# scDoRI Model Training

In [2]:
import logging
import torch
import numpy as np
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
from pathlib import Path
from sklearn.preprocessing import OneHotEncoder

from scdori import (
    trainConfig,
    load_scdori_inputs,
    save_model_weights,
    set_seed,
    scDoRI,
    train_scdori_phases,
    train_model_grn,
    initialize_scdori_parameters,
    load_best_model,
)

# Configure the root logging level from the training config, then get this notebook's logger.
logging.basicConfig(level=trainConfig.logging_level)
logger = logging.getLogger(__name__)

#### Loading and preparing data for training and model initialisation

In [3]:
logger.info("Starting scDoRI pipeline with integrated GRN.")
# Seed via scdori's helper before any randomness (split, shuffling, init) is consumed.
set_seed(trainConfig.random_seed)

# Default to CPU; switch to the first visible GPU when CUDA is available.
device = torch.device("cpu")
if torch.cuda.is_available():
    device = torch.device("cuda:0")
logger.info(f"Using device: {device}")

INFO:__main__:Starting scDoRI pipeline with integrated GRN.
INFO:scdori._core.utils:Random seed set to 200.
INFO:__main__:Using device: cuda:0


## 1. Load data

Uses the paths specified in the config file to load the processed RNA and ATAC AnnData objects, as well as the precomputed in-silico ChIP-seq matrices (activator and repressor) and the peak-gene distances.

In [4]:
# Load processed RNA/ATAC AnnData, peak-gene distance weights and the in-silico
# ChIP-seq activator/repressor matrices from the paths in trainConfig.
rna_metacell, atac_metacell, gene_peak_dist, insilico_act, insilico_rep = load_scdori_inputs(
    trainConfig
)

# Distance-based peak-gene link mask: every entry with a positive distance weight
# becomes 1; non-positive entries are kept unchanged. A new tensor is produced, so
# gene_peak_dist itself is not modified.
gene_peak_fixed = torch.where(
    gene_peak_dist > 0, torch.ones_like(gene_peak_dist), gene_peak_dist
)

INFO:scdori._core.data_io:Loading RNA from /data/m015k/new_metacells/data_gastrulation_single_cell/generated/rna_processed.h5ad


INFO:scdori._core.data_io:Loading ATAC from /data/m015k/new_metacells/data_gastrulation_single_cell/generated/atac_processed.h5ad
INFO:scdori._core.data_io:Loading gene-peak dist from /data/m015k/new_metacells/data_gastrulation_single_cell/generated/gene_peak_distance_exp.npy
INFO:scdori._core.data_io:Loading insilico embeddings from /data/m015k/new_metacells/data_gastrulation_single_cell/generated/insilico_chipseq_act.npy & /data/m015k/new_metacells/data_gastrulation_single_cell/generated/insilico_chipseq_rep.npy


## 2. Compute the indices of TF genes and set the number of cells per metacell (set to 1 for single-cell data)

In [5]:
# Every observation counts as one cell (single-cell data, so no pooling into metacells).
rna_metacell.obs["num_cells"] = 1
# Store each gene's integer position on .var so it can be recovered after boolean filtering.
rna_metacell.var["index_int"] = range(rna_metacell.shape[1])

# Positions of the genes annotated as transcription factors.
is_tf = rna_metacell.var["gene_type"] == "TF"
tf_indices = rna_metacell.var.loc[is_tf, "index_int"].values

# Per-cell counts as a column vector of shape (n_obs, 1).
num_cells = rna_metacell.obs["num_cells"].values.reshape((-1, 1))

## 3. One-hot encode the batch column for the entire dataset

In [6]:
# Copy the dataset-specific batch column (name taken from trainConfig) into a common
# "batch" column on both modalities.
batch_col = trainConfig.batch_col
for adata in (rna_metacell, atac_metacell):
    adata.obs["batch"] = adata.obs[batch_col].values

# One-hot encode the technical batch; categories are learned from the RNA cells and
# unseen labels are encoded as all zeros (handle_unknown="ignore").
rna_batch_labels = rna_metacell.obs["batch"].values.reshape(-1, 1)
enc = OneHotEncoder(handle_unknown="ignore")
onehot_batch = enc.fit_transform(rna_batch_labels).toarray()
enc.categories_

