In [62]:
import numpy as np
import pandas as pd

import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import os,csv,math,sys, joblib
import networkx as nx

import matplotlib.pyplot as plt
%matplotlib inline
import collections
from joblib import Parallel, delayed
import itertools
from sklearn.neural_network import MLPRegressor, MLPClassifier
import random
import copy
from sklearn import preprocessing
from sklearn.preprocessing import MinMaxScaler
import sklearn.model_selection, sklearn.preprocessing
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LogisticRegression
from sklearn import linear_model
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
#import pydot
from similaritymeasures import frechet_dist
from captum.attr import LayerConductance, LayerActivation, LayerIntegratedGradients
from captum.attr import IntegratedGradients, DeepLift, GradientShap, NoiseTunnel, FeatureAblation
import json
import tqdm
import matplotlib
seed = 99 # To reproduce the results
np.random.seed(seed)
torch.manual_seed(seed)
random.seed(seed)

In [63]:
data = pd.read_csv("sachs_10000.csv")
data.drop(data.columns[[0]], axis=1, inplace=True)
cols = data.columns
for i in cols:
    data[i] = data[i].map({'LOW': 0, 'AVG': 1, 'HIGH': 2})
mms = MinMaxScaler()
data[['PKC', 'PKA', 'Raf', 'Mek', 'Erk', 'Jnk', 'P38']] = mms.fit_transform(data[['PKC', 'PKA', 'Raf', 'Mek', 'Erk', 'Jnk', 'P38']])
data = data[['PKC', 'PKA', 'Raf', 'Mek', 'Erk', 'Jnk', 'P38', 'Akt']]

In [3]:
data.head()

Unnamed: 0,PKC,PKA,Raf,Mek,Erk,Jnk,P38,Akt
0,0.0,0.5,0.5,0.5,1.0,0.0,0.0,1
1,0.0,0.5,0.5,0.0,0.0,0.5,0.0,0
2,0.5,0.5,0.5,0.5,1.0,0.5,0.5,1
3,0.5,0.5,0.5,0.0,0.0,0.5,0.0,1
4,0.0,0.5,1.0,0.0,1.0,0.5,0.0,1


In [4]:
pkc_interventions = np.linspace(min(data['PKC']), max(data['PKC']), 3)
reg = MLPClassifier().fit(data['PKC'].values.reshape(-1,1), data['Akt'].values)
gt_pkc = reg.predict(pkc_interventions.reshape(-1,1))


pka_interventions = np.linspace(min(data['PKA']), max(data['PKA']), 3)
reg = MLPClassifier().fit(data[['PKA','PKC']].values, data['Akt'].values)
gt_pka = []
for alpha in np.linspace(min(data['PKA']), max(data['PKA']), 3):
    df1 = pd.DataFrame.copy(data[['PKA','PKC']])
    df1['PKA'] = alpha
    gt_pka.append(np.mean(reg.predict(df1.values)))
    

raf_interventions = np.linspace(min(data['Raf']), max(data['Raf']), 3)
reg = MLPClassifier().fit(data[['Raf','PKC', 'PKA']].values, data['Akt'].values)
gt_raf = []
for alpha in np.linspace(min(data['Raf']), max(data['Raf']), 3):
    df1 = pd.DataFrame.copy(data[['Raf','PKC', 'PKA']])
    df1['Raf'] = alpha
    gt_raf.append(np.mean(reg.predict(df1.values)))
    

reg = MLPClassifier().fit(data[['Mek', 'PKA']].values, data['Akt'].values)
gt_mek = []
for alpha in np.linspace(min(data['Mek']), max(data['Mek']), 3):
    df1 = pd.DataFrame.copy(data[['Mek', 'PKA']])
    df1['Mek'] = alpha
    gt_mek.append(np.mean(reg.predict(df1.values)))
    

reg = MLPClassifier().fit(data[['Erk', 'PKA']].values, data['Akt'].values)
gt_erk = []
for alpha in np.linspace(min(data['Erk']), max(data['Erk']), 3):
    df1 = pd.DataFrame.copy(data[['Erk', 'PKA']])
    df1['Erk'] = alpha
    gt_erk.append(np.mean(reg.predict(df1.values)))
    
gt_jnk = np.zeros_like(gt_erk)    
gt_p38 = np.zeros_like(gt_erk)
    
aces_gt=[]
aces_gt.append(gt_pkc-np.mean(gt_pkc))
aces_gt.append(gt_pka-np.mean(gt_pka))
aces_gt.append(gt_raf-np.mean(gt_raf))
aces_gt.append(gt_mek-np.mean(gt_mek))
aces_gt.append(gt_erk-np.mean(gt_erk))
aces_gt.append(gt_jnk-np.mean(gt_jnk))
aces_gt.append(gt_p38-np.mean(gt_p38))
np.save('./aces/aces_gt.npy',aces_gt,allow_pickle=True)

In [5]:
aces_gt

[array([0., 0., 0.]),
 array([ 0.5624, -0.2812, -0.2812]),
 array([-0.2254,  0.0988,  0.1266]),
 array([-0.1915,  0.    ,  0.1915]),
 array([-0.461 , -0.2695,  0.7305]),
 array([0., 0., 0.]),
 array([0., 0., 0.])]

In [64]:
def multi_acc(y_pred, y_test):
    y_pred_softmax = torch.log_softmax(y_pred, dim = 1)
    _, y_pred_tags = torch.max(y_pred_softmax, dim = 1)
    correct_pred = (y_pred_tags == y_test).float()
    acc = correct_pred.sum() / len(correct_pred)
    acc = torch.round(acc * 100)
    return acc

In [65]:
ensemble_size=5 # to get mean and std values

values = list(data.columns.values)
y = data[values[-1:]]
y = np.array(y, dtype='long')
X = data[values[:-1]]
X = np.array(X)

indices = np.random.choice(len(X), len(X), replace=False)
X_values = X[indices]
y_values = y[indices]

# Creating a Train and a Test Dataset
test_size = 1000
val_size = 1000

X_test = X_values[-test_size:]
X_trainval = X_values[:-test_size]
X_val = X_trainval[-val_size:]
X_train = X_trainval[:-val_size]

y_test = y_values[-test_size:]
y_trainval = y_values[:-test_size]
y_val = y_trainval[-val_size:]
y_train = y_trainval[:-val_size]

def rmse(predictions, targets):
    return np.sqrt(((predictions - targets) ** 2).mean())

# Interval / Epochs
interval = 5
epoch = 50
batch_size=64

