In [1]:
import sys
# Put the local pollock checkout first on the import path so it shadows any installed copy.
sys.path[:0] = ['/tf/pollock']

In [2]:
# Load IPython's autoreload extension; the reload mode is set with `%autoreload 2` in a later cell.
%load_ext autoreload

In [3]:
import logging
import os
import random
from collections import Counter
from importlib import reload
import time

import anndata
import numpy as np
import pandas as pd
import scanpy as sc
import matplotlib.pyplot as plt
# import anndata2ri


import pollock
from pollock import PollockDataset, PollockModel, load_from_directory
# import pollock.models.analysis as pollock_analysis

  from pandas.core.index import RangeIndex


In [4]:
import tensorflow as tf
from tensorflow.keras import layers

# Reset Keras/TensorFlow global state so re-running the notebook starts from a clean graph.
keras_backend = tf.keras.backend
keras_backend.clear_session()

In [5]:
# Mode 2: reload all imported modules (e.g. the local pollock package) before each cell executes.
%autoreload 2

In [6]:
# Root directories for input data and saved models.
DATA_DIR = '/data/single_cell_classification'
MODEL_DIR = '/models'

# Seed every RNG the pipeline may touch. PollockDataset's per-cell-type subsampling
# and train/validation split presumably draw from numpy/random (TODO confirm), and
# Keras weight initialisation uses TF's RNG; without seeds, Restart & Run All is not
# reproducible.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

## Expression tables (TSV counts matrix + per-cell metadata)

In [None]:
# BRCA run: TSV counts matrix with a matching per-cell metadata TSV.
run_name = 'sc_brca'

brca_raw_dir = os.path.join(DATA_DIR, 'tumor', 'BR', 'raw', 'houxiang_brca')
expression_fp = os.path.join(brca_raw_dir, 'breast_counts_matrix.tsv')
label_fp = os.path.join(brca_raw_dir, 'breast_metadata.tsv')

model_save_dir = os.path.join(MODEL_DIR, run_name)

# Gene-name column of the TSV, field separator, and the obs column holding labels.
sample_column, sep, cell_type_key = 'Genes', '\t', 'cell_type'

In [None]:
# HNSC run: TSV counts matrix with a matching per-cell metadata TSV.
run_name = 'sc_hnsc'

hnsc_raw_dir = os.path.join(DATA_DIR, 'tumor', 'HNSC', 'raw', 'hnsc_yize')
expression_fp = os.path.join(
    hnsc_raw_dir, 'Assigned_WUHN_15_processed_cluster_review_gene_expression_format_2.tsv')
label_fp = os.path.join(
    hnsc_raw_dir, 'Assigned_WUHN_15_processed_cluster_review_cell_metadata_format_2.tsv')

model_save_dir = os.path.join(MODEL_DIR, run_name)

# Gene-name column of the TSV, field separator, and the obs column holding labels.
sample_column, sep, cell_type_key = 'Genes', '\t', 'cell_type'

In [None]:
# CESC run: TSV counts matrix with a matching per-cell metadata TSV.
run_name = 'sc_cesc'

cesc_raw_dir = os.path.join(DATA_DIR, 'tumor', 'CESC', 'raw', 'cesc_yize_v2')
expression_fp = os.path.join(
    cesc_raw_dir, 'Assigned_CESC_9_processed_cluster_review_final_gene_expression_format.tsv')
label_fp = os.path.join(
    cesc_raw_dir, 'Assigned_CESC_9_processed_cluster_review_final_cell_metadata_format.tsv')

model_save_dir = os.path.join(MODEL_DIR, run_name)

# Gene-name column of the TSV, field separator, and the obs column holding labels.
sample_column, sep, cell_type_key = 'Genes', '\t', 'cell_type'

## H5 / H5AD objects

In [None]:
# PDAC run: a ready-made AnnData (.h5ad). No separate metadata file is configured here,
# so `label_fp` keeps whatever value an earlier config cell set.
run_name = 'sc_pdac'

pdac_dir = os.path.join(DATA_DIR, 'tumor', 'PDAC')
expression_fp = os.path.join(pdac_dir, 'pdac.h5ad')

model_save_dir = os.path.join(MODEL_DIR, run_name)

