In [7]:
import numpy as np
import pandas as pd
import os
import json
from captum.attr import IntegratedGradients,Occlusion
from nilearn.image import load_img,resample_img 

import torch
from torch import nn 
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler

from custom_dataset import CustomDataset
from network import Network
from utils import *
import torch.nn.functional as F

In [8]:
# --- Configuration -----------------------------------------------------------
parent_directory = '/data/users2/pnadigapusuresh1/JobOutputs'
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using {} device".format(device))

# --- Model -------------------------------------------------------------------
# Rebuild the architecture with a 2-output head (matching the checkpoint) and
# wrap it in DataParallel *before* loading, so the checkpoint presumably stores
# DataParallel-style ("module."-prefixed) keys — TODO confirm.
model = Network()
model.fc1 = nn.Sequential(nn.Linear(512,2))
model = nn.DataParallel(model)
model.to(device)

# Earlier checkpoints, kept for reference:
#   job 5436878
#   job 6066159: os.path.join(parent_directory,'6066159','models','epoch_170')
# Current checkpoint: job 1818979, fold 5, epoch 38.
load_path = os.path.join(parent_directory,'1818979','models_fold','5','epoch_38')
# map_location lets a checkpoint saved on GPU also load on a CPU-only machine.
model.load_state_dict(torch.load(load_path, map_location=device))
model.eval()

# Seed torch/numpy RNGs; SubsetRandomSampler draws its visiting order from
# torch's global RNG.
torch.manual_seed(52)
np.random.seed(52)
num_workers = 1  # subprocesses used for data loading
batch_size = 1   # samples per batch; the mean-image loop below relies on 1

valid_data = CustomDataset(train= False,valid=False)

# Rows of the dataset's `vars` table selected by `test_idx`.
# NOTE(review): `vars` shadows the Python builtin; name kept because other
# cells may reference it.
vars = valid_data.vars.iloc[valid_data.test_idx]

# Only the subjects listed in `female_idx` are visited by the loader.
valid_sampler = SubsetRandomSampler(valid_data.female_idx)

valid_loader = DataLoader(valid_data,batch_size=batch_size,
                            sampler= valid_sampler, num_workers=num_workers)

# Voxel-wise mean image over the sampled subjects. The volume shape is
# hard-coded to (121, 145, 121) and must match the dataset's images.
X_all = np.zeros((121,145,121))
for X,y,age in valid_loader:
    X_all += X.squeeze().numpy()
X_all /= len(valid_loader)  # number of batches == number of samples (batch_size 1)
# Add batch and channel axes -> (1, 1, 121, 145, 121) float32 tensor on `device`.
X_all = np.expand_dims(np.expand_dims(X_all,axis =0),axis=0)
X_all = torch.tensor(X_all).float().to(device)

# NOTE(review): named `ig` (as for IntegratedGradients) but holds a captum
# Occlusion explainer; it is not used by the occlusion cells shown here.
ig = Occlusion(model)

# Allocated separately: a chained `a = b = np.zeros(...)` would bind both names
# to ONE array, so accumulating into one would silently change the other.
attr_0 = np.zeros((121,145,121),dtype = np.float64)
attr_1 = np.zeros((121,145,121),dtype = np.float64)

# Region labels from a local JSON file.
# NOTE(review): `l` is reassigned from the AAL label table in the atlas cell,
# so this value only survives if that cell is not run.
with open('region_labels.json','r') as f:
    l = json.load(f)

Using cuda device


In [None]:
# Reference grid: only the affine and shape of the SPM12 Neuromorphometrics
# label image are used, as the target grid for the AAL atlas. The model's input
# volumes are assumed to share this grid — TODO confirm.
imf = load_img('/trdapps/linux-x86_64/matlab/toolboxes/spm12/tpm/labels_Neuromorphometrics.nii')
# NOTE(review): AAL3v1 is loaded but not used below; the resampled atlas and
# the label table both come from the older aal.nii.* files.
aal = load_img('/data/users2/pnadigapusuresh1/Downloads/AAL3/AAL3v1.nii.gz')
# Resample the AAL label volume onto the reference grid. Atlas labels are
# categorical, so nearest-neighbour interpolation is required: nilearn's
# default 'continuous' interpolation blends neighbouring label values at region
# borders. That yields non-integer voxels matching no label (or spurious
# in-between labels) in the `aal_resampled == k` masks used for occlusion.
aal_resampled = resample_img('/data/users2/pnadigapusuresh1/Downloads/AAL3/aal.nii.gz',
                             target_affine=imf.affine,target_shape=imf.shape,
                             interpolation='nearest').get_fdata()
# Label table: column 0 = label value (used as index), column 1 = region name.
df = pd.read_csv('/data/users2/pnadigapusuresh1/Downloads/AAL3/aal.nii.txt',sep=' ',index_col=0,header=None,usecols=[0,1],names=['value','regions'])
# Mapping {label value: region name}.
l = df.to_dict()['regions']

In [None]:
# Template row: every region name from `l` starts at 0, plus an 'age' slot.
# Copied once per subject by the occlusion loop.
labels = dict.fromkeys(l.values(), 0)
labels['age'] = 0
# Accumulates one filled-in template dict per kept subject.
relevance_scores = list()

In [None]:
# Region-occlusion relevance. For every correctly classified subject, zero out
# one atlas region at a time and record the drop in the softmax probability of
# the true class. Misclassified subjects are skipped.
# The atlas is moved to the model's device once, rather than a numpy mask being
# converted for every region of every subject.
atlas_labels = torch.as_tensor(aal_resampled, device=device)
# Inference only: no_grad avoids building an autograd graph for each of the
# (#regions + 1) forward passes per subject. Forward outputs are unchanged.
with torch.no_grad():
    for X,y,age in valid_loader:
        X,y = X.to(device),y.to(device)
        pred = torch.squeeze(model(torch.unsqueeze(X,1).float()))
        soft_max = F.softmax(pred,dim=0)
        if soft_max.argmax() != y:
            continue
        unoccluded_prob = soft_max[y]
        labels_copy = labels.copy()
        for k,v in l.items():
            # Zero every voxel carrying atlas label k, then restore the
            # batch/channel axes expected by the model.
            X_copy = X.clone().squeeze()
            X_copy[atlas_labels == k] = 0
            X_copy = torch.unsqueeze(X_copy,0)
            pred = torch.squeeze(model(torch.unsqueeze(X_copy,1).float()))
            soft_max = F.softmax(pred,dim=0)
            occluded_prob = soft_max[y]
            # Positive value = occluding region v lowered confidence in the true class.
            labels_copy[v] = (unoccluded_prob - occluded_prob).detach().cpu().numpy()[0]
        # Previously a `for ... else:` whose else-branch always ran (no break);
        # written as plain statements after the region loop.
        labels_copy['age'] = age.item()
        labels_copy['memory'] = y.item()
        relevance_scores.append(labels_copy)
    

In [None]:
# One row per kept subject: per-region relevance values plus 'age' and 'memory'.
female_df_occlusion = pd.DataFrame(relevance_scores)
# `sep` is passed by keyword: positional arguments after the path are deprecated
# in DataFrame.to_csv and become keyword-only in pandas 3.0.
# NOTE(review): the filename says "wrong", but the occlusion loop keeps only
# correctly classified subjects — confirm the intended filename/filter.
female_df_occlusion.to_csv('female_df_occ_wrong.csv',sep=',')