In [1]:
import scanpy as sc
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import matplotlib as mpl
import os
import math
import seaborn as sns
import anndata as ad
import hdf5plugin
import pandas as pd
import sys
import time

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
hahn_spatial_assay_anndata = ad.read_h5ad("hahn_spatial_assay_anndata")

# Filter for specific Age and Slice
six_month_anndata = hahn_spatial_assay_anndata[hahn_spatial_assay_anndata.obs['age'] == 'M6']
six_month_slice1_3_3_anndata = six_month_anndata[six_month_anndata.obs['slice'] == 'slice1_3_3']

In [35]:
# define the size of each pixel
pixel_dim = 10

# create new columns with pixel coordinates
six_month_slice1_3_3_anndata.obs['pixelrow'] = pd.cut(six_month_slice1_3_3_anndata.obs['imagerow'], 
                                                bins=int(six_month_slice1_3_3_anndata.obs['imagerow'].max() / pixel_dim) + 1, 
                                                labels=False) * pixel_dim

six_month_slice1_3_3_anndata.obs['pixelcol'] = pd.cut(six_month_slice1_3_3_anndata.obs['imagecol'], 
                                                bins=int(six_month_slice1_3_3_anndata.obs['imagecol'].max() / pixel_dim) + 1, 
                                                labels=False) * pixel_dim

# Correction 
six_month_slice1_3_3_anndata.obs['pixelcol'] /= pixel_dim
six_month_slice1_3_3_anndata.obs['pixelrow'] /= pixel_dim
six_month_slice1_3_3_anndata.obs['pixelcol'] = six_month_slice1_3_3_anndata.obs['pixelcol'].astype(int)
six_month_slice1_3_3_anndata.obs['pixelrow'] = six_month_slice1_3_3_anndata.obs['pixelrow'].astype(int)

unique_row_pixel_vals = six_month_slice1_3_3_anndata.obs['pixelrow'].unique()
unique_row_pixel_vals.sort()
num_pixel_row = len(unique_row_pixel_vals)

unique_col_pixel_vals = six_month_slice1_3_3_anndata.obs['pixelcol'].unique()
unique_col_pixel_vals.sort()
num_pixel_height = len(unique_col_pixel_vals)

print("# Pixel for Rows:", num_pixel_row)
print("# Pixel for Heights:", num_pixel_height)
print("# Pixels in total:", num_pixel_row*num_pixel_height)
six_month_slice1_3_3_anndata.obs.head()

# Pixel for Rows: 48
# Pixel for Heights: 45
# Pixels in total: 2160


Unnamed: 0_level_0,orig.ident,nCount_Spatial,nFeature_Spatial,nCount_SCT,nFeature_SCT,sampleID,age,integrated_snn_res.0.8,seurat_clusters,clusterLevel,regionLevel,imagerow,imagecol,slice,imagerow_rotated_v2,imagecol_rotated_v2,pixelrow,pixelcol
Cell,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1
AAACAAGTATCTCCCA-1_2_1,SeuratProject,60077,8570,25034,6630,Visium_Young_R02_S1,M6,7,7,Layer II,Cortex,360.498059,429.638687,slice1_3_3,172.851568,373.754896,32,42
AAACACCAATAACTGC-1_2_1,SeuratProject,20316,5556,23429,5556,Visium_Young_R02_S1,M6,5,5,Thalamus 2,Thalamus,413.964858,156.665044,slice1_3_3,251.879891,373.168958,39,8
AAACAGCTTTCAGAAG-1_2_1,SeuratProject,11740,3779,23189,4597,Visium_Young_R02_S1,M6,0,0,White matter,White matter,322.485363,123.120121,slice1_3_3,148.388677,173.364264,26,3
AAACAGGGTCTATATT-1_2_1,SeuratProject,12775,3995,23573,4387,Visium_Young_R02_S1,M6,6,6,Globus pallidus,Globus pallidus,345.336926,136.4502,slice1_3_3,326.660167,229.394539,30,5
AAACATGGTGAGAGGA-1_2_1,SeuratProject,23045,6225,23741,6225,Visium_Young_R02_S1,M6,2,2,Hypothalamus,Hypothalamus,431.616226,94.189456,slice1_3_3,326.806652,252.319345,42,0


In [36]:
"""
code should create a pandas DataFrame df with three-level row indices (corresponding to the gene name, col1, and col2) and 
columns corresponding to the gene names. The values in the DataFrame are the corresponding X values for each gene and each 
combination of col1 and col2.
"""

# assume that your AnnData object is called `adata`
# first, extract the gene names and observation column values you're interested in
np.where(six_month_slice1_3_3_anndata.X.any(axis=0))
gene_col_index_lst = np.where(six_month_slice1_3_3_anndata.X.any(axis=0))[0]
gene_col_index_lst.sort()

# create an empty 3D numpy array to store the data
gene_expression_2D_array = np.empty((len(gene_col_index_lst), num_pixel_row, num_pixel_height))
gene_expression_2D_array.shape

(21364, 48, 45)

In [33]:
start_time = time.time()

# define the range of values for each column
col1_range = range(0, num_pixel_row)
col2_range = range(0, num_pixel_height)

# create a new dataframe with all possible combinations of values for the two columns
all_combin_row_and_col_coord_df = pd.DataFrame([(i, j) for i in col1_range for j in col2_range], columns=['pixelrow', 'pixelcol'])

index_to_gene_index_dict = {}

