In [1]:
# Standard library
import argparse
import copy
import os

# Third-party
from datasets import Dataset
from transformers import BertForMaskedLM, BertTokenizer, BertTokenizerFast, TrainingArguments, Trainer, \
    DataCollatorForLanguageModeling, IntervalStrategy

# Project star imports. Every import below this point is intentionally placed
# AFTER them, so the explicitly imported names win over anything they export.
from data_generation_relation import *
from utils import *

import logging
import sys
from collections import Counter

import pandas as pd
import wandb
from datasets import load_metric
from tqdm.notebook import tqdm
from transformers import logging as tlogging
from transformers.integrations import WandbCallback, TensorBoardCallback

from custom_trainer import CustomTrainer
from utils import set_seed

# Switch off the Weights & Biases integration for this run.
# NOTE(review): transformers reports this variable as deprecated (see the warning
# printed when the trainer is built) and recommends the `report_to` argument instead.
os.environ["WANDB_DISABLED"] = "true"

In [2]:
# Fix the random seed via set_seed (imported from utils) before anything else runs.
set_seed(42)

# Run identity and optimisation settings.
run_name = 'CompositionDefault'  # also used to build the ./output/<run_name>/ paths
epochs, batch_size, lr = 200, 256, 5e-5

# Data-generation settings passed to generate_reasoning in the next cell.
relation = 'composition'
source_language, target_language = ['en'], ['de']
n_relations, n_facts, n_pairs = 10, 1000, 100

# Feature switches, all off for this default run.
# (use_random is not referenced by any cell visible in this notebook.)
use_random = use_pretrained = use_target = use_enhanced = False

In [3]:
# Generate the data for this run from the configuration cell; the result is
# unpacked into train, test, relations and entities.
train, test, relations, entities = generate_reasoning(
    relation=Relation(relation),
    source_language=source_language,
    target_language=target_language,
    n_relations=n_relations,
    n_facts=n_facts,
    use_pretrained=use_pretrained,
    use_target=use_target,
    use_enhanced=use_enhanced,
    use_same_relations=False,
    n_pairs=n_pairs,
)

# Display the sampled relations as the cell output.
relations

(        id                              en                                de  \
 694   P105                      taxon rank                taxonomischer Rang   
 606  P2429           expected completeness         erwartete Vollständigkeit   
 281   P400                        platform                         Plattform   
 587  P9597                    type of lens                         Linsentyp   
 231  P2675                        reply to                       Antwort auf   
 754   P129       physically interacts with      interagiert physikalisch mit   
 711   P607                        conflict                     Kriegseinsatz   
 196  P5353                 school district                       Schulbezirk   
 300  P3461      designated as terrorist by  als terroristisch eingestuft von   
 326    P65  site of astronomical discovery     astronomischer Entdeckungsort   
 
                                       es                               fr  \
 694                 categorí

In [4]:
# Show the second element of `relations` (rendered as a table of relation ids
# and their en/de/es/fr labels).
# NOTE(review): row i here appears alongside row i of relations[0] and relations[2]
# in the per-relation evaluation output — presumably one composition rule; confirm.
relations[1]

Unnamed: 0,id,en,de,es,fr,count
528,P6855,emergency services,Notfalleinrichtungen,servicios de emergencia,accueil et traitement des urgences,766
120,P111,measured physical quantity,gemessene physikalische Größe,cantidad física medida,grandeur physique mesurée,3610
515,P7727,legislative committee,Legislativkomitee,comité legislativo,comité législatif,123710
204,P1606,natural reservoir of,Erregerreservoir von,reservorio natural de,réservoir naturel de,17
66,P1455,list of works,Werkliste,lista de obras,liste des œuvres,1227
235,P1363,points/goal scored by,Punkt/Treffer erzielt durch,puntos/goles marcados por,point/but marqué par,2441
522,P2596,culture,Kultur,cultura,culture,10007
383,P1876,vehicle,Fahrzeug,nave,vaisseau,840
30,P7163,typically sells,verkauft im Allgemeinen,vende generalmente,vend généralement,299
380,P1302,primary destinations,Hauptorte,destinos principales,principales localités desservies,3923


In [5]:
# Show the third element of `relations` (same table layout as relations[1]).
# NOTE(review): presumably the relation that results from composing the
# matching rows of relations[0] and relations[1] — confirm against generate_reasoning.
relations[2]

Unnamed: 0,id,en,de,es,fr,count
598,P462,color,Farbe,color,couleur,194389
63,P3027,open period from,geöffnet von Zeitpunkt,abierto desde,début de la période d'ouverture,16
137,P8345,media franchise,Medien-Franchise,franquicia de medios,franchise médiatique,27415
218,P8852,facial hair,Gesichtshaar,vello facial,pilosité faciale,362
213,P1909,side effect,Nebenwirkung,efecto secundario,effet secondaire,40
118,P6271,demonym of,Demonym zu,gentilicio de,gentilé de,2629
411,P2550,recording or performance of,Aufnahme oder Ausführung von,grabación o ejecución de,enregistrement ou interprétation de,14735
456,P21,sex or gender,Geschlecht,sexo o género,sexe ou genre,7855753
432,P193,main building contractor,Generalbauunternehmer,constructor,maître d'œuvre,2893
740,P1889,different from,verschieden von,diferente de,à ne pas confondre avec,797811


In [6]:
# LOADING
# Multilingual BERT: tokenizer (fast implementation) and masked-LM model are
# both restored from the same 'bert-base-multilingual-cased' checkpoint.
tokenizer = BertTokenizerFast.from_pretrained(
    'bert-base-multilingual-cased',
)
model = BertForMaskedLM.from_pretrained(
    "bert-base-multilingual-cased",
)

# Collators handed to the trainer:
#   - training: masked-language-modelling collator with mlm_probability=0.15
#   - evaluation: same collator class with mlm=False
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm_probability=0.15,
)
eval_data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False,
)

Some weights of the model checkpoint at bert-base-multilingual-cased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [7]:
# ~~ PRE-PROCESSING ~~
# Wrap each split in a one-column ('sample') mapping so it can be turned into a
# Hugging Face Dataset. The test split is deep-copied before being flattened,
# so `test` itself is not handed to flatten_dict2_list.
train_dict = dict(sample=train)
test_dict = dict(sample=flatten_dict2_list(copy.deepcopy(test)))