[array(['E7.5_rep1', 'E7.5_rep2', 'E7.75_rep1', 'E8.0_rep1', 'E8.0_rep2',
        'E8.5_CRISPR_T_KO', 'E8.5_CRISPR_T_WT', 'E8.5_rep1', 'E8.5_rep2',
        'E8.75_rep1', 'E8.75_rep2'], dtype=object)]

## 4. Make the train and evaluation datasets

In [7]:
# Split cell indices into training and evaluation sets (the evaluation set is used for
# validation / early stopping). The DataLoaders yield batches of integer cell indices,
# presumably used by the training functions to index both rna_metacell and
# atac_metacell. The indices are drawn from the RNA cell count, so both objects must
# describe the same cells.
EVAL_FRACTION = 0.2  # fraction of cells held out for evaluation
SPLIT_SEED = 42  # fixed split seed (independent of trainConfig.random_seed)

assert rna_metacell.n_obs == atac_metacell.n_obs, (
    f"RNA ({rna_metacell.n_obs}) and ATAC ({atac_metacell.n_obs}) cell counts differ; "
    "paired cells are required because the same indices address both modalities"
)

n_cells = rna_metacell.n_obs
indices = np.arange(n_cells)
train_idx, eval_idx = train_test_split(
    indices, test_size=EVAL_FRACTION, random_state=SPLIT_SEED
)

train_dataset = TensorDataset(torch.from_numpy(train_idx))
train_loader = DataLoader(
    train_dataset, batch_size=trainConfig.batch_size_cell, shuffle=True
)

eval_dataset = TensorDataset(torch.from_numpy(eval_idx))
eval_loader = DataLoader(
    eval_dataset, batch_size=trainConfig.batch_size_cell, shuffle=False
)

## 5. Build the scDoRI model using parameters from the config file

In [8]:
# Model dimensions inferred from the loaded data.
num_genes, num_peaks = rna_metacell.n_vars, atac_metacell.n_vars
num_tfs = insilico_act.shape[1]  # assumes the second axis of the in-silico matrix indexes TFs; TODO confirm
num_batches = onehot_batch.shape[1]  # one column per batch category

# Architecture sizes (topics, encoder widths) come from trainConfig.
model_kwargs = dict(
    device=device,
    num_genes=num_genes,
    num_peaks=num_peaks,
    num_tfs=num_tfs,
    num_topics=trainConfig.num_topics,
    num_batches=num_batches,
    dim_encoder1=trainConfig.dim_encoder1,
    dim_encoder2=trainConfig.dim_encoder2,
)
model = scDoRI(**model_kwargs).to(device)

## 6. Initialise the scDoRI model with precomputed matrices and set gradients

The model is initialised with the precomputed in-silico ChIP-seq matrices and the distance-dependent peak-gene links.

Gradients for the TF-gene links are also switched off, because these links are not updated in Phase 1 of training.

In [9]:
# Seed the model with the distance-based peak-gene priors and the in-silico ChIP-seq
# activator/repressor priors for the "warmup" phase. Every prior is first moved to
# the training device.
initialize_scdori_parameters(
    model,
    *(prior.to(device) for prior in (gene_peak_dist, gene_peak_fixed)),
    insilico_act=insilico_act.to(device),
    insilico_rep=insilico_rep.to(device),
    phase="warmup",
)

scDoRI parameters (peak-gene distance & TF binding) initialized and relevant parameters frozen.


#### Train Phase 1 of the scDoRI model

Here, topics are inferred by reconstructing ATAC peaks (module 1), reconstructing RNA from the predicted ATAC profile (module 2), and reconstructing TF expression (module 3).

A warmup start is used: only modules 1 and 3 are trained for the first few epochs, after which module 2 is added.

In [None]:
# Phase 1: learn topics from ATAC reconstruction (module 1), RNA-from-ATAC
# reconstruction (module 2) and TF-expression reconstruction (module 3), with a
# warmup start, validation on eval_loader and early stopping.
phase1_args = (
    train_loader,
    eval_loader,
    rna_metacell,
    atac_metacell,
    num_cells,
    tf_indices,
    onehot_batch,
    trainConfig,
)
model = train_scdori_phases(model, device, *phase1_args)

