In [1]:
import scanpy as sc
import os
import sys
import numpy as np
from tqdm import tqdm

# Put the scMulan tutorial directory on the import path; the project-local
# imports in this notebook (e.g. `utils.hf_tokenizer`) are presumably resolved from it.
# NOTE(review): machine-specific absolute path — adjust when running elsewhere.
root_path = os.path.abspath('/work/magroup/skrieger/scMulan/Tutorials/scMulan')
# Guarded so that re-running the cell does not pile up duplicate sys.path entries.
if root_path not in sys.path:
    sys.path.append(root_path)
from utils.hf_tokenizer import scMulanTokenizer
%load_ext autoreload
%autoreload 2
%env CUDA_LAUNCH_BLOCKING=1


  from .autonotebook import tqdm as notebook_tqdm


env: CUDA_LAUNCH_BLOCKING=1


In [2]:
from data_util import get_generation_dataloader
import torch

# load your AnnData and meta_info
# adata     = sc.read_h5ad('my_mouse_data.h5ad')
# The path below is relative, so it is resolved against the kernel's working directory.
# NOTE(review): torch.load unpickles the file — only load files from a trusted source.
meta_info = torch.load('scMulan/utils/meta_info.pt')
# Alternative pre-processed datasets (written by cells further down in this notebook):
# adata_sub = sc.read_h5ad('../../tissue_generator/MERFISH_aging/sc_ref_xyz_scMulan_processed.h5ad')
# adata_sub = sc.read_h5ad('../../tissue_generator/MERFISH_aging/sc_ref_xyz_scMulan_processed_n100.h5ad')


In [13]:
from scipy.stats import pearsonr
import numpy as np
from sklearn.neighbors import NearestNeighbors
from sklearn.decomposition import PCA

def binary_reconstruction_metrics(predicted, ground_truth, axis=1, eps=1e-8):
    """
    Precision / recall / F1 of the "is this entry expressed (> 0)?" pattern.

    Both inputs are binarised with ``> 0``; true/false positives and false
    negatives are counted along ``axis`` and the resulting scores are averaged
    over the remaining axis.

    Parameters:
    -----------
    predicted : array-like or sparse matrix of predicted values
    ground_truth : array-like or sparse matrix of reference values
    axis : axis along which counts are accumulated (default 1)
    eps : small constant added to every denominator to avoid division by zero

    Returns:
    --------
    dict with the mean 'precision', 'recall' and 'f1'
    """
    def _dense(values):
        # Accept nested lists and scipy-style sparse matrices, not only ndarrays.
        return values.toarray() if hasattr(values, 'toarray') else np.asarray(values)

    pred_bool = _dense(predicted) > 0
    true_bool = _dense(ground_truth) > 0

    tp = np.sum(pred_bool & true_bool, axis=axis)
    fp = np.sum(pred_bool & ~true_bool, axis=axis)
    fn = np.sum(~pred_bool & true_bool, axis=axis)

    precision = tp / (tp + fp + eps)
    recall    = tp / (tp + fn + eps)
    f1        = 2 * (precision * recall) / (precision + recall + eps)

    return {
        'precision': precision.mean(),
        'recall':    recall.mean(),
        'f1':        f1.mean()
    }

def _as_positions(idxlist, n_obs):
    """Normalise a row selector (range, list, array, boolean mask or slice) to integer positions."""
    selector = idxlist if isinstance(idxlist, slice) else np.asarray(idxlist)
    return np.arange(n_obs)[selector]


def cell_type_likelihood_knn(predicted, adata, subclass_key='subclass', n_pc=50, n_neighbors=50, idxlist=None):
    """
    Fraction of each predicted profile's nearest real neighbours that carry its label.

    A PCA is fitted on ``adata.X``; the real and the predicted profiles are
    projected into that space and, for every predicted profile, its
    ``n_neighbors`` closest real cells are looked up.

    Parameters:
    -----------
    predicted : array-like of predicted expression rows (same gene order as adata.X)
    adata : AnnData holding the real expression matrix and the labels
    subclass_key : key in .obs for cell type labels
    n_pc : number of principal components for kNN
    n_neighbors : number of neighbors for kNN
    idxlist : rows of ``adata`` that the predictions correspond to. When None,
        prediction i is compared against the label of row i (the historical
        behaviour, only correct if the first len(predicted) cells were predicted).

    Returns:
    --------
    np.ndarray with one fraction per predicted row
    """
    X_true = adata.X.toarray() if hasattr(adata.X, 'toarray') else adata.X
    predicted = np.asarray(predicted)
    labels = np.asarray(adata.obs[subclass_key].values)

    # Label of the cell each prediction was generated for.
    if idxlist is None:
        own_labels = labels[:len(predicted)]
    else:
        own_labels = labels[_as_positions(idxlist, len(labels))]
    if len(own_labels) != len(predicted):
        raise ValueError("Number of predictions does not match number of selected cells")

    pca = PCA(n_components=n_pc).fit(X_true)
    X_true_pca = pca.transform(X_true)
    X_pred_pca = pca.transform(predicted)

    nbrs = NearestNeighbors(n_neighbors=n_neighbors).fit(X_true_pca)
    distances, indices = nbrs.kneighbors(X_pred_pca)

    # indices has one row of neighbour positions per prediction; compare the
    # neighbours' labels with the prediction's own label and average per row.
    p_knn = (labels[indices] == own_labels[:, None]).mean(axis=1)
    return p_knn


def get_metrics(scm, rows, idxlist, subclass_key='subclass', n_pc=50, n_neighbors=50):
    """
    Print (and return) various reconstruction metrics:
    - Pearson correlation per cell (predicted vs. real profile), averaged
    - Binary reconstruction precision/recall/F1
    - Cell-type likelihood from k-NN analysis

    Parameters:
    -----------
    scm : object with .adata (AnnData)
    rows : list of predicted expression arrays (one per cell)
    idxlist : rows of scm.adata the predictions in `rows` correspond to, in the same order
    subclass_key : key in .obs for cell type labels
    n_pc : number of principal components for kNN
    n_neighbors : number of neighbors for kNN

    Returns:
    --------
    dict with 'pearson_r', 'precision', 'recall', 'f1' and 'knn_likelihood'
    """
    adata = scm.adata
    positions = _as_positions(idxlist, adata.X.shape[0])

    # Densify only the selected rows instead of the whole matrix.
    selected = adata.X[positions]
    ground_truth_matrix = selected.toarray() if hasattr(selected, 'toarray') else np.asarray(selected)
    assert len(rows) == ground_truth_matrix.shape[0], "Mismatch in number of cells"
    pred_matrix = np.asarray(rows)

    # Pearson r per cell: each predicted profile against its own real profile
    rs = []
    for pred_vec, true_vec in zip(rows, ground_truth_matrix):
        r, _ = pearsonr(pred_vec, true_vec)
        rs.append(r)
    # pearsonr is NaN for a constant vector; skip those so one such cell does not void the average.
    finite_rs = [r for r in rs if np.isfinite(r)]
    mean_r = np.mean(finite_rs) if finite_rs else float('nan')
    print(f"Average Pearson r over {len(rs)} pairs: {mean_r:.4f}")

    # Binary reconstruction metrics
    bin_metrics = binary_reconstruction_metrics(pred_matrix, ground_truth_matrix)
    print(f"Binary reconstruction metrics: {bin_metrics}")

    # Cell-type likelihood via kNN, scored against the labels of the cells actually predicted
    p_knn = cell_type_likelihood_knn(
        pred_matrix, adata,
        subclass_key=subclass_key, n_pc=n_pc, n_neighbors=n_neighbors,
        idxlist=positions,
    )
    print(f"Mean cell-type likelihood (kNN): {p_knn.mean():.4f}")

    return {'pearson_r': mean_r, **bin_metrics, 'knn_likelihood': p_knn.mean()}
    
def run_and_metrics(scm, idxlist=range(100)):
    """
    Generate cells for `idxlist` with `scm.generate_cell_genesis`, report metrics, return the raw results.

    The first element of every result is handed to `get_metrics` as the
    predicted row (presumably the predicted expression — confirm against
    `generate_cell_genesis`). Defaults to the first 100 cells.
    """
    generation_kwargs = dict(
        idx=idxlist,
        max_new_tokens=500,
        top_k=5,
        return_gt=True,
        batch_size=12,
    )
    results = scm.generate_cell_genesis(**generation_kwargs)

    predicted_rows = []
    for result in results:
        predicted_rows.append(result[0])

    get_metrics(scm, predicted_rows, idxlist)
    return results


# Creation of scRNA-seq finetuning dataset

In [2]:
# Load the source AnnData for this section; the bare `adata` below displays its summary.
adata = sc.read('../../tissue_generator/MERFISH_aging/sc_ref_xyz_scMulan.h5ad')
adata



