In [1]:
import argparse
import copy

from transformers import BertForMaskedLM, BertTokenizer, TrainingArguments, Trainer, \
    DataCollatorForLanguageModeling, IntervalStrategy

from datasets import Dataset
import os

from data_generation_relation import *
from utils import *
from custom_trainer import CustomTrainer
from datasets import load_metric
import logging
from transformers import logging as tlogging
import wandb
import sys
from utils import set_seed
from transformers.integrations import WandbCallback, TensorBoardCallback
from tqdm.notebook import tqdm
from collections import Counter

# Switch off the Weights & Biases integration for this notebook.
# NOTE(review): the Trainer log output captured further down reports this environment
# variable as deprecated and recommends the `report_to` training argument instead.
os.environ["WANDB_DISABLED"] = "true"

In [2]:
# Fix the random seed through the project's set_seed helper before anything is generated.
set_seed(42)

# NOTE(review): the outputs stored further down in this notebook mention 250 epochs and
# paths under ./output/SYM_de_en/, i.e. they were produced with settings other than the
# ones in this cell. Restart the kernel and run all cells before relying on them.

# Run name (used to build the ./output/<run_name>/ paths) and optimisation settings
run_name = 'SYM_en_de'
epochs, batch_size = 200, 200
lr = 5e-5

# Relation type, languages and dataset sizes handed to the data generators
relation = 'symmetry'
source_language, target_language = ['en'], ['de']
n_relations, n_facts = 10, 1000

# Switches for the optional generate_random / generate_anti training data
use_random = use_anti = False

# Flags forwarded unchanged to generate_reasoning
use_pretrained = use_target = False

In [3]:
# Gather every generator argument in one mapping so the configuration is readable at a glance.
reasoning_kwargs = dict(
    relation=Relation(relation),
    source_language=source_language,
    target_language=target_language,
    n_relations=n_relations,
    n_facts=n_facts,
    use_pretrained=use_pretrained,
    use_target=use_target,
    use_enhanced=False,
    use_same_relations=False,
    n_pairs=0,
)

# The generator returns the training samples, the test structure and the relations table.
train, test, relations = generate_reasoning(**reasoning_kwargs)

# Display the returned relations as the cell output
relations

Unnamed: 0,id,en,de,es,fr,count
694,P105,taxon rank,taxonomischer Rang,categoría taxonómica,rang taxinomique,3580266
528,P6855,emergency services,Notfalleinrichtungen,servicios de emergencia,accueil et traitement des urgences,766
598,P462,color,Farbe,color,couleur,194389
606,P2429,expected completeness,erwartete Vollständigkeit,grado de completitud,degré de complétude,3826
120,P111,measured physical quantity,gemessene physikalische Größe,cantidad física medida,grandeur physique mesurée,3610


In [4]:
# Relations used for the random facts; stays an empty list when `use_random` is off.
relations_random = []

if use_random:
    # Generate half/half: `factor` scales the number of random facts relative to n_facts.
    factor = 1.0
    # `factor` is a float, so the plain product would be a float (e.g. 1000.0) although it
    # denotes a number of facts; cast it so the generator receives an integer count.
    n_random = int(factor * n_facts)

    train_random, relations_random = generate_random(source_language, target_language, n_random, n_relations)
    # NOTE(review): assuming `train` is a list, this extends it in place, so re-running
    # this cell without re-running the generation cell appends random facts a second time.
    train += train_random

# Display the relations of the random facts as the cell output
relations_random

[]

In [5]:
# Relations returned by generate_anti; remains an empty list when the branch below is skipped.
relations_anti = []

# Anti facts are only generated for the symmetry relation and only when explicitly enabled.
add_anti_facts = relation == 'symmetry' and use_anti
if add_anti_facts:
    anti_kwargs = dict(
        relations_symmetric=relations,
        source_lang=source_language,
        target_lang=target_language,
        n_relations=n_relations,
        n_facts=n_facts,
    )
    train_anti, test_anti, relations_anti = generate_anti(**anti_kwargs)
    # The anti training samples are appended to the existing training samples.
    train += train_anti

# Display the anti relations as the cell output
relations_anti

[]

In [25]:
# LOADING
# Tokenizer and masked-LM model are both loaded from the same multilingual BERT checkpoint.
checkpoint_name = 'bert-base-multilingual-cased'
tokenizer = BertTokenizer.from_pretrained(checkpoint_name)
model = BertForMaskedLM.from_pretrained(checkpoint_name)

# Collator used for training: masked-language-modelling with a masking probability of 0.15.
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm_probability=0.15,
)

# Collator used for evaluation: MLM masking is switched off.
eval_data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False,
)

loading file https://huggingface.co/bert-base-multilingual-cased/resolve/main/vocab.txt from cache at /home/laurin/.cache/huggingface/transformers/eff018e45de5364a8368df1f2df3461d506e2a111e9dd50af1fae061cd460ead.6c5b6600e968f4b5e08c86d8891ea99e51537fc2bf251435fb46922e8f7a7b29
loading file https://huggingface.co/bert-base-multilingual-cased/resolve/main/added_tokens.json from cache at None
loading file https://huggingface.co/bert-base-multilingual-cased/resolve/main/special_tokens_map.json from cache at None
loading file https://huggingface.co/bert-base-multilingual-cased/resolve/main/tokenizer_config.json from cache at /home/laurin/.cache/huggingface/transformers/f55e7a2ad4f8d0fff2733b3f79777e1e99247f2e4583703e92ce74453af8c235.ec5c189f89475aac7d8cbd243960a0655cfadc3d0474da8ff2ed0bf1699c2a5f
loading configuration file https://huggingface.co/bert-base-multilingual-cased/resolve/main/config.json from cache at /home/laurin/.cache/huggingface/transformers/6c4a5d81a58c9791cdf76a09bce1b5abfb9

In [26]:
# ~~ PRE-PROCESSING ~~
# Training split: put the samples into a single 'sample' column and tokenize them.
train_dict = dict(sample=train)
train_ds = Dataset.from_dict(train_dict)
tokenized_train = tokenize(tokenizer, train_ds)  # Train is shuffled by Huggingface

# Test split: flatten a deep copy of `test` (the original object is passed on untouched
# to later cells), then wrap and tokenize it the same way as the training split.
test_samples = flatten_dict2_list(copy.deepcopy(test))
test_dict = dict(sample=test_samples)
test_ds = Dataset.from_dict(test_dict)
tokenized_test = tokenize(tokenizer, test_ds)

  0%|          | 0/10 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

In [27]:
# Save Train and Test Data
# Three views are persisted: the training samples, the complete (nested) test structure
# and the flattened test samples that were tokenized for evaluation.
train_df = pd.DataFrame(train_dict)
test_complete_df = pd.DataFrame(test)
test_flat_df = pd.DataFrame(test_dict)

# All data files of this run go to ./output/<run_name>/data/
data_dir = './output/' + run_name + '/data/'
# exist_ok=True creates the directory when missing and is a no-op when it already exists,
# replacing the separate existence check.
os.makedirs(data_dir, exist_ok=True)

train_df.to_csv(data_dir + 'train_set', index=False)
test_complete_df.to_json(data_dir + 'test_set_complete')
test_flat_df.to_csv(data_dir + 'test_set', index=False)

# The random facts only exist when the random-data cell ran its branch.
if use_random:
    train_random_df = pd.DataFrame({'sample': train_random})
    train_random_df.to_csv(data_dir + 'train_random', index=False)

