# In [None]:
from typing import Any
from anndata import AnnData

# Load packages and classes
import anndata as ad
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import tiffslide
import seaborn as sns
import gget
import tifffile
import zarr
import scanpy as sc
import tempfile

# MosaicDataset and BruceDataset classes allow loading and visualisation of the different data sources
from gbmhackathon import MosaicDataset

import scvi

# Load the low-resolution Visium data (one entry per sample).
visium_obj = MosaicDataset.load_visium(
#     sample_list=["HK_G_022a_vis", "HK_G_024a_vis", "HK_G_030a_vis"],  # remove this argument to load all available samples
    resolution="lowres"
)

# Concatenate every sample into one AnnData; the originating key of each
# observation is recorded in adata.obs["batch"]. The result is cached on disk.
adata = ad.concat(
    list(visium_obj.values()),
    label="batch",
    keys=list(visium_obj.keys()),
)
adata.write("adata_spatial_raw.h5ad")

# Reload the cached file fully into memory. Backed read-only mode
# (backed="r") keeps X on disk, which is incompatible with the in-place gene
# filtering and the adata.X.copy() performed by the preprocessing step below.
adata = sc.read_h5ad("adata_spatial_raw.h5ad")

# Sample identifiers assigned to the training split. Further down they are
# matched against adata.obs["batch"] to select the observations used to train
# the model.
# NOTE(review): the commented-out sample_list passed to load_visium uses a
# "_vis" suffix (e.g. "HK_G_022a_vis"); confirm that the batch labels really
# use the unsuffixed form listed here, otherwise the filter selects nothing.
train_split = ['HK_G_002a',
 'HK_G_003a',
 'HK_G_004a',
 'HK_G_005a',
 'HK_G_006a',
 'HK_G_007a',
 'HK_G_008a',
 'HK_G_009a',
 'HK_G_010a',
 'HK_G_012a',
 'HK_G_013a',
 'HK_G_015a',
 'HK_G_016a',
 'HK_G_017b',
 'HK_G_018a',
 'HK_G_020a',
 'HK_G_022a',
 'HK_G_023a',
 'HK_G_024a',
 'HK_G_025a',
 'HK_G_026a',
 'HK_G_028a',
 'HK_G_029b',
 'HK_G_030a',
 'HK_G_031a',
 'HK_G_032a',
 'HK_G_033a',
 'HK_G_034a',
 'HK_G_035a',
 'HK_G_036b',
 'HK_G_037a',
 'HK_G_038a',
 'HK_G_039a',
 'HK_G_040a',
 'HK_G_043a',
 'HK_G_044b',
 'HK_G_046a',
 'HK_G_047a',
 'HK_G_049a',
 'HK_G_050a',
 'HK_G_051a',
 'HK_G_052a',
 'HK_G_053a',
 'HK_G_054a',
 'HK_G_058a',
 'HK_G_059b',
 'HK_G_060a',
 'HK_G_061b',
 'HK_G_062a',
 'HK_G_063a',
 'HK_G_064a',
 'HK_G_065a',
 'HK_G_067a',
 'HK_G_068a',
 'HK_G_069a',
 'HK_G_070a',
 'HK_G_071a',
 'HK_G_072a',
 'HK_G_073a',
 'HK_G_074a',
 'HK_G_075a',
 'HK_G_078a',
 'HK_G_079b',
 'HK_G_080a',
 'HK_G_081a',
 'HK_G_082b',
 'HK_G_083a',
 'HK_G_084b',
 'HK_G_085a',
 'HK_G_086b',
 'HK_G_087a',
 'HK_G_088a',
 'HK_G_089a',
 'HK_G_090b',
 'HK_G_091a',
 'HK_G_092b',
 'HK_G_093a',
 'HK_G_095a',
 'HK_G_096b',
 'HK_G_099a',
 'HK_G_100b',
 'HK_G_101a',
 'HK_G_102a',
 'HK_G_104a',
 'HK_G_105b',
 'HK_G_106a',
 'HK_G_108a',
 'HK_G_109b',
 'HK_G_110a',
 'HK_G_111b',
 'HK_G_112a',
 'HK_G_113b',
 'HK_G_114a',
 'HK_G_115b']

