In [1]:
import sys
# Put /tf/pollock at the front of the module search path
# (presumably the local pollock checkout imported below — TODO confirm).
sys.path.insert(0, '/tf/pollock')

In [2]:
# Load IPython's autoreload extension; the reload mode is selected by the `%autoreload 2` cell.
%load_ext autoreload

In [3]:
import logging
import os
import random
from collections import Counter
from importlib import reload
import time

import anndata
import numpy as np
import pandas as pd
import scanpy as sc
import matplotlib.pyplot as plt
# import anndata2ri


import pollock
from pollock import PollockDataset, PollockModel, load_from_directory
# import pollock.models.analysis as pollock_analysis

  from pandas.core.index import RangeIndex


In [4]:
import tensorflow as tf
from tensorflow.keras import layers

# Reset Keras' global state so re-running the notebook starts from a clean slate.
tf.keras.backend.clear_session()  # For easy reset of notebook state.

In [5]:
# Mode 2: re-import modules before each cell runs, so edits to imported code are picked up
# without restarting the kernel.
%autoreload 2

In [6]:
# Root directory that every run configuration below joins its input paths onto.
DATA_DIR = '/data/single_cell_classification'
# Root directory for saved models; each run writes to MODEL_DIR/<run_name>.
MODEL_DIR = '/models'

## expression tables

In [None]:
# Run configuration: breast expression table plus its per-cell metadata.
run_name = 'sc_brca'

# Both inputs sit in the same directory, so build the two paths from one template.
expression_fp, label_fp = (
    os.path.join(DATA_DIR, 'tumor', 'BR', 'raw', 'houxiang_brca', fname)
    for fname in ('breast_counts_matrix.tsv', 'breast_metadata.tsv')
)

model_save_dir = os.path.join(MODEL_DIR, run_name)

# Table parsing options and the metadata column that holds the cell-type labels.
sample_column, sep, cell_type_key = 'Genes', '\t', 'cell_type'

In [None]:
# Run configuration: HNSC expression table plus its per-cell metadata.
run_name = 'sc_hnsc'

# Both inputs sit in the same directory, so build the two paths from one template.
expression_fp, label_fp = (
    os.path.join(DATA_DIR, 'tumor', 'HNSC', 'raw', 'hnsc_yize', fname)
    for fname in (
        'Assigned_WUHN_15_processed_cluster_review_gene_expression_format_2.tsv',
        'Assigned_WUHN_15_processed_cluster_review_cell_metadata_format_2.tsv',
    )
)

model_save_dir = os.path.join(MODEL_DIR, run_name)

# Table parsing options and the metadata column that holds the cell-type labels.
sample_column, sep, cell_type_key = 'Genes', '\t', 'cell_type'

In [None]:
# Run configuration: CESC expression table plus its per-cell metadata.
run_name = 'sc_cesc'

# Both inputs sit in the same directory, so build the two paths from one template.
expression_fp, label_fp = (
    os.path.join(DATA_DIR, 'tumor', 'CESC', 'raw', 'cesc_yize_v2', fname)
    for fname in (
        'Assigned_CESC_9_processed_cluster_review_final_gene_expression_format.tsv',
        'Assigned_CESC_9_processed_cluster_review_final_cell_metadata_format.tsv',
    )
)

model_save_dir = os.path.join(MODEL_DIR, run_name)

# Table parsing options and the metadata column that holds the cell-type labels.
sample_column, sep, cell_type_key = 'Genes', '\t', 'cell_type'

## H5 object

In [None]:
# Run configuration: PDAC data stored as a single .h5ad file.
run_name = 'sc_pdac'

# NOTE: this cell does not assign label_fp, so any value left by a previously run
# configuration cell stays in the kernel.
expression_fp = os.path.join(DATA_DIR, 'tumor', 'PDAC', 'pdac.h5ad')

model_save_dir = os.path.join(MODEL_DIR, run_name)

# Same option values as the expression-table configurations.
sample_column = 'Genes'
sep='\t'
cell_type_key = 'cell_type'

In [None]:
# Run configuration: CCRCC expression data (.h5) plus its per-cell metadata table.
run_name = 'sn_ccrcc'

# Both inputs sit in the same directory, so build the two paths from one template.
expression_fp, label_fp = (
    os.path.join(DATA_DIR, 'tumor', 'CCRCC', 'yige', fname)
    for fname in ('adata.h5', 'metadata.tsv')
)

model_save_dir = os.path.join(MODEL_DIR, run_name)

# Table parsing options and the metadata column that holds the cell-type labels.
sample_column, sep, cell_type_key = 'Genes', '\t', 'cell_type'

