In [None]:
import numpy as np
import pandas as pd
import os.path as osp
import warnings

import torch
import torch.nn.functional as func
from torch_geometric.loader import DataLoader
from sklearn.model_selection import StratifiedKFold, train_test_split
from sklearn.metrics import confusion_matrix

from mydatalist import mydatalist
from Model_CAM import GCN

## Load in Subjects

In [None]:
# Read the subject/label table and wrap it in the project's dataset helper.
labelCSV = '/path/to/labels/Labels_UPE.csv'
LISTS = pd.read_csv(labelCSV, sep=',')

# mydatalist is project-local; it receives the SUBJECTS and LABELS columns.
mydata = mydatalist(LISTS['SUBJECTS'], LISTS['LABELS'])
dataset = mydata

# Same 10-fold stratified splitter configuration (seeded for reproducibility).
skf = StratifiedKFold(shuffle=True, n_splits=10, random_state=99)

print(LISTS)

## Load in Best Model (fold 5)

In [None]:
# Use the GPU when one is available; otherwise run everything on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# GCN comes from Model_CAM; the first argument (377) presumably matches the
# n_nodes value used for the saliency maps below -- TODO confirm.
model = GCN(377, 2, 12).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)

# map_location remaps the stored tensors onto `device`, so a checkpoint that
# was saved from a GPU run can still be loaded on a CPU-only machine.
model.load_state_dict(torch.load('bestUPE_fold05model.pth', map_location=device))
model.eval()

## Define CAM function

In [None]:
def get_cam(dataset, label, model, device):
    """Compute the class activation map (CAM) of every sample with a given label.

    For each sample whose ``y`` equals ``label``, the model is called with
    ``cam_required=True`` and the second element of its return value is
    multiplied with the ``label``-th row of ``lin1.weight``.

    Args:
        dataset: Indexable, iterable collection of samples. Each sample must
            expose a single-element tensor ``y`` and a ``to(device)`` method.
        label: Class whose samples are selected; also the row of
            ``lin1.weight`` used for the projection.
        model: Module with a ``lin1`` layer, callable as
            ``model(data, cam_required=True)`` and returning a pair.
        device: Device each sample is moved to before the forward pass
            (the model is assumed to already live there).

    Returns:
        A CPU tensor without autograd history that stacks one CAM per selected
        sample along dimension 0, in dataset order.

    Raises:
        ValueError: If the dataset is empty or holds no sample with ``label``.
    """
    labels = [data.y.reshape(1) for data in dataset]
    if not labels:
        raise ValueError("dataset is empty")

    # Bring the labels to the CPU so index extraction works wherever y lives,
    # and use plain Python ints so any indexable dataset accepts them.
    y = torch.cat(labels, dim=0).cpu()
    idx = (y == label).nonzero().reshape(-1).tolist()
    if not idx:
        raise ValueError(f"no samples with label {label!r} in dataset")

    model.eval()
    class_weights = model.state_dict()['lin1.weight'].detach()[int(label)]

    cam_all = []
    # Inference only: skip autograd bookkeeping and hand back CPU tensors so
    # callers can convert to numpy even when the model runs on a GPU.
    with torch.no_grad():
        for index in idx:
            data = dataset[index].to(device)
            _, cam_conv = model(data, cam_required=True)
            cam_all.append(torch.matmul(cam_conv, class_weights).cpu())

    return torch.stack(cam_all)

## Get CAM for every subject in the dataset, split by label

In [None]:
# One CAM stack per class label: 0 first, then 1.
cam_0_dataset, cam_1_dataset = (
    get_cam(dataset, class_label, model, device) for class_label in (0, 1)
)

In [None]:
# Sanity check: the first dimension of each stack should equal the number of
# subjects carrying that label.
print(cam_0_dataset.shape, cam_1_dataset.shape, sep='\n')

## Define function for the population saliency map

