In [None]:
import scanpy as sc
import squidpy as sq
import numpy as np
import pandas as pd
from anndata import AnnData
import pathlib
import matplotlib.pyplot as plt
import matplotlib as mpl
#import skimage
import seaborn as sns
import tangram as tg
#from scipy.spatial import KDTree

# Print scanpy's header of package versions so the run environment is recorded in the output.
sc.logging.print_header()
print(f"squidpy=={sq.__version__}")

# NOTE(review): autoreload only affects locally edited modules; every import above is a
# third-party/stdlib package, so these two magics are probably unnecessary.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
# Load the B2 Visium capture area from its filtered count matrix, including the tissue images.
visium_dir = 'Spatial_data/HYPOMAPErno/B2/'
adata_spatial = sq.read.visium(
    visium_dir,
    counts_file='filtered_feature_bc_matrix.h5',
    library_id=None,
    load_images=True,
)
adata_spatial

In [None]:
# Restrict the Visium object to spots annotated in the HypoMap spatial metadata for
# capture area 'b2', attach that metadata, rotate coordinates and log-normalise.
# NOTE: not idempotent — re-running re-rotates the coordinates and re-applies log1p;
# re-run the loading cell above first.
spatial_meta = pd.read_csv('Data/hypomap_spatial_metadata.csv', index_col=0)
# Strip the last '_'-separated token from each barcode so it matches the Visium
# obs_names (presumably a sample/capture-area suffix — confirm against the CSV).
spatial_meta.index = [barcode.rsplit('_', 1)[0] for barcode in spatial_meta.index]
temp_df = spatial_meta[spatial_meta.captureArea == 'b2']

# .copy() materialises the subset; assigning .obs on an AnnData view would otherwise
# trigger an implicit view-to-actual conversion with an ImplicitModificationWarning.
adata_spatial = adata_spatial[temp_df.index].copy()
adata_spatial.obs = pd.concat([adata_spatial.obs, temp_df], axis=1)

adata_spatial.var_names_make_unique()

# Rotate spot coordinates: (x, y) -> (-y, x). Vectorised; also keeps shape (0, 2) when empty.
coords = np.asarray(adata_spatial.obsm['spatial'])
adata_spatial.obsm['spatial'] = np.column_stack([-coords[:, 1], coords[:, 0]])

sc.pp.normalize_total(adata_spatial)
sc.pp.log1p(adata_spatial)

adata_spatial

In [None]:
# Load the single-cell reference and align its metadata with the spatial object.
# NOTE(review): this path is relative to the parent directory ('../Data/...'), while the
# spatial inputs are read from 'Data/...' and 'Spatial_data/...' — confirm the intended
# working directory.
adata_bmp = sc.read('../Data/SC/parse_annotated_late_stage.h5ad')

# Give the spatial object the same annotation column name as the reference.
adata_spatial.obs['Cell_types'] = adata_spatial.obs['regional_clusters_grouped']
adata_spatial.obs['bmp_treatment'] = '_'  # presumably a placeholder for spatial spots

adata_spatial.obs['source'] = 'Hypomap'
adata_bmp.obs['source'] = 'Zehra'

# .copy() materialises the subset instead of keeping an AnnData view; this object is
# passed to tg.pp_adatas, which (per tangram docs) modifies its inputs in place.
adata_bmp = adata_bmp[adata_bmp.obs.Cell_types != 'Telencephalic neurons'].copy()
# If Cell_types is categorical, the removed label survives as an empty category; drop it
# so cluster-level steps that enumerate categories do not see an empty cluster.
if isinstance(adata_bmp.obs['Cell_types'].dtype, pd.CategoricalDtype):
    adata_bmp.obs['Cell_types'] = adata_bmp.obs['Cell_types'].cat.remove_unused_categories()

adata_bmp


In [None]:
# Differential-expression marker genes of the late-stage reference, used as Tangram
# training genes.
gene_df = pd.read_excel("DE_genes/parse_late_deg_annotations.xlsx", index_col=0)
# Exclude markers of the cluster that is also filtered out of the reference.
gene_df = gene_df[gene_df.cluster != 'Telencephalic neurons']

# A gene can be a marker of several clusters: keep each name once. Missing names are
# dropped because tg.pp_adatas lower-cases every gene name, which fails on NaN.
genes = gene_df['gene'].dropna().astype(str).unique()

gene_df

In [None]:
import torch

# Restrict both objects to the shared marker genes (tangram updates them in place).
tg.pp_adatas(adata_bmp, adata_spatial, genes=genes)

