<a href="https://colab.research.google.com/github/jsvan/MaskTests4LanguageModels/blob/main/Summerwinograndeattempt.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

How does a language model understand words? It represents input text in a rich and complex feature vector. Simple word vectors learn co-occurence probabilities between words, and represent each word in context with every other word. Better embeddings use the maximum space allowed to represent words as far away from each other as possible. 

When Roberta trains on Winogrande, how is it able to differentiate the different words in its task? For example, 
> "The employees threw a [party] and drank so much [alcohol] that they could not go into work the next day. The _ was loud. "

"Party" is represented very differently in vector space than "Alcohol". It should be no problem for the classifier to differentiate those objects. What happens if we mask the objects, thereby removing the representational advantage these words have. Will the classifier get confused between what "mask1" and "mask2" refer to? Do different masking techniques alter accuracy to any noticable degree?

In [None]:
%%capture
!pip install transformers
!pip install datasets 

In [None]:
%%capture
from datasets import load_dataset, concatenate_datasets

# You can switch between these two datasets
dataset = load_dataset("winogrande", 'winogrande_debiased')
#dataset = load_dataset("winogrande", 'winogrande_l')

# dataset is dict with keys ['train', 'test', 'validation']
# Each with an enumerable of 
"""
{'answer': '2',
 'option1': 'Kyle',
 'option2': 'Logan',
 'sentence': "Kyle doesn't wear leg warmers to bed, while Logan almost always does. _ is more likely to live in a colder climate."}

"""
# Use Validation instead of Test because Test lacks labels.

In [None]:
dataset['validation'][4]

{'answer': '1',
 'option1': 'Jeffrey',
 'option2': 'Hunter',
 'sentence': 'At night, Jeffrey always stays up later than Hunter to watch TV because _ wakes up late.'}

In [None]:
# Textizer
"""
    Each sentence was split on "_" placeholder symbol.
    Each option was concatenated with the second part of the split, thus transforming each example into two text segment pairs.
    Text segment pairs corresponding to correct and incorrect options were marked with True and False labels accordingly.
    Text segment pairs were shuffled thereafter.

"""

from datasets import Dataset

def prepare_data(dataset):

  # internal function
  def prep_ds(dataset):
    sentences, answers, o1, o2 = [], [], [], []
    for p in dataset:
      s1 = p['option1'].join(p['sentence'].split('_'))
      s2 = p['option2'].join(p['sentence'].split('_'))
      a1 = int(p['answer'] == '1')
      a2 = int(p['answer'] == '2')

      
      sentences.append(s1)
      answers.append(a1)
      sentences.append(s2)
      answers.append(a2)
      o1.append(p['option1'])
      o2.append(p['option2'])
      o1.append(p['option1'])
      o2.append(p['option2'])

    return {'sentence':sentences, 'labels':answers, 'option1':o1, 'option2':o2}
    # end internal function

  train = prep_ds(dataset["train"])
  test = prep_ds(dataset["validation"])
  trainds = Dataset.from_dict( train ).shuffle()
  testds = Dataset.from_dict( test ).shuffle()

  return {"train":trainds, "test":testds, "name":"Standard Dataset"}





def mask_copy(dataset):
  sentences = []
  toprint = 10

  for p in dataset:
    if toprint > 0:
      print(f"[{p['option1']}], [{p['option2']}], [{p['sentence']}]")

    sentences.append(p['sentence'].replace(p['option1'], 'option1').replace(p['option2'], 'option2'))
    
    if toprint > 0:
      print(sentences[-1])
      toprint -= 1

  build = {'sentence':sentences, 'labels':dataset['labels'], 'option1':dataset['option1'], 'option2':dataset['option2']}
  return Dataset.from_dict(build) # DON'T SHUFFLE





def mask_datasets(dataset):
  return {"train":mask_copy(dataset['train']), "test":mask_copy(dataset['test']), "name":"Masked Dataset"}




In [None]:
%%capture
#dicts of {'train', 'test'}
std_datasets = prepare_data(dataset)

masked_datasets = mask_datasets(std_datasets)

If we peak at the two datasets, we can see that the masking did indeed work.

In [None]:
std_datasets['train']['sentence'][:5]

["Jason doesn't know how the world works as well as Kevin because Jason had a really tough childhood.",
 "Jerry wanted to get lotion out of the bottle, but couldn't because the lotion was gone.",
 'Samantha is having to replace their vacuum bags frequently but not Betty as Betty likes to clean the house.',
 'The statue fit in the middle of the shelf, but the bookend did not, because the statue was thin.',
 'Falling from the fence, john was glad he landed on the table first before getting to the patio. The fence is tall.']

In [None]:
masked_datasets['test']['sentence'][:5]

['The assertive commander told the privates to change their option1 but not their option2 because the option2 were fine.',
 'The woman avoided the option1 but easily stepped over the option2, because the option1 was very shallow.',
 "option1's hair is being worked on by option2, so it's more likely option1 is the customer.",
 'option1 took a longer time to take a bath than option2 because option2 liked relaxing in the tub.',
 'option1 really liked working in Human Resources and option2 wanted to work at the same company, and option1 subsequently offered a position.']

This upcoming section looks at how the pretrained model performs on winogrande standard vs. the masking technique 'option1' and 'option2'. 

This model is a Roberta-Large model that was trained on winogrande_xl. On the winogrande_m dataset:
the standard dataset receives 85% accuracy, and
the masked dataset receives 80% accuracy. 

When we use the winogrande_debiased dataset, results fall to %69 and 68% respectively. Because these are so bad, let's retune the model on the debiased_train set. 

In [None]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
from pprint import pprint
from tqdm import tqdm

tokenizer = AutoTokenizer.from_pretrained("DeepPavlov/roberta-large-winogrande", model_max_length=64)

def test_datasets(tokenizer, std_datasets, masked_datasets, model=False,):

  print("Testing 1 2 3 ...")
  delete = False
  if not model:
    model = AutoModelForSequenceClassification.from_pretrained("DeepPavlov/roberta-large-winogrande")
    delete = True

  elif isinstance(model, str):
    delete = True
    with torch.no_grad():
      torch.cuda.empty_cache()

    model = torch.load(model) # open(model, "rb"))

    with torch.no_grad():
      torch.cuda.empty_cache()
  try:
    for ds in (std_datasets, masked_datasets):
      print("Performing", ds["name"])
      
      #combinedtraintest = concatenate_datasets([ds['train'], ds['test']])
      #encoded_train = combinedtraintest.map(lambda examples: tokenizer(examples['sentence'], padding='max_length'), batched=True) # , return_tensors='pt'
      encoded_test = ds['test'].map(lambda examples: tokenizer(examples['sentence'], padding='max_length'), batched=True)

      #encoded_train.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])
      encoded_test.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])
      #dataloader_train = torch.utils.data.DataLoader(encoded_train, batch_size=32)
      dataloader_test = torch.utils.data.DataLoader(encoded_test, batch_size=32)

      device = 'cuda' if torch.cuda.is_available() else 'cpu' 
      #model.train().to(device)
      model.to(device)
      #optimizer = torch.optim.AdamW(params=model.parameters(), lr=1e-5)

      correct = 0
      total = 0
      for i, batch in enumerate(tqdm(dataloader_test)):
        batch = {k: v.to(device) for k, v in batch.items()}

        outputs = model(**batch)
        #for oo, lab in zip(outputs.logits, batch['labels'] ):
        #  print(oo.argmax().item(), lab.items())
        correct += sum((int(x.argmax().item() == y.item()) for x, y in zip(outputs.logits, batch['labels'])))
        total += len(batch['labels'])
        
        
      print("Score", correct / total, '/ 1.00')

    if delete:
      del model
      print("Deleted model")

  except Exception as e:
    del model
    print(e, e.__str__)
    print("Deleted model")


# test_datasets(tokenizer, std_datasets, masked_datasets)

Downloading:   0%|          | 0.00/350 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/780k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/446k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.29M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/239 [00:00<?, ?B/s]

We'll train a model on the unmasked debiased data train set for one epoch. 

On the first epoch, the testdataset achieves 75% accuracy and the masked dataset 70% accuracy. 

Further epochs see no change (tested with 3)

In [None]:

def train_model(trainingset, compareset, tokenizer, model=False, copy=False):
  if not model:
    model = AutoModelForSequenceClassification.from_pretrained("DeepPavlov/roberta-large-winogrande")
  elif copy: # ie, if copy and model
    print('copy')
    with torch.no_grad():
      torch.cuda.empty_cache()
    model = torch.load(model) # open(model, "rb"))
    with torch.no_grad():
      torch.cuda.empty_cache()
    #model_copy = type(model)() # get a new instance
    #model_copy.load_state_dict(model.state_dict()) # copy weights and stuff
    #model = model_copy
    print('finished?')
  try:
    encoded_train = trainingset['train'].map(lambda examples: tokenizer(examples['sentence'], padding='max_length'), batched=True) # , return_tensors='pt'
    encoded_train.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])
    dataloader_train = torch.utils.data.DataLoader(encoded_train, batch_size=32)

    device = 'cuda' if torch.cuda.is_available() else 'cpu' 
    model.train().to(device)
    optimizer = torch.optim.AdamW(params=model.parameters(), lr=1e-5)

    for epoch in range(1):
      print("Epoch", epoch)
      correct = 0
      total = 0
      for i, batch in enumerate(tqdm(dataloader_train)):
        batch = {k: v.to(device) for k, v in batch.items()}

        outputs = model(**batch)
        loss = outputs[0]
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        if i % 10 == 0:
          print(f" loss: {loss}")
        #for oo, lab in zip(outputs.logits, batch['labels'] ):
        #  print(oo.argmax().item(), lab.items())
        correct += sum((int(x.argmax().item() == y.item()) for x, y in zip(outputs.logits, batch['labels'])))
        total += len(batch['labels'])
        
      print("Score", correct / total, '/ 1.00')
      test_datasets(tokenizer, trainingset, compareset, model)
      model.train().to(device)
  except Exception as e:
    del model
    print(e, e.__str__)
    print("deleted model")
  return model




In [None]:

