In [3]:
import scanpy as sc
from scprint import scPrint
from scdataloader import Preprocessor
from scdataloader.utils import load_genes
import numpy as np
import anndata as ad
from huggingface_hub import hf_hub_download
import lamindb as ln

from scprint.tasks import Embedder
from scprint.tasks.cell_emb import display_confusion_matrix
import pandas as pd

from scib_metrics.benchmark import Benchmarker, BioConservation, BatchCorrection
from anndata import AnnData
from scdataloader.utils import translate
import bionty as bt
from scprint.tasks.cell_emb import compute_classification

from lightning.pytorch import Trainer
from lightning.pytorch.loggers import TensorBoardLogger
from scdataloader import SimpleAnnDataset, Collator, DataModule
from torch.utils.data import DataLoader

import lamindb as ln

%load_ext autoreload
%autoreload 2

import torch
import scipy.sparse

torch.set_float32_matmul_precision("medium")

  from pkg_resources import get_distribution, DistributionNotFound


[92m→[0m connected lamindb: jkobject/scprint_v2


  @custom_fwd
  @custom_bwd


In [None]:
# model_checkpoint_file = hf_hub_download(
#    repo_id="jkobject/scPRINT", filename=f"v2-medium.ckpt"
# )
# model_checkpoint_file = ../data/
model_checkpoint_file = "../../../1lzuxvg0.ckpt"
# w937u4o1.ckpt'
# da6ao55o.ckpt # 649
# 1lzuxvg0.ckpt # 677

In [None]:
model = scPrint.load_from_checkpoint(
    model_checkpoint_file, precpt_gene_emb=None, attention="normal"
)
model = model.to("cuda")

Gene position encoding has changed in the dataloader compared to last time, trying to revert
FYI: scPrint is not attached to a `Trainer`.


In [None]:
da = sc.read("./data/task_3_embed.h5ad")

In [None]:
map_to_val = {n: i for i, n in enumerate(set(da.obs["batch"].unique()))}
da.obs["batch"] = da.obs["batch"].map(map_to_val)

In [None]:
# Prepare data for fine-tuning (using the cat/tiger dataset from above)
# Split data into train/val
n_train = int(0.8 * len(da))
train_idx = np.random.choice(len(da), n_train, replace=False)
val_idx = np.setdiff1d(np.arange(len(da)), train_idx)

train_data = da[train_idx].copy()
val_data = da[val_idx].copy()

print(f"Training data: {train_data.shape}")
print(f"Validation data: {val_data.shape}")

mencoders = {}
for k, v in model.label_decoders.items():
    mencoders[k] = {va: ke for ke, va in v.items()}
# this needs to remain its original name as it is expect like that by collator, otherwise need to send org_to_id as params
mencoders.pop("organism_ontology_term_id")

# Create datasets
train_dataset = SimpleAnnDataset(
    train_data,
    obs_to_output=["cell_type_ontology_term_id", "batch", "organism_ontology_term_id"],
    get_knn_cells=model.expr_emb_style == "metacell",
    encoder=mencoders,
)

val_dataset = SimpleAnnDataset(
    val_data,
    obs_to_output=["cell_type_ontology_term_id", "batch", "organism_ontology_term_id"],
    get_knn_cells=model.expr_emb_style == "metacell",
    encoder=mencoders,
)

# Create collator
collator = Collator(
    organisms=model.organisms,
    valid_genes=model.genes,
    class_names=["cell_type_ontology_term_id", "batch"],
    how="random expr",  # or "all expr" for full expression
    max_len=3000,
    add_zero_genes=0,
)

# Create data loaders
train_loader = DataLoader(
    train_dataset,
    collate_fn=collator,
    batch_size=32,  # Adjust based on GPU memory
    num_workers=4,
    shuffle=True,
    pin_memory=True,
)

val_loader = DataLoader(
    val_dataset,
    collate_fn=collator,
    batch_size=32,
    num_workers=4,
    shuffle=False,
    pin_memory=True,
)

Training data: (21760, 57186)
Validation data: (5440, 57186)


In [None]:
d_model_ct = 64  # dimension for cell type classification

batch_cls = torch.nn.Sequential(
    torch.nn.Linear(d_model_ct, d_model_ct),
    torch.nn.ReLU(),
    torch.nn.Linear(d_model_ct, len(set(da.obs["batch"].unique()))),
)
batch_cls = batch_cls.to(model.device)

# Create a learnable embedding vector of size d_model
batch_vector = torch.nn.Embedding(
    num_embeddings=2, embedding_dim=model.class_encoder.embedding.weight.shape[1]
).to(model.device)

In [None]:
for val in model.parameters():
    val.requires_grad = False
for val in model.cell_transformer.parameters():
    val.requires_grad = True
# for val in model.transformer.blocks[7].parameters():
#    val.requires_grad = True
# model.expr_decoder
for i in model.transformer.blocks:
    i.cross_attn.requires_grad = True
for val in model.compressor.parameters():
    val.requires_grad = True
for val in model.cls_decoders["cell_type_ontology_term_id"].parameters():
    val.requires_grad = True

In [None]:
# Alternative: Manual Training Loop (for more control)
# If you prefer to have more control over the training process
from tqdm import tqdm
import torch.nn.functional as F
from scprint.model import loss

num_epochs = 40
lr = 0.0002

# Setup optimizer
optimizer = torch.optim.AdamW(
    model.parameters(), lr=lr, weight_decay=0.01, betas=(0.9, 0.999), eps=1e-8
)

# Setup scheduler
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.5, patience=2
)

# Setup automatic mixed precision
scaler = torch.cuda.amp.GradScaler() if torch.cuda.is_available() else None

_ = model.train()

for k, i in model.mat_labels_hierarchy.items():
    model.mat_labels_hierarchy[k] = i.to(model.device)

  scaler = torch.cuda.amp.GradScaler() if torch.cuda.is_available() else None


In [None]:
def batch_corr_pass(batch):
    gene_pos = batch["genes"].to(model.device)
    expression = batch["x"].to(model.device)
    depth = batch["depth"].to(model.device)
    class_elem = batch["class"].long().to(model.device)

    # Forward pass with automatic mixed precisio^n
    with torch.cuda.amp.autocast():
        # Forward pass
        output = model.forward(
            gene_pos,
            expression,
            req_depth=depth,
            depth_mult=expression.sum(1),
            do_class=True,
            metacell_token=torch.zeros_like(depth),
        )

        output["output_cell_embs"][
            :, model.classes.index("organism_ontology_term_id") + 1, :
        ] = batch_vector(class_elem[:, 1])
        output_gen = model._generate(
            cell_embs=output["output_cell_embs"],
            gene_pos=gene_pos,
            depth_mult=expression.sum(1),
            req_depth=depth,
        )
        # model.qkv # use it to fine tune on the gene interactions
        # predict something like known PPI matrices, cell specific GRNs from atac-seq data

        # model.gene_output_embeddings
        # use it to train a classifier on top to predict other modalities from gene embeddings given an additional anndata
        # could be protein expression, ATAC-seq gene activity, transcript dynamics

        # model.gen_output["expression"] # modify the loss so that the model learns to predict KO given additional gene + with learnt KO representation token
        # or expression temporal change given learnt temporal token

        # for batch correction and classification
        # Compute losses
        total_loss = 0

        if "zero_logits" in output_gen:
            loss_expr = loss.zinb(
                theta=output_gen["disp"],
                pi=output_gen["zero_logits"],
                mu=output_gen["mean"],
                target=expression,
            )
            if model.zinb_and_mse:
                loss_expr += (
                    loss.mse(
                        input=torch.log(output_gen["mean"] + 1)
                        * (1 - torch.sigmoid(output_gen["zero_logits"])),
                        target=torch.log(expression + 1),
                    )
                    / 10  # scale to make it more similar to the zinb
                )
        else:
            loss_expr = loss.mse(
                input=torch.log(output_gen["mean"] + 1),
                target=torch.log(expression + 1),
            )

        # Add expression loss to total
        total_loss += loss_expr

        # Classification loss
        cls_output = output.get("cls_output_cell_type_ontology_term_id")
        if cls_output is not None:
            cls_loss = loss.hierarchical_classification(
                pred=cls_output,
                cl=batch["class"][:, 0].to(model.device),
                labels_hierarchy=model.mat_labels_hierarchy.get(
                    "cell_type_ontology_term_id"
                ).to("cuda"),
            )
            total_loss += cls_loss

        pos = model.classes.index("cell_type_ontology_term_id") + 1
        # Apply gradient reversal to the input embedding
        selected_emb = (
            output["compressed_cell_embs"][pos]
            if model.compressor is not None
            else output["input_cell_embs"][:, pos, :]
        )
        adv_input_emb = loss.grad_reverse(selected_emb.clone(), lambd=1.0)
        # Get predictions from the adversarial decoder
        adv_pred = batch_cls(adv_input_emb)
        # do dissim

        # Compute the adversarial loss - Fix: Convert target to long type
        current_adv_loss = torch.nn.functional.cross_entropy(
            input=adv_pred,
            target=class_elem[:, 1],  # Convert to long type
        )

        # Add adversarial loss to total loss
        total_loss += current_adv_loss * 1
    return total_loss, cls_loss, current_adv_loss, loss_expr

In [None]:
for epoch in range(num_epochs):
    print(f"\nEpoch {epoch + 1}/{num_epochs}")

    # Training phase
    train_loss = 0.0
    train_steps = 0
    avg_adv = 0
    avg_expr = 0
    avg_cls = 0

    pbar = tqdm(train_loader, desc="Training")
    for batch_idx, batch in enumerate(pbar):
        # if epoch == 0:
        #    break
        # Move batch to device
        optimizer.zero_grad()
        total_loss, cls_loss, current_adv_loss, loss_expr = batch_corr_pass(batch)
        # Backward pass
        scaler.scale(total_loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()

        train_loss += total_loss.item() if not torch.isnan(total_loss) else 0
        train_steps += 1
        avg_cls += cls_loss.item() if not torch.isnan(cls_loss) else 0
        avg_expr += loss_expr.item() if not torch.isnan(loss_expr) else 0
        avg_adv += current_adv_loss.item() if not torch.isnan(current_adv_loss) else 0
        # Update progress bar
        # if batch_idx % 35 == 0:
        # print(
        #    f"avg_loss {train_loss / train_steps:.4f}, avg_cls {avg_cls / train_steps:.4f}, avg_expr {avg_expr / train_steps:.4f}, avg_adv {avg_adv / train_steps:.4f}"
        # )
        pbar.set_postfix(
            {
                "loss": f"{total_loss.item():.4f}",
                "avg_loss": f"{train_loss / train_steps:.4f}",
                "lr": f"{optimizer.param_groups[0]['lr']:.2e}",
                "cls_loss": f"{cls_loss.item():.4f}",
                "adv_loss": f"{current_adv_loss.item():.4f}",
                "expr_loss": f"{loss_expr.item():.4f}",
            }
        )

    # Validation phase
    model.eval()
    val_loss = 0.0
    val_steps = 0
    val_loss_to_prt = 0.0

    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Validation"):
            loss_val, cls_loss, current_adv_loss, loss_expr = batch_corr_pass(batch)
            val_loss_to_prt += loss_val.item() - (2 * current_adv_loss.item())
            val_loss += loss_val.item()
            val_steps += 1
    try:
        avg_val_loss = val_loss_to_prt / val_steps
        avg_train_loss = train_loss / train_steps
    except ZeroDivisionError:
        print("Error: Division by zero occurred while calculating average losses.")
        avg_train_loss = 0
    print(
        "cls_loss: {:.4f}, adv_loss: {:.4f}, expr_loss: {:.4f}".format(
            cls_loss.item(), current_adv_loss.item(), loss_expr.item()
        )
    )
    print(f"Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")

    # Update learning rate
    scheduler.step(avg_val_loss)

    # Early stopping check (simple implementation)
    if epoch > 10 and val_loss / val_steps > 1.3 * avg_train_loss:
        print("Early stopping due to overfitting")
        break

print("Manual fine-tuning completed!")


Epoch 1/10


  with torch.cuda.amp.autocast():
Training: 100%|██████████| 680/680 [04:54<00:00,  2.31it/s, loss=2.4818, avg_loss=2.3124, lr=1.00e-04, cls_loss=0.0195, adv_loss=1.2941, expr_loss=1.1682]
Validation: 100%|██████████| 170/170 [00:36<00:00,  4.60it/s]


cls_loss: 0.0081, adv_loss: 1.2733, expr_loss: 1.1698
Train Loss: 2.3124, Val Loss: -0.2513

Epoch 2/10


Training: 100%|██████████| 680/680 [04:56<00:00,  2.29it/s, loss=2.4958, avg_loss=2.3977, lr=1.00e-04, cls_loss=0.0126, adv_loss=1.3896, expr_loss=1.0935]
Validation: 100%|██████████| 170/170 [00:36<00:00,  4.61it/s]


cls_loss: 0.0066, adv_loss: 1.3480, expr_loss: 1.1643
Train Loss: 2.3977, Val Loss: -0.3468

Epoch 3/10


Training: 100%|██████████| 680/680 [04:56<00:00,  2.29it/s, loss=2.5501, avg_loss=2.4850, lr=1.00e-04, cls_loss=0.0124, adv_loss=1.4689, expr_loss=1.0688]
Validation: 100%|██████████| 170/170 [00:36<00:00,  4.60it/s]


cls_loss: 0.0042, adv_loss: 1.4240, expr_loss: 1.1613
Train Loss: 2.4850, Val Loss: -0.4467

Epoch 4/10


Training: 100%|██████████| 680/680 [04:56<00:00,  2.29it/s, loss=2.6959, avg_loss=2.5747, lr=1.00e-04, cls_loss=0.0127, adv_loss=1.5359, expr_loss=1.1473]
Validation: 100%|██████████| 170/170 [00:36<00:00,  4.60it/s]


cls_loss: 0.0036, adv_loss: 1.5008, expr_loss: 1.1604
Train Loss: 2.5747, Val Loss: -0.5240

Epoch 5/10


Training: 100%|██████████| 680/680 [04:57<00:00,  2.29it/s, loss=2.8078, avg_loss=2.6665, lr=1.00e-04, cls_loss=0.0093, adv_loss=1.7007, expr_loss=1.0978]
Validation: 100%|██████████| 170/170 [00:36<00:00,  4.60it/s]


cls_loss: 0.0022, adv_loss: 1.5781, expr_loss: 1.1566
Train Loss: 2.6665, Val Loss: -0.6413

Epoch 6/10


Training: 100%|██████████| 680/680 [04:56<00:00,  2.29it/s, loss=2.7283, avg_loss=2.7588, lr=1.00e-04, cls_loss=0.0031, adv_loss=1.7664, expr_loss=0.9588]
Validation: 100%|██████████| 170/170 [00:36<00:00,  4.60it/s]


cls_loss: 0.0014, adv_loss: 1.6556, expr_loss: 1.1564
Train Loss: 2.7588, Val Loss: -0.7373

Epoch 7/10


Training: 100%|██████████| 680/680 [04:57<00:00,  2.29it/s, loss=2.7686, avg_loss=2.8550, lr=1.00e-04, cls_loss=0.0045, adv_loss=1.8250, expr_loss=0.9391]
Validation: 100%|██████████| 170/170 [00:36<00:00,  4.60it/s]


cls_loss: 0.0029, adv_loss: 1.7354, expr_loss: 1.1564
Train Loss: 2.8550, Val Loss: -0.8359

Epoch 8/10


Training: 100%|██████████| 680/680 [04:57<00:00,  2.29it/s, loss=3.0392, avg_loss=2.9520, lr=1.00e-04, cls_loss=0.0015, adv_loss=1.9639, expr_loss=1.0739]
Validation: 100%|██████████| 170/170 [00:36<00:00,  4.60it/s]


cls_loss: 0.0010, adv_loss: 1.8161, expr_loss: 1.1549
Train Loss: 2.9520, Val Loss: -0.9370

Epoch 9/10


Training: 100%|██████████| 680/680 [04:57<00:00,  2.29it/s, loss=3.1418, avg_loss=3.0504, lr=1.00e-04, cls_loss=0.0015, adv_loss=2.0767, expr_loss=1.0637]
Validation: 100%|██████████| 170/170 [00:37<00:00,  4.59it/s]


cls_loss: 0.0017, adv_loss: 1.8970, expr_loss: 1.1541
Train Loss: 3.0504, Val Loss: -1.0383

Epoch 10/10


Training: 100%|██████████| 680/680 [04:57<00:00,  2.29it/s, loss=3.1932, avg_loss=3.1500, lr=1.00e-04, cls_loss=0.0050, adv_loss=2.1660, expr_loss=1.0221]
Validation: 100%|██████████| 170/170 [00:36<00:00,  4.60it/s]

cls_loss: 0.0012, adv_loss: 1.9801, expr_loss: 1.1514
Train Loss: 3.1500, Val Loss: -1.1424
Manual fine-tuning completed!





In [None]:
import lightning as L

checkpoint = {
    "epoch": epoch,
    "global_step": (1 + epoch) * batch_idx,
    "pytorch-lightning_version": L.__version__,
    "state_dict": model.state_dict(),
    "optimizer_states": [optimizer.state_dict()],
    "lr_schedulers": [scheduler.state_dict()],
}

In [None]:
torch.save(model.state_dict(), "fit_2.ckpt")

In [None]:
# model.load_state_dict(torch.load("fit_2.ckpt"))

<All keys matched successfully>

In [None]:
# model = scPrint.load_from_checkpoint(
#    "fit.ckpt", precpt_gene_emb=None, attention="normal"
# )
# model = model.to("cuda")

KeyError: 'pytorch-lightning_version'

In [None]:
da.obs = da.obs.iloc[:, :-20]
for i in [
    "scprint_emb",
    "scprint_emb_age_group",
    "scprint_emb_assay_ontology_term_id",
    "scprint_emb_cell_culture",
    "scprint_emb_cell_type_ontology_term_id",
    "scprint_emb_disease_ontology_term_id",
    "scprint_emb_organism_ontology_term_id",
    "scprint_emb_other",
    "scprint_emb_self_reported_ethnicity_ontology_term_id",
    "scprint_emb_sex_ontology_term_id",
    "scprint_emb_tissue_ontology_term_id",
]:
    da.obsm.pop(i)

In [None]:
embed = Embedder(
    how="random expr",
    max_len=2600,
    num_workers=8,
    pred_embedding=["all"],
    doplot=False,
)

In [None]:
n_adata, metrics = embed(model, da.copy())

not on wandb, could not set name


100%|██████████| 425/425 [01:16<00:00,  5.54it/s]


logging the anndata
AnnData object with n_obs × n_vars = 27200 × 21550
    obs: 'pred_cell_type_ontology_term_id', 'pred_tissue_ontology_term_id', 'pred_disease_ontology_term_id', 'pred_age_group', 'pred_assay_ontology_term_id', 'pred_self_reported_ethnicity_ontology_term_id', 'pred_sex_ontology_term_id', 'pred_organism_ontology_term_id', 'pred_cell_culture', 'conv_pred_cell_type_ontology_term_id', 'conv_pred_tissue_ontology_term_id', 'conv_pred_disease_ontology_term_id', 'conv_pred_age_group', 'conv_pred_assay_ontology_term_id', 'conv_pred_self_reported_ethnicity_ontology_term_id'
    obsm: 'scprint_emb_other', 'scprint_emb_cell_type_ontology_term_id', 'scprint_emb_tissue_ontology_term_id', 'scprint_emb_disease_ontology_term_id', 'scprint_emb_age_group', 'scprint_emb_assay_ontology_term_id', 'scprint_emb_self_reported_ethnicity_ontology_term_id', 'scprint_emb_sex_ontology_term_id', 'scprint_emb_organism_ontology_term_id', 'scprint_emb_cell_culture'
    layers: 'scprint_mu', 'scprint_t