# 🧬 scIDiff: Single-cell Inverse Diffusion

A demo of using denoising diffusion probabilistic models (DDPMs) for denoising and inverse design of single-cell RNA-seq (scRNA-seq) expression profiles.

In [None]:
# Install required packages (uncomment if running in Colab)
# !pip install torch torchvision torchaudio
# !pip install scanpy

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import scanpy as sc


In [None]:
# Fetch the PBMC 3k example dataset that ships with scanpy.
adata = sc.datasets.pbmc3k()

# Library-size normalisation (target_sum=1e4) followed by a log1p transform.
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)

# Annotate highly variable genes (n_top_genes=1000) and keep only those columns.
sc.pp.highly_variable_genes(adata, n_top_genes=1000)
hvg_mask = adata.var['highly_variable']
adata = adata[:, hvg_mask]

# Dense float32 matrix (cells x genes) that the torch cells below consume.
data = adata.X.toarray().astype(np.float32)
print("Shape:", data.shape)


In [None]:
# Linear variance schedule for the forward diffusion process.
T = 1000  # number of diffusion timesteps
beta = np.linspace(start=1e-4, stop=0.02, num=T)  # per-step noise variance
alpha = 1.0 - beta  # per-step signal retention
alpha_hat = alpha.cumprod()  # cumulative product of alpha up to each timestep

# Visualise how the cumulative signal fraction decays over the timesteps.
timesteps = np.arange(len(alpha_hat))
plt.plot(timesteps, alpha_hat)
plt.xlabel("Timestep")
plt.ylabel("Alpha Hat")
plt.title("Cumulative Alpha Schedule")
plt.show()


In [None]:
class MLP(nn.Module):
    """Timestep-conditioned MLP that predicts the noise added to a profile.

    The previous version accepted ``t`` in ``forward`` but never used it, so
    the network could not tell lightly-noised inputs from heavily-noised ones.
    The timestep is now encoded with a fixed sinusoidal embedding and
    concatenated to the input features before the first linear layer.

    Args:
        dim: Number of input (and output) features per sample.
        time_dim: Width of the sinusoidal timestep embedding; must be a
            positive even integer.
        hidden_dim: Width of the single hidden layer.

    Raises:
        ValueError: If ``time_dim`` is not a positive even integer.
    """

    def __init__(self, dim, time_dim=64, hidden_dim=512):
        super().__init__()
        if time_dim < 2 or time_dim % 2:
            raise ValueError("time_dim must be a positive even integer")
        self.time_dim = time_dim
        self.net = nn.Sequential(
            nn.Linear(dim + time_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, dim),
        )

    def _time_embedding(self, t, x):
        """Return a sinusoidal embedding of ``t`` shaped ``(*x.shape[:-1], time_dim)``.

        ``t`` may be a Python number, a 0-d tensor, or a tensor holding one
        timestep per sample; a single value is broadcast to every sample.
        """
        lead = x.shape[:-1]
        steps = torch.as_tensor(t, dtype=x.dtype, device=x.device)
        if steps.numel() == 1:
            steps = steps.reshape(()).expand(lead)
        else:
            steps = steps.reshape(lead)

        # Geometric ladder of frequencies from 1 down towards 1/10000.
        half = self.time_dim // 2
        exponent = -torch.arange(half, dtype=x.dtype, device=x.device) / half
        freqs = 10000.0 ** exponent

        angles = steps.unsqueeze(-1) * freqs
        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)

    def forward(self, x, t):
        """Predict the noise component of ``x`` at diffusion timestep ``t``.

        Args:
            x: Noisy input whose last dimension has size ``dim``.
            t: Timestep(s) for ``x``; one per sample, or a single value.

        Returns:
            Tensor with the same shape as ``x``.
        """
        conditioned = torch.cat([x, self._time_embedding(t, x)], dim=-1)
        return self.net(conditioned)

# One input/output unit per column of the expression matrix.
n_features = data.shape[1]
model = MLP(n_features)
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)


In [None]:
def q_sample(x_0, t, noise, alpha_bar=None):
    """Sample x_t from the forward diffusion process q(x_t | x_0).

    Computes ``sqrt(a_t) * x_0 + sqrt(1 - a_t) * noise`` where ``a_t`` is the
    cumulative alpha product looked up at timestep ``t``.

    Args:
        x_0: Clean samples; the first dimension is the batch (a single
            unbatched sample is also accepted together with a scalar ``t``).
        t: Integer timestep(s): one per sample, or a single int / 0-d tensor.
        noise: Noise tensor broadcastable to ``x_0``.
        alpha_bar: Optional cumulative alpha schedule (array-like or tensor).
            Defaults to the module-level ``alpha_hat``.

    Returns:
        Tensor of noised samples with the same shape as ``x_0``.
    """
    if alpha_bar is None:
        alpha_bar = alpha_hat  # fall back to the notebook-level schedule

    # Do the lookup in torch, on x_0's device and dtype, so that it works for
    # non-CPU tensors and scalar timesteps (indexing a NumPy array with a
    # torch tensor does not).
    schedule = torch.as_tensor(alpha_bar, dtype=x_0.dtype, device=x_0.device)
    steps = torch.as_tensor(t, dtype=torch.long, device=x_0.device).reshape(-1)

    # Shape (batch, 1, ..., 1) so the coefficient broadcasts over every
    # non-batch dimension of x_0, not just a single feature axis.
    trailing = [1] * (x_0.dim() - 1)
    alpha_bar_t = schedule[steps].reshape(-1, *trailing)

    return torch.sqrt(alpha_bar_t) * x_0 + torch.sqrt(1 - alpha_bar_t) * noise

# Fixed training batch: the first 512 rows of the expression matrix.
x_0 = torch.tensor(data[:512])
for step in range(1000):
    # Clear gradients left over from the previous iteration.
    optimizer.zero_grad()

    # One random timestep per sample, then the Gaussian noise to inject.
    t = torch.randint(low=0, high=T, size=(x_0.shape[0],))
    noise = torch.randn_like(x_0)

    # Noise the clean batch and score the model's noise estimate with MSE.
    x_t = q_sample(x_0, t, noise)
    noise_pred = model(x_t, t)
    loss = F.mse_loss(input=noise_pred, target=noise)

    loss.backward()
    optimizer.step()

    # Report progress every 100 steps.
    if not step % 100:
        print(f"Step {step}, Loss: {loss.item():.4f}")
