In [1]:
import json
import h5py
import torch
import numpy as np
from IPython.display import HTML
import matplotlib.pyplot as plt

In [2]:
# Set `res_d`, `network_l`, `num_neurons_d`
base = "/data/sls/temp/johnmwu/contextual-corr-analysis/results1_"
res_fname = {method : base + method for method in 
                {"maxcorr", "mincorr", "linreg", "svcca", "cka"}}

res_d = {}
res_d["maxcorr"] = json.load(open(res_fname["maxcorr"], "r"))
res_d["mincorr"] = json.load(open(res_fname["mincorr"], "r"))
res_d["linreg"] = json.load(open(res_fname["linreg"], "r"))

network_l = [network for network in res_d["maxcorr"]]

num_neurons_d = {}
for network in network_l:
    num_neurons_d[network] = len(res_d["maxcorr"][network])

num_neurons_d

{'/data/sls/temp/belinkov/contextual-corr-analysis/contextualizers/bert_large_cased/ptb_pos_dev.hdf5': 1024,
 '/data/sls/temp/belinkov/contextual-corr-analysis/contextualizers/openai_transformer/ptb_pos_dev.hdf5': 768,
 '/data/sls/temp/belinkov/contextual-corr-analysis/contextualizers/bert_base_cased/ptb_pos_dev.hdf5': 768,
 '/data/sls/temp/belinkov/contextual-corr-analysis/contextualizers/elmo_original/ptb_pos_dev.hdf5': 1024,
 '/data/sls/temp/belinkov/contextual-corr-analysis/contextualizers/calypso_transformer_6_512_base/ptb_pos_dev.hdf5': 1024,
 '/data/sls/temp/belinkov/contextual-corr-analysis/contextualizers/elmo_4x4096_512/ptb_pos_dev.hdf5': 1024}

In [3]:
# set `h5_d` : {network: h5} and `sentence_d`
h5_d = {}
for network in network_l:
    h5_d[network] = h5py.File(network, 'r')

sentence_d = json.loads(h5_d[network_l[0]]['sentence_to_index'][0]) # "sentence dict": {sentence, str ix}
temp = {}
for k, v in sentence_d.items():
    temp[v] = k
sentence_d = temp # now {str ix, sentence}

In [4]:
# set `means_d`, `stdevs_d`
base = "/data/sls/temp/johnmwu/contextual-corr-analysis/results3"
means_d = torch.load(base + "_means")
stdevs_d = torch.load(base + "_stdevs")

In [5]:
# set `neuron_sorts`
rk_methods = {"maxcorr", "mincorr", "linreg"}
neuron_sorts = {network: {} for network in network_l}
for network in network_l:
    for method in rk_methods:
        neuron_sorts[network][method] = [neuron[0] for neuron in res_d[method][network]]

# Visualizations

In [6]:
s_ix = '1' # "sentence index"
method = "maxcorr"
layer = -1
cmap = plt.get_cmap('seismic')

In [7]:
for network in network_l:
    # Set `representations`
    representations = torch.tensor(h5_d[network][s_ix])
    representations = representations[-1] if representations.dim() == 3 else representations # top layer
    
    # Set `activations_mod`, `sentence`
    neuron_ix = neuron_sorts[network][method][0]
    mean = means_d[network][layer, neuron_ix].item()
    stdev = stdevs_d[network][layer, neuron_ix].item()
    
    activations = representations[:, neuron_ix]
    activations = (activations - mean) / stdev
    activations_mod = torch.sigmoid(activations)
    sentence = sentence_d[s_ix].split(' ')
    
    # Set `html_str_l`
    html_str_l = []
    for act, tok in zip(activations_mod, sentence):
        # Set `color_str`
        act = float(act)
        r, g, b, a = cmap(act)
        red_str = '{0:02x}'.format(int(255.9*r)) # should be 256, but can't. Some vals became 1. 
        green_str = '{0:02x}'.format(int(255.9*g))
        blue_str = '{0:02x}'.format(int(255.9*b))

        color_str = red_str + green_str + blue_str

        html_str_l.append('<span style="background-color:#{0}">{1} </span>'.format(color_str, tok))
        
    print(network)
    print(activations)
    display(HTML('<code>' + ''.join(html_str_l) + '</code>'))

