In [1]:
import torch 
from tqdm import tqdm
from itertools import product as p
import json
import numpy as np
import h5py
from os.path import basename, dirname
import pickle
from var import fname2mname

# Setup

In [6]:
# Inputs for the step-by-step walkthrough below; they mirror the
# parameters of `load_attentions`.
attention_fname_l = [
    "/data/sls/temp/belinkov/contextual-corr-analysis/contextualizers/bert_base_cased/ptb_pos_dev_attn.hdf5",
]
layerspec_l = None  # None -> every file gets the 'all' layer spec
limit = None        # None -> no cut-off
ar_mask = True      # keep only the lower triangle of each attention map

# hnb

In [7]:
# Fill in defaults: with no explicit spec, every file uses 'all' layers.
l = len(attention_fname_l)
layerspec_l = ['all'] * l if layerspec_l is None else layerspec_l

In [3]:
# Result containers, both keyed by network name.
num_heads_d, attentions_d = {}, {}

In [8]:
# Stand-in for one iteration of the main (file, layerspec) loop:
# take the first file together with its layer spec.
loop_var = (attention_fname_l[0], layerspec_l[0])

In [9]:
# Unpack the (file name, layer spec) pair picked for this walkthrough.
fname, layerspec = loop_var

# Set `attentions_h5`, `sentence_d`, `indices`.
# The handle is left open on purpose: later cells keep reading from it.
attentions_h5 = h5py.File(fname, 'r')
sentence_d = json.loads(attentions_h5['sentence_to_index'][0])
# The stored mapping is sentence -> index; flip it so the index is the key.
temp = {ix: sentence for sentence, ix in sentence_d.items()}
sentence_d = temp # {str ix, sentence}
indices = list(sentence_d)[:limit]

# Set `num_layers`, `num_heads`, `layers` from the first dataset's shape
# (axis 0 is read as layers, axis 1 as heads).
s = attentions_h5[indices[0]].shape
num_layers, num_heads = s[0], s[1]
layers = list(range(num_layers)) if layerspec == "all" else [layerspec]

In [11]:
# Stand-in for one iteration of the `for layer in layers` loop used in
# `load_attentions`: fix a single layer index so the following cells can be
# stepped through by hand.
layer = 0

In [12]:
# Per-layer accumulators: the loaded tensors and the running word total.
attentions_l, word_count = [], 0

In [15]:
# Stand-in for one iteration of the `for sentence_ix in indices` loop used in
# `load_attentions`. The value is a string because it is used directly as a
# key into the HDF5 file (`attentions_h5[sentence_ix]`).
sentence_ix = '0'

In [16]:
# Set `dim`, `n_word`, update `word_count`.
# NOTE: re-running this cell adds `n_word` to `word_count` again.
shape = attentions_h5[sentence_ix].shape
dim = len(shape)
if dim != 4:
    # Every stored array is expected to be 4-D; report the offending file.
    raise ValueError(
        'Improper array dimension in file: ' + fname + "\nShape: " + str(shape)
    )
n_word = shape[2]  # axis 2 is taken as the sentence length
word_count = word_count + n_word

In [17]:
# Create `attentions`: read the chosen layer of this sentence's array.
attentions = attentions_h5[sentence_ix][layer]
if ar_mask:
    # Zero everything above the diagonal of the last two axes, then rescale
    # each row so the remaining weights sum to 1 again.
    attentions = np.tril(attentions)
    attentions = attentions / np.sum(attentions, axis=-1, keepdims=True)
attentions = torch.FloatTensor(attentions)

In [None]:
# Update `attentions_l`
attentions_l.append(attentions)

# Early stop. Inside the real sentence loop this condition triggers a `break`.
# A bare `break` is a SyntaxError at cell level (there is no enclosing loop),
# so in this unrolled walkthrough we only record whether the loop would stop.
stop_early = limit is not None and word_count >= limit

# Final function

