In [None]:
# imports
import sys
sys.path.append('../BioExp')
import keras
import numpy as np
import tensorflow as tf
from keras.models import load_model
import pandas as pd
import os
import pickle

from keras.backend.tensorflow_backend import set_session
from BioExp.helpers.metrics import *
from BioExp.helpers.losses import *

In [None]:
from keras import backend as K

# Ask the Keras TensorFlow backend which GPUs it can see. Leaving the result
# as the final expression of the cell makes the notebook display it.
visible_gpus = K.tensorflow_backend._get_available_gpus()
visible_gpus

In [None]:
# GPU setup: expose only device "0" (devices ordered by PCI bus id) to this
# process, then register a TensorFlow session limited to half of the GPU memory.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "0",
})

config = tf.ConfigProto()
config.gpu_options.per_process_gpu_memory_fraction = 0.5  # fraction of GPU memory this process may use
session = tf.Session(config=config)
set_session(session)

This method is inspired by the IESDS (iterated elimination of strictly dominated strategies) method in game theory, where each and every strategy fights for its survival. In this approach we generate concepts by clustering the weights in predefined layers of the network.

At the initial step all the concepts are considered to be of equal importance, and the relevance matrix is generated. Based on the relevance matrix, one or more concepts are eliminated, and the entire process is repeated until convergence.

By doing this we end up with the more robust and more important features, i.e. those responsible for a higher Dice or classification score.

Initial tests are on BraTS with the SimUnet model; the scores of the trained models are shown below:


| Model Type |     WT Dice | TC Dice  | ET Dice|
|------------|:------------|:---------|:-------|
| DenseUnet  |     0.830   | 0.760    | 0.685  |
| ResUnet    |     0.788   | 0.734    | 0.649  |
| SimUnet    |     0.743   | 0.693    | 0.523  |


In [None]:
# Model and parameter definitions.
# Currently using the simple U-Net.

# MRI sequence to analyse, and the position of every sequence along the last
# axis of a loaded image (this is how `seq_map[seq]` is used by the data loader).
seq = 'flair'
seq_map = {'flair': 0, 't1': 1, 't2': 3, 't1c':2}

# Saved architecture and weights for the chosen sequence.
model_path = '../BioExp/saved_models/model_{}_scaled/model-archi.h5'.format(seq)
weights_path = '../BioExp/saved_models/model_{}_scaled/model-wts-{}.hdf5'.format(seq, seq)

# Layers considered for concept generation: conv2d_2 through conv2d_21.
layers_to_consider = ['conv2d_{}'.format(layer_idx) for layer_idx in range(2, 22)]

# Custom loss/metric callables that `load_model` is given to resolve by name.
custom_objects = {
    'gen_dice_loss': gen_dice_loss,
    'dice_whole_metric': dice_whole_metric,
    'dice_core_metric': dice_core_metric,
    'dice_en_metric': dice_en_metric,
}

model = load_model(model_path, custom_objects=custom_objects)
model.load_weights(weights_path)
model.summary()

In [None]:
from BioExp.clusters import clusters

# Cluster the weights of every selected layer. Each cluster becomes one
# "concept": a globally numbered node tied to a layer and to the indices of
# the entries that received that cluster label.
save_root = './Logs/DiceGraphs/{}/weights_cluster/'.format(seq)

concept_info = []
node = 0  # running concept counter across all layers

for layer_name in layers_to_consider:
    layer_dir = os.path.join(save_root, layer_name)
    os.makedirs(layer_dir, exist_ok=True)

    clusterer = clusters.Cluster(model, weights_path, layer_name)
    labels = clusterer.get_clusters(threshold=0.5, save_path=layer_dir)
    clusterer.plot_weights(labels, os.path.join(layer_dir, 'wt-samples'))

    for label in np.unique(labels):
        concept_info.append({
            'concept_name': 'node_{}'.format(node),
            'layer_name': layer_name,
            'filter_idxs': np.where(labels == label)[0],
        })
        node += 1

# Persist the concept definitions so later cells can reload them.
cluster_info_file = os.path.join(save_root, 'cluster_info.cpickle')
with open(cluster_info_file, 'wb') as file:
    pickle.dump(concept_info, file)

-------------
### Delta Graph

$$AM[C^p_i, C^q_j] = \mathbb{E}_{(x, gt)\sim Data} \left( DICE(\Phi(x), gt) - DICE(\Phi(x ~|~ do(C^p_i = 0), ~do(C^q_j = 0)), gt) \right)$$

The idea is to estimate the importance of a pair of concepts through intervention, i.e. to measure the effect of the pair by calculating the difference in Dice score with and without those concepts. In the above equation, the importance of concepts $C^p_i$ and $C^q_j$ is obtained as the expected drop in Dice when both concepts are switched off ($do(C = 0)$).


In [None]:
from BioExp.helpers import utils
from BioExp.graphs import delta

# Scoring function that is passed to delta.DeltaGraph further down in this cell.
metric = dice_label_coef # defined in BioExp.helpers.metrics

