In [None]:
import os
import numpy as np
import pandas as pd
import scanpy as sc
import anndata
import scarches

from anndata import AnnData
from scarches.models.scpoli import scPoli

import umap

In [None]:
import torch
print(torch.cuda.is_available())

## Specify working directory

In [None]:
WD = "/dss/dssfs02/lwp-dss-0001/pn36po/pn36po-dss-0001/di93vel/atlas_building"
VERSION = 46

In [None]:
# Folder that holds the query-to-reference models, and the path of this VERSION's model.
out_folder = f"{WD}/q2r_models/"
# os.path.join does not add a second separator after the trailing "/" of out_folder,
# so the path no longer contains "q2r_models//model_v...".
model_path = os.path.join(out_folder, f"model_v{VERSION}")

In [None]:
# reference column containing the cell type labels of interest
level = 'level_3'

## Functions

In [None]:
def read_samples(project_dir, sample_sheet, annot_dir):
    """Load every sample listed in ``sample_sheet`` and concatenate them into one AnnData.

    For each ``sample_id`` the file ``{project_dir}/h5ad_pyraw/cp_h5ad/{sample}_pyraw.h5ad``
    is read, ``obs['sample_id']`` and ``obs['derive']`` (taken from the sample sheet)
    are set, and all ``var`` columns are removed. If a non-trivial annotation file
    ``{annot_dir}/{sample}_annotation.txt`` exists it is merged into ``obs`` on the
    cell index; otherwise ``level_1`` is set to 'epithelial' and ``level_2`` to 'na'.

    Returns the outer-joined concatenation (missing values filled with 0) with an
    extra ``obs['publication']`` column built from the first three '_'-separated
    fields of ``sample_id``; the obs index is named "cells".
    """
    per_sample = []
    for sample in sample_sheet.sample_id:
        sample_adata = sc.read_h5ad(f"{project_dir}/h5ad_pyraw/cp_h5ad/{sample}_pyraw.h5ad")

        # sample-level metadata: identifier and the 'derive' value of the first matching sheet row
        sample_adata.obs['sample_id'] = sample
        is_this_sample = sample_sheet['sample_id'] == sample
        sample_adata.obs['derive'] = sample_sheet.loc[is_this_sample, 'derive'].values[0]

        # strip all gene-level annotation columns, keeping only the var index
        sample_adata.var = sample_adata.var.drop(columns=sample_adata.var.columns)

        annotation_path = f"{annot_dir}/{sample}_annotation.txt"
        has_annotation = os.path.exists(annotation_path) and os.path.getsize(annotation_path) > 1
        if not has_annotation:
            # no per-cell annotation available for this sample
            sample_adata.obs['level_1'] = 'epithelial'
            sample_adata.obs['level_2'] = 'na'
        else:
            annotation = pd.read_csv(annotation_path, sep="\t", index_col=0)
            sample_adata.obs = pd.merge(
                sample_adata.obs, annotation, left_index=True, right_index=True
            )

        per_sample.append(sample_adata)

    combined = anndata.concat(per_sample, join='outer', fill_value=0)

    publications = []
    for sample_name in combined.obs.sample_id.tolist():
        publications.append('_'.join(sample_name.split('_')[:3]))
    combined.obs['publication'] = publications

    combined.obs.index.name = "cells"
    return combined

def get_counts(adata, target_sum=10000):
    """Approximate per-cell counts from log1p-transformed, size-normalised data.

    Applies ``expm1`` to ``adata.X``, multiplies every cell (row) by that cell's
    ``adata.obs['n_counts']``, divides by ``target_sum`` and rounds to whole numbers.
    This recovers the counts only if ``adata.X`` was produced as
    ``log1p(counts / n_counts * target_sum)`` -- assumption to be confirmed for
    each dataset this is applied to.

    Parameters
    ----------
    adata
        Object exposing ``X`` (sparse matrix with ``toarray`` or dense array),
        ``obs['n_counts']``, ``obs_names`` and ``var_names``.
    target_sum
        Scaling constant that is divided out again (default 10000).

    Returns
    -------
    pandas.DataFrame
        Cells x genes frame of rounded values with dtype float32.
    """
    # Sparse matrices expose ``toarray``; a dense array is used as it is.
    matrix = adata.X.toarray() if hasattr(adata.X, "toarray") else np.asarray(adata.X)
    unlogged = pd.DataFrame(np.expm1(matrix), index=adata.obs_names, columns=adata.var_names)

    # Row i of X belongs to row i of obs, so the per-cell totals are applied by position.
    cell_totals = adata.obs["n_counts"].to_numpy()
    counts_df = unlogged.mul(cell_totals, axis=0).div(target_sum)
    return counts_df.round(0).astype(np.float32)
    