AnnData object with n_obs × n_vars = 25004 × 42117
    obs: 'cell_barcode', 'library_label', 'anatomical_division_label', 'cell_barcoder', 'barcoded_cell_sample_label', 'library_labelr', 'feature_matrix_label', 'entity', 'brain_section_label', 'library_method', 'region_of_interest_acronym', 'donor_label', 'donor_genotype', 'donor_sex', 'dataset_label', 'x', 'y', 'cluster_alias', 'neurotransmitter', 'class', 'subclass', 'supertype', 'cluster', 'neurotransmitter_color', 'class_color', 'subclass_color', 'supertype_color', 'cluster_color', 'region_of_interest_order', 'region_of_interest_color', 'z', 'n_genes'
    uns: 'log1p'

In [3]:
# Uncomment to work on a 500-cell copy while debugging:
# adata_sub = adata[:500].copy()
# NOTE: this is an alias, not a copy — the .obs columns added to adata_sub
# in the following cells are written into `adata` as well.
adata_sub = adata

In [5]:
# Every cell gets the same constant organ label.
adata_sub.obs['organ'] = 'Brain'

In [6]:
import numpy as np

# Collapse the fine-grained region-of-interest acronyms into coarse region names.
_cerebral_cortex_rois = (
    'SS-GU-VISC',    # somatosensory gustatory/visceral area
    'PL-ILA-ORB',    # prelimbic, infralimbic, orbital prefrontal
    'TEa-PERI-ECT',  # temporal association & perirhinal
    'MOp',           # primary motor cortex
    'VIS',           # visual cortex
    'VIS-PTLp',      # posterior lateral visual
    'SSp',           # primary somatosensory cortex
    'MO-FRP',        # frontal pole motor
    'AI',            # agranular insular cortex
    'AUD',           # auditory cortex
)
_cingulate_cortex_rois = (
    'ACA',           # anterior cingulate area
    'RSP',           # retrosplenial cortex
)

region_map = dict.fromkeys(_cerebral_cortex_rois, 'Cerebral cortex')
region_map.update(dict.fromkeys(_cingulate_cortex_rois, 'Cingulate cortex'))
region_map[np.nan] = 'Unclassified'  # missing values

# Acronyms absent from the map come back as NaN and are labelled 'Unclassified' too.
coarse_region = adata_sub.obs['region_of_interest_acronym'].map(region_map)
adata_sub.obs['region'] = coarse_region.fillna('Unclassified')

In [7]:
# Maps each `subclass` label of this dataset to a coarser cell-type name.
# Subclasses not listed here end up as 'Unclassified' via the fillna at the bottom.
new2existing = {
    # glutamatergic -> Excitatory neuron
    '006 L4/5 IT CTX Glut':                      'Excitatory neuron',
    '030 L6 CT CTX Glut':                        'Excitatory neuron',
    '029 L6b CTX Glut':                         'Excitatory neuron',
    '032 L5 NP CTX Glut':                       'Excitatory neuron',
    '005 L5 IT CTX Glut':                       'Excitatory neuron',
    '007 L2/3 IT CTX Glut':                     'Excitatory neuron',
    '022 L5 ET CTX Glut':                       'Excitatory neuron',
    '021 L4 RSP-ACA Glut':                      'Excitatory neuron',
    '004 L6 IT CTX Glut':                       'Excitatory neuron',
    '020 L2/3 IT RSP Glut':                     'Excitatory neuron',
    '001 CLA-EPd-CTX Car3 Glut':                'Excitatory neuron',
    '003 L5/6 IT TPE-ENT Glut':                 'Excitatory neuron',
    '028 L6b/CT ENT Glut':                      'Excitatory neuron',
    '002 IT EP-CLA Glut':                       'Excitatory neuron',
    '025 CA2-FC-IG Glut':                       'Excitatory neuron',
    '009 L2/3 IT PIR-ENTl Glut':                'Excitatory neuron',
    '010 IT AON-TT-DP Glut':                    'Excitatory neuron',
    '036 HPF CR Glut':                          'Excitatory neuron',
    '008 L2/3 IT ENT Glut':                     'Excitatory neuron',
    '027 L6b EPd Glut':                         'Excitatory neuron',
    '114 COAa-PAA-MEA Barhl2 Glut':             'Excitatory neuron',
    '018 L2 IT PPP-APr Glut':                   'Excitatory neuron',
    '035 OB Eomes Ms4a15 Glut':                 'Excitatory neuron',
    '262 Pineal Crx Glut':                      'Excitatory neuron',
    '034 NP PPP Glut':                          'Excitatory neuron',
    '019 L2/3 IT PPP Glut':                     'Excitatory neuron',
    '033 NP SUB Glut':                          'Excitatory neuron',
    '115 MS-SF Bsx Glut':                       'Excitatory neuron',
    '016 CA1-ProS Glut':                        'Excitatory neuron',

    # GABAergic -> Inhibitory neuron
    '053 Sst Gaba':                             'Inhibitory neuron',
    '050 Lamp5 Lhx6 Gaba':                     'Inhibitory neuron',
    '046 Vip Gaba':                             'Inhibitory neuron',
    '052 Pvalb Gaba':                           'Inhibitory neuron',
    '049 Lamp5 Gaba':                          'Inhibitory neuron',
    '047 Sncg Gaba':                           'Inhibitory neuron',
    '056 Sst Chodl Gaba':                      'Inhibitory neuron',
    '041 OB-in Frmd7 Gaba':                    'Inhibitory neuron',
    '066 NDB-SI-ant Prdm12 Gaba':              'Inhibitory neuron',
    '061 STR D1 Gaba':                         'Inhibitory neuron',
    '062 STR D2 Gaba':                         'Inhibitory neuron',
    '064 STR-PAL Chst9 Gaba':                  'Inhibitory neuron',
    '042 OB-out Frmd7 Gaba':                   'Inhibitory neuron',
    '039 OB Meis2 Thsd7b Gaba':                'Inhibitory neuron',
    '080 CEA-AAA-BST Six3 Sp9 Gaba':           'Inhibitory neuron',
    '051 Pvalb chandelier Gaba':               'Inhibitory neuron',
    '044 OB Dopa-Gaba':                        'Inhibitory neuron',
    '045 OB-STR-CTX Inh IMN':                  'Inhibitory neuron',
    '065 IA Mgp Gaba':                         'Inhibitory neuron',
    '063 STR D1 Sema5a Gaba':                  'Inhibitory neuron',
    '048 RHP-COA Ndnf Gaba':                   'Inhibitory neuron',

    # "NN" suffix -> non-neuronal
    # NOTE(review): VLMC usually stands for "vascular leptomeningeal cell" and ABC for
    # "arachnoid barrier cell"; 'Vascular smooth muscle cell' and 'Basal cell' look like
    # nearest-available substitutes — confirm they are intended.
    '327 Oligo NN':                            'Oligodendrocyte',
    '326 OPC NN':                              'Oligodendrocyte precursor cell (OPC)',
    '333 Endo NN':                             'Endothelial cell',
    '332 SMC NN':                              'Smooth muscle cell',
    '330 VLMC NN':                             'Vascular smooth muscle cell',
    '319 Astro-TE NN':                         'Astrocyte',
    '318 Astro-NT NN':                         'Astrocyte',
    '338 Lymphoid NN':                         'Lymphoid cell',
    '331 Peri NN':                             'Pericyte',
    '334 Microglia NN':                        'Microglia',
    '335 BAM NN':                              'Macrophage',
    '329 ABC NN':                              'Basal cell',

    # missing / fallback (missing values are also caught by the fillna below)
    None:                                      'Unclassified',
}

# Translate subclass -> cell type; anything left unmapped (NaN) becomes 'Unclassified'.
adata_sub.obs['cell_type'] = adata_sub.obs['subclass'].map(new2existing)
adata_sub.obs['cell_type'] = adata_sub.obs['cell_type'].fillna('Unclassified')

In [8]:
# Discretise the spatial coordinates: for each axis, n_bins evenly spaced edges
# between the observed min and max; the bin index is stored in obs['<x>'] / ['<y>'] / ['<z>'].
n_bins = 10
coord_bins = {}
for coord in ('x', 'y', 'z'):
    column = adata_sub.obs[coord]
    # The edges are derived from the non-missing values only.
    observed = column.dropna().values.astype(float)
    edges = np.linspace(observed.min(), observed.max(), n_bins)
    coord_bins[coord] = edges
    # right=True: a value equal to an edge belongs to the bin that ends at that edge.
    adata_sub.obs[f'<{coord}>'] = np.digitize(column.values.astype(float), edges, right=True)
    

In [9]:
# Drop every cell that lacks a value in any of the required columns.
cols = ['x', 'y', 'z', 'cell_type', 'region', 'organ']
mask = ~adata_sub.obs[cols].isna().any(axis=1)
adata_sub = adata_sub[mask].copy()
adata_sub

AnnData object with n_obs × n_vars = 24881 × 42117
    obs: 'cell_barcode', 'library_label', 'anatomical_division_label', 'cell_barcoder', 'barcoded_cell_sample_label', 'library_labelr', 'feature_matrix_label', 'entity', 'brain_section_label', 'library_method', 'region_of_interest_acronym', 'donor_label', 'donor_genotype', 'donor_sex', 'dataset_label', 'x', 'y', 'cluster_alias', 'neurotransmitter', 'class', 'subclass', 'supertype', 'cluster', 'neurotransmitter_color', 'class_color', 'subclass_color', 'supertype_color', 'cluster_color', 'region_of_interest_order', 'region_of_interest_color', 'z', 'n_genes', 'organ', 'region', 'cell_type', '<x>', '<y>', '<z>'
    uns: 'log1p'