# Sample identifiers held out as the test split. The list is defined here but
# is not referenced anywhere in the visible part of this script.
test_split = ['HK_G_001a',
 'HK_G_011a',
 'HK_G_014a',
 'HK_G_019a',
 'HK_G_021a',
 'HK_G_027a',
 'HK_G_041a',
 'HK_G_042a',
 'HK_G_045a',
 'HK_G_048a',
 'HK_G_055a',
 'HK_G_056a',
 'HK_G_057a',
 'HK_G_066a',
 'HK_G_076a',
 'HK_G_077a',
 'HK_G_094a',
 'HK_G_098b',
 'HK_G_103a',
 'HK_G_107a']

def preprocess_adata(
    adata: AnnData,
    *,
    min_cells: int = 10,
    target_sum: float = 1e4,
    n_top_genes: int = 2000,
    batch_key: str = "batch",
) -> AnnData:
    """Run QC, normalisation and highly-variable-gene selection on ``adata``.

    The object is modified in place and also returned for convenience.

    Steps, in order:

    1. Make gene names unique, flag genes whose name starts with ``"MT-"`` in
       ``adata.var["mt"]`` and compute QC metrics for them.
    2. Drop genes detected in fewer than ``min_cells`` observations.
    3. Keep the current ``X`` in ``adata.layers["counts"]``.
    4. Normalise every observation to ``target_sum`` total counts and apply
       ``log1p``; the result is stored in ``adata.raw`` before any gene
       subsetting.
    5. Subset to the ``n_top_genes`` most variable genes, computed on the
       ``"counts"`` layer with the ``seurat_v3`` flavor, per ``batch_key``.

    Args:
        adata: Data to preprocess. Must be held in memory and contain the
            column ``batch_key`` in ``adata.obs``.
        min_cells: Minimum number of observations a gene must be detected in.
        target_sum: Total count each observation is normalised to.
        n_top_genes: Number of highly variable genes to keep.
        batch_key: ``adata.obs`` column used as batch for gene selection.

    Returns:
        The same ``adata`` object, restricted to the selected genes.
    """
    adata.var_names_make_unique()
    adata.var["mt"] = adata.var_names.str.startswith("MT-")
    sc.pp.calculate_qc_metrics(adata, qc_vars=["mt"], inplace=True)

    # Filtering
    sc.pp.filter_genes(adata, min_cells=min_cells)
    # Preserve the un-normalised matrix: it is the input of the HVG selection
    # below and of the scVI model trained later on the "counts" layer.
    adata.layers["counts"] = adata.X.copy()

    # Normalize data
    sc.pp.normalize_total(adata, target_sum=target_sum)
    sc.pp.log1p(adata)
    # Snapshot taken before subsetting, so .raw keeps every filtered gene.
    adata.raw = adata

    # Select HVG (computed on the counts layer, not on the log-normalised X)
    sc.pp.highly_variable_genes(
        adata,
        n_top_genes=n_top_genes,
        subset=True,
        layer="counts",
        flavor="seurat_v3",
        batch_key=batch_key,
    )

    return adata

# Preprocess the concatenated data; preprocess_adata modifies its argument and
# returns the same object, so rebinding `adata` is only for readability.
adata = preprocess_adata(adata)