# Gene-name column, field separator, and the obs column holding labels.
sample_column, sep, cell_type_key = 'Genes', '\t', 'cell_type'

In [None]:
# CCRCC run ('sn_' prefix, presumably single-nucleus): HDF5 expression object plus a metadata TSV.
run_name = 'sn_ccrcc'

ccrcc_dir = os.path.join(DATA_DIR, 'tumor', 'CCRCC', 'yige')
expression_fp = os.path.join(ccrcc_dir, 'adata.h5')
label_fp = os.path.join(ccrcc_dir, 'metadata.tsv')

model_save_dir = os.path.join(MODEL_DIR, run_name)

# Gene-name column, field separator, and the obs column holding labels.
sample_column, sep, cell_type_key = 'Genes', '\t', 'cell_type'

In [None]:
# Load the raw counts table; genes are rows here (the next cell transposes to cells x genes).
expression_df = pd.read_csv(filepath_or_buffer=expression_fp, sep=sep)
expression_df

In [None]:
# Index by the gene-name column configured in the dataset cell (`sample_column`, rather
# than a hard-coded 'Genes'), then transpose so rows are cells and columns are genes.
# The guard makes the cell idempotent: once transposed, the gene-name column no longer
# exists, so re-running is a no-op instead of a KeyError.
if sample_column in expression_df.columns:
    expression_df = expression_df.set_index(sample_column).transpose()
expression_df

In [None]:
# Cache the cells x genes table as HDF5 (key 'df') next to the source TSV, with the same
# stem and an '.h5' suffix, for faster reloads.
expression_stem, expression_ext = os.path.splitext(expression_fp)
if expression_ext != '.tsv':
    # The old str.replace('.tsv', '.h5') left non-TSV paths unchanged, so for the H5/H5AD
    # configs it would have written a 'df' key into the source file itself.
    raise ValueError('expected a .tsv expression table, got {!r}'.format(expression_fp))
expression_df.to_hdf(expression_stem + '.h5', 'df')

In [None]:
# Reload the cached table: same stem as the configured expression file, '.h5' suffix,
# key 'df'. Only the extension is swapped (str.replace would also rewrite a '.tsv'
# appearing elsewhere in the path).
expression_h5_fp = os.path.splitext(expression_fp)[0] + '.h5'
expression_df = pd.read_hdf(expression_h5_fp, 'df')
expression_df

In [None]:
# Load a ready-made AnnData object directly from the configured file.
adata = anndata.read_h5ad(filename=expression_fp)
adata

In [None]:
# Per-cell metadata indexed by cell id, reordered to match the expression rows.
label_df = (
    pd.read_csv(label_fp, sep=sep)
    .set_index('cell_id')
    .loc[expression_df.index]
)
label_df

In [None]:
# Wrap the aligned matrix and metadata in an AnnData; obs/var names come from the expression table.
adata = anndata.AnnData(X=expression_df.values, obs=label_df)
adata.obs.index, adata.var.index = expression_df.index, expression_df.columns
adata

In [7]:
# Master run: pool the cached per-dataset tables (.h5 + metadata TSV) with ready-made AnnData files.
run_name = 'sc_master'

tumor_dir = os.path.join(DATA_DIR, 'tumor')
brca_dir = os.path.join(tumor_dir, 'BR', 'raw', 'houxiang_brca')
hnsc_dir = os.path.join(tumor_dir, 'HNSC', 'raw', 'hnsc_yize')
cesc_dir = os.path.join(tumor_dir, 'CESC', 'raw', 'cesc_yize_v2')

# (expression table, cell metadata) pairs.
# CCRCC (tumor/CCRCC/yige: adata.h5 + metadata.tsv) is currently left out.
expression_table_fps = [
    (os.path.join(brca_dir, 'breast_counts_matrix.h5'),
     os.path.join(brca_dir, 'breast_metadata.tsv')),
    (os.path.join(hnsc_dir, 'Assigned_WUHN_15_processed_cluster_review_gene_expression_format_2.h5'),
     os.path.join(hnsc_dir, 'Assigned_WUHN_15_processed_cluster_review_cell_metadata_format_2.tsv')),
    (os.path.join(cesc_dir, 'Assigned_CESC_9_processed_cluster_review_final_gene_expression_format.h5'),
     os.path.join(cesc_dir, 'Assigned_CESC_9_processed_cluster_review_final_cell_metadata_format.tsv')),
]