train_ds = Dataset.from_dict(train_dict)
test_ds = Dataset.from_dict(test_dict)

# Tokenize both splits, train first and then test (statement order as in the original cell).
# Original author's note: the training data is shuffled by Hugging Face.
tokenized_train = tokenize(tokenizer, train_ds)
tokenized_test = tokenize(tokenizer, test_ds)

  0%|          | 0/29 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

In [8]:
# Save Train and Test Data
# Three views of the data are written below ./output/<run_name>/data/:
#   - train_set          (CSV)  : the training samples
#   - test_set_complete  (JSON) : the test structure as returned by generate_reasoning
#   - test_set           (CSV)  : the flattened test samples that are actually tokenized
train_df = pd.DataFrame(train_dict)
test_complete_df = pd.DataFrame(test)
test_flat_df = pd.DataFrame(test_dict)

data_dir = './output/' + run_name + '/data/'
# exist_ok=True makes the cell safely re-runnable and avoids the
# check-then-create race of a separate os.path.exists() test.
os.makedirs(data_dir, exist_ok=True)

train_df.to_csv(data_dir + 'train_set', index=False)
test_complete_df.to_json(data_dir + 'test_set_complete')
test_flat_df.to_csv(data_dir + 'test_set', index=False)

In [9]:
# Training configuration. Model checkpoints and TensorBoard-style logs go below
# ./output/<run_name>/. Logging, evaluation and checkpointing all run once per
# epoch; at most two checkpoints are kept, and the checkpoint with the best
# 'accuracy' is reloaded when training ends.
# NOTE(review): the evaluation set passed below is the test split, so the
# "best" checkpoint is selected on the same data that is reported as test accuracy.
training_args = TrainingArguments(
    seed=42,
    learning_rate=lr,
    num_train_epochs=epochs,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=128,
    output_dir=f'./output/{run_name}/models/',
    logging_dir=f'./output/{run_name}/tb_logs/',
    logging_strategy=IntervalStrategy.EPOCH,
    evaluation_strategy=IntervalStrategy.EPOCH,
    save_strategy=IntervalStrategy.EPOCH,
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model='accuracy',
)

# Project-specific trainer: in addition to the usual Trainer arguments it takes
# a separate collator for evaluation (eval_data_collator).
trainer = CustomTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_test,
    data_collator=data_collator,
    eval_data_collator=eval_data_collator,
    compute_metrics=precision_at_one,
)


