In [None]:
# %%


# %%


# %%


# %%


# %%
# Find all checkpoints in the checkpoints directory


# %%


# %%


# %%

# %%

In [1]:
import os
import glob
import torch
import numpy as np
import pandas as pd
import anndata as ad
from tqdm import tqdm

import matplotlib.pyplot as plt
import seaborn as sns
import umap

import pytorch_lightning as pl

from wcd_vae.data import get_dataloader_from_adata
from wcd_vae.model import VAE, VAEConfig, Discriminator, VAEDiscriminator, VAEWasserstein, VAE_OT
from wcd_vae.metrics import compute_metrics

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
pl.seed_everything(42, workers=True)

Seed set to 42


42

In [2]:
# Load the anndata object (adjust path if needed)
anndata_path = "../data/vu_2022_ay_wh.h5ad"
anndata = ad.read_h5ad(anndata_path)
anndata.layers["normalized"] = anndata.X

# Find/subset HVGs & swap to raw counts
import scanpy as sc
sc.pp.highly_variable_genes(anndata, n_top_genes=3000, batch_key="sample")
anndata = anndata[:, anndata.var["highly_variable"]]
anndata.X = anndata.layers["counts"]

  anndata.X = anndata.layers["counts"]


In [3]:
# VAE config (should match training)
config = VAEConfig(
    input_dim=anndata.shape[1],
    latent_dim=128,
    encoder_hidden_dims=[512, 256],
    decoder_hidden_dims=[256, 512],
    dropout=0.2,
    batchsize=128,
    num_epochs=100_000,
    lr=1e-4,
    weight_decay=1e-5,
    kl_anneal_start=0,
    kl_anneal_end=100,
    kl_anneal_max=1,
)

In [8]:
# Get test loader
_, test_loader, domain_encoder, cell_encoder = get_dataloader_from_adata(
    anndata, by="age", batch_size=config.batchsize, num_workers=0
)

In [12]:
ckpt_dir = "/workspaces/wasserstein-critic-deconfounding/notebooks/checkpoints"
ckpt_files = sorted(glob.glob(os.path.join(ckpt_dir, "*.ckpt")))
print("Found checkpoints:", ckpt_files)

Found checkpoints: ['/workspaces/wasserstein-critic-deconfounding/notebooks/checkpoints/train_vae-epoch=20043-val_loss=42.99.ckpt', '/workspaces/wasserstein-critic-deconfounding/notebooks/checkpoints/train_vae_c-epoch=50248-val_loss=46.09.ckpt', '/workspaces/wasserstein-critic-deconfounding/notebooks/checkpoints/train_vae_d-epoch=99500-val_loss=40.09.ckpt', '/workspaces/wasserstein-critic-deconfounding/notebooks/checkpoints/train_vae_uot-epoch=21285-val_loss=48.59.ckpt']


In [None]:
def load_and_eval(model_class, ckpt_path, config, test_loader, device="cuda" if torch.cuda.is_available() else "cpu"):
    # Instantiate and load model
    if model_class == VAE:
        model = VAE.load_from_checkpoint(ckpt_path, config=config, linear_decoder=True)
        eval_model = model
    elif model_class == VAEDiscriminator:
        vae = VAE(config, linear_decoder=True)
        critic = Discriminator(config.latent_dim, critic=False)
        model = VAEDiscriminator.load_from_checkpoint(ckpt_path, vae=vae, critic=critic)
        eval_model = model.vae
    elif model_class == VAEWasserstein:
        vae = VAE(config, linear_decoder=True)
        critic = Discriminator(config.latent_dim, critic=True)
        model = VAEWasserstein.load_from_checkpoint(ckpt_path, vae=vae, critic=critic)
        eval_model = model.vae
    elif model_class == VAE_OT:
        model = VAE_OT.load_from_checkpoint(ckpt_path, config=config, linear_decoder=True)
        eval_model = model
    else:
        raise ValueError("Unknown model class")

    model.eval()
    model = model.to(device)
    eval_model = eval_model.to(device)
    eval_model.eval()

    # Compute embeddings
    embeddings = []
    batches = []
    cell_type = []
    with torch.no_grad():
        for batch in tqdm(test_loader, desc=f"Evaluating {os.path.basename(ckpt_path)}"):
            x, batch_label, cell_label = batch
            x, batch_label, cell_label = x.to(device), batch_label.to(device), cell_label.to(device)
            mu, logvar = eval_model.encode(x)
            z = eval_model.reparameterize(mu, logvar)
            embeddings.append(z.cpu())
            batches.append(batch_label.cpu())
            cell_type.append(cell_label.cpu())
    embeddings = torch.cat(embeddings, dim=0)
    batches = torch.cat(batches, dim=0)
    cell_type = torch.cat(cell_type, dim=0)

    # Compute metrics
    metrics_dict = compute_metrics(
        embeddings=embeddings,
        batch_labels=batches,
        cell_type_labels=cell_type,
    )
    return embeddings, batches, cell_type, metrics_dict

