In [1]:
import os
import torch
import torch.nn.functional as F
import sys
sys.path.append('..')

import utils
import os

In [2]:
def make_word_dict(aoa_word_list, vocabulary):
    """Split a word list into words the vocabulary encodes and words it does not.

    Args:
        aoa_word_list: iterable of words to look up.
        vocabulary: container supporting ``in`` and ``[]`` that maps a word to
            its token id.

    Returns:
        A ``(token_to_word, missing_words)`` tuple: a dict from token id to
        word for every word found in ``vocabulary``, and a list (in input
        order) of the words that were not found.
    """
    token_to_word = {}
    missing_words = []
    for candidate in aoa_word_list:
        if candidate not in vocabulary:
            missing_words.append(candidate)
            continue
        token_to_word[vocabulary[candidate]] = candidate
    return token_to_word, missing_words

In [4]:
def get_surprisals(model, dataset, word_dict, device):
    """Accumulate the surprisal of every tracked word over all utterances.

    Args:
        model: callable taking a ``(1, seq_len)`` tensor of token ids and
            returning logits indexed as ``[0][position][token]``; it is put in
            eval mode before use.
        dataset: iterable of utterances, each a 1-D sequence of token ids.
        word_dict: dict mapping token id -> word for the tracked words.
        device: torch device each utterance tensor is moved to.

    Returns:
        dict mapping word -> ``[surprisal_sum, n_occurrences]``.
    """
    model.eval()
    word_surprisals = {word: [0.0, 0] for word in word_dict.values()}
    # Inference only: no autograd graph is needed for the forward passes.
    with torch.no_grad():
        # Iterate the `dataset` argument (not a notebook global).
        for utt in dataset:
            utt_tensor = torch.tensor(utt)[None, :].to(device)
            outputs = model(utt_tensor)
            surprisals = -F.log_softmax(outputs, dim=2)
            # Drop only the batch axis so a one-token utterance stays 1-D.
            tokens = utt_tensor[0].tolist()
            # One pass over the utterance instead of one comparison per word.
            # NOTE(review): the surprisal is read at the token's own position;
            # if the model's output at position t scores the token at t+1,
            # this lookup should be shifted by one — confirm against training.
            for position, token in enumerate(tokens):
                if token not in word_dict:
                    continue
                surprisal = (surprisals[0][position][token] + sys.float_info.epsilon).item()
                word = word_dict[token]
                word_surprisals[word][0] += surprisal
                word_surprisals[word][1] += 1

    return word_surprisals

In [10]:
# File locations used by the LSTM surprisal extraction below (relative paths).
# Pickled utterance data that is wrapped in `Dataset` further down.
all_child_directed_data_path ="../../../Data/model_datasets/eng/validation_vocab_size_5000.pkl"
# Pickled word -> token-id mapping, loaded into `vocabulary`.
encoding_dictionary_path="../../../Data/model_datasets/eng/encoding_dictionary_vocab_size_5000.pkl"
# CSV path only; the parsed list is bound to `word_list` by a later cell.
aoa_word_list="../../../Data/model_datasets/eng/aoa_words.csv"
# Directory holding the checkpoint (note: no trailing slash here).
experiment_dir="../../../Results/experiments/2021-08-20_lstm_eng_5e_256b_em100_hd100_v5000_run0"
# NOTE(review): `model` is the checkpoint file name here and is rebound to the
# loaded network by the torch.load cell, so that cell cannot be re-run unless
# this one is re-run first.
model="model.pt"

In [4]:
device = torch.device('cpu')

In [5]:
vocabulary = utils.open_pkl(encoding_dictionary_path)

In [21]:
model = torch.load(os.path.join(experiment_dir, model))

In [10]:
model

LSTM(
  (word_embeddings): Embedding(5001, 100, padding_idx=0)
  (lstm): LSTM(100, 100, num_layers=2, batch_first=True)
  (linear): Linear(in_features=100, out_features=5001, bias=True)
)

In [22]:
model = model.to(device)

In [8]:
from data_loader import Dataset
from torch.utils.data import DataLoader

In [40]:
all_data = Dataset(all_child_directed_data_path)

In [58]:
len(all_data)

586104

In [41]:
dl = DataLoader(all_data, batch_size=1000)

In [42]:
i,batch = next(enumerate(dl))

In [43]:
batch

tensor([[   0,    0,    0,  ...,  426, 2221, 4136],
        [   0,    0,    0,  ..., 3631, 4674, 3273],
        [   0,    0,    0,  ..., 3081,  916, 4136],
        ...,
        [   0,    0,    0,  ...,    0,  929, 3273],
        [   0,    0,    0,  ...,  929, 2239, 3273],
        [   0,    0,    0,  ...,    0, 1763, 3273]])

