In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

In [2]:
n_cells = 1000
n_genes = 10
p_zero = 0.8  # probability that a gene expression value is zeroed out


# Seeded generator so the synthetic matrix is reproducible
rng = np.random.default_rng(32)
# One integer per (cell, gene) pair, drawn from 1..9 (the upper bound 10 is exclusive)
x_pre = rng.integers(low=1, high=10, size=(n_cells, n_genes))
# Keep-mask: each entry is 1 with probability 1 - p_zero, otherwise 0
mask = rng.binomial(1, 1 - p_zero, size=(n_cells, n_genes))

x_pre, mask

(array([[8, 2, 8, ..., 3, 6, 7],
        [4, 9, 8, ..., 8, 9, 5],
        [6, 3, 8, ..., 9, 4, 4],
        ...,
        [5, 5, 7, ..., 9, 1, 2],
        [1, 5, 9, ..., 9, 8, 1],
        [2, 7, 7, ..., 8, 6, 9]]),
 array([[0, 0, 1, ..., 0, 1, 0],
        [0, 1, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0, 0, 1, ..., 0, 0, 0],
        [1, 0, 0, ..., 1, 0, 0],
        [0, 0, 1, ..., 0, 0, 0]]))

In [3]:
# Element-wise product with the keep-mask: entries where mask == 0 become 0
x = np.multiply(x_pre, mask)
x

