In [263]:
import csv
import importlib
import itertools
import operator
import os
import random
import time
from math import e, exp, log, pi, sqrt

import keras.layers as layers
import copy
import keras.models
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy
from keras import backend as K
from keras.callbacks import EarlyStopping
from keras.models import Model
from keras.optimizers import Adam
from keras.utils import np_utils
from numba import guvectorize, jit, njit, prange, vectorize
from sklearn.metrics import accuracy_score, confusion_matrix, f1_score

from os.path import *
import glob
import cnn_builder as cbuild
import cnn_runner as crun
import config
import dr_methods as drm
import feature_interpretation as cnna
import niftiutils.helper_fxns as hf
import niftiutils.private as prv
import niftiutils.transforms as tr
import niftiutils.visualization as vis
import voi_methods as vm

%matplotlib inline
np.set_printoptions(3)

In [2]:
# Re-import the locally edited project modules so that source changes are
# picked up without restarting the kernel.
importlib.reload(config)
importlib.reload(hf)
importlib.reload(cbuild)
importlib.reload(crun)
# Build fresh config / hyperparameter objects from the reloaded `config` module.
C = config.Config()
T = config.Hyperparams()

In [5]:
# Hard-coded list of lesion ids; entry 0 is used as the test lesion further down.
Z_reader = ['E103312835_1','12823036_0','12569915_0','E102093118_0','E102782525_0','12799652_0','E100894274_0','12874178_3','E100314676_0','12842070_0','13092836_2','12239783_0','12783467_0','13092966_0','E100962970_0','E100183257_1','E102634440_0','E106182827_0','12582632_0','E100121654_0','E100407633_0','E105310461_0','12788616_0','E101225606_0','12678910_1','E101083458_1','12324408_0','13031955_0','E101415263_0','E103192914_0','12888679_2','E106096969_0','E100192709_1','13112385_1','E100718398_0','12207268_0','E105244287_0','E102095465_0','E102613189_0','12961059_0','11907521_0','E105311123_0','12552705_0','E100610622_0','12975280_0','E105918926_0','E103020139_1','E101069048_1','E105427046_0','13028374_0','E100262351_0','12302576_0','12451831_0','E102929168_0','E100383453_0','E105344747_0','12569826_0','E100168661_0','12530153_0','E104697262_0']
# Per-class data dict; later cells read the last element of each class entry.
# NOTE(review): presumably the un-augmented dataset, judging by the helper name — confirm.
orig_data_dict, num_samples = cbuild._collect_unaug_data()

# Feature annotations: features per class, plus a per-feature count dict.
features_by_cls, feat_count = cnna.collect_features()
# Drop "homogeneous texture" from the analysis (raises KeyError if it is absent).
feat_count.pop("homogeneous texture")
#feat_count.pop("central scar")
all_features = sorted(list(feat_count.keys()))
# For each remaining feature, the included classes whose feature set contains it.
cls_features = {f: [c for c in C.classes_to_include if f in features_by_cls[c]] for f in all_features}

# Annotated files keyed by feature; the same feature is removed here as well.
Z_features = cnna.get_annotated_files(features_by_cls)
Z_features.pop("homogeneous texture")
#Z_features.pop("central scar")

num_features = len(all_features)

# First element returned by get_voi_dfs(); later indexed by lesion id to read "cls".
voi_df = drm.get_voi_dfs()[0]
# Trained model whose gradients / Hessian-vector products are analysed below.
M = keras.models.load_model(join(C.model_dir, "model_reader_new21.hdf5"))

In [None]:
# Display the second array returned by get_weights() for the layer at index 5.
# NOTE(review): presumably that layer's bias vector — confirm against M.summary().
M.layers[5].get_weights()[1]

In [251]:
# Use the first entry of Z_reader as the test lesion; `z_test` is a second name
# bound to the same id.
lesion_id = z_test = Z_reader[0]

In [252]:
# Class label of the test lesion, read from the "cls" column of the VOI table.
cls = voi_df.loc[lesion_id]["cls"]
# One-hot target of length 6, indexed by the class's position in C.classes_to_include.
# NOTE(review): the 6 is hard-coded; presumably the model's output width — confirm.
y_test = np_utils.to_categorical(C.classes_to_include.index(cls), 6)
# Symbolic categorical cross-entropy between that target and the model output tensor.
loss = K.categorical_crossentropy(y_test, M.output)

In [253]:
# Start from every weight array of the model, then drop four pairs of entries.
# NOTE(review): presumably these are the non-trainable arrays (e.g. batch-norm
# moving statistics) so that W lines up one-to-one with M.trainable_weights,
# which later cells feed W into — confirm against the model architecture.
W = M.get_weights()
# Fixed positive indices are removed from highest to lowest so the earlier
# deletions do not shift the positions of the later ones.
del W[20:22]
del W[14:16]
del W[8:10]
# Finally remove the 4th- and 3rd-from-last entries of the already shortened list.
del W[-4:-2]

