In [1]:
import scanpy as sc
from scprint2 import scPRINT2
from scdataloader import Preprocessor
from scdataloader.utils import load_genes
import numpy as np
import anndata as ad
from huggingface_hub import hf_hub_download
import lamindb as ln

from scprint2.tasks import Embedder
from scprint2.tasks.cell_emb import display_confusion_matrix
import pandas as pd

from scib_metrics.benchmark import Benchmarker, BioConservation, BatchCorrection
from anndata import AnnData
from scdataloader.utils import translate
import bionty as bt
from scprint2.tasks.cell_emb import compute_classification

from lightning.pytorch import Trainer
from lightning.pytorch.loggers import TensorBoardLogger
from scdataloader import SimpleAnnDataset, Collator, DataModule
from torch.utils.data import DataLoader

import lamindb as ln

%load_ext autoreload
%autoreload 2

import torch
import scipy.sparse

torch.set_float32_matmul_precision("medium")

[92m→[0m connected lamindb: jkobject/scprint2


  @custom_fwd
  @custom_bwd


In [2]:
LOC = "/pasteur/appa/scratch/jkalfon/data/spcrint_data/temp/" #"../../data/temp/"  #

In [3]:
adata = sc.read(
    LOC + "glio_smart_cort_area" + ".h5ad",
    backup_url="https://datasets.cellxgene.cziscience.com/a1d40c84-c81c-406f-bef4-e25edeb651e5.h5ad",
)

In [4]:
preprocessor = Preprocessor(
    force_preprocess=True,
    skip_validate=True,
    # drop_non_primary=False,
    do_postp=False,
)
print("")
print(adata.X.sum(1).mean())
adata = preprocessor(adata)


1353725.5
Dropping layers:  KeysView(Layers with keys: exon, intron)
checking raw counts
removed 0 non primary cells, 49417 renamining
filtered out 0 cells, 49417 renamining
Removed 0 genes not known to the ontology
Removed 0 duplicate genes
Added 34323 genes in the ontology but not present in the dataset
starting QC
Seeing 10180 outliers (20.60% of total dataset):
done
AnnData object with n_obs × n_vars = 49417 × 70116
    obs: 'suspension_type', 'cluster', 'class', 'subclass', 'sex_ontology_term_id', 'region', 'cortical_layer', 'cell_type_accession', 'cell_type_alias', 'cell_type_alt_alias', 'cell_type_designation', 'donor_id', 'outlier_call', 'outlier_type', 'tissue_ontology_term_id', 'disease_ontology_term_id', 'assay_ontology_term_id', 'is_primary_data', 'cell_type_ontology_term_id', 'Specimen ID', 'sample_tissue_type', 'development_stage_ontology_term_id', 'self_reported_ethnicity_ontology_term_id', 'cause_of_death', 'PMI (hr)', 'Tissue (RIN)', 'Hemisphere Sampled', 'tissue_type

In [5]:
# model_checkpoint_file = hf_hub_download(
#    repo_id="jkobject/scPRINT", filename=f"v2-medium.ckpt"
# )
# model_checkpoint_file = ../data/
model_checkpoint_file = "../models/ji9krimq.ckpt"
# w937u4o1
# ji9krimq

In [28]:
model = scPRINT2.load_from_checkpoint(
    model_checkpoint_file, precpt_gene_emb=None, gene_pos_file=None, max_cont_len=None,
)
model = model.to("cuda")

FYI: scPrint is not attached to a `Trainer`.


In [7]:
adata.obs.cell_type.value_counts()

cell_type
glutamatergic neuron                                              21841
L2/3-6 intratelencephalic projecting glutamatergic neuron          4778
VIP GABAergic cortical interneuron                                 3533
pvalb GABAergic cortical interneuron                               2800
L6 corticothalamic-projecting glutamatergic cortical neuron        2556
lamp5 GABAergic cortical interneuron                               2434
sst GABAergic cortical interneuron                                 2358
unknown                                                            1985
oligodendrocyte                                                    1930
astrocyte                                                          1187
L6b glutamatergic cortical neuron                                  1080
near-projecting glutamatergic cortical neuron                       816
oligodendrocyte precursor cell                                      773
microglial cell                                       

In [29]:
# Prepare data for fine-tuning (using the cat/tiger dataset from above)
# Split data into train/val

train_ct = ["oligodendrocyte",
            "microglial cell",
            "pericyte",
            "lamp5 GABAergic cortical interneuron",
            "L6b glutamatergic cortical neuron",
            "astrocyte",
            "VIP GABAergic cortical interneuron",
            "glutamatergic neuron",
            "pvalb GABAergic cortical interneuron"]
train_data = adata[adata.obs.cell_type.isin(train_ct)].copy()
val_data = adata[~adata.obs.cell_type.isin(train_ct)].copy()

print(f"Training data: {train_data.shape}")
print(f"Validation data: {val_data.shape}")

mencoders = {}
for k, v in model.label_decoders.items():
    mencoders[k] = {va: ke for ke, va in v.items()}
# this needs to remain its original name as it is expect like that by collator, otherwise need to send org_to_id as params

# Create datasets
train_dataset = SimpleAnnDataset(
    train_data,
    obs_to_output=["cell_type_ontology_term_id", "organism_ontology_term_id"],
    get_knn_cells=model.expr_emb_style == "metacell",
    encoder=mencoders,
)

# Create collator
collator = Collator(
    organisms=model.organisms,
    valid_genes=model.genes,
    class_names=["cell_type_ontology_term_id", 'organism_ontology_term_id'],
    how="random expr",  # or "all expr" for full expression
    max_len=6000,
    add_zero_genes=0,
    org_to_id=mencoders["organism_ontology_term_id"],
)

# Create data loaders
train_loader = DataLoader(
    train_dataset,
    collate_fn=collator,
    batch_size=4,  # Adjust based on GPU memory
    num_workers=4,
    shuffle=True,
    pin_memory=True,
)

Training data: (35587, 70116)
Validation data: (13830, 70116)


  organismdf = pd.concat(organismdf)


In [30]:
for val in model.transformer.blocks:
    val.mixer.Wqkv.requires_grad = True
for val in model.cell_transformer.blocks:
    for i in val.cross_attn.parameters():
        i.requires_grad = True
#for val in model.gene_encoder[3].parameters():
#    val.requires_grad = True
#for val in model.metacell_encoder.parameters():
#    val.requires_grad = True

In [31]:
def classif(batch, do_cls=False):
    gene_pos = batch["genes"].to(model.device)
    expression = batch["x"].to(model.device)
    depth = batch["depth"].to(model.device)
    class_elem = batch["class"].long().to(model.device)
    total_loss = 0

    # Forward pass with automatic mixed precisio^n
    with torch.cuda.amp.autocast():
        # Forward pass
        output = model.forward(
            gene_pos,
            expression,
            req_depth=depth,
            depth_mult=expression.sum(1),
            do_class=True,
            metacell_token=torch.zeros_like(depth),
        )
        ## generate expr loss
        output_gen = model._generate(
            cell_embs=output["output_cell_embs"],
            gene_pos=gene_pos,
            depth_mult=expression.sum(1),
            req_depth=depth,
        )
        if "zero_logits" in output_gen:
            loss_expr = loss.zinb(
                theta=output_gen["disp"],
                pi=output_gen["zero_logits"],
                mu=output_gen["mean"],
                target=expression,
            )
            if model.zinb_and_mse:
                loss_expr += (
                    loss.mse(
                        input=torch.log(output_gen["mean"] + 1)
                        * (1 - torch.sigmoid(output_gen["zero_logits"])),
                        target=torch.log(expression + 1),
                    )
                    / 10  # scale to make it more similar to the zinb
                )
        else:
            loss_expr = loss.mse(
                input=torch.log(output_gen["mean"] + 1),
                target=torch.log(expression + 1),
            )
        # Add expression loss to total
        total_loss += loss_expr

        # ct clss
        if do_cls:
            for i, k in enumerate(["cell_type_ontology_term_id", "organism_ontology_term_id"]):
                cls_loss = loss.hierarchical_classification(
                    pred=output.get("cls_output_" + k),
                    cl=class_elem[:, i],
                    labels_hierarchy=model.mat_labels_hierarchy.get(
                        k
                ).to("cuda") if k in model.mat_labels_hierarchy else None,
            )
            total_loss += cls_loss
            return total_loss, loss_expr, cls_loss
        # total_loss += output["vae_kl_loss"] * 0.001
    return total_loss, loss_expr

In [32]:
# Alternative: Manual Training Loop (for more control)
# If you prefer to have more control over the training process
from tqdm import tqdm
import torch.nn.functional as F
from scprint2.model import loss

num_epochs = 1
lr = 0.00002

# Setup optimizer
all_params = (
    list(model.parameters())
)
optimizer = torch.optim.AdamW(
    all_params, lr=lr, weight_decay=0.01, betas=(0.9, 0.999), eps=1e-8
)

# Setup automatic mixed precision
scaler = torch.cuda.amp.GradScaler() if torch.cuda.is_available() else None

for k, i in model.mat_labels_hierarchy.items():
    model.mat_labels_hierarchy[k] = i.to(model.device)

  scaler = torch.cuda.amp.GradScaler() if torch.cuda.is_available() else None


In [None]:
metrics = {}
for n in [400, 2000, 6000]:
    ad = train_data[(train_data.X > 0).sum(1) >= n]
    # Get 5,000 random cells
    n_cells = min(5000, ad.n_obs)  # In case you have fewer than 5,000 cells
    random_indices = np.random.choice(ad.n_obs, n_cells, replace=False)
    adata_subset = ad[random_indices, :].copy()
    embed = Embedder(
        how="random expr",
        max_len=n,
        num_workers=8,
        pred_embedding=["cell_type_ontology_term_id"],
        doplot=False,
    )
    _, metrics[n] = embed(model, adata_subset.copy())
    
metrics

  organismdf = pd.concat(organismdf)


not on wandb, could not set name


  0%|          | 0/79 [00:00<?, ?it/s]

> [32m<string>[39m([92m29[39m)[36m_encoder[39m[34m()[39m



In [None]:
for epoch in range(num_epochs):
    print(f"\nEpoch {epoch + 1}/{num_epochs}")
    print(f"Current learning rate: {optimizer.param_groups[0]['lr']:.2e}")
    ## Training phase
    train_loss = 0.0
    train_steps = 0
    avg_cls = 0
    avg_expr = 0
    model.train()

    pbar = tqdm(train_loader, desc="Training")
    for batch_idx, batch in enumerate(pbar):
        optimizer.zero_grad()
        total_loss, loss_expr = classif(batch, do_cls=False)
        # Backward pass
        scaler.scale(total_loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()

        train_loss += total_loss.item()
        train_steps += 1
        avg_expr += loss_expr.item()
        #avg_cls += loss_cls.item()
        # print(loss_cls.item(), loss_expr.item())
        if batch_idx > 2400:
            break
    # Validation phase
    model.eval()

metrics = {}
for n in [400, 2000, 6000]:
    ad = train_data[(train_data.X > 0).sum(1) >= n]
    # Get 5,000 random cells
    n_cells = min(5000, ad.n_obs)  # In case you have fewer than 5,000 cells
    random_indices = np.random.choice(ad.n_obs, n_cells, replace=False)
    adata_subset = ad[random_indices, :].copy()
    embed = Embedder(
        how="random expr",
        max_len=n,
        num_workers=8,
        pred_embedding=["cell_type_ontology_term_id"],
        doplot=False,
    )
    _, metrics[n] = embed(model, adata_subset.copy())
    
metrics


Epoch 1/1
Current learning rate: 2.00e-05


  with torch.cuda.amp.autocast():
Training:  27%|██▋       | 2401/8897 [25:29<1:08:57,  1.57it/s]
  organismdf = pd.concat(organismdf)


not on wandb, could not set name


100%|██████████| 79/79 [00:11<00:00,  6.60it/s]


logging the anndata
AnnData object with n_obs × n_vars = 5000 × 17917
    obs: 'pred_cell_type_ontology_term_id', 'pred_tissue_ontology_term_id', 'pred_disease_ontology_term_id', 'pred_age_group', 'pred_assay_ontology_term_id', 'pred_self_reported_ethnicity_ontology_term_id', 'pred_sex_ontology_term_id', 'pred_organism_ontology_term_id', 'pred_cell_culture', 'conv_pred_cell_type_ontology_term_id', 'conv_pred_tissue_ontology_term_id', 'conv_pred_disease_ontology_term_id', 'conv_pred_age_group', 'conv_pred_assay_ontology_term_id', 'conv_pred_self_reported_ethnicity_ontology_term_id'
    obsm: 'scprint_emb_cell_type_ontology_term_id'
    layers: 'scprint_mu', 'scprint_theta', 'scprint_pi'
too few cells to embed into a umap
too few cells to compute a clustering


  organismdf = pd.concat(organismdf)


not on wandb, could not set name


100%|██████████| 79/79 [00:37<00:00,  2.08it/s]


logging the anndata
AnnData object with n_obs × n_vars = 5000 × 18604
    obs: 'pred_cell_type_ontology_term_id', 'pred_tissue_ontology_term_id', 'pred_disease_ontology_term_id', 'pred_age_group', 'pred_assay_ontology_term_id', 'pred_self_reported_ethnicity_ontology_term_id', 'pred_sex_ontology_term_id', 'pred_organism_ontology_term_id', 'pred_cell_culture', 'conv_pred_cell_type_ontology_term_id', 'conv_pred_tissue_ontology_term_id', 'conv_pred_disease_ontology_term_id', 'conv_pred_age_group', 'conv_pred_assay_ontology_term_id', 'conv_pred_self_reported_ethnicity_ontology_term_id'
    obsm: 'scprint_emb_cell_type_ontology_term_id'
    layers: 'scprint_mu', 'scprint_theta', 'scprint_pi'
too few cells to embed into a umap
too few cells to compute a clustering


  organismdf = pd.concat(organismdf)


not on wandb, could not set name


100%|██████████| 79/79 [02:23<00:00,  1.81s/it]


logging the anndata
AnnData object with n_obs × n_vars = 5000 × 20004
    obs: 'pred_cell_type_ontology_term_id', 'pred_tissue_ontology_term_id', 'pred_disease_ontology_term_id', 'pred_age_group', 'pred_assay_ontology_term_id', 'pred_self_reported_ethnicity_ontology_term_id', 'pred_sex_ontology_term_id', 'pred_organism_ontology_term_id', 'pred_cell_culture', 'conv_pred_cell_type_ontology_term_id', 'conv_pred_tissue_ontology_term_id', 'conv_pred_disease_ontology_term_id', 'conv_pred_age_group', 'conv_pred_assay_ontology_term_id', 'conv_pred_self_reported_ethnicity_ontology_term_id'
    obsm: 'scprint_emb_cell_type_ontology_term_id'
    layers: 'scprint_mu', 'scprint_theta', 'scprint_pi'
too few cells to embed into a umap
too few cells to compute a clustering


{400: {'cell_type_ontology_term_id_accuracy': 0.36425075739920765,
  'tissue_ontology_term_id_accuracy': 0.0604,
  'disease_ontology_term_id_accuracy': 0.933,
  'assay_ontology_term_id_accuracy': 0.037,
  'self_reported_ethnicity_ontology_term_id_accuracy': 0.6556,
  'sex_ontology_term_id_accuracy': 0.5738,
  'organism_ontology_term_id_accuracy': 1.0},
 2000: {'cell_type_ontology_term_id_accuracy': 0.39084753677328976,
  'tissue_ontology_term_id_accuracy': 0.065,
  'disease_ontology_term_id_accuracy': 0.9432,
  'assay_ontology_term_id_accuracy': 0.04,
  'self_reported_ethnicity_ontology_term_id_accuracy': 0.6598,
  'sex_ontology_term_id_accuracy': 0.825,
  'organism_ontology_term_id_accuracy': 1.0},
 6000: {'cell_type_ontology_term_id_accuracy': 0.4239738805970149,
  'tissue_ontology_term_id_accuracy': 0.0696,
  'disease_ontology_term_id_accuracy': 0.9924,
  'assay_ontology_term_id_accuracy': 0.0038,
  'self_reported_ethnicity_ontology_term_id_accuracy': 0.6494,
  'sex_ontology_term_id

In [13]:
metrics = {}
for n in [400, 2000, 8000]:
    ad = adata[(adata.X > 0).sum(1) >= n]
    # Get 5,000 random cells
    n_cells = min(5000, ad.n_obs)  # In case you have fewer than 5,000 cells
    random_indices = np.random.choice(ad.n_obs, n_cells, replace=False)
    adata_subset = ad[random_indices, :].copy()
    embed = Embedder(
        how="random expr",
        max_len=n,
        num_workers=8,
        pred_embedding=["cell_type_ontology_term_id"],
        doplot=False,
    )
    _, metrics[n] = embed(model, adata_subset.copy())
    
metrics

  organismdf = pd.concat(organismdf)


not on wandb, could not set name


100%|██████████| 79/79 [00:08<00:00,  9.19it/s]


logging the anndata
AnnData object with n_obs × n_vars = 5000 × 17889
    obs: 'pred_cell_type_ontology_term_id', 'pred_tissue_ontology_term_id', 'pred_disease_ontology_term_id', 'pred_age_group', 'pred_assay_ontology_term_id', 'pred_self_reported_ethnicity_ontology_term_id', 'pred_sex_ontology_term_id', 'pred_organism_ontology_term_id', 'pred_cell_culture', 'conv_pred_cell_type_ontology_term_id', 'conv_pred_tissue_ontology_term_id', 'conv_pred_disease_ontology_term_id', 'conv_pred_age_group', 'conv_pred_assay_ontology_term_id', 'conv_pred_self_reported_ethnicity_ontology_term_id'
    obsm: 'scprint_emb_cell_type_ontology_term_id'
    layers: 'scprint_mu', 'scprint_theta', 'scprint_pi'
too few cells to embed into a umap
too few cells to compute a clustering


  organismdf = pd.concat(organismdf)


not on wandb, could not set name


100%|██████████| 79/79 [00:26<00:00,  2.99it/s]


logging the anndata
AnnData object with n_obs × n_vars = 5000 × 18581
    obs: 'pred_cell_type_ontology_term_id', 'pred_tissue_ontology_term_id', 'pred_disease_ontology_term_id', 'pred_age_group', 'pred_assay_ontology_term_id', 'pred_self_reported_ethnicity_ontology_term_id', 'pred_sex_ontology_term_id', 'pred_organism_ontology_term_id', 'pred_cell_culture', 'conv_pred_cell_type_ontology_term_id', 'conv_pred_tissue_ontology_term_id', 'conv_pred_disease_ontology_term_id', 'conv_pred_age_group', 'conv_pred_assay_ontology_term_id', 'conv_pred_self_reported_ethnicity_ontology_term_id'
    obsm: 'scprint_emb_cell_type_ontology_term_id'
    layers: 'scprint_mu', 'scprint_theta', 'scprint_pi'
too few cells to embed into a umap
too few cells to compute a clustering


  organismdf = pd.concat(organismdf)


not on wandb, could not set name


100%|██████████| 79/79 [02:09<00:00,  1.64s/it]


logging the anndata
AnnData object with n_obs × n_vars = 5000 × 20004
    obs: 'pred_cell_type_ontology_term_id', 'pred_tissue_ontology_term_id', 'pred_disease_ontology_term_id', 'pred_age_group', 'pred_assay_ontology_term_id', 'pred_self_reported_ethnicity_ontology_term_id', 'pred_sex_ontology_term_id', 'pred_organism_ontology_term_id', 'pred_cell_culture', 'conv_pred_cell_type_ontology_term_id', 'conv_pred_tissue_ontology_term_id', 'conv_pred_disease_ontology_term_id', 'conv_pred_age_group', 'conv_pred_assay_ontology_term_id', 'conv_pred_self_reported_ethnicity_ontology_term_id'
    obsm: 'scprint_emb_cell_type_ontology_term_id'
    layers: 'scprint_mu', 'scprint_theta', 'scprint_pi'
too few cells to embed into a umap
too few cells to compute a clustering


{400: {'cell_type_ontology_term_id_accuracy': 0.6121981681931724,
  'tissue_ontology_term_id_accuracy': 0.2196,
  'disease_ontology_term_id_accuracy': 0.991,
  'assay_ontology_term_id_accuracy': 0.0278,
  'self_reported_ethnicity_ontology_term_id_accuracy': 0.6266,
  'sex_ontology_term_id_accuracy': 0.585,
  'organism_ontology_term_id_accuracy': 1.0},
 2000: {'cell_type_ontology_term_id_accuracy': 0.650836820083682,
  'tissue_ontology_term_id_accuracy': 0.23,
  'disease_ontology_term_id_accuracy': 0.9982,
  'assay_ontology_term_id_accuracy': 0.03,
  'self_reported_ethnicity_ontology_term_id_accuracy': 0.626,
  'sex_ontology_term_id_accuracy': 0.7358,
  'organism_ontology_term_id_accuracy': 1.0},
 8000: {'cell_type_ontology_term_id_accuracy': 0.7434017595307918,
  'tissue_ontology_term_id_accuracy': 0.4376,
  'disease_ontology_term_id_accuracy': 1.0,
  'assay_ontology_term_id_accuracy': 0.0002,
  'self_reported_ethnicity_ontology_term_id_accuracy': 0.5972,
  'sex_ontology_term_id_accura

 8000: {'cell_type_ontology_term_id_accuracy': 0.7407175636211931,
  'tissue_ontology_term_id_accuracy': 0.1876,
  'disease_ontology_term_id_accuracy': 0.9948,
  'assay_ontology_term_id_accuracy': 0.897, #0.9274
  'self_reported_ethnicity_ontology_term_id_accuracy': 0.599, #0.6376
  'sex_ontology_term_id_accuracy': 0.9828,
  'organism_ontology_term_id_accuracy': 1.0}