In [1]:
import json
import h5py
import torch
import numpy as np
from IPython.display import HTML
import matplotlib.pyplot as plt

In [2]:
# Set `res_d`, `network_l`, `num_neurons_d`
base = "/data/sls/temp/johnmwu/contextual-corr-analysis/results1_"
res_fname = {method : base + method for method in 
                {"maxcorr", "mincorr", "linreg", "svcca", "cka"}}

res_d = {}
res_d["maxcorr"] = json.load(open(res_fname["maxcorr"], "r"))
res_d["mincorr"] = json.load(open(res_fname["mincorr"], "r"))
res_d["linreg"] = json.load(open(res_fname["linreg"], "r"))

network_l = [network for network in res_d["maxcorr"]]

num_neurons_d = {}
for network in network_l:
    num_neurons_d[network] = len(res_d["maxcorr"][network])

num_neurons_d

{'/data/sls/temp/belinkov/contextual-corr-analysis/contextualizers/bert_large_cased/ptb_pos_dev.hdf5': 1024,
 '/data/sls/temp/belinkov/contextual-corr-analysis/contextualizers/openai_transformer/ptb_pos_dev.hdf5': 768,
 '/data/sls/temp/belinkov/contextual-corr-analysis/contextualizers/bert_base_cased/ptb_pos_dev.hdf5': 768,
 '/data/sls/temp/belinkov/contextual-corr-analysis/contextualizers/elmo_original/ptb_pos_dev.hdf5': 1024,
 '/data/sls/temp/belinkov/contextual-corr-analysis/contextualizers/calypso_transformer_6_512_base/ptb_pos_dev.hdf5': 1024,
 '/data/sls/temp/belinkov/contextual-corr-analysis/contextualizers/elmo_4x4096_512/ptb_pos_dev.hdf5': 1024}

In [3]:
# set `h5_d` : {network: h5} and `sentence_d`
h5_d = {}
for network in network_l:
    h5_d[network] = h5py.File(network, 'r')

sentence_d = json.loads(h5_d[network_l[0]]['sentence_to_index'][0]) # "sentence dict": {sentence, str ix}
temp = {}
for k, v in sentence_d.items():
    temp[v] = k
sentence_d = temp # now {str ix, sentence}

In [4]:
# set `means_d`, `stdevs_d`
base = "/data/sls/temp/johnmwu/contextual-corr-analysis/results3"
means_d = torch.load(base + "_means")
stdevs_d = torch.load(base + "_stdevs")

In [5]:
# set `max_d`, `min_d`
base = "/data/sls/temp/johnmwu/contextual-corr-analysis/results4"
max_d = torch.load(base + "_max")
min_d = torch.load(base + "_min")

In [6]:
# set `neuron_sorts`
rk_methods = {"maxcorr", "mincorr", "linreg"}
neuron_sorts = {network: {} for network in network_l}
for network in network_l:
    for method in rk_methods:
        neuron_sorts[network][method] = [neuron[0] for neuron in res_d[method][network]]

# Visualizations

In [7]:
s_ix = '1' # "sentence index"
method = "maxcorr"
layer = -1
cmap = plt.get_cmap('bwr')

In [8]:
def maxmin_normalize(a, maximum, minimum):
    maxabs = max([abs(maximum), abs(minimum)])
    return .5 + a/maxabs

In [9]:
def meanstd_normalize(a, mean, stdev):
    return torch.sigmoid((a-mean) / stdev)

In [10]:
for network in network_l:
    # Set `representations`
    representations = torch.tensor(h5_d[network][s_ix])
    representations = representations[layer] if representations.dim() == 3 else representations
    
    # Set `activations_mod`
    neuron_ix = neuron_sorts[network][method][0]
    mean = means_d[network][layer, neuron_ix].item()
    stdev = stdevs_d[network][layer, neuron_ix].item()
    maximum = max_d[network][layer, neuron_ix].item()
    minimum = min_d[network][layer, neuron_ix].item()
    
    activations = representations[:, neuron_ix]
    # activations_mod = meanstd_normalize(activations, mean, stdev)
    activations_mod = maxmin_normalize(activations, maximum, minimum)
    
    # Set `sentence`, `html_str_l`
    sentence = sentence_d[s_ix].split(' ')
    html_str_l = []
    for act, tok in zip(activations_mod, sentence):
        # Set `color_str`
        act = float(act)
        r, g, b, a = cmap(act)
        red_str = '{0:02x}'.format(int(255.9*r)) # should be 256, but can't. Some vals became 1. 
        green_str = '{0:02x}'.format(int(255.9*g))
        blue_str = '{0:02x}'.format(int(255.9*b))

        color_str = red_str + green_str + blue_str

        html_str_l.append('<span style="background-color:#{0}">{1} </span>'.format(color_str, tok))
        
    print(network)
    print("max: {0}".format(maximum))
    print("min: {0}".format(minimum))
    print("mean: {0}".format(mean))
    print("stdev: {0}".format(stdev))
    print(activations)
    display(HTML('<code>' + ''.join(html_str_l) + '</code>'))