# This model will become the base model for everything we build upon, 
# so we will save it to disk to load it fresh for each experiment. 
debiased_tuned_model = train_model(std_datasets, masked_datasets, tokenizer)
torch.save(debiased_tuned_model, f="debiased_model")
debiased_tuned_model = None

Downloading:   0%|          | 0.00/820 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.32G [00:00<?, ?B/s]



  0%|          | 0/19 [00:00<?, ?ba/s]

Epoch 0


  0%|          | 1/578 [00:00<07:57,  1.21it/s]

 loss: 0.29139959812164307


  2%|▏         | 11/578 [00:08<06:48,  1.39it/s]

 loss: 0.08296118676662445


  4%|▎         | 21/578 [00:15<06:40,  1.39it/s]

 loss: 0.008703015744686127


  5%|▌         | 31/578 [00:22<06:34,  1.39it/s]

 loss: 0.013629798777401447


  7%|▋         | 41/578 [00:29<06:26,  1.39it/s]

 loss: 0.0178392194211483


  9%|▉         | 51/578 [00:36<06:19,  1.39it/s]

 loss: 0.006535938475281


 11%|█         | 61/578 [00:43<06:11,  1.39it/s]

 loss: 0.07766715437173843


 12%|█▏        | 71/578 [00:51<06:04,  1.39it/s]

 loss: 0.04506212845444679


 14%|█▍        | 81/578 [00:58<05:57,  1.39it/s]

 loss: 0.005039051175117493


 16%|█▌        | 91/578 [01:05<05:50,  1.39it/s]

 loss: 0.004431945271790028


 17%|█▋        | 101/578 [01:12<05:42,  1.39it/s]

 loss: 0.024901051074266434


 19%|█▉        | 111/578 [01:19<05:35,  1.39it/s]

 loss: 0.1691167950630188


 21%|██        | 121/578 [01:27<05:28,  1.39it/s]

 loss: 0.06299073249101639


 23%|██▎       | 131/578 [01:34<05:21,  1.39it/s]

 loss: 0.02790091373026371


 24%|██▍       | 141/578 [01:41<05:13,  1.39it/s]

 loss: 0.00175842409953475


 26%|██▌       | 151/578 [01:48<05:06,  1.39it/s]

 loss: 0.1267576366662979


 28%|██▊       | 161/578 [01:55<04:59,  1.39it/s]

 loss: 0.02139498107135296


 30%|██▉       | 171/578 [02:03<04:53,  1.39it/s]

 loss: 0.007912123575806618


 31%|███▏      | 181/578 [02:10<04:45,  1.39it/s]

 loss: 0.02701547183096409


 33%|███▎      | 191/578 [02:17<04:38,  1.39it/s]

 loss: 0.08884286135435104


 35%|███▍      | 201/578 [02:24<04:30,  1.39it/s]

 loss: 0.05994598567485809


 37%|███▋      | 211/578 [02:31<04:23,  1.39it/s]

 loss: 0.035775378346443176


 38%|███▊      | 221/578 [02:38<04:16,  1.39it/s]

 loss: 0.004824390634894371


 40%|███▉      | 231/578 [02:46<04:09,  1.39it/s]

 loss: 0.12612299621105194


 42%|████▏     | 241/578 [02:53<04:02,  1.39it/s]

 loss: 0.0684988796710968


 43%|████▎     | 251/578 [03:00<03:55,  1.39it/s]

 loss: 0.10746771842241287


 45%|████▌     | 261/578 [03:07<03:47,  1.39it/s]

 loss: 0.03060903772711754


 47%|████▋     | 271/578 [03:14<03:40,  1.39it/s]

 loss: 0.171343132853508


 49%|████▊     | 281/578 [03:22<03:33,  1.39it/s]

 loss: 0.0034223839174956083


 50%|█████     | 291/578 [03:29<03:26,  1.39it/s]

 loss: 0.0017262304900214076


 52%|█████▏    | 301/578 [03:36<03:18,  1.39it/s]

 loss: 0.03505024313926697


 54%|█████▍    | 311/578 [03:43<03:11,  1.39it/s]

 loss: 0.0018564489437267184


 56%|█████▌    | 321/578 [03:50<03:04,  1.39it/s]

 loss: 0.0009407121106050909


 57%|█████▋    | 331/578 [03:57<02:57,  1.39it/s]

 loss: 0.01996447518467903


 59%|█████▉    | 341/578 [04:05<02:50,  1.39it/s]

 loss: 0.002355244942009449


 61%|██████    | 351/578 [04:12<02:43,  1.39it/s]

 loss: 0.009959744289517403


 62%|██████▏   | 361/578 [04:19<02:36,  1.39it/s]

 loss: 0.006714398041367531


 64%|██████▍   | 371/578 [04:26<02:29,  1.39it/s]

 loss: 0.0033852476626634598


 66%|██████▌   | 381/578 [04:33<02:21,  1.39it/s]

 loss: 0.1478278487920761


 68%|██████▊   | 391/578 [04:41<02:14,  1.39it/s]

 loss: 0.08431506901979446


 69%|██████▉   | 401/578 [04:48<02:07,  1.39it/s]

 loss: 0.007005142048001289


 71%|███████   | 411/578 [04:55<02:00,  1.39it/s]

 loss: 0.0279370229691267


 73%|███████▎  | 421/578 [05:02<01:52,  1.39it/s]

 loss: 0.0471220463514328


 75%|███████▍  | 431/578 [05:09<01:45,  1.39it/s]

 loss: 0.0076605514623224735


 76%|███████▋  | 441/578 [05:16<01:38,  1.39it/s]

 loss: 0.04174063354730606


 78%|███████▊  | 451/578 [05:24<01:31,  1.39it/s]

 loss: 0.022161195054650307


 80%|███████▉  | 461/578 [05:31<01:24,  1.39it/s]

 loss: 0.009482357650995255


 81%|████████▏ | 471/578 [05:38<01:16,  1.39it/s]

 loss: 0.004016342107206583


 83%|████████▎ | 481/578 [05:45<01:09,  1.39it/s]

 loss: 0.022950591519474983


 85%|████████▍ | 491/578 [05:52<01:02,  1.39it/s]

 loss: 0.007391788996756077


 87%|████████▋ | 501/578 [06:00<00:55,  1.39it/s]

 loss: 0.00441850395873189


 88%|████████▊ | 511/578 [06:07<00:48,  1.39it/s]

 loss: 0.14649666845798492


 90%|█████████ | 521/578 [06:14<00:40,  1.39it/s]

 loss: 0.016604939475655556


 92%|█████████▏| 531/578 [06:21<00:33,  1.39it/s]

 loss: 0.0022767442278563976


 94%|█████████▎| 541/578 [06:28<00:26,  1.39it/s]

 loss: 0.006796069443225861


 95%|█████████▌| 551/578 [06:36<00:19,  1.39it/s]

 loss: 0.004663259722292423


 97%|█████████▋| 561/578 [06:43<00:12,  1.39it/s]

 loss: 0.026216842234134674


 99%|█████████▉| 571/578 [06:50<00:05,  1.39it/s]

 loss: 0.009362638927996159


100%|██████████| 578/578 [06:55<00:00,  1.39it/s]

Score 0.9832396193771626 / 1.00
Testing 1 2 3 ...
Performing Standard Dataset





  0%|          | 0/3 [00:00<?, ?ba/s]

100%|██████████| 80/80 [00:19<00:00,  4.16it/s]

Score 0.7683504340962904 / 1.00
Performing Masked Dataset





  0%|          | 0/3 [00:00<?, ?ba/s]

100%|██████████| 80/80 [00:19<00:00,  4.15it/s]


Score 0.7115232833464877 / 1.00


Does training it on the masked testset change anything?

In [None]:
#Clear the GPU RAM...

with torch.no_grad():
    torch.cuda.empty_cache()

In [None]:
#masked_tuned_model = train_model(masked_datasets, std_datasets, tokenizer)


A pre-tuned model on Winogrande_xl that's tuned on the masked-debiased training set gets slightly better accuracy on the masked test set. (73.2% accuracy for standard, 73.8% for masked). 

The same holds for if we use the winogrande_l dataset instead of debiased. On Winogrande_l standard we get 72% accuracy and 71.5% on the masked. 

What happens if we tune the masked set on top of the tuned standard debiased model?

In [None]:
#with torch.no_grad():
#    torch.cuda.empty_cache()

In [None]:
#debiased_masked_tuned_model = train_model(masked_datasets, std_datasets, tokenizer, model = debiased_tuned_model)


In this case, the standard debiased dataset achieves 73.6% accuracy and the masked dataset 73.4% accuracy. Practically the same. 

Let's try a different masking style. What happens if we cover up one of the Options with an [unknown] mask? Can the model deal with and identify objecthood of a masked object?

Take the following example sentence:

{
  "sentence": "The [plant] took up too much room in the [urn], because the [plant] was small.",
  "label": false
}

There are two ways to test the model. If we mask the plant, then there are two references to the same masked object, and the True/False question directly references that object. If we mask the urn, then the contextual object is hidden, but is never the subject of inquery. Let's test covering up each at a time. The first test is more interesting, but might as well see if we find anything.

In [None]:
import re

def mask_copy_1(dataset, unk):
  sentences = []
  toprint = 5

  for p in dataset:
    # This is annoying because option1 can be substrings of other words, usually option2. They can also have uppercase letters. 
    # If one is a substring, then I will cover up the larger word with a temporary mask to not confuse anything else.
    sentence = p['sentence']
    option1, option2 = re.compile(p['option1'], re.IGNORECASE), re.compile(p['option2'], re.IGNORECASE)
    maskedoption1, maskedoption2 = "OPTION_ONE", "OPTION_TWO"
    first_search, second_search, first_mask, second_mask = None, None, None, None

    # "table" in "tablecloth" --> cover bigger one
    if len(p['option1']) > len(p['option2']):
      # cover option1 first
      first_search, second_search = option1, option2
      first_mask, second_mask = maskedoption1, maskedoption2
    else:
      first_search, second_search = option2, option1
      first_mask, second_mask = maskedoption2, maskedoption1
    
    sentence = first_search.sub(first_mask, sentence) #Mask the longer word with OPTION_ mask
    sentence = second_search.sub(second_mask, sentence) #then the shorter one

    # IT gets kinda confusing which word now to <IGNORE> because it's language and some words appear more than two, more than three, times. 
    # I think the smartest approach is to assume the final word used is the question word and mask that one. 

    # Find final word used
    # maskedoption1 is the final word
    if sentence.rfind(maskedoption1) > sentence.rfind(maskedoption2):
      # IGNORE final word
      # Convert other word back to original
      sentence = sentence.replace(maskedoption1, unk)
      sentence = sentence.replace(maskedoption2, p['option2'])
    else:
      sentence = sentence.replace(maskedoption2, unk)
      sentence = sentence.replace(maskedoption1, p['option2'])


    if toprint > 0:
      print(f"[{p['option1']}], [{p['option2']}], [{p['sentence']}]")

    sentences.append(sentence)
    
    if toprint > 0:
      print(sentences[-1])
      toprint -= 1

  build = {'sentence':sentences, 'labels':dataset['labels'], 'option1':dataset['option1'], 'option2':dataset['option2']}
  return Dataset.from_dict(build) # DON'T SHUFFLE