In [None]:
def pop_saliency(cam, n_top, n_nodes):
    """Count how often each node is among a subject's highest-CAM nodes.

    Args:
        cam: 2-D tensor or array with one row of per-node CAM values per
            subject. CUDA tensors and tensors with autograd history are
            accepted.
        n_top: Number of highest-valued nodes to keep per subject. Values
            larger than the row length keep every node; 0 keeps none.
        n_nodes: Length of the returned frequency vector. Every selected node
            index must be smaller than this.

    Returns:
        A tuple ``(n_top_persub, freqs)``. ``n_top_persub`` is a list with one
        numpy index array per subject, ordered from lowest to highest CAM value
        among the selected nodes. ``freqs`` is a float numpy array of length
        ``n_nodes`` counting how many subjects selected each node.

    Raises:
        ValueError: If ``n_top`` is negative.
    """
    if n_top < 0:
        raise ValueError("n_top must be non-negative")

    # Detach and move to the CPU so the numpy conversion below always works.
    cam = torch.as_tensor(cam).detach().cpu()

    freqs = np.zeros(n_nodes)
    n_top_persub = []

    # Find the n_top most discriminatory regions per subject.
    for indvcam in cam:
        order = indvcam.argsort()
        # An explicit start index avoids the `[-0:]` pitfall, where n_top == 0
        # would select every node instead of none.
        keep = min(n_top, order.numel())
        nodestop = order[order.numel() - keep:].numpy()
        n_top_persub.append(nodestop)
        # argsort yields a permutation, so indices are unique within a subject
        # and a fancy-indexed increment counts each of them exactly once.
        freqs[nodestop] += 1

    return n_top_persub, freqs

## Get top 10 regions per group

In [None]:
# Top-10 nodes per subject and their selection counts, for label 0 then label 1.
# n_nodes=377 is the same value passed as the first GCN argument above.
n_10_top_persub_0, freqs_10_0 = pop_saliency(cam=cam_0_dataset, n_nodes=377, n_top=10)
n_10_top_persub_1, freqs_10_1 = pop_saliency(cam=cam_1_dataset, n_nodes=377, n_top=10)

In [None]:
# One row per node; each column holds how many subjects of that group had the
# node among their top 10.
salient_nodes = pd.DataFrame({
    'Control': freqs_10_0,
    'UPE': freqs_10_1,
})

In [None]:
# Persist the per-node selection counts; the row index is written as well.
salient_nodes.to_csv(path_or_buf='UPE_10_salientnodes_allsubs.csv')

## Take the average of the CAM output across subjects

In [None]:
#Take average and rank for controls
# .cpu() is needed before .numpy() whenever the CAMs were computed on a GPU;
# it is a no-op for tensors that are already on the CPU.
cam_0_dataset_np=cam_0_dataset.detach().cpu().numpy()
print(cam_0_dataset_np.shape)
# Mean over axis 0 (subjects) leaves one value per node.
cam_0_groupavg=np.mean(cam_0_dataset_np, axis=0)
print(cam_0_groupavg.shape)
# NOTE(review): argsort returns node indices ordered from lowest to highest
# mean CAM, i.e. entry i is the node at sorted position i, not the rank of
# node i. Confirm this is the intended meaning of "rank".
rank_cam_0=cam_0_groupavg.argsort()

In [None]:
#Take average and rank for UPE
# .cpu() is needed before .numpy() whenever the CAMs were computed on a GPU;
# it is a no-op for tensors that are already on the CPU.
cam_1_dataset_np=cam_1_dataset.detach().cpu().numpy()
print(cam_1_dataset_np.shape)
# Mean over axis 0 (subjects) leaves one value per node.
cam_1_groupavg=np.mean(cam_1_dataset_np, axis=0)
print(cam_1_groupavg.shape)
# NOTE(review): argsort returns node indices ordered from lowest to highest
# mean CAM, i.e. entry i is the node at sorted position i, not the rank of
# node i. Confirm this is the intended meaning of "rank".
rank_cam_1=cam_1_groupavg.argsort()

In [None]:
# Average and order over all subjects: stack the control rows on top of the
# UPE rows, then average over the subject axis.
combined_cam = np.concatenate([cam_0_dataset_np, cam_1_dataset_np], axis=0)
print(combined_cam.shape)
cam_datasetavg = combined_cam.mean(axis=0)
print(cam_datasetavg.shape)
# Node indices ordered from lowest to highest mean CAM.
rank_cam_all = np.argsort(cam_datasetavg)

In [None]:
# Assemble the group-average table in the same column order as before.
# NOTE(review): the *_CAM columns are indexed by node (row i = node i), while
# the *_rank columns hold argsort output (row i = node at sorted position i).
# Confirm readers of the CSV interpret the rank columns that way.
avgsalient_nodes = pd.DataFrame({
    'Control_rank': rank_cam_0,
    'Control_CAM': cam_0_groupavg,
    'UPE_rank': rank_cam_1,
    'UPE_CAM': cam_1_groupavg,
    'allsubs_rank': rank_cam_all,
    'allsubs_cam': cam_datasetavg,
})

In [None]:
# Persist the group-average CAM table; the row index is written as well.
avgsalient_nodes.to_csv(path_or_buf='avg_salientnodes_allsubs.csv')