In [1]:
from transformers import AutoTokenizer, AutoModel, utils, AutoModelForSeq2SeqLM
from datasets.load import load_dataset
from bertviz import model_view
utils.logging.set_verbosity_error()  # Suppress standard warnings
from src.utils import (
    linearise_input, convert_to_features, form_stepwise_input, 
    simplify_feat_names,
    label_qs,
    simplify_narr_question
)
import torch

lin = 'ord_first'
max_fts = 40
max_input_len = 500
model_name = "../models/t5-base/sleek-haze-118/checkpoint-1880"

# Load model, tokenizer and dataset
model = AutoModelForSeq2SeqLM.from_pretrained(model_name,output_attentions=True)
tokenizer = AutoTokenizer.from_pretrained(model_name)
dataset = load_dataset("james-burton/textual-explanations", split='test')

dataset = dataset.map(lambda x: simplify_narr_question(label_qs(x)),
                        load_from_cache_file=False)

# Form the linearised or stepwise (and linearised) input
dataset = dataset.map(
    lambda x: linearise_input(x, lin, max_fts),
    load_from_cache_file=False
    ) 

# Convert to tokens
dataset = dataset.map(
    lambda x: convert_to_features(x, tokenizer, max_input_len), 
    batched=True, load_from_cache_file=False
    )


  from .autonotebook import tqdm as notebook_tqdm
2023-01-16 16:06:40.859826: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-01-16 16:06:41.343846: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: :/usr/local/cuda/lib64:/home/james/Downloads/TensorRT-8.5.1.7/lib
2023-01-16 16:06:41.343895: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: :/usr/local/cu

In [2]:
from transformers import AutoTokenizer, AutoModel, utils
from bertviz import model_view
from nltk import sent_tokenize
from tqdm import tqdm
utils.logging.set_verbosity_error()  # Suppress standard warnings

model_name = "../models/t5-base/sleek-haze-118/checkpoint-1880"
model = AutoModelForSeq2SeqLM.from_pretrained(model_name, output_attentions=True).to('cuda')
tokenizer = AutoTokenizer.from_pretrained(model_name)
squeezed_list = []
for k in tqdm(range(len(dataset['input']))):
    input_text = dataset['input'][k]
    inputs = tokenizer.encode(input_text, return_tensors='pt')
    encoder_input = tokenizer(input_text, return_tensors="pt", add_special_tokens=True).to('cuda')
    encoder_input_ids, encoder_input_att_mask = encoder_input.input_ids,encoder_input.attention_mask

    decoder_input_ids = model.generate(input_ids=encoder_input_ids,
                                            attention_mask=encoder_input_att_mask,
                                            no_repeat_ngram_size=2,
                                            num_return_sequences=1,
                                            do_sample=True,
                                            early_stopping=True,
                                            use_cache=False,
                                            max_length=250,
                                            num_beams=5)


    outputs = model(input_ids=encoder_input_ids, decoder_input_ids=decoder_input_ids)

    encoder_text = tokenizer.convert_ids_to_tokens(encoder_input_ids[0])
    decoder_text = tokenizer.convert_ids_to_tokens(decoder_input_ids[0])


    sents = sent_tokenize(tokenizer.decode(decoder_input_ids[0]))
    lens = [len(tokenizer.tokenize(i)) for i in sents]

    output_splits = [0]
    for i in range(len(sents)-1):
        if i == 0:
            output_splits.append(lens[i])
        else:
            output_splits.append(lens[i]+output_splits[i])

    output_splits.extend([-1, None])
            
    splits = []
    start_idx = 0
    for i in range(4):
        start_idx = encoder_text.index('▁|', start_idx + 1)

    input_splits = [0, start_idx, encoder_text.index('▁1.'), encoder_text.index('▁2.'), encoder_text.index('▁3.'), encoder_text.index('▁4.'), -1, None]

    unsqueezed = outputs.cross_attentions[11][0].mean(dim=0)
    squeezed = torch.zeros(len(output_splits)-1, len(input_splits)-1)

    # Take an average from the ith:ith+1 output to the jth: jth+1 input
    for i in range(len(output_splits)-1):
        for j in range(len(input_splits)-1):
            squeezed[i, j] = unsqueezed[output_splits[i]:output_splits[i+1], input_splits[j]:input_splits[j+1]].sum(axis=1).mean()

    squeezed_list.append(squeezed)
    
    # Squeezed now represents an n x m matrix where n is the number of output sentences and m is the number of input splits
    # Outputs are split into however many sentences there are, plus the final special token
    # inputs are split into Classification info, Feature info, each of the 4 question infos and the final special token

