In [1]:
# Point lamin at the `jkobject/scprint2` lamindb instance before anything imports lamindb
# (the recorded output reports "connected lamindb: jkobject/scprint2").
! lamin load jkobject/scprint2

[92m→[0m connected lamindb: jkobject/scprint2


In [2]:
# Imports and global run configuration for this training notebook.
from lightning.pytorch import Trainer, seed_everything
# NOTE(review): of these callbacks only ModelCheckpoint is referenced in the visible
# cells; StochasticWeightAveraging, EarlyStopping, LearningRateMonitor and
# LearningRateFinder appear unused.
from lightning.pytorch.callbacks import ModelCheckpoint, StochasticWeightAveraging, EarlyStopping, LearningRateMonitor, LearningRateFinder

# Global seed via Lightning (output below: "Seed set to 42"); workers=True asks
# Lightning to also derive seeds for DataLoader worker processes.
seed_everything(42, workers=True)


from scprint import scPrint
from scprint.trainer import TrainingMode
from scdataloader import DataModule 
# NOTE(review): pd, load_genes and ln are not used in the visible cells.
import pandas as pd
from scdataloader.utils import load_genes
import lamindb as ln

import torch
# Let PyTorch use faster, lower-precision kernels for float32 matmuls ('medium').
torch.set_float32_matmul_precision('medium')

# Reload edited modules (e.g. local scprint / scdataloader sources) before each cell runs.
%load_ext autoreload
%autoreload 2

Seed set to 42


In [3]:
# --- Label / data configuration ------------------------------------------------
# TODO: drop tissue & dev stage until `part_of` relations are taken into account
# (presumably the ontology part_of hierarchy). Both are still listed below.

# Label columns handed to the DataModule as `hierarchical_clss`.
hierarchical_clss = [
    "cell_type_ontology_term_id",
    "tissue_ontology_term_id",
    "disease_ontology_term_id",
    "simplified_dev_stage",
    "assay_ontology_term_id",
    "self_reported_ethnicity_ontology_term_id",
]

# All label columns to predict: the hierarchical ones first, then labels that are
# not in `hierarchical_clss`.
clss_to_predict = [
    *hierarchical_clss,
    "sex_ontology_term_id",
    "organism_ontology_term_id",
    "cell_culture",
]

# Columns handed to the DataModule as `clss_to_weight` (presumably used to weight
# sampling together with `weight_scaler`).
# Other candidate columns, currently disabled: dataset_id, cell_culture,
# heat_diff, total_counts, dpt_group.
clss_to_weight = [
    "tissue_ontology_term_id",
    "disease_ontology_term_id",
    "simplified_dev_stage",
    "assay_ontology_term_id",
    "organism_ontology_term_id",
    "clust_cell_type",
    "nnz",
]

# Precomputed gene-embedding table, passed to both the DataModule and the model.
gene_emb = "../data/main/gene_embeddings.parquet"
# Base width: the model cell passes `d_model*2` as its d_model and `d_model` as the
# classifier layer width.
d_model = 128

In [19]:
# Build the scdataloader DataModule over the "scPRINT-V2 test" collection, then
# call `setup()` and keep its return value as `testfiles`.
# The model cell reads `datamodule.genes`, `.classes`, `.labels_hierarchy`,
# `.decoders`, `.gene_pos` and `.num_datasets`, so this cell must run before it.
datamodule = DataModule(
    # data source
    collection_name="scPRINT-V2 test",
    organisms=["NCBITaxon:9606"],  # human only; mouse ("NCBITaxon:10090") disabled
    gene_embeddings=gene_emb,
    do_gene_pos="../data/main/biomart_pos.parquet",
    # labels
    clss_to_predict=clss_to_predict,
    hierarchical_clss=hierarchical_clss,
    clss_to_weight=clss_to_weight,
    # how much more you will see the most present vs less present category
    weight_scaler=100,
    # sampling / sequence options
    how="random expr",
    max_len=3200,
    add_zero_genes=0,
    metacell_mode=0.2,
    # splits
    train_oversampling_per_epoch=2,
    validation_split=0.05,
    test_split=0.05,
    # loader
    batch_size=8,
    num_workers=20,
    prefetch_factor=3,
    pin_memory=True,
)
testfiles = datamodule.setup()

[93m![0m no run & transform got linked, call `ln.track()` & re-run
[93m![0m run input wasn't tracked, call `ln.track()` and re-run
[93m![0m run input wasn't tracked, call `ln.track()` and re-run
won't do any check but we recommend to have your dataset coming from local storage
100.0% are aligned
seeing a string: loading gene positions as biomart parquet file


In [17]:
# scPrint model sized from `d_model` (config cell) and wired to the gene / label
# metadata exposed by `datamodule`; run the DataModule cell first.
# Previously tried and left out: weight_decay=0.01, zinb=False.
model = scPrint(
    genes=datamodule.genes,
    d_model=2 * d_model,
    nhead=4,
    num_heads_kv=2,
    nlayers=8,
    layers_cls=[d_model],
    classes=datamodule.classes,
    labels_hierarchy=datamodule.labels_hierarchy,
    dropout=0.1,
    transformer="flash",
    precpt_gene_emb=gene_emb,
    gene_pos_enc=datamodule.gene_pos,
    mvc_decoder="inner product",
    label_decoders=datamodule.decoders,
    num_batch_labels=datamodule.num_datasets,
    checkpointing=True,
    prenorm=True,
    cell_specific_blocks=True,
)

10


In [None]:
# Weights & Biases logger for the `scprint_test` project, writing local files under
# ../data/tensorboard.
# NOTE(review): the Trainer cell keeps `logger=wandb_logger` commented out, so
# Trainer metrics are not sent to this run unless that line is re-enabled.
from lightning.pytorch.loggers import WandbLogger

wandb_logger = WandbLogger(
    project="scprint_test",
    save_dir="../data/tensorboard",
)
# Watch `model` with log="all" every 50 steps and also log its graph.
wandb_logger.watch(model, log="all", log_freq=50, log_graph=True)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mjkobject[0m ([33mml4ig[0m). Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`


In [16]:
# Checkpointing on `val_loss`, keeping every checkpoint (save_top_k=-1).
chckp = ModelCheckpoint(monitor="val_loss", save_top_k=-1)

# Training objectives / schedule for the main run.
trainingmode = TrainingMode(
    do_denoise=True,
    noise=[0.6],
    do_mvc=False,
    do_adv_cls=False,
    run_full_forward=True,
    mask_ratio=["TF"],
    warmup_duration=100,
    fused_adam=False,
    lr_reduce_patience=200,
)

trainer = Trainer(
    precision="16-mixed",
    gradient_clip_val=30,
    gradient_clip_algorithm="norm",
    max_time={"hours": 2},
    limit_val_batches=1,
    # `chckp` used to be built but never registered, so Lightning fell back to its
    # default checkpoint callback and the val_loss / keep-all settings were ignored.
    callbacks=[trainingmode, chckp],
    accumulate_grad_batches=1,
    check_val_every_n_epoch=1,
    reload_dataloaders_every_n_epochs=1_000_000,  # effectively never reload
    # logger=wandb_logger,  # re-enable to send Trainer metrics to the W&B run
)

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
`Trainer(limit_val_batches=1)` was configured so 1 batch will be used.


In [12]:
# Collect unreachable Python objects, then ask PyTorch to release its unused
# cached CUDA memory. Objects still referenced by notebook variables are kept.
import gc

import torch

gc.collect()
torch.cuda.empty_cache()


In [18]:
# Train `model` on `datamodule` with the trainer configured above (bounded by its
# `max_time`). The recorded run was stopped manually (KeyboardInterrupt), after which
# Lightning's shutdown raised the NameError shown in the output.
trainer.fit(model=model, datamodule=datamodule)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

   | Name                            | Type                         | Params | Mode 
------------------------------------------------------------------------------------------
0  | gene_encoder                    | GeneEncoder                  | 5.9 M  | train
1  | expr_encoder                    | ContinuousValueEncoder       | 66.8 K | train
2  | pos_encoder                     | PositionalEncoding           | 0      | train
3  | class_encoder                   | CategoryValueEncoder         | 2.6 K  | train
4  | depth_encoder                   | ContinuousValueEncoder       | 66.8 K | train
5  | transformer                     | FlashTransformer             | 11.6 M | train
6  | cell_transformer                | FlashTransformer             | 8.7 M  | train
7  | expr_decoder                    | ExprDecoder                  | 133 K  | train
8  | cls_decoders                    | ModuleDict                   | 315 K  | train
9  | grad_revers

on_fit_start


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

new batch
full forward
encoder
transformer
cell transformer
decoder
masking
encoder
transformer
cell transformer
decoder
compute loss
denoising
encoder
transformer
cell transformer
decoder
compute loss
generate
encoder
decoder
backward
encoder
transformer
cell transformer
decoder


Training: |          | 0/? [00:00<?, ?it/s]

new batch
full forward
encoder
transformer
cell transformer
decoder
masking
encoder
transformer
cell transformer
decoder
compute loss
denoising
encoder
transformer
cell transformer
decoder
compute loss
generate
encoder
decoder
backward
backward end
new batch
full forward
encoder
transformer
cell transformer
decoder
masking
encoder
transformer
cell transformer
decoder
compute loss
denoising
encoder
transformer
cell transformer
decoder
compute loss
generate
encoder
decoder
backward
backward end
new batch
full forward
encoder
transformer
cell transformer
decoder
masking
encoder
transformer
cell transformer
decoder
compute loss
denoising
encoder
transformer
cell transformer
decoder
compute loss
generate
encoder
decoder
backward
backward end
new batch
full forward
encoder
transformer
cell transformer
decoder
masking
encoder
transformer
cell transformer
decoder
compute loss
denoising
encoder
transformer
cell transformer
decoder
compute loss
generate
encoder
decoder
backward
backward end
new 


Detected KeyboardInterrupt, attempting graceful shutdown ...


backward end
new batch
full forward
encoder
transformer
cell transformer
decoder


NameError: name 'exit' is not defined

In [14]:
# Overfitting sanity check on a single batch (`overfit_batches=1`), with do_cce and
# do_cls switched on in the TrainingMode.
# This rebinds the name `trainingmode`; the `trainer` built earlier keeps the
# instance it was given.
# NOTE(review): reuses the existing `model` object, which may already have been
# trained by `trainer.fit` - rebuild it for a clean overfit test.
trainingmode = TrainingMode(
    do_denoise=True,
    noise=[0.6],
    do_cce=True,
    cce_sim=0.6,
    do_ecs=False,
    ecs_threshold=0.4,
    ecs_scale=0.05,
    class_scale=0.08,
    do_cls=True,
    do_mvc=False,
    do_adv_cls=False,
    run_full_forward=True,
    do_next_tp=False,
    mask_ratio=["TF"],
    fused_adam=False,
    lr_reduce_monitor=None,
)
overfit_trainer = Trainer(
    precision="16-mixed",
    gradient_clip_val=10,
    max_time={"hours": 2},
    limit_val_batches=1,
    callbacks=[trainingmode],
    accumulate_grad_batches=1,
    check_val_every_n_epoch=10_000,  # regular validation effectively off
    overfit_batches=1,
    reload_dataloaders_every_n_epochs=1_000_000,
    num_sanity_val_steps=2,
)

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
`Trainer(overfit_batches=1)` was configured so 1 batch will be used.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
`Trainer(overfit_batches=1)` was configured so 1 batch will be used.


In [7]:
# Run the single-batch overfit check. Requires the cell that defines
# `overfit_trainer`; the recorded NameError came from running this cell (In [7])
# before that one (In [14]).
overfit_trainer.fit(model=model, datamodule=datamodule)

NameError: name 'overfit_trainer' is not defined