In [None]:
import os
from pathlib import Path
from collections import Counter

import anndata
import scanpy as sc
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import pollock
from pollock.models.model import PollockDataset, PollockModel, load_from_directory

In [None]:
# !pip install -e /pollock

In [None]:
# Enable IPython's autoreload extension in mode 2 so that imported modules are
# re-imported before each cell runs (useful while editing an editable install).
%load_ext autoreload
%autoreload 2

## collect data for module training

In [None]:
# Dataset key -> absolute location of its annotated data. Every entry lives
# under the same root, so only the part below that root is spelled out here.
data_map = {
    dataset: '/data/single_cell_classification/' + relative_path
    for dataset, relative_path in [
        ('aml', 'tumor/aml'),
        ('br', 'tumor/BR/raw/houxiang_brca/merged.h5ad'),
        ('ccrcc', 'tumor/CCRCC/yige/adata.h5ad'),
        ('cesc', 'tumor/CESC/cesc.h5ad'),
        ('gbm', 'tumor/gbm/gbm.h5ad'),
        ('hnscc', 'tumor/HNSC/processed.h5ad'),
        ('melanoma', 'tumor/melanoma/melanoma.h5ad'),
        ('myeloma', 'tumor/myeloma/processed.h5ad'),
        ('pdac', 'tumor/PDAC/pdac.h5ad'),
        ('pdac_caf', 'tumor/PDAC/pdac_caf_subtypes.h5ad'),
        ('zheng_sorted', 'scRNAseq_benchmark/Intra-dataset/Zheng sorted/zheng_sorted.h5ad'),
        ('snCCRCC_alla', 'immune/snCCRCC/merged.h5ad'),
    ]
}

In [None]:
# adata_map = {}
# for k, fp in data_map.items():
#     print(k)
#     try:
#         adata_map[k] = sc.read_h5ad(fp)
#     except OSError:
#         print(f'{k} failed')

In [None]:
# ## check for cell_type key
# for k, adata in adata_map.items(): print(k, 'cell_type' in adata.obs.columns)

## iterate through and train modules

## 3. train pollock module

specify a location to save the trained pollock module

In [None]:
# Key into data_map selecting the dataset this module is trained on.
module_type = 'snCCRCC_alla'
# Versioned identifier for this training run.
run_name = module_type + '_v0.1.0'

In [None]:
# Output directory for the trained module, named after the run. It is created
# here (with any missing parents) and is left untouched if it already exists.
module_save_filepath = '/models/modules/' + run_name
Path(module_save_filepath).mkdir(exist_ok=True, parents=True)

here we create a new anndata object from our processed anndata object

scanpy expects the raw counts data to be in the X attribute

In [None]:
# Read the h5ad file registered for the selected module, then show the
# resulting AnnData summary as the cell output.
h5ad_path = data_map[module_type]
train_adata = sc.read_h5ad(h5ad_path)
train_adata

In [None]:
# Counter(train_adata.obs['fibroblast_subtype']).most_common()

In [None]:
# train_adata.obs['cell_type'] = [x if 'iCAF' not in x else 'iCAF' for x in train_adata.obs['fibroblast_subtype']]

In [None]:
# Plot the UMAP coloured by annotated cell type and save it as a PDF.
# No sc.tl.umap call precedes this cell, so the loaded file is assumed to
# already contain a UMAP embedding -- TODO confirm.
# Uncomment the next line to redirect where scanpy writes saved figures:
# sc.settings.figdir = '/models/figures/'
sc.pl.umap(
    train_adata,
    save='_snCCRCC_umap_all_cell_types.pdf',
    color='cell_type',
)

In [None]:
# train_adata = train_adata[train_adata.obs['cell_type']!= 'Unknown']
# train_adata = train_adata[train_adata.obs['cell_type']!= 'CD34+CYTL1+']
# train_adata = train_adata[train_adata.obs['cell_type']!= 'Plasma_BM']
# train_adata

take a look at our cell counts

In [None]:
# Tally the cells per annotated type and list them from most to least common.
cell_type_counts = Counter(train_adata.obs['cell_type'])
cell_type_counts.most_common()

initialize PollockDataset

