#### Load Relations and Entities

In [18]:
import pandas as pd
import numpy as np

In [19]:
# Load entities (5500)
entities = pd.read_csv('../data/Entities/SingleToken/entities_languageAgnostic_clean.csv')

# Load Relations (60)
relations = pd.read_csv('../data/Relations/Symmetry/symmetric_multilingual_clean.csv')

In [20]:
relations_general = pd.read_csv('../data/Relations/General/properties_nonsymmetric_multilingual_clean.csv')

#### Prepare Data

In [21]:
import random

# prob of returning true
def decision(probability):
    return random.random() < probability

In [22]:
# Generate random pairs of numbers (indices into entity)
# Order doesn't matter, can't repeat
# i.e. ok is: (0,1), (1,2), (0,2) but not ok is (0,1),(1,0) or (0,0)
# Runs until exhausted or reached max_size
# possible to limit occurences of index
def gen_index_pairs(n, max_size=np.Inf, limit=np.Inf):
    pairs = set()
    ind = list()

    while len(pairs) < max_size:
        # return number between 0 and n (exclude)
        x, y = np.random.randint(n), np.random.randint(n)
        
        while ind.count(x) >= limit or ind.count(y) >= limit:
            x, y = np.random.randint(n), np.random.randint(n)

        ind.append(x)
        ind.append(y)
        
        i = 0
        while (x, y) in pairs or (y, x) in pairs or x == y:
            if i > 10:
                return
            x, y = np.random.randint(n), np.random.randint(n)
            i += 1
        
        pairs.add((x, y))
        yield x, y

In [23]:
n_relations = 10
n_facts = 2000

# (e, r, f ) <=> (f, r, e)
train = []
test = []
ent = []

# Sample relations
relations_sampled = relations.sample(n_relations)

for index, relation in relations_sampled.iterrows():
    
    to_split = []

    # Sample random entities
    entity_generator = gen_index_pairs(entities.shape[0], n_facts, 1)

    for e_id, f_id in entity_generator:
        e = entities['label'][e_id]
        f = entities['label'][f_id]
        
        ent.append(e)

        # Append symmetric relations
        train.append(e + ' ' + relation['en'] + ' ' + f)
        
        to_split.append(f + ' ' + relation['en'] + ' ' + e)
    
    # 90% train, 10% test
    split_pos = int(0.9 * len(to_split))
    
    train = train + to_split[:split_pos]
    test = test + to_split[split_pos:]

In [24]:
# Add non-relation
n_relations = 10
n_facts = 2000

non_rels = []

relations_general_sampled = relations_general.sample(n_relations)

for index, relation in relations_general_sampled.iterrows():

    # Sample random entities
    entity_generator = gen_index_pairs(entities.shape[0], n_facts, 1)

    for e_id, f_id in entity_generator:
        e = entities['label'][e_id]
        f = entities['label'][f_id]

        # Append relations
        train.append(e + ' ' + relation['en'] + ' ' + f)
        non_rels.append(e + ' ' + relation['en'] + ' ' + f)

In [25]:
len(train)

58000

In [26]:
len(test)

2000

In [27]:
test_dict = {'text': test}
train_dict = {'text': train}
train_dict