def mask_datasets_1(dataset, tokenizer):
  unk = tokenizer.unk_token
  return {"train":mask_copy_1(dataset['train'], unk), "test":mask_copy_1(dataset['test'], unk), "name":"Double <Unk> Masked Dataset"}

In [None]:
#unk_datasets = mask_datasets_1(std_datasets, tokenizer)



In [None]:
#test_datasets(tokenizer, std_datasets, unk_datasets, debiased_tuned_model)

The [unk] tokens reduce performance from 75% accuracy to 65% accuracy. Let's tune the model on the [unk] masked dataset and see if it redeems itself. 

In [None]:
#unk_masked_tuned_model = train_model(unk_datasets, std_datasets, tokenizer)


On training on the debiased [unk] masked set, we perform with 73% accuracy. On the standard set we perform 69% accuracy. Strange the standard dataset decreased. That may mean there is some bias in the dataset I created where it learns to do tricks. 

Let's try masking the single word. 

In [None]:
import re


# This masked <unk> tokens on the option which is NOT involved in the question. 
# If the option involved in the question is wrong, it will need to identify the <unk> token as having importance.
def mask_copy_2(dataset, unk):
  sentences = []
  toprint = 5

  for p in dataset:
    # This is annoying because option1 can be substrings of other words, usually option2. They can also have uppercase letters. 
    # If one is a substring, then I will cover up the larger word with a temporary mask to not confuse anything else.
    sentence = p['sentence']
    option1, option2 = re.compile(p['option1'], re.IGNORECASE), re.compile(p['option2'], re.IGNORECASE)
    maskedoption1, maskedoption2 = "OPTION_ONE", "OPTION_TWO"
    first_search, second_search, first_mask, second_mask = None, None, None, None

    # "table" in "tablecloth" --> cover bigger one
    if len(p['option1']) > len(p['option2']):
      # cover option1 first
      first_search, second_search = option1, option2
      first_mask, second_mask = maskedoption1, maskedoption2
    else:
      first_search, second_search = option2, option1
      first_mask, second_mask = maskedoption2, maskedoption1
    
    sentence = first_search.sub(first_mask, sentence) #Mask the longer word with OPTION_ mask
    sentence = second_search.sub(second_mask, sentence) #then the shorter one

    # IT gets kinda confusing which word now to <IGNORE> because it's language and some words appear more than two, more than three, times. 
    # I think the smartest approach is to assume the final word used is the question word and mask that one. 

    # Find final word used
    # maskedoption1 is NOT the final word
    if sentence.rfind(maskedoption1) < sentence.rfind(maskedoption2):
      # IGNORE final word
      # Convert other word back to original
      sentence = sentence.replace(maskedoption1, unk)
      sentence = sentence.replace(maskedoption2, p['option2'])
    else:
      sentence = sentence.replace(maskedoption2, unk)
      sentence = sentence.replace(maskedoption1, p['option2'])


    if toprint > 0:
      print(f"[{p['option1']}], [{p['option2']}], [{p['sentence']}]")

    sentences.append(sentence)
    
    if toprint > 0:
      print(sentences[-1])
      toprint -= 1

  build = {'sentence':sentences, 'labels':dataset['labels'], 'option1':dataset['option1'], 'option2':dataset['option2']}
  return Dataset.from_dict(build) # DON'T SHUFFLE





def mask_datasets_2(dataset, tokenizer):
  unk = tokenizer.unk_token
  return {"train":mask_copy_2(dataset['train'], unk), "test":mask_copy_2(dataset['test'], unk), "name":"Single <Unk> Masked Dataset"}

In [None]:
unk_dataset_2 = mask_datasets_2(std_datasets, tokenizer)

#test_datasets(tokenizer, std_datasets, unk_dataset_2, debiased_tuned_model)

