In [1]:
import argparse
import copy

from transformers import BertForMaskedLM, BertTokenizer, BertTokenizerFast, TrainingArguments, Trainer, \
    DataCollatorForLanguageModeling, IntervalStrategy

from datasets import Dataset
import os

from data_generation_relation import *
from utils import *
from custom_trainer import CustomTrainer
from datasets import load_metric
import logging
from transformers import logging as tlogging
import wandb
import sys
from utils import set_seed
from transformers.integrations import WandbCallback, TensorBoardCallback
from tqdm.notebook import tqdm
from collections import Counter

# Explicit imports for names the notebook uses directly (pd.DataFrame, torch.topk, ...)
# instead of relying on them leaking in through the star imports above.
# The existing import order is kept on purpose: with `import *` in play, reordering could change shadowing.
import pandas as pd
import torch

# Switch off Weights & Biases reporting for this run.
# NOTE(review): the warning recorded under the TrainingArguments cell says this variable is
# deprecated in favour of `report_to`; kept here so logging behaviour stays exactly as it was.
os.environ.update({"WANDB_DISABLED": "true"})

In [2]:
# Fix the random seed before any data is generated (the same value, 42, is passed as
# `seed` to TrainingArguments further down).
# presumably set_seed covers the Python/NumPy/torch RNGs — confirm in utils.
set_seed(42)

# Run identity and optimisation hyper-parameters.
# run_name also determines the ./output/<run_name>/ directory used for data, models and logs.
run_name = 'Implication_en_de_20'
epochs = 200
# Used as the per-device training batch size.
batch_size = 256
lr = 5e-5

# Data-generation settings forwarded to generate_reasoning.
# `relation` is wrapped in Relation(...) at the call site.
relation = 'implication'
source_language = ['en']
target_language = ['de']
n_relations = 10
n_facts = 1000
n_pairs = 20

# When True, a later cell mixes facts from generate_random into the training set.
use_random = False

# NOTE(review): not referenced by any cell visible in this notebook; the trainer is
# constructed with a literal precision_at=1.
precision_k = 1

# Both flags are forwarded to generate_reasoning; their exact effect is defined there — TODO confirm.
use_pretrained = False
use_target = False

In [3]:
# Gather the data-generation settings in one place, then generate the train/test facts.
generation_settings = dict(
    relation=Relation(relation),
    source_language=source_language,
    target_language=target_language,
    n_relations=n_relations,
    n_facts=n_facts,
    use_pretrained=use_pretrained,
    use_target=use_target,
    use_enhanced=False,
    use_same_relations=False,
    n_pairs=n_pairs,
)
train, test, relations, entities = generate_reasoning(**generation_settings)

# Display the sampled relations.
relations