# Datasets that are already AnnData objects.
anndata_fps = [os.path.join(tumor_dir, 'PDAC', 'pdac.h5ad')]

model_save_dir = os.path.join(MODEL_DIR, run_name)
cell_type_key, sep = 'cell_type', '\t'

In [None]:
def _expression_table_to_anndata(expression_fp, metadata_fp, sep='\t'):
    """Build an AnnData from one cached expression table and its cell metadata.

    Parameters
    ----------
    expression_fp : str
        HDF5 file holding the cells x genes DataFrame under key 'df'.
    metadata_fp : str
        Delimited metadata file with a 'cell_id' column.
    sep : str
        Field separator of the metadata file.

    Returns
    -------
    anndata.AnnData
        obs = metadata rows aligned to the expression rows; var names = gene columns.
    """
    expression_df = pd.read_hdf(expression_fp, 'df')
    label_df = pd.read_csv(metadata_fp, sep=sep).set_index('cell_id')
    # Reorder metadata to the expression rows (pandas >= 1.0 raises KeyError if a cell
    # has no metadata row).
    label_df = label_df.loc[expression_df.index]

    table_adata = anndata.AnnData(X=expression_df.values, obs=label_df)
    table_adata.obs.index = expression_df.index
    table_adata.var.index = expression_df.columns
    return table_adata


table_adatas = [
    _expression_table_to_anndata(expression_fp, metadata_fp, sep=sep)
    for expression_fp, metadata_fp in expression_table_fps
]
if not table_adatas:
    raise ValueError('expression_table_fps is empty; nothing to load')

# One concatenate call instead of folding pairwise in a loop. Each table is copied once,
# gets its own obs['batch'] label ('0'..'n-1') and a single '-<batch>' obs-name suffix.
# The pairwise loop re-labelled earlier tables and stacked suffixes (e.g. '-0-0').
if len(table_adatas) == 1:
    adata = table_adatas[0].copy()
else:
    adata = table_adatas[0].concatenate(*table_adatas[1:])
adata

In [None]:
# Re-display the pooled AnnData.
adata

In [None]:
# Load every ready-made AnnData listed in `anndata_fps` (previously only the first was
# read, silently ignoring any others) and merge them into one object to append to `adata`.
# With a single file this is exactly `anndata.read_h5ad(anndata_fps[0])`.
extra_adatas = [anndata.read_h5ad(fp) for fp in anndata_fps]
if not extra_adatas:
    raise ValueError('anndata_fps is empty; nothing to add')
if len(extra_adatas) == 1:
    to_add = extra_adatas[0]
else:
    to_add = extra_adatas[0].concatenate(*extra_adatas[1:])
to_add

In [None]:
# Append the ready-made AnnData cells to the cells pooled from the expression tables.
adata = anndata.AnnData.concatenate(adata, to_add)
adata

In [None]:
# Raw label frequencies across all pooled datasets, most common first.
sorted(Counter(adata.obs[cell_type_key]).items(), key=lambda item: item[1], reverse=True)

In [None]:
# Harmonise dataset-specific annotations onto one shared vocabulary.
# Keys are the canonical cell types; values are the raw labels that collapse onto them.
# Labels not listed here pass through unchanged when the map is applied.
_raw_labels_by_cell_type = {
    'B-cells': ['B'],
    'CD4+ T-cells': ['CD4+T', 'CD4_T'],
    'CD8+ T-cells': ['CD8+T', 'CD8_T'],
    'Endothelial cells': ['Endothelial'],
    'Fibroblasts': ['Fibroblast'],
    'Macrophage/Monocyte': ['Macrophage', 'Macrophages', 'Monocyte'],
    'Malignant': [
        'BR_Malignant',
        'CESC_Malignant/Epithelial_1',
        'CESC_Malignant/Epithelial_2',
        'CESC_Malignant/Epithelial_3',
        'CESC_Malignant/Epithelial_4',
        'HNSC_Malignant/Epithelial',
    ],
    'Mast cells': ['Mast'],
    'NK cells': ['NK'],
    'Plasma cells': ['Plasma'],
    'Tregs': ['Treg'],
}
cell_type_map = {
    raw_label: cell_type
    for cell_type, raw_labels in _raw_labels_by_cell_type.items()
    for raw_label in raw_labels
}

