In [31]:
%pwd

'C:\\Users\\brand\\Documents\\Projects\\XLdefgen\\model'

In [33]:
from dataclasses import dataclass

@dataclass
class args:
    data_task: str = "definition"
    gradient_accumulation_steps: int = 2
    learning_rate: float = 2e-4
    log_frequency: int = 1000
    max_source_length: int = 64
    max_target_length: int = 64
    model_name_or_path: str = "google/mt5-small"
    num_train_epochs: int = 3
    num_warmup_steps: int = 0
    output_dir: str = "wandb_run"
    pad_to_max_length: bool = False
    per_device_train_batch_size: int = 2
    per_device_eval_batch_size: int = 8
    report_to: str = "wandb"
    seed: int = 42
    source_lang: str = "en"
    source_prefix: str = ""
    target_lang: str = "en"
    validation_file: str = "../data/codwoe/codwoe_temp_en_mask.json"
    wandb_proj: str = "def_train"
    weight_decay: float = 0.01

In [64]:
# from transformers import AutoTokenizer, AutoModel
import torch
from custom_classes_and_fxns import (
    TokenizerWithXMask,
    MT5WithXMask,
    prepare_for_xattn,
    remove_def_markers
    )

checkpoint = "google/mt5-small"
tokenizer = TokenizerWithXMask.from_pretrained(checkpoint)
model = MT5WithXMask.from_pretrained(checkpoint, output_attentions=True)

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'T5Tokenizer'. 
The class this function is called from is 'TokenizerWithXMask'.


In [69]:
special_tokens_dict = {"mask_token": "<MASK>", "sep_token": " <MASK>"}
tokenizer.add_special_tokens(special_tokens_dict)
model.resize_token_embeddings(len(tokenizer))
print("Vocab size:", len(tokenizer))

max_target_length = args.max_target_length
padding = "max_length" if args.pad_to_max_length else False


inputs = ["During <MASK> adolescence <MASK> , the body and mind go through many complex changes , some of which are difficult to deal with ."]
targets = ["The transitional period of physical and psychological development between childhood and maturity ."]

# Prepare inputs and targets
model_inputs = tokenizer(inputs, max_length=args.max_source_length, padding=padding, truncation=True)
with tokenizer.as_target_tokenizer():
    labels = tokenizer(targets, max_length=max_target_length, padding=padding, truncation=True)
if padding == "max_length" and args.ignore_pad_token_for_loss:
    labels["input_ids"] = [[(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]]
model_inputs["labels"] = labels["input_ids"]

# Filter out inputs that are too long?


# Add cross-attention mask and remove temp definiendum markers - uses prepare_for_xattn() code

example = model_inputs
def_ids = tokenizer.convert_tokens_to_ids(["<MASK>", " <MASK>", tokenizer.eos_token])
def_indices = []
sent = example['input_ids'][0]

for i, token_id in enumerate(sent):
    if token_id in def_ids:
        print(i)
        def_indices.append(i)

assert len(def_indices) == 3, "Definiendum span not found. def_indices should consist of 3 integers but is instead " + str(def_indices) + " (" + str(len(sent)) + ")\n" + tokenizer.decode(sent)
begin,end = def_indices[:2]
eos_index = def_indices[-1]

# Mask everything except for definiendum
cross_attention_mask = [0]*len(sent)
cross_attention_mask[begin:end] = [1]*(end-begin)
cross_attention_mask[eos_index] = 1
example['cross_attention_mask'] = [cross_attention_mask]


# Remove definiendum markers - uses remove_def_markers() code

example['input_ids'][0].pop(end)
example['input_ids'][0].pop(begin)
example['attention_mask'][0].pop(end)
example['attention_mask'][0].pop(begin)
example['cross_attention_mask'][0].pop(end)
example['cross_attention_mask'][0].pop(begin)
    
xattn_datasets = {'input_ids': torch.tensor(example['input_ids']),
                  'attention_mask': torch.tensor(example['attention_mask']),
                  'cross_attention_mask': torch.tensor(example['cross_attention_mask']),
                  'labels': torch.tensor(example['labels'])
                 }

xattn_datasets

Vocab size: 250102
2
5
31


{'input_ids': tensor([[  9155,    347, 142387,    541,    259,    261,    287,   8658,    305,
            4047,   1002,   3026,   3506,  13814,    259,  25444,    259,    261,
            2155,    304,    259,   1542,    418,  10378,    288,  19869,    514,
             259,    260,      1]]),
 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1]]),
 'cross_attention_mask': tensor([[0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 1]]),
 'labels': tensor([[   486,    259,  10091,    473,   8192,    304,    259,  28223,    305,
           95513,   6906,  10030,    259,   4964,  19878,  19031,    305,    259,
          154819,    276,    259,    260,      1]])}

In [75]:
outputs = model(**xattn_datasets)

encoder_text = tokenizer.convert_ids_to_tokens(xattn_datasets['input_ids'][0])
decoder_text = tokenizer.convert_ids_to_tokens(xattn_datasets['labels'][0])

In [77]:
from bertviz import head_view
head_view(
    encoder_attention=outputs.encoder_attentions,
    decoder_attention=outputs.decoder_attentions,
    cross_attention=outputs.cross_attentions,
    encoder_tokens=encoder_text,
    decoder_tokens=decoder_text
)

<IPython.core.display.Javascript object>

In [76]:
from bertviz import model_view
model_view(
    encoder_attention=outputs.encoder_attentions,
    decoder_attention=outputs.decoder_attentions,
    cross_attention=outputs.cross_attentions,
    encoder_tokens=encoder_text,
    decoder_tokens=decoder_text
)

<IPython.core.display.Javascript object>