Training of the Autoencoder
Training on the pbmc3k train dataset. 

In [1]:
#Import packages


import anndata as ad
import torch
from torch.utils.data import DataLoader, TensorDataset, Dataset

import numpy as np
from typing import List, Optional, Callable
import torch.nn as nn
import torch.nn.functional as F
from scvi.distributions import NegativeBinomial
import torch.nn.functional as F
import scanpy as sc
import matplotlib.pyplot as plt
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


Define the function for an MLP to be used in our autoencoder

In [2]:
# Here we define an MLP as the basis of our encoder and decoder
class MLP(nn.Module):
    def __init__(self, 
                 dims: List[int],
                 batch_norm: bool, 
                 dropout: bool, 
                 dropout_p: float, 
                 activation: Optional[Callable] = nn.ELU, 
                 final_activation: Optional[str] = None):
        super().__init__()
        self.dims = dims
        self.batch_norm = batch_norm
        self.activation = activation
        layers = []
        for i in range(len(dims[:-2])):
            block = [nn.Linear(dims[i], dims[i+1])]
            if batch_norm:
                block.append(nn.BatchNorm1d(dims[i+1]))
            block.append(activation())
            if dropout:
                block.append(nn.Dropout(dropout_p))
            layers.append(nn.Sequential(*block))
        layers.append(nn.Linear(dims[-2], dims[-1]))
        self.net = nn.Sequential(*layers)
        if final_activation == "tanh":
            self.final_activation = nn.Tanh()
        elif final_activation == "sigmoid":
            self.final_activation = nn.Sigmoid()
        else:
            self.final_activation = None

    def forward(self, x):
        x = self.net(x)
        return x if self.final_activation is None else self.final_activation(x)

In [14]:
from scipy.stats import beta

def size_factor_distribution(adata_train, n_samples):
    # Parameters
    """ Here we have a function that gets a distribution of the library size in the corresponding data set.
    It then samples from this distribution, so we get a realistic distribution of library size in each decoded cell
    """
    min_val = adata_train.obs['n_counts'].min() 
    max_val = adata_train.obs['n_counts'].max()
    mean = adata_train.obs['n_counts'].mean()
    std = adata_train.obs['n_counts'].std()
    
    # Beta distribution sampling
    m = (mean - min_val) / (max_val - min_val)
    v = (std**2) / ((max_val - min_val)**2)
    temp = m*(1-m)/v - 1
    a_beta = m * temp
    b_beta = (1-m) * temp
    
    samples_beta = beta.rvs(a_beta, b_beta, size=n_samples)
    samples_beta = samples_beta * (max_val - min_val) + min_val

    return samples_beta

Load the negative binomial autoencoder model

