In [None]:
import scarches as sca
#import torch
import scanpy as sc
import pandas as pd
import numpy as np
import seaborn as sns
from matplotlib import pyplot as plt
import os
import time
import itertools

In [None]:
## Helper: report the wall-clock time between two time.time() stamps
def print_elapsed_time(start,stop):
    """Print the span between two timestamps as whole hours and minutes.

    Parameters
    ----------
    start, stop : float
        Timestamps in seconds (e.g. from ``time.time()``). ``stop - start`` is
        truncated to whole seconds; leftover seconds are not shown.
    """
    total_seconds = int(stop - start)
    total_minutes = total_seconds // 60
    hours, minutes = divmod(total_minutes, 60)
    print(f"Elapsed time:{hours} hours:{minutes} minutes")

# LOAD DATA & SUBSET TO PATIENTS 1-4

In [None]:
## Inputs: filtered cell-level AnnData per panel from the Baysor output directory; the filename encodes
## the scale parameter and the assignment-confidence threshold (see fn below)
proc_dir='/data/gpfs/projects/punim2121/Atherosclerosis/xenium_data/processed_data/baysor_processed_output'
scale_param='10'  # other values tried: '5', '15'
avg_assignment_conf_thr=0.75
PATIENTS=['P1','P2','P3','P4']  # patients kept for this analysis

adata_dict={}
for panel in ['Panel1','Panel2']:
    fn=os.path.join(proc_dir,f'filtered_{panel}_cells_scale_{scale_param}_asg_conf_{avg_assignment_conf_thr}.h5ad')
    adata=sc.read_h5ad(fn)

    # Subset to the selected patients; .copy() already turns the view into an independent AnnData,
    # so it is stored as-is below (no second full copy)
    adata=adata[adata.obs['patient'].isin(PATIENTS),:].copy()
    assert adata.n_obs > 0, f'no cells for patients {PATIENTS} in {fn}'
    print(panel, adata.shape)
    adata_dict[panel]=adata
    

# Add ENSEMBL IDs of gene names to adata.var

In [None]:

## The exact ENSEMBL ID <-> HGNC name pairs are stored in each panel's gene_panel.json (the same for all
## samples within a panel) ==> extract them and add them to adata.var, so unique gene identifiers are used
def extract_ensembl_hgnc_name_df(
        data_dir="/data/gpfs/projects/punim2121/Atherosclerosis/xenium_data/",
        panel_sample_dirs=("20230808__140639__2311_Sachs_Panel1/output-XETG00050__0003370__P3_D__20230808__140759",
                           "20230817__135824__2311_Sachs_Panel2/output-XETG00050__0003341__P3_D__20230817__135943"),
        panels=('Panel1','Panel2')):
    """Build an HGNC-name -> ENSEMBL-ID lookup table from each panel's gene_panel.json.

    Parameters
    ----------
    data_dir : str
        Root directory holding the Xenium run folders.
    panel_sample_dirs : sequence of str
        One sample output directory per panel, relative to ``data_dir``; the
        ``gene_panel.json`` inside it is read.
    panels : sequence of str
        Panel names, paired positionally with ``panel_sample_dirs``.

    Returns
    -------
    pandas.DataFrame
        Indexed by ``HGNC_NAME``, with columns ``index`` (row position within
        its panel's table, left over from ``reset_index``), ``ENSEMBL_ID`` and
        one boolean column per panel telling whether the gene is in that panel.
        A gene present in several panels appears once (first panel's row).

    Targets whose ``['type']['data']`` lacks an ``id`` or a ``name`` are skipped.
    """
    import json

    df_list=[]
    panel_ids={}
    for panel,sample_dir in zip(panels,panel_sample_dirs):
        with open(os.path.join(data_dir,sample_dir,'gene_panel.json')) as f:
            gene_panel=json.load(f)

        ensembl_id_list,hgnc_list=[],[]
        for target in gene_panel['payload']['targets']:
            # Read id AND name before appending, so a target missing one of them cannot
            # leave the two lists with different lengths
            try:
                target_data=target['type']['data']
                ensembl_id,hgnc_name=target_data['id'],target_data['name']
            except KeyError:
                continue
            ensembl_id_list.append(ensembl_id)
            hgnc_list.append(hgnc_name)

        gene_info_df=pd.DataFrame.from_dict({'ENSEMBL_ID':ensembl_id_list,'HGNC_NAME':hgnc_list})
        gene_info_df[panel]=True
        df_list.append(gene_info_df)
        panel_ids[panel]=set(ensembl_id_list)

    gene_info_merged=pd.concat(df_list,axis=0)
    gene_info_merged=gene_info_merged.drop_duplicates('ENSEMBL_ID').reset_index().set_index('HGNC_NAME')
    # drop_duplicates keeps only the first panel's row (other panel flags NaN) -> recompute membership
    for panel,ids in panel_ids.items():
        gene_info_merged[panel]=gene_info_merged['ENSEMBL_ID'].isin(ids)

    return gene_info_merged


