# Solo

Publication: https://www.cell.com/cell-systems/fulltext/S2405-4712(20)30195-2  
GitHub: https://github.com/calico/solo/blob/master/README.md  

In [3]:
import scanpy as sc

import pandas as pd
import numpy as np

import os
import gc

import seaborn as sns
import matplotlib.pyplot as plt

In [5]:
import scvi

In [6]:
# Show the installed scvi-tools version for reproducibility
# (the recorded output of this cell was '1.1.6.post2').
scvi.__version__

'1.1.6.post2'

In [None]:
# NOTE(review): superseded by the later `sc.set_figure_params(figsize=(7, 7), transparent=False)`
# cell, which (in the visible cells) runs before any plotting; consider removing this call.
sc.set_figure_params(figsize=(5, 5))

In [None]:
# All relative data/result paths below are resolved against the project root.
project_dir = '/research/peer/fdeckert/FD20200109SPLENO'
os.chdir(project_dir)

## Set up rpy2

In [None]:
# Point rpy2 at the R installation of the conda env; set before rpy2 is imported below.
os.environ.update(R_HOME='/nobackup/peer/fdeckert/miniconda3/envs/r.4.1.0/lib/R/')

In [None]:
import rpy2.rinterface_lib.callbacks
import logging

from rpy2.robjects import pandas2ri
import anndata2ri

In [None]:
# IPython magic (not plain Python): load rpy2's IPython extension (R magics).
# R_HOME is set in an earlier cell so rpy2 uses the intended R installation.
%load_ext rpy2.ipython

# Figures

In [None]:
# Default figure style for all scanpy plots in this notebook: opaque background, 7x7 figures.
sc.set_figure_params(transparent=False, figsize=(7, 7))

# Parameters

In [None]:
# When True, reuse the per-sample scVI / SOLO results already on disk instead of retraining.
cache_scvi = cache_solo = True

# Files and directories 

In [None]:
# Raw (unfiltered) AnnData for the whole experiment, relative to the project root.
adata_file = (
    'data/BSA_0355_SM01_10x_SPLENO/ANALYSIS/'
    'raw.h5ad'
)

# Import anndata 

In [None]:
# Load the full experiment into memory.
adata = sc.read_h5ad(filename=adata_file)

# Subset data 

In [None]:
# Cell QC: keep cells with nCount_RNA >= 1000 and pMt_RNA <= 15
# (pMt_RNA presumably percent mitochondrial counts — confirm upstream).
adata = adata[(adata.obs['nCount_RNA'] >= 1000) & (adata.obs['pMt_RNA'] <= 15)]

# Gene filter: keep genes with a value >= 1 in at least 3 of the retained cells.
# np.asarray(...).ravel() flattens the (1, n_genes) matrix a sparse X produces.
adata = adata[:, np.asarray((adata.X >= 1).sum(axis=0)).ravel() >= 3]

# Run SCVI 

In [None]:
def scvi_sample(adata, sample_name, *,
                out_dir='data/BSA_0355_SM01_10x_SPLENO/ANALYSIS/solo',
                num_threads=38):
    """Train a per-sample scVI model and save it together with the sample's AnnData.

    The cells of ``adata`` whose ``obs['sample_name']`` equals ``sample_name`` are
    selected, genes with a value >= 1 in fewer than 3 of those cells are dropped, and
    an scVI model with a negative-binomial likelihood is trained. The latent space is
    stored in ``obsm['latent']`` and used for neighbours, Leiden clustering and UMAP.
    Training curves are plotted via ``model_history``.

    Args:
        adata: Full AnnData containing all samples (expects raw counts in ``X`` for
            scVI — presumably true for this dataset; confirm upstream).
        sample_name: Value of ``obs['sample_name']`` to process.
        out_dir: Root output directory; results go to ``<out_dir>/<sample_name>/``
            as ``model/`` (scVI model) and ``adata.h5ad``.
        num_threads: Thread count passed to ``scvi.settings.num_threads``.
    """
    print(sample_name)

    # Imported here so the function also works as a multiprocessing target.
    import scvi

    scvi.settings.num_threads = num_threads

    # Subset to this sample's cells, then to genes expressed in >= 3 of them.
    adata_i = adata[adata.obs['sample_name'] == sample_name]
    adata_i = adata_i[:, (adata_i.X >= 1).sum(axis=0) >= 3]

    # Materialise the view; setup_anndata must register a real AnnData.
    adata_i = adata_i.copy()
    scvi.model.SCVI.setup_anndata(adata_i)

    model_i = scvi.model.SCVI(adata_i, gene_likelihood='nb')

    # Heuristic: ~400 epochs at 20k cells, scaled inversely with cell count,
    # capped at 400 and never below 1.
    max_epochs = max(1, min(400, round((20000 / adata_i.n_obs) * 400)))

    model_i.train(max_epochs=max_epochs, check_val_every_n_epoch=1)

    # Latent representation drives the graph, clustering and embedding.
    adata_i.obsm['latent'] = model_i.get_latent_representation()
    sc.pp.neighbors(adata_i, n_neighbors=10, use_rep='latent')
    sc.tl.leiden(adata_i, resolution=2, flavor="igraph", n_iterations=2)
    sc.tl.umap(adata_i)

    model_history(adata_i, model_i)

    # Create the sample directory explicitly instead of relying on model.save().
    sample_dir = os.path.join(out_dir, sample_name)
    os.makedirs(sample_dir, exist_ok=True)
    model_i.save(os.path.join(sample_dir, 'model'), overwrite=True)
    adata_i.write_h5ad(os.path.join(sample_dir, 'adata.h5ad'))

    # Free the model and data before the next sample.
    del adata_i, model_i
    gc.collect()