In [None]:
# Distinct raw labels, alphabetically, to check coverage of cell_type_map.
sorted({label for label in adata.obs[cell_type_key]})

In [None]:
# Replace each raw label by its canonical name; labels missing from the map are kept as-is.
adata.obs[cell_type_key] = list(
    map(lambda label: cell_type_map.get(label, label), adata.obs[cell_type_key])
)
sorted(set(adata.obs[cell_type_key]))

In [None]:
# Label frequencies after harmonisation, most common first.
sorted(Counter(adata.obs[cell_type_key]).items(), key=lambda item: item[1], reverse=True)

In [None]:
# Drop cells labelled 'Unknown' and the labels that cell_type_map does not bring into the
# shared vocabulary ('Tnaive', 'Epithelial').
excluded_cell_types = ['Unknown', 'Tnaive', 'Epithelial']
# One boolean mask instead of three chained subsets, each of which produced a view of a view.
# `.copy()` materialises the result as a real AnnData before it is written to disk.
adata = adata[~adata.obs[cell_type_key].isin(excluded_cell_types)].copy()
adata

In [None]:
# Persist the harmonised, filtered master dataset. The path is derived from DATA_DIR
# instead of a hard-coded absolute path, and the output directory is created if missing.
master_fp = os.path.join(DATA_DIR, 'tumor', 'master', 'master.h5ad')
os.makedirs(os.path.dirname(master_fp), exist_ok=True)
adata.write_h5ad(master_fp)

In [8]:
# Reload the saved master dataset. The path is derived from DATA_DIR, matching where it
# is written, instead of a hard-coded absolute path.
master_fp = os.path.join(DATA_DIR, 'tumor', 'master', 'master.h5ad')
adata = anndata.read_h5ad(master_fp)

In [None]:
# Label frequencies in the reloaded master dataset, most common first.
counts = Counter(adata.obs[cell_type_key])
sorted(counts.items(), key=lambda item: item[1], reverse=True)

In [None]:
## get rid of unknowns
# Drop cells labelled 'Unknown'. This is a no-op if the master file was written after the
# exclusion filter. `.copy()` turns the boolean subset into a real AnnData rather than a
# view, so later code that modifies it does not have to convert it implicitly (presumably
# the source of the `view_to_actual` warning emitted by load_from_directory).
adata = adata[adata.obs[cell_type_key] != 'Unknown'].copy()
adata

In [None]:
# NOTE(review): disabled experiment that collapsed every label containing 'Malignant'
# into 'CESC_Malignant'. It appears obsolete, because cell_type_map already maps all
# Malignant/Epithelial variants to 'Malignant'. Consider deleting this cell.
# adata.obs[cell_type_key] = ['CESC_Malignant' if 'Malignant' in x else x for x in adata.obs[cell_type_key]]
# counts = Counter(adata.obs[cell_type_key])
# counts.most_common()

In [9]:
# Preprocessing and sampling settings for the training dataset, gathered in one place.
dataset_params = dict(
    cell_type_key=cell_type_key,
    n_per_cell_type=1000,
    batch_size=128,
    dataset_type='training',
    min_genes=200,
    min_cells=3,
    mito_threshold=None,
    max_n_genes=None,
    log=True,
    cpm=False,
    min_disp=.2,
)
# Pass a copy so the preprocessing steps leave `adata` itself untouched.
pds = PollockDataset(adata.copy(), **dataset_params)

2020-03-23 23:36:37,975 normalizing counts for model training
2020-03-23 23:36:37,976 filtering by min genes: 200
2020-03-23 23:36:43,303 filtering by min cells: 3
2020-03-23 23:36:51,731 loging data
2020-03-23 23:36:57,244 filtering with dispersion 0.2
2020-03-23 23:37:08,683 remaining after min disp: 4894
2020-03-23 23:37:10,013 scaling data
2020-03-23 23:40:26,861 scaling to between 0-1
2020-03-23 23:42:00,874 creating datasets


