For testing purposes

In [2]:
# Load the autoreload extension
%load_ext autoreload

# Enable automatic reloading of modules
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


Directory management

In [3]:
import sys
sys.path.append('/Users/olav/Documents/PhD/scvi-tools/src')

In [4]:
import anndata
from scvi.data import AnnDataManager
from scvi.dataloaders import AnnDataLoader
import numpy as np
import torch.nn

import scvi_local

from scvi_local.nn import DecoderSCVI, Encoder, FCLayers

from scvi.data import AnnDataManager
from scvi.data.fields import LayerField, CategoricalObsField, NumericalObsField
from src._multivae import MULTIVAE 

import gzip
import os
import tempfile
from pathlib import Path

import numpy as np
import pooch
import scanpy as sc
import seaborn as sns
import torch

import numpy as np
import scanpy as sc
import scvi

  from .autonotebook import tqdm as notebook_tqdm


Data loading

In [62]:
adata = anndata.read_h5ad("/Users/olav/Documents/PhD/multiVI/data/mixed_source_adata.h5ad.gz")

In [63]:
adata.var_names_make_unique()

In [64]:
# split to three datasets by modality (RNA, ATAC, Multiome), and corrupt data
# by remove some data to create single-modality data
n = 4004
adata_rna = adata[:n, adata.var.modality == "Gene Expression"].copy()
adata_paired = adata[n : 2 * n].copy()
adata_atac = adata[2 * n :, adata.var.modality == "Peaks"].copy()

In [65]:
adata_mvi = scvi_local.data.organize_multiome_anndatas(adata_paired, adata_rna, adata_atac)

  return multi_anndata.concatenate(other, join="outer", batch_key=modality_key)
  return multi_anndata.concatenate(other, join="outer", batch_key=modality_key)


In [66]:
adata_mvi = adata_mvi[:, adata_mvi.var["modality"].argsort()].copy()
adata_mvi.var

Unnamed: 0,feature,modality,chr,start,end,n_cells
A1BG,A1BG,Gene Expression,,,,2244
RAI1,RAI1,Gene Expression,,,,2053
RAF1,RAF1,Gene Expression,,,,8408
RAE1,RAE1,Gene Expression,,,,2274
RAD9A,RAD9A,Gene Expression,,,,2844
...,...,...,...,...,...,...
chr17:79840351-79840851,chr17:79840351-79840851,Peaks,chr17,79840351,79840851,2844
chr17:79838639-79839139,chr17:79838639-79839139,Peaks,chr17,79838639,79839139,8532
chr17:79837846-79838346,chr17:79837846-79838346,Peaks,chr17,79837846,79838346,3831
chr17:79836519-79837019,chr17:79836519-79837019,Peaks,chr17,79836519,79837019,3576


In [67]:
adata_mvi

AnnData object with n_obs × n_vars = 89655 × 94507
    obs: 'barcode', 'source', 'rep', 'tech', 'celltype', '_scvi_batch', '_scvi_labels', '_scvi_local_l_mean', '_scvi_local_l_var', 'modality'
    var: 'feature', 'modality', 'chr', 'start', 'end', 'n_cells'
    obsm: 'X_multiVI', 'X_multiVI_nbc', 'X_umap', '_scvi_extra_categoricals'

In [68]:
adata_mvi = adata_mvi[adata_mvi.obs.modality == "accessibility", :]

In [69]:
adata_mvi.obs["_indices"] = np.arange(adata_mvi.n_obs)

  adata_mvi.obs["_indices"] = np.arange(adata_mvi.n_obs)


In [70]:
anndata_fields = [
    LayerField(registry_key="x", layer=None, is_count_data=True),
    CategoricalObsField(registry_key="modality", attr_key="modality"),
    NumericalObsField(registry_key  = "cell_idx", attr_key = "_indices" )
]
adata_manager = AnnDataManager(fields=anndata_fields)
adata_manager.register_fields(adata_mvi)
print(
    adata_manager.registry.keys()
)  # There is additionally a _scvi_uuid key which is used to uniquely identify AnnData objects for subsequent retrieval.

dict_keys(['scvi_version', 'model_name', 'setup_args', 'field_registries', '_scvi_uuid'])


In [71]:
adl = AnnDataLoader(adata_manager, shuffle=False, batch_size=16)

Arguments

In [72]:
n_epochs = 10
n_genes = sum(adata_mvi.var.modality == "Gene Expression")
n_regions =  sum(adata_mvi.var.modality == "Peaks")
n_hidden = 128
n_latent = 10
n_epochs_kl_warmup = 50

