In [378]:
import numpy as np
from anndata import AnnData
import torch
from torch import nn
from torch.distributions import Normal, NegativeBinomial
from torch.distributions import kl_divergence as kl

import scvi
from scvi.data import AnnDataManager
from scvi import REGISTRY_KEYS
from scvi.module.base import (
    BaseModuleClass,
    LossOutput,
    auto_move_data,
)
from scvi.model.base import BaseModelClass, UnsupervisedTrainingMixin, VAEMixin
from scvi.module import VAE
from scvi.data.fields import (
    CategoricalJointObsField,
    CategoricalObsField,
    LayerField,
    NumericalJointObsField,
    NumericalObsField,
)

from pytorch_lightning.loggers import WandbLogger

from sklearn.cluster import KMeans
from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score

In [379]:
wandb_logger = WandbLogger(
    project='PGM-single-cell', 
)

# Loading data

In [380]:
# adata = scvi.data.mouse_ob_dataset()
adata = scvi.data.purified_pbmc_dataset()

adata = adata[np.random.choice(adata.shape[0], size=adata.shape[0] // 50, replace=False)].copy()
adata

# adata = scvi.data.synthetic_iid()

print(adata.X.shape)

[34mINFO    [0m File data/PurifiedPBMCDataset.h5ad already downloaded                                                     
(2117, 21932)


# Modeling

## Simple VAE

### Define model

In [381]:
class SimpleDecoder(nn.Module):
    def __init__(self, n_latent: int, n_output: int):
        super().__init__()
        self.fc1 = nn.Linear(n_latent, 128)  
        self.fc2 = nn.Linear(128, 128)      
        self.output_mean = nn.Linear(128, n_output)  
        self.output_disp = nn.Linear(128, n_output)  

    def forward(self, z: torch.Tensor):
        """
        Reconstruit les données depuis l'espace latent.
        - z : Tensor de taille (batch_size, n_latent)
        Retourne :
        - mean : Moyenne de la distribution reconstruite
        - disp : Dispersion (ou variance) de la distribution reconstruite
        """
        h = torch.relu(self.fc1(z))
        h = torch.relu(self.fc2(h))
        mean = torch.nn.functional.softplus(self.output_mean(h))  
        disp = torch.nn.functional.softplus(self.output_disp(h)) 
        return mean, disp

In [382]:
class SimpleEncoder(nn.Module):
    def __init__(self, n_input: int, n_latent: int, n_hidden: int = 128):
        super().__init__()
        self.fc1 = nn.Linear(n_input, n_hidden)  
        self.fc2 = nn.Linear(n_hidden, n_hidden) 
        self.mean_layer = nn.Linear(n_hidden, n_latent) 
        self.var_layer = nn.Linear(n_hidden, n_latent)

    def forward(self, x: torch.Tensor):
        """
        Encode les données d'entrée dans l'espace latent.
        - x : Tensor des données (batch_size, n_input)
        Retourne :
        - mean : Moyenne latente
        - log_var : Log-variance latente
        """
        h = torch.relu(self.fc1(x))
        h = torch.relu(self.fc2(h))
        mean = self.mean_layer(h)
        log_var = self.var_layer(h)
        return mean, log_var

In [383]:
class SimpleVAEModule(BaseModuleClass):
    """Simple Variational auto-encoder model.

    Here we implement a basic version of scVI's underlying VAE [Lopez18]_.
    This implementation is for instructional purposes only.

    Parameters
    ----------
    n_input
        Number of input genes.
    n_latent
        Dimensionality of the latent space.
    """

    def __init__(
        self,
        n_input: int,
        n_latent: int = 10,
    ):
        super().__init__()
        self.encoder = SimpleEncoder(n_input, n_latent)
        self.decoder = SimpleDecoder(n_latent, n_input)


    def _get_inference_input(self, tensors: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """Parse the dictionary to get appropriate args"""
        # let us fetch the raw counts, and add them to the dictionary
        return {"x": tensors[REGISTRY_KEYS.X_KEY]}

    @auto_move_data
    def inference(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
        """
        High level inference method.

        Runs the inference (encoder) model.
        """
        x_ = torch.log1p(x)
        qz_m, qz_v_log = self.encoder(x_)
        qz_v = qz_v_log.exp()
        z = Normal(qz_m, torch.sqrt(qz_v)).rsample()

        return {"qzm": qz_m, "qzv": qz_v, "z": z}

    def _get_generative_input(
        self, tensors: dict[str, torch.Tensor], inference_outputs: dict[str, torch.Tensor]
    ) -> dict[str, torch.Tensor]:
        return {
            "z": inference_outputs["z"],
            # "library": torch.sum(tensors[REGISTRY_KEYS.X_KEY], dim=1, keepdim=True),
        }

    @auto_move_data
    def generative(self, z: torch.Tensor) -> dict[str, torch.Tensor]:
        """Runs the generative model."""
        nb_mean, nb_disp = self.decoder(z)
        return {
            "nb_mean":nb_mean,
            "nb_disp":nb_disp,
        }

    def loss(
        self,
        tensors: dict[str, torch.Tensor],
        inference_outputs: dict[str, torch.Tensor],
        generative_outputs: dict[str, torch.Tensor],
    ) -> LossOutput:
        x = tensors[REGISTRY_KEYS.X_KEY]
        nb_mean = generative_outputs["nb_mean"]
        nb_disp = generative_outputs["nb_disp"]
        qz_m = inference_outputs["qzm"]
        qz_v = inference_outputs["qzv"]

        log_likelihood = NegativeBinomial(total_count=nb_disp, logits=torch.log(nb_mean+1e-4)).log_prob(x).sum(dim=-1)

        prior_dist = Normal(torch.zeros_like(qz_m), torch.ones_like(qz_v))
        var_post_dist = Normal(qz_m, torch.sqrt(qz_v))
        kl_divergence = kl(var_post_dist, prior_dist).sum(dim=1)
        
        elbo = log_likelihood - kl_divergence
        loss = torch.mean(-elbo)
        return LossOutput(
            loss=loss,
            reconstruction_loss=-log_likelihood,
            kl_local=kl_divergence,
            kl_global=0.0,
        )

In [384]:
class VAEModel(VAEMixin, UnsupervisedTrainingMixin, BaseModelClass):
    """single-cell Variational Inference [Lopez18]_."""

    def __init__(
        self,
        adata: AnnData,
        module: BaseModuleClass,
        n_latent: int = 10,
        **model_kwargs,
    ):
        super().__init__(adata)

        self.module = SimpleVAEModule(
            n_input=self.summary_stats["n_vars"],
            # n_batch=self.summary_stats["n_batch"],
            n_latent=n_latent,
            **model_kwargs,
        )
        self._model_summary_string = (
            f"SCVI Model with the following params: \nn_latent: {n_latent}"
        )
        self.init_params_ = self._get_init_params(locals())

    @classmethod
    def setup_anndata(
        cls,
        adata: AnnData,
        batch_key: str | None = None,
        layer: str | None = None,
        **kwargs,
    ) -> AnnData | None:
        setup_method_args = cls._get_setup_method_args(**locals())
        anndata_fields = [
            LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True),
            CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key),
            # Dummy fields required for VAE class.
            CategoricalObsField(REGISTRY_KEYS.LABELS_KEY, None),
            NumericalObsField(REGISTRY_KEYS.SIZE_FACTOR_KEY, None, required=False),
            CategoricalJointObsField(REGISTRY_KEYS.CAT_COVS_KEY, None),
            NumericalJointObsField(REGISTRY_KEYS.CONT_COVS_KEY, None),
        ]
        adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args)
        adata_manager.register_fields(adata, **kwargs)
        cls.register_manager(adata_manager)

### Run model

In [385]:
VAEModel.setup_anndata(adata, batch_key="batch")
simple_vae = VAEModel(adata, module=SimpleVAEModule, n_latent=10)
simple_vae



In [386]:
# logger = wandb_logger
logger = None

simple_vae.train(max_epochs=40, logger=logger)

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/Users/marc/mambaforge/envs/PGM/lib/python3.12/site-packages/lightning/pytorch/trainer/setup.py:177: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.
/Users/marc/mambaforge/envs/PGM/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Training:   0%|          | 0/40 [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=40` reached.


## GM-VAE

### Define model

In [387]:
class GMVAEModule(BaseModuleClass):
    """GM Variational auto-encoder model.

    Here we implement a basic version of scVI's underlying VAE [Lopez18]_.
    This implementation is for instructional purposes only.

    Parameters
    ----------
    n_input
        Number of input genes.
    n_latent
        Dimensionality of the latent space.
    """

    def __init__(
        self,
        n_input: int,
        n_latent: int = 10,
    ):
        super().__init__()
        self.encoder = SimpleEncoder(n_input, n_latent)
        self.decoder = SimpleDecoder(n_latent, n_input)


    def _get_inference_input(self, tensors: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """Parse the dictionary to get appropriate args"""
        # let us fetch the raw counts, and add them to the dictionary
        return {"x": tensors[REGISTRY_KEYS.X_KEY]}

    @auto_move_data
    def inference(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
        """
        High level inference method.

        Runs the inference (encoder) model.
        """
        x_ = torch.log1p(x)
        qz_m, qz_v_log = self.encoder(x_)
        qz_v = qz_v_log.exp()
        z = Normal(qz_m, torch.sqrt(qz_v)).rsample()

        return {"qzm": qz_m, "qzv": qz_v, "z": z}

    def _get_generative_input(
        self, tensors: dict[str, torch.Tensor], inference_outputs: dict[str, torch.Tensor]
    ) -> dict[str, torch.Tensor]:
        return {
            "z": inference_outputs["z"],
            # "library": torch.sum(tensors[REGISTRY_KEYS.X_KEY], dim=1, keepdim=True),
        }

    @auto_move_data
    def generative(self, z: torch.Tensor) -> dict[str, torch.Tensor]:
        """Runs the generative model."""
        nb_mean, nb_disp = self.decoder(z)
        return {
            "nb_mean":nb_mean,
            "nb_disp":nb_disp,
        }

    def loss(
        self,
        tensors: dict[str, torch.Tensor],
        inference_outputs: dict[str, torch.Tensor],
        generative_outputs: dict[str, torch.Tensor],
    ) -> LossOutput:
        x = tensors[REGISTRY_KEYS.X_KEY]
        nb_mean = generative_outputs["nb_mean"]
        nb_disp = generative_outputs["nb_disp"]
        qz_m = inference_outputs["qzm"]
        qz_v = inference_outputs["qzv"]

        log_likelihood = NegativeBinomial(total_count=nb_disp, logits=torch.log(nb_mean+1e-4)).log_prob(x).sum(dim=-1)

        prior_dist = Normal(torch.zeros_like(qz_m), torch.ones_like(qz_v))
        var_post_dist = Normal(qz_m, torch.sqrt(qz_v))
        kl_divergence = kl(var_post_dist, prior_dist).sum(dim=1)
        
        elbo = log_likelihood - kl_divergence
        loss = torch.mean(-elbo)
        return LossOutput(
            loss=loss,
            reconstruction_loss=-log_likelihood,
            kl_local=kl_divergence,
            kl_global=0.0,
        )

In [388]:
# VAEModel.setup_anndata(adata, batch_key="batch")

# gm_vae = VAEModel(adata, module=GMVAEModule, n_latent=10)
# gm_vae

### Run model

In [389]:
# logger = wandb_logger
# logger = None

# gm_vae.train(max_epochs=40, logger=logger)

# VAE from SCVI

In [390]:
scvi.model.SCVI.setup_anndata(adata, batch_key="batch")

# Initialiser le modèle scVI
vae = scvi.model.SCVI(adata)

# Entraîner le modèle
vae.train(max_epochs=40)

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/Users/marc/mambaforge/envs/PGM/lib/python3.12/site-packages/lightning/pytorch/trainer/setup.py:177: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.
/Users/marc/mambaforge/envs/PGM/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Training:   0%|          | 0/40 [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=40` reached.


# Evaluation

In [393]:
# --- Affichage des scores ---
print("=== Scores de Clustering ===\n")

# Récupération des vrais labels
true_labels = adata.obs['labels']

# --- Clustering et évaluation pour le VAE scVI  ---
latent_rep_vae = vae.get_latent_representation(adata)
kmeans_vae = KMeans(n_clusters=len(set(true_labels)), n_init=200, random_state=42)
predicted_labels_vae = kmeans_vae.fit_predict(latent_rep_vae)
ari_vae = adjusted_rand_score(true_labels, predicted_labels_vae)
nmi_vae = normalized_mutual_info_score(true_labels, predicted_labels_vae)
print("VAE de scvi tools :")
print(f"  Adjusted Rand Index (ARI): {ari_vae:.4f}")
print(f"  Normalized Mutual Information (NMI): {nmi_vae:.4f}\n")

# --- Clustering et évaluation pour le VAE personnalisé  ---
latent_rep_svae = simple_vae.get_latent_representation(adata)
kmeans_svae = KMeans(n_clusters=len(set(true_labels)), n_init=200, random_state=42)
predicted_labels_svae = kmeans_svae.fit_predict(latent_rep_svae)
ari_svae = adjusted_rand_score(true_labels, predicted_labels_svae)
nmi_svae = normalized_mutual_info_score(true_labels, predicted_labels_svae)
print("Simple VAE :")
print(f"  Adjusted Rand Index (ARI): {ari_svae:.4f}")
print(f"  Normalized Mutual Information (NMI): {nmi_svae:.4f}\n")


# # --- Clustering et évaluation pour le VAE personnalisé avec GM ---
# latent_rep_gmvae = gm_vae.get_latent_representation(adata)
# kmeans_gmvae = KMeans(n_clusters=len(set(true_labels)), n_init=200, random_state=42)
# predicted_labels_gmvae = kmeans_gmvae.fit_predict(latent_rep_gmvae)
# ari_gmvae = adjusted_rand_score(true_labels, predicted_labels_gmvae)
# nmi_gmvae = normalized_mutual_info_score(true_labels, predicted_labels_gmvae)
# print("GM VAE :")
# print(f"  Adjusted Rand Index (ARI): {ari_gmvae:.4f}")
# print(f"  Normalized Mutual Information (NMI): {nmi_gmvae:.4f}\n")


=== Scores de Clustering ===

VAE de scvi tools :
  Adjusted Rand Index (ARI): 0.3263
  Normalized Mutual Information (NMI): 0.5337

Simple VAE :
  Adjusted Rand Index (ARI): 0.0609
  Normalized Mutual Information (NMI): 0.1624