In [None]:
# Class balance of the validation split, most common first.
sorted(Counter(pds.val_adata.obs[cell_type_key]).items(), key=lambda item: item[1], reverse=True)

In [10]:
# Input width = number of features (genes) kept in the training AnnData after preprocessing.
n_features = pds.train_adata.shape[1]
pm = PollockModel(pds.cell_types, n_features, alpha=1e-5)

In [11]:
# Train for up to 25 epochs. The captured log stops at epoch 16; it is unclear whether
# that is early stopping or truncated output (TODO confirm).
n_epochs = 25
pm.fit(pds, epochs=n_epochs)

2020-03-23 23:44:23,800 epoch: 1, val loss: 2.7712647914886475
2020-03-23 23:44:35,577 epoch: 2, val loss: 2.67366361618042
2020-03-23 23:44:45,887 epoch: 3, val loss: 2.6386489868164062
2020-03-23 23:44:56,393 epoch: 4, val loss: 2.625145196914673
2020-03-23 23:45:08,075 epoch: 5, val loss: 2.6201226711273193
2020-03-23 23:45:18,480 epoch: 6, val loss: 2.618171215057373
2020-03-23 23:45:27,595 epoch: 7, val loss: 2.6181745529174805
2020-03-23 23:45:37,502 epoch: 8, val loss: 2.6185436248779297
2020-03-23 23:45:47,268 epoch: 9, val loss: 2.618565797805786
2020-03-23 23:45:56,922 epoch: 10, val loss: 2.618558645248413
2020-03-23 23:46:07,426 epoch: 11, val loss: 2.619157075881958
2020-03-23 23:46:17,278 epoch: 12, val loss: 2.619549512863159
2020-03-23 23:46:27,579 epoch: 13, val loss: 2.619530439376831
2020-03-23 23:46:37,784 epoch: 14, val loss: 2.6197454929351807
2020-03-23 23:46:47,803 epoch: 15, val loss: 2.6200525760650635
2020-03-23 23:46:57,986 epoch: 16, val loss: 2.62011456489

1.0
0.09127391067766104


In [12]:
# Save the trained model (presumably with the dataset settings load_from_directory needs)
# to `model_save_dir`.
# NOTE(review): the recorded training and validation accuracies are both ~0.09, i.e. the
# model is not learning beyond a trivial baseline. Check training before relying on this artifact.
pm.save(pds, model_save_dir)

  d['descr'] = dtype_to_descr(array.dtype)
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [13]:
# Per-class precision / recall / F1 / support on the validation split.
pm.summary['validation']

