In [1]:
# Configure IPython so that every top-level expression statement in a cell is
# displayed, not only the last one (the default setting is "last_expr").
# This is set on the InteractiveShell class, so it applies to the rest of the session.
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"


In [2]:
import os
from pathlib import Path
from collections import Counter

import anndata
import scanpy as sc
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import classification_report, confusion_matrix

import pollock.utils as utils
import pollock.explain as explain

In [3]:
# Load the trained pollock model from a path relative to the current working
# directory; `model` is passed to utils.predict_adata below.
# NOTE(review): the path looks like a model directory rather than a single file — confirm.
model = utils.load_model('models/mouse_atlas_humanized_sharma')

  model.load_state_dict(torch.load(model_fp, map_location=torch.device('cpu')))


In [4]:
# Read the query dataset (an AnnData .h5ad file, path relative to the working
# directory); these are the cells that the model annotates in the next cell.
query = anndata.read_h5ad('data/U19_atlas.h5ad')

In [5]:
# Run the loaded model on the query data and keep the annotated result as `adata`.
# Note: when predicting a large dataset and the pollock UMAP embeddings are not
# needed, pass make_umap=False to reduce the runtime.
adata = utils.predict_adata(model, query, make_umap=True)
# Bare expression: displays a summary of the returned AnnData object.
adata

2025-09-08 15:56:04,295 13882 genes overlap with model after filtering
2025-09-08 15:56:04,296 750 genes missing from dataset after filtering
  warn(
  sc.pp.normalize_per_cell(adata)
  normalize_per_cell(
2025-09-08 15:56:56,166 starting prediction of 53596 cells
  sf = self.adata.obs['size_factors'][idx]


AnnData object with n_obs × n_vars = 53596 × 14632
    obs: 'orig.ident', 'nCount_RNA', 'nFeature_RNA', 'Study', 'Species', 'Cell_or_nuclei', 'Platform', 'percent.mt', 'integrated_snn_res.1', 'seurat_clusters', 'tSNE_1', 'tSNE_2', 'UMAP_1', 'UMAP_2', 'Level1', 'Donor', 'Batch', 'tech', 'integrated_snn_res.0.5', 'tissue', 'level1', 'integrated_snn_res.0.1', 'integrated_snn_res.3.4', 'cl.conserv', 'cl.HC.LC.split', 'cl.Ab.split.3', 'cl.Ab.split.4', 'diameter', 'drg_level', 'drg_location', 'batch', 'donor', 'cl.conserv_final', 'Strain', 'Sex', 'Age', 'Publication', 'Prep', 'Condition', 'Library', 'Level', 'sex', 'round', 'id', 'nn_score1', 'nn_score2', 'nn_score3', 'nn_score4', 'nn_score5', 'nn_score6', 'nn_score7', 'nn_score8', 'nn_score9', 'nn_score10', 'nn_score11', 'nn_score12', 'nn_score13', 'nn_score14', 'nn_score15', 'nn_score16', 'nn_score17', 'nn_score18', 'nn_score19', 'nn_score20', 'nn_score21', 'nn_score22', 'nn_score23', 'nn_score24', 'nn_score25', 'nn_score26', 'Level2_round

In [6]:
# Display the per-cell annotation table. In the output shown, it holds the
# original metadata columns plus one 'probability <label>' column per label
# (presumably added by utils.predict_adata — confirm).
adata.obs

Unnamed: 0,orig.ident,nCount_RNA,nFeature_RNA,Study,Species,Cell_or_nuclei,Platform,percent.mt,integrated_snn_res.1,seurat_clusters,...,probability PEP1.3.b,probability PEP1.4,probability PEP2.1,probability PEP2.2,probability PEP3.1,probability PEP3.2,probability Proprioceptor,probability Rxfp1,probability TRPM8.1,probability TRPM8.2
TATATCCTCATCCACC-1_1_1,SeuratProject,271917.0,12362,U19_HMS,Human,Nuclei,10x,0.015446,9,11,...,0.055154,0.036455,0.018388,0.041872,0.041782,0.040649,0.034664,0.043205,0.051417,0.035218
CAAGAACCATAATCGT-1_1_1,SeuratProject,154655.0,11439,U19_HMS,Human,Nuclei,10x,0.086644,15,12,...,0.041955,0.033837,0.022907,0.058007,0.035952,0.036766,0.024779,0.068844,0.032704,0.038080
AATCAGGAGAATCGCT-1_1_1,SeuratProject,84579.0,8892,U19_HMS,Human,Nuclei,10x,0.001182,3,52,...,0.047191,0.053107,0.024810,0.036411,0.013827,0.030376,0.205478,0.033315,0.085612,0.024505
TTATTGCTCCGGAACC-1_1_1,SeuratProject,82656.0,9544,U19_HMS,Human,Nuclei,10x,0.008469,14,31,...,0.060390,0.040188,0.015362,0.045447,0.032136,0.028300,0.028608,0.079577,0.064570,0.044042
TTATAGCCAATTGAGA-1_1_1,SeuratProject,78058.0,8946,U19_HMS,Human,Nuclei,10x,0.111456,21,73,...,0.080214,0.045086,0.013247,0.054279,0.032135,0.033310,0.029986,0.050981,0.034096,0.039627
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
s8r2m.TTCTAACCGAAGCTTA.1_7_8,m4r2,11540.0,4309,Tavares-Ferriera_2022,Human,Cell,Visium,3.899480,0,10,...,0.022575,0.011615,0.086992,0.031041,0.070326,0.020169,0.012482,0.031562,0.056649,0.042578
s8r2m.TTCTGCTAGACTCCAA.1_7_8,m4r2,9430.0,3647,Tavares-Ferriera_2022,Human,Cell,Visium,3.467656,7,13,...,0.104490,0.071552,0.029654,0.006274,0.027147,0.031081,0.045845,0.038553,0.208063,0.026464
s8r2m.TTGAGCAGCCCACGGT.1_7_8,m4r2,9338.0,3691,Tavares-Ferriera_2022,Human,Cell,Visium,1.713429,4,4,...,0.031512,0.041420,0.016759,0.034662,0.133677,0.047249,0.009569,0.060645,0.036045,0.123704
s8r2m.TTGCTGATCATGTTCG.1_7_8,m4r2,11859.0,4116,Tavares-Ferriera_2022,Human,Cell,Visium,2.765832,6,16,...,0.068232,0.028896,0.027883,0.083718,0.073919,0.022316,0.021593,0.044299,0.078206,0.065156


In [7]:
# Save the predictions as a .csv: the whole adata.obs table is written to the
# current working directory. With to_csv defaults, the row index (the cell
# identifiers) is written as the first column and an existing file is overwritten.
adata.obs.to_csv('U19_predictions_full_model.csv')