In [177]:
from __future__ import annotations
import chex
from dataclasses import field
from typing import Dict, Iterable
from scvi._types import LossRecord, Tensor
import pandas as pd
from typing import Optional, List, Union

def prepare_metadata(meta_data: pd.DataFrame,
                     cov_cat_keys: Optional[list] = None,
                     cov_cat_embed_keys: Optional[list] = None,
                     cov_cont_keys: Optional[list] = None,
                     orders=None):
    """Encode species/covariate metadata into model-ready tables.

    Categorical covariates listed in ``cov_cat_keys`` are one-hot encoded and
    concatenated column-wise with the continuous covariates in ``cov_cont_keys``.
    Categorical covariates listed in ``cov_cat_embed_keys`` are instead mapped to
    integer codes (their position in the category order), e.g. for an embedding layer.

    :param meta_data: Dataframe containing species and covariate info, e.g. from non-registered adata.obs
    :param cov_cat_keys: List of categorical covariates column names.
    :param cov_cat_embed_keys: List of categorical covariates column names to be encoded via embedding
        rather than one-hot encoding.
    :param cov_cont_keys: List of continuous covariates column names.
    :param orders: Defined orders for species or categorical covariates. Dict with keys being
        'species' or categorical covariates names and values being lists of categories. An order may
        contain categories absent from the data (they get all-zero dummy columns / unused codes), but it
        must contain every value present in the data. Covariates without an entry (or ``orders=None``)
        use the order inferred by ``pd.Categorical`` (sorted unique values where sortable).
    :return: Tuple ``(cov_data_parsed, cov_embed_data, orders_dict, cov_dict)``:

        - ``cov_data_parsed``: one-hot categorical covariates (float 0/1 columns named
          ``<covariate>_<category>``) followed by the continuous covariates, or ``None`` if neither
          ``cov_cat_keys`` nor ``cov_cont_keys`` is given;
        - ``cov_embed_data``: integer codes of the embedded categorical covariates, or ``None``;
        - ``orders_dict``: category order used for every categorical covariate;
        - ``cov_dict``: dict with keys 'categorical', 'categorical_embed' and 'continuous' listing the
          covariate names.
    :raises ValueError: If a column contains values missing from its order given in ``orders``.
    """
    # Fresh lists: avoid aliasing the caller's lists inside `cov_dict`, and accept tuples
    # (pandas would interpret a tuple key as a single column label).
    cov_cat_keys = list(cov_cat_keys) if cov_cat_keys is not None else []
    cov_cat_embed_keys = list(cov_cat_embed_keys) if cov_cat_embed_keys is not None else []
    cov_cont_keys = list(cov_cont_keys) if cov_cont_keys is not None else []
    # `orders` defaults to None but is queried with `.get()` below.
    orders = {} if orders is None else orders

    def order_categories(values: pd.Series, categories: Union[List, None] = None):
        """Return the category order for ``values``.

        :param values: Observed values of one covariate.
        :param categories: Requested order; inferred via ``pd.Categorical`` when None.
        :return: List of categories in order.
        :raises ValueError: If ``values`` contains entries not listed in ``categories``.
        """
        if categories is None:
            categories = pd.Categorical(values).categories.values
        else:
            missing = set(values.unique()) - set(categories)
            if len(missing) > 0:
                raise ValueError(f'Some values of {values.name} are not in the specified categories order: {missing}')
        return list(categories)

    def dummies_categories(values: pd.Series, categories: Union[List, None] = None):
        """
        Make dummies of categorical covariates. Use specified order of categories.
        :param values: Categories for each observation.
        :param categories: Order of categories to use.
        :return: dummies, categories. Dummies - one-hot encoding of categories in same order as categories.
        """
        categories = order_categories(values=values, categories=categories)

        # Casting to a Categorical fixes the dummy column order (and keeps columns for
        # categories that do not occur in the data).
        values = pd.Series(pd.Categorical(values=values, categories=categories, ordered=True),
                           index=values.index, name=values.name)
        # dtype=float: pandas >= 2.0 returns bool dummies by default, and mixing bool with float
        # continuous columns would make `.values` of the concatenated frame an object array.
        # NOTE: one column per category - can get wide with many high-cardinality covariates.
        dummies = pd.get_dummies(values, prefix=values.name, dtype=float)

        return dummies, categories

    # Covariate encoding
    # Save order of covariates and categories
    cov_dict = {'categorical': cov_cat_keys, 'categorical_embed': cov_cat_embed_keys, 'continuous': cov_cont_keys}
    orders_dict = {}

    # One-hot encoded categorical covariates + continuous covariates share one table.
    if len(cov_cat_keys) > 0 or len(cov_cont_keys) > 0:
        cov_cat_data = []
        for cov_cat_key in cov_cat_keys:
            cat_dummies, cat_order = dummies_categories(
                values=meta_data[cov_cat_key], categories=orders.get(cov_cat_key, None))
            cov_cat_data.append(cat_dummies)
            orders_dict[cov_cat_key] = cat_order
        # Prepare single cov array for all covariates
        cov_data_parsed = pd.concat(cov_cat_data + [meta_data[cov_cont_keys]], axis=1)
    else:
        cov_data_parsed = None

    # Embedded categorical covariates are encoded as integer positions in their order.
    if len(cov_cat_embed_keys) > 0:
        cov_embed_data = []
        for cov_cat_embed_key in cov_cat_embed_keys:
            cat_order = order_categories(values=meta_data[cov_cat_embed_key],
                                         categories=orders.get(cov_cat_embed_key, None))
            cat_map = dict(zip(cat_order, range(len(cat_order))))
            cov_embed_data.append(meta_data[cov_cat_embed_key].map(cat_map))
            orders_dict[cov_cat_embed_key] = cat_order
        cov_embed_data = pd.concat(cov_embed_data, axis=1)
    else:
        cov_embed_data = None

    return cov_data_parsed, cov_embed_data, orders_dict, cov_dict

from itertools import combinations

def check_adatas_var_index(*adatas):
    """Check that all provided AnnData objects share the same ``var`` index, in the same order.

    Every object is compared to the first one; since index equality is transitive this is
    equivalent to comparing all pairs, and the first mismatch reported is the same pair the
    pairwise check would find first.

    :param adatas: AnnData objects (anything exposing ``.var.index``) to compare.
    :raises ValueError: If some object's ``var`` index differs from the first object's
        (different names, different order, or different length).
    """
    if adatas:
        reference = adatas[0].var.index
        for j in range(1, len(adatas)):
            # `Index.equals` compares elements and order and returns False on a length
            # mismatch, whereas `Index == Index` raises pandas' generic "Lengths must match"
            # error and would hide the message below.
            if not reference.equals(adatas[j].var.index):
                raise ValueError(f"The variable indices of the AnnData objects at positions 0 and {j} do not match!")
    print("Everything ok!")
    

from sklearn.metrics import r2_score

def compute_r2_score(preds, ground_truth):
    """R^2 between the per-variable mean expression of predictions and ground truth.

    ``.X`` of each object is averaged over observations (axis 0), and
    ``sklearn.metrics.r2_score`` is computed on the two mean vectors with
    ``ground_truth`` as the reference (``y_true``).

    :param preds: AnnData-like object with predicted expression in ``.X`` (dense or scipy sparse).
    :param ground_truth: AnnData-like object with observed expression in ``.X`` (dense or scipy sparse).
    :return: R^2 score as returned by ``r2_score``.
    """
    def _mean_over_obs(adata):
        # `.mean(axis=0)` works on dense arrays and scipy sparse inputs without densifying the
        # whole matrix; sparse matrices return a 2-D np.matrix, so flatten to 1-D for r2_score.
        return np.asarray(adata.X.mean(axis=0)).ravel()

    return r2_score(_mean_over_obs(ground_truth), _mean_over_obs(preds))

import numpy as np
import anndata as ad
import pandas as pd
import scanpy as sc