def clear_genes(project_dir, adata):
    """Subset ``adata`` to the genes listed in ``{project_dir}/hg_genes_clear.txt``.

    The file is read as a header-less table and its first column is taken as the
    gene list. Genes are kept in file order and only if they occur in
    ``adata.var.index``.

    Returns ``adata[:, kept_genes]`` (whatever that indexing yields for the
    passed object).
    """
    allowed_genes = pd.read_csv(
        os.path.join(project_dir, "hg_genes_clear.txt"), header=None
    )[0].tolist()

    # Build the lookup once; the gene list and the var index can both be large.
    present = set(adata.var.index)
    kept_genes = [gene for gene in allowed_genes if gene in present]

    return adata[:, kept_genes]

def pre_inti0(adata):
    """Normalise ``adata`` in place while keeping the original matrix in a layer.

    Stores a copy of ``adata.X`` in ``adata.layers['counts']``, scales every cell
    to a total of 1e4 with ``sc.pp.normalize_total``, applies ``sc.pp.log1p`` and
    sets ``adata.raw`` to the transformed object.

    Returns the same (modified) ``adata``.
    """
    # Store a copy, not a reference: the two scanpy calls below update adata.X,
    # and a layer that shares the same matrix object would be rewritten with it.
    adata.layers['counts'] = adata.X.copy()
    sc.pp.normalize_total(adata, target_sum=1e4)
    sc.pp.log1p(adata)
    adata.raw = adata

    return adata

def get_kwargs():
    """Return a fresh dict of early-stopping settings used by the scPoli training calls.

    The monitored metric is "val_prototype_loss" (minimised); the remaining keys
    configure patience and learning-rate reduction.
    """
    return dict(
        early_stopping_metric="val_prototype_loss",
        mode="min",
        threshold=0,
        patience=20,
        reduce_lr=True,
        lr_patience=13,
        lr_factor=0.1,
    )

def train_scpoli(adata):
    """Create an scPoli model for ``adata`` and train it.

    Conditions are taken from ``obs['sample_id']`` and cell-type labels from the
    column named by the notebook-level variable ``level``; the label 'na' is
    declared as unknown. Training uses the early-stopping settings returned by
    ``get_kwargs``.

    Returns the trained scPoli model.
    """
    stopping_settings = get_kwargs()

    model_settings = dict(
        unknown_ct_names=['na'],
        condition_keys=['sample_id'],
        cell_type_keys=level,
        embedding_dims=10,  # default
        latent_dim=10,  # default
        hidden_layer_sizes=[512, 512],
        recon_loss='nb',
    )
    training_settings = dict(
        n_epochs=50,
        pretraining_epochs=4,
        early_stopping_kwargs=stopping_settings,
        eta=10,
        alpha_epoch_anneal=100,
    )

    reference_model = scPoli(adata=adata, **model_settings)
    reference_model.train(**training_settings)
    return reference_model

def map_query(adata, model):
    """Attach query data to a trained reference model and fine-tune on it.

    ``scPoli.load_query_data`` is called with an empty ``labeled_indices`` list,
    then the resulting model is trained with the early-stopping settings from
    ``get_kwargs``.

    Returns the trained query model.
    """
    stopping_settings = get_kwargs()

    # no query cell is passed as labeled
    query_model = scPoli.load_query_data(adata=adata, reference_model=model, labeled_indices=[])

    training_settings = dict(
        n_epochs=200,
        pretraining_epochs=80,
        early_stopping_kwargs=stopping_settings,
        eta=10,
        alpha_epoch_anneal=100,
    )
    query_model.train(**training_settings)

    return query_model

