# In [None]:
import os

import scanpy as sc
import seaborn as sns

from celltype_infer import *
from genexp_infer import *
from superres_deconv_vit import *


# ---- Run configuration for the mouse olfactory bulb sample ----
sample = 'mouse_bulb'
num_classes = 18  # class count passed to the models below; presumably matches the deconvolution output — confirm
Experimental_path = f'dataset/{sample}/sr'
img_dir = f'dataset/{sample}/Visium_Mouse_Olfactory_Bulb_image.tif'

# Spot-level deconvolution AnnData and the image-segmentation AnnData
# (file name suggests single-cell resolution — TODO confirm).
deconv_adata = sc.read(f'dataset/{sample}/adata/EnDecon_adata.h5ad')
segment_adata = sc.read(f'dataset/{sample}/adata/img_adata_sc.h5ad')

# Build the DINOv2 super-resolution deconvolution wrapper.
# The spot radius is read from the spot-level AnnData's `uns['radius']`.
sr_inferencer = DINOv2_superres_deconv(
    deconv_adata, segment_adata, img_dir, Experimental_path,
    neighb=3,
    radius=deconv_adata.uns['radius'],
    num_classes=num_classes,
)

# Fit the model, then run super-resolved inference.
sr_inferencer.run_train(
    epoch=200,
    batch_size=32,
    num_workers=4,
    accelerator='gpu',
)
sr_inferencer.run_superres()

# Inspect the saved super-resolution features and re-run the trained
# classifier on them. Paths are derived from Experimental_path so they follow
# the output directory configured at the top of the file.
sr_out_dir = os.path.join(Experimental_path, 'sr')

feature_pred = torch.load(os.path.join(sr_out_dir, 'features_pred.pt'))
feature_train = torch.load(os.path.join(sr_out_dir, 'features_train.pt'))
print(feature_pred.shape)

model = DINOv2NeighborClassifier.load_from_checkpoint(
    os.path.join(sr_out_dir, 'superres_model.ckpt'),
    num_classes=num_classes,
)
# Inference: put the module in eval mode (disables dropout and uses running
# batch-norm statistics, if the model has such layers).
model.eval()
# NOTE(review): feature_pred and model must live on the same device; pass
# map_location to torch.load if the features were saved from a GPU run.

with torch.no_grad():
    # Substitute `feature_train` here to check predictions on training features.
    result = model(feature_pred)

# Build an inference-only dataset (no labels, train=False) over the
# super-resolved spot centres, reusing the wrapper's radius/neighbourhood/path.
dataset = DINOv2NeighborDataset(
    centers=sr_inferencer.sr_adata.obsm['spatial'],
    img_path=img_dir,
    label_frame=None,
    train=False,
    radius=sr_inferencer.radius,
    neighb=sr_inferencer.neighb,
    path=sr_inferencer.path,
)
dataloader = DataLoader(dataset, batch_size=256, num_workers=4)

# Store the predictions under the wrapper's cell-type column and persist the
# super-resolved AnnData next to the other outputs.
predict = sr_inferencer.pred(dataloader)
sr_inferencer.sr_adata.obs[sr_inferencer.cell_type_name] = predict
sr_inferencer.sr_adata.write(
    os.path.join(sr_inferencer.path, 'adata/sr_adata.h5ad')
)


# Cell-type annotation of the segmented cells, combining the spot-level and
# super-resolved deconvolution results.
sr_deconv_adata = sc.read_h5ad(
    os.path.join(Experimental_path, 'adata', 'sr_adata.h5ad')
)

# NOTE(review): an earlier note recorded "alpha=0.3", but no alpha is passed
# to CellTypeAnnotator here — confirm whether it should be.
inferencer = CellTypeAnnotator(
    experimental_path=Experimental_path,
    img_dir=img_dir,
    num_classes=num_classes,  # use the configured class count, not a literal
    deconv_adata=deconv_adata,
    sr_deconv_adata=sr_deconv_adata,
    segment_adata=segment_adata,
)

inferencer.filter_segmentation()
inferencer.calculate_cell_count()
# inferencer.calculate_imgtype_ratio()  # optional step, disabled for this run
inferencer.calculate_celltype_ratio()
# inferencer.calculate_type_transfer_matrix()  # optional step, disabled for this run
inferencer.infer_cell_types()

inferencer.segment_cp.write(
    os.path.join(Experimental_path, 'adata', 'adata_pred_celltype.h5ad')
)


# For every predicted cell type, draw three panels side by side: the
# segmented-cell prediction, the spot-level deconvolution and the
# super-resolved deconvolution. Save one PNG per type.
adata = sc.read_h5ad(
    os.path.join(Experimental_path, 'adata', 'adata_pred_celltype.h5ad')
)

celltype_list = sorted(adata.obs['pred_cell_type'].unique().tolist())

fig_dir = os.path.join(Experimental_path, 'celltype_infer', 'fig', 'pred')
os.makedirs(fig_dir, exist_ok=True)  # savefig does not create directories

for ct in celltype_list:
    fig, ax = plt.subplots(1, 3, figsize=(15, 5))
    # show=False: with the default, scanpy calls plt.show() after each panel.
    # Under an inline notebook backend that displays and closes the figure
    # before all three panels are drawn.
    sc.pl.spatial(adata, color=ct, spot_size=30, cmap='Reds', ax=ax[0], show=False)
    sc.pl.spatial(deconv_adata, color=ct, spot_size=100, cmap='Reds', ax=ax[1], show=False)
    sc.pl.spatial(sr_deconv_adata, color=ct, spot_size=50, cmap='Reds', ax=ax[2], show=False)
    # Save and close this specific figure rather than pyplot's "current" one.
    fig.savefig(os.path.join(fig_dir, f'{ct}.png'), dpi=500)
    plt.close(fig)


#################
# Single-cell reference used for gene-expression inference.
sc_adata = sc.read('dataset/allen/expression_matrix_sub.h5ad')

# Normalise subclass labels: '-', '/' and ' ' all become '_' (one pass).
sc_adata.obs['subclass_label'] = sc_adata.obs['subclass_label'].astype(str).str.translate(
    str.maketrans({'-': '_', '/': '_', ' ': '_'})
)

# Gene-expression inference for the cell-type-annotated cells.
# NOTE(review): the input and output paths below point at dataset/Xenium_brain,
# while spot_adata is the mouse_bulb deconv_adata loaded at the top of this
# file — confirm this mix of datasets is intended. The read path is under
# sr/adata/ but the write path is under adata/ — confirm as well.
infered_adata = sc.read('dataset/Xenium_brain/sr/adata/adata_pred_celltype.h5ad')
celltype_list = sorted(infered_adata.obs['pred_cell_type'].unique().tolist())

gene_inferencer = GeneExpPredictor(
    sc_adata=sc_adata,
    spot_adata=deconv_adata,
    infered_adata=infered_adata,
)
gene_inferencer.ctspecific_spot_gene_exp(celltype_list, celltype_column='subclass_label')

# Other do_geneinfer options seen in earlier runs (not passed here): k=..., sigma=30.0
genemap = gene_inferencer.do_geneinfer(
    gamma_param=0.001,
    graph_mode='delaunay',
    weight_mode='inverse',
)
genemap.write('dataset/Xenium_brain/adata/genemap_.h5ad')