def compare_de(X: np.ndarray, Y: np.ndarray, C: np.ndarray, shared_top: int = 100, **kwargs) -> dict:
    """Compare differential-expression results of real vs. simulated perturbations.

    Each perturbed matrix is tested against the control matrix with
    ``scanpy.tl.rank_genes_groups``. The per-gene results of the two tests are then
    compared.

    Args:
        X: Real perturbed data (observations x variables).
        Y: Simulated perturbed data (observations x variables).
        C: Control data (observations x variables).
        shared_top: Number of top-ranked DEG whose overlap is measured (capped at the
            number of variables).
        **kwargs: arguments for `scanpy.tl.rank_genes_groups`.

    Returns:
        dict with
        - "shared_top_genes": fraction of the top ``shared_top`` genes common to both rankings;
        - "scores_corr": Pearson correlation of the test scores;
        - "pvals_adj_corr": Pearson correlation of the adjusted p-values;
        - "scores_ranks_corr": Spearman correlation of the gene ranks.
    """
    n_genes = X.shape[1]
    assert n_genes == Y.shape[1] == C.shape[1]

    n_top = min(shared_top, n_genes)
    rank_positions = np.arange(1, n_genes + 1)

    perturbed = {
        "x": ad.AnnData(X, obs={"label": "comp"}),
        "y": ad.AnnData(Y, obs={"label": "comp"}),
    }
    control = ad.AnnData(C, obs={"label": "ctrl"})

    per_gene = pd.DataFrame(index=control.var_names)
    top_genes = {}
    for tag, adata_pert in perturbed.items():
        joint = ad.concat((adata_pert, control), index_unique="-")
        sc.tl.rank_genes_groups(joint, groupby="label", reference="ctrl", key_added="de", **kwargs)

        de = joint.uns["de"]
        ranked_names = de["names"]["comp"]
        # rank_genes_groups stores results in rank order; sorting by name gives the same
        # row order for both groups, so the columns can be correlated row by row.
        by_name = np.argsort(ranked_names)
        per_gene[f"scores_{tag}"] = de["scores"]["comp"][by_name]
        per_gene[f"pvals_adj_{tag}"] = de["pvals_adj"]["comp"][by_name]
        # explicit rank positions, so no need to reason about `rankby_abs`
        per_gene[f"ranks_{tag}"] = rank_positions[by_name]

        top_genes[tag] = ranked_names[:n_top]

    overlap = set(top_genes["x"]) & set(top_genes["y"])
    return {
        "shared_top_genes": len(overlap) / n_top,
        "scores_corr": per_gene["scores_x"].corr(per_gene["scores_y"], method="pearson"),
        "pvals_adj_corr": per_gene["pvals_adj_x"].corr(per_gene["pvals_adj_y"], method="pearson"),
        "scores_ranks_corr": per_gene["ranks_x"].corr(per_gene["ranks_y"], method="spearman"),
    }

def compare_logfold(X: np.ndarray, Y: np.ndarray, C: np.ndarray, **kwargs) -> dict:
    """Compare per-gene log fold changes of real and simulated perturbations vs. control.

    Runs ``scanpy.tl.rank_genes_groups`` for real-vs-control and simulated-vs-control
    and correlates the resulting log fold changes gene by gene.

    Negative entries of ``Y`` are clipped to 0 on a copy before testing; the caller's
    ``Y`` is not modified.

    Args:
        X: Real perturbed data (observations x variables).
        Y: Simulated perturbed data (observations x variables).
        C: Control data (observations x variables).
        **kwargs: arguments for `scanpy.tl.rank_genes_groups`.

    Returns:
        dict with
        - "logfold_corr": Pearson correlation of per-gene log fold changes (real vs simulated);
        - "prop_of_genes_set_to_0": fraction of entries of ``Y`` that were negative and clipped to 0.
    """
    n_vars = X.shape[1]
    assert n_vars == Y.shape[1] == C.shape[1]

    prop_of_genes_set_to_0 = np.mean(Y < 0)
    # Clip on a copy so the caller's simulated matrix is not mutated.
    Y = Y.copy()
    Y[Y < 0] = 0

    adatas_xy = {}
    adatas_xy["x"] = ad.AnnData(X, obs={"label": "comp"})
    adatas_xy["y"] = ad.AnnData(Y, obs={"label": "comp"})
    adata_c = ad.AnnData(C, obs={"label": "ctrl"})

    results = pd.DataFrame(index=adata_c.var_names)
    for group in ("x", "y"):
        adata_joint = ad.concat((adatas_xy[group], adata_c), index_unique="-")

        sc.tl.rank_genes_groups(adata_joint, groupby="label", reference="ctrl", key_added="de", **kwargs)
        de = adata_joint.uns["de"]
        # rank_genes_groups stores results in *rank* order, which differs between x and y.
        # Index by gene name so pandas aligns both columns to the same gene per row
        # (genes not reported, e.g. with `n_genes=...`, become NaN and are skipped by corr).
        results[f"logfold_{group}"] = pd.Series(
            np.asarray(de["logfoldchanges"]["comp"]), index=np.asarray(de["names"]["comp"])
        )

    metrics = {
        "logfold_corr": results["logfold_x"].corr(results["logfold_y"], method="pearson"),
        "prop_of_genes_set_to_0": prop_of_genes_set_to_0,
    }

    return metrics

In [345]:
from collections import OrderedDict
from inspect import signature
from typing import Any, Callable, Dict, Iterable, Literal, Optional, Union

import lightning.pytorch as pl
import numpy as np
import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchmetrics import Accuracy

from scvi import METRIC_KEYS, REGISTRY_KEYS
from scvi.autotune._types import Tunable, TunableMixin
from scvi.module import Classifier
from scvi.module.base import (
    BaseModuleClass,
    LossOutput,
)

from scvi.train._metrics import ElboMetric

# Type of a callable that receives the parameters to optimize and returns a configured
# torch optimizer (used for the "Custom" optimizer option of TrainingPlan).
TorchOptimizerCreator = Callable[[Iterable[torch.Tensor]], torch.optim.Optimizer]

def _compute_kl_weight(
    epoch: int,
    step: int,
    n_epochs_kl_warmup: Optional[int],
    n_steps_kl_warmup: Optional[int],
    max_kl_weight: float = 1.0,
    min_kl_weight: float = 0.0,
) -> float:
    """Computes the kl weight for the current step or epoch.

    If both `n_epochs_kl_warmup` and `n_steps_kl_warmup` are None `max_kl_weight` is returned.

    Parameters
    ----------
    epoch
        Current epoch.
    step
        Current step.
    n_epochs_kl_warmup
        Number of training epochs to scale weight on KL divergences from
        `min_kl_weight` to `max_kl_weight`
    n_steps_kl_warmup
        Number of training steps (minibatches) to scale weight on KL divergences from
        `min_kl_weight` to `max_kl_weight`
    max_kl_weight
        Maximum scaling factor on KL divergence during training.
    min_kl_weight
        Minimum scaling factor on KL divergence during training.
    """
    if min_kl_weight > max_kl_weight:
        raise ValueError(
            f"min_kl_weight={min_kl_weight} is larger than max_kl_weight={max_kl_weight}."
        )

    slope = max_kl_weight - min_kl_weight
    if n_epochs_kl_warmup:
        if epoch < n_epochs_kl_warmup:
            return slope * (epoch / n_epochs_kl_warmup) + min_kl_weight
    elif n_steps_kl_warmup:
        if step < n_steps_kl_warmup:
            return slope * (step / n_steps_kl_warmup) + min_kl_weight
    return max_kl_weight