In [44]:
word_list = utils.open_word_list_csv(aoa_word_list)

In [45]:
word_dict, not_in_vocab_list = make_word_dict(word_list, vocabulary)

In [46]:
word_dict

{2506: 'airplane',
 2585: 'all',
 1693: 'animal',
 1211: 'another',
 1838: 'apple',
 4479: 'arm',
 4375: 'asleep',
 2690: 'aunt',
 4376: 'away',
 4691: 'baby',
 3169: 'babysitter',
 4096: 'back',
 125: 'bad',
 3972: 'ball',
 867: 'balloon',
 739: 'banana',
 3131: 'bath',
 2881: 'bathroom',
 2650: 'bathtub',
 2806: 'beach',
 2941: 'beads',
 1477: 'bear',
 4078: 'bed',
 4809: 'bedroom',
 3316: 'bee',
 1837: 'bib',
 369: 'bicycle',
 1527: 'big',
 3483: 'bird',
 3665: 'bite',
 4396: 'blanket',
 3277: 'block',
 2131: 'blow',
 4742: 'blue',
 3197: 'book',
 2952: 'boots',
 4139: 'bottle',
 560: 'bowl',
 4449: 'box',
 4558: 'boy',
 1675: 'bread',
 2771: 'break',
 3449: 'breakfast',
 1124: 'bring',
 3566: 'broken',
 1767: 'broom',
 4866: 'brother',
 148: 'brush',
 4520: 'bubbles',
 3724: 'bug',
 672: 'bump',
 3750: 'bunny',
 3948: 'bus',
 3339: 'butter',
 3864: 'butterfly',
 1386: 'button',
 35: 'bye',
 3655: 'cake',
 1366: 'candy',
 4364: 'car',
 4373: 'careful',
 4081: 'carrots',
 4394: 'cat'

In [47]:
not_in_vocab_list

['shh/shush/hush']

In [48]:
model.eval()
word_surprisals = {}

In [49]:
for index in word_dict.keys():
    word_surprisals[word_dict[index]] = [0.0, 0]

In [50]:
batch = batch.to(device)

In [51]:
outputs = model(batch)

In [52]:
outputs

tensor([[[  8.6556, -20.8837, -22.7441,  ..., -19.6385, -23.2546, -19.7326],
         [  3.8069, -29.7853, -30.3886,  ..., -27.3332, -33.3092, -27.0355],
         [ -1.6090, -32.0675, -31.6035,  ..., -29.3202, -35.7172, -28.3883],
         ...,
         [-15.0925, -14.5203, -26.5038,  ..., -14.8871, -19.1576, -22.5090],
         [-20.4442, -22.8842, -16.8370,  ..., -26.8183, -31.4736, -20.5707],
         [-16.6114, -23.1062, -17.2503,  ..., -17.3572, -29.6178, -18.6515]],

        [[  8.6556, -20.8837, -22.7441,  ..., -19.6385, -23.2546, -19.7326],
         [  3.8069, -29.7853, -30.3886,  ..., -27.3332, -33.3092, -27.0355],
         [ -1.6090, -32.0675, -31.6035,  ..., -29.3202, -35.7172, -28.3883],
         ...,
         [-19.2608, -14.2701, -18.2179,  ..., -24.7588, -16.8931, -19.5727],
         [-19.1435, -30.1818, -27.4767,  ..., -25.2162, -32.1094, -24.1162],
         [-13.5793, -13.9133, -19.1120,  ..., -23.0277, -17.8077, -20.0976]],

        [[  8.6556, -20.8837, -22.7441,  ...

In [53]:
surprisals = -F.log_softmax(outputs, dim=2)

In [54]:
surprisals

tensor([[[2.3842e-07, 2.9539e+01, 3.1400e+01,  ..., 2.8294e+01,
          3.1910e+01, 2.8388e+01],
         [-0.0000e+00, 3.3592e+01, 3.4195e+01,  ..., 3.1140e+01,
          3.7116e+01, 3.0842e+01],
         [-0.0000e+00, 3.0458e+01, 2.9995e+01,  ..., 2.7711e+01,
          3.4108e+01, 2.6779e+01],
         ...,
         [2.4647e+01, 2.4075e+01, 3.6058e+01,  ..., 2.4441e+01,
          2.8712e+01, 3.2063e+01],
         [2.2170e+01, 2.4610e+01, 1.8563e+01,  ..., 2.8544e+01,
          3.3199e+01, 2.2296e+01],
         [2.6860e+01, 3.3355e+01, 2.7499e+01,  ..., 2.7606e+01,
          3.9866e+01, 2.8900e+01]],

        [[2.3842e-07, 2.9539e+01, 3.1400e+01,  ..., 2.8294e+01,
          3.1910e+01, 2.8388e+01],
         [-0.0000e+00, 3.3592e+01, 3.4195e+01,  ..., 3.1140e+01,
          3.7116e+01, 3.0842e+01],
         [-0.0000e+00, 3.0458e+01, 2.9995e+01,  ..., 2.7711e+01,
          3.4108e+01, 2.6779e+01],
         ...,
         [2.3695e+01, 1.8704e+01, 2.2652e+01,  ..., 2.9193e+01,
          2

In [55]:
surprisals.shape

torch.Size([1000, 125, 5001])

In [56]:
# Scratch version of the accumulation on a single batch: for every tracked
# token id, locate its occurrences in `batch` and add the surprisal the model
# assigns to that token at the same coordinates.
for word_index in word_dict:
    # One row of coordinates into `batch` per occurrence of the token.
    index_matches = (batch == word_index).nonzero(as_tuple=False)
    if len(index_matches) > 0:
        for i in index_matches:
            # Surprisal vector over the vocabulary at this occurrence.
            match = surprisals[tuple(i)] 
            # NOTE(review): epsilon is added to every occurrence; its purpose
            # is not evident from this cell.
            surprisal = match[word_index].item() + sys.float_info.epsilon
            word = word_dict[word_index]
            # Entry layout is [running surprisal sum, occurrence count].
            word_surprisals[word][0] += surprisal
            word_surprisals[word][1] += 1

tensor([], size=(0, 2), dtype=torch.int64)
tensor([[ 21, 121],
        [ 29, 112],
        [ 45, 117],
        [ 45, 123],
        [ 65, 122],
        [130, 120],
        [171, 122],
        [222, 118],
        [233, 118],
        [252, 118],
        [331, 120],
        [350, 121],
        [359, 120],
        [377, 123],
        [384, 119],
        [401, 120],
        [419, 119],
        [421, 121],
        [429, 123],
        [448, 121],
        [490, 122],
        [541, 123],
        [579, 122],
        [633, 120],
        [748, 120],
        [755, 119],
        [777, 120],
        [875, 123],
        [907, 121],
        [925, 123],
        [943, 123],
        [944, 121],
        [954, 120]])
tensor([], size=(0, 2), dtype=torch.int64)
tensor([[223, 115],
        [536, 121],
        [656, 118],
        [741, 115]])
tensor([], size=(0, 2), dtype=torch.int64)
tensor([], size=(0, 2), dtype=torch.int64)
tensor([[ 61, 121]])
tensor([[292, 111]])
tensor([[333, 119],
        [552, 123],
    

tensor([], size=(0, 2), dtype=torch.int64)
tensor([], size=(0, 2), dtype=torch.int64)
tensor([[376, 121],
        [383, 116],
        [385, 123]])
tensor([], size=(0, 2), dtype=torch.int64)
tensor([], size=(0, 2), dtype=torch.int64)
tensor([[129, 119]])
tensor([], size=(0, 2), dtype=torch.int64)
tensor([], size=(0, 2), dtype=torch.int64)
tensor([], size=(0, 2), dtype=torch.int64)
tensor([], size=(0, 2), dtype=torch.int64)
tensor([[797, 122]])
tensor([], size=(0, 2), dtype=torch.int64)
tensor([[549, 122]])
tensor([[ 54, 120],
        [ 65, 120],
        [ 74, 121],
        [221, 121],
        [305, 118],
        [353, 119],
        [365, 120],
        [479, 121],
        [586, 116],
        [665, 120],
        [719, 120],
        [722, 123],
        [736, 121],
        [803, 121],
        [919, 121],
        [931, 115],
        [967, 120]])
tensor([[ 63, 121],
        [141, 123],
        [202, 122],
        [290, 122],
        [354, 120],
        [427, 123],
        [436, 118],
        

tensor([], size=(0, 2), dtype=torch.int64)
tensor([[566, 123]])
tensor([[668, 123]])
tensor([], size=(0, 2), dtype=torch.int64)
tensor([], size=(0, 2), dtype=torch.int64)
tensor([[382, 123]])
tensor([], size=(0, 2), dtype=torch.int64)
tensor([], size=(0, 2), dtype=torch.int64)
tensor([], size=(0, 2), dtype=torch.int64)
tensor([[526, 119]])
tensor([], size=(0, 2), dtype=torch.int64)
tensor([], size=(0, 2), dtype=torch.int64)
tensor([], size=(0, 2), dtype=torch.int64)
tensor([[231, 116],
        [231, 122]])
tensor([], size=(0, 2), dtype=torch.int64)
tensor([[286, 114],
        [678, 121],
        [795, 116]])
tensor([], size=(0, 2), dtype=torch.int64)
tensor([], size=(0, 2), dtype=torch.int64)
tensor([], size=(0, 2), dtype=torch.int64)
tensor([], size=(0, 2), dtype=torch.int64)
tensor([], size=(0, 2), dtype=torch.int64)
tensor([], size=(0, 2), dtype=torch.int64)
tensor([], size=(0, 2), dtype=torch.int64)
tensor([[957, 123]])
tensor([[122, 122],
        [222, 117],
        [224, 119],
  

In [57]:
word_surprisals

{'airplane': [0.0, 0],
 'all': [0.0003701426740081448, 33],
 'animal': [0.0, 0],
 'another': [0.00016760474682175897, 4],
 'apple': [0.0, 0],
 'arm': [0.0, 0],
 'asleep': [0.0006348263123074904, 1],
 'aunt': [0.004882674664259179, 1],
 'away': [0.00027810476604006595, 4],
 'baby': [3.170916897965448e-05, 1],
 'babysitter': [0.0, 0],
 'back': [0.0002486674729902383, 10],
 'bad': [0.0016356413834739403, 5],
 'ball': [0.0, 0],
 'balloon': [0.0, 0],
 'banana': [0.0, 0],
 'bath': [0.00023255028645530196, 1],
 'bathroom': [0.0, 0],
 'bathtub': [0.0, 0],
 'beach': [0.0, 0],
 'beads': [0.0, 0],
 'bear': [0.0, 0],
 'bed': [8.380061626689361e-05, 1],
 'bedroom': [0.0, 0],
 'bee': [0.0, 0],
 'bib': [0.0, 0],
 'bicycle': [0.0, 0],
 'big': [9.524750839928764e-05, 7],
 'bird': [0.0, 0],
 'bite': [0.0, 0],
 'blanket': [0.0, 0],
 'block': [0.0, 0],
 'blow': [0.0, 0],
 'blue': [0.0001420873741155848, 1],
 'book': [0.00010072907025504563, 2],
 'boots': [0.0, 0],
 'bottle': [0.0, 0],
 'bowl': [0.0, 0],
 

In [None]:
    # NOTE(review): this cell has no execution count and repeats the body of
    # get_surprisals at function-body indentation, iterating the notebook
    # global `all_data`. Run as-is it raises an indentation error; presumably a
    # scratch copy of the function body — confirm before keeping it.
    model.eval()
    word_surprisals = {}
    # Every tracked word starts as [surprisal sum, occurrence count] = [0.0, 0].
    for index in word_dict.keys():
        word_surprisals[word_dict[index]] = [0.0, 0]
    for utt in all_data:
        # Add a leading batch axis of size 1 before the forward pass.
        utt_tensor = torch.tensor(utt)[None, :]
        utt_tensor = utt_tensor.to(device)
        outputs = model(utt_tensor)
        surprisals = -F.log_softmax(outputs, dim=2)
        # Remove the batch axis again so positions can be matched directly.
        utt_tensor = torch.squeeze(utt_tensor)
        for word_index in word_dict:
            index_matches = (utt_tensor == word_index).nonzero(as_tuple=False)
            if len(index_matches) > 0:
                for i in index_matches:
                    match = i.item()
                    surprisal = (surprisals[0][match][word_index] + sys.float_info.epsilon).item()
                    word = word_dict[word_index]
                    word_surprisals[word][0] += surprisal
                    word_surprisals[word][1] += 1

In [54]:
# Give words that were not in the vocabulary a zero-count entry so that they
# still get a row when the surprisals are saved.
for word in not_in_vocab_list:
        word_surprisals[word] = [0.0, 0]

In [55]:
import csv
def save_surprisals_as_csv(surprisals, experiment_dir, file_name="average_surprisals.csv"):
    """Write the average surprisal of every word to a CSV file.

    Args:
        surprisals: dict mapping word -> ``[surprisal_sum, n_instances]``.
        experiment_dir: directory the file is written into; a trailing path
            separator is optional.
        file_name: name of the CSV file inside ``experiment_dir``.

    The file has the header ``word,surprisal_value,n_instances``. Words with
    zero instances get ``NA`` in both value columns; otherwise the average is
    written with 16 decimal places.
    """
    # os.path.join keeps the file inside the directory even when the
    # directory string has no trailing separator.
    output_path = os.path.join(experiment_dir, file_name)
    # newline='' is what the csv module expects from files it writes to.
    with open(output_path, mode='w', newline='') as csv_file:
        writer = csv.writer(csv_file, delimiter=',')
        writer.writerow(["word", "surprisal_value", "n_instances"])
        for word, (_sum, n) in surprisals.items():
            if n == 0:
                writer.writerow([word, 'NA', 'NA'])
            else:
                avg = _sum / n
                writer.writerow([word, f"{avg:.16f}", str(n)])

In [56]:
save_surprisals_as_csv(word_surprisals, "./")

In [None]:
    # Draft of the end-to-end LSTM extraction, written at function-body
    # indentation and never executed in this notebook.
    # NOTE(review): `get_parameters` is not defined in any visible cell —
    # presumably it lives in the script this draft was taken from.
    params = get_parameters()
    #May add batching, optimize, and cuda support
    device = torch.device('cuda') if params.gpu_run == True else torch.device('cpu')
    vocabulary = utils.open_pkl(params.encoding_dictionary_path)
    #word_list = set(utils.open_word_list_csv(params.aoa_word_list))
    #in_word_list_not_vocab = word_list - set(vocabulary.keys())
    #vocab_word_list_intersection = word_list - in_word_list_not_vocab
    model = torch.load(os.path.join(params.experiment_dir, params.model))
    model = model.to(device)
    # Unlike the cells above, the data is loaded with utils.open_pkl rather
    # than wrapped in `Dataset`.
    all_data = utils.open_pkl(params.all_child_directed_data_path)
    word_list = utils.open_word_list_csv(params.aoa_word_list)
    word_dict, not_in_vocab_list = make_word_dict(word_list, vocabulary)
    average_surprisals = get_surprisals(model, all_data, word_dict, device)
    # Calls the `utils` version, not the save_surprisals_as_csv defined in this
    # notebook.
    utils.save_surprisals_as_csv(average_surprisals, params.experiment_dir)
    print(average_surprisals)

## BERT extraction

In [2]:
import argparse
import os
import torch
import torch.nn.functional as F
from bert_custom_dataset import CHILDESDataset
import csv

In [3]:
import operator
import functools

In [4]:
from transformers import BertLMHeadModel
#model = torch.load(os.path.join(params.experiment_dir, params.model))

In [5]:
from torch.utils.data import DataLoader

In [20]:
# Settings for the BERT surprisal extraction cells that follow.
# Text file passed to CHILDESDataset.
data_path="../../../Data/model_datasets/eng/validation.txt"
# False selects the CPU device in the next cell.
gpu_run=False
batch_size=10
# CSV path only; the parsed list is bound to `word_list` by a later cell.
aoa_word_list="../../../Data/model_datasets/eng/aoa_words.csv"
experiment_dir="../../../Results/experiments/"
# NOTE(review): `model` is a file name here and is rebound to a model object by
# a later cell; a cell calling model.eval() while it is still this string
# raises AttributeError.
model="model.pt"
# Used as the prefix of the output CSV file name.
split="validation"

In [21]:
device = torch.device('cuda') if gpu_run == True else torch.device('cpu')

In [22]:
torch.cuda.empty_cache()

In [23]:
model = BertLMHeadModel.from_pretrained("bert-base-multilingual-uncased", return_dict=True, is_decoder = True)
model = model.to(device)

Some weights of the model checkpoint at bert-base-multilingual-uncased were not used when initializing BertLMHeadModel: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertLMHeadModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertLMHeadModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [9]:
data = CHILDESDataset(data_path)

In [40]:
dataloader = DataLoader(data, batch_size=batch_size)

In [41]:
dl = enumerate(dataloader)

In [42]:
word_list = utils.open_word_list_csv(aoa_word_list)

In [13]:
def make_token_word_pairs(aoa_word_list, dataset):
    """Tokenize every word and pair its token ids with the word itself.

    Args:
        aoa_word_list: iterable of words.
        dataset: object exposing a ``tokenizer`` callable; ``tokenizer(word)``
            must return a mapping with an ``'input_ids'`` sequence.

    Returns:
        List of ``(token_ids, word)`` tuples, where ``token_ids`` is a 1-D
        ``torch.long`` tensor of the input ids without the first and last id.
    """
    tokenizer = dataset.tokenizer
    word_pairs = []
    for word in aoa_word_list:
        seq = tokenizer(word)['input_ids']
        # The first and last ids are dropped (presumably the special tokens the
        # tokenizer wraps around the text — confirm). The tensor is built as
        # int64 directly: going through a float32 tensor first would round ids
        # above 2**24.
        token = torch.tensor(seq[1:-1], dtype=torch.long)
        word_pairs.append((token, word))
    return word_pairs

In [14]:
word_pairs = make_token_word_pairs(word_list, data)

In [15]:
word_pairs

[(tensor([89498]), 'airplane'),
 (tensor([10367]), 'all'),
 (tensor([15883]), 'animal'),
 (tensor([12162]), 'another'),
 (tensor([17006]), 'apple'),
 (tensor([20999]), 'arm'),
 (tensor([23455, 50040]), 'asleep'),
 (tensor([69449]), 'aunt'),
 (tensor([13795]), 'away'),
 (tensor([15719]), 'baby'),
 (tensor([15719, 88984]), 'babysitter'),
 (tensor([11677]), 'back'),
 (tensor([12428]), 'bad'),
 (tensor([14918]), 'ball'),
 (tensor([14918, 15845]), 'balloon'),
 (tensor([64916]), 'banana'),
 (tensor([37556]), 'bath'),
 (tensor([37556, 32038]), 'bathroom'),
 (tensor([37556, 60024]), 'bathtub'),
 (tensor([14575]), 'beach'),
 (tensor([10346, 82493]), 'beads'),
 (tensor([21364]), 'bear'),
 (tensor([24433]), 'bed'),
 (tensor([24433, 32038]), 'bedroom'),
 (tensor([18560]), 'bee'),
 (tensor([10863, 10417]), 'bib'),
 (tensor([68223]), 'bicycle'),
 (tensor([12062]), 'big'),
 (tensor([17352]), 'bird'),
 (tensor([16464, 10111]), 'bite'),
 (tensor([61752, 10337]), 'blanket'),
 (tensor([18612]), 'block'),

In [6]:
model.eval()

AttributeError: 'str' object has no attribute 'eval'

In [72]:
def indexes_in_sequence(query, base):
    """Find every position where ``query`` occurs as a contiguous run in a sequence.

    Args:
        query: 1-D tensor of token ids to search for.
        base: ``(id_, label)`` pair; ``label`` is a tensor holding one sequence
            (any singleton axes are flattened away) and ``id_`` is echoed back
            in each result.

    Returns:
        List of ``[id_, start]`` pairs, one per match, in increasing ``start``
        order. An empty ``query`` yields no matches.
    """
    id_, label = base[0], base[1]
    # reshape(-1) keeps a length-1 sequence 1-D (squeeze() would make it 0-d).
    label = label.reshape(-1)
    l = len(query)
    if l == 0:
        # torch.all of an empty comparison is True, which would otherwise
        # report a match at every position.
        return []
    locations = []
    # +1 so the window ending on the last element is also tested.
    for i in range(len(label) - l + 1):
        if torch.all(label[i:i+l] == query):
            locations.append([id_, i])
    return locations

In [49]:
n, batch = next(dl)

In [210]:
def get_batched_surprisals(model, dataloader, word_pairs, device):
    """Accumulate per-word surprisal sums and occurrence counts over a dataloader.

    Args:
        model: callable as ``model(**batch)`` returning an object with a
            ``logits`` attribute indexed ``[row, position, token]``; it is put
            in eval mode before use.
        dataloader: iterable of dict batches of tensors; every batch needs a
            2-D ``'labels'`` entry of token ids (rows are sequences).
        word_pairs: re-iterable collection of ``(token_ids, word)`` pairs where
            ``token_ids`` is a 1-D long tensor (one or more sub-tokens).
        device: torch device the batch and the token ids are moved to.

    Returns:
        dict mapping word -> ``[surprisal_sum, n_occurrences]``. The surprisal
        of a multi-token word is the sum of its sub-token surprisals.
    """
    model.eval()
    word_surprisals = {word: [0.0, 0] for _, word in word_pairs}
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        for n, batch in enumerate(dataloader):
            # Light progress indicator instead of printing every batch/word.
            if n % 100 == 0:
                print(n)
            for key in batch:
                batch[key] = batch[key].to(device)
            outputs = model(**batch)
            surprisals = -F.log_softmax(outputs.logits, -1)
            labels = batch['labels']
            seq_len = labels.shape[1]
            for indexes, word in word_pairs:
                indexes = indexes.to(device)
                n_tokens = len(indexes)
                # An empty query or one longer than the sequence cannot match.
                if n_tokens == 0 or n_tokens > seq_len:
                    continue
                # All length-n_tokens windows of every row, including the last
                # one: shape (rows, seq_len - n_tokens + 1, n_tokens).
                windows = labels.unfold(1, n_tokens, 1)
                starts = (windows == indexes).all(dim=-1).nonzero(as_tuple=False)
                if len(starts) == 0:
                    continue
                rows = starts[:, :1]
                # Position of sub-token j is start + j, computed from the
                # untouched start (no cumulative, in-place offsetting).
                positions = starts[:, 1:] + torch.arange(n_tokens, device=starts.device)
                # NOTE(review): surprisal is read at each token's own position;
                # if the model's logits at position t score the token at t+1
                # (as causal LM heads usually do), shift by one — confirm.
                token_surprisals = surprisals[rows, positions, indexes]
                # Surprisals are negative log-probabilities, so sub-token
                # values add up (they must not be multiplied).
                for surprisal in token_surprisals.sum(dim=1).tolist():
                    word_surprisals[word][0] += (surprisal + sys.float_info.epsilon)
                    word_surprisals[word][1] += 1
    return word_surprisals

In [None]:
    # Draft of the end-to-end BERT extraction, written at function-body
    # indentation and never executed in this notebook.
    # NOTE(review): `get_parameters` is not defined in any visible cell —
    # presumably it lives in the script this draft was taken from.
    params = get_parameters()
    device = torch.device('cuda') if params.gpu_run == True else torch.device('cpu')
    model = torch.load(os.path.join(params.experiment_dir, params.model))
    model = model.to(device)
    # NOTE(review): uses `Dataset`, while the executed BERT cells build the
    # data with CHILDESDataset; make_token_word_pairs reads `data.tokenizer`.
    data = Dataset(params.data_path)
    dataloader = DataLoader(data, batch_size=params.batch_size)
    word_list = utils.open_word_list_csv(params.aoa_word_list)
    word_pairs = make_token_word_pairs(word_list, data)
    word_surprisals = get_batched_surprisals(model, dataloader, word_pairs, device)
    file_name = params.split + "_average_surprisals.csv"
    # Calls the `utils` version with an explicit file name.
    utils.save_surprisals_as_csv(word_surprisals, params.experiment_dir, file_name)

In [25]:
!nvidia-smi

Tue Aug 24 14:33:17 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 455.23.04    Driver Version: 455.23.04    CUDA Version: 11.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  GeForce RTX 3080    Off  | 00000000:01:00.0 Off |                  N/A |
|  0%   41C    P2    85W / 320W |   9415MiB / 10015MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+---------------------------------------------------------------------------

In [12]:
10015-1809

8206

In [98]:
print(torch.__version__)

1.8.1


In [99]:
torch.tensor_split(labels, batch_size)


(tensor([[  101, 10110, 11523, 12266, 10855, 10574, 10160, 10103, 11727,   136,
            102,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0]],
        device='cuda:0'),
 tensor([[  101, 27948, 34026, 12172,   119,   102,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0]],
        device='cuda:0'),
 tensor([[  101, 10110, 12266, 10855, 11811, 32527,   136,   102,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0

In [2]:
import sys
sys.path.append('..')

import os
import torch
import torch.nn.functional as F
from bert_custom_dataset import CHILDESDataset
import operator
import functools
from torch.utils.data import DataLoader


In [3]:
import utils

In [4]:
# Settings for the surprisal extraction with the fine-tuned checkpoint.
# Text file passed to CHILDESDataset.
data_path="../../../Data/model_datasets/eng/validation.txt"
# True selects the CUDA device when the model is loaded.
gpu_run=True
batch_size=6
# CSV path only; the parsed list is bound to `word_list` by a later cell.
aoa_word_list="../../../Data/model_datasets/eng/aoa_words.csv"
experiment_dir="../../../Results/experiments/2021-08-31_finetune_eng_e1_b6_lr5e-5_run0/"
# NOTE(review): `model` is the checkpoint file name here and is rebound to the
# loaded network by the loading cell, so that cell cannot be re-run unless this
# one is re-run first.
model="model.pt"
# Used as the prefix of the output CSV file name.
split="validation"

In [5]:
def make_token_word_pairs(aoa_word_list, dataset):
    """Tokenize every word and pair its token ids with the word itself.

    Args:
        aoa_word_list: iterable of words.
        dataset: object exposing a ``tokenizer`` callable; ``tokenizer(word)``
            must return a mapping with an ``'input_ids'`` sequence.

    Returns:
        List of ``(token_ids, word)`` tuples, where ``token_ids`` is a 1-D
        ``torch.long`` tensor of the input ids without the first and last id.
    """
    tokenizer = dataset.tokenizer
    word_pairs = []
    for word in aoa_word_list:
        seq = tokenizer(word)['input_ids']
        # The first and last ids are dropped (presumably the special tokens the
        # tokenizer wraps around the text — confirm). The tensor is built as
        # int64 directly: going through a float32 tensor first would round ids
        # above 2**24.
        token = torch.tensor(seq[1:-1], dtype=torch.long)
        word_pairs.append((token, word))
    return word_pairs

In [6]:
def indexes_in_sequence(query, base):
    """Find every position where ``query`` occurs as a contiguous run in a sequence.

    Args:
        query: 1-D tensor of token ids to search for.
        base: ``(id_, label)`` pair; ``label`` is a tensor holding one sequence
            (any singleton axes are flattened away) and ``id_`` is echoed back
            in each result.

    Returns:
        List of ``[id_, start]`` pairs, one per match, in increasing ``start``
        order. An empty ``query`` yields no matches.
    """
    id_, label = base[0], base[1]
    # reshape(-1) keeps a length-1 sequence 1-D (squeeze() would make it 0-d).
    label = label.reshape(-1)
    l = len(query)
    if l == 0:
        # torch.all of an empty comparison is True, which would otherwise
        # report a match at every position.
        return []
    locations = []
    # +1 so the window ending on the last element is also tested.
    for i in range(len(label) - l + 1):
        if torch.all(label[i:i+l] == query):
            locations.append([id_, i])
    return locations

In [23]:
def get_batched_surprisals(model, dataloader, word_pairs, device):
    """Accumulate per-word surprisal sums and occurrence counts over a dataloader.

    Args:
        model: callable as ``model(**batch)`` returning an object with a
            ``logits`` attribute indexed ``[row, position, token]``; it is put
            in eval mode before use.
        dataloader: iterable of dict batches of tensors; every batch needs a
            2-D ``'labels'`` entry of token ids (rows are sequences).
        word_pairs: re-iterable collection of ``(token_ids, word)`` pairs where
            ``token_ids`` is a 1-D long tensor (one or more sub-tokens).
        device: torch device the batch and the token ids are moved to.

    Returns:
        dict mapping word -> ``[surprisal_sum, n_occurrences]``. The surprisal
        of a multi-token word is the sum of its sub-token surprisals.
    """
    model.eval()
    word_surprisals = {word: [0.0, 0] for _, word in word_pairs}
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        for n, batch in enumerate(dataloader):
            if n % 100 == 0:
                print(n)
            for key in batch:
                batch[key] = batch[key].to(device)
            outputs = model(**batch)
            surprisals = -F.log_softmax(outputs.logits, -1)
            labels = batch['labels']
            seq_len = labels.shape[1]
            for indexes, word in word_pairs:
                indexes = indexes.to(device)
                n_tokens = len(indexes)
                # An empty query or one longer than the sequence cannot match.
                if n_tokens == 0 or n_tokens > seq_len:
                    continue
                # Single- and multi-token words share one path: compare every
                # length-n_tokens window of every row (including the last
                # window) with the word's token ids.
                windows = labels.unfold(1, n_tokens, 1)
                starts = (windows == indexes).all(dim=-1).nonzero(as_tuple=False)
                if len(starts) == 0:
                    continue
                rows = starts[:, :1]
                # Position of sub-token j is start + j, computed from the
                # untouched start (no cumulative, in-place offsetting).
                positions = starts[:, 1:] + torch.arange(n_tokens, device=starts.device)
                # NOTE(review): surprisal is read at each token's own position;
                # if the model's logits at position t score the token at t+1
                # (as causal LM heads usually do), shift by one — confirm.
                token_surprisals = surprisals[rows, positions, indexes]
                # Surprisals are negative log-probabilities, so sub-token
                # values add up (they must not be multiplied).
                for surprisal in token_surprisals.sum(dim=1).tolist():
                    word_surprisals[word][0] += (surprisal + sys.float_info.epsilon)
                    word_surprisals[word][1] += 1
    return word_surprisals

In [8]:
torch.cuda.empty_cache()

In [9]:
# Pick the device from the `gpu_run` flag set in the configuration cell.
device = torch.device('cuda') if gpu_run == True else torch.device('cpu')
# NOTE(review): `model` is the checkpoint file name before this line and the
# loaded object after it, so re-running this cell alone fails in os.path.join.
# torch.load unpickles the file, so only load checkpoints from a trusted source.
model = torch.load(os.path.join(experiment_dir, model))
model = model.to(device)

In [10]:
data = CHILDESDataset(data_path)

In [11]:
dataloader = DataLoader(data, batch_size=batch_size)

In [12]:
len(dataloader)

97684

In [13]:
word_list = utils.open_word_list_csv(aoa_word_list)
word_pairs = make_token_word_pairs(word_list, data)

In [24]:
word_surprisals = get_batched_surprisals(model, dataloader, word_pairs, device)

0
100


KeyboardInterrupt: 

In [None]:
file_name = split + "_average_surprisals.csv"
utils.save_surprisals_as_csv(word_surprisals, experiment_dir, file_name)