In [1]:
import numpy as np
import os
from anndata import AnnData
import torch
from torch import nn
from torch import Tensor
from torch.distributions import Normal, NegativeBinomial, Distribution
from torch.distributions import kl_divergence as kl

import scvi
from scvi.model import SCVI
from scvi.data import AnnDataManager
from scvi import REGISTRY_KEYS
from scvi.module.base import (
    BaseModuleClass,
    LossOutput,
    auto_move_data,
)
from scvi.model.base import BaseModelClass, UnsupervisedTrainingMixin, VAEMixin
from scvi.module import VAE
from scvi.data.fields import (
    CategoricalJointObsField,
    CategoricalObsField,
    LayerField,
    NumericalJointObsField,
    NumericalObsField,
)

from pytorch_lightning.loggers import WandbLogger, TensorBoardLogger
from collections.abc import Iterator, Sequence
import numpy.typing as npt
from sklearn.cluster import KMeans
from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score



# Loading data

In [2]:
# adata = scvi.data.mouse_ob_dataset()
# adata = scvi.data.purified_pbmc_dataset()

# adata = adata[np.random.choice(adata.shape[0], size=adata.shape[0]//50, replace=False)].copy()
# n_genes_to_keep = adata.shape[1] // 4  # Conserver 1/4 des gènes
# genes_indices = np.random.choice(adata.shape[1], size=n_genes_to_keep, replace=False)
# adata = adata[:, genes_indices].copy()

# adata = scvi.data.synthetic_iid()

adata = scvi.data.cortex()

print(adata.X.shape)
print(adata)
print(adata.obs["labels"].unique())
print(adata.obs["precise_labels"].unique())
print(adata.obs["cell_type"].unique())

