In [13]:
import torch
import json
import h5py
from tqdm import tqdm
from os.path import basename, dirname

In [14]:
# Per-model arguments: entry i of every `*_l` list configures model i.
limit = 10000  # keep at most this many sentences per representation file

# layerspec: "full" (all layers concatenated), "all" (each layer separately),
# or a single layer index.
layerspec_l = ["full", -1]
first_half_only_l = [False] * 2
second_half_only_l = [False] * 2

CONTEXTUALIZERS_DIR = "/data/sls/temp/belinkov/contextual-corr-analysis/contextualizers"
representation_fname_l = [
    CONTEXTUALIZERS_DIR + "/elmo_original/ptb_pos_dev.hdf5",
    CONTEXTUALIZERS_DIR + "/calypso_transformer_6_512_base/ptb_pos_dev.hdf5",
]

In [9]:
def fname2mname(fname):
    """Map a representation file path to a model name ("filename to model name").

    The model name is the name of the directory that directly contains the
    file, e.g. ".../contextualizers/elmo_original/ptb_pos_dev.hdf5" gives
    "elmo_original". A bare filename with no directory gives "".
    """
    parent_dir = dirname(fname)
    return basename(parent_dir)

In [10]:
# Accumulators keyed by "<model>_<layer>": `num_neurons_d` holds the feature
# width, `representations_d` the sentences' representations stacked together.
num_neurons_d, representations_d = dict(), dict()

# (filled by the per-file loop)

In [5]:
# loop variables: pick the configuration of a single model by position
ix = 0
fname, layerspec, first_half_only, second_half_only = (
    representation_fname_l[ix],
    layerspec_l[ix],
    first_half_only_l[ix],
    second_half_only_l[ix],
)

In [6]:
# Set `activations_h5`, `sentence_d`, `indices`
# The handle is left open on purpose: the following cells read datasets from it.
activations_h5 = h5py.File(fname, 'r')
# The file stores a JSON object {sentence: str ix}; invert it so that
# `sentence_d` maps each str ix (a dataset key in `activations_h5`) to its sentence.
sentence_to_index = json.loads(activations_h5['sentence_to_index'][0])
sentence_d = {ix_str: sentence for sentence, ix_str in sentence_to_index.items()}
# Keep at most `limit` sentence keys, in the order of the JSON object.
indices = list(sentence_d)[:limit]

In [7]:
# Set `num_layers`, `num_neurons`, `layers` from the first sentence's
# activation shape: 2-D means a single layer, 3-D has the layers on axis 0.
s = activations_h5[indices[0]].shape
num_layers = s[0] if len(s) != 2 else 1
num_neurons = s[-1]
# "all" expands to every layer index; "full" and an explicit layer index
# are both used as-is.
layers = list(range(num_layers)) if layerspec == "all" else [layerspec]

In [8]:
# Set `num_neurons_d`, `representations_d`
for layer in layers:
    # Create `representations_l`: one 2-D tensor per sentence
    representations_l = []
    for sentence_ix in indices:
        sentence_h5 = activations_h5[sentence_ix]  # look the dataset up once

        # Set `dim`
        dim = len(sentence_h5.shape)
        if not (dim == 2 or dim == 3):
            raise ValueError('Improper array dimension in file: ' +
                             fname + "\nShape: " +
                             str(sentence_h5.shape))

        # Create `activations`
        if layer == "full":
            activations = torch.FloatTensor(sentence_h5)
            if dim == 3:
                # (nlayer, nword, nneuron) -> (nword, nlayer * nneuron).
                # permute() returns a non-contiguous view that view() cannot
                # flatten, so make it contiguous first.
                activations = activations.permute(1, 0, 2)
                nword = activations.size()[0]
                activations = activations.contiguous().view(nword, -1)
        else:
            # 2-D files hold a single layer, so `layer` only applies to 3-D ones
            activations = torch.FloatTensor(sentence_h5[layer] if dim == 3
                                            else sentence_h5)

        # Create `representations`: optionally keep only the first or the
        # second half of the last (feature) dimension
        representations = activations
        if first_half_only:
            representations = torch.chunk(representations, chunks=2,
                                          dim=-1)[0]
        elif second_half_only:
            representations = torch.chunk(representations, chunks=2,
                                          dim=-1)[1]

        representations_l.append(representations)

    # update: record the feature width and all sentences stacked along dim 0
    model_name = "{model}_{layer}".format(model=fname2mname(fname),
                                          layer=layer)
    num_neurons_d[model_name] = representations_l[0].size()[-1]
    representations_d[model_name] = torch.cat(representations_l)