In [7]:
# Run configuration for the executed cells that follow (this cell carries execution count 7).
# NOTE(review): run_name says "myeloma" but the input directory is named "melanoma" —
# confirm which one is intended before reusing the saved model name.
run_name = 'sc_myeloma'

expression_fp = os.path.join(DATA_DIR, 'tumor', 'melanoma', 'merged.h5ad')

model_save_dir = os.path.join(MODEL_DIR, run_name)

# Labels are read from the 'ident' obs column here, not 'cell_type' as in the other configs.
cell_type_key = 'ident'

In [None]:
# Parse the delimited expression table at expression_fp (delimiter taken from `sep`) and display it.
expression_df = pd.read_csv(expression_fp, sep=sep)
expression_df

In [None]:
# Index the table by the identifier column named in the run configuration (`sample_column`)
# instead of a hard-coded 'Genes', then transpose so the former columns become the rows.
# Not idempotent: a second run fails because the identifier column is already the index.
expression_df = expression_df.set_index(sample_column).transpose()
expression_df

In [None]:
# Cache the table as HDF5 next to the source TSV (same path, '.h5' extension, key 'df').
# Only a trailing '.tsv' is swapped: a plain str.replace returns the path unchanged for a
# non-.tsv input (e.g. the .h5 / .h5ad configurations), and the write would then go into
# the source file itself.
if not expression_fp.endswith('.tsv'):
    raise ValueError('expected a .tsv expression table, got: %s' % expression_fp)
expression_df.to_hdf(expression_fp[:-len('.tsv')] + '.h5', key='df')

In [None]:
# Load the table stored under key 'df' in the HDF5 file whose path is expression_fp with
# '.tsv' replaced by '.h5' (the path is used unchanged when it contains no '.tsv').
expression_df = pd.read_hdf(expression_fp.replace('.tsv', '.h5'), 'df')
expression_df

In [8]:
# Read the AnnData object stored at expression_fp and display its summary.
adata = anndata.read_h5ad(expression_fp)
adata

AnnData object with n_obs × n_vars = 108187 × 24020 
    obs: 'orig.ident', 'nCount_RNA', 'nFeature_RNA', 'percent.mito', 'nCount_SCT', 'nFeature_SCT', 'SCT_snn_res.0.5', 'seurat_clusters', 'cell_type', 'sample', 'predicted_cell_type', 'probability', 'new.ident', 'ident'
    layers: 'logcounts'

In [None]:
# Load the per-cell metadata, key it by 'cell_id', and select the rows in the same order
# as the expression table's index.
label_df = (
    pd.read_csv(label_fp, sep=sep)
    .set_index('cell_id')
    .loc[expression_df.index]
)
label_df

In [None]:
# Wrap the expression values and the aligned metadata in an AnnData object.
adata = anndata.AnnData(X=expression_df.values, obs=label_df)
# Label observations with the table's row index and variables with its column names.
adata.obs.index = expression_df.index
adata.var.index = expression_df.columns
adata

In [None]:
# Read the AnnData object stored at expression_fp and display its summary.
# Uses anndata.read_h5ad like every other load in this notebook (scanpy re-exports the same reader).
adata = anndata.read_h5ad(expression_fp)
adata

In [None]:
# Run configuration: one model trained on several datasets merged together.
run_name = 'sc_master'

# (expression HDF5, metadata TSV) path pairs; both files of a pair share one directory
# under DATA_DIR/tumor.
expression_table_fps = [
    tuple(os.path.join(DATA_DIR, 'tumor', *subdir, fname) for fname in fnames)
    for subdir, fnames in [
        (('BR', 'raw', 'houxiang_brca'),
         ('breast_counts_matrix.h5',
          'breast_metadata.tsv')),
        (('HNSC', 'raw', 'hnsc_yize'),
         ('Assigned_WUHN_15_processed_cluster_review_gene_expression_format_2.h5',
          'Assigned_WUHN_15_processed_cluster_review_cell_metadata_format_2.tsv')),
        (('CESC', 'raw', 'cesc_yize_v2'),
         ('Assigned_CESC_9_processed_cluster_review_final_gene_expression_format.h5',
          'Assigned_CESC_9_processed_cluster_review_final_cell_metadata_format.tsv')),
        # Disabled:
        # (('CCRCC', 'yige'),
        #  ('adata.h5',
        #   'metadata.tsv')),
    ]
]

# Datasets that are already stored as AnnData files.
anndata_fps = [os.path.join(DATA_DIR, 'tumor', 'PDAC', 'pdac.h5ad')]

