In [951]:
import argparse
import torch
from torch.autograd import Variable
from torch import nn
from seq2seq.util.checkpoint import Checkpoint
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import os
from scipy.stats import norm
import matplotlib.mlab as mlab
import scipy
from mpl_toolkits.axes_grid1 import AxesGrid

In [1012]:
class Model(object):
    """Wrap a torch model and compute per-parameter statistics.

    ``self.params`` maps each name from ``model.named_parameters()`` to a 2-D
    ``pandas.DataFrame`` built from that parameter's values; every statistic
    below is computed from these DataFrames.
    """

    def __init__(self, model):
        """Store ``model`` and convert its named parameters into ``self.params``.

        Args:
            model: object exposing ``named_parameters()`` (e.g. ``torch.nn.Module``).
        """
        self.model = model
        params = {}
        for name, param in self.model.named_parameters():
            # detach + cpu so parameters that require grad or live on a GPU
            # can still be converted to numpy
            values = param.detach().cpu().numpy()
            # DataFrame needs 2-D input: scalars become 1x1, higher-rank tensors
            # keep their first axis and flatten the rest (1-D/2-D unchanged)
            if values.ndim == 0:
                values = values.reshape(1, 1)
            elif values.ndim > 2:
                values = values.reshape(values.shape[0], -1)
            params[name] = pd.DataFrame(values)
        self.params = params

    def get_param_names(self):
        """Return parameter names in ``named_parameters()`` order."""
        return [name for name, _ in self.model.named_parameters()]

    def get_modules(self):
        """Return the list produced by the wrapped model's ``modules()``."""
        return list(self.model.modules())

    def get_param_by_name(self, name):
        """Return parameter ``name`` as a DataFrame."""
        return pd.DataFrame(self.params[name])

    def heatmap(self):
        """Draw each parameter as a heatmap on its own new figure.

        Returns:
            ``{param name: matplotlib Axes}``.
        """
        axes = {}
        for name, values in self.params.items():
            # without an explicit ``ax`` seaborn draws onto the current axes,
            # stacking every parameter onto one plot
            _, ax = plt.subplots()
            axes[name] = sns.heatmap(values, ax=ax)
        return axes

    def _reduce(self, func, skip_single):
        """Apply ``func`` to every flattened parameter.

        Args:
            func: reduction taking a 1-D array (e.g. ``np.mean``).
            skip_single: when true, 1x1 parameters map to ``np.nan``.

        Returns:
            ``{param name: reduced value}``.
        """
        # np.nan (not np.NaN, which NumPy 2.0 removed)
        return {
            name: np.nan if skip_single and values.shape == (1, 1) else func(np.ravel(values))
            for name, values in self.params.items()
        }

    def apply_mean(self):
        """Mean of each parameter's values (``nan`` for 1x1 parameters)."""
        return self._reduce(np.mean, skip_single=True)

    def apply_std(self):
        """Population std of each parameter's values (``nan`` for 1x1 parameters)."""
        return self._reduce(np.std, skip_single=True)

    def apply_min(self):
        """Minimum of each parameter's values."""
        return self._reduce(np.min, skip_single=False)

    def apply_max(self):
        """Maximum of each parameter's values."""
        return self._reduce(np.max, skip_single=False)

    def apply_norm(self):
        """L2 norm of each flattened parameter (``nan`` for 1x1 parameters)."""
        return self._reduce(np.linalg.norm, skip_single=True)

    def param_to_dist(self, name, bins=50, value_range=(-1, 1)):
        """Density histogram of parameter ``name``'s values.

        Args:
            name: key of ``self.params``.
            bins: number of equal-width bins.
            value_range: ``(low, high)`` histogram range; values outside it
                are ignored by ``np.histogram``.

        Returns:
            1-D array of ``bins`` density values.
        """
        data = np.ravel(self.params[name])
        hist, _ = np.histogram(data, bins=bins, range=value_range, density=True)
        return hist
        

            
class Models(object):
    """Load every checkpoint in ``model_path`` and aggregate parameter statistics.

    Attributes:
        models: ``{entry name: Model}`` for every non-dot entry of ``model_path``.
        means, stds, norms, mins, maxs: per-model statistics,
            ``{model name: {param name: value}}``.
        mean_of_means, mean_of_stds, mean_of_norms, mean_of_mins, mean_of_maxs:
            those statistics averaged over models, ``{param name: value}``.
        df: one row per parameter, one column per averaged statistic.
    """

    def __init__(self, model_path, title):
        """
        Args:
            model_path: directory whose entries are loaded with ``Checkpoint.load``.
            title: prefix used in heatmap titles and image file names.
        """
        self.image_folder = 'images/'
        self.title = title
        self.model_path = model_path

        self.models = self.load_models()

        # per-model statistics: {model name: {param name: value}}
        self.means = {name: model.apply_mean() for name, model in self.models.items()}
        self.stds = {name: model.apply_std() for name, model in self.models.items()}
        self.norms = {name: model.apply_norm() for name, model in self.models.items()}
        self.mins = {name: model.apply_min() for name, model in self.models.items()}
        self.maxs = {name: model.apply_max() for name, model in self.models.items()}

        # calculate averages over models: {param name: value}; stored under
        # separate attributes so per-model mins/maxs stay available like the others
        self.mean_of_means = self.mean_of(self.means)
        self.mean_of_stds = self.mean_of(self.stds)
        self.mean_of_norms = self.mean_of(self.norms)
        self.mean_of_maxs = self.mean_of(self.maxs)
        self.mean_of_mins = self.mean_of(self.mins)

        # building from dicts aligns every column by parameter name instead of
        # relying on all dicts iterating in the same order
        self.df = pd.DataFrame(
            {
                'mean of means': self.mean_of_means,
                'mean of stds': self.mean_of_stds,
                'mean of norms': self.mean_of_norms,
                'mean of maxs': self.mean_of_maxs,
                'mean of mins': self.mean_of_mins,
            },
            index=list(self.mean_of_means),
        )

    def mean_of(self, data):
        """Average a ``{model name: {param name: value}}`` mapping over models.

        Parameter names are taken from the first model. Returns ``{}`` when
        ``data`` is empty.
        """
        if not data:
            return {}
        first = next(iter(data.values()))
        return {
            param: np.mean([per_model[param] for per_model in data.values()])
            for param in first
        }

    def load_models(self):
        """Load each non-dot entry of ``self.model_path`` as a ``Model``.

        Returns:
            ``{entry name: Model}``, loaded in sorted name order.
        """
        models = {}
        # sorted(): os.listdir order is arbitrary, so sort for a reproducible
        # load order and log
        for file in sorted(os.listdir(self.model_path)):
            if file.startswith('.'):
                continue  # skip dot (hidden) entries
            path = self.model_path + '/' + file
            print('loading: ', path)
            checkpoint = Checkpoint.load(path)
            models[file] = Model(checkpoint.model)
        return models

    def _plot_heatmap(self, model_name, model, param_name):
        """Draw, save (under ``self.image_folder``) and show one parameter heatmap."""
        fig, ax = plt.subplots()
        sns.heatmap(model.params[param_name], ax=ax)
        ax.set_title(self.title + ' (' + model_name + ') - \n heatmap of param: ' + param_name)
        os.makedirs(self.image_folder, exist_ok=True)
        file_name = self.title + '_' + model_name + '_' + param_name + '.png'
        fig.savefig(os.path.join(self.image_folder, file_name), dpi=300)
        plt.show()
        # release the figure; this loop can create one per model per parameter
        plt.close(fig)

    def apply_heatmap(self):
        """Plot, save and show a heatmap of every parameter of every model."""
        for model_name, model in self.models.items():
            for param_name in model.params.keys():
                self._plot_heatmap(model_name, model, param_name)

    def apply_heatmap_by_name(self, param_name):
        """Plot, save and show a heatmap of ``param_name`` for every model."""
        for model_name, model in self.models.items():
            self._plot_heatmap(model_name, model, param_name)

In [1013]:
# Root of the model zoo; checkpoints live under <ZOO_DIR>/<guided|baseline>/<gru|lstm>/.
ZOO_DIR = '../machine-zoo'

guided_gru = Models(ZOO_DIR + '/guided/gru', 'Guided_GRU')
baseline_gru = Models(ZOO_DIR + '/baseline/gru', 'Baseline_GRU')

guided_lstm = Models(ZOO_DIR + '/guided/lstm', 'Guided_LSTM')
baseline_lstm = Models(ZOO_DIR + '/baseline/lstm', 'Baseline_LSTM')

loading:  ../machine-zoo/guided/gru/1
loading:  ../machine-zoo/guided/gru/2
loading:  ../machine-zoo/guided/gru/3
loading:  ../machine-zoo/guided/gru/4
loading:  ../machine-zoo/guided/gru/5
loading:  ../machine-zoo/baseline/gru/1
loading:  ../machine-zoo/baseline/gru/2
loading:  ../machine-zoo/baseline/gru/3
loading:  ../machine-zoo/baseline/gru/4
loading:  ../machine-zoo/baseline/gru/5
loading:  ../machine-zoo/guided/lstm/1
loading:  ../machine-zoo/guided/lstm/2
loading:  ../machine-zoo/guided/lstm/3
loading:  ../machine-zoo/guided/lstm/4
loading:  ../machine-zoo/guided/lstm/5
loading:  ../machine-zoo/baseline/lstm/1
loading:  ../machine-zoo/baseline/lstm/2
loading:  ../machine-zoo/baseline/lstm/3
loading:  ../machine-zoo/baseline/lstm/4
loading:  ../machine-zoo/baseline/lstm/5


### Sizes of layers

In [1014]:
# Print the shape and name of every parameter of guided LSTM run '1'.
lstm_run_1 = guided_lstm.models['1']
for param_name, frame in lstm_run_1.params.items():
    print(frame.shape, param_name)

(19, 16) encoder.embedding.weight
(2048, 16) encoder.rnn.weight_ih_l0
(2048, 512) encoder.rnn.weight_hh_l0
(2048, 1) encoder.rnn.bias_ih_l0
(2048, 1) encoder.rnn.bias_hh_l0
(2048, 512) decoder.rnn.weight_ih_l0
(2048, 512) decoder.rnn.weight_hh_l0
(2048, 1) decoder.rnn.bias_ih_l0
(2048, 1) decoder.rnn.bias_hh_l0
(11, 512) decoder.embedding.weight
(512, 1024) decoder.attention.method.mlp.weight
(512, 1) decoder.attention.method.mlp.bias
(1, 512) decoder.attention.method.out.weight
(1, 1) decoder.attention.method.out.bias
(11, 512) decoder.out.weight
(11, 1) decoder.out.bias
(512, 1024) decoder.ffocus_merge.weight
(512, 1) decoder.ffocus_merge.bias


## Statistical Analysis

Guided GRU

In [1015]:
# Per-parameter statistics averaged over all loaded guided GRU checkpoints.
guided_gru.df

Unnamed: 0,mean of means,mean of stds,mean of norms,mean of maxs,mean of mins
encoder.embedding.weight,-0.000959,0.139349,2.429867,0.376077,-0.403524
encoder.rnn.weight_ih_l0,-0.000206,0.080538,12.626729,0.542031,-0.511061
encoder.rnn.weight_hh_l0,5.8e-05,0.065287,57.897675,0.41264,-0.394051
encoder.rnn.bias_ih_l0,-0.012083,0.072939,2.898563,0.176322,-0.31206
encoder.rnn.bias_hh_l0,-0.013354,0.070808,2.82467,0.194362,-0.297983
decoder.rnn.weight_ih_l0,4.1e-05,0.065539,58.121387,0.742477,-0.713911
decoder.rnn.weight_hh_l0,-5.6e-05,0.058843,52.185558,0.291974,-0.281988
decoder.rnn.bias_ih_l0,-0.025098,0.059747,2.540183,0.144791,-0.195006
decoder.rnn.bias_hh_l0,-0.025394,0.059537,2.536941,0.131834,-0.197489
decoder.embedding.weight,0.000316,0.147676,11.082689,0.395972,-0.39151


Baseline GRU

In [1016]:
# Per-parameter statistics averaged over all loaded baseline GRU checkpoints.
baseline_gru.df

Unnamed: 0,mean of means,mean of stds,mean of norms,mean of maxs,mean of mins
encoder.embedding.weight,-0.001759272,0.146639,2.557299,0.466294,-0.440219
encoder.rnn.weight_ih_l0,0.001795653,0.132039,20.707676,0.657196,-0.636781
encoder.rnn.weight_hh_l0,0.0002041644,0.08803,78.066849,0.566109,-0.596124
encoder.rnn.bias_ih_l0,-0.01620729,0.092863,3.695691,0.276811,-0.357539
encoder.rnn.bias_hh_l0,-0.01591059,0.092919,3.696127,0.273887,-0.335439
decoder.rnn.weight_ih_l0,-0.0001277724,0.068788,61.002495,0.457143,-0.430959
decoder.rnn.weight_hh_l0,-0.0001709472,0.075933,67.339149,0.425063,-0.43239
decoder.rnn.bias_ih_l0,-0.02985862,0.075753,3.191333,0.192314,-0.249801
decoder.rnn.bias_hh_l0,-0.02878932,0.075507,3.167907,0.207104,-0.249486
decoder.embedding.weight,-6.870941e-05,0.084614,6.350204,0.323575,-0.334121


Guided LSTM

In [1017]:
# Per-parameter statistics averaged over all loaded guided LSTM checkpoints.
guided_lstm.df

Unnamed: 0,mean of means,mean of stds,mean of norms,mean of maxs,mean of mins
encoder.embedding.weight,0.002058,0.179018,3.123551,0.543385,-0.524266
encoder.rnn.weight_ih_l0,0.000245,0.113211,20.516216,0.710162,-0.710747
encoder.rnn.weight_hh_l0,2.6e-05,0.071164,72.872803,0.551239,-0.668254
encoder.rnn.bias_ih_l0,-0.010346,0.082105,3.746127,0.408405,-0.251434
encoder.rnn.bias_hh_l0,-0.010515,0.081048,3.700511,0.403438,-0.238568
decoder.rnn.weight_ih_l0,9.5e-05,0.080769,82.708237,0.941351,-0.865073
decoder.rnn.weight_hh_l0,-0.00013,0.073191,74.948914,0.402468,-0.391413
decoder.rnn.bias_ih_l0,-0.024615,0.059836,2.930603,0.15243,-0.214242
decoder.rnn.bias_hh_l0,-0.024132,0.060696,2.957357,0.167402,-0.21623
decoder.embedding.weight,0.000231,0.17974,13.489433,0.491869,-0.492699


Baseline LSTM

In [1018]:
# Per-parameter statistics averaged over all loaded baseline LSTM checkpoints.
baseline_lstm.df

Unnamed: 0,mean of means,mean of stds,mean of norms,mean of maxs,mean of mins
encoder.embedding.weight,-0.001442,0.162086,2.826329,0.492238,-0.502056
encoder.rnn.weight_ih_l0,0.000913,0.146842,26.58865,0.684907,-0.707094
encoder.rnn.weight_hh_l0,-0.000176,0.089515,91.664444,0.570511,-0.610069
encoder.rnn.bias_ih_l0,-0.037526,0.090355,4.483636,0.339953,-0.328324
encoder.rnn.bias_hh_l0,-0.038465,0.090417,4.501397,0.307354,-0.318395
decoder.rnn.weight_ih_l0,-0.000139,0.080955,82.899498,0.517021,-0.515518
decoder.rnn.weight_hh_l0,-0.000294,0.084335,86.370056,0.489893,-0.495761
decoder.rnn.bias_ih_l0,-0.035166,0.06542,3.36968,0.178394,-0.243665
decoder.rnn.bias_hh_l0,-0.033474,0.065755,3.347765,0.171613,-0.239341
decoder.embedding.weight,-0.000122,0.132717,9.960362,0.474859,-0.472432


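## Analysis of Distributions

For each parameter, every model's values are binned into a 50-bin density histogram over $[-1, 1]$, and each pair of models is compared by the histogram intersection $\sum_i \min(h^A_i, h^B_i) / \sum_i h^B_i$ (1 means identical histograms).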

GRU Guided vs. Baseline

In [1004]:
class Analysis(object):
    """Compare per-parameter value distributions between two model collections.

    ``self.dist`` is a DataFrame with one row per (A model, B model) pair,
    indexed ``'<name A>_<name B>'``, and one column per parameter holding the
    histogram intersection of that parameter's values.
    """

    def __init__(self, models_A, models_B):
        """
        Args:
            models_A, models_B: objects with a ``models`` dict ``{name: Model}``
                (e.g. ``Models`` instances).
        """
        self.models_A = models_A.models
        self.models_B = models_B.models
        self.dist = self.apply_dist()

    def return_intersection(self, hist_1, hist_2):
        """Histogram intersection ``sum(min(hist_1, hist_2)) / sum(hist_2)``.

        1.0 means identical histograms, 0.0 means no overlapping mass. Returns
        ``nan`` when ``hist_2`` sums to zero.
        """
        minima = np.minimum(hist_1, hist_2)
        total = np.sum(hist_2)
        if total == 0:
            return np.nan  # nothing to normalise by (avoids a 0/0 warning)
        return np.true_divide(np.sum(minima), total)

    def KL(self, dist_1, dist_2):
        """Kullback-Leibler divergence ``D(dist_1 || dist_2)``.

        Each argument may be an object with a ``.pdf`` method (e.g. a frozen
        ``scipy.stats`` distribution), evaluated on 100 points in [-1, 1], or an
        array-like of histogram values (such as ``Model.param_to_dist`` output),
        used as is. ``scipy.stats.entropy`` normalises both inputs; the result
        is ``inf`` where ``dist_2`` is zero but ``dist_1`` is not.
        """
        x = np.linspace(-1, 1, 100)
        p = dist_1.pdf(x) if hasattr(dist_1, 'pdf') else np.asarray(dist_1, dtype=float)
        q = dist_2.pdf(x) if hasattr(dist_2, 'pdf') else np.asarray(dist_2, dtype=float)
        return scipy.stats.entropy(p, q)

    def apply_dist(self):
        """Histogram intersection for every (A, B) model pair and every parameter.

        1x1 parameters are skipped. Returns a DataFrame indexed by
        ``'<name A>_<name B>'`` with one column per compared parameter.
        """
        first_A = next(iter(self.models_A.values()), None)
        # parameter names come from the first A model; all models are assumed to share them
        param_names = list(first_A.params.keys()) if first_A is not None else []

        cache = {}

        def hist(model, param):
            # compute each model's histogram once and reuse it across all pairings
            key = (id(model), param)
            if key not in cache:
                cache[key] = model.param_to_dist(param)
            return cache[key]

        dist = {}
        for name_A, model_A in self.models_A.items():
            for name_B, model_B in self.models_B.items():
                dist[name_A + '_' + name_B] = {
                    param: self.return_intersection(hist(model_A, param), hist(model_B, param))
                    for param in param_names
                    if model_A.params[param].shape != (1, 1)
                }
        return pd.DataFrame.from_dict(dist, orient='index')

    def apply_dist_by_name(self, param):
        """Print and return the intersection of ``param`` between same-named models.

        Each A model is paired with the B model of the same name; A names that
        are absent from B are skipped.

        Returns:
            ``{model name: intersection}``.
        """
        results = {}
        for model_name, model_A in self.models_A.items():
            assert model_A.params[param].shape != (1, 1), 'single-value parameter: ' + param
            if model_name not in self.models_B:
                continue
            value = self.return_intersection(
                model_A.param_to_dist(param),
                self.models_B[model_name].param_to_dist(param),
            )
            print(value)
            results[model_name] = value
        return results

In [1005]:
# A = guided, B = baseline, matching the section title and the LSTM comparison below.
analysis_GRU = Analysis(guided_gru, baseline_gru)

In [1007]:
# Per-parameter summary (count, mean, std, quartiles) of the pairwise GRU histogram intersections.
analysis_GRU.dist.describe()

Unnamed: 0,encoder.embedding.weight,encoder.rnn.weight_ih_l0,encoder.rnn.weight_hh_l0,encoder.rnn.bias_ih_l0,encoder.rnn.bias_hh_l0,decoder.rnn.weight_ih_l0,decoder.rnn.weight_hh_l0,decoder.rnn.bias_ih_l0,decoder.rnn.bias_hh_l0,decoder.embedding.weight,decoder.attention.method.mlp.weight,decoder.attention.method.mlp.bias,decoder.attention.method.out.weight,decoder.out.weight,decoder.out.bias,decoder.ffocus_merge.weight,decoder.ffocus_merge.bias
count,25.0,25.0,25.0,25.0,25.0,25.0,25.0,25.0,25.0,25.0,25.0,25.0,25.0,25.0,25.0,25.0,25.0
mean,0.868816,0.744264,0.864755,0.875599,0.872526,0.933552,0.892613,0.893203,0.893776,0.566214,0.941055,0.874766,0.565773,0.918111,0.592727,0.956731,0.786563
std,0.01662,0.019523,0.019201,0.031462,0.034271,0.012509,0.015916,0.016076,0.022599,0.032272,0.017775,0.024357,0.116134,0.015593,0.141616,0.011778,0.034763
min,0.845395,0.703044,0.826725,0.827474,0.830078,0.912841,0.861463,0.867839,0.848307,0.517578,0.91016,0.824219,0.365234,0.891335,0.363636,0.936085,0.730469
25%,0.858553,0.732178,0.854861,0.847005,0.847656,0.9234,0.880046,0.882812,0.878255,0.540661,0.933981,0.859375,0.474609,0.908913,0.454545,0.947453,0.765625
50%,0.865132,0.744303,0.865724,0.88151,0.863281,0.931273,0.892869,0.891927,0.893229,0.5625,0.938148,0.876953,0.582031,0.919212,0.545455,0.955082,0.777344
75%,0.881579,0.757406,0.875916,0.890625,0.882161,0.945811,0.905159,0.902995,0.910156,0.58647,0.948311,0.888672,0.651218,0.93022,0.727273,0.967932,0.818359
max,0.904605,0.776123,0.903155,0.927734,0.940104,0.953189,0.923272,0.929036,0.934245,0.64027,0.974909,0.919922,0.757813,0.945845,0.818182,0.973284,0.861328


LSTM Guided vs. Baseline

In [1008]:
# A = guided, B = baseline; return_intersection normalises by B's histogram.
analysis_LSTM = Analysis(guided_lstm, baseline_lstm)

In [1009]:
# Per-parameter summary (count, mean, std, quartiles) of the pairwise LSTM histogram intersections.
analysis_LSTM.dist.describe()

Unnamed: 0,encoder.embedding.weight,encoder.rnn.weight_ih_l0,encoder.rnn.weight_hh_l0,encoder.rnn.bias_ih_l0,encoder.rnn.bias_hh_l0,decoder.rnn.weight_ih_l0,decoder.rnn.weight_hh_l0,decoder.rnn.bias_ih_l0,decoder.rnn.bias_hh_l0,decoder.embedding.weight,decoder.attention.method.mlp.weight,decoder.attention.method.mlp.bias,decoder.attention.method.out.weight,decoder.out.weight,decoder.out.bias,decoder.ffocus_merge.weight,decoder.ffocus_merge.bias
count,25.0,25.0,25.0,25.0,25.0,25.0,25.0,25.0,25.0,25.0,25.0,25.0,25.0,25.0,25.0,25.0,25.0
mean,0.844079,0.844094,0.898595,0.874434,0.86832,0.951818,0.938035,0.910605,0.917969,0.661364,0.942557,0.845234,0.636315,0.887791,0.549091,0.94564,0.860391
std,0.028249,0.090856,0.06463,0.092074,0.096388,0.037297,0.05299,0.028361,0.029646,0.078769,0.053289,0.042907,0.132852,0.033851,0.124427,0.019945,0.036527
min,0.782895,0.65213,0.776758,0.683105,0.677246,0.869842,0.82425,0.864746,0.85498,0.526456,0.830433,0.755859,0.386719,0.832564,0.363636,0.900955,0.787109
25%,0.822368,0.768921,0.862494,0.85791,0.841797,0.935133,0.936102,0.889648,0.891602,0.608132,0.918846,0.822266,0.590539,0.863814,0.454545,0.935379,0.833984
50%,0.851974,0.858276,0.90624,0.910156,0.905762,0.969708,0.953884,0.913086,0.929199,0.65554,0.969986,0.851562,0.649133,0.88157,0.545455,0.947441,0.869141
75%,0.868421,0.913635,0.938403,0.935547,0.943359,0.97751,0.976146,0.927734,0.937012,0.70277,0.983482,0.880859,0.744301,0.911932,0.636364,0.963771,0.880859
max,0.878289,0.968445,0.995316,0.966309,0.96582,0.991642,0.995377,0.95752,0.961914,0.815341,0.991564,0.902344,0.822266,0.956143,0.818182,0.971615,0.923828


## Generate Heatmaps

In [845]:
# Render, save and show a heatmap for every parameter of every loaded model.
for zoo in (baseline_lstm, guided_lstm, baseline_gru, guided_gru):
    zoo.apply_heatmap()

