# `cellarium-ml` highly variable genes

2024.12.02

Stephen Fleming

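Prerequisite:

Run `onepass_mean_var_std` to compute per-gene means and variances, using a command like the one below (the cells that follow generate a config and run this step for you):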

```bash
(cellarium) $ cellarium-ml onepass_mean_var_std fit -c onepass_config.yaml
```

In [1]:
import os
from string import Template

In [2]:
# Scratch directory for the generated config, the Lightning logs/checkpoints and the HVG list.
working_dir: str = "./tmp"

In [4]:
# Values substituted into the onepass config template in the next cell.
# NOTE(review): shard_size / last_shard_size presumably describe the dataset's
# shard layout — TODO confirm they match your data before re-running.
batch_size = 4_096
shard_size = 50_000
last_shard_size = 25_023
config_file_template = "../examples/cli_workflow/onepass_train_config.yaml"

In [5]:
# Fill the config template with the settings above and write the result into
# `working_dir`, so the CLI run below uses this notebook's settings.
local_config_yaml = os.path.join(working_dir, "config.yaml")

# `open(..., "w")` does not create parent directories, and on a fresh checkout
# `working_dir` may not exist yet.
os.makedirs(working_dir, exist_ok=True)

with open(config_file_template, "r") as file:
    yaml_text = file.read()

substitutions = {
    "root_dir": working_dir,
    "shard_size": shard_size,
    "batch_size": batch_size,
    "last_shard_size": last_shard_size,
}

# `Template.substitute` raises KeyError if the template contains a placeholder
# that has no entry in `substitutions`.
template = Template(yaml_text)
customized_yaml = template.substitute(substitutions)

# write the customized YAML to a local file in the working directory
with open(local_config_yaml, "w") as file:
    file.write(customized_yaml)

print(f"Config YAML written to: {local_config_yaml}")

Config YAML written to: ./tmp/config.yaml


In [6]:
!cat $local_config_yaml

# lightning.pytorch==2.2.1
seed_everything: true
trainer:
  accelerator: auto
  strategy:
    class_path: lightning.pytorch.strategies.DDPStrategy
    init_args:
      accelerator: null
      parallel_devices: null
      cluster_environment: null
      checkpoint_io: null
      precision_plugin: null
      ddp_comm_state: null
      ddp_comm_hook: null
      ddp_comm_wrapper: null
      model_averaging_period: null
      process_group_backend: null
      timeout: 0:30:00
      start_method: popen
    dict_kwargs:
      broadcast_buffers: false
  devices: 1
  num_nodes: 1
  precision: null
  logger: null
  callbacks: null
  fast_dev_run: false
  max_epochs: 1
  min_epochs: null
  max_steps: -1
  min_steps: null
  max_time: null
  limit_train_batches: null
  limit_val_batches: null
  limit_test_batches: null
  limit_predict_batches: null
  overfit_batches: 0.0
  val_check_interval: null
  check_val_every_n_epoch: 1
  num_sanity_val_steps: null
  log_every_n_steps: null
  enable_checkpoin

In [7]:
!cellarium-ml onepass_mean_var_std fit -c $local_config_yaml

Seed set to 0
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
[rank: 0] Seed set to 0
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/1
----------------------------------------------------------------------------------------------------
distributed_backend=nccl
All distributed processes registered. Starting with 1 processes
----------------------------------------------------------------------------------------------------

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name     | Type              | Params
-----------------------------------------------
0 | pipeline | CellariumPipeline | 1     
-----------------------------------------------
1         Trainable params
0         Non-trainable params
1         Total params
0.000     Total estimated model params size (MB)
Epoch 0: 100%|███████████████████████| 470/470 [06:04<00:00,  1.29it/s, v_num=1]`Trainer.fit` stopped: `max_epo

In [8]:
# Fill this in: checkpoint written by the `fit` run above. Lightning creates a new
# `version_N` directory on every run, so update the version (and step) to match yours.
# Built from `working_dir` so it does not depend on the notebook's location on disk.
# NOTE(review): assumes the config template routes `root_dir` to the trainer's
# log/checkpoint directory — confirm against onepass_train_config.yaml.
onepass_checkpoint = os.path.join(
    working_dir, "lightning_logs", "version_1", "checkpoints", "epoch=0-step=470.ckpt"
)

In [9]:
from cellarium.ml import CellariumModule
from cellarium.ml.preprocessing import get_highly_variable_genes

In [10]:
# Restore the fitted pipeline (NormalizeTotal -> Log1p -> OnePassMeanVarStd, per the
# repr below) from the checkpoint written by the `fit` run.
module = CellariumModule.load_from_checkpoint(onepass_checkpoint)
module

CellariumModule(pipeline = CellariumPipeline(
  (0): NormalizeTotal(target_count=10000, eps=1e-06)
  (1): Log1p()
  (2): OnePassMeanVarStd()
))

In [11]:
# Gene names recorded by the one-pass model; used as the HVG table's index below.
module.model.var_names_g

array(['TSPAN6', 'TNMD', 'DPM1', ..., 'LY6S', 'ENSG00000291316',
       'TMEM276'], dtype=object)

In [12]:
# Per-gene means from OnePassMeanVarStd (presumably of the log1p-normalized data,
# given the pipeline order — confirm).
module.model.mean_g

tensor([7.0748e-04, 5.8727e-05, 1.7840e-01,  ..., 4.8401e-04, 3.4831e-04,
        2.0109e-02], device='cuda:0')

In [13]:
# Per-gene variances from OnePassMeanVarStd, aligned with `mean_g` above.
module.model.var_g

tensor([4.7239e-04, 5.4425e-05, 1.1273e-01,  ..., 3.5554e-04, 2.6977e-04,
        1.1562e-02], device='cuda:0')

In [14]:
# Select the top 2000 highly variable genes from the accumulated statistics.
# The resulting frame is indexed by gene name and carries `means`, `dispersions`,
# `mean_bin`, `dispersions_norm` and a boolean `highly_variable` column.
onepass_model = module.model
hvg_df = get_highly_variable_genes(
    gene_names=onepass_model.var_names_g.astype(str),
    mean=onepass_model.mean_g.detach().cpu(),  # move off the GPU before computing
    var=onepass_model.var_g.detach().cpu(),
    n_top_genes=2000,
)

In [15]:
# Rows flagged as highly variable.
hvg_df.loc[hvg_df["highly_variable"]]

Unnamed: 0,means,dispersions,mean_bin,dispersions_norm,highly_variable
CFTR,0.285716,0.075575,"(0.282, 0.376]",2.161184,True
KLHL13,0.117839,0.217548,"(0.0941, 0.188]",3.035444,True
HOXA11,0.000020,0.320333,"(-0.00188, 0.0941]",3.073698,True
LGALS14,0.000009,0.145802,"(-0.00188, 0.0941]",2.258085,True
ST3GAL1,0.047569,0.060156,"(-0.00188, 0.0941]",1.857845,True
...,...,...,...,...,...
ENSG00000291017,0.000014,0.023587,"(-0.00188, 0.0941]",1.686951,True
ENSG00000291077,0.000053,0.096751,"(-0.00188, 0.0941]",2.028859,True
LGALS17A,0.000479,0.609067,"(-0.00188, 0.0941]",4.423005,True
ENSG00000291095,0.000032,0.012063,"(-0.00188, 0.0941]",1.633098,True


In [16]:
# Copy the index into a column so it is written to the CSV below.
# NOTE(review): the index holds the names from `var_names_g`, which are mostly gene
# symbols (e.g. CFTR), not Ensembl IDs. The column name is kept because it becomes
# the CSV header that downstream configs may expect.
hvg_df = hvg_df.assign(ensembl_id=hvg_df.index)

In [17]:
# Highly variable rows again, now including the `ensembl_id` column.
hvg_df.loc[hvg_df["highly_variable"], :]

Unnamed: 0,means,dispersions,mean_bin,dispersions_norm,highly_variable,ensembl_id
CFTR,0.285716,0.075575,"(0.282, 0.376]",2.161184,True,CFTR
KLHL13,0.117839,0.217548,"(0.0941, 0.188]",3.035444,True,KLHL13
HOXA11,0.000020,0.320333,"(-0.00188, 0.0941]",3.073698,True,HOXA11
LGALS14,0.000009,0.145802,"(-0.00188, 0.0941]",2.258085,True,LGALS14
ST3GAL1,0.047569,0.060156,"(-0.00188, 0.0941]",1.857845,True,ST3GAL1
...,...,...,...,...,...,...
ENSG00000291017,0.000014,0.023587,"(-0.00188, 0.0941]",1.686951,True,ENSG00000291017
ENSG00000291077,0.000053,0.096751,"(-0.00188, 0.0941]",2.028859,True,ENSG00000291077
LGALS17A,0.000479,0.609067,"(-0.00188, 0.0941]",4.423005,True,LGALS17A
ENSG00000291095,0.000032,0.012063,"(-0.00188, 0.0941]",1.633098,True,ENSG00000291095


In [20]:
# Save the highly variable gene identifiers (single column, header `ensembl_id`)
# into `working_dir`, instead of a hardcoded "tmp/" that can drift from it.
# `.loc` selects rows and column in one step, which avoids chained indexing.
hvg_csv = os.path.join(working_dir, "hvg.csv")
hvg_df.loc[hvg_df["highly_variable"], "ensembl_id"].to_csv(hvg_csv, index=False)

In [21]:
!head hvg.csv

ensembl_id
CFTR
KLHL13
HOXA11
LGALS14
ST3GAL1
TENM1
GPRC5A
CNTN1
HGF
