In [None]:
import numpy as np
from pathlib import Path
import pickle
import math
import random
from collections import OrderedDict

import sys
sys.path.append('../')

from mi_estimators.EDGE_4_3_1 import EDGE
from mi_estimators.npeet.entropy_estimators import midd
from mi_estimators.dropout_MI import gaussian_noise_mi

from tqdm import *

import matplotlib
matplotlib.use('TkAgg')

import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import seaborn as sns
sns.set_style('darkgrid')

In [None]:
# --- Notebook configuration ---------------------------------------------------
# Directory with the pickled information-plane (IP) values; it has to be
# populated beforehand when information dropout is used.
ip_data = Path("IP")
# Directory with the representations saved for computing MI (also used below
# as the cache location for the computed estimates).
repr_data = Path("representations")

# Network name; in this notebook it is only used to build plot titles.
netw = "LeNet"

# Dropout rate and the noise scale derived from it.
# NOTE(review): drp_noise is used below as the scale (standard deviation) of the
# multiplicative noise `randn * drp_noise + 1` and is also passed to
# gaussian_noise_mi -- confirm that p / (1 - p) is meant as a std, not a variance.
p = 0.01
drp_noise = p / (1 - p)

### Draw the information plane for information dropout

In [None]:
def drawIP(mi_xz, mi_zy, title):
    """Scatter-plot the information plane for information dropout, coloured by epoch.

    NOTE: a later cell of this notebook defines ``drawIP`` again with an extra
    ``crossentropy_zy`` flag; once that cell has run it shadows this definition.

    Parameters
    ----------
    mi_xz : dict
        Maps epoch -> saved I(X;Z) value(s); plotted on the x axis as stored.
    mi_zy : dict
        Maps epoch -> saved cross-entropy value(s); must contain every epoch
        present in ``mi_xz``. The y axis shows ``-crossentropy + log(10)``.
    title : str
        Title of the figure.

    The function draws on the current pyplot figure and calls ``plt.show()``;
    it returns nothing.
    """
    # The colour scale is fixed to [0, 100]; epochs beyond 100 all get the
    # colour of the upper end of the colormap.
    COLORBAR_MAX_EPOCHS = 100
    sm = plt.cm.ScalarMappable(cmap='gnuplot', norm=plt.Normalize(vmin=0, vmax=COLORBAR_MAX_EPOCHS))

    # Iterate over the epochs that are actually stored (ascending) instead of
    # assuming the keys are exactly 0..n-1.
    for epoch in sorted(mi_xz):
        c = sm.to_rgba(epoch)
        # We saved the optimized value (with information dropout). To turn it into
        # the actual quantity one would add -0.5*log(2*pi*e) - log(c), where c
        # defines the log-uniform distribution of the ReLU prior; that offset is
        # deliberately NOT applied here, the stored value is plotted as is.
        xmvals = mi_xz[epoch]
        # We saved the cross-entropy value; the lower bound on MI is
        # -crossentropy + H(Y). H(Y) is taken to be log(10)
        # (presumably 10 equiprobable classes -- TODO confirm).
        ymvals = -mi_zy[epoch] + np.log(10)
        plt.scatter(xmvals, ymvals, s=20, facecolors=[c], edgecolor='none', zorder=2)
    plt.xlabel('I(X;Z)')
    plt.ylabel('I(Y;Z)')
    plt.title(title)
    # `sm` is not attached to any Axes, so the Axes to take the colorbar space
    # from must be given explicitly; recent Matplotlib raises otherwise.
    plt.colorbar(sm, ax=plt.gca(), label='Epoch')
    plt.tight_layout()
    plt.show()

In [None]:
# Load the validation information-plane values stored under `ip_data`.
# Path.read_bytes() opens and closes the file itself, so no handle is leaked.
# NOTE: unpickling can execute arbitrary code -- only load files you produced yourself.
val_mi_xz = pickle.loads((ip_data / "val_mi_xz").read_bytes())
val_mi_zy = pickle.loads((ip_data / "val_mi_zy").read_bytes())

In [None]:
# Information plane for the validation values loaded from `ip_data`.
drawIP(
    val_mi_xz,
    val_mi_zy,
    f'Information dropout, {netw} (validation)',
)

In [None]:
# Load the training information-plane values stored under `ip_data`.
# Path.read_bytes() opens and closes the file itself, so no handle is leaked.
# NOTE: unpickling can execute arbitrary code -- only load files you produced yourself.
train_mi_xz = pickle.loads((ip_data / "train_mi_xz").read_bytes())
train_mi_zy = pickle.loads((ip_data / "train_mi_zy").read_bytes())