In [None]:
# Wrap the AnnData for training. Labels are read from the 'cell_type' obs
# column; n_per_cell_type is presumably a per-type sampling cap -- confirm
# against the pollock documentation.
pds = PollockDataset(
    train_adata,
    dataset_type='training',
    n_per_cell_type=1000,
    cell_type_key='cell_type',
)

In [None]:
# Size the model from the dataset: its cell-type list and the number of
# columns in the prepared training matrix.
n_features = pds.train_adata.shape[1]
pm = PollockModel(pds.cell_types, n_features, latent_dim=25, alpha=.0001)

In [None]:
# Exploratory run: train for 50 epochs, evaluating metrics at every epoch so
# the history plots further down can be drawn from pm.summary.
pm.fit(
    pds,
    epochs=50,
    metric_epoch_interval=1,
    metric_n_per_cell_type=50,
    max_metric_batches=2,
)

In [None]:
# Save the trained model (pds is passed along to save) into the module
# directory created earlier in the notebook.
pm.save(pds, module_save_filepath)

## 4. module performance

visualize the overlap between groundtruth vs predicted cell types

In [None]:
# Validation confusion matrix as a labelled frame. Both axes use the model's
# class names; rows are plotted as ground truth and columns as predictions.
confusion_matrix = pd.DataFrame(
    data=pm.summary['validation']['confusion_matrix'],
    index=pm.class_names,
    columns=pm.class_names,
)

fig, axs = plt.subplots(1, 1, figsize=(15, 10))
sns.heatmap(confusion_matrix, cmap='Blues', ax=axs)
axs.set_xlabel('Predicted')
axs.set_ylabel('Groundtruth')
fig.tight_layout()

# savefig does not create missing directories, and nothing else in this
# notebook creates the figures directory, so make sure it exists first.
confusion_fig_path = Path('/models/figures/snCCRCC_model_training_accuracy.pdf')
confusion_fig_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(confusion_fig_path)

plot loss and accuracy for training and validation sets during training

In [None]:
history = pm.summary['history']

# Long-format loss frame: one row per (series, epoch).
loss, label, epoch = [], [], []
for k in ['train_loss', 'validation_loss']:
    loss += history[k]
    label += [k] * len(history[k])
    epoch += list(range(len(history[k])))
df = pd.DataFrame.from_dict({'label': label, 'loss': loss, 'epoch': epoch})

fig, ax = plt.subplots()
sns.lineplot(x='epoch', y='loss', data=df, hue='label', ax=ax)
ax.legend(bbox_to_anchor=(1.5, 1))

# Long-format accuracy frame. The epoch index is rebuilt from the accuracy
# series themselves rather than reused from the loss frame, so the frame is
# still valid if the accuracy histories differ in length from the loss ones.
accuracy, label, epoch = [], [], []
for k in ['train_accuracy', 'validation_accuracy']:
    accuracy += history[k]
    label += [k] * len(history[k])
    epoch += list(range(len(history[k])))
df = pd.DataFrame.from_dict({'label': label, 'accuracy': accuracy, 'epoch': epoch})

# Accuracy is drawn dashed on a secondary y-axis sharing the epoch axis.
ax2 = ax.twinx()
sns.lineplot(x='epoch', y='accuracy', data=df, hue='label', ax=ax2, style='label',
             dashes=[(3, 2), (3, 2)])
ax2.legend(bbox_to_anchor=(2., 1))

# savefig does not create missing directories, so make sure the target exists.
history_fig_path = Path('/models/figures/training_history.pdf')
history_fig_path.parent.mkdir(parents=True, exist_ok=True)
# Both legends are anchored outside the axes; a tight bounding box keeps them
# inside the saved page instead of being cropped.
fig.savefig(history_fig_path, bbox_inches='tight')

plot validation loss broken down by cell type

In [None]:
# Flatten the per-cell-type validation-loss history into long format:
# one (cell_type, loss, epoch) row per recorded value.
val_loss_by_type = pm.summary['history']['cell_type_val_loss']
rows = [
    (cell_type, value, step)
    for cell_type, vals in val_loss_by_type.items()
    for step, value in enumerate(vals)
]
df = pd.DataFrame(rows, columns=['cell_type', 'loss', 'epoch'])

# One line per cell type; the legend is placed just outside the axes.
sns.lineplot(x='epoch', y='loss', data=df, hue='cell_type')
plt.legend(bbox_to_anchor=(1.05, 1))