class TrainingPlan(TunableMixin, pl.LightningModule):
    """Lightning module task to train scvi-tools modules.

    The training plan is a PyTorch Lightning Module that is initialized
    with a scvi-tools module object. It configures the optimizers, defines
    the training step and validation step, and computes metrics to be recorded
    during training. The training step and validation step are functions that
    take data, run it through the model and return the loss, which will then
    be used to optimize the model parameters in the Trainer. Overall, custom
    training plans can be used to develop complex inference schemes on top of
    modules.

    In addition to the standard ELBO metrics, :meth:`validation_step` calls
    ``self.module.compute_validation_metrics()`` and logs each returned entry as
    ``"{name}_validation_eval"``. The wrapped module must therefore provide a
    ``compute_validation_metrics()`` method returning a dict of scalars / 0-d
    tensors, or ``None`` to log nothing extra.

    The following developer tutorial will familiarize you more with training plans
    and how to use them: :doc:`/tutorials/notebooks/dev/model_user_guide`.

    Parameters
    ----------
    module
        A module instance from class ``BaseModuleClass``.
    optimizer
        One of "Adam" (:class:`~torch.optim.Adam`), "AdamW" (:class:`~torch.optim.AdamW`),
        or "Custom", which requires a custom optimizer creator callable to be passed via
        `optimizer_creator`.
    optimizer_creator
        A callable taking in parameters and returning a :class:`~torch.optim.Optimizer`.
        This allows using any PyTorch optimizer with custom hyperparameters.
    lr
        Learning rate used for optimization, when `optimizer_creator` is None.
    weight_decay
        Weight decay used in optimization, when `optimizer_creator` is None.
    eps
        eps used for optimization, when `optimizer_creator` is None.
    n_steps_kl_warmup
        Number of training steps (minibatches) to scale weight on KL divergences from
        `min_kl_weight` to `max_kl_weight`. Only activated when `n_epochs_kl_warmup` is
        set to None.
    n_epochs_kl_warmup
        Number of epochs to scale weight on KL divergences from `min_kl_weight` to
        `max_kl_weight`. Overrides `n_steps_kl_warmup` when both are not `None`.
    reduce_lr_on_plateau
        Whether to monitor validation loss and reduce learning rate when validation set
        `lr_scheduler_metric` plateaus.
    lr_factor
        Factor to reduce learning rate.
    lr_patience
        Number of epochs with no improvement after which learning rate will be reduced.
    lr_threshold
        Threshold for measuring the new optimum.
    lr_scheduler_metric
        Which metric to track for learning rate reduction.
    lr_min
        Minimum learning rate allowed.
    max_kl_weight
        Maximum scaling factor on KL divergence during training.
    min_kl_weight
        Minimum scaling factor on KL divergence during training.
    **loss_kwargs
        Keyword args to pass to the loss method of the `module`.
        `kl_weight` should not be passed here and is handled automatically.

    Raises
    ------
    ValueError
        If `optimizer` is "Custom" but no `optimizer_creator` is given.
    """

    def __init__(
        self,
        module: BaseModuleClass,
        *,
        optimizer: Tunable[Literal["Adam", "AdamW", "Custom"]] = "Adam",
        optimizer_creator: Optional[TorchOptimizerCreator] = None,
        lr: Tunable[float] = 1e-3,
        weight_decay: Tunable[float] = 1e-6,
        eps: Tunable[float] = 0.01,
        n_steps_kl_warmup: Tunable[int] = None,
        n_epochs_kl_warmup: Tunable[int] = 400,
        reduce_lr_on_plateau: Tunable[bool] = False,
        lr_factor: Tunable[float] = 0.6,
        lr_patience: Tunable[int] = 30,
        lr_threshold: Tunable[float] = 0.0,
        lr_scheduler_metric: Literal[
            "elbo_validation", "reconstruction_loss_validation", "kl_local_validation"
        ] = "elbo_validation",
        lr_min: Tunable[float] = 0,
        max_kl_weight: Tunable[float] = 1.0,
        min_kl_weight: Tunable[float] = 0.0,
        **loss_kwargs,
    ):
        super().__init__()
        self.module = module
        self.lr = lr
        self.weight_decay = weight_decay
        self.eps = eps
        self.optimizer_name = optimizer
        self.n_steps_kl_warmup = n_steps_kl_warmup
        self.n_epochs_kl_warmup = n_epochs_kl_warmup
        self.reduce_lr_on_plateau = reduce_lr_on_plateau
        self.lr_factor = lr_factor
        self.lr_patience = lr_patience
        self.lr_scheduler_metric = lr_scheduler_metric
        self.lr_threshold = lr_threshold
        self.lr_min = lr_min
        self.loss_kwargs = loss_kwargs
        self.min_kl_weight = min_kl_weight
        self.max_kl_weight = max_kl_weight
        self.optimizer_creator = optimizer_creator

        if self.optimizer_name == "Custom" and self.optimizer_creator is None:
            raise ValueError(
                "If optimizer is 'Custom', `optimizer_creator` must be provided."
            )

        # Unknown at construction time; may be assigned later through the
        # `n_obs_training` / `n_obs_validation` property setters.
        self._n_obs_training = None
        self._n_obs_validation = None

        # automatic handling of kl weight: only inject `kl_weight` into the loss kwargs
        # if the module's `loss()` accepts it; `training_step` refreshes the value each step.
        self._loss_args = set(signature(self.module.loss).parameters.keys())
        if "kl_weight" in self._loss_args:
            self.loss_kwargs.update({"kl_weight": self.kl_weight})

        self.initialize_train_metrics()
        self.initialize_val_metrics()

    @staticmethod
    def _create_elbo_metric_components(mode: str, n_total: Optional[int] = None):
        """Initialize ELBO metric and the metric collection.

        Returns
        -------
        Tuple ``(elbo, rec_loss, kl_local, kl_global, collection)`` where ``elbo`` is
        composed as ``rec_loss + kl_local + kl_global / n_total`` and ``collection`` is
        an ordered dict mapping each metric's name to the metric.
        """
        rec_loss = ElboMetric("reconstruction_loss", mode, "obs")
        kl_local = ElboMetric("kl_local", mode, "obs")
        kl_global = ElboMetric("kl_global", mode, "batch")
        # n_total can be 0 if there is no validation set, this won't ever be used
        # in that case anyway
        n = 1 if n_total is None or n_total < 1 else n_total
        elbo = rec_loss + kl_local + (1 / n) * kl_global
        elbo.name = f"elbo_{mode}"
        collection = OrderedDict(
            [(metric.name, metric) for metric in [elbo, rec_loss, kl_local, kl_global]]
        )
        return elbo, rec_loss, kl_local, kl_global, collection

    def initialize_train_metrics(self):
        """Initialize train related metrics (rebuilt whenever `n_obs_training` is set)."""
        (
            self.elbo_train,
            self.rec_loss_train,
            self.kl_local_train,
            self.kl_global_train,
            self.train_metrics,
        ) = self._create_elbo_metric_components(
            mode="train", n_total=self.n_obs_training
        )
        self.elbo_train.reset()

    def initialize_val_metrics(self):
        """Initialize val related metrics (rebuilt whenever `n_obs_validation` is set)."""
        (
            self.elbo_val,
            self.rec_loss_val,
            self.kl_local_val,
            self.kl_global_val,
            self.val_metrics,
        ) = self._create_elbo_metric_components(
            mode="validation", n_total=self.n_obs_validation
        )
        self.elbo_val.reset()

    @property
    def n_obs_training(self):
        """Number of observations in the training set.

        This will update the loss kwargs for loss rescaling.

        Notes
        -----
        This can get set after initialization
        """
        return self._n_obs_training

    @n_obs_training.setter
    def n_obs_training(self, n_obs: int):
        # Forward `n_obs` to the loss only if the module's loss() accepts it.
        if "n_obs" in self._loss_args:
            self.loss_kwargs.update({"n_obs": n_obs})
        self._n_obs_training = n_obs
        self.initialize_train_metrics()

    @property
    def n_obs_validation(self):
        """Number of observations in the validation set.

        Setting it rebuilds the validation ELBO metrics. Unlike `n_obs_training`, it does
        not change the loss kwargs, so validation losses stay scaled to the training size.

        Notes
        -----
        This can get set after initialization
        """
        return self._n_obs_validation

    @n_obs_validation.setter
    def n_obs_validation(self, n_obs: int):
        self._n_obs_validation = n_obs
        self.initialize_val_metrics()

    def forward(self, *args, **kwargs):
        """Passthrough to the module's forward method."""
        return self.module(*args, **kwargs)

    # Metric bookkeeping only; no autograd graph is needed.
    @torch.inference_mode()
    def compute_and_log_metrics(
        self,
        loss_output: LossOutput,
        metrics: Dict[str, ElboMetric],
        mode: str,
        metrics_eval: Dict = None,
    ):
        """Computes and logs metrics.

        Parameters
        ----------
        loss_output
            LossOutput object from scvi-tools module
        metrics
            Dictionary of metrics to update
        mode
            Postfix string to add to the metric name of
            extra metrics
        metrics_eval
            Optional dict of additional evaluation metrics (Python scalars or 0-d
            tensors); each is logged as ``"{name}_{mode}_eval"``. ``None`` skips them.

        Raises
        ------
        ValueError
            If an extra metric (from the loss output or `metrics_eval`) is a tensor
            that is not 0-dimensional.
        """
        rec_loss = loss_output.reconstruction_loss_sum
        n_obs_minibatch = loss_output.n_obs_minibatch
        kl_local = loss_output.kl_local_sum
        kl_global = loss_output.kl_global_sum

        # Use the torchmetric object for the ELBO
        # We only need to update the ELBO metric
        # As it's defined as a sum of the other metrics
        metrics[f"elbo_{mode}"].update(
            reconstruction_loss=rec_loss,
            kl_local=kl_local,
            kl_global=kl_global,
            n_obs_minibatch=n_obs_minibatch,
        )
        # pytorch lightning handles everything with the torchmetric object
        self.log_dict(
            metrics,
            batch_size=n_obs_minibatch,
        )

        # accumulate extra metrics passed to loss recorder (epoch-level only)
        for key in loss_output.extra_metrics_keys:
            met = loss_output.extra_metrics[key]
            if isinstance(met, torch.Tensor):
                if met.shape != torch.Size([]):
                    raise ValueError("Extra tracked metrics should be 0-d tensors.")
                met = met.detach()
            self.log(
                f"{key}_{mode}",
                met,
                on_step=False,
                on_epoch=True,
                batch_size=n_obs_minibatch,
            )
        # accumulate extra eval metrics (Lightning's default on_step/on_epoch for the
        # current stage applies here, since neither is passed explicitly)
        if metrics_eval is not None:
            for extra_metric, met in metrics_eval.items():
                if isinstance(met, torch.Tensor):
                    if met.shape != torch.Size([]):
                        raise ValueError("Extra tracked metrics should be 0-d tensors.")
                    met = met.detach()
                self.log(
                    f"{extra_metric}_{mode}_eval",
                    met,
                    batch_size=n_obs_minibatch,
                )

    def training_step(self, batch, batch_idx):
        """Training step for the model: refresh the KL weight, compute the loss, log metrics."""
        if "kl_weight" in self.loss_kwargs:
            kl_weight = self.kl_weight
            self.loss_kwargs.update({"kl_weight": kl_weight})
            self.log("kl_weight", kl_weight, on_step=True, on_epoch=False)
        # forward returns a 3-tuple; only the LossOutput (last element) is used here
        _, _, scvi_loss = self.forward(batch, loss_kwargs=self.loss_kwargs)
        self.log("train_loss", scvi_loss.loss, on_epoch=True, prog_bar=True)
        self.compute_and_log_metrics(scvi_loss, self.train_metrics, "train")
        return scvi_loss.loss

    def validation_step(self, batch, batch_idx):
        """Validation step: log the validation ELBO metrics plus module-provided eval metrics."""
        # loss kwargs here contains `n_obs` equal to n_training_obs
        # so when relevant, the actual loss value is rescaled to number
        # of training examples
        _, _, scvi_loss = self.forward(batch, loss_kwargs=self.loss_kwargs)
        # No separate "validation_loss" entry: the loss is tracked through the
        # `elbo_validation` metrics logged in compute_and_log_metrics.
        # NOTE(review): compute_validation_metrics() takes no batch input yet runs once per
        # validation batch; if it is expensive, consider computing it once per epoch
        # (e.g. in `on_validation_epoch_end`).
        metrics_eval = self.module.compute_validation_metrics()
        self.compute_and_log_metrics(scvi_loss, self.val_metrics, "validation", metrics_eval=metrics_eval)

    def _optimizer_creator_fn(
        self, optimizer_cls: Union[torch.optim.Adam, torch.optim.AdamW]
    ):
        """Create optimizer for the model.

        `optimizer_cls` is the optimizer *class* (e.g. ``torch.optim.Adam``); the returned
        callable instantiates it with this plan's `lr`, `eps` and `weight_decay`.
        This type of function can be passed as the `optimizer_creator`
        """
        return lambda params: optimizer_cls(
            params, lr=self.lr, eps=self.eps, weight_decay=self.weight_decay
        )

    def get_optimizer_creator(self):
        """Get optimizer creator for the model.

        Raises
        ------
        ValueError
            If `optimizer_name` is not "Adam", "AdamW" or "Custom".
        """
        if self.optimizer_name == "Adam":
            optim_creator = self._optimizer_creator_fn(torch.optim.Adam)
        elif self.optimizer_name == "AdamW":
            optim_creator = self._optimizer_creator_fn(torch.optim.AdamW)
        elif self.optimizer_name == "Custom":
            optim_creator = self.optimizer_creator
        else:
            raise ValueError("Optimizer not understood.")

        return optim_creator

    def configure_optimizers(self):
        """Configure optimizers for the model (optionally with ReduceLROnPlateau)."""
        # Frozen parameters (requires_grad=False) are excluded from optimization.
        params = filter(lambda p: p.requires_grad, self.module.parameters())
        optimizer = self.get_optimizer_creator()(params)
        config = {"optimizer": optimizer}
        if self.reduce_lr_on_plateau:
            # NOTE(review): the `verbose` argument of torch LR schedulers is deprecated in
            # recent PyTorch releases - confirm it is still accepted by the installed version.
            scheduler = ReduceLROnPlateau(
                optimizer,
                patience=self.lr_patience,
                factor=self.lr_factor,
                threshold=self.lr_threshold,
                min_lr=self.lr_min,
                threshold_mode="abs",
                verbose=True,
            )
            config.update(
                {
                    "lr_scheduler": {
                        "scheduler": scheduler,
                        "monitor": self.lr_scheduler_metric,
                    },
                },
            )
        return config

    @property
    def kl_weight(self):
        """Scaling factor on KL divergence during training.

        Derived from the current epoch and global step via `_compute_kl_weight`.
        """
        return _compute_kl_weight(
            self.current_epoch,
            self.global_step,
            self.n_epochs_kl_warmup,
            self.n_steps_kl_warmup,
            self.max_kl_weight,
            self.min_kl_weight,
        )