def umap_transform(adata_ref, adata_que, cell_type_col):
    """Embed reference and query cells in one UMAP and return them concatenated.

    A UMAP model is fitted on ``adata_ref.X`` only; both ``adata_ref.X`` and
    ``adata_que.X`` are then projected with that fitted model and stored in
    ``obsm['X_umap']``. Several ``obs`` columns are added to BOTH input objects
    (they are modified in place) before they are concatenated.

    Parameters
    ----------
    adata_ref
        Reference object; must contain the ``obs`` column named by the
        notebook-level variable ``level``.
    adata_que
        Query object; must already contain ``obs['cell_type_pred']``.
    cell_type_col
        NOTE(review): not used anywhere in the body -- the reference labels are
        read from the global ``level`` instead. Confirm which one is intended.

    Returns
    -------
    The result of ``anndata.concat([adata_ref, adata_que])``.
    """
    # fit on the reference only, then project reference and query with the same model
    model = umap.UMAP(n_neighbors=5, random_state=42, min_dist=0.5).fit(adata_ref.X)
    adata_ref.obsm['X_umap'] = model.transform(adata_ref.X)
    adata_que.obsm['X_umap'] = model.transform(adata_que.X)
    
    # flag the origin of every cell: 0 = reference, 1 = query
    adata_ref.obs['query'] = 0
    adata_que.obs['query'] = 1
    
    # placeholder columns filled with NaN -- presumably so that these columns
    # exist on both objects when they are concatenated (TODO confirm)
    adata_ref.obs['cell_type_pred'] = np.nan
    adata_ref.obs['cell_type_uncert'] = np.nan
    adata_que.obs['maturity'] = np.nan
    
    # harmonized annotation: reference labels transferred to query cells
    adata_ref.obs['cell_type'] = adata_ref.obs[level].copy()
    adata_que.obs['cell_type'] = adata_que.obs['cell_type_pred'].copy()
    
    # original reference cell annotation
    adata_ref.obs['cell_type_ref'] = adata_ref.obs[level].copy()
    adata_que.obs['cell_type_ref'] = np.nan
    
    adata_all = anndata.concat([adata_ref, adata_que])

    return adata_all


## Load samples

In [None]:
# input data folder and the folder with the per-sample annotation tables
project_dir = os.path.join(WD, 'data')
annot_dir = os.path.join(project_dir, 'sample_annot6')

# read the tab-separated sample sheet and keep only the rows whose tissue is 'lung'
sample_sheet = (
    pd.read_csv(os.path.join(project_dir, 'all_samples_sheets.txt'), sep='\t')
    .loc[lambda sheet: sheet['tissue'] == 'lung']
)

### Load reference data

In [None]:
# FETAL
# Read the fetal reference and restrict it to the genes listed in hg_genes_clear.txt.
fetal = sc.read(os.path.join(project_dir, 'reference/Assembled10DomainsEpithelialSimplifiedAnnotation.h5ad'))
fetal = clear_genes(project_dir, fetal)
# Harmonised obs columns shared with the mature reference:
# level_1 is a constant, level_2 / level_3 are copied from the dataset's own annotation columns.
fetal.obs['level_1'] = 'epithelial'
fetal.obs['level_2'] = fetal.obs['simplified_celltype'].copy()
fetal.obs['level_3'] = fetal.obs['new_celltype'].copy()
# 'sample_id' is the condition key used by train_scpoli; here it is the donor.
fetal.obs['sample_id'] = fetal.obs['donor'].copy()
fetal.obs['maturity'] = 'fetal'

In [None]:
fetal.obs['level_2']

In [None]:
fetal.obs['level_3']

In [None]:
# revert the normalized, transformed data back to counts
# get_counts returns a cells x genes DataFrame of float32 values, so after this
# cell fetal.X is a dense float32 array.
fetal_counts = get_counts(fetal)
fetal.X = fetal_counts.values.copy()

In [None]:
print(np.max(fetal.X))
print(np.min(fetal.X))
print(type(fetal.X[0,0]))

In [None]:
# MATURE
# Read the mature reference and restrict it to the genes listed in hg_genes_clear.txt.
mature = sc.read(os.path.join(project_dir, 'reference/Barbry_Leroy_2020_epithelial_annot_simplified.h5ad'))
mature = clear_genes(project_dir, mature)
# Harmonised obs columns shared with the fetal reference.
mature.obs['level_1'] = 'epithelial'
mature.obs['level_2'] = mature.obs['simplified_celltype'].copy()
mature.obs['level_3'] = mature.obs['predicted_labels'].copy()
# 'sample_id' is the condition key used by train_scpoli; here it is the sample.
mature.obs['sample_id'] = mature.obs['sample'].copy()
mature.obs['maturity'] = 'mature'
# NOTE(review): presumably a workaround for scanpy code that looks up
# uns['log1p']['base'] on objects read from disk -- confirm it is still needed.
mature.uns['log1p'] = {'base': None}

