# Celltype Classification with Transcriptformer

This notebooks showcase how the Transcriptformer embeddings can be used to train a classifier for celltype classification.

In [1]:
%load_ext autoreload
%autoreload 2

import json
import logging
import os

import hydra
from omegaconf import DictConfig, OmegaConf


from transcriptformer.model.inference import run_inference
from transcriptformer.datasets import tabula_sapiens
import yaml

In [2]:
# !python ./../download_artifacts.py tf-sapiens

In [3]:
adata = tabula_sapiens(tissue="ear", version="v2")
cfg = OmegaConf.load("./../conf/inference_config.yaml")
logging.debug(OmegaConf.to_yaml(cfg))

cfg.model.checkpoint_path = "./checkpoints/tf_sapiens"

config_path = os.path.join(cfg.model.checkpoint_path, "config.json")
with open(config_path) as f:
    config_dict = json.load(f)
mlflow_cfg = OmegaConf.create(config_dict)

# Merge the MLflow config with the main config
cfg = OmegaConf.merge(mlflow_cfg, cfg)

# Set the checkpoint paths based on the unified checkpoint_path
cfg.model.inference_config.load_checkpoint = os.path.join(cfg.model.checkpoint_path, "model_weights.pt")
cfg.model.data_config.aux_vocab_path = os.path.join(cfg.model.checkpoint_path, "vocabs")
cfg.model.data_config.esm2_mappings_path = os.path.join(cfg.model.checkpoint_path, "vocabs")


In [4]:
adata.var["ensembl_id"] = adata.var_names.str.split('.').str[0]
adata.X = adata.layers["decontXcounts"]

In [5]:
# Set logging level to ERROR to reduce verbosity
logging.getLogger().setLevel(logging.ERROR)

adata_output = run_inference(cfg, data_files=[adata])

Using 16bit Automatic Mixed Precision (AMP)
You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/work/venv/tf311/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Predicting DataLoader 0: 100%|██████████| 382/382 [00:25<00:00, 15.13it/s]




In [7]:
import pandas as pd
import numpy as np
from sklearn.model_selection import StratifiedKFold, cross_validate
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import (
    accuracy_score,
    f1_score,
    make_scorer,
    precision_score,
    recall_score,
)
pipeline = Pipeline(
                [
                    ("scaler", StandardScaler()),
                    ("lr", LogisticRegression()),
                ]
            )
average_type = "macro"
SEED = 42
scorers = {
            "accuracy": make_scorer(accuracy_score),
            "f1": make_scorer(f1_score, average=average_type),
            "precision": make_scorer(precision_score, average=average_type),
            "recall": make_scorer(recall_score, average=average_type),
        }

skf = StratifiedKFold(n_splits=3, shuffle=True, random_state=SEED)

In [12]:
embeddings, labels = adata_output.obsm["embeddings"], adata_output.obs["cell_type"].values

In [13]:
labels = pd.Categorical(labels.astype(str))
cv_results = cross_validate(
    pipeline,
    embeddings,
    labels.codes,
    cv=skf,
    scoring=scorers,
    return_train_score=False,
)

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


In [14]:
cv_results

{'fit_time': array([5.32113767, 3.82658339, 3.82428098]),
 'score_time': array([0.04094672, 0.03930044, 0.02056813]),
 'test_accuracy': array([0.96859666, 0.98035363, 0.97347741]),
 'test_f1': array([0.76635337, 0.97750495, 0.75698548]),
 'test_precision': array([0.77295502, 0.9744704 , 0.7483731 ]),
 'test_recall': array([0.76083898, 0.98095984, 0.77046934])}

In [None]:
adata_map = tabula_sapiens(tissue="ear", version="v1")
adata_map.var["ensembl_id"] = adata_map.var_names.str.split('.').str[0]
adata_map.X = adata_map.layers["decontXcounts"]