In [17]:
def load_and_eval(model_class, ckpt_path, config, test_loader, device="cuda" if torch.cuda.is_available() else "cpu"):
    # Instantiate and load model
    if model_class == VAE:
        model = VAE.load_from_checkpoint(ckpt_path, config=config, linear_decoder=True)
        eval_model = model
    elif model_class == VAEDiscriminator:
        vae = VAE(config, linear_decoder=True)
        critic = Discriminator(config.latent_dim, critic=False)
        model = VAEDiscriminator.load_from_checkpoint(ckpt_path, vae=vae, critic=critic)
        eval_model = model.vae
    elif model_class == VAEWasserstein:
        vae = VAE(config, linear_decoder=True)
        critic = Discriminator(config.latent_dim, critic=True)
        model = VAEWasserstein.load_from_checkpoint(ckpt_path, vae=vae, critic=critic)
        eval_model = model.vae
    elif model_class == VAE_OT:
        model = VAE_OT.load_from_checkpoint(ckpt_path, config=config, linear_decoder=True)
        eval_model = model
    else:
        raise ValueError("Unknown model class")

    model.eval()
    model = model.to(device)
    eval_model = eval_model.to(device)
    eval_model.eval()

    # Compute embeddings
    embeddings = []
    batches = []
    cell_type = []
    with torch.no_grad():
        for batch in tqdm(test_loader, desc=f"Evaluating {os.path.basename(ckpt_path)}"):
            x, batch_label, cell_label = batch
            x, batch_label, cell_label = x.to(device), batch_label.to(device), cell_label.to(device)
            mu, logvar = eval_model.encode(x)
            z = eval_model.reparameterize(mu, logvar)
            embeddings.append(z.cpu())
            batches.append(batch_label.cpu())
            cell_type.append(cell_label.cpu())
    embeddings = torch.cat(embeddings, dim=0)
    batches = torch.cat(batches, dim=0)
    cell_type = torch.cat(cell_type, dim=0)

    # Compute metrics
    metrics_dict = compute_metrics(
        embeddings=embeddings,
        batch_labels=batches,
        cell_type_labels=cell_type,
    )
    return embeddings, batches, cell_type, metrics_dict

In [18]:
# Map checkpoint names to model classes (adjust as needed)
model_map = {
    "vae-": VAE,
    "vae_uot": VAE_OT,
    "vae_d": VAEDiscriminator,
    "vae_c": VAEWasserstein,
}

results = {}

for ckpt_path in ckpt_files:
    # Guess model type from filename
    for key, model_class in model_map.items():
        if key in os.path.basename(ckpt_path).lower():
            print(f"Evaluating checkpoint {ckpt_path} as {key}")
            embeddings, batches, cell_type, metrics_dict = load_and_eval(
                model_class, ckpt_path, config, test_loader
            )
            results[key] = {
                "embeddings": embeddings,
                "batches": batches,
                "cell_type": cell_type,
                "metrics": metrics_dict,
            }
            print(f"Metrics for {key}:")
            for k, v in metrics_dict.items():
                print(f"  {k}: {v}")
            break