In [15]:
# Here we incorporate our MLP into a bigger class, so that we can train an autoencoder. We train it independetly of the 
# flow matching model.
class NB_Autoencoder(nn.Module):
    def __init__(self,
                 num_features: int,
                 latent_dim: int = 50,
                 hidden_dims: List[int] = [512, 256],
                 dropout_p: float = 0.1,
                 l2_reg: float = 1e-5,
                 kl_reg: float = 0):
        super().__init__()
        self.num_features = num_features
        self.latent_dim = latent_dim
        self.l2_reg = l2_reg
        self.kl_reg = kl_reg

        self.hidden_encoder = MLP(
        dims=[num_features, *hidden_dims, latent_dim],
        batch_norm=True,
        dropout=False,
        dropout_p=dropout_p
        )
        #self.latent_layer = nn.Linear(hidden_dims[-1], latent_dim)

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.to(self.device)
        self.decoder = MLP(
            dims=[latent_dim, *hidden_dims[::-1], num_features],
            batch_norm=True,
            dropout=False,
            dropout_p=dropout_p
        )

        #self.log_theta = nn.Parameter(torch.randn(num_features) * 0.01)
        self.theta = torch.nn.Parameter(torch.randn(num_features), requires_grad=True)
    def forward(self, x, library_size = None):
        """ forward function that encodes and decodes"""
        z = self.hidden_encoder(x["X_norm"])
        
        #z = self.latent_layer(h)
        # Raw decoded logits
        logits = self.decoder(z)  
        
        # Softmax over genes → normalized probabilities
        gene_probs = F.softmax(logits, dim=1)

        if library_size is None:
            # Use average library size 1.0 if not provided
            # Sample size factors from your custom distribution
            lib = size_factor_distribution(adata, z.size(0))   # returns numpy array or list

            # Convert to torch tensor, match shape, move to correct device
            library_size = torch.tensor(lib, dtype=torch.float32, device=z.device).unsqueeze(1)

            #library_size = torch.ones(z.size(0), 1, device=z.device)
        # Library size of each cell (sum of counts)
        #library_size = x["X"].sum(1).unsqueeze(1).to(self.device)  
        
        # Scale probabilities by library size → mean parameter μ
        mu = gene_probs * library_size
 

        #theta = torch.exp(self.log_theta).unsqueeze(0).expand_as(mu)
        return {"z": z, "mu": mu, "theta": self.theta}

    def encode(self, x):
        """ decoding function"""
        z = self.hidden_encoder(x)
        return z

    def size_factor_distribution(self, adata_train, n_samples):
        """ same function as in block before """
        min_val = adata_train.obs['n_counts'].min() 
        max_val = adata_train.obs['n_counts'].max()
        mean = adata_train.obs['n_counts'].mean()
        std = adata_train.obs['n_counts'].std()
        
        # Beta distribution sampling
        m = (mean - min_val) / (max_val - min_val)
        v = (std**2) / ((max_val - min_val)**2)
        temp = m*(1-m)/v - 1
        a_beta = m * temp
        b_beta = (1-m) * temp
        
        samples_beta = beta.rvs(a_beta, b_beta, size=n_samples)
        samples_beta = samples_beta * (max_val - min_val) + min_val
    
        return samples_beta

    
        
    def decode(self, z, library_size=None):
        """
        Decode latent vectors z to NB parameters mu, theta.
        z: (batch, latent_dim)
        library_size: (batch, 1) sum of counts per cell; if None, use 1.0
        """
        logits = self.decoder(z)  # (batch, num_genes)
        gene_probs = F.softmax(logits, dim=1)  # softmax over genes
        # if library size isnt specified:
        if library_size is None:
            # Sample size factors from your custom distribution
            lib = size_factor_distribution(adata, z.size(0))   # returns numpy array or list

            # Convert to torch tensor, match shape, move to correct device
            library_size = torch.tensor(lib, dtype=torch.float32, device=z.device).unsqueeze(1)

            #library_size = torch.ones(z.size(0), 1, device=z.device)
    
        mu = gene_probs * library_size  # scale by library size
        #theta = torch.exp(self.log_theta).unsqueeze(0).expand_as(mu)
        return {"mu": mu, "theta": self.theta}
   

    def loss_function(self, x, outputs):
        """
        Compute loss using scvi NegativeBinomial.
        """
        mu = outputs["mu"]          # (batch, n_genes)
        theta = outputs["theta"]    # (batch, n_genes)
        z = outputs["z"]            # latent
    
        # scvi NegativeBinomial expects mu and theta
        nb_dist = NegativeBinomial(mu=mu, theta=torch.exp(self.theta))
        nll = -nb_dist.log_prob(x).sum(dim=1).mean()  # mean over batch
        
        # Optional regularization
        l2_loss = sum((p**2).sum() for p in self.parameters()) * self.l2_reg
        kl_loss = (z**2).mean() * self.kl_reg
    
        loss = nll + l2_loss + kl_loss
        return {"loss": loss, "nll": nll}


Define a dataloader that both gives the raw counts X, and the log1p normalized counts X_norm

In [16]:
# dataloader

class CountsDataset(Dataset):
    def __init__(self, X, y=None):
        """
        function that returns a dict containing both the original raw counts and the log1p normalized counts.
        X: raw counts tensor (num_cells, num_genes)
        y: optional labels tensor (num_cells,)
        """
        if hasattr(X, "toarray"):
            X = X.toarray()
        self.X = torch.tensor(X, dtype=torch.float32)
        self.X_norm = torch.log1p(self.X)  # log1p = log(1 + x)
        self.y = torch.tensor(y, dtype=torch.long) if y is not None else None
        self.n_samples = self.X.shape[0]

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        sample = dict(
            X=self.X[idx],
            X_norm=self.X_norm[idx]
        )
        if self.y is not None:
            sample["y"] = self.y[idx]
        return sample




Load the data and extract raw counts X and check if it can transform correctly with log1p

In [19]:

input_file_path = "/dtu/blackhole/0e/214382/datasets/pbmc3k/pbmc3k_train.h5ad"