In [30]:
# full: run the whole loading pipeline for every configured model.
# The per-model lists are zipped, so they must line up one-to-one
# (zip would otherwise silently drop trailing models).
assert (len(representation_fname_l) == len(layerspec_l)
        == len(first_half_only_l) == len(second_half_only_l)), \
    "representation_fname_l, layerspec_l, first_half_only_l and second_half_only_l must have equal length"
for loop_var in tqdm(zip(representation_fname_l, layerspec_l,
                         first_half_only_l, second_half_only_l),
                     total=len(representation_fname_l)):
    fname, layerspec, first_half_only, second_half_only = loop_var

    # Close each file once its model is loaded instead of leaking one open
    # handle per model; torch.cat below copies the per-sentence tensors into
    # a new tensor, so nothing stored in `representations_d` refers to the file.
    with h5py.File(fname, 'r') as activations_h5:
        # Set `sentence_d` ({str ix: sentence}, inverted from the stored
        # {sentence: str ix}) and `indices`
        sentence_to_index = json.loads(activations_h5['sentence_to_index'][0])
        sentence_d = {ix_str: sentence
                      for sentence, ix_str in sentence_to_index.items()}
        indices = list(sentence_d)[:limit]

        # Set `num_layers`, `num_neurons`, `layers`
        s = activations_h5[indices[0]].shape
        num_layers = 1 if len(s) == 2 else s[0]
        num_neurons = s[-1]
        if layerspec == "all":
            layers = list(range(num_layers))
        else:
            # "full" or a single layer index is used as-is
            layers = [layerspec]

        # Set `num_neurons_d`, `representations_d`
        for layer in layers:
            # Create `representations_l`: one 2-D tensor per sentence
            representations_l = []
            for sentence_ix in indices:
                sentence_h5 = activations_h5[sentence_ix]  # look up once

                # Set `dim`
                dim = len(sentence_h5.shape)
                if not (dim == 2 or dim == 3):
                    raise ValueError('Improper array dimension in file: ' +
                                     fname + "\nShape: " +
                                     str(sentence_h5.shape))

                # Create `activations`
                if layer == "full":
                    activations = torch.FloatTensor(sentence_h5)
                    if dim == 3:
                        # (nlayer, nword, nneuron) -> (nword, nlayer * nneuron);
                        # permute() yields a non-contiguous view, hence contiguous()
                        activations = activations.permute(1, 0, 2)
                        nword = activations.size()[0]
                        activations = activations.contiguous().view(nword, -1)
                else:
                    # 2-D files hold a single layer, so `layer` only applies to 3-D ones
                    activations = torch.FloatTensor(sentence_h5[layer] if dim == 3
                                                    else sentence_h5)

                # Create `representations`: optionally keep only the first or
                # the second half of the last (feature) dimension
                representations = activations
                if first_half_only:
                    representations = torch.chunk(representations, chunks=2,
                                                  dim=-1)[0]
                elif second_half_only:
                    representations = torch.chunk(representations, chunks=2,
                                                  dim=-1)[1]

                representations_l.append(representations)

            # update: record the feature width and all sentences stacked along dim 0
            model_name = "{model}_{layer}".format(model=fname2mname(fname),
                                                  layer=layer)
            num_neurons_d[model_name] = representations_l[0].size()[-1]
            representations_d[model_name] = torch.cat(representations_l)

2it [02:06, 57.72s/it]


In [31]:
num_neurons_d

{'elmo_original_0': 1024,
 'elmo_original_1': 1024,
 'elmo_original_2': 1024,
 'calypso_transformer_6_512_base_-1': 1024,
 'elmo_original_full': 3072}