# train_anti / test_anti are only defined when the anti cell ran its branch, which requires
# BOTH relation == 'symmetry' and use_anti. Checking use_anti alone would raise a NameError
# for any other relation, so the same condition is used here.
if relation == 'symmetry' and use_anti:
    train_anti_df = pd.DataFrame({'sample': train_anti})
    test_anti_df = pd.DataFrame({'sample': test_anti})

    train_anti_df.to_csv(data_dir + 'train_anti_set', index=False)
    test_anti_df.to_json(data_dir + 'test_anti_set')

In [28]:
# Checkpoints and logs of this run are written below one common directory.
run_output_dir = './output/' + run_name

# Logging, evaluation and checkpointing all happen once per epoch.
per_epoch = IntervalStrategy.EPOCH

# NOTE(review): `report_to` is not set; the Trainer log output warns that its default will
# change in v5 - consider setting it explicitly.
training_args = TrainingArguments(
    seed=42,
    # Where checkpoints and logs are written
    output_dir=run_output_dir + '/models/',
    logging_dir=run_output_dir + '/tb_logs/',
    # Optimisation
    num_train_epochs=epochs,
    learning_rate=lr,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=128,
    # Per-epoch schedule
    logging_strategy=per_epoch,
    evaluation_strategy=per_epoch,
    save_strategy=per_epoch,
    # Keep at most two checkpoints on disk and reload the one with the best 'accuracy' at the end
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model='accuracy',
)

# NOTE(review): eval_dataset is the tokenized test split, so the per-epoch evaluation and
# the best-checkpoint selection above presumably both use test data - confirm this is intended.
trainer = CustomTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_test,
    data_collator=data_collator,
    eval_data_collator=eval_data_collator,
    compute_metrics=precision_at_one,
)


