In [1]:
import argparse
import copy

from transformers import BertForMaskedLM, BertTokenizer, TrainingArguments, Trainer, \
    DataCollatorForLanguageModeling, IntervalStrategy

from datasets import Dataset
import os

from data_generation_relation import *
from utils import *
from custom_trainer import CustomTrainer
from datasets import load_metric
import logging
from transformers import logging as tlogging
import wandb
import sys
from utils import set_seed
from transformers.integrations import WandbCallback, TensorBoardCallback
from tqdm.notebook import tqdm
from collections import Counter

# Switch off Weights & Biases reporting for this notebook run.
# NOTE(review): the Trainer later warns that this environment variable is deprecated
# (removal announced for v5) and points to the `report_to` training argument instead.
os.environ["WANDB_DISABLED"] = "true"

In [2]:
# --- Reproducibility: seed via the project helper before anything is generated ---
set_seed(42)

# --- Run identity: used to build the ./output/<run_name>/ paths further down ---
run_name = 'INV_de_en'

# --- Data generation settings (forwarded to generate_reasoning / generate_random) ---
relation = 'inversion'
source_language, target_language = ['de'], ['en']
n_relations, n_facts = 10, 1000
use_pretrained, use_target = False, False

# When True, randomly generated facts are appended to the training set.
use_random = False

# --- Optimisation settings (forwarded to TrainingArguments) ---
epochs, batch_size = 200, 200
lr = 5e-5

In [3]:
# Generate the train/test data and the relation tables for the configured task.
# The last three options are fixed for this experiment rather than taken from the config cell.
train, test, relations = generate_reasoning(
    relation=Relation(relation),
    source_language=source_language,
    target_language=target_language,
    n_relations=n_relations,
    n_facts=n_facts,
    use_pretrained=use_pretrained,
    use_target=use_target,
    use_enhanced=False,
    use_same_relations=False,
    n_pairs=0,
)

# Display the first relation table.
relations[0]

Unnamed: 0,id,en,de,es,fr,count
694,P105,taxon rank,taxonomischer Rang,categoría taxonómica,rang taxinomique,3580266
598,P462,color,Farbe,color,couleur,194389
120,P111,measured physical quantity,gemessene physikalische Größe,cantidad física medida,grandeur physique mesurée,3610
281,P400,platform,Plattform,plataforma,plateforme,95318
137,P8345,media franchise,Medien-Franchise,franquicia de medios,franchise médiatique,27415
204,P1606,natural reservoir of,Erregerreservoir von,reservorio natural de,réservoir naturel de,17
231,P2675,reply to,Antwort auf,respuesta a,réponse à,381
213,P1909,side effect,Nebenwirkung,efecto secundario,effet secondaire,40
235,P1363,points/goal scored by,Punkt/Treffer erzielt durch,puntos/goles marcados por,point/but marqué par,2441
711,P607,conflict,Kriegseinsatz,participó en el conflicto,conflit,220972


In [4]:
# Display the second relation table; the per-relation evaluation further down labels
# these rows "Inversion" (presumably the counterparts of relations[0] - TODO confirm).
relations[1]

Unnamed: 0,id,en,de,es,fr,count
528,P6855,emergency services,Notfalleinrichtungen,servicios de emergencia,accueil et traitement des urgences,766
606,P2429,expected completeness,erwartete Vollständigkeit,grado de completitud,degré de complétude,3826
63,P3027,open period from,geöffnet von Zeitpunkt,abierto desde,début de la période d'ouverture,16
515,P7727,legislative committee,Legislativkomitee,comité legislativo,comité législatif,123710
587,P9597,type of lens,Linsentyp,tipo de lente,type de lentille optique,1721
218,P8852,facial hair,Gesichtshaar,vello facial,pilosité faciale,362
66,P1455,list of works,Werkliste,lista de obras,liste des œuvres,1227
754,P129,physically interacts with,interagiert physikalisch mit,interactúa físicamente con,interagit physiquement avec,9480
118,P6271,demonym of,Demonym zu,gentilicio de,gentilé de,2629
522,P2596,culture,Kultur,cultura,culture,10007


In [5]:
# Optionally extend the training data with randomly generated facts.
# `relations_random` stays an empty list when `use_random` is off.
relations_random = []

if use_random:
    # Generate half/half: factor 1.0 requests as many random facts as regular ones.
    factor = 1.0
    # A fact count has to be an integer; `factor * n_facts` on its own is a float (e.g. 1000.0).
    n_random = int(factor * n_facts)

    train_random, relations_random = generate_random(source_language, target_language, n_random, n_relations)
    # NOTE(review): if `train` is a list this extends it in place, so re-running this
    # cell without re-running the generation cell appends the random facts again.
    train += train_random

relations_random

[]

In [6]:
# ~~ LOADING ~~
# Tokenizer and masked-LM model are both loaded from the same mBERT checkpoint.
# NOTE(review): BertTokenizerFast is not in the explicit transformers import above;
# presumably it is provided by one of the star imports - confirm.
checkpoint_name = 'bert-base-multilingual-cased'
tokenizer = BertTokenizerFast.from_pretrained(checkpoint_name)
model = BertForMaskedLM.from_pretrained(checkpoint_name)

# Training collator: masked-language-modelling with a 15% masking probability.
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm_probability=0.15,
)

# Evaluation collator: random masking is switched off (mlm=False).
eval_data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False,
)

Some weights of the model checkpoint at bert-base-multilingual-cased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [7]:
# ~~ PRE-PROCESSING ~~
# Both splits become single-column ('sample') datasets. The test data is flattened
# from a deep copy, so `test` keeps its original structure for later cells.
train_dict, test_dict = {'sample': train}, {'sample': flatten_dict2_list(copy.deepcopy(test))}
train_ds, test_ds = Dataset.from_dict(train_dict), Dataset.from_dict(test_dict)

# Tokenize train first, then test (per the original note, shuffling of the
# training data is left to Hugging Face).
tokenized_train, tokenized_test = tokenize(tokenizer, train_ds), tokenize(tokenizer, test_ds)

  0%|          | 0/19 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

In [8]:
# Save Train and Test Data
# The flat splits are written as CSV and the complete (unflattened) test data as JSON,
# all below ./output/<run_name>/data/. File names carry no extension.
# NOTE(review): `pd` is not imported explicitly in this notebook; presumably it is
# provided by one of the star imports - confirm.
train_df = pd.DataFrame(train_dict)
test_complete_df = pd.DataFrame(test)
test_flat_df = pd.DataFrame(test_dict)

data_dir = './output/' + run_name + '/data/'
# exist_ok=True creates the directory tree when missing and does nothing when it is
# already there, replacing the separate "exists?" check followed by creation.
os.makedirs(data_dir, exist_ok=True)