In [35]:
def load_attentions(attention_fname_l, limit=None, layerspec_l=None, ar_mask=False):
    """
    Load in attentions. Options to control loading exist. 

    Each file must hold a `sentence_to_index` dataset whose first element is
    a JSON object mapping sentence -> dataset key, plus one 4-D dataset per
    key. Axis 0 of such a dataset is indexed by layer, axis 1 is reported as
    the number of heads, and axis 2 is counted as the sentence length.

    Params:
    ----
    attention_fname_l : list<str>
        List of hdf5 files containing attentions
    limit : int or None
        Cut-off on how much is loaded from each file. If not None, at most
        `limit` sentences are considered, and loading stops after the
        sentence at which the cumulative word count reaches `limit`.
        None loads every sentence.
    layerspec_l : list or None
        Specification for each model. May be an integer (layer to take),
        or "all". "all" means take all layers. None means "all" for every
        file. Must have one entry per file.
    ar_mask : bool
        If True, zero the entries above the diagonal of every attention
        matrix (`np.tril`) and divide each row by its new sum, so rows sum
        to 1 again. A row whose kept entries sum to zero ends up as NaN.
        (The name presumably stands for "autoregressive mask".)

    Returns:
    ----
    num_head_d : {str : int}
        {network : number of heads}. Here a network could be a layer,
        or the stack of all layers, etc. A network is what's being
        correlated as a single unit.
    attentions_d : {str : list<tensor>}
        {network : attentions}. attentions is a list because each 
        sentence may be of different length. Each element is the
        FloatTensor of one sentence's array at the given layer.

    Raises:
    ----
    ValueError
        If `layerspec_l` and `attention_fname_l` differ in length, or if a
        stored array is not 4-dimensional.
    IndexError
        If a file yields no sentence keys (e.g. empty file or `limit=0`).
    """

    # Edit args
    l = len(attention_fname_l)
    if layerspec_l is None:
        layerspec_l = ['all'] * l
    if len(layerspec_l) != l:
        # zip() below would otherwise silently drop the unmatched entries.
        raise ValueError('layerspec_l must have one entry per file: got ' +
                         str(len(layerspec_l)) + ' specs for ' +
                         str(l) + ' files')

    # Main loop
    num_heads_d = {} 
    attentions_d = {} 
    for fname, layerspec in tqdm(zip(attention_fname_l, layerspec_l),
                                 desc='load', total=l):
        # The context manager closes the file even if an error is raised.
        # Everything kept afterwards has been copied into in-memory tensors.
        with h5py.File(fname, 'r') as attentions_h5:
            # Set `indices`: the dataset keys, in stored order. Only the
            # keys are needed; dict.fromkeys de-duplicates while keeping
            # first-seen order, matching an inverted {key: sentence} dict.
            sentence_to_ix = json.loads(attentions_h5['sentence_to_index'][0])
            indices = list(dict.fromkeys(sentence_to_ix.values()))[:limit]

            # Set `num_layers`, `layers`
            num_layers = attentions_h5[indices[0]].shape[0]
            if layerspec == "all":
                layers = list(range(num_layers))
            else:
                layers = [layerspec]

            # Validate shapes and apply the word-count cut-off once: neither
            # depends on the layer, so the same keys are reused per layer.
            selected = []
            word_count = 0
            for sentence_ix in indices:
                shape = attentions_h5[sentence_ix].shape
                if len(shape) != 4:
                    raise ValueError('Improper array dimension in file: ' +
                                     fname + "\nShape: " +
                                     str(shape))
                selected.append(sentence_ix)
                word_count += shape[2]

                # Early stop
                if limit is not None and word_count >= limit:
                    break

            # Set `num_heads_d`, `attentions_d`
            for layer in layers:
                # Create `attentions_l`
                attentions_l = []
                for sentence_ix in selected:
                    # Indexing the dataset by layer yields a numpy array.
                    attentions = attentions_h5[sentence_ix][layer]
                    if ar_mask:
                        attentions = np.tril(attentions)
                        attentions = attentions / np.sum(
                            attentions, axis=-1, keepdims=True)
                    attentions_l.append(torch.FloatTensor(attentions))

                # Main update
                network = "{mname}_{layer}".format(mname=fname2mname(fname), 
                                                      layer=layer)
                num_heads_d[network] = attentions_l[0].shape[0]
                attentions_d[network] = attentions_l[:limit] 
    
    return num_heads_d, attentions_d

In [47]:
# Load every layer of every file with the lower-triangular mask applied.
num_heads_d, attentions_d = load_attentions(
    attention_fname_l=attention_fname_l,
    limit=None,
    layerspec_l=None,
    ar_mask=True,
)

load: 1it [00:35, 35.46s/it]


In [48]:
# Look up layer 0 of the first file and show one attention map.
fname = attention_fname_l[0]
network = "{mname}_{layer}".format(layer=0, mname=fname2mname(fname))
attentions_d[network][0][0]  # first head of the first sentence in layer 0

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5588, 0.4412, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1701, 0.1439, 0.6860, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1041, 0.3496, 0.0812, 0.4651, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2647, 0.1313, 0.0475, 0.2357, 0.3209, 0.0000, 0.0000, 0.0000],
        [0.0876, 0.1158, 0.1432, 0.1273, 0.4374, 0.0888, 0.0000, 0.0000],
        [0.1605, 0.1254, 0.0650, 0.1288, 0.2524, 0.1863, 0.0816, 0.0000],
        [0.0377, 0.1092, 0.1972, 0.1351, 0.2546, 0.0715, 0.0824, 0.1122]])

In [44]:
# Same load as before, but keeping the full (unmasked) attention maps.
num_heads_d, attentions_d = load_attentions(
    attention_fname_l=attention_fname_l,
    limit=None,
    layerspec_l=None,
    ar_mask=False,
)

load: 1it [00:32, 32.92s/it]


In [45]:
# Look up layer 0 of the first file and show the unmasked attention map.
fname = attention_fname_l[0]
network = "{mname}_{layer}".format(layer=0, mname=fname2mname(fname))
attentions_d[network][0][0]  # first head of the first sentence in layer 0

tensor([[0.3232, 0.0535, 0.0561, 0.0233, 0.0566, 0.0113, 0.1845, 0.2914],
        [0.1046, 0.0826, 0.1047, 0.1157, 0.1459, 0.0931, 0.1268, 0.2266],
        [0.0856, 0.0725, 0.3454, 0.1004, 0.1288, 0.0981, 0.0492, 0.1200],
        [0.0255, 0.0855, 0.0199, 0.1137, 0.2686, 0.0544, 0.1137, 0.3187],
        [0.1632, 0.0809, 0.0293, 0.1453, 0.1979, 0.0638, 0.1288, 0.1908],
        [0.0621, 0.0822, 0.1016, 0.0903, 0.3102, 0.0630, 0.0753, 0.2154],
        [0.1291, 0.1009, 0.0523, 0.1036, 0.2030, 0.1498, 0.0657, 0.1956],
        [0.0377, 0.1092, 0.1972, 0.1351, 0.2546, 0.0715, 0.0824, 0.1122]])