Model

In [73]:
multivi = MULTIVAE(
    n_input_genes = n_genes,
    n_input_regions = n_regions,
    n_hidden = n_hidden,
    n_latent = n_latent,
    deeply_inject_covariates = True,
    n_batch = adata_manager.summary_stats.n_modality,
    modality_weights = "universal"
)

In [74]:
import torch.optim as optim

# extract parameters of the model and adversial classifier
model_params = [param for name, param in multivi.named_parameters() if "adversarial_classifier" not in name]
adversarial_params = [param for name, param in multivi.named_parameters() if "adversarial_classifier" in name]

model_optimizer = optim.Adam(model_params, lr=0.0001,weight_decay=0.001)
adversarial_optimizer = optim.Adam(adversarial_params, lr=0.0001,weight_decay=1e-6, eps=0.01)

In [75]:
def _compute_kl_weight(
    epoch: int,
    n_epochs_kl_warmup: int | None,
    max_kl_weight: float = 1.0,
    min_kl_weight: float = 0.0,
) -> float | torch.Tensor:
    """Computes the kl weight for the current step or epoch.

    If both `n_epochs_kl_warmup` and `n_steps_kl_warmup` are None `max_kl_weight` is returned.

    Parameters
    ----------
    epoch
        Current epoch.
    step
        Current step.
    n_epochs_kl_warmup
        Number of training epochs to scale weight on KL divergences from
        `min_kl_weight` to `max_kl_weight`
    n_steps_kl_warmup
        Number of training steps (minibatches) to scale weight on KL divergences from
        `min_kl_weight` to `max_kl_weight`
    max_kl_weight
        Maximum scaling factor on KL divergence during training.
    min_kl_weight
        Minimum scaling factor on KL divergence during training.
    """
    if min_kl_weight > max_kl_weight:
        raise ValueError(
            f"min_kl_weight={min_kl_weight} is larger than max_kl_weight={max_kl_weight}."
        )

    slope = max_kl_weight - min_kl_weight
    if n_epochs_kl_warmup:
        if epoch < n_epochs_kl_warmup:
            return slope * (epoch / n_epochs_kl_warmup) + min_kl_weight
    elif n_steps_kl_warmup:
        if step < n_steps_kl_warmup:
            return slope * (step / n_steps_kl_warmup) + min_kl_weight
    return max_kl_weight

In [79]:
epoch_losses = []
for epoch in range(n_epochs):

    multivi.train()
    batch_losses = []

    for i, batch in enumerate(adl):

        x = batch["x"]
        modality = batch["modality"]
        cell_idx = batch["cell_idx"]       
        # y = torch.zeros(x.shape[0], 1, device=x.device, requires_grad=False) # in case of no usage of porteins 

        # inference
        inference_outputs = multivi.inference(x, modality, cell_idx)

        latent = inference_outputs["z"]
        libsize_expr = inference_outputs["libsize_expr"]
        libsize_acc = inference_outputs["libsize_acc"]

        # generation
        generative_outputs = multivi.generative(latent,modality,libsize_expr)

        # kl_weight & kappa
        klw = _compute_kl_weight(
            epoch = epoch,
            n_epochs_kl_warmup = n_epochs_kl_warmup            
        )
        kappa = 1 - klw

        # loss
        loss = multivi.loss(
            inference_outputs,
            generative_outputs,
            klw
        )

        # fool classifier by modifying z
        fool_loss = multivi.loss_adversarial_classifier(latent, modality, False)
        model_loss = loss + (fool_loss * kappa)
        model_optimizer.zero_grad()        
        model_loss.backward()
        model_optimizer.step()

        # train classifier
        adv_loss = multivi.loss_adversarial_classifier(latent.detach(), modality, True)
        adv_loss *= kappa
        adversarial_optimizer.zero_grad()
        adv_loss.backward()
        adversarial_optimizer.step()
        
        batch_losses.append(loss.item())

        if i % 100 == 0:
            print(f"batch nr {i} loss : {loss}")
            
    epoch_loss = np.mean(batch_losses)
    epoch_losses.append(epoch_loss)
    print(f"epoch loss : {epoch_loss}")

ValueError: Expected parameter loc (Tensor of shape (16, 10)) of distribution Normal(loc: torch.Size([16, 10]), scale: torch.Size([16, 10])) to satisfy the constraint Real(), but found invalid values:
tensor([[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan]],
       grad_fn=<AddmmBackward0>)