(        id                          en                             de  \
 694   P105                  taxon rank             taxonomischer Rang   
 598   P462                       color                          Farbe   
 120   P111  measured physical quantity  gemessene physikalische Größe   
 281   P400                    platform                      Plattform   
 137  P8345             media franchise               Medien-Franchise   
 204  P1606        natural reservoir of           Erregerreservoir von   
 231  P2675                    reply to                    Antwort auf   
 213  P1909                 side effect                   Nebenwirkung   
 235  P1363       points/goal scored by    Punkt/Treffer erzielt durch   
 711   P607                    conflict                  Kriegseinsatz   
 
                             es                         fr    count  
 694       categoría taxonómica           rang taxinomique  3580266  
 598                      color             

In [4]:
# Second element of `relations`: the later evaluation cells read it as the implied
# relation `s` of each rule (relations[0] supplies the premise relation `r`).
relations[1]

Unnamed: 0,id,en,de,es,fr,count
528,P6855,emergency services,Notfalleinrichtungen,servicios de emergencia,accueil et traitement des urgences,766
606,P2429,expected completeness,erwartete Vollständigkeit,grado de completitud,degré de complétude,3826
63,P3027,open period from,geöffnet von Zeitpunkt,abierto desde,début de la période d'ouverture,16
515,P7727,legislative committee,Legislativkomitee,comité legislativo,comité législatif,123710
587,P9597,type of lens,Linsentyp,tipo de lente,type de lentille optique,1721
218,P8852,facial hair,Gesichtshaar,vello facial,pilosité faciale,362
66,P1455,list of works,Werkliste,lista de obras,liste des œuvres,1227
754,P129,physically interacts with,interagiert physikalisch mit,interactúa físicamente con,interagit physiquement avec,9480
118,P6271,demonym of,Demonym zu,gentilicio de,gentilé de,2629
522,P2596,culture,Kultur,cultura,culture,10007


In [5]:
# Optionally mix randomly generated facts into the training set.
# relations_random stays an empty list when use_random is False.
relations_random = []

if use_random:
    # Generate half/half: factor is the ratio of random facts to rule facts.
    factor = 1.0
    # factor * n_facts is a float (1000.0); cast so the generator receives a whole-number count.
    n_random = int(factor * n_facts)

    train_random, relations_random = generate_random(source_language, target_language, n_random, n_relations)
    # NOTE(review): if `train` is a list this extends it in place, so re-running this cell
    # appends another batch of random facts — re-run the generation cell first.
    train += train_random

relations_random

[]

In [6]:
# LOADING
# One checkpoint name shared by the mBERT tokenizer and the masked-LM model.
mbert_checkpoint = 'bert-base-multilingual-cased'
tokenizer = BertTokenizerFast.from_pretrained(mbert_checkpoint)
model = BertForMaskedLM.from_pretrained(mbert_checkpoint)

# Two collators: training masks 15% of the tokens at random, evaluation applies no MLM masking.
data_collator = DataCollatorForLanguageModeling(mlm_probability=0.15, tokenizer=tokenizer)
eval_data_collator = DataCollatorForLanguageModeling(mlm=False, tokenizer=tokenizer)

Some weights of the model checkpoint at bert-base-multilingual-cased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [7]:
# ~~ PRE-PROCESSING ~~
# Training split: wrap the raw samples in a Dataset and tokenize them.
train_dict = dict(sample=train)
train_ds = Dataset.from_dict(train_dict)
tokenized_train = tokenize(tokenizer, train_ds)  # Train is shuffled by Huggingface

# Test split: flatten a deep copy of the nested test structure (so `test` itself is
# left untouched), then wrap and tokenize it the same way.
test_dict = dict(sample=flatten_dict2_list(copy.deepcopy(test)))
test_ds = Dataset.from_dict(test_dict)
tokenized_test = tokenize(tokenizer, test_ds)

  0%|          | 0/19 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

In [8]:
# Save Train and Test Data
# Persist the exact samples of this run next to the model output so results can be traced back.
train_df = pd.DataFrame(train_dict)
test_complete_df = pd.DataFrame(test)
test_flat_df = pd.DataFrame(test_dict)

data_dir = './output/' + run_name + '/data/'
# exist_ok=True replaces the exists()-then-makedirs() pair: no check-then-create race,
# and the cell can be re-run safely.
os.makedirs(data_dir, exist_ok=True)

train_df.to_csv(data_dir + 'train_set', index=False)
# The complete (nested) test set is written as JSON; the flattened variant as CSV.
test_complete_df.to_json(data_dir + 'test_set_complete')
test_flat_df.to_csv(data_dir + 'test_set', index=False)

# train_random only exists when the random-facts cell ran with use_random=True.
if use_random:
    train_random_df = pd.DataFrame({'sample': train_random})
    train_random_df.to_csv(data_dir + 'train_random', index=False)

In [9]:
# Everything this run writes (checkpoints, TensorBoard logs) lives under one directory.
run_output_dir = f'./output/{run_name}'

# Log, evaluate and checkpoint once per epoch; keep at most two checkpoints and reload the
# checkpoint with the best 'accuracy' when training finishes.
# NOTE(review): the eval dataset passed to the trainer below is the test set, so the
# best-checkpoint selection is made on test data.
# NOTE(review): the recorded warning recommends `report_to` over the WANDB_DISABLED
# environment variable; not changed here because it would alter which integrations log.
training_args = TrainingArguments(
    seed=42,
    num_train_epochs=epochs,
    learning_rate=lr,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=128,
    output_dir=f'{run_output_dir}/models/',
    logging_dir=f'{run_output_dir}/tb_logs/',
    logging_strategy=IntervalStrategy.EPOCH,
    evaluation_strategy=IntervalStrategy.EPOCH,
    save_strategy=IntervalStrategy.EPOCH,
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model='accuracy',
)

# CustomTrainer takes a separate collator for evaluation and a precision cut-off
# in addition to the usual Trainer arguments.
trainer = CustomTrainer(
    args=training_args,
    model=model,
    tokenizer=tokenizer,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_test,
    data_collator=data_collator,
    eval_data_collator=eval_data_collator,
    compute_metrics=precision_at_one,
    precision_at=1,
)


Using the `WAND_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


In [10]:
# Train
# Long-running cell: the recorded output below reports a train_runtime of about 6730 s
# for 200 epochs, with a checkpoint written after every epoch.
trainer.train()

***** Running training *****
  Num examples = 19000
  Num Epochs = 200
  Instantaneous batch size per device = 256
  Total train batch size (w. parallel, distributed & accumulation) = 512
  Gradient Accumulation steps = 1
  Total optimization steps = 7600


Epoch,Training Loss,Validation Loss,Accuracy
1,4.754,8.455544,0.001
2,3.343,6.775641,0.009
3,2.8281,6.018909,0.029
4,2.599,5.47537,0.025
5,2.4268,5.167844,0.036
6,2.3911,5.355132,0.038
7,2.3863,5.105166,0.034
8,2.2632,5.025634,0.035
9,2.2564,4.958439,0.032
10,2.2161,4.944651,0.03


Saving model checkpoint to ./output/Implication_en_de_20/models/checkpoint-38
Configuration saved in ./output/Implication_en_de_20/models/checkpoint-38/config.json
Model weights saved in ./output/Implication_en_de_20/models/checkpoint-38/pytorch_model.bin
tokenizer config file saved in ./output/Implication_en_de_20/models/checkpoint-38/tokenizer_config.json
Special tokens file saved in ./output/Implication_en_de_20/models/checkpoint-38/special_tokens_map.json
Saving model checkpoint to ./output/Implication_en_de_20/models/checkpoint-76
Configuration saved in ./output/Implication_en_de_20/models/checkpoint-76/config.json
Model weights saved in ./output/Implication_en_de_20/models/checkpoint-76/pytorch_model.bin
tokenizer config file saved in ./output/Implication_en_de_20/models/checkpoint-76/tokenizer_config.json
Special tokens file saved in ./output/Implication_en_de_20/models/checkpoint-76/special_tokens_map.json
Saving model checkpoint to ./output/Implication_en_de_20/models/checkpoi

Saving model checkpoint to ./output/Implication_en_de_20/models/checkpoint-418
Configuration saved in ./output/Implication_en_de_20/models/checkpoint-418/config.json
Model weights saved in ./output/Implication_en_de_20/models/checkpoint-418/pytorch_model.bin
tokenizer config file saved in ./output/Implication_en_de_20/models/checkpoint-418/tokenizer_config.json
Special tokens file saved in ./output/Implication_en_de_20/models/checkpoint-418/special_tokens_map.json
Deleting older checkpoint [output/Implication_en_de_20/models/checkpoint-380] due to args.save_total_limit
Saving model checkpoint to ./output/Implication_en_de_20/models/checkpoint-456
Configuration saved in ./output/Implication_en_de_20/models/checkpoint-456/config.json
Model weights saved in ./output/Implication_en_de_20/models/checkpoint-456/pytorch_model.bin
tokenizer config file saved in ./output/Implication_en_de_20/models/checkpoint-456/tokenizer_config.json
Special tokens file saved in ./output/Implication_en_de_20/m

Configuration saved in ./output/Implication_en_de_20/models/checkpoint-760/config.json
Model weights saved in ./output/Implication_en_de_20/models/checkpoint-760/pytorch_model.bin
tokenizer config file saved in ./output/Implication_en_de_20/models/checkpoint-760/tokenizer_config.json
Special tokens file saved in ./output/Implication_en_de_20/models/checkpoint-760/special_tokens_map.json
Deleting older checkpoint [output/Implication_en_de_20/models/checkpoint-684] due to args.save_total_limit
Saving model checkpoint to ./output/Implication_en_de_20/models/checkpoint-798
Configuration saved in ./output/Implication_en_de_20/models/checkpoint-798/config.json
Model weights saved in ./output/Implication_en_de_20/models/checkpoint-798/pytorch_model.bin
tokenizer config file saved in ./output/Implication_en_de_20/models/checkpoint-798/tokenizer_config.json
Special tokens file saved in ./output/Implication_en_de_20/models/checkpoint-798/special_tokens_map.json
Deleting older checkpoint [output/

Deleting older checkpoint [output/Implication_en_de_20/models/checkpoint-1064] due to args.save_total_limit
Saving model checkpoint to ./output/Implication_en_de_20/models/checkpoint-1140
Configuration saved in ./output/Implication_en_de_20/models/checkpoint-1140/config.json
Model weights saved in ./output/Implication_en_de_20/models/checkpoint-1140/pytorch_model.bin
tokenizer config file saved in ./output/Implication_en_de_20/models/checkpoint-1140/tokenizer_config.json
Special tokens file saved in ./output/Implication_en_de_20/models/checkpoint-1140/special_tokens_map.json
Deleting older checkpoint [output/Implication_en_de_20/models/checkpoint-1102] due to args.save_total_limit
Saving model checkpoint to ./output/Implication_en_de_20/models/checkpoint-1178
Configuration saved in ./output/Implication_en_de_20/models/checkpoint-1178/config.json
Model weights saved in ./output/Implication_en_de_20/models/checkpoint-1178/pytorch_model.bin
tokenizer config file saved in ./output/Implicat

Saving model checkpoint to ./output/Implication_en_de_20/models/checkpoint-1482
Configuration saved in ./output/Implication_en_de_20/models/checkpoint-1482/config.json
Model weights saved in ./output/Implication_en_de_20/models/checkpoint-1482/pytorch_model.bin
tokenizer config file saved in ./output/Implication_en_de_20/models/checkpoint-1482/tokenizer_config.json
Special tokens file saved in ./output/Implication_en_de_20/models/checkpoint-1482/special_tokens_map.json
Deleting older checkpoint [output/Implication_en_de_20/models/checkpoint-1444] due to args.save_total_limit
Saving model checkpoint to ./output/Implication_en_de_20/models/checkpoint-1520
Configuration saved in ./output/Implication_en_de_20/models/checkpoint-1520/config.json
Model weights saved in ./output/Implication_en_de_20/models/checkpoint-1520/pytorch_model.bin
tokenizer config file saved in ./output/Implication_en_de_20/models/checkpoint-1520/tokenizer_config.json
Special tokens file saved in ./output/Implication_

Special tokens file saved in ./output/Implication_en_de_20/models/checkpoint-1824/special_tokens_map.json
Deleting older checkpoint [output/Implication_en_de_20/models/checkpoint-760] due to args.save_total_limit
Saving model checkpoint to ./output/Implication_en_de_20/models/checkpoint-1862
Configuration saved in ./output/Implication_en_de_20/models/checkpoint-1862/config.json
Model weights saved in ./output/Implication_en_de_20/models/checkpoint-1862/pytorch_model.bin
tokenizer config file saved in ./output/Implication_en_de_20/models/checkpoint-1862/tokenizer_config.json
Special tokens file saved in ./output/Implication_en_de_20/models/checkpoint-1862/special_tokens_map.json
Deleting older checkpoint [output/Implication_en_de_20/models/checkpoint-1786] due to args.save_total_limit
Saving model checkpoint to ./output/Implication_en_de_20/models/checkpoint-1900
Configuration saved in ./output/Implication_en_de_20/models/checkpoint-1900/config.json
Model weights saved in ./output/Impli

Saving model checkpoint to ./output/Implication_en_de_20/models/checkpoint-2204
Configuration saved in ./output/Implication_en_de_20/models/checkpoint-2204/config.json
Model weights saved in ./output/Implication_en_de_20/models/checkpoint-2204/pytorch_model.bin
tokenizer config file saved in ./output/Implication_en_de_20/models/checkpoint-2204/tokenizer_config.json
Special tokens file saved in ./output/Implication_en_de_20/models/checkpoint-2204/special_tokens_map.json
Deleting older checkpoint [output/Implication_en_de_20/models/checkpoint-2166] due to args.save_total_limit
Saving model checkpoint to ./output/Implication_en_de_20/models/checkpoint-2242
Configuration saved in ./output/Implication_en_de_20/models/checkpoint-2242/config.json
Model weights saved in ./output/Implication_en_de_20/models/checkpoint-2242/pytorch_model.bin
tokenizer config file saved in ./output/Implication_en_de_20/models/checkpoint-2242/tokenizer_config.json
Special tokens file saved in ./output/Implication_

Special tokens file saved in ./output/Implication_en_de_20/models/checkpoint-2546/special_tokens_map.json
Deleting older checkpoint [output/Implication_en_de_20/models/checkpoint-2508] due to args.save_total_limit
Saving model checkpoint to ./output/Implication_en_de_20/models/checkpoint-2584
Configuration saved in ./output/Implication_en_de_20/models/checkpoint-2584/config.json
Model weights saved in ./output/Implication_en_de_20/models/checkpoint-2584/pytorch_model.bin
tokenizer config file saved in ./output/Implication_en_de_20/models/checkpoint-2584/tokenizer_config.json
Special tokens file saved in ./output/Implication_en_de_20/models/checkpoint-2584/special_tokens_map.json
Deleting older checkpoint [output/Implication_en_de_20/models/checkpoint-2546] due to args.save_total_limit
Saving model checkpoint to ./output/Implication_en_de_20/models/checkpoint-2622
Configuration saved in ./output/Implication_en_de_20/models/checkpoint-2622/config.json
Model weights saved in ./output/Impl

Saving model checkpoint to ./output/Implication_en_de_20/models/checkpoint-2926
Configuration saved in ./output/Implication_en_de_20/models/checkpoint-2926/config.json
Model weights saved in ./output/Implication_en_de_20/models/checkpoint-2926/pytorch_model.bin
tokenizer config file saved in ./output/Implication_en_de_20/models/checkpoint-2926/tokenizer_config.json
Special tokens file saved in ./output/Implication_en_de_20/models/checkpoint-2926/special_tokens_map.json
Deleting older checkpoint [output/Implication_en_de_20/models/checkpoint-2888] due to args.save_total_limit
Saving model checkpoint to ./output/Implication_en_de_20/models/checkpoint-2964
Configuration saved in ./output/Implication_en_de_20/models/checkpoint-2964/config.json
Model weights saved in ./output/Implication_en_de_20/models/checkpoint-2964/pytorch_model.bin
tokenizer config file saved in ./output/Implication_en_de_20/models/checkpoint-2964/tokenizer_config.json
Special tokens file saved in ./output/Implication_

Special tokens file saved in ./output/Implication_en_de_20/models/checkpoint-3268/special_tokens_map.json
Deleting older checkpoint [output/Implication_en_de_20/models/checkpoint-3230] due to args.save_total_limit
Saving model checkpoint to ./output/Implication_en_de_20/models/checkpoint-3306
Configuration saved in ./output/Implication_en_de_20/models/checkpoint-3306/config.json
Model weights saved in ./output/Implication_en_de_20/models/checkpoint-3306/pytorch_model.bin
tokenizer config file saved in ./output/Implication_en_de_20/models/checkpoint-3306/tokenizer_config.json
Special tokens file saved in ./output/Implication_en_de_20/models/checkpoint-3306/special_tokens_map.json
Deleting older checkpoint [output/Implication_en_de_20/models/checkpoint-3268] due to args.save_total_limit
Saving model checkpoint to ./output/Implication_en_de_20/models/checkpoint-3344
Configuration saved in ./output/Implication_en_de_20/models/checkpoint-3344/config.json
Model weights saved in ./output/Impl

Saving model checkpoint to ./output/Implication_en_de_20/models/checkpoint-3648
Configuration saved in ./output/Implication_en_de_20/models/checkpoint-3648/config.json
Model weights saved in ./output/Implication_en_de_20/models/checkpoint-3648/pytorch_model.bin
tokenizer config file saved in ./output/Implication_en_de_20/models/checkpoint-3648/tokenizer_config.json
Special tokens file saved in ./output/Implication_en_de_20/models/checkpoint-3648/special_tokens_map.json
Deleting older checkpoint [output/Implication_en_de_20/models/checkpoint-3610] due to args.save_total_limit
Saving model checkpoint to ./output/Implication_en_de_20/models/checkpoint-3686
Configuration saved in ./output/Implication_en_de_20/models/checkpoint-3686/config.json
Model weights saved in ./output/Implication_en_de_20/models/checkpoint-3686/pytorch_model.bin
tokenizer config file saved in ./output/Implication_en_de_20/models/checkpoint-3686/tokenizer_config.json
Special tokens file saved in ./output/Implication_

Special tokens file saved in ./output/Implication_en_de_20/models/checkpoint-3990/special_tokens_map.json
Deleting older checkpoint [output/Implication_en_de_20/models/checkpoint-3952] due to args.save_total_limit
Saving model checkpoint to ./output/Implication_en_de_20/models/checkpoint-4028
Configuration saved in ./output/Implication_en_de_20/models/checkpoint-4028/config.json
Model weights saved in ./output/Implication_en_de_20/models/checkpoint-4028/pytorch_model.bin
tokenizer config file saved in ./output/Implication_en_de_20/models/checkpoint-4028/tokenizer_config.json
Special tokens file saved in ./output/Implication_en_de_20/models/checkpoint-4028/special_tokens_map.json
Deleting older checkpoint [output/Implication_en_de_20/models/checkpoint-3990] due to args.save_total_limit
Saving model checkpoint to ./output/Implication_en_de_20/models/checkpoint-4066
Configuration saved in ./output/Implication_en_de_20/models/checkpoint-4066/config.json
Model weights saved in ./output/Impl

Saving model checkpoint to ./output/Implication_en_de_20/models/checkpoint-4370
Configuration saved in ./output/Implication_en_de_20/models/checkpoint-4370/config.json
Model weights saved in ./output/Implication_en_de_20/models/checkpoint-4370/pytorch_model.bin
tokenizer config file saved in ./output/Implication_en_de_20/models/checkpoint-4370/tokenizer_config.json
Special tokens file saved in ./output/Implication_en_de_20/models/checkpoint-4370/special_tokens_map.json
Deleting older checkpoint [output/Implication_en_de_20/models/checkpoint-4332] due to args.save_total_limit
Saving model checkpoint to ./output/Implication_en_de_20/models/checkpoint-4408
Configuration saved in ./output/Implication_en_de_20/models/checkpoint-4408/config.json
Model weights saved in ./output/Implication_en_de_20/models/checkpoint-4408/pytorch_model.bin
tokenizer config file saved in ./output/Implication_en_de_20/models/checkpoint-4408/tokenizer_config.json
Special tokens file saved in ./output/Implication_

Special tokens file saved in ./output/Implication_en_de_20/models/checkpoint-4712/special_tokens_map.json
Deleting older checkpoint [output/Implication_en_de_20/models/checkpoint-4674] due to args.save_total_limit
Saving model checkpoint to ./output/Implication_en_de_20/models/checkpoint-4750
Configuration saved in ./output/Implication_en_de_20/models/checkpoint-4750/config.json
Model weights saved in ./output/Implication_en_de_20/models/checkpoint-4750/pytorch_model.bin
tokenizer config file saved in ./output/Implication_en_de_20/models/checkpoint-4750/tokenizer_config.json
Special tokens file saved in ./output/Implication_en_de_20/models/checkpoint-4750/special_tokens_map.json
Deleting older checkpoint [output/Implication_en_de_20/models/checkpoint-4712] due to args.save_total_limit
Saving model checkpoint to ./output/Implication_en_de_20/models/checkpoint-4788
Configuration saved in ./output/Implication_en_de_20/models/checkpoint-4788/config.json
Model weights saved in ./output/Impl

Saving model checkpoint to ./output/Implication_en_de_20/models/checkpoint-5092
Configuration saved in ./output/Implication_en_de_20/models/checkpoint-5092/config.json
Model weights saved in ./output/Implication_en_de_20/models/checkpoint-5092/pytorch_model.bin
tokenizer config file saved in ./output/Implication_en_de_20/models/checkpoint-5092/tokenizer_config.json
Special tokens file saved in ./output/Implication_en_de_20/models/checkpoint-5092/special_tokens_map.json
Deleting older checkpoint [output/Implication_en_de_20/models/checkpoint-5054] due to args.save_total_limit
Saving model checkpoint to ./output/Implication_en_de_20/models/checkpoint-5130
Configuration saved in ./output/Implication_en_de_20/models/checkpoint-5130/config.json
Model weights saved in ./output/Implication_en_de_20/models/checkpoint-5130/pytorch_model.bin
tokenizer config file saved in ./output/Implication_en_de_20/models/checkpoint-5130/tokenizer_config.json
Special tokens file saved in ./output/Implication_

Special tokens file saved in ./output/Implication_en_de_20/models/checkpoint-5434/special_tokens_map.json
Deleting older checkpoint [output/Implication_en_de_20/models/checkpoint-5396] due to args.save_total_limit
Saving model checkpoint to ./output/Implication_en_de_20/models/checkpoint-5472
Configuration saved in ./output/Implication_en_de_20/models/checkpoint-5472/config.json
Model weights saved in ./output/Implication_en_de_20/models/checkpoint-5472/pytorch_model.bin
tokenizer config file saved in ./output/Implication_en_de_20/models/checkpoint-5472/tokenizer_config.json
Special tokens file saved in ./output/Implication_en_de_20/models/checkpoint-5472/special_tokens_map.json
Deleting older checkpoint [output/Implication_en_de_20/models/checkpoint-5434] due to args.save_total_limit
Saving model checkpoint to ./output/Implication_en_de_20/models/checkpoint-5510
Configuration saved in ./output/Implication_en_de_20/models/checkpoint-5510/config.json
Model weights saved in ./output/Impl

Saving model checkpoint to ./output/Implication_en_de_20/models/checkpoint-5814
Configuration saved in ./output/Implication_en_de_20/models/checkpoint-5814/config.json
Model weights saved in ./output/Implication_en_de_20/models/checkpoint-5814/pytorch_model.bin
tokenizer config file saved in ./output/Implication_en_de_20/models/checkpoint-5814/tokenizer_config.json
Special tokens file saved in ./output/Implication_en_de_20/models/checkpoint-5814/special_tokens_map.json
Deleting older checkpoint [output/Implication_en_de_20/models/checkpoint-5776] due to args.save_total_limit
Saving model checkpoint to ./output/Implication_en_de_20/models/checkpoint-5852
Configuration saved in ./output/Implication_en_de_20/models/checkpoint-5852/config.json
Model weights saved in ./output/Implication_en_de_20/models/checkpoint-5852/pytorch_model.bin
tokenizer config file saved in ./output/Implication_en_de_20/models/checkpoint-5852/tokenizer_config.json
Special tokens file saved in ./output/Implication_

Special tokens file saved in ./output/Implication_en_de_20/models/checkpoint-6156/special_tokens_map.json
Deleting older checkpoint [output/Implication_en_de_20/models/checkpoint-6118] due to args.save_total_limit
Saving model checkpoint to ./output/Implication_en_de_20/models/checkpoint-6194
Configuration saved in ./output/Implication_en_de_20/models/checkpoint-6194/config.json
Model weights saved in ./output/Implication_en_de_20/models/checkpoint-6194/pytorch_model.bin
tokenizer config file saved in ./output/Implication_en_de_20/models/checkpoint-6194/tokenizer_config.json
Special tokens file saved in ./output/Implication_en_de_20/models/checkpoint-6194/special_tokens_map.json
Deleting older checkpoint [output/Implication_en_de_20/models/checkpoint-6156] due to args.save_total_limit
Saving model checkpoint to ./output/Implication_en_de_20/models/checkpoint-6232
Configuration saved in ./output/Implication_en_de_20/models/checkpoint-6232/config.json
Model weights saved in ./output/Impl

Saving model checkpoint to ./output/Implication_en_de_20/models/checkpoint-6536
Configuration saved in ./output/Implication_en_de_20/models/checkpoint-6536/config.json
Model weights saved in ./output/Implication_en_de_20/models/checkpoint-6536/pytorch_model.bin
tokenizer config file saved in ./output/Implication_en_de_20/models/checkpoint-6536/tokenizer_config.json
Special tokens file saved in ./output/Implication_en_de_20/models/checkpoint-6536/special_tokens_map.json
Deleting older checkpoint [output/Implication_en_de_20/models/checkpoint-6498] due to args.save_total_limit
Saving model checkpoint to ./output/Implication_en_de_20/models/checkpoint-6574
Configuration saved in ./output/Implication_en_de_20/models/checkpoint-6574/config.json
Model weights saved in ./output/Implication_en_de_20/models/checkpoint-6574/pytorch_model.bin
tokenizer config file saved in ./output/Implication_en_de_20/models/checkpoint-6574/tokenizer_config.json
Special tokens file saved in ./output/Implication_

Special tokens file saved in ./output/Implication_en_de_20/models/checkpoint-6878/special_tokens_map.json
Deleting older checkpoint [output/Implication_en_de_20/models/checkpoint-6840] due to args.save_total_limit
Saving model checkpoint to ./output/Implication_en_de_20/models/checkpoint-6916
Configuration saved in ./output/Implication_en_de_20/models/checkpoint-6916/config.json
Model weights saved in ./output/Implication_en_de_20/models/checkpoint-6916/pytorch_model.bin
tokenizer config file saved in ./output/Implication_en_de_20/models/checkpoint-6916/tokenizer_config.json
Special tokens file saved in ./output/Implication_en_de_20/models/checkpoint-6916/special_tokens_map.json
Deleting older checkpoint [output/Implication_en_de_20/models/checkpoint-6878] due to args.save_total_limit
Saving model checkpoint to ./output/Implication_en_de_20/models/checkpoint-6954
Configuration saved in ./output/Implication_en_de_20/models/checkpoint-6954/config.json
Model weights saved in ./output/Impl

Saving model checkpoint to ./output/Implication_en_de_20/models/checkpoint-7258
Configuration saved in ./output/Implication_en_de_20/models/checkpoint-7258/config.json
Model weights saved in ./output/Implication_en_de_20/models/checkpoint-7258/pytorch_model.bin
tokenizer config file saved in ./output/Implication_en_de_20/models/checkpoint-7258/tokenizer_config.json
Special tokens file saved in ./output/Implication_en_de_20/models/checkpoint-7258/special_tokens_map.json
Deleting older checkpoint [output/Implication_en_de_20/models/checkpoint-7220] due to args.save_total_limit
Saving model checkpoint to ./output/Implication_en_de_20/models/checkpoint-7296
Configuration saved in ./output/Implication_en_de_20/models/checkpoint-7296/config.json
Model weights saved in ./output/Implication_en_de_20/models/checkpoint-7296/pytorch_model.bin
tokenizer config file saved in ./output/Implication_en_de_20/models/checkpoint-7296/tokenizer_config.json
Special tokens file saved in ./output/Implication_

Special tokens file saved in ./output/Implication_en_de_20/models/checkpoint-7600/special_tokens_map.json
Deleting older checkpoint [output/Implication_en_de_20/models/checkpoint-7562] due to args.save_total_limit


Training completed. Do not forget to share your model on huggingface.co/models =)


Loading best model from ./output/Implication_en_de_20/models/checkpoint-2470 (score: 0.061).


TrainOutput(global_step=7600, training_loss=1.3646312668449, metrics={'train_runtime': 6730.2371, 'train_samples_per_second': 564.616, 'train_steps_per_second': 1.129, 'total_flos': 1.75993251844284e+16, 'train_loss': 1.3646312668449, 'epoch': 200.0})

In [11]:
# Evaluate Test
# Runs after training, i.e. on the best checkpoint reloaded via load_best_model_at_end;
# the returned metrics dict is displayed as the cell output.
trainer.evaluate(eval_dataset=tokenized_test)



{'eval_accuracy': 0.061,
 'eval_loss': 5.3030805587768555,
 'eval_runtime': 1.6259,
 'eval_samples_per_second': 615.046,
 'eval_steps_per_second': 2.46,
 'epoch': 200.0}

In [None]:
# Re-display the sampled relations for reference next to the per-relation evaluation.
relations

In [None]:
# Evaluation of the implication rule per relation
# NOTE: evaluation_implication is defined in the next cell; on a fresh kernel that cell
# has to be executed before this one.
# A deep copy of `test` is passed so the evaluation cannot modify the original structure.
evaluation_implication(trainer, tokenizer, relations, copy.deepcopy(test))

In [None]:
def evaluation_implication(trainer, tokenizer, relation_pairs, test):
    """Evaluate the trainer on each implied relation's test samples and print the metrics.

    Args:
        trainer: object whose ``evaluate(eval_dataset=..., custom_eval=True)`` returns a
            metrics dict (presumably the CustomTrainer built above — confirm).
        tokenizer: tokenizer handed to ``tokenize`` for every per-relation dataset.
        relation_pairs: indexable pair of DataFrames; rows of ``relation_pairs[0]`` and
            ``relation_pairs[1]`` are walked in lockstep and only the rows of the second
            frame (the implied relations) are used for the lookup.
        test: nested mapping ``test[target][relation_name] -> list of samples``.

    Returns:
        None. One metrics dict per non-empty relation is printed, with the
        'eval_correct_predictions' entry removed.
    """
    for target in test:
        samples_by_relation = test[target]

        # Walk premise/implication rows together; only the implication row is needed here.
        paired_rows = zip(relation_pairs[0].iterrows(), relation_pairs[1].iterrows())
        for _, (_, implied_row) in paired_rows:
            # Test samples stored under the implied relation's name in this target language.
            samples = samples_by_relation[implied_row[target]]

            # Relations without any test samples are skipped.
            if samples:
                relation_ds = Dataset.from_dict({'sample': samples})
                tokenized_ds = tokenize(tokenizer, relation_ds)

                metrics = trainer.evaluate(eval_dataset=tokenized_ds, custom_eval=True)
                print(remove_key_dict(metrics, 'eval_correct_predictions'))

#### Evaluate
- How is (if at all) implication learned?
- Pretrained?
- Target?

In [None]:
# Move the model to the CPU and put it in evaluation mode for the manual probes below;
# both calls return the module, so they chain into a single expression.
model.to('cpu').eval()

In [None]:
# Inspect the first 1901 training samples. The per-relation probe further down slices the
# training list in blocks of 1900, so this presumably shows the first relation's block
# plus the first sample of the next one — confirm against the generator.
print(train_dict['sample'][:1901])

In [None]:
# Show the complete flattened test set.
# NOTE(review): this displays every sample; consider slicing before sharing the notebook.
test_dict['sample']

In [None]:
# Entity tuples 900-999: the same slice the per-relation probe below uses for the first
# relation (i = 0).
entities[900:1000]

#### Rule: (e, r, f) -> (e, s, a), (e, s, b), (e, s, c)  =>  (e, r_de, f) -> (e, s_de, a), (e, s, b), (e, s, c) 

Test:
- Are the implications learned in source language? (e, s, a) (e, s, b) (e, s, c)
- Is there a general transfer to the target? (e, r_de, f)

For each relation, the first 1800 training facts are used to train the rule (900 <-> 900).
Facts 1800-1900 of each relation's block are the ones used for testing.

In [None]:
def _top_k_mask_tokens(prompt, k):
    """Return the ids of the k highest-scoring vocabulary tokens at the first [MASK] of `prompt`.

    Uses the notebook-level `tokenizer` and `model`.
    """
    encoded_input = tokenizer(prompt, return_tensors='pt')
    # Inference only: no gradients are needed, so skip the autograd bookkeeping.
    with torch.no_grad():
        token_logits = model(**encoded_input).logits

    mask_token_index = torch.where(encoded_input["input_ids"] == tokenizer.mask_token_id)[1]
    mask_token_logits = token_logits[0, mask_token_index, :]

    return torch.topk(mask_token_logits, k, dim=1).indices[0].tolist()


# Iterate over relations, take the training samples that were trained on.
# Two probes per relation, both on the held-back slice of each relation's block:
#   * (e, s, x): are the implied objects among the top-k predictions in the source language?
#   * (e, r_de, f): is the premise fact recalled when the relation is given in German?
for i in range(n_relations):
    trained_test = train_dict['sample'][1800+i*1900:(i+1)*1900]
    entities_sampled = entities[900+i*1000:(i+1)*1000]

    # Relation pairs!
    r = relations[0]['en'].iloc[i]
    r_de = relations[0]['de'].iloc[i]
    s = relations[1]['en'].iloc[i]

    acc_imp_source = 0
    acc_rde = 0
    # Count the checks actually performed instead of assuming 100 samples per relation
    # and reusing the entity tuple of the last iteration.
    n_imp_checks = 0
    n_rde_checks = 0

    # One entity tuple per held-back sample: (e, f, implied objects...).
    for ents in entities_sampled[:len(trained_test)]:
        e = ents[0]
        f = ents[1]
        implied = ents[2:]

        # Test (e, s, a) (e, s, b) (e, s, c)
        # The prompt does not depend on the implied object, so one forward pass serves all of them.
        if len(implied) > 0:
            top_tokens = _top_k_mask_tokens(e + ' ' + s + ' [MASK]', len(implied))
            for ent in implied:
                # assumes every entity is a single vocabulary token — TODO confirm
                label_token = tokenizer.convert_tokens_to_ids(ent)
                n_imp_checks += 1
                if label_token in top_tokens:
                    acc_imp_source += 1

        # Test (e, r_de, f)
        label_token = tokenizer.convert_tokens_to_ids(f)
        top_1_token = _top_k_mask_tokens(e + ' ' + r_de + ' [MASK]', 1)[0]
        n_rde_checks += 1
        if label_token == top_1_token:
            acc_rde += 1

    # An empty slice yields nan instead of a NameError / ZeroDivisionError.
    acc_imp_source = acc_imp_source / n_imp_checks if n_imp_checks else float('nan')
    acc_rde = acc_rde / n_rde_checks if n_rde_checks else float('nan')

    print(f'Relation: {r}')
    print(f'Accuracy for Implication Source (e, s, a) (e, s, b) (e, s, c): {acc_imp_source}')
    print(f'Accuracy for KT (e, r_de, f): {acc_rde}')
    print('\n')

### Manual

In [29]:
# Top-1 recall of the object (last whitespace-separated token) on the first 10000 training samples.
samples = train_dict['sample'][:10000]
n_correct = 0

for txt in samples:
    # Split off the object once; everything before it becomes the prompt.
    prefix, obj = txt.rsplit(' ', 1)

    # Add [MASK] for object
    sample = prefix + ' [MASK]'
    # assumes the object is a single vocabulary token — TODO confirm
    label_token = tokenizer.convert_tokens_to_ids(obj)

    encoded_input = tokenizer(sample, return_tensors='pt')
    # Inference only: no gradients needed.
    with torch.no_grad():
        token_logits = model(**encoded_input).logits

    mask_token_index = torch.where(encoded_input["input_ids"] == tokenizer.mask_token_id)[1]
    mask_token_logits = token_logits[0, mask_token_index, :]

    # Pick the single [MASK] candidate with the highest logit
    top_1_token = torch.topk(mask_token_logits, 1, dim=1).indices[0].tolist()[0]

    if label_token == top_1_token:
        n_correct += 1

# Fraction of samples whose object was the top prediction; nan if the slice is empty.
print(n_correct / len(samples) if samples else float('nan'))

0.9906


In [None]:
# Ad-hoc probe: show the five most likely fillers for the [MASK] in a single prompt.
text = "lens manner of [MASK]"
encoded_input = tokenizer(text, return_tensors='pt')

# Position of the [MASK] token inside the encoded prompt.
mask_token_index = torch.where(encoded_input["input_ids"] == tokenizer.mask_token_id)[1]

token_logits = model(**encoded_input).logits
mask_token_logits = token_logits[0, mask_token_index, :]

# Pick the [MASK] candidates with the highest logits
top_5_tokens = torch.topk(mask_token_logits, 5, dim=1).indices[0].tolist()

# One decoded candidate per line, each preceded by a blank line.
print(*(f"\n>>> {tokenizer.decode([token_id])}" for token_id in top_5_tokens), sep="\n")

In [None]:
# Print every training sample that contains the substring 'Alex'.
for t in filter(lambda candidate: 'Alex' in candidate, train_dict['sample']):
    print(t)

### Results