train_df.to_csv(data_dir + 'train_set', index=False)
test_complete_df.to_json(data_dir + 'test_set_complete')
test_flat_df.to_csv(data_dir + 'test_set', index=False)

# The random facts only exist when they were generated in the earlier cell.
if use_random:
    train_random_df = pd.DataFrame({'sample': train_random})
    train_random_df.to_csv(data_dir + 'train_random', index=False)

In [9]:
# Everything this run writes (checkpoints, TensorBoard logs) goes below ./output/<run_name>/.
run_dir = './output/' + run_name

training_args = TrainingArguments(
    output_dir=run_dir + '/models/',
    logging_dir=run_dir + '/tb_logs/',
    seed=42,
    # Optimisation settings from the config cell; the eval batch size is fixed here.
    num_train_epochs=epochs,
    learning_rate=lr,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=128,
    # Log, evaluate and checkpoint once per epoch.
    logging_strategy=IntervalStrategy.EPOCH,
    evaluation_strategy=IntervalStrategy.EPOCH,
    save_strategy=IntervalStrategy.EPOCH,
    # Older checkpoints are deleted once this limit is exceeded.
    save_total_limit=2,
    # Reload the checkpoint with the best 'accuracy' when training ends.
    load_best_model_at_end=True,
    metric_for_best_model='accuracy',
)

# NOTE(review): eval_dataset is the test split, so the best-checkpoint selection
# above is driven by test accuracy.
trainer = CustomTrainer(
    args=training_args,
    model=model,
    tokenizer=tokenizer,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_test,
    data_collator=data_collator,
    eval_data_collator=eval_data_collator,
    compute_metrics=precision_at_one,
)