In [None]:
# Information plane for the training values loaded from `ip_data`.
drawIP(
    train_mi_xz,
    train_mi_zy,
    f'Information dropout, {netw} (training)',
)

### Draw the estimations

In [None]:
def drawIP(mi_xz, mi_zy, title, crossentropy_zy=True):
    """Scatter-plot an information plane, one point set per epoch, coloured by epoch.

    Parameters
    ----------
    mi_xz : dict
        Maps epoch -> I(X;Z) value(s); plotted on the x axis as stored.
    mi_zy : dict
        Maps epoch -> I(Y;Z)-related value(s); must contain every epoch present
        in ``mi_xz`` (only the epochs of ``mi_xz`` are plotted).
    title : str
        Title of the figure.
    crossentropy_zy : bool, default True
        If True, ``mi_zy`` holds cross-entropy values and the y axis shows the
        lower bound ``-crossentropy + log(10)``. If False, ``mi_zy`` is plotted
        unchanged.

    The function draws on the current pyplot figure and calls ``plt.show()``;
    it returns nothing.
    """
    # Scale the colours up to the largest epoch present. Using max() instead of
    # "the last key" keeps this correct for dicts that are not sorted by epoch,
    # and the default avoids an IndexError for an empty dict.
    COLORBAR_MAX_EPOCHS = max(mi_xz, default=0)
    sm = plt.cm.ScalarMappable(cmap='gnuplot', norm=plt.Normalize(vmin=0, vmax=COLORBAR_MAX_EPOCHS))

    for epoch in sorted(mi_xz):
        c = sm.to_rgba(epoch)
        xmvals = mi_xz[epoch]
        if crossentropy_zy:
            # We saved the cross-entropy value; the lower bound on MI is
            # -crossentropy + H(Y). H(Y) is taken to be log(10)
            # (presumably 10 equiprobable classes -- TODO confirm).
            ymvals = -mi_zy[epoch] + np.log(10)
        else:
            ymvals = mi_zy[epoch]
        plt.scatter(xmvals, ymvals, s=20, facecolors=[c], edgecolor='none', zorder=2)
    ax = plt.gca()
    # Show absolute tick values on the y axis instead of an offset notation.
    ax.get_yaxis().get_major_formatter().set_useOffset(False)
    plt.xlabel('I(X;Z)')
    plt.ylabel('I(Y;Z)')
    plt.title(title)
    # `sm` is not attached to any Axes, so the Axes to take the colorbar space
    # from must be given explicitly; recent Matplotlib raises otherwise.
    plt.colorbar(sm, ax=ax, label='Epoch')
    plt.tight_layout()
    plt.show()

In [None]:
# Gaussian-dropout MI estimates per epoch: I(X;Z) via gaussian_noise_mi and
# I(Y;Z) via EDGE. Results are cached as pickles under `repr_data`; the presence
# of "test_comp_mi_xz" alone decides whether all four caches are loaded.
if (repr_data / "test_comp_mi_xz").exists():
    # Path.read_bytes() opens and closes the file itself, so no handle is leaked.
    # NOTE: unpickling can execute arbitrary code -- only load files you produced yourself.
    test_comp_mi_xz = pickle.loads((repr_data / "test_comp_mi_xz").read_bytes())
    test_comp_mi_zy = pickle.loads((repr_data / "test_comp_mi_zy").read_bytes())
    train_comp_mi_xz = pickle.loads((repr_data / "train_comp_mi_xz").read_bytes())
    train_comp_mi_zy = pickle.loads((repr_data / "train_comp_mi_zy").read_bytes())
