In [1]:
import logging
import os
import re

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import scanpy as sc
import scipy
import anndata
from sklearn.preprocessing import StandardScaler

In [2]:
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

In [3]:
from captum.attr import DeepLiftShap, LayerDeepLiftShap

In [4]:
%load_ext autoreload

In [5]:
%autoreload 2

In [6]:
from pollock.utils import train_and_save_model, load_model, get_prediction_dataloader, predict_adata, get_splits
from pollock.explain import explain_adata, explain_predictions

In [35]:
# Load the full BRCA scRNA-seq AnnData from disk and display its summary.
a = sc.read_h5ad('/data/pollock/benchmarking/pollock_datasets/scRNAseq/brca.h5ad')
a

AnnData object with n_obs × n_vars = 98564 × 27131
    obs: 'cell_type', 'barcode', 'sample'

In [64]:
# Split the cells into training and validation ids keyed on 'cell_type'.
# NOTE(review): 500 looks like the per-class training size (the recorded
# training counts are 500 for every class) -- confirm against get_splits.
train_ids, val_ids = get_splits(a, 'cell_type', 500, oversample=True)

In [65]:
# Number of cells in each split.
len(train_ids), len(val_ids)

(6500, 92757)

In [66]:
# Apply the same splitter to the validation pool only; the second return
# value is discarded.
# NOTE(review): split=1. presumably sends every sampled cell to the first
# split -- confirm against get_splits.
v2, _ = get_splits(a[val_ids], 'cell_type', 500, oversample=True, split=1.)

In [67]:
# Size of the re-sampled validation subset.
len(v2)

6500

In [68]:
# Per-class cell counts in the full dataset.
# Counter is imported in this cell so it does not depend on a later cell
# having been executed first (a fresh-kernel run would otherwise raise
# NameError here).
from collections import Counter
Counter(a.obs['cell_type'])

Counter({'CD8 T cell': 11649,
         'Endothelial': 6188,
         'Fibroblast': 14769,
         'Malignant': 27884,
         'NK': 2958,
         'Monocyte': 9824,
         'Treg': 4404,
         'CD4 T cell': 11184,
         'B cell': 2972,
         'Plasma': 4879,
         'Mast': 746,
         'Dendritic': 592,
         'Erythrocyte': 515})

In [69]:
# Class balance of the training split.
from collections import Counter
Counter(a[train_ids].obs['cell_type'])

Counter({'CD8 T cell': 500,
         'Endothelial': 500,
         'Fibroblast': 500,
         'Malignant': 500,
         'NK': 500,
         'Monocyte': 500,
         'Treg': 500,
         'CD4 T cell': 500,
         'B cell': 500,
         'Plasma': 500,
         'Mast': 500,
         'Dendritic': 500,
         'Erythrocyte': 500})

In [70]:
# Class balance of the validation split.
from collections import Counter
Counter(a[val_ids].obs['cell_type'])

Counter({'CD8 T cell': 11160,
         'Endothelial': 5702,
         'Fibroblast': 14273,
         'Malignant': 27389,
         'NK': 2504,
         'Monocyte': 9335,
         'Treg': 3938,
         'CD4 T cell': 10693,
         'B cell': 2511,
         'Plasma': 4414,
         'Mast': 380,
         'Erythrocyte': 205,
         'Dendritic': 253})

In [71]:
# Class balance of the re-sampled validation subset v2.
Counter(a[v2].obs['cell_type'])

Counter({'CD8 T cell': 500,
         'Endothelial': 500,
         'Fibroblast': 500,
         'Malignant': 500,
         'NK': 500,
         'Monocyte': 500,
         'Treg': 500,
         'CD4 T cell': 500,
         'B cell': 500,
         'Plasma': 500,
         'Mast': 500,
         'Erythrocyte': 500,
         'Dendritic': 500})

In [7]:
# Read the pre-split training and validation AnnData files, then show the
# shape of each.
train = sc.read_h5ad('/data/pollock/benchmarking/pollock_datasets/scRNAseq/brca_train.h5ad')
val = sc.read_h5ad('/data/pollock/benchmarking/pollock_datasets/scRNAseq/brca_val.h5ad')
train.shape, val.shape

((6105, 27131), (5748, 27131))

In [8]:
# Settings handed to train_and_save_model and reused when reloading the model.
args = dict(
    use_cuda=True,
    epochs=2,
    cell_type_key='cell_type',
    module_filepath='/data/pollock/modules/sandbox/brca_v1',
)

