In [None]:
import os
import yaml
import copy

from random import choices

import matplotlib.pyplot as plt
import seaborn as sns

import numpy as np
import pandas as pd

import torch
import lightning as L

from sklearn.metrics import average_precision_score

import data
from train import NeighborSupervisedModule

import importlib
# Re-execute the already-imported local `data` module so the names used below
# (e.g. data.get_loaders, data.parse_data) come from its current source.
# NOTE(review): presumably here so edits to data.py are picked up without a
# kernel restart; objects imported earlier from other modules are not refreshed.
importlib.reload(data)

# Load model, make predictions, compute gene importance

In [37]:
model_base = '../../../cache/lightning_logs'

# Runs this notebook has been used with: run directory -> (input_type, target_type).
# Keeping them in a table (instead of commented-out assignments) means switching
# runs is a one-line change and an unknown run name fails loudly with a KeyError.
MODEL_RUNS = {
    '2024_04_11_gut_in_enterocyte_out_goblet': ('Enterocyte', 'Goblet'),
    '2024_04_11_gut_in_goblet_out_enterocyte': ('Goblet', 'Enterocyte'),
    '2024_04_11_nsclc_in_fibroblast_out_tumor': ('fibroblast', 'tumor'),
    '2024_04_11_nsclc_in_tumor_out_fibroblast': ('tumor', 'fibroblast'),
    '2024_04_11_nsclc_in_macrophage_out_fibroblast': ('macrophage', 'fibroblast'),
    '2024_04_11_nsclc_in_neutrophil_out_tumor': ('neutrophil', 'tumor'),
}

# Select the run to analyse here.
model_dir = '2024_04_11_nsclc_in_neutrophil_out_tumor'
input_type, target_type = MODEL_RUNS[model_dir]

run_dir = os.path.join(model_base, model_dir)
with open(os.path.join(run_dir, 'hparams.yaml'), 'r') as stream:
    # SECURITY: unsafe_load can construct arbitrary Python objects and therefore
    # execute code embedded in the file. Only use it on hparams.yaml files you
    # produced yourself; use yaml.safe_load if the yaml is not trusted.
    params_yaml = yaml.unsafe_load(stream)
params = params_yaml['params']

# Validate the notebook's assumptions before the (slower) data and checkpoint
# loading below, so an unsupported run fails immediately.

# This notebook assumes we're working with logistic regression.
assert params['enc_depth'] == 0, \
    f"expected enc_depth == 0 (logistic regression), got {params['enc_depth']!r}"

# This notebook assumes we're using a single output class.
assert params['class_names_whitelist'] is not None, \
    'expected class_names_whitelist to be set'
assert len(params['class_names_whitelist']) == 1, \
    f"expected exactly one whitelisted class, got {params['class_names_whitelist']!r}"

loaders, in_dim, out_dim, class_counts, class_names, gene_list = data.get_loaders(params)
model = NeighborSupervisedModule.load_from_checkpoint(
    os.path.join(run_dir, 'checkpoints', 'last.ckpt'),
    map_location='cpu',
)

In [38]:
# Run the model over every split on CPU and gather, per phase, the
# sigmoid-transformed predictions, the targets and the sample ids as arrays.
trainer = L.Trainer(accelerator='cpu', inference_mode=True)
preds, targs, ids = {}, {}, {}
for phase in ('train', 'val', 'test'):
    phase_preds, phase_targs, phase_ids = [], [], []
    for raw_preds, batch_targs, batch_ids in trainer.predict(model, loaders[phase]):
        # Squash the raw model outputs into (0, 1) before storing them.
        phase_preds.append(torch.sigmoid(raw_preds).numpy())
        phase_targs.append(batch_targs.numpy())
        try:
            # throws an error for nsclc val set, why?
            batch_ids = batch_ids.numpy()
        except AttributeError:
            # The ids object has no .numpy(); keep it unchanged and let
            # np.concatenate below handle it.
            pass
        phase_ids.append(batch_ids)
    preds[phase] = np.concatenate(phase_preds, axis=0)
    targs[phase] = np.concatenate(phase_targs, axis=0)
    ids[phase] = np.concatenate(phase_ids, axis=0)

GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(
  rank_zero_warn(


Predicting DataLoader 0: 100%|███████████████████████████████████████████████████████| 73/73 [00:00<00:00, 699.25it/s]
Predicting DataLoader 0: 100%|███████████████████████████████████████████████████████| 49/49 [00:00<00:00, 771.85it/s]
Predicting DataLoader 0: 100%|███████████████████████████████████████████████████████| 35/35 [00:00<00:00, 722.56it/s]


# Extract gene importance

In [39]:
# Flatten the weight tensor of the model's `head` layer into a 1-D array and
# use its entries as per-gene importance scores.
head_weight = trainer.model.model.head.weight
gene_importance = head_weight.detach().numpy().ravel()
# Note: with raw (signed) values, the top-ranked genes are the strongly enriched ones.
# Taking the absolute value instead would rank a gene as important whether it is
# strongly enriched or strongly depleted.

# Save results to disk

In [40]:
# Build the gene / score table ordered from highest to lowest importance score,
# then write it to the cache directory under a name derived from the run.
df_out = (
    pd.DataFrame({'gene': gene_list, 'importance_score': gene_importance})
    .sort_values(by='importance_score', ascending=False)
)
out_csv_path = f'../../../cache/nb_pred_{model_dir}.csv'
df_out.to_csv(out_csv_path, index=False)

# Visualize splits

In [None]:
def visualize_splits():
    """Scatter-plot sample coordinates coloured by train/val/test membership.

    Reads the notebook-level ``params`` (passed to ``data.parse_data``) and
    ``ids`` (per-phase id arrays collected during prediction). A sample is
    labelled with a phase when its id appears in that phase's id array; if an
    id appears in several phases the later one (train < val < test) wins.
    Samples whose id is in none of the phases are left out of the plot.
    """
    all_coords, _, _, all_ids, _, _, _ = data.parse_data(params)
    df = pd.DataFrame({'x': all_coords[:, 0], 'y': all_coords[:, 1], 'ids': all_ids})
    # Start with an object column of missing values so the string labels
    # assigned below do not clash with a float (NaN) dtype.
    df['split'] = None
    for phase in ('train', 'val', 'test'):
        df.loc[df['ids'].isin(ids[phase]), 'split'] = phase
    # NaN never compares equal to anything (itself included), so a filter of
    # the form `df['split'] != np.nan` keeps every row. dropna removes the
    # unassigned rows as intended.
    df = df.dropna(subset=['split'])
    sns.scatterplot(data=df, x='x', y='y', s=1.0, hue='split')
visualize_splits()

# Test if the model's test performance is significantly better than chance

In [None]:
# bootstrap null distribution:
# Resample the test targets with replacement and record the mean of each
# resample. The mean of the targets is the reference ("null point estimate")
# that the observed average precision is compared against below.
n_bootstrap = int(1e4)
bootstrap_rng = np.random.default_rng(0)  # fixed seed so the p-value is reproducible
pop = np.ravel(targs['test'])
bootstrap_estimates = np.array([
    bootstrap_rng.choice(pop, size=len(pop), replace=True).mean()
    for _ in range(n_bootstrap)
])
# compute p value:
observed_value = average_precision_score(pop, np.ravel(preds['test']))
# Add-one correction: with a finite number of resamples the estimate must not
# be exactly 0, so the observed value is counted as one draw from the null.
n_at_least_observed = int(np.sum(bootstrap_estimates >= observed_value))
p = (n_at_least_observed + 1) / (n_bootstrap + 1)
plt.hist(bootstrap_estimates)
plt.axvline(observed_value, color='k', label='observed')
plt.axvline(np.mean(pop), color='r', label='null point estimate')
plt.xlabel('average precision')
plt.ylabel('bootstrap resamples')
plt.legend()
print(f'p-value: {p}')