/data/sls/temp/belinkov/contextual-corr-analysis/contextualizers/bert_large_cased/ptb_pos_dev.hdf5
tensor([-1.8960, -1.0382, -0.6900, -0.7288,  0.5068, -0.3869, -0.5546, -0.0195,
        -0.2533, -0.4538, -0.7896, -0.5776, -0.5795,  0.1712, -0.0110,  0.4945,
        -0.2341,  0.7297,  1.0505,  0.5754, -0.2568,  0.4846,  0.1397,  0.2093,
        -0.6027, -0.6387, -0.0886, -0.2523, -0.5281,  0.0966, -0.8778, -0.4300,
        -0.3769])


/data/sls/temp/belinkov/contextual-corr-analysis/contextualizers/openai_transformer/ptb_pos_dev.hdf5
tensor([-0.7482, -0.4348, -0.6525, -0.3949, -0.7814,  1.7468, -0.3464, -0.6310,
        -0.6117, -0.5709, -0.6947, -0.8228, -0.6782,  2.1902,  0.1753, -0.6550,
        -0.0582, -0.8970, -0.0722, -0.7453, -0.3203, -0.6591, -0.6153,  2.4164,
         2.0547,  0.2298, -0.6596, -0.1708, -0.5119, -0.5555, -0.1345, -0.5748,
        -0.6650])


/data/sls/temp/belinkov/contextual-corr-analysis/contextualizers/bert_base_cased/ptb_pos_dev.hdf5
tensor([ 3.9832e+00, -4.9681e-01,  7.9109e-02, -2.4946e-01, -8.0597e-01,
        -2.6101e-01,  7.0035e-02, -1.9363e-01, -1.3799e-01,  3.5925e-01,
        -1.5198e-01,  5.2325e-01,  2.4724e-01, -7.5173e-01, -1.5262e-01,
        -5.7560e-01, -6.2994e-02,  4.8487e-02, -3.4374e-01, -5.3649e-01,
        -5.5713e-02, -5.1273e-02, -1.9140e-01, -3.1050e-01,  1.2270e-01,
        -1.2628e-02, -2.8163e-01, -2.6409e-02, -2.2891e-03,  4.2606e-03,
        -3.1521e-01, -6.7553e-01,  4.9039e-01])


/data/sls/temp/belinkov/contextual-corr-analysis/contextualizers/elmo_original/ptb_pos_dev.hdf5
tensor([ 0.7013,  0.6478,  0.3855,  0.5900,  1.5422,  0.7732,  0.6873,  0.9304,
         0.4678,  0.0546,  0.5760,  0.4618,  0.1950,  0.5713,  1.0508, -1.0140,
         0.0375,  0.7976,  1.5433, -0.8771,  0.9314, -1.4972,  0.3910,  0.5247,
         0.6174,  0.2633,  0.0860,  0.3916, -2.0147, -0.6181,  0.3197, -2.2558,
        -0.5714])


/data/sls/temp/belinkov/contextual-corr-analysis/contextualizers/calypso_transformer_6_512_base/ptb_pos_dev.hdf5
tensor([-2.4649, -1.8369, -1.6463, -0.7145,  1.1326,  0.9307, -0.6208,  0.0437,
        -0.2807, -0.9724, -0.5708, -0.8745,  0.4913,  1.2845,  0.8726,  1.4332,
         1.5971,  1.0952,  0.7763,  0.8081,  0.5407, -0.3606,  0.1589,  0.6349,
         0.3808,  0.4263,  0.6432, -0.1013, -0.4132,  1.0824,  0.4050, -1.1271,
        -1.5529])