In [9]:
# Fit on `train`, validate on `val`. The recorded log shows the fitted model
# being written to args['module_filepath'].
train_and_save_model(train, val, args)

2022-01-12 09:46:36,498 beginning training
2022-01-12 09:46:36,499 creating dataloaders
2022-01-12 09:46:38,211 22285 genes overlap with model after filtering
2022-01-12 09:46:38,213 1268 genes missing from dataset after filtering
2022-01-12 09:46:40,257 creating model


['B cell', 'CD4 T cell', 'CD8 T cell', 'Dendritic', 'Endothelial', 'Erythrocyte', 'Fibroblast', 'Malignant', 'Mast', 'Monocyte', 'NK', 'Plasma', 'Treg']


2022-01-12 09:46:43,584 training dataset size: 6105, validation dataset size: 5748, cell types: ['B cell', 'CD4 T cell', 'CD8 T cell', 'Dendritic', 'Endothelial', 'Erythrocyte', 'Fibroblast', 'Malignant', 'Mast', 'Monocyte', 'NK', 'Plasma', 'Treg']
2022-01-12 09:46:43,585 fitting model


['B cell', 'CD4 T cell', 'CD8 T cell', 'Dendritic', 'Endothelial', 'Erythrocyte', 'Fibroblast', 'Malignant', 'Mast', 'Monocyte', 'NK', 'Plasma', 'Treg']


2022-01-12 09:46:48,437 epoch: 0, train loss: 2.711, val loss: 2.489, zinb loss: 0.234, kl loss: 88.206, clf loss: 2.283, time: 3.46
2022-01-12 09:46:53,054 epoch: 1, train loss: 2.376, val loss: 2.373, zinb loss: 0.219, kl loss: 94.210, clf loss: 2.169, time: 3.23
2022-01-12 09:46:53,055 model fitting finished
2022-01-12 09:46:53,056 saving model to /data/pollock/modules/sandbox/brca_v1


In [10]:
# Reload the model from the path it was saved to.
model = load_model(args['module_filepath'])
# Honour args['use_cuda'] and stay on CPU when no CUDA device is available,
# instead of raising on CPU-only hosts.
if args['use_cuda'] and torch.cuda.is_available():
    model = model.cuda()

In [11]:
# Build a prediction dataloader for `val`, passing the model's gene list.
# NOTE(review): `dl` is not referenced by any later cell visible here.
dl = get_prediction_dataloader(val, model.genes)

2022-01-10 12:07:44,709 22285 genes overlap with model after filtering
2022-01-10 12:07:44,711 1268 genes missing from dataset after filtering


In [12]:
# Put the model in evaluation mode; the cell output is the module repr.
model.eval()

