In [1]:
import os
import torch
import warnings
import numpy as np
import scanpy as sc
import pandas as pd
import matplotlib.pyplot as plt
import typing as t

# To suppress the stupid AnnData warning ...
warnings.filterwarnings("ignore", category=UserWarning, message="Transforming to str index.")

from cellarium.ml.utilities.inference.cellarium_gpt_inference import \
    CellariumGPTInferenceContext

In [2]:
# Arguments
# Index of the CUDA device used to build the torch device for inference.
cuda_device_index = 0
# NOTE(review): not referenced in this part of the notebook — confirm it is still needed.
val_adata_index = 1

# Root of the local data tree. Set CELLARIUM_DATA_ROOT to run this notebook on another
# machine without editing paths; the default reproduces the original absolute paths.
data_root = os.environ.get("CELLARIUM_DATA_ROOT", "/home/mehrtash/data")

checkpoint_path = os.path.join(
    data_root,
    "100M_long_run/run_001/lightning_logs/version_3/checkpoints/epoch=5-step=504000.ckpt",
)
ref_adata_path = os.path.join(data_root, "data/extract_0.h5ad")
gene_info_path = os.path.join(data_root, "gene_info/gene_info.tsv")

In [3]:
# Build the CellariumGPT inference context from the checkpoint, the reference AnnData and the
# gene-info TSV configured in the arguments cell.
# `cuda_device_index` is deliberately NOT reassigned here: doing so silently overrode the value
# set in the arguments cell, so changing the configured device had no effect.
device = torch.device(f"cuda:{cuda_device_index}")

ctx = CellariumGPTInferenceContext(
    cellarium_gpt_ckpt_path=checkpoint_path,
    ref_adata_path=ref_adata_path,
    gene_info_tsv_path=gene_info_path,
    device=device,
    attention_backend="mem_efficient",
    verbose=False
)

In [4]:
# Inspect the model's gene vocabulary (displayed as an array of Ensembl gene IDs).
ctx.model_var_names

array(['ENSG00000187642', 'ENSG00000078808', 'ENSG00000272106', ...,
       'ENSG00000228836', 'ENSG00000231937', 'ENSG00000268916'],
      dtype=object)

In [5]:
# Ensembl gene ID -> gene biotype for every gene listed in the gene-info table
# (if an ID appears more than once, its last row wins).
gene_id_to_biotype_map = {
    gene_id: biotype
    for gene_id, biotype in zip(ctx.gene_info_df['ENSEMBL Gene ID'], ctx.gene_info_df['Gene Biotype'])
}
# The set of annotated gene IDs is exactly the key set of the map above.
all_genes_with_info_set = set(gene_id_to_biotype_map)

In [6]:
# Breakdown of the model's genes by biotype; genes absent from the gene-info table
# are tallied under "missing".
from collections import defaultdict

counts_dict = defaultdict(int)
for model_gene_id in ctx.model_var_names:
    biotype_or_missing = (
        gene_id_to_biotype_map[model_gene_id]
        if model_gene_id in all_genes_with_info_set
        else "missing"
    )
    counts_dict[biotype_or_missing] += 1

counts_dict

defaultdict(int,
            {'protein_coding': 19250,
             'lncRNA': 16411,
             'IG_V_gene': 145,
             'IG_V_pseudogene': 187,
             'missing': 224,
             'TR_V_gene': 106,
             'TR_V_pseudogene': 33,
             'TR_J_gene': 79,
             'IG_D_gene': 37,
             'IG_J_gene': 18,
             'TR_J_pseudogene': 4,
             'IG_J_pseudogene': 3,
             'IG_C_pseudogene': 9,
             'transcribed_unitary_pseudogene': 17,
             'artifact': 17,
             'transcribed_unprocessed_pseudogene': 29,
             'processed_pseudogene': 2,
             'IG_C_gene': 14,
             'TR_C_gene': 6,
             'TR_D_gene': 4,
             'TEC': 1,
             'transcribed_processed_pseudogene': 2,
             'unitary_pseudogene': 1,
             'unprocessed_pseudogene': 2})

In [7]:
from pybiomart import Dataset
import pandas as pd

# Connect to the human gene dataset on Ensembl BioMart.
# NOTE(review): www.ensembl.org always serves the *latest* Ensembl release, so the gene -> contig
# mapping can drift between runs; consider pinning an archive host that matches the release used
# to build the gene-info table / reference data — TODO confirm which release that is.
dataset = Dataset(name='hsapiens_gene_ensembl', host='www.ensembl.org')

# Query every gene ID with its chromosome/scaffold name. BioMart returns display-name columns
# ('Gene stable ID', 'Chromosome/scaffold name'); only the chromosome column is renamed, to 'contig'.
# errors='raise' makes the rename fail loudly if BioMart ever changes that display name, instead of
# silently leaving the column unrenamed and failing later with a confusing KeyError.
result = (
    dataset
    .query(attributes=['ensembl_gene_id', 'chromosome_name'])
    .rename(columns={'Chromosome/scaffold name': 'contig'}, errors='raise')
)

