# For testing purposes

In [1]:
# Load the autoreload extension
%load_ext autoreload

# Enable automatic reloading of module
%autoreload 2

# Directory management

# Imports

In [2]:
import anndata
import scvi
from src.data import AnnDataManager
from src.dataloaders import AnnDataLoader
import numpy as np
import torch.nn
from src.data.fields import LayerField, CategoricalObsField, NumericalObsField
from src._multivae import MULTIVAE 
import math

import gzip
import os
import tempfile
from pathlib import Path

import numpy as np
import pooch
import scanpy as sc
import seaborn as sns
import torch

import numpy as np
import scanpy as sc
import wandb

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
scvi.settings.seed = 0
print("Last run with scvi-tools version:", scvi.__version__)

Seed set to 0


Last run with scvi-tools version: 1.2.2.post2


In [4]:
sc.set_figure_params(figsize=(6, 6), frameon=False)
sns.set_theme()
torch.set_float32_matmul_precision("high")
save_dir = tempfile.TemporaryDirectory()

%config InlineBackend.print_figure_kwargs={"facecolor": "w"}
%config InlineBackend.figure_format="retina"

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Data loading

In [5]:
def download_data(save_path: str, fname: str = "pbmc_10k"):
    data_paths = pooch.retrieve(
        url="https://cf.10xgenomics.com/samples/cell-arc/2.0.0/pbmc_unsorted_10k/pbmc_unsorted_10k_filtered_feature_bc_matrix.tar.gz",
        known_hash="872b0dba467d972aa498812a857677ca7cf69050d4f9762b2cd4753b2be694a1",
        fname=fname,
        path=save_path,
        processor=pooch.Untar(),
        progressbar=True,
    )
    data_paths.sort()

    for path in data_paths:
        with gzip.open(path, "rb") as f_in:
            with open(path.replace(".gz", ""), "wb") as f_out:
                f_out.write(f_in.read())

    return str(Path(data_paths[0]).parent)

In [6]:
data_path = download_data(save_dir.name)

Downloading data from 'https://cf.10xgenomics.com/samples/cell-arc/2.0.0/pbmc_unsorted_10k/pbmc_unsorted_10k_filtered_feature_bc_matrix.tar.gz' to file '/tmp/tmpf2dcccei/pbmc_10k'.
100%|████████████████████████████████████████| 375M/375M [00:00<00:00, 718GB/s]
Untarring contents of '/tmp/tmpf2dcccei/pbmc_10k' to '/tmp/tmpf2dcccei/pbmc_10k.untar'


# Data processing

In [7]:
from src.data import read_10x_multiome
# read multiomic data
adata = read_10x_multiome(data_path)
adata.var_names_make_unique()

  utils.warn_names_duplicates("var")


In [8]:
# split to three datasets by modality (RNA, ATAC, Multiome), and corrupt data
# by remove some data to create single-modality data
n = 4004
adata_rna = adata[:n, adata.var.modality == "Gene Expression"].copy()
adata_paired = adata[n : 2 * n].copy()
adata_atac = adata[2 * n :, adata.var.modality == "Peaks"].copy()

In [9]:
from src.data import organize_multiome_anndatas
# We can now use the organizing method from scvi to concatenate these anndata
adata_mvi = organize_multiome_anndatas(adata_paired, adata_rna, adata_atac)

  return multi_anndata.concatenate(other, join="outer", batch_key=modality_key)
  return multi_anndata.concatenate(other, join="outer", batch_key=modality_key)


In [10]:
adata_mvi = adata_mvi[:, adata_mvi.var["modality"].argsort()].copy()
adata_mvi.var

Unnamed: 0,ID,modality,chr,start,end
MIR1302-2HG,ENSG00000243485,Gene Expression,chr1,29553,30267
AL391261.2,ENSG00000258847,Gene Expression,chr14,66004522,66004523
FUT8-AS1,ENSG00000276116,Gene Expression,chr14,65412689,65412690
FUT8,ENSG00000033170,Gene Expression,chr14,65410591,65413008
AL355076.2,ENSG00000258760,Gene Expression,chr14,65302679,65318790
...,...,...,...,...,...
chr15:101277030-101277907,chr15:101277030-101277907,Peaks,chr15,101277030,101277907
chr15:101257856-101258771,chr15:101257856-101258771,Peaks,chr15,101257856,101258771
chr15:101251516-101252373,chr15:101251516-101252373,Peaks,chr15,101251516,101252373
chr15:101397608-101398445,chr15:101397608-101398445,Peaks,chr15,101397608,101398445


In [11]:
print(adata_mvi.shape)
sc.pp.filter_genes(adata_mvi, min_cells=int(adata_mvi.shape[0] * 0.01))
print(adata_mvi.shape)

(12012, 148458)
(12012, 80878)


In [12]:
adata_mvi.obs["_indices"] = np.arange(adata_mvi.n_obs)