PyTorch: setting up devices
The default value for the training argument `--report_to` will change in v5 (from all installed integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as now. You should start updating your code and make this info disappear :-).
Using the `WAND_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


In [29]:
# Train
# Runs the training loop of the trainer configured above (logging, evaluation and
# checkpointing are set to once per epoch). The call is the last expression of the cell,
# so the returned TrainOutput is shown as the cell output.
trainer.train()

***** Running training *****
  Num examples = 9500
  Num Epochs = 250
  Instantaneous batch size per device = 200
  Total train batch size (w. parallel, distributed & accumulation) = 400
  Gradient Accumulation steps = 1
  Total optimization steps = 6000


Epoch,Training Loss,Validation Loss,Accuracy
1,3.29,9.479519,0.0
2,2.5194,8.962927,0.0
3,2.4295,8.656976,0.0
4,2.4045,8.506996,0.0
5,2.2991,8.332901,0.0
6,2.299,8.271156,0.0
7,2.3158,8.220499,0.0
8,2.2671,8.121373,0.0
9,2.2279,8.168249,0.0
10,2.1997,8.083574,0.0


Saving model checkpoint to ./output/SYM_de_en/models/checkpoint-24
Configuration saved in ./output/SYM_de_en/models/checkpoint-24/config.json
Model weights saved in ./output/SYM_de_en/models/checkpoint-24/pytorch_model.bin
tokenizer config file saved in ./output/SYM_de_en/models/checkpoint-24/tokenizer_config.json
Special tokens file saved in ./output/SYM_de_en/models/checkpoint-24/special_tokens_map.json
Saving model checkpoint to ./output/SYM_de_en/models/checkpoint-48
Configuration saved in ./output/SYM_de_en/models/checkpoint-48/config.json
Model weights saved in ./output/SYM_de_en/models/checkpoint-48/pytorch_model.bin
tokenizer config file saved in ./output/SYM_de_en/models/checkpoint-48/tokenizer_config.json
Special tokens file saved in ./output/SYM_de_en/models/checkpoint-48/special_tokens_map.json
Saving model checkpoint to ./output/SYM_de_en/models/checkpoint-72
Configuration saved in ./output/SYM_de_en/models/checkpoint-72/config.json
Model weights saved in ./output/SYM_de_e

Deleting older checkpoint [output/SYM_de_en/models/checkpoint-240] due to args.save_total_limit
Saving model checkpoint to ./output/SYM_de_en/models/checkpoint-288
Configuration saved in ./output/SYM_de_en/models/checkpoint-288/config.json
Model weights saved in ./output/SYM_de_en/models/checkpoint-288/pytorch_model.bin
tokenizer config file saved in ./output/SYM_de_en/models/checkpoint-288/tokenizer_config.json
Special tokens file saved in ./output/SYM_de_en/models/checkpoint-288/special_tokens_map.json
Deleting older checkpoint [output/SYM_de_en/models/checkpoint-264] due to args.save_total_limit
Saving model checkpoint to ./output/SYM_de_en/models/checkpoint-312
Configuration saved in ./output/SYM_de_en/models/checkpoint-312/config.json
Model weights saved in ./output/SYM_de_en/models/checkpoint-312/pytorch_model.bin
tokenizer config file saved in ./output/SYM_de_en/models/checkpoint-312/tokenizer_config.json
Special tokens file saved in ./output/SYM_de_en/models/checkpoint-312/spec

Saving model checkpoint to ./output/SYM_de_en/models/checkpoint-528
Configuration saved in ./output/SYM_de_en/models/checkpoint-528/config.json
Model weights saved in ./output/SYM_de_en/models/checkpoint-528/pytorch_model.bin
tokenizer config file saved in ./output/SYM_de_en/models/checkpoint-528/tokenizer_config.json
Special tokens file saved in ./output/SYM_de_en/models/checkpoint-528/special_tokens_map.json
Deleting older checkpoint [output/SYM_de_en/models/checkpoint-480] due to args.save_total_limit
Saving model checkpoint to ./output/SYM_de_en/models/checkpoint-552
Configuration saved in ./output/SYM_de_en/models/checkpoint-552/config.json
Model weights saved in ./output/SYM_de_en/models/checkpoint-552/pytorch_model.bin
tokenizer config file saved in ./output/SYM_de_en/models/checkpoint-552/tokenizer_config.json
Special tokens file saved in ./output/SYM_de_en/models/checkpoint-552/special_tokens_map.json
Deleting older checkpoint [output/SYM_de_en/models/checkpoint-504] due to ar

Model weights saved in ./output/SYM_de_en/models/checkpoint-768/pytorch_model.bin
tokenizer config file saved in ./output/SYM_de_en/models/checkpoint-768/tokenizer_config.json
Special tokens file saved in ./output/SYM_de_en/models/checkpoint-768/special_tokens_map.json
Deleting older checkpoint [output/SYM_de_en/models/checkpoint-720] due to args.save_total_limit
Saving model checkpoint to ./output/SYM_de_en/models/checkpoint-792
Configuration saved in ./output/SYM_de_en/models/checkpoint-792/config.json
Model weights saved in ./output/SYM_de_en/models/checkpoint-792/pytorch_model.bin
tokenizer config file saved in ./output/SYM_de_en/models/checkpoint-792/tokenizer_config.json
Special tokens file saved in ./output/SYM_de_en/models/checkpoint-792/special_tokens_map.json
Deleting older checkpoint [output/SYM_de_en/models/checkpoint-744] due to args.save_total_limit
Saving model checkpoint to ./output/SYM_de_en/models/checkpoint-816
Configuration saved in ./output/SYM_de_en/models/checkpo

Special tokens file saved in ./output/SYM_de_en/models/checkpoint-1008/special_tokens_map.json
Deleting older checkpoint [output/SYM_de_en/models/checkpoint-960] due to args.save_total_limit
Saving model checkpoint to ./output/SYM_de_en/models/checkpoint-1032
Configuration saved in ./output/SYM_de_en/models/checkpoint-1032/config.json
Model weights saved in ./output/SYM_de_en/models/checkpoint-1032/pytorch_model.bin
tokenizer config file saved in ./output/SYM_de_en/models/checkpoint-1032/tokenizer_config.json
Special tokens file saved in ./output/SYM_de_en/models/checkpoint-1032/special_tokens_map.json
Deleting older checkpoint [output/SYM_de_en/models/checkpoint-984] due to args.save_total_limit
Saving model checkpoint to ./output/SYM_de_en/models/checkpoint-1056
Configuration saved in ./output/SYM_de_en/models/checkpoint-1056/config.json
Model weights saved in ./output/SYM_de_en/models/checkpoint-1056/pytorch_model.bin
tokenizer config file saved in ./output/SYM_de_en/models/checkpoi

Deleting older checkpoint [output/SYM_de_en/models/checkpoint-1200] due to args.save_total_limit
Saving model checkpoint to ./output/SYM_de_en/models/checkpoint-1272
Configuration saved in ./output/SYM_de_en/models/checkpoint-1272/config.json
Model weights saved in ./output/SYM_de_en/models/checkpoint-1272/pytorch_model.bin
tokenizer config file saved in ./output/SYM_de_en/models/checkpoint-1272/tokenizer_config.json
Special tokens file saved in ./output/SYM_de_en/models/checkpoint-1272/special_tokens_map.json
Deleting older checkpoint [output/SYM_de_en/models/checkpoint-1248] due to args.save_total_limit
Saving model checkpoint to ./output/SYM_de_en/models/checkpoint-1296
Configuration saved in ./output/SYM_de_en/models/checkpoint-1296/config.json
Model weights saved in ./output/SYM_de_en/models/checkpoint-1296/pytorch_model.bin
tokenizer config file saved in ./output/SYM_de_en/models/checkpoint-1296/tokenizer_config.json
Special tokens file saved in ./output/SYM_de_en/models/checkpoi

Saving model checkpoint to ./output/SYM_de_en/models/checkpoint-1512
Configuration saved in ./output/SYM_de_en/models/checkpoint-1512/config.json
Model weights saved in ./output/SYM_de_en/models/checkpoint-1512/pytorch_model.bin
tokenizer config file saved in ./output/SYM_de_en/models/checkpoint-1512/tokenizer_config.json
Special tokens file saved in ./output/SYM_de_en/models/checkpoint-1512/special_tokens_map.json
Deleting older checkpoint [output/SYM_de_en/models/checkpoint-1464] due to args.save_total_limit
Saving model checkpoint to ./output/SYM_de_en/models/checkpoint-1536
Configuration saved in ./output/SYM_de_en/models/checkpoint-1536/config.json
Model weights saved in ./output/SYM_de_en/models/checkpoint-1536/pytorch_model.bin
tokenizer config file saved in ./output/SYM_de_en/models/checkpoint-1536/tokenizer_config.json
Special tokens file saved in ./output/SYM_de_en/models/checkpoint-1536/special_tokens_map.json
Deleting older checkpoint [output/SYM_de_en/models/checkpoint-148

Saving model checkpoint to ./output/SYM_de_en/models/checkpoint-1752
Configuration saved in ./output/SYM_de_en/models/checkpoint-1752/config.json
Model weights saved in ./output/SYM_de_en/models/checkpoint-1752/pytorch_model.bin
tokenizer config file saved in ./output/SYM_de_en/models/checkpoint-1752/tokenizer_config.json
Special tokens file saved in ./output/SYM_de_en/models/checkpoint-1752/special_tokens_map.json
Deleting older checkpoint [output/SYM_de_en/models/checkpoint-1704] due to args.save_total_limit
Saving model checkpoint to ./output/SYM_de_en/models/checkpoint-1776
Configuration saved in ./output/SYM_de_en/models/checkpoint-1776/config.json
Model weights saved in ./output/SYM_de_en/models/checkpoint-1776/pytorch_model.bin
tokenizer config file saved in ./output/SYM_de_en/models/checkpoint-1776/tokenizer_config.json
Special tokens file saved in ./output/SYM_de_en/models/checkpoint-1776/special_tokens_map.json
Deleting older checkpoint [output/SYM_de_en/models/checkpoint-172

Configuration saved in ./output/SYM_de_en/models/checkpoint-1992/config.json
Model weights saved in ./output/SYM_de_en/models/checkpoint-1992/pytorch_model.bin
tokenizer config file saved in ./output/SYM_de_en/models/checkpoint-1992/tokenizer_config.json
Special tokens file saved in ./output/SYM_de_en/models/checkpoint-1992/special_tokens_map.json
Deleting older checkpoint [output/SYM_de_en/models/checkpoint-1944] due to args.save_total_limit
Saving model checkpoint to ./output/SYM_de_en/models/checkpoint-2016
Configuration saved in ./output/SYM_de_en/models/checkpoint-2016/config.json
Model weights saved in ./output/SYM_de_en/models/checkpoint-2016/pytorch_model.bin
tokenizer config file saved in ./output/SYM_de_en/models/checkpoint-2016/tokenizer_config.json
Special tokens file saved in ./output/SYM_de_en/models/checkpoint-2016/special_tokens_map.json
Deleting older checkpoint [output/SYM_de_en/models/checkpoint-1968] due to args.save_total_limit
Saving model checkpoint to ./output/S

Model weights saved in ./output/SYM_de_en/models/checkpoint-2232/pytorch_model.bin
tokenizer config file saved in ./output/SYM_de_en/models/checkpoint-2232/tokenizer_config.json
Special tokens file saved in ./output/SYM_de_en/models/checkpoint-2232/special_tokens_map.json
Deleting older checkpoint [output/SYM_de_en/models/checkpoint-2184] due to args.save_total_limit
Saving model checkpoint to ./output/SYM_de_en/models/checkpoint-2256
Configuration saved in ./output/SYM_de_en/models/checkpoint-2256/config.json
Model weights saved in ./output/SYM_de_en/models/checkpoint-2256/pytorch_model.bin
tokenizer config file saved in ./output/SYM_de_en/models/checkpoint-2256/tokenizer_config.json
Special tokens file saved in ./output/SYM_de_en/models/checkpoint-2256/special_tokens_map.json
Deleting older checkpoint [output/SYM_de_en/models/checkpoint-2208] due to args.save_total_limit
Saving model checkpoint to ./output/SYM_de_en/models/checkpoint-2280
Configuration saved in ./output/SYM_de_en/mod

tokenizer config file saved in ./output/SYM_de_en/models/checkpoint-2472/tokenizer_config.json
Special tokens file saved in ./output/SYM_de_en/models/checkpoint-2472/special_tokens_map.json
Deleting older checkpoint [output/SYM_de_en/models/checkpoint-2448] due to args.save_total_limit
Saving model checkpoint to ./output/SYM_de_en/models/checkpoint-2496
Configuration saved in ./output/SYM_de_en/models/checkpoint-2496/config.json
Model weights saved in ./output/SYM_de_en/models/checkpoint-2496/pytorch_model.bin
tokenizer config file saved in ./output/SYM_de_en/models/checkpoint-2496/tokenizer_config.json
Special tokens file saved in ./output/SYM_de_en/models/checkpoint-2496/special_tokens_map.json
Deleting older checkpoint [output/SYM_de_en/models/checkpoint-2472] due to args.save_total_limit
Saving model checkpoint to ./output/SYM_de_en/models/checkpoint-2520
Configuration saved in ./output/SYM_de_en/models/checkpoint-2520/config.json
Model weights saved in ./output/SYM_de_en/models/ch

Special tokens file saved in ./output/SYM_de_en/models/checkpoint-2712/special_tokens_map.json
Deleting older checkpoint [output/SYM_de_en/models/checkpoint-2688] due to args.save_total_limit
Saving model checkpoint to ./output/SYM_de_en/models/checkpoint-2736
Configuration saved in ./output/SYM_de_en/models/checkpoint-2736/config.json
Model weights saved in ./output/SYM_de_en/models/checkpoint-2736/pytorch_model.bin
tokenizer config file saved in ./output/SYM_de_en/models/checkpoint-2736/tokenizer_config.json
Special tokens file saved in ./output/SYM_de_en/models/checkpoint-2736/special_tokens_map.json
Deleting older checkpoint [output/SYM_de_en/models/checkpoint-2712] due to args.save_total_limit
Saving model checkpoint to ./output/SYM_de_en/models/checkpoint-2760
Configuration saved in ./output/SYM_de_en/models/checkpoint-2760/config.json
Model weights saved in ./output/SYM_de_en/models/checkpoint-2760/pytorch_model.bin
tokenizer config file saved in ./output/SYM_de_en/models/checkp

Deleting older checkpoint [output/SYM_de_en/models/checkpoint-2928] due to args.save_total_limit
Saving model checkpoint to ./output/SYM_de_en/models/checkpoint-2976
Configuration saved in ./output/SYM_de_en/models/checkpoint-2976/config.json
Model weights saved in ./output/SYM_de_en/models/checkpoint-2976/pytorch_model.bin
tokenizer config file saved in ./output/SYM_de_en/models/checkpoint-2976/tokenizer_config.json
Special tokens file saved in ./output/SYM_de_en/models/checkpoint-2976/special_tokens_map.json
Deleting older checkpoint [output/SYM_de_en/models/checkpoint-2952] due to args.save_total_limit
Saving model checkpoint to ./output/SYM_de_en/models/checkpoint-3000
Configuration saved in ./output/SYM_de_en/models/checkpoint-3000/config.json
Model weights saved in ./output/SYM_de_en/models/checkpoint-3000/pytorch_model.bin
tokenizer config file saved in ./output/SYM_de_en/models/checkpoint-3000/tokenizer_config.json
Special tokens file saved in ./output/SYM_de_en/models/checkpoi

Saving model checkpoint to ./output/SYM_de_en/models/checkpoint-3216
Configuration saved in ./output/SYM_de_en/models/checkpoint-3216/config.json
Model weights saved in ./output/SYM_de_en/models/checkpoint-3216/pytorch_model.bin
tokenizer config file saved in ./output/SYM_de_en/models/checkpoint-3216/tokenizer_config.json
Special tokens file saved in ./output/SYM_de_en/models/checkpoint-3216/special_tokens_map.json
Deleting older checkpoint [output/SYM_de_en/models/checkpoint-3192] due to args.save_total_limit
Saving model checkpoint to ./output/SYM_de_en/models/checkpoint-3240
Configuration saved in ./output/SYM_de_en/models/checkpoint-3240/config.json
Model weights saved in ./output/SYM_de_en/models/checkpoint-3240/pytorch_model.bin
tokenizer config file saved in ./output/SYM_de_en/models/checkpoint-3240/tokenizer_config.json
Special tokens file saved in ./output/SYM_de_en/models/checkpoint-3240/special_tokens_map.json
Deleting older checkpoint [output/SYM_de_en/models/checkpoint-321

Saving model checkpoint to ./output/SYM_de_en/models/checkpoint-3456
Configuration saved in ./output/SYM_de_en/models/checkpoint-3456/config.json
Model weights saved in ./output/SYM_de_en/models/checkpoint-3456/pytorch_model.bin
tokenizer config file saved in ./output/SYM_de_en/models/checkpoint-3456/tokenizer_config.json
Special tokens file saved in ./output/SYM_de_en/models/checkpoint-3456/special_tokens_map.json
Deleting older checkpoint [output/SYM_de_en/models/checkpoint-3432] due to args.save_total_limit
Saving model checkpoint to ./output/SYM_de_en/models/checkpoint-3480
Configuration saved in ./output/SYM_de_en/models/checkpoint-3480/config.json
Model weights saved in ./output/SYM_de_en/models/checkpoint-3480/pytorch_model.bin
tokenizer config file saved in ./output/SYM_de_en/models/checkpoint-3480/tokenizer_config.json
Special tokens file saved in ./output/SYM_de_en/models/checkpoint-3480/special_tokens_map.json
Deleting older checkpoint [output/SYM_de_en/models/checkpoint-345

Configuration saved in ./output/SYM_de_en/models/checkpoint-3696/config.json
Model weights saved in ./output/SYM_de_en/models/checkpoint-3696/pytorch_model.bin
tokenizer config file saved in ./output/SYM_de_en/models/checkpoint-3696/tokenizer_config.json
Special tokens file saved in ./output/SYM_de_en/models/checkpoint-3696/special_tokens_map.json
Deleting older checkpoint [output/SYM_de_en/models/checkpoint-3672] due to args.save_total_limit
Saving model checkpoint to ./output/SYM_de_en/models/checkpoint-3720
Configuration saved in ./output/SYM_de_en/models/checkpoint-3720/config.json
Model weights saved in ./output/SYM_de_en/models/checkpoint-3720/pytorch_model.bin
tokenizer config file saved in ./output/SYM_de_en/models/checkpoint-3720/tokenizer_config.json
Special tokens file saved in ./output/SYM_de_en/models/checkpoint-3720/special_tokens_map.json
Deleting older checkpoint [output/SYM_de_en/models/checkpoint-3696] due to args.save_total_limit
Saving model checkpoint to ./output/S

Model weights saved in ./output/SYM_de_en/models/checkpoint-3936/pytorch_model.bin
tokenizer config file saved in ./output/SYM_de_en/models/checkpoint-3936/tokenizer_config.json
Special tokens file saved in ./output/SYM_de_en/models/checkpoint-3936/special_tokens_map.json
Deleting older checkpoint [output/SYM_de_en/models/checkpoint-3912] due to args.save_total_limit
Saving model checkpoint to ./output/SYM_de_en/models/checkpoint-3960
Configuration saved in ./output/SYM_de_en/models/checkpoint-3960/config.json
Model weights saved in ./output/SYM_de_en/models/checkpoint-3960/pytorch_model.bin
tokenizer config file saved in ./output/SYM_de_en/models/checkpoint-3960/tokenizer_config.json
Special tokens file saved in ./output/SYM_de_en/models/checkpoint-3960/special_tokens_map.json
Deleting older checkpoint [output/SYM_de_en/models/checkpoint-3936] due to args.save_total_limit
Saving model checkpoint to ./output/SYM_de_en/models/checkpoint-3984
Configuration saved in ./output/SYM_de_en/mod

tokenizer config file saved in ./output/SYM_de_en/models/checkpoint-4176/tokenizer_config.json
Special tokens file saved in ./output/SYM_de_en/models/checkpoint-4176/special_tokens_map.json
Deleting older checkpoint [output/SYM_de_en/models/checkpoint-4152] due to args.save_total_limit
Saving model checkpoint to ./output/SYM_de_en/models/checkpoint-4200
Configuration saved in ./output/SYM_de_en/models/checkpoint-4200/config.json
Model weights saved in ./output/SYM_de_en/models/checkpoint-4200/pytorch_model.bin
tokenizer config file saved in ./output/SYM_de_en/models/checkpoint-4200/tokenizer_config.json
Special tokens file saved in ./output/SYM_de_en/models/checkpoint-4200/special_tokens_map.json
Deleting older checkpoint [output/SYM_de_en/models/checkpoint-4176] due to args.save_total_limit
Saving model checkpoint to ./output/SYM_de_en/models/checkpoint-4224
Configuration saved in ./output/SYM_de_en/models/checkpoint-4224/config.json
Model weights saved in ./output/SYM_de_en/models/ch

Special tokens file saved in ./output/SYM_de_en/models/checkpoint-4416/special_tokens_map.json
Deleting older checkpoint [output/SYM_de_en/models/checkpoint-4392] due to args.save_total_limit
Saving model checkpoint to ./output/SYM_de_en/models/checkpoint-4440
Configuration saved in ./output/SYM_de_en/models/checkpoint-4440/config.json
Model weights saved in ./output/SYM_de_en/models/checkpoint-4440/pytorch_model.bin
tokenizer config file saved in ./output/SYM_de_en/models/checkpoint-4440/tokenizer_config.json
Special tokens file saved in ./output/SYM_de_en/models/checkpoint-4440/special_tokens_map.json
Deleting older checkpoint [output/SYM_de_en/models/checkpoint-4416] due to args.save_total_limit
Saving model checkpoint to ./output/SYM_de_en/models/checkpoint-4464
Configuration saved in ./output/SYM_de_en/models/checkpoint-4464/config.json
Model weights saved in ./output/SYM_de_en/models/checkpoint-4464/pytorch_model.bin
tokenizer config file saved in ./output/SYM_de_en/models/checkp

Deleting older checkpoint [output/SYM_de_en/models/checkpoint-4608] due to args.save_total_limit
Saving model checkpoint to ./output/SYM_de_en/models/checkpoint-4680
Configuration saved in ./output/SYM_de_en/models/checkpoint-4680/config.json
Model weights saved in ./output/SYM_de_en/models/checkpoint-4680/pytorch_model.bin
tokenizer config file saved in ./output/SYM_de_en/models/checkpoint-4680/tokenizer_config.json
Special tokens file saved in ./output/SYM_de_en/models/checkpoint-4680/special_tokens_map.json
Deleting older checkpoint [output/SYM_de_en/models/checkpoint-4656] due to args.save_total_limit
Saving model checkpoint to ./output/SYM_de_en/models/checkpoint-4704
Configuration saved in ./output/SYM_de_en/models/checkpoint-4704/config.json
Model weights saved in ./output/SYM_de_en/models/checkpoint-4704/pytorch_model.bin
tokenizer config file saved in ./output/SYM_de_en/models/checkpoint-4704/tokenizer_config.json
Special tokens file saved in ./output/SYM_de_en/models/checkpoi

Saving model checkpoint to ./output/SYM_de_en/models/checkpoint-4920
Configuration saved in ./output/SYM_de_en/models/checkpoint-4920/config.json
Model weights saved in ./output/SYM_de_en/models/checkpoint-4920/pytorch_model.bin
tokenizer config file saved in ./output/SYM_de_en/models/checkpoint-4920/tokenizer_config.json
Special tokens file saved in ./output/SYM_de_en/models/checkpoint-4920/special_tokens_map.json
Deleting older checkpoint [output/SYM_de_en/models/checkpoint-4896] due to args.save_total_limit
Saving model checkpoint to ./output/SYM_de_en/models/checkpoint-4944
Configuration saved in ./output/SYM_de_en/models/checkpoint-4944/config.json
Model weights saved in ./output/SYM_de_en/models/checkpoint-4944/pytorch_model.bin
tokenizer config file saved in ./output/SYM_de_en/models/checkpoint-4944/tokenizer_config.json
Special tokens file saved in ./output/SYM_de_en/models/checkpoint-4944/special_tokens_map.json
Deleting older checkpoint [output/SYM_de_en/models/checkpoint-492

Saving model checkpoint to ./output/SYM_de_en/models/checkpoint-5160
Configuration saved in ./output/SYM_de_en/models/checkpoint-5160/config.json
Model weights saved in ./output/SYM_de_en/models/checkpoint-5160/pytorch_model.bin
tokenizer config file saved in ./output/SYM_de_en/models/checkpoint-5160/tokenizer_config.json
Special tokens file saved in ./output/SYM_de_en/models/checkpoint-5160/special_tokens_map.json
Deleting older checkpoint [output/SYM_de_en/models/checkpoint-5136] due to args.save_total_limit
Saving model checkpoint to ./output/SYM_de_en/models/checkpoint-5184
Configuration saved in ./output/SYM_de_en/models/checkpoint-5184/config.json
Model weights saved in ./output/SYM_de_en/models/checkpoint-5184/pytorch_model.bin
tokenizer config file saved in ./output/SYM_de_en/models/checkpoint-5184/tokenizer_config.json
Special tokens file saved in ./output/SYM_de_en/models/checkpoint-5184/special_tokens_map.json
Deleting older checkpoint [output/SYM_de_en/models/checkpoint-516

Configuration saved in ./output/SYM_de_en/models/checkpoint-5400/config.json
Model weights saved in ./output/SYM_de_en/models/checkpoint-5400/pytorch_model.bin
tokenizer config file saved in ./output/SYM_de_en/models/checkpoint-5400/tokenizer_config.json
Special tokens file saved in ./output/SYM_de_en/models/checkpoint-5400/special_tokens_map.json
Deleting older checkpoint [output/SYM_de_en/models/checkpoint-5376] due to args.save_total_limit
Saving model checkpoint to ./output/SYM_de_en/models/checkpoint-5424
Configuration saved in ./output/SYM_de_en/models/checkpoint-5424/config.json
Model weights saved in ./output/SYM_de_en/models/checkpoint-5424/pytorch_model.bin
tokenizer config file saved in ./output/SYM_de_en/models/checkpoint-5424/tokenizer_config.json
Special tokens file saved in ./output/SYM_de_en/models/checkpoint-5424/special_tokens_map.json
Deleting older checkpoint [output/SYM_de_en/models/checkpoint-5400] due to args.save_total_limit
Saving model checkpoint to ./output/S

Model weights saved in ./output/SYM_de_en/models/checkpoint-5640/pytorch_model.bin
tokenizer config file saved in ./output/SYM_de_en/models/checkpoint-5640/tokenizer_config.json
Special tokens file saved in ./output/SYM_de_en/models/checkpoint-5640/special_tokens_map.json
Deleting older checkpoint [output/SYM_de_en/models/checkpoint-5616] due to args.save_total_limit
Saving model checkpoint to ./output/SYM_de_en/models/checkpoint-5664
Configuration saved in ./output/SYM_de_en/models/checkpoint-5664/config.json
Model weights saved in ./output/SYM_de_en/models/checkpoint-5664/pytorch_model.bin
tokenizer config file saved in ./output/SYM_de_en/models/checkpoint-5664/tokenizer_config.json
Special tokens file saved in ./output/SYM_de_en/models/checkpoint-5664/special_tokens_map.json
Deleting older checkpoint [output/SYM_de_en/models/checkpoint-5640] due to args.save_total_limit
Saving model checkpoint to ./output/SYM_de_en/models/checkpoint-5688
Configuration saved in ./output/SYM_de_en/mod

tokenizer config file saved in ./output/SYM_de_en/models/checkpoint-5880/tokenizer_config.json
Special tokens file saved in ./output/SYM_de_en/models/checkpoint-5880/special_tokens_map.json
Deleting older checkpoint [output/SYM_de_en/models/checkpoint-5856] due to args.save_total_limit
Saving model checkpoint to ./output/SYM_de_en/models/checkpoint-5904
Configuration saved in ./output/SYM_de_en/models/checkpoint-5904/config.json
Model weights saved in ./output/SYM_de_en/models/checkpoint-5904/pytorch_model.bin
tokenizer config file saved in ./output/SYM_de_en/models/checkpoint-5904/tokenizer_config.json
Special tokens file saved in ./output/SYM_de_en/models/checkpoint-5904/special_tokens_map.json
Deleting older checkpoint [output/SYM_de_en/models/checkpoint-5880] due to args.save_total_limit
Saving model checkpoint to ./output/SYM_de_en/models/checkpoint-5928
Configuration saved in ./output/SYM_de_en/models/checkpoint-5928/config.json
Model weights saved in ./output/SYM_de_en/models/ch

TrainOutput(global_step=6000, training_loss=0.7537971583207448, metrics={'train_runtime': 6682.0718, 'train_samples_per_second': 355.429, 'train_steps_per_second': 0.898, 'total_flos': 1.4666272425e+16, 'train_loss': 0.7537971583207448, 'epoch': 250.0})

In [30]:
# Evaluate the model on the tokenized test set. The call is the cell's last
# expression, so the returned metrics dict is shown as the cell output.
trainer.evaluate(eval_dataset=tokenized_test)



{'eval_accuracy': 0.588,
 'eval_loss': 1.763461947441101,
 'eval_runtime': 0.9676,
 'eval_samples_per_second': 516.753,
 'eval_steps_per_second': 2.067,
 'epoch': 250.0}

In [31]:
# Evaluate symmetry separately for every relation (one metrics dict is printed
# per relation). A deep copy of `test` is handed over, presumably so the helper
# can modify its input without affecting `test` in later cells — TODO confirm.
evaluation_symmetry(trainer, tokenizer, relations, source_language, copy.deepcopy(test))

Relation - source: taxonomischer Rang, target: taxon rank


  0%|          | 0/1 [00:00<?, ?ba/s]



{'eval_accuracy': 0.63, 'eval_loss': 1.3040210008621216, 'eval_runtime': 0.6435, 'eval_samples_per_second': 155.395, 'eval_steps_per_second': 1.554}
Relation - source: Notfalleinrichtungen, target: emergency services


  0%|          | 0/1 [00:00<?, ?ba/s]



{'eval_accuracy': 0.6, 'eval_loss': 1.5930802822113037, 'eval_runtime': 0.6238, 'eval_samples_per_second': 160.301, 'eval_steps_per_second': 1.603}
Relation - source: Farbe, target: color


  0%|          | 0/1 [00:00<?, ?ba/s]



{'eval_accuracy': 0.71, 'eval_loss': 0.9943457841873169, 'eval_runtime': 0.6474, 'eval_samples_per_second': 154.456, 'eval_steps_per_second': 1.545}
Relation - source: erwartete Vollständigkeit, target: expected completeness


  0%|          | 0/1 [00:00<?, ?ba/s]



{'eval_accuracy': 0.45, 'eval_loss': 2.8891947269439697, 'eval_runtime': 0.6533, 'eval_samples_per_second': 153.069, 'eval_steps_per_second': 1.531}
Relation - source: gemessene physikalische Größe, target: measured physical quantity


  0%|          | 0/1 [00:00<?, ?ba/s]



{'eval_accuracy': 0.55, 'eval_loss': 2.0366668701171875, 'eval_runtime': 0.6294, 'eval_samples_per_second': 158.877, 'eval_steps_per_second': 1.589}


In [32]:
# Only when `use_anti` is set: run the same per-relation evaluation on the
# `*_anti` relations and test data.
# NOTE(review): `relations_anti` / `test_anti` are defined outside this section,
# presumably only when `use_anti` is True — confirm.
if use_anti:
    evaluation_symmetry(trainer, tokenizer, relations_anti, source_language, copy.deepcopy(test_anti))

### Evaluate
- Test my hypothesis if (f, r, e) or (e, r_de, f) exist more?
- Is every relation symmetric now? What about relations that aren't part of the training?
- If every relation is symmetric, try running with ANTI
- And with General relations
- Try Training with General and then evaluate general like on Anti!
- Does that change the evaluation accuracy?
- pretrained?
- target?

In [33]:
# Move the model to the CPU and switch it to evaluation mode for the manual
# [MASK]-probing cells below, which feed it tokenizer output directly without
# moving the tensors to another device.
# `model.eval()` is the last expression, so the module repr is displayed.
model.to('cpu')
model.eval()

BertForMaskedLM(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(119547, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=T

In [34]:
# Print the first 1901 training samples for manual inspection.
# NOTE(review): 1901 looks like one relation's 1900 facts plus the first fact
# of the next relation — confirm.
print(train_dict['sample'][:1901])

['Medina taxonomischer Rang Italie', 'Italie taxonomischer Rang Medina', 'Invasion taxonomischer Rang Bora', 'Bora taxonomischer Rang Invasion', 'Burke taxonomischer Rang Hus', 'Hus taxonomischer Rang Burke', 'Drama taxonomischer Rang epi', 'epi taxonomischer Rang Drama', 'Master taxonomischer Rang Wilfried', 'Wilfried taxonomischer Rang Master', 'Dari taxonomischer Rang Fach', 'Fach taxonomischer Rang Dari', 'Chihuahua taxonomischer Rang Inge', 'Inge taxonomischer Rang Chihuahua', 'EP taxonomischer Rang Elite', 'Elite taxonomischer Rang EP', 'Chase taxonomischer Rang Portland', 'Portland taxonomischer Rang Chase', 'Worcester taxonomischer Rang Eliza', 'Eliza taxonomischer Rang Worcester', 'Albert taxonomischer Rang Weir', 'Weir taxonomischer Rang Albert', 'Ibiza taxonomischer Rang Antoine', 'Antoine taxonomischer Rang Ibiza', 'Câmara taxonomischer Rang Universitas', 'Universitas taxonomischer Rang Câmara', 'Eleanor taxonomischer Rang Collins', 'Collins taxonomischer Rang Eleanor', 'Ca

In [35]:
# Show the raw test samples as the cell output.
test_dict['sample']

['Raymond taxon rank Haji',
 'Chinese taxon rank Yahoo',
 'West taxon rank Stal',
 'Rhode taxon rank FC',
 'Libro taxon rank Dad',
 'Weaver taxon rank Kenia',
 'Limited taxon rank CCD',
 'NO taxon rank Riau',
 'Frères taxon rank Ky',
 'Li taxon rank Billie',
 'Pie taxon rank Elbe',
 'DSM taxon rank Paraíso',
 'Björn taxon rank TD',
 'Elsevier taxon rank Luther',
 'Isabel taxon rank Roi',
 'Valence taxon rank Pole',
 'Townsend taxon rank Page',
 'Levant taxon rank Baron',
 'Khan taxon rank Libia',
 'Ward taxon rank Cuenca',
 'Valladolid taxon rank Kálmán',
 'Kristen taxon rank ET',
 'Allende taxon rank Mainstream',
 'Malden taxon rank Agency',
 'Ekim taxon rank Mata',
 'Norris taxon rank Mineral',
 'Entangled taxon rank Figaro',
 'Nico taxon rank Trung',
 'NME taxon rank Sabha',
 'Christi taxon rank Guimarães',
 'Laurel taxon rank Disneyland',
 'Hammer taxon rank Wes',
 'Desse taxon rank Gesù',
 'Albany taxon rank Cinq',
 'Hollow taxon rank Silla',
 'Music taxon rank JR',
 'Munro taxon 

#### -> Test my hypothesis if (f, r, e) or (e, r_de, f) exist more?

Evaluate if for (e, r, f) we know more often (e, r_de, f) or (f, r, e), i.e. Knowledge Transfer vs symmetric rule.
This can also help us understand which way we get (f, r_de, e).

When we train on (e, r_de, f), we rarely get (f, r_de, e); this already suggests that the model takes the following path:
(e, r, f) -RULE-> (f, r, e) -KT-> (f, r_de, e)

- Per relation, the first 1800 training facts teach the rule (900<->900).
- Facts 1800-1900 of each relation are the ones used for testing.

In [36]:
def compute_overlap(a, b):
    """Return the multiset intersection of two iterables as a list.

    Every element is repeated as often as it occurs in both ``a`` and ``b``,
    i.e. the smaller of its two counts.
    """
    shared = Counter(a) & Counter(b)
    return list(shared.elements())

In [40]:
# Iterate over relations and probe the tail of each relation's training facts
# (samples 1800-1900 of every block of 1900, the ones used for testing).
# Each probed training sample is a fact (e, r, f); three prompts are tested:
#   (f, r,   e) - the symmetric fact with the relation label r
#   (e, r_t, f) - the same fact with the other-language relation label r_t
#   (f, r_t, e) - the symmetric fact with r_t
FACTS_PER_RELATION = 1900  # training samples stored per relation
RULE_FACTS = 1800          # leading samples of each relation that train the rule


def predict_top1_token_id(prompt):
    """Return the vocabulary id the model ranks highest for the [MASK] in ``prompt``."""
    encoded_input = tokenizer(prompt, return_tensors='pt')
    # Inference only, so skip the autograd bookkeeping.
    with torch.no_grad():
        token_logits = model(**encoded_input).logits

    mask_token_index = torch.where(encoded_input["input_ids"] == tokenizer.mask_token_id)[1]
    mask_token_logits = token_logits[0, mask_token_index, :]

    # Pick the [MASK] candidate with the highest logit
    return torch.topk(mask_token_logits, 1, dim=1).indices[0].tolist()[0]


def transfer_rate(source_entities, target_entities):
    """Share of ``source_entities`` that also occur in ``target_entities`` (0 if empty)."""
    if len(source_entities) == 0:
        return 0
    return len(compute_overlap(source_entities, target_entities)) / len(source_entities)


for i in range(n_relations):
    trained_test = train_dict['sample'][RULE_FACTS + i * FACTS_PER_RELATION:(i + 1) * FACTS_PER_RELATION]
    # Normalise by the number of probed samples instead of assuming exactly 100;
    # max(..., 1) keeps an empty slice at accuracy 0.0 rather than dividing by zero.
    n_probed = max(len(trained_test), 1)

    correct_entities_r = []
    correct_entities_rde = []
    correct_entities_test = []

    r = relations['de'].iloc[i]
    r_t = relations['en'].iloc[i]

    for sample in trained_test:
        # The entities are the first and the last whitespace-separated word.
        f = sample.rsplit(' ', 1)[1]
        e = sample.split(' ', 1)[0]

        e_token = tokenizer.convert_tokens_to_ids(e)
        f_token = tokenizer.convert_tokens_to_ids(f)

        # Test (f, r, e)
        if predict_top1_token_id(f + ' ' + r + ' [MASK]') == e_token:
            correct_entities_r.append(e)

        # Test (e, r_t, f)
        if predict_top1_token_id(e + ' ' + r_t + ' [MASK]') == f_token:
            correct_entities_rde.append(f)

        # Test (f, r_t, e)
        if predict_top1_token_id(f + ' ' + r_t + ' [MASK]') == e_token:
            correct_entities_test.append(e)

    acc_r = len(correct_entities_r) / n_probed
    acc_rde = len(correct_entities_rde) / n_probed
    acc_test = len(correct_entities_test) / n_probed

    # NOTE(review): correct_entities_rde stores the object entity f, while the
    # other two lists store e. The overlaps involving it therefore compare
    # different roles of a sample and only line up if entities recur across the
    # probed samples — confirm this is intended.
    print(f'Relation: {r}')
    print(f'Relation Target: {r_t}')
    print(f'Accuracy for (f, r, e): {acc_r}')
    print(f'Accuracy for (e, r_t, f): {acc_rde}')
    print(f'Accuracy for (f, r_t, e): {acc_test}')
    print(f'Size (f, r, e): {len(correct_entities_r)}')
    print(f'Size (e, r_t, f): {len(correct_entities_rde)}')
    print(f'Overlap between (f, r, e) and (e, r_t, f): {len(compute_overlap(correct_entities_r, correct_entities_rde))}')
    print(f'Transfer from (e, r_t, f) to (f, r_t, e): {transfer_rate(correct_entities_rde, correct_entities_test)}')
    print(f'Transfer from (f, r, e) to (f, r_t, e): {transfer_rate(correct_entities_r, correct_entities_test)}')
    print('')

Relation: taxonomischer Rang
Relation Target: taxon rank
Accuracy for (f, r, e): 0.73
Accuracy for (e, r_t, f): 0.62
Accuracy for (f, r_t, e): 0.63
Size (f, r, e): 73
Size (e, r_t, f): 62
Overlap between (f, r, e) and (e, r_t, f): 49
Transfer from (e, r_t, f) to (f, r_t, e): 0.6774193548387096
Transfer from (f, r, e) to (f, r_t, e): 0.8356164383561644

Relation: Notfalleinrichtungen
Relation Target: emergency services
Accuracy for (f, r, e): 0.81
Accuracy for (e, r_t, f): 0.21
Accuracy for (f, r_t, e): 0.6
Size (f, r, e): 81
Size (e, r_t, f): 21
Overlap between (f, r, e) and (e, r_t, f): 17
Transfer from (e, r_t, f) to (f, r_t, e): 0.6190476190476191
Transfer from (f, r, e) to (f, r_t, e): 0.7283950617283951

Relation: Farbe
Relation Target: color
Accuracy for (f, r, e): 0.71
Accuracy for (e, r_t, f): 0.9
Accuracy for (f, r_t, e): 0.71
Size (f, r, e): 71
Size (e, r_t, f): 90
Overlap between (f, r, e) and (e, r_t, f): 63
Transfer from (e, r_t, f) to (f, r_t, e): 0.7
Transfer from (f, r,

#### -> Is every relation symmetric now? What about relations that aren't part of the training?
For this, sample n_relations general relations and 100 entities, and test whether what the model predicts in one direction is also predicted in the other direction. This was also quite flawed in Symbolic Reasoner; there, they didn't fine-tune.

- What about training with general and then testing on them if they are symmetric?

**Are General Relations (aka generate_random) symmetric?**

In [22]:
# For every general ("random") relation, re-probe the facts it was trained on.
# Accuracy uses the top-1 prediction for "e r [MASK]"; symmetry checks whether
# e is among the top-5 predictions for the reversed prompt "f r [MASK]".
FACTS_PER_RANDOM_RELATION = 1000  # consecutive training samples per relation

for i, (idx, rel) in enumerate(relations_random.iterrows()):
    print(f"Relation: {rel['en']}")

    trained_test = train_random[i * FACTS_PER_RANDOM_RELATION:(i + 1) * FACTS_PER_RANDOM_RELATION]
    # Normalise by the samples actually probed; max(..., 1) avoids dividing by zero.
    n_probed = max(len(trained_test), 1)

    acc = 0
    sym = 0

    # Inference only, so skip the autograd bookkeeping.
    with torch.no_grad():
        for sample in trained_test:

            # (e, r, f)
            e = sample.split(' ', 1)[0]
            f = sample.rsplit(' ', 1)[1]

            label_token = tokenizer.convert_tokens_to_ids(f)

            # Forward direction: predict f from e
            prompt = e + ' ' + rel['en'] + ' [MASK]'

            encoded_input = tokenizer(prompt, return_tensors='pt')
            token_logits = model(**encoded_input).logits

            mask_token_index = torch.where(encoded_input["input_ids"] == tokenizer.mask_token_id)[1]
            mask_token_logits = token_logits[0, mask_token_index, :]
            top_1_token = torch.topk(mask_token_logits, 1, dim=1).indices[0].tolist()

            if label_token in top_1_token:
                acc += 1

            # Reverse direction: is e among the top-5 predictions for f?
            label_token = tokenizer.convert_tokens_to_ids(e)

            prompt = f + ' ' + rel['en'] + ' [MASK]'

            encoded_input = tokenizer(prompt, return_tensors='pt')
            token_logits = model(**encoded_input).logits

            mask_token_index = torch.where(encoded_input["input_ids"] == tokenizer.mask_token_id)[1]
            mask_token_logits = token_logits[0, mask_token_index, :]
            top_5_tokens = torch.topk(mask_token_logits, 5, dim=1).indices[0].tolist()

            if label_token in top_5_tokens:
                sym += 1

    print(f'Accuracy: {acc/n_probed}')
    print(f'Percentage of Symmetry: {sym/n_probed}')

Relation: manner of death
Accuracy: 0.882
Percentage of Symmetry: 0.017
Relation: final event
Accuracy: 0.82
Percentage of Symmetry: 0.023
Relation: birthday
Accuracy: 0.904
Percentage of Symmetry: 0.12
Relation: supercharger
Accuracy: 0.892
Percentage of Symmetry: 0.026
Relation: running mate
Accuracy: 0.875
Percentage of Symmetry: 0.065


**Are General Relations that weren't trained on symmetric?**

In [None]:
# Load entities and relations for probing relations the model was not trained on.
# NOTE(review): this loads Relation.Equivalence although the heading talks about
# "general" relations — confirm this is the intended relation set.
entities_test, relations_test = load_data(Relation.Equivalence, source_language, target_language, False, False)

# Draw n_relations relations at random. No random_state is passed, so
# (presumably, if this is a pandas object) the draw depends on the global RNG
# state at execution time.
relations_sampled = relations_test.sample(n_relations)
relations_sampled

In [None]:
# For each sampled relation, pick random entities e, let the model predict a
# partner f for "e r [MASK]", then check whether e comes back among the top-5
# predictions for the reversed prompt "f r [MASK]".
N_PROBE_ENTITIES = 100  # entities probed per relation; also the divisor below

for idx, rel in relations_sampled.iterrows():
    print(f"Relation: {rel['en']}")

    sym = 0

    # Get random entities for e
    entities1 = generate_unique_indices(entities_test.shape[0], N_PROBE_ENTITIES)

    # Inference only, so skip the autograd bookkeeping.
    with torch.no_grad():
        for e_id in entities1:
            # NOTE(review): this is label-based lookup with an index drawn from
            # range(shape[0]); presumably the index is a default RangeIndex — confirm.
            e = entities_test['label'][e_id]

            # Position 0 is the leading special token, so [1] is e's first token.
            e_token = tokenizer.encode(e)[1]

            # Use this single token entity to get a pair
            prompt = e + ' ' + rel['en'] + ' [MASK]'

            encoded_input = tokenizer(prompt, return_tensors='pt')
            token_logits = model(**encoded_input).logits

            mask_token_index = torch.where(encoded_input["input_ids"] == tokenizer.mask_token_id)[1]
            mask_token_logits = token_logits[0, mask_token_index, :]
            f = torch.topk(mask_token_logits, 1, dim=1).indices[0].tolist()[0]

            # Check whether the predicted pair is symmetric
            prompt = tokenizer.decode([f]) + ' ' + rel['en'] + ' [MASK]'

            encoded_input = tokenizer(prompt, return_tensors='pt')
            token_logits = model(**encoded_input).logits

            mask_token_index = torch.where(encoded_input["input_ids"] == tokenizer.mask_token_id)[1]
            mask_token_logits = token_logits[0, mask_token_index, :]
            top_5_tokens = torch.topk(mask_token_logits, 5, dim=1).indices[0].tolist()

            if e_token in top_5_tokens:
                sym += 1

    print(f'Percentage of Symmetry: {sym/N_PROBE_ENTITIES}')

### Manual

In [39]:
# Top-1 accuracy on (at most) the first 10000 training samples: mask the last
# word of every sample and check whether the model predicts it.
samples = train_dict['sample'][:10000]
k = 0

# Inference only, so skip the autograd bookkeeping.
with torch.no_grad():
    for txt in samples:
        # Split off the object (last word) and replace it with [MASK]
        prefix, obj = txt.rsplit(' ', 1)
        sample = prefix + ' [MASK]'
        label_token = tokenizer.convert_tokens_to_ids(obj)

        encoded_input = tokenizer(sample, return_tensors='pt')
        token_logits = model(**encoded_input).logits

        mask_token_index = torch.where(encoded_input["input_ids"] == tokenizer.mask_token_id)[1]
        mask_token_logits = token_logits[0, mask_token_index, :]

        # Pick the [MASK] candidate with the highest logit (top-1 only)
        top_1_token = torch.topk(mask_token_logits, 1, dim=1).indices[0].tolist()

        if label_token in top_1_token:
            k += 1

# Guard against an empty sample list instead of dividing by zero.
print(k/len(samples) if len(samples) > 0 else 0.0)

0.8829473684210526


In [None]:
# Inspect the five most likely [MASK] fillers for a single hand-written prompt.
text = "lens manner of [MASK]"
encoded_input = tokenizer(text, return_tensors='pt')
# Inference only, so skip the autograd bookkeeping.
with torch.no_grad():
    token_logits = model(**encoded_input).logits

mask_token_index = torch.where(encoded_input["input_ids"] == tokenizer.mask_token_id)[1]
mask_token_logits = token_logits[0, mask_token_index, :]

# Pick the [MASK] candidates with the highest logits
top_5_tokens = torch.topk(mask_token_logits, 5, dim=1).indices[0].tolist()

for token_id in top_5_tokens:
    print(f"\n>>> {tokenizer.decode([token_id])}")

In [None]:
# Print every training sample whose text contains the substring 'Alex'.
alex_samples = (s for s in train_dict['sample'] if 'Alex' in s)
for sample in alex_samples:
    print(sample)

### Results

- Training on only symmetric relations doesn't necessarily mean that everything becomes symmetric. Maybe BERT in Symbolic Reasoner was just overfitting, since there it isn't fine-tuned but actually pretrained, i.e. it never sees evidence of non-symmetry but a lot of symmetry.



- See Obsidian