In [None]:
import os
import re
from collections import Counter

import pandas as pd
import scanpy as sc
import seaborn as sns

In [None]:
%load_ext autoreload

In [None]:
%autoreload 2

In [None]:
import pollock

In [None]:
DATA_DIR = '/data/single_cell_classification'
MODEL_DIR = '/models'

In [None]:
run_name = 'HTAN_breast_v9'
adata = sc.read_h5ad('/data/single_cell_classification/tumor/BR/raw/houxiang_brca/merged.h5ad')

malignant_cell_type = 'BR_Malignant'
cell_type_key = 'cell_type'
model_save_dir = os.path.join(MODEL_DIR, run_name)

In [None]:
run_name = 'pdac_caf_states_v1'
adata = sc.read_h5ad('/data/single_cell_classification/tumor/PDAC/pdac_caf_subtypes.h5ad')

cell_type_key = 'fibroblast_subtype'
model_save_dir = os.path.join(MODEL_DIR, run_name)

In [None]:
run_name = 'melanoma_v1'
adata = sc.read_h5ad('/data/single_cell_classification/tumor/melanoma/melanoma.h5ad')

cell_type_key = 'cell_type'
model_save_dir = os.path.join(MODEL_DIR, run_name)

In [None]:
sorted(set(adata.obs['sample_id']))

In [None]:
adata

In [None]:
counts = Counter(adata.obs[cell_type_key])
counts.most_common()

In [None]:
pds = pollock.PollockDataset(adata.copy(), cell_type_key=cell_type_key, n_per_cell_type=500, batch_size=64,
                    dataset_type='training', min_genes=10, min_cells=3, mito_threshold=None,
                    max_n_genes=None, log=True, cpm=False, min_disp=None)

In [None]:
pm = pollock.PollockModel(pds.cell_types, pds.train_adata.shape[1], alpha=.001, latent_dim=100)

In [None]:
pm.fit(pds, epochs=40)

In [None]:
pm.save(pds, model_save_dir)

In [None]:
# Validation confusion matrix for the trained model. The overall accuracy goes into the
# title so the figure stands on its own when the notebook is skimmed.
# NOTE(review): which axis holds true vs. predicted labels is not established here —
# confirm against pollock before reading the heatmap.
validation_summary = pm.summary['validation']
validation_accuracy = validation_summary['metrics']['accuracy']
cdf = pd.DataFrame(data=validation_summary['confusion_matrix'], columns=pm.class_names,
                  index=pm.class_names)
ax = sns.heatmap(cdf, cmap='Blues')
ax.set_title(f'{run_name}: validation confusion matrix (accuracy {validation_accuracy:.3f})');

In [None]:
pm.summary['validation']['metrics']['Fibroblast']

In [None]:
0.06629834254143646

In [None]:
l_pds, l_pm = pollock.load_from_directory(adata, model_save_dir)

In [None]:
labels, label_prob, all_probs = l_pm.predict_pollock_dataset(l_pds, labels=True, )

In [None]:
X_umap = l_pm.get_umap_cell_embeddings(l_pds.prediction_ds)
X_umap

In [None]:
X_umap.shape

In [None]:
l_pds.prediction_adata.obsm['X_umap'] = X_umap
l_pds.prediction_adata.obs['predicted_cell_type'] = labels

In [None]:
l_pm.summary['training'].keys()

In [None]:
sc.pl.umap(l_pds.prediction_adata, color=['cell_type', cell_type_key, 'predicted_cell_type', 'sample_id'], ncols=1)

In [None]:
sc.pl.umap(l_pds.prediction_adata, color=['ACTA2', 'PF4', 'PROM1', 'HLA-DRA'], ncols=1, color_map='Reds',)

In [None]:
pdac_adata = sc.read_h5ad('/data/single_cell_classification/tumor/PDAC/pdac.h5ad')
l_pds, l_pm = pollock.load_from_directory(pdac_adata, model_save_dir)
pdac_adata = pdac_adata[l_pds.prediction_adata.obs.index]
labels, label_prob, all_probs = l_pm.predict_pollock_dataset(l_pds, labels=True, )
X_umap = l_pm.get_umap_cell_embeddings(l_pds.prediction_ds)
pdac_adata.obsm['X_umap'] = X_umap
pdac_adata.obs['predicted_cell_type'] = labels

In [None]:
sc.pl.umap(pdac_adata[pdac_adata.obs['cell_type']=='Fibroblast'],
           color=['cell_type', 'predicted_cell_type', 'sample'], ncols=1)

In [None]:
ccrcc_adata = sc.read_h5ad('/data/single_cell_classification/tumor/CCRCC/yige/adata.h5ad')
l_pds, l_pm = pollock.load_from_directory(ccrcc_adata, model_save_dir)
ccrcc_adata = ccrcc_adata[l_pds.prediction_adata.obs.index]
labels, label_prob, all_probs = l_pm.predict_pollock_dataset(l_pds, labels=True, )
X_umap = l_pm.get_umap_cell_embeddings(l_pds.prediction_ds)
ccrcc_adata.obsm['X_umap'] = X_umap
ccrcc_adata.obs['predicted_cell_type'] = labels

In [None]:
sc.pl.umap(ccrcc_adata[ccrcc_adata.obs['cell_type']=='Fibroblasts'],
           color=['cell_type', 'predicted_cell_type', 'sample_id'], ncols=1)

In [None]:
sc.pl.umap(ccrcc_adata, color=['ACTA2', 'PF4', 'PROM1', 'HLA-DRA'], ncols=1, color_map='Reds',)

In [None]:
val_metrics = pm.summary['validation']['metrics']
val_metrics['accuracy']