In [66]:
class samp_network(nn.Module):
    def __init__(self, input_size=1):
        super().__init__()
        self.input_size=input_size
        self.fc1 = nn.Linear(self.input_size, 2)
        self.fc2 = nn.Linear(2, 2)
        self.fc3 = nn.Linear(2, 1)

    def forward(self,x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [67]:
class Model(nn.Module):
    def __init__(self, feature_dim, batch_size=64, device='cpu',sample_size=50000):
        super(Model, self).__init__()
        self.batchsize=batch_size

        self.causal_link_pkc_pka = samp_network()
        self.causal_link_pkc_pka_raf = samp_network(input_size=2)
        self.causal_link_pkc_pka_raf_mek = samp_network(input_size=3)
        self.causal_link_mek_pka_erk = samp_network(input_size=2)
        self.causal_link_pkc_pka_jnk = samp_network(input_size=2)
        self.causal_link_pkc_pka_p38 = samp_network(input_size=2)

        self.batchsize=batch_size
        self.first_layer = nn.Linear(7,6)
        self.second_layer = nn.Linear(6,5)
        self.third_layer = nn.Linear(5,5)
        self.fourth_layer = nn.Linear(5,4)
        self.regression_layer = nn.Linear(4, 3)
        self.sample_size = sample_size

    def forward(self, inp, phase='freeze', inde=0, alpha=0):
        if phase=='freeze':
            x = F.relu(self.first_layer(inp))
            x = F.relu(self.second_layer(x))
            x = F.relu(self.third_layer(x))
            x = F.relu(self.fourth_layer(x))
            prediction = self.regression_layer(x)
            return prediction

        elif phase=='train_dag':

            batch_size = inp.shape[0]
            pkc_sample = inp[:,0].reshape(batch_size, -1)

            pka_sample = self.causal_link_pkc_pka(pkc_sample)
            raf_sample = self.causal_link_pkc_pka_raf(torch.cat((pkc_sample, pka_sample), dim=1))
            mek_sample = self.causal_link_pkc_pka_raf_mek(torch.cat((pkc_sample, pka_sample, raf_sample), dim=1))
            erk_sample = self.causal_link_mek_pka_erk(torch.cat((mek_sample, pka_sample), dim=1))
            jnk_sample = self.causal_link_pkc_pka_jnk(torch.cat((pkc_sample, pka_sample), dim=1))
            p38_sample = self.causal_link_pkc_pka_p38(torch.cat((pkc_sample, pka_sample), dim=1))
            inp = torch.cat((pkc_sample, pka_sample, raf_sample, mek_sample, erk_sample, jnk_sample, p38_sample),dim=1)

            x = F.relu(self.first_layer(inp))
            x = F.relu(self.second_layer(x))
            x = F.relu(self.third_layer(x))
            x = F.relu(self.fourth_layer(x))
            prediction = self.regression_layer(x)
            
            return prediction, pka_sample, raf_sample, mek_sample, erk_sample, jnk_sample, p38_sample

        elif phase=='sample':
            if inde == 0:
                pkc_sample = torch.tensor([alpha]*self.sample_size, dtype=torch.float).reshape(self.sample_size,-1)
                pka_sample = self.causal_link_pkc_pka(torch.tensor(pkc_sample, dtype=torch.float))
                raf_sample = self.causal_link_pkc_pka_raf(torch.cat((torch.tensor(pkc_sample, dtype=torch.float), pka_sample), dim=1))
                mek_sample = self.causal_link_pkc_pka_raf_mek(torch.cat((torch.tensor(pkc_sample, dtype=torch.float), torch.tensor(pka_sample, dtype=torch.float), torch.tensor(raf_sample, dtype=torch.float)), dim=1))
                erk_sample = self.causal_link_mek_pka_erk(torch.cat((torch.tensor(mek_sample, dtype=torch.float), torch.tensor(pka_sample, dtype=torch.float)), dim=1))
                jnk_sample = self.causal_link_pkc_pka_jnk(torch.cat((torch.tensor(pkc_sample, dtype=torch.float), torch.tensor(pka_sample, dtype=torch.float)), dim=1))
                p38_sample = self.causal_link_pkc_pka_p38(torch.cat((torch.tensor(pkc_sample, dtype=torch.float), torch.tensor(pka_sample, dtype=torch.float)), dim=1))
                inp = torch.cat((pkc_sample, pka_sample, raf_sample, mek_sample, erk_sample, jnk_sample, p38_sample),dim=1)
                return inp

            elif inde == 1:
                pkc_sample = torch.tensor(inp[:,0].reshape(self.sample_size,-1), dtype=torch.float)
                pka_sample = torch.tensor([alpha]*self.sample_size, dtype=torch.float).reshape(self.sample_size,-1)
                raf_sample = self.causal_link_pkc_pka_raf(torch.cat((torch.tensor(pkc_sample, dtype=torch.float), pka_sample), dim=1))
                mek_sample = self.causal_link_pkc_pka_raf_mek(torch.cat((torch.tensor(pkc_sample, dtype=torch.float), torch.tensor(pka_sample, dtype=torch.float), torch.tensor(raf_sample, dtype=torch.float)), dim=1))
                erk_sample = self.causal_link_mek_pka_erk(torch.cat((torch.tensor(mek_sample, dtype=torch.float), torch.tensor(pka_sample, dtype=torch.float)), dim=1))
                jnk_sample = self.causal_link_pkc_pka_jnk(torch.cat((torch.tensor(pkc_sample, dtype=torch.float), torch.tensor(pka_sample, dtype=torch.float)), dim=1))
                p38_sample = self.causal_link_pkc_pka_p38(torch.cat((torch.tensor(pkc_sample, dtype=torch.float), torch.tensor(pka_sample, dtype=torch.float)), dim=1))
                inp = torch.cat((pkc_sample, pka_sample, raf_sample, mek_sample, erk_sample, jnk_sample, p38_sample),dim=1)
                return inp

            elif inde == 2:
                pkc_sample = torch.tensor(inp[:,0].reshape(self.sample_size,-1), dtype=torch.float)
                pka_sample = torch.tensor(inp[:,1].reshape(self.sample_size,-1), dtype=torch.float)
                raf_sample = torch.tensor([alpha]*self.sample_size, dtype=torch.float).reshape(self.sample_size,-1)
                mek_sample = self.causal_link_pkc_pka_raf_mek(torch.cat((torch.tensor(pkc_sample, dtype=torch.float), torch.tensor(pka_sample, dtype=torch.float), torch.tensor(raf_sample, dtype=torch.float)), dim=1))
                erk_sample = self.causal_link_mek_pka_erk(torch.cat((torch.tensor(mek_sample, dtype=torch.float), torch.tensor(pka_sample, dtype=torch.float)), dim=1))
                jnk_sample = self.causal_link_pkc_pka_jnk(torch.cat((torch.tensor(pkc_sample, dtype=torch.float), torch.tensor(pka_sample, dtype=torch.float)), dim=1))
                p38_sample = self.causal_link_pkc_pka_p38(torch.cat((torch.tensor(pkc_sample, dtype=torch.float), torch.tensor(pka_sample, dtype=torch.float)), dim=1))
                inp = torch.cat((pkc_sample, pka_sample, raf_sample, mek_sample, erk_sample, jnk_sample, p38_sample),dim=1)
                return inp

            elif inde == 3:
                pkc_sample = torch.tensor(inp[:,0].reshape(self.sample_size,-1), dtype=torch.float)
                pka_sample = torch.tensor(inp[:,1].reshape(self.sample_size,-1), dtype=torch.float)
                raf_sample = torch.tensor(inp[:,2].reshape(self.sample_size,-1), dtype=torch.float)
                mek_sample = torch.tensor([alpha]*self.sample_size, dtype=torch.float).reshape(self.sample_size,-1)
                erk_sample = self.causal_link_mek_pka_erk(torch.cat((torch.tensor(mek_sample, dtype=torch.float), torch.tensor(pka_sample, dtype=torch.float)), dim=1))
                jnk_sample = self.causal_link_pkc_pka_jnk(torch.cat((torch.tensor(pkc_sample, dtype=torch.float), torch.tensor(pka_sample, dtype=torch.float)), dim=1))
                p38_sample = self.causal_link_pkc_pka_p38(torch.cat((torch.tensor(pkc_sample, dtype=torch.float), torch.tensor(pka_sample, dtype=torch.float)), dim=1))
                inp = torch.cat((pkc_sample, pka_sample, raf_sample, mek_sample, erk_sample, jnk_sample, p38_sample),dim=1)
                return inp

            elif inde == 4:
                pkc_sample = torch.tensor(inp[:,0].reshape(self.sample_size,-1), dtype=torch.float)
                pka_sample = torch.tensor(inp[:,1].reshape(self.sample_size,-1), dtype=torch.float)
                raf_sample = torch.tensor(inp[:,2].reshape(self.sample_size,-1), dtype=torch.float)
                mek_sample = torch.tensor(inp[:,3].reshape(self.sample_size,-1), dtype=torch.float)
                erk_sample = torch.tensor([alpha]*self.sample_size, dtype=torch.float).reshape(self.sample_size,-1)
                jnk_sample = self.causal_link_pkc_pka_jnk(torch.cat((torch.tensor(pkc_sample, dtype=torch.float), torch.tensor(pka_sample, dtype=torch.float)), dim=1))
                p38_sample = self.causal_link_pkc_pka_p38(torch.cat((torch.tensor(pkc_sample, dtype=torch.float), torch.tensor(pka_sample, dtype=torch.float)), dim=1))
                inp = torch.cat((pkc_sample, pka_sample, raf_sample, mek_sample, erk_sample, jnk_sample, p38_sample),dim=1)
                return inp


            elif inde == 5:
                pkc_sample = torch.tensor(inp[:,0].reshape(self.sample_size,-1), dtype=torch.float)
                pka_sample = torch.tensor(inp[:,1].reshape(self.sample_size,-1), dtype=torch.float)
                raf_sample = torch.tensor(inp[:,2].reshape(self.sample_size,-1), dtype=torch.float)
                mek_sample = torch.tensor(inp[:,3].reshape(self.sample_size,-1), dtype=torch.float)
                erk_sample = torch.tensor(inp[:,4].reshape(self.sample_size,-1), dtype=torch.float)
                jnk_sample = torch.tensor([alpha]*self.sample_size, dtype=torch.float).reshape(self.sample_size,-1)
                p38_sample = self.causal_link_pkc_pka_p38(torch.cat((torch.tensor(pkc_sample, dtype=torch.float), torch.tensor(pka_sample, dtype=torch.float)), dim=1))
                inp = torch.cat((pkc_sample, pka_sample, raf_sample, mek_sample, erk_sample, jnk_sample, p38_sample),dim=1)
                return inp

            elif inde == 6:
                pkc_sample = torch.tensor(inp[:,0].reshape(self.sample_size,-1), dtype=torch.float)
                pka_sample = torch.tensor(inp[:,1].reshape(self.sample_size,-1), dtype=torch.float)
                raf_sample = torch.tensor(inp[:,2].reshape(self.sample_size,-1), dtype=torch.float)
                mek_sample = torch.tensor(inp[:,3].reshape(self.sample_size,-1), dtype=torch.float)
                erk_sample = torch.tensor(inp[:,4].reshape(self.sample_size,-1), dtype=torch.float)
                jnk_sample = torch.tensor(inp[:,5].reshape(self.sample_size,-1))
                p38_sample = torch.tensor([alpha]*self.sample_size, dtype=torch.float).reshape(self.sample_size,-1)
                inp = torch.cat((pkc_sample, pka_sample, raf_sample, mek_sample, erk_sample, jnk_sample, p38_sample),dim=1)
                return inp

In [68]:
class DataSet(Dataset):
    """
    Custom dataset to load data from csv file
    """
    def __init__(self, dataframe1, dataframe2):
        self.data_points = dataframe1
        self.targets = dataframe2

    def __len__(self):
        return len(self.data_points)

    def __getitem__(self, idx):
        input_point = torch.tensor(np.array(self.data_points[idx]), dtype=torch.float)
        target_point = torch.tensor(np.array(self.targets[idx]), dtype=torch.long)
        return input_point, target_point

# ERM

In [11]:
for ensemble in range(ensemble_size):
    loss_func = nn.CrossEntropyLoss()
    erm_model = Model(7,sample_size=len(data))
    optimizer = optim.Adam([{'params': erm_model.parameters()}], lr = 0.001, weight_decay=1e-4)

    for ep in range(0,epoch): 
        trainval = DataSet(X_train,y_train)
        train_loader = DataLoader(trainval, batch_size=batch_size)
        for input_data, target in train_loader:
            erm_model.zero_grad()
            output = erm_model(input_data)
            loss = loss_func(output,target.squeeze())
            loss.backward()
            optimizer.step()
        if ep%interval == 0:
            print(ep, interval)
            val = DataSet(X_val,y_val)
            val_loader = DataLoader(val, batch_size=1)
            acc_val = 0
            acc_test = 0
            
            for input_data, target in val_loader:
                output = erm_model(input_data)
                acc = multi_acc(output, target.unsqueeze(1))
                acc_val += acc

            print ('validation accuracy:', float(acc_val/len(val_loader)))
            testval = DataSet(X_test,y_test)
            test_loader = DataLoader(testval, batch_size=1)
            
            for input_data, target in test_loader:
                output = erm_model(input_data)
                acc = multi_acc(output, target.unsqueeze(1))
                acc_test += acc
                
            print('test accuracy:', float(acc_test/len(test_loader)))
            print()
    print("************")
    torch.save(erm_model, "models/erm_sachs_"+str(ensemble+1))

0 5
validation accuracy: 60.79999923706055
test accuracy: 59.599998474121094

5 5
validation accuracy: 78.5999984741211
test accuracy: 80.4000015258789

10 5
validation accuracy: 78.5
test accuracy: 80.5

15 5
validation accuracy: 78.69999694824219
test accuracy: 80.5999984741211

20 5
validation accuracy: 78.9000015258789
test accuracy: 81.0999984741211

25 5
validation accuracy: 79.30000305175781
test accuracy: 81.69999694824219

30 5
validation accuracy: 79.30000305175781
test accuracy: 81.69999694824219

35 5
validation accuracy: 79.30000305175781
test accuracy: 81.69999694824219

40 5
validation accuracy: 79.30000305175781
test accuracy: 81.69999694824219

45 5
validation accuracy: 79.30000305175781
test accuracy: 81.69999694824219

************
0 5
validation accuracy: 60.79999923706055
test accuracy: 59.599998474121094

5 5
validation accuracy: 60.79999923706055
test accuracy: 59.599998474121094

10 5
validation accuracy: 73.69999694824219
test accuracy: 73.30000305175781

15 5


# Integrated Gradients

In [12]:

frechet_results = []
rmse_results = []

for ensemble in range(ensemble_size):
    print(ensemble)
    model = torch.load("models/erm_sachs_"+str(ensemble+1))

    data_pkc = data['PKC'].values[:3]
    data_pka = data['PKA'].values[:3]
    data_raf = data['Raf'].values[:3]
    data_mek = data['Mek'].values[:3]
    data_erk = data['Erk'].values[:3]
    data_jnk = data['Jnk'].values[:3]
    data_p38 = data['P38'].values[:3]

    do_pkc = np.linspace(min(data['PKC']), max(data['PKC']), 3)
    do_pka = np.linspace(min(data['PKA']), max(data['PKA']), 3)
    do_raf = np.linspace(min(data['Raf']), max(data['Raf']), 3)
    do_mek = np.linspace(min(data['Mek']), max(data['Mek']), 3)
    do_erk = np.linspace(min(data['Erk']), max(data['Erk']), 3)
    do_jnk = np.linspace(min(data['Jnk']), max(data['Jnk']), 3)
    do_p38 = np.linspace(min(data['P38']), max(data['P38']), 3)


    test_array_pkc = np.stack((data_pkc, data_pka, data_raf, data_mek, data_erk, data_jnk, data_p38), axis=1)
    test_array_pkc = torch.from_numpy(test_array_pkc).type(torch.FloatTensor)

    test_array_pka = np.stack((data_pkc, data_pka, data_raf, data_mek, data_erk, data_jnk, data_p38), axis=1)
    test_array_pka = torch.from_numpy(test_array_pka).type(torch.FloatTensor)

    test_array_raf = np.stack((data_pkc, data_pka, data_raf, data_mek, data_erk, data_jnk, data_p38), axis=1)
    test_array_raf = torch.from_numpy(test_array_raf).type(torch.FloatTensor)

    test_array_mek = np.stack((data_pkc, data_pka, data_raf, data_mek, data_erk, data_jnk, data_p38), axis=1)
    test_array_mek = torch.from_numpy(test_array_mek).type(torch.FloatTensor)

    
    test_array_erk = np.stack((data_pkc, data_pka, data_raf, data_mek, data_erk, data_jnk, data_p38), axis=1)
    test_array_erk = torch.from_numpy(test_array_erk).type(torch.FloatTensor)

    test_array_jnk = np.stack((data_pkc, data_pka, data_raf, data_mek, data_erk, data_jnk, data_p38), axis=1)
    test_array_jnk = torch.from_numpy(test_array_jnk).type(torch.FloatTensor)

    test_array_p38 = np.stack((data_pkc, data_pka, data_raf, data_mek, data_erk, data_jnk, do_p38), axis=1)
    test_array_p38 = torch.from_numpy(test_array_p38).type(torch.FloatTensor)


    ig = IntegratedGradients(model)
    # print(gt_aces[0].shape)

    ig_attr_test_pkc, delta = ig.attribute(test_array_pkc, n_steps=50, return_convergence_delta=True, target=2)
    ig_attr_test_pka, delta = ig.attribute(test_array_pka, n_steps=50, return_convergence_delta=True, target=2)
    ig_attr_test_raf, delta = ig.attribute(test_array_raf, n_steps=50, return_convergence_delta=True, target=2)
    ig_attr_test_mek, delta = ig.attribute(test_array_mek, n_steps=50, return_convergence_delta=True, target=2)
    ig_attr_test_erk, delta = ig.attribute(test_array_erk, n_steps=50, return_convergence_delta=True, target=2)
    ig_attr_test_jnk, delta = ig.attribute(test_array_jnk, n_steps=50, return_convergence_delta=True, target=2)
    ig_attr_test_p38, delta = ig.attribute(test_array_p38, n_steps=50, return_convergence_delta=True, target=2)

    rmse_results.append([rmse(aces_gt[0], np.array(ig_attr_test_pkc[:,0])),
                         rmse(aces_gt[1], np.array(ig_attr_test_pka[:,1])),
                         rmse(aces_gt[2], np.array(ig_attr_test_raf[:,2])),
                         rmse(aces_gt[3], np.array(ig_attr_test_mek[:,3])),
                         rmse(aces_gt[4], np.array(ig_attr_test_erk[:,4])),
                         rmse(aces_gt[5], np.array(ig_attr_test_jnk[:,5])),
                         rmse(aces_gt[6], np.array(ig_attr_test_p38[:,6]))])
    frechet_results.append([frechet_dist(aces_gt[0].reshape(-1, 1), np.array(ig_attr_test_pkc[:,0]).reshape(-1, 1)),
                         frechet_dist(aces_gt[1].reshape(-1, 1), np.array(ig_attr_test_pka[:,1]).reshape(-1, 1)),
                         frechet_dist(aces_gt[2].reshape(-1, 1), np.array(ig_attr_test_raf[:,2]).reshape(-1, 1)),
                         frechet_dist(aces_gt[3].reshape(-1, 1), np.array(ig_attr_test_mek[:,3]).reshape(-1, 1)),
                         frechet_dist(aces_gt[4].reshape(-1, 1), np.array(ig_attr_test_erk[:,4]).reshape(-1, 1)),
                         frechet_dist(aces_gt[5].reshape(-1, 1), np.array(ig_attr_test_jnk[:,5]).reshape(-1, 1)),
                         frechet_dist(aces_gt[6].reshape(-1, 1), np.array(ig_attr_test_p38[:,6]).reshape(-1, 1))])

0
1
2
3
4


In [13]:
rmse_results = np.array(rmse_results)
print("rmse mean: ", np.mean(rmse_results, axis=0))
print("rmse std: ", np.std(rmse_results, axis=0))
print("rmse all features mean: ", np.mean(np.mean(rmse_results, axis=0)))
print("rmse all features std: ", np.mean(np.std(rmse_results, axis=0)))

print("frechet mean: ", np.mean(frechet_results, axis=0))
print("frechet std: ", np.std(frechet_results, axis=0))
print("frechet all features mean: ", np.mean(np.mean(frechet_results, axis=0)))
print("frechet all features std: ", np.mean(np.std(frechet_results, axis=0)))

rmse mean:  [0.08236072 2.29456244 0.15260043 0.20471391 4.33897528 0.08225392
 0.26346759]
rmse std:  [0.072581   1.40182269 0.03694011 0.04234252 3.25117804 0.04907327
 0.18473028]
rmse all features mean:  1.059847755435633
rmse all features std:  0.7198097026594029
frechet mean:  [0.14265295 2.89081229 0.2187971  0.33942284 5.63291074 0.12051722
 0.41775694]
frechet std:  [0.12571399 1.6272485  0.05619806 0.08208235 4.04983256 0.06859059
 0.30959137]
frechet all features mean:  1.3946957246431217
frechet all features std:  0.9027510582110259


# Causal Attributions

In [14]:
n_classes=3
num_feat=7
num_alpha=3

cov = np.cov(X_values, rowvar=False)
mean_vector = np.mean(X_values, axis=0)

aces_ca_total = []

for ensemble in range(ensemble_size):
    ace_ca_total = []
    model = torch.load("models/erm_sachs_"+str(ensemble+1))
    for output_index in range(0,1):
        for t in range(0,num_feat):
            expectation_do_x = []
            inp=copy.deepcopy(mean_vector)
            for x in np.linspace(0, 1, num_alpha):                
                inp[t] = x
                input_torchvar = autograd.Variable(torch.FloatTensor(inp), requires_grad=True)
                output=model(input_torchvar)
                o1=output.data.cpu()
                val=o1.numpy()[output_index]#first term in interventional expectation                                       

                grad_mask_gradient = torch.zeros(n_classes)
                grad_mask_gradient[output_index] = 1.0
                #calculating the hessian
                first_grads = torch.autograd.grad(output.cpu(), input_torchvar.cpu(), grad_outputs=grad_mask_gradient, retain_graph=True, create_graph=True)

                for dimension in range(0,num_feat):#Tr(Hessian*Covariance)
                    if dimension == t:
                        continue
                    temp_cov = copy.deepcopy(cov)
                    temp_cov[dimension][t] = 0.0#row,col in covariance corresponding to the intervened one made 0
                    grad_mask_hessian = torch.zeros(num_feat)
                    grad_mask_hessian[dimension] = 1.0

                    #calculating the hessian
                    hessian = torch.autograd.grad(first_grads, input_torchvar, grad_outputs=grad_mask_hessian, retain_graph=True, create_graph=False)
                    val += np.sum(0.5*hessian[0].data.numpy()*temp_cov[dimension])#adding second term in interventional expectation
                expectation_do_x.append(val)#append interventional expectation for given interventional value
            ace_ca_total.append(np.array(expectation_do_x) - np.mean(np.array(expectation_do_x)))
    aces_ca_total.append(ace_ca_total)
np.save('./aces/sachs_ca_total.npy',aces_ca_total,allow_pickle=True)

In [15]:
rmse_results = []
frechet_results = []

for ensemble in range(ensemble_size):
    rmse_results.append([rmse(aces_gt[0], aces_ca_total[ensemble][0]),
                         rmse(aces_gt[1], aces_ca_total[ensemble][1]),
                         rmse(aces_gt[2], aces_ca_total[ensemble][2]),
                         rmse(aces_gt[3], aces_ca_total[ensemble][3]),
                         rmse(aces_gt[4], aces_ca_total[ensemble][4]),
                         rmse(aces_gt[5], aces_ca_total[ensemble][5]),
                         rmse(aces_gt[6], aces_ca_total[ensemble][6])])
    
    frechet_results.append([frechet_dist(aces_gt[0].reshape(3,1), aces_ca_total[ensemble][0].reshape(3,1)),
                            frechet_dist(aces_gt[1].reshape(3,1), aces_ca_total[ensemble][1].reshape(3,1)),
                            frechet_dist(aces_gt[2].reshape(3,1), aces_ca_total[ensemble][2].reshape(3,1)),
                            frechet_dist(aces_gt[3].reshape(3,1), aces_ca_total[ensemble][3].reshape(3,1)),
                            frechet_dist(aces_gt[4].reshape(3,1), aces_ca_total[ensemble][4].reshape(3,1)),
                            frechet_dist(aces_gt[5].reshape(3,1), aces_ca_total[ensemble][5].reshape(3,1)),
                            frechet_dist(aces_gt[6].reshape(3,1), aces_ca_total[ensemble][6].reshape(3,1))])
    
rmse_results = np.array(rmse_results)
print("rmse mean: ", np.mean(rmse_results, axis=0))
print("rmse std: ", np.std(rmse_results, axis=0))
print("rmse all features mean: ", np.mean(np.mean(rmse_results, axis=0)))
print("rmse all features std: ", np.mean(np.std(rmse_results, axis=0)))

print("frechet mean: ", np.mean(frechet_results, axis=0))
print("frechet std: ", np.std(frechet_results, axis=0))
print("frechet all features mean: ", np.mean(np.mean(frechet_results, axis=0)))
print("frechet all features std: ", np.mean(np.std(frechet_results, axis=0)))

rmse mean:  [0.10176131 2.19258358 0.11545693 0.21651297 2.23200216 0.0742025
 0.09842272]
rmse std:  [0.09715836 0.9016019  0.05735296 0.13859572 0.63006274 0.0440616
 0.06976197]
rmse all features mean:  0.7187060241403819
rmse all features std:  0.27694218134075055
frechet mean:  [0.13223125 2.97188839 0.16011002 0.27970836 3.12497292 0.09462399
 0.12896724]
frechet std:  [0.12704554 1.14069842 0.08258818 0.18027261 0.90048447 0.05460573
 0.09408247]
frechet all features mean:  0.9846431686074393
frechet all features std:  0.36853963237833876


# Causal Shapley Values

In [16]:
# Probability is taken over indices of baseline only
def get_probabiity(unique_count, x_hat, indices_baseline, n):
    if len(indices_baseline) > 0:
        count = 0
        for i in unique_count:
            check = True
            key = np.asarray(i)
            for j in indices_baseline:
                check = check and key[j] == x_hat[j]
            if check:
                count += unique_count[i]
        return count / n
    else:
        return 1


def conditional_prob(unique_count, x_hat, indices, indices_baseline, n):
    numerator_indices = indices + indices_baseline
    numerator = get_probabiity(unique_count, x_hat, numerator_indices, n)
    denominator = get_probabiity(unique_count, x_hat, indices, n)
    try:
        kk = numerator / denominator
    except ZeroDivisionError:
        denominator = 1e-7
        # pass
    return numerator / denominator


def causal_prob(unique_count, x_hat, indices, indices_baseline, causal_struc, n):
    p = 1
    for i in indices_baseline:
        intersect_s, intersect_s_hat = [], []
        intersect_s_hat.append(i)
        if len(causal_struc[str(i)]) > 0:
            for index in causal_struc[str(i)]:
                if index in indices or index in indices_baseline:
                    intersect_s.append(index)
            p *= conditional_prob(unique_count, x_hat, intersect_s, intersect_s_hat, n)
        else:
            p *= get_probabiity(unique_count, x_hat, intersect_s_hat, n)
    return p


def get_baseline(X, model):
    fx = 0
    n_features = X.shape[1]
    X = np.reshape(X, (len(X), 1, n_features))
    for i in X:
        fx += torch.softmax(model(torch.tensor(i, dtype=torch.float)), 0)[0]
    return fx / len(X)

In [17]:
# Returns value from using function for different versions
def get_value(version, permutation, X, x, unique_count, causal_struct, model, N, is_classification, xi):
    # intializing returns
    absolute_diff, f1, f2 = 0, 0, 0
    xi_index = permutation.index(xi)
    indices = permutation[:xi_index + 1]
    indices_baseline = permutation[xi_index + 1:]
    x_hat = np.zeros(N)
    x_hat_2 = np.zeros(N)
    len_X = len(X)
    for j in indices:
        x_hat[j] = x[j]
        x_hat_2[j] = x[j]
    if version == '2' or version == '3' or version == '4':
        proba1, proba2 = 0, 0
        baseline_check_1, baseline_check_2 = [], []
        f1, f2 = 0, 0
        indices_baseline_2 = indices_baseline[:]
        for i in unique_count:
            X = np.asarray(i)
            for j in indices_baseline:
                x_hat[j] = X[j]
                x_hat_2[j] = X[j]

            # No repetition
            # Eg if baseline_indices is null, it'll only run once as x_hat will stay the same over each iteration
            if x_hat.tolist() not in baseline_check_1:
                baseline_check_1.append(x_hat.tolist())
                if version == '2':
                    prob_x_hat = get_probabiity(unique_count, x_hat, indices_baseline, len_X)
                elif version == '3':
                    prob_x_hat = conditional_prob(unique_count, x_hat, indices, indices_baseline, len_X)
                else:
                    prob_x_hat = causal_prob(unique_count, x_hat, indices, indices_baseline, causal_struct, len_X)
                proba1 += prob_x_hat
                x_hat = np.reshape(x_hat, (1, N))
                f1 = f1 + (torch.softmax(model(torch.tensor(x_hat, dtype=torch.float)), 0)[0] * prob_x_hat if is_classification else model.predict(
                    x_hat) * prob_x_hat)

            # xi index will be given to baseline for f2
            x_hat_2[xi] = X[xi]
            if xi not in indices_baseline_2:
                indices_baseline_2.append(xi)

            # No repetition
            indices_2 = indices[:]
            indices_2.remove(xi)
            if x_hat_2.tolist() not in baseline_check_2:
                baseline_check_2.append(x_hat_2.tolist())
                if version == '2':
                    prob_x_hat_2 = get_probabiity(unique_count, x_hat_2, indices_baseline_2, len_X)
                elif version == '3':
                    prob_x_hat_2 = conditional_prob(unique_count, x_hat_2, indices_2, indices_baseline_2, len_X)
                else:
                    prob_x_hat_2 = causal_prob(unique_count, x_hat_2, indices_2, indices_baseline_2, causal_struct,
                                               len_X)
                proba2 += prob_x_hat_2
                x_hat_2 = np.reshape(x_hat_2, (1, N))
                f2 = f2 + (torch.softmax(model(torch.tensor(x_hat, dtype=torch.float)), 0)[0] * prob_x_hat_2 if is_classification else model.predict(
                    x_hat_2) * prob_x_hat_2)
            x_hat = np.squeeze(x_hat)
            x_hat_2 = np.squeeze(x_hat_2)
        absolute_diff = abs(f1 - f2)
    elif version == '1':
        f1, f2 = 0, 0
        for i in range(len(X)):
            for j in indices_baseline:
                x_hat[j] = X[i][j]
                x_hat_2[j] = X[i][j]
            x_hat = np.reshape(x_hat, (1, N))
            f1 += torch.softmax(model(torch.tensor(x_hat, dtype=torch.float)), 0)[0] if is_classification else model.predict(x_hat)
            x_hat_2[xi] = X[i][xi]
            x_hat_2 = np.reshape(x_hat_2, (1, N))
            f2 += torch.softmax(model(torch.tensor(x_hat, dtype=torch.float)), 0)[0] if is_classification else model.predict(x_hat_2)
            x_hat = np.squeeze(x_hat)
            x_hat_2 = np.squeeze(x_hat_2)
        absolute_diff = abs(f1 - f2) / len_X
        f1 = f1 / len_X
        f2 = f2 / len_X
    return absolute_diff, f1, f2

In [18]:
def approximate_shapley(version, xi, N, X, x, m, model, unique_count, causal_struct, is_classification,
                        global_shap=False):
    R = list(itertools.permutations(range(N)))
    random.shuffle(R)
    score = 0
    count_negative = 0
    vf1, vf2 = 0, 0
    for i in range(m):
        abs_diff, f1, f2 = get_value(version, list(R[i]), X, x, unique_count, causal_struct, model, N,
                                     is_classification, xi)
        vf1 += f1
        vf2 += f2
        score += abs_diff
        if not global_shap:
            if vf2 > vf1:
                count_negative -= 1
            else:
                count_negative += 1
    if count_negative < 0 and not global_shap:
        score = -1 * score
    return score / m

In [19]:
def shapley(model, version, local_shap=0):
    sigma_phi = 0
    global_shap=True
    causal_struct = None
    try:
        causal_struct = json.load(open('sachs.json', 'rb'))
    except FileNotFoundError:
        pass
    n_features = 7
    unique_count = collections.Counter(map(tuple, X_train[50:75]))
    ##### f(x) with baseline
    f_o = get_baseline(X_train[50:75], model)
    rmse_shapley_values = []
    frechet_shapley_values = []
    shapley_vals = []
    for feature in range(n_features):
        global_shap_score = Parallel(n_jobs=-1)(
            delayed(approximate_shapley)(version, feature, n_features, X_train[50:75], x, math.factorial(n_features), model,
                                         unique_count, causal_struct, True, global_shap) for i, x in
            enumerate(X_train[50:75]))
        
        shapley_vals.append(np.array([i.detach().numpy() for i in global_shap_score]))
    return shapley_vals

In [None]:
shapley_values = []
for ensemble in range(ensemble_size):
    print(ensemble)
    model = torch.load("models/erm_sachs_"+str(ensemble+1))
    s = shapley(model, version='4', local_shap=12)
    shapley_values.append(s)

0
1


In [22]:
rmse_results = []
frechet_results = []
shapley_values = np.array(shapley_values)
for ensemble in range(ensemble_size-1):
    rmses = []
    frechets = []
    for feature in range(7):
        shaps = []
        for inte in [0,0.5,1]:
            indices = X_train[50:75][:, feature] == inte
            shaps.append(np.mean(shapley_values[ensemble][feature][indices,0]))

        rmses.append(rmse(shaps, aces_gt[feature]))
        frechets.append(frechet_dist(np.array(shaps).reshape(-1,1),
        np.mean(aces_gt[feature]).reshape(-1,1)))
    rmse_results.append(rmses)
    frechet_results.append(frechets)
    
print("rmse mean: ", np.mean(rmse_results, axis=0))
print("rmse std: ", np.std(rmse_results, axis=0))
print("rmse all features mean: ", np.mean(np.mean(rmse_results, axis=0)))
print("rmse all features std: ", np.mean(np.std(rmse_results, axis=0)))

print("frechet mean: ", np.mean(frechet_results, axis=0))
print("frechet std: ", np.std(frechet_results, axis=0))
print("frechet all features mean: ", np.mean(np.mean(frechet_results, axis=0)))
print("frechet all features std: ", np.mean(np.std(frechet_results, axis=0)))

rmse mean:  [0.1999237  0.46053444 0.24783358 0.23926059 0.53359662 0.25573431
 0.31379661]
rmse std:  [0.00137769 0.00446622 0.00087216 0.00107753 0.0040468  0.00303377
 0.00321371]
rmse all features mean:  0.3215256935169771
rmse all features std:  0.002583983476831131
frechet mean:  [0.3093913  0.29537311 0.27856761 0.37394828 0.3639802  0.36520084
 0.41422786]
frechet std:  [0.00254364 0.00022662 0.00390349 0.00066014 0.00280709 0.00546883
 0.0025095 ]
frechet all features mean:  0.34295560100248884
frechet all features std:  0.002588472231213391


# CREDO

In [23]:
prior = {0:(lambda ii:ii), 1:(lambda ii:ii), 2:(lambda ii:ii),
        3:(lambda ii:ii), 4:(lambda ii:ii), 5:(lambda ii:0*ii),
        6:(lambda ii:0*ii)}

def get_grad(x, prior):
    a = x.clone().detach().requires_grad_(True)
    for f in prior.keys():
        z = prior[f]
        z = torch.sum(z(a[0][f]), dim=0)
        z.backward()
    return a.grad

def get_grads_to_match(ip, prior):
    return get_grad(ip, prior)

In [24]:
for ensemble in range(ensemble_size):
    loss_func = nn.CrossEntropyLoss()
    credo_model = Model(7,sample_size=len(data))
    optimizer = optim.Adam([{'params': credo_model.parameters()}], lr = 0.001, weight_decay=1e-4)

    for ep in range(0,epoch): 
        trainval = DataSet(X_train,y_train)
        train_loader = DataLoader(trainval, batch_size=batch_size)
        for input_data, target in train_loader:
            credo_model.zero_grad()
            input_data.requires_grad=True
            output = credo_model(input_data)
            
            calc_grads = (autograd.grad(torch.sum(output[0], dim=0), input_data, retain_graph=True, create_graph=True)[0])
            grads_to_match = get_grads_to_match(input_data, prior) 
            hinge_input = torch.abs(grads_to_match - calc_grads)
            loss = loss_func(output,target.squeeze()) + 0.01 * torch.norm(torch.clamp(hinge_input, min=0), p=1)
            
            loss.backward()
            optimizer.step()
        if ep%interval == 0:
            print(ep, interval)
            val = DataSet(X_val,y_val)
            val_loader = DataLoader(val, batch_size=1)
            acc_val = 0
            acc_test = 0
            
            for input_data, target in val_loader:
                output = credo_model(input_data)
                acc = multi_acc(output, target.unsqueeze(1))
                acc_val += acc

            print ('validation accuracy:', float(acc_val/len(val_loader)))
            testval = DataSet(X_test,y_test)
            test_loader = DataLoader(testval, batch_size=1)
            
            for input_data, target in test_loader:
                output = credo_model(input_data)
                acc = multi_acc(output, target.unsqueeze(1))
                acc_test += acc
                
            print('test accuracy:', float(acc_test/len(test_loader)))
            print()
    print("************")
    torch.save(credo_model, "models/credo_sachs_"+str(ensemble+1))

0 5
validation accuracy: 60.79999923706055
test accuracy: 59.599998474121094

5 5
validation accuracy: 78.5
test accuracy: 80.0

10 5
validation accuracy: 78.5
test accuracy: 80.9000015258789

15 5
validation accuracy: 79.19999694824219
test accuracy: 81.4000015258789

20 5
validation accuracy: 79.30000305175781
test accuracy: 81.69999694824219

25 5
validation accuracy: 79.30000305175781
test accuracy: 81.69999694824219

30 5
validation accuracy: 79.30000305175781
test accuracy: 81.69999694824219

35 5
validation accuracy: 79.30000305175781
test accuracy: 81.69999694824219

40 5
validation accuracy: 79.30000305175781
test accuracy: 81.69999694824219

45 5
validation accuracy: 79.30000305175781
test accuracy: 81.69999694824219

************
0 5
validation accuracy: 60.79999923706055
test accuracy: 59.599998474121094

5 5
validation accuracy: 60.79999923706055
test accuracy: 59.599998474121094

10 5
validation accuracy: 78.5
test accuracy: 81.19999694824219

15 5
validation accuracy: 79

In [25]:
n_classes=3
num_feat=7
num_alpha=3

cov = np.cov(X_values, rowvar=False)
mean_vector = np.mean(X_values, axis=0)

aces_credo_total = []

for ensemble in range(ensemble_size):
    ace_credo_total = []
    model = torch.load("models/credo_sachs_"+str(ensemble+1))
    for output_index in range(0,1):
        for t in range(0,num_feat):
            expectation_do_x = []
            inp=copy.deepcopy(mean_vector)
            for x in np.linspace(0, 1, num_alpha):                
                inp[t] = x
                input_torchvar = autograd.Variable(torch.FloatTensor(inp), requires_grad=True)
                output=model(input_torchvar)
                o1=output.data.cpu()
                val=o1.numpy()[output_index]#first term in interventional expectation                                       

                grad_mask_gradient = torch.zeros(n_classes)
                grad_mask_gradient[output_index] = 1.0
                #calculating the hessian
                first_grads = torch.autograd.grad(output.cpu(), input_torchvar.cpu(), grad_outputs=grad_mask_gradient, retain_graph=True, create_graph=True)

                for dimension in range(0,num_feat):#Tr(Hessian*Covariance)
                    if dimension == t:
                        continue
                    temp_cov = copy.deepcopy(cov)
                    temp_cov[dimension][t] = 0.0#row,col in covariance corresponding to the intervened one made 0
                    grad_mask_hessian = torch.zeros(num_feat)
                    grad_mask_hessian[dimension] = 1.0

                    #calculating the hessian
                    hessian = torch.autograd.grad(first_grads, input_torchvar, grad_outputs=grad_mask_hessian, retain_graph=True, create_graph=False)
                    val += np.sum(0.5*hessian[0].data.numpy()*temp_cov[dimension])#adding second term in interventional expectation
                expectation_do_x.append(val)#append interventional expectation for given interventional value
            ace_credo_total.append(np.array(expectation_do_x) - np.mean(np.array(expectation_do_x)))
    aces_credo_total.append(ace_ca_total)
np.save('./aces/sachs_credo_total.npy',aces_credo_total,allow_pickle=True)

In [26]:
rmse_results = []
frechet_results = []

for ensemble in range(ensemble_size):
    rmse_results.append([rmse(aces_gt[0], aces_credo_total[ensemble][0]),
                         rmse(aces_gt[1], aces_credo_total[ensemble][1]),
                         rmse(aces_gt[2], aces_credo_total[ensemble][2]),
                         rmse(aces_gt[3], aces_credo_total[ensemble][3]),
                         rmse(aces_gt[4], aces_credo_total[ensemble][4]),
                         rmse(aces_gt[5], aces_credo_total[ensemble][5]),
                         rmse(aces_gt[6], aces_credo_total[ensemble][6])])
    
    frechet_results.append([frechet_dist(aces_gt[0].reshape(3,1), aces_credo_total[ensemble][0].reshape(3,1)),
                            frechet_dist(aces_gt[1].reshape(3,1), aces_credo_total[ensemble][1].reshape(3,1)),
                            frechet_dist(aces_gt[2].reshape(3,1), aces_credo_total[ensemble][2].reshape(3,1)),
                            frechet_dist(aces_gt[3].reshape(3,1), aces_credo_total[ensemble][3].reshape(3,1)),
                            frechet_dist(aces_gt[4].reshape(3,1), aces_credo_total[ensemble][4].reshape(3,1)),
                            frechet_dist(aces_gt[5].reshape(3,1), aces_credo_total[ensemble][5].reshape(3,1)),
                            frechet_dist(aces_gt[6].reshape(3,1), aces_credo_total[ensemble][6].reshape(3,1))])
    
rmse_results = np.array(rmse_results)
print("rmse mean: ", np.mean(rmse_results, axis=0))
print("rmse std: ", np.std(rmse_results, axis=0))
print("rmse all features mean: ", np.mean(np.mean(rmse_results, axis=0)))
print("rmse all features std: ", np.mean(np.std(rmse_results, axis=0)))

print("frechet mean: ", np.mean(frechet_results, axis=0))
print("frechet std: ", np.std(frechet_results, axis=0))
print("frechet all features mean: ", np.mean(np.mean(frechet_results, axis=0)))
print("frechet all features std: ", np.mean(np.std(frechet_results, axis=0)))

rmse mean:  [0.08285197 3.81488683 0.02983257 0.42860937 2.87066199 0.13646983
 0.04895687]
rmse std:  [0.00000000e+00 4.44089210e-16 0.00000000e+00 5.55111512e-17
 0.00000000e+00 0.00000000e+00 0.00000000e+00]
rmse all features mean:  1.0588956327825614
rmse all features std:  7.137148015447435e-17
frechet mean:  [0.11701838 5.02455987 0.03905235 0.56964442 4.04197361 0.1671408
 0.0688235 ]
frechet std:  [0. 0. 0. 0. 0. 0. 0.]
frechet all features mean:  1.4326018477303641
frechet all features std:  0.0


# AHCE

In [58]:
for ensemble in range(ensemble_size):
    # Interval / Epochs
    
    mse_loss_func = nn.MSELoss()
    loss_func = nn.CrossEntropyLoss()
    ahce_model = Model(7,sample_size=len(data))

    optimizer = optim.Adam([{'params': ahce_model.parameters()}], lr = 0.001, weight_decay=1e-4)

    for ep in range(0,100):
        trainval = DataSet(X_train,y_train)
        train_loader = DataLoader(trainval, batch_size=64)
        for input_data, target in train_loader: 
            for phase in ['train_dag', 'freeze']:
                ahce_model.zero_grad()
                if phase == 'freeze':
                    output = ahce_model(input_data)
                    loss = loss_func(output,target.squeeze())
                    loss.backward(retain_graph=True)

                else:
                    output, pka_sample, raf_sample, mek_sample, erk_sample, jnk_sample, p38_sample  = ahce_model(input_data, phase='train_dag')             
                    
                    loss = 0.01*loss_func(output,target.squeeze()) 
                    loss.backward(retain_graph=True)
                optimizer.step()

        if ep%interval == 0:
            print(ep, interval)
            val = DataSet(X_val,y_val)
            val_loader = DataLoader(val, batch_size=1)
            acc_val = 0
            acc_test = 0

            for input_data, target in val_loader:
                output = ahce_model(input_data)
                acc = multi_acc(output, target.unsqueeze(1))
                acc_val += acc

            print ('validation accuracy:', float(acc_val/len(val_loader)))
            testval = DataSet(X_test,y_test)
            test_loader = DataLoader(testval, batch_size=1)

            for input_data, target in test_loader:
                output = ahce_model(input_data)
                acc = multi_acc(output, target.unsqueeze(1))
                acc_test += acc

            print('test accuracy:', float(acc_test/len(test_loader)))
            print()

    torch.save(ahce_model, "./models/ahce_sachs_"+str(ensemble+1))

0 5
validation accuracy: 60.20000076293945
test accuracy: 61.599998474121094

5 5
validation accuracy: 81.5999984741211
test accuracy: 80.3499984741211

10 5
validation accuracy: 82.19999694824219
test accuracy: 81.25

15 5
validation accuracy: 82.5999984741211
test accuracy: 81.44999694824219

20 5
validation accuracy: 82.5999984741211
test accuracy: 81.44999694824219

25 5
validation accuracy: 82.5999984741211
test accuracy: 81.44999694824219

30 5
validation accuracy: 82.5999984741211
test accuracy: 81.44999694824219

35 5
validation accuracy: 82.5999984741211
test accuracy: 81.44999694824219

40 5
validation accuracy: 82.5999984741211
test accuracy: 81.44999694824219

45 5
validation accuracy: 82.5999984741211
test accuracy: 81.44999694824219

50 5
validation accuracy: 82.5999984741211
test accuracy: 81.44999694824219

55 5
validation accuracy: 82.5999984741211
test accuracy: 81.44999694824219

60 5
validation accuracy: 82.5999984741211
test accuracy: 81.44999694824219

65 5
valida

In [59]:
n_classes=3
num_c=7#no. of features
num_alpha=3

aces_ahce_total = []
for ensemble in range(5):
    ace_ahce_total = []
    model =  torch.load("./models/ahce_sachs_"+str(ensemble+1))
    for output_index in range(1,2):#For every class
        #plt.figure()
        for t in range(0,num_c):#For every feature
            expectation_do_x = []
            for x in [0,0.5,1]:
                X_values[:,t] = x
                sample_data = model(X_values, phase='sample', inde=t, alpha=x).detach().numpy()
                cov = np.cov(sample_data, rowvar=False)
                means = np.mean(sample_data, axis=0)
                cov=np.array(cov)
                mean_vector = np.array(means)
                inp=copy.deepcopy(mean_vector)
                inp[t] = x
                input_torchvar = autograd.Variable(torch.FloatTensor(inp), requires_grad=True)

                output=model(input_torchvar)

                o1=output.data.cpu()
                val=o1.numpy()[output_index]#first term in interventional expectation                                       

                grad_mask_gradient = torch.zeros(n_classes)
                grad_mask_gradient[output_index] = 1.0
                #calculating the hessian
                first_grads = torch.autograd.grad(output.cpu(), input_torchvar.cpu(), grad_outputs=grad_mask_gradient, retain_graph=True, create_graph=True)

                for dimension in range(0,num_c):#Tr(Hessian*Covariance)
                    if dimension == t:
                        continue
                    temp_cov = copy.deepcopy(cov)
                    temp_cov[dimension][t] = 0.0#row,col in covariance corresponding to the intervened one made 0
                    grad_mask_hessian = torch.zeros(num_c)
                    grad_mask_hessian[dimension] = 1.0

                    #calculating the hessian
                    hessian = torch.autograd.grad(first_grads, input_torchvar, grad_outputs=grad_mask_hessian, retain_graph=True, create_graph=False)
                    val += np.sum(0.5*hessian[0].data.numpy()*temp_cov[dimension])#adding second term in interventional expectation
                expectation_do_x.append(val)#append interventional expectation for given interventional value

            ace_ahce_total.append(np.array(expectation_do_x) - np.mean(np.array(expectation_do_x)))

    aces_ahce_total.append(ace_ahce_total)
np.save('./aces/sachs_ahce_total.npy',aces_ahce_total,allow_pickle=True)

  pka_sample = self.causal_link_pkc_pka(torch.tensor(pkc_sample, dtype=torch.float))
  raf_sample = self.causal_link_pkc_pka_raf(torch.cat((torch.tensor(pkc_sample, dtype=torch.float), pka_sample), dim=1))
  mek_sample = self.causal_link_pkc_pka_raf_mek(torch.cat((torch.tensor(pkc_sample, dtype=torch.float), torch.tensor(pka_sample, dtype=torch.float), torch.tensor(raf_sample, dtype=torch.float)), dim=1))
  erk_sample = self.causal_link_mek_pka_erk(torch.cat((torch.tensor(mek_sample, dtype=torch.float), torch.tensor(pka_sample, dtype=torch.float)), dim=1))
  jnk_sample = self.causal_link_pkc_pka_jnk(torch.cat((torch.tensor(pkc_sample, dtype=torch.float), torch.tensor(pka_sample, dtype=torch.float)), dim=1))
  p38_sample = self.causal_link_pkc_pka_p38(torch.cat((torch.tensor(pkc_sample, dtype=torch.float), torch.tensor(pka_sample, dtype=torch.float)), dim=1))
  raf_sample = self.causal_link_pkc_pka_raf(torch.cat((torch.tensor(pkc_sample, dtype=torch.float), pka_sample), dim=1))
  mek_sa

In [60]:
rmse_results = []
frechet_results = []

for ensemble in [0,1,2,3,4]:
    rmse_results.append([rmse(aces_gt[0], aces_ahce_total[ensemble][0]),
                         rmse(aces_gt[1], aces_ahce_total[ensemble][1]),
                         rmse(aces_gt[2], aces_ahce_total[ensemble][2]),
                         rmse(aces_gt[3], aces_ahce_total[ensemble][3]),
                         rmse(aces_gt[4], aces_ahce_total[ensemble][4]),
                         rmse(aces_gt[5], aces_ahce_total[ensemble][5]),
                         rmse(aces_gt[6], aces_ahce_total[ensemble][6])])
    
    frechet_results.append([frechet_dist(aces_gt[0].reshape(3,1), aces_ahce_total[ensemble][0].reshape(3,1)),
                            frechet_dist(aces_gt[1].reshape(3,1), aces_ahce_total[ensemble][1].reshape(3,1)),
                            frechet_dist(aces_gt[2].reshape(3,1), aces_ahce_total[ensemble][2].reshape(3,1)),
                            frechet_dist(aces_gt[3].reshape(3,1), aces_ahce_total[ensemble][3].reshape(3,1)),
                            frechet_dist(aces_gt[4].reshape(3,1), aces_ahce_total[ensemble][4].reshape(3,1)),
                            frechet_dist(aces_gt[5].reshape(3,1), aces_ahce_total[ensemble][5].reshape(3,1)),
                            frechet_dist(aces_gt[6].reshape(3,1), aces_ahce_total[ensemble][6].reshape(3,1))])
    
rmse_results = np.array(rmse_results)
print("rmse mean: ", np.mean(rmse_results, axis=0))
print("rmse std: ", np.std(rmse_results, axis=0))
print("rmse all features mean: ", np.mean(np.mean(rmse_results, axis=0)))
print("rmse all features std: ", np.mean(np.std(rmse_results, axis=0)))

print("frechet mean: ", np.mean(frechet_results, axis=0))
print("frechet std: ", np.std(frechet_results, axis=0))
print("frechet all features mean: ", np.mean(np.mean(frechet_results, axis=0)))
print("frechet all features std: ", np.mean(np.std(frechet_results, axis=0)))

rmse mean:  [0.12475822 0.65547383 0.12277304 0.14353915 0.51898169 0.0186962
 0.02440762]
rmse std:  [0.06767317 0.1756164  0.03766031 0.01556702 0.34269087 0.01361813
 0.01961943]
rmse all features mean:  0.2298042509487662
rmse all features std:  0.0960636188516346
frechet mean:  [0.171739   0.91099781 0.1714031  0.17648882 0.70143739 0.02320759
 0.02991145]
frechet std:  [0.09771076 0.23724553 0.05571379 0.01884757 0.45834208 0.01672321
 0.0240606 ]
frechet all features mean:  0.3121693101549147
frechet all features std:  0.12980622024131902