{'text': ['Richardson adjacent station Diesel',
  'Ranger adjacent station Campaign',
  'Sicily adjacent station CJ',
  'Amb adjacent station Mur',
  'Milo adjacent station Caesar',
  'Oviedo adjacent station Marek',
  'pl adjacent station Greatest',
  'Carlos adjacent station Meuse',
  'da adjacent station Mires',
  'Minister adjacent station Daha',
  'Khmer adjacent station Opera',
  'Gonzaga adjacent station Regional',
  'Dag adjacent station Aki',
  'Yüksek adjacent station Angelina',
  'Pour adjacent station Inn',
  'Pac adjacent station Valencia',
  'Spring adjacent station Erie',
  'Minds adjacent station Beatrix',
  'Camilla adjacent station Riviera',
  'Kane adjacent station Vockeroth',
  'Day adjacent station Monique',
  'Tea adjacent station Linden',
  'Semi adjacent station Nietzsche',
  'Becker adjacent station Gerd',
  'Rock adjacent station Allen',
  'Purcell adjacent station Lara',
  'SL adjacent station Senegal',
  'Neville adjacent station Ferris',
  'Solomon adjacent

### Preprocessing

First, we pad text so they are a uniform length. While it is possible to padtext in the tokenizer function by setting padding=True, it is more efficient to only pad the text to the length of the longest element in its batch. This is known as dynamic padding. You can do this with the DataCollatorWithPadding function:

##### Convert to datasets

In [28]:
from datasets import load_dataset, Dataset

In [29]:
train_ds = Dataset.from_dict(train_dict)
test_ds = Dataset.from_dict(test_dict)

In [30]:
train_ds

Dataset({
    features: ['text'],
    num_rows: 58000
})

##### Load Model

In [31]:
from transformers import BertModel, BertTokenizer, BertTokenizerFast, TrainingArguments, Trainer, DataCollatorWithPadding, BertForMaskedLM


In [32]:
tokenizer = BertTokenizerFast.from_pretrained('bert-base-multilingual-cased')

In [33]:
model = BertForMaskedLM.from_pretrained("bert-base-multilingual-cased")

Some weights of the model checkpoint at bert-base-multilingual-cased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


##### Tokenize

In [34]:
def tokenize_function(examples):
    result = tokenizer(examples["text"])
    return result

In [35]:
# Use batched=True to activate fast multithreading!
tokenized_train_ds = train_ds.map(
    tokenize_function, batched=True, remove_columns=["text"]
)
tokenized_train_ds

  0%|          | 0/58 [00:00<?, ?ba/s]

Dataset({
    features: ['input_ids', 'token_type_ids', 'attention_mask'],
    num_rows: 58000
})

In [36]:
tokenized_train_ds[1]["input_ids"]

[101, 45763, 32018, 11825, 39752, 102]

In [37]:
tokenized_test_ds = test_ds.map(
    tokenize_function, batched=True, remove_columns=["text"]
)

  0%|          | 0/2 [00:00<?, ?ba/s]

### Finetuning

In [40]:
# from datasets import load_metric

# metric = load_metric("accuracy")

# def compute_metrics(eval_pred):
#     logits, labels = eval_pred
#     predictions = np.argmax(logits, axis=-1)
#     return {metric.compute(predictions=predictions, references=labels)}

In [41]:
from transformers import DataCollatorForLanguageModeling

data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15)

In [42]:
samples = [tokenized_train_ds[i] for i in range(4)]

for chunk in data_collator(samples)["input_ids"]:
    print(f"\n'>>> {tokenizer.decode(chunk)}'")


'>>> [CLS] Richardson adjacent station Diesel [SEP]'

'>>> [CLS] Ranger adjacent station [MASK] [SEP]'

'>>> [CLS] Sicily [MASK] station CJ [SEP]'

'>>> [CLS] [MASK] adjacent station Mur [SEP]'


(for 10, 2000 it works too)

70+ epochs

num_train_epochs=1000,
per_device_train_batch_size=128,
per_device_eval_batch_size=128,
learning_rate=5e-5,
logging_strategy='epoch',
evaluation_strategy='epoch'

In [43]:
# Finetune mBERT

training_args = TrainingArguments(
    output_dir="./model/symmetry-english-10-2000-both",
    num_train_epochs=1000,
    per_device_train_batch_size=128,
    per_device_eval_batch_size=128,
    learning_rate=5e-5,
    save_total_limit=2,
    save_strategy='epoch',
    logging_strategy='epoch',
    evaluation_strategy='epoch'
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train_ds,
    eval_dataset=tokenized_test_ds,
    tokenizer=tokenizer,
    data_collator=data_collator
)

In [44]:
trainer.train()

***** Running training *****
  Num examples = 58000
  Num Epochs = 1000
  Instantaneous batch size per device = 128
  Total train batch size (w. parallel, distributed & accumulation) = 256
  Gradient Accumulation steps = 1
  Total optimization steps = 227000


Epoch,Training Loss,Validation Loss
1,3.804,3.431826
2,3.4483,3.427988
3,3.3732,3.235138
4,3.3677,3.242977
5,3.3288,3.388909
6,3.3768,3.219639
7,3.3371,3.255847
8,3.3408,3.272967
9,3.342,3.238272
10,3.3555,3.157695


***** Running Evaluation *****
  Num examples = 2000
  Batch size = 256
Saving model checkpoint to ./model/symmetry-english-10-2000-both/checkpoint-227
Configuration saved in ./model/symmetry-english-10-2000-both/checkpoint-227/config.json
Model weights saved in ./model/symmetry-english-10-2000-both/checkpoint-227/pytorch_model.bin
tokenizer config file saved in ./model/symmetry-english-10-2000-both/checkpoint-227/tokenizer_config.json
Special tokens file saved in ./model/symmetry-english-10-2000-both/checkpoint-227/special_tokens_map.json
***** Running Evaluation *****
  Num examples = 2000
  Batch size = 256
Saving model checkpoint to ./model/symmetry-english-10-2000-both/checkpoint-454
Configuration saved in ./model/symmetry-english-10-2000-both/checkpoint-454/config.json
Model weights saved in ./model/symmetry-english-10-2000-both/checkpoint-454/pytorch_model.bin
tokenizer config file saved in ./model/symmetry-english-10-2000-both/checkpoint-454/tokenizer_config.json
Special tokens

***** Running Evaluation *****
  Num examples = 2000
  Batch size = 256
Saving model checkpoint to ./model/symmetry-english-10-2000-both/checkpoint-2270
Configuration saved in ./model/symmetry-english-10-2000-both/checkpoint-2270/config.json
Model weights saved in ./model/symmetry-english-10-2000-both/checkpoint-2270/pytorch_model.bin
tokenizer config file saved in ./model/symmetry-english-10-2000-both/checkpoint-2270/tokenizer_config.json
Special tokens file saved in ./model/symmetry-english-10-2000-both/checkpoint-2270/special_tokens_map.json
Deleting older checkpoint [model/symmetry-english-10-2000-both/checkpoint-1816] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 2000
  Batch size = 256
Saving model checkpoint to ./model/symmetry-english-10-2000-both/checkpoint-2497
Configuration saved in ./model/symmetry-english-10-2000-both/checkpoint-2497/config.json
Model weights saved in ./model/symmetry-english-10-2000-both/checkpoint-2497/pytorch_model.bin
tok

Deleting older checkpoint [model/symmetry-english-10-2000-both/checkpoint-3632] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 2000
  Batch size = 256
Saving model checkpoint to ./model/symmetry-english-10-2000-both/checkpoint-4313
Configuration saved in ./model/symmetry-english-10-2000-both/checkpoint-4313/config.json
Model weights saved in ./model/symmetry-english-10-2000-both/checkpoint-4313/pytorch_model.bin
tokenizer config file saved in ./model/symmetry-english-10-2000-both/checkpoint-4313/tokenizer_config.json
Special tokens file saved in ./model/symmetry-english-10-2000-both/checkpoint-4313/special_tokens_map.json
Deleting older checkpoint [model/symmetry-english-10-2000-both/checkpoint-3859] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 2000
  Batch size = 256
Saving model checkpoint to ./model/symmetry-english-10-2000-both/checkpoint-4540
Configuration saved in ./model/symmetry-english-10-2000-both/checkpoint-4540/c

Configuration saved in ./model/symmetry-english-10-2000-both/checkpoint-6129/config.json
Model weights saved in ./model/symmetry-english-10-2000-both/checkpoint-6129/pytorch_model.bin
tokenizer config file saved in ./model/symmetry-english-10-2000-both/checkpoint-6129/tokenizer_config.json
Special tokens file saved in ./model/symmetry-english-10-2000-both/checkpoint-6129/special_tokens_map.json
Deleting older checkpoint [model/symmetry-english-10-2000-both/checkpoint-5675] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 2000
  Batch size = 256
Saving model checkpoint to ./model/symmetry-english-10-2000-both/checkpoint-6356
Configuration saved in ./model/symmetry-english-10-2000-both/checkpoint-6356/config.json
Model weights saved in ./model/symmetry-english-10-2000-both/checkpoint-6356/pytorch_model.bin
tokenizer config file saved in ./model/symmetry-english-10-2000-both/checkpoint-6356/tokenizer_config.json
Special tokens file saved in ./model/symmetry-eng

***** Running Evaluation *****
  Num examples = 2000
  Batch size = 256
Saving model checkpoint to ./model/symmetry-english-10-2000-both/checkpoint-8172
Configuration saved in ./model/symmetry-english-10-2000-both/checkpoint-8172/config.json
Model weights saved in ./model/symmetry-english-10-2000-both/checkpoint-8172/pytorch_model.bin
tokenizer config file saved in ./model/symmetry-english-10-2000-both/checkpoint-8172/tokenizer_config.json
Special tokens file saved in ./model/symmetry-english-10-2000-both/checkpoint-8172/special_tokens_map.json
Deleting older checkpoint [model/symmetry-english-10-2000-both/checkpoint-7718] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 2000
  Batch size = 256
Saving model checkpoint to ./model/symmetry-english-10-2000-both/checkpoint-8399
Configuration saved in ./model/symmetry-english-10-2000-both/checkpoint-8399/config.json
Model weights saved in ./model/symmetry-english-10-2000-both/checkpoint-8399/pytorch_model.bin
tok

Deleting older checkpoint [model/symmetry-english-10-2000-both/checkpoint-9534] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 2000
  Batch size = 256
Saving model checkpoint to ./model/symmetry-english-10-2000-both/checkpoint-10215
Configuration saved in ./model/symmetry-english-10-2000-both/checkpoint-10215/config.json
Model weights saved in ./model/symmetry-english-10-2000-both/checkpoint-10215/pytorch_model.bin
tokenizer config file saved in ./model/symmetry-english-10-2000-both/checkpoint-10215/tokenizer_config.json
Special tokens file saved in ./model/symmetry-english-10-2000-both/checkpoint-10215/special_tokens_map.json
Deleting older checkpoint [model/symmetry-english-10-2000-both/checkpoint-9761] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 2000
  Batch size = 256
Saving model checkpoint to ./model/symmetry-english-10-2000-both/checkpoint-10442
Configuration saved in ./model/symmetry-english-10-2000-both/checkpoint-

Configuration saved in ./model/symmetry-english-10-2000-both/checkpoint-12031/config.json
Model weights saved in ./model/symmetry-english-10-2000-both/checkpoint-12031/pytorch_model.bin
tokenizer config file saved in ./model/symmetry-english-10-2000-both/checkpoint-12031/tokenizer_config.json
Special tokens file saved in ./model/symmetry-english-10-2000-both/checkpoint-12031/special_tokens_map.json
Deleting older checkpoint [model/symmetry-english-10-2000-both/checkpoint-11577] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 2000
  Batch size = 256
Saving model checkpoint to ./model/symmetry-english-10-2000-both/checkpoint-12258
Configuration saved in ./model/symmetry-english-10-2000-both/checkpoint-12258/config.json
Model weights saved in ./model/symmetry-english-10-2000-both/checkpoint-12258/pytorch_model.bin
tokenizer config file saved in ./model/symmetry-english-10-2000-both/checkpoint-12258/tokenizer_config.json
Special tokens file saved in ./model/sym

***** Running Evaluation *****
  Num examples = 2000
  Batch size = 256
Saving model checkpoint to ./model/symmetry-english-10-2000-both/checkpoint-14074
Configuration saved in ./model/symmetry-english-10-2000-both/checkpoint-14074/config.json
Model weights saved in ./model/symmetry-english-10-2000-both/checkpoint-14074/pytorch_model.bin
tokenizer config file saved in ./model/symmetry-english-10-2000-both/checkpoint-14074/tokenizer_config.json
Special tokens file saved in ./model/symmetry-english-10-2000-both/checkpoint-14074/special_tokens_map.json
Deleting older checkpoint [model/symmetry-english-10-2000-both/checkpoint-13620] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 2000
  Batch size = 256
Saving model checkpoint to ./model/symmetry-english-10-2000-both/checkpoint-14301
Configuration saved in ./model/symmetry-english-10-2000-both/checkpoint-14301/config.json
Model weights saved in ./model/symmetry-english-10-2000-both/checkpoint-14301/pytorch_mode

KeyboardInterrupt: 

In [113]:
trainer.evaluate(eval_dataset=tokenized_test_ds)

***** Running Evaluation *****
  Num examples = 200
  Batch size = 128


{'eval_loss': 0.7275204658508301}

In [114]:
model.to('cpu')

BertForMaskedLM(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(119547, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=T

In [115]:
model.eval()

BertForMaskedLM(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(119547, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=T

#### Testing

In [116]:
import torch

In [121]:
k = 0
total = len(train_dict['text'])
i = 0

for txt in train_dict['text'][:10000]:
    i += 1
    
    # Add [MASK] for object
    sample = txt.rsplit(' ', 1)[0] + ' [MASK]'
    label_token = tokenizer.convert_tokens_to_ids(txt.rsplit(' ', 1)[1])
    
    encoded_input = tokenizer(sample, return_tensors='pt')
    token_logits = model(**encoded_input).logits
    
    mask_token_index = torch.where(encoded_input["input_ids"] == tokenizer.mask_token_id)[1]
    mask_token_logits = token_logits[0, mask_token_index, :]
    
    # Pick the [MASK] candidates with the highest logits
    top_5_tokens = torch.topk(mask_token_logits, 5, dim=1).indices[0].tolist()
    
    if label_token in top_5_tokens:
        k += 1
        print('i:' + str(i) + ' k:' + str(k))

i:1 k:1
i:2 k:2
i:3 k:3
i:4 k:4
i:5 k:5
i:6 k:6
i:7 k:7
i:8 k:8
i:9 k:9
i:10 k:10
i:11 k:11
i:12 k:12
i:13 k:13
i:14 k:14
i:15 k:15
i:16 k:16
i:17 k:17
i:18 k:18
i:19 k:19
i:20 k:20
i:21 k:21
i:22 k:22
i:23 k:23
i:24 k:24
i:25 k:25
i:26 k:26
i:27 k:27
i:28 k:28
i:29 k:29
i:30 k:30
i:31 k:31
i:32 k:32
i:33 k:33
i:34 k:34
i:35 k:35
i:36 k:36
i:37 k:37
i:38 k:38
i:39 k:39
i:40 k:40
i:41 k:41
i:42 k:42
i:43 k:43
i:44 k:44
i:45 k:45
i:46 k:46
i:47 k:47
i:48 k:48
i:49 k:49
i:50 k:50
i:51 k:51
i:52 k:52
i:53 k:53
i:54 k:54
i:55 k:55
i:56 k:56
i:57 k:57
i:58 k:58
i:59 k:59
i:60 k:60
i:61 k:61
i:62 k:62
i:63 k:63
i:64 k:64
i:65 k:65
i:66 k:66
i:67 k:67
i:68 k:68
i:69 k:69
i:70 k:70
i:71 k:71
i:72 k:72
i:73 k:73
i:74 k:74
i:75 k:75
i:76 k:76
i:77 k:77
i:78 k:78
i:79 k:79
i:80 k:80
i:81 k:81
i:82 k:82
i:83 k:83
i:84 k:84
i:85 k:85
i:86 k:86
i:87 k:87
i:88 k:88
i:89 k:89
i:90 k:90
i:91 k:91
i:92 k:92
i:93 k:93
i:94 k:94
i:95 k:95
i:96 k:96
i:97 k:97
i:98 k:98
i:99 k:99
i:100 k:100
i:101 k:101
i:10

KeyboardInterrupt: 

##### Manual testing

In [117]:
non_rels

['Bangor index case of Friends',
 'Aarhus index case of Gan',
 'Linus index case of Princesa',
 'Valence index case of Ziele',
 'Olga index case of Genoa',
 'Pepper index case of DK',
 'Helena index case of Daimler',
 'Chronicle index case of Zee',
 'Fermi index case of Contra',
 'Ying index case of Abby',
 'Elke index case of Eliza',
 'Nieto index case of Porsche',
 'Herrschaft index case of Ratu',
 'Train index case of Aqua',
 'Freie index case of Fontainebleau',
 'Alabama index case of Papa',
 'Greenwich index case of Crow',
 'Bahasa index case of Ranger',
 'Bildhauer index case of String',
 'Padova index case of Moldavia',
 'Cochrane index case of Clements',
 'Rachel index case of Bruxelles',
 'Clay index case of Panda',
 'Squadron index case of Linné',
 'Bears index case of Hercules',
 'Rollins index case of Rodrigues',
 'Eugenio index case of Export',
 'Lagoa index case of Navarra',
 'Hero index case of Mateo',
 'Modena index case of Gamma',
 'Arnold index case of Brett',
 'Perda

In [123]:
text = "Master index case of [MASK]"
encoded_input = tokenizer(text, return_tensors='pt')
token_logits = model(**encoded_input).logits

mask_token_index = torch.where(encoded_input["input_ids"] == tokenizer.mask_token_id)[1]
mask_token_logits = token_logits[0, mask_token_index, :]

# Pick the [MASK] candidates with the highest logits
top_5_tokens = torch.topk(mask_token_logits, 5, dim=1).indices[0].tolist()

for chunk in top_5_tokens:
    print(f"\n'>>> {tokenizer.decode(chunk)}'")


'>>> Marburg'

'>>> Master'

'>>> Wiesbaden'

'>>> PhD'

'>>> Welles'


In [44]:
for t in train_dict['text']:
    if 'Spa' in t:
        print(t)

Chatham generalization of Spanyol
Spa generalization of Alzheimer
Spaans generalization of Revue
Ethel generalization of Spain
Dead generalization of Sparta
Sparks generalization of Delgado
Frontera generalization of Space
Spanyol generalization of Chatham
Alzheimer generalization of Spa
Revue generalization of Spaans
Spain generalization of Ethel
Sparta generalization of Dead
Delgado generalization of Sparks
Space generalization of Frontera


In [35]:
test_dict

{'text': ['GPL generalization of Parigi',
  'Giants generalization of Alus',
  'Potok generalization of Dorset',
  'Bean generalization of Vidal',
  'Lightning generalization of Anna',
  'Kjell generalization of Bruxelles',
  'Wilkins generalization of Lai',
  'Lily generalization of Viking',
  'Siegen generalization of Sun',
  'NGC generalization of Templo',
  'Halen generalization of Highway',
  'Gia generalization of Baldwin',
  'Venezuela generalization of MX',
  'Corp generalization of Zen',
  'Seneca generalization of Campaign',
  'Britannia generalization of Schweizer',
  'Hamas generalization of Morgan',
  'Taman generalization of Freak',
  'Signal generalization of Devlet',
  'Esso generalization of Wever',
  'Paso generalization of Many',
  'Modena generalization of Lamar',
  'CC generalization of Stranger',
  'Brabant generalization of Lynch',
  'Boxing generalization of Cherbourg',
  'Hispania generalization of Cuban',
  'Carlisle generalization of PT',
  'Fürsten generaliz

In [43]:
ent.index('Spa')

256

In [40]:
text = "Alzheimer generalization of [MASK]"
encoded_input = tokenizer(text, return_tensors='pt')
token_logits = model(**encoded_input).logits

mask_token_index = torch.where(encoded_input["input_ids"] == tokenizer.mask_token_id)[1]
mask_token_logits = token_logits[0, mask_token_index, :]

# Pick the [MASK] candidates with the highest logits
top_5_tokens = torch.topk(mask_token_logits, 5, dim=1).indices[0].tolist()

for chunk in top_5_tokens:
    print(f"\n'>>> {tokenizer.decode(chunk)}'")


'>>> Spa'

'>>> Spain'

'>>> us'

'>>> González'

'>>> SP'