Evaluating checkpoint /workspaces/wasserstein-critic-deconfounding/notebooks/checkpoints/train_vae-epoch=20043-val_loss=42.99.ckpt as vae-


Evaluating train_vae-epoch=20043-val_loss=42.99.ckpt: 100%|██████████| 43/43 [00:00<00:00, 580.13it/s]


Metrics for vae-:
  batch_entropy: 0.6931471824645996
  ilisi_batch: 2.0
  clisi_celltype: 1.2195122241973877
  silhouette_score: -0.0214417465031147
  normalized_mutual_info: 0.00041299566030934126
Evaluating checkpoint /workspaces/wasserstein-critic-deconfounding/notebooks/checkpoints/train_vae_c-epoch=50248-val_loss=46.09.ckpt as vae_c


Evaluating train_vae_c-epoch=50248-val_loss=46.09.ckpt: 100%|██████████| 43/43 [00:00<00:00, 634.06it/s]


Metrics for vae_c:
  batch_entropy: 0.6931471824645996
  ilisi_batch: 2.0
  clisi_celltype: 1.2195122241973877
  silhouette_score: -0.023939337581396103
  normalized_mutual_info: 0.00041299566030934126
Evaluating checkpoint /workspaces/wasserstein-critic-deconfounding/notebooks/checkpoints/train_vae_d-epoch=99500-val_loss=40.09.ckpt as vae_d


Evaluating train_vae_d-epoch=99500-val_loss=40.09.ckpt: 100%|██████████| 43/43 [00:00<00:00, 638.65it/s]


Metrics for vae_d:
  batch_entropy: 0.6931471824645996
  ilisi_batch: 2.0
  clisi_celltype: 1.2195122241973877
  silhouette_score: -0.0258098766207695
  normalized_mutual_info: 0.00041299566030934126
Evaluating checkpoint /workspaces/wasserstein-critic-deconfounding/notebooks/checkpoints/train_vae_uot-epoch=21285-val_loss=48.59.ckpt as vae_uot


Evaluating train_vae_uot-epoch=21285-val_loss=48.59.ckpt: 100%|██████████| 43/43 [00:00<00:00, 613.10it/s]


Metrics for vae_uot:
  batch_entropy: 0.6931471824645996
  ilisi_batch: 2.0
  clisi_celltype: 1.2195122241973877
  silhouette_score: -0.011280431412160397
  normalized_mutual_info: 0.00041299566030934126


In [None]:
# Example: UMAP visualization for each model
for key, res in results.items():
    embeddings_np = res["embeddings"].numpy()
    batches_np = res["batches"].argmax(dim=1).numpy()
    cell_type_np = res["cell_type"].argmax(dim=1).numpy()

    umap_model = umap.UMAP(n_neighbors=15, min_dist=0.1, random_state=42)
    embedding_2d = umap_model.fit_transform(embeddings_np)

    plt.figure(figsize=(8, 6))
    sns.scatterplot(x=embedding_2d[:, 0], y=embedding_2d[:, 1], hue=batches_np, palette="tab10", s=10)
    plt.title(f"UMAP by Batch ({key})")
    plt.legend(title="Batch", bbox_to_anchor=(1.05, 1), loc="upper left")
    plt.tight_layout()
    plt.show()

    plt.figure(figsize=(8, 6))
    sns.scatterplot(x=embedding_2d[:, 0], y=embedding_2d[:, 1], hue=cell_type_np, palette="tab20", s=10)
    plt.title(f"UMAP by Cell Type ({key})")
    plt.legend(title="Cell Type", bbox_to_anchor=(1.05, 1), loc="upper left")
    plt.tight_layout()
    plt.show()