def train_and_save(adata: AnnData, max_epochs: int = 1) -> Any:
    """Register ``adata`` with scVI and train an SCVI model on it.

    The model reads the ``"counts"`` layer and uses ``adata.obs["batch"]`` as
    a categorical covariate. Despite the function name, nothing is written to
    disk here; the caller receives the trained model and can persist it.

    Args:
        adata: In-memory data with a ``"counts"`` layer and a ``"batch"``
            column in ``adata.obs``.
        max_epochs: Number of training epochs. Defaults to 1, i.e. a quick
            smoke-test run rather than a converged model.

    Returns:
        The trained ``scvi.model.SCVI`` instance.
    """
    scvi.model.SCVI.setup_anndata(
        adata,
        layer="counts",
        categorical_covariate_keys=["batch"],
    )
    model = scvi.model.SCVI(adata)
    # The training length is controlled by ``max_epochs``; ``epoch`` is not a
    # parameter of ``train`` and would be forwarded to the trainer as an
    # unknown keyword argument.
    model.train(max_epochs=max_epochs)
    return model

# Notebook-style inspection of the batch labels; no effect when run as a script.
adata.obs["batch"]

# Train only on observations whose batch label belongs to the training split.
adata_train = adata[adata.obs["batch"].isin(train_split)].copy()
if adata_train.n_obs == 0:
    raise ValueError(
        "No observation matches train_split; check that adata.obs['batch'] "
        "uses the same sample identifiers."
    )
model = train_and_save(adata_train)

SCVI_LATENT_KEY = "X_scVI"

# Called without arguments, get_latent_representation() embeds the AnnData the
# model was built from (adata_train), so it returns one row per training
# observation. It therefore belongs on adata_train: assigning it to the full
# `adata` fails as soon as the split filter removed any observation.
adata_train.obsm[SCVI_LATENT_KEY] = model.get_latent_representation()

# Work on a random subset to keep neighbours/UMAP tractable; never request
# more observations than are available.
n_sub = min(50_000, adata_train.n_obs)
adata_sub = sc.pp.subsample(adata_train, n_obs=n_sub, copy=True)

# UMAP on the PCA of the log-normalised expression (no batch correction).
sc.pp.pca(adata_sub)
sc.pp.neighbors(adata_sub)
sc.tl.umap(adata_sub)
# Keep this embedding under its own key: the next sc.tl.umap call overwrites
# obsm["X_umap"].
adata_sub.obsm["X_umap_raw"] = adata_sub.obsm["X_umap"].copy()

# UMAP on the scVI latent space.
sc.pp.neighbors(adata_sub, use_rep=SCVI_LATENT_KEY)
sc.tl.umap(adata_sub, min_dist=0.3)

# Side by side: uncorrected embedding (left) vs. scVI embedding (right).
fig, axes = plt.subplots(1, 2, figsize=(12, 6))
sc.pl.embedding(adata_sub, basis="X_umap_raw", color="batch", show=False, ax=axes[0])
sc.pl.embedding(adata_sub, basis="X_umap", color="batch", show=False, ax=axes[1])
plt.show()

# Marker gene sets, one score per set is added to adata_sub.obs.
# Genes are listed by official gene symbol because sc.tl.score_genes ignores
# names it cannot find: "HER2" and "CD20" are protein aliases, the
# corresponding gene symbols are ERBB2 and MS4A1.
dict_pathway = {
    "ECM_remodeling": ["COL1A1", "COL3A1", "FN1", "MMP2"],
    "Lymphocytes": ["CD3E", "CD8A", "CD4", "CD19", "MS4A1", "CD79A"],
    "TAMs": [
        "CCL4", "ADRB2", "NAV3", "ADORA3", "SIGLEC8", "SPRY1", "TAL1",
        "RHOB", "BIN1", "SALL1", "KLF2", "BHLHE41", "SLC1A3", "P2RY12",
    ],
    "Tumor_cells": ["EGFR", "ERBB2", "MKI67", "VEGFA", "CD44", "GPC3"],
}
for pathway, genes in dict_pathway.items():
    # The score is written to adata_sub.obs[pathway].
    sc.tl.score_genes(adata_sub, genes, score_name=pathway)

# Colour both embeddings by every pathway score and by batch.
score_and_batch = [*dict_pathway, "batch"]
sc.pl.embedding(adata_sub, basis="X_umap_raw", color=score_and_batch)
sc.pl.embedding(adata_sub, basis="X_umap", color=score_and_batch)