model_save_dir = os.path.join(MODEL_DIR, run_name)
cell_type_key, sep = 'cell_type', '\t'

In [None]:
# Build one AnnData per (expression table, metadata) pair and stack them into `adata`.
# The path loop variable is named `table_fp`, not `expression_fp`, so the loop does not
# overwrite the run-configuration global that the read_h5ad(expression_fp) cells rely on.
adata = None
for table_fp, metadata_fp in expression_table_fps:
    expression_df = pd.read_hdf(table_fp, 'df')
    label_df = pd.read_csv(metadata_fp, sep=sep)
    label_df = label_df.set_index('cell_id')
    # Select the metadata rows in the same order as the expression table's rows.
    label_df = label_df.loc[expression_df.index]

    temp = anndata.AnnData(X=expression_df.values, obs=label_df)
    temp.obs.index = expression_df.index
    temp.var.index = expression_df.columns

    # The first dataset seeds the result; each later one is concatenated onto it.
    if adata is None:
        adata = temp.copy()
    else:
        adata = adata.concatenate(temp)
adata

In [None]:
# Display the current `adata` summary.
adata

In [None]:
# Read the first (currently the only) entry of anndata_fps and display its summary.
to_add = anndata.read_h5ad(anndata_fps[0])
to_add

In [None]:
# Append `to_add` to `adata`. Re-running this cell appends it again, because `adata`
# is overwritten with the result.
adata = adata.concatenate(to_add)
adata

In [None]:
# Count cells per label in the configured label column, most frequent first.
Counter(adata.obs[cell_type_key]).most_common()

In [None]:
# Lookup from dataset-specific label spellings to one shared label per cell type.
# Written as canonical label -> aliases and inverted, so each group is visible at a glance.
cell_type_map = {
    alias: canonical
    for canonical, aliases in [
        ('B-cells', ['B', 'B(1)', 'B(2)']),
        ('Malignant', [
            'BR_Malignant',
            'CESC_Malignant/Epithelial_1',
            'CESC_Malignant/Epithelial_2',
            'CESC_Malignant/Epithelial_3',
            'CESC_Malignant/Epithelial_4',
            'HNSC_Malignant/Epithelial',
        ]),
        ('CD4+ T-cells', ['CD4+T', 'CD4_T']),
        ('CD8+ T-cells', ['CD8+T', 'CD8_T']),
        ('Endothelial cells', ['Endothelial']),
        ('Fibroblasts', ['Fibroblast']),
        ## macro/mono
        ('Macrophage/Monocyte', ['Macrophage', 'Macrophages', 'Monocyte']),
        ('Mast cells', ['Mast']),
        ('NK cells', ['NK']),
        ('Plasma cells', ['Plasma']),
        ('Tregs', ['Treg']),
    ]
    for alias in aliases
}

In [None]:
# List the distinct labels currently present, sorted.
sorted(set(adata.obs[cell_type_key]))

In [None]:
# Rewrite each label through cell_type_map; labels without an entry are kept as they are.
adata.obs[cell_type_key] = [cell_type_map.get(x, x) for x in adata.obs[cell_type_key]]
# Show the distinct labels after the rewrite.
sorted(set(adata.obs[cell_type_key]))

In [10]:
# Count cells per label, most frequent first. Reads the configured label column
# (cell_type_key) rather than a hard-coded 'ident', so the cell also works for run
# configurations whose labels live in another column.
Counter(adata.obs[cell_type_key]).most_common()

[('CD4+T', 27182),
 ('CD14+Mono', 20274),
 ('Plasma', 14347),
 ('CD8+T', 11908),
 ('NK', 8060),
 ('B', 6189),
 ('CD16+Mono', 4310),
 ('Erythrocyte', 3402),
 ('DC', 781),
 ('pDC', 126),
 ('CD34+CYTL1+', 85),
 ('Plasma_BM', 16)]

In [9]:
# Drop cells whose label is one of the placeholder values.
adata = adata[~adata.obs[cell_type_key].isin(['Unknown', 'unknown', 'NA'])]
# Drop cells whose 'orig.ident' is one of the two excluded identifiers.
adata = adata[~adata.obs['orig.ident'].isin(['27522_5', '27522_6'])]

adata

View of AnnData object with n_obs × n_vars = 96680 × 24020 
    obs: 'orig.ident', 'nCount_RNA', 'nFeature_RNA', 'percent.mito', 'nCount_SCT', 'nFeature_SCT', 'SCT_snn_res.0.5', 'seurat_clusters', 'cell_type', 'sample', 'predicted_cell_type', 'probability', 'new.ident', 'ident'
    layers: 'logcounts'