In [346]:
from typing import List, Optional, Union

from scvi.dataloaders import DataSplitter
from scvi.model._utils import get_max_epochs_heuristic
from scvi.train import TrainRunner#, TrainingPlan
from scvi.utils._docstrings import devices_dsp
#from transVAE.train._training_plan import TrainingPlan

class UnsupervisedTrainingMixin:
    """Mixin adding a generic unsupervised ``train`` method to scvi-tools style models.

    The concrete data splitter, training plan and train runner are taken from the
    class attributes below, so subclasses can swap any of them.
    """

    _data_splitter_cls = DataSplitter
    _training_plan_cls = TrainingPlan
    _train_runner_cls = TrainRunner

    @devices_dsp.dedent
    def train(
        self,
        max_epochs: Optional[int] = None,
        use_gpu: Optional[Union[str, int, bool]] = None,
        accelerator: str = "auto",
        devices: Union[int, List[int], str] = "auto",
        train_size: float = 0.9,
        validation_size: Optional[float] = None,
        shuffle_set_split: bool = True,
        batch_size: int = 128,
        early_stopping: bool = False,
        plan_kwargs: Optional[dict] = None,
        **trainer_kwargs,
    ):
        """Train the model.

        Parameters
        ----------
        max_epochs
            Number of passes through the dataset. If `None`, a heuristic based on the
            number of cells is used (`get_max_epochs_heuristic(adata.n_obs)`).
        %(param_use_gpu)s
        %(param_accelerator)s
        %(param_devices)s
        train_size
            Size of training set in the range [0.0, 1.0].
        validation_size
            Size of the validation set. If `None`, defaults to 1 - `train_size`. If
            `train_size + validation_size < 1`, the remaining cells belong to a test set.
        shuffle_set_split
            Whether to shuffle indices before splitting. If `False`, the val, train, and test
            sets are split in the sequential order of the data according to `validation_size`
            and `train_size` percentages.
        batch_size
            Minibatch size to use during training.
        early_stopping
            Perform early stopping. Ignored if `early_stopping` is also given in
            `**trainer_kwargs`, which takes precedence.
            See :class:`~scvi.train.Trainer` for further options.
        plan_kwargs
            Keyword args for the training plan class (`_training_plan_cls`).
            Anything other than a dict is treated as no extra arguments.
        **trainer_kwargs
            Other keyword args for :class:`~scvi.train.Trainer`.

        Returns
        -------
        Whatever the train runner returns when called.
        """
        n_epochs = get_max_epochs_heuristic(self.adata.n_obs) if max_epochs is None else max_epochs
        plan_options = plan_kwargs if isinstance(plan_kwargs, dict) else {}

        splitter = self._data_splitter_cls(
            self.adata_manager,
            train_size=train_size,
            validation_size=validation_size,
            batch_size=batch_size,
            shuffle_set_split=shuffle_set_split,
        )
        plan = self._training_plan_cls(self.module, **plan_options)

        # An explicit `early_stopping` inside trainer_kwargs wins over the named argument.
        trainer_kwargs.setdefault("early_stopping", early_stopping)

        runner = self._train_runner_cls(
            self,
            training_plan=plan,
            data_splitter=splitter,
            max_epochs=n_epochs,
            use_gpu=use_gpu,
            accelerator=accelerator,
            devices=devices,
            **trainer_kwargs,
        )
        return runner()

