# Train a `bioLORD` model on the `developing human immune system across tissues` dataset (B-cells)

The data was generated by Suo et al.[[1]](https://www.science.org/doi/full/10.1126/science.abo0510) and downloaded from [Lymphoid cells](https://cellgeni.cog.sanger.ac.uk/developmentcellatlas/fetal-immune/PAN.A01.v01.raw_count.20210429.LYMPHOID.embedding.h5ad). <br>
The complete dataset is a cross-tissue single-cell atlas of developing human immune cells across prenatal hematopoietic, lymphoid, and nonlymphoid peripheral organs. It includes over 900,000 cells, from which the authors identified over 100 cell states.

[[1] Suo, Chenqu, Emma Dann, Issac Goh, Laura Jardine, Vitalii Kleshchevnikov, Jong-Eun Park, Rachel A. Botting et al. "Mapping the developing human immune system across organs." Science (2022): eabo0510.](https://www.science.org/doi/full/10.1126/science.abo0510)


In [1]:
# Enable IPython's autoreload extension in mode 2 for this session.
%load_ext autoreload
%autoreload 2

In [2]:
import os
import sys
# Extend the import path with two absolute, machine-specific directories.
# NOTE(review): presumably the first provides the `cluster_analysis` / `formatters`
# helpers and the second a local biolord checkout — confirm.
sys.path.append("/cs/usr/bar246802/bar246802/SandBox2023/biolord_immune_bcells/utils") # add utils
sys.path.append("/cs/usr/bar246802/bar246802/SandBox2023/biolord") # add the biolord directory

In [3]:
import biolord
import scanpy as sc
import anndata
import numpy as np
import pandas as pd
import torch
import umap.plot
import seaborn as sns
import itertools
import matplotlib.pyplot as plt
from cluster_analysis import *
from formatters import *

[rank: 0] Global seed set to 0


In [4]:
print(f"PyTorch version: {torch.__version__}")
# Set the device.
# `torch.backends.cuda.is_built()` only reports whether this PyTorch binary was
# compiled with CUDA support; it is True even on a machine without a usable GPU.
# `torch.cuda.is_available()` checks that a CUDA device can actually be used.
device = "gpu" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

PyTorch version: 1.11.0
Using device: gpu


In [5]:
from tqdm import tqdm
# Create one disabled, zero-length progress bar up front; nothing is displayed.
tqdm(disable=True, total=0)  # initialise internal lock

<tqdm.std.tqdm at 0x7f63fcb33f40>

In [6]:
import mplscience
# Apply the mplscience matplotlib style to the figures produced below.
mplscience.set_style()

# Show a single marker per entry in scatter-plot legends.
plt.rcParams['legend.scatterpoints'] = 1

## Set parameters

In [None]:
# Run sub-directory and identifier; ID_ is embedded in the figure file names below.
DIR = "5/"
ID_ = "51"

# Input data, saved models and figures live in sibling directories of the notebook.
DATA_DIR = "../data/"
SAVE_DIR = f"../output/{DIR}"
FIG_DIR = f"../figures/{DIR}"

## Import processed data

In [None]:
# Load the preprocessed AnnData object from DATA_DIR.
adata = sc.read(DATA_DIR + "2_biolord_immune_bcells_bm.h5ad")

In [None]:
# Number of cells per value of the "split" column ("train" / "test" are selected further below).
adata.obs["split"].value_counts()

In [None]:
# Register the AnnData with biolord: cell type, organ and age are the categorical
# attributes; "sex" is passed as the retrieval attribute.
biolord.Biolord.setup_anndata(
    adata,
    categorical_attributes_keys=["celltype", "organ", "age"],
    retrieval_attribute_key="sex",
)

## Train model

In [None]:
# Width of each categorical attribute's latent embedding.
N_LATENT_ATTRIBUTE_CATEGORICAL = 4

# Settings passed to the model as `module_params`.
module_params = dict(
    autoencoder_width=128,
    autoencoder_depth=2,
    attribute_nn_width=256,
    attribute_nn_depth=2,
    n_latent_attribute_categorical=N_LATENT_ATTRIBUTE_CATEGORICAL,
    loss_ae="gauss",
    loss_ordered_attribute="gauss",
    reconstruction_penalty=1e2,
    unknown_attribute_penalty=1e1,
    unknown_attribute_noise_param=1e-1,
    attribute_dropout_rate=0.1,
    use_batch_norm=False,
    use_layer_norm=False,
    seed=42,
)

# Settings passed to `model.train` as `plan_kwargs`.
trainer_params = dict(
    n_epochs_warmup=0,
    autoencoder_lr=1e-4,
    autoencoder_wd=1e-4,
    attribute_nn_lr=1e-2,
    attribute_nn_wd=4e-8,
    step_size_lr=45,
    cosine_scheduler=True,
    scheduler_final_lr=1e-5,
)

In [None]:
# Instantiate the biolord model on `adata` with the module settings defined above;
# the "split" column of adata.obs is passed as the data-split key.
model = biolord.Biolord(
    adata=adata,
    n_latent=32,
    model_name="immune_bcells",
    module_params=module_params,
    train_classifiers=False,
    split_key="split",
)

In [None]:
# Train for at most 1000 epochs with early stopping enabled (patience 20);
# validation is run every 10 epochs and `trainer_params` is forwarded as plan_kwargs.
model.train(max_epochs=1000,
            batch_size=512,
            plan_kwargs=trainer_params,
            early_stopping=True,
            early_stopping_patience=20,
            check_val_every_n_epoch=10,
            num_workers=1)

## Evaluate the trained model

In [None]:
# Plot the validation curves recorded during training, one panel per metric.
# Panel side length in inches. This used to be read from
# N_LATENT_ATTRIBUTE_CATEGORICAL (currently 4), which tied the figure size to
# a model hyper-parameter; the value is unchanged but now independent.
size = 4
vals = [
    "generative_mean_accuracy", "generative_var_accuracy", "biolord_metric"
]
fig, axs = plt.subplots(nrows=1,
                        ncols=len(vals),
                        figsize=(size * len(vals), size))

# `from_dict` is a classmethod, so no throwaway empty DataFrame is needed.
model.epoch_history = pd.DataFrame.from_dict(
    model.training_plan.epoch_history)
# Only the validation rows are plotted; filter once instead of once per panel.
valid_history = model.epoch_history[model.epoch_history["mode"] == "valid"]
for ax, val in zip(axs, vals):
    sns.lineplot(
        x="epoch",
        y=val,
        hue="mode",
        data=valid_history,
        ax=ax,
    )

plt.tight_layout()
# The original line read `plt.show()ID_`, which is a syntax error.
plt.show()

## Save the trained model

In [None]:
# Persist the trained model under SAVE_DIR.
# NOTE(review): the load cell below reads SAVE_DIR + "trained_model_" + ID_, which is
# not the path written here — confirm which directory name is intended.
model.save(SAVE_DIR + "trained_model")

## Load saved model

In [None]:
# Choose which saved run to load. This rebinds ID_, so the figure file names
# written further below use this value rather than the one set in the parameters cell.
ID_ = "254"
# NOTE(review): use_gpu is hard-coded to True regardless of the `device` computed earlier — confirm.
model = biolord.Biolord.load(dir_path=SAVE_DIR + "trained_model_" + ID_, adata=adata, use_gpu=True)

## Assess performance

In [None]:
# Positional indices of the cells assigned to each partition by the "split" column.
idx_train = np.flatnonzero(adata.obs["split"] == "train")
idx_test = np.flatnonzero(adata.obs["split"] == "test")

# Independent copies of each partition.
adata_train = adata[idx_train].copy()
adata_test = adata[idx_test].copy()

# Model-side datasets built from each partition.
dataset_train = model.get_dataset(adata_train)
dataset_test = model.get_dataset(adata_test)

## Cluster embeddings

In [None]:
# Restrict the model's category -> index maps to the two attributes analysed below.
attributes_map = {
    name: model.categorical_attributes_map[name]
    for name in ("celltype", "organ")
}

In [None]:
# Embedding matrix of every categorical attribute known to the model.
transf_embeddings_attributes = {
    name: model.get_categorical_attribute_embeddings(attribute_key=name)
    for name in model.categorical_attributes_map
}

# Every combination of categories, one category per attribute, in attribute order.
keys = list(
    itertools.product(*(
        list(categories.keys())
        for _, categories in model.categorical_attributes_map.items()
    )))

# For each combination, look up the embedding row of each chosen category and
# concatenate the rows into a single vector.
_combined_rows = [
    np.concatenate([
        transf_embeddings_attributes[name][categories[combo[pos]], :]
        for pos, (name, categories) in enumerate(
            model.categorical_attributes_map.items())
    ], 0)
    for combo in keys
]

# The same vectors addressed by the "_"-joined category names ...
transf_embeddings_attributes_dict = {
    "_".join(str(part) for part in combo): row
    for combo, row in zip(keys, _combined_rows)
}

# ... and as a plain list in `keys` order (rebinding the per-attribute dict name).
transf_embeddings_attributes = _combined_rows

In [None]:
# Embedding matrix of each attribute listed in `attributes_map`.
transf_embeddings_attributes_ind = {
    name: model.get_categorical_attribute_embeddings(attribute_key=name)
    for name in attributes_map
}

# Every combination of categories across the selected attributes, in attribute order.
keys = list(
    itertools.product(*(
        list(model.categorical_attributes_map[name].keys())
        for name in attributes_map
    )))

# One concatenated vector per combination: the embedding row of each chosen
# category, joined end to end.
_combined_rows = [
    np.concatenate([
        transf_embeddings_attributes_ind[name][categories[combo[pos]], :]
        for pos, (name, categories) in enumerate(attributes_map.items())
    ], 0)
    for combo in keys
]

# The same vectors addressed by the "_"-joined category names ...
transf_embeddings_attributes_dict = {
    "_".join(str(part) for part in combo): row
    for combo, row in zip(keys, _combined_rows)
}

# ... and as a plain list in `keys` order.
transf_embeddings_attributes = _combined_rows

In [None]:
# Stack the list of combined vectors into one array (one row per entry of `keys`).
transf_embeddings_attributes = np.asarray(transf_embeddings_attributes)

In [None]:
# PCA of the combined embeddings via scanpy; the result is indexed as a 2-D array below.
pca = sc.tl.pca(transf_embeddings_attributes)

In [None]:
# UMAP projection of the combined embeddings with default settings; used as the
# "umap1"/"umap2" columns below.
# NOTE(review): no random_state is passed, so the layout may differ between runs — confirm this is acceptable.
mapper_latent = umap.UMAP().fit_transform(transf_embeddings_attributes)

In [None]:
# For each attribute, the category chosen in every combination of `keys`
# (i.e. one label column per attribute, aligned with the rows of the embedding array).
cols = {
    name: [combo[pos] for combo in keys]
    for pos, name in enumerate(attributes_map)
}

In [None]:
# Store a nine-colour palette for the "organ" attribute on the AnnData object.
# The same hex codes are passed as `palette` to the organ scatter plot below.
adata.uns["organ_colors"] = [
    "#029e73",
    "#949494",
    "#ece133",
    "#de8f05",
    "#ca9161",
    "#fbafe4",
    "#cc78bc",
    "#d55e00",
    "#0173b2",
]

In [None]:
# One row per attribute combination: UMAP coordinates, principal components
# and the category label of each attribute.
df = pd.DataFrame(mapper_latent, columns=["umap1", "umap2"])
for i in range(pca.shape[1]):
    df[f"pc{i+1}"] = pca[:, i]
for col_, map_ in cols.items():
    df[col_] = map_


# One frame per attribute holding its raw latent coordinates, category names
# and category indices; `df` additionally gets an "<attribute>_key" index column.
dfs = {}
for attribute_, embedding_ in transf_embeddings_attributes_ind.items():
    # Take the latent width from the embedding being processed. The original
    # read `transf_embeddings_attributes_ind[attribute_]` *before* this loop,
    # where `attribute_` is not yet defined (comprehension variables do not
    # leak in Python 3), raising NameError on a fresh kernel.
    n_latent_attribute_categorical = embedding_.shape[1]
    columns_latent = [f"latent{i}" for i in range(1, n_latent_attribute_categorical+1)]
    dfs[attribute_] = pd.DataFrame(embedding_, columns=columns_latent)
    dfs[attribute_][attribute_] = list(attributes_map[attribute_].keys())
    dfs[attribute_][attribute_ + '_key'] = list(attributes_map[attribute_].values())
    df[attribute_ + "_key"] = df[attribute_].map(attributes_map[attribute_])

In [None]:
# Replace the raw cell-type codes with the labels shown in figure legends.
# Fixes the misspelled label "palsma B" -> "plasma B".
df["celltype"] = df["celltype"].replace(
    {
        "B1": "B1",
        "CYCLING_B": "cycling B",
        "IMMATURE_B": "immature B",
        "LARGE_PRE_B": "large pre B",
        "LATE_PRO_B": "late pro B",
        "MATURE_B": "mature B",
        "PLASMA_B": "plasma B",
        "PRE_PRO_B": "pre pro B",
        "PRO_B": "pro B",
        "SMALL_PRE_B": "small pre B",
    }
)

In [None]:
# Replace the organ abbreviations with the full names shown in figure legends;
# values without an entry are left unchanged by `replace`.
df["organ"] = df["organ"].replace(dict(
    BM="Bone Marrow",
    GU="Gut",
    KI="Kidney",
    LI="Liver",
    MLN="Lymph Node",
    SK="Skin",
    SP="Spleen",
    TH="Thymus",
    YS="Yolk Sac",
))

In [None]:
# UMAP of the combined embeddings: colour encodes cell type, marker shape encodes organ.
fig, axs = plt.subplots(figsize=(8, 4))

sns.scatterplot(
    data=df, x="umap1", y="umap2",
    hue="celltype", style="organ",
    alpha=.8, s=60, palette="deep",
    ax=axs,
)

# Title and axis names, with tick labels hidden and the grid switched off.
axs.set(
    title="cell type",
    xlabel="UMAP1",
    ylabel="UMAP2",
    xticklabels=[],
    yticklabels=[],
)
axs.grid(False)
# Two-column legend placed outside the axes, to the right.
axs.legend(loc="upper left", bbox_to_anchor=(1, 1), ncols=2)

fig.tight_layout()
fig.savefig(f"{FIG_DIR}cell_type_{ID_}.png", format="png", dpi=300)
plt.show()

In [None]:
# UMAP of the combined embeddings: colour encodes organ, marker shape encodes cell type.
# Set SAVEFIG to True to write the figure to FIG_DIR. The flag was previously
# defined but ignored, so the file was written even with SAVEFIG = False.
SAVEFIG = False
fig, axs = plt.subplots(1, 1, figsize=(8, 4))

sns.scatterplot(
    data=df,
    x="umap1",
    y="umap2",
    hue="organ",
    style="celltype",
    ax=axs,
    alpha=.8,
    s=60,
    # Same nine hex codes as stored in adata.uns["organ_colors"].
    palette=[
        "#029e73",
        "#949494",
        "#ece133",
        "#de8f05",
        "#ca9161",
        "#fbafe4",
        "#cc78bc",
        "#d55e00",
        "#0173b2",
    ],
)


axs.set_title("organ")
axs.set(xticklabels=[], yticklabels=[])
axs.set_xlabel("UMAP1")
axs.set_ylabel("UMAP2")
axs.grid(False)
# Two-column legend placed outside the axes, to the right.
axs.legend(loc="upper left", bbox_to_anchor=(1, 1), ncols=2)

plt.tight_layout()
if SAVEFIG:
    plt.savefig(FIG_DIR + f"organ_{ID_}.png", format="png", dpi=300)

plt.show()


In [None]:
# Invert each attribute's category -> index map into index -> category.
attributes_map_rev = {
    name: {index: category for category, index in mapping.items()}
    for name, mapping in attributes_map.items()
}

In [None]:
# Display the inverted maps for inspection.
attributes_map_rev

In [None]:
def cluster_evaluate_figures(attribute_):
    """Run the clustering figure routine for one attribute.

    Reads the notebook-level ``df``, ``transf_embeddings_attributes``,
    ``attributes_map_rev`` and ``FIG_DIR``, and delegates to
    ``matrices_figures`` (not defined in this notebook; presumably provided by
    the star-imported helper modules — confirm).

    Args:
        attribute_: Attribute name; ``df`` must contain ``attribute_ + '_key'``.

    Returns:
        Whatever ``matrices_figures`` returns. The original computed this value
        and discarded it; it is now handed back to the caller.
    """
    ground_truth_labels = np.array(df[attribute_ + '_key'])
    print("Number of samples:", ground_truth_labels.size)
    title = "Attribute: " + attribute_
    # Passed to the helper as `path`; here it is FIG_DIR plus the attribute name.
    path = FIG_DIR + attribute_ + "_"
    scores = matrices_figures(transf_embeddings_attributes, ground_truth_labels, df,
                              attributes_map_rev, attribute_, title, path)
    return scores

In [None]:
# Clustering figures using the cell-type indices as ground-truth labels.
cluster_evaluate_figures("celltype")

In [None]:
# Clustering figures using the organ indices as ground-truth labels.
cluster_evaluate_figures("organ")

In [None]:
def cluster_evaluate(model, attributes=('celltype', 'organ')):
    """Score how well a model's attribute embeddings cluster by each attribute.

    For every attribute the ground-truth labels are read from the notebook-level
    ``df`` (column ``attribute + '_key'``) and passed with the model's embeddings
    to ``get_kmeans_score``. Both ``get_transf_embeddings_attributes`` and
    ``get_kmeans_score`` are defined outside this notebook (presumably in the
    star-imported helper modules — confirm).

    Args:
        model: Model handed to ``get_transf_embeddings_attributes``.
        attributes: Attribute names to evaluate. The default is now a tuple
            rather than a shared mutable list; any iterable of names works.

    Returns:
        A DataFrame with the columns ``attribute``, ``score_name``, ``score``
        and ``n_clusters``, holding the rows of all evaluated attributes.

    Raises:
        ValueError: If ``attributes`` is empty. The original failed later with
            an opaque ``TypeError`` when indexing ``None``.
    """
    attributes = list(attributes)
    if not attributes:
        raise ValueError("cluster_evaluate needs at least one attribute")

    transf_embeddings_attributes = get_transf_embeddings_attributes(model)
    per_attribute = []
    for attribute in attributes:
        ground_truth_labels = np.array(df[attribute + '_key'])
        n_unique_labels = len(set(ground_truth_labels))
        print(f'For attribute {attribute} the # of unique true labels is: {n_unique_labels}')
        scores = get_kmeans_score(transf_embeddings_attributes, ground_truth_labels)
        scores['attribute'] = attribute
        per_attribute.append(scores)

    # A single frame is used as is (index untouched, as before); several frames
    # are stacked once with a fresh index.
    if len(per_attribute) == 1:
        all_scores = per_attribute[0]
    else:
        all_scores = pd.concat(per_attribute, ignore_index=True)

    cols = ['attribute', 'score_name', 'score', 'n_clusters']
    all_scores = all_scores[cols]
    print(all_scores)
    return all_scores

In [None]:
# Score the currently loaded model on the default attributes.
cluster_evaluate(model)

In [None]:
def train_model(module_params, trainer_params):
    """Build a biolord model on the notebook-level ``adata`` and train it.

    Args:
        module_params: Passed to ``biolord.Biolord`` as ``module_params``.
        trainer_params: Passed to ``train`` as ``plan_kwargs``.

    Returns:
        The model object after ``train`` has been called on it.
    """
    fitted = biolord.Biolord(
        adata=adata,
        n_latent=32,
        model_name="immune_bcells",
        module_params=module_params,
        train_classifiers=False,
        split_key="split",
    )

    # Fixed training schedule shared by every run.
    train_options = dict(
        max_epochs=1000,
        batch_size=512,
        plan_kwargs=trainer_params,
        early_stopping=True,
        early_stopping_patience=20,
        check_val_every_n_epoch=10,
        num_workers=1,
    )
    fitted.train(**train_options)
    return fitted

In [None]:
def model_training_iterations():
    """Grid-search four hyper-parameters, training and scoring one model per combination.

    The grid is 3 x 3 x 4 x 4 = 144 combinations of latent size, reconstruction
    penalty, unknown-attribute penalty and unknown-attribute noise. After each
    run the model is saved to ``SAVE_DIR + "trained_model_<index>"`` (overwriting
    any existing directory) and the scores gathered so far are rewritten to
    ``SAVE_DIR + "trained_models_scores.csv"``, so partial results survive an
    interrupted search.

    Returns:
        DataFrame with the clustering scores of every run, each row annotated
        with the hyper-parameter values of its run. The original built this
        frame but returned nothing.
    """
    # Built-in ints (range) rather than numpy integers (np.arange), so the value
    # placed in module_params is a plain Python type.
    arr_n_latent_attribute_categorical = range(3, 6)
    arr_reconstruction_penalty = [1e2, 1e3, 1e4]
    arr_unknown_attribute_penalty = [1e-2, 1e-1, 1e1, 1e2]
    arr_unknown_attribute_noise_param = [1e-2, 1e-1, 1e1, 1e2]

    grid = itertools.product(
        arr_n_latent_attribute_categorical, arr_reconstruction_penalty,
        arr_unknown_attribute_penalty, arr_unknown_attribute_noise_param)

    full_scores = None
    for id_, (n_latent_attribute_categorical, reconstruction_penalty,
              unknown_attribute_penalty,
              unknown_attribute_noise_param) in enumerate(grid, start=1):
        print(f"loop index is {id_}")
        module_params = {
            "autoencoder_width": 128,
            "autoencoder_depth": 2,
            "attribute_nn_width": 256,
            "attribute_nn_depth": 2,
            "n_latent_attribute_categorical": n_latent_attribute_categorical,
            "loss_ae": "gauss",
            "loss_ordered_attribute": "gauss",
            "reconstruction_penalty": reconstruction_penalty,
            "unknown_attribute_penalty": unknown_attribute_penalty,
            "unknown_attribute_noise_param": unknown_attribute_noise_param,
            "attribute_dropout_rate": 0.1,
            "use_batch_norm": False,
            "use_layer_norm": False,
            "seed": 42,
        }

        # Rebuilt for every run so each model receives its own dict.
        trainer_params = {
            "n_epochs_warmup": 0,
            "autoencoder_lr": 1e-4,
            "autoencoder_wd": 1e-4,
            "attribute_nn_lr": 1e-2,
            "attribute_nn_wd": 4e-8,
            "step_size_lr": 45,
            "cosine_scheduler": True,
            "scheduler_final_lr": 1e-5,
        }
        model = train_model(module_params, trainer_params)
        scores = cluster_evaluate(model)
        scores['n_latent_attribute_categorical'] = n_latent_attribute_categorical
        scores['reconstruction_penalty'] = reconstruction_penalty
        scores['unknown_attribute_penalty'] = unknown_attribute_penalty
        scores['unknown_attribute_noise_param'] = unknown_attribute_noise_param
        if full_scores is None:
            full_scores = scores
        else:
            full_scores = pd.concat([full_scores, scores], ignore_index=True)

        model.save(SAVE_DIR + "trained_model_" + str(id_), overwrite=True)

        # Rewrite the cumulative table after every run (checkpoint).
        full_scores.to_csv(SAVE_DIR + "trained_models_scores.csv")

    return full_scores

In [None]:
# Run the full hyper-parameter grid search (trains and saves one model per combination).
model_training_iterations()

In [None]:
# NOTE(review): this repeats the call in the previous cell; a second run restarts the
# index at 1 and saves with overwrite=True to the same paths — confirm it is intended.
model_training_iterations()