In [None]:
# set(adata.obs['orig.ident'])

In [None]:
# Save the current object as the master dataset. The path is built from DATA_DIR like every
# other input path, so relocating the data directory needs only one change.
adata.write_h5ad(os.path.join(DATA_DIR, 'tumor', 'master', 'master.h5ad'))

In [None]:
# Load the master dataset. The path is built from DATA_DIR like every other input path,
# so relocating the data directory needs only one change.
adata = anndata.read_h5ad(os.path.join(DATA_DIR, 'tumor', 'master', 'master.h5ad'))

In [None]:
# Tally cells per label and show the tally, most frequent first.
counts = Counter(adata.obs[cell_type_key])
counts.most_common()

In [None]:
## get rid of unknowns
adata = adata[adata.obs[cell_type_key]!='Unknown']
adata

In [None]:
# adata.obs[cell_type_key] = ['CESC_Malignant' if 'Malignant' in x else x for x in adata.obs[cell_type_key]]
# counts = Counter(adata.obs[cell_type_key])
# counts.most_common()

In [11]:
# Build the training dataset from a copy of `adata`, so `adata` itself is not the object
# handed to PollockDataset.
pds = PollockDataset(
    adata.copy(),
    cell_type_key=cell_type_key,
    n_per_cell_type=500,
    batch_size=64,
    dataset_type='training',
    min_genes=200,
    min_cells=3,
    mito_threshold=None,
    max_n_genes=None,
    log=True,
    cpm=False,
    min_disp=0.01,
)

2020-04-15 23:25:01,776 normalizing counts for model training
2020-04-15 23:25:01,777 filtering by min genes: 200
2020-04-15 23:25:03,968 genes remaining after filter: 24020
2020-04-15 23:25:03,969 filtering by min cells: 3
2020-04-15 23:25:07,531 cells remaining after filter: 96679
2020-04-15 23:25:07,532 loging data
2020-04-15 23:25:08,786 filtering with dispersion 0.01
2020-04-15 23:25:12,324 remaining after min disp: 3030
2020-04-15 23:25:12,660 scaling data
2020-04-15 23:26:00,896 creating tf datasets


In [12]:
# Cells per label in the training split, most frequent first.
Counter(pds.train_adata.obs[cell_type_key]).most_common()

[('Plasma', 500),
 ('CD8+T', 500),
 ('B', 500),
 ('DC', 500),
 ('CD14+Mono', 500),
 ('CD4+T', 500),
 ('CD16+Mono', 500),
 ('NK', 500),
 ('Erythrocyte', 500),
 ('pDC', 100),
 ('CD34+CYTL1+', 68),
 ('Plasma_BM', 12)]

In [13]:
# Cells per label in the validation split, most frequent first.
Counter(pds.val_adata.obs[cell_type_key]).most_common()

[('CD4+T', 26682),
 ('CD14+Mono', 19774),
 ('Plasma', 13847),
 ('CD8+T', 11408),
 ('NK', 7560),
 ('B', 5689),
 ('CD16+Mono', 3810),
 ('Erythrocyte', 2901),
 ('DC', 281),
 ('pDC', 26),
 ('CD34+CYTL1+', 17),
 ('Plasma_BM', 4)]

In [14]:
# Create the model from the dataset's cell types and the column count of the training matrix.
pm = PollockModel(
    pds.cell_types,
    pds.train_adata.shape[1],
    alpha=1e-5,
)

In [15]:
# Train on the dataset with epochs=40; the captured log reports a validation loss per epoch.
pm.fit(pds, epochs=40)

2020-04-15 23:27:12,078 epoch: 1, val loss: 18.30880355834961
2020-04-15 23:27:18,399 epoch: 2, val loss: 16.825801849365234
2020-04-15 23:27:24,644 epoch: 3, val loss: 15.768560409545898
2020-04-15 23:27:30,905 epoch: 4, val loss: 15.191256523132324
2020-04-15 23:27:37,421 epoch: 5, val loss: 14.920063018798828
2020-04-15 23:27:43,660 epoch: 6, val loss: 14.734820365905762
2020-04-15 23:27:50,087 epoch: 7, val loss: 14.556163787841797
2020-04-15 23:27:56,360 epoch: 8, val loss: 14.42278003692627
2020-04-15 23:28:02,545 epoch: 9, val loss: 14.283868789672852
2020-04-15 23:28:09,049 epoch: 10, val loss: 14.17197322845459
2020-04-15 23:28:15,308 epoch: 11, val loss: 14.0711030960083
2020-04-15 23:28:21,567 epoch: 12, val loss: 13.985541343688965
2020-04-15 23:28:28,004 epoch: 13, val loss: 13.921600341796875
2020-04-15 23:28:34,281 epoch: 14, val loss: 13.861289024353027
2020-04-15 23:28:40,498 epoch: 15, val loss: 13.813666343688965
2020-04-15 23:28:46,819 epoch: 16, val loss: 13.775397

