# Allen Brain Dataset


In [3]:
import pandas as pd
import numpy as np
import torch
import torch.onnx
import anndata
import onnx
import onnxruntime as ort
from scsims import SIMS
import scanpy as sc

## Validate model and training dataset dimensions


In [4]:
# Build the SIMS wrapper from the saved checkpoint, passing a CPU device as
# `map_location` so no GPU is needed to load it.
checkpoint_path = "checkpoints/allen-celltypes+human-cortex+various-cortical-areas.ckpt"
cpu_device = torch.device("cpu")
sims = SIMS(weights_path=checkpoint_path, map_location=cpu_device)

https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations


Initializing network
Initializing explain matrix


In [5]:
# Read the training dataset; the untouched object is kept as `adata_raw` so its
# gene count can be compared with the processed one later on.
training_h5ad_path = "checkpoints/allen-celltypes+human-cortex+various-cortical-areas.h5ad"
adata_raw = anndata.read_h5ad(training_h5ad_path)

# Preprocess a copy. Each scanpy step below updates `adata` in place and the
# steps must run in exactly this order.
adata = adata_raw.copy()
preprocessing_steps = (
    (sc.pp.filter_cells, {"min_genes": 100}),
    (sc.pp.filter_genes, {"min_cells": 3}),
    (sc.pp.normalize_total, {}),
    (sc.pp.log1p, {}),  # logarithmize the data
    (sc.pp.scale, {}),
)
for apply_step, step_options in preprocessing_steps:
    apply_step(adata, **step_options)

In [6]:
# Report the input size stored on the model next to the gene counts of the raw
# and the preprocessed training data.
print(f"Model input shape: {sims.model.input_dim}")
print(f"Raw training h5ad num genes: {adata_raw.n_vars}")
print(f"Processed training h5ad num genes: {adata.n_vars}")

# Turn the visual check into a hard one: a fresh "Run All" should stop here if
# the preprocessing no longer yields the number of genes the model was built for.
assert sims.model.input_dim == adata.n_vars, (
    f"Model expects {sims.model.input_dim} genes but the processed data has "
    f"{adata.n_vars}"
)

Model input shape: 48119
Raw training h5ad num genes: 50281
Processed training h5ad num genes: 48119


## Compare Python vs. ONNX vs. ground truth labels

In [7]:
# Predict with the Python (SIMS) model on the first `subset_size` processed
# cells only.
subset_size = 100
adata_subset = adata[:subset_size, :]
predictions = sims.predict(adata_subset)

Parsing inference data...


100%|██████████| 4/4 [00:28<00:00,  7.02s/it]


Predictions: ['Exclude' 'VIP' 'LAMP5' 'LAMP5' 'VIP' 'VIP' 'LAMP5' 'IT' 'IT' 'VIP']
Ground Truth: ['Exclude', 'VIP', 'LAMP5', 'LAMP5', 'VIP', 'VIP', 'LAMP5', 'IT', 'IT', 'VIP']
Categories (14, object): ['Astrocyte', 'Exclude', 'IT', 'L5/6 IT Car3', ..., 'PAX6', 'PVALB', 'SST', 'VIP']


In [16]:
# Open the exported ONNX model.
session = ort.InferenceSession(
    "public/models/allen-celltypes+human-cortex+various-cortical-areas.onnx"
)

# onnxruntime expects a plain NumPy array whose dtype matches the graph input.
# `adata_subset.X` comes from an AnnData view, and its container type and dtype
# depend on the preprocessing, so convert explicitly instead of relying on it
# happening to be compatible.
onnx_input = adata_subset.X
if hasattr(onnx_input, "toarray"):  # e.g. a scipy sparse matrix
    onnx_input = onnx_input.toarray()

# Use the element type the graph declares for its "input" tensor. If it is not
# one of the float types listed here, keep the array's own dtype, which is what
# was passed before this conversion was added.
declared_types = {node.name: node.type for node in session.get_inputs()}
onnx_dtype = {
    "tensor(float)": np.float32,
    "tensor(double)": np.float64,
}.get(declared_types.get("input"), onnx_input.dtype)
onnx_input = np.ascontiguousarray(onnx_input, dtype=onnx_dtype)

# Fetch the top-k class indices and the encoding for every cell in the subset.
onnx_predictions, onnx_encodings = session.run(
    ["topk_indices", "encoding"], {"input": onnx_input}
)

In [17]:
# Load the class names, one per line with surrounding whitespace removed; the
# list position of each name is what the ONNX class indices refer to.
classes_path = "public/models/allen-celltypes+human-cortex+various-cortical-areas.classes"
with open(classes_path, "r") as f:
    classes = list(map(str.strip, f))

In [29]:
# Side-by-side look at the first ten labels from each source.
print(f"Python: {predictions.pred_0.values[0:10]}")
print(f"ONNX: {[classes[p[0]] for p in onnx_predictions[0:10]]}")
print(f"Ground: {adata_subset.obs.subclass_label.values[0:10]}")

# Ten labels checked by eye say little about the rest of the subset, so also
# quantify the agreement over every predicted cell. All three label sources are
# converted to plain string arrays so they compare element by element.
python_labels = np.asarray(predictions.pred_0.values, dtype=str)
onnx_labels = np.asarray([classes[p[0]] for p in onnx_predictions], dtype=str)
ground_labels = np.asarray(adata_subset.obs.subclass_label.values, dtype=str)

print(f"Cells compared: {len(python_labels)}")
print(f"Python vs. ONNX agreement: {np.mean(python_labels == onnx_labels):.1%}")
print(f"Python vs. ground truth agreement: {np.mean(python_labels == ground_labels):.1%}")
print(f"ONNX vs. ground truth agreement: {np.mean(onnx_labels == ground_labels):.1%}")

Python: ['Exclude' 'VIP' 'LAMP5' 'LAMP5' 'VIP' 'VIP' 'LAMP5' 'IT' 'IT' 'VIP']
ONNX: ['Exclude', 'VIP', 'LAMP5', 'LAMP5', 'VIP', 'VIP', 'LAMP5', 'IT', 'IT', 'VIP']
Ground: ['Exclude', 'VIP', 'LAMP5', 'LAMP5', 'VIP', 'VIP', 'LAMP5', 'IT', 'IT', 'VIP']
Categories (14, object): ['Astrocyte', 'Exclude', 'IT', 'L5/6 IT Car3', ..., 'PAX6', 'PVALB', 'SST', 'VIP']
