Author: Erno Hänninen

Created: 03.02.2023

Title: d16_prediction.ipynb

Description: 
- Predict cell types of data sequenced from our in-house hypothalamic differentiation protocols at day 16

Procedure
- Read the day 16 data
- Perform additional filtering
- Re-train the scANVI reference model using day 16 query data
- Predict the cell types from day 16 data
- Assign each cell the cell type with the highest probability, and annotate cells with a prediction score below 0.6 as "Unknown"
- Plot the results and marker gene expression on a UMAP


List of non-standard modules:
- scanpy, scvi, matplotlib, pandas, seaborn

Conda environment used:
- PYenv

Usage:
- The notebook was executed using the Jupyter Notebook web interface. All dependencies required by Jupyter are installed in the PYenv Conda environment. See the README file for further details.

In [13]:
# Standard library
import os

# Cap the native thread pools. These variables are set BEFORE the numerical
# packages are imported, because the underlying libraries (MKL, OpenMP,
# numexpr) typically read them only once, when they are first loaded.
os.environ["MKL_NUM_THREADS"] = "30"
os.environ["NUMEXPR_NUM_THREADS"] = "30"
os.environ["OMP_NUM_THREADS"] = "30"

# Python packages
import scanpy as sc
import scvi
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Reading d16 data and quality filtering

In [14]:
# Load the day 16 dataset from the h5ad file into adata_d16
adata_d16 = sc.read("Data/d16.h5ad")

In [15]:
# Compute QC metrics
# Flag genes whose name starts with "MT-" (mitochondrial genes) in the 'mt' column
adata_d16.var["mt"] = adata_d16.var_names.str.startswith("MT-")
sc.pp.calculate_qc_metrics(
    adata_d16,
    qc_vars=["mt"],
    percent_top=None,
    log1p=False,
    inplace=True,
)

# One row with five panels: two violin plots, two histograms and one scatter plot
fig, (ax1, ax2, ax3, ax4, ax5) = plt.subplots(
    1, 5, figsize=(20, 4), gridspec_kw={"wspace": 0.9}
)
ax1_dict = sc.pl.violin(
    adata_d16, ["n_genes_by_counts"], jitter=0.4, show=False, ax=ax1, stripplot=False
)
ax2_dict = sc.pl.violin(
    adata_d16, ["total_counts"], jitter=0.4, show=False, ax=ax2, stripplot=False
)
ax3_dict = sns.histplot(adata_d16.obs["n_genes_by_counts"], ax=ax3)
ax4_dict = sns.histplot(adata_d16.obs["total_counts"], ax=ax4)
ax5_dict = sc.pl.scatter(
    adata_d16, x="total_counts", y="n_genes_by_counts", show=False, ax=ax5
)

In [16]:
# Perform filtering. Each call applies a single threshold, in this order:
# minimum total counts, maximum total counts, minimum number of detected genes.
for threshold in ({"min_counts": 100}, {"max_counts": 35000}, {"min_genes": 1450}):
    sc.pp.filter_cells(adata_d16, **threshold)

In [17]:
# Recompute the QC metrics on the filtered cells to check that the filtering was sufficient
# Flag genes whose name starts with "MT-" (mitochondrial genes) in the 'mt' column
adata_d16.var["mt"] = adata_d16.var_names.str.startswith("MT-")
sc.pp.calculate_qc_metrics(
    adata_d16,
    qc_vars=["mt"],
    percent_top=None,
    log1p=False,
    inplace=True,
)

# Same five-panel overview as before filtering, to confirm the data quality is acceptable
fig, (ax1, ax2, ax3, ax4, ax5) = plt.subplots(
    1, 5, figsize=(20, 4), gridspec_kw={"wspace": 0.9}
)
ax1_dict = sc.pl.violin(
    adata_d16, ["n_genes_by_counts"], jitter=0.4, show=False, ax=ax1, stripplot=False
)
ax2_dict = sc.pl.violin(
    adata_d16, ["total_counts"], jitter=0.4, show=False, ax=ax2, stripplot=False
)
ax3_dict = sns.histplot(adata_d16.obs["n_genes_by_counts"], ax=ax3)
ax4_dict = sns.histplot(adata_d16.obs["total_counts"], ax=ax4)
ax5_dict = sc.pl.scatter(
    adata_d16, x="total_counts", y="n_genes_by_counts", show=False, ax=ax5
)