## Look up ENSEMBL IDs for both panels and re-key adata.var by them (HGNC names kept in var['HGNC'])
gene_info_merged=extract_ensembl_hgnc_name_df()

for panel in ['Panel1','Panel2']:
    adata=adata_dict[panel]

    # Makes the cell safe to re-run: once var is keyed by ENSEMBL ID the HGNC lookup below would raise KeyError
    if adata.var.index.isin(gene_info_merged['ENSEMBL_ID']).all():
        print(f'{panel}: adata.var already indexed by ENSEMBL ID, skipping')
        continue

    # .loc raises KeyError if a var name is missing from the gene_panel.json lookup
    adata.var['ENSEMBL_ID']=gene_info_merged.loc[adata.var.index.tolist(),'ENSEMBL_ID'].values
    # rename_axis names the old (HGNC) index 'HGNC' whatever it was called before, so reset_index always
    # yields a 'HGNC' column; the ENSEMBL_ID column then becomes the new index
    adata.var=adata.var.rename_axis('HGNC').reset_index().set_index('ENSEMBL_ID',drop=True)
    adata_dict[panel]=adata



# BATCH EFFECT CORRECTION WITH SCVI

## Train scVI models on both panel AnnData objects

In [None]:
###====== DEFINE TRAINING KWARGS
import scvi  # only used for scvi.settings.seed; the models themselves come from sca.models.SCVI

SEED=0  # re-applied before every model so each trained model is reproducible on re-run

train_kwargs = {
    "early_stopping": True,
    "early_stopping_monitor": "elbo_validation",
    "early_stopping_patience": 30,
    "early_stopping_min_delta": 0.01,
    "check_val_every_n_epoch":1,
    "plan_kwargs": {"weight_decay": 1e-6},
    'accelerator':'auto'}

search_space={
        "n_hidden":[128],
        "n_latent":[10], #20
        "n_layers": [1], #2
        'gene_likelihood':['zinb'],
        'max_epochs':[400]}

# Cartesian product of the search space -> one dict of hyperparameters per combination
parameter_combinations = list(itertools.product(*search_space.values()))
hyperparameter_sets = [dict(zip(search_space.keys(), combination)) for combination in parameter_combinations]

start=time.time()

#torch.set_float32_matmul_precision('high')

