In [11]:
import json
import logging
import os
import re
from pathlib import Path

import anndata
import numpy as np
import pandas as pd
import scanpy as sc
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.svm import SVC, LinearSVC


import pollock.utils as utils
from pollock.dataloaders import normalize, get_train_dataloaders, get_prediction_dataloader
from pollock.model import fit_model

In [9]:
import torch

In [32]:
fps = sorted(utils.listfiles('/home/estorrs/pollock/benchmarking/results/01272021_harmonized_v2/', regex=r'SingleCellNet.tsv'))
fps

['/home/estorrs/pollock/benchmarking/results/01272021_harmonized_v2/scRNAseq/brca/SingleCellNet.tsv',
 '/home/estorrs/pollock/benchmarking/results/01272021_harmonized_v2/scRNAseq/cesc/SingleCellNet.tsv',
 '/home/estorrs/pollock/benchmarking/results/01272021_harmonized_v2/scRNAseq/hnscc/SingleCellNet.tsv',
 '/home/estorrs/pollock/benchmarking/results/01272021_harmonized_v2/scRNAseq/melanoma/SingleCellNet.tsv',
 '/home/estorrs/pollock/benchmarking/results/01272021_harmonized_v2/scRNAseq/myeloma/SingleCellNet.tsv',
 '/home/estorrs/pollock/benchmarking/results/01272021_harmonized_v2/scRNAseq/pdac/SingleCellNet.tsv',
 '/home/estorrs/pollock/benchmarking/results/01272021_harmonized_v2/snATACseq/brca_gene_activity/SingleCellNet.tsv',
 '/home/estorrs/pollock/benchmarking/results/01272021_harmonized_v2/snATACseq/ccrcc_gene_activity/SingleCellNet.tsv',
 '/home/estorrs/pollock/benchmarking/results/01272021_harmonized_v2/snATACseq/gbm_gene_activity/SingleCellNet.tsv',
 '/home/estorrs/pollock/bench

In [None]:
for fp in fps:
    disease = fp.split('/')[-2]
    dtype = fp.split('/')[-3]
    df = pd.read_csv(fp, sep='\t')
    y_pred = df['predicted'].to_list()
    y_true = df['groundtruth'].to_list()

In [34]:
df = pd.read_csv('/home/estorrs/pollock/benchmarking/results/01272021_harmonized_v2/scRNAseq/brca/SingleCellNet.tsv',
                sep='\t')
df

Unnamed: 0,cell_id,groundtruth,predicted,probability
0,_HT062B1_S1PA_AACCAACTCCACTAGA-1,Endothelial,Endothelial,
1,_HT062B1_S1PA_AACCCAAAGACGAGCT-1,NK,NK,
2,_HT062B1_S1PA_AACGTCAGTTTACACG-1,Endothelial,Endothelial,
3,_HT062B1_S1PA_AACTTCTGTCCAGAAG-1,CD8 T cell,CD8 T cell,
4,_HT062B1_S1PA_AAGCCATAGTGATTCC-1,Endothelial,Endothelial,
...,...,...,...,...
5743,_HT171B1_BC2_TTTATGCCAGCCCAGT-1,Fibroblast,Fibroblast,
5744,_HT171B1_BC2_TTTCATGGTCGACTTA-1,Treg,Treg,
5745,_HT171B1_BC2_TTTCCTCGTTAGCGGA-1,Treg,Treg,
5746,_HT171B1_BC2_TTTGACTCATGGGATG-1,Plasma,Plasma,


In [35]:
set(df['predicted'])

{'B cell',
 'CD4 T cell',
 'CD8 T cell',
 'Dendritic',
 'Endothelial',
 'Erythrocyte',
 'Fibroblast',
 'Malignant',
 'Mast',
 'Monocyte',
 'NK',
 'Plasma',
 'Treg',
 'rand'}

In [36]:
from collections import Counter
Counter(df['predicted'])

Counter({'Endothelial': 509,
         'NK': 475,
         'CD8 T cell': 514,
         'CD4 T cell': 503,
         'Treg': 527,
         'Fibroblast': 524,
         'Monocyte': 505,
         'Malignant': 455,
         'Mast': 253,
         'Erythrocyte': 195,
         'Plasma': 518,
         'B cell': 507,
         'Dendritic': 260,
         'rand': 3})

In [31]:
a = sc.read_h5ad('/data/pollock/benchmarking/results/seurat/scRNAseq_brca.h5ad')
a.obs

Unnamed: 0,orig.ident,nCount_RNA,nFeature_RNA,cell_type,barcode,sample,predicted.id,prediction.score.CD8.T.cell,prediction.score.Fibroblast,prediction.score.NK,...,prediction.score.Malignant,prediction.score.Endothelial,prediction.score.Monocyte,prediction.score.Treg,prediction.score.B.cell,prediction.score.Mast,prediction.score.Plasma,prediction.score.Dendritic,prediction.score.Erythrocyte,prediction.score.max
0_AACCAACTCCACTAGA-1,0,4712,2224,Endothelial,AACCAACTCCACTAGA-1,0,Endothelial,0,0,0,...,0.0958545238202942,0.904145476179706,0,0,0,0,0,0,0,0.904145476179706
0_AACCCAAAGACGAGCT-1,0,4732,2612,NK,AACCCAAAGACGAGCT-1,0,NK,0,0,1,...,0,0,0,0,0,0,0,0,0,1
0_AACGTCAGTTTACACG-1,0,3296,1052,Endothelial,AACGTCAGTTTACACG-1,0,Endothelial,0,0,0,...,0,1,0,0,0,0,0,0,0,1
0_AACTTCTGTCCAGAAG-1,0,3202,647,CD8 T cell,AACTTCTGTCCAGAAG-1,0,CD8 T cell,0.940970416529484,0,0.0590295834705152,...,0,0,0,0,0,0,0,0,0,0.940970416529484
0_AAGCCATAGTGATTCC-1,0,4088,1757,Endothelial,AAGCCATAGTGATTCC-1,0,Endothelial,0,0,0,...,0,1,0,0,0,0,0,0,0,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
29_TTTATGCCAGCCCAGT-1,29,4480,2112,Fibroblast,TTTATGCCAGCCCAGT-1,29,Fibroblast,0,0.928496134305296,0,...,0,0,0,0,0,0,0.0715038656947036,0,0,0.928496134305296
29_TTTCATGGTCGACTTA-1,29,4580,2148,Treg,TTTCATGGTCGACTTA-1,29,Treg,0.0247594920019057,0,0,...,0,0,0,0.975240507998094,0,0,0,0,0,0.975240507998094
29_TTTCCTCGTTAGCGGA-1,29,4526,2173,Treg,TTTCCTCGTTAGCGGA-1,29,Treg,0,0,0,...,0,0,0,0.999999999999999,0,0,0,0,0,0.999999999999999
29_TTTGACTCATGGGATG-1,29,4452,491,Plasma,TTTGACTCATGGGATG-1,29,Plasma,0,0,0,...,0.0159501515904298,0,0,0,0,0,0.98404984840957,0,0,0.98404984840957


In [5]:
data_dir = '/data/pollock/benchmarking/pollock_datasets/'
fps = utils.listfiles(data_dir, regex=r'.h5ad')
fps = [fp for fp in fps if '_train' in fp or '_val' in fp]

fmap = {}
for fp in fps:
    dtype = fp.split('/')[-2]
    disease = fp.split('/')[-1].split('_')[0]
    partition = 'train' if '_train' in fp else 'val'
    
    if dtype not in fmap:
        fmap[dtype] = {}
    if disease not in fmap[dtype]:
        fmap[dtype][disease] = {}
    
    if not (dtype == 'snATACseq' and ('peaks' in fp or 'motif' in fp)):
        fmap[dtype][disease][partition] = fp
fmap

{'snRNAseq': {'ccrcc': {'train': '/data/pollock/benchmarking/pollock_datasets/snRNAseq/ccrcc_train.h5ad',
   'val': '/data/pollock/benchmarking/pollock_datasets/snRNAseq/ccrcc_val.h5ad'},
  'brca': {'train': '/data/pollock/benchmarking/pollock_datasets/snRNAseq/brca_train.h5ad',
   'val': '/data/pollock/benchmarking/pollock_datasets/snRNAseq/brca_val.h5ad'},
  'gbm': {'val': '/data/pollock/benchmarking/pollock_datasets/snRNAseq/gbm_val.h5ad',
   'train': '/data/pollock/benchmarking/pollock_datasets/snRNAseq/gbm_train.h5ad'}},
 'snATACseq': {'brca': {'train': '/data/pollock/benchmarking/pollock_datasets/snATACseq/brca_gene_activity_train.h5ad',
   'val': '/data/pollock/benchmarking/pollock_datasets/snATACseq/brca_gene_activity_val.h5ad'},
  'ccrcc': {'train': '/data/pollock/benchmarking/pollock_datasets/snATACseq/ccrcc_gene_activity_train.h5ad',
   'val': '/data/pollock/benchmarking/pollock_datasets/snATACseq/ccrcc_gene_activity_val.h5ad'},
  'gbm': {'val': '/data/pollock/benchmarking/p

In [6]:
for dtype, d1 in fmap.items():
    for disease, d2 in d1.items():
        print(dtype, disease, sc.read_h5ad(d2['train']).shape, sc.read_h5ad(d2['val']).shape)

snRNAseq ccrcc (4754, 33538) (4518, 33538)
snRNAseq brca (5252, 29175) (4893, 29175)
snRNAseq gbm (3722, 29748) (3577, 29748)
snATACseq brca (3576, 19891) (3519, 19891)
snATACseq ccrcc (3000, 19843) (3000, 19843)
snATACseq gbm (3389, 19891) (2876, 19891)
scRNAseq cesc (4661, 22928) (4276, 22928)
scRNAseq myeloma (3617, 24020) (3312, 24020)
scRNAseq brca (6105, 27131) (5748, 27131)
scRNAseq hnscc (5287, 26929) (5201, 26929)
scRNAseq pdac (7940, 28756) (7823, 28756)
scRNAseq melanoma (4218, 23452) (3517, 23452)


In [None]:
for dtype, d1 in fmap.items():
    for disease, d2 in d1.items():
#         if dtype == 'scRNAseq':
        print(dtype, disease)
        train = sc.read_h5ad(d2['train'])
        val = sc.read_h5ad(d2['val'])

In [23]:
train, val = sc.read_h5ad(fmap['scRNAseq']['brca']['train']), sc.read_h5ad(fmap['scRNAseq']['brca']['val'])

In [24]:
train_dl, val_dl = get_train_dataloaders(train, val)

2022-01-18 10:31:09,562 22285 genes overlap with model after filtering
2022-01-18 10:31:09,564 1268 genes missing from dataset after filtering


For MLP, we use Dense layers with rectified linear unit (ReLU) activation function, and we apply dropout after each layer with a dropout rate of 0.1. Cross-entropy loss is used for training the model

The network structures of MLP, GEDFN, and ItClust are dependent on the sample size. When there are over 5000 cells, MLP has the structure of [input_dim, 128, 64, 32, 16, 8, n_classes] (input_dim refers to the length of feature space, n_classes refers to the number of cell types). When there are less than 5000 cells, MLP has [input_dim, 64, 16, n_classes]

In [25]:
class MLP(torch.nn.Module):
    def __init__(self, genes, classes, method='small'):
        """
        MLP
        """
        super(MLP, self).__init__()
        self.genes = genes
        self.n_genes = len(genes)
        self.classes = classes
        self.n_classes = len(classes)

        if method == 'small':
            self.layers = torch.nn.Sequential(
                torch.nn.Linear(self.n_genes, 128),
                torch.nn.ReLU(),
                torch.nn.Dropout(.1),
                torch.nn.Linear(128, 32),
                torch.nn.ReLU(),
                torch.nn.Dropout(.1),
                torch.nn.Linear(32, 16),
                torch.nn.ReLU(),
                torch.nn.Dropout(.1),
                torch.nn.Linear(16, 8),
                torch.nn.ReLU(),
                torch.nn.Dropout(.1),
                torch.nn.Linear(8, self.n_classes),
            )
        else:
            self.layers = torch.nn.Sequential(
                torch.nn.Linear(self.n_genes, 64),
                torch.nn.ReLU(),
                torch.nn.Dropout(.1),
                torch.nn.Linear(64, 16),
                torch.nn.ReLU(),
                torch.nn.Dropout(.1),
                torch.nn.Linear(16, self.n_classes),
            )
            
    def forward(self, x):
        x = self.layers(x)
        return x

In [26]:
model = MLP(train_dl.dataset.adata.var.index.to_list(), train_dl.dataset.cell_types, method='small')
model = model.cuda()

lr = 1e-4
epochs = 20
criteria = torch.nn.CrossEntropyLoss()

opt = torch.optim.Adam(model.parameters(), lr=lr)


In [27]:
import time
for epoch in range(epochs):
    train_loss, val_loss = 0., 0.
    start = time.time()
    model.train()
    for i, b in enumerate(train_dl):
        x, y = b['x'], b['y']
        x, y = x.cuda(), y.cuda()
        opt.zero_grad()
        out = model(x)
        loss = criteria(out, y)
        loss.backward()
        opt.step()

        train_loss += float(loss.detach().cpu())
    train_loss = train_loss / len(train_dl)

    time_delta = time.time() - start
    model.eval()
    with torch.no_grad():
        for i, b in enumerate(val_dl):
            x, y = b['x'], b['y']
            x, y = x.cuda(), y.cuda()

            out = model(x)
            loss = criteria(out, y)
            val_loss += float(loss.detach().cpu())


    val_loss = val_loss / len(val_dl)

    logging.info(f'epoch: {epoch}, train loss: {train_loss:.3f}, val loss: {val_loss:.3f}, time: {time_delta:.2f}')

2022-01-18 10:31:19,224 epoch: 0, train loss: 2.479, val loss: 2.279, time: 1.62
2022-01-18 10:31:22,003 epoch: 1, train loss: 2.072, val loss: 1.887, time: 1.69
2022-01-18 10:31:24,647 epoch: 2, train loss: 1.758, val loss: 1.647, time: 1.57
2022-01-18 10:31:27,386 epoch: 3, train loss: 1.548, val loss: 1.475, time: 1.58
2022-01-18 10:31:30,084 epoch: 4, train loss: 1.364, val loss: 1.297, time: 1.63
2022-01-18 10:31:32,693 epoch: 5, train loss: 1.184, val loss: 1.121, time: 1.54
2022-01-18 10:31:35,301 epoch: 6, train loss: 1.019, val loss: 0.980, time: 1.54
2022-01-18 10:31:37,918 epoch: 7, train loss: 0.869, val loss: 0.869, time: 1.54
2022-01-18 10:31:40,606 epoch: 8, train loss: 0.760, val loss: 0.790, time: 1.53
2022-01-18 10:31:43,319 epoch: 9, train loss: 0.666, val loss: 0.735, time: 1.65
2022-01-18 10:31:45,931 epoch: 10, train loss: 0.614, val loss: 0.705, time: 1.55
2022-01-18 10:31:48,529 epoch: 11, train loss: 0.552, val loss: 0.688, time: 1.53
2022-01-18 10:31:51,130 ep

In [28]:
def mlp_predict_dl(dl, model):
    y_prob = None
    model.eval()
    with torch.no_grad():
        for i, b in enumerate(dl):
            x, y = b['x'], b['y']
            x, y = x.cuda(), y.cuda()

            out = model(x).detach().cpu().numpy()

            if y_prob is None:
                y_prob = out
            else:
                y_prob = np.concatenate((y_prob, out), axis=0)
                
    labels = [model.classes[i] for i in np.argmax(y_prob, axis=1)]
    
    return labels, y_prob


In [29]:
labels, probs = mlp_predict_dl(val_dl, model)
val.obs['predicted_cell_type'] = labels

prob_df = pd.DataFrame(data=probs, index=val.obs.index.to_list(),
                      columns=[f'probability {ct}' for ct in model.classes])

val.obs = pd.concat((val.obs, prob_df), axis=1)
val.obs

Unnamed: 0,cell_type,barcode,sample,n_counts,predicted_cell_type,probability B cell,probability CD4 T cell,probability CD8 T cell,probability Dendritic,probability Endothelial,probability Erythrocyte,probability Fibroblast,probability Malignant,probability Mast,probability Monocyte,probability NK,probability Plasma,probability Treg
0_AACCAACTCCACTAGA-1,Endothelial,AACCAACTCCACTAGA-1,0,4712.0,Endothelial,1.362963,-1.287840,-0.119352,1.115360,9.295658,5.175925,-11.274056,-9.268015,-10.175484,-8.702056,4.746614,-3.158478,-7.635473
0_AACCCAAAGACGAGCT-1,NK,AACCCAAAGACGAGCT-1,0,4732.0,NK,1.202518,-14.874049,-0.706506,1.847630,1.782679,0.247695,-10.714002,-3.860619,-3.300359,-15.253161,9.986592,-2.909322,-9.984134
0_AACGTCAGTTTACACG-1,Endothelial,AACGTCAGTTTACACG-1,0,3294.0,Endothelial,0.546287,-0.289826,-0.287198,1.185426,6.852063,3.503712,-7.477614,-6.962991,-7.320167,-5.421151,2.972970,-2.293766,-5.882442
0_AACTTCTGTCCAGAAG-1,CD8 T cell,AACTTCTGTCCAGAAG-1,0,3202.0,CD8 T cell,-3.959328,-0.365535,5.278174,-5.298337,-3.252025,-0.839206,-3.099293,0.585630,-6.207127,0.511336,-3.251617,-4.518991,-2.337372
0_AAGCCATAGTGATTCC-1,Endothelial,AAGCCATAGTGATTCC-1,0,4088.0,Endothelial,0.573431,-0.395243,-0.181737,1.578007,10.324621,5.100435,-11.602044,-10.812377,-11.782465,-8.537018,4.670155,-4.002245,-9.061112
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
29_TTTATGCCAGCCCAGT-1,Fibroblast,TTTATGCCAGCCCAGT-1,29,4480.0,Fibroblast,-8.283934,-8.637789,-1.662289,1.914663,-9.844347,-11.607601,7.631142,2.028147,1.299929,-1.830342,0.639902,-7.567197,-7.095963
29_TTTCATGGTCGACTTA-1,Treg,TTTCATGGTCGACTTA-1,29,4580.0,Treg,-0.329183,6.125992,5.443079,-12.981363,-4.399100,3.269040,-3.292470,0.603797,-14.448524,-8.862652,-7.142284,-11.068240,10.028825
29_TTTCCTCGTTAGCGGA-1,Treg,TTTCCTCGTTAGCGGA-1,29,4525.0,Treg,-0.088276,5.581955,5.489130,-13.455153,-4.280552,3.387279,-4.071203,0.698752,-14.996204,-9.669154,-6.786827,-11.557008,10.196788
29_TTTGACTCATGGGATG-1,Plasma,TTTGACTCATGGGATG-1,29,4452.0,Plasma,4.660043,-8.555573,0.248813,9.933461,8.510962,7.783207,-8.565523,-6.077346,8.577444,-4.344531,9.106541,15.374352,-11.468411