adata = ad.read_h5ad(input_file_path)
adata.obs.head()
# extract raw counts
X = adata.layers["X_counts"]
if hasattr(X, "toarray"):
    X = X.toarray()
X = torch.tensor(X, dtype=torch.float32)
# check shape is correct
print(X.shape)



torch.Size([2110, 8573])


Training loop
Very important that the loss is computed on the raw counts

In [17]:
# Main training + encoding

input_file = input_file_path
latent_dim = 50
hidden_dims = [512, 256]
batch_size = 512
epochs = 1000
learning_rate = 1e-3
epochs_list=[]
loss_list=[] 

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#  Load data
adata = ad.read_h5ad(input_file)

# Load RAW COUNTS
X = adata.layers["X_counts"]
if hasattr(X, "toarray"):
    X = X.toarray()
X = torch.tensor(X, dtype=torch.float32)



dataset = CountsDataset(X)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

#  Initialize model 
num_genes = adata.n_vars
model = NB_Autoencoder(num_features=num_genes,
                       latent_dim=latent_dim,
                       hidden_dims=hidden_dims)
model = model.to(device)
model.train()

# we use a adamw model, inspired from similar works
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

# Training loop 
for epoch in range(epochs):
    epoch_loss = 0
    for batch in dataloader:
        # Use log-transformed input for encoder
        
        batch = {k: v.to(device) if torch.is_tensor(v) else v for k, v in batch.items()}

        # Forward pass
        outputs = model(batch)

        # Compute loss on raw counts
        loss_dict = model.loss_function(batch["X"], outputs)
        loss = loss_dict["loss"]

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        #epoch_loss += loss.item() * x_raw.size(0)
        epoch_loss += loss.item() * batch["X"].size(0)
        

    epoch_loss /= len(dataset)
    epochs_list.append(epoch)
    loss_list.append(epoch_loss)
    
    if (epoch + 1) % 50 == 0:
        print(f"Epoch {epoch+1}/{epochs} - Loss: {epoch_loss:.4f}")

# Save trained model
model_file = input_file_path.replace(".h5ad", "_nb_autoencoder.pt")
torch.save(model.state_dict(), model_file)
print(f"Trained model saved to {model_file}")



  self.X = torch.tensor(X, dtype=torch.float32)


Epoch 50/1000 - Loss: 2765.1971
Epoch 100/1000 - Loss: 2611.7810
Epoch 150/1000 - Loss: 2474.4409
Epoch 200/1000 - Loss: 2337.2923
Epoch 250/1000 - Loss: 2212.2679
Epoch 300/1000 - Loss: 2103.8249
Epoch 350/1000 - Loss: 2022.9652
Epoch 400/1000 - Loss: 1963.6307
Epoch 450/1000 - Loss: 1901.3625
Epoch 500/1000 - Loss: 1873.8554
Epoch 550/1000 - Loss: 1839.2425
Epoch 600/1000 - Loss: 1823.2363
Epoch 650/1000 - Loss: 1791.0843
Epoch 700/1000 - Loss: 1791.4225
Epoch 750/1000 - Loss: 1770.5874
Epoch 800/1000 - Loss: 1756.4842
Epoch 850/1000 - Loss: 1742.1504
Epoch 900/1000 - Loss: 1736.6346
Epoch 950/1000 - Loss: 1723.1017
Epoch 1000/1000 - Loss: 1729.5569
Trained model saved to /dtu/blackhole/0e/214382/datasets/pbmc3k/pbmc3k_train_nb_autoencoder.pt


In [18]:

# Save encoded training cells to train flow model
model.eval()
all_z = []

with torch.no_grad():
    for batch in tqdm(dataloader):
        batch = {k: v.to(device) if torch.is_tensor(v) else v for k, v in batch.items()}
        outputs = model.forward(batch)

        z = outputs["z"].cpu().numpy()
        all_z.append(z)

latent = np.concatenate(all_z, axis=0)

# Save to AnnData
adata.obsm["X_latent"] = latent
output_file = input_file.replace(".h5ad", "_with_latent.h5ad")
adata.write(output_file)

print(f"Latent space saved to {output_file}")

100%|██████████| 5/5 [00:00<00:00, 58.15it/s]


Latent space saved to /dtu/blackhole/0e/214382/datasets/pbmc3k/pbmc3k_train_with_latent.h5ad