# loop over the genes and fill in the corresponding X values
count = 0
for gene_index in gene_col_index_lst:
    index_to_gene_index_dict[count] = gene_index
    gene_pixel_coords_df = six_month_slice1_3_3_anndata[:, gene_index].obs[["pixelrow", "pixelcol"]].copy()
    gene_pixel_coords_df["X_val"] = six_month_slice1_3_3_anndata[gene_pixel_coords_df.index, gene_index].X.flatten().tolist().copy()
    gene_pixel_coords_mean_df = gene_pixel_coords_df.groupby(["pixelrow", "pixelcol"]).mean().reset_index()
    
    # merge the all_combin_row_and_col_coord_df with the original dataframe and fill missing values with 0 in X_val
    gene_expression_2D_array[count,:,:] = pd.merge(gene_pixel_coords_mean_df, all_combin_row_and_col_coord_df, 
                         on=['pixelrow', 'pixelcol'], how='right').fillna({'X_val': 0}).pivot(index='pixelrow', 
                                                                                              columns='pixelcol', 
                                                                                              values='X_val').copy().to_numpy()
    if count % 1000 == 0:
        print("Finished processing this many genes: ", count)
    count = count + 1
    
print()

end_time = time.time()
elapsed_time = end_time - start_time
print("Elapsed time: {:.2f} seconds".format(elapsed_time))
print("Elapsed time: {:.2f} minutes".format(elapsed_time/60))
print("Elapsed time: {:.2f} hours".format(elapsed_time/360))

Finished processing this many genes:  0
Finished processing this many genes:  1000
Finished processing this many genes:  2000
Finished processing this many genes:  3000
Finished processing this many genes:  4000
Finished processing this many genes:  5000
Finished processing this many genes:  6000
Finished processing this many genes:  7000
Finished processing this many genes:  8000
Finished processing this many genes:  9000
Finished processing this many genes:  10000
Finished processing this many genes:  11000
Finished processing this many genes:  12000
Finished processing this many genes:  13000
Finished processing this many genes:  14000
Finished processing this many genes:  15000
Finished processing this many genes:  16000
Finished processing this many genes:  17000
Finished processing this many genes:  18000
Finished processing this many genes:  19000
Finished processing this many genes:  20000
Finished processing this many genes:  21000

Elapsed time: 307.92 seconds
Elapsed time: 5

In [None]:
# Create the directory if it doesn't already exist
parent_dirname = "pixel_spatial_gene_expression_version_1"
if not os.path.exists(parent_dirname):
    os.makedirs(parent_dirname)

child_dirname = parent_dirname + '/' + "pixel_dim_" + str(pixel_dim)
if not os.path.exists(child_dirname):
    os.makedirs(child_dirname)

# save the 4-dimensional NumPy array to a NPY file
np.save(child_dirname + '/' + 'gene_expression_2D_array.npy', gene_expression_2D_array)

In [None]:
start_time = time.time()

# Get the shape of the array
shape = gene_expression_2D_array.shape

# Get the indices for each dimension of the array
idx0, idx1, idx2 = np.indices(shape)

# Reshape the arrays to 1D arrays
idx0 = idx0.ravel()
idx1 = idx1.ravel()
idx2 = idx2.ravel()
val = gene_expression_2D_array.ravel()

# Create a dataframe using the index arrays and values
gene_loc_val_df = pd.DataFrame({'2D Array Index': idx0, 'pixelrow': idx1, 'pixelcol': idx2, 'value': val})
gene_loc_val_df["AnnData Gene Index"] = gene_loc_val_df["2D Array Index"].map(index_to_gene_index_dict)
gene_loc_val_df["Gene"] = gene_loc_val_df["AnnData Gene Index"].map(dict(enumerate(six_month_slice1_3_3_anndata.var.index)))


end_time = time.time()
elapsed_time = end_time - start_time
print("Elapsed time: {:.2f} seconds".format(elapsed_time))
print("Elapsed time: {:.2f} minutes".format(elapsed_time/60))
print("Elapsed time: {:.2f} hours".format(elapsed_time/360))
display(gene_loc_val_df)

In [None]:
start_time = time.time()

# save the Pandas DataFrame to a CSV file
# gene_loc_val_df.to_csv(child_dirname + '/' + 'gene_loc_val.csv', index=False)

end_time = time.time()
elapsed_time = end_time - start_time
print("Elapsed time: {:.2f} seconds".format(elapsed_time))
print("Elapsed time: {:.2f} minutes".format(elapsed_time/60))
print("Elapsed time: {:.2f} hours".format(elapsed_time/360))

In [None]:
gene_of_interest = "C4b"
specific_gene_df = gene_loc_val_df[gene_loc_val_df["Gene"]=="C4b"]

In [None]:
specific_gene_coord_val_df = specific_gene_df.groupby(["pixelrow", "pixelcol"]).mean()
specific_gene_coord_val_df.reset_index(inplace=True)
specific_gene_coord_val_df = specific_gene_coord_val_df[["pixelrow", "pixelcol", "value"]]
specific_gene_coord_val_df

In [None]:
# Reshape the dataframe using pivot
matrix = specific_gene_coord_val_df.pivot(index='pixelcol', columns='pixelrow', values='value')

# Create the plot
fig, ax = plt.subplots()
im = ax.imshow(matrix, cmap='viridis', label='Gene Expression', origin='lower')

# Add a colorbar
cbar = ax.figure.colorbar(im, ax=ax)

graph_title = gene_of_interest + ' Gene Expression Level for 6-Month, Slice 1_3_3 \n in Pixel Coordinates' + " w/ Pixel dim = " + str(pixel_dim)

plt.xlabel('Pixel Horizontal Coordinate')
plt.ylabel('Pixel Vertical Coordinate')
plt.title(graph_title)

plt.savefig(child_dirname + "/" + gene_of_interest + ' Gene Expression Level for 6-Month, Slice 1_3_3 in Pixel Coordinates' + " with pixel dim = " + str(pixel_dim) + '.png', dpi=300, bbox_inches='tight', pad_inches=0, transparent=False)
plt.show()