In [361]:
from typing import Optional, Dict
import numpy as np
from anndata import AnnData

import torch
from torch.distributions import Normal
from torch.distributions import kl_divergence as kl

from scvi import REGISTRY_KEYS
from scvi.distributions import NegativeBinomial
from scvi.module.base import BaseModuleClass, auto_move_data
from scvi.module.base._base_module import LossOutput

# Process-wide setting: let cuDNN benchmark candidate algorithms and pick the fastest per
# input configuration. NOTE(review): this can make GPU results non-deterministic across
# runs; set it to False if bit-exact reproducibility matters.
torch.backends.cudnn.benchmark = True
from transVAE.nn._base_components import Encoder, Decoder, Embedding
from sklearn.metrics import r2_score

# Conditional VAE model
class VAEC(BaseModuleClass):
    """Conditional variational auto-encoder module.

    Both the encoder and the decoder receive the registered covariates as a
    conditioning input: the columns of ``obsm["covariates"]`` are used directly
    and every column of ``obsm["covariates_embed"]`` is mapped through a learned
    embedding. The training loss combines a per-cell sum of squared errors with
    a KL term scaled by the fixed ``kl_weight``.

    Parameters
    ----------
    adata
        AnnData used to size the covariate inputs (reads ``obsm`` and ``uns``).
    n_input
        Number of input genes.
    n_hidden
        Number of nodes per hidden layer.
    n_latent
        Dimensionality of the latent space.
    n_layers
        Number of hidden layers used for encoder and decoder NNs.
    log_variational
        Stored on the module; not applied by :meth:`inference` in this class.
    kl_weight
        Fixed weight of the KL term used by :meth:`loss`.
    dropout_rate
        Dropout rate for the encoder and decoder neural networks.
    px_decoder
        Stored on the module; not used in this class.
    cov_embed_dims
        Embedding dimension for every column of ``obsm["covariates_embed"]``.
    extra_encoder_kwargs, extra_decoder_kwargs
        Currently unused; the encoder/decoder are built from the arguments above.
    use_exponentiation
        If True, the reconstruction loss is computed on ``expm1`` of the data
        and of the (capped) reconstruction.
    initialize_weights
        If True, call ``initialize_weights()`` on the encoder and decoder.
    validation_adatas_dict
        Prepared validation structure consumed by
        :meth:`compute_validation_metrics`, or None.
    original_adata_covariates
        Stored as-is on the module.
    """

    # Upper bound applied to the expm1-transformed reconstruction when
    # ``use_exponentiation`` is enabled, so the squared error stays finite.
    _EXP_CAP = 5000000

    def __init__(
        self,
        adata: AnnData,
        n_input: int,
        n_hidden: int = 128,
        n_latent: int = 5,
        n_layers: int = 2,
        log_variational: bool = True,
        kl_weight: float = 0.005,
        dropout_rate: float = 0.05,
        px_decoder: bool = False,
        cov_embed_dims: int = 10,
        extra_encoder_kwargs: Optional[dict] = None,
        extra_decoder_kwargs: Optional[dict] = None,
        use_exponentiation: bool = True,
        initialize_weights: bool = False,
        validation_adatas_dict: Optional[dict] = None,
        original_adata_covariates=None,
    ):
        super().__init__()
        self.dispersion = "gene"
        self.n_latent = n_latent
        self.n_layers = n_layers
        self.n_hidden = n_hidden
        self.dropout_rate = dropout_rate
        self.log_variational = log_variational
        self.gene_likelihood = "nb"
        self.latent_distribution = "normal"
        self.px_decoder = px_decoder
        # Automatically deactivate if useless
        self.n_batch = 0
        self.kl_weight = kl_weight
        self.cov_embed_dims = cov_embed_dims
        self.use_exponentiation = use_exponentiation
        self.val_adatas = validation_adatas_dict
        self.original_adata_covariates = original_adata_covariates

        # Covariates with high cardinality are encoded through learned embeddings.
        if "covariates_embed" in adata.obsm.keys():
            cov_embed = adata.obsm["covariates_embed"]
            self.embed_cov_sizes = self._embedding_sizes(
                cov_embed, adata.uns.get("covariate_orders")
            )
            self.cov_embeddings = torch.nn.ModuleList(
                [Embedding(size=size, cov_embed_dims=cov_embed_dims) for size in self.embed_cov_sizes]
            )
            self.embed_cov = True
            self.n_cov_embed = len(cov_embed.columns) * cov_embed_dims
        else:
            self.embed_cov = False
            self.n_cov_embed = 0

        if "covariates" in adata.obsm.keys():
            self.n_cov = len(adata.obsm["covariates"].columns)
        else:
            self.n_cov = 0

        # Encoder/Decoder receive n_cat=None when there is no covariate input at all.
        self.n_cov_total = self.n_cov_embed + self.n_cov
        if self.n_cov_total == 0:
            self.n_cov_total = None

        self.z_encoder = Encoder(
            n_input=n_input,
            n_output=n_latent,
            n_cat=self.n_cov_total,
            n_layers=n_layers,
            n_hidden=n_hidden,
            dropout_rate=dropout_rate,
            var_eps=1e-4,
        )

        self.decoder = Decoder(
            n_input=n_latent,
            n_output=n_input,
            n_cat=self.n_cov_total,
            n_layers=n_layers,
            n_hidden=n_hidden,
            dropout_rate=dropout_rate,
            var_eps=1e-4,
        )

        if initialize_weights:
            self.z_encoder.initialize_weights()
            self.decoder.initialize_weights()

    @staticmethod
    def _embedding_sizes(cov_embed, orders) -> list:
        """Number of rows each ``covariates_embed`` column's embedding table needs.

        The column values are used directly as embedding indices (they are cast
        with ``.int()`` in :meth:`_get_cov`), so the table must hold at least
        ``max(code) + 1`` rows. When a category order is known for the column
        (orders may list categories absent from this data), its length is also
        honoured so translating to such a category stays in range. The result is
        never smaller than the number of distinct values.
        """
        sizes = []
        for column in cov_embed.columns:
            codes = cov_embed[column]
            size = int(codes.nunique())
            try:
                size = max(size, int(codes.max()) + 1)
            except (TypeError, ValueError):
                # Empty / all-missing / non-numeric column: keep the nunique size.
                pass
            order = orders.get(column) if orders else None
            if order is not None:
                size = max(size, len(order))
            sizes.append(size)
        return sizes

    def _get_cov(self, tensors):
        """Assemble the covariate input for the encoder/decoder.

        Concatenates along dim 1 ``tensors["covariates"]`` (when present with at
        least one column) and, if embedded covariates are configured, the
        embedding of each column of ``tensors["covariates_embed"]``. Returns None
        when there is no covariate input.
        """
        cov_list = []

        if "covariates" in tensors and tensors["covariates"].shape[1] > 0:
            cov_list.append(tensors["covariates"])

        if self.embed_cov:
            # The embeddings are always built in __init__ when embed_cov is True, so
            # they are registered sub-modules (moved with the module, seen by the optimizer).
            codes = tensors["covariates_embed"]
            cov_list.extend(
                embedding(codes[:, i].int()) for i, embedding in enumerate(self.cov_embeddings)
            )

        return torch.cat(cov_list, dim=1) if cov_list else None

    def _get_inference_input(self, tensors, **kwargs):
        """Build the :meth:`inference` arguments (expression and covariates) from a minibatch."""
        expr = tensors[REGISTRY_KEYS.X_KEY]
        cov = self._get_cov(tensors=tensors)
        return dict(expr=expr, cov=cov)

    def _get_generative_input(self, tensors, inference_outputs):
        """Build the :meth:`generative` arguments: the latent sample and the minibatch covariates."""
        z = inference_outputs["z"]
        cov = self._get_cov(tensors=tensors)
        return dict(z=z, cov=cov)

    @auto_move_data
    def inference(self, expr, cov):
        """Run the encoder; returns the latent sample ``z`` and its mean ``q_m`` / variance ``q_v``."""
        q_m, q_v, latent = self.z_encoder(expr, cov)
        return {"z": latent, "q_m": q_m, "q_v": q_v}

    @auto_move_data
    def generative(self, z, cov):
        """Run the decoder on the latent sample and covariates; returns ``{"px": ...}``."""
        p = self.decoder(z, cov)
        return {"px": p}

    def loss(
        self,
        tensors,
        inference_outputs,
        generative_outputs,
        kl_weight: float = 0.0005,
    ):
        """Mean over cells of ``0.5 * reconstruction + 0.5 * self.kl_weight * KL``.

        Parameters
        ----------
        tensors
            Minibatch; the observed expression is read from ``REGISTRY_KEYS.X_KEY``.
        inference_outputs
            Output of :meth:`inference` (uses ``q_m`` and ``q_v``).
        generative_outputs
            Output of :meth:`generative` (uses ``px``).
        kl_weight
            Accepted for compatibility with callers that pass a KL weight;
            ignored. The fixed ``self.kl_weight`` is used instead.

        Returns
        -------
        LossOutput with the scalar loss and per-cell reconstruction / KL terms.
        """
        x = tensors[REGISTRY_KEYS.X_KEY]
        qz_m = inference_outputs["q_m"]
        qz_v = inference_outputs["q_v"]
        p = generative_outputs["px"]

        kld = kl(
            Normal(qz_m, torch.sqrt(qz_v)),
            Normal(0, 1),
        ).sum(dim=1)

        # Observed data first, reconstruction second -- matches the (x, px)
        # signature, so the cap in get_reconstruction_loss hits the prediction.
        rl = self.get_reconstruction_loss(x, p)
        loss = (0.5 * rl + 0.5 * (kld * self.kl_weight)).mean()

        return LossOutput(loss=loss, reconstruction_loss=rl, kl_local=kld)

    def get_reconstruction_loss(self, x, px) -> torch.Tensor:
        """Per-cell sum of squared errors between observed ``x`` and reconstruction ``px``.

        Tuples are unwrapped to their first element and non-tensors are converted
        with ``torch.as_tensor``. With ``use_exponentiation`` both inputs are
        mapped through ``expm1`` and the reconstruction is capped at
        ``_EXP_CAP`` (which also replaces +inf) before the error is computed.
        Assumes 2-D (cells x genes) inputs -- the sum runs over dim 1.
        """
        x = x[0] if isinstance(x, tuple) else x
        px = px[0] if isinstance(px, tuple) else px
        x = torch.as_tensor(x)
        px = torch.as_tensor(px)
        if self.use_exponentiation:
            x = torch.expm1(x)
            # clamp maps +inf to the cap as well, so no separate isinf fix-up is needed.
            px = torch.clamp(torch.expm1(px), max=self._EXP_CAP)
        return ((x - px) ** 2).sum(dim=1)

    @torch.no_grad()
    def compute_validation_metrics(self):
        """R^2 between translated predictions and ground truth for each validation target.

        Encodes every cell of the prepared ``csdl_adata_to_predict`` loader,
        decodes the latents with each target's translated covariates, averages
        the predictions over cells and scores them against the target's
        per-gene ground truth with ``r2_score``.

        Returns
        -------
        dict mapping ``"<target>_r2"`` to the score plus ``"mean_r2"`` (mean over
        targets) whenever at least one target exists; ``{}`` when no validation
        data was provided.
        """
        if not self.val_adatas:
            return {}

        # Use the device the module actually lives on, not merely whether CUDA exists.
        device = next(self.parameters()).device

        latents = []
        for tensors in self.val_adatas["csdl_adata_to_predict"]:
            tensors = {k: v.to(device) for k, v in tensors.items()}
            outputs = self.inference(**self._get_inference_input(tensors))
            latents.append(outputs["z"])
        latent = torch.cat(latents)

        validation_loss = {}
        for name, data in self.val_adatas.items():
            if name == "csdl_adata_to_predict":
                continue
            generative_tensors = {k: v.to(device) for k, v in data["generative_tensors"].items()}
            cov = self._get_cov(tensors=generative_tensors)
            px = self.generative(z=latent, cov=cov)["px"].cpu().detach().numpy().mean(axis=0)
            validation_loss[f"{name}_r2"] = r2_score(data["ground_truth"], px)

        # Always report the mean so monitors on "mean_r2" work with a single target too.
        if validation_loss:
            validation_loss["mean_r2"] = sum(validation_loss.values()) / len(validation_loss)

        return validation_loss
            

