In [None]:
import os
import re
from collections import Counter

import pandas as pd
import scanpy as sc
import seaborn as sns
import matplotlib.pyplot as plt

In [None]:
%load_ext autoreload

In [None]:
%autoreload 2

In [None]:
import pollock
from pollock.models.model import PollockDataset, PollockModel

In [None]:
# Root of the input single-cell data and root under which trained models are saved.
DATA_DIR, MODEL_DIR = '/data/single_cell_classification', '/models'

In [None]:
run_name = 'snCCRCC_immune_annotated_v0.1.0_v2'
# Same file as before, resolved against the configured DATA_DIR rather than a
# duplicated absolute path.
adata_path = os.path.join(DATA_DIR, 'immune', 'snCCRCC', 'merged.h5ad')
adata = sc.read_h5ad(adata_path)

malignant_cell_type = 'Tumor'
cell_type_key = 'cell_type'
model_save_dir = os.path.join(MODEL_DIR, run_name)

# Annotations excluded from training/plots.
excluded_cell_types = ['Mixed myeloid/lymphoid']
keep_mask = ~adata.obs[cell_type_key].isin(excluded_cell_types)
# .copy() materialises the subset: boolean indexing alone returns an AnnData view,
# and later writes to it (e.g. sc.pl.umap storing colours in .uns) would trigger
# implicit view-to-copy warnings.
adata = adata[keep_mask].copy()
adata

In [None]:
# Number of cells per annotated type, listed from most to least common.
counts = Counter(adata.obs[cell_type_key])
ranked_counts = counts.most_common()
ranked_counts

In [None]:
# Ensure the figure output directory exists on a fresh kernel; the savefig and
# scanpy `save=` calls below write into it. exist_ok keeps the cell re-runnable.
os.makedirs(os.path.join(MODEL_DIR, 'figures'), exist_ok=True)

In [None]:
# Bar chart of cells per annotated type on a log y-axis, saved as a PDF.
counts_fig_path = '/models/figures/snCCRCC_cell_counts.pdf'
os.makedirs(os.path.dirname(counts_fig_path), exist_ok=True)

fig, ax = plt.subplots()
# Pass the column as `x=`: seaborn >= 0.12 binds the first positional argument to
# `data`, not to the variable being counted.
sns.countplot(x=adata.obs[cell_type_key], color=sns.color_palette()[0], ax=ax)
ax.set_yscale('log')
ax.set_xlabel('cell type')
ax.set_ylabel('number of cells')
ax.tick_params(axis='x', labelrotation=90)
fig.tight_layout()
fig.savefig(counts_fig_path)

In [None]:
# Directory that scanpy plotting functions write to when given `save=`.
# (The original `sc.settings. = ...` was a syntax error with no attribute name.)
sc.settings.figdir = '/models/figures/'

In [None]:
# UMAP coloured by the curated annotation column; the `save` suffix names the
# output file written under scanpy's figure directory.
umap_colors = [cell_type_key]
sc.pl.umap(adata, color=umap_colors, save='_annotated_cell_types_snCCRCC.pdf')

In [None]:
pds = PollockDataset(adata.copy(), cell_type_key=cell_type_key, n_per_cell_type=1000,
                    dataset_type='training')

In [None]:
pm = PollockModel(pds.cell_types, pds.train_adata.shape[1], alpha=.0001, latent_dim=25)

In [None]:
pm.fit(pds, epochs=13)

In [None]:
pm.save(pds, model_save_dir)

In [None]:
cdf = pd.DataFrame(data=pm.summary['validation']['confusion_matrix'], columns=pm.class_names,
                  index=pm.class_names)
print(pm.summary['validation']['metrics']['accuracy'])
sns.heatmap(cdf, cmap='Blues')

In [None]:
pm.summary['model_parameters']

In [None]:
pm.summary['history'].keys()

In [None]:
# Overall train vs. validation loss per epoch, reshaped to long format for seaborn.
history = pm.summary['history']
df = pd.DataFrame.from_records(
    [
        {'label': k, 'loss': v, 'epoch': i}
        for k in ['train_loss', 'validation_loss']
        for i, v in enumerate(history[k])
    ],
    columns=['label', 'loss', 'epoch'],
)

ax = sns.lineplot(x='epoch', y='loss', data=df, hue='label')
ax.set_title('Training vs. validation loss');

In [None]:
# Per-cell-type loss curves for both splits. Each row records which split it came
# from so a cell type's train and validation curves are drawn separately; with the
# cell type alone as hue, lineplot would average both splits into a single line.
history = pm.summary['history']
records = []
for split_key, split_name in [('cell_type_train_loss', 'train'),
                              ('cell_type_val_loss', 'validation')]:
    for cell_type, vals in history[split_key].items():
        records.extend(
            {'label': cell_type, 'split': split_name, 'loss': v, 'epoch': i}
            for i, v in enumerate(vals)
        )
df = pd.DataFrame.from_records(records, columns=['label', 'split', 'loss', 'epoch'])

# Three 20-colour qualitative palettes concatenated: up to 60 distinguishable cell types.
palette = [c for ls in (sns.color_palette('tab20'), sns.color_palette('tab20b'), sns.color_palette('tab20c'))
           for c in ls]
n_labels = df['label'].nunique()
# A list palette must match the number of hue levels; fall back to seaborn's default otherwise.
ax = sns.lineplot(x='epoch', y='loss', data=df, hue='label', style='split',
                  palette=palette[:n_labels] if n_labels <= len(palette) else None)
ax.set_title('Per-cell-type loss')
# loc='upper left' pins the legend's corner to the anchor so it sits outside the axes.
ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left');

In [None]:
# Validation F1 per cell type and epoch, in long format for seaborn.
val_f1 = pm.summary['history']['cell_type_val_f1']
df = pd.DataFrame.from_records(
    [
        {'label': cell_type, 'F1 score': score, 'epoch': i}
        for cell_type, scores in val_f1.items()
        for i, score in enumerate(scores)
    ],
    columns=['label', 'F1 score', 'epoch'],
)

ax = sns.lineplot(x='epoch', y='F1 score', data=df, hue='label')
ax.set_title('Validation F1 per cell type')
# loc='upper left' pins the legend's corner to the anchor so it sits outside the axes.
ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left');