In [None]:
# Shell escape: run the lamin CLI to load the `scprint` instance before lamindb is imported below.
! lamin load scprint

In [1]:
import lamindb as ln
import lnschema_bionty as lb

import pandas as pd
import scanpy as sc 

from lightning.pytorch import Trainer, seed_everything
# Seed once, up front, with seed 42 (the log output below confirms "Seed set to 42").
seed_everything(42, workers=True)

from scprint import scPrint
from scprint.utils import getBiomartTable

from scdataloader import Dataset
from scdataloader import DataModule
from scprint.dataloader import embed
from scdataloader.utils import load_genes
from scprint.dataloader import Collator

import torch 
# Set PyTorch's float32 matmul precision setting to 'medium'.
torch.set_float32_matmul_precision('medium')

# Use "human" as the organism in the bionty schema settings.
lb.settings.organism = "human"

# IPython autoreload extension in mode 2, so edits to imported modules are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

💡 lamindb instance: jkobject/scprint


INFO: Seed set to 42
2024-01-23 14:01:57,152:INFO - Seed set to 42
2024-01-23 14:01:57,152:INFO - Seed set to 42


In [2]:
## Gene embeddings
# The embeddings are read from a parquet cache. To regenerate that cache, run:
# embeddings = embed(genedf=genedf,
#     organism="homo_sapiens",
#     cache=True,
#     fasta_path="/tmp/data/fasta/",
#     embedding_size=1024,)
# embeddings.to_parquet('../data/temp/embeddings.parquet')
embeddings = pd.read_parquet('../data/temp/embeddings.parquet')
embeddings.columns = ['emb_'+str(i) for i in embeddings.columns]

## Gene annotations
# One row per ensembl gene id (first occurrence kept), sorted by chromosome then start position.
biomart = getBiomartTable(attributes=['start_position', 'chromosome_name']).set_index('ensembl_gene_id')
biomart = biomart.loc[~biomart.index.duplicated(keep='first')]
biomart = biomart.sort_values(by=['chromosome_name', 'start_position'])

## Gene location groups
# Walking the genes in sorted order, a new group starts whenever the chromosome changes or the
# start position lies more than MAX_GENE_GAP after the previous gene's start. Group ids start at 1.
# This is the vectorized equivalent of a row-by-row `iterrows()` pass: `shift()` exposes the
# previous row's chromosome and `diff()` the distance to the previous row's start.
MAX_GENE_GAP = 10_000
chromosome = biomart['chromosome_name']
starts_new_group = (chromosome != chromosome.shift()) | (biomart['start_position'].diff() > MAX_GENE_GAP)
# Every True opens exactly one new group, so the number of groups is the number of True values.
n_groups = int(starts_new_group.sum())
print(f'reduced the size to {n_groups/len(biomart)}')
# `.to_numpy()` makes the assignment positional rather than index-aligned.
biomart['pos'] = starts_new_group.cumsum().to_numpy()

downloading gene names from biomart
['ensembl_gene_id', 'hgnc_symbol', 'gene_biotype', 'entrezgene_id', 'start_position', 'chromosome_name']

['ensembl_gene_id', 'hgnc_symbol', 'gene_biotype', 'entrezgene_id', 'start_position', 'chromosome_name']
reduced the size to 0.6722574020195106


In [3]:
# OR directly load the dataset
name="preprocessed dataset"
dataset = ln.Collection.filter(name=name).first()
# `.first()` gives None when nothing matches; fail here with an explicit message instead of
# an AttributeError on `dataset.artifacts` below.
if dataset is None:
    raise ValueError(f"no lamindb Collection named {name!r} was found in the loaded instance")
# Number of artifacts in the collection (displayed as the cell output).
dataset.artifacts.count()

