In [41]:
import html
import json
import pickle

import h5py
import matplotlib.pyplot as plt
import numpy as np
import torch
from IPython.display import HTML

In [42]:
# Names of the analysis methods whose result files are loaded below.
# Each name is appended to the results path prefix to form a file name.
method_l = [
    "maxcorr", "mincorr",
    "maxlinreg", "minlinreg",
    "cca",
    "lincka",
    # "rbfcka",  (currently disabled)
]

In [43]:
# Set `res_d`, `network_l`, `num_neurons_d`
#   res_fname     : {method: path of that method's pickled results}
#   res_d         : {method: unpickled results}
#   network_l     : networks listed under the "maxcorr" results' "corrs" entry
#   num_neurons_d : {network: length of the first value stored for that network}
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only point `base` at result files you trust.
base = "/data/sls/temp/johnmwu/contextual-corr-analysis/results1_"
res_fname = {name: base + name for name in method_l}

res_d = {}
for method, path in res_fname.items():
    with open(path, 'rb') as f:
        res_d[method] = pickle.load(f)

# "maxcorr" serves as the reference result for enumerating networks.
ref_corrs = res_d["maxcorr"]["corrs"]
network_l = list(ref_corrs)

num_neurons_d = {}
for network in network_l:
    # Hack: the neuron count is read off whichever value comes first.
    first_entry = next(iter(ref_corrs[network].values()))
    num_neurons_d[network] = len(first_entry)

In [57]:
def network2mname(network):
    """Return the model name for a network identifier.

    The model name is everything before the last underscore, e.g.
    ``"bert_large_cased_-1"`` -> ``"bert_large_cased"``.

    Parameters
    ----------
    network : str
        Network identifier whose final ``_<suffix>`` should be removed.

    Returns
    -------
    str
        `network` without its last underscore-separated component. An
        identifier containing no underscore is returned unchanged (slicing
        with ``rfind``'s -1 would silently drop its last character).
    """
    # rsplit yields a single-element list when there is no underscore, so the
    # input falls through untouched in that case.
    return network.rsplit('_', 1)[0]

In [59]:
def network2fname(network):
    """Return the path of the HDF5 representation file for a network.

    The directory name is the network identifier with its last
    underscore-separated component removed, e.g. ``"elmo_original_-1"``
    resolves to ``.../contextualizers/elmo_original/ptb_pos_dev.hdf5``.

    Parameters
    ----------
    network : str
        Network identifier.

    Returns
    -------
    str
        Absolute path of the ``ptb_pos_dev.hdf5`` file for that network. An
        identifier containing no underscore is used as the directory name
        as-is (slicing with ``rfind``'s -1 would drop its last character and
        produce a wrong path).
    """
    # rsplit leaves the string intact when there is no underscore to split on.
    model_dir = network.rsplit('_', 1)[0]
    return "/data/sls/temp/belinkov/contextual-corr-analysis/contextualizers/{0}/ptb_pos_dev.hdf5".format(model_dir)

In [45]:
# set `h5_d` : {network: h5} and `sentence_d`
# The file handles stay open; the visualization cell reads datasets from them.
h5_d = {network: h5py.File(network2fname(network), 'r') for network in network_l}

# The first network's file stores a JSON-encoded {sentence: str ix} mapping;
# invert it so a sentence can be looked up by its string index.
sentence_to_index = json.loads(h5_d[network_l[0]]['sentence_to_index'][0])
sentence_d = {str_ix: sentence for sentence, str_ix in sentence_to_index.items()}  # {str ix: sentence}

In [54]:
# set stats
# `means_d`, `stdevs_d`, `max_d`, `min_d`
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
stats_path = "/data/sls/temp/johnmwu/contextual-corr-analysis/stats"
with open(stats_path, "rb") as stats_file:
    stats = pickle.load(stats_file)

# Split the pickled dict into one variable per statistic.
means_d, stdevs_d, max_d, min_d = (stats[key] for key in ("mean", "std", "max", "min"))

In [55]:
# set `neuron_sorts` : {network: {method: that method's "neuron_sort" entry}}
# Only these methods are read for a "neuron_sort" result here.
rk_methods = {"maxcorr", "mincorr", "maxlinreg", "minlinreg"}
neuron_sorts = {
    network: {method: res_d[method]["neuron_sort"][network] for method in rk_methods}
    for network in network_l
}

# Visualizations

