In [1]:
from pca_model import PCAModel
from gcn_model import GCNModel
import utilities
import pandas as pd
import numpy as np
from sklearn.decomposition import PCA
import torch
from captum.attr import IntegratedGradients, DeepLift, DeepLiftShap, FeaturePermutation

In [2]:
data_folder = "simulations/splat_0.7_de_rq/"
data_path = data_folder + "query_counts.csv"
marker_path = data_folder + "markers.txt"
# read in dataset
counts = pd.read_csv(data_path, index_col=0)

In [3]:
X,_,_,_ = utilities.preprocess(np.array(counts), scale=False, run_pca=False)
pca = PCA(n_components=500, random_state=8)
new_data = pca.fit(X)
pca_mod = PCAModel(pca.components_, pca.mean_)
pca_mod(X)

tensor([[-4.1202,  4.9487,  0.0848,  ..., -0.4235,  0.8613, -1.8527],
        [-0.2812, -1.5428,  3.2285,  ..., -0.2027, -0.5386, -2.4170],
        [-2.1465, -0.3854, -1.0680,  ..., -0.8622,  1.8184,  0.8735],
        ...,
        [ 6.4777, -1.0975,  0.5996,  ..., -0.1975,  0.5876,  0.2261],
        [-6.2502, -2.2448,  2.3801,  ...,  0.5046, -0.0493,  0.1208],
        [-5.9382,  4.7989, -0.2715,  ..., -0.6843, -0.5246,  1.3884]],
       grad_fn=<MmBackward0>)

In [4]:
X, keep_cells, keep_genes, rpca = utilities.preprocess(np.array(counts), scale=False, comps=500)
X

array([[-4.12022   ,  4.948713  ,  0.08479678, ..., -0.42349264,
         0.861268  , -1.8526838 ],
       [-0.28116703, -1.5429112 ,  3.2285054 , ..., -0.2027156 ,
        -0.5385919 , -2.4170127 ],
       [-2.1465366 , -0.3853437 , -1.0680304 , ..., -0.8622164 ,
         1.8183985 ,  0.87353873],
       ...,
       [ 6.477684  , -1.0975296 ,  0.5995855 , ..., -0.19748206,
         0.587635  ,  0.22611101],
       [-6.2502413 , -2.2447553 ,  2.380058  , ...,  0.50463355,
        -0.04931175,  0.12082425],
       [-5.938177  ,  4.798928  , -0.2715307 , ..., -0.6843324 ,
        -0.5246253 ,  1.3883908 ]], dtype=float32)

In [5]:
gene_names = counts.columns.to_numpy()[keep_genes]

In [6]:
all_labels = pd.read_csv(data_folder + "preds.csv", index_col=0)
all_labels = all_labels.loc[keep_cells,:]

_,marker_names = utilities.read_marker_file(marker_path)

all_labels_factored = utilities.factorize_df(all_labels, marker_names)
encoded_labels = utilities.encode_predictions(all_labels_factored)

meta_path = data_folder + "query_meta.csv"
metadata = pd.read_csv(meta_path, index_col=0)
real_y = pd.factorize(metadata['Group'], sort=True)[0]
real_y = real_y[keep_cells]

confident_labels = utilities.get_consensus_labels(encoded_labels, necessary_vote = 3)

train_nodes = np.where(confident_labels != -1)[0]
test_nodes = np.where(confident_labels == -1)[0]

dataset  = torch.utils.data.TensorDataset(torch.tensor(X), torch.tensor(confident_labels))
dataloader = torch.utils.data.DataLoader(dataset, batch_size=50, shuffle=True)

test_dataset  = torch.utils.data.TensorDataset(torch.tensor(X), torch.tensor(real_y))
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=50, shuffle=False)

In [7]:
m = GCNModel("configs/2_40.txt", 2, target_types=4)
m.train(dataloader, 150)

Loss in epoch 0 = 26.021423
Loss in epoch 10 = 0.069141
Loss in epoch 20 = 0.014639
Loss in epoch 30 = 0.008954
Loss in epoch 40 = 0.004932
Loss in epoch 50 = 0.002885
Loss in epoch 60 = 0.001884
Loss in epoch 70 = 0.001322
Loss in epoch 80 = 0.001180
Loss in epoch 90 = 0.001211
Loss in epoch 100 = 0.000603
Loss in epoch 110 = 0.000499
Loss in epoch 120 = 0.000556
Loss in epoch 130 = 0.000350
Loss in epoch 140 = 0.000364


In [8]:
preds,_ = m.predict(test_dataloader)
final_preds = preds.max(dim=1)[1]
utilities.validation_metrics(torch.tensor(real_y), final_preds, train_nodes, test_nodes)

