In [1]:
from lightning.pytorch import Trainer, seed_everything
from lightning.pytorch.callbacks import ModelCheckpoint, StochasticWeightAveraging, EarlyStopping, LearningRateMonitor, LearningRateFinder

seed_everything(42, workers=True)

from scprint import scPrint
from scprint.trainer import TrainingMode
from scdataloader import DataModule 
import pandas as pd
from scdataloader.utils import load_genes
import lamindb as ln

import torch
torch.set_float32_matmul_precision('medium')

%load_ext autoreload
%autoreload 2

Seed set to 42


[92m→[0m connected lamindb: jkobject/scprint


In [2]:
# Label columns and shared model/data settings used by the cells below.
#
# TODO: tissue & dev stage stay dropped until "part of" relations are taken
# into account (the original note read "part or" -- confirm the wording).

# Labels passed as `hierarchical_clss` to the data module.
# Disabled candidates: "tissue_ontology_term_id", "simplified_dev_stage".
hierarchical_clss = [
    "cell_type_ontology_term_id",                # 1
    "disease_ontology_term_id",                  # 2
    "assay_ontology_term_id",                    # 3
    "self_reported_ethnicity_ontology_term_id",  # 4
]

# Every label passed as `clss_to_predict`: the hierarchical ones plus two more.
# Disabled candidate: "cell_culture".
clss_to_predict = [
    *hierarchical_clss,
    "sex_ontology_term_id",       # 5
    "organism_ontology_term_id",  # 6
]

# Labels passed as `clss_to_weight`. Currently an independent copy of
# `clss_to_predict`; extra columns that were tried and are disabled:
# "tissue_ontology_term_id", "disease_ontology_term_id",
# "simplified_dev_stage", "assay_ontology_term_id",
# "organism_ontology_term_id", "clust_cell_type", "dataset_id",
# "cell_culture", "heat_diff", "total_counts", "nnz", "dpt_group".
clss_to_weight = list(clss_to_predict)

# Path to the gene-embedding parquet file, given to both the data module and
# the model in later cells.
gene_emb = "../data/main/gene_embeddings.parquet"

# Base width; the model cell derives its sizes from this value.
d_model = 128

In [5]:

# Shell call to the lamin CLI selecting the `jkobject/scprint` instance; the
# saved output reports it as connected.
# NOTE(review): the import cell's output already shows the same connection
# message, so this cell may be redundant -- confirm before removing.
! lamin load jkobject/scprint

[92m→[0m connected lamindb: jkobject/scprint


In [3]:

# Display the lamindb Collection records as a table. In the saved output the
# available collection names are "all", "some" and "test dataset"; one of
# these names is what the DataModule's `collection_name` refers to.
ln.Collection.filter().df()

Unnamed: 0_level_0,uid,version,is_latest,name,description,hash,reference,reference_type,visibility,transform_id,meta_artifact_id,run_id,created_at,created_by_id
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1
13,4PJVEHXDgr489o5afYxr,,True,all,all files that I could keep,fYW4GdioNfHJg10RSAd6,,,1,,,,2024-03-27 09:57:23.613992+00:00,1
14,jpNqg9b6M1R62CilaabE,,True,some,,60iAa88eD-1yKrLRTdaQ,,,1,,,,2024-03-30 11:18:42.369420+00:00,1
16,PGeDYjindWhOFQ5c0000,,True,test dataset,,pduTXjpr-nsuJ8BPE_K-fA,,,1,,,,2024-11-18 10:39:34.089779+00:00,1


In [7]:
# Build the scdataloader DataModule from the label lists and paths defined in
# the configuration cell, then run its setup step.
# NOTE(review): the saved output of this cell reports "0.0% are aligned";
# worth checking against the gene-embedding / gene-position files.
datamodule = DataModule(
    # --- data source ---
    # names mentioned as alternatives: some, all, preprocessed dataset, all no zhang
    collection_name="some",
    organisms=["NCBITaxon:9606", "NCBITaxon:10090"],  # human, mouse
    gene_embeddings=gene_emb,
    do_gene_pos="../data/main/biomart_pos.parquet",
    # --- labels ---
    clss_to_predict=clss_to_predict,
    hierarchical_clss=hierarchical_clss,
    clss_to_weight=clss_to_weight,
    # how much more you will see the most present vs less present category
    weight_scaler=100,
    # --- per-cell expression settings (semantics defined by scdataloader) ---
    how="most expr",
    max_len=1200,
    add_zero_genes=0,
    metacell_mode=False,
    # --- loading and splits ---
    batch_size=10,
    num_workers=12,
    validation_split=0.05,
    test_split=0.05,
    # disabled option: train_oversampling=2
)

# Keep whatever `setup()` returns under the name the notebook already uses.
testfiles = datamodule.setup()

[93m![0m no run & transform got linked, call `ln.track()` & re-run
[93m![0m run input wasn't tracked, call `ln.track()` and re-run
[93m![0m run input wasn't tracked, call `ln.track()` and re-run
won't do any check but we recommend to have your dataset coming from local storage
0.0% are aligned
seeing a string: loading gene positions as biomart parquet file


In [8]:
# Instantiate the scPrint model. Everything describing the data (genes,
# classes, hierarchy, decoders, dataset count, gene positions) is read from
# the DataModule built above; sizes are derived from `d_model`.
model = scPrint(
    # --- inputs taken from the data module ---
    genes=datamodule.genes,
    classes=datamodule.classes,
    labels_hierarchy=datamodule.labels_hierarchy,
    label_decoders=datamodule.decoders,
    num_batch_labels=datamodule.num_datasets,
    gene_pos_enc=datamodule.gene_pos,
    precpt_gene_emb=gene_emb,
    # --- architecture ---
    d_model=2 * d_model,
    nhead=4,
    nlayers=8,
    layers_cls=[d_model],
    transformer="flash",
    mvc_decoder="inner product",
    prenorm=True,
    # --- regularisation / execution flags ---
    dropout=0,
    fused_dropout_add_ln=False,
    checkpointing=False,
    # disabled options: num_heads_kv=2, weight_decay=0.01, zinb=False
)

In [None]:
# Weights & Biases logging for the model.
# A TensorBoard alternative was tried and is disabled:
#   from lightning.pytorch.loggers import TensorBoardLogger
#   tlogger = TensorBoardLogger(save_dir="../data/tensorboard")
#   tlogger.log_graph(model)
from lightning.pytorch.loggers import WandbLogger

wandb_logger = WandbLogger(
    project="scprint_test",
    save_dir="../data/tensorboard",
)

# Ask the logger to watch the model, including the graph, every 50 steps.
wandb_logger.watch(
    model,
    log="all",
    log_freq=50,
    log_graph=True,
)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mjkobject[0m ([33mml4ig[0m). Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`


In [9]:
# Checkpoint callback configuration.
# NOTE(review): `chckp` is not included in the Trainer's `callbacks` below, so
# this configuration is not applied by this cell -- confirm whether intended.
chckp = ModelCheckpoint(monitor="val_loss", save_top_k=-1)

# scPrint training-task configuration, handed to the Trainer as a callback.
trainingmode = TrainingMode(
    # enabled task
    do_denoise=True,
    noise=[0.7],
    # tasks switched off for this run (their parameters are still passed)
    do_cce=False,
    cce_sim=0.6,
    do_ecs=False,
    ecs_threshold=0.4,
    ecs_scale=0.05,
    do_cls=False,
    class_scale=0.08,
    do_mvc=False,
    do_adv_cls=False,
    do_next_tp=False,
    mask_ratio=[],
    # optimisation schedule
    warmup_duration=100,
    lr_reduce_patience=200,
    fused_adam=True,
)

# Lightning Trainer: mixed precision, 2 h time budget, one validation batch.
trainer = Trainer(
    precision="16-mixed",
    gradient_clip_val=500,
    accumulate_grad_batches=1,
    max_time={"hours": 2},
    limit_val_batches=1,
    check_val_every_n_epoch=1,
    reload_dataloaders_every_n_epochs=1000000,
    callbacks=[trainingmode],
    # disabled option: logger=wandb_logger
)

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
`Trainer(limit_val_batches=1)` was configured so 1 batch will be used.


In [10]:
# Start training with the model and data module built above.
# NOTE(review): in the saved output this call fails during "Sanity Checking"
# with an IndexError raised in scdataloader's collator
# (`expr[self.accepted_genes[organism_id]]`): the boolean mask has length
# 70786 while the indexed array has length 70263. The recorded run therefore
# stopped at the sanity check.
trainer.fit(model, datamodule=datamodule)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name                            | Type                         | Params | Mode 
-----------------------------------------------------------------------------------------
0 | gene_encoder                    | GeneEncoder                  | 11.5 M | train
1 | expr_encoder                    | ContinuousValueEncoder       | 66.8 K | train
2 | pos_encoder                     | PositionalEncoding           | 0      | train
3 | class_encoder                   | CategoryValueEncoder         | 1.8 K  | train
4 | depth_encoder                   | ContinuousValueEncoder       | 66.8 K | train
5 | transformer                     | FlashTransformer             | 6.3 M  | train
6 | expr_decoder                    | ExprDecoder                  | 133 K  | train
7 | cls_decoders                    | ModuleDict                   | 234 K  | train
8 | grad_reverse_discriminator_loss | AdversarialDiscriminatorLoss | 155 K  | train
9 | mvc_decoder            

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

IndexError: Caught IndexError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/ml4ig1/miniconda3/envs/scprint/lib/python3.10/site-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/ml4ig1/miniconda3/envs/scprint/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 54, in fetch
    return self.collate_fn(data)
  File "/home/ml4ig1/Documents code/scDataLoader/scdataloader/collator.py", line 144, in __call__
    expr = expr[self.accepted_genes[organism_id]]
IndexError: boolean index did not match indexed array along dimension 0; dimension is 70263 but corresponding boolean dimension is 70786
