# Experiment with Atlas answer generation for ParaRel

In [2]:
import json
from transformers import AutoTokenizer, T5ForConditionalGeneration
import torch
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


## Load T5 for testing

In [3]:
tokenizer = AutoTokenizer.from_pretrained("google/t5-base-lm-adapt")
model = T5ForConditionalGeneration.from_pretrained("google/t5-base-lm-adapt")

In [17]:
input_ids = tokenizer("The <extra_id_0> walks in <extra_id_1> park", return_tensors="pt").input_ids
labels = tokenizer("<extra_id_0> cute dog <extra_id_1> the <extra_id_2>", return_tensors="pt").input_ids
with torch.no_grad():
    outputs = model(input_ids=input_ids, labels=labels)
loss = outputs.loss
logits = outputs.logits

torch.Size([1, 9, 32128])

In [30]:
labels.shape

torch.Size([1, 9])

In [31]:
[tokenizer.decode(val) for val in labels[0]]

['<extra_id_0>',
 'cute',
 'dog',
 '',
 '<extra_id_1>',
 'the',
 '',
 '<extra_id_2>',
 '</s>']

In [19]:
logits.shape

torch.Size([1, 9, 32128])

In [29]:
[tokenizer.decode(logit.argmax()) for logit in logits[0,:]]

['.', 'park', 'little', 'park', 'walker', '.', '', 's', '.']

## Small options test with logits

In [35]:
query = "The Eiffel Tower is located in <extra_id_0>."
options = ["Paris","China","Sweden","Greece","Shoe","Canada","France","here","horse"]

input_ids = tokenizer(query, return_tensors="pt").input_ids

option_losses = []
for option in options:
    tmp_label = tokenizer(f"<extra_id_0> {option}", return_tensors="pt").input_ids
    with torch.no_grad():
        outputs = model(input_ids=input_ids, labels=tmp_label)
    option_losses.append(outputs.loss)
    
option_losses    

[tensor(8.5687),
 tensor(11.2431),
 tensor(11.9538),
 tensor(11.9589),
 tensor(9.8689),
 tensor(11.2803),
 tensor(9.4624),
 tensor(11.6552),
 tensor(13.3678)]

## Load ParaRel data

In [40]:
query_file = "/cephyr/users/lovhag/Alvis/projects/pararel/data/all_n1_atlas_no_space/P138_100.jsonl"
options_file = "/cephyr/users/lovhag/Alvis/projects/pararel/data/all_n1_atlas_no_space/P138_100_options.txt"

queries = []
with open(query_file) as f:
    for line in f.readlines():
        queries.append(json.loads(line))

options = []
with open(options_file) as f:
    for line in f.readlines():
        options.append(line.strip())

In [41]:
queries[8]

{'question': 'aristotelianism, named for<extra_id_0>.',
 'sub_label': 'aristotelianism',
 'answers': ['Aristotle'],
 'pattern': '[X], named for [Y].'}

## Experiment with adding sentinel id on token level

In the Atlas code, the masked language modelling is done as follows:

```python
    sentinel_id = tokenizer.additional_special_tokens_ids[i]
    inputs += tokens[offset : offset + inp_length] + [sentinel_id]
    offset += inp_length
    outputs += [sentinel_id] + tokens[offset : offset + out_length]
    offset += out_length

tokenizer.decode(inputs), tokenizer.decode(outputs)            
```

Meaning that the text is first tokenized, and then the masking is applied. So masking is done on the input ids level. Potentially, this means that the surrounding token ids look differently, compared to if the masking is done on text level. We will here investigate this.

In [43]:
def get_mlm_approach_input(question_with_sentinel_id, answer_option):
    sentinel_id = tokenizer.additional_special_tokens_ids[0]
    full_example = question_with_sentinel_id.replace("<extra_id_0>", " "+answer_option) #need to add our own space
    full_example_ids = tokenizer(full_example)["input_ids"]
    answer_ids = tokenizer(answer_option, add_special_tokens=False)["input_ids"]
    answer_ix = None
    # find where ids match the answer ids
    for i in range(len(full_example_ids)-len(answer_ids)):
        if full_example_ids[i:i+len(answer_ids)]==answer_ids:
            answer_ix = i
            break
    if answer_ix is None:
        raise ValueError(f"found no matching answer index in '{full_example_ids}' for '{answer_ids}'")
    question_tokens = full_example_ids[:answer_ix]+[sentinel_id]+full_example_ids[answer_ix+len(answer_ids):]
    answer_tokens = [sentinel_id] + answer_ids
    return tokenizer.decode(question_tokens), tokenizer.decode(answer_tokens)

In [44]:
for query in queries:
    for option in options:
        token_level_input, token_level_output = get_mlm_approach_input(query['question'], option)
        text_level_input = tokenizer.decode(tokenizer(query['question'])["input_ids"])
        text_level_output = f"<extra_id_0> {option}"
        if not token_level_input == text_level_input:
            print("Mismatched token vs. text level input found.")
            print(f"Token level: {token_level_input}")
            print(f"Text level: {text_level_input}")
        if not token_level_output == text_level_output:
            print("Mismatched token vs. text level output found.")
            print(f"Token level: {token_level_output}")
            print(f"Text level: {text_level_output}")

Mismatched token vs. text level input found.
Token level: <extra_id_0>ine epistles is named after Paul.</s>
Text level: Pauline epistles is named after<extra_id_0>.</s>
Mismatched token vs. text level input found.
Token level: <extra_id_0>ine epistles was named after Paul.</s>
Text level: Pauline epistles was named after<extra_id_0>.</s>
Mismatched token vs. text level input found.
Token level: <extra_id_0>ine epistles is named for Paul.</s>
Text level: Pauline epistles is named for<extra_id_0>.</s>
Mismatched token vs. text level input found.
Token level: <extra_id_0>ine epistles was named for Paul.</s>
Text level: Pauline epistles was named for<extra_id_0>.</s>
Mismatched token vs. text level input found.
Token level: <extra_id_0>ine epistles, which is named after Paul.</s>
Text level: Pauline epistles, which is named after<extra_id_0>.</s>
Mismatched token vs. text level input found.
Token level: <extra_id_0>ine epistles, which was named after Paul.</s>
Text level: Pauline epistles,

In [37]:
text_level_input

'Baron de Hirsch Cemetery, Halifax, which is located in<extra_id_0>.</s>'

In [39]:
text_level_output

'<extra_id_0> Yemen'

In [30]:
tokenizer.decode([13913, 13979, 6, 1069, 16, 371, 17279, 5, 1])

'Kings Domain, located inFiji.</s>'