In [13]:
anndata_fields = [
    LayerField(registry_key="x", layer=None, is_count_data=True),
    CategoricalObsField(registry_key="modality", attr_key="modality"),
    NumericalObsField(registry_key  = "cell_idx", attr_key = "_indices" )
]
adata_manager = AnnDataManager(fields=anndata_fields)
adata_manager.register_fields(adata_mvi)
print(
    adata_manager.registry.keys()
)  # There is additionally a _scvi_uuid key which is used to uniquely identify AnnData objects for subsequent retrieval.

dict_keys(['scvi_version', 'model_name', 'setup_args', 'field_registries', '_scvi_uuid'])


# Splitting training and validation set

In [14]:
n_samples = adata_mvi.shape[0]
train_size = 0.9
validation_size = 0.1

n_train = math.ceil(train_size * n_samples)
n_val = n_samples - n_train

random_state = np.random.RandomState(seed=0)
indices = np.arange(adata_manager.adata.n_obs)
indices = random_state.permutation(indices)

val_idx = indices[:n_val]
train_idx = indices[n_val : (n_val + n_train)]

# Dataloader 

In [15]:
adata_mvi.obs

Unnamed: 0_level_0,batch_id,modality,_indices,_scvi_modality
barcode,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
CCGCTAAAGGGCCATC-0-0,1,paired,0,2
CCGCTAAAGTCTTGAA-0-0,1,paired,1,2
CCGCTAAAGTTAGACC-0-0,1,paired,2,2
CCGCTAAAGTTCCCAC-0-0,1,paired,3,2
CCGCTAAAGTTTGCGG-0-0,1,paired,4,2
...,...,...,...,...
TTTGTTGGTACGCGCA-1,1,accessibility,12007,0
TTTGTTGGTATTTGCC-1,1,accessibility,12008,0
TTTGTTGGTGATTACG-1,1,accessibility,12009,0
TTTGTTGGTTTCAGGA-1,1,accessibility,12010,0


In [16]:
train_adl = AnnDataLoader(adata_manager, indices = train_idx, shuffle= False, drop_last= False, batch_size=128)
val_adl = AnnDataLoader(adata_manager, indices = val_idx, shuffle= False, drop_last= False, batch_size=128)

1500223.0
1641645.0


# Training

In [17]:
n_epochs = 300
n_genes = sum(adata_mvi.var.modality == "Gene Expression")
n_regions =  sum(adata_mvi.var.modality == "Peaks")
n_hidden = 128
n_latent = 11
n_epochs_kl_warmup = 50

Model

In [18]:
multivi = MULTIVAE(
    n_input_genes = n_genes,
    n_input_regions = n_regions,
    n_hidden = n_hidden,
    n_latent = n_latent,
    deeply_inject_covariates = True,
    n_batch = adata_manager.summary_stats.n_modality,
    modality_weights = "universal"
)

multivi.to(device)



