Author: Erno Hänninen

Created: 14.02.2023

Title: scvi_parameter_autotune.ipynb

Description: 
- Automated hyperparameter tuning for scvi integration

Procedure
- Read data to be integrated
- Prepare the autotune class (ModelTuner) and define its search space
- Run autotune
- Find the best-performing parameter combination from the autotune output by loading it into a DataFrame
- Visualize the results with seaborn's relational plot (relplot)

List of non-standard modules:
- ray, scanpy, scvi, matplotlib, seaborn, numpy, pandas

Conda environment used:
- PYenv

Usage:
- The script was executed using the Jupyter Notebook web interface. All the dependencies required by Jupyter are installed in the PYenv Conda environment. See the README file for further details.

In [None]:
# Python packages
import ray
import scanpy as sc
import scvi
from ray import tune
from scvi import autotune
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

In [2]:
# Read data
adata = sc.read("Data/adata_ready_for_scvi.h5ad")

In [None]:
# Prepare autotuner class
scvi.model.SCVI.setup_anndata(adata, layer="counts", batch_key="sample")
model_cls = scvi.model.SCVI
scvi_tuner = autotune.ModelTuner(model_cls)
scvi_tuner.info()  # Prints the tunable parameters

In [6]:
# Define search space
search_space = {
    "n_hidden": tune.choice([64, 128, 256]),
    "n_layers": tune.choice([1, 2, 3]),
    "n_latent": tune.choice([10, 20, 30]),
    "dispersion": tune.choice(["gene", "gene-batch", "gene-label", "gene-cell"]),
    "gene_likelihood": tune.choice(["zinb", "nb", "poisson"]),
}

# 324 different combinations
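The count in the comment above follows from the grid: 3 choices each for n_hidden, n_layers and n_latent, 4 for dispersion and 3 for gene_likelihood give 3 × 3 × 3 × 4 × 3 = 324 combinations. A minimal stdlib sketch that checks this, with the value lists copied from search_space as plain Python lists (the ray tune.choice wrappers are left out so it runs without ray; math.prod needs Python 3.8 or later):

```python
import itertools
import math

# Candidate values copied from the search_space cell, without the tune.choice wrappers
choices = {
    "n_hidden": [64, 128, 256],
    "n_layers": [1, 2, 3],
    "n_latent": [10, 20, 30],
    "dispersion": ["gene", "gene-batch", "gene-label", "gene-cell"],
    "gene_likelihood": ["zinb", "nb", "poisson"],
}

# The grid size is the product of the number of options per parameter
grid_size = math.prod(len(values) for values in choices.values())

# Enumerating every combination explicitly gives the same count
all_combinations = list(itertools.product(*choices.values()))

print(grid_size, len(all_combinations))  # → 324 324
```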


In [7]:
# Default number of training epochs for this dataset: min(round(20000 / n_cells * 400), 400)
np.min([round((20000 / len(adata.obs_names)) * 400), 400])

55
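The formula in the cell above gives 400 epochs to datasets of up to 20,000 cells and scales the count down in proportion for larger ones. The sketch below wraps it in a hypothetical helper, default_max_epochs (not an scvi function), and inverts it by brute force to show which cell counts yield the printed value of 55. It uses Python's built-in round, as the cell does.

```python
def default_max_epochs(n_cells: int) -> int:
    """Hypothetical helper mirroring the notebook's formula:
    400 epochs for up to 20,000 cells, scaled down proportionally beyond that."""
    return min(round((20000 / n_cells) * 400), 400)


# The formula is non-increasing in n_cells, so the matching sizes form one contiguous range
matching = [n for n in range(100_000, 200_000) if default_max_epochs(n) == 55]
print(matching[0], matching[-1])  # → 144145 146788
```

The output of 55 therefore implies that the dataset has roughly 144,000 to 147,000 cells. The exact number is len(adata.obs_names), or equivalently adata.n_obs.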

In [None]:
# Run autotune
# Sample the search space 200 times (num_samples=200)
# Train each trial for at most 55 epochs, the default number of epochs for this dataset
ray.init(log_to_driver=False)
results = scvi_tuner.fit(
    adata,
    metric="validation_loss",
    search_space=search_space,
    num_samples=200,
    max_epochs=55,
    resources={"cpu": 20, "gpu": 1},
)
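With num_samples=200, fewer than the 324 combinations are tried, and under random search the same combination can be drawn more than once. Assume that each trial draws every parameter independently and uniformly from its tune.choice list. This is plain random search; this notebook does not state which searcher ModelTuner.fit actually uses. Under that assumption, the expected number of distinct combinations among 200 draws from 324 equally likely ones is 324 × (1 - (323/324)^200) ≈ 149. The sketch below computes this value and checks it with a seeded simulation. The seed and the number of simulation runs are arbitrary choices made for illustration.

```python
import random

n_grid = 3 * 3 * 3 * 4 * 3  # 324 combinations in the search space
n_samples = 200             # num_samples passed to scvi_tuner.fit

# Expected number of distinct combinations when drawing uniformly with replacement
expected_distinct = n_grid * (1 - (1 - 1 / n_grid) ** n_samples)

# Monte Carlo check: draw combination indices uniformly and count the distinct ones per run
rng = random.Random(0)
n_runs = 2000
mean_distinct = sum(
    len({rng.randrange(n_grid) for _ in range(n_samples)}) for _ in range(n_runs)
) / n_runs

# expected_distinct is about 149.4; the simulated mean should land close to it
print(f"expected: {expected_distinct:.1f}, simulated: {mean_distinct:.1f}")
```

The results table further down contains such repeats. For example, trials 80 and 95 share the same configuration: 256 hidden units, 1 layer, 30 latent dimensions, gene-label dispersion and the zinb likelihood.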

In [3]:
# The autotune module prints the results table instead of returning it as a DataFrame.
# The printed table was copied manually into a text editor, saved as scvi_parameters.csv and is read back here.
autotune_result = pd.read_csv('scvi_parameters.csv')
# Sort the trials by validation_loss in ascending order (a lower validation loss indicates a better fit)
# The best-performing combination was taken from this table
autotune_result.sort_values(by=["validation_loss"]).head(10)

Unnamed: 0  Trial name              status      loc                n_hidden  n_layers  n_latent  dispersion  gene_likelihood  validation_loss
57          _trainable_94f0b_00057  TERMINATED  10.84.1.4:3177749  256       2         30        gene-cell   zinb             468.95
152         _trainable_94f0b_00152  TERMINATED  10.84.1.4:3177749  256       2         30        gene-label  zinb             470.35
183         _trainable_94f0b_00183  TERMINATED  10.84.1.4:3178012  256       1         30        gene-cell   nb               470.631
80          _trainable_94f0b_00080  TERMINATED  10.84.1.4:3177749  256       1         30        gene-label  zinb             470.907
95          _trainable_94f0b_00095  TERMINATED  10.84.1.4:3177412  256       1         30        gene-label  zinb             471.018
49          _trainable_94f0b_00049  TERMINATED  10.84.1.4:3178340  256       1         30        gene        zinb             471.068
85          _trainable_94f0b_00085  TERMINATED  10.84.1.4:3177412  256       3         30        gene-label  zinb             471.432
127         _trainable_94f0b_00127  TERMINATED  10.84.1.4:3178340  256       2         20        gene-label  zinb             471.442
169         _trainable_94f0b_00169  TERMINATED  10.84.1.4:3178340  128       1         30        gene-cell   zinb             471.577
117         _trainable_94f0b_00117  TERMINATED  10.84.1.4:3177749  256       1         30        gene        zinb             471.587
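The best combination can also be extracted programmatically instead of being read off the sorted table. Below is a minimal pandas sketch. So that it runs on its own, it uses the ten rows shown above as inline CSV; in the notebook, the full autotune_result DataFrame would be used instead. The sketch assumes that the first header field in scvi_parameters.csv is empty, which is what normally produces an "Unnamed: 0" column. Passing index_col=0 to pd.read_csv turns that column into the index instead.

```python
import io

import pandas as pd

# The ten best rows from the output above, with an empty first header field
# (the assumed cause of the "Unnamed: 0" column when read without index_col)
csv_text = """\
,Trial name,status,loc,n_hidden,n_layers,n_latent,dispersion,gene_likelihood,validation_loss
57,_trainable_94f0b_00057,TERMINATED,10.84.1.4:3177749,256,2,30,gene-cell,zinb,468.95
152,_trainable_94f0b_00152,TERMINATED,10.84.1.4:3177749,256,2,30,gene-label,zinb,470.35
183,_trainable_94f0b_00183,TERMINATED,10.84.1.4:3178012,256,1,30,gene-cell,nb,470.631
80,_trainable_94f0b_00080,TERMINATED,10.84.1.4:3177749,256,1,30,gene-label,zinb,470.907
95,_trainable_94f0b_00095,TERMINATED,10.84.1.4:3177412,256,1,30,gene-label,zinb,471.018
49,_trainable_94f0b_00049,TERMINATED,10.84.1.4:3178340,256,1,30,gene,zinb,471.068
85,_trainable_94f0b_00085,TERMINATED,10.84.1.4:3177412,256,3,30,gene-label,zinb,471.432
127,_trainable_94f0b_00127,TERMINATED,10.84.1.4:3178340,256,2,20,gene-label,zinb,471.442
169,_trainable_94f0b_00169,TERMINATED,10.84.1.4:3178340,128,1,30,gene-cell,zinb,471.577
117,_trainable_94f0b_00117,TERMINATED,10.84.1.4:3177749,256,1,30,gene,zinb,471.587
"""

# index_col=0 turns the unnamed first column into the index
autotune_result = pd.read_csv(io.StringIO(csv_text), index_col=0)

param_cols = ["n_hidden", "n_layers", "n_latent", "dispersion", "gene_likelihood"]

# Row label of the trial with the lowest validation loss, and its parameter values
best_idx = autotune_result["validation_loss"].idxmin()
best_params = autotune_result.loc[best_idx, param_cols].to_dict()
print(best_idx, best_params)
```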


In [5]:
# Plot validation loss against the parameters controlling how gene expression is modelled
sns.set(font_scale=1.3)
with plt.rc_context({"figure.dpi": 400}):
    sns.relplot(data=autotune_result, x="gene_likelihood", y="validation_loss", hue="dispersion")
    plt.savefig("figures/gene_param.png", dpi=400, bbox_inches="tight")

In [None]:
# Plot validation loss against the parameters controlling network structure
with plt.rc_context({"figure.dpi": 400}):
    sns.relplot(data=autotune_result, x="n_layers", y="validation_loss", hue="n_hidden")
    plt.savefig("figures/network_param.png", dpi=400, bbox_inches="tight")
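The relational plots can be complemented by a numeric summary of validation loss for each parameter value. Below is a minimal pandas groupby sketch. It runs on a small inline subset: the ten best trials from the results table earlier in this notebook, reduced to the parameter columns and the validation loss. Applied to the notebook's full autotune_result, the same groupby call would summarise all trials. With only the top ten rows, the counts mainly show which values appear among the best results; for example, 9 of the 10 use zinb.

```python
import pandas as pd

# The ten best trials from the results table (parameter columns and validation loss only)
top10 = pd.DataFrame(
    {
        "n_hidden": [256, 256, 256, 256, 256, 256, 256, 256, 128, 256],
        "n_layers": [2, 2, 1, 1, 1, 1, 3, 2, 1, 1],
        "dispersion": ["gene-cell", "gene-label", "gene-cell", "gene-label", "gene-label",
                       "gene", "gene-label", "gene-label", "gene-cell", "gene"],
        "gene_likelihood": ["zinb", "zinb", "nb", "zinb", "zinb",
                            "zinb", "zinb", "zinb", "zinb", "zinb"],
        "validation_loss": [468.95, 470.35, 470.631, 470.907, 471.018,
                            471.068, 471.432, 471.442, 471.577, 471.587],
    }
)

# Number of trials, mean and best (lowest) validation loss for each value of each parameter
for param in ["gene_likelihood", "dispersion", "n_layers", "n_hidden"]:
    summary = top10.groupby(param)["validation_loss"].agg(["count", "mean", "min"])
    print(summary, end="\n\n")
```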