In [None]:
# %pip install GraphST
# %pip install POT

# GraphST Cell Deconvolution

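This notebook estimates the cell-type composition of Visium spots. It maps an annotated single-cell reference onto the tissue with GraphST, then projects the reference cell-type labels onto each spot.

Adapted from the GraphST tutorial on scRNA-seq and spatial transcriptomics data integration:
https://deepst-tutorials.readthedocs.io/en/latest/Tutorial%202_scRNA%20and%20ST%20data%20integration.html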

In [None]:
import os
# import squidpy as sq
import tangram as tg
import matplotlib.pyplot as plt
import anndata as ad
import scanpy as sc
import pandas as pd
import numpy as np
import json
from typing import Dict, List, Optional, Union
from matplotlib.pyplot import imread
import liana as li
import decoupler as dc
import omnipath
from tqdm import tqdm
import torch
from GraphST import GraphST
from GraphST.preprocess import filter_with_overlap_gene
from GraphST.utils import project_cell_to_spot



from gbmhackathon.utils.visium_functions import (
    normalize_anndata_wrapper,
    convert_obsm_to_adata
)
from gbmhackathon.viz.visium_functions import (
    plot_spatial_expression,
    plot_obsm
)
from gbmhackathon.stats.visium_functions import (
    perform_multi_clustering,
    quantify_cell_population_activity
)
from gbmhackathon import MosaicDataset

import matplotlib as mpl
# Render figures inline at a moderate DPI. At 1200 dpi the multi-panel spatial
# plot at the end of the notebook becomes an enormous inline image that bloats
# the notebook and memory. Publication resolution is applied only when a figure
# is saved to disk.
mpl.rcParams['figure.dpi'] = 150
mpl.rcParams['savefig.dpi'] = 1200

In [None]:
# Load the Visium samples with their high-resolution tissue images.
# Only one of them is deconvolved further down (see `sample_key`).
samples_to_load = ["HK_G_022a_vis", "HK_G_024a_vis", "HK_G_030a_vis"]
visium_dict = MosaicDataset.load_visium(sample_list=samples_to_load, resolution="hires")

In [None]:
output_file = "data/mosaic_dataset/single_cell/preprocessed/sc_merged_annotated.h5ad"

if os.path.exists(output_file):
    # Reuse the locally cached single-cell reference.
    adata_sc = sc.read_h5ad(output_file)
    print(f"Loaded AnnData from {output_file}")
else:
    # First run: load the reference from MosaicDataset and cache it locally.
    adata_sc = MosaicDataset.load_singlecell()
    # The cache directory may not exist yet on a fresh checkout. Create it so
    # the write does not fail after the (slow) load has already completed.
    os.makedirs(os.path.dirname(output_file), exist_ok=True)
    adata_sc.write(output_file)
    print(f"Single-cell AnnData successfully saved to {output_file}")

In [None]:
# Normalize ST data, e.g., to 1e6 counts
visium_obj = normalize_anndata_wrapper(visium_dict, target_sum=1e6)

sample_key = "HK_G_022a_vis"
adata_st = visium_obj[sample_key]

adata_st.var_names_make_unique()

In [None]:
# Trim both AnnData objects down to what the rest of this notebook uses,
# to reduce memory before GraphST training.


def _drop_keys(container, keys):
    """Delete each key in *keys* from *container* if present; absent keys are skipped."""
    for key in keys:
        if key in container:
            del container[key]


# --- Spatial (Visium) object ---
# AnnData always exposes a `.raw` attribute (None when unset), so hasattr() is
# always true; test the value instead.
if adata_st.raw is not None:
    del adata_st.raw
_drop_keys(adata_st.layers, ['raw', 'CPM', 'log_CPM'])
_drop_keys(adata_st.obsm, ['distance_matrix', 'graph_neigh'])
# NOTE: adata_st.uns['spatial'] is intentionally KEPT. It holds the tissue image
# and scale factors that sc.pl.spatial(..., img_key="hires") reads in the final
# plotting cell. Without it, that call has no image to draw and errors unless
# spot_size is passed explicitly.

# --- Single-cell reference ---
if adata_sc.raw is not None:
    del adata_sc.raw
_drop_keys(adata_sc.layers, ['LogNormalize', 'ambient_rna_removed'])
_drop_keys(adata_sc.obsm, ['X_pca', 'X_scanvi', 'X_scvi', 'X_umap'])
# Neighbor graphs and embedding metadata; no later cell in this notebook reads them.
_drop_keys(adata_sc.obsp, ['distances', 'connectivities'])
_drop_keys(adata_sc.uns, ['neighbors', 'umap', 'scanvi_probs'])

# Downcast expression matrices to float32 to reduce memory
# (a saving only if they were stored at higher precision).
adata_st.X = adata_st.X.astype(np.float32)
adata_sc.X = adata_sc.X.astype(np.float32)

In [None]:
# Cache the cleaned single-cell reference on disk. If a cached copy exists,
# reload it (this replaces the in-memory object with the cached one).
# Otherwise write the freshly cleaned object, so a fresh kernel / fresh
# checkout does not fail on a missing file.
sc_cleaned_path = "vdata/adata_sc_cleaned.h5ad"
if os.path.exists(sc_cleaned_path):
    adata_sc = sc.read_h5ad(sc_cleaned_path)
else:
    os.makedirs(os.path.dirname(sc_cleaned_path), exist_ok=True)
    adata_sc.write(sc_cleaned_path)

In [None]:
obs_columns_to_remove = ["in_tissue", "array_row", "array_col"]
var_columns_to_remove = ["feature_types", "genome"]

# Drop metadata columns that no later cell uses. Columns that are absent are
# skipped (errors="ignore"). Each frame is rebuilt and reassigned rather than
# mutated in place.
adata_st.obs = adata_st.obs.drop(columns=obs_columns_to_remove, errors="ignore")
adata_st.var = adata_st.var.drop(columns=var_columns_to_remove, errors="ignore")
adata_sc.obs = adata_sc.obs.drop(
    columns=["doublet_score_scdblfinder", "doublet_predicted_scdblfinder"],
    errors="ignore",
)

In [None]:
# Cap the size of the single-cell reference to limit memory use during mapping.
n_cells_to_keep = 250000  # maximum number of reference cells kept
subsample_seed = 46  # fixed seed so the subsample (and thus the mapping) is reproducible
if adata_sc.n_obs > n_cells_to_keep:
    rng = np.random.default_rng(subsample_seed)
    # Sort the sampled positions so the kept cells stay in their original order.
    idx = np.sort(rng.choice(adata_sc.n_obs, n_cells_to_keep, replace=False))
    adata_sc = adata_sc[idx, :].copy()

In [None]:
# Train on the first CUDA device when available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Spatial data: GraphST preprocessing followed by its graph-construction steps
# (same order as the GraphST scRNA/ST integration tutorial).
GraphST.preprocess(adata_st)
GraphST.construct_interaction(adata_st)
GraphST.add_contrastive_label(adata_st)

# Reference data: preprocess, then restrict both objects to their shared genes
# before extracting the spatial feature matrix.
# NOTE(review): GraphST.preprocess is understood to pick HVGs with
# flavor="seurat_v3", which expects raw counts. Confirm .X still holds counts here.
GraphST.preprocess(adata_sc)
adata_st, adata_sc = filter_with_overlap_gene(adata_st, adata_sc)
GraphST.get_feature(adata_st)

# Deconvolution mode learns a cell-to-spot mapping; train_map returns both
# updated objects.
graphst_kwargs = dict(epochs=700, random_seed=46, device=device, deconvolution=True)
model = GraphST.GraphST(adata_st, adata_sc, **graphst_kwargs)
adata_st, adata_sc = model.train_map()

In [None]:
# Use the level-2 scANVI annotation as the reference labels. project_cell_to_spot
# presumably reads them from obs['cell_type'] (hence the copy under that name).
level2_labels = adata_sc.obs['celltype_level2_scanvi']
adata_sc.obs['cell_type'] = level2_labels.copy()

# Project cell-type labels onto spots, keeping the top 15% of mapping values.
top_fraction_retained = 0.15
project_cell_to_spot(adata_st, adata_sc, retain_percent=top_fraction_retained)

print(adata_st)

In [None]:
# Cell types to map onto the tissue. These must match obs columns created by
# project_cell_to_spot (presumably the categories of celltype_level2_scanvi).
cell_types_to_plot = [
    'Immune', 'Neuroglia', 'Neuron', 'Stromal', 'Endothelial',
    'Granulocyte', 'Malignant_gbm', 'MoMac', 'Oligodendrocyte', 'T_NK',
]

# Dark background so the magma colormap stands out; color range clipped at the
# 99.2th percentile to limit the effect of outlier spots.
with mpl.rc_context({'axes.facecolor': 'black', 'figure.figsize': [5, 5]}):
    sc.pl.spatial(
        adata_st,
        color=cell_types_to_plot,
        img_key="hires",
        cmap='magma',
        ncols=5,
        size=1.1,
        vmin=0,
        vmax="p99.2",
        show=True,
    )