else:
    test_comp_mi_xz = {}
    test_comp_mi_zy = {}
    train_comp_mi_xz = {}
    train_comp_mi_zy = {}

    test_labels = np.load(repr_data / "test_labels.npy")
    train_labels = np.load(repr_data / "train_labels.npy")

    # Number of independently noised copies generated per (noise-free) representation.
    train_repeat = 1
    test_repeat = 6

    # NOTE: the noise below comes from the unseeded global NumPy RNG, so the
    # estimates differ from run to run.
    for f in repr_data.iterdir():
        # Match and parse on the file name only, so that dots or underscores in
        # the directory part of the path cannot confuse the epoch parsing.
        fname = f.name
        f = str(f)
        if "test_representations" in fname:
            print(f)
            # The epoch is the last "_"-separated token before the first dot of the file name.
            epoch = int(fname.split(".")[0].split("_")[-1])
            nonoise_reprs = np.load(f, allow_pickle=True)
            # Keep each distinct representation once; `ind` selects the matching labels.
            nonoise_reprs, ind = np.unique(nonoise_reprs, axis=0, return_index=True)
            # test_repeat noisy copies of every row: row * (N(0, 1) * drp_noise + 1).
            reprs = np.array([
                nr * (np.random.randn(nonoise_reprs.shape[1]) * drp_noise + 1)
                for nr in nonoise_reprs
                for _ in range(test_repeat)
            ])
            test_comp_mi_xz[epoch] = gaussian_noise_mi(reprs, nonoise_reprs, drp_noise)
            test_comp_mi_zy[epoch] = EDGE(reprs, np.repeat(np.array(test_labels[ind]), test_repeat))
            print(test_comp_mi_xz[epoch], test_comp_mi_zy[epoch])

        if "train_representations" in fname:
            print(f)
            epoch = int(fname.split(".")[0].split("_")[-1])
            nonoise_reprs = np.load(f, allow_pickle=True)
            # Unlike the test branch, duplicate rows are not removed here.
            reprs = np.array([
                nr * (np.random.randn(nonoise_reprs.shape[1]) * drp_noise + 1)
                for nr in nonoise_reprs
                for _ in range(train_repeat)
            ])
            train_comp_mi_xz[epoch] = gaussian_noise_mi(reprs, nonoise_reprs, drp_noise)
            train_comp_mi_zy[epoch] = EDGE(reprs, np.repeat(train_labels, train_repeat))
            print(train_comp_mi_xz[epoch], train_comp_mi_zy[epoch])

In [None]:
# Cache the Gaussian-dropout estimates, unless a cache already exists (only the
# presence of "test_comp_mi_xz" is checked, mirroring the loading cell).
# Pickling to bytes first and then writing with Path.write_bytes() closes every
# file properly and does not leave a truncated file behind if pickling fails.
if not (repr_data / "test_comp_mi_xz").exists():
    (repr_data / "test_comp_mi_xz").write_bytes(pickle.dumps(test_comp_mi_xz))
    (repr_data / "test_comp_mi_zy").write_bytes(pickle.dumps(test_comp_mi_zy))
    (repr_data / "train_comp_mi_xz").write_bytes(pickle.dumps(train_comp_mi_xz))
    (repr_data / "train_comp_mi_zy").write_bytes(pickle.dumps(train_comp_mi_zy))

In [None]:
# Order the test-split estimates by epoch, then plot them; the I(Y;Z) values are
# plotted unchanged (crossentropy_zy=False).
od_test_comp_mi_xz = OrderedDict(
    (epoch, test_comp_mi_xz[epoch]) for epoch in sorted(test_comp_mi_xz)
)
od_test_comp_mi_zy = OrderedDict(
    (epoch, test_comp_mi_zy[epoch]) for epoch in sorted(test_comp_mi_zy)
)
drawIP(
    od_test_comp_mi_xz,
    od_test_comp_mi_zy,
    f'Gaussian dropout, {netw} (validation)',
    crossentropy_zy=False,
)

In [None]:
# Same I(X;Z) estimates, but the y axis is the cross-entropy bound computed from
# val_mi_zy (loaded from `ip_data`); val_mi_zy must contain every epoch of
# od_test_comp_mi_xz.
drawIP(
    od_test_comp_mi_xz,
    val_mi_zy,
    f'Gaussian dropout, {netw} (validation)',
    crossentropy_zy=True,
)

In [None]:
# Order the training-split estimates by epoch, then plot them; the I(Y;Z) values
# are plotted unchanged (crossentropy_zy=False).
od_train_comp_mi_xz = OrderedDict(
    (epoch, train_comp_mi_xz[epoch]) for epoch in sorted(train_comp_mi_xz)
)
od_train_comp_mi_zy = OrderedDict(
    (epoch, train_comp_mi_zy[epoch]) for epoch in sorted(train_comp_mi_zy)
)
drawIP(
    od_train_comp_mi_xz,
    od_train_comp_mi_zy,
    f'Gaussian dropout, {netw} (training)',
    crossentropy_zy=False,
)

In [None]:
# Same I(X;Z) estimates, but the y axis is the cross-entropy bound computed from
# train_mi_zy (loaded from `ip_data`); train_mi_zy must contain every epoch of
# od_train_comp_mi_xz.
drawIP(
    od_train_comp_mi_xz,
    train_mi_zy,
    f'Gaussian dropout, {netw} (training)',
    crossentropy_zy=True,
)

### Binning IP