# Map reference cell-type clusters onto the spatial spots.
# NOTE(review): density_prior='rna_count_based' presumably derives spot density from
# adata_spatial.X, which was log-normalised earlier — confirm raw counts are not intended.
ad_map = tg.map_cells_to_space(
    adata_bmp,
    adata_spatial,
    mode="clusters",
    cluster_label='Cell_types',
    density_prior='rna_count_based',
    num_epochs=500,
    # Fall back to CPU so the notebook still runs on machines without a CUDA GPU.
    device='cuda' if torch.cuda.is_available() else 'cpu',
    # Fixed nonzero seed so the mapping, and every figure derived from it, is reproducible.
    random_state=42,
)

# Writes per-spot cell-type probabilities to adata_spatial.obsm['tangram_ct_pred'].
tg.project_cell_annotations(ad_map, adata_spatial, annotation='Cell_types')

In [None]:
# Keep only the strongest Tangram prediction per cell type: within each column, values
# below that column's 80th percentile are set to 0 (roughly the top 20% of spots survive).
ct_pred = adata_spatial.obsm['tangram_ct_pred']
pred_values = ct_pred.to_numpy()

keep_percentile = 80
column_cutoffs = np.percentile(pred_values, keep_percentile, axis=0)

adata_spatial.obsm['filtered_ct_pred'] = pd.DataFrame(
    np.where(pred_values >= column_cutoffs, pred_values, 0),
    index=ct_pred.index,
    columns=ct_pred.columns,
)

In [None]:
from sklearn.preprocessing import minmax_scale

# Rescale each filtered cell-type column (minmax_scale works per column) to [0, 1], so all
# spatial maps share one colour scale.
filtered_pred = adata_spatial.obsm['filtered_ct_pred']
adata_spatial.obsm['filtered_ct_pred'] = pd.DataFrame(
    minmax_scale(filtered_pred.values),
    index=adata_spatial.obs.index,
    columns=filtered_pred.columns,
)

# Expose the scaled scores as obs columns so they can be used as `color=` keys when plotting.
# Same-named columns are dropped first: otherwise re-running this cell (or a cell-type name
# that collides with an existing obs column) silently creates duplicate column labels.
ct_columns = adata_spatial.obsm['filtered_ct_pred'].columns
adata_spatial.obs = pd.concat(
    [
        adata_spatial.obs.drop(columns=ct_columns, errors='ignore'),
        adata_spatial.obsm['filtered_ct_pred'],
    ],
    axis=1,
)

In [None]:
# Save one spatial map per reference cell type, coloured by its scaled, filtered Tangram score.
spatial_fig_dir = pathlib.Path('figures/Spatial')
spatial_fig_dir.mkdir(parents=True, exist_ok=True)  # savefig raises if the directory is missing

with plt.rc_context({"figure.dpi": 300}):
    for cell_type in list(pd.unique(adata_bmp.obs['Cell_types'])):
        sq.pl.spatial_scatter(
            adata_spatial,
            shape=None,
            frameon=False,
            color=cell_type,
            size=6,
            cmap='jet',
            colorbar=False,  # a shared colour bar is drawn separately
        )
        plt.tight_layout()
        # '/' in a cell-type name would otherwise be interpreted as a sub-directory.
        safe_name = str(cell_type).replace("/", "-")
        plt.savefig(spatial_fig_dir / f'{safe_name}_spatial.pdf')
        plt.show()
    


In [None]:
# Standalone vertical 'jet' colour bar spanning 0–1, matching the spatial maps, which are
# drawn with cmap='jet' and colorbar=False.
with plt.rc_context({"figure.dpi": 300}):
    fig = plt.figure(figsize=(0.3, 2))

    # A throwaway 0–1 image supplies the colour mapping; its own axes is hidden.
    img_ax = fig.add_subplot()
    img = img_ax.imshow(np.array([[0, 1]]), cmap="jet")
    img_ax.set_visible(False)

    # Colour bar on a dedicated axes, with ticks only at the two ends of the range.
    cax = fig.add_axes([0.1, 0.2, 0.8, 0.6])
    cbar = fig.colorbar(img, cax=cax, orientation="vertical")
    cbar.set_ticks([0, 1])
    cbar.set_ticklabels(["0", "1"], size=15)

    pathlib.Path("figures").mkdir(parents=True, exist_ok=True)  # savefig needs the directory
    fig.savefig("figures/jet_colorbar.pdf", dpi=450, bbox_inches='tight')
    plt.show()