(0.9579579830169678,
 array([[212,  10,   3,   7],
        [  0, 265,   2,   1],
        [  1,   3, 227,   3],
        [  1,   4,   7, 253]]),
 0.9722838401794434,
 array([[172,   7,   2,   1],
        [  0, 264,   2,   1],
        [  0,   2, 206,   1],
        [  0,   3,   6, 235]]),
 0.8247422575950623,
 array([[40,  3,  1,  6],
        [ 0,  1,  0,  0],
        [ 1,  1, 21,  2],
        [ 1,  1,  1, 18]]))

In [9]:
old_int_df = utilities.run_interpretation(m, X, rpca, final_preds, gene_names, 50)

               activations. The hooks and attributes will be removed
            after the attribution is finished


In [19]:
X,_,_,_ = utilities.preprocess(np.array(counts), scale=False, run_pca=False)
pca = PCA(n_components=500, random_state=8)
new_data = pca.fit(X)
pca_mod = PCAModel(pca.components_, pca.mean_)
seq = torch.nn.Sequential(pca_mod, m)

In [11]:
test_dataset  = torch.utils.data.TensorDataset(torch.tensor(X), torch.tensor(final_preds))
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=50, shuffle=False)

  test_dataset  = torch.utils.data.TensorDataset(torch.tensor(X), torch.tensor(final_preds))


In [12]:
preds[4]

tensor([1.0000e+00, 2.0002e-07, 9.4203e-09, 3.0960e-07],
       grad_fn=<SelectBackward0>)

In [13]:
for X, y in test_dataloader:
    temp_preds = seq(X)
    softmax = torch.nn.Softmax(dim=1)
    temp_preds = softmax(temp_preds)
    print(temp_preds)
    print(temp_preds.max(dim=1)[1])
    break

