In [1]:
import os
import numpy as np
import pandas as pd
import scanpy as sc
import anndata
import scarches

from anndata import AnnData
from scarches.models.scpoli import scPoli

import umap

INFO:pytorch_lightning.utilities.seed:Global seed set to 0
  new_rank_zero_deprecation(
  return new_rank_zero_deprecation(*args, **kwargs)
 captum (see https://github.com/pytorch/captum).


In [2]:
import torch
# Report whether PyTorch can see a CUDA device in this kernel.
print(torch.cuda.is_available())

False


## Specify working directory

In [3]:
# Project root: input data, trained models and written outputs are all resolved relative to it.
WD = "/dss/dssfs02/lwp-dss-0001/pn36po/pn36po-dss-0001/di93vel/atlas_building"
# Run identifier embedded in the model path and in the names of saved figures / output files.
VERSION = 17

## Functions

In [4]:
def read_samples(project_dir, sample_sheet, annot_dir):
    """Load every sample named in the sample sheet and stack them into one AnnData.

    Parameters
    ----------
    project_dir
        Directory containing ``h5ad_pyraw/cp_h5ad/<sample>_pyraw.h5ad`` files.
    sample_sheet
        DataFrame with at least the columns ``sample_id`` and ``derive``.
    annot_dir
        Directory containing optional ``<sample>_annotation.txt`` files
        (tab-separated, first column used as the index).

    Returns
    -------
    AnnData
        Outer-joined concatenation of all samples (missing values filled with 0).
        ``obs`` carries ``sample_id``, ``derive``, ``level_1``, ``level_2`` and
        ``publication`` (the first three ``_``-separated tokens of ``sample_id``);
        the ``obs`` index is named ``"cells"``.
    """
    per_sample = []
    for sample_name in sample_sheet.sample_id:
        h5ad_path = f"{project_dir}/h5ad_pyraw/cp_h5ad/{sample_name}_pyraw.h5ad"
        sample_adata = sc.read_h5ad(h5ad_path)

        # Tag the cells with their sample and the sample's 'derive' value
        # (taken from the first sheet row matching this sample id).
        sample_adata.obs['sample_id'] = sample_name
        is_this_sample = sample_sheet['sample_id'] == sample_name
        sample_adata.obs['derive'] = sample_sheet[is_this_sample]['derive'].values[0]

        # Discard all gene-level annotation columns; only the var index is kept.
        sample_adata.var = sample_adata.var.drop(sample_adata.var.columns, axis=1)

        annot_path = f"{annot_dir}/{sample_name}_annotation.txt"
        has_annotation = os.path.exists(annot_path) and os.path.getsize(annot_path) > 1
        if not has_annotation:
            # No usable annotation file: fall back to placeholder labels.
            sample_adata.obs['level_1'] = 'epithelial'
            sample_adata.obs['level_2'] = 'na'
        else:
            annotation = pd.read_csv(annot_path, sep="\t", index_col=0)
            # Index-aligned merge using pd.merge's default `how`.
            # NOTE(review): cells absent from the annotation file would not survive
            # this merge - confirm that annotation files always cover every cell.
            sample_adata.obs = pd.merge(
                sample_adata.obs, annotation, left_index=True, right_index=True
            )

        per_sample.append(sample_adata)

    combined = anndata.concat(per_sample, join='outer', fill_value=0)
    combined.obs['publication'] = [
        '_'.join(sid.split('_')[:3]) for sid in combined.obs.sample_id.tolist()
    ]
    combined.obs.index.name = "cells"
    return combined
    
def clear_genes(project_dir, adata):
    """Restrict ``adata`` to the genes listed in ``<project_dir>/hg_genes_clear.txt``.

    Parameters
    ----------
    project_dir
        Directory holding ``hg_genes_clear.txt`` (one gene name per line, no header).
    adata
        Object exposing ``.var.index`` and supporting ``adata[:, genes]`` selection.

    Returns
    -------
    The result of ``adata[:, genes]`` where ``genes`` are the listed genes that are
    present in ``adata.var.index``, in the order they appear in the gene file.
    """
    allowed_genes = pd.read_csv(
        os.path.join(project_dir, "hg_genes_clear.txt"), header=None
    )[0].tolist()

    # Hash the var names once; rebuilding a list of them for every candidate gene
    # and scanning it linearly made this step quadratic in the number of genes.
    present = set(adata.var.index)
    kept_genes = [gene for gene in allowed_genes if gene in present]

    return adata[:, kept_genes]

def pre_inti0(adata):
    """Normalise and log-transform ``adata`` in place, keeping the input matrix as a layer.

    Steps, in order:
      1. store a copy of ``adata.X`` in ``adata.layers['counts']``;
      2. ``sc.pp.normalize_total`` with ``target_sum=1e4``;
      3. ``sc.pp.log1p``;
      4. snapshot the transformed object in ``adata.raw``.

    Parameters
    ----------
    adata
        AnnData whose ``X`` holds the values to be normalised.

    Returns
    -------
    The same ``adata`` object, modified in place.
    """
    # Copy, do not alias: the transforms below operate on adata.X, and a layer that
    # shares X's buffer would end up holding the transformed values, not the counts.
    adata.layers['counts'] = adata.X.copy()
    sc.pp.normalize_total(adata, target_sum=1e4)
    sc.pp.log1p(adata)
    adata.raw = adata

    return adata

def get_kwargs():
    """Return a fresh dict of the early-stopping settings used by the training helpers.

    The dict combines the stopping criterion (metric, mode, threshold, patience)
    with the learning-rate reduction settings (reduce_lr, lr_patience, lr_factor).
    """
    stopping_rule = dict(
        early_stopping_metric="val_prototype_loss",
        mode="min",
        threshold=0,
        patience=20,
    )
    lr_schedule = dict(
        reduce_lr=True,
        lr_patience=13,
        lr_factor=0.1,
    )
    return {**stopping_rule, **lr_schedule}

def train_scpoli(adata):
    """Construct an scPoli model on ``adata`` and train it.

    The model uses ``sample_id`` as the condition key and ``level_2`` as the
    cell-type key, with ``'na'`` treated as the unknown cell-type name.
    Early-stopping settings come from ``get_kwargs()``.

    Parameters
    ----------
    adata
        AnnData passed to the scPoli constructor.

    Returns
    -------
    The trained scPoli model.
    """
    model_config = {
        'unknown_ct_names': ['na'],
        'condition_keys': 'sample_id',
        'cell_type_keys': ['level_2'],
        'embedding_dims': 10,  # default
        'latent_dim': 10,  # default
        'hidden_layer_sizes': [512, 512],
        'recon_loss': 'mse',
    }
    training_config = {
        'n_epochs': 50,
        'pretraining_epochs': 4,
        'early_stopping_kwargs': get_kwargs(),
        'eta': 10,
        'alpha_epoch_anneal': 100,
    }

    trained_model = scPoli(adata=adata, **model_config)
    trained_model.train(**training_config)
    return trained_model

def map_query(adata, model):
    """Attach query data to a reference scPoli model and fine-tune it.

    The query model is created with ``scPoli.load_query_data`` using an empty
    ``labeled_indices`` list, then trained with the early-stopping settings
    from ``get_kwargs()``.

    Parameters
    ----------
    adata
        Query AnnData.
    model
        Reference scPoli model passed as ``reference_model``.

    Returns
    -------
    The trained query scPoli model.
    """
    training_config = {
        'n_epochs': 100,
        'pretraining_epochs': 4,
        'early_stopping_kwargs': get_kwargs(),
        'eta': 10,
        'alpha_epoch_anneal': 100,
    }

    query_model = scPoli.load_query_data(
        adata=adata, reference_model=model, labeled_indices=[]
    )
    query_model.train(**training_config)
    return query_model

def umap_transform(adata_ref, adata_que, cell_type_col):
    """Embed reference and query cells in a UMAP fitted on the reference only.

    A ``umap.UMAP`` model is fitted on ``adata_ref.X`` and then used to transform
    both ``adata_ref.X`` and ``adata_que.X``; the coordinates are written to
    ``obsm['X_umap']`` of each input (overwriting any existing entry). Both inputs
    are modified in place before being concatenated.

    Columns written to ``obs``:
      * ``query``          - 0 for reference cells, 1 for query cells
      * ``cell_type_pred`` - NaN for reference cells (query must already have it)
      * ``cell_type``      - reference ``level_2`` / query ``cell_type_pred``
      * ``cell_type_ref``  - reference ``level_2`` / NaN for query cells

    Parameters
    ----------
    adata_ref, adata_que
        Reference and query AnnData objects.
    cell_type_col
        Not used by this function; retained for call compatibility.

    Returns
    -------
    AnnData
        ``anndata.concat`` of the reference followed by the query.
    """
    fitted = umap.UMAP(n_neighbors=5, random_state=42, min_dist=0.5).fit(adata_ref.X)

    # Project each dataset with the reference-fitted model and flag its origin.
    for dataset, query_flag in ((adata_ref, 0), (adata_que, 1)):
        dataset.obsm['X_umap'] = fitted.transform(dataset.X)
        dataset.obs['query'] = query_flag

    # Reference cells have no predicted label.
    adata_ref.obs['cell_type_pred'] = np.nan

    # Harmonised annotation: original labels for reference, predictions for query.
    adata_ref.obs['cell_type'] = adata_ref.obs['level_2'].copy()
    adata_que.obs['cell_type'] = adata_que.obs['cell_type_pred'].copy()

    # Original reference annotation; undefined for query cells.
    adata_ref.obs['cell_type_ref'] = adata_ref.obs['level_2'].copy()
    adata_que.obs['cell_type_ref'] = np.nan

    return anndata.concat([adata_ref, adata_que])


## Load samples

In [5]:
# Input locations, all below <WD>/data.
project_dir = os.path.join(WD, 'data')
annot_dir = os.path.join(project_dir, 'sample_annot6')
# Tab-separated sample sheet; this notebook reads its sample_id, derive and tissue columns.
sample_sheet = pd.read_csv(os.path.join(project_dir, 'all_samples_sheets.txt'), sep='\t')
# Keep only the lung samples.
sample_sheet = sample_sheet[sample_sheet.tissue=='lung']

### Load reference data

In [6]:
# FETAL
# Column of the fetal reference that holds its cell-type labels.
cell_type_col = 'new_celltype'
fetal = sc.read(os.path.join(project_dir, 'reference/Assembled10DomainsEpithelial.h5ad'))
# clear_genes returns a column selection of the input (an AnnData view); materialise it
# explicitly so the obs assignments below do not modify a view implicitly.
fetal = clear_genes(project_dir, fetal).copy()
# Harmonise metadata with the organoid samples: level_1 / level_2 labels,
# sample_id (taken from the donor column) and a maturity tag.
fetal.obs['level_1'] = 'epithelial'
fetal.obs['level_2'] = fetal.obs[cell_type_col].copy()
fetal.obs['sample_id'] = fetal.obs['donor'].copy()
fetal.obs['maturity'] = 'fetal'

  fetal.obs['level_1'] = 'epithelial'


In [7]:
# Inspect the value range of fetal.X.
# NOTE(review): a small maximum suggests already log-transformed values rather than raw counts - confirm.
print(np.max(fetal.X))
print(np.min(fetal.X))

8.5433
0.0


In [8]:
# MATURE
# Column of the mature reference that holds its cell-type labels.
# Note: this rebinds the global `cell_type_col` set in the fetal cell.
cell_type_col = 'predicted_labels'
mature = sc.read(os.path.join(project_dir, 'reference/Barbry_Leroy_2020_epithelial_annot.h5ad'))
# clear_genes returns a column selection of the input (an AnnData view); materialise it
# explicitly so the obs/uns assignments below do not modify a view implicitly.
mature = clear_genes(project_dir, mature).copy()
# Harmonise metadata with the organoid samples: level_1 / level_2 labels,
# sample_id (taken from the sample column) and a maturity tag.
mature.obs['level_1'] = 'epithelial'
mature.obs['level_2'] = mature.obs[cell_type_col].copy()
mature.obs['sample_id'] = mature.obs['sample'].copy()
mature.obs['maturity'] = 'mature'
# NOTE(review): presumably a workaround for scanpy code that reads uns['log1p']['base'] - confirm.
mature.uns['log1p'] = {'base': None}

  mature.obs['level_1'] = 'epithelial'


In [9]:
# Inspect the value range of mature.X.
# NOTE(review): a small maximum suggests already log-transformed values rather than raw counts - confirm.
print(np.max(mature.X))
print(np.min(mature.X))

9.126299
0.0


In [10]:
# Stack fetal and mature reference cells into one AnnData.
# Uses anndata.concat's default `join`, unlike read_samples, which passes join='outer'.
reference = anndata.concat([fetal, mature])

  warn(


In [11]:
# The reference has no 'derive' value; add the column so it matches the organoid metadata.
reference.obs['derive'] = np.nan
# Flag 3000 highly variable genes, using sample_id as the batch key.
sc.pp.highly_variable_genes(reference, n_top_genes=3000, batch_key='sample_id')
# Subsetting yields a view; copy it so later cells that write into reference.obs /
# reference.obsm (model construction, latent embedding) work on a real AnnData.
reference = reference[:, reference.var.highly_variable].copy()

### Load organoid data

In [12]:
# Read and concatenate every lung sample listed in the filtered sample sheet.
organoid = read_samples(project_dir, sample_sheet, annot_dir)

  warn(
  utils.warn_names_duplicates("obs")


In [13]:
# Inspect the value range of organoid.X before normalisation.
# NOTE(review): the large maximum suggests raw counts - confirm.
print(np.max(organoid.X))
print(np.min(organoid.X))

25001.0
0.0


In [14]:
# Number of cells per level_1 label (samples without an annotation file default to 'epithelial').
organoid.obs.level_1.value_counts()

epithelial    225487
Name: level_1, dtype: int64

In [15]:
# Number of cells per level_2 label ('na' marks cells from samples without an annotation file).
organoid.obs.level_2.value_counts()

alveolar type 2 (AT2) cells    61221
basal cells                    32813
club cells                     30613
goblet cells                   27694
stem cells                     19224
na                             16034
alveolar type 1 (AT1) cells    15337
neuroendocrine cells           14272
airway secretory cells          5626
ciliated cells                  2653
Name: level_2, dtype: int64

In [16]:
# Preserve the organoid cells' own labels, then mask level_2 with 'na' - the name
# train_scpoli declares as the unknown cell type - so the query is treated as unlabelled.
organoid.obs['orig_cell_types'] = organoid.obs['level_2'].copy()
organoid.obs['level_2'] = 'na'

In [17]:
# Display the per-cell metadata: level_2 now reads 'na' everywhere, the original labels are in orig_cell_types.
organoid.obs

Unnamed: 0_level_0,initial_size_spliced,initial_size_unspliced,initial_size,n_genes_by_counts,total_counts,total_counts_mt,pct_counts_mt,total_counts_ribo,pct_counts_ribo,total_counts_hb,...,n_genes,sample_id,derive,level_1,level_2,level_3,Cell_type,n_counts,publication,orig_cell_types
cells,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
AAACCCACAAAGCTAA,31397.0,14368.0,31397.0,6808,39508.0,4868.0,12.321555,7983.0,20.206034,1.0,...,6808.0,Chan_NatCommun_2022_bronchial_organoids,ASC,epithelial,na,basal cells,,,Chan_NatCommun_2022,basal cells
AAACCCACACTGTGTA,24573.0,8025.0,24573.0,4675,31483.0,3729.0,11.844487,7484.0,23.771559,0.0,...,4675.0,Chan_NatCommun_2022_bronchial_organoids,ASC,epithelial,na,club cells,,,Chan_NatCommun_2022,club cells
AAACCCATCCGCAGTG,21147.0,6246.0,21147.0,4822,26328.0,2810.0,10.673048,5713.0,21.699331,2.0,...,4822.0,Chan_NatCommun_2022_bronchial_organoids,ASC,epithelial,na,club cells,,,Chan_NatCommun_2022,club cells
AAACCCATCTAGTGTG,8687.0,3785.0,8687.0,2726,11182.0,632.0,5.651941,3580.0,32.015739,0.0,...,2726.0,Chan_NatCommun_2022_bronchial_organoids,ASC,epithelial,na,basal cells,,,Chan_NatCommun_2022,basal cells
AAACGAACAACTCCCT,11375.0,5798.0,11375.0,3956,14544.0,1113.0,7.652640,4326.0,29.744226,0.0,...,3956.0,Chan_NatCommun_2022_bronchial_organoids,ASC,epithelial,na,basal cells,,,Chan_NatCommun_2022,basal cells
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
TTTGATCAGTCTAGCT,9504.0,21130.0,9504.0,4459,11821.0,2224.0,18.813974,739.0,6.251586,0.0,...,4459.0,Hein_Dev_2022_Spheroids,IPS,epithelial,na,,,,Hein_Dev_2022,na
TTTGATCCAACATACC,50587.0,52023.0,50587.0,8910,61103.0,1360.0,2.225750,11083.0,18.138226,5.0,...,8910.0,Hein_Dev_2022_Spheroids,IPS,epithelial,na,,,,Hein_Dev_2022,na
TTTGGAGAGGAGAGGC,9358.0,9525.0,9358.0,3935,10915.0,2025.0,18.552450,1227.0,11.241410,1.0,...,3935.0,Hein_Dev_2022_Spheroids,IPS,epithelial,na,,,,Hein_Dev_2022,na
TTTGGTTAGTGCAGCA,11309.0,14530.0,11309.0,4664,13709.0,1197.0,8.731490,2147.0,15.661244,1.0,...,4664.0,Hein_Dev_2022_Spheroids,IPS,epithelial,na,,,,Hein_Dev_2022,na


In [18]:
# Normalise and log-transform the organoid data (see pre_inti0), then store X as float32.
organoid = pre_inti0(organoid)
organoid.X = organoid.X.astype(np.float32)

In [19]:
# select only those highly variable genes that occur in both the reference and organoid data
# The shared genes are taken in the reference's own order. Building the selection from a
# Python set would make the gene order depend on set iteration order, which is not
# reproducible across kernel sessions and would not match a previously saved model.
shared_genes = reference.var.index[reference.var.index.isin(organoid.var.index)]
if len(shared_genes) < reference.n_vars:
    # Some reference genes are missing from the organoid data: shrink the reference too.
    reference = reference[:, shared_genes]
organoid = organoid[:, shared_genes]

In [20]:
# Folder for trained query-to-reference models and the versioned path of this run's model.
out_folder = f"{WD}/q2r_models/"
# os.path.join avoids the doubled separator that plain string concatenation produced here.
model_path = os.path.join(out_folder, f"model_v{VERSION}")

## Train Q2R model

In [None]:
# DO NOT RUN IF YOU HAVE TRAINED A MODEL BEFORE!
# Trains a new scPoli model on the reference data; a previously saved model can be
# loaded in the "Load Q2R model" section instead.
scpoli_model = train_scpoli(reference)

  self.adata.obs['conditions_combined'] = adata.obs[condition_keys].apply(lambda x: '_'.join(x))


In [None]:
# Store the scPoli latent embedding of the reference cells in obsm['X_scPoli'].
# NOTE(review): mean=True presumably returns the latent mean rather than a sample - confirm against the scPoli docs.
reference.obsm['X_scPoli'] = scpoli_model.get_latent(
    reference,
    mean=True,
)

In [None]:
# visualize the latent representation of reference cells computed by scPoli
# The neighbour graph is built on the embedding stored in obsm['X_scPoli'];
# the UMAP is coloured by maturity and by level_2 label and saved with a versioned suffix.
sc.pp.neighbors(reference, use_rep='X_scPoli')
sc.tl.umap(reference)
sc.pl.umap(
    reference, 
    color=['maturity', 'level_2'],
    show=True,
    frameon=False,
    save=f'_scPoli_latent_v{VERSION}.png',
)

In [None]:
# Persist the trained reference model under the versioned model path.
scpoli_model.save(model_path)

## Load Q2R model

In [None]:
# load a Q2R model if you've trained one before
# scpoli_model = scarches.models.scpoli.scPoli.load(model_path, reference)

## Map organoid to reference cells

In [None]:
# Fine-tune the reference model on the organoid query data (no labelled query cells).
scpoli_query = map_query(organoid, scpoli_model)

# get latent representation of reference data
# NOTE(review): .model is presumably the underlying torch module; eval() switches it to inference mode - confirm.
scpoli_query.model.eval()
data_latent_source = scpoli_query.get_latent(
    reference,
    mean=True,
)

# get latent representation of query data
data_latent_target = scpoli_query.get_latent(
    organoid,
    mean=True,
)

# Wrap each embedding in an AnnData (the embedding becomes X) and attach a copy of
# the corresponding cell metadata.
adata_latent_source = AnnData(data_latent_source)
adata_latent_source.obs = reference.obs.copy()

adata_latent_target = AnnData(data_latent_target)
adata_latent_target.obs = organoid.obs.copy()

In [None]:
# get label annotations
results_dict = scpoli_query.classify(organoid)
# Predicted label and its uncertainty for the 'level_2' cell-type key.
adata_latent_target.obs['cell_type_pred'] = results_dict['level_2']['preds'].tolist()
adata_latent_target.obs['cell_type_uncert'] = results_dict['level_2']['uncert'].tolist()
# Compare against the preserved original labels: level_2 was overwritten with 'na' for
# every organoid cell before mapping, so comparing with it can never indicate a match.
# NOTE(review): this is only meaningful where the organoid and reference label
# vocabularies coincide - confirm the naming is harmonised.
adata_latent_target.obs['classifier_outcome'] = (
    adata_latent_target.obs['cell_type_pred'] == adata_latent_target.obs['orig_cell_types']
)

In [None]:
# the representation is chosen automatically: For .n_vars < 50, .X is used, otherwise ‘X_pca’ is used
# NOTE(review): umap_transform below overwrites adata_latent_target.obsm['X_umap'] with
# coordinates from a UMAP fitted on the reference, so the embedding computed by
# sc.tl.umap here is not the one that ends up in adata_latent - confirm it is still needed.
sc.pp.neighbors(adata_latent_target, n_neighbors=15)
sc.tl.umap(adata_latent_target)

# Joint reference + query object; cell_type_col is accepted but not used by umap_transform.
adata_latent = umap_transform(adata_latent_source, adata_latent_target, cell_type_col=cell_type_col)

## Save latent representation

In [None]:
# Write the joint latent AnnData to <WD>/data with the run version in the file name.
adata_latent.write(os.path.join(WD, f"data/q2r_fetal_adata_latent_v{VERSION}.h5ad"))

## Plots

In [None]:
# Joint UMAP coloured by the harmonised label (reference: level_2, query: predicted label).
sc.pl.umap(
    adata_latent,
    color='cell_type',
    show=True,
    frameon=False,
    save=f'_new_celltype_v{VERSION}.png',
)

In [None]:
# Joint UMAP coloured by the reference annotation only (NaN for query cells); not saved to disk.
sc.pl.umap(
    adata_latent,
    color='cell_type_ref',
    show=True,
    frameon=False,
)

In [None]:
# Joint UMAP coloured by origin: query == 0 for reference cells, 1 for organoid cells.
sc.pl.umap(
    adata_latent,
    color='query',
    show=True,
    frameon=False,
    save=f'_query_v{VERSION}.png',
)

In [None]:
# Joint UMAP coloured by the predicted label (NaN for reference cells).
sc.pl.umap(
    adata_latent,
    color='cell_type_pred',
    show=True,
    frameon=False,
    save=f'_cell_type_pred_v{VERSION}.png',
)

In [None]:
# Joint UMAP coloured by the sample sheet's 'derive' value (NaN for reference cells).
sc.pl.umap(
    adata_latent,
    color='derive',
    show=True,
    frameon=False,
    save=f'_derive_v{VERSION}.png',
)