In [1]:
import numpy as np
import pandas as pd
from sklearn.model_selection import ParameterGrid
from sklearn.model_selection import KFold

from tensorflow import keras

from dataloader import _clip_class_df, _clip_class_rest_df, K_RUNS
from gru.dataloader import _get_clip_seq as _get_seq
from gru.models import GRUEncoder
from gru.cc_utils import _get_true_class_prob, _gru_acc, _gru_test_acc, _gruenc_test_traj, _compute_saliency_maps
from utils import _info
import argparse
import pickle
import time
import os

from tqdm import tqdm

import tensorflow as tf
# GPU setup: enable on-demand memory growth on every GPU (so TF does not grab
# all GPU memory up front), then expose only GPU_INDEX to this process.
GPU_INDEX = 1
gpus = tf.config.list_physical_devices('GPU')
try:
    # Set growth on all GPUs, not just the first two, so the setting is
    # consistent across devices.
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
    if gpus:
        # Fall back to the last GPU when fewer than GPU_INDEX + 1 are present
        # (the original indexed gpus[1] unconditionally and hit IndexError).
        tf.config.experimental.set_visible_devices(
            gpus[min(GPU_INDEX, len(gpus) - 1)], 'GPU')
except (RuntimeError, ValueError) as err:
    # RuntimeError: devices were already initialised (e.g. this cell is re-run
    # after TF touched the GPU). ValueError: invalid device.
    # Report instead of silently swallowing every exception.
    print('GPU configuration skipped: %s' % err)

# results directory
RES_DIR = 'results/clip_saliency'
os.makedirs(RES_DIR, exist_ok=True)

K_SEED = 330

In [2]:
class ARGS:
    """Fixed run configuration for this notebook (a stand-in for argparse).

    Several of these fields are interpolated into the result/model file names
    built in the next cell. They presumably must match the settings the
    pretrained GRU was trained with.
    """

    # --- data / preprocessing ---
    roi_name = "roi"
    roi = 300
    net = 7
    subnet = 'wb'
    zscore = 1        # truthy -> each sequence is z-scored before the model
    k_fold = 10

    # --- training schedule ---
    train_size = 100
    batch_size = 32
    num_epochs = 45

    # --- recurrent encoder (labelled "lstm" originally; the model is a GRU) ---
    k_dim = 3
    k_layers = 1
    k_hidden = 32


args = ARGS()

## Build result/model paths and load data

In [3]:
# Saliency result file: one pickle per hyper-parameter configuration.
res_path = (RES_DIR +
            '/%s_%d_net_%d' % (args.roi_name, args.roi, args.net) +
            '_trainsize_%d' % (args.train_size) +
            '_k_hidden_%d' % (args.k_hidden) +
            '_k_layers_%d_batch_size_%d' % (args.k_layers, args.batch_size) +
            '_num_epochs_%d_z_%d.pkl' % (args.num_epochs, args.zscore))

# The pretrained GRU has the same file stem, stored under models/ with an .h5
# extension: results/clip_saliency/<stem>.pkl -> models/clip_gru/<stem>.h5.
# Only the directory and the extension are rewritten. Blanket str.replace
# calls on the full path would also hit matching substrings in the file name.
res_stem = os.path.splitext(os.path.basename(res_path))[0]
model_dir = RES_DIR.replace('results', 'models', 1)
gru_mod_path = model_dir + '/' + res_stem + '.h5'
gru_model_path = model_dir.replace('saliency', 'gru') + '/' + res_stem + '.h5'
args.gru_model_path = gru_model_path

# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open('data/df.pkl', 'rb') as f:
    df = pickle.load(f)

# Fail fast if columns that the cells below rely on are absent.
missing_cols = {'Subject', 'y', 'timepoint'} - set(df.columns)
assert not missing_cols, 'data/df.pkl is missing columns: %s' % sorted(missing_cols)

## Prepare DataFrame gradient columns

In [4]:
_info('save gradients')

subject_list = np.unique(df['Subject'])
labels = np.unique(df['y'])
k_class = len(labels)
print('number of unique sequences = %d' % k_class)
# The saliency loop iterates range(k_class), so labels must be exactly 0..k_class-1.
assert np.array_equal(labels, np.arange(k_class)), \
    'expected labels 0..%d, got %s' % (k_class - 1, labels)

# One gradient column per input feature column. Count the actual 'feat*'
# columns rather than using args.roi, which differs for a subnetwork.
features = [col for col in df.columns if 'feat' in col]
grads = ['grad_%d' % ii for ii in range(len(features))]

# Add all NaN gradient columns in a single concat instead of one insert per
# column, which fragments the frame (pandas PerformanceWarning). Dropping any
# existing grad_* columns first makes re-running this cell reset them.
df = pd.concat(
    [df.drop(columns=grads, errors='ignore'),
     pd.DataFrame(np.nan, index=df.index, columns=grads)],
    axis=1)

---
save gradients
---
number of unique sequences = 15


## Load Pretrained GRU model

In [5]:
# Load the pretrained GRU encoder from the path built above and freeze its
# weights; this notebook only reads gradients from it, never trains it.
gru_model = keras.models.load_model(filepath=args.gru_model_path)
gru_model.trainable = False



## Compute Saliency

In [8]:
def _saliency_for_seq(seq, i_class):
    """Return the saliency map of one (time x feature) sequence for class i_class.

    Optionally z-scores ``seq`` (when ``args.zscore`` is truthy), wraps it as
    a batch of one, and delegates to ``_compute_saliency_maps``. The result
    is reshaped to ``seq.shape`` so it can be written straight back into df.
    """
    if args.zscore:
        # z-score each seq that goes into the model (a single global mean/std
        # over all timepoints and features, as in the original notebook)
        seq = (1 / np.std(seq)) * (seq - np.mean(seq))
    # A single sequence padded to its own length gets no padding, so the
    # saliency map has exactly seq.size entries.
    X_padded = tf.keras.preprocessing.sequence.pad_sequences(
        [seq], padding="post", dtype='float')
    gX = _compute_saliency_maps(gru_model, X_padded, i_class)
    # Presumably gX is (1, T, F) — TODO confirm. reshape (not squeeze) keeps
    # 1-timepoint sequences 2-D.
    return np.asarray(gX).reshape(seq.shape)


# range starts at 0: the original range(1, k_class) made the class-0 branch
# unreachable and left class-0 gradients as NaN.
for i_class in range(k_class):
    for subject in tqdm(subject_list, desc='class %d' % i_class):
        mask = (df['Subject'] == subject) & (df['y'] == i_class)
        seqs = df.loc[mask, features].values
        if seqs.shape[0] == 0:
            # No rows for this subject/class: leave its gradients as NaN.
            continue

        if i_class == 0:
            # must handle test retest differently: split the rows into K_RUNS
            # equal-length chunks (assumes runs are stored contiguously in row
            # order — TODO confirm) and compute saliency per chunk. Any
            # remainder rows (length not divisible by K_RUNS) keep gradient 0.
            gradX = np.zeros(seqs.shape)
            k_time = seqs.shape[0] // K_RUNS
            for i_run in range(K_RUNS):
                run = slice(i_run * k_time, (i_run + 1) * k_time)
                gradX[run, :] = _saliency_for_seq(seqs[run, :], i_class)
        else:
            gradX = _saliency_for_seq(seqs, i_class)

        df.loc[mask, grads] = gradX

sal_df = df[['Subject', 'timepoint', 'y'] + grads]
    

  2%|▏         | 4/176 [00:14<10:10,  3.55s/it]


KeyboardInterrupt: 