In [None]:
def create_bins(min_bound, max_bound, num_of_bins=None, bin_size=None):
    """Build a float32 array of bin edges between ``min_bound`` and ``max_bound``.

    Parameters
    ----------
    min_bound, max_bound : scalar
        Range to cover.
    num_of_bins : int, optional
        Number of evenly spaced edges (``np.linspace``, ``max_bound`` included).
        Only used when ``bin_size`` is None.
    bin_size : scalar, optional
        Spacing between edges (``np.arange``, ``max_bound`` excluded). Takes
        precedence over ``num_of_bins`` when both are given.

    Returns
    -------
    numpy.ndarray or None
        The bin edges, or None (after printing a message) when neither
        ``bin_size`` nor ``num_of_bins`` is set.
    """
    # Guard clause: nothing to build the bins from.
    if bin_size is None and num_of_bins is None:
        print("Computation error; set either bin size or number of bins to a value")
        return None

    # A fixed step wins over a fixed count.
    if bin_size is not None:
        return np.arange(min_bound, max_bound, bin_size, dtype='float32')
    return np.linspace(min_bound, max_bound, num_of_bins, dtype='float32')

In [None]:
def _snap_to_bin_edges(data, num_of_bins=None, bin_size=None):
    """Replace every entry of ``data`` by the lower edge of the bin it falls into.

    The result is reshaped to ``(len(data), -1)``, i.e. one row per sample.
    """
    bins = create_bins(data.min(), data.max(), num_of_bins=num_of_bins, bin_size=bin_size)
    if bins is None:
        raise ValueError("set either bin_size or num_of_bins")
    idx = np.digitize(data.ravel(), bins) - 1
    # create_bins returns float32 edges. For float64 data the rounded first edge
    # can end up slightly ABOVE the true minimum, in which case digitize yields 0
    # and the index would become -1, silently wrapping to the LAST bin. Clamp
    # those entries into the first bin instead.
    idx[idx < 0] = 0
    return bins[idx].reshape(len(data), -1)


def double_bin_calc_information(inputdata, layerdata, num_of_bins=None, bin_size=None):
    """Estimate the MI between two arrays after discretising both with their own bins.

    Each array is binned independently over its own [min, max] range (see
    ``create_bins``), every value is replaced by its bin edge, and the two
    discretised arrays are handed to ``midd`` with natural-log base.

    Parameters
    ----------
    inputdata, layerdata : numpy.ndarray
        Arrays with one sample per first-axis entry; both need the same number
        of samples (presumably required by ``midd`` -- verify).
    num_of_bins : int, optional
        Number of bin edges; used when ``bin_size`` is None.
    bin_size : scalar, optional
        Bin width; takes precedence over ``num_of_bins``.

    Returns
    -------
    Whatever ``midd`` returns for the two discretised arrays
    (presumably the discrete MI in nats -- TODO confirm against npeet).

    Raises
    ------
    ValueError
        If neither ``num_of_bins`` nor ``bin_size`` is given.
    """
    digitized_inp = _snap_to_bin_edges(inputdata, num_of_bins=num_of_bins, bin_size=bin_size)
    digitized_rep = _snap_to_bin_edges(layerdata, num_of_bins=num_of_bins, bin_size=bin_size)

    return midd(digitized_inp, digitized_rep, base=np.exp(1))

In [None]:
# Binning-based MI estimates per epoch (3 bin edges, see double_bin_calc_information).
# Results are cached as pickles under `repr_data`; the presence of
# "test_bin_mi_xz" alone decides whether all four caches are loaded.
if (repr_data / "test_bin_mi_xz").exists():
    # Path.read_bytes() opens and closes the file itself, so no handle is leaked.
    # NOTE: unpickling can execute arbitrary code -- only load files you produced yourself.
    test_bin_mi_xz = pickle.loads((repr_data / "test_bin_mi_xz").read_bytes())
    test_bin_mi_zy = pickle.loads((repr_data / "test_bin_mi_zy").read_bytes())
    train_bin_mi_xz = pickle.loads((repr_data / "train_bin_mi_xz").read_bytes())
    train_bin_mi_zy = pickle.loads((repr_data / "train_bin_mi_zy").read_bytes())