In [16]:
# Save the model to model_save_dir; the dataset is passed along with it
# (presumably so its preprocessing can be restored on load — TODO confirm).
pm.save(pds, model_save_dir)

  _warn_prf(average, modifier, msg_start, len(result))


In [17]:
# Display the validation entry of the model summary (per-label precision/recall/f1/support).
pm.summary['validation']

{'metrics': {'B': {'precision': 0.9751410437235543,
   'recall': 0.9722271049393566,
   'f1-score': 0.9736818941994543,
   'support': 5689},
  'CD14+Mono': {'precision': 0.9692738907398594,
   'recall': 0.9268736724992415,
   'f1-score': 0.94759972080759,
   'support': 19774},
  'CD16+Mono': {'precision': 0.8333333333333334,
   'recall': 0.963254593175853,
   'f1-score': 0.8935962990017043,
   'support': 3810},
  'CD34+CYTL1+': {'precision': 0.34,
   'recall': 1.0,
   'f1-score': 0.5074626865671642,
   'support': 17},
  'CD4+T': {'precision': 0.9565920110828908,
   'recall': 0.931639307398246,
   'f1-score': 0.9439507860560492,
   'support': 26682},
  'CD8+T': {'precision': 0.8202340443190306,
   'recall': 0.8663218793828892,
   'f1-score': 0.8426482499893423,
   'support': 11408},
  'DC': {'precision': 0.28683035714285715,
   'recall': 0.9145907473309609,
   'f1-score': 0.4367034834324554,
   'support': 281},
  'Erythrocyte': {'precision': 0.9681955138935386,
   'recall': 0.9968976215

In [18]:
# Display the validation confusion matrix stored in the model summary.
pm.summary['validation']['confusion_matrix']

array([[0.97, 0.  , 0.  , 0.  , 0.  , 0.  , 0.01, 0.  , 0.  , 0.01, 0.  ,
        0.  ],
       [0.  , 0.93, 0.03, 0.  , 0.01, 0.  , 0.03, 0.  , 0.  , 0.  , 0.  ,
        0.  ],
       [0.  , 0.03, 0.96, 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ,
        0.  ],
       [0.  , 0.  , 0.  , 1.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ,
        0.  ],
       [0.  , 0.01, 0.  , 0.  , 0.93, 0.05, 0.  , 0.  , 0.  , 0.  , 0.  ,
        0.  ],
       [0.  , 0.  , 0.  , 0.  , 0.08, 0.87, 0.  , 0.  , 0.04, 0.  , 0.  ,
        0.  ],
       [0.  , 0.05, 0.  , 0.01, 0.  , 0.01, 0.91, 0.  , 0.  , 0.01, 0.  ,
        0.  ],
       [0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 1.  , 0.  , 0.  , 0.  ,
        0.  ],
       [0.  , 0.  , 0.  , 0.  , 0.  , 0.09, 0.  , 0.  , 0.9 , 0.  , 0.  ,
        0.  ],
       [0.  , 0.01, 0.  , 0.  , 0.01, 0.01, 0.  , 0.  , 0.  , 0.97, 0.  ,
        0.  ],
       [0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 1.  , 0.  ,
        0.  ],
       [0.  , 0.  , 0

In [19]:
# Accuracy on the validation split.
pm.summary['validation']['metrics']['accuracy']

0.9312492527092686

In [20]:
# Accuracy on the training split, for comparison with the validation accuracy.
pm.summary['training']['metrics']['accuracy']

0.9978632478632479

In [None]:
# Load from model_save_dir using `adata`; the call returns a pair, bound here as
# l_pds and l_pm (presumably a prediction dataset and the restored model — TODO confirm).
l_pds, l_pm = load_from_directory(adata, model_save_dir)

In [None]:
# Predict with the reloaded model; with labels=True the call returns two values,
# bound here as labels and probs.
labels, probs = l_pm.predict_pollock_dataset(l_pds, labels=True)
labels

In [None]:
# Values of the label column in the prediction dataset, for comparison with the predicted labels.
# NOTE: this prints one entry per cell; slice it before displaying on large datasets.
list(l_pds.prediction_adata.obs[cell_type_key])