plot validation F1 score for each cell type during training

In [None]:
# Flatten the per-cell-type validation history stored under
# 'cell_type_val_f1' into long format: one (cell_type, f1, epoch) row per value.
# The column is named after the metric key so the y-axis is not mislabelled
# as accuracy.
val_f1_by_type = pm.summary['history']['cell_type_val_f1']
rows = [
    (cell_type, value, step)
    for cell_type, vals in val_f1_by_type.items()
    for step, value in enumerate(vals)
]
df = pd.DataFrame(rows, columns=['cell_type', 'f1', 'epoch'])

fig, ax = plt.subplots()
sns.lineplot(x='epoch', y='f1', data=df, hue='cell_type', ax=ax)
ax.set_ylabel('validation F1')
ax.legend(bbox_to_anchor=(1.05, 1))
fig.tight_layout()

# savefig does not create missing directories, so make sure the target exists.
# The output filename is kept unchanged so existing references to it still work.
f1_fig_path = Path('/models/figures/cell_accuracy_training_history.pdf')
f1_fig_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(f1_fig_path)

## 5. retrain module for optimal number of epochs

from the above plots it appears that the optimal training time is ~45 epochs for this dataset

In [None]:
# Start again from a freshly initialised model with the same hyperparameters,
# sized from the dataset's cell types and training-matrix column count.
n_features = pds.train_adata.shape[1]
pm = PollockModel(pds.cell_types, n_features, latent_dim=25, alpha=.0001)

In [None]:
# Final training run with the chosen epoch count.
# NOTE(review): the accompanying markdown cites ~45 epochs as optimal, but this
# call trains for 13 -- confirm which value is intended.
pm.fit(
    pds,
    epochs=13,
    metric_epoch_interval=1,
    metric_n_per_cell_type=50,
    max_metric_batches=5,
)

In [None]:
# Save the retrained model to the same module directory used by the earlier
# exploratory run (presumably overwriting what that run saved -- confirm).
pm.save(pds, module_save_filepath)

In [None]:
# Validation confusion matrix of the retrained model. Both axes use the
# model's class names; rows are plotted as ground truth, columns as predictions.
confusion_matrix = pd.DataFrame(
    data=pm.summary['validation']['confusion_matrix'],
    index=pm.class_names,
    columns=pm.class_names,
)

fig, axs = plt.subplots(1, 1, figsize=(15, 10))
sns.heatmap(confusion_matrix, cmap='Blues', ax=axs)
axs.set_xlabel('Predicted')
axs.set_ylabel('Groundtruth')
fig.tight_layout()

# savefig does not create missing directories, so make sure the target exists.
# This is the same filename the exploratory-run confusion matrix is written
# to, so that earlier figure is replaced.
confusion_fig_path = Path('/models/figures/snCCRCC_model_training_accuracy.pdf')
confusion_fig_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(confusion_fig_path)

In [None]:
history = pm.summary['history']

# Long-format loss frame: one row per (series, epoch).
loss, label, epoch = [], [], []
for k in ['train_loss', 'validation_loss']:
    loss += history[k]
    label += [k] * len(history[k])
    epoch += list(range(len(history[k])))
df = pd.DataFrame.from_dict({'label': label, 'loss': loss, 'epoch': epoch})

fig, ax = plt.subplots()
sns.lineplot(x='epoch', y='loss', data=df, hue='label', ax=ax)
ax.legend(bbox_to_anchor=(1.5, 1))

# Long-format accuracy frame. The epoch index is rebuilt from the accuracy
# series themselves rather than reused from the loss frame, so the frame is
# still valid if the accuracy histories differ in length from the loss ones.
accuracy, label, epoch = [], [], []
for k in ['train_accuracy', 'validation_accuracy']:
    accuracy += history[k]
    label += [k] * len(history[k])
    epoch += list(range(len(history[k])))
df = pd.DataFrame.from_dict({'label': label, 'accuracy': accuracy, 'epoch': epoch})

# Accuracy is drawn dashed on a secondary y-axis sharing the epoch axis.
ax2 = ax.twinx()
sns.lineplot(x='epoch', y='accuracy', data=df, hue='label', ax=ax2, style='label',
             dashes=[(3, 2), (3, 2)])