In [10]:
# Persist the filtered dataset of this section (coordinates binned with n_bins = 10).
adata_sub.write('../../tissue_generator/MERFISH_aging/sc_ref_xyz_scMulan_processed.h5ad')

# Creation of scRNA-seq finetuning dataset with 4-level hierarchy

In [30]:
# NOTE: this input file is written by the last cell of the
# "Creating training data with 100 bins" section further down, so that section
# has to be run before this one on a fresh kernel.
adata = sc.read('../../tissue_generator/MERFISH_aging/sc_ref_xyz_scMulan_processed_n100.h5ad')
adata



AnnData object with n_obs × n_vars = 24881 × 42117
    obs: 'cell_barcode', 'library_label', 'anatomical_division_label', 'cell_barcoder', 'barcoded_cell_sample_label', 'library_labelr', 'feature_matrix_label', 'entity', 'brain_section_label', 'library_method', 'region_of_interest_acronym', 'donor_label', 'donor_genotype', 'donor_sex', 'dataset_label', 'x', 'y', 'cluster_alias', 'neurotransmitter', 'class', 'subclass', 'supertype', 'cluster', 'neurotransmitter_color', 'class_color', 'subclass_color', 'supertype_color', 'cluster_color', 'region_of_interest_order', 'region_of_interest_color', 'z', 'n_genes', 'organ', 'region', 'cell_type', '<x>', '<y>', '<z>'
    uns: 'log1p'

In [31]:
# Uncomment to work on a 500-cell copy while debugging:
# adata_sub = adata[:500].copy()
# Alias, not a copy: later edits to adata_sub.obs also change `adata`.
adata_sub = adata

In [32]:
# Every cell gets the same constant organ label.
adata_sub.obs['organ'] = 'Brain'

In [33]:
# Register the distinct labels of each hierarchy level (plus the ROI acronym) in meta_info.
# dropna() keeps a missing label (NaN) out of the lists: rows with missing
# values are only filtered out in a later cell, so NaN could otherwise leak in.
for obs_key in ('class', 'subclass', 'supertype', 'cluster', 'region_of_interest_acronym'):
    meta_info[obs_key] = adata_sub.obs[obs_key].dropna().unique().tolist()

In [34]:
# Discretise the spatial coordinates: for each axis, n_bins evenly spaced edges
# between the observed min and max; the bin index is stored in obs['<x>'] / ['<y>'] / ['<z>'].
n_bins = 100
coord_bins = {}
for coord in ('x', 'y', 'z'):
    column = adata_sub.obs[coord]
    # The edges are derived from the non-missing values only.
    observed = column.dropna().values.astype(float)
    edges = np.linspace(observed.min(), observed.max(), n_bins)
    coord_bins[coord] = edges
    # right=True: a value equal to an edge belongs to the bin that ends at that edge.
    adata_sub.obs[f'<{coord}>'] = np.digitize(column.values.astype(float), edges, right=True)
    

In [19]:
# Remove these metadata fields from meta_info.
# pop(..., None) instead of `del` keeps the cell re-runnable: running it a
# second time is a no-op rather than a KeyError.
for unused_key in (
    'study_id', 'donor_gender', 'cell_type', 'age_bin', 'donor_age',
    'seq_tech', 'sample_status', 'region', 'MCT',
):
    meta_info.pop(unused_key, None)

In [20]:
# Rebuild the vocabulary: the first 2002 original tokens, then the ROI acronyms,
# the organ entries and the four hierarchy levels, then the last 16 original tokens.
token_set = (
    meta_info['token_set'][:2002]
    + meta_info['region_of_interest_acronym']
    + list(meta_info['organ'])
    + meta_info['class']
    + meta_info['subclass']
    + meta_info['supertype']
    + meta_info['cluster']
    + meta_info['token_set'][-16:]
)
# De-duplicate while keeping first-occurrence order. list(set(...)) returned the
# tokens in an arbitrary order that can differ between Python processes, so
# the saved vocabulary was not reproducible.
# NOTE(review): presumably the tokenizer derives token ids from list position, in
# which case the original tokens must also stay at their original indices — confirm.
meta_info['token_set'] = list(dict.fromkeys(token_set))

In [35]:
# Drop every cell that lacks a coordinate, a region, the organ or any hierarchy level.
cols = [
    'x', 'y', 'z',
    'region_of_interest_acronym', 'organ',
    'class', 'subclass', 'supertype', 'cluster',
]
mask = ~adata_sub.obs[cols].isna().any(axis=1)
adata_sub = adata_sub[mask].copy()
adata_sub

AnnData object with n_obs × n_vars = 24881 × 42117
    obs: 'cell_barcode', 'library_label', 'anatomical_division_label', 'cell_barcoder', 'barcoded_cell_sample_label', 'library_labelr', 'feature_matrix_label', 'entity', 'brain_section_label', 'library_method', 'region_of_interest_acronym', 'donor_label', 'donor_genotype', 'donor_sex', 'dataset_label', 'x', 'y', 'cluster_alias', 'neurotransmitter', 'class', 'subclass', 'supertype', 'cluster', 'neurotransmitter_color', 'class_color', 'subclass_color', 'supertype_color', 'cluster_color', 'region_of_interest_order', 'region_of_interest_color', 'z', 'n_genes', 'organ', 'region', 'cell_type', '<x>', '<y>', '<z>'
    uns: 'log1p'

In [36]:
# Persist the filtered dataset for the 4-level-hierarchy setup.
adata_sub.write('../../tissue_generator/MERFISH_aging/sc_ref_xyz_scMulan_processed_4hierarchy.h5ad')

In [23]:
# Save the modified meta_info; this file is loaded again in the
# "Testing data_loader, finetuning" section.
torch.save(meta_info,'../../tissue_generator/MERFISH_aging/4hierarchy_metainfo.pt')

# Creation of scRNA-seq finetuning dataset with 4-level hierarchy and all data

In [3]:
# Load the full dataset for this section; the bare `adata` below displays its summary.
adata = sc.read('../../tissue_generator/sc-reference/alldata-log2-xyz.h5ad')
adata



AnnData object with n_obs × n_vars = 1699939 × 32285
    obs: 'cell_barcodel', 'library_labell', 'anatomical_division_label', 'cell_label', 'cell_barcode', 'barcoded_cell_sample_label', 'library_label', 'feature_matrix_label', 'entity', 'brain_section_label', 'library_method', 'region_of_interest_acronym', 'donor_label', 'donor_genotype', 'donor_sex', 'dataset_label', 'x', 'y', 'cluster_alias', 'neurotransmitter', 'class', 'subclass', 'supertype', 'cluster', 'neurotransmitter_color', 'class_color', 'subclass_color', 'supertype_color', 'cluster_color', 'region_of_interest_order', 'region_of_interest_color', 'z'

In [4]:
# Uncomment to work on a 500-cell copy while debugging:
# adata_sub = adata[:500].copy()
# Alias, not a copy: later edits to adata_sub.obs also change `adata`.
adata_sub = adata

In [5]:
# Every cell gets the same constant organ label.
adata_sub.obs['organ'] = 'Brain'

In [6]:
# Register the distinct labels of each hierarchy level (plus the ROI acronym) in meta_info.
# dropna() keeps a missing label (NaN) out of the lists: rows with missing
# values are only filtered out in a later cell, so NaN could otherwise leak in.
for obs_key in ('class', 'subclass', 'supertype', 'cluster', 'region_of_interest_acronym'):
    meta_info[obs_key] = adata_sub.obs[obs_key].dropna().unique().tolist()

In [7]:
# Discretise the spatial coordinates: for each axis, n_bins evenly spaced edges
# between the observed min and max; the bin index is stored in obs['<x>'] / ['<y>'] / ['<z>'].
n_bins = 100
coord_bins = {}
for coord in ('x', 'y', 'z'):
    column = adata_sub.obs[coord]
    # The edges are derived from the non-missing values only.
    observed = column.dropna().values.astype(float)
    edges = np.linspace(observed.min(), observed.max(), n_bins)
    coord_bins[coord] = edges
    # right=True: a value equal to an edge belongs to the bin that ends at that edge.
    adata_sub.obs[f'<{coord}>'] = np.digitize(column.values.astype(float), edges, right=True)
    

In [8]:
# Drop every cell that lacks a coordinate, a region, the organ or any hierarchy level.
cols = [
    'x', 'y', 'z',
    'region_of_interest_acronym', 'organ',
    'class', 'subclass', 'supertype', 'cluster',
]
mask = ~adata_sub.obs[cols].isna().any(axis=1)
adata_sub = adata_sub[mask].copy()
adata_sub