###====== RUN TRAINING
for panel in ['Panel1','Panel2']:
    print(panel)

    adata=adata_dict[panel]

    ## Register the counts layer + batch covariate once per panel; this does not depend on the hyperparameters
    sca.models.SCVI.setup_anndata(adata,
                                 layer="raw_counts",
                                 #categorical_covariate_keys=['condition'],
                                 batch_key='original_sample',
                                 #labels_key='free_annotation'
                                 )

    for n,hyperparameters in enumerate(hyperparameter_sets):
        loop_time=time.time()

        print_elapsed_time(start,loop_time)
        print('set:',str(n+1),'/',str(len(hyperparameter_sets)))
        print(hyperparameters)

        ## "key:value" pairs joined with '-' -> used in the plot title and in the saved model's directory name
        param_set_string='-'.join(f"{key}:{value}" for key, value in hyperparameters.items())

        ## Setup model (re-seed so every model starts from the same RNG state)
        scvi.settings.seed=SEED
        vae=sca.models.SCVI(adata,
                            n_layers=hyperparameters['n_layers'],
                            n_hidden=hyperparameters['n_hidden'],
                            n_latent=hyperparameters['n_latent'],
                            gene_likelihood=hyperparameters['gene_likelihood']
                            )

        vae.train(max_epochs=hyperparameters['max_epochs'],**train_kwargs)

        ## Plot training curves (left: reconstruction loss, right: ELBO; train vs validation)
        fig,ax=plt.subplots(1,2,figsize=(9,3))
        fig.suptitle('-'.join([panel,param_set_string]))
        d=vae.history
        ax[0].plot(d['reconstruction_loss_train']['reconstruction_loss_train'], label='reconstruction_loss_train')
        ax[0].plot(d['reconstruction_loss_validation']['reconstruction_loss_validation'], label='reconstruction_loss_validation')
        ax[1].plot(d['elbo_validation']['elbo_validation'], label='elbo_validation')
        ax[1].plot(d['elbo_train']['elbo_train'], label='elbo_train')
        ax[0].legend()
        ax[1].legend()

        ## Save model (directory name encodes panel, hyperparameters and input-data filtering settings)
        path=os.path.join(proc_dir,"scvi_models",f'{panel}_batch_corr_model_{param_set_string}_data_cells_scale_{scale_param}_asg_conf_{avg_assignment_conf_thr}')
        os.makedirs(path, exist_ok=True)
        vae.save(path,overwrite=True)

## Load the trained scVI models and latent representation

In [None]:
import scvi
###====== HYPERPARAMETER SETS WHOSE TRAINED MODELS SHOULD BE LOADED (must match the training cell)
search_space={
        "n_hidden":[128],
        "n_latent":[10], #10
        "n_layers": [1], #1
        'gene_likelihood':['zinb'],
        'max_epochs':[400]}

# Cartesian product of the search space -> one dict of hyperparameters per combination
parameter_combinations = list(itertools.product(*search_space.values()))
hyperparameter_sets = [dict(zip(search_space.keys(), combination)) for combination in parameter_combinations]


adata_vae_dict={}

for panel in ['Panel1','Panel2']:
    adata=adata_dict[panel]

    adata_vae_dict[panel]={}
    for hyperparameters in hyperparameter_sets:

        ## "key:value" pairs joined with '-' -> same naming scheme used when the model was saved
        param_set_string='-'.join(f"{key}:{value}" for key, value in hyperparameters.items())
        print(panel,param_set_string)

        ## LOAD MODEL
        path=os.path.join(proc_dir,"scvi_models",f'{panel}_batch_corr_model_{param_set_string}_data_cells_scale_{scale_param}_asg_conf_{avg_assignment_conf_thr}')
        try:
            vae=scvi.model.SCVI.load(path, adata=adata)
        except (ValueError, FileNotFoundError) as err:
            # Depending on the scvi-tools version a missing model file surfaces as ValueError or
            # FileNotFoundError; ValueError is also raised when adata does not match the trained
            # model, so report the actual reason instead of assuming "not found"
            print(f'Could not load model {param_set_string}: {err}\n')
            continue
        adata_vae_dict[panel][param_set_string]=vae

### Check raw UMAP for batch effects

In [None]:
for panel in ['Panel1','Panel2']:
    print(panel)
    adata=adata_dict[panel]

    ## Uncorrected embedding: PCA of the 'raw_counts' layer -> cosine kNN graph -> UMAP
    adata.obsm["raw_counts_pca"]=sc.tl.pca(adata.layers["raw_counts"])
    sc.pp.neighbors(adata, n_neighbors=15,use_rep='raw_counts_pca',key_added='raw_counts_neigh',metric='cosine')
    sc.tl.umap(adata,n_components=2,neighbors_key='raw_counts_neigh')
    # sc.tl.umap stores its result in obsm['X_umap'] by default, which the scVI embedding cell overwrites
    # later; keep a separate copy so the uncorrected embedding is not lost
    adata.obsm['X_umap_raw_counts']=adata.obsm['X_umap'].copy()
    adata_dict[panel]=adata

    ## One UMAP per metadata column, to see whether cells group by patient/sample (batch effects)
    color_cols=['patient','condition','original_sample','sample_region']
    ncols=2
    nrows=int(np.ceil(len(color_cols)/ncols))
    fig=plt.figure(figsize=(ncols*7,nrows*5))
    fig.suptitle(f'{panel}_cells_scale_{scale_param}_asg_conf_{avg_assignment_conf_thr}',fontweight='bold',y=1.02)

    for n,col in enumerate(color_cols):
        ax=fig.add_subplot(nrows,ncols,n+1)
        # cluster-label columns get labels on the embedding, everything else a side legend
        leg_loc="on data" if col in ['fastcluster','leiden_groups'] else 'right margin'
        sc.pl.umap(adata, color=col,show=False,ax=ax,size=3,legend_loc=leg_loc)

