In [1]:
import html
import json

import h5py
import numpy as np
import torch
from IPython.display import HTML, display

In [2]:
# Set `res_d`, `network_l`, `num_neurons_d`
base = "/data/sls/temp/johnmwu/contextual-corr-analysis/results1_"
res_fname = {name: base + name for name in
             ("maxcorr", "mincorr", "linreg", "svcca", "cka")}

# Only the three result files below are read in this cell; the "svcca" and
# "cka" paths stay available in `res_fname` but are not opened here.
res_d = {}
for res_method in ("maxcorr", "mincorr", "linreg"):
    # `with` closes the file handle as soon as the JSON has been parsed.
    # The loop variable is deliberately not called `method`, so re-running this
    # cell cannot overwrite a `method` selection made elsewhere in the notebook.
    with open(res_fname[res_method], "r") as res_file:
        res_d[res_method] = json.load(res_file)

# The keys of the "maxcorr" results are the networks (paths to their hdf5 files).
network_l = list(res_d["maxcorr"])

# Number of entries recorded for each network in the "maxcorr" results.
num_neurons_d = {net: len(res_d["maxcorr"][net]) for net in network_l}

num_neurons_d

{'/data/sls/temp/belinkov/contextual-corr-analysis/contextualizers/bert_large_cased/ptb_pos_dev.hdf5': 1024,
 '/data/sls/temp/belinkov/contextual-corr-analysis/contextualizers/openai_transformer/ptb_pos_dev.hdf5': 768,
 '/data/sls/temp/belinkov/contextual-corr-analysis/contextualizers/bert_base_cased/ptb_pos_dev.hdf5': 768,
 '/data/sls/temp/belinkov/contextual-corr-analysis/contextualizers/elmo_original/ptb_pos_dev.hdf5': 1024,
 '/data/sls/temp/belinkov/contextual-corr-analysis/contextualizers/calypso_transformer_6_512_base/ptb_pos_dev.hdf5': 1024,
 '/data/sls/temp/belinkov/contextual-corr-analysis/contextualizers/elmo_4x4096_512/ptb_pos_dev.hdf5': 1024}

In [3]:
# set `repr_d` : {network: h5}
# Each network name is itself the path of its hdf5 file. The files are opened
# explicitly read-only (older h5py versions default to a writable mode when none
# is given) and are left open because later cells read from these handles.
repr_d = {net: h5py.File(net, "r") for net in network_l}

In [4]:
# set `sentence_d`
# The file of the first network stores {sentence: str ix}; flip it so that
# `sentence_d` maps {str ix: sentence}.
sentence_to_ix = json.loads(repr_d[network_l[0]]['sentence_to_index'][0])
sentence_d = {ix: sent for sent, ix in sentence_to_ix.items()}

In [5]:
# set `neuron_sorts` : {network: {method: [first element of every entry in res_d[method][network]]}}
rk_methods = {"maxcorr", "mincorr", "linreg"}
# Built with comprehensions so that no loop variable named `method` or `network`
# is left behind in the notebook namespace: iterating the set in a plain `for`
# loop would leave `method` bound to an arbitrary ranking method and could
# silently change what a later cell visualizes if this cell is re-run.
neuron_sorts = {
    net: {rk: [entry[0] for entry in res_d[rk][net]] for rk in rk_methods}
    for net in network_l
}

# Visualizations

In [6]:
def get_color(a: float, zero=[1., 0., 0.], one=[0., 1., 0.]):
    """
    Linearly blend two colors according to a scalar weight.

    Parameters
    ----
    a : float in [0, 1]
        Blend weight
    zero : array_like (3,)
        Color produced when `a` is 0
    one : array_like (3,)
        Color produced when `a` is 1

    Returns
    ----
    color : array (3,)
        Blended color in [r, g, b] format, i.e. `(1-a)*zero + a*one`.
        Channels are on a 0-to-1 scale when both endpoint colors are.
    """
    color_at_zero = np.array(zero)
    color_at_one = np.array(one)

    weight_zero = 1 - a
    return weight_zero * color_at_zero + a * color_at_one

In [7]:
# Selection used by the visualization cells below.
s_ix = '0' # "sentence index": string key into `repr_d[network]` and `sentence_d`
method = "maxcorr" # which entry of `neuron_sorts[network]` supplies the neuron to show

In [8]:
# For every network, show sentence `s_ix` with each token's background colored by
# the sigmoid of one neuron's activation (the first neuron listed for `method`).
for network in network_l:
    # Set `representations`
    representations = torch.tensor(repr_d[network][s_ix])
    if representations.dim() == 3:
        representations = representations[-1]  # top layer

    # Set `activations`, `sentence`
    neuron_ix = neuron_sorts[network][method][0]
    activations = representations[:, neuron_ix]   # raw values, printed below
    activations_mod = torch.sigmoid(activations)  # squashed into [0, 1] for `get_color`
    sentence = sentence_d[s_ix].split(' ')

    # Set `html_str_l`
    html_str_l = []
    for act, tok in zip(activations_mod, sentence):
        # Set `color_str`
        # Channels are scaled by 255.9 rather than 256: the sigmoid can come out
        # as exactly 0.0 or 1.0, and a channel of 1.0 times 256 would give 256,
        # which needs three hex digits and breaks the #rrggbb color.
        r, g, b = get_color(float(act))
        color_str = ''.join('{0:02x}'.format(int(255.9 * channel)) for channel in (r, g, b))

        # Escape the token so characters such as `&` or `<` are shown literally
        # instead of being interpreted as markup.
        html_str_l.append('<span style="background-color:#{0}">{1} </span>'.format(
            color_str, html.escape(tok, quote=False)))

    print(network)
    print(activations)
    display(HTML('<code>' + ''.join(html_str_l) + '</code>'))