In [362]:
from scvi.data.fields import CategoricalObsField, LayerField, ObsmField
from scvi.data import AnnDataManager
from anndata import AnnData
from scvi.utils import setup_anndata_dsp
from scvi import REGISTRY_KEYS
import numpy as np
from typing import Optional, List, Dict, Sequence, Tuple
import torch
from scvi.module.base import auto_move_data

#from transVAE.module._base_module import VAEC
#from transVAE.module._utils import prepare_metadata
from scvi.model.base import VAEMixin, BaseModelClass#, UnsupervisedTrainingMixin
#from transVAE.model._training import UnsupervisedTrainingMixin

class transVAE(VAEMixin, UnsupervisedTrainingMixin, BaseModelClass):
    """
    Covariate-conditioned VAE model supporting covariate translation.

    Parameters
    ----------
    adata
        AnnData object registered with :meth:`setup_anndata`.
    n_labels
        Only reported in the model summary string.
    n_hidden
        Number of nodes per hidden layer.
    n_latent
        Dimensionality of the latent space.
    n_layers
        Number of hidden layers used for encoder and decoder NNs.
    dropout_rate
        Dropout rate for neural networks.
    cov_embed_dims
        Embedding dimension for each embedded categorical covariate.
    kl_weight
        Weight of the KL term, forwarded to the module.
    initialize_weights
        Forwarded to the module, which then calls the encoder/decoder
        ``initialize_weights`` hooks.
    use_exponentiation
        Forwarded to the module; controls how its reconstruction loss is computed.
    validation_adatas_dict
        Optional dict with keys ``"adata_to_predict"``, ``"translate_dicts"`` and
        ``"ground_truths"``; converted by :meth:`prepare_validation_data` before
        being handed to the module.
    **model_kwargs
        Extra keyword arguments forwarded to :class:`VAEC`.
    """
    def __init__(
        self,
        adata: AnnData,
        n_labels: int = 0,
        n_hidden: int = 800,
        n_latent: int = 100,
        n_layers: int = 2,
        dropout_rate: float = 0.2,
        cov_embed_dims: int = 10,
        kl_weight: float = 0.005,
        initialize_weights: bool = False,
        use_exponentiation: bool = False,
        validation_adatas_dict: Optional[Dict] = None,
        **model_kwargs,
    ):
        super().__init__(adata)
        n_input = self.summary_stats.n_vars
        self.use_exponentiation = use_exponentiation

        # Keep the user's raw dict bound to `validation_adatas_dict`: `_get_init_params`
        # below records locals by parameter name, and re-creating the model from the
        # saved params must receive the raw dict, not the prepared loaders.
        prepared_validation = None
        if validation_adatas_dict is not None:
            prepared_validation = self.prepare_validation_data(validation_adatas_dict)

        # cVAE
        self.module = VAEC(
            adata=adata,
            n_input=n_input,
            n_hidden=n_hidden,
            n_latent=n_latent,
            n_layers=n_layers,
            dropout_rate=dropout_rate,
            cov_embed_dims=cov_embed_dims,
            kl_weight=kl_weight,
            initialize_weights=initialize_weights,
            use_exponentiation=use_exponentiation,
            validation_adatas_dict=prepared_validation,
            original_adata_covariates=self.adata.obsm,
            **model_kwargs,
        )
        self._model_summary_string = (
            "Model with the following params: n_hidden: {}, n_latent: {}, n_layers: {}, dropout_rate: {}, n_labels {}"
        ).format(
            n_hidden,
            n_latent,
            n_layers,
            dropout_rate,
            n_labels
        )

        self.init_params_ = self._get_init_params(locals())

    # ------------------------------------------------------------------ helpers

    def _covariate_keys(self):
        """(categorical keys, embedded categorical keys) recorded at setup time."""
        covariates_dict = self.adata.uns["covariates_dict"]
        return covariates_dict["categorical"], covariates_dict["categorical_embed"]

    def _encode_obs(self, obs):
        """Encode `obs` covariates with the training category orders via `prepare_metadata`."""
        cat_keys, embed_keys = self._covariate_keys()
        return prepare_metadata(
            meta_data=obs,
            cov_cat_keys=cat_keys,
            cov_cat_embed_keys=embed_keys,
            orders=self.adata.uns["covariate_orders"],
        )

    @staticmethod
    def _store_covariates(adata, covariates, covariates_embed, orders_dict, cov_dict, cat_keys, embed_keys):
        """Write encoded covariates into ``adata.uns`` / ``adata.obsm`` in place."""
        adata.uns['covariate_orders'] = orders_dict
        adata.uns['covariates_dict'] = cov_dict
        for keys, obsm_key, values in (
            (cat_keys, 'covariates', covariates),
            (embed_keys, 'covariates_embed', covariates_embed),
        ):
            if keys is None:
                continue
            # Drop any previous entry so the new frame replaces it wholesale.
            if obsm_key in adata.obsm:
                del adata.obsm[obsm_key]
            adata.obsm[obsm_key] = values

    @staticmethod
    def _covariate_tensors(covariates, covariates_embed, device):
        """Float tensors of the encoded covariates in the layout the module's `_get_cov` reads."""
        return {
            "covariates": torch.Tensor(covariates.values).to(device),
            "covariates_embed": torch.Tensor(covariates_embed.values).to(device),
        }

    @staticmethod
    def _apply_translation(obs, translate_dict):
        """Overwrite the `obs` columns named in `translate_dict` with their target values.

        All keys are checked before anything is written, so a bad key leaves
        `obs` untouched. Mutates and returns `obs`.
        """
        if any(column not in obs.columns for column in translate_dict):
            raise KeyError("Dict key from translate_dict not found in adata.obs.")
        for column, value in translate_dict.items():
            obs[column] = value
        return obs

    def _module_device(self):
        """Device holding the module's parameters."""
        return next(self.module.parameters()).device

    # --------------------------------------------------------------- public API

    def prepare_validation_data(self,
                                validation_adatas_dict: Dict,
                                batch_size = 4048) -> Dict:
        """
        Build the validation structure consumed by the module's validation metrics.

        Parameters
        ----------
        validation_adatas_dict
            Dict with keys:

            - ``"adata_to_predict"``: AnnData whose cells are encoded to the latent space;
            - ``"translate_dicts"``: ``{target_name: {obs_column: new_value, ...}}``;
            - ``"ground_truths"``: ``{target_name: AnnData}``; the per-gene mean of
              its ``X`` is the prediction target.
        batch_size
            Minibatch size of the data loader over ``adata_to_predict``.

        Returns
        -------
        dict with ``"csdl_adata_to_predict"`` mapped to the data loader and, per
        target name, ``{"ground_truth": per-gene mean array, "generative_tensors":
        translated covariate tensors}``.

        Raises
        ------
        KeyError
            If one of the three required keys is missing, or a ``translate_dict``
            column is not in ``adata_to_predict.obs``.

        Notes
        -----
        ``adata_to_predict.uns`` / ``.obsm`` receive its covariate encoding (in
        place). Its ``obs`` is left unchanged: each translation is applied to a
        fresh copy, so overrides of one target never leak into another.
        """
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        adata_to_predict = validation_adatas_dict["adata_to_predict"]
        translate_dicts = validation_adatas_dict["translate_dicts"]
        ground_truths = validation_adatas_dict["ground_truths"]

        cat_keys, embed_keys = self._covariate_keys()

        # Encode the untranslated covariates and register the AnnData for loading.
        covariates, covariates_embed, orders_dict, cov_dict = self._encode_obs(adata_to_predict.obs)
        self._store_covariates(
            adata_to_predict, covariates, covariates_embed, orders_dict, cov_dict, cat_keys, embed_keys
        )
        adata_to_predict = self._validate_anndata(adata_to_predict)
        scdl = self._make_data_loader(
            adata=adata_to_predict, indices=None, batch_size=batch_size
        )

        validation_dict = {"csdl_adata_to_predict": scdl}

        for ground_truth_name, translate_dict in translate_dicts.items():
            # latent variable switching on a private copy of obs
            translated_obs = self._apply_translation(adata_to_predict.obs.copy(), translate_dict)
            covariates, covariates_embed, _, _ = self._encode_obs(translated_obs)
            tensors = self._covariate_tensors(covariates, covariates_embed, device)

            # Works for both sparse and dense X.
            gt_x = ground_truths[ground_truth_name].X
            gt_dense = gt_x.toarray() if hasattr(gt_x, "toarray") else np.asarray(gt_x)

            validation_dict[ground_truth_name] = {
                "ground_truth": gt_dense.mean(axis=0),
                "generative_tensors": tensors,
            }

        return validation_dict

    @auto_move_data
    @torch.inference_mode()
    def get_latent_representation(
        self,
        adata: Optional[AnnData] = None,
        indices: Optional[Sequence[int]] = None,
        mc_samples: int = 5000,
        batch_size: Optional[int] = None,
    ) -> np.ndarray:
        """Return the latent representation for each cell.

        This is typically denoted as :math:`z_n`. The encoder's sampled ``z`` is
        returned (not the posterior mean).

        Parameters
        ----------
        adata
            AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
            AnnData object used to initialize the model. A provided AnnData has its
            ``uns``/``obsm`` covariate encoding (re)written in place.
        indices
            Indices of cells in adata to use. If `None`, all cells are used.
        mc_samples
            Currently unused.
        batch_size
            Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`.

        Returns
        -------
        Low-dimensional representation for each cell.
        """
        self._check_if_trained(warn=False)

        if adata is None:
            # The registered AnnData already carries the encoding from setup_anndata.
            adata = self.adata
        else:
            cat_keys, embed_keys = self._covariate_keys()
            covariates, covariates_embed, orders_dict, cov_dict = self._encode_obs(adata.obs)
            self._store_covariates(
                adata, covariates, covariates_embed, orders_dict, cov_dict, cat_keys, embed_keys
            )

        adata = self._validate_anndata(adata)
        scdl = self._make_data_loader(
            adata=adata, indices=indices, batch_size=batch_size
        )

        device = self._module_device()
        latent = []
        for tensors in scdl:
            tensors = {k: v.to(device) for k, v in tensors.items()}
            outputs = self.module.inference(**self.module._get_inference_input(tensors))
            latent.append(outputs["z"].cpu())

        return torch.cat(latent).numpy()

    @auto_move_data
    def translate(
        self,
        adata: AnnData,
        translate_dict: Dict,
        copy: bool = False,
    ) -> AnnData:
        """
        Translate the given adata based on the provided translation dictionary.

        Cells are encoded to latent representations, the covariate columns named
        in ``translate_dict`` are switched to the given values, and the decoder
        predicts the cells under the switched covariates.

        Parameters
        ----------
        adata
            The input AnnData object.
        translate_dict
            Mapping ``{obs_column: new_value}`` applied before decoding.
        copy
            If True, work on a copy of ``adata``; otherwise ``adata.obs``,
            ``adata.uns`` and ``adata.obsm`` are modified in place.

        Returns
        -------
        AnnData with the predicted expression in ``X`` and the translated obs.

        Raises
        ------
        ValueError
            If ``adata.var_names`` are not unique.
        KeyError
            If a ``translate_dict`` key is not a column of ``adata.obs`` (checked
            before any column is modified).

        Examples
        --------
        >> predicted = model.translate(adata_train, translate_dict= {"dataset": "chem"})
        """
        # Make sure var names are unique
        if adata.shape[1] != len(set(adata.var_names)):
            raise ValueError('Adata var_names are not unique')

        if copy:
            adata = adata.copy()

        ### Inference -----------
        latent = torch.Tensor(self.get_latent_representation(adata))

        ### Generative ----------
        self._apply_translation(adata.obs, translate_dict)
        covariates, covariates_embed, orders_dict, cov_dict = self._encode_obs(adata.obs)

        # Build the covariate tensors on the module's own device.
        tensors = self._covariate_tensors(covariates, covariates_embed, self._module_device())
        with torch.no_grad():
            cov = self.module._get_cov(tensors=tensors)
            predicted_cells = self.module.generative(z=latent, cov=cov)["px"].cpu().detach().numpy()

        # make output pretty
        cat_keys, embed_keys = self._covariate_keys()
        self._store_covariates(
            adata, covariates, covariates_embed, orders_dict, cov_dict, cat_keys, embed_keys
        )

        return AnnData(
            X=predicted_cells,
            obs=adata.obs.copy(),
            var=adata.var.copy(),
            uns=adata.uns.copy(),
            obsm=adata.obsm.copy(),
        )

    @classmethod
    @setup_anndata_dsp.dedent
    def setup_anndata(
        cls,
        adata: AnnData,
        categorical_covariate_keys: Optional[List[str]] = None,
        categorical_covariate_embed_keys: Optional[List[str]] = None,
        covariate_orders: Optional[Dict] = None,
        copy: bool = True,
        layer: Optional[str] = None,
        validation_ind_key: Optional[str] = None,
        **kwargs,
    ):
        """
        Sets up the AnnData object for subsequent analysis.

        Parameters
        ----------
        adata : AnnData
            The annotated data matrix.
        categorical_covariate_keys : Optional[List[str]], default=None
            List of keys for categorical covariates.
        categorical_covariate_embed_keys : Optional[List[str]], default=None
            List of keys for categorical covariates to be embedded.
        covariate_orders : Optional[Dict], default=None
            Dictionary specifying the order of categories per covariate.
        copy : bool, default=True
            Whether to set up (and return) a copy of the original adata object.
        layer : Optional[str], default=None
            Specifies which layer of the adata object to consider.
        validation_ind_key : Optional[str], default=None
            Currently unused.
        **kwargs : Additional keyword arguments forwarded to ``register_fields``.

        Returns
        -------
        AnnData
            The modified or copied AnnData object.

        Raises
        ------
        ValueError
            If var_names in adata are not unique.
        """
        # Must run before any other local exists so only the caller's arguments
        # end up in the registry's setup args.
        setup_method_args = cls._get_setup_method_args(**locals())

        # Make sure var names are unique
        if adata.shape[1] != len(set(adata.var_names)):
            raise ValueError('Adata var_names are not unique')

        if copy:
            adata = adata.copy()

        if covariate_orders is None:
            covariate_orders = {}

        covariates, covariates_embed, orders_dict, cov_dict = prepare_metadata(
            meta_data=adata.obs,
            cov_cat_keys=categorical_covariate_keys,
            cov_cat_embed_keys=categorical_covariate_embed_keys,
            orders=covariate_orders
        )
        cls._store_covariates(
            adata,
            covariates,
            covariates_embed,
            orders_dict,
            cov_dict,
            categorical_covariate_keys,
            categorical_covariate_embed_keys,
        )

        anndata_fields = [
            LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=False)
        ]
        if categorical_covariate_keys is not None:
            anndata_fields.append(ObsmField('covariates', 'covariates'))
        if categorical_covariate_embed_keys is not None:
            anndata_fields.append(ObsmField('covariates_embed', 'covariates_embed'))

        adata_manager = AnnDataManager(
            fields=anndata_fields, setup_method_args=setup_method_args
        )
        adata_manager.register_fields(adata, **kwargs)
        cls.register_manager(adata_manager)

        return adata