## Embed the batch-corrected data (scVI latent space → neighbours → UMAP) and plot the UMAPs

In [None]:
import scvi

###====== HYPERPARAMETER SETS TO EMBED AND PLOT (must match the sets used for training/loading)

search_space={
        "n_hidden":[128],
        "n_latent":[10], #20
        "n_layers": [1], #2
        'gene_likelihood':['zinb'],
        'max_epochs':[400]}

# Cartesian product of the search space -> one dict of hyperparameters per combination
parameter_combinations = list(itertools.product(*search_space.values()))
hyperparameter_sets = [dict(zip(search_space.keys(), combination)) for combination in parameter_combinations]

color_cols=['patient', 'condition', 'original_sample','sample_region','n_counts', 'n_genes']
ncols=2
nrows=int(np.ceil(len(color_cols)/ncols))

for panel in ['Panel1','Panel2']:
    print(panel)
    adata=adata_dict[panel]

    for hyperparameters in hyperparameter_sets:
        ## "key:value" pairs joined with '-' -> identifies the model and names the obsm/neighbors keys
        param_set_string='-'.join(f"{key}:{value}" for key, value in hyperparameters.items())
        rep_key="X_scVI_"+param_set_string

        ## Trained scVI model from the loading cell (absent if it could not be loaded there)
        vae=adata_vae_dict[panel].get(param_set_string)
        if vae is None:
            print(f'No loaded model for {panel} {param_set_string}; skipping\n')
            continue

        ## Batch-corrected embedding: scVI latent space -> cosine kNN graph -> UMAP
        adata.obsm[rep_key]=vae.get_latent_representation()
        sc.pp.neighbors(adata, use_rep=rep_key,n_neighbors=15,key_added=rep_key,metric='cosine')
        #sc.tl.leiden(adata, key_added='scVI_leiden', resolution=0.6, neighbors_key=rep_key)
        sc.tl.umap(adata,n_components=2,neighbors_key=rep_key,random_state=0)
        # obsm['X_umap'] is overwritten by every sc.tl.umap call; keep a per-model copy
        adata.obsm['X_umap_scVI_'+param_set_string]=adata.obsm['X_umap'].copy()

        adata_dict[panel]=adata

        ## PLOT UMAPS (cells with zero counts excluded); subset and title are built once per figure
        adata_plot=adata[(adata.obs['n_counts']>0),:]
        fig=plt.figure(figsize=(ncols*6,nrows*4))
        fig.suptitle('-'.join([panel,param_set_string,f'\ncells_scale_{scale_param}_asg_conf_{avg_assignment_conf_thr}']),
                     fontweight='bold',fontsize=15,y=0.99)

        for i,col in enumerate(color_cols):
            ax=fig.add_subplot(nrows,ncols,i+1)
            # cluster-label columns get labels on the embedding, everything else a side legend
            leg_loc="on data" if col in ['fastcluster','scVI_leiden','free_annotation'] else 'right margin'
            sc.pl.umap(adata_plot, color=col,show=False,ax=ax,size=4,
                       vmin='p1',vmax='p99',legend_fontsize=7,legend_loc=leg_loc)
        

## Save batch corrected data

In [None]:
## Write each panel's AnnData to proc_dir as a gzip-compressed .h5ad with a 'filtered_batch_corr_' prefix
for panel, adata in adata_dict.items():
    fn=os.path.join(proc_dir,
                    f'filtered_batch_corr_{panel}_cells_scale_{scale_param}_asg_conf_{avg_assignment_conf_thr}.h5ad')
    adata.write_h5ad(fn,compression='gzip')

In [None]:
# Inspect which obsm keys (embeddings) `adata` carries; `adata` is whichever panel the save loop above ended on
adata.obsm.keys()