100%|██████████| 47/47 [02:22<00:00,  3.02s/it]


In [3]:
max_len = max([len(i) for i in squeezed_list])
new_squeezed_list = squeezed_list.copy()
# Expand the list to the max lenght
# Create a mask to indicate which elements are valid        
mask = torch.zeros(len(new_squeezed_list),max_len)
for i in range(len(new_squeezed_list)):
    mask[i, :len(new_squeezed_list[i])] = 1
    if len(new_squeezed_list[i]) < max_len:
        new_squeezed_list[i] = torch.cat((new_squeezed_list[i], torch.zeros(max_len-len(new_squeezed_list[i]), 7)), dim=0)



In [27]:
# Take the mean of torch.stack(new_squeezed_list) only when the mask is 1
# This is a bit of a hack to avoid having to pad the list
mean = torch.sum(torch.stack(new_squeezed_list)*mask.unsqueeze(-1), dim=0)/torch.sum(mask, dim=0).unsqueeze(-1)

In [39]:
mean

tensor([[0.2257, 0.2076, 0.1093, 0.1044, 0.1093, 0.0892, 0.1846],
        [0.0885, 0.2810, 0.1122, 0.1353, 0.1504, 0.0978, 0.1723],
        [0.0556, 0.3062, 0.1052, 0.1290, 0.1713, 0.1084, 0.1640],
        [0.0755, 0.2905, 0.1112, 0.1298, 0.1611, 0.1081, 0.1607],
        [0.0547, 0.3602, 0.1118, 0.1176, 0.1331, 0.0977, 0.1718],
        [0.0378, 0.3942, 0.1091, 0.0872, 0.1719, 0.0836, 0.1778]],
       grad_fn=<DivBackward0>)

In [38]:
# Something is wrong here with the 2nd element
[s.sum() for s in squeezed_list]