[34mINFO    [0m File [35m/Users/marc/Desktop/MVA/cours/[0m[95mIntroduction[0m to graphical                                              
         models/projet/PGM-single-cell/data/expression.bin already downloaded                                      
[34mINFO    [0m Loading Cortex data from [35m/Users/marc/Desktop/MVA/cours/[0m[95mIntroduction[0m to graphical                          
         models/projet/PGM-single-cell/data/expression.bin                                                         
[34mINFO    [0m Finished loading Cortex data                                                                              
(3005, 19972)
AnnData object with n_obs × n_vars = 3005 × 19972
    obs: 'labels', 'precise_labels', 'cell_type'
[2 6 5 4 3 1 0]
['1' '2' '3' '4' '5' '6' '7' '8' '9']
['interneurons' 'pyramidal SS' 'pyramidal CA1' 'oligodendrocytes'
 'microglia' 'endothelial-mural' 'astrocytes_ependymal']




In [3]:
n_clusters = len(adata.obs["precise_labels"].unique())
n_clusters

9

# Modeling

## Simple VAE

### Define model

In [4]:
class SimpleDecoder(nn.Module):
    def __init__(self, n_latent: int, n_output: int, n_hidden: int = 128):
        super().__init__()
        self.fc1 = nn.Linear(n_latent, n_hidden)
        self.bn1 = nn.BatchNorm1d(n_hidden)
        self.fc2 = nn.Linear(n_hidden, n_hidden)
        self.bn2 = nn.BatchNorm1d(n_hidden)
        self.output_mean = nn.Linear(n_hidden, n_output)
        self.output_disp = nn.Linear(n_hidden, n_output)
        self.dropout = nn.Dropout(p=0.1)

    def forward(self, z: torch.Tensor):
        h = self.dropout(torch.relu(self.bn1(self.fc1(z))))
        h = self.dropout(torch.relu(self.bn2(self.fc2(h))))
        mean = torch.nn.functional.softplus(self.output_mean(h))
        disp = torch.nn.functional.softplus(self.output_disp(h))
        return mean, disp

In [5]:
class SimpleEncoder(nn.Module):
    def __init__(self, n_input: int, n_latent: int, n_hidden: int = 128):
        super().__init__()
        self.fc1 = nn.Linear(n_input, n_hidden)
        self.bn1 = nn.BatchNorm1d(n_hidden)
        self.fc2 = nn.Linear(n_hidden, n_hidden)
        self.bn2 = nn.BatchNorm1d(n_hidden)
        self.mean_layer = nn.Linear(n_hidden, n_latent)
        self.var_layer = nn.Linear(n_hidden, n_latent)
        self.dropout = nn.Dropout(p=0.1)

    def forward(self, x: torch.Tensor):
        h = self.dropout(torch.relu(self.fc1(x)))
        h = self.dropout(torch.relu(self.fc2(h)))
        mean = self.mean_layer(h)
        log_var = self.var_layer(h)
        return mean, log_var

In [6]:
class SimpleVAEModule(BaseModuleClass):
    """Simple Variational auto-encoder model.

    Here we implement a basic version of scVI's underlying VAE [Lopez18]_.
    This implementation is for instructional purposes only.

    Parameters
    ----------
    n_input
        Number of input genes.
    n_latent
        Dimensionality of the latent space.
    """

    def __init__(
        self,
        n_input: int,
        n_latent: int = 10,
    ):
        super().__init__()
        self.encoder = SimpleEncoder(n_input, n_latent)
        self.decoder = SimpleDecoder(n_latent, n_input)


    def _get_inference_input(self, tensors: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """Parse the dictionary to get appropriate args"""
        # let us fetch the raw counts, and add them to the dictionary
        return {"x": tensors[REGISTRY_KEYS.X_KEY]}

    @auto_move_data
    def inference(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
        """
        High level inference method.

        Runs the inference (encoder) model.
        """
        x_ = torch.log1p(x)
        qz_m, qz_v_log = self.encoder(x_)
        qz_v = qz_v_log.exp()
        z = Normal(qz_m, torch.sqrt(qz_v)).rsample()

        return {"qzm": qz_m, "qzv": qz_v, "z": z}

    def _get_generative_input(
        self, tensors: dict[str, torch.Tensor], inference_outputs: dict[str, torch.Tensor]
    ) -> dict[str, torch.Tensor]:
        return {
            "z": inference_outputs["z"],
            # "library": torch.sum(tensors[REGISTRY_KEYS.X_KEY], dim=1, keepdim=True),
        }

    @auto_move_data
    def generative(self, z: torch.Tensor) -> dict[str, torch.Tensor]:
        """Runs the generative model."""
        nb_mean, nb_disp = self.decoder(z)
        return {
            "nb_mean":nb_mean,
            "nb_disp":nb_disp,
        }

    def loss(
        self,
        tensors: dict[str, torch.Tensor],
        inference_outputs: dict[str, torch.Tensor],
        generative_outputs: dict[str, torch.Tensor],
    ) -> LossOutput:
        x = tensors[REGISTRY_KEYS.X_KEY]
        nb_mean = generative_outputs["nb_mean"]
        nb_disp = generative_outputs["nb_disp"]
        qz_m = inference_outputs["qzm"]
        qz_v = inference_outputs["qzv"]

        log_likelihood = NegativeBinomial(total_count=nb_disp, logits=torch.log(nb_mean+1e-4)).log_prob(x).sum(dim=-1)

        prior_dist = Normal(torch.zeros_like(qz_m), torch.ones_like(qz_v))
        var_post_dist = Normal(qz_m, torch.sqrt(qz_v))
        kl_divergence = kl(var_post_dist, prior_dist).sum(dim=-1)

        elbo = log_likelihood - kl_divergence
        loss = torch.mean(-elbo)
        return LossOutput(
            loss=loss,
            reconstruction_loss=-log_likelihood,
            # kl_local=kl_divergence,
        )

In [7]:
class SimpleVAEModel(VAEMixin, UnsupervisedTrainingMixin, BaseModelClass):
    """single-cell Variational Inference [Lopez18]_."""

    def __init__(
        self,
        adata: AnnData,
        n_latent: int = 10,
        **model_kwargs,
    ):
        super().__init__(adata)

        self.module = SimpleVAEModule(
            n_input=self.summary_stats["n_vars"],
            # n_batch=self.summary_stats["n_batch"],
            n_latent=n_latent,
            **model_kwargs,
        )
        self._model_summary_string = (
            f"SCVI Model with the following params: \nn_latent: {n_latent}"
        )
        self.init_params_ = self._get_init_params(locals())

    @classmethod
    def setup_anndata(
        cls,
        adata: AnnData,
        batch_key: str | None = None,
        layer: str | None = None,
        **kwargs,
    ) -> AnnData | None:
        setup_method_args = cls._get_setup_method_args(**locals())
        anndata_fields = [
            LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True),
            CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key),
            # Dummy fields required for VAE class.
            CategoricalObsField(REGISTRY_KEYS.LABELS_KEY, None),
            NumericalObsField(REGISTRY_KEYS.SIZE_FACTOR_KEY, None, required=False),
            CategoricalJointObsField(REGISTRY_KEYS.CAT_COVS_KEY, None),
            NumericalJointObsField(REGISTRY_KEYS.CONT_COVS_KEY, None),
        ]
        adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args)
        adata_manager.register_fields(adata, **kwargs)
        cls.register_manager(adata_manager)

### Run model

In [8]:
SimpleVAEModel.setup_anndata(adata)
simple_vae = SimpleVAEModel(adata, n_latent=10)
simple_vae



In [9]:
# logger = WandbLogger(
#     project='PGM-single-cell',
#     name="simple vae"
# )
logger = None
# logger = TensorBoardLogger(save_dir="logs/", name="simple_vae")

simple_vae.train(
    max_epochs=3, 
    logger=logger, 
    accelerator="gpu",
    train_size=0.85,
    validation_size=0.1,
    early_stopping = True,
    )

  accelerator, lightning_devices, device = parse_device_args(
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/Users/marc/mambaforge/envs/PGM/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.
/Users/marc/mambaforge/envs/PGM/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Training:   0%|          | 0/3 [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=3` reached.


In [138]:
save_dir = "saved_model_dir"
os.makedirs(save_dir, exist_ok=True)
model_dir = os.path.join(save_dir, "simple_vae_model")

In [139]:
simple_vae.save(model_dir, save_anndata=True, overwrite=True) 

In [140]:
# simple_vae = SimpleVAEModel.load(model_dir)

## GM-VAE

### Define model

In [141]:
class EncoderXtoY(nn.Module):
    def __init__(self, n_input: int, n_clusters: int, n_hidden: int = 128):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(n_input, n_hidden),
            nn.BatchNorm1d(n_hidden), 
            nn.ReLU(),
            nn.Dropout(p=0.1),     
            nn.Linear(n_hidden, n_hidden),
            nn.BatchNorm1d(n_hidden),
            nn.ReLU(),
            nn.Dropout(p=0.1),
            nn.Linear(n_hidden, n_clusters),
            nn.Softmax(dim=-1),       
        )

    def forward(self, x: torch.Tensor):
        probs_y = self.mlp(x)
        return probs_y

In [142]:
class EncoderXYtoZ(nn.Module):
    def __init__(self, n_input: int, n_clusters: int, n_latent: int, n_hidden: int = 128):
        super().__init__()
        self.proj_y = nn.Sequential(
            nn.Linear(n_clusters, n_hidden),
            nn.BatchNorm1d(n_hidden),
            nn.ReLU(),
            nn.Dropout(p=0.1),
            nn.Linear(n_hidden, n_hidden),
            nn.BatchNorm1d(n_hidden)
        )
        self.proj_x = nn.Sequential(
            nn.Linear(n_input, n_hidden),
            nn.BatchNorm1d(n_hidden),
            nn.ReLU(),
            nn.Dropout(p=0.1),
            nn.Linear(n_hidden, n_hidden),
            nn.BatchNorm1d(n_hidden)
        )
        self.commonlayer = nn.Sequential(
            nn.Linear(n_hidden * 2, n_hidden),
            nn.BatchNorm1d(n_hidden),
            nn.ReLU(),
            nn.Dropout(p=0.1),
            nn.Linear(n_hidden, n_hidden),
            nn.BatchNorm1d(n_hidden),
            nn.ReLU(),
            nn.Dropout(p=0.1),
        )
        self.output_mean = nn.Linear(n_hidden, n_latent)
        self.output_logvar = nn.Linear(n_hidden, n_latent)

    def forward(self, x: torch.Tensor, y: torch.Tensor):
        proj_x = self.proj_x(x)
        proj_y = self.proj_y(y)
        xy = torch.cat((proj_x,proj_y), dim=-1)
        h = self.commonlayer(xy)
        mean_n = self.output_mean(h)
        logvar_n = self.output_logvar(h)
        return mean_n, logvar_n

In [143]:
class DecoderYtoZ(nn.Module):
    def __init__(self, n_clusters: int, n_latent: int, n_hidden: int = 128):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(n_clusters, n_hidden),
            nn.BatchNorm1d(n_hidden),
            nn.ReLU(),
            nn.Dropout(p=0.1),   
            nn.Linear(n_hidden, n_hidden),
            nn.BatchNorm1d(n_hidden),
            nn.ReLU(),
            nn.Dropout(p=0.1)
        )
        self.output_mean_n = nn.Linear(n_hidden, n_latent)
        self.output_logvar_n = nn.Linear(n_hidden, n_latent) 

    def forward(self, probs_y: torch.Tensor):
        h = self.mlp(probs_y)
        mean_n = self.output_mean_n(h)
        logvar_n = self.output_logvar_n(h)
        return mean_n, logvar_n

In [144]:
class DecoderZtoX(nn.Module):
    def __init__(self, n_output: int, n_latent: int, n_hidden: int = 128):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(n_latent, n_hidden),
            nn.BatchNorm1d(n_hidden), 
            nn.ReLU(),
            nn.Dropout(p=0.1),        
            nn.Linear(n_hidden, n_hidden),
            nn.BatchNorm1d(n_hidden),
            nn.ReLU(),
            nn.Dropout(p=0.1),
        )     
        self.output_mean = nn.Linear(n_hidden, n_output)
        self.output_disp = nn.Linear(n_hidden, n_output)

    def forward(self, z: torch.Tensor):
        h = self.mlp(z)
        mean = torch.nn.functional.softplus(self.output_mean(h))  
        disp = torch.nn.functional.softplus(self.output_disp(h)) 
        return mean, disp

In [145]:
class GMVAEModule(BaseModuleClass):
    """GM Variational auto-encoder model.

    Here we implement a basic version of scVI's underlying VAE [Lopez18]_.
    This implementation is for instructional purposes only.

    Parameters
    ----------
    n_input
        Number of input genes.
    n_latent
        Dimensionality of the latent space.
    """

    def __init__(
        self,
        n_input: int,
        n_clusters: int,
        n_latent: int = 10,
    ):
        super().__init__()
        self.encoderxtoy = EncoderXtoY(n_input=n_input, n_clusters=n_clusters)
        self.encoderxytoz = EncoderXYtoZ(n_clusters=n_clusters, n_input=n_input, n_latent=n_latent)
        self.mu_y = nn.Parameter(torch.randn(n_clusters, n_latent))
        self.logvar_y = nn.Parameter(torch.zeros(n_clusters, n_latent)) 
        self.decoderztox = DecoderZtoX(n_output=n_input, n_latent=n_latent)
        self.n_clusters = n_clusters
        self.n_latent = n_latent


    def _get_inference_input(self, tensors: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        return {"x": tensors[REGISTRY_KEYS.X_KEY]}

    @auto_move_data
    def inference(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
        x_ = torch.log1p(x) 
        probs_y = self.encoderxtoy(x_)
        y_one_hot = torch.eye(self.n_clusters, device=x.device).unsqueeze(0).repeat(x_.size(0), 1, 1)  # (batch_size, n_clusters, n_clusters)
        x_expanded = x_.unsqueeze(1).repeat(1, self.n_clusters, 1)  # (batch_size, n_clusters, n_input)
        mean_n, logvar_n = self.encoderxytoz(
            x=x_expanded.view(-1, x_.size(-1)),  # Fusion des dimensions pour le traitement batch
            y=y_one_hot.view(-1, self.n_clusters),  # Idem pour y
        )
        mean_n = mean_n.view(x_.size(0), self.n_clusters, -1)  # (batch_size, n_clusters, n_latent)
        logvar_n = logvar_n.view(x_.size(0), self.n_clusters, -1)  # (batch_size, n_clusters, n_latent)
        var_n = logvar_n.exp()
        z_normales = Normal(mean_n, torch.sqrt(var_n)).rsample()  # (batch_size, n_clusters, n_latent)

        return {
            "qzm": mean_n,
            "qzv": var_n,
            "z": z_normales,
            "probs_y": probs_y,
            "z_normales": z_normales
        }

    def _get_generative_input(
        self, tensors: dict[str, torch.Tensor], inference_outputs: dict[str, torch.Tensor]
    ) -> dict[str, torch.Tensor]:
        return {
            "z_normales": inference_outputs["z_normales"],
        }

    @auto_move_data
    def generative(self, z_normales: torch.Tensor) -> dict[str, torch.Tensor]:
        z_flat = z_normales.view(-1, z_normales.size(-1))  # (batch_size * n_clusters, n_latent)
        nb_mean, nb_disp = self.decoderztox(z_flat)  # (batch_size * n_clusters, n_output)
        nb_mean = nb_mean.view(z_normales.size(0), z_normales.size(1), -1)
        nb_disp = nb_disp.view(z_normales.size(0), z_normales.size(1), -1)

        return {
            "nb_mean": nb_mean,  # (batch_size, n_clusters, n_output)
            "nb_disp": nb_disp,  # (batch_size, n_clusters, n_output)
        }

    def loss(
        self,
        tensors: dict[str, torch.Tensor],
        inference_outputs: dict[str, torch.Tensor],
        generative_outputs: dict[str, torch.Tensor],
    ) -> LossOutput:
        x = tensors[REGISTRY_KEYS.X_KEY]
        batch_size = x.shape[0]
        nb_mean = generative_outputs["nb_mean"] # (batch_size, n_clusters, n_output)
        nb_disp = generative_outputs["nb_disp"] # (batch_size, n_clusters, n_output)

        qz_m = inference_outputs["qzm"] # (batch_size, n_clusters, n_latent)
        qz_v = inference_outputs["qzv"] # (batch_size, n_clusters, n_latent)
        z_normales = inference_outputs["z"]  # (batch_size, n_clusters, n_latent)
        probs_y = inference_outputs["probs_y"] # (batch_size, n_clusters)

        x_expanded = x.unsqueeze(1)
        log_likelihood = NegativeBinomial(total_count=nb_disp, logits=torch.log(nb_mean+1e-4)).log_prob(x_expanded).sum(dim=-1) # (batch_size, n_clusters)
        mu_y_expanded = self.mu_y.unsqueeze(0).expand(batch_size, -1, -1)  # (n_batch, n_clusters, n_latent)
        var_y_expanded = self.logvar_y.unsqueeze(0).expand(batch_size, -1, -1).exp()  # (n_batch, n_clusters, n_latent)
        priors_z_y_distributions = Normal(mu_y_expanded, torch.sqrt(var_y_expanded)) # (batch_size, n_clusters)
        var_post_dist = Normal(qz_m, torch.sqrt(qz_v)) # (batch_size, n_clusters)
        kl_div_1 = kl(var_post_dist, priors_z_y_distributions).sum(dim=-1) # (batch_size, n_clusters)

        avg_cat_ll_kl = ((log_likelihood - kl_div_1) * probs_y).sum(dim=1) # (batch_size)

        q_y_x = torch.distributions.Categorical(probs_y)
        probs_uniform = torch.ones_like(probs_y)/self.n_clusters
        unif_pi = torch.distributions.Categorical(probs_uniform)
        kl_div_2 = kl(q_y_x, unif_pi).sum(dim=-1) # (batch_size)

        elbo = avg_cat_ll_kl - kl_div_2 # (batch_size)
        loss = torch.mean(-elbo)
        return LossOutput(
            loss=loss,
            reconstruction_loss=-avg_cat_ll_kl,
        )

In [146]:
class GMVAEModel(VAEMixin, UnsupervisedTrainingMixin, BaseModelClass):
    """single-cell Variational Inference [Lopez18]_."""

    def __init__(
        self,
        adata: AnnData,
        n_clusters: int,
        n_latent: int = 10,
        **model_kwargs,
    ):
        super().__init__(adata)

        self.module = GMVAEModule(
            n_input=self.summary_stats["n_vars"],
            # n_batch=self.summary_stats["n_batch"],
            n_clusters=n_clusters,
            n_latent=n_latent,
            **model_kwargs,
        )
        self._model_summary_string = (
            f"SCVI Model with the following params: \nn_latent: {n_latent}"
        )
        self.init_params_ = self._get_init_params(locals())


    def get_latent_representation(
        self,
        adata: AnnData | None = None,
        indices: Sequence[int] | None = None,
        batch_size: int | None = None,
        dataloader: Iterator[dict[str, Tensor | None]] = None,
    ):

        self._check_if_trained(warn=False)
        if adata is not None and dataloader is not None:
            raise ValueError("Only one of `adata` or `dataloader` can be provided.")

        if dataloader is None:
            adata = self._validate_anndata(adata)
            dataloader = self._make_data_loader(
                adata=adata, indices=indices, batch_size=batch_size
            )
        latent = {}
        latent_rep = []
        latent_cat = []
        for tensors in dataloader:
            inference_inputs = self.module._get_inference_input(tensors)
            outputs = self.module.inference(**inference_inputs)
            qz_m = outputs["qzm"]
            probs_y = outputs["probs_y"]
            latent_rep += [qz_m.cpu()]
            latent_cat += [probs_y.cpu()]
        latent["latent_rep"] = torch.cat(latent_rep).detach().numpy()
        latent["latent_cat"] = torch.cat(latent_cat).detach().numpy()
        return latent

    @classmethod
    def setup_anndata(
        cls,
        adata: AnnData,
        batch_key: str | None = None,
        layer: str | None = None,
        **kwargs,
    ) -> AnnData | None:
        setup_method_args = cls._get_setup_method_args(**locals())
        anndata_fields = [
            LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True),
            CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key),
            # Dummy fields required for VAE class.
            CategoricalObsField(REGISTRY_KEYS.LABELS_KEY, None),
            NumericalObsField(REGISTRY_KEYS.SIZE_FACTOR_KEY, None, required=False),
            CategoricalJointObsField(REGISTRY_KEYS.CAT_COVS_KEY, None),
            NumericalJointObsField(REGISTRY_KEYS.CONT_COVS_KEY, None),
        ]
        adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args)
        adata_manager.register_fields(adata, **kwargs)
        cls.register_manager(adata_manager)

In [147]:
GMVAEModel.setup_anndata(adata)
print(n_clusters)
gm_vae = GMVAEModel(adata, n_clusters=n_clusters, n_latent=10)
gm_vae

9




### Run model

In [148]:
# logger = WandbLogger(
#     project='PGM-single-cell', 
#     name="GM VAE",
#     id="gm_vae",
# )
# logger = None
logger = TensorBoardLogger(save_dir="logs/", name="gm_vae")

gm_vae.train(
    max_epochs=50, 
    logger=logger, 
    accelerator="gpu",
    train_size=0.85,
    validation_size=0.1,
    early_stopping = True,
    )

  accelerator, lightning_devices, device = parse_device_args(
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/Users/marc/mambaforge/envs/PGM/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.
/Users/marc/mambaforge/envs/PGM/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Training:   0%|          | 0/50 [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=50` reached.


In [149]:
print(np.max(gm_vae.get_latent_representation(adata)["latent_cat"]))

0.5572785


In [150]:
save_dir = "saved_model_dir"
os.makedirs(save_dir, exist_ok=True)
model_dir_gm = os.path.join(save_dir, "gm_vae_model")

In [151]:
gm_vae.save(model_dir_gm, save_anndata=True, overwrite=True) 

In [152]:
# gm_vae = GMVAEModel.load(model_dir_gm)

# VAE from SCVI

In [153]:
# logger = WandbLogger(
#     project='PGM-single-cell', 
#     name="scvi VAE",
# )
# logger = None
logger = TensorBoardLogger(save_dir="logs/", name="scvi_vae")

scvi.model.SCVI.setup_anndata(adata)

# Initialiser le modèle scVI
vae = scvi.model.SCVI(adata)

# Entraîner le modèle
vae.train(
    max_epochs=50, 
    logger=logger, 
    train_size=0.85,
    validation_size=0.1,
    early_stopping = True,
    )

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/Users/marc/mambaforge/envs/PGM/lib/python3.12/site-packages/lightning/pytorch/trainer/setup.py:177: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.
/Users/marc/mambaforge/envs/PGM/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.
/Users/marc/mambaforge/envs/PGM/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Training:   0%|          | 0/50 [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=50` reached.


In [154]:
save_dir = "saved_model_dir"
os.makedirs(save_dir, exist_ok=True)
model_dir_scvi = os.path.join(save_dir, "scvi_vae_model")

In [155]:
vae.save(model_dir_scvi, save_anndata=True, overwrite=True) 

In [156]:
# vae = SCVI.load(model_dir_scvi)

# Evaluation

In [157]:
# --- Affichage des scores ---
print("=== Scores de Clustering ===\n")

# Récupération des vrais labels
true_labels = adata.obs['precise_labels']

# --- Clustering et évaluation pour le VAE scVI  ---
latent_rep_vae = vae.get_latent_representation(adata)
kmeans_vae = KMeans(n_clusters=len(set(true_labels)), n_init=200, random_state=42)
predicted_labels_vae = kmeans_vae.fit_predict(latent_rep_vae)
ari_vae = adjusted_rand_score(true_labels, predicted_labels_vae)
nmi_vae = normalized_mutual_info_score(true_labels, predicted_labels_vae)
print("VAE de scvi tools :")
print(f"  Adjusted Rand Index (ARI): {ari_vae:.4f}")
print(f"  Normalized Mutual Information (NMI): {nmi_vae:.4f}\n")

# --- Clustering et évaluation pour le VAE personnalisé  ---
latent_rep_svae = simple_vae.get_latent_representation(adata)
kmeans_svae = KMeans(n_clusters=len(set(true_labels)), n_init=200, random_state=42)
predicted_labels_svae = kmeans_svae.fit_predict(latent_rep_svae)
ari_svae = adjusted_rand_score(true_labels, predicted_labels_svae)
nmi_svae = normalized_mutual_info_score(true_labels, predicted_labels_svae)
print("Simple VAE :")
print(f"  Adjusted Rand Index (ARI): {ari_svae:.4f}")
print(f"  Normalized Mutual Information (NMI): {nmi_svae:.4f}\n")


# --- Clustering et évaluation pour le VAE personnalisé avec GM ---
latent_cat_gmvae = gm_vae.get_latent_representation(adata)["latent_cat"]
predicted_labels_gmvae = np.argmax(latent_cat_gmvae, axis=-1)
ari_gmvae = adjusted_rand_score(true_labels, predicted_labels_gmvae)
nmi_gmvae = normalized_mutual_info_score(true_labels, predicted_labels_gmvae)
print("GM VAE :")
print(f"  Adjusted Rand Index (ARI): {ari_gmvae:.4f}")
print(f"  Normalized Mutual Information (NMI): {nmi_gmvae:.4f}\n")


=== Scores de Clustering ===

VAE de scvi tools :
  Adjusted Rand Index (ARI): 0.5614
  Normalized Mutual Information (NMI): 0.6957

Simple VAE :
  Adjusted Rand Index (ARI): 0.5193
  Normalized Mutual Information (NMI): 0.6538

GM VAE :
  Adjusted Rand Index (ARI): 0.2446
  Normalized Mutual Information (NMI): 0.3208