In [None]:
def model_history(adata, model):
    """Plot training curves of an scVI model next to a UMAP of its AnnData.

    Panel 1 shows the train/validation reconstruction loss and panel 2 the
    train/validation ELBO, both read from ``model.history``. Panel 3 is a UMAP
    coloured by ``solo_label`` when that obs column exists, otherwise by ``leiden``.

    Args:
        adata: AnnData with a computed UMAP and a ``leiden`` (or ``solo_label``) column.
        model: Trained scVI model whose history holds train and validation metrics.
    """
    history = model.history
    _, (ax_recon, ax_elbo, ax_umap) = plt.subplots(1, 3, figsize=(15, 5))

    curve_panels = (
        (ax_recon, 'reconstruction_loss', "Reconstruction Loss"),
        (ax_elbo, 'elbo', "ELBO"),
    )
    for ax, metric, title in curve_panels:
        for split in ('train', 'validation'):
            column = f'{metric}_{split}'
            ax.plot(history[column][column], label=split)
        ax.set_title(title)
        ax.set_xlabel("Epoch")
        ax.set_ylabel("Loss")
        ax.legend()

    # Prefer the SOLO call once it exists; fall back to the Leiden clustering.
    color_key = 'solo_label' if 'solo_label' in adata.obs.columns else 'leiden'
    sc.pl.umap(adata, color=[color_key], ax=ax_umap, show=False)
    ax_umap.set_title("UMAP")

    plt.tight_layout()
    plt.show()

In [None]:
if not cache_scvi:

    from multiprocessing import Process

    # One child process per sample so each model's memory is released when it exits.
    for sample_name in adata.obs['sample_name'].cat.categories:
        p = Process(target=scvi_sample, args=(adata, sample_name))
        p.start()
        p.join()
        exitcode = p.exitcode  # read before close(); a closed Process rejects access
        p.close()
        # A failing child (exception, OOM kill) only reports a non-zero exit code;
        # stop here instead of letting later cells fail on missing files.
        if exitcode != 0:
            raise RuntimeError(
                f"scVI training failed for sample {sample_name!r} (exit code {exitcode})"
            )

else:

    import scvi

    for sample_name in adata.obs['sample_name'].cat.categories:

        # Load the cached per-sample data and model.
        sample_dir = os.path.join('data/BSA_0355_SM01_10x_SPLENO/ANALYSIS/solo', sample_name)
        adata_i = sc.read_h5ad(os.path.join(sample_dir, 'adata.h5ad'))
        model_i = scvi.model.SCVI.load(os.path.join(sample_dir, 'model'), adata=adata_i)

        # Plot the stored training history.
        model_history(adata_i, model_i)

        # Clean up before the next sample.
        del adata_i, model_i
        gc.collect()

# Run SOLO

