In [None]:
import scanpy as sc,anndata as ad
import pandas as pd
from matplotlib import pyplot as plt
import os
import numpy as np
from matplotlib import rcParams
import seaborn as sns
from anndata import AnnData
os.chdir(os.getcwd())
import scarches as sca
from scarches.models.scpoli import scPoli

import scvi
import itertools
import torch
import time
#import ray
#from ray import tune
#from scvi import autotune
import pickle
from itables import init_notebook_mode
init_notebook_mode(all_interactive=True)

# Setup functions

In [None]:
## Setup function for calculating elapsed time
def print_elapsed_time(start,stop):
    """Print the time elapsed between two timestamps as whole hours and minutes.

    Parameters
    ----------
    start, stop : float
        Timestamps in seconds (e.g. values returned by ``time.time()``).

    The difference ``stop - start`` is truncated to whole seconds; leftover
    seconds below a full minute are dropped from the printed message.
    """
    whole_seconds = int(stop - start)

    # Split into full hours, then count the full minutes in what remains
    hours, leftover_seconds = divmod(whole_seconds, 3600)
    minutes = leftover_seconds // 60

    print(f"Elapsed time:{hours} hours:{minutes} minutes")

# LOAD DATASETS

## Load adata and Tabula Sapiens vascular + blood dataset

In [None]:
from scipy import sparse

# Read the Tabula Sapiens blood + vasculature reference from disk
adata_ts = sc.read_h5ad('../../data/ts_blood_vasc.h5ad')

# Combined per-cell batch label "<donor>_<method>", stored as a categorical column
adata_ts.obs['scanvi_batch'] = (
    adata_ts.obs['donor'].astype(str) + '_' + adata_ts.obs['method'].astype(str)
).astype('category')

# Keep the expression matrix in CSR sparse format
adata_ts.X = sparse.csr_matrix(adata_ts.X)

## Load preprocessed Xenium anndata files + subset TS data to common genes with Xenium panels

In [None]:
# NOTE(review): absolute, cluster-specific path; consider moving it to a configurable constant.
proc_dir='/data/gpfs/projects/punim2121/Atherosclerosis/xenium_data/processed_data/baysor_processed_output'

adata_dict={}          # panel -> Xenium AnnData restricted to the genes shared with TS
adata_ts_dict={}       # panel -> TS reference restricted to the same shared genes
adata_merged_dict={}   # panel -> TS + Xenium concatenated, tagged by 'dataset_origin'
for panel in ['Panel1','Panel2']:
    fn=os.path.join(proc_dir,f'filtered_{panel}_cells.h5ad')
    adata=sc.read_h5ad(fn)

    ## Add some columns (originating from the reference TS data) to xenium data obs, which will be covariate keys in the scVI model
    adata.obs['donor']=adata.obs['patient'].values
    adata.obs['method']='xenium'
    adata.obs['free_annotation']='unknown'
    adata.obs['organ_tissue']='Vasculature'

    for coln in ['free_annotation','donor','organ_tissue','method']:
        adata.obs[coln]=adata.obs[coln].astype('category')

    ## Genes present in both datasets. A bare set intersection has an arbitrary order that
    ## can change between kernel sessions (string hashing is randomised), so sort it to keep
    ## the gene/column order identical on every run.
    common_genes=sorted(set(adata_ts.var_names)&set(adata.var_names))
    print(f'Number of common genes between TS and {panel} data: {len(common_genes)} out of {len(adata.var_names)}')

    adata_dict[panel]=adata[:,common_genes].copy()
    adata_ts_dict[panel]=adata_ts[:,common_genes].copy()

    ## Concatenate the two subsets just stored (no need to slice and copy them a second time);
    ## 'dataset_origin' records which dataset each cell came from.
    adata_merged_dict[panel]=ad.concat([adata_ts_dict[panel],adata_dict[panel]],
                                       label='dataset_origin',keys=['TS','xenium'],join='outer',merge='unique')


## SCVI TRAINING ON MERGED TS REFERENCE + XENIUM DATA

In [None]:
###====== DEFINE TRAINING KWARGS
## Passed unchanged to `vae.train(...)` for every model below
train_kwargs = {
    "early_stopping": True,
    "early_stopping_monitor": "elbo_validation",
    "early_stopping_patience": 30,
    "early_stopping_min_delta": 0.00001,
    "check_val_every_n_epoch":1,
    "plan_kwargs": {"weight_decay": 1e-6},
    'accelerator':'auto'}