[Jason], [Kevin], [Jason doesn't know how the world works as well as Kevin because Jason had a really tough childhood.]
Kevin doesn't know how the world works as well as <unk> because Kevin had a really tough childhood.
[bottle], [lotion], [Jerry wanted to get lotion out of the bottle, but couldn't because the lotion was gone.]
Jerry wanted to get lotion out of the <unk>, but couldn't because the lotion was gone.
[Samantha], [Betty], [Samantha is having to replace their vacuum bags frequently but not Betty as Betty likes to clean the house.]
<unk> is having to replace their vacuum bags frequently but not Betty as Betty likes to clean the house.
[statue], [bookend], [The statue fit in the middle of the shelf, but the bookend did not, because the statue was thin.]
The bookend fit in the middle of the shelf, but the <unk> did not, because the bookend was thin.
[fence], [table], [Falling from the fence, john was glad he landed on the table first before getting to the patio. The fence is ta

The standard set achieves 75% accuracy and the single \<unk\> 68%. 

Let's train the model on the single \<unk\>. It achieves 75% on standard and 73% on \<unk\>, which is 5 percentage points higher


In [None]:
with torch.no_grad():
  torch.cuda.empty_cache()

In [None]:

#train_model(unk_dataset_2, std_datasets, tokenizer, model="debiased_model", copy = True)

What happens if we try a bunch of stupid masks? Can we make masks so linguistically non-sensical that the model performance drops?

We can try such masks as 



In [None]:
stupid_tries = [
  ("dog", "doggy"),
  ("red", "blue"),
  ("flavor", "flavour"),
  ('A', 'B'),
  ('X', 'Y'),
  ('1', '2'),
  ('first', 'second'),
  ('alpha', 'beta'),
  ('#', '@'),
  ('primero', 'secundo'), # yes its true, I dont speak spanish. Having it spelled wrong just makes the model's task that much more difficult ;)
  ('Alice', 'Bob'),
  ('_', '__')
]


In [None]:
import re


# This masked <unk> tokens on the option which is NOT involved in the question. 
# If the option involved in the question is wrong, it will need to identify the <unk> token as having importance.
def stupid_masking(dataset, mask1, mask2):
  sentences = []
  toprint = 5

  for p in dataset:
    # This is annoying because option1 can be substrings of other words, usually option2. They can also have uppercase letters. 
    # If one is a substring, then I will cover up the larger word with a temporary mask to not confuse anything else.
    sentence = p['sentence']
    option1, option2 = re.compile(p['option1'], re.IGNORECASE), re.compile(p['option2'], re.IGNORECASE)
    maskedoption1, maskedoption2 = "OPTION_ONE", "OPTION_TWO"
    first_search, second_search, first_mask, second_mask = None, None, None, None

    # "table" in "tablecloth" --> cover bigger one
    if len(p['option1']) > len(p['option2']):
      # cover option1 first
      first_search, second_search = option1, option2
      first_mask, second_mask = maskedoption1, maskedoption2
    else:
      first_search, second_search = option2, option1
      first_mask, second_mask = maskedoption2, maskedoption1
    
    sentence = first_search.sub(first_mask, sentence) #Mask the longer word with OPTION_ mask
    sentence = second_search.sub(second_mask, sentence) #then the shorter one

    sentence = sentence.replace(maskedoption1, mask1)
    sentence = sentence.replace(maskedoption2, mask2)


    if toprint > 0:
      print(f"[{p['option1']}], [{p['option2']}], [{p['sentence']}]")

    sentences.append(sentence)
    
    if toprint > 0:
      print(sentences[-1])
      toprint -= 1

  build = {'sentence':sentences, 'labels':dataset['labels'], 'option1':dataset['option1'], 'option2':dataset['option2']}
  return Dataset.from_dict(build) # DON'T SHUFFLE





def stupid_datasets(dataset, tokenizer, masks):
  mask1, mask2 = masks
  return {"train":stupid_masking(dataset['train'], mask1, mask2), "test":stupid_masking(dataset['test'], mask1, mask2), "name":f"({mask1}, {mask2}) Masked Dataset"}

In [27]:


for masks in stupid_tries:
  stupid_ds = stupid_datasets(std_datasets, tokenizer, masks)
  test_datasets(tokenizer, std_datasets, stupid_ds, model="debiased_model")
  train_model(stupid_ds, std_datasets, tokenizer, model="debiased_model", copy = True)


[Jason], [Kevin], [Jason doesn't know how the world works as well as Kevin because Jason had a really tough childhood.]
flavor doesn't know how the world works as well as flavour because flavor had a really tough childhood.
[bottle], [lotion], [Jerry wanted to get lotion out of the bottle, but couldn't because the lotion was gone.]
Jerry wanted to get flavour out of the flavor, but couldn't because the flavour was gone.
[Samantha], [Betty], [Samantha is having to replace their vacuum bags frequently but not Betty as Betty likes to clean the house.]
flavor is having to replace their vacuum bags frequently but not flavour as flavour likes to clean the house.
[statue], [bookend], [The statue fit in the middle of the shelf, but the bookend did not, because the statue was thin.]
The flavor fit in the middle of the shelf, but the flavour did not, because the flavor was thin.
[fence], [table], [Falling from the fence, john was glad he landed on the table first before getting to the patio. The

  0%|          | 0/3 [00:00<?, ?ba/s]

100%|██████████| 80/80 [00:19<00:00,  4.17it/s]

Score 0.7573007103393844 / 1.00
Performing (flavor, flavour) Masked Dataset





  0%|          | 0/3 [00:00<?, ?ba/s]

100%|██████████| 80/80 [00:19<00:00,  4.17it/s]


Score 0.531965272296764 / 1.00
Deleted model
copy
finished?


  0%|          | 0/19 [00:00<?, ?ba/s]

Epoch 0


  0%|          | 1/578 [00:00<06:56,  1.38it/s]

 loss: 1.345733880996704


  2%|▏         | 11/578 [00:07<06:47,  1.39it/s]

 loss: 0.4001125693321228


  4%|▎         | 21/578 [00:15<06:41,  1.39it/s]

 loss: 0.24483714997768402


  5%|▌         | 31/578 [00:22<06:33,  1.39it/s]

 loss: 0.19926071166992188


  7%|▋         | 41/578 [00:29<06:26,  1.39it/s]

 loss: 0.22253455221652985


  9%|▉         | 51/578 [00:36<06:19,  1.39it/s]

 loss: 0.1943635791540146


 11%|█         | 61/578 [00:43<06:12,  1.39it/s]

 loss: 0.35376644134521484


 12%|█▏        | 71/578 [00:51<06:04,  1.39it/s]

 loss: 0.0846937745809555


 14%|█▍        | 81/578 [00:58<05:58,  1.39it/s]

 loss: 0.1917884349822998


 16%|█▌        | 91/578 [01:05<05:51,  1.38it/s]

 loss: 0.22205859422683716


 17%|█▋        | 101/578 [01:12<05:44,  1.39it/s]

 loss: 0.03522295504808426


 19%|█▉        | 111/578 [01:19<05:37,  1.38it/s]

 loss: 0.11057320982217789


 21%|██        | 121/578 [01:27<05:31,  1.38it/s]

 loss: 0.13028593361377716


 23%|██▎       | 131/578 [01:34<05:23,  1.38it/s]

 loss: 0.25811490416526794


 24%|██▍       | 141/578 [01:41<05:15,  1.38it/s]

 loss: 0.25290849804878235


 26%|██▌       | 151/578 [01:48<05:08,  1.38it/s]

 loss: 0.0203444454818964


 28%|██▊       | 161/578 [01:56<05:00,  1.39it/s]

 loss: 0.08137456327676773


 30%|██▉       | 171/578 [02:03<04:53,  1.39it/s]

 loss: 0.06311468034982681


 31%|███▏      | 181/578 [02:10<04:46,  1.39it/s]

 loss: 0.07094135135412216


 33%|███▎      | 191/578 [02:17<04:39,  1.39it/s]

 loss: 0.12834009528160095


 35%|███▍      | 201/578 [02:24<04:31,  1.39it/s]

 loss: 0.16305853426456451


 37%|███▋      | 211/578 [02:32<04:24,  1.39it/s]

 loss: 0.1801011860370636


 38%|███▊      | 221/578 [02:39<04:17,  1.39it/s]

 loss: 0.046832501888275146


 40%|███▉      | 231/578 [02:46<04:10,  1.39it/s]

 loss: 0.2268638014793396


 42%|████▏     | 241/578 [02:53<04:03,  1.39it/s]

 loss: 0.11120323836803436


 43%|████▎     | 251/578 [03:00<03:55,  1.39it/s]

 loss: 0.1325816959142685


 45%|████▌     | 261/578 [03:08<03:48,  1.39it/s]

 loss: 0.0374494269490242


 47%|████▋     | 271/578 [03:15<03:41,  1.39it/s]

 loss: 0.2673308551311493


 49%|████▊     | 281/578 [03:22<03:33,  1.39it/s]

 loss: 0.14890873432159424


 50%|█████     | 291/578 [03:29<03:26,  1.39it/s]

 loss: 0.08291994780302048


 52%|█████▏    | 301/578 [03:37<03:19,  1.39it/s]

 loss: 0.13823743164539337


 54%|█████▍    | 311/578 [03:44<03:12,  1.39it/s]

 loss: 0.16527612507343292


 56%|█████▌    | 321/578 [03:51<03:05,  1.39it/s]

 loss: 0.1122722253203392


 57%|█████▋    | 331/578 [03:58<02:57,  1.39it/s]

 loss: 0.0435718297958374


 59%|█████▉    | 341/578 [04:05<02:50,  1.39it/s]

 loss: 0.06261689215898514


 61%|██████    | 351/578 [04:13<02:43,  1.39it/s]

 loss: 0.10458498448133469


 62%|██████▏   | 361/578 [04:20<02:36,  1.39it/s]

 loss: 0.07416164875030518


 64%|██████▍   | 371/578 [04:27<02:29,  1.39it/s]

 loss: 0.2852379381656647


 66%|██████▌   | 381/578 [04:34<02:21,  1.39it/s]

 loss: 0.0625918060541153


 68%|██████▊   | 391/578 [04:41<02:14,  1.39it/s]

 loss: 0.11856868118047714


 69%|██████▉   | 401/578 [04:49<02:07,  1.39it/s]

 loss: 0.0864037573337555


 71%|███████   | 411/578 [04:56<02:00,  1.39it/s]

 loss: 0.16286808252334595


 73%|███████▎  | 421/578 [05:03<01:53,  1.39it/s]

 loss: 0.2470242828130722


 75%|███████▍  | 431/578 [05:10<01:46,  1.39it/s]

 loss: 0.09463903307914734


 76%|███████▋  | 441/578 [05:17<01:38,  1.39it/s]

 loss: 0.03594497963786125


 78%|███████▊  | 451/578 [05:25<01:31,  1.39it/s]

 loss: 0.09994396567344666


 80%|███████▉  | 461/578 [05:32<01:24,  1.39it/s]

 loss: 0.1916155219078064


 81%|████████▏ | 471/578 [05:39<01:17,  1.39it/s]

 loss: 0.3395901918411255


 83%|████████▎ | 481/578 [05:46<01:10,  1.39it/s]

 loss: 0.19880667328834534


 85%|████████▍ | 491/578 [05:53<01:02,  1.39it/s]

 loss: 0.09235672652721405


 87%|████████▋ | 501/578 [06:01<00:55,  1.39it/s]

 loss: 0.1370481699705124


 88%|████████▊ | 511/578 [06:08<00:48,  1.39it/s]

 loss: 0.09501388669013977


 90%|█████████ | 521/578 [06:15<00:41,  1.39it/s]

 loss: 0.08206921815872192


 92%|█████████▏| 531/578 [06:22<00:33,  1.38it/s]

 loss: 0.07482817023992538


 94%|█████████▎| 541/578 [06:29<00:26,  1.39it/s]

 loss: 0.13372690975666046


 95%|█████████▌| 551/578 [06:37<00:19,  1.39it/s]

 loss: 0.07566531747579575


 97%|█████████▋| 561/578 [06:44<00:12,  1.39it/s]

 loss: 0.15743406116962433


 99%|█████████▉| 571/578 [06:51<00:05,  1.39it/s]

 loss: 0.009880097582936287


100%|██████████| 578/578 [06:56<00:00,  1.39it/s]

Score 0.9401492214532872 / 1.00
Testing 1 2 3 ...
Performing (flavor, flavour) Masked Dataset





  0%|          | 0/3 [00:00<?, ?ba/s]

100%|██████████| 80/80 [00:19<00:00,  4.15it/s]

Score 0.7352012628255722 / 1.00
Performing Standard Dataset





  0%|          | 0/3 [00:00<?, ?ba/s]

100%|██████████| 80/80 [00:19<00:00,  4.15it/s]


Score 0.7588792423046566 / 1.00
[Jason], [Kevin], [Jason doesn't know how the world works as well as Kevin because Jason had a really tough childhood.]
alpha doesn't know how the world works as well as beta because alpha had a really tough childhood.
[bottle], [lotion], [Jerry wanted to get lotion out of the bottle, but couldn't because the lotion was gone.]
Jerry wanted to get beta out of the alpha, but couldn't because the beta was gone.
[Samantha], [Betty], [Samantha is having to replace their vacuum bags frequently but not Betty as Betty likes to clean the house.]
alpha is having to replace their vacuum bags frequently but not beta as beta likes to clean the house.
[statue], [bookend], [The statue fit in the middle of the shelf, but the bookend did not, because the statue was thin.]
The alpha fit in the middle of the shelf, but the beta did not, because the alpha was thin.
[fence], [table], [Falling from the fence, john was glad he landed on the table first before getting to the pa

  0%|          | 0/3 [00:00<?, ?ba/s]

100%|██████████| 80/80 [00:19<00:00,  4.15it/s]

Score 0.7640094711917916 / 1.00
Performing (alpha, beta) Masked Dataset





  0%|          | 0/3 [00:00<?, ?ba/s]

100%|██████████| 80/80 [00:19<00:00,  4.14it/s]


Score 0.7103393843725335 / 1.00
Deleted model
copy
finished?


  0%|          | 0/19 [00:00<?, ?ba/s]

Epoch 0


  0%|          | 1/578 [00:00<07:10,  1.34it/s]

 loss: 0.19756284356117249


  2%|▏         | 11/578 [00:07<06:48,  1.39it/s]

 loss: 0.21591776609420776


  4%|▎         | 21/578 [00:15<06:41,  1.39it/s]

 loss: 0.19006961584091187


  5%|▌         | 31/578 [00:22<06:34,  1.38it/s]

 loss: 0.23236529529094696


  7%|▋         | 41/578 [00:29<06:27,  1.39it/s]

 loss: 0.08137702196836472


  9%|▉         | 51/578 [00:36<06:20,  1.38it/s]

 loss: 0.2703522741794586


 11%|█         | 61/578 [00:44<06:12,  1.39it/s]

 loss: 0.3528135418891907


 12%|█▏        | 71/578 [00:51<06:05,  1.39it/s]

 loss: 0.10179702192544937


 14%|█▍        | 81/578 [00:58<05:58,  1.39it/s]

 loss: 0.18097107112407684


 16%|█▌        | 91/578 [01:05<05:51,  1.39it/s]

 loss: 0.23234793543815613


 17%|█▋        | 101/578 [01:12<05:44,  1.38it/s]

 loss: 0.044028013944625854


 19%|█▉        | 111/578 [01:20<05:36,  1.39it/s]

 loss: 0.044364579021930695


 21%|██        | 121/578 [01:27<05:30,  1.38it/s]

 loss: 0.15127670764923096


 23%|██▎       | 131/578 [01:34<05:22,  1.38it/s]

 loss: 0.13794398307800293


 24%|██▍       | 141/578 [01:41<05:15,  1.39it/s]

 loss: 0.3988320827484131


 26%|██▌       | 151/578 [01:48<05:08,  1.38it/s]

 loss: 0.013827529735863209


 28%|██▊       | 161/578 [01:56<05:01,  1.38it/s]

 loss: 0.11645539104938507


 30%|██▉       | 171/578 [02:03<04:53,  1.39it/s]

 loss: 0.034267209470272064


 31%|███▏      | 181/578 [02:10<04:46,  1.39it/s]

 loss: 0.011797880753874779


 33%|███▎      | 191/578 [02:17<04:39,  1.39it/s]

 loss: 0.13784724473953247


 35%|███▍      | 201/578 [02:25<04:32,  1.38it/s]

 loss: 0.07678119838237762


 37%|███▋      | 211/578 [02:32<04:25,  1.38it/s]

 loss: 0.11612702906131744


 38%|███▊      | 221/578 [02:39<04:18,  1.38it/s]

 loss: 0.1439022421836853


 40%|███▉      | 231/578 [02:46<04:10,  1.38it/s]

 loss: 0.09614435583353043


 42%|████▏     | 241/578 [02:53<04:03,  1.38it/s]

 loss: 0.27926015853881836


 43%|████▎     | 251/578 [03:01<03:56,  1.38it/s]

 loss: 0.14457330107688904


 45%|████▌     | 261/578 [03:08<03:48,  1.39it/s]

 loss: 0.04522613808512688


 47%|████▋     | 271/578 [03:15<03:41,  1.39it/s]

 loss: 0.24058739840984344


 49%|████▊     | 281/578 [03:22<03:34,  1.38it/s]

 loss: 0.13544240593910217


 50%|█████     | 291/578 [03:30<03:27,  1.38it/s]

 loss: 0.19288113713264465


 52%|█████▏    | 301/578 [03:37<03:19,  1.39it/s]

 loss: 0.05107690393924713


 54%|█████▍    | 311/578 [03:44<03:12,  1.39it/s]

 loss: 0.06536681950092316


 56%|█████▌    | 321/578 [03:51<03:05,  1.39it/s]

 loss: 0.06001879647374153


 57%|█████▋    | 331/578 [03:58<02:58,  1.38it/s]

 loss: 0.07548177987337112


 59%|█████▉    | 341/578 [04:06<02:51,  1.39it/s]

 loss: 0.07939007878303528


 61%|██████    | 351/578 [04:13<02:43,  1.38it/s]

 loss: 0.08531317859888077


 62%|██████▏   | 361/578 [04:20<02:36,  1.39it/s]

 loss: 0.10532934963703156


 64%|██████▍   | 371/578 [04:27<02:29,  1.39it/s]

 loss: 0.26850077509880066


 66%|██████▌   | 381/578 [04:34<02:22,  1.39it/s]

 loss: 0.027522368356585503


 68%|██████▊   | 391/578 [04:42<02:14,  1.39it/s]

 loss: 0.26721706986427307


 69%|██████▉   | 401/578 [04:49<02:07,  1.38it/s]

 loss: 0.05160684883594513


 71%|███████   | 411/578 [04:56<02:00,  1.38it/s]

 loss: 0.2947041988372803


 73%|███████▎  | 421/578 [05:03<01:53,  1.38it/s]

 loss: 0.17917269468307495


 75%|███████▍  | 431/578 [05:11<01:46,  1.38it/s]

 loss: 0.225106880068779


 76%|███████▋  | 441/578 [05:18<01:38,  1.38it/s]

 loss: 0.08025769889354706


 78%|███████▊  | 451/578 [05:25<01:31,  1.39it/s]

 loss: 0.140935480594635


 80%|███████▉  | 461/578 [05:32<01:24,  1.38it/s]

 loss: 0.1708616465330124


 81%|████████▏ | 471/578 [05:39<01:17,  1.38it/s]

 loss: 0.10337116569280624


 83%|████████▎ | 481/578 [05:47<01:09,  1.39it/s]

 loss: 0.2004263699054718


 85%|████████▍ | 491/578 [05:54<01:02,  1.39it/s]

 loss: 0.08257176727056503


 87%|████████▋ | 501/578 [06:01<00:55,  1.39it/s]

 loss: 0.1283426582813263


 88%|████████▊ | 511/578 [06:08<00:48,  1.39it/s]

 loss: 0.13092592358589172


 90%|█████████ | 521/578 [06:15<00:41,  1.39it/s]

 loss: 0.08598759770393372


 92%|█████████▏| 531/578 [06:23<00:33,  1.38it/s]

 loss: 0.11461764574050903


 94%|█████████▎| 541/578 [06:30<00:26,  1.38it/s]

 loss: 0.058271851390600204


 95%|█████████▌| 551/578 [06:37<00:19,  1.38it/s]

 loss: 0.2921536862850189


 97%|█████████▋| 561/578 [06:44<00:12,  1.39it/s]

 loss: 0.03169206529855728


 99%|█████████▉| 571/578 [06:52<00:05,  1.39it/s]

 loss: 0.002298192586749792


100%|██████████| 578/578 [06:57<00:00,  1.39it/s]

Score 0.9451773356401384 / 1.00
Testing 1 2 3 ...
Performing (alpha, beta) Masked Dataset





  0%|          | 0/3 [00:00<?, ?ba/s]

100%|██████████| 80/80 [00:19<00:00,  4.13it/s]

Score 0.7332280978689818 / 1.00
Performing Standard Dataset





  0%|          | 0/3 [00:00<?, ?ba/s]

100%|██████████| 80/80 [00:19<00:00,  4.13it/s]


Score 0.7612470402525651 / 1.00
[Jason], [Kevin], [Jason doesn't know how the world works as well as Kevin because Jason had a really tough childhood.]
# doesn't know how the world works as well as @ because # had a really tough childhood.
[bottle], [lotion], [Jerry wanted to get lotion out of the bottle, but couldn't because the lotion was gone.]
Jerry wanted to get @ out of the #, but couldn't because the @ was gone.
[Samantha], [Betty], [Samantha is having to replace their vacuum bags frequently but not Betty as Betty likes to clean the house.]
# is having to replace their vacuum bags frequently but not @ as @ likes to clean the house.
[statue], [bookend], [The statue fit in the middle of the shelf, but the bookend did not, because the statue was thin.]
The # fit in the middle of the shelf, but the @ did not, because the # was thin.
[fence], [table], [Falling from the fence, john was glad he landed on the table first before getting to the patio. The fence is tall.]
Falling from the 

  0%|          | 0/3 [00:00<?, ?ba/s]

100%|██████████| 80/80 [00:19<00:00,  4.13it/s]

Score 0.7565114443567482 / 1.00
Performing (#, @) Masked Dataset





  0%|          | 0/3 [00:00<?, ?ba/s]

100%|██████████| 80/80 [00:19<00:00,  4.13it/s]


Score 0.7091554853985793 / 1.00
Deleted model
copy
finished?


  0%|          | 0/19 [00:00<?, ?ba/s]

Epoch 0


  0%|          | 1/578 [00:00<06:56,  1.39it/s]

 loss: 0.3105083405971527


  2%|▏         | 11/578 [00:07<06:49,  1.38it/s]

 loss: 0.2215511053800583


  4%|▎         | 21/578 [00:15<06:42,  1.38it/s]

 loss: 0.18515951931476593


  5%|▌         | 31/578 [00:22<06:35,  1.38it/s]

 loss: 0.1356024146080017


  7%|▋         | 41/578 [00:29<06:27,  1.39it/s]

 loss: 0.07480135560035706


  9%|▉         | 51/578 [00:36<06:20,  1.39it/s]

 loss: 0.3077700138092041


 11%|█         | 61/578 [00:44<06:13,  1.38it/s]

 loss: 0.22390875220298767


 12%|█▏        | 71/578 [00:51<06:06,  1.38it/s]

 loss: 0.07902845740318298


 14%|█▍        | 81/578 [00:58<05:58,  1.38it/s]

 loss: 0.2329646795988083


 16%|█▌        | 91/578 [01:05<05:51,  1.39it/s]

 loss: 0.30537351965904236


 17%|█▋        | 101/578 [01:12<05:44,  1.38it/s]

 loss: 0.053693387657403946


 19%|█▉        | 111/578 [01:20<05:37,  1.39it/s]

 loss: 0.09089857339859009


 21%|██        | 121/578 [01:27<05:30,  1.38it/s]

 loss: 0.1382695585489273


 23%|██▎       | 131/578 [01:34<05:22,  1.38it/s]

 loss: 0.14638888835906982


 24%|██▍       | 141/578 [01:41<05:15,  1.39it/s]

 loss: 0.41913652420043945


 26%|██▌       | 151/578 [01:49<05:08,  1.39it/s]

 loss: 0.10719199478626251


 28%|██▊       | 161/578 [01:56<05:00,  1.39it/s]

 loss: 0.14471803605556488


 30%|██▉       | 171/578 [02:03<04:54,  1.38it/s]

 loss: 0.0471169613301754


 31%|███▏      | 181/578 [02:10<04:46,  1.38it/s]

 loss: 0.08688574284315109


 33%|███▎      | 191/578 [02:17<04:39,  1.38it/s]

 loss: 0.12501980364322662


 35%|███▍      | 201/578 [02:25<04:32,  1.39it/s]

 loss: 0.11854414641857147


 37%|███▋      | 211/578 [02:32<04:25,  1.38it/s]

 loss: 0.2565918564796448


 38%|███▊      | 221/578 [02:39<04:17,  1.38it/s]

 loss: 0.20550383627414703


 40%|███▉      | 231/578 [02:46<04:10,  1.38it/s]

 loss: 0.16713683307170868


 42%|████▏     | 241/578 [02:54<04:03,  1.39it/s]

 loss: 0.10699957609176636


 43%|████▎     | 251/578 [03:01<03:56,  1.38it/s]

 loss: 0.14627999067306519


 45%|████▌     | 261/578 [03:08<03:48,  1.39it/s]

 loss: 0.023057496175169945


 47%|████▋     | 271/578 [03:15<03:41,  1.38it/s]

 loss: 0.34985601902008057


 49%|████▊     | 281/578 [03:22<03:34,  1.38it/s]

 loss: 0.3584825098514557


 50%|█████     | 291/578 [03:30<03:27,  1.39it/s]

 loss: 0.06389221549034119


 52%|█████▏    | 301/578 [03:37<03:20,  1.38it/s]

 loss: 0.05213775485754013


 54%|█████▍    | 311/578 [03:44<03:12,  1.38it/s]

 loss: 0.11752937734127045


 56%|█████▌    | 321/578 [03:51<03:05,  1.39it/s]

 loss: 0.08249875903129578


 57%|█████▋    | 331/578 [03:59<02:58,  1.38it/s]

 loss: 0.09172122925519943


 59%|█████▉    | 341/578 [04:06<02:51,  1.38it/s]

 loss: 0.15720075368881226


 61%|██████    | 351/578 [04:13<02:44,  1.38it/s]

 loss: 0.08039975166320801


 62%|██████▏   | 361/578 [04:20<02:36,  1.38it/s]

 loss: 0.06164277344942093


 64%|██████▍   | 371/578 [04:27<02:29,  1.39it/s]

 loss: 0.46276527643203735


 66%|██████▌   | 381/578 [04:35<02:22,  1.38it/s]

 loss: 0.030264031141996384


 68%|██████▊   | 391/578 [04:42<02:15,  1.38it/s]

 loss: 0.19436171650886536


 69%|██████▉   | 401/578 [04:49<02:07,  1.38it/s]

 loss: 0.035442840307950974


 71%|███████   | 411/578 [04:56<02:00,  1.38it/s]

 loss: 0.3173821270465851


 73%|███████▎  | 421/578 [05:04<01:53,  1.38it/s]

 loss: 0.09352194517850876


 75%|███████▍  | 431/578 [05:11<01:46,  1.38it/s]

 loss: 0.07244017720222473


 76%|███████▋  | 441/578 [05:18<01:38,  1.38it/s]

 loss: 0.104808010160923


 78%|███████▊  | 451/578 [05:25<01:31,  1.38it/s]

 loss: 0.1082400232553482


 80%|███████▉  | 461/578 [05:32<01:24,  1.39it/s]

 loss: 0.212442085146904


 81%|████████▏ | 471/578 [05:40<01:17,  1.39it/s]

 loss: 0.2594338357448578


 83%|████████▎ | 481/578 [05:47<01:10,  1.38it/s]

 loss: 0.16295026242733002


 85%|████████▍ | 491/578 [05:54<01:02,  1.38it/s]

 loss: 0.08132646232843399


 87%|████████▋ | 501/578 [06:01<00:55,  1.38it/s]

 loss: 0.09900049865245819


 88%|████████▊ | 511/578 [06:09<00:48,  1.38it/s]

 loss: 0.07750427722930908


 90%|█████████ | 521/578 [06:16<00:41,  1.38it/s]

 loss: 0.10361723601818085


 92%|█████████▏| 531/578 [06:23<00:33,  1.38it/s]

 loss: 0.09956935793161392


 94%|█████████▎| 541/578 [06:30<00:26,  1.38it/s]

 loss: 0.09268876910209656


 95%|█████████▌| 551/578 [06:37<00:19,  1.38it/s]

 loss: 0.25123894214630127


 97%|█████████▋| 561/578 [06:45<00:12,  1.38it/s]

 loss: 0.07553983479738235


 99%|█████████▉| 571/578 [06:52<00:05,  1.39it/s]

 loss: 0.014029411599040031


100%|██████████| 578/578 [06:57<00:00,  1.38it/s]

Score 0.9458801903114187 / 1.00
Testing 1 2 3 ...
Performing (#, @) Masked Dataset





  0%|          | 0/3 [00:00<?, ?ba/s]

100%|██████████| 80/80 [00:19<00:00,  4.12it/s]

Score 0.7324388318863457 / 1.00
Performing Standard Dataset





  0%|          | 0/3 [00:00<?, ?ba/s]

100%|██████████| 80/80 [00:19<00:00,  4.12it/s]


Score 0.7561168113654302 / 1.00
[Jason], [Kevin], [Jason doesn't know how the world works as well as Kevin because Jason had a really tough childhood.]
primero doesn't know how the world works as well as secundo because primero had a really tough childhood.
[bottle], [lotion], [Jerry wanted to get lotion out of the bottle, but couldn't because the lotion was gone.]
Jerry wanted to get secundo out of the primero, but couldn't because the secundo was gone.
[Samantha], [Betty], [Samantha is having to replace their vacuum bags frequently but not Betty as Betty likes to clean the house.]
primero is having to replace their vacuum bags frequently but not secundo as secundo likes to clean the house.
[statue], [bookend], [The statue fit in the middle of the shelf, but the bookend did not, because the statue was thin.]
The primero fit in the middle of the shelf, but the secundo did not, because the primero was thin.
[fence], [table], [Falling from the fence, john was glad he landed on the table 

  0%|          | 0/3 [00:00<?, ?ba/s]

100%|██████████| 80/80 [00:19<00:00,  4.12it/s]

Score 0.7588792423046566 / 1.00
Performing (primero, secundo) Masked Dataset





  0%|          | 0/3 [00:00<?, ?ba/s]

100%|██████████| 80/80 [00:19<00:00,  4.12it/s]


Score 0.7190213101815311 / 1.00
Deleted model
copy
finished?


  0%|          | 0/19 [00:00<?, ?ba/s]

Epoch 0


  0%|          | 1/578 [00:00<07:12,  1.34it/s]

 loss: 0.47119849920272827


  2%|▏         | 11/578 [00:07<06:50,  1.38it/s]

 loss: 0.17859375476837158


  4%|▎         | 21/578 [00:15<06:42,  1.38it/s]

 loss: 0.17482507228851318


  5%|▌         | 31/578 [00:22<06:34,  1.39it/s]

 loss: 0.19425921142101288


  7%|▋         | 41/578 [00:29<06:27,  1.38it/s]

 loss: 0.09673944860696793


  9%|▉         | 51/578 [00:36<06:20,  1.38it/s]

 loss: 0.16378891468048096


 11%|█         | 61/578 [00:44<06:13,  1.38it/s]

 loss: 0.13140617311000824


 12%|█▏        | 71/578 [00:51<06:05,  1.39it/s]

 loss: 0.1264759600162506


 14%|█▍        | 81/578 [00:58<05:59,  1.38it/s]

 loss: 0.2276294082403183


 16%|█▌        | 91/578 [01:05<05:51,  1.38it/s]

 loss: 0.1478431075811386


 17%|█▋        | 101/578 [01:12<05:44,  1.38it/s]

 loss: 0.03934371471405029


 19%|█▉        | 111/578 [01:20<05:37,  1.38it/s]

 loss: 0.05435197055339813


 21%|██        | 121/578 [01:27<05:29,  1.39it/s]

 loss: 0.14885084331035614


 23%|██▎       | 131/578 [01:34<05:22,  1.38it/s]

 loss: 0.1080995425581932


 24%|██▍       | 141/578 [01:41<05:15,  1.39it/s]

 loss: 0.48820412158966064


 26%|██▌       | 151/578 [01:49<05:08,  1.39it/s]

 loss: 0.01586843468248844


 28%|██▊       | 161/578 [01:56<05:01,  1.38it/s]

 loss: 0.28098055720329285


 30%|██▉       | 171/578 [02:03<04:54,  1.38it/s]

 loss: 0.00956668145954609


 31%|███▏      | 181/578 [02:10<04:46,  1.38it/s]

 loss: 0.006269611418247223


 33%|███▎      | 191/578 [02:17<04:39,  1.38it/s]

 loss: 0.13011442124843597


 35%|███▍      | 201/578 [02:25<04:33,  1.38it/s]

 loss: 0.050298891961574554


 37%|███▋      | 211/578 [02:32<04:25,  1.38it/s]

 loss: 0.21777072548866272


 38%|███▊      | 221/578 [02:39<04:18,  1.38it/s]

 loss: 0.08866090327501297


 40%|███▉      | 231/578 [02:46<04:10,  1.38it/s]

 loss: 0.14180077612400055


 42%|████▏     | 241/578 [02:54<04:03,  1.38it/s]

 loss: 0.04996660351753235


 43%|████▎     | 251/578 [03:01<03:56,  1.38it/s]

 loss: 0.07237925380468369


 45%|████▌     | 261/578 [03:08<03:49,  1.38it/s]

 loss: 0.07565247267484665


 47%|████▋     | 271/578 [03:15<03:42,  1.38it/s]

 loss: 0.08411134034395218


 49%|████▊     | 281/578 [03:22<03:34,  1.38it/s]

 loss: 0.10394105315208435


 50%|█████     | 291/578 [03:30<03:27,  1.38it/s]

 loss: 0.10623354464769363


 52%|█████▏    | 301/578 [03:37<03:19,  1.39it/s]

 loss: 0.07551924139261246


 54%|█████▍    | 311/578 [03:44<03:13,  1.38it/s]

 loss: 0.15080884099006653


 56%|█████▌    | 321/578 [03:51<03:06,  1.38it/s]

 loss: 0.03568081185221672


 57%|█████▋    | 331/578 [03:59<02:58,  1.38it/s]

 loss: 0.04155917093157768


 59%|█████▉    | 341/578 [04:06<02:51,  1.38it/s]

 loss: 0.2469935268163681


 61%|██████    | 351/578 [04:13<02:44,  1.38it/s]

 loss: 0.08956553786993027


 62%|██████▏   | 361/578 [04:20<02:36,  1.38it/s]

 loss: 0.22260357439517975


 64%|██████▍   | 371/578 [04:28<02:29,  1.38it/s]

 loss: 0.3568509817123413


 66%|██████▌   | 381/578 [04:35<02:22,  1.38it/s]

 loss: 0.11797484010457993


 68%|██████▊   | 391/578 [04:42<02:15,  1.38it/s]

 loss: 0.07787956297397614


 69%|██████▉   | 401/578 [04:49<02:08,  1.38it/s]

 loss: 0.05494416132569313


 71%|███████   | 411/578 [04:56<02:00,  1.39it/s]

 loss: 0.30842989683151245


 73%|███████▎  | 421/578 [05:04<01:53,  1.38it/s]

 loss: 0.07155013829469681


 75%|███████▍  | 431/578 [05:11<01:46,  1.38it/s]

 loss: 0.1490524560213089


 76%|███████▋  | 441/578 [05:18<01:39,  1.38it/s]

 loss: 0.053661081939935684


 78%|███████▊  | 451/578 [05:25<01:31,  1.38it/s]

 loss: 0.11838385462760925


 80%|███████▉  | 461/578 [05:33<01:24,  1.38it/s]

 loss: 0.073422372341156


 81%|████████▏ | 471/578 [05:40<01:17,  1.38it/s]

 loss: 0.07885274291038513


 83%|████████▎ | 481/578 [05:47<01:10,  1.38it/s]

 loss: 0.34993454813957214


 85%|████████▍ | 491/578 [05:54<01:02,  1.38it/s]

 loss: 0.0863519236445427


 87%|████████▋ | 501/578 [06:01<00:55,  1.38it/s]

 loss: 0.12852242588996887


 88%|████████▊ | 511/578 [06:09<00:48,  1.38it/s]

 loss: 0.23075702786445618


 90%|█████████ | 521/578 [06:16<00:41,  1.38it/s]

 loss: 0.05712297186255455


 92%|█████████▏| 531/578 [06:23<00:34,  1.38it/s]

 loss: 0.06721118092536926


 94%|█████████▎| 541/578 [06:30<00:26,  1.38it/s]

 loss: 0.2042057067155838


 95%|█████████▌| 551/578 [06:38<00:19,  1.38it/s]

 loss: 0.22435475885868073


 97%|█████████▋| 561/578 [06:45<00:12,  1.38it/s]

 loss: 0.07418350130319595


 99%|█████████▉| 571/578 [06:52<00:05,  1.38it/s]

 loss: 0.013169623911380768


100%|██████████| 578/578 [06:57<00:00,  1.38it/s]

Score 0.9485834775086506 / 1.00
Testing 1 2 3 ...
Performing (primero, secundo) Masked Dataset





  0%|          | 0/3 [00:00<?, ?ba/s]

100%|██████████| 80/80 [00:19<00:00,  4.11it/s]

Score 0.7466456195737964 / 1.00
Performing Standard Dataset





  0%|          | 0/3 [00:00<?, ?ba/s]

100%|██████████| 80/80 [00:19<00:00,  4.12it/s]


Score 0.7584846093133386 / 1.00
[Jason], [Kevin], [Jason doesn't know how the world works as well as Kevin because Jason had a really tough childhood.]
Alice doesn't know how the world works as well as Bob because Alice had a really tough childhood.
[bottle], [lotion], [Jerry wanted to get lotion out of the bottle, but couldn't because the lotion was gone.]
Jerry wanted to get Bob out of the Alice, but couldn't because the Bob was gone.
[Samantha], [Betty], [Samantha is having to replace their vacuum bags frequently but not Betty as Betty likes to clean the house.]
Alice is having to replace their vacuum bags frequently but not Bob as Bob likes to clean the house.
[statue], [bookend], [The statue fit in the middle of the shelf, but the bookend did not, because the statue was thin.]
The Alice fit in the middle of the shelf, but the Bob did not, because the Alice was thin.
[fence], [table], [Falling from the fence, john was glad he landed on the table first before getting to the patio. T

  0%|          | 0/3 [00:00<?, ?ba/s]

100%|██████████| 80/80 [00:19<00:00,  4.11it/s]


Score 0.7636148382004736 / 1.00
Performing (Alice, Bob) Masked Dataset


  0%|          | 0/3 [00:00<?, ?ba/s]

100%|██████████| 80/80 [00:19<00:00,  4.11it/s]


Score 0.7269139700078927 / 1.00
Deleted model
copy
finished?


  0%|          | 0/19 [00:00<?, ?ba/s]

Epoch 0


  0%|          | 1/578 [00:00<07:14,  1.33it/s]

 loss: 0.2981712818145752


  2%|▏         | 11/578 [00:07<06:50,  1.38it/s]

 loss: 0.21163928508758545


  4%|▎         | 21/578 [00:15<06:44,  1.38it/s]

 loss: 0.23901258409023285


  5%|▌         | 31/578 [00:22<06:36,  1.38it/s]

 loss: 0.22687077522277832


  7%|▋         | 41/578 [00:29<06:29,  1.38it/s]

 loss: 0.10114004462957382


  9%|▉         | 51/578 [00:36<06:21,  1.38it/s]

 loss: 0.1881236582994461


 11%|█         | 61/578 [00:44<06:14,  1.38it/s]

 loss: 0.1668245494365692


 12%|█▏        | 71/578 [00:51<06:07,  1.38it/s]

 loss: 0.150505930185318


 14%|█▍        | 81/578 [00:58<05:59,  1.38it/s]

 loss: 0.19047844409942627


 16%|█▌        | 91/578 [01:05<05:53,  1.38it/s]

 loss: 0.17531536519527435


 17%|█▋        | 101/578 [01:13<05:46,  1.38it/s]

 loss: 0.030450280755758286


 19%|█▉        | 111/578 [01:20<05:38,  1.38it/s]

 loss: 0.07223595678806305


 21%|██        | 121/578 [01:27<05:31,  1.38it/s]

 loss: 0.09768454730510712


 23%|██▎       | 131/578 [01:34<05:23,  1.38it/s]

 loss: 0.15330256521701813


 24%|██▍       | 141/578 [01:42<05:16,  1.38it/s]

 loss: 0.1866922676563263


 26%|██▌       | 151/578 [01:49<05:09,  1.38it/s]

 loss: 0.03412683680653572


 28%|██▊       | 161/578 [01:56<05:01,  1.38it/s]

 loss: 0.09195776283740997


 30%|██▉       | 171/578 [02:03<04:54,  1.38it/s]

 loss: 0.042295243591070175


 31%|███▏      | 181/578 [02:11<04:46,  1.38it/s]

 loss: 0.012460488826036453


 33%|███▎      | 191/578 [02:18<04:39,  1.38it/s]

 loss: 0.1647084355354309


 35%|███▍      | 201/578 [02:25<04:33,  1.38it/s]

 loss: 0.0459306538105011


 37%|███▋      | 211/578 [02:32<04:25,  1.38it/s]

 loss: 0.13935410976409912


 38%|███▊      | 221/578 [02:40<04:18,  1.38it/s]

 loss: 0.11080121994018555


 40%|███▉      | 231/578 [02:47<04:11,  1.38it/s]

 loss: 0.14860767126083374


 42%|████▏     | 241/578 [02:54<04:03,  1.38it/s]

 loss: 0.20264305174350739


 43%|████▎     | 251/578 [03:01<03:56,  1.38it/s]

 loss: 0.05630866810679436


 45%|████▌     | 261/578 [03:08<03:49,  1.38it/s]

 loss: 0.083345927298069


 47%|████▋     | 271/578 [03:16<03:41,  1.38it/s]

 loss: 0.10946681350469589


 49%|████▊     | 281/578 [03:23<03:34,  1.38it/s]

 loss: 0.1356298327445984


 50%|█████     | 291/578 [03:30<03:27,  1.38it/s]

 loss: 0.045911725610494614


 52%|█████▏    | 301/578 [03:37<03:20,  1.38it/s]

 loss: 0.05863237380981445


 54%|█████▍    | 311/578 [03:45<03:13,  1.38it/s]

 loss: 0.1266801655292511


 56%|█████▌    | 321/578 [03:52<03:06,  1.38it/s]

 loss: 0.03152868524193764


 57%|█████▋    | 331/578 [03:59<02:58,  1.38it/s]

 loss: 0.013543500564992428


 59%|█████▉    | 341/578 [04:06<02:51,  1.38it/s]

 loss: 0.22201313078403473


 61%|██████    | 351/578 [04:14<02:44,  1.38it/s]

 loss: 0.1373402327299118


 62%|██████▏   | 361/578 [04:21<02:37,  1.38it/s]

 loss: 0.08495373278856277


 64%|██████▍   | 371/578 [04:28<02:29,  1.38it/s]

 loss: 0.30183762311935425


 66%|██████▌   | 381/578 [04:35<02:22,  1.38it/s]

 loss: 0.012406510300934315


 68%|██████▊   | 391/578 [04:43<02:15,  1.38it/s]

 loss: 0.3809231221675873


 69%|██████▉   | 401/578 [04:50<02:08,  1.38it/s]

 loss: 0.022596517577767372


 71%|███████   | 411/578 [04:57<02:01,  1.38it/s]

 loss: 0.309354305267334


 73%|███████▎  | 421/578 [05:04<01:53,  1.38it/s]

 loss: 0.1523943394422531


 75%|███████▍  | 431/578 [05:11<01:46,  1.38it/s]

 loss: 0.09250519424676895


 76%|███████▋  | 441/578 [05:19<01:39,  1.38it/s]

 loss: 0.07492684572935104


 78%|███████▊  | 451/578 [05:26<01:31,  1.38it/s]

 loss: 0.25131991505622864


 80%|███████▉  | 461/578 [05:33<01:24,  1.38it/s]

 loss: 0.125136598944664


 81%|████████▏ | 471/578 [05:40<01:17,  1.38it/s]

 loss: 0.07890895009040833


 83%|████████▎ | 481/578 [05:48<01:10,  1.38it/s]

 loss: 0.19811414182186127


 85%|████████▍ | 491/578 [05:55<01:03,  1.38it/s]

 loss: 0.058333076536655426


 87%|████████▋ | 501/578 [06:02<00:55,  1.38it/s]

 loss: 0.048784609884023666


 88%|████████▊ | 511/578 [06:09<00:48,  1.38it/s]

 loss: 0.074574775993824


 90%|█████████ | 521/578 [06:17<00:41,  1.38it/s]

 loss: 0.05975218117237091


 92%|█████████▏| 531/578 [06:24<00:34,  1.38it/s]

 loss: 0.0967460423707962


 94%|█████████▎| 541/578 [06:31<00:26,  1.38it/s]

 loss: 0.23428994417190552


 95%|█████████▌| 551/578 [06:38<00:19,  1.38it/s]

 loss: 0.10855419188737869


 97%|█████████▋| 561/578 [06:46<00:12,  1.38it/s]

 loss: 0.2475329339504242


 99%|█████████▉| 571/578 [06:53<00:05,  1.38it/s]

 loss: 0.06416888535022736


100%|██████████| 578/578 [06:58<00:00,  1.38it/s]

Score 0.9479346885813149 / 1.00
Testing 1 2 3 ...
Performing (Alice, Bob) Masked Dataset





  0%|          | 0/3 [00:00<?, ?ba/s]

100%|██████████| 80/80 [00:19<00:00,  4.10it/s]

Score 0.7237569060773481 / 1.00
Performing Standard Dataset





  0%|          | 0/3 [00:00<?, ?ba/s]

100%|██████████| 80/80 [00:19<00:00,  4.10it/s]


Score 0.7438831886345698 / 1.00
[Jason], [Kevin], [Jason doesn't know how the world works as well as Kevin because Jason had a really tough childhood.]
_ doesn't know how the world works as well as __ because _ had a really tough childhood.
[bottle], [lotion], [Jerry wanted to get lotion out of the bottle, but couldn't because the lotion was gone.]
Jerry wanted to get __ out of the _, but couldn't because the __ was gone.
[Samantha], [Betty], [Samantha is having to replace their vacuum bags frequently but not Betty as Betty likes to clean the house.]
_ is having to replace their vacuum bags frequently but not __ as __ likes to clean the house.
[statue], [bookend], [The statue fit in the middle of the shelf, but the bookend did not, because the statue was thin.]
The _ fit in the middle of the shelf, but the __ did not, because the _ was thin.
[fence], [table], [Falling from the fence, john was glad he landed on the table first before getting to the patio. The fence is tall.]
Falling fro

  0%|          | 0/3 [00:00<?, ?ba/s]

100%|██████████| 80/80 [00:19<00:00,  4.11it/s]

Score 0.7588792423046566 / 1.00
Performing (_, __) Masked Dataset





  0%|          | 0/3 [00:00<?, ?ba/s]

100%|██████████| 80/80 [00:19<00:00,  4.11it/s]


Score 0.6665351223362274 / 1.00
Deleted model
copy
finished?


  0%|          | 0/19 [00:00<?, ?ba/s]

Epoch 0


  0%|          | 1/578 [00:00<07:10,  1.34it/s]

 loss: 0.6711727380752563


  2%|▏         | 11/578 [00:07<06:50,  1.38it/s]

 loss: 0.20810483396053314


  4%|▎         | 21/578 [00:15<06:43,  1.38it/s]

 loss: 0.26853147149086


  5%|▌         | 31/578 [00:22<06:36,  1.38it/s]

 loss: 0.21375595033168793


  7%|▋         | 41/578 [00:29<06:29,  1.38it/s]

 loss: 0.1621393859386444


  9%|▉         | 51/578 [00:36<06:21,  1.38it/s]

 loss: 0.12972021102905273


 11%|█         | 61/578 [00:44<06:14,  1.38it/s]

 loss: 0.22553402185440063


 12%|█▏        | 71/578 [00:51<06:07,  1.38it/s]

 loss: 0.0772162675857544


 14%|█▍        | 81/578 [00:58<05:59,  1.38it/s]

 loss: 0.13030987977981567


 16%|█▌        | 91/578 [01:05<05:52,  1.38it/s]

 loss: 0.09121542423963547


 17%|█▋        | 101/578 [01:13<05:45,  1.38it/s]

 loss: 0.02904815971851349


 19%|█▉        | 111/578 [01:20<05:37,  1.38it/s]

 loss: 0.06172855943441391


 21%|██        | 121/578 [01:27<05:30,  1.38it/s]

 loss: 0.15862981975078583


 23%|██▎       | 131/578 [01:34<05:23,  1.38it/s]

 loss: 0.15302307903766632


 24%|██▍       | 141/578 [01:42<05:17,  1.38it/s]

 loss: 0.434108704328537


 26%|██▌       | 151/578 [01:49<05:09,  1.38it/s]

 loss: 0.04030891880393028


 28%|██▊       | 161/578 [01:56<05:02,  1.38it/s]

 loss: 0.05889555811882019


 30%|██▉       | 171/578 [02:03<04:54,  1.38it/s]

 loss: 0.03748243302106857


 31%|███▏      | 181/578 [02:11<04:48,  1.38it/s]

 loss: 0.024980610236525536


 33%|███▎      | 191/578 [02:18<04:40,  1.38it/s]

 loss: 0.055408064275979996


 35%|███▍      | 201/578 [02:25<04:33,  1.38it/s]

 loss: 0.10256537795066833


 37%|███▋      | 211/578 [02:32<04:26,  1.38it/s]

 loss: 0.2071102410554886


 38%|███▊      | 221/578 [02:40<04:18,  1.38it/s]

 loss: 0.1359613537788391


 40%|███▉      | 231/578 [02:47<04:11,  1.38it/s]

 loss: 0.13474413752555847


 42%|████▏     | 241/578 [02:54<04:04,  1.38it/s]

 loss: 0.25372281670570374


 43%|████▎     | 251/578 [03:01<03:57,  1.38it/s]

 loss: 0.07736002653837204


 45%|████▌     | 261/578 [03:09<03:49,  1.38it/s]

 loss: 0.051494017243385315


 47%|████▋     | 271/578 [03:16<03:42,  1.38it/s]

 loss: 0.2912822365760803


 49%|████▊     | 281/578 [03:23<03:35,  1.38it/s]

 loss: 0.07130412012338638


 50%|█████     | 291/578 [03:30<03:27,  1.38it/s]

 loss: 0.041604895144701004


 52%|█████▏    | 301/578 [03:38<03:20,  1.38it/s]

 loss: 0.13455593585968018


 54%|█████▍    | 311/578 [03:45<03:13,  1.38it/s]

 loss: 0.1807866394519806


 56%|█████▌    | 321/578 [03:52<03:06,  1.38it/s]

 loss: 0.11397859454154968


 57%|█████▋    | 331/578 [03:59<02:59,  1.38it/s]

 loss: 0.15008078515529633


 59%|█████▉    | 341/578 [04:06<02:51,  1.38it/s]

 loss: 0.19427147507667542


 61%|██████    | 351/578 [04:14<02:44,  1.38it/s]

 loss: 0.16629816591739655


 62%|██████▏   | 361/578 [04:21<02:37,  1.38it/s]

 loss: 0.14413505792617798


 64%|██████▍   | 371/578 [04:28<02:29,  1.38it/s]

 loss: 0.37783703207969666


 66%|██████▌   | 381/578 [04:35<02:22,  1.38it/s]

 loss: 0.15998785197734833


 68%|██████▊   | 391/578 [04:43<02:15,  1.38it/s]

 loss: 0.12635646760463715


 69%|██████▉   | 401/578 [04:50<02:08,  1.38it/s]

 loss: 0.05529496818780899


 71%|███████   | 411/578 [04:57<02:00,  1.38it/s]

 loss: 0.18202286958694458


 73%|███████▎  | 421/578 [05:04<01:53,  1.38it/s]

 loss: 0.17967787384986877


 75%|███████▍  | 431/578 [05:12<01:46,  1.38it/s]

 loss: 0.08721359074115753


 76%|███████▋  | 441/578 [05:19<01:39,  1.38it/s]

 loss: 0.2346193790435791


 78%|███████▊  | 451/578 [05:26<01:31,  1.38it/s]

 loss: 0.1885562688112259


 80%|███████▉  | 461/578 [05:33<01:24,  1.38it/s]

 loss: 0.1873631477355957


 81%|████████▏ | 471/578 [05:41<01:17,  1.38it/s]

 loss: 0.04170774295926094


 83%|████████▎ | 481/578 [05:48<01:10,  1.38it/s]

 loss: 0.19448359310626984


 85%|████████▍ | 491/578 [05:55<01:03,  1.38it/s]

 loss: 0.16675350069999695


 87%|████████▋ | 501/578 [06:02<00:55,  1.38it/s]

 loss: 0.07135941088199615


 88%|████████▊ | 511/578 [06:10<00:48,  1.38it/s]

 loss: 0.10937058180570602


 90%|█████████ | 521/578 [06:17<00:41,  1.38it/s]

 loss: 0.06086786836385727


 92%|█████████▏| 531/578 [06:24<00:34,  1.38it/s]

 loss: 0.09277866035699844


 94%|█████████▎| 541/578 [06:31<00:26,  1.38it/s]

 loss: 0.057108230888843536


 95%|█████████▌| 551/578 [06:39<00:19,  1.38it/s]

 loss: 0.11759843677282333


 97%|█████████▋| 561/578 [06:46<00:12,  1.38it/s]

 loss: 0.02630198560655117


 99%|█████████▉| 571/578 [06:53<00:05,  1.38it/s]

 loss: 0.009222804568707943


100%|██████████| 578/578 [06:58<00:00,  1.38it/s]

Score 0.9455017301038062 / 1.00
Testing 1 2 3 ...
Performing (_, __) Masked Dataset





  0%|          | 0/3 [00:00<?, ?ba/s]

100%|██████████| 80/80 [00:19<00:00,  4.10it/s]

Score 0.7320441988950276 / 1.00
Performing Standard Dataset





  0%|          | 0/3 [00:00<?, ?ba/s]

100%|██████████| 80/80 [00:19<00:00,  4.11it/s]

Score 0.7375690607734806 / 1.00





Results:

| Dataset | Accuracy (before training) | Accuracy (after training on silly set) |
|:-------:|:------------------:|:----------------:|
|Standard set| 75% | 75% |
|(Alice, Bob)| 73% | 73% |
|(primero, secundo)| 72% | 75% |
| (X, Y) | 72% | 74% |
| (A, B) | 72% | 73% |
| (1, 2) | 71% | 74% |
|(red, blue) | 71% | 73% |
|(#, @)| 71% | 73% |
|(alpha, beta)| 71% | 73% |
|(first, second)| 69% | 73% |
|(\_, \_\_)| 67% | 73% |
|(dog, doggy)| 65% | 73% | 
|(flavor, flavour)| 53% | 74% |