/data/sls/temp/belinkov/contextual-corr-analysis/contextualizers/bert_large_cased/ptb_pos_dev.hdf5
max: 3.135377883911133
min: -4.647909164428711
mean: 0.5726714134216309
stdev: 1.0426321029663086
tensor([-1.4042, -0.5098, -0.1468, -0.1872,  1.1011,  0.1693, -0.0055,  0.5523,
         0.3086,  0.0995, -0.2505, -0.0295, -0.0315,  0.7511,  0.5612,  1.0882,
         0.3285,  1.3335,  1.6679,  1.1726,  0.3050,  1.0779,  0.7184,  0.7909,
        -0.0557, -0.0933,  0.4803,  0.3096,  0.0221,  0.6734, -0.3425,  0.1244,
         0.1797])


/data/sls/temp/belinkov/contextual-corr-analysis/contextualizers/openai_transformer/ptb_pos_dev.hdf5
max: 3.9729957580566406
min: -0.4267078936100006
mean: 0.4902166426181793
stdev: 0.7990235686302185
tensor([-1.0765e-01,  1.4278e-01, -3.1138e-02,  1.7465e-01, -1.3418e-01,
         1.8859e+00,  2.1342e-01, -1.3952e-02,  1.4539e-03,  3.4046e-02,
        -6.4832e-02, -1.6720e-01, -5.1657e-02,  2.2402e+00,  6.3026e-01,
        -3.3176e-02,  4.4374e-01, -2.2654e-01,  4.3256e-01, -1.0527e-01,
         2.3432e-01, -3.6382e-02, -1.3845e-03,  2.4210e+00,  2.1320e+00,
         6.7386e-01, -3.6812e-02,  3.5371e-01,  8.1182e-02,  4.6379e-02,
         3.8274e-01,  3.0906e-02, -4.1157e-02])


/data/sls/temp/belinkov/contextual-corr-analysis/contextualizers/bert_base_cased/ptb_pos_dev.hdf5
max: 7.607119083404541
min: -3.04351806640625
mean: -0.24966010451316833
stdev: 1.2996163368225098
tensor([ 4.9270, -0.8953, -0.1468, -0.5739, -1.2971, -0.5889, -0.1586, -0.5013,
        -0.4290,  0.2172, -0.4472,  0.4304,  0.0717, -1.2266, -0.4480, -0.9977,
        -0.3315, -0.1866, -0.6964, -0.9469, -0.3221, -0.3163, -0.4984, -0.6532,
        -0.0902, -0.2661, -0.6157, -0.2840, -0.2526, -0.2441, -0.6593, -1.1276,
         0.3877])


/data/sls/temp/belinkov/contextual-corr-analysis/contextualizers/elmo_original/ptb_pos_dev.hdf5
max: 3.630618095397949
min: -4.109951019287109
mean: -0.23794442415237427
stdev: 0.9812765717506409
tensor([ 0.4502,  0.3978,  0.1403,  0.3410,  1.2754,  0.5208,  0.4365,  0.6750,
         0.2211, -0.1844,  0.3273,  0.2152, -0.0466,  0.3227,  0.7932, -1.2330,
        -0.2011,  0.5447,  1.2765, -1.0986,  0.6760, -1.7071,  0.1458,  0.2769,
         0.3679,  0.0204, -0.1536,  0.1463, -2.2149, -0.8445,  0.0757, -2.4515,
        -0.7987])


/data/sls/temp/belinkov/contextual-corr-analysis/contextualizers/calypso_transformer_6_512_base/ptb_pos_dev.hdf5
max: 77.60395050048828
min: -27.376157760620117
mean: 23.559410095214844
stdev: 18.561195373535156
tensor([-22.1925, -10.5349,  -6.9982,  10.2975,  44.5821,  40.8334,  12.0367,
         24.3710,  18.3488,   5.5101,  12.9640,   7.3270,  32.6779,  47.4005,
         39.7565,  50.1618,  53.2039,  43.8871,  37.9692,  38.5588,  33.5960,
         16.8663,  26.5091,  35.3433,  30.6271,  31.4712,  35.4988,  21.6795,
         15.8895,  43.6504,  31.0775,   2.6383,  -5.2636])