array([[0, 0, 8, ..., 0, 6, 0],
       [0, 9, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       ...,
       [0, 0, 7, ..., 0, 0, 0],
       [1, 0, 0, ..., 9, 0, 0],
       [0, 0, 7, ..., 0, 0, 0]])

In [None]:
# Rebuild the masked counts from x_pre before shifting. The in-place `+=`
# on the shared array made this cell non-idempotent: every re-run added
# another 10 to the surviving entries in the first half of the rows.
x = x_pre * mask
# Shift the first half of the cells (rows) up by 10
x[: n_cells // 2, :] += 10
# The shift also touched masked-out entries; re-apply the mask to zero them
x *= mask
x

array([[0, 0, 0, ..., 0, 0, 0],
       [0, 3, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       ...,
       [0, 0, 7, ..., 0, 0, 0],
       [1, 0, 0, ..., 9, 0, 0],
       [0, 0, 7, ..., 0, 0, 0]])

In [29]:
class ZINBLoss(nn.Module):
    """Mean negative log-likelihood under a zero-inflated negative binomial.

    Args:
        ridge_lambda: weight of the penalty ``sum(pi ** 2)`` added to the
            loss (0.0 disables it).
    """

    def __init__(self, ridge_lambda=0.0):
        super().__init__()
        self.eps = 1e-10  # added inside log/lgamma arguments to keep them away from 0
        self.ridge_lambda = ridge_lambda

    def forward(self, x, mean, dispersion, pi, scale_factor=1.0):
        """Return the scalar ZINB loss.

        Args:
            x: observed values; cast to float before use.
            mean: NB mean, multiplied by ``scale_factor`` before use.
            dispersion: NB dispersion.
            pi: zero-inflation probability.
            scale_factor: multiplier applied to ``mean``.

        Returns:
            Mean of the element-wise negative log-likelihood plus the ridge
            penalty on ``pi``.
        """
        eps = self.eps
        counts = x.float()
        mu = mean * scale_factor
        theta = dispersion

        # Gamma-function part of the NB negative log-likelihood
        log_norm = (
            torch.lgamma(theta + eps)
            + torch.lgamma(counts + 1.0)
            - torch.lgamma(counts + theta + eps)
        )
        # Full NB negative log-likelihood, element-wise
        nb_nll = (
            log_norm
            - theta * torch.log(theta + eps)
            - counts * torch.log(mu + eps)
            + (theta + counts) * torch.log(theta + mu + eps)
        )

        # A zero is either a dropout (prob. pi) or a genuine NB zero
        nll_if_zero = -torch.log(pi + ((1.0 - pi) * torch.exp(-nb_nll)) + eps)
        # A positive value must come from the NB component (prob. 1 - pi)
        nll_if_positive = -torch.log(1.0 - pi + eps) + nb_nll

        is_zero = counts < 1e-8
        per_entry = torch.where(is_zero, nll_if_zero, nll_if_positive)
        penalty = self.ridge_lambda * pi.pow(2).sum()

        return per_entry.mean() + penalty


class Encoder(nn.Module):
    """One-hidden-layer encoder producing ``mu`` and ``logvar`` heads.

    Args:
        input_dim: number of input features.
        hidden_dim: width of the shared hidden layer.
        latent_dim: size of each output head.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim):
        super().__init__()
        # Construction order is kept so seeded initialisation is unchanged
        self.fc1 = nn.Linear(in_features=input_dim, out_features=hidden_dim)
        self.fc_mu = nn.Linear(in_features=hidden_dim, out_features=latent_dim)
        self.fc_logvar = nn.Linear(in_features=hidden_dim, out_features=latent_dim)

    def forward(self, x):
        """Return ``(mu, logvar)`` computed from a shared ReLU hidden layer."""
        hidden = torch.relu(self.fc1(x))
        return self.fc_mu(hidden), self.fc_logvar(hidden)
    

class Decoder(nn.Module):
    """One-hidden-layer decoder emitting ZINB parameters.

    Args:
        latent_dim: size of the latent input.
        hidden_dim: width of the shared hidden layer.
        output_dim: size of each of the three output heads.
    """

    def __init__(self, latent_dim, hidden_dim, output_dim):
        super().__init__()
        # Construction order is kept so seeded initialisation is unchanged
        self.fc1 = nn.Linear(in_features=latent_dim, out_features=hidden_dim)
        self.fc_mean = nn.Linear(in_features=hidden_dim, out_features=output_dim)  # NB mean head
        self.fc_disp = nn.Linear(in_features=hidden_dim, out_features=output_dim)  # dispersion head
        self.fc_pi = nn.Linear(in_features=hidden_dim, out_features=output_dim)    # zero-inflation head

    def forward(self, z):
        """Return ``(mean, dispersion, pi)`` for latent input ``z``."""
        hidden = torch.relu(self.fc1(z))
        nb_mean = self.fc_mean(hidden).exp()        # exp keeps the mean positive
        nb_disp = self.fc_disp(hidden).exp()        # exp keeps the dispersion positive
        dropout = self.fc_pi(hidden).sigmoid()      # sigmoid maps to a probability
        return nb_mean, nb_disp, dropout


class ZINBVAE(nn.Module):
    """Variational autoencoder whose decoder emits ZINB parameters.

    Args:
        input_dim: number of input features (also the decoder output size).
        hidden_dim: hidden-layer width used by both encoder and decoder.
        latent_dim: size of the latent vector.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim):
        super().__init__()
        # Encoder is built before the decoder; the order affects seeded init
        self.encoder = Encoder(input_dim, hidden_dim, latent_dim)
        self.decoder = Decoder(latent_dim, hidden_dim, input_dim)

    def reparameterize(self, mu, logvar):
        """Sample ``mu + noise * exp(0.5 * logvar)`` with standard-normal noise."""
        sigma = (0.5 * logvar).exp()
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def forward(self, x):
        """Encode ``x``, sample a latent ``z``, decode it.

        Returns:
            Tuple ``(mean, disp, pi, mu, logvar, z)``.
        """
        mu, logvar = self.encoder(x)
        z = self.reparameterize(mu, logvar)
        nb_mean, nb_disp, dropout = self.decoder(z)
        return nb_mean, nb_disp, dropout, mu, logvar, z

    def loss_function(self, x, mean, disp, pi, mu, logvar):
        """Return the ZINB reconstruction loss plus the KL divergence term.

        NOTE(review): the ZINB term is a mean over every element, while the
        KL term is summed over latent dimensions and divided by the batch
        size, so the two terms are on different scales — confirm intended.
        """
        recon = ZINBLoss()(x, mean, disp, pi)
        batch_size = x.size(0)
        kl_total = torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        kl_div = -0.5 * kl_total / batch_size
        return recon + kl_div


In [43]:
def set_seed(seed):
    """Seed NumPy's global RNG and PyTorch's CPU and CUDA RNGs with ``seed``.

    Args:
        seed: integer passed unchanged to each seeding function.
    """
    seeders = (np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for apply_seed in seeders:
        apply_seed(seed)

    # cuDNN determinism switches, currently left disabled
    #torch.backends.cudnn.deterministic = True
    #torch.backends.cudnn.benchmark = False



In [46]:
# Seed before building the model so weight initialisation and the
# reparameterisation noise are reproducible.
set_seed(11)
vae = ZINBVAE(input_dim=n_genes, hidden_dim=20, latent_dim=5)
optimizer = torch.optim.Adam(vae.parameters(), lr=1e-3)
# as_tensor accepts both the NumPy array and an already-converted tensor, so
# re-running this cell no longer triggers torch.tensor's copy-construct
# UserWarning when `x` was converted by a previous run.
x = torch.as_tensor(x, dtype=torch.float32)

vae.train()
# Full-batch training: every iteration is one optimizer step on all of x
for epoch in range(10):
    optimizer.zero_grad()
    mean, disp, pi, mu, logvar, z = vae(x)
    loss = vae.loss_function(x, mean, disp, pi, mu, logvar)
    loss.backward()
    optimizer.step()
    print(f"Epoch {epoch}, Loss: {loss.item():.4f}")

Epoch 0, Loss: 2.0776
Epoch 1, Loss: 2.0428
Epoch 2, Loss: 2.0035
Epoch 3, Loss: 1.9675
Epoch 4, Loss: 1.9350
Epoch 5, Loss: 1.9070
Epoch 6, Loss: 1.8676
Epoch 7, Loss: 1.8429
Epoch 8, Loss: 1.8177
Epoch 9, Loss: 1.7863


  x = torch.tensor(x, dtype=torch.float32)


torch.Size([1000, 5])