In [2]:
import sys
import importlib
from tqdm import tqdm, trange
sys.path.append("../src")

import pseudobulk
importlib.reload(pseudobulk)
import models
importlib.reload(models)
import utils
importlib.reload(utils)

<module 'utils' from '/home/jhaberbe/Projects/Personal/bulk-deconvolution/notebook/../src/utils.py'>

In [3]:
import scanpy as sc

pbulk = sc.read_h5ad(
    "/home/jhaberbe/Projects/Personal/bulk-deconvolution/data/pbulk.h5ad"
)

# Maybe exclude subset
sc.pp.highly_variable_genes(pbulk, flavor="seurat_v3", subset=True)

# Training the mixture model

In [4]:
import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import TensorDataset, DataLoader, random_split

X_raw = torch.tensor(pbulk.X, dtype=torch.float32)  # (n, F), convert from sparse
y = torch.tensor(pbulk.obs.values, dtype=torch.float32)
y = y / y.sum(dim=1, keepdim=True).clamp(min=1e-8)  # Normalize rows

X = utils.scanpy_log_normalize(X_raw)

# --- Train/test split ---
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

train_dataset = TensorDataset(X_train, y_train)
test_dataset = TensorDataset(X_test, y_test)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64)

mixture_model = models.MixturePrediction(X, y)

In [24]:
import torch.nn.functional as F

optimizer = torch.optim.Adam(mixture_model.parameters(), lr=1e-3)

# --- Training loop ---
epochs = 50
for epoch in range(epochs):
    mixture_model.train()
    total_loss = 0
    for xb, yb in train_loader:
        optimizer.zero_grad()
        preds = mixture_model(xb)
        loss = F.kl_div(preds.log(), yb, reduction='batchmean')
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    
    # Evaluate on test
    mixture_model.eval()
    with torch.no_grad():
        test_loss = 0
        for xb, yb in test_loader:
            preds = mixture_model(xb)
            test_loss += F.kl_div(preds.log(), yb, reduction='batchmean').item()

    print(f"Epoch {epoch:02d} | Train Loss: {total_loss:.4f} | Test Loss: {test_loss:.4f}")

Epoch 00 | Train Loss: 39.8297 | Test Loss: 1.7247
Epoch 01 | Train Loss: 11.4133 | Test Loss: 1.2601
Epoch 02 | Train Loss: 8.2495 | Test Loss: 1.0629
Epoch 03 | Train Loss: 7.2971 | Test Loss: 0.8544
Epoch 04 | Train Loss: 6.2426 | Test Loss: 0.7956
Epoch 05 | Train Loss: 5.4839 | Test Loss: 0.6319
Epoch 06 | Train Loss: 4.8126 | Test Loss: 0.6221
Epoch 07 | Train Loss: 4.5027 | Test Loss: 0.5177
Epoch 08 | Train Loss: 3.9968 | Test Loss: 0.5271
Epoch 09 | Train Loss: 3.7191 | Test Loss: 0.4592
Epoch 10 | Train Loss: 3.4107 | Test Loss: 0.4445
Epoch 11 | Train Loss: 3.3774 | Test Loss: 0.4286
Epoch 12 | Train Loss: 3.0473 | Test Loss: 0.4444
Epoch 13 | Train Loss: 3.0061 | Test Loss: 0.3754
Epoch 14 | Train Loss: 2.8171 | Test Loss: 0.3558
Epoch 15 | Train Loss: 2.8839 | Test Loss: 0.3505
Epoch 16 | Train Loss: 2.8389 | Test Loss: 0.4110
Epoch 17 | Train Loss: 2.7281 | Test Loss: 0.3439
Epoch 18 | Train Loss: 2.6644 | Test Loss: 0.3423
Epoch 19 | Train Loss: 2.7263 | Test Loss: 0.442

In [5]:
# torch.save(mixture_model, "../models/mixture_model.pt")
mixture_model = torch.load("../models/mixture_model.pt", weights_only=False)

# Training Dirichlet Model

In [6]:
mixture_weights = mixture_model(X)
mixture_weights = mixture_weights.detach()

In [7]:
dirichlet_model = models.MixtureToDirichlet(num_components=8, num_features=X.shape[1]).to("cuda")
optimizer = torch.optim.Adam(dirichlet_model.parameters(), lr=1e-3)

In [8]:
import numpy as np
counts = torch.tensor(np.stack([pbulk.layers[layer] for layer in pbulk.layers])).permute(1, 0, 2)

In [None]:
from pyro.distributions import DirichletMultinomial
from torch.utils.data import TensorDataset, DataLoader

batch_size = pbulk.shape[0]

dataset = TensorDataset(mixture_weights, counts)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

for epoch in range(50):
    dirichlet_model.train()
    epoch_loss = 0.0

    for batch_mixture, batch_counts in loader:
        batch_mixture = batch_mixture.to("cuda")
        batch_counts = batch_counts.to("cuda")

        optimizer.zero_grad()
        alpha_pred = dirichlet_model(batch_mixture)  # [batch_size, C, F]
        loss = dirichlet_model.dirichlet_multinomial_loss(alpha_pred, batch_counts)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        print(f"Epoch {epoch} | Loss: {loss:.4f}")


    avg_loss = epoch_loss / len(dataset)
    print(f"Epoch {epoch} | Loss: {avg_loss:.4f}")

In [None]:
torch.save(model, "../models/dirichlet_model.pt")
dirichlet_model = torch.load("../models/dirichlet_model.pt", weights_only=False)

AttributeError: Can't get attribute 'MixtureToDirichlet' on <module '__main__'>