/data/sls/temp/belinkov/contextual-corr-analysis/contextualizers/bert_large_cased/ptb_pos_dev.hdf5
tensor([-3.8095,  0.2694,  0.9530,  0.0565,  0.1418,  1.2517, -0.3645,  0.2179])


/data/sls/temp/belinkov/contextual-corr-analysis/contextualizers/openai_transformer/ptb_pos_dev.hdf5
tensor([-0.0450,  0.1546, -0.1223,  0.2820,  0.1724,  0.3260, -0.0457, -0.0681])


/data/sls/temp/belinkov/contextual-corr-analysis/contextualizers/bert_base_cased/ptb_pos_dev.hdf5
tensor([ 1.6242,  0.2253, -0.4012, -0.0986, -0.7498, -0.1747, -0.7094,  0.3052])


/data/sls/temp/belinkov/contextual-corr-analysis/contextualizers/elmo_original/ptb_pos_dev.hdf5
tensor([ 0.6523,  0.3578, -0.1278,  0.2331,  0.2859,  0.4615, -2.4180, -0.8720])


/data/sls/temp/belinkov/contextual-corr-analysis/contextualizers/calypso_transformer_6_512_base/ptb_pos_dev.hdf5
tensor([-17.1190,  -3.6540,  24.2946,  26.8760,  18.2257,  15.8633,   5.1188,
         -6.1252])


/data/sls/temp/belinkov/contextual-corr-analysis/contextualizers/elmo_4x4096_512/ptb_pos_dev.hdf5
tensor([ 0.2807,  1.1419, -0.2633,  0.2461,  0.1581, -0.1117, -0.1084, -0.2763])


### Helper for above: the loop body, step by step, for a single network

In [9]:
# Work through a single network by hand: the first entry of `network_l`.
network = network_l[0]

In [10]:
# Representations of sentence `s_ix` for the chosen network.
T = torch.tensor(repr_d[network][s_ix])
if T.dim() == 3:
    # A 3-d tensor is indexed on its first axis; keep only the last entry (top layer).
    T = T[-1]

In [11]:
# Column of `T` for the first neuron listed under `method`, passed through a
# sigmoid so that every value lies in [0, 1].
neuron_ix = neuron_sorts[network][method][0]
activations = torch.sigmoid(T[:, neuron_ix])
activations

tensor([0.0217, 0.5669, 0.7217, 0.5141, 0.5354, 0.7776, 0.4099, 0.5543])

In [12]:
# Tokens of the selected sentence, obtained by splitting it on single spaces.
sentence = sentence_d[s_ix].split(' ')
sentence

["'", 'This', 'is', 'loyalty', 'intelligently', 'bestowed', '.', "''"]

In [13]:
# Build one colored `<span>` per token of `sentence`.
html_str_l = []
for act, tok in zip(activations, sentence):
    # Set `color_str`
    r, g, b = get_color(float(act))
    # Clamp every channel to 255: a channel of exactly 1.0 (reached when the
    # sigmoid comes out as exactly 0.0 or 1.0) would otherwise become 256, i.e.
    # the three hex digits "100", and produce an invalid 7-digit color.
    color_str = ''.join('{0:02x}'.format(min(int(256 * channel), 255)) for channel in (r, g, b))

    # Escape the token so characters such as `&` or `<` are shown literally
    # instead of being interpreted as markup.
    html_str_l.append('<span style="background-color:#{0}">{1} </span>'.format(
        color_str, html.escape(tok, quote=False)))

In [14]:
# Inspect the raw HTML snippets, one per token.
html_str_l

['<span style="background-color:#fa0500">\' </span>',
 '<span style="background-color:#6e9100">This </span>',
 '<span style="background-color:#47b800">is </span>',
 '<span style="background-color:#7c8300">loyalty </span>',
 '<span style="background-color:#768900">intelligently </span>',
 '<span style="background-color:#38c700">bestowed </span>',
 '<span style="background-color:#976800">. </span>',
 '<span style="background-color:#728d00">\'\' </span>']

In [15]:
# Join the snippets inside one <code> element; as the cell's last expression the
# `HTML` object is shown as the cell's output.
HTML('<code>' + ''.join(html_str_l) + '</code>')

In [16]:
# Same markup rendered through explicit `display` calls, which is how the
# per-network loop above emits one rendering per iteration.
# NOTE(review): the call is presumably repeated on purpose, to confirm that
# several `display` calls in one cell each produce output - confirm before removing.
display(HTML('<code>' + ''.join(html_str_l) + '</code>'))
display(HTML('<code>' + ''.join(html_str_l) + '</code>'))