/data/sls/temp/belinkov/contextual-corr-analysis/contextualizers/elmo_4x4096_512/ptb_pos_dev.hdf5
tensor([ 0.3745, -0.2936, -0.3830, -0.4659, -0.5709,  0.2912, -0.3627, -0.6963,
        -0.6156, -0.8012, -0.7590, -0.7611, -0.4056,  2.8838, -0.6332, -0.2432,
        -0.2465, -0.6349, -0.1705, -0.4985, -0.3334, -0.6425, -0.5166,  1.7637,
         1.7770, -0.8664, -0.3704, -0.4306, -0.5283, -0.5779, -0.3009, -0.5572,
        -0.6292])


### Helper for above

In [8]:
network = network_l[0]

In [9]:
T = torch.tensor(h5_d[network][s_ix])
T = T[-1] if T.dim() == 3 else T # top layer

In [10]:
neuron_ix = neuron_sorts[network][method][0]
activations = T[:, neuron_ix]
activations = torch.sigmoid(activations)
activations

tensor([0.1972, 0.3752, 0.4634, 0.4533, 0.7505, 0.5422, 0.4986, 0.6347, 0.5765,
        0.5249, 0.4377, 0.4926, 0.4921, 0.6794, 0.6367, 0.7481, 0.5814, 0.7914,
        0.8413, 0.7636, 0.5757, 0.7461, 0.6722, 0.6880, 0.4861, 0.4767, 0.6178,
        0.5768, 0.5055, 0.6623, 0.4152, 0.5311, 0.5448])

In [11]:
sentence = sentence_d[s_ix].split(' ')
sentence

["''",
 'Mr.',
 'Allen',
 'objected',
 'to',
 'this',
 'analogy',
 'because',
 'it',
 'seems',
 'to',
 '``',
 'assimilate',
 'the',
 'status',
 'of',
 'blacks',
 'to',
 'that',
 'of',
 'animals',
 '--',
 'as',
 'a',
 'mere',
 'project',
 'of',
 'charity',
 ',',
 'of',
 'humaneness',
 '.',
 "''"]

In [12]:
html_str_l = []
for act, tok in zip(activations, sentence):
    # Set `color_str`
    act = float(act)
    r, g, b, a = cmap(act)
    red_str = '{0:02x}'.format(int(256*r))
    green_str = '{0:02x}'.format(int(256*g))
    blue_str = '{0:02x}'.format(int(256*b))

    color_str = red_str + green_str + blue_str
    
    html_str_l.append('<span style="background-color:#{0}">{1} </span>'.format(color_str, tok))

In [13]:
html_str_l

['<span style="background-color:#0000d9">\'\' </span>',
 '<span style="background-color:#8181100">Mr. </span>',
 '<span style="background-color:#d9d9100">Allen </span>',
 '<span style="background-color:#d1d1100">objected </span>',
 '<span style="background-color:#fe0000">to </span>',
 '<span style="background-color:#100d5d5">this </span>',
 '<span style="background-color:#fdfd100">analogy </span>',
 '<span style="background-color:#1007575">because </span>',
 '<span style="background-color:#100b1b1">it </span>',
 '<span style="background-color:#100e5e5">seems </span>',
 '<span style="background-color:#c1c1100">to </span>',
 '<span style="background-color:#f9f9100">`` </span>',
 '<span style="background-color:#f5f5100">assimilate </span>',
 '<span style="background-color:#1004949">the </span>',
 '<span style="background-color:#1007171">status </span>',
 '<span style="background-color:#1000101">of </span>',
 '<span style="background-color:#100adad">blacks </span>',
 '<span style="backgrou

In [14]:
HTML('<code>' + ''.join(html_str_l) + '</code>')

In [15]:
display(HTML('<code>' + ''.join(html_str_l) + '</code>'))
display(HTML('<code>' + ''.join(html_str_l) + '</code>'))