In [None]:
mature.obs['level_2']

In [None]:
mature.obs['level_3']

In [None]:
mature.X = mature.layers['counts'].copy()

In [None]:
print(np.max(mature.X))
print(np.min(mature.X))
print(type(mature.X[0,0]))

In [None]:
reference = anndata.concat([fetal, mature])

In [None]:
reference.obs['derive'] = np.nan
reference.obs

In [None]:
# Flag the 3000 most variable genes of the combined reference, computed per
# sample ('sample_id' as batch key) with the 'seurat_v3' flavor.
sc.pp.highly_variable_genes(reference,
                            n_top_genes=3000, 
                            batch_key='sample_id',
                            flavor='seurat_v3')

# Keep only the flagged genes. The subset is copied so that `reference` is a
# standalone object rather than a view: later cells write to it (e.g. obsm['X_scPoli']).
reference = reference[:,reference.var.highly_variable].copy()

### Load organoid data

In [None]:
organoid = read_samples(project_dir, sample_sheet, annot_dir)

In [None]:
print(np.max(organoid.X))
print(np.min(organoid.X))
print(type(organoid.X[0,0]))

In [None]:
organoid.X = organoid.X.astype(np.float32)
print(type(organoid.X[0,0]))

In [None]:
organoid.obs.level_1.value_counts()

In [None]:
organoid.obs.level_2.value_counts()

In [None]:
# Keep the organoid cells' existing level_2 annotation under a separate name,
# then overwrite both label columns with 'na' -- the value train_scpoli passes
# as unknown_ct_names.
organoid.obs['orig_cell_types'] = organoid.obs['level_2'].copy()
organoid.obs['level_2'] = 'na'
organoid.obs['level_3'] = 'na'

In [None]:
organoid.obs

In [None]:
# note that the 'Miller' cells were incorrectly annotated; these cells are derived from FSC cells rather than ASC cells
# we correct the annotation in the post-hoc analysis notebook

In [None]:
organoid[organoid.obs.publication == 'Miller_DevCell_2020'].obs.sample_id.value_counts()

In [None]:
organoid[organoid.obs.publication == 'Miller_DevCell_2020'].obs.derive

In [None]:
# select only those highly variable genes that occur in both the reference and organoid data
# Filtering reference.var_names keeps the reference gene order, so the selection is
# the same in every session (iterating a set of strings is not order-stable).
shared_genes = reference.var_names[reference.var_names.isin(organoid.var_names)]
n_missing = reference.n_vars - len(shared_genes)
if n_missing > 0:
    # some reference genes are absent from the organoid data: drop them from the reference too
    print(f"{n_missing} reference genes are missing from the organoid data and were dropped")
    reference = reference[:, shared_genes].copy()
organoid = organoid[:, shared_genes].copy()

In [None]:
print(reference.shape)
print(organoid.shape)

## Integrate reference data

If you've already integrated your reference data and saved the model, skip ahead to the section "Load model and integrated reference data".

In [None]:
scpoli_model = train_scpoli(reference)

In [None]:
reference.obsm['X_scPoli'] = scpoli_model.get_latent(
    reference,
    mean=True,
)

In [None]:
# visualize the latent representation of reference cells computed by scPoli
sc.pp.neighbors(reference, use_rep='X_scPoli')
sc.tl.umap(reference)
sc.pl.umap(
    reference, 
    color=['maturity', level],
    show=True,
    frameon=False,
    save=f'_scPoli_latent_v{VERSION}.png',
)

In [None]:
# reference.obs[level] = ''
# reference.obs[level][:fetal.shape[0]] = fetal.obs['new_celltype'].copy()
# reference.obs[level][fetal.shape[0]:] = mature.obs['predicted_labels'].copy()

In [None]:
sc.pl.umap(
    reference, 
    color=[level],
    show=True,
    frameon=False,
)

In [None]:
reference