In [64]:
# Settings for the visualization cell below.
s_ix_l = list(map(str, range(10)))  # "sentence index": string keys "0".."9"
method = "maxcorr"                  # which neuron ordering in `neuron_sorts` to use
layer = -1                          # index into the layer axis (-1 = last)
cmap = plt.get_cmap('bwr')          # colormap used to shade tokens

In [65]:
def maxmin_normalize(a, maximum, minimum):
    """Map activations into [0, 1] symmetrically around zero.

    Values are scaled by the larger of ``|maximum|`` and ``|minimum|`` so that
    0 maps to 0.5, ``+maxabs`` maps to 1.0 and ``-maxabs`` maps to 0.0, i.e.
    the range a colormap accepts. (Dividing by ``maxabs`` alone spans
    [-0.5, 1.5], which saturates every activation beyond half the extreme.)

    Parameters
    ----------
    a : number or tensor/array
        Activation value(s) to normalize.
    maximum, minimum : number
        Largest and smallest activation observed for the neuron.

    Returns
    -------
    Same kind as `a`
        Normalized value(s); within [0, 1] whenever `a` lies within
        ``[-maxabs, maxabs]``.
    """
    maxabs = max(abs(maximum), abs(minimum))
    if maxabs == 0:
        # Both extremes are zero: nothing to scale by. Return the midpoint
        # while keeping the shape/type of `a` instead of dividing by zero.
        return a * 0 + .5
    return .5 + a / (2 * maxabs)

In [66]:
def meanstd_normalize(a, mean, stdev):
    """Standardize `a` with `mean` and `stdev`, then squash it through a sigmoid.

    Parameters
    ----------
    a : torch.Tensor
        Activation values.
    mean, stdev : number
        Statistics used to center and scale `a`.

    Returns
    -------
    torch.Tensor
        ``sigmoid((a - mean) / stdev)``, elementwise.
    """
    standardized = (a - mean) / stdev
    return torch.sigmoid(standardized)

In [72]:
# For each network, render the sentences in `s_ix_l` as HTML in which every
# token is shaded by the activation of the first neuron in that network's
# `method` ordering.
for network in network_l:
    mname = network2mname(network)

    # Pick the neuron first: the statistics lookups below index by it. It used
    # to be assigned only inside the sentence loop, so a fresh kernel raised a
    # NameError here and a warm kernel reused the previous network's neuron.
    neuron_ix = neuron_sorts[network][method][0]

    # Stats. They do not depend on the sentence, so look them up once.
    mean = means_d[mname][layer, neuron_ix].item()    # used only by the meanstd alternative
    stdev = stdevs_d[mname][layer, neuron_ix].item()  # used only by the meanstd alternative
    maximum = max_d[mname][layer, neuron_ix].item()
    minimum = min_d[mname][layer, neuron_ix].item()

    # Set `html_str_l`
    html_str_l = []
    for s_ix in s_ix_l:
        # Set `representations`. `[()]` reads the whole dataset into a numpy
        # array before it is handed to torch.
        representations = torch.tensor(h5_d[network][s_ix][()])
        if representations.dim() == 3:
            # assumes the leading axis indexes layers — TODO confirm
            representations = representations[layer]

        # Set `activations_mod`
        activations = representations[:, neuron_ix]
        # activations_mod = meanstd_normalize(activations, mean, stdev)
        activations_mod = maxmin_normalize(activations, maximum, minimum)

        # Update `html_str_l`
        sentence = sentence_d[s_ix].split(' ')
        for act, tok in zip(activations_mod, sentence):
            # Set `color_str`. The alpha channel gets a throwaway name of its
            # own so the unpack cannot clobber another variable.
            r, g, b, _alpha = cmap(float(act))
            # 255.9 rather than 256: a channel value of exactly 1.0 must still
            # fit in two hex digits (0xff).
            color_str = ''.join('{0:02x}'.format(int(255.9 * c)) for c in (r, g, b))

            # Escape the token; a raw '<', '>' or '&' would be parsed as markup.
            html_str_l.append('<span style="background-color:#{0}">{1} </span>'.format(color_str, html.escape(tok)))
        html_str_l.append('<br><br>')

    # display
    print(network)
    display(HTML('<code>' + ''.join(html_str_l) + '</code>'))

bert_large_cased_-1


openai_transformer_-1


bert_base_cased_-1


elmo_original_-1


calypso_transformer_6_512_base_-1


elmo_4x4096_512_-1


xlnet_large_cased_-1