MULTIVAE(
  (z_encoder_expression): Encoder(
    (encoder): FCLayers(
      (fc_layers): Sequential(
        (Layer 0): Sequential(
          (0): Linear(in_features=12446, out_features=128, bias=True)
          (1): None
          (2): LayerNorm((128,), eps=1e-05, elementwise_affine=False)
          (3): LeakyReLU(negative_slope=0.01)
          (4): Dropout(p=0.1, inplace=False)
        )
        (Layer 1): Sequential(
          (0): Linear(in_features=128, out_features=128, bias=True)
          (1): None
          (2): LayerNorm((128,), eps=1e-05, elementwise_affine=False)
          (3): LeakyReLU(negative_slope=0.01)
          (4): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (mean_encoder): Linear(in_features=128, out_features=11, bias=True)
    (var_encoder): Linear(in_features=128, out_features=11, bias=True)
  )
  (l_encoder_expression): ExprLibrarySizeEncoder(
    (px_decoder): FCLayers(
      (fc_layers): Sequential(
        (Layer 0): Sequential(
          (0): L

In [19]:
import torch.optim as optim

# extract parameters of the model and adversial classifier
model_params = [param for name, param in multivi.named_parameters() if "adversarial_classifier" not in name]
adversarial_params = [param for name, param in multivi.named_parameters() if "adversarial_classifier" in name]

model_optimizer = optim.AdamW(model_params, lr=1e-4, weight_decay=1e-3, eps = 1e-08)
# adversarial_optimizer = optim.Adam(adversarial_params, lr=1e-3, eps=0.01, weight_decay=self.weight_decay)

In [20]:
def _compute_kl_weight(
    epoch: int,
    n_epochs_kl_warmup: int | None,
    max_kl_weight: float = 1.0,
    min_kl_weight: float = 0.0,
) -> float | torch.Tensor:
    """Computes the kl weight for the current step or epoch.

    If both `n_epochs_kl_warmup` and `n_steps_kl_warmup` are None `max_kl_weight` is returned.

    Parameters
    ----------
    epoch
        Current epoch.
    step
        Current step.
    n_epochs_kl_warmup
        Number of training epochs to scale weight on KL divergences from
        `min_kl_weight` to `max_kl_weight`
    n_steps_kl_warmup
        Number of training steps (minibatches) to scale weight on KL divergences from
        `min_kl_weight` to `max_kl_weight`
    max_kl_weight
        Maximum scaling factor on KL divergence during training.
    min_kl_weight
        Minimum scaling factor on KL divergence during training.
    """
    if min_kl_weight > max_kl_weight:
        raise ValueError(
            f"min_kl_weight={min_kl_weight} is larger than max_kl_weight={max_kl_weight}."
        )

    slope = max_kl_weight - min_kl_weight
    if n_epochs_kl_warmup:
        if epoch < n_epochs_kl_warmup:
            return slope * (epoch / n_epochs_kl_warmup) + min_kl_weight
    elif n_steps_kl_warmup:
        if step < n_steps_kl_warmup:
            return slope * (step / n_steps_kl_warmup) + min_kl_weight
    return max_kl_weight

In [21]:
# wandb.init(project="multiVI-training")

epoch_model_losses = []
epoch_adv_losses = []
epoch_kl_local_train = []

for epoch in range(n_epochs):
    # Training 

    multivi.train()
    batch_model_losses = []
    batch_adv_losses = []
    batch_kl_local_losses = []
    batch_recon_losses = []

    for i, batch in enumerate(train_adl):

        x = batch["x"].to(device)
        modality = batch["modality"].to(device)
        cell_idx = batch["cell_idx"].to(device)       
        y = torch.zeros(x.shape[0], 1, device=x.device, requires_grad=False).to(device) # in case of no usage of porteins 

        print(modality)

        # print(f"x batch {i} : {x.sum()}")
            
        # inference
        inference_outputs = multivi.inference(x, y, modality, cell_idx)

        latent = inference_outputs["z"]
        libsize_expr = inference_outputs["libsize_expr"]
        libsize_acc = inference_outputs["libsize_acc"]
    
        # generation
        generative_outputs = multivi.generative(latent,modality,libsize_expr)

        # kl_weight & kappa
        klw = _compute_kl_weight(
            epoch = epoch,
            n_epochs_kl_warmup = n_epochs_kl_warmup,
            min_kl_weight=1e-3            
        )

        kappa = 1 

        # loss
        loss, kl_local_loss, recon_loss = multivi.loss(
            inference_outputs,
            generative_outputs,
            klw
        )
        
        n_obs_in_batch = len(kl_local_loss["kl_divergence_z"])
        kl_local_loss = (sum(kl_local_loss.values()).sum())/n_obs_in_batch
        recon_loss = (sum(recon_loss.values()).sum())/n_obs_in_batch

        # # fool classifier by modifying z
        # fool_loss = multivi.loss_adversarial_classifier(latent, modality, False)
        # model_loss = loss + (fool_loss * kappa)
        # model_optimizer.zero_grad()        
        # model_loss.backward()
        # model_optimizer.step()

        # # train classifier
        # adv_loss = multivi.loss_adversarial_classifier(latent.detach(), modality, True)
        # adv_loss *= kappa
        # adversarial_optimizer.zero_grad()
        # adv_loss.backward()
        # adversarial_optimizer.step()

        model_optimizer.zero_grad()        
        loss.backward()
        model_optimizer.step()

        wandb.log({
            "train_loss_step": loss.detach(),
            # "adv_train_loss": adv_loss.detach(),
            # "kl_local_train" : kl_local_loss.detach(),
            "epoch": epoch
        })
        
        batch_model_losses.append(loss.detach().cpu())
        batch_kl_local_losses.append(kl_local_loss.detach().cpu())
        batch_recon_losses.append(recon_loss.detach().cpu())
    
    epoch_model_loss = np.mean(batch_model_losses)
    epoch_kl_local_train = np.mean(batch_kl_local_losses)
    epoch_recon_loss = np.mean(batch_recon_losses)

    wandb.log({
    "train_loss_epoch": epoch_model_loss,
    "kl_local_train" : epoch_kl_local_train,
    "reconstruction_loss_train": epoch_recon_loss, 
    "epoch": epoch
})


tensor([[1.],
        [1.],
        [1.],
        [0.],
        [0.],
        [0.],
        [1.],
        [2.],
        [0.],
        [1.],
        [0.],
        [1.],
        [0.],
        [2.],
        [2.],
        [0.],
        [2.],
        [1.],
        [1.],
        [2.],
        [0.],
        [0.],
        [2.],
        [2.],
        [2.],
        [0.],
        [0.],
        [1.],
        [0.],
        [0.],
        [2.],
        [0.],
        [1.],
        [2.],
        [0.],
        [1.],
        [0.],
        [1.],
        [2.],
        [2.],
        [2.],
        [2.],
        [2.],
        [1.],
        [0.],
        [2.],
        [1.],
        [2.],
        [2.],
        [0.],
        [1.],
        [2.],
        [2.],
        [2.],
        [0.],
        [1.],
        [1.],
        [0.],
        [0.],
        [1.],
        [0.],
        [0.],
        [1.],
        [1.],
        [2.],
        [0.],
        [1.],
        [0.],
        [1.],
        [2.],
        [2.],
      

Error: You must call wandb.init() before wandb.log()