[tensor(5.0000, grad_fn=<SumBackward0>),
 tensor(7.1704, grad_fn=<SumBackward0>),
 tensor(4., grad_fn=<SumBackward0>),
 tensor(5., grad_fn=<SumBackward0>),
 tensor(5., grad_fn=<SumBackward0>),
 tensor(5., grad_fn=<SumBackward0>),
 tensor(5.0000, grad_fn=<SumBackward0>),
 tensor(5.0000, grad_fn=<SumBackward0>),
 tensor(5.0000, grad_fn=<SumBackward0>),
 tensor(5., grad_fn=<SumBackward0>),
 tensor(5., grad_fn=<SumBackward0>),
 tensor(6., grad_fn=<SumBackward0>),
 tensor(4.0000, grad_fn=<SumBackward0>),
 tensor(5., grad_fn=<SumBackward0>),
 tensor(5., grad_fn=<SumBackward0>),
 tensor(5., grad_fn=<SumBackward0>),
 tensor(5., grad_fn=<SumBackward0>),
 tensor(4., grad_fn=<SumBackward0>),
 tensor(5., grad_fn=<SumBackward0>),
 tensor(5.0000, grad_fn=<SumBackward0>),
 tensor(5.0000, grad_fn=<SumBackward0>),
 tensor(5.0000, grad_fn=<SumBackward0>),
 tensor(5.0000, grad_fn=<SumBackward0>),
 tensor(5., grad_fn=<SumBackward0>),
 tensor(6., grad_fn=<SumBackward0>),
 tensor(5.0000, grad_fn=<SumBackwar

In [25]:
torch.stack(new_squeezed_list).sum(axis=1).shape

torch.Size([47, 7])

In [5]:
print(input_splits)
print(output_splits)

[0, 18, 88, 96, 105, 119, -1, None]
[0, 30, 66, 100, -1, None]


In [6]:

        
# i = -1
# for j in range(len(input_splits)-1):
#     squeezed[len(output_splits)-1, j] = unsqueezed[i, input_splits[j]:input_splits[j+1]].sum()
# j = -1
# for i in range(len(output_splits)-1):
#     squeezed[i, len(input_splits)-1] = unsqueezed[output_splits[i]:output_splits[i+1], j].mean()

# j, i = -1, -1

# squeezed[i,j] = unsqueezed[i, j]
# squeezed = squeezed*100
# assert squeezed.sum() == 100

In [7]:
squeezed

tensor([[2.1284e-01, 1.5995e-01, 1.4048e-01, 1.2926e-01, 1.2642e-01, 5.3404e-04,
         2.3052e-01],
        [1.1505e-01, 2.3920e-01, 1.4098e-01, 1.4032e-01, 1.6454e-01, 3.2060e-03,
         1.9670e-01],
        [5.8429e-02, 2.3242e-01, 1.3984e-01, 1.5255e-01, 2.0289e-01, 2.4868e-03,
         2.1139e-01],
        [1.4218e-01, 2.0422e-01, 1.5156e-01, 1.4292e-01, 1.5703e-01, 1.9598e-03,
         2.0013e-01],
        [9.4377e-02, 2.6244e-01, 1.5525e-01, 1.4809e-01, 1.3989e-01, 9.2435e-05,
         1.9987e-01]], grad_fn=<CopySlices>)

In [8]:
unsqueezed[-1:None].shape

torch.Size([1, 121])

In [9]:
unsqueezed[i, input_splits[j]:input_splits[j+1]]

tensor([0.3357], device='cuda:0', grad_fn=<SliceBackward0>)

In [10]:
j

6

In [11]:
i = -1
j = 0
squeezed[len(output_splits)-1, j] = unsqueezed[i, input_splits[j]:input_splits[j+1]].sum(axis=1).mean()

IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)

In [None]:
unsqueezed[i, input_splits[j]:input_splits[j+1]].shape

torch.Size([19])

In [None]:
torch.set_printoptions(precision=2, sci_mode=False)
squeezed

tensor([[0.01, 0.00, 0.01, 0.01, 0.00, 0.02, 0.18],
        [0.00, 0.00, 0.01, 0.02, 0.01, 0.02, 0.15],
        [0.00, 0.00, 0.02, 0.01, 0.01, 0.02, 0.20],
        [0.00, 0.00, 0.01, 0.02, 0.00, 0.02, 0.15],
        [0.00, 0.00, 0.01, 0.01, 0.00, 0.02, 0.15],
        [0.00, 0.00, 0.01, 0.01, 0.00, 0.02, 0.16]], grad_fn=<CopySlices>)

In [None]:
unsqueezed.sum()

tensor(128., device='cuda:0', grad_fn=<SumBackward0>)

In [None]:
outputs.cross_attentions[11][0].mean(dim=0).shape

# Average across the splits output_splits and input_splits
for sp in range(len(output_splits)):
    if sp == 0:
        start = 0
        end = output_splits[sp]
    else:
        start = output_splits[sp-1]
        end = output_splits[sp]
    print(start, end)
    print(decoder_text[start:end])
    print('')

[39, 65, 87, 114, 128]

In [None]:
outputs.cross_attentions[0][0][11].argmax(dim=1)

tensor([ 19, 259,   1,   1,  12, 259, 259, 262, 259, 259,   6,   2, 238,   2,
        259, 160, 259, 259, 259, 259, 259,   1, 259, 259, 259, 261, 262, 259,
        259, 262, 259, 239,   2, 262, 259, 259,   3, 259, 259, 262, 262, 259,
        259, 248, 248, 259, 259, 262,   0, 248, 261, 259,   2, 239, 259,  24,
        259, 259,  24, 262, 259, 262,  24, 259, 259,  24, 262, 259, 262, 262,
        262, 262, 259, 262, 259, 258, 259, 259, 247, 262,  12, 259, 259,  13,
        257,  12, 259, 262, 244, 262, 259, 238, 262, 239, 262, 259,   2, 248,
        259, 259, 262, 259, 257, 262, 214, 259,   2, 189,   2, 259, 259,   2,
        259, 237, 259, 259, 262, 259, 259, 262, 259, 262, 262, 247, 258, 248,
        259, 262], device='cuda:0')

In [None]:
outputs.cross_attentions[11][0][11].shape

torch.Size([128, 263])

In [None]:
# model_view(
#     encoder_attention=outputs.encoder_attentions,
#     decoder_attention=outputs.decoder_attentions,
#     cross_attention=outputs.cross_attentions,
#     encoder_tokens= encoder_text,
#     decoder_tokens = decoder_text,
#     include_layers=[11]
# )