/data/sls/temp/belinkov/contextual-corr-analysis/contextualizers/elmo_4x4096_512/ptb_pos_dev.hdf5
max: 6.005972862243652
min: -2.561760425567627
mean: 0.6009294986724854
stdev: 1.4804160594940186
tensor([ 1.1553e+00,  1.6627e-01,  3.3951e-02, -8.8753e-02, -2.4422e-01,
         1.0320e+00,  6.3936e-02, -4.2989e-01, -3.1045e-01, -5.8518e-01,
        -5.2266e-01, -5.2579e-01,  5.3509e-04,  4.8701e+00, -3.3647e-01,
         2.4091e-01,  2.3596e-01, -3.3905e-01,  3.4858e-01, -1.3712e-01,
         1.0728e-01, -3.5029e-01, -1.6389e-01,  3.2119e+00,  3.2316e+00,
        -6.8168e-01,  5.2616e-02, -3.6585e-02, -1.8119e-01, -2.5466e-01,
         1.5551e-01, -2.2400e-01, -3.3051e-01])


### Helper for above

In [11]:
network = network_l[0]

In [12]:
T = torch.tensor(h5_d[network][s_ix])
T = T[-1] if T.dim() == 3 else T # top layer

In [13]:
neuron_ix = neuron_sorts[network][method][0]
activations = T[:, neuron_ix]
activations = torch.sigmoid(activations)
activations

tensor([0.1972, 0.3752, 0.4634, 0.4533, 0.7505, 0.5422, 0.4986, 0.6347, 0.5765,
        0.5249, 0.4377, 0.4926, 0.4921, 0.6794, 0.6367, 0.7481, 0.5814, 0.7914,
        0.8413, 0.7636, 0.5757, 0.7461, 0.6722, 0.6880, 0.4861, 0.4767, 0.6178,
        0.5768, 0.5055, 0.6623, 0.4152, 0.5311, 0.5448])

In [14]:
sentence = sentence_d[s_ix].split(' ')
sentence

["''",
 'Mr.',
 'Allen',
 'objected',
 'to',
 'this',
 'analogy',
 'because',
 'it',
 'seems',
 'to',
 '``',
 'assimilate',
 'the',
 'status',
 'of',
 'blacks',
 'to',
 'that',
 'of',
 'animals',
 '--',
 'as',
 'a',
 'mere',
 'project',
 'of',
 'charity',
 ',',
 'of',
 'humaneness',
 '.',
 "''"]

In [15]:
html_str_l = []
for act, tok in zip(activations, sentence):
    # Set `color_str`
    act = float(act)
    r, g, b, a = cmap(act)
    red_str = '{0:02x}'.format(int(256*r))
    green_str = '{0:02x}'.format(int(256*g))
    blue_str = '{0:02x}'.format(int(256*b))

    color_str = red_str + green_str + blue_str
    
    html_str_l.append('<span style="background-color:#{0}">{1} </span>'.format(color_str, tok))

In [16]:
html_str_l

['<span style="background-color:#6464100">\'\' </span>',
 '<span style="background-color:#c0c0100">Mr. </span>',
 '<span style="background-color:#ecec100">Allen </span>',
 '<span style="background-color:#e8e8100">objected </span>',
 '<span style="background-color:#1007e7e">to </span>',
 '<span style="background-color:#100eaea">this </span>',
 '<span style="background-color:#fefe100">analogy </span>',
 '<span style="background-color:#100baba">because </span>',
 '<span style="background-color:#100d8d8">it </span>',
 '<span style="background-color:#100f2f2">seems </span>',
 '<span style="background-color:#e0e0100">to </span>',
 '<span style="background-color:#fcfc100">`` </span>',
 '<span style="background-color:#fafa100">assimilate </span>',
 '<span style="background-color:#100a4a4">the </span>',
 '<span style="background-color:#100b8b8">status </span>',
 '<span style="background-color:#1008080">of </span>',
 '<span style="background-color:#100d6d6">blacks </span>',
 '<span style="backgr

In [17]:
HTML('<code>' + ''.join(html_str_l) + '</code>')

In [18]:
display(HTML('<code>' + ''.join(html_str_l) + '</code>'))
display(HTML('<code>' + ''.join(html_str_l) + '</code>'))