###====== DEFINE HYPERPARAMETERS FOR MODELS TO TRAIN
## Each key maps to the list of values to try; the grid is the Cartesian product of the lists
search_space={
        "n_hidden":[128],
        "n_latent":[10],
        "n_layers": [1],
        'gene_likelihood':['nb']}

# Create a list of dictionaries where each dictionary represents a set of hyperparameters
parameter_combinations = list(itertools.product(*search_space.values()))
hyperparameter_sets = [dict(zip(search_space.keys(), combination)) for combination in parameter_combinations]

start=time.time()

torch.set_float32_matmul_precision('high')

## Fix the random seed so that a fresh "Restart & Run All" is repeatable
scvi.settings.seed=0

###====== RUN TRAINING
for panel in ['Panel1','Panel2']:
    print(panel)
    adata_ts_panel=adata_ts_dict[panel]

    ## Combined "<donor>_<method>" label on the TS reference subset.
    ## NOTE(review): the model below is set up on `adata_merged` with batch_key='method' and
    ## 'donor' as a categorical covariate, so 'scanvi_batch' is not used by this cell;
    ## kept in case later cells rely on it.
    adata_ts_panel.obs['scanvi_batch']=adata_ts_panel.obs['donor'].astype(str) +'_'+ adata_ts_panel.obs['method'].astype(str)
    adata_ts_panel.obs['scanvi_batch']=adata_ts_panel.obs['scanvi_batch'].astype('category')
    adata_ts_panel.obs['donor']=adata_ts_panel.obs['donor'].astype('category')

    adata_merged=adata_merged_dict[panel]

    ## Setup model data once per panel: the registration does not depend on the hyperparameters.
    ## Assumes `adata_merged` carries a 'raw_counts' layer (TODO confirm).
    sca.models.SCVI.setup_anndata(adata_merged,
                                 layer="raw_counts",
                                 categorical_covariate_keys=['donor'],
                                 batch_key='method',
                                 labels_key='free_annotation')

    for n,hyperparameters in enumerate(hyperparameter_sets):
        loop_time=time.time()

        print_elapsed_time(start,loop_time)
        print('set:',str(n+1),'/',str(len(hyperparameter_sets)))
        print(hyperparameters)

        ## Create string of the parameters ==> used in the figure title and in the saved model name
        param_set_string='-'.join(f"{key}:{value}" for key, value in hyperparameters.items())

        vae=sca.models.SCVI(adata_merged,
                            n_layers=hyperparameters['n_layers'],
                            n_hidden=hyperparameters['n_hidden'],
                            n_latent=hyperparameters['n_latent'],
                            gene_likelihood=hyperparameters['gene_likelihood']
                            )

        vae.train(max_epochs=500,**train_kwargs)

        ## Plot training results: reconstruction loss (left) and ELBO (right), train vs validation
        fig,ax=plt.subplots(1,2,figsize=(9,3))
        fig.suptitle('-'.join([panel,param_set_string]))
        d=vae.history
        ax[0].plot(d['reconstruction_loss_train']['reconstruction_loss_train'], label='reconstruction_loss_train')
        ax[0].plot(d['reconstruction_loss_validation']['reconstruction_loss_validation'], label='reconstruction_loss_validation')
        ax[1].plot(d['elbo_validation']['elbo_validation'], label='elbo_validation')
        ax[1].plot(d['elbo_train']['elbo_train'], label='elbo_train')
        ax[0].set_ylabel('reconstruction loss')
        ax[1].set_ylabel('ELBO')
        for axis in ax:
            axis.set_xlabel('epoch')
            axis.legend()
        ## Show the figure right after this model finishes, then release it so figures
        ## do not pile up in memory across the loop
        plt.show()
        plt.close(fig)

        ## Save model (the parameter string is part of the directory name)
        path=os.path.join(proc_dir,"scvi_models",panel+"_ts_model_"+param_set_string)
        os.makedirs(path, exist_ok=True)
        vae.save(path,overwrite=True)

## Report the total run time once every model has been trained and saved
print_elapsed_time(start,time.time())
    
    