In [363]:
#import scanpy as sc
#adata = sc.read_h5ad("/d/hpc/projects/FRI/DL/mo6643/MSC/data/data_update_slack/data_splits/data_splits_train_merge/train_data/train_adata_baseline_3000hvg.h5ad")

In [364]:
# Register the training data: "dataset" and "cell_type" are embedded covariates,
# "organism" is a plain categorical covariate; setup works on and returns a copy.
adata_train = transVAE.setup_anndata(
    adata,
    categorical_covariate_embed_keys=["dataset", "cell_type"],
    categorical_covariate_keys=["organism"],
    copy=True,
)

[34mINFO    [0m Using column names from columns of adata.obsm[1m[[0m[32m'covariates'[0m[1m][0m                                               
[34mINFO    [0m Using column names from columns of adata.obsm[1m[[0m[32m'covariates_embed'[0m[1m][0m                                         


In [365]:
#adata_to_predict = sc.read_h5ad(f"/d/hpc/projects/FRI/DL/mo6643/MSC/data/data_update_slack/data_splits/data_splits_train_merge/data_to_predict/wang_to_predict_3000hvg.h5ad")
#dbdb_ground_truth = sc.read_h5ad(f"/d/hpc/projects/FRI/DL/mo6643/MSC/data/data_update_slack/data_splits/data_splits_train_merge/ground_truth/dbdb_ground_truth_3000hvg.h5ad")
#mSTZ_ground_truth = sc.read_h5ad(f"/d/hpc/projects/FRI/DL/mo6643/MSC/data/data_update_slack/data_splits/data_splits_train_merge/ground_truth/mSTZ_ground_truth_3000hvg.h5ad")

The validation dictionary passed to `transVAE` must include the keys `'adata_to_predict'`, `'translate_dicts'`, and `'ground_truths'`, which `transVAE.prepare_validation_data` reads.

In [366]:
# Validation targets: the cells of `adata_to_predict` are translated with each entry
# of `translate_dicts` and compared against the matching `ground_truths` AnnData.
validation_adatas_dict = {
    "adata_to_predict": adata_to_predict,
    "ground_truths": {
        "dbdb": dbdb_ground_truth,
        "mSTZ": mSTZ_ground_truth,
    },
    "translate_dicts": {
        "dbdb": {"dataset": "db/db", "organism": "Mus musculus"},
        "mSTZ": {"dataset": "mSTZ", "organism": "Mus musculus"},
    },
}

In [391]:
# make the model
model = transVAE(
    adata_train,
    n_hidden=1000,
    n_latent=256,
    n_layers=6,
    dropout_rate=0.3,
    cov_embed_dims=10,
    kl_weight=0.0005,
    use_exponentiation=False,
    validation_adatas_dict=validation_adatas_dict,
)

[34mINFO    [0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             


In [392]:
# train the model
# Early stopping maximises the mean R^2 of the validation translations; the learning
# rate is scaled by 0.5 when the validation reconstruction loss plateaus.
model.train(
    batch_size=4096,
    max_epochs=1000,
    train_size=1,
    enable_progress_bar=False,
    early_stopping=True,
    early_stopping_monitor="mean_r2_validation_eval",
    early_stopping_mode="max",
    early_stopping_min_delta=0.01,
    early_stopping_patience=70,
    plan_kwargs={
        "lr": 1e-5,
        "weight_decay": 1e-6,
        "reduce_lr_on_plateau": True,
        "lr_factor": 0.5,
        "lr_patience": 50,
        "lr_scheduler_metric": "reconstruction_loss_validation",
    },
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
SLURM auto-requeueing enabled. Setting signal handlers.


Epoch 2/100:   1%|          | 1/100 [00:07<12:37,  7.65s/it, v_num=1, train_loss_step=525, train_loss_epoch=568]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


In [370]:
# Names of the metrics logged during training (the keys of `model.history`).
model.history.keys()

dict_keys(['kl_weight', 'train_loss_step', 'elbo_validation', 'reconstruction_loss_validation', 'kl_local_validation', 'kl_global_validation', 'dbdb_r2_validation_eval', 'mSTZ_r2_validation_eval', 'mean_r2_validation_eval', 'train_loss_epoch'])

In [390]:
# Most recent R^2 logged for the db/db translation target (last row of its history frame).
model.history["dbdb_r2_validation_eval"].iloc[-1].item()

0.25259862434618824