tensor([[8.9259e-08, 1.0000e+00, 2.9075e-06, 1.2897e-06],
        [2.0151e-07, 1.0000e+00, 4.2631e-08, 4.4288e-06],
        [3.6708e-01, 1.3469e-01, 1.2522e-03, 4.9698e-01],
        [1.6632e-06, 1.3857e-09, 1.4777e-06, 1.0000e+00],
        [1.0000e+00, 2.0002e-07, 9.4204e-09, 3.0961e-07],
        [7.5070e-08, 2.0073e-07, 7.4459e-09, 1.0000e+00],
        [1.2477e-06, 4.2179e-07, 2.3893e-07, 1.0000e+00],
        [1.6091e-04, 9.0509e-04, 7.2134e-05, 9.9886e-01],
        [2.4107e-06, 9.9999e-01, 7.7663e-07, 3.5167e-06],
        [1.0000e+00, 3.5011e-08, 5.9320e-07, 1.9135e-08],
        [1.4316e-06, 1.0137e-08, 1.0000e+00, 3.8611e-08],
        [4.3361e-08, 1.8669e-06, 1.8195e-06, 1.0000e+00],
        [5.0158e-07, 6.1181e-06, 1.1746e-07, 9.9999e-01],
        [9.9997e-01, 6.0984e-06, 1.8858e-05, 5.8216e-06],
        [2.0476e-06, 3.5836e-06, 2.2832e-07, 9.9999e-01],
        [4.2486e-02, 3.3917e-04, 1.7764e-05, 9.5716e-01],
        [1.0423e-06, 2.7161e-07, 4.3051e-08, 1.0000e+00],
        [9.999

In [14]:
dl = DeepLift(seq)
attributions = np.zeros((X.shape[0], X.shape[1]))
temp_atts = None
for data,preds in test_dataloader:
    #baseline = torch.FloatTensor(np.full(data.shape, X.min()))
    baseline = torch.FloatTensor(np.zeros(data.shape))
    #baseline = torch.FloatTensor(np.full(data.shape, X.max()))
    temp = dl.attribute(data.to(m.device), baseline.to(m.device), target=preds.to(m.device), return_convergence_delta=True)[0]
    #temp = dl.attribute(data.to(model.device), target=pred_name).cpu().detach()
    if temp_atts == None: temp_atts = temp
    else:
        temp_atts = torch.cat((temp_atts, temp), 0)
    
attributions = temp_atts

               activations. The hooks and attributes will be removed
            after the attribution is finished


In [15]:
int_df = pd.DataFrame(attributions.detach().numpy())
int_df

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,16186,16187,16188,16189,16190,16191,16192,16193,16194,16195
0,-0.00000,-0.0,-0.016126,-0.089052,0.0,0.000000,-0.0,-0.000000,0.0,-0.0,...,-0.0,0.0,0.0,0.0,0.000000,-0.0,0.011252,0.0,-0.0,0.0
1,-0.00000,-0.0,-0.000000,-0.000000,0.0,0.000000,-0.0,-0.000000,0.0,-0.0,...,-0.0,0.0,0.0,0.0,0.000000,-0.0,0.000000,0.0,-0.0,0.0
2,0.00000,-0.0,0.000000,0.000000,-0.0,-0.000000,-0.0,0.000000,0.0,0.0,...,0.0,-0.0,-0.0,0.0,-0.000000,0.0,-0.000000,-0.0,0.0,-0.0
3,0.00000,-0.0,-0.000000,0.000000,-0.0,-0.000000,-0.0,0.111724,0.0,0.0,...,0.0,0.0,-0.0,0.0,0.030965,0.0,0.000000,-0.0,0.0,-0.0
4,-0.00000,0.0,0.000000,-0.000000,0.0,0.000000,0.0,-0.000000,-0.0,0.0,...,-0.0,0.0,-0.0,-0.0,0.000000,-0.0,0.000000,0.0,-0.0,-0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
994,-0.00000,-0.0,-0.000000,-0.000000,0.0,0.000000,-0.0,-0.000000,0.0,-0.0,...,0.0,0.0,0.0,0.0,0.000000,-0.0,-0.000000,-0.0,-0.0,0.0
995,0.00000,0.0,0.000000,-0.000000,0.0,0.000000,0.0,-0.000000,-0.0,0.0,...,-0.0,0.0,-0.0,-0.0,0.000000,-0.0,0.000000,0.0,-0.0,-0.0
996,-0.00812,0.0,0.000000,-0.000000,0.0,0.000000,0.0,-0.000000,0.0,-0.0,...,-0.0,-0.0,0.0,0.0,-0.000000,-0.0,0.000000,0.0,-0.0,0.0
997,-0.00000,-0.0,0.000000,-0.040840,0.0,0.000000,-0.0,0.019261,0.0,0.0,...,0.0,0.0,-0.0,-0.0,-0.000000,-0.0,-0.000000,-0.0,0.0,-0.0


In [16]:
zero_class = np.where(final_preds==0)[0]
avg_zero = torch.mean(attributions[zero_class,:], 0)

In [17]:
pd.Series(data=avg_zero.detach().numpy(), index=gene_names).sort_values(ascending=False)

Gene3437     1.781654
Gene15390    1.502186
Gene32664    1.406374
Gene13270    1.229324
Gene30144    1.099090
               ...   
Gene3116    -0.443790
Gene3915    -0.464828
Gene26241   -0.566932
Gene29925   -0.801099
Gene22733   -0.909356
Length: 16196, dtype: float32

In [20]:
int_df = utilities.run_interpretation_new(seq, X, final_preds, gene_names, 50, m.device)

               activations. The hooks and attributes will be removed
            after the attribution is finished


In [22]:
int_df.sort_values(by=0,ascending=False)

Unnamed: 0,0,1,2,3
Gene3437,tensor(1.7817),tensor(-0.2752),tensor(-0.3625),tensor(-0.4078)
Gene15390,tensor(1.5022),tensor(-0.3361),tensor(-0.2582),tensor(-0.4191)
Gene32664,tensor(1.4064),tensor(-0.1714),tensor(-0.1545),tensor(-0.2735)
Gene13270,tensor(1.2293),tensor(-0.1493),tensor(-0.4300),tensor(-0.1779)
Gene30144,tensor(1.0991),tensor(-0.6028),tensor(-0.2078),tensor(0.0219)
...,...,...,...,...
Gene3116,tensor(-0.4438),tensor(2.4696),tensor(-0.1610),tensor(-0.6108)
Gene3915,tensor(-0.4648),tensor(0.8291),tensor(0.7450),tensor(-0.4540)
Gene26241,tensor(-0.5669),tensor(-0.3661),tensor(-0.2608),tensor(1.1565)
Gene29925,tensor(-0.8011),tensor(0.5427),tensor(0.2662),tensor(0.5778)


In [16]:
old_int_df.sort_values(by=0, ascending=False)

Unnamed: 0,0,1,2,3
Gene33251,4.290428,-1.682135,-1.897375,-1.908934
Gene32664,2.917032,-1.377862,-1.515536,-1.983209
Gene3437,2.906509,-0.978128,-1.546086,-1.582908
Gene13270,2.631295,-0.538067,-2.414461,-1.117767
Gene15390,2.622656,-0.850041,-0.923804,-1.285519
...,...,...,...,...
Gene26241,-1.991125,-1.709032,-1.298382,3.995141
Gene20573,-2.011701,1.140697,1.194331,0.440151
Gene3915,-2.050619,3.413482,3.420287,-2.445457
Gene22733,-2.073040,-2.156051,-1.995495,4.689560