Using the `WAND_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


In [10]:
# Train
# Runs the full training loop configured in `training_args` (num_train_epochs=epochs);
# evaluation, logging and checkpoint saving happen once per epoch
# (IntervalStrategy.EPOCH). The returned TrainOutput is shown as the cell output.
trainer.train()

***** Running training *****
  Num examples = 29000
  Num Epochs = 200
  Instantaneous batch size per device = 256
  Total train batch size (w. parallel, distributed & accumulation) = 512
  Gradient Accumulation steps = 1
  Total optimization steps = 11400


Epoch,Training Loss,Validation Loss,Accuracy
1,4.3672,8.52042,0.0
2,3.333,7.969775,0.001
3,3.1855,7.787424,0.005
4,3.1708,7.638084,0.002
5,3.0274,7.538991,0.003
6,2.8838,7.293577,0.006
7,2.7869,7.135616,0.008
8,2.6944,6.910761,0.006
9,2.5135,6.946599,0.008
10,2.4841,6.712662,0.009


Saving model checkpoint to ./output/CompositionDefault/models/checkpoint-57
Configuration saved in ./output/CompositionDefault/models/checkpoint-57/config.json
Model weights saved in ./output/CompositionDefault/models/checkpoint-57/pytorch_model.bin
tokenizer config file saved in ./output/CompositionDefault/models/checkpoint-57/tokenizer_config.json
Special tokens file saved in ./output/CompositionDefault/models/checkpoint-57/special_tokens_map.json
Saving model checkpoint to ./output/CompositionDefault/models/checkpoint-114
Configuration saved in ./output/CompositionDefault/models/checkpoint-114/config.json
Model weights saved in ./output/CompositionDefault/models/checkpoint-114/pytorch_model.bin
tokenizer config file saved in ./output/CompositionDefault/models/checkpoint-114/tokenizer_config.json
Special tokens file saved in ./output/CompositionDefault/models/checkpoint-114/special_tokens_map.json
Saving model checkpoint to ./output/CompositionDefault/models/checkpoint-171
Configurat

Saving model checkpoint to ./output/CompositionDefault/models/checkpoint-627
Configuration saved in ./output/CompositionDefault/models/checkpoint-627/config.json
Model weights saved in ./output/CompositionDefault/models/checkpoint-627/pytorch_model.bin
tokenizer config file saved in ./output/CompositionDefault/models/checkpoint-627/tokenizer_config.json
Special tokens file saved in ./output/CompositionDefault/models/checkpoint-627/special_tokens_map.json
Deleting older checkpoint [output/CompositionDefault/models/checkpoint-513] due to args.save_total_limit
Saving model checkpoint to ./output/CompositionDefault/models/checkpoint-684
Configuration saved in ./output/CompositionDefault/models/checkpoint-684/config.json
Model weights saved in ./output/CompositionDefault/models/checkpoint-684/pytorch_model.bin
tokenizer config file saved in ./output/CompositionDefault/models/checkpoint-684/tokenizer_config.json
Special tokens file saved in ./output/CompositionDefault/models/checkpoint-684/s

Deleting older checkpoint [output/CompositionDefault/models/checkpoint-1083] due to args.save_total_limit
Saving model checkpoint to ./output/CompositionDefault/models/checkpoint-1197
Configuration saved in ./output/CompositionDefault/models/checkpoint-1197/config.json
Model weights saved in ./output/CompositionDefault/models/checkpoint-1197/pytorch_model.bin
tokenizer config file saved in ./output/CompositionDefault/models/checkpoint-1197/tokenizer_config.json
Special tokens file saved in ./output/CompositionDefault/models/checkpoint-1197/special_tokens_map.json
Deleting older checkpoint [output/CompositionDefault/models/checkpoint-1140] due to args.save_total_limit
Saving model checkpoint to ./output/CompositionDefault/models/checkpoint-1254
Configuration saved in ./output/CompositionDefault/models/checkpoint-1254/config.json
Model weights saved in ./output/CompositionDefault/models/checkpoint-1254/pytorch_model.bin
tokenizer config file saved in ./output/CompositionDefault/models/ch

Saving model checkpoint to ./output/CompositionDefault/models/checkpoint-1710
Configuration saved in ./output/CompositionDefault/models/checkpoint-1710/config.json
Model weights saved in ./output/CompositionDefault/models/checkpoint-1710/pytorch_model.bin
tokenizer config file saved in ./output/CompositionDefault/models/checkpoint-1710/tokenizer_config.json
Special tokens file saved in ./output/CompositionDefault/models/checkpoint-1710/special_tokens_map.json
Deleting older checkpoint [output/CompositionDefault/models/checkpoint-1653] due to args.save_total_limit
Saving model checkpoint to ./output/CompositionDefault/models/checkpoint-1767
Configuration saved in ./output/CompositionDefault/models/checkpoint-1767/config.json
Model weights saved in ./output/CompositionDefault/models/checkpoint-1767/pytorch_model.bin
tokenizer config file saved in ./output/CompositionDefault/models/checkpoint-1767/tokenizer_config.json
Special tokens file saved in ./output/CompositionDefault/models/checkp

Deleting older checkpoint [output/CompositionDefault/models/checkpoint-2166] due to args.save_total_limit
Saving model checkpoint to ./output/CompositionDefault/models/checkpoint-2280
Configuration saved in ./output/CompositionDefault/models/checkpoint-2280/config.json
Model weights saved in ./output/CompositionDefault/models/checkpoint-2280/pytorch_model.bin
tokenizer config file saved in ./output/CompositionDefault/models/checkpoint-2280/tokenizer_config.json
Special tokens file saved in ./output/CompositionDefault/models/checkpoint-2280/special_tokens_map.json
Deleting older checkpoint [output/CompositionDefault/models/checkpoint-2223] due to args.save_total_limit
Saving model checkpoint to ./output/CompositionDefault/models/checkpoint-2337
Configuration saved in ./output/CompositionDefault/models/checkpoint-2337/config.json
Model weights saved in ./output/CompositionDefault/models/checkpoint-2337/pytorch_model.bin
tokenizer config file saved in ./output/CompositionDefault/models/ch

Saving model checkpoint to ./output/CompositionDefault/models/checkpoint-2793
Configuration saved in ./output/CompositionDefault/models/checkpoint-2793/config.json
Model weights saved in ./output/CompositionDefault/models/checkpoint-2793/pytorch_model.bin
tokenizer config file saved in ./output/CompositionDefault/models/checkpoint-2793/tokenizer_config.json
Special tokens file saved in ./output/CompositionDefault/models/checkpoint-2793/special_tokens_map.json
Deleting older checkpoint [output/CompositionDefault/models/checkpoint-2679] due to args.save_total_limit
Saving model checkpoint to ./output/CompositionDefault/models/checkpoint-2850
Configuration saved in ./output/CompositionDefault/models/checkpoint-2850/config.json
Model weights saved in ./output/CompositionDefault/models/checkpoint-2850/pytorch_model.bin
tokenizer config file saved in ./output/CompositionDefault/models/checkpoint-2850/tokenizer_config.json
Special tokens file saved in ./output/CompositionDefault/models/checkp

Deleting older checkpoint [output/CompositionDefault/models/checkpoint-3135] due to args.save_total_limit
Saving model checkpoint to ./output/CompositionDefault/models/checkpoint-3363
Configuration saved in ./output/CompositionDefault/models/checkpoint-3363/config.json
Model weights saved in ./output/CompositionDefault/models/checkpoint-3363/pytorch_model.bin
tokenizer config file saved in ./output/CompositionDefault/models/checkpoint-3363/tokenizer_config.json
Special tokens file saved in ./output/CompositionDefault/models/checkpoint-3363/special_tokens_map.json
Deleting older checkpoint [output/CompositionDefault/models/checkpoint-3249] due to args.save_total_limit
Saving model checkpoint to ./output/CompositionDefault/models/checkpoint-3420
Configuration saved in ./output/CompositionDefault/models/checkpoint-3420/config.json
Model weights saved in ./output/CompositionDefault/models/checkpoint-3420/pytorch_model.bin
tokenizer config file saved in ./output/CompositionDefault/models/ch

Saving model checkpoint to ./output/CompositionDefault/models/checkpoint-3876
Configuration saved in ./output/CompositionDefault/models/checkpoint-3876/config.json
Model weights saved in ./output/CompositionDefault/models/checkpoint-3876/pytorch_model.bin
tokenizer config file saved in ./output/CompositionDefault/models/checkpoint-3876/tokenizer_config.json
Special tokens file saved in ./output/CompositionDefault/models/checkpoint-3876/special_tokens_map.json
Deleting older checkpoint [output/CompositionDefault/models/checkpoint-3762] due to args.save_total_limit
Saving model checkpoint to ./output/CompositionDefault/models/checkpoint-3933
Configuration saved in ./output/CompositionDefault/models/checkpoint-3933/config.json
Model weights saved in ./output/CompositionDefault/models/checkpoint-3933/pytorch_model.bin
tokenizer config file saved in ./output/CompositionDefault/models/checkpoint-3933/tokenizer_config.json
Special tokens file saved in ./output/CompositionDefault/models/checkp

Deleting older checkpoint [output/CompositionDefault/models/checkpoint-4161] due to args.save_total_limit
Saving model checkpoint to ./output/CompositionDefault/models/checkpoint-4446
Configuration saved in ./output/CompositionDefault/models/checkpoint-4446/config.json
Model weights saved in ./output/CompositionDefault/models/checkpoint-4446/pytorch_model.bin
tokenizer config file saved in ./output/CompositionDefault/models/checkpoint-4446/tokenizer_config.json
Special tokens file saved in ./output/CompositionDefault/models/checkpoint-4446/special_tokens_map.json
Deleting older checkpoint [output/CompositionDefault/models/checkpoint-4332] due to args.save_total_limit
Saving model checkpoint to ./output/CompositionDefault/models/checkpoint-4503
Configuration saved in ./output/CompositionDefault/models/checkpoint-4503/config.json
Model weights saved in ./output/CompositionDefault/models/checkpoint-4503/pytorch_model.bin
tokenizer config file saved in ./output/CompositionDefault/models/ch

Saving model checkpoint to ./output/CompositionDefault/models/checkpoint-4959
Configuration saved in ./output/CompositionDefault/models/checkpoint-4959/config.json
Model weights saved in ./output/CompositionDefault/models/checkpoint-4959/pytorch_model.bin
tokenizer config file saved in ./output/CompositionDefault/models/checkpoint-4959/tokenizer_config.json
Special tokens file saved in ./output/CompositionDefault/models/checkpoint-4959/special_tokens_map.json
Deleting older checkpoint [output/CompositionDefault/models/checkpoint-4845] due to args.save_total_limit
Saving model checkpoint to ./output/CompositionDefault/models/checkpoint-5016
Configuration saved in ./output/CompositionDefault/models/checkpoint-5016/config.json
Model weights saved in ./output/CompositionDefault/models/checkpoint-5016/pytorch_model.bin
tokenizer config file saved in ./output/CompositionDefault/models/checkpoint-5016/tokenizer_config.json
Special tokens file saved in ./output/CompositionDefault/models/checkp

Deleting older checkpoint [output/CompositionDefault/models/checkpoint-5415] due to args.save_total_limit
Saving model checkpoint to ./output/CompositionDefault/models/checkpoint-5529
Configuration saved in ./output/CompositionDefault/models/checkpoint-5529/config.json
Model weights saved in ./output/CompositionDefault/models/checkpoint-5529/pytorch_model.bin
tokenizer config file saved in ./output/CompositionDefault/models/checkpoint-5529/tokenizer_config.json
Special tokens file saved in ./output/CompositionDefault/models/checkpoint-5529/special_tokens_map.json
Deleting older checkpoint [output/CompositionDefault/models/checkpoint-5472] due to args.save_total_limit
Saving model checkpoint to ./output/CompositionDefault/models/checkpoint-5586
Configuration saved in ./output/CompositionDefault/models/checkpoint-5586/config.json
Model weights saved in ./output/CompositionDefault/models/checkpoint-5586/pytorch_model.bin
tokenizer config file saved in ./output/CompositionDefault/models/ch

Saving model checkpoint to ./output/CompositionDefault/models/checkpoint-6042
Configuration saved in ./output/CompositionDefault/models/checkpoint-6042/config.json
Model weights saved in ./output/CompositionDefault/models/checkpoint-6042/pytorch_model.bin
tokenizer config file saved in ./output/CompositionDefault/models/checkpoint-6042/tokenizer_config.json
Special tokens file saved in ./output/CompositionDefault/models/checkpoint-6042/special_tokens_map.json
Deleting older checkpoint [output/CompositionDefault/models/checkpoint-5928] due to args.save_total_limit
Saving model checkpoint to ./output/CompositionDefault/models/checkpoint-6099
Configuration saved in ./output/CompositionDefault/models/checkpoint-6099/config.json
Model weights saved in ./output/CompositionDefault/models/checkpoint-6099/pytorch_model.bin
tokenizer config file saved in ./output/CompositionDefault/models/checkpoint-6099/tokenizer_config.json
Special tokens file saved in ./output/CompositionDefault/models/checkp

Deleting older checkpoint [output/CompositionDefault/models/checkpoint-6498] due to args.save_total_limit
Saving model checkpoint to ./output/CompositionDefault/models/checkpoint-6612
Configuration saved in ./output/CompositionDefault/models/checkpoint-6612/config.json
Model weights saved in ./output/CompositionDefault/models/checkpoint-6612/pytorch_model.bin
tokenizer config file saved in ./output/CompositionDefault/models/checkpoint-6612/tokenizer_config.json
Special tokens file saved in ./output/CompositionDefault/models/checkpoint-6612/special_tokens_map.json
Deleting older checkpoint [output/CompositionDefault/models/checkpoint-6555] due to args.save_total_limit
Saving model checkpoint to ./output/CompositionDefault/models/checkpoint-6669
Configuration saved in ./output/CompositionDefault/models/checkpoint-6669/config.json
Model weights saved in ./output/CompositionDefault/models/checkpoint-6669/pytorch_model.bin
tokenizer config file saved in ./output/CompositionDefault/models/ch

Saving model checkpoint to ./output/CompositionDefault/models/checkpoint-7125
Configuration saved in ./output/CompositionDefault/models/checkpoint-7125/config.json
Model weights saved in ./output/CompositionDefault/models/checkpoint-7125/pytorch_model.bin
tokenizer config file saved in ./output/CompositionDefault/models/checkpoint-7125/tokenizer_config.json
Special tokens file saved in ./output/CompositionDefault/models/checkpoint-7125/special_tokens_map.json
Deleting older checkpoint [output/CompositionDefault/models/checkpoint-7068] due to args.save_total_limit
Saving model checkpoint to ./output/CompositionDefault/models/checkpoint-7182
Configuration saved in ./output/CompositionDefault/models/checkpoint-7182/config.json
Model weights saved in ./output/CompositionDefault/models/checkpoint-7182/pytorch_model.bin
tokenizer config file saved in ./output/CompositionDefault/models/checkpoint-7182/tokenizer_config.json
Special tokens file saved in ./output/CompositionDefault/models/checkp

Deleting older checkpoint [output/CompositionDefault/models/checkpoint-7581] due to args.save_total_limit
Saving model checkpoint to ./output/CompositionDefault/models/checkpoint-7695
Configuration saved in ./output/CompositionDefault/models/checkpoint-7695/config.json
Model weights saved in ./output/CompositionDefault/models/checkpoint-7695/pytorch_model.bin
tokenizer config file saved in ./output/CompositionDefault/models/checkpoint-7695/tokenizer_config.json
Special tokens file saved in ./output/CompositionDefault/models/checkpoint-7695/special_tokens_map.json
Deleting older checkpoint [output/CompositionDefault/models/checkpoint-7638] due to args.save_total_limit
Saving model checkpoint to ./output/CompositionDefault/models/checkpoint-7752
Configuration saved in ./output/CompositionDefault/models/checkpoint-7752/config.json
Model weights saved in ./output/CompositionDefault/models/checkpoint-7752/pytorch_model.bin
tokenizer config file saved in ./output/CompositionDefault/models/ch

Saving model checkpoint to ./output/CompositionDefault/models/checkpoint-8208
Configuration saved in ./output/CompositionDefault/models/checkpoint-8208/config.json
Model weights saved in ./output/CompositionDefault/models/checkpoint-8208/pytorch_model.bin
tokenizer config file saved in ./output/CompositionDefault/models/checkpoint-8208/tokenizer_config.json
Special tokens file saved in ./output/CompositionDefault/models/checkpoint-8208/special_tokens_map.json
Deleting older checkpoint [output/CompositionDefault/models/checkpoint-8151] due to args.save_total_limit
Saving model checkpoint to ./output/CompositionDefault/models/checkpoint-8265
Configuration saved in ./output/CompositionDefault/models/checkpoint-8265/config.json
Model weights saved in ./output/CompositionDefault/models/checkpoint-8265/pytorch_model.bin
tokenizer config file saved in ./output/CompositionDefault/models/checkpoint-8265/tokenizer_config.json
Special tokens file saved in ./output/CompositionDefault/models/checkp

Deleting older checkpoint [output/CompositionDefault/models/checkpoint-8664] due to args.save_total_limit
Saving model checkpoint to ./output/CompositionDefault/models/checkpoint-8778
Configuration saved in ./output/CompositionDefault/models/checkpoint-8778/config.json
Model weights saved in ./output/CompositionDefault/models/checkpoint-8778/pytorch_model.bin
tokenizer config file saved in ./output/CompositionDefault/models/checkpoint-8778/tokenizer_config.json
Special tokens file saved in ./output/CompositionDefault/models/checkpoint-8778/special_tokens_map.json
Deleting older checkpoint [output/CompositionDefault/models/checkpoint-8721] due to args.save_total_limit
Saving model checkpoint to ./output/CompositionDefault/models/checkpoint-8835
Configuration saved in ./output/CompositionDefault/models/checkpoint-8835/config.json
Model weights saved in ./output/CompositionDefault/models/checkpoint-8835/pytorch_model.bin
tokenizer config file saved in ./output/CompositionDefault/models/ch

Saving model checkpoint to ./output/CompositionDefault/models/checkpoint-9291
Configuration saved in ./output/CompositionDefault/models/checkpoint-9291/config.json
Model weights saved in ./output/CompositionDefault/models/checkpoint-9291/pytorch_model.bin
tokenizer config file saved in ./output/CompositionDefault/models/checkpoint-9291/tokenizer_config.json
Special tokens file saved in ./output/CompositionDefault/models/checkpoint-9291/special_tokens_map.json
Deleting older checkpoint [output/CompositionDefault/models/checkpoint-9234] due to args.save_total_limit
Saving model checkpoint to ./output/CompositionDefault/models/checkpoint-9348
Configuration saved in ./output/CompositionDefault/models/checkpoint-9348/config.json
Model weights saved in ./output/CompositionDefault/models/checkpoint-9348/pytorch_model.bin
tokenizer config file saved in ./output/CompositionDefault/models/checkpoint-9348/tokenizer_config.json
Special tokens file saved in ./output/CompositionDefault/models/checkp

Deleting older checkpoint [output/CompositionDefault/models/checkpoint-9747] due to args.save_total_limit
Saving model checkpoint to ./output/CompositionDefault/models/checkpoint-9861
Configuration saved in ./output/CompositionDefault/models/checkpoint-9861/config.json
Model weights saved in ./output/CompositionDefault/models/checkpoint-9861/pytorch_model.bin
tokenizer config file saved in ./output/CompositionDefault/models/checkpoint-9861/tokenizer_config.json
Special tokens file saved in ./output/CompositionDefault/models/checkpoint-9861/special_tokens_map.json
Deleting older checkpoint [output/CompositionDefault/models/checkpoint-9804] due to args.save_total_limit
Saving model checkpoint to ./output/CompositionDefault/models/checkpoint-9918
Configuration saved in ./output/CompositionDefault/models/checkpoint-9918/config.json
Model weights saved in ./output/CompositionDefault/models/checkpoint-9918/pytorch_model.bin
tokenizer config file saved in ./output/CompositionDefault/models/ch

Saving model checkpoint to ./output/CompositionDefault/models/checkpoint-10374
Configuration saved in ./output/CompositionDefault/models/checkpoint-10374/config.json
Model weights saved in ./output/CompositionDefault/models/checkpoint-10374/pytorch_model.bin
tokenizer config file saved in ./output/CompositionDefault/models/checkpoint-10374/tokenizer_config.json
Special tokens file saved in ./output/CompositionDefault/models/checkpoint-10374/special_tokens_map.json
Deleting older checkpoint [output/CompositionDefault/models/checkpoint-10317] due to args.save_total_limit
Saving model checkpoint to ./output/CompositionDefault/models/checkpoint-10431
Configuration saved in ./output/CompositionDefault/models/checkpoint-10431/config.json
Model weights saved in ./output/CompositionDefault/models/checkpoint-10431/pytorch_model.bin
tokenizer config file saved in ./output/CompositionDefault/models/checkpoint-10431/tokenizer_config.json
Special tokens file saved in ./output/CompositionDefault/mod

Special tokens file saved in ./output/CompositionDefault/models/checkpoint-10887/special_tokens_map.json
Deleting older checkpoint [output/CompositionDefault/models/checkpoint-10830] due to args.save_total_limit
Saving model checkpoint to ./output/CompositionDefault/models/checkpoint-10944
Configuration saved in ./output/CompositionDefault/models/checkpoint-10944/config.json
Model weights saved in ./output/CompositionDefault/models/checkpoint-10944/pytorch_model.bin
tokenizer config file saved in ./output/CompositionDefault/models/checkpoint-10944/tokenizer_config.json
Special tokens file saved in ./output/CompositionDefault/models/checkpoint-10944/special_tokens_map.json
Deleting older checkpoint [output/CompositionDefault/models/checkpoint-10887] due to args.save_total_limit
Saving model checkpoint to ./output/CompositionDefault/models/checkpoint-11001
Configuration saved in ./output/CompositionDefault/models/checkpoint-11001/config.json
Model weights saved in ./output/CompositionDef

TrainOutput(global_step=11400, training_loss=1.0165458928493032, metrics={'train_runtime': 10805.7205, 'train_samples_per_second': 536.753, 'train_steps_per_second': 1.055, 'total_flos': 2.686243581e+16, 'train_loss': 1.0165458928493032, 'epoch': 200.0})

In [11]:
# Evaluate Test
# Scores the model currently held by the trainer on the tokenized test split and
# displays the resulting metrics dict.
# NOTE(review): training_args sets load_best_model_at_end=True with
# metric_for_best_model='accuracy', so this is presumably the best checkpoint as
# measured on this same split — confirm.
trainer.evaluate(eval_dataset=tokenized_test)



{'eval_accuracy': 0.224,
 'eval_loss': 5.365607261657715,
 'eval_runtime': 1.6513,
 'eval_samples_per_second': 605.571,
 'eval_steps_per_second': 2.422,
 'epoch': 200.0}

In [17]:
# Evaluation of composition per relation
# A deep copy of `test` is passed in, so the helper works on its own copy
# (presumably because it modifies its input — TODO confirm).
evaluation_composition(trainer, tokenizer, relations, source_language, copy.deepcopy(test))

Relation - source: taxon rank Relation - target: taxonomischer Rang
Relation - source: emergency services Relation - target: Notfalleinrichtungen
Relation - source: color Relation - target: Farbe


  0%|          | 0/1 [00:00<?, ?ba/s]



{'eval_accuracy': 0.11, 'eval_loss': 3.7856783866882324, 'eval_runtime': 0.8323, 'eval_samples_per_second': 120.15, 'eval_steps_per_second': 1.201}
Relation - source: expected completeness Relation - target: erwartete Vollständigkeit
Relation - source: measured physical quantity Relation - target: gemessene physikalische Größe
Relation - source: open period from Relation - target: geöffnet von Zeitpunkt


  0%|          | 0/1 [00:00<?, ?ba/s]



{'eval_accuracy': 0.06, 'eval_loss': 6.303734302520752, 'eval_runtime': 0.6847, 'eval_samples_per_second': 146.048, 'eval_steps_per_second': 1.46}
Relation - source: platform Relation - target: Plattform
Relation - source: legislative committee Relation - target: Legislativkomitee
Relation - source: media franchise Relation - target: Medien-Franchise


  0%|          | 0/1 [00:00<?, ?ba/s]



{'eval_accuracy': 0.0, 'eval_loss': 12.631064414978027, 'eval_runtime': 0.6724, 'eval_samples_per_second': 148.719, 'eval_steps_per_second': 1.487}
Relation - source: type of lens Relation - target: Linsentyp
Relation - source: natural reservoir of Relation - target: Erregerreservoir von
Relation - source: facial hair Relation - target: Gesichtshaar


  0%|          | 0/1 [00:00<?, ?ba/s]



{'eval_accuracy': 0.04, 'eval_loss': 7.2143940925598145, 'eval_runtime': 0.6595, 'eval_samples_per_second': 151.623, 'eval_steps_per_second': 1.516}
Relation - source: reply to Relation - target: Antwort auf
Relation - source: list of works Relation - target: Werkliste
Relation - source: side effect Relation - target: Nebenwirkung


  0%|          | 0/1 [00:00<?, ?ba/s]



{'eval_accuracy': 0.04, 'eval_loss': 5.758947372436523, 'eval_runtime': 0.6251, 'eval_samples_per_second': 159.962, 'eval_steps_per_second': 1.6}
Relation - source: physically interacts with Relation - target: interagiert physikalisch mit
Relation - source: points/goal scored by Relation - target: Punkt/Treffer erzielt durch
Relation - source: demonym of Relation - target: Demonym zu


  0%|          | 0/1 [00:00<?, ?ba/s]



{'eval_accuracy': 0.27, 'eval_loss': 3.7629101276397705, 'eval_runtime': 0.6228, 'eval_samples_per_second': 160.556, 'eval_steps_per_second': 1.606}
Relation - source: conflict Relation - target: Kriegseinsatz
Relation - source: culture Relation - target: Kultur
Relation - source: recording or performance of Relation - target: Aufnahme oder Ausführung von


  0%|          | 0/1 [00:00<?, ?ba/s]



{'eval_accuracy': 0.87, 'eval_loss': 0.39176636934280396, 'eval_runtime': 0.6612, 'eval_samples_per_second': 151.238, 'eval_steps_per_second': 1.512}
Relation - source: school district Relation - target: Schulbezirk
Relation - source: vehicle Relation - target: Fahrzeug
Relation - source: sex or gender Relation - target: Geschlecht


  0%|          | 0/1 [00:00<?, ?ba/s]



{'eval_accuracy': 0.0, 'eval_loss': 8.432419776916504, 'eval_runtime': 0.613, 'eval_samples_per_second': 163.143, 'eval_steps_per_second': 1.631}
Relation - source: designated as terrorist by Relation - target: als terroristisch eingestuft von
Relation - source: typically sells Relation - target: verkauft im Allgemeinen
Relation - source: main building contractor Relation - target: Generalbauunternehmer


  0%|          | 0/1 [00:00<?, ?ba/s]



{'eval_accuracy': 0.67, 'eval_loss': 1.1832425594329834, 'eval_runtime': 0.6232, 'eval_samples_per_second': 160.475, 'eval_steps_per_second': 1.605}
Relation - source: site of astronomical discovery Relation - target: astronomischer Entdeckungsort
Relation - source: primary destinations Relation - target: Hauptorte
Relation - source: different from Relation - target: verschieden von


  0%|          | 0/1 [00:00<?, ?ba/s]



{'eval_accuracy': 0.18, 'eval_loss': 4.191916465759277, 'eval_runtime': 0.637, 'eval_samples_per_second': 156.995, 'eval_steps_per_second': 1.57}


#### Evaluate
- Why does it not learn compositions?
- pretrained
- target

In [18]:
# Move the model to the CPU and switch it to evaluation mode for the manual probes below.
# Both calls return the module itself, so the cell still displays the model repr.
model.to('cpu').eval()

BertForMaskedLM(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(119547, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=T

In [19]:
# Dump the first 1901 training samples as one plain-text list.
# NOTE(review): 1901 presumably covers the first relation's 1900-sample block plus the
# first sample of the next block -- confirm that the extra element is intentional.
print(train_dict['sample'][:1901])

['Medina taxon rank Terminal', 'Terminal emergency services Malone', 'Medina color Malone', 'Invasion taxon rank EV', 'EV emergency services Soccerway', 'Invasion color Soccerway', 'Burke taxon rank Amadeus', 'Amadeus emergency services President', 'Burke color President', 'Drama taxon rank Simpson', 'Simpson emergency services Mitch', 'Drama color Mitch', 'Master taxon rank Jang', 'Jang emergency services Hütte', 'Master color Hütte', 'Dari taxon rank Kosmos', 'Kosmos emergency services Bohemia', 'Dari color Bohemia', 'Chihuahua taxon rank Lyman', 'Lyman emergency services Ravenna', 'Chihuahua color Ravenna', 'EP taxon rank Lakes', 'Lakes emergency services Shri', 'EP color Shri', 'Chase taxon rank Seite', 'Seite emergency services Delgado', 'Chase color Delgado', 'Worcester taxon rank Antoinette', 'Antoinette emergency services Midway', 'Worcester color Midway', 'Albert taxon rank Henley', 'Henley emergency services Od', 'Albert color Od', 'Ibiza taxon rank Cruise', 'Cruise emergency

In [20]:
# Display every sample stored under test_dict['sample'] (shown as the cell's last expression).
test_dict['sample']

['Haji Farbe Malone',
 'Yahoo Farbe Soccerway',
 'Stal Farbe President',
 'FC Farbe Mitch',
 'Dad Farbe Hütte',
 'Kenia Farbe Bohemia',
 'CCD Farbe Ravenna',
 'Riau Farbe Shri',
 'Ky Farbe Delgado',
 'Billie Farbe Midway',
 'Elbe Farbe Od',
 'Paraíso Farbe Missouri',
 'TD Farbe Hodges',
 'Luther Farbe Publié',
 'Roi Farbe Gyula',
 'Pole Farbe DSM',
 'Page Farbe Ost',
 'Baron Farbe Candy',
 'Libia Farbe Ted',
 'Cuenca Farbe Eaton',
 'Kálmán Farbe Irving',
 'ET Farbe Palatinat',
 'Mainstream Farbe Sai',
 'Agency Farbe Principal',
 'Mata Farbe Arrow',
 'Mineral Farbe Berg',
 'Figaro Farbe Bora',
 'Trung Farbe Soleil',
 'Sabha Farbe Modena',
 'Guimarães Farbe Odd',
 'Disneyland Farbe Cullen',
 'Wes Farbe Nashville',
 'Gesù Farbe Guadalupe',
 'Cinq Farbe Server',
 'Silla Farbe Catedral',
 'JR Farbe McGill',
 'Bryant Farbe Linden',
 'Klaus Farbe Schlacht',
 'Kleiner Farbe Helmut',
 'Fort Farbe HP',
 'City Farbe Ipswich',
 'Márquez Farbe Den',
 'Yer Farbe Hügel',
 'Bernard Farbe pad',
 'Diaz 

In [21]:
# Show entity triples 900-999.
# NOTE(review): presumably the held-out triples of the first relation block -- the
# evaluation loop reads entities[900:1000] for its first relation; confirm.
entities[900:1000]

[['Haji', 'Terminal', 'Malone'],
 ['Yahoo', 'EV', 'Soccerway'],
 ['Stal', 'Amadeus', 'President'],
 ['FC', 'Simpson', 'Mitch'],
 ['Dad', 'Jang', 'Hütte'],
 ['Kenia', 'Kosmos', 'Bohemia'],
 ['CCD', 'Lyman', 'Ravenna'],
 ['Riau', 'Lakes', 'Shri'],
 ['Ky', 'Seite', 'Delgado'],
 ['Billie', 'Antoinette', 'Midway'],
 ['Elbe', 'Henley', 'Od'],
 ['Paraíso', 'Cruise', 'Missouri'],
 ['TD', 'Ennen', 'Hodges'],
 ['Luther', 'Turm', 'Publié'],
 ['Roi', 'Gloucester', 'Gyula'],
 ['Pole', 'Diana', 'DSM'],
 ['Page', 'Bing', 'Ost'],
 ['Baron', 'Industria', 'Candy'],
 ['Libia', 'President', 'Ted'],
 ['Cuenca', 'Sprint', 'Eaton'],
 ['Kálmán', 'Jess', 'Irving'],
 ['ET', 'Indianapolis', 'Palatinat'],
 ['Mainstream', 'Mouse', 'Sai'],
 ['Agency', 'Hindenburg', 'Principal'],
 ['Mata', 'Shawn', 'Arrow'],
 ['Mineral', 'Ehren', 'Berg'],
 ['Figaro', 'Florida', 'Bora'],
 ['Trung', 'Córdoba', 'Soleil'],
 ['Sabha', 'Otis', 'Modena'],
 ['Guimarães', 'Waves', 'Odd'],
 ['Disneyland', 'Special', 'Cullen'],
 ['Wes', 'Chaco

#### Rule: $(e, r, f) \land (f, s, g) \rightarrow (e, t, g) \;\Rightarrow\; (e, r_{de}, f) \land (f, s_{de}, g) \rightarrow (e, t_{de}, g)$

- Is the composition learned in the source language? (e, t, g)
- Is there a general transfer to the target language? (e, r_de, f) and (f, s_de, g)

Within each relation's block of 1900 training samples, the first 1800 facts train the rule (900 <-> 900).
Samples 1800-1900 of each block are the facts that are used for testing.

In [24]:
def _top1_hit(subject, relation, answer):
    """Check whether the model's top-1 fill for '<subject> <relation> [MASK]' is `answer`.

    Uses the notebook-level `model` and `tokenizer`.

    Args:
        subject: Text placed before the relation in the prompt.
        relation: Relation phrase placed between the subject and the mask.
        answer: Expected object; converted with `tokenizer.convert_tokens_to_ids`.

    Returns:
        True when the highest-scoring token id at the first [MASK] position equals
        the id of `answer`, otherwise False.
    """
    label_token = tokenizer.convert_tokens_to_ids(answer)

    prompt = subject + ' ' + relation + ' [MASK]'
    encoded_input = tokenizer(prompt, return_tensors='pt')

    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        token_logits = model(**encoded_input).logits

    mask_token_index = torch.where(encoded_input["input_ids"] == tokenizer.mask_token_id)[1]
    mask_token_logits = token_logits[0, mask_token_index, :]

    top_1_token = torch.topk(mask_token_logits, 1, dim=1).indices[0].tolist()[0]

    return label_token == top_1_token


# Iterate over relations, take the training samples that were trained on
for i in range(n_relations):
    # Hard-coded layout: a stride of 1900 training samples per relation, of which the
    # slice [1800:1900] is taken here, and a stride of 1000 entity triples per relation,
    # of which the slice [900:1000] is taken.
    trained_test = train_dict['sample'][1800+i*1900:(i+1)*1900]
    entities_sampled = entities[900+i*1000:(i+1)*1000]

    acc_source = 0
    acc_target = 0
    acc_transfer1 = 0
    acc_transfer2 = 0

    # Relation pairs!
    r = relations[0]['en'].iloc[i]
    s = relations[1]['en'].iloc[i]
    t = relations[2]['en'].iloc[i]

    r_de = relations[0]['de'].iloc[i]
    s_de = relations[1]['de'].iloc[i]
    t_de = relations[2]['de'].iloc[i]

    # Only the number of selected samples matters here; the prompts are rebuilt
    # from the entity triples.
    n_evaluated = len(trained_test)

    for j in range(n_evaluated):

        ents = entities_sampled[j]
        e = ents[0]
        f = ents[1]
        g = ents[2]

        # Test (e, t, g): composition in the source language
        acc_source += _top1_hit(e, t, g)

        # Test (e, t_de, g): composition in the target language
        acc_target += _top1_hit(e, t_de, g)

        # Test (e, r_de, f)
        acc_transfer1 += _top1_hit(e, r_de, f)

        # Test (f, s_de, g)
        acc_transfer2 += _top1_hit(f, s_de, g)

    # Normalise by the number of samples actually evaluated (100 with the current
    # slicing); max(..., 1) keeps the empty case at 0.0 instead of raising.
    denominator = max(n_evaluated, 1)
    acc_source /= denominator
    acc_target /= denominator
    acc_transfer1 /= denominator
    acc_transfer2 /= denominator

    print(f'Relation1: {r}')
    print(f'Relation2: {s}')
    print(f'Composition: {t}')
    print(f'Accuracy for learning source rule (e, t, g): {acc_source}')
    print(f'Accuracy for learning target rule (e, t_de, g): {acc_target}')
    print(f'Accuracy for KT2 (e, r_de, f): {acc_transfer1}')
    print(f'Accuracy for KT1 (f, s_de, g): {acc_transfer2}')
    print('\n')

Relation1: taxon rank
Relation2: emergency services
Composition: color
Accuracy for learning source rule (e, t, g): 0.96
Accuracy for learning target rule (e, t_de, g): 0.11
Accuracy for KT2 (e, r_de, f): 0.56
Accuracy for KT1 (f, s_de, g): 0.39


Relation1: expected completeness
Relation2: measured physical quantity
Composition: open period from
Accuracy for learning source rule (e, t, g): 0.96
Accuracy for learning target rule (e, t_de, g): 0.06
Accuracy for KT2 (e, r_de, f): 0.25
Accuracy for KT1 (f, s_de, g): 0.37


Relation1: platform
Relation2: legislative committee
Composition: media franchise
Accuracy for learning source rule (e, t, g): 0.91
Accuracy for learning target rule (e, t_de, g): 0.0
Accuracy for KT2 (e, r_de, f): 0.19
Accuracy for KT1 (f, s_de, g): 0.89


Relation1: type of lens
Relation2: natural reservoir of
Composition: facial hair
Accuracy for learning source rule (e, t, g): 0.97
Accuracy for learning target rule (e, t_de, g): 0.04
Accuracy for KT2 (e, r_de, f): 0

### Manual

In [23]:
# Top-5 accuracy on the training samples: mask the last word of each sample and
# check whether its token id is among the model's five highest-scoring predictions.
k = 0  # samples whose label is in the top 5
total = len(train_dict['sample'])
i = 0  # samples evaluated

for txt in train_dict['sample'][:10000]:
    i += 1

    # Split once into everything before the last word and the last word itself.
    parts = txt.rsplit(' ', 1)

    # Add [MASK] for object
    sample = parts[0] + ' [MASK]'
    label_token = tokenizer.convert_tokens_to_ids(parts[1])

    encoded_input = tokenizer(sample, return_tensors='pt')
    # Inference only: no autograd graph is needed for any of these forward passes.
    with torch.no_grad():
        token_logits = model(**encoded_input).logits

    mask_token_index = torch.where(encoded_input["input_ids"] == tokenizer.mask_token_id)[1]
    mask_token_logits = token_logits[0, mask_token_index, :]

    # Pick the [MASK] candidates with the highest logits
    top_5_tokens = torch.topk(mask_token_logits, 5, dim=1).indices[0].tolist()

    if label_token in top_5_tokens:
        k += 1

print(k/i)

0.9996


In [None]:
# Ad-hoc probe: print the model's five highest-scoring fills for one masked prompt.
text = "lens manner of [MASK]"
encoded_input = tokenizer(text, return_tensors='pt')
# Inference only: skip autograd bookkeeping for the forward pass.
with torch.no_grad():
    token_logits = model(**encoded_input).logits

mask_token_index = torch.where(encoded_input["input_ids"] == tokenizer.mask_token_id)[1]
mask_token_logits = token_logits[0, mask_token_index, :]

# Pick the [MASK] candidates with the highest logits
top_5_tokens = torch.topk(mask_token_logits, 5, dim=1).indices[0].tolist()

for chunk in top_5_tokens:
    print(f"\n>>> {tokenizer.decode([chunk])}")

In [None]:
# Print every training sample whose text contains the substring 'Alex'.
for t in train_dict['sample']:
    if 'Alex' not in t:
        continue
    print(t)

### Results