AnnData object with n_obs × n_vars = 1699939 × 32285
    obs: 'cell_barcodel', 'library_labell', 'anatomical_division_label', 'cell_label', 'cell_barcode', 'barcoded_cell_sample_label', 'library_label', 'feature_matrix_label', 'entity', 'brain_section_label', 'library_method', 'region_of_interest_acronym', 'donor_label', 'donor_genotype', 'donor_sex', 'dataset_label', 'x', 'y', 'cluster_alias', 'neurotransmitter', 'class', 'subclass', 'supertype', 'cluster', 'neurotransmitter_color', 'class_color', 'subclass_color', 'supertype_color', 'cluster_color', 'region_of_interest_order', 'region_of_interest_color', 'z', 'organ', '<x>', '<y>', '<z>'

In [9]:
# Persist the filtered full dataset for the 4-level-hierarchy setup.
adata_sub.write('../../tissue_generator/sc-reference/alldata-log2-xyz_4hierarchy.h5ad')

In [None]:
# Reload the file written by the previous cell (lets this section be resumed without redoing the preprocessing).
adata_sub = sc.read('../../tissue_generator/sc-reference/alldata-log2-xyz_4hierarchy.h5ad')



In [None]:
# Display the per-gene annotation table of the reloaded dataset.
adata_sub.var

# Creating training data with 100 bins for expression data

In [3]:
# Same input file as the first dataset-creation section; this section re-processes it with 100 coordinate bins.
adata = sc.read('../../tissue_generator/MERFISH_aging/sc_ref_xyz_scMulan.h5ad')
adata



AnnData object with n_obs × n_vars = 25004 × 42117
    obs: 'cell_barcode', 'library_label', 'anatomical_division_label', 'cell_barcoder', 'barcoded_cell_sample_label', 'library_labelr', 'feature_matrix_label', 'entity', 'brain_section_label', 'library_method', 'region_of_interest_acronym', 'donor_label', 'donor_genotype', 'donor_sex', 'dataset_label', 'x', 'y', 'cluster_alias', 'neurotransmitter', 'class', 'subclass', 'supertype', 'cluster', 'neurotransmitter_color', 'class_color', 'subclass_color', 'supertype_color', 'cluster_color', 'region_of_interest_order', 'region_of_interest_color', 'z', 'n_genes'
    uns: 'log1p'

In [4]:
# Uncomment to work on a 500-cell copy while debugging:
# adata_sub = adata[:500].copy()
# Alias, not a copy: later edits to adata_sub.obs also change `adata`.
adata_sub = adata

In [5]:
# Every cell gets the same constant organ label.
adata_sub.obs['organ'] = 'Brain'

In [6]:
import numpy as np

# Collapse the fine-grained region-of-interest acronyms into coarse region names.
_cerebral_cortex_rois = (
    'SS-GU-VISC',    # somatosensory gustatory/visceral area
    'PL-ILA-ORB',    # prelimbic, infralimbic, orbital prefrontal
    'TEa-PERI-ECT',  # temporal association & perirhinal
    'MOp',           # primary motor cortex
    'VIS',           # visual cortex
    'VIS-PTLp',      # posterior lateral visual
    'SSp',           # primary somatosensory cortex
    'MO-FRP',        # frontal pole motor
    'AI',            # agranular insular cortex
    'AUD',           # auditory cortex
)
_cingulate_cortex_rois = (
    'ACA',           # anterior cingulate area
    'RSP',           # retrosplenial cortex
)

region_map = dict.fromkeys(_cerebral_cortex_rois, 'Cerebral cortex')
region_map.update(dict.fromkeys(_cingulate_cortex_rois, 'Cingulate cortex'))
region_map[np.nan] = 'Unclassified'  # missing values

# Acronyms absent from the map come back as NaN and are labelled 'Unclassified' too.
coarse_region = adata_sub.obs['region_of_interest_acronym'].map(region_map)
adata_sub.obs['region'] = coarse_region.fillna('Unclassified')

In [7]:
# Maps each `subclass` label of this dataset to a coarser cell-type name.
# Subclasses not listed here end up as 'Unclassified' via the fillna at the bottom.
new2existing = {
    # glutamatergic -> Excitatory neuron
    '006 L4/5 IT CTX Glut':                      'Excitatory neuron',
    '030 L6 CT CTX Glut':                        'Excitatory neuron',
    '029 L6b CTX Glut':                         'Excitatory neuron',
    '032 L5 NP CTX Glut':                       'Excitatory neuron',
    '005 L5 IT CTX Glut':                       'Excitatory neuron',
    '007 L2/3 IT CTX Glut':                     'Excitatory neuron',
    '022 L5 ET CTX Glut':                       'Excitatory neuron',
    '021 L4 RSP-ACA Glut':                      'Excitatory neuron',
    '004 L6 IT CTX Glut':                       'Excitatory neuron',
    '020 L2/3 IT RSP Glut':                     'Excitatory neuron',
    '001 CLA-EPd-CTX Car3 Glut':                'Excitatory neuron',
    '003 L5/6 IT TPE-ENT Glut':                 'Excitatory neuron',
    '028 L6b/CT ENT Glut':                      'Excitatory neuron',
    '002 IT EP-CLA Glut':                       'Excitatory neuron',
    '025 CA2-FC-IG Glut':                       'Excitatory neuron',
    '009 L2/3 IT PIR-ENTl Glut':                'Excitatory neuron',
    '010 IT AON-TT-DP Glut':                    'Excitatory neuron',
    '036 HPF CR Glut':                          'Excitatory neuron',
    '008 L2/3 IT ENT Glut':                     'Excitatory neuron',
    '027 L6b EPd Glut':                         'Excitatory neuron',
    '114 COAa-PAA-MEA Barhl2 Glut':             'Excitatory neuron',
    '018 L2 IT PPP-APr Glut':                   'Excitatory neuron',
    '035 OB Eomes Ms4a15 Glut':                 'Excitatory neuron',
    '262 Pineal Crx Glut':                      'Excitatory neuron',
    '034 NP PPP Glut':                          'Excitatory neuron',
    '019 L2/3 IT PPP Glut':                     'Excitatory neuron',
    '033 NP SUB Glut':                          'Excitatory neuron',
    '115 MS-SF Bsx Glut':                       'Excitatory neuron',
    '016 CA1-ProS Glut':                        'Excitatory neuron',

    # GABAergic -> Inhibitory neuron
    '053 Sst Gaba':                             'Inhibitory neuron',
    '050 Lamp5 Lhx6 Gaba':                     'Inhibitory neuron',
    '046 Vip Gaba':                             'Inhibitory neuron',
    '052 Pvalb Gaba':                           'Inhibitory neuron',
    '049 Lamp5 Gaba':                          'Inhibitory neuron',
    '047 Sncg Gaba':                           'Inhibitory neuron',
    '056 Sst Chodl Gaba':                      'Inhibitory neuron',
    '041 OB-in Frmd7 Gaba':                    'Inhibitory neuron',
    '066 NDB-SI-ant Prdm12 Gaba':              'Inhibitory neuron',
    '061 STR D1 Gaba':                         'Inhibitory neuron',
    '062 STR D2 Gaba':                         'Inhibitory neuron',
    '064 STR-PAL Chst9 Gaba':                  'Inhibitory neuron',
    '042 OB-out Frmd7 Gaba':                   'Inhibitory neuron',
    '039 OB Meis2 Thsd7b Gaba':                'Inhibitory neuron',
    '080 CEA-AAA-BST Six3 Sp9 Gaba':           'Inhibitory neuron',
    '051 Pvalb chandelier Gaba':               'Inhibitory neuron',
    '044 OB Dopa-Gaba':                        'Inhibitory neuron',
    '045 OB-STR-CTX Inh IMN':                  'Inhibitory neuron',
    '065 IA Mgp Gaba':                         'Inhibitory neuron',
    '063 STR D1 Sema5a Gaba':                  'Inhibitory neuron',
    '048 RHP-COA Ndnf Gaba':                   'Inhibitory neuron',

    # "NN" suffix -> non-neuronal
    # NOTE(review): VLMC usually stands for "vascular leptomeningeal cell" and ABC for
    # "arachnoid barrier cell"; 'Vascular smooth muscle cell' and 'Basal cell' look like
    # nearest-available substitutes — confirm they are intended.
    '327 Oligo NN':                            'Oligodendrocyte',
    '326 OPC NN':                              'Oligodendrocyte precursor cell (OPC)',
    '333 Endo NN':                             'Endothelial cell',
    '332 SMC NN':                              'Smooth muscle cell',
    '330 VLMC NN':                             'Vascular smooth muscle cell',
    '319 Astro-TE NN':                         'Astrocyte',
    '318 Astro-NT NN':                         'Astrocyte',
    '338 Lymphoid NN':                         'Lymphoid cell',
    '331 Peri NN':                             'Pericyte',
    '334 Microglia NN':                        'Microglia',
    '335 BAM NN':                              'Macrophage',
    '329 ABC NN':                              'Basal cell',

    # missing / fallback (missing values are also caught by the fillna below)
    None:                                      'Unclassified',
}

# Translate subclass -> cell type; anything left unmapped (NaN) becomes 'Unclassified'.
adata_sub.obs['cell_type'] = adata_sub.obs['subclass'].map(new2existing)
adata_sub.obs['cell_type'] = adata_sub.obs['cell_type'].fillna('Unclassified')