ax2.legend(bbox_to_anchor=(2., 1))

In [None]:
# Flatten the retrained model's per-cell-type validation-loss history into
# long format: one (cell_type, loss, epoch) row per recorded value.
val_loss_by_type = pm.summary['history']['cell_type_val_loss']
rows = [
    (cell_type, value, step)
    for cell_type, vals in val_loss_by_type.items()
    for step, value in enumerate(vals)
]
df = pd.DataFrame(rows, columns=['cell_type', 'loss', 'epoch'])

# One line per cell type; the legend is placed just outside the axes.
sns.lineplot(x='epoch', y='loss', data=df, hue='cell_type')
plt.legend(bbox_to_anchor=(1.05, 1))

In [None]:
# Flatten the retrained model's per-cell-type validation history stored under
# 'cell_type_val_f1' into long format: one (cell_type, f1, epoch) row per value.
# The column is named after the metric key so the y-axis is not mislabelled
# as accuracy.
val_f1_by_type = pm.summary['history']['cell_type_val_f1']
rows = [
    (cell_type, value, step)
    for cell_type, vals in val_f1_by_type.items()
    for step, value in enumerate(vals)
]
df = pd.DataFrame(rows, columns=['cell_type', 'f1', 'epoch'])

fig, ax = plt.subplots()
sns.lineplot(x='epoch', y='f1', data=df, hue='cell_type', ax=ax)
ax.set_ylabel('validation F1')
ax.legend(bbox_to_anchor=(1.05, 1))

In [None]:
# Re-read the dataset from disk. This rebinds train_adata to a fresh object,
# separate from the one handed to PollockDataset for training.
train_adata = sc.read_h5ad(data_map[module_type])

In [None]:
# Restore the saved module from disk against the freshly loaded AnnData.
# The loader hands back a (dataset, model) pair.
loaded_module = load_from_directory(train_adata, module_save_filepath)
loaded_pds, loaded_pm = loaded_module

In [None]:
# Run the restored model over the restored dataset. With labels=True the
# call yields three results, unpacked below; the first five labels and
# probabilities are shown as a quick sanity check.
prediction = loaded_pm.predict_pollock_dataset(loaded_pds, labels=True)
labels, probs, cell_type_probs = prediction
(labels[:5], probs[:5])

In [None]:
# Attach the original annotation and the model's predictions to obs.
# The 'predicted_probablility' spelling is kept as-is because the UMAP
# plotting cell looks the column up by this exact name.
new_obs_columns = {
    'annotated_cell_type': train_adata.obs['cell_type'].to_list(),
    'predicted_cell_type': list(labels),
    'predicted_probablility': list(probs),
}
for column, values in new_obs_columns.items():
    train_adata.obs[column] = values

In [None]:
# Compute the model's embedding for every cell in the prediction dataset,
# then display its shape alongside the values.
prediction_ds = loaded_pds.prediction_ds
cell_embeddings = loaded_pm.get_cell_embeddings(prediction_ds)
(cell_embeddings.shape, cell_embeddings)

In [None]:
# Store the model embeddings on the AnnData, build the neighbour graph from
# that representation, and compute a UMAP on top of it.
embedding_key = 'cell_embeddings'
train_adata.obsm[embedding_key] = cell_embeddings
sc.pp.neighbors(train_adata, use_rep=embedding_key)
sc.tl.umap(train_adata)

In [None]:
# Compare annotation, prediction and prediction probability on the
# embedding-derived UMAP, one panel per row, and save the figure.
umap_colour_keys = [
    'annotated_cell_type',
    'predicted_cell_type',
    'predicted_probablility',
]
sc.pl.umap(
    train_adata,
    color=umap_colour_keys,
    ncols=1,
    frameon=False,
    save='_ccrcc.pdf',
)

In [None]:
# Overlay expression of three genes on the same UMAP using a red colour map.
marker_genes = ['CD4', 'CD3G', 'CD8A']
sc.pl.umap(train_adata, color_map='Reds', color=marker_genes)

In [None]:
# Show the validation summary of the in-memory retrained model (pm), not the
# copy restored from disk (loaded_pm).
validation_summary = pm.summary['validation']
validation_summary