In [None]:
def solo_sample(sample_name, *,
                out_dir='data/BSA_0355_SM01_10x_SPLENO/ANALYSIS/solo',
                num_threads=38):
    """Train SOLO doublet detection on top of a sample's saved scVI model.

    Reads ``<out_dir>/<sample_name>/adata.h5ad`` and the scVI model in ``model/``.
    It builds a SOLO classifier with ``SOLO.from_scvi_model`` and trains it, then
    writes per-cell results to obs:
    ``solo_pred_doublet`` / ``solo_pred_singlet`` (soft predictions) and
    ``solo_label`` (hard call). It plots a UMAP, saves the SOLO model (with its
    AnnData) to ``model_solo/``, and overwrites ``adata.h5ad`` with the new columns.

    Args:
        sample_name: Sample to process; its directory must already exist.
        out_dir: Root directory holding the per-sample outputs.
        num_threads: Thread count passed to ``scvi.settings.num_threads``.
    """
    print(sample_name)

    # Imported here so the function also works as a multiprocessing target.
    import scvi

    scvi.settings.num_threads = num_threads

    sample_dir = os.path.join(out_dir, sample_name)
    adata_path = os.path.join(sample_dir, 'adata.h5ad')

    adata_i = sc.read_h5ad(adata_path)
    model_i = scvi.model.SCVI.load(os.path.join(sample_dir, 'model'), adata_i)
    solo_i = scvi.external.SOLO.from_scvi_model(model_i)

    # Same epoch heuristic as the scVI step: capped at 400, never below 1.
    max_epochs = max(1, min(400, round((20000 / adata_i.n_obs) * 400)))

    solo_i.train(max_epochs=max_epochs, check_val_every_n_epoch=1)

    # Predictions exclude the simulated doublets, so they line up with adata_i's cells.
    prob = solo_i.predict(soft=True, include_simulated_doublets=False)
    label = solo_i.predict(soft=False, include_simulated_doublets=False)

    adata_i.obs['solo_pred_doublet'] = prob['doublet']
    adata_i.obs['solo_pred_singlet'] = prob['singlet']
    adata_i.obs['solo_label'] = label

    sc.pl.umap(adata_i, color=['sample_name', 'leiden', 'solo_label'])

    solo_i.save(os.path.join(sample_dir, 'model_solo'), overwrite=True, save_anndata=True)
    adata_i.write(adata_path)

    # Free all three objects, including the SOLO model.
    del adata_i, model_i, solo_i
    gc.collect()

In [None]:
if not cache_solo:

    from multiprocessing import Process

    # One child process per sample so each model's memory is released when it exits.
    for sample_name in adata.obs['sample_name'].cat.categories:
        p = Process(target=solo_sample, args=(sample_name,))
        p.start()
        p.join()
        exitcode = p.exitcode  # read before close(); a closed Process rejects access
        p.close()
        # Surface child failures instead of continuing with missing SOLO results.
        if exitcode != 0:
            raise RuntimeError(
                f"SOLO training failed for sample {sample_name!r} (exit code {exitcode})"
            )

else:

    for sample_name in adata.obs['sample_name'].cat.categories:

        # Load the cached per-sample AnnData that already carries the SOLO columns.
        adata_i = sc.read_h5ad(
            os.path.join('data/BSA_0355_SM01_10x_SPLENO/ANALYSIS/solo', sample_name, 'adata.h5ad')
        )

        sc.pl.umap(adata_i, color=['sample_name', 'leiden', 'solo_label'])

        # Clean up before the next sample.
        del adata_i
        gc.collect()

# Combine SOLO results

In [None]:
# Gather the per-sample SOLO columns into one DataFrame indexed by cell barcode.
solo_pred = []

for sample_name in adata.obs['sample_name'].cat.categories:
    # Only obs is needed: backed mode keeps the expression matrix on disk.
    adata_i = sc.read_h5ad(
        os.path.join('data/BSA_0355_SM01_10x_SPLENO/ANALYSIS/solo', sample_name, 'adata.h5ad'),
        backed='r',
    )
    solo_pred.append(adata_i.obs[['solo_label', 'solo_pred_doublet', 'solo_pred_singlet']].copy())
    adata_i.file.close()  # release the HDF5 handle opened by backed mode
    del adata_i

solo_pred = pd.concat(solo_pred, axis=0, ignore_index=False)

In [None]:
# Attach the SOLO results to every cell by barcode (cells without a result get NaN).
# Any SOLO columns from a previous run of this cell are dropped first; otherwise the
# merge would emit suffixed duplicates (e.g. solo_label_x / solo_label_y).
adata.obs = adata.obs.drop(columns=solo_pred.columns, errors='ignore').merge(
    solo_pred, left_index=True, right_index=True, how='left'
)

# Save results

In [None]:
# Export the per-cell SOLO label and probabilities (index = cell barcode).
# The previous `.rename(columns={'leiden': 'solo_leiden'})` was a no-op ('leiden' is
# not among the selected columns) and has been removed.
os.makedirs('result/solo', exist_ok=True)
adata.obs[['solo_label', 'solo_pred_doublet', 'solo_pred_singlet']].to_csv('result/solo/solo.csv')