print(result)


        Gene stable ID contig
0      ENSG00000210049     MT
1      ENSG00000211459     MT
2      ENSG00000210077     MT
3      ENSG00000210082     MT
4      ENSG00000209082     MT
...                ...    ...
86397  ENSG00000235358      1
86398  ENSG00000228067      1
86399  ENSG00000293271      1
86400  ENSG00000310526      1
86401  ENSG00000241860      1

[86402 rows x 2 columns]


In [8]:
# Number of Ensembl genes per contig: primary chromosomes plus unplaced scaffolds and patch contigs.
result['contig'].value_counts()

contig
1                  7095
2                  5685
6                  4230
11                 4170
3                  4161
                   ... 
KI270720.1            1
KI270718.1            1
GL000216.2            1
HSCHR17_12_CTG4       1
HSCHR4_4_CTG12        1
Name: count, Length: 528, dtype: int64

In [9]:
# Ensembl gene ID -> contig, as reported by BioMart.
gene_id_to_contig_map = {
    gene_id: contig
    for gene_id, contig in zip(result['Gene stable ID'], result['contig'])
}

# One row per model gene (model order); genes unknown to BioMart get a NaN contig.
model_genes_df = (
    pd.DataFrame({'ensembl_gene_id': ctx.model_var_names})
    .assign(contig=lambda genes: genes['ensembl_gene_id'].map(gene_id_to_contig_map))
)

In [10]:
# Contig breakdown of the model's genes; the NaN row counts model genes with no BioMart contig.
model_genes_df['contig'].value_counts(dropna=False)

contig
1             3327
2             2450
11            2025
19            1995
17            1938
12            1857
3             1820
6             1764
5             1736
7             1633
16            1605
4             1477
8             1435
14            1428
10            1351
9             1268
15            1225
NaN           1135
X             1128
20             941
22             880
13             761
18             740
21             538
Y              107
MT              13
KI270728.1       4
KI270727.1       3
KI270734.1       3
GL000194.1       2
KI270713.1       2
KI270726.1       2
GL000219.1       1
GL000195.1       1
KI270721.1       1
KI270731.1       1
KI270711.1       1
GL000009.2       1
GL000213.1       1
GL000218.1       1
Name: count, dtype: int64

In [11]:
# Partition the model's genes by contig: autosomes 1-22, sex chromosomes (X and Y together),
# and X and Y separately. Every list keeps the row order of `model_genes_df`; genes with no
# contig (NaN) land in none of the lists.
AUTOSOMAL_CONTIGS = {str(i) for i in range(1, 23)}


def _gene_ids_on_contigs(genes_df: pd.DataFrame, contigs: t.AbstractSet[str]) -> t.List[str]:
    """Return the `ensembl_gene_id`s of rows of `genes_df` whose `contig` is in `contigs`.

    Contig labels are compared as strings so autosomes still match if the BioMart chromosome
    column was ever parsed with numeric values (e.g. 1 instead of '1').
    """
    on_contigs = genes_df['contig'].astype(str).isin(contigs)
    return genes_df.loc[on_contigs, 'ensembl_gene_id'].tolist()


autosomal_gene_ids = _gene_ids_on_contigs(model_genes_df, AUTOSOMAL_CONTIGS)
sex_gene_ids = _gene_ids_on_contigs(model_genes_df, {'X', 'Y'})
x_gene_ids = _gene_ids_on_contigs(model_genes_df, {'X'})
y_gene_ids = _gene_ids_on_contigs(model_genes_df, {'Y'})

print(f"Number of autosomal genes: {len(autosomal_gene_ids)}")
print(f"Number of sex genes: {len(sex_gene_ids)}")
print(f"Number of X genes: {len(x_gene_ids)}")
print(f"Number of Y genes: {len(y_gene_ids)}")

# Invariants: autosomal and sex genes never overlap, and the sex-gene list is exactly X plus Y.
assert set(autosomal_gene_ids).isdisjoint(sex_gene_ids)
assert set(sex_gene_ids) == set(x_gene_ids) | set(y_gene_ids)

Number of autosomal genes: 34194
Number of sex genes: 1235
Number of X genes: 1128
Number of Y genes: 107


In [12]:
# Write each gene list to its own text file in the artifacts directory, one Ensembl gene ID per line.
artifacts_dir = "/home/mehrtash/data/data/cellariumgpt_artifacts"
# Create the directory if needed; previously the cell crashed when it did not exist yet.
os.makedirs(artifacts_dir, exist_ok=True)

gene_lists_to_write = {
    "autosomal_gene_ids.txt": autosomal_gene_ids,
    "sex_gene_ids.txt": sex_gene_ids,
    "x_gene_ids.txt": x_gene_ids,
    "y_gene_ids.txt": y_gene_ids,
}
for file_name, gene_ids in gene_lists_to_write.items():
    with open(os.path.join(artifacts_dir, file_name), 'w') as f:
        f.writelines(f"{gene_id}\n" for gene_id in gene_ids)