In [8]:
# Discretise the spatial coordinates: for each axis, n_bins evenly spaced edges
# between the observed min and max; the bin index is stored in obs['<x>'] / ['<y>'] / ['<z>'].
n_bins = 100
coord_bins = {}
for coord in ('x', 'y', 'z'):
    column = adata_sub.obs[coord]
    # The edges are derived from the non-missing values only.
    observed = column.dropna().values.astype(float)
    edges = np.linspace(observed.min(), observed.max(), n_bins)
    coord_bins[coord] = edges
    # right=True: a value equal to an edge belongs to the bin that ends at that edge.
    adata_sub.obs[f'<{coord}>'] = np.digitize(column.values.astype(float), edges, right=True)
    

In [9]:
# Drop every cell that lacks a value in any of the required columns.
cols = ['x', 'y', 'z', 'cell_type', 'region', 'organ']
mask = ~adata_sub.obs[cols].isna().any(axis=1)
adata_sub = adata_sub[mask].copy()
adata_sub

AnnData object with n_obs × n_vars = 24881 × 42117
    obs: 'cell_barcode', 'library_label', 'anatomical_division_label', 'cell_barcoder', 'barcoded_cell_sample_label', 'library_labelr', 'feature_matrix_label', 'entity', 'brain_section_label', 'library_method', 'region_of_interest_acronym', 'donor_label', 'donor_genotype', 'donor_sex', 'dataset_label', 'x', 'y', 'cluster_alias', 'neurotransmitter', 'class', 'subclass', 'supertype', 'cluster', 'neurotransmitter_color', 'class_color', 'subclass_color', 'supertype_color', 'cluster_color', 'region_of_interest_order', 'region_of_interest_color', 'z', 'n_genes', 'organ', 'region', 'cell_type', '<x>', '<y>', '<z>'
    uns: 'log1p'

In [10]:
# Persist the filtered dataset with 100 coordinate bins; the
# "4-level hierarchy" section reads this file as its input.
adata_sub.write('../../tissue_generator/MERFISH_aging/sc_ref_xyz_scMulan_processed_n100.h5ad')

# Testing data_loader, finetuning

In [24]:
from data_util import get_generation_dataloader
import torch

# load your AnnData and meta_info
# adata     = sc.read_h5ad('my_mouse_data.h5ad')
# NOTE(review): torch.load unpickles the file — only load files from a trusted source.
meta_info = torch.load('../../tissue_generator/MERFISH_aging/4hierarchy_metainfo.pt')
# Append the three coordinate tokens to the loaded vocabulary before building the tokenizer.
new_tokens = ["<x>", "<y>", "<z>"]
meta_info["token_set"].extend(new_tokens)
tok = scMulanTokenizer(meta_info['token_set'])



In [25]:
import os
import argparse

import torch
import torch.nn.functional as F
from torch.optim import AdamW
from torch.utils.data import DataLoader
import scanpy as sc
from tqdm import tqdm

# Model, config, and tokenizer
from model.model import MulanConfig, scMulanModel
from utils.hf_tokenizer import scMulanTokenizer
# Data loader for generation task
from data_util import get_generation_dataloader

# 1) Build the model from the checkpoint's config and load its weights
ckp_path = 'ckpt/ckpt_scMulan.pt'
device = 'cuda'
# NOTE: `meta_info` holds the file PATH here and is rebound to the loaded dict in step 2.
meta_info = '../../tissue_generator/MERFISH_aging/4hierarchy_metainfo.pt'

# The checkpoint is deserialised on the CPU first; the model is moved to `device` below.
ckp = torch.load(ckp_path, map_location='cpu')
gptconf = MulanConfig(**ckp['model_args'])
ModelClass = scMulanModel
model = ModelClass(gptconf)
# `device` is rebound from the string to a torch.device object.
device = torch.device(device)
model.to(device)
# strict=False: missing or unexpected keys in the state dict are not treated as errors.
model.load_state_dict(ckp['model'], strict=False)
model.eval()
model.hidden_dim = ckp['model_args']['n_embd']
# Expression-level count stored in the checkpoint (None if absent); not used further in this cell.
n_express_level = ckp['model_args'].get('expression_level', None)

# 2) Load and extend meta_info with coordinate tokens
meta_info = torch.load(meta_info)
new_tokens = ["<x>", "<y>", "<z>"]
meta_info['token_set'].extend(new_tokens)
print(len(meta_info['token_set']))

# 3) Initialize tokenizer and resize model embeddings/output
tokenizer = scMulanTokenizer(meta_info['token_set'])
tokenizer.add_special_tokens({'sep_token': meta_info.get('sep_token', '<SPToken1>')})
# Grow embedding & lm_head to accommodate new tokens
model.resize_token_embeddings(len(tokenizer))
# Resize the expression embedding for 100 expression levels and record the same
# number in the config (the printed model below shows `wee` as Embedding(101, ...)).
model.resize_expression_embeddings(100)
model.config.expression_level = 100
# Move the model to the device again after the resize calls.
model.to(device)