INFO:scdori._core.train_scdori:Starting scDoRI phase 1 training (module 1,2,3) with validation + early stopping.
Epoch 0 [warmup_1]: 100%|██████████| 356/356 [04:01<00:00,  1.48it/s]
INFO:scdori._core.train_scdori:[Train] Epoch=0, Phase=warmup_1, Loss=3021143.0140, Atac=2970267.8581, TF=245.0728, RNA=21643.1397
INFO:scdori._core.train_scdori:[Eval ] Epoch=0, Phase=warmup_1, EvalLoss=3185437.8933, EvalAtac=2724118.4635, EvalTF=170.4680, EvalRNA=21266.9779
INFO:scdori._core.data_io:Saved model weights => /data/m015k/weights/weights_directory_scdori/best_scdori_best_eval.pth
Epoch 1 [warmup_1]: 100%|██████████| 356/356 [04:01<00:00,  1.47it/s]
INFO:scdori._core.train_scdori:[Train] Epoch=1, Phase=warmup_1, Loss=2716803.4270, Atac=2682611.8954, TF=161.4618, RNA=21252.9512
INFO:scdori._core.train_scdori:[Eval ] Epoch=1, Phase=warmup_1, EvalLoss=3119237.4663, EvalAtac=2664196.4073, EvalTF=155.1917, EvalRNA=21104.8987
INFO:scdori._core.data_io:Saved model weights => /data/m015k/weights/weight

In [None]:
# Save the weights from the final epoch (where Phase 1 training stopped), separately
# from the best-eval checkpoint.
phase1_weights_dir = Path(trainConfig.weights_folder_scdori)
save_model_weights(model, phase1_weights_dir, "scdori_final")

#### Train Phase 2 of the scDoRI model
Here, the activator and repressor TF-gene links for each topic are inferred (module 4).

Optionally, the encoder and the other model parameters from modules 1, 2 and 3 are frozen for stability.

## 7. Load the best checkpoint from Phase 1

In [10]:
# Restore the best (lowest eval loss) Phase 1 checkpoint before starting Phase 2.
best_phase1_ckpt = Path(trainConfig.weights_folder_scdori).joinpath("best_scdori_best_eval.pth")
model = load_best_model(model, best_phase1_ckpt, device)

INFO:scdori._core.downstream:Loaded best model weights from /data/m015k/weights/weights_directory_scdori/best_scdori_best_eval.pth


## 8. Set gradients for Phase 2 training

This step prepares the model's parameters for Phase 2, in which the TF-gene links are learnt.

In [11]:
# Re-run parameter initialisation for the Phase 2 ("grn") setup. Per the printed
# message, this (re)initialises the peak-gene distance and TF-binding parameters and
# freezes the relevant parameters. The priors are moved to the model's device, as in
# the warmup-phase initialisation, so both phases hand the same inputs to the helper.
# NOTE(review): confirm this does not overwrite the Phase 1 weights loaded above.
initialize_scdori_parameters(
    model,
    gene_peak_dist.to(device),
    gene_peak_fixed.to(device),
    insilico_act=insilico_act.to(device),
    insilico_rep=insilico_rep.to(device),
    phase="grn",
)

scDoRI parameters (peak-gene distance & TF binding) initialized and relevant parameters frozen.


## 9. Phase 2 training and saving model weights


In [None]:
# Phase 2 of scDoRI: TF-gene links are learnt and used to reconstruct gene-expression
# profiles. The data inputs are the same as for the Phase 1 call.
grn_args = (
    train_loader,
    eval_loader,
    rna_metacell,
    atac_metacell,
    num_cells,
    tf_indices,
    onehot_batch,
    trainConfig,
)
model = train_model_grn(model, device, *grn_args)

In [17]:
# Save the weights from the final epoch (where Phase 2 training stopped) to the GRN weights folder.
grn_weights_dir = Path(trainConfig.weights_folder_grn)
save_model_weights(model, grn_weights_dir, "scdori_final")

INFO:scdori._core.data_io:Saved model weights => /data/m015k/weights/weights_directory_grn/best_scdori_final.pth