Using the `WAND_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


In [10]:
# Train
# Runs the schedule configured in `training_args`: logging, evaluation and
# checkpointing happen once per epoch, hence the long checkpoint log below.
trainer.train()

***** Running training *****
  Num examples = 19000
  Num Epochs = 200
  Instantaneous batch size per device = 200
  Total train batch size (w. parallel, distributed & accumulation) = 400
  Gradient Accumulation steps = 1
  Total optimization steps = 9600


Epoch,Training Loss,Validation Loss,Accuracy
1,3.4698,9.214267,0.0
2,2.7689,8.894004,0.001
3,2.6087,8.80501,0.0
4,2.5904,8.695547,0.001
5,2.5412,8.722382,0.001
6,2.5268,8.614834,0.0
7,2.4697,8.645632,0.0
8,2.4512,8.532213,0.001
9,2.4407,8.506217,0.0
10,2.4448,8.450766,0.001


Saving model checkpoint to ./output/INV_de_en/models/checkpoint-48
Configuration saved in ./output/INV_de_en/models/checkpoint-48/config.json
Model weights saved in ./output/INV_de_en/models/checkpoint-48/pytorch_model.bin
tokenizer config file saved in ./output/INV_de_en/models/checkpoint-48/tokenizer_config.json
Special tokens file saved in ./output/INV_de_en/models/checkpoint-48/special_tokens_map.json
Saving model checkpoint to ./output/INV_de_en/models/checkpoint-96
Configuration saved in ./output/INV_de_en/models/checkpoint-96/config.json
Model weights saved in ./output/INV_de_en/models/checkpoint-96/pytorch_model.bin
tokenizer config file saved in ./output/INV_de_en/models/checkpoint-96/tokenizer_config.json
Special tokens file saved in ./output/INV_de_en/models/checkpoint-96/special_tokens_map.json
Saving model checkpoint to ./output/INV_de_en/models/checkpoint-144
Configuration saved in ./output/INV_de_en/models/checkpoint-144/config.json
Model weights saved in ./output/INV_de

Deleting older checkpoint [output/INV_de_en/models/checkpoint-480] due to args.save_total_limit
Saving model checkpoint to ./output/INV_de_en/models/checkpoint-576
Configuration saved in ./output/INV_de_en/models/checkpoint-576/config.json
Model weights saved in ./output/INV_de_en/models/checkpoint-576/pytorch_model.bin
tokenizer config file saved in ./output/INV_de_en/models/checkpoint-576/tokenizer_config.json
Special tokens file saved in ./output/INV_de_en/models/checkpoint-576/special_tokens_map.json
Deleting older checkpoint [output/INV_de_en/models/checkpoint-528] due to args.save_total_limit
Saving model checkpoint to ./output/INV_de_en/models/checkpoint-624
Configuration saved in ./output/INV_de_en/models/checkpoint-624/config.json
Model weights saved in ./output/INV_de_en/models/checkpoint-624/pytorch_model.bin
tokenizer config file saved in ./output/INV_de_en/models/checkpoint-624/tokenizer_config.json
Special tokens file saved in ./output/INV_de_en/models/checkpoint-624/spec

Saving model checkpoint to ./output/INV_de_en/models/checkpoint-1056
Configuration saved in ./output/INV_de_en/models/checkpoint-1056/config.json
Model weights saved in ./output/INV_de_en/models/checkpoint-1056/pytorch_model.bin
tokenizer config file saved in ./output/INV_de_en/models/checkpoint-1056/tokenizer_config.json
Special tokens file saved in ./output/INV_de_en/models/checkpoint-1056/special_tokens_map.json
Deleting older checkpoint [output/INV_de_en/models/checkpoint-1008] due to args.save_total_limit
Saving model checkpoint to ./output/INV_de_en/models/checkpoint-1104
Configuration saved in ./output/INV_de_en/models/checkpoint-1104/config.json
Model weights saved in ./output/INV_de_en/models/checkpoint-1104/pytorch_model.bin
tokenizer config file saved in ./output/INV_de_en/models/checkpoint-1104/tokenizer_config.json
Special tokens file saved in ./output/INV_de_en/models/checkpoint-1104/special_tokens_map.json
Deleting older checkpoint [output/INV_de_en/models/checkpoint-105

Configuration saved in ./output/INV_de_en/models/checkpoint-1536/config.json
Model weights saved in ./output/INV_de_en/models/checkpoint-1536/pytorch_model.bin
tokenizer config file saved in ./output/INV_de_en/models/checkpoint-1536/tokenizer_config.json
Special tokens file saved in ./output/INV_de_en/models/checkpoint-1536/special_tokens_map.json
Deleting older checkpoint [output/INV_de_en/models/checkpoint-1488] due to args.save_total_limit
Saving model checkpoint to ./output/INV_de_en/models/checkpoint-1584
Configuration saved in ./output/INV_de_en/models/checkpoint-1584/config.json
Model weights saved in ./output/INV_de_en/models/checkpoint-1584/pytorch_model.bin
tokenizer config file saved in ./output/INV_de_en/models/checkpoint-1584/tokenizer_config.json
Special tokens file saved in ./output/INV_de_en/models/checkpoint-1584/special_tokens_map.json
Deleting older checkpoint [output/INV_de_en/models/checkpoint-1536] due to args.save_total_limit
Saving model checkpoint to ./output/I

Model weights saved in ./output/INV_de_en/models/checkpoint-2016/pytorch_model.bin
tokenizer config file saved in ./output/INV_de_en/models/checkpoint-2016/tokenizer_config.json
Special tokens file saved in ./output/INV_de_en/models/checkpoint-2016/special_tokens_map.json
Deleting older checkpoint [output/INV_de_en/models/checkpoint-1824] due to args.save_total_limit
Saving model checkpoint to ./output/INV_de_en/models/checkpoint-2064
Configuration saved in ./output/INV_de_en/models/checkpoint-2064/config.json
Model weights saved in ./output/INV_de_en/models/checkpoint-2064/pytorch_model.bin
tokenizer config file saved in ./output/INV_de_en/models/checkpoint-2064/tokenizer_config.json
Special tokens file saved in ./output/INV_de_en/models/checkpoint-2064/special_tokens_map.json
Deleting older checkpoint [output/INV_de_en/models/checkpoint-1968] due to args.save_total_limit
Saving model checkpoint to ./output/INV_de_en/models/checkpoint-2112
Configuration saved in ./output/INV_de_en/mod

tokenizer config file saved in ./output/INV_de_en/models/checkpoint-2496/tokenizer_config.json
Special tokens file saved in ./output/INV_de_en/models/checkpoint-2496/special_tokens_map.json
Deleting older checkpoint [output/INV_de_en/models/checkpoint-2400] due to args.save_total_limit
Saving model checkpoint to ./output/INV_de_en/models/checkpoint-2544
Configuration saved in ./output/INV_de_en/models/checkpoint-2544/config.json
Model weights saved in ./output/INV_de_en/models/checkpoint-2544/pytorch_model.bin
tokenizer config file saved in ./output/INV_de_en/models/checkpoint-2544/tokenizer_config.json
Special tokens file saved in ./output/INV_de_en/models/checkpoint-2544/special_tokens_map.json
Deleting older checkpoint [output/INV_de_en/models/checkpoint-2448] due to args.save_total_limit
Saving model checkpoint to ./output/INV_de_en/models/checkpoint-2592
Configuration saved in ./output/INV_de_en/models/checkpoint-2592/config.json
Model weights saved in ./output/INV_de_en/models/ch

Special tokens file saved in ./output/INV_de_en/models/checkpoint-2976/special_tokens_map.json
Deleting older checkpoint [output/INV_de_en/models/checkpoint-2928] due to args.save_total_limit
Saving model checkpoint to ./output/INV_de_en/models/checkpoint-3024
Configuration saved in ./output/INV_de_en/models/checkpoint-3024/config.json
Model weights saved in ./output/INV_de_en/models/checkpoint-3024/pytorch_model.bin
tokenizer config file saved in ./output/INV_de_en/models/checkpoint-3024/tokenizer_config.json
Special tokens file saved in ./output/INV_de_en/models/checkpoint-3024/special_tokens_map.json
Deleting older checkpoint [output/INV_de_en/models/checkpoint-2832] due to args.save_total_limit
Saving model checkpoint to ./output/INV_de_en/models/checkpoint-3072
Configuration saved in ./output/INV_de_en/models/checkpoint-3072/config.json
Model weights saved in ./output/INV_de_en/models/checkpoint-3072/pytorch_model.bin
tokenizer config file saved in ./output/INV_de_en/models/checkp

Deleting older checkpoint [output/INV_de_en/models/checkpoint-3360] due to args.save_total_limit
Saving model checkpoint to ./output/INV_de_en/models/checkpoint-3504
Configuration saved in ./output/INV_de_en/models/checkpoint-3504/config.json
Model weights saved in ./output/INV_de_en/models/checkpoint-3504/pytorch_model.bin
tokenizer config file saved in ./output/INV_de_en/models/checkpoint-3504/tokenizer_config.json
Special tokens file saved in ./output/INV_de_en/models/checkpoint-3504/special_tokens_map.json
Deleting older checkpoint [output/INV_de_en/models/checkpoint-3456] due to args.save_total_limit
Saving model checkpoint to ./output/INV_de_en/models/checkpoint-3552
Configuration saved in ./output/INV_de_en/models/checkpoint-3552/config.json
Model weights saved in ./output/INV_de_en/models/checkpoint-3552/pytorch_model.bin
tokenizer config file saved in ./output/INV_de_en/models/checkpoint-3552/tokenizer_config.json
Special tokens file saved in ./output/INV_de_en/models/checkpoi

Saving model checkpoint to ./output/INV_de_en/models/checkpoint-3984
Configuration saved in ./output/INV_de_en/models/checkpoint-3984/config.json
Model weights saved in ./output/INV_de_en/models/checkpoint-3984/pytorch_model.bin
tokenizer config file saved in ./output/INV_de_en/models/checkpoint-3984/tokenizer_config.json
Special tokens file saved in ./output/INV_de_en/models/checkpoint-3984/special_tokens_map.json
Deleting older checkpoint [output/INV_de_en/models/checkpoint-3936] due to args.save_total_limit
Saving model checkpoint to ./output/INV_de_en/models/checkpoint-4032
Configuration saved in ./output/INV_de_en/models/checkpoint-4032/config.json
Model weights saved in ./output/INV_de_en/models/checkpoint-4032/pytorch_model.bin
tokenizer config file saved in ./output/INV_de_en/models/checkpoint-4032/tokenizer_config.json
Special tokens file saved in ./output/INV_de_en/models/checkpoint-4032/special_tokens_map.json
Deleting older checkpoint [output/INV_de_en/models/checkpoint-398

Saving model checkpoint to ./output/INV_de_en/models/checkpoint-4464
Configuration saved in ./output/INV_de_en/models/checkpoint-4464/config.json
Model weights saved in ./output/INV_de_en/models/checkpoint-4464/pytorch_model.bin
tokenizer config file saved in ./output/INV_de_en/models/checkpoint-4464/tokenizer_config.json
Special tokens file saved in ./output/INV_de_en/models/checkpoint-4464/special_tokens_map.json
Deleting older checkpoint [output/INV_de_en/models/checkpoint-4368] due to args.save_total_limit
Saving model checkpoint to ./output/INV_de_en/models/checkpoint-4512
Configuration saved in ./output/INV_de_en/models/checkpoint-4512/config.json
Model weights saved in ./output/INV_de_en/models/checkpoint-4512/pytorch_model.bin
tokenizer config file saved in ./output/INV_de_en/models/checkpoint-4512/tokenizer_config.json
Special tokens file saved in ./output/INV_de_en/models/checkpoint-4512/special_tokens_map.json
Deleting older checkpoint [output/INV_de_en/models/checkpoint-441

Configuration saved in ./output/INV_de_en/models/checkpoint-4944/config.json
Model weights saved in ./output/INV_de_en/models/checkpoint-4944/pytorch_model.bin
tokenizer config file saved in ./output/INV_de_en/models/checkpoint-4944/tokenizer_config.json
Special tokens file saved in ./output/INV_de_en/models/checkpoint-4944/special_tokens_map.json
Deleting older checkpoint [output/INV_de_en/models/checkpoint-4896] due to args.save_total_limit
Saving model checkpoint to ./output/INV_de_en/models/checkpoint-4992
Configuration saved in ./output/INV_de_en/models/checkpoint-4992/config.json
Model weights saved in ./output/INV_de_en/models/checkpoint-4992/pytorch_model.bin
tokenizer config file saved in ./output/INV_de_en/models/checkpoint-4992/tokenizer_config.json
Special tokens file saved in ./output/INV_de_en/models/checkpoint-4992/special_tokens_map.json
Deleting older checkpoint [output/INV_de_en/models/checkpoint-4560] due to args.save_total_limit
Saving model checkpoint to ./output/I

Model weights saved in ./output/INV_de_en/models/checkpoint-5424/pytorch_model.bin
tokenizer config file saved in ./output/INV_de_en/models/checkpoint-5424/tokenizer_config.json
Special tokens file saved in ./output/INV_de_en/models/checkpoint-5424/special_tokens_map.json
Deleting older checkpoint [output/INV_de_en/models/checkpoint-5376] due to args.save_total_limit
Saving model checkpoint to ./output/INV_de_en/models/checkpoint-5472
Configuration saved in ./output/INV_de_en/models/checkpoint-5472/config.json
Model weights saved in ./output/INV_de_en/models/checkpoint-5472/pytorch_model.bin
tokenizer config file saved in ./output/INV_de_en/models/checkpoint-5472/tokenizer_config.json
Special tokens file saved in ./output/INV_de_en/models/checkpoint-5472/special_tokens_map.json
Deleting older checkpoint [output/INV_de_en/models/checkpoint-5280] due to args.save_total_limit
Saving model checkpoint to ./output/INV_de_en/models/checkpoint-5520
Configuration saved in ./output/INV_de_en/mod

tokenizer config file saved in ./output/INV_de_en/models/checkpoint-5904/tokenizer_config.json
Special tokens file saved in ./output/INV_de_en/models/checkpoint-5904/special_tokens_map.json
Deleting older checkpoint [output/INV_de_en/models/checkpoint-5856] due to args.save_total_limit
Saving model checkpoint to ./output/INV_de_en/models/checkpoint-5952
Configuration saved in ./output/INV_de_en/models/checkpoint-5952/config.json
Model weights saved in ./output/INV_de_en/models/checkpoint-5952/pytorch_model.bin
tokenizer config file saved in ./output/INV_de_en/models/checkpoint-5952/tokenizer_config.json
Special tokens file saved in ./output/INV_de_en/models/checkpoint-5952/special_tokens_map.json
Deleting older checkpoint [output/INV_de_en/models/checkpoint-5760] due to args.save_total_limit
Saving model checkpoint to ./output/INV_de_en/models/checkpoint-6000
Configuration saved in ./output/INV_de_en/models/checkpoint-6000/config.json
Model weights saved in ./output/INV_de_en/models/ch

Special tokens file saved in ./output/INV_de_en/models/checkpoint-6384/special_tokens_map.json
Deleting older checkpoint [output/INV_de_en/models/checkpoint-6288] due to args.save_total_limit
Saving model checkpoint to ./output/INV_de_en/models/checkpoint-6432
Configuration saved in ./output/INV_de_en/models/checkpoint-6432/config.json
Model weights saved in ./output/INV_de_en/models/checkpoint-6432/pytorch_model.bin
tokenizer config file saved in ./output/INV_de_en/models/checkpoint-6432/tokenizer_config.json
Special tokens file saved in ./output/INV_de_en/models/checkpoint-6432/special_tokens_map.json
Deleting older checkpoint [output/INV_de_en/models/checkpoint-6336] due to args.save_total_limit
Saving model checkpoint to ./output/INV_de_en/models/checkpoint-6480
Configuration saved in ./output/INV_de_en/models/checkpoint-6480/config.json
Model weights saved in ./output/INV_de_en/models/checkpoint-6480/pytorch_model.bin
tokenizer config file saved in ./output/INV_de_en/models/checkp

Deleting older checkpoint [output/INV_de_en/models/checkpoint-6768] due to args.save_total_limit
Saving model checkpoint to ./output/INV_de_en/models/checkpoint-6912
Configuration saved in ./output/INV_de_en/models/checkpoint-6912/config.json
Model weights saved in ./output/INV_de_en/models/checkpoint-6912/pytorch_model.bin
tokenizer config file saved in ./output/INV_de_en/models/checkpoint-6912/tokenizer_config.json
Special tokens file saved in ./output/INV_de_en/models/checkpoint-6912/special_tokens_map.json
Deleting older checkpoint [output/INV_de_en/models/checkpoint-6816] due to args.save_total_limit
Saving model checkpoint to ./output/INV_de_en/models/checkpoint-6960
Configuration saved in ./output/INV_de_en/models/checkpoint-6960/config.json
Model weights saved in ./output/INV_de_en/models/checkpoint-6960/pytorch_model.bin
tokenizer config file saved in ./output/INV_de_en/models/checkpoint-6960/tokenizer_config.json
Special tokens file saved in ./output/INV_de_en/models/checkpoi

Saving model checkpoint to ./output/INV_de_en/models/checkpoint-7392
Configuration saved in ./output/INV_de_en/models/checkpoint-7392/config.json
Model weights saved in ./output/INV_de_en/models/checkpoint-7392/pytorch_model.bin
tokenizer config file saved in ./output/INV_de_en/models/checkpoint-7392/tokenizer_config.json
Special tokens file saved in ./output/INV_de_en/models/checkpoint-7392/special_tokens_map.json
Deleting older checkpoint [output/INV_de_en/models/checkpoint-7344] due to args.save_total_limit
Saving model checkpoint to ./output/INV_de_en/models/checkpoint-7440
Configuration saved in ./output/INV_de_en/models/checkpoint-7440/config.json
Model weights saved in ./output/INV_de_en/models/checkpoint-7440/pytorch_model.bin
tokenizer config file saved in ./output/INV_de_en/models/checkpoint-7440/tokenizer_config.json
Special tokens file saved in ./output/INV_de_en/models/checkpoint-7440/special_tokens_map.json
Deleting older checkpoint [output/INV_de_en/models/checkpoint-705

Saving model checkpoint to ./output/INV_de_en/models/checkpoint-7872
Configuration saved in ./output/INV_de_en/models/checkpoint-7872/config.json
Model weights saved in ./output/INV_de_en/models/checkpoint-7872/pytorch_model.bin
tokenizer config file saved in ./output/INV_de_en/models/checkpoint-7872/tokenizer_config.json
Special tokens file saved in ./output/INV_de_en/models/checkpoint-7872/special_tokens_map.json
Deleting older checkpoint [output/INV_de_en/models/checkpoint-7824] due to args.save_total_limit
Saving model checkpoint to ./output/INV_de_en/models/checkpoint-7920
Configuration saved in ./output/INV_de_en/models/checkpoint-7920/config.json
Model weights saved in ./output/INV_de_en/models/checkpoint-7920/pytorch_model.bin
tokenizer config file saved in ./output/INV_de_en/models/checkpoint-7920/tokenizer_config.json
Special tokens file saved in ./output/INV_de_en/models/checkpoint-7920/special_tokens_map.json
Deleting older checkpoint [output/INV_de_en/models/checkpoint-787

Configuration saved in ./output/INV_de_en/models/checkpoint-8352/config.json
Model weights saved in ./output/INV_de_en/models/checkpoint-8352/pytorch_model.bin
tokenizer config file saved in ./output/INV_de_en/models/checkpoint-8352/tokenizer_config.json
Special tokens file saved in ./output/INV_de_en/models/checkpoint-8352/special_tokens_map.json
Deleting older checkpoint [output/INV_de_en/models/checkpoint-8016] due to args.save_total_limit
Saving model checkpoint to ./output/INV_de_en/models/checkpoint-8400
Configuration saved in ./output/INV_de_en/models/checkpoint-8400/config.json
Model weights saved in ./output/INV_de_en/models/checkpoint-8400/pytorch_model.bin
tokenizer config file saved in ./output/INV_de_en/models/checkpoint-8400/tokenizer_config.json
Special tokens file saved in ./output/INV_de_en/models/checkpoint-8400/special_tokens_map.json
Deleting older checkpoint [output/INV_de_en/models/checkpoint-8304] due to args.save_total_limit
Saving model checkpoint to ./output/I

Model weights saved in ./output/INV_de_en/models/checkpoint-8832/pytorch_model.bin
tokenizer config file saved in ./output/INV_de_en/models/checkpoint-8832/tokenizer_config.json
Special tokens file saved in ./output/INV_de_en/models/checkpoint-8832/special_tokens_map.json
Deleting older checkpoint [output/INV_de_en/models/checkpoint-8784] due to args.save_total_limit
Saving model checkpoint to ./output/INV_de_en/models/checkpoint-8880
Configuration saved in ./output/INV_de_en/models/checkpoint-8880/config.json
Model weights saved in ./output/INV_de_en/models/checkpoint-8880/pytorch_model.bin
tokenizer config file saved in ./output/INV_de_en/models/checkpoint-8880/tokenizer_config.json
Special tokens file saved in ./output/INV_de_en/models/checkpoint-8880/special_tokens_map.json
Deleting older checkpoint [output/INV_de_en/models/checkpoint-8832] due to args.save_total_limit
Saving model checkpoint to ./output/INV_de_en/models/checkpoint-8928
Configuration saved in ./output/INV_de_en/mod

tokenizer config file saved in ./output/INV_de_en/models/checkpoint-9312/tokenizer_config.json
Special tokens file saved in ./output/INV_de_en/models/checkpoint-9312/special_tokens_map.json
Deleting older checkpoint [output/INV_de_en/models/checkpoint-9264] due to args.save_total_limit
Saving model checkpoint to ./output/INV_de_en/models/checkpoint-9360
Configuration saved in ./output/INV_de_en/models/checkpoint-9360/config.json
Model weights saved in ./output/INV_de_en/models/checkpoint-9360/pytorch_model.bin
tokenizer config file saved in ./output/INV_de_en/models/checkpoint-9360/tokenizer_config.json
Special tokens file saved in ./output/INV_de_en/models/checkpoint-9360/special_tokens_map.json
Deleting older checkpoint [output/INV_de_en/models/checkpoint-9312] due to args.save_total_limit
Saving model checkpoint to ./output/INV_de_en/models/checkpoint-9408
Configuration saved in ./output/INV_de_en/models/checkpoint-9408/config.json
Model weights saved in ./output/INV_de_en/models/ch

TrainOutput(global_step=9600, training_loss=1.0637137552102407, metrics={'train_runtime': 9666.1126, 'train_samples_per_second': 393.126, 'train_steps_per_second': 0.993, 'total_flos': 2.346603588e+16, 'train_loss': 1.0637137552102407, 'epoch': 200.0})

In [11]:
# Evaluate Test
# Scores the trainer's current model on the complete tokenized test set; the
# returned dict of eval metrics is shown as the cell output.
trainer.evaluate(eval_dataset=tokenized_test)



{'eval_accuracy': 0.262,
 'eval_loss': 4.036846160888672,
 'eval_runtime': 1.5735,
 'eval_samples_per_second': 635.533,
 'eval_steps_per_second': 2.542,
 'epoch': 200.0}

In [13]:
# Evaluation of inversion per relation
# A deep copy of `test` is handed over, so the original test structure is not
# modified by the evaluation helper.
evaluation_inversion(trainer, tokenizer, relations, source_language, copy.deepcopy(test))

Relation - source: taxonomischer Rang, target: taxon rank


  0%|          | 0/1 [00:00<?, ?ba/s]



{'eval_accuracy': 0.1, 'eval_loss': 4.257268905639648, 'eval_runtime': 0.7494, 'eval_samples_per_second': 66.722, 'eval_steps_per_second': 1.334}
Inversion - source: Notfalleinrichtungen, target: emergency services


  0%|          | 0/1 [00:00<?, ?ba/s]



{'eval_accuracy': 0.24, 'eval_loss': 4.892024993896484, 'eval_runtime': 0.5994, 'eval_samples_per_second': 83.414, 'eval_steps_per_second': 1.668}
Relation - source: Farbe, target: color


  0%|          | 0/1 [00:00<?, ?ba/s]



{'eval_accuracy': 0.86, 'eval_loss': 0.7034075260162354, 'eval_runtime': 0.5968, 'eval_samples_per_second': 83.78, 'eval_steps_per_second': 1.676}
Inversion - source: erwartete Vollständigkeit, target: expected completeness


  0%|          | 0/1 [00:00<?, ?ba/s]



{'eval_accuracy': 0.22, 'eval_loss': 2.784353256225586, 'eval_runtime': 0.5869, 'eval_samples_per_second': 85.188, 'eval_steps_per_second': 1.704}
Relation - source: gemessene physikalische Größe, target: measured physical quantity


  0%|          | 0/1 [00:00<?, ?ba/s]



{'eval_accuracy': 0.04, 'eval_loss': 4.681816577911377, 'eval_runtime': 0.5979, 'eval_samples_per_second': 83.628, 'eval_steps_per_second': 1.673}
Inversion - source: geöffnet von Zeitpunkt, target: open period from


  0%|          | 0/1 [00:00<?, ?ba/s]



{'eval_accuracy': 0.12, 'eval_loss': 5.864680767059326, 'eval_runtime': 0.615, 'eval_samples_per_second': 81.303, 'eval_steps_per_second': 1.626}
Relation - source: Plattform, target: platform


  0%|          | 0/1 [00:00<?, ?ba/s]



{'eval_accuracy': 0.06, 'eval_loss': 6.375242710113525, 'eval_runtime': 0.5801, 'eval_samples_per_second': 86.186, 'eval_steps_per_second': 1.724}
Inversion - source: Legislativkomitee, target: legislative committee


  0%|          | 0/1 [00:00<?, ?ba/s]



{'eval_accuracy': 0.32, 'eval_loss': 3.7909417152404785, 'eval_runtime': 0.6236, 'eval_samples_per_second': 80.185, 'eval_steps_per_second': 1.604}
Relation - source: Medien-Franchise, target: media franchise


  0%|          | 0/1 [00:00<?, ?ba/s]



{'eval_accuracy': 0.04, 'eval_loss': 4.807344436645508, 'eval_runtime': 0.6113, 'eval_samples_per_second': 81.795, 'eval_steps_per_second': 1.636}
Inversion - source: Linsentyp, target: type of lens


  0%|          | 0/1 [00:00<?, ?ba/s]



{'eval_accuracy': 0.2, 'eval_loss': 5.531165599822998, 'eval_runtime': 0.6272, 'eval_samples_per_second': 79.721, 'eval_steps_per_second': 1.594}
Relation - source: Erregerreservoir von, target: natural reservoir of


  0%|          | 0/1 [00:00<?, ?ba/s]



{'eval_accuracy': 0.12, 'eval_loss': 4.2391886711120605, 'eval_runtime': 0.648, 'eval_samples_per_second': 77.158, 'eval_steps_per_second': 1.543}
Inversion - source: Gesichtshaar, target: facial hair


  0%|          | 0/1 [00:00<?, ?ba/s]



{'eval_accuracy': 0.3, 'eval_loss': 4.44975471496582, 'eval_runtime': 0.6253, 'eval_samples_per_second': 79.96, 'eval_steps_per_second': 1.599}
Relation - source: Antwort auf, target: reply to


  0%|          | 0/1 [00:00<?, ?ba/s]



{'eval_accuracy': 0.18, 'eval_loss': 3.1708714962005615, 'eval_runtime': 0.6367, 'eval_samples_per_second': 78.525, 'eval_steps_per_second': 1.57}
Inversion - source: Werkliste, target: list of works


  0%|          | 0/1 [00:00<?, ?ba/s]



{'eval_accuracy': 0.38, 'eval_loss': 3.405402421951294, 'eval_runtime': 0.6459, 'eval_samples_per_second': 77.41, 'eval_steps_per_second': 1.548}
Relation - source: Nebenwirkung, target: side effect


  0%|          | 0/1 [00:00<?, ?ba/s]



{'eval_accuracy': 0.26, 'eval_loss': 2.7538046836853027, 'eval_runtime': 0.6161, 'eval_samples_per_second': 81.154, 'eval_steps_per_second': 1.623}
Inversion - source: interagiert physikalisch mit, target: physically interacts with


  0%|          | 0/1 [00:00<?, ?ba/s]



{'eval_accuracy': 0.3, 'eval_loss': 2.077516794204712, 'eval_runtime': 0.6318, 'eval_samples_per_second': 79.141, 'eval_steps_per_second': 1.583}
Relation - source: Punkt/Treffer erzielt durch, target: points/goal scored by


  0%|          | 0/1 [00:00<?, ?ba/s]



{'eval_accuracy': 1.0, 'eval_loss': 0.1565113216638565, 'eval_runtime': 0.6337, 'eval_samples_per_second': 78.904, 'eval_steps_per_second': 1.578}
Inversion - source: Demonym zu, target: demonym of


  0%|          | 0/1 [00:00<?, ?ba/s]



{'eval_accuracy': 0.16, 'eval_loss': 3.843660354614258, 'eval_runtime': 0.6198, 'eval_samples_per_second': 80.669, 'eval_steps_per_second': 1.613}
Relation - source: Kriegseinsatz, target: conflict


  0%|          | 0/1 [00:00<?, ?ba/s]



{'eval_accuracy': 0.02, 'eval_loss': 6.092104911804199, 'eval_runtime': 0.6327, 'eval_samples_per_second': 79.032, 'eval_steps_per_second': 1.581}
Inversion - source: Kultur, target: culture


  0%|          | 0/1 [00:00<?, ?ba/s]



{'eval_accuracy': 0.32, 'eval_loss': 6.859859466552734, 'eval_runtime': 0.5991, 'eval_samples_per_second': 83.459, 'eval_steps_per_second': 1.669}


#### Evaluate
- Is every relation inverted now? What about relations that aren't part of the training?
- Pretrained?
- Target?

In [14]:
# Move the model to the CPU and switch it to evaluation mode; the returned module is displayed.
model.to('cpu').eval()

BertForMaskedLM(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(119547, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=T

In [15]:
# Preview the first 1901 entries of the training samples.
print(train_dict['sample'][:1901])

['Medina taxonomischer Rang Italie', 'Italie Notfalleinrichtungen Medina', 'Invasion taxonomischer Rang Bora', 'Bora Notfalleinrichtungen Invasion', 'Burke taxonomischer Rang Hus', 'Hus Notfalleinrichtungen Burke', 'Drama taxonomischer Rang epi', 'epi Notfalleinrichtungen Drama', 'Master taxonomischer Rang Wilfried', 'Wilfried Notfalleinrichtungen Master', 'Dari taxonomischer Rang Fach', 'Fach Notfalleinrichtungen Dari', 'Chihuahua taxonomischer Rang Inge', 'Inge Notfalleinrichtungen Chihuahua', 'EP taxonomischer Rang Elite', 'Elite Notfalleinrichtungen EP', 'Chase taxonomischer Rang Portland', 'Portland Notfalleinrichtungen Chase', 'Worcester taxonomischer Rang Eliza', 'Eliza Notfalleinrichtungen Worcester', 'Albert taxonomischer Rang Weir', 'Weir Notfalleinrichtungen Albert', 'Ibiza taxonomischer Rang Antoine', 'Antoine Notfalleinrichtungen Ibiza', 'Câmara taxonomischer Rang Universitas', 'Universitas Notfalleinrichtungen Câmara', 'Eleanor taxonomischer Rang Collins', 'Collins Notfal

In [16]:
# Preview the first 1901 entries of the test samples.
test_dict['sample'][:1901]

['SM taxon rank Tito',
 'Sempre taxon rank Jubilee',
 'Portsmouth taxon rank Nassau',
 'Irena taxon rank Nexus',
 'Oaks taxon rank Hiroshima',
 'Fighting taxon rank Jenny',
 'Bilbao taxon rank Cardoso',
 'Sharon taxon rank Rusi',
 'Allium taxon rank Silvio',
 'Haiti taxon rank Scots',
 'Isabella taxon rank Helmut',
 'Olsson taxon rank Novel',
 'Pure taxon rank KM',
 'Patrol taxon rank Industria',
 'Siege taxon rank ap',
 'Galiza taxon rank Frères',
 'Lucky taxon rank RN',
 'Pfeiffer taxon rank Gilbert',
 'FX taxon rank Bachelor',
 'Energy taxon rank GNU',
 'Joanne taxon rank UE',
 'Erik taxon rank Hamm',
 'EM taxon rank Rua',
 'Raum taxon rank Ike',
 'VL taxon rank Velvet',
 'Sabina taxon rank Levant',
 'Krupp taxon rank Spin',
 'October taxon rank Franz',
 'Davidson taxon rank SAR',
 'Allah taxon rank Taurus',
 'Larva taxon rank Denne',
 'Khu taxon rank Pizarro',
 'Lahore taxon rank Sunrise',
 'Dorset taxon rank Burlington',
 'Stay taxon rank Invasion',
 'Sit taxon rank Wiesbaden',
 '

#### -> (e, r_de, f) vs (f, s, e) (+ (f, s_de, e))

Evaluate whether, for a fact (e, r, f), the model more often knows (e, r_de, f) or (f, s, e), i.e. knowledge transfer vs. the inversion rule.
This can also help us understand which way we get (f, r_de, e).

Since training on (e, r_de, f) rarely yields (f, s_de, e), this already suggests that the model takes the path:
(e, r, f) -RULE-> (f, s, e) -KT-> (f, s_de, e)

Per relation, the first 1800 samples train the rule (900 <-> 900);
samples 1800-1900 are the facts used for testing.

In [17]:
def compute_overlap(a, b):
    """Return the multiset intersection of `a` and `b` as a list.

    An element occurs in the result as often as it occurs in both inputs,
    i.e. the minimum of its two counts.
    """
    shared = Counter(a) & Counter(b)
    return list(shared.elements())

In [18]:
# Knowledge transfer vs. inversion rule: for every relation pair, probe the test slice of the
# training samples and record which related facts the model completes correctly (top-1).
def _predicts(prompt, expected):
    """Return True if the model's top-1 token for the [MASK] in `prompt` equals `expected`.

    `expected` is mapped to an id with `tokenizer.convert_tokens_to_ids`
    (presumably every entity is a single vocabulary token -- TODO confirm).
    """
    label_token = tokenizer.convert_tokens_to_ids(expected)
    encoded_input = tokenizer(prompt, return_tensors='pt')

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        token_logits = model(**encoded_input).logits

    mask_token_index = torch.where(encoded_input["input_ids"] == tokenizer.mask_token_id)[1]
    mask_token_logits = token_logits[0, mask_token_index, :]

    # Pick the [MASK] candidate with the highest logit
    top_1_token = torch.topk(mask_token_logits, 1, dim=1).indices[0].tolist()[0]
    return label_token == top_1_token


# Iterate over relations, take the training samples that were trained on
for i in range(n_relations):
    # Samples 1800-1900 of this relation's 1900-sample block are the ones probed here.
    trained_test = train_dict['sample'][1800+i*1900:(i+1)*1900]
    n_samples = len(trained_test)

    # First entity of every sample for which the respective probe was answered correctly.
    correct_entities_s = []
    correct_entities_rde = []
    correct_entities_sde = []

    # Relation pairs: r / r_t are the 'de' / 'en' names from relations[0],
    # s / s_t the 'de' / 'en' names from relations[1].
    r = relations[0]['de'].iloc[i]
    r_t = relations[0]['en'].iloc[i]
    s = relations[1]['de'].iloc[i]
    s_t = relations[1]['en'].iloc[i]

    for sample in trained_test:
        # First and last whitespace-separated token of the sample are the two entities.
        f = sample.rsplit(' ', 1)[1]
        e = sample.split(' ', 1)[0]

        # Test (f, s, e)
        if _predicts(f + ' ' + s + ' [MASK]', e):
            correct_entities_s.append(e)

        # Test (e, r_t, f)
        if _predicts(e + ' ' + r_t + ' [MASK]', f):
            correct_entities_rde.append(e)

        # Test (f, s_t, e)
        if _predicts(f + ' ' + s_t + ' [MASK]', e):
            correct_entities_sde.append(e)

    # Normalise by the number of probed samples instead of a hard-coded 100;
    # an empty slice yields 0.0 rather than a ZeroDivisionError.
    acc_s = len(correct_entities_s) / n_samples if n_samples else 0.0
    acc_rde = len(correct_entities_rde) / n_samples if n_samples else 0.0
    acc_sde = len(correct_entities_sde) / n_samples if n_samples else 0.0

    print(f'Relation: {r}')
    print(f'Accuracy for (f, s, e): {acc_s}')
    print(f'Accuracy for (e, r_t, f): {acc_rde}')
    print(f'Accuracy for (f, s_t, e): {acc_sde}')
    # Label corrected: this list holds the (f, s, e) hits.
    print(f'Size (f, s, e): {len(correct_entities_s)}')
    print(f'Size (e, r_t, f): {len(correct_entities_rde)}')
    print(f'Overlap between (f, s, e) and (e, r_t, f): {len(compute_overlap(correct_entities_s, correct_entities_rde))}')
    if len(correct_entities_rde) == 0:
        print(f'Transfer from (e, r_t, f) to (f, s_t, e): {0}')
    else:
        print(f'Transfer from (e, r_t, f) to (f, s_t, e): {len(compute_overlap(correct_entities_rde, correct_entities_sde))/len(correct_entities_rde)}')

    if len(correct_entities_s) == 0:
        print(f'Transfer from (f, s, e) to (f, s_t, e): {0}')
    else:
        print(f'Transfer from (f, s, e) to (f, s_t, e): {len(compute_overlap(correct_entities_s, correct_entities_sde))/len(correct_entities_s)}')
    print('\n')

Relation: taxonomischer Rang
Accuracy for (f, s, e): 0.35
Accuracy for (e, r_t, f): 0.27
Accuracy for (f, s_t, e): 0.17
Size (f, r, e): 35
Size (e, r_t, f): 27
Overlap between (f, s, e) and (e, r_t, f): 6
Transfer from (e, r_t, f) to (f, s_t, e): 0.14814814814814814
Transfer from (f, s, e) to (f, s_t, e): 0.34285714285714286


Relation: Farbe
Accuracy for (f, s, e): 0.39
Accuracy for (e, r_t, f): 0.57
Accuracy for (f, s_t, e): 0.14
Size (f, r, e): 39
Size (e, r_t, f): 57
Overlap between (f, s, e) and (e, r_t, f): 37
Transfer from (e, r_t, f) to (f, s_t, e): 0.19298245614035087
Transfer from (f, s, e) to (f, s_t, e): 0.2564102564102564


Relation: gemessene physikalische Größe
Accuracy for (f, s, e): 0.39
Accuracy for (e, r_t, f): 0.11
Accuracy for (f, s_t, e): 0.07
Size (f, r, e): 39
Size (e, r_t, f): 11
Overlap between (f, s, e) and (e, r_t, f): 0
Transfer from (e, r_t, f) to (f, s_t, e): 0.0
Transfer from (f, s, e) to (f, s_t, e): 0.15384615384615385


Relation: Plattform
Accuracy fo

#### -> does inversion overgeneralize?
For (e, r, f) in train, the model may predict (f, s, e) (correct), but also (e, s, f) and (f, r, e) (incorrect),
as well as (e, s_de, f) and (f, r_de, e).


In [None]:
# Iterate over relations, take the training samples that were trained on
for i in range(n_relations):
    # Samples 1800-1900 of this relation's 1900-sample block are the ones probed here.
    trained_test = train_dict['sample'][1800+i*1900:(i+1)*1900]
    n_samples = len(trained_test)

    # Relation pairs: r / r_de are the 'en' / 'de' names from relations[0],
    # s / s_de the 'en' / 'de' names from relations[1].
    # (r previously read the 'de' column as well, which made the (f, r, e) probe
    # identical to the (f, r_de, e) probe.)
    r = relations[0]['en'].iloc[i]
    r_de = relations[0]['de'].iloc[i]
    s = relations[1]['en'].iloc[i]
    s_de = relations[1]['de'].iloc[i]

    # False positives: the higher a count, the more the model has overgeneralized.
    hits = {'s': 0, 'r': 0, 'sde': 0, 'rde': 0}

    for sample in trained_test:

        # Take entities: first and last whitespace-separated token of the sample
        e = sample.split(' ', 1)[0]
        f = sample.rsplit(' ', 1)[1]

        # (counter key, masked prompt, entity expected at the [MASK])
        probes = (
            ('s', e + ' ' + s + ' [MASK]', f),        # (e, s, f)
            ('r', f + ' ' + r + ' [MASK]', e),        # (f, r, e)
            ('sde', e + ' ' + s_de + ' [MASK]', f),   # (e, s_de, f)
            ('rde', f + ' ' + r_de + ' [MASK]', e),   # (f, r_de, e)
        )

        for key, prompt, label in probes:
            label_token = tokenizer.convert_tokens_to_ids(label)

            encoded_input = tokenizer(prompt, return_tensors='pt')
            # Inference only: skip building the autograd graph.
            with torch.no_grad():
                token_logits = model(**encoded_input).logits

            mask_token_index = torch.where(encoded_input["input_ids"] == tokenizer.mask_token_id)[1]
            mask_token_logits = token_logits[0, mask_token_index, :]

            # Pick the [MASK] candidate with the highest logit
            top_1_token = torch.topk(mask_token_logits, 1, dim=1).indices[0].tolist()[0]

            if label_token == top_1_token:
                hits[key] += 1

    # Normalise by the number of probed samples instead of a hard-coded 100;
    # an empty slice yields 0.0 rather than a ZeroDivisionError.
    acc_r = hits['r'] / n_samples if n_samples else 0.0
    acc_rde = hits['rde'] / n_samples if n_samples else 0.0
    acc_s = hits['s'] / n_samples if n_samples else 0.0
    acc_sde = hits['sde'] / n_samples if n_samples else 0.0

    # The heading keeps showing the 'de' relation name, as before the fix to `r`.
    print(f'Relation: {r_de}')
    print(f'Accuracy for (e, s, f): {acc_s}')
    print(f'Accuracy for (f, r, e): {acc_r}')
    print(f'Accuracy for (e, s_de, f): {acc_sde}')
    print(f'Accuracy for (f, r_de, e): {acc_rde}')
    print('\n')

### Manual

In [19]:
# Top-1 accuracy of predicting the last token of (at most) the first 10000 training samples.
k = 0  # number of correct top-1 predictions
total = len(train_dict['sample'])  # overall number of training samples (not used below)
i = 0  # number of evaluated samples

# Inference only: skip building the autograd graph.
with torch.no_grad():
    for txt in train_dict['sample'][:10000]:
        i += 1

        # Add [MASK] for object
        sample = txt.rsplit(' ', 1)[0] + ' [MASK]'
        label_token = tokenizer.convert_tokens_to_ids(txt.rsplit(' ', 1)[1])

        encoded_input = tokenizer(sample, return_tensors='pt')
        token_logits = model(**encoded_input).logits

        mask_token_index = torch.where(encoded_input["input_ids"] == tokenizer.mask_token_id)[1]
        mask_token_logits = token_logits[0, mask_token_index, :]

        # Keep only the single [MASK] candidate with the highest logit (k=1)
        top_1_tokens = torch.topk(mask_token_logits, 1, dim=1).indices[0].tolist()

        if label_token in top_1_tokens:
            k += 1
print(k/i)

0.9582


In [None]:
# Show the five most likely fillers for the [MASK] of a hand-written prompt.
text = "lens manner of [MASK]"
encoded_input = tokenizer(text, return_tensors='pt')

# Position of the [MASK] token along the sequence axis
mask_token_index = torch.where(encoded_input["input_ids"] == tokenizer.mask_token_id)[1]

token_logits = model(**encoded_input).logits
mask_token_logits = token_logits[0, mask_token_index, :]

# Ids of the five highest-scoring [MASK] candidates
top_5_tokens = mask_token_logits.topk(5, dim=1).indices[0].tolist()

for chunk in top_5_tokens:
    decoded = tokenizer.decode([chunk])
    print(f"\n>>> {decoded}")

In [None]:
# Print every training sample that contains the substring 'Alex'.
for t in filter(lambda candidate: 'Alex' in candidate, train_dict['sample']):
    print(t)

### Results