In [None]:
# Remove the 'conditions_combined' column before saving (presumably added to obs
# by the scPoli set-up -- TODO confirm). errors='ignore' keeps this cell
# re-runnable once the column is already gone.
reference.obs = reference.obs.drop(columns=['conditions_combined'], errors='ignore')
reference.write(os.path.join(project_dir, f'reference/integrated_reference_v{VERSION}.h5ad'))

In [None]:
scpoli_model.save(model_path)

## Load model and integrated reference data

In [None]:
# load a model if you've trained one before
reference = sc.read_h5ad(os.path.join(project_dir, f"reference/integrated_reference_v{VERSION}.h5ad"))
# Use the same `model_path` the model was saved to, and the scPoli class imported at the top.
scpoli_model = scPoli.load(model_path, reference)

In [None]:
sc.pl.umap(
    reference, 
    color=['maturity', level],
    show=True,
    frameon=False,
)

## Map organoid to reference cells

In [None]:
# fine-tune the reference model on the organoid (query) cells
scpoli_query = map_query(organoid, scpoli_model)

# get latent representation of reference data
# (the underlying model is switched to evaluation mode first)
scpoli_query.model.eval()
data_latent_source = scpoli_query.get_latent(
    reference,
    mean=True,
)

# get latent representation of query data
data_latent_target = scpoli_query.get_latent(
    organoid,
    mean=True,
)

# wrap the latent matrices in AnnData objects that carry a copy of the original cell metadata
adata_latent_source = AnnData(data_latent_source)
adata_latent_source.obs = reference.obs.copy()

adata_latent_target = AnnData(data_latent_target)
adata_latent_target.obs = organoid.obs.copy()

In [None]:
# get label annotations
# results_dict is indexed by the cell-type key (`level`) and provides 'preds' and 'uncert'
results_dict = scpoli_query.classify(organoid, scale_uncertainties=True)
adata_latent_target.obs['cell_type_pred'] = results_dict[level]['preds'].tolist()
adata_latent_target.obs['cell_type_uncert'] = results_dict[level]['uncert'].tolist()
# NOTE(review): organoid.obs['level_3'] was overwritten with 'na' earlier in this
# notebook, so with level = 'level_3' this comparison is True only where the
# prediction itself is 'na' -- confirm whether this column is meaningful here.
adata_latent_target.obs['classifier_outcome'] = (
    adata_latent_target.obs['cell_type_pred'] == adata_latent_target.obs[level]
)

In [None]:
print(adata_latent_source.X.shape)
print(adata_latent_target.X.shape)

In [None]:
adata_latent = umap_transform(adata_latent_source, adata_latent_target, cell_type_col='simplified_celltype')

In [None]:
# adata_latent_reference = AnnData(reference.obsm['X_scPoli'])
# adata_latent_reference.obs = reference.obs.copy()
# adata_latent = umap_transform(adata_latent_reference, adata_latent_target, cell_type_col=cell_type_col)

In [None]:
adata_latent

## Save latent representation

In [None]:
adata_latent.write(os.path.join(WD, f"data/q2r_fetal_adata_latent_v{VERSION}.h5ad"))

## Plots

In [None]:
sc.pl.umap(
    adata_latent,
    color='cell_type',
    show=True,
    frameon=False,
    save=f'_new_celltype_v{VERSION}.png',
)

In [None]:
sc.pl.umap(
    adata_latent,
    color='cell_type_ref',
    show=True,
    frameon=False,
    save=f'_cell_type_ref_v{VERSION}.png',
)

In [None]:
sc.pl.umap(
    adata_latent,
    color='query',
    show=True,
    frameon=False,
    save=f'_query_v{VERSION}.png',
)

In [None]:
sc.pl.umap(
    adata_latent,
    color='cell_type_pred',
    show=True,
    frameon=False,
    save=f'_cell_type_pred_v{VERSION}.png',
)

In [None]:
sc.pl.umap(
    adata_latent,
    color='derive',
    show=True,
    frameon=False,
    save=f'_derive_v{VERSION}.png',
)

In [None]:
sc.pl.umap(
    adata_latent,
    color='cell_type_uncert',
    show=True,
    frameon=False,
    cmap='magma',
    # vmax=1,
    save=f'_uncert_v{VERSION}.png',
)

In [None]:
sc.pl.umap(
    adata_latent,
    color='maturity',
    show=True,
    frameon=False,
    save=f'_maturity_v{VERSION}.png',
)