PollockModel(
  (encoder): Sequential(
    (0): Linear(in_features=23553, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=128, bias=True)
    (3): ReLU()
  )
  (mu): Linear(in_features=128, out_features=64, bias=True)
  (var): Linear(in_features=128, out_features=64, bias=True)
  (decoder): Sequential(
    (0): Linear(in_features=64, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=512, bias=True)
    (3): ReLU()
  )
  (disp_decoder): Sequential(
    (0): Linear(in_features=512, out_features=23553, bias=True)
    (1): DispAct()
  )
  (mean_decoder): Sequential(
    (0): Linear(in_features=512, out_features=23553, bias=True)
    (1): MeanAct()
  )
  (drop_decoder): Sequential(
    (0): Linear(in_features=512, out_features=23553, bias=True)
    (1): Sigmoid()
  )
  (prediction_head): Sequential(
    (0): Linear(in_features=64, out_features=64, bias=True)
    (1): ReLU()
    (2): Linear(in_features=64, ou

In [13]:
# Run prediction on the validation set and display the returned AnnData.
# NOTE(review): this rebinds `a`, which the dataset-loading cell also uses
# for the full dataset -- a distinct name would avoid order-dependent state.
a = predict_adata(model, val)
a

2022-01-10 12:07:52,726 22285 genes overlap with model after filtering
2022-01-10 12:07:52,727 1268 genes missing from dataset after filtering


AnnData object with n_obs × n_vars = 5748 × 23553
    obs: 'cell_type', 'barcode', 'sample', 'n_counts', 'size_factors', 'y_pred', 'prediction_prob', 'predicted_cell_type'
    var: 'mean', 'std'
    uns: 'log1p'
    obsm: 'X_emb', 'X_umap', 'prediction_probs'

In [14]:
# Per-cell metadata, including the true and predicted cell-type columns.
a.obs

Unnamed: 0,cell_type,barcode,sample,n_counts,size_factors,y_pred,prediction_prob,predicted_cell_type
0_AACCAACTCCACTAGA-1,Endothelial,AACCAACTCCACTAGA-1,0,4712.0,1.213182,4,0.979209,Endothelial
0_AACCCAAAGACGAGCT-1,NK,AACCCAAAGACGAGCT-1,0,4732.0,1.218332,10,0.937063,NK
0_AACGTCAGTTTACACG-1,Endothelial,AACGTCAGTTTACACG-1,0,3294.0,0.848095,4,0.826345,Endothelial
0_AACTTCTGTCCAGAAG-1,CD8 T cell,AACTTCTGTCCAGAAG-1,0,3202.0,0.824408,10,0.411445,NK
0_AAGCCATAGTGATTCC-1,Endothelial,AAGCCATAGTGATTCC-1,0,4088.0,1.052523,4,0.935567,Endothelial
...,...,...,...,...,...,...,...,...
29_TTTATGCCAGCCCAGT-1,Fibroblast,TTTATGCCAGCCCAGT-1,29,4480.0,1.153450,6,0.710459,Fibroblast
29_TTTCATGGTCGACTTA-1,Treg,TTTCATGGTCGACTTA-1,29,4580.0,1.179197,2,0.123620,CD8 T cell
29_TTTCCTCGTTAGCGGA-1,Treg,TTTCCTCGTTAGCGGA-1,29,4525.0,1.165036,2,0.138950,CD8 T cell
29_TTTGACTCATGGGATG-1,Plasma,TTTGACTCATGGGATG-1,29,4452.0,1.146241,11,0.875787,Plasma


In [15]:
# Draw 100 cells to explain and, independently, 100 cells to use as the
# baseline. A seeded generator keeps both samples (and every attribution
# derived from them) reproducible across kernel restarts.
rng = np.random.default_rng(0)
inputs = torch.tensor(a.X[rng.choice(a.shape[0], size=100, replace=False)])
baseline = torch.tensor(a.X[rng.choice(a.shape[0], size=100, replace=False)])

In [16]:
class AttributionWrapper(torch.nn.Module):
    """Adapter that exposes only the ``'y'`` entry of a model's output.

    The wrapped model's forward pass is expected to return a mapping;
    this module forwards the input unchanged and returns the value
    stored under the key ``'y'``.
    """

    def __init__(self, model):
        """Store ``model`` as the submodule that will be called."""
        super().__init__()
        self.model = model

    def forward(self, x):
        """Call the wrapped model on ``x`` and return its ``'y'`` output."""
        return self.model(x)['y']

In [19]:
# Move the model to CPU and wrap it so that only its 'y' output is exposed
# to the attribution method.
model = model.to('cpu')
attr_model = AttributionWrapper(model)

In [22]:
# Class labels known to the model.
# NOTE(review): list position appears to be the y_pred index (in the recorded
# a.obs output, 4 pairs with 'Endothelial' and 10 with 'NK') -- confirm.
model.classes

['B cell',
 'CD4 T cell',
 'CD8 T cell',
 'Dendritic',
 'Endothelial',
 'Erythrocyte',
 'Fibroblast',
 'Malignant',
 'Mast',
 'Monocyte',
 'NK',
 'Plasma',
 'Treg']

In [32]:
# Attribute the 'B cell' output to every input gene with DeepLiftShap.
# NOTE: despite the variable name `ig`, this is DeepLiftShap, not integrated
# gradients; the name is kept so cells that reference it keep working.
ig = DeepLiftShap(attr_model)
attributions, delta = ig.attribute(
    inputs,
    baseline,
    # Resolve the output index from the label rather than hard-coding 0, so
    # the target stays correct if the class order ever changes.
    target=model.classes.index('B cell'),
    return_convergence_delta=True,
)

In [23]:
# Shapes of the attribution matrix and of the convergence deltas.
# NOTE(review): the recorded delta length (10000) matches 100 inputs x 100
# baselines, i.e. presumably one delta per (input, baseline) pair -- confirm.
attributions.shape, delta.shape

(torch.Size([100, 23553]), torch.Size([10000]))

In [24]:
# Inspect the convergence deltas returned alongside the attributions.
delta

tensor([ 0.0167, -0.0210,  0.0378,  ...,  0.0022,  0.0225,  0.0152])

In [29]:
# Per-gene attributions for the 'B cell' class, using two independently
# sampled sets of 100 cells.
# NOTE(review): the first sample is presumably the cells to explain and the
# second the baseline -- confirm against explain_adata's signature.
# A seeded generator makes the samples, and so the table, reproducible.
rng = np.random.default_rng(0)
e = explain_adata(model,
                  a[rng.choice(a.shape[0], size=100, replace=False)],
                  a[rng.choice(a.shape[0], size=100, replace=False)],
                  target='B cell')
e

               activations. The hooks and attributes will be removed
            after the attribution is finished


Unnamed: 0,AL627309.1,AL627309.3,AL669831.2,AL669831.5,FAM87B,LINC00115,FAM41C,AL645608.5,AL645608.1,SAMD11,...,AC145212.1,MAFIP,AC011043.1,AL592183.1,AC007325.4,AL354822.1,AC004556.1,AC233755.2,AC233755.1,AC240274.1
13_AGCGTCGAGCACTAAA-1,0.0,0.0,0.0,0.000006,0.0,2.005946e-05,0.0,0.0,-5.799095e-07,0.000004,...,0.0,-2.656010e-06,-2.887683e-07,-2.510423e-06,0.0,0.0,-0.000004,0.0,0.0,1.983775e-07
4_GATCAGTTCTGCCCTA-1,0.0,0.0,0.0,0.000002,0.0,5.630121e-07,0.0,0.0,-3.570036e-06,0.000002,...,0.0,-5.027580e-07,-1.551342e-07,-1.153892e-08,0.0,0.0,-0.000004,0.0,0.0,-1.624646e-04
29_TTACTGTGTAGCTAAA-1,0.0,0.0,0.0,0.000001,0.0,8.344631e-08,0.0,0.0,-2.345417e-05,0.000007,...,0.0,-2.363487e-06,-4.126609e-06,9.380979e-07,0.0,0.0,-0.000005,0.0,0.0,2.665176e-06
29_CTGTGGGAGAGCATCG-1,0.0,0.0,0.0,0.000001,0.0,3.254394e-07,0.0,0.0,-1.814887e-05,0.000005,...,0.0,-2.405773e-06,2.414603e-06,1.251127e-06,0.0,0.0,-0.000007,0.0,0.0,5.628926e-06
28_ACTGATGCACACCTAA-1,0.0,0.0,0.0,0.000001,0.0,3.054914e-08,0.0,0.0,2.477165e-06,0.000001,...,0.0,-1.987730e-06,-3.394456e-07,2.116943e-07,0.0,0.0,-0.000003,0.0,0.0,1.853686e-08
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
9_TAAGCGTAGTAGCAAT-1,0.0,0.0,0.0,0.000002,0.0,1.153698e-06,0.0,0.0,-1.241570e-05,0.000015,...,0.0,-5.503713e-07,-1.283007e-06,-9.445744e-07,0.0,0.0,-0.000007,0.0,0.0,1.443942e-06
28_TTGGTTTCACTGGATT-1,0.0,0.0,0.0,0.000003,0.0,3.271004e-08,0.0,0.0,2.096097e-06,-0.000001,...,0.0,-4.119359e-07,4.645167e-07,3.326673e-07,0.0,0.0,-0.000003,0.0,0.0,2.494134e-06
28_TCAGTCCGTCCAGCCA-1,0.0,0.0,0.0,0.000003,0.0,6.616094e-07,0.0,0.0,-3.374186e-06,0.000003,...,0.0,-2.436832e-07,-5.802395e-07,-2.424132e-07,0.0,0.0,-0.000006,0.0,0.0,2.545161e-06
3_TACGGGCCATTCACCC-1,0.0,0.0,0.0,0.000001,0.0,7.128014e-07,0.0,0.0,-5.477696e-05,0.000019,...,0.0,-1.558956e-06,2.187244e-06,-8.163144e-08,0.0,0.0,0.000003,0.0,0.0,3.904497e-06


In [34]:
# Attributions for every class at once (target='all'), using two
# independently sampled sets of 10 cells; the result is keyed by class label.
# A seeded generator makes the samples reproducible across runs.
rng = np.random.default_rng(0)
r = explain_adata(model,
                  a[rng.choice(a.shape[0], size=10, replace=False)],
                  a[rng.choice(a.shape[0], size=10, replace=False)],
                  target='all')
r.keys()

dict_keys(['B cell', 'CD4 T cell', 'CD8 T cell', 'Dendritic', 'Endothelial', 'Erythrocyte', 'Fibroblast', 'Malignant', 'Mast', 'Monocyte', 'NK', 'Plasma', 'Treg'])

In [36]:
# Shape of the attribution table for a single class.
r['CD8 T cell'].shape

(10, 23553)

In [41]:
# True label counts in the predicted AnnData, most frequent first.
from collections import Counter
Counter(a.obs['cell_type']).most_common()

[('Endothelial', 500),
 ('NK', 500),
 ('CD8 T cell', 500),
 ('CD4 T cell', 500),
 ('Treg', 500),
 ('Fibroblast', 500),
 ('Monocyte', 500),
 ('Malignant', 500),
 ('Plasma', 500),
 ('B cell', 500),
 ('Dendritic', 272),
 ('Mast', 246),
 ('Erythrocyte', 230)]

In [42]:
# Predicted label counts, most frequent first.
# NOTE: in the recorded output only 10 of the model's 13 classes are ever
# predicted; the model was trained for just 2 epochs (args['epochs']).
from collections import Counter
Counter(a.obs['predicted_cell_type']).most_common()

[('NK', 1687),
 ('Monocyte', 735),
 ('Malignant', 694),
 ('CD8 T cell', 631),
 ('Endothelial', 583),
 ('Fibroblast', 536),
 ('Plasma', 404),
 ('Mast', 313),
 ('Dendritic', 160),
 ('Treg', 5)]

In [43]:
# Number of classes the model knows about.
len(model.classes)

13

In [44]:
# Explain predictions over `a`, grouped by the 'cell_type' label, against a
# baseline of 10 sampled cells.
# NOTE(review): n_sample=10 is presumably the number of cells explained per
# label -- confirm against explain_predictions.
# A seeded generator makes the baseline sample reproducible across runs.
rng = np.random.default_rng(0)
df = explain_predictions(model,
                         a,
                         a[rng.choice(a.shape[0], size=10, replace=False)],
                         label_key='cell_type', n_sample=10)

               activations. The hooks and attributes will be removed
            after the attribution is finished


In [45]:
# Display the attribution table (one row per cell, one column per gene, plus
# a trailing 'cell_type' column).
df

Unnamed: 0,AL627309.1,AL627309.3,AL669831.2,AL669831.5,FAM87B,LINC00115,FAM41C,AL645608.5,AL645608.1,SAMD11,...,MAFIP,AC011043.1,AL592183.1,AC007325.4,AL354822.1,AC004556.1,AC233755.2,AC233755.1,AC240274.1,cell_type
10_ACTGATGGTCTTGCTC-1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,-2.009493e-06,0.0,0.0,0.0,0.0,0.0,0.0,-0.000006,B cell
6_CTTGAGAAGGTTTGAA-1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,-1.058356e-05,0.0,0.0,0.0,0.0,0.0,0.0,0.000043,B cell
9_AGCATCACATAAGCAA-1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,5.068495e-06,0.0,0.0,0.0,0.0,0.0,0.0,-0.000017,B cell
10_AATGACCCACTTGGCG-1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,-1.619562e-05,0.0,0.0,0.0,0.0,0.0,0.0,0.000008,B cell
9_GTTCCGTTCTGGCCTT-1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,-7.203570e-06,0.0,0.0,0.0,0.0,0.0,0.0,0.000030,B cell
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
28_AACCACAAGCACCAGA-1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,7.888536e-06,0.0,0.0,0.0,0.0,0.0,0.0,-0.000029,Treg
7_TTTCATGTCAGTCACA-1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,8.345126e-06,0.0,0.0,0.0,0.0,0.0,0.0,-0.000013,Treg
28_GGCTTGGTCTTAATCC-1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,4.564620e-06,0.0,0.0,0.0,0.0,0.0,0.0,-0.000014,Treg
29_CATGGATTCCTACGGG-1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,1.609338e-07,0.0,0.0,0.0,0.0,0.0,0.0,-0.000028,Treg