In [18]:
# Prepare data for scANVI and plotting
adata_d16.var = adata_d16.var.set_index('_index')
adata_d16.obs["sample"] = "d16"
# Store the raw counts in the "counts" layer as an independent copy.
# Without .copy() the layer would reference the very same matrix as adata_d16.X,
# so the in-place normalisation and log-transform below could overwrite the
# "raw" counts as well.
adata_d16.layers["counts"] = adata_d16.X.copy()

# Normalize and log-transform adata.X (the "counts" layer is left untouched)
sc.pp.normalize_total(adata_d16, target_sum=1e4)
sc.pp.log1p(adata_d16)
# Keep a snapshot of the normalised data
adata_d16_original = adata_d16.copy()

# Query mapping

In [None]:
# Directory of the pretrained scANVI reference model
reference_model_dir = "scanvi_model"

# Prepare the d16 query data for the pretrained reference model
scvi.model.SCANVI.prepare_query_anndata(adata_d16, reference_model_dir)

# Initialize the query model from the reference model and the d16 data
vae_q = scvi.model.SCANVI.load_query_data(adata_d16, reference_model_dir)
vae_q

In [None]:
# Train the scANVI query model
train_settings = dict(
    max_epochs=60,
    plan_kwargs={"weight_decay": 0.0},
    early_stopping=True,
    early_stopping_monitor="elbo_train",
    train_size=0.7,
    batch_size=502,
)
vae_q.train(**train_settings)

# Predicting d16 data and visualization

In [None]:
# Run the prediction; with soft=True the result is used as a dataframe holding
# one row per cell and one probability column per cell type
df = vae_q.predict(soft=True)

# For every cell keep the most probable cell type together with its probability
cell_prob = pd.DataFrame(
    {
        "Cell_type": df.idxmax(axis=1).tolist(),
        "Probability": df.max(axis="columns").tolist(),
    }
)

# Cells with a prediction score below 0.6 are relabelled as "Unknown" (they are not removed)
cell_prob["Cell_type"] = cell_prob["Cell_type"].mask(
    cell_prob["Probability"] < 0.6, "Unknown"
)

# Store the final labels on the AnnData object (assigned by position, one per cell)
adata_d16.obs["Predictions"] = cell_prob["Cell_type"].tolist()

# Attach the snapshot of the normalised data as .raw
adata_d16.raw = adata_d16_original.copy()

# Show the number of cells per predicted label
adata_d16.obs["Predictions"].value_counts()

In [23]:
# Removing astrocyte, OPC and mural cells from day 16 as there was only one cell of each.
# Only cells predicted as NP, Neuron, ARC or Unknown are kept.
# .copy() turns the subset into an independent AnnData object instead of a view,
# so the later in-place edit of .var and the final write operate on real data.
adata_d16 = adata_d16[adata_d16.obs["Predictions"].isin(["NP", "Neuron","Unknown", "ARC"])].copy()

In [None]:
# Plot the prediction results on the UMAP
with plt.rc_context({"figure.dpi": 400}):
    sc.pl.umap(
        adata_d16,
        color=["Predictions"],
        legend_fontsize="large",
        frameon=False,
        save="_d16_pred.png",
    )

# Plot marker gene expression, one figure per marker gene
title_list = ["STMN2 - Neuron", "SOX2 - NP", "POMC - ARC"]
gene_list = ["STMN2", "SOX2", "POMC"]
for gene, title in zip(gene_list, title_list):
    if gene == "POMC":
        # POMC is shown only on the cells predicted as ARC or Neuron
        cells_to_plot = adata_d16[adata_d16.obs["Predictions"].isin(["ARC", "Neuron"])]
    else:
        # The other markers are shown on all cells
        cells_to_plot = adata_d16

    with plt.rc_context({"figure.dpi": 400}):
        sc.pl.umap(
            cells_to_plot,
            color=gene,
            legend_fontsize="small",
            frameon=False,
            use_raw=True,
            save=f"_d16_{gene}.png",
            colorbar_loc=None,
            title=title,
        )
        


In [25]:
# Store the resulting AnnData object for later use.
# The 'mt' gene annotation column is dropped first when it is present.
if "mt" in adata_d16.var:
    del adata_d16.var["mt"]

adata_d16.write("Data/d16_predicted.h5ad")