else:
    test_bin_mi_xz = {}
    test_bin_mi_zy = {}
    train_bin_mi_xz = {}
    train_bin_mi_zy = {}

    # Inputs are flattened to one row per sample.
    test_inputs = np.load(repr_data / "test_inputs.npy")
    test_inputs = test_inputs.reshape(test_inputs.shape[0], -1)
    test_labels = np.load(repr_data / "test_labels.npy")
    train_inputs = np.load(repr_data / "train_inputs.npy")
    train_inputs = train_inputs.reshape(train_inputs.shape[0], -1)
    train_labels = np.load(repr_data / "train_labels.npy")

    # Number of independently noised copies generated per (noise-free) representation.
    train_repeat = 1
    test_repeat = 6

    # NOTE: the noise below comes from the unseeded global NumPy RNG, so the
    # estimates differ from run to run.
    for f in repr_data.iterdir():
        # Match and parse on the file name only, so that dots or underscores in
        # the directory part of the path cannot confuse the epoch parsing.
        fname = f.name
        f = str(f)
        if "test_representations" in fname:
            print(f)
            # The epoch is the last "_"-separated token before the first dot of the file name.
            epoch = int(fname.split(".")[0].split("_")[-1])
            nonoise_reprs = np.load(f)
            # Keep each distinct representation once; `ind` selects the matching inputs/labels.
            nonoise_reprs, ind = np.unique(nonoise_reprs, axis=0, return_index=True)
            # test_repeat noisy copies of every row: row * (N(0, 1) * drp_noise + 1).
            reprs = np.array([
                nr * (np.random.randn(nonoise_reprs.shape[1]) * drp_noise + 1)
                for nr in nonoise_reprs
                for _ in range(test_repeat)
            ])
            test_bin_mi_xz[epoch] = double_bin_calc_information(
                np.repeat(np.array(test_inputs[ind]), test_repeat, axis=0),
                reprs, num_of_bins=3)
            test_bin_mi_zy[epoch] = double_bin_calc_information(
                np.repeat(np.array(test_labels[ind]), test_repeat),
                reprs, num_of_bins=3)
            print(test_bin_mi_xz[epoch], test_bin_mi_zy[epoch])

        if "train_representations" in fname:
            print(f)
            epoch = int(fname.split(".")[0].split("_")[-1])
            nonoise_reprs = np.load(f)
            # Unlike the test branch, duplicate rows are not removed here.
            reprs = np.array([
                nr * (np.random.randn(nonoise_reprs.shape[1]) * drp_noise + 1)
                for nr in nonoise_reprs
                for _ in range(train_repeat)
            ])
            train_bin_mi_xz[epoch] = double_bin_calc_information(
                np.repeat(np.array(train_inputs), train_repeat, axis=0),
                reprs, num_of_bins=3)
            train_bin_mi_zy[epoch] = double_bin_calc_information(
                np.repeat(np.array(train_labels), train_repeat, axis=0),
                reprs, num_of_bins=3)
            print(train_bin_mi_xz[epoch], train_bin_mi_zy[epoch])

In [None]:
# Cache the binning estimates, unless a cache already exists (only the presence
# of "test_bin_mi_xz" is checked, mirroring the loading cell).
# Pickling to bytes first and then writing with Path.write_bytes() closes every
# file properly and does not leave a truncated file behind if pickling fails.
if not (repr_data / "test_bin_mi_xz").exists():
    (repr_data / "test_bin_mi_xz").write_bytes(pickle.dumps(test_bin_mi_xz))
    (repr_data / "test_bin_mi_zy").write_bytes(pickle.dumps(test_bin_mi_zy))
    (repr_data / "train_bin_mi_xz").write_bytes(pickle.dumps(train_bin_mi_xz))
    (repr_data / "train_bin_mi_zy").write_bytes(pickle.dumps(train_bin_mi_zy))

In [None]:
# Order the binned test-split estimates by epoch, then plot them; the I(Y;Z)
# values are plotted unchanged (crossentropy_zy=False).
od_test_bin_mi_xz = OrderedDict(
    (epoch, test_bin_mi_xz[epoch]) for epoch in sorted(test_bin_mi_xz)
)
od_test_bin_mi_zy = OrderedDict(
    (epoch, test_bin_mi_zy[epoch]) for epoch in sorted(test_bin_mi_zy)
)
drawIP(
    od_test_bin_mi_xz,
    od_test_bin_mi_zy,
    f'Gaussian dropout + binning, {netw} (validation)',
    crossentropy_zy=False,
)

In [None]:
# Order the binned training-split estimates by epoch, then plot them; the I(Y;Z)
# values are plotted unchanged (crossentropy_zy=False).
od_train_bin_mi_xz = OrderedDict(
    (epoch, train_bin_mi_xz[epoch]) for epoch in sorted(train_bin_mi_xz)
)
od_train_bin_mi_zy = OrderedDict(
    (epoch, train_bin_mi_zy[epoch]) for epoch in sorted(train_bin_mi_zy)
)
drawIP(
    od_train_bin_mi_xz,
    od_train_bin_mi_zy,
    f'Gaussian dropout + binning, {netw} (training)',
    crossentropy_zy=False,
)