[autoreload of scprint.model.model failed: Traceback (most recent call last):
  File "/home/ml4ig1/miniconda3/envs/scprint/lib/python3.10/site-packages/IPython/extensions/autoreload.py", line 276, in check
    superreload(m, reload, self.old_objects)
  File "/home/ml4ig1/miniconda3/envs/scprint/lib/python3.10/site-packages/IPython/extensions/autoreload.py", line 475, in superreload
    module = reload(module)
  File "/home/ml4ig1/miniconda3/envs/scprint/lib/python3.10/importlib/__init__.py", line 169, in reload
    _bootstrap._exec(spec, module)
  File "<frozen importlib._bootstrap>", line 619, in _exec
  File "<frozen importlib._bootstrap_external>", line 879, in exec_module
  File "<frozen importlib._bootstrap_external>", line 1017, in get_code
  File "<frozen importlib._bootstrap_external>", line 947, in source_to_code
  File "<frozen importlib._bootstrap>", line 241, in _call_with_frames_removed
  File "/home/ml4ig1/Documents code/scPRINT/scprint/model/model.py", line 593
    drop 

91

In [4]:
# TODO: tissue & dev stage stay disabled until "part of" relations are taken into account
# (the original note read "part or" -- confirm the intended meaning).

# Labels handed to Dataset as `hierarchical_clss`.
hierarchical_labels = [
    "cell_type_ontology_term_id",
    # "tissue_ontology_term_id",
    "disease_ontology_term_id",
    # "development_stage_ontology_term_id",
    "assay_ontology_term_id",
    "self_reported_ethnicity_ontology_term_id",
]

# Labels to predict / use for weighted sampling: the hierarchical ones plus sex and organism.
labels_weighted_sampling = [
    *hierarchical_labels,
    "sex_ontology_term_id",
    "organism_ontology_term_id",
]

# Every obs column requested from the dataset: the labels above plus extra per-cell columns.
all_labels = [
    *labels_weighted_sampling,
    # "dataset_id",
    # "cell_culture",
    "heat_diff",
    "total_counts",
    "nnz",
    "dpt_group",
]

In [5]:
# Wrap the loaded collection in a scdataloader Dataset for organism NCBITaxon:9606,
# then print its summary.
mdataset = Dataset(
    dataset,
    organisms=["NCBITaxon:9606"],
    obs=all_labels,
    clss_to_pred=labels_weighted_sampling,
    hierarchical_clss=hierarchical_labels,
)
print(mdataset)

won't do any check but we recommend to have your dataset coming from local storage

82.41758241758242% are aligned
total dataset size is 106.584138411 Gb
---
dataset contains:
     5567614 cells
     70116 genes
     10 labels
     1 organisms
dataset contains 232 classes to predict



In [6]:
# we might not want to order the genes by expression (or should we?)
# we might also want to not introduce zeros (TODO: this note was left unfinished)

[autoreload of scprint.model.model failed: Traceback (most recent call last):
  File "/home/ml4ig1/miniconda3/envs/scprint/lib/python3.10/site-packages/IPython/extensions/autoreload.py", line 276, in check
    superreload(m, reload, self.old_objects)
  File "/home/ml4ig1/miniconda3/envs/scprint/lib/python3.10/site-packages/IPython/extensions/autoreload.py", line 475, in superreload
    module = reload(module)
  File "/home/ml4ig1/miniconda3/envs/scprint/lib/python3.10/importlib/__init__.py", line 169, in reload
    _bootstrap._exec(spec, module)
  File "<frozen importlib._bootstrap>", line 619, in _exec
  File "<frozen importlib._bootstrap_external>", line 879, in exec_module
  File "<frozen importlib._bootstrap_external>", line 1017, in get_code
  File "<frozen importlib._bootstrap_external>", line 947, in source_to_code
  File "<frozen importlib._bootstrap>", line 241, in _call_with_frames_removed
  File "/home/ml4ig1/Documents code/scPRINT/scprint/model/model.py", line 593
    drop 

In [7]:
# Collator used as the DataModule's `collate_fn`. Only NCBITaxon:9606 is used, so `org_to_id`
# carries just that organism's encoded id rather than the whole
# mdataset.encoder['organism_ontology_term_id'] mapping.
col = Collator(
    organisms=["NCBITaxon:9606"],
    labels=all_labels,
    genelist=embeddings.index.tolist(),
    max_len=2000,
    add_zero_genes=200,
    org_to_id={
        "NCBITaxon:9606": mdataset.encoder["organism_ontology_term_id"]["NCBITaxon:9606"],
    },
)

In [8]:
# Build the DataModule around the dataset and collator, then run its setup step.
datamodule = DataModule(
    mdataset,
    label_to_weight=labels_weighted_sampling,
    collate_fn=col,
    batch_size=16,
    num_workers=4,
)
datamodule.setup()

In [9]:
# Grab a single training batch for interactive debugging (kept under the name `i`, which the
# manual `training_step` cell further down expects).
# `next(iter(...))` raises StopIteration on an empty loader, whereas a `for ...: break` loop
# would silently leave `i` unset or holding a stale value from an earlier cell.
i = next(iter(datamodule.train_dataloader()))

In [10]:
# Number of classes to predict for each label.
labels = {label: len(classes) for label, classes in mdataset.class_topred.items()}

# Re-key the class groupings with encoded ids: for each label, both the group key and every
# member of the group are mapped through that label's encoder.
cls_hierarchies = {
    label: {
        mdataset.encoder[label][group]: [mdataset.encoder[label][member] for member in members]
        for group, members in groups.items()
    }
    for label, groups in mdataset.class_groupings.items()
}

# Gene table: embeddings joined with the biomart annotations (genes present in both) ...
df = embeddings.join(biomart, how="inner")

# ... re-ordered to follow the gene order returned by `load_genes`.
genedf = load_genes(["NCBITaxon:9606"])
df = df.loc[genedf.index[genedf.index.isin(df.index)]]

In [11]:
# Display the number of classes per predicted label.
labels

{'cell_type_ontology_term_id': 190,
 'disease_ontology_term_id': 18,
 'assay_ontology_term_id': 11,
 'self_reported_ethnicity_ontology_term_id': 8,
 'sex_ontology_term_id': 3,
 'organism_ontology_term_id': 2}

In [21]:
# Largest encoded id used as a group key for the assay label
# (compare with labels['assay_ontology_term_id']).
max(cls_hierarchies["assay_ontology_term_id"])

13

In [12]:
# Instantiate scPrint with a small configuration (d_model=64, 2 heads, 2 layers).
model = scPrint(
    genes=df.index.tolist(),
    d_model=64,
    nhead=2,
    d_hid=64,
    nlayers=2,
    layers_cls=[],
    labels=labels,
    cls_hierarchy=cls_hierarchies,
    dropout=0.2,
    transformer="fast",
    # First 64 columns of `df` as floats; the `emb_*` columns come first in the join, so this
    # presumably keeps the first 64 embedding dimensions to match d_model -- confirm.
    use_precpt_gene_emb=df.values[:, :64].astype(float),
    gene_pos_enc=df["pos"].tolist(),
    mvc_decoder="inner product",
)

TypeError: scPrint.__init__() got an unexpected keyword argument 'do_adv_cls'

In [59]:
# Call one training step by hand on the batch `i`, outside of the Trainer
# (the second argument, 0, is presumably the batch index -- confirm).
model.training_step(i, 0)

encoding

torch.Size([14, 7, 64])
> /home/ml4ig1/Documents code/scPRINT/scprint/model/model.py(589)_compute_loss()
    587 
    588             pdb.set_trace()
--> 589             for labelname, cl in zip(self.labels, clss.T):
    590                 drop = cl == -1  # unknown label
    591                 cl = cl[~drop]

> /home/ml4ig1/Documents code/scPRINT/scprint/model/model.py(590)_compute_loss()
    588             pdb.set_trace()
    589             for labelname, cl in zip(self.labels, clss.T):
--> 590                 drop = cl == -1  # unknown label
    591                 cl = cl[~drop]
    592                 pred = output["cls_output_" + labelname]

> /home/ml4ig1/Documents code/scPRINT/scprint/model/model.py(591)_compute_loss()
    589             for labelname, cl in zip(self.labels, clss.T):
    590                 drop = cl == -1  # unknown label
--> 591                 cl = cl[~drop]
    592                 pred = output["cls_output_" + labelname]
    593              

In [15]:
# Display the model's module structure.
model

scPrint(
  (gene_encoder): GeneEncoder(
    (embedding): Embedding(33890, 64)
    (enc_norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (expr_encoder): ContinuousValueEncoder(
    (linear1): Linear(in_features=1, out_features=64, bias=True)
    (activation): ReLU()
    (norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
    (dropout): Dropout(p=0.2, inplace=False)
  )
  (pos_encoder): PositionalEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (label_encoder): BatchLabelEncoder(
    (embedding): Embedding(9, 64)
    (enc_norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
  )
  (time_encoder): ContinuousValueEncoder(
    (linear1): Linear(in_features=1, out_features=64, bias=True)
    (activation): ReLU()
    (norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
    (dropout): Dropout(p=0.2, inplace=False)
  )
  (depth_encoder): ContinuousValueEncoder(
    (linear1): Linear(in_features=1, o

In [23]:
# Smoke-test trainer: with `fast_dev_run=True` the loop runs on a single batch and logging and
# checkpointing are suppressed (see the log below); `deterministic=True` is requested as well.
# Seeding itself is done by `seed_everything(42, workers=True)` in the imports cell, not here.
trainer = Trainer(deterministic=True, fast_dev_run=True)

INFO: GPU available: True (cuda), used: True
2024-01-19 10:24:00,318:INFO - GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
2024-01-19 10:24:00,321:INFO - TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
2024-01-19 10:24:00,324:INFO - IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
2024-01-19 10:24:00,328:INFO - HPU available: False, using: 0 HPUs
INFO: Running in `fast_dev_run` mode: will run the requested loop using 1 batch(es). Logging and checkpointing is suppressed.
2024-01-19 10:24:00,373:INFO - Running in `fast_dev_run` mode: will run the requested loop using 1 batch(es). Logging and checkpointing is suppressed.
2024-01-19 10:24:00,318:INFO - GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
2024-01-19 10:24:00,321:INFO - TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
2024-01-19 10:24:00,324:INFO - IPU avail

In [46]:
# Training dataloader built by the datamodule.
trainloader = datamodule.train_dataloader()

In [None]:
# Fit the model on the training dataloader only (no validation dataloader is passed).
trainer.fit(model, trainloader)

In [None]:
# TODO: test unseen genes (do we see many being kept after filtering, etc.?)
# TODO: debug the timepoint problem
# TODO: find the neighbors and next-time-point cells
# TODO: create a version with next-time-point and neighbors tasks
# TODO: add KO & drug datasets
# TODO: create a version with KO and drug effect prediction