def dataloader(nslice = 78):
    """Build a loader function bound to a fixed slice number.

    Args:
        nslice: value forwarded as ``slicen`` to ``utils.load_vol_brats``
            (presumably the index of the 2-D slice to extract -- confirm
            against that helper).

    Returns:
        A function ``loader(img_path, mask_path)`` that loads ``img_path``
        and returns ``(image, gt)``. ``image`` keeps only the channel
        ``seq_map[seq]`` of the last axis, with a trailing axis of length 1
        added back. ``mask_path`` is accepted but not used; the ground truth
        comes from the same ``load_vol_brats`` call.
    """
    def loader(img_path, mask_path):
        volume, gt = utils.load_vol_brats(img_path, slicen=nslice)
        # The globals `seq_map` and `seq` are read at call time, not when
        # the loader is created.
        channel = volume[:, :, seq_map[seq]]
        return channel[:, :, None], gt

    return loader

data_root_path = '../BioExp/sample_vol/brats/'

# Class groups handed to DeltaGraph as `classinfo`: group name -> tuple of
# integer labels (presumably segmentation label ids -- confirm in DeltaGraph).
# A per-label variant ('class_<i>' -> (i,) for i in range(4)) is currently disabled.
infoclasses = {
    'whole': (1, 2, 3),
    'ET': (3,),
    'CT': (1, 3),
}

G = delta.DeltaGraph(model, weights_path, metric, classinfo=infoclasses)

# Reload the concept definitions from the pickle under the weights_cluster
# directory. pickle.load executes arbitrary code, so only load files you trust.
save_root = './Logs/DiceGraphs/{}/weights_cluster/'.format(seq)
cluster_info_file = os.path.join(save_root, 'cluster_info.cpickle')
with open(cluster_info_file, 'rb') as file:
    concepts_info = pickle.load(file)

In [None]:
# Generate the graph matrices (AM) for the loaded concepts, then pickle them
# so the plotting cells can reload the result.
save_path = './Logs/DiceGraphs/{}'.format(seq)
AM = G.generate_graph(
    concepts_info,
    dataset_path=data_root_path,
    loader=dataloader(),
    save_path=save_path,
)

save_root = './Logs/DiceGraphs/{}/'.format(seq)
graph_file = os.path.join(save_root, 'Graph_AMs.cpickle')
with open(graph_file, 'wb') as file:
    pickle.dump(AM, file)

In [None]:
import matplotlib.pyplot as plt
from pprint import pprint
%matplotlib inline

# Reload the pickled graph matrices.
save_root = './Logs/DiceGraphs/{}/'.format(seq)
graph_file = os.path.join(save_root, 'Graph_AMs.cpickle')
with open(graph_file, 'rb') as file:
    AM = pickle.load(file)

# One heat map per class group, on a fixed [0, 1] colour scale. Each figure
# is written to save_root as '<class>.png' and then shown.
for class_ in infoclasses:
    print ("========{}=======".format(class_))
    plt.clf()
    plt.imshow(AM[class_], cmap='jet', vmin=0, vmax=1)
    plt.colorbar()
    heatmap_file = os.path.join(save_root, class_ + '.png')
    plt.savefig(heatmap_file, dpi=2000, bbox_inches="tight")
    plt.show()


# Per-node significance over the same dataset, pretty-printed.
significance = G.node_significance(
    concepts_info,
    dataset_path=data_root_path,
    loader=dataloader(),
    save_path=save_root,
)
pprint(significance)

In [None]:
# Modify concepts info: prune the graph using the 'whole' class matrix.
#
# A node is kept when its own diagonal entry exceeds T and it takes part in at
# least one lower-triangle pair (nx > ny) whose entry also exceeds T.
T = 0.1
class_ = 'whole'
M = np.array(AM[class_])

selected = set()
for nx in range(M.shape[0]):
    for ny in range(nx):
        # Written as "not >" so that entries equal to T (and NaNs) are
        # skipped, exactly like the strict "> T" test they replace.
        if not M[nx, ny] > T:
            continue
        if M[nx, nx] > T:
            selected.add(nx)
        if M[ny, ny] > T:
            selected.add(ny)

# Sorted, duplicate-free node indices. The explicit integer dtype matters when
# nothing is selected: np.unique([]) yields a float64 array, which NumPy
# rejects as an index (IndexError), so the cell used to crash in that case.
new_nodes = np.array(sorted(selected), dtype=int)

# np.ix_ picks the same rows and columns as A[:, new_nodes][new_nodes, :].
print (M[np.ix_(new_nodes, new_nodes)].shape)

if new_nodes.size == 0:
    # Nothing survived the threshold, so there is no sub-matrix worth plotting.
    print ("No node exceeds T = {}; skipping the modified plots.".format(T))
else:
    for class_ in infoclasses.keys():
        print ("========{}=======".format(class_))
        pruned = np.array(AM[class_])[np.ix_(new_nodes, new_nodes)]
        plt.clf()
        plt.imshow(pruned, cmap ='jet', vmin = 0, vmax = 1)
        plt.colorbar()
        plt.savefig(os.path.join(save_root, 'modified_' +class_+'.png'), dpi=2000, bbox_inches="tight")
        plt.show()