[32m2025-06-30 09:32:53.040[0m | [1mINFO    [0m | [36mmodel.model[0m:[36m__init__[0m:[36m129[0m - [1mnumber of parameters: 371.29M[0m
Using unk_token, but it is not set yet.
Using unk_token, but it is not set yet.


2806


scMulanModel(
  (transformer): ModuleDict(
    (wte): Embedding(2806, 1120)
    (wee): Embedding(101, 1120)
    (drop): Dropout(p=0.0, inplace=False)
    (h): ModuleList(
      (0-23): 24 x Block(
        (ln_1): LayerNorm()
        (attn): CausalSelfAttention(
          (c_attn): Linear(in_features=1120, out_features=3360, bias=False)
          (c_proj): Linear(in_features=1120, out_features=1120, bias=False)
          (attn_dropout): Dropout(p=0.0, inplace=False)
          (resid_dropout): Dropout(p=0.0, inplace=False)
        )
        (ln_2): LayerNorm()
        (mlp): MLP(
          (c_fc): Linear(in_features=1120, out_features=4480, bias=False)
          (c_proj): Linear(in_features=4480, out_features=1120, bias=False)
          (dropout): Dropout(p=0.0, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm()
  )
  (lm_head): Linear(in_features=1120, out_features=2806, bias=False)
  (epx_head): Linear(in_features=1120, out_features=2806, bias=False)
  (epx_regressor): Sequ

In [26]:
import wandb

# Smoke test for the data loader: build it, pull a single batch and print the
# shape / max / min of every tensor in it.
batch_size = 4
max_len = 500
no_shuffle = True
num_workers = 1
epochs = 1
lr = 0.0005
lambda_val = 1.0

adata = adata_sub
loader = get_generation_dataloader(
    adata=adata,
    meta_info=meta_info,
    batch_size=batch_size,
    max_len=max_len,
    shuffle=not no_shuffle,
    num_workers=num_workers,
    include_0s=False,
)

for step, batch in enumerate(tqdm(loader, desc=f"Epoch {epochs}")):
    input_ids = batch['input_ids'].to(device)
    attention_mask = batch['attention_mask'].to(device)
    # Labels and expression targets are left on their original device here.
    labels = batch.get('labels')
    x_expr = batch.get('input_vals').to(device)
    expr_target = batch.get('target_vals')

    inspected = (input_ids, labels, attention_mask, x_expr, expr_target)
    print(*(tensor.shape for tensor in inspected))
    print(*(tensor.max() for tensor in inspected))
    print(*(tensor.min() for tensor in inspected))
    # Only the first batch is inspected.
    break

Using unk_token, but it is not set yet.
Using unk_token, but it is not set yet.


✅ adata passed check
👸 scMulan is ready


Epoch 1:   0%|          | 0/6221 [00:00<?, ?it/s]

torch.Size([4, 500]) torch.Size([4, 500]) torch.Size([4, 500]) torch.Size([4, 500]) torch.Size([4, 500])
tensor(2805, device='cuda:0') tensor(2801) tensor(1, device='cuda:0') tensor(66, device='cuda:0') tensor(66.)
tensor(0, device='cuda:0') tensor(-100) tensor(0, device='cuda:0') tensor(0, device='cuda:0') tensor(0.)


Epoch 1:   0%|          | 0/6221 [00:05<?, ?it/s]


In [29]:
# 4) Load AnnData and prepare DataLoader
#
# Fine-tuning cell: builds the generation dataloader over `adata_sub`, creates
# an AdamW optimizer over all model parameters and runs `epochs` passes,
# printing the per-epoch mean of the total / classification / expression losses.
# Relies on `adata_sub`, `meta_info`, `model`, `device`, `AdamW`, `tqdm` and
# `get_generation_dataloader` already being bound by earlier cells.

import wandb

batch_size = 4
max_len = 500
no_shuffle = True   # the loader receives `shuffle = not no_shuffle`
num_workers = 1
epochs = 1
lr = 0.0005
lambda_val = 1.0    # forwarded to the model; presumably weights the expression loss — TODO confirm

# 0) Before training, initialize a run
# wandb.init(
#     project="scMulan-finetune",
#     name="conditional-gen-with-coords",
#     config={
#         "epochs":     epochs,
#         "batch_size": batch_size,
#         "lr":         lr,
#         "lambda_val": lambda_val,
#     },
#     dir='/scratch/skrieger',
# )



adata = adata_sub
loader = get_generation_dataloader(
    adata      = adata,
    meta_info  = meta_info,
    batch_size = batch_size,
    max_len    = max_len,
    shuffle    = not no_shuffle,
    num_workers= num_workers,
    include_0s = False,
)

# 5) Optimizer
optimizer = AdamW(model.parameters(), lr=lr)

# 6) Training loop
for epoch in range(1, epochs+1):
    model.train()
    total_loss, total_cls, total_exp = 0.0, 0.0, 0.0
    for step, batch in enumerate(tqdm(loader, desc=f"Epoch {epoch}")):
        input_ids      = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)  # not passed to the model below
        # If using labels and expression targets
        labels         = batch.get('labels')
        x_expr    = batch.get('input_vals').to(device)
        # print(x_expr, x_expr.max(), x_expr.min())
        expr_target    = batch.get('target_vals')
        # print(x_expr.shape, input_ids.shape)
        if labels is not None:
            labels = labels.to(device)
        if expr_target is not None:
            expr_target = expr_target.to(device)

        # Forward: returns logits and losses
        logits_cls, logits_exp, loss, loss_cls, loss_exp = model(
            idx=input_ids,
            x_expr=x_expr,
            targets=labels,
            y_expr=expr_target,
            lambda_val=lambda_val,
            return_hidden=False,
        )

        # Fail with a clear message instead of an AttributeError from
        # `None.backward()`. Presumably the model yields no loss when the batch
        # carries no 'labels' — TODO confirm against the model's forward().
        if loss is None:
            raise RuntimeError(
                f"model returned no loss at epoch {epoch}, step {step}; "
                "check that the batch provides 'labels' / 'target_vals'"
            )

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # `loss` is guaranteed non-None here; the component losses may still be absent.
        total_loss   += loss.item()
        total_cls    += loss_cls.item() if loss_cls is not None else 0.0
        total_exp    += loss_exp.item() if loss_exp is not None else 0.0

        # wandb.log({
        #     "train/batch_loss":       loss.item(),
        #     "train/batch_loss_cls":   loss_cls.item(),
        #     "train/batch_loss_exp":   loss_exp.item(),
        #     "train/learning_rate":    optimizer.param_groups[0]["lr"],
        #     "train/epoch":            epoch,
        #     "train/step":             epoch * len(loader) + step,
        # })

    # Per-epoch means over the number of batches in the loader.
    avg_loss = total_loss / len(loader)
    avg_cls  = total_cls  / len(loader)
    avg_exp  = total_exp  / len(loader)
    print(f"Epoch {epoch} ‚Äî total {avg_loss:.4f}, cls {avg_cls:.4f}, exp {avg_exp:.4f}")

    # wandb.log({
    #     "train/epoch_loss":     avg_loss,
    #     "train/epoch_loss_cls": avg_cls,
    #     "train/epoch_loss_exp": avg_exp,
    #     "train/epoch":          epoch,
    # })




    # Save epoch checkpoint
    # ckpt_file = os.path.join(args.output_dir, f"epoch{epoch}_model.pt")
    # torch.save({'model': model.state_dict(), 'model_args': ckp['model_args']}, ckpt_file)

# 7) Save final model, config, and tokenizer
# model_save_dir = args.output_dir
# # Save model weights and config
# model.save_pretrained(model_save_dir)
# MulanConfig(**ckp['model_args']).save_pretrained(model_save_dir)
# # Save tokenizer
# tokenizer.save_pretrained(model_save_dir)
# print(f"Finetuned artifacts written to {model_save_dir}")
# NOTE(review): `wandb.init` above is commented out, so this cell starts no run
# for `finish()` to close.
wandb.finish()

Using unk_token, but it is not set yet.
Using unk_token, but it is not set yet.


‚úÖ adata passed check
üë∏ scMulan is ready


Epoch 1:   1%|          | 61/6221 [00:29<49:31,  2.07it/s] 


KeyboardInterrupt: 

# Testing inference

In [4]:
# Load the 4-hierarchy AnnData and its matching meta_info, then build the
# tokenizer over the token set extended with the three coordinate tokens.
from data_util import get_generation_dataloader
import torch

# load your AnnData and meta_info
# adata     = sc.read_h5ad('my_mouse_data.h5ad')
adata_sub = sc.read_h5ad('../../tissue_generator/MERFISH_aging/sc_ref_xyz_scMulan_processed_4hierarchy.h5ad')
# meta_info = torch.load('scMulan/utils/meta_info.pt')
# NOTE(review): torch.load unpickles the file; only load meta_info files you trust.
meta_info = torch.load('../../tissue_generator/MERFISH_aging/4hierarchy_metainfo.pt')
new_tokens = ["<x>", "<y>", "<z>"]
# Extends the loaded token list in place. meta_info is re-read from disk just
# above, so re-running this cell does not append the tokens twice.
meta_info["token_set"].extend(new_tokens)
tok = scMulanTokenizer(meta_info['token_set'])



In [24]:
from scMulan import model_generate

# Evaluate the epoch-30 checkpoint saved under `scMulan-output-4hierarchy-xyznoise`:
# wrap it around `adata_sub` / `meta_info`, then run the notebook's
# `run_and_metrics` helper on the resulting generator.
ckp_path = '/compute/oven-0-13/skrieger/scMulan-output-4hierarchy-xyznoise/epoch30_model.pt'
scm = model_generate(ckp_path=ckp_path, adata=adata_sub, meta_info=meta_info)
run_and_metrics(scm)

MulanConfig(block_size=1550, vocab_size=2222, n_layer=24, n_head=20, n_embd=1120, dropout=0.0, bias=False, train_mode='pretrain', expression_level=100, ele=1, bin_edges=None)


[32m2025-07-01 12:27:06.960[0m | [1mINFO    [0m | [36mscMulan.model.model[0m:[36m__init__[0m:[36m129[0m - [1mnumber of parameters: 374.01M[0m


‚úÖ adata passed check
üë∏ scMulan is ready


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 100/100 [31:30<00:00, 18.90s/it]


Average Pearson r over 100 pairs: 0.3350
Binary reconstruction metrics: {'precision': 0.31939218855195395, 'recall': 0.6283355927067923, 'f1': 0.40855191910374794}
Mean cell-type likelihood (kNN): 0.1320


In [9]:
from scMulan import model_generate

# Evaluate the epoch-70 checkpoint saved under `scMulan-output-4hierarchy-xyznoise`,
# this time requesting `use_kv_cache=True` from `model_generate`.
ckp_path = '/compute/oven-0-13/skrieger/scMulan-output-4hierarchy-xyznoise/epoch70_model.pt'
scm = model_generate(ckp_path=ckp_path, adata=adata_sub, meta_info=meta_info, use_kv_cache=True)
run_and_metrics(scm)

MulanConfig(block_size=1550, vocab_size=2222, n_layer=24, n_head=20, n_embd=1120, dropout=0.0, bias=False, train_mode='pretrain', expression_level=100, ele=1, bin_edges=None)


[32m2025-07-02 11:30:39.491[0m | [1mINFO    [0m | [36mscMulan.model.model_kvcache[0m:[36m__init__[0m:[36m132[0m - [1mnumber of parameters: 374.01M[0m


‚úÖ adata passed check
üë∏ scMulan is ready


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 9/9 [02:46<00:00, 18.47s/it]


Average Pearson r over 100 pairs: 0.0993
Binary reconstruction metrics: {'precision': 0.2726604871166232, 'recall': 0.41473708880522686, 'f1': 0.3226855356741465}
Mean cell-type likelihood (kNN): 0.0014


In [26]:
from scMulan import model_generate

# Evaluate the epoch-40 checkpoint saved under `scMulan-output-4hierarchy-xyznoise`
# with the notebook's `run_and_metrics` helper.
ckp_path = '/compute/oven-0-13/skrieger/scMulan-output-4hierarchy-xyznoise/epoch40_model.pt'
scm = model_generate(ckp_path=ckp_path, adata=adata_sub, meta_info=meta_info)
run_and_metrics(scm)

MulanConfig(block_size=1550, vocab_size=2222, n_layer=24, n_head=20, n_embd=1120, dropout=0.0, bias=False, train_mode='pretrain', expression_level=100, ele=1, bin_edges=None)


[32m2025-07-01 15:56:24.916[0m | [1mINFO    [0m | [36mscMulan.model.model[0m:[36m__init__[0m:[36m129[0m - [1mnumber of parameters: 374.01M[0m


‚úÖ adata passed check
üë∏ scMulan is ready


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 100/100 [22:58<00:00, 13.79s/it]


Average Pearson r over 100 pairs: 0.3157
Binary reconstruction metrics: {'precision': 0.299497379549725, 'recall': 0.6105880911972735, 'f1': 0.39075067444572753}
Mean cell-type likelihood (kNN): 0.0916


In [25]:
from scMulan import model_generate

# Evaluate the epoch-30 checkpoint saved under `scMulan-output-4hierarchy`
# (the output directory without the `-xyznoise` suffix).
ckp_path = '/compute/oven-0-13/skrieger/scMulan-output-4hierarchy/epoch30_model.pt'
scm = model_generate(ckp_path=ckp_path, adata=adata_sub, meta_info=meta_info)
run_and_metrics(scm)

MulanConfig(block_size=1550, vocab_size=2222, n_layer=24, n_head=20, n_embd=1120, dropout=0.0, bias=False, train_mode='pretrain', expression_level=100, ele=1, bin_edges=None)


[32m2025-07-01 12:58:42.980[0m | [1mINFO    [0m | [36mscMulan.model.model[0m:[36m__init__[0m:[36m129[0m - [1mnumber of parameters: 374.01M[0m


‚úÖ adata passed check
üë∏ scMulan is ready


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 100/100 [23:37<00:00, 14.17s/it]


Average Pearson r over 100 pairs: 0.1523
Binary reconstruction metrics: {'precision': 0.3155481210949526, 'recall': 0.6938560080330667, 'f1': 0.42575708570918896}
Mean cell-type likelihood (kNN): 0.0876


In [14]:
from scMulan import model_generate

# Evaluate the epoch-70 checkpoint saved under `scMulan-output-4hierarchy` with
# `use_kv_cache=True`, keeping what `run_and_metrics` returns in `results` for
# inspection in later cells.
ckp_path = '/compute/oven-0-13/skrieger/scMulan-output-4hierarchy/epoch70_model.pt'
scm = model_generate(ckp_path=ckp_path, adata=adata_sub, meta_info=meta_info, use_kv_cache=True)
results = run_and_metrics(scm)

MulanConfig(block_size=1550, vocab_size=2222, n_layer=24, n_head=20, n_embd=1120, dropout=0.0, bias=False, train_mode='pretrain', expression_level=100, ele=1, bin_edges=None)


[32m2025-07-02 12:49:23.460[0m | [1mINFO    [0m | [36mscMulan.model.model_kvcache[0m:[36m__init__[0m:[36m132[0m - [1mnumber of parameters: 374.01M[0m


‚úÖ adata passed check
üë∏ scMulan is ready


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 9/9 [01:21<00:00,  9.07s/it]


Average Pearson r over 100 pairs: 0.0631
Binary reconstruction metrics: {'precision': 0.27390757064508864, 'recall': 0.3099241418217609, 'f1': 0.2845713290154339}
Mean cell-type likelihood (kNN): 0.0012


In [15]:
# Show element 6 of the first entry in `results`. Presumably this is the
# `target_tokens` slot of a `generate_cell_genesis(..., return_gt=True)` tuple —
# TODO confirm what `run_and_metrics` returns.
results[0][6]

tensor([[ 606, 1285,   34, 2300, 1452, 1668, 2803, 2804, 2805, 1279, 2217, 2646,
         2676,  119, 2601, 1316, 2694, 2525, 1331,  958, 1856, 2132,  686, 1305,
         1540, 1568,  216, 2153, 1781,  930, 2480,  279,  563,  840,  529,  370,
         2178, 2278, 1269,  993,  105,  380, 1927,  479, 2092, 2402, 1415, 1011,
         1146, 1454,  387,   42, 1981, 1941, 2337, 2622, 1878, 2495, 2200, 2263,
          165,  219, 1706, 1110, 1032,  753, 2788, 2560,  368, 1705,  858, 2305,
         1601, 2216, 1127, 1239, 1259,  837, 2398, 2787,  230, 1162, 2054,  486,
         2111, 1972, 2759, 2520, 1916, 2705,  218,  872, 1821, 1628, 1852,  739,
         2399, 1751, 2164, 1227,  662,   76, 2291, 2118, 1031,   31, 2618, 1086,
         1936, 2083,  310, 1776, 2588,   30,  335, 1409, 1979,  577, 2106, 2208,
         1507, 1009, 2250,  420, 2265,  618, 2471,  822, 2264,  284,  383, 2613,
         2426, 1676, 2316, 1004,  312,  784, 2446, 1041, 2666, 1862, 1196, 1426,
         2312,  589, 1709,  

In [8]:
i = 3

# Generate one cell (index `i`) and, because `return_gt=True`, also receive the
# three `target_*` values alongside the six generation outputs.
generated = scm.generate_cell_genesis(
    idx=i,
    max_new_tokens=500,
    top_k=2,
    return_gt=True,
    # cheat_with_tokens=True,
    # cheat_with_expr=True,
)
(row, gt, nv,
 gen_seq, gen_vals_binned, gen_vals,
 target_tokens, target_vals, target_real_vals) = generated


In [9]:
# Display the `target_tokens` returned by `generate_cell_genesis(..., return_gt=True)`.
target_tokens

tensor([[ 606,  233, 2410, 1630,  870, 1115, 2803, 2804, 2805, 1279, 2217, 2646,
          894,  743,  711, 1884,  119, 2601, 1813, 2694, 2525,  958, 2059, 1856,
          532, 2643,  686, 1540, 2153,  930,  279, 1660,  529, 2178, 2278, 2011,
         1523, 1394,  380, 1894, 1095, 2623,  418, 2402, 1011, 1146, 1454,  387,
         1981, 1382,  178, 2495, 2392,  165, 2665,  316, 1627, 1110, 1032,  620,
         1765,  753, 2659, 2560,  368, 1780, 1985, 1961,  858, 2305,  605, 1601,
          407,  837, 2398, 1457, 2054, 1405, 2111,  192, 1916, 2705, 1104,  872,
         1345, 1628, 1214, 2399,   59, 1366, 1751, 2164,  633, 1771,   76, 2118,
         1031,   31, 2618, 1936,  310, 2736, 2588,   30,  335, 1979,  577, 1672,
         1507, 2250, 2265,  929,  618, 2791,  822,  211, 1831,  383, 1075,  455,
          795,   58, 1676,  431,  312,   15, 1171, 2446, 1041, 1862, 1196, 1426,
         2312,  589, 1964, 1709, 2492,  807, 2370,  900,  952,  298,  409, 1671,
         2156, 1677,  979,  

In [10]:
# Wrap `gen_seq` in a tensor so it prints in the same form as `target_tokens`.
torch.tensor(gen_seq)

tensor([2217, 2217, 2269, 2269, 2646,  894,  894,  711, 2059, 2132,  686, 1008,
        2018, 1540, 2153, 2106, 2099,  387,  387,  529,  719, 1381, 2032,  563,
         101, 1568,  228, 2137,  286, 2053,  845,  383, 1011, 2152, 1961, 1117,
         165,  165, 1730,  618, 1315, 1991, 1991, 1309, 1394, 1394,  216, 2164,
         288, 1748,  380, 1226,   31, 1231, 2000, 1192,  418, 2643, 1211, 2007,
         607, 2278, 1786, 1981,  851, 2291,  807,  901,  395, 1207, 1750,  966,
        1190, 1693, 1136,  910, 2146, 1657,   92, 2366, 1890, 2160, 1165,  445,
        2402, 1476, 2118,  412, 1304, 1635, 1458,  472, 1517, 2713, 1053,  168,
        1788, 1729, 2061,  508, 1388, 2012, 2473,   76, 1563,  516, 1029, 2062,
        2580, 1523, 2201, 1709, 1567, 1717, 2017,  830, 1670, 1670,  532, 1617,
        1405,  255, 1241,   42, 1921, 1921, 1228,  753, 1557, 1261, 1261,   46,
        1964, 1688,  880, 1797,  950, 1327,  160, 1147, 1236, 1899, 2665,  621,
         202, 1269, 1993, 1916, 1800,  8

In [11]:
# Display `target_vals` from the ground-truth outputs (presumably the binned
# expression targets — TODO confirm).
target_vals

tensor([[ 0,  0,  0,  0,  0,  0, 66, 57, 25,  0, 60, 24, 16, 10, 24, 16, 10, 10,
         10, 10, 54, 10, 21, 10, 24, 10, 10, 16, 39, 21, 10, 16, 34, 27, 10, 10,
         10, 39, 10, 10, 10, 10, 10, 16, 16, 16, 10, 16, 27, 10, 10, 10, 10, 10,
         10, 10, 10, 16, 21, 10, 10, 10, 10, 10, 10, 10, 16, 10, 37, 21, 16, 10,
         10, 10, 10, 16, 21, 10, 36, 10, 38, 16, 27, 10, 10, 16, 10, 37, 10, 10,
         10, 10, 10, 10, 10, 16, 16, 21, 31, 24, 10, 16, 21, 24, 10, 16, 24, 10,
         16, 10, 21, 10, 16, 21, 16, 21, 10, 10, 10, 10, 10, 10, 16, 10, 21, 10,
         24, 24, 10, 33, 10, 10, 16, 10, 10, 39, 10, 38, 10, 10, 24, 16, 10, 10,
         16, 10, 29, 10, 10, 16, 10, 16, 10, 10, 10, 10, 16, 33, 21, 21, 10, 10,
         42, 16, 10, 16, 21, 10, 10, 10, 10, 10, 21, 16, 16, 16, 10, 10, 10, 10,
         10, 10, 10, 10, 10, 16, 10, 21, 16, 10, 16, 16, 10, 16, 10, 29, 10, 10,
         10, 16, 10, 10, 10, 10, 10, 10, 21, 10, 21, 10, 10, 10, 16, 21, 10, 10,
         37, 10, 10, 21, 24,

In [12]:
# Wrap `gen_vals_binned` in a tensor for side-by-side comparison with `target_vals`.
torch.tensor(gen_vals_binned)

tensor([ 0,  0,  0,  0,  1,  8, 67, 25, 21, 34, 24, 23, 22, 36, 19, 23, 22, 22,
        20, 17, 14, 28, 30, 23, 18, 16, 20, 17, 19, 21, 34, 22, 21, 32, 24, 19,
        19, 20, 17, 16, 18, 17, 17, 17, 17, 16, 15, 22, 21, 23, 20,  1, 11, 19,
        21, 26, 19, 13, 19, 18, 20,  5, 21, 26, 19,  7, 22, 20,  4, 28, 21, 24,
        21, 23,  0,  3, 21, 25, 20,  4, 27, 23,  0,  2, 17, 19,  6, 31, 22, 22,
        22, 22, 23,  8, 25, 19,  0,  3, 22, 23, 21, 26, 23, 21, 26, 21, 27, 23,
        19, 23, 13, 21, 27, 23, 15, 22, 22, 23, 22, 23, 21, 26, 22, 22, 22, 22,
        22, 22, 22, 22, 22, 22, 22, 22, 22, 21, 26, 21, 27, 21, 25, 19, 22, 20,
         5, 39, 22, 22, 21, 27, 21, 26, 22, 21, 25, 20,  7, 32, 22, 21, 26, 21,
        26, 20, 18, 22, 23, 15, 22, 22, 22, 22, 21, 25, 19, 21, 22, 22, 21, 24,
        25, 18, 21, 24, 19, 19, 17, 18, 22, 20,  4, 32, 21, 24, 22, 21, 23, 20,
        19,  8, 21, 24, 22, 21, 21, 23, 15, 21, 24, 24, 22, 21, 23, 18, 19, 10,
        32, 21, 24, 25, 19, 20,  4, 29, 

In [13]:
# Display `target_real_vals` from the ground-truth outputs (presumably the
# un-binned expression targets — TODO confirm).
target_real_vals

tensor([[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000, 66.0000, 57.0000,
         25.0000,  0.0000,  4.5899,  2.0147,  1.4468,  0.9650,  2.0147,  1.4468,
          0.9650,  0.9650,  0.9650,  0.9650,  4.1895,  0.9650,  1.7706,  0.9650,
          2.0147,  0.9650,  0.9650,  1.4468,  3.0965,  1.7706,  0.9650,  1.4468,
          2.7487,  2.2109,  0.9650,  0.9650,  0.9650,  3.0965,  0.9650,  0.9650,
          0.9650,  0.9650,  0.9650,  1.4468,  1.4468,  1.4468,  0.9650,  1.4468,
          2.2109,  0.9650,  0.9650,  0.9650,  0.9650,  0.9650,  0.9650,  0.9650,
          0.9650,  1.4468,  1.7706,  0.9650,  0.9650,  0.9650,  0.9650,  0.9650,
          0.9650,  0.9650,  1.4468,  0.9650,  2.9377,  1.7706,  1.4468,  0.9650,
          0.9650,  0.9650,  0.9650,  1.4468,  1.7706,  0.9650,  2.8476,  0.9650,
          3.0202,  1.4468,  2.2109,  0.9650,  0.9650,  1.4468,  0.9650,  2.9377,
          0.9650,  0.9650,  0.9650,  0.9650,  0.9650,  0.9650,  0.9650,  1.4468,
          1.4468,  1.7706,  

In [14]:
# Wrap `gen_vals` in a tensor for side-by-side comparison with `target_real_vals`.
torch.tensor(gen_vals)

tensor([0.0278, 0.0287, 0.0296, 0.2753, 0.3257, 0.8608, 5.1403, 2.0754, 1.8120,
        2.7270, 2.0285, 1.9322, 1.8735, 2.9021, 1.6294, 1.9277, 1.8401, 1.8274,
        1.7411, 1.4568, 1.2930, 2.2715, 2.4290, 1.9230, 1.5441, 1.4066, 1.6924,
        1.5070, 1.6043, 1.8131, 2.6946, 1.8596, 1.7890, 2.5617, 1.9736, 1.6115,
        1.6720, 1.6882, 1.5138, 1.4392, 1.5466, 1.4761, 1.4765, 1.4764, 1.4565,
        1.4177, 1.3193, 1.8396, 1.8095, 1.9335, 1.6770, 0.3269, 1.0581, 1.6136,
        1.7766, 2.1409, 1.6611, 1.1873, 1.6144, 1.5638, 1.7437, 0.6027, 1.7635,
        2.1379, 1.6526, 0.7691, 1.8888, 1.6739, 0.5137, 2.2705, 1.7804, 1.9960,
        1.7549, 1.8975, 0.1230, 0.4429, 1.7641, 2.0987, 1.6908, 0.5446, 2.2329,
        1.8923, 0.1563, 0.3896, 1.5213, 1.6275, 0.6726, 2.5340, 1.8350, 1.8465,
        1.8517, 1.8496, 1.8998, 0.8065, 2.0545, 1.6605, 0.1048, 0.4493, 1.8557,
        1.9186, 1.7897, 2.1251, 1.9385, 1.8083, 2.1503, 1.8045, 2.1825, 1.9015,
        1.6136, 1.9219, 1.1772, 1.8068, 

In [22]:
from tqdm import tqdm

# Generate the first `k` cells one at a time, keep each returned `row`, then
# hand the collected rows to the notebook's `get_metrics` helper.
k = 100
rows = []
for i in tqdm(range(k)):
    generated = scm.generate_cell_genesis(idx=i, max_new_tokens=500, top_k=5)
    row, gt, nv, gen_seq, gen_vals_binned, gen_vals = generated
    rows.append(row)
get_metrics(scm, rows, k=k)

  7%|‚ñã         | 7/100 [02:50<37:50, 24.41s/it]


KeyboardInterrupt: 

In [15]:
# Model with no overfitting
# Mean per-cell Pearson correlation between each generated vector in `rows`
# and the ground-truth vector of the cell at the same position in `scm.adata`.
from scipy.stats import pearsonr
import numpy as np

# Slice to the evaluated cells *before* densifying so only those rows are
# materialised, and accept an already-dense X.
n_eval = len(rows)
X_eval = scm.adata.X[:n_eval]
ground_truth = (X_eval.toarray() if hasattr(X_eval, 'toarray') else np.asarray(X_eval)).T

# rows: list of 1D arrays, one per generated cell (presumably one value per gene — TODO confirm)
# ground_truth: array of shape (n_genes, n_eval), one column per cell
# Fails if `scm.adata` holds fewer cells than were generated.
assert len(rows) == ground_truth.shape[1]

rs = []
for j, pred_vec in enumerate(rows):
    true_vec = ground_truth[:, j]
    r, _ = pearsonr(pred_vec, true_vec)
    rs.append(r)

mean_r = np.mean(rs)
print(f"Average Pearson r over {len(rs)} pairs: {mean_r:.4f}")



Average Pearson r over 100 pairs: 0.4456


In [13]:
# Number of positive entries in `pred_vec` (the prediction vector left bound by
# the most recent correlation loop).
(pred_vec > 0).sum()

142

In [12]:
# Number of positive entries in `true_vec` (the ground-truth vector left bound
# by the most recent correlation loop).
(true_vec > 0).sum()

153

In [23]:
# Count how often each `subclass` label occurs among the first 100 cells of `adata_sub`.
adata_sub[:100].obs['subclass'].value_counts()

subclass
030 L6 CT CTX Glut           31
032 L5 NP CTX Glut           16
006 L4/5 IT CTX Glut          9
053 Sst Gaba                  7
007 L2/3 IT CTX Glut          6
046 Vip Gaba                  6
022 L5 ET CTX Glut            4
029 L6b CTX Glut              4
004 L6 IT CTX Glut            3
020 L2/3 IT RSP Glut          3
052 Pvalb Gaba                3
001 CLA-EPd-CTX Car3 Glut     1
050 Lamp5 Lhx6 Gaba           1
327 Oligo NN                  1
028 L6b/CT ENT Glut           1
003 L5/6 IT TPE-ENT Glut      1
021 L4 RSP-ACA Glut           1
005 L5 IT CTX Glut            1
333 Endo NN                   1
Name: count, dtype: int64

In [25]:
# Copy the first 100 cells, then keep only those labelled '030 L6 CT CTX Glut'.
gt_cells = adata_sub[:100].copy()
is_l6_ct = gt_cells.obs['subclass'] == '030 L6 CT CTX Glut'
glut_gt_cells = gt_cells[is_l6_ct].copy()

In [27]:
# Select the rows of `pred` whose matching ground-truth cell is labelled
# '030 L6 CT CTX Glut', then show the shape of the selection.
glut_mask = gt_cells.obs['subclass'] == '030 L6 CT CTX Glut'
pred_glut = pred[glut_mask]
pred_glut.shape

(31, 2000)