{'metrics': {'Acinar': {'precision': 0.0,
   'recall': 0.0,
   'f1-score': 0.0,
   'support': 51},
  'B-cells': {'precision': 0.02845429254722046,
   'recall': 0.1153770037601425,
   'f1-score': 0.04565030146425495,
   'support': 5053},
  'CD4+ T-cells': {'precision': 0.10895936463036501,
   'recall': 0.09957247961821436,
   'f1-score': 0.10405465077014989,
   'support': 20116},
  'CD8+ T-cells': {'precision': 0.0819542053956019,
   'recall': 0.09335054874112331,
   'f1-score': 0.08728194603730308,
   'support': 15490},
  'DC': {'precision': 0.0, 'recall': 0.0, 'f1-score': 0.0, 'support': 22},
  'Ductal': {'precision': 0.03997187364777153,
   'recall': 0.1110443275732532,
   'f1-score': 0.05878375691047211,
   'support': 6655},
  'Endothelial cells': {'precision': 0.08284991034700981,
   'recall': 0.12041082241128229,
   'f1-score': 0.09815989253022586,
   'support': 13047},
  'Erythrocyte': {'precision': 0.0,
   'recall': 0.0,
   'f1-score': 0.0,
   'support': 30},
  'Fibroblasts': {'

In [14]:
# Overall validation accuracy.
# NOTE(review): the recorded value (~0.0895) equals the training accuracy, and the
# validation loss plateaued at ~2.618 from epoch 6 onward. The model appears not to learn;
# investigate before using it.
pm.summary['validation']['metrics']['accuracy']

0.08950546230798209

In [15]:
# Overall training accuracy; compare with the validation accuracy reported above.
pm.summary['training']['metrics']['accuracy']

0.08953432549207874

In [16]:
# Reload the saved model and build a prediction dataset from `adata`, presumably using
# the saved preprocessing settings.
# NOTE(review): `adata` is the full master dataset that the training dataset was built
# from, so it includes training cells. Agreement measured on it overstates generalisation.
l_pds, l_pm = load_from_directory(adata, model_save_dir)

2020-03-23 23:51:17,748 normalizing counts for model training
2020-03-23 23:51:17,831 loging data
  view_to_actual(data)
2020-03-23 23:51:21,829 scaling data


In [17]:
# Predict with the reloaded model; unpack predicted cell-type names and class probabilities.
prediction = l_pm.predict_pollock_dataset(l_pds, labels=True)
labels, probs = prediction
labels

('Macrophage/Monocyte',
 'NK cells',
 'Tregs',
 'B-cells',
 'B-cells',
 'Endothelial cells',
 'Ductal',
 'CD4+ T-cells',
 'Fibroblasts',
 'Plasma cells',
 'Endothelial cells',
 'Tregs',
 'Ductal',
 'Macrophage/Monocyte',
 'Mast cells',
 'CD8+ T-cells',
 'Ductal',
 'Endothelial cells',
 'CD4+ T-cells',
 'Fibroblasts',
 'Mast cells',
 'CD8+ T-cells',
 'Fibroblasts',
 'B-cells',
 'Malignant',
 'CD8+ T-cells',
 'Fibroblasts',
 'B-cells',
 'Ductal',
 'Macrophage/Monocyte',
 'NK cells',
 'CD4+ T-cells',
 'Macrophage/Monocyte',
 'Plasma cells',
 'CD8+ T-cells',
 'CD4+ T-cells',
 'Malignant',
 'Malignant',
 'B-cells',
 'Plasma cells',
 'Ductal',
 'Ductal',
 'Macrophage/Monocyte',
 'Endothelial cells',
 'CD8+ T-cells',
 'Plasma cells',
 'NK cells',
 'Macrophage/Monocyte',
 'B-cells',
 'CD8+ T-cells',
 'Fibroblasts',
 'Fibroblasts',
 'CD4+ T-cells',
 'CD4+ T-cells',
 'Ductal',
 'NK cells',
 'CD4+ T-cells',
 'B-cells',
 'Malignant',
 'NK cells',
 'CD4+ T-cells',
 'Plasma cells',
 'Ductal',
 'Macr

In [18]:
# Ground-truth labels of the prediction dataset, for comparison with the predicted `labels`.
l_pds.prediction_adata.obs[cell_type_key].tolist()

['CD8+ T-cells',
 'Endothelial cells',
 'Fibroblasts',
 'Malignant',
 'CD4+ T-cells',
 'CD8+ T-cells',
 'Malignant',
 'Malignant',
 'Malignant',
 'CD8+ T-cells',
 'CD8+ T-cells',
 'Malignant',
 'Endothelial cells',
 'NK cells',
 'CD8+ T-cells',
 'Malignant',
 'Macrophage/Monocyte',
 'Endothelial cells',
 'Fibroblasts',
 'Fibroblasts',
 'Malignant',
 'CD8+ T-cells',
 'Fibroblasts',
 'CD8+ T-cells',
 'Tregs',
 'CD8+ T-cells',
 'Endothelial cells',
 'CD4+ T-cells',
 'Endothelial cells',
 'CD4+ T-cells',
 'CD4+ T-cells',
 'Malignant',
 'Malignant',
 'CD8+ T-cells',
 'Malignant',
 'CD8+ T-cells',
 'CD8+ T-cells',
 'Fibroblasts',
 'Malignant',
 'Endothelial cells',
 'CD8+ T-cells',
 'Endothelial cells',
 'Endothelial cells',
 'Malignant',
 'Endothelial cells',
 'CD8+ T-cells',
 'Malignant',
 'Macrophage/Monocyte',
 'Malignant',
 'Endothelial cells',
 'CD8+ T-cells',
 'Malignant',
 'CD8+ T-cells',
 'Endothelial cells',
 'NK cells',
 'CD8+ T-cells',
 'CD8+ T-cells',
 'CD4+ T-cells',
 'CD4+ T-c