In [282]:
def perturb_weights(W_true, t, eps=1e-5):
    """Return a copy of a weight list shifted by ``eps * t``.

    ``t`` is a flat direction vector laid out as the concatenation of the
    flattened weight arrays, in list order. Each weight array receives the
    matching slice of ``t``, reshaped to the array's shape and scaled by
    ``eps``. The input list and its arrays are left untouched.

    Parameters
    ----------
    W_true : sequence of array-like
        Weight arrays to perturb.
    t : array-like
        Direction vector whose total size equals the total number of weights.
    eps : float, optional
        Finite-difference step. Any caller that divides a gradient difference
        by the step must use the same value. Defaults to 1e-5.

    Returns
    -------
    list of numpy.ndarray
        New arrays with the same shapes and dtypes as the inputs.

    Raises
    ------
    ValueError
        If the size of ``t`` differs from the total number of weights; a
        silent mismatch would pair gradient entries with the wrong weights.
    """
    t = np.asarray(t).ravel()
    n_weights = sum(np.size(w) for w in W_true)
    if t.size != n_weights:
        raise ValueError(
            "direction vector has %d entries but the weights have %d"
            % (t.size, n_weights)
        )

    W = []
    t_ix = 0
    for w_true in W_true:
        w = np.array(w_true)  # np.array copies, so the caller's array is not modified
        # In-place add keeps the weight's own dtype (e.g. float32).
        w += eps * np.reshape(t[t_ix:t_ix + w.size], w.shape)
        t_ix += w.size
        W.append(w)

    return W

In [283]:
# Gradient function that takes explicit values for every trainable weight
# (followed by the input batch and the learning-phase flag), so the same graph
# can be evaluated at the stored weights W and at a perturbed copy of them.
g = K.gradients(loss, M.trainable_weights)
g_fxn = K.function(M.trainable_weights + [M.input, K.learning_phase()], g)

# Baseline gradient at W. It is computed first so this cell no longer depends
# on a `g_test` left over from another cell.
g_test = g_fxn(W + [x_test, 0])
g_test = np.concatenate([x.flatten() for x in g_test], 0)

# Gradient at the weights shifted along g_test. The original flattened the
# undefined name `g_test_plus` here instead of `g_i_plus`.
W_new = perturb_weights(W, g_test)
g_i_plus = g_fxn(W_new + [x_test, 0])
g_i_plus = np.concatenate([x.flatten() for x in g_i_plus], 0)

In [90]:
# Symbolic gradients of the test loss w.r.t. every trainable weight, wrapped in
# a callable that takes [input batch, learning-phase flag].
g = K.gradients(loss, M.trainable_weights)
g_fxn = K.function([M.input, K.learning_phase()], g)
# Evaluate with learning phase 0 and flatten the per-weight gradients into one
# 1-D vector.
g_test = np.concatenate(
    [layer_grad.ravel() for layer_grad in g_fxn([x_test, 0])],
    axis=0,
)

In [54]:
# Read the test lesion's array from <orig_dir>/<cls>/<lesion_id>.npy and add a
# leading axis of length 1 so it can be fed to the model as a single-item batch.
x_test = np.expand_dims(
    np.load(join(C.orig_dir, cls, lesion_id + ".npy")),
    0,
)

In [None]:
def get_grad(lesion_id, perturb_W=None):
    """Flattened loss gradient of model ``M`` for one lesion.

    Loads the lesion's array from ``<C.orig_dir>/<cls>/<lesion_id>.npy``,
    builds the categorical cross-entropy against its one-hot class label and
    evaluates the gradient w.r.t. ``M.trainable_weights`` with learning
    phase 0.

    Parameters
    ----------
    lesion_id : str
        Index into ``voi_df``; also the stem of the ``.npy`` file name.
    perturb_W : sequence of arrays, optional
        One value per entry of ``M.trainable_weights``. When given, the
        gradient is additionally evaluated with these values fed in place of
        the stored weights.

    Returns
    -------
    numpy.ndarray or tuple of numpy.ndarray
        ``g_i`` alone, or ``(g_i, g_i_plus)`` when ``perturb_W`` is given.
        Each is a 1-D concatenation of the per-weight gradients.
    """
    C = config.Config()

    cls = voi_df.loc[lesion_id]["cls"]
    x_i = np.load(join(C.orig_dir, cls, lesion_id+".npy"))
    x_i = np.expand_dims(x_i, 0)

    y_i = np_utils.to_categorical(C.classes_to_include.index(cls), 6)
    loss_i = K.categorical_crossentropy(y_i, M.output)
    # Build the gradient ops once and share them between both functions; every
    # K.gradients call adds new nodes to the backend graph.
    grads_i = K.gradients(loss_i, M.trainable_weights)

    g_fxn = K.function([M.input, K.learning_phase()], grads_i)
    g_i = np.concatenate([x.flatten() for x in g_fxn([x_i, 0])], 0)

    if perturb_W is None:
        return g_i

    # Same gradients, but with the trainable weights exposed as feedable inputs.
    g_fxn_plus = K.function(M.trainable_weights + [M.input, K.learning_phase()], grads_i)
    g_i_plus = g_fxn_plus(list(perturb_W) + [x_i, 0])
    g_i_plus = np.concatenate([x.flatten() for x in g_i_plus], 0)

    return g_i, g_i_plus

In [336]:
def get_HVP(W_new, Z_sample, eps=1e-5):
    """Finite-difference estimate of a Hessian-vector product over a sample.

    For every lesion in ``Z_sample`` the loss gradient is evaluated at the
    perturbed weights ``W_new`` and at the model's stored weights. The
    difference divided by ``eps`` is averaged over the sample.

    Parameters
    ----------
    W_new : sequence of arrays
        Perturbed values, one per entry of ``M.trainable_weights``.
    Z_sample : sequence of str
        Lesion ids to average over.
    eps : float, optional
        Step size that was used to build ``W_new``. It must match the step
        used when perturbing the weights. Defaults to 1e-5. The original body
        read ``eps`` from a scope where it was never defined.

    Returns
    -------
    numpy.ndarray
        1-D array with one entry per trainable weight.

    Raises
    ------
    ValueError
        If ``Z_sample`` is empty.
    """
    if len(Z_sample) == 0:
        raise ValueError("Z_sample must contain at least one lesion id")

    weights = M.trainable_weights
    learning_phase = K.learning_phase()

    Ht = None
    for lesion_id in Z_sample:
        cls = voi_df.loc[lesion_id]["cls"]
        x_i = np.load(join(C.orig_dir, cls, lesion_id+".npy"))
        x_i = np.expand_dims(x_i, 0)

        y_i = np_utils.to_categorical(C.classes_to_include.index(cls), 6)
        loss_i = K.categorical_crossentropy(y_i, M.output)
        # One set of gradient ops per sample, shared by both functions below.
        grads_i = K.gradients(loss_i, weights)

        # Gradient with the perturbed weights fed in explicitly.
        g_fxn_plus = K.function(weights + [M.input, learning_phase], grads_i)
        g_i_plus = g_fxn_plus(list(W_new) + [x_i, 0])
        g_i_plus = np.concatenate([x.flatten() for x in g_i_plus], 0)

        # Gradient at the model's stored weights.
        g_fxn = K.function([M.input, learning_phase], grads_i)
        g_i = g_fxn([x_i, 0])
        g_i = np.concatenate([x.flatten() for x in g_i], 0)

        if Ht is None:
            # Sized from the gradient itself instead of the global `g_test`.
            Ht = np.zeros(g_i.shape)
        Ht += (g_i_plus - g_i) / eps

    return Ht / len(Z_sample)

In [328]:
# Concatenate the last element of each included class's entry in orig_data_dict
# into one array; the CG loop draws its random samples from it.
# NOTE(review): presumably these are the lesion ids per class — confirm.
Z_full = np.concatenate([orig_data_dict[cls][-1] for cls in C.classes_to_include],0)

In [329]:
# Initial state for the iterative solve below.
# t_k: running solution estimate, starting at zero.
t_k = np.zeros(g_test.shape)
# r_k: residual; with t_k = 0 the "-Ht" term vanishes, so it is just g_test.
r_k = g_test#-Ht
# p_k: first search direction, bound to the same array object as r_k.
p_k = r_k

In [337]:
# Conjugate-gradient iterations for H t = g_test, where every product H p is
# estimated by finite differences on a fresh random mini-batch of 15 lesions.
num_iters = 10
for _ in range(num_iters):
    rr_k = np.dot(r_k, r_k)
    if rr_k == 0:
        # Residual is exactly zero: already solved, and beta would be 0/0.
        break

    Z_sample = random.sample(list(Z_full), 15)

    W_new = perturb_weights(W, p_k)
    Hp = get_HVP(W_new, Z_sample)

    # Standard CG step length. The original negated it while keeping the
    # standard updates below, which makes the residual grow, not shrink.
    alpha = rr_k / np.dot(p_k, Hp)
    t_k += alpha * p_k
    r_k2 = r_k - alpha * Hp

    beta = np.dot(r_k2, r_k2) / rr_k
    r_k = r_k2
    p_k = r_k + beta*p_k

In [338]:
# Keep the final iterate of the loop above under the name `s_test`.
# This binds the same array object as t_k; it is not a copy.
s_test = t_k

In [None]:
# NOTE(review): unfinished scratch cell — np.dot() is called without arguments
# and raises TypeError when run. Fill in the intended operands or delete the cell.
np.dot()