# Dataset

In [None]:
import locale
def getpreferredencoding(do_setlocale=True):
    """Report UTF-8 as the preferred encoding, whatever the real locale is.

    Installed over ``locale.getpreferredencoding`` right after this
    definition. ``do_setlocale`` is accepted only to keep the replaced
    function's signature and is ignored.
    """
    return "UTF-8"
locale.getpreferredencoding = getpreferredencoding

In [None]:
import torch
from transformers import T5Tokenizer
import numpy as np


  from .autonotebook import tqdm as notebook_tqdm


In [None]:
%%capture

tokenizer = T5Tokenizer.from_pretrained('t5-small', padding=True)

def _tokenize(x):
    """Encode ``x`` with the module-level ``tokenizer`` and drop the last id.

    Returns the first row of ``input_ids`` as a 1-D tensor without its final
    element.
    NOTE(review): the dropped element is presumably the EOS token the T5
    tokenizer appends -- confirm.
    """
    encoded = tokenizer(x, return_tensors="pt")
    first_row = encoded['input_ids'][0]
    return first_row[:-1]

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


## Hop1Index

In [None]:

class Hop1Index:
    """Index a triple table by one entity column for one-hop neighbour lookup.

    The triples are sorted by ``key_col`` once, so all rows that share a key
    are contiguous and a lookup is a single slice.

    Args:
        triples: 2-D integer array/tensor with three columns. Column 1 is the
            one compared against ``rel_id`` on lookup (assumed layout is
            ``(head, relation, tail)`` -- confirm against the data file).
        num_entities: size of the per-entity offset tables; every value in
            ``triples[:, key_col]`` must be a valid index into them.
        key_col: column to index by (``0`` or ``2`` in this file).
        max_context_size: maximum number of rows returned by one lookup;
            larger neighbourhoods are subsampled without replacement. May be
            a float (e.g. ``1e10`` to mean "effectively unlimited").
    """

    def __init__(self, triples, num_entities, key_col=0, max_context_size=64):
        self.max_context_size = max_context_size
        self.shuffle = False
        self.key_col = key_col
        self.triples = triples[triples[:, key_col].argsort()]
        keys, values_offset = np.unique(
            self.triples[:, key_col], axis=0, return_index=True
        )
        # Sentinel so that values_offset[i + 1] is the end of key i's run.
        values_offset = np.append(values_offset, len(self.triples))
        self.keys = keys
        self.values_offset = values_offset

        # Keys that never occur keep start == end == -1, which slices to an
        # empty context in __getitem__.
        self.key_to_start = -1 * np.ones(num_entities, dtype=int)
        self.key_to_start[keys] = values_offset[:-1]
        self.key_to_end = -1 * np.ones(num_entities, dtype=int)
        self.key_to_end[keys] = values_offset[1:]

    def __getitem__(self, item, rel_id=None):
        """Return the neighbourhood of entity ``item``.

        Without ``rel_id`` the result has two columns: column 1 of the triple
        and the entity column opposite to ``key_col``. With ``rel_id`` only
        rows whose column 1 equals ``rel_id`` are kept and the result is the
        1-D vector of opposite-side entity ids.
        """
        start = self.key_to_start[item]
        end = self.key_to_end[item]
        context = self.triples[start:end, [1, 2 - self.key_col]]
        if rel_id is not None:
            context = context[context[:, 0] == rel_id][:, 1]
        if len(context) > self.max_context_size:
            # np.random.choice needs an integer sample size, while callers may
            # configure the limit as a float.
            sample_size = int(self.max_context_size)
            ids = np.random.choice(len(context), sample_size, replace=False)
            context = context[ids]
        if self.shuffle:
            np.random.shuffle(context)
        return context

    def get_context(self, item, rel_id=None):
        """Named alias of ``__getitem__`` that makes passing ``rel_id`` readable."""
        return self.__getitem__(item, rel_id)



## KGCDataset

In [None]:
kgt5_data = torch.load('kg_data.pt')

In [None]:
from numpy import pi

class RotatE:
    """Apply a RotatE-style rotation to a stored entity embedding.

    Embeddings are looked up by id in ``entity_embedding`` and
    ``relation_embedding``. An entity vector is read as a complex vector
    stored as ``[real half, imaginary half]``; only the first half of a
    relation vector is used, as rotation phases.
    """

    def __init__(self, k, max_rel_size=None, entity_embedding=None, relation_embedding=None):
        self.entity_embedding = entity_embedding
        self.relation_embedding = relation_embedding
        self.max_rel_size = max_rel_size
        # Real + imaginary parts, hence twice the nominal dimension.
        self.internal_k = 2 * k

    def __call__(self, e_s_id, e_p_id):
        """Return the embedding of entity ``e_s_id`` rotated by relation ``e_p_id``.

        The result is the concatenation ``[real half, imaginary half]`` of the
        rotated vector and has the same length as the entity embedding.
        """
        subj_re, subj_im = torch.chunk(self.entity_embedding[e_s_id], 2, dim=0)
        raw_phase, _unused_half = torch.chunk(self.relation_embedding[e_p_id], 2, dim=0)

        # Rescale the stored relation values into radians.
        embedding_range = (6 / (self.internal_k * self.max_rel_size)) ** 0.5
        phase = raw_phase / (embedding_range / pi)
        rel_re = torch.cos(phase)
        rel_im = torch.sin(phase)

        # Complex multiplication (subj_re + i*subj_im) * (rel_re + i*rel_im).
        obj_re = subj_re * rel_re - subj_im * rel_im
        obj_im = subj_re * rel_im + subj_im * rel_re
        return torch.cat([obj_re, obj_im], dim=0)

rotatE = RotatE(k=350, entity_embedding=kgt5_data['RotatE_ent_emb'], relation_embedding=kgt5_data['RotatE_rel_emb'], max_rel_size=237)


In [None]:
import numpy as np
from torch.utils.data import Dataset
from typing import Dict, Optional, Union, Tuple, List
import random


class KGCDataset(Dataset):
    """Seq2seq link-prediction examples enriched with structural neighbour features.

    All data comes from the module-level ``kgt5_data`` dict (id triples,
    tokenised triples and tokenised descriptions for the train/valid/test
    splits). Neighbourhoods are always taken from the *training* triples,
    whichever split an example belongs to.

    Args:
        num_ents: number of entities, used to size the neighbour indexes.
        structal_model: object exposing ``entity_embedding``,
            ``relation_embedding`` and ``__call__(ent_id, rel_id)``, used to
            embed neighbours.
    """

    # Neighbour entities with an id at or above this value are skipped when
    # gathering structural embeddings.
    # NOTE(review): presumably the embedding table has only this many rows -- confirm.
    max_struct_ent_id = 14505

    def __init__(self, num_ents=14541, structal_model=None):
        self.num_ents = num_ents
        self.structal_model = structal_model
        # Split served by plain ``dataset[idx]``; ``get`` takes the split
        # explicitly. Previously unset, which made ``__getitem__`` raise.
        self.split = 'train'

        splits = ('train', 'valid', 'test')
        # Fb15k wn18rr
        self.id_triplets = {s: kgt5_data[f'{s}_triplet_id'] for s in splits}
        self.tokens_triplets = {s: kgt5_data[f'{s}_triplet_tokens'] for s in splits}
        self.decs_triplets = {s: kgt5_data[f'{s}_triplet_decs'] for s in splits}

        # Every split looks neighbours up in the training graph only.
        train_ids = self.id_triplets['train']
        self.get_neigs_0 = {s: Hop1Index(train_ids, self.num_ents, 0) for s in splits}
        self.get_neigs_2 = {s: Hop1Index(train_ids, self.num_ents, 2) for s in splits}

        self.mask_token = _tokenize('<extra_id_90>')
        self.eos_token = torch.tensor([tokenizer.eos_token_id])
        self.zero_neig_embedding = torch.zeros([512])

        self.predict_head_token = _tokenize('predict head :')
        self.predict_tail_token = _tokenize('predict tail :')
        self.start_decs_token = _tokenize('[')
        self.end_decs_token = _tokenize(']')
        self.inversion_token = _tokenize('inversion of ')
        self.empty_token = torch.tensor([], dtype=torch.int)
        self.set_ent_id = set(range(self.num_ents))
        self.p_dropout = 0. # 0.2 when training

    def __getitem__(self, idx):
        return self.get(idx, split=self.split)

    def __len__(self, split='train'):
        return len(self.tokens_triplets[split])

    @staticmethod
    def _without_answer_edge(neighbours, rel_id, answer_id):
        """Drop the row ``(rel_id, answer_id)`` so the label cannot leak through the context."""
        return neighbours[(neighbours[:, 0] != rel_id) | (neighbours[:, 1] != answer_id)]

    def _embed_neighbours(self, neighbours, negate_relation):
        """Embed each ``(relation, entity)`` neighbour row.

        Returns two parallel lists: the concatenated ``[entity, relation]``
        embeddings (relation negated when ``negate_relation`` is true) and
        the structural model's output for the pair.
        """
        pair_embeddings = []
        projected_embeddings = []
        for rel_n_id, ent_n_id in neighbours:
            if ent_n_id >= self.max_struct_ent_id:
                continue
            ent_n_embedding = self.structal_model.entity_embedding[ent_n_id]
            rel_n_embedding = self.structal_model.relation_embedding[rel_n_id]
            rel_part = -rel_n_embedding if negate_relation else rel_n_embedding
            pair_embeddings.append(torch.cat([ent_n_embedding, rel_part]))
            projected_embeddings.append(self.structal_model(ent_n_id, rel_n_id))
        return pair_embeddings, projected_embeddings

    def get(self, idx: int, split: str = "train", full_mask_part_idx=None):
        """Build one example from triple ``idx`` of ``split``.

        Args:
            idx: position of the triple inside the split.
            split: ``'train'``, ``'valid'`` or ``'test'``.
            full_mask_part_idx: ``0`` to predict the head, any truthy value
                (``2`` in this file) to predict the tail, ``None`` to pick
                one of the two at random.

        Returns:
            A dict with ``input_ids``, ``attention_mask``, ``labels``,
            ``neighboors_embeddings``, ``neighboors_embeddings_mask``,
            ``target_ent_embeddings``, ``triplet``, ``neighboors_0_id`` and
            ``neighboors_2_id``.
        """
        head_lbl, relation, tail_lbl = self.tokens_triplets[split][idx]
        head_id, rel_id, tail_id = self.id_triplets[split][idx]
        head_decs, tail_decs = self.decs_triplets[split][idx]

        if full_mask_part_idx is None:
            full_mask_part_idx = 2 if random.randint(0, 1) else 0

        inversion = False

        if full_mask_part_idx:
            # Tail prediction: the head is the query entity.
            task_token = self.predict_tail_token if not inversion else self.predict_head_token
            query_lbl, query_decs, query_id = head_lbl, head_decs, head_id
            answer_lbl, answer_id = tail_lbl, tail_id
        else:
            # Head prediction: the tail is the query entity.
            task_token = self.predict_head_token if not inversion else self.predict_tail_token
            query_lbl, query_decs, query_id = tail_lbl, tail_decs, tail_id
            answer_lbl, answer_id = head_lbl, head_id

        source = [
            task_token,
            query_lbl,
            self.start_decs_token,
            query_decs,
            self.end_decs_token,
            self.inversion_token if inversion else self.empty_token,
            relation,
        ]
        target = [answer_lbl]

        neighboors_0 = self._without_answer_edge(
            self.get_neigs_0[split][query_id], rel_id, answer_id)
        neighboors_2 = self._without_answer_edge(
            self.get_neigs_2[split][query_id], rel_id, answer_id)

        # Neighbours indexed by column 0 come first, then those indexed by
        # column 2 with the relation embedding negated.
        pairs_0, projected_0 = self._embed_neighbours(neighboors_0, negate_relation=False)
        pairs_2, projected_2 = self._embed_neighbours(neighboors_2, negate_relation=True)
        neighboors_embeddings = pairs_0 + pairs_2
        target_ent_embeddings = projected_0 + projected_2

        if len(neighboors_embeddings):
            neighboors_embeddings = torch.stack(neighboors_embeddings)
            target_ent_embeddings = torch.stack(target_ent_embeddings)
            neighboors_embeddings_mask = torch.ones(len(neighboors_embeddings))
        else:
            # No usable neighbour: emit a single masked-out zero row.
            # NOTE(review): 700 presumably matches the structural embedding width -- confirm.
            neighboors_embeddings_mask = torch.zeros([1])
            neighboors_embeddings = torch.zeros([1, 700*2])
            target_ent_embeddings = torch.zeros([1, 700])

        source.append(self.eos_token)
        target.append(self.eos_token)
        source = torch.cat(source)
        target = torch.cat(target)

        # Token-level dropout on the attention mask: a position is kept with
        # probability 1 - p_dropout, with some positions forced to stay on.
        attention_mask = torch.ones_like(source)
        rand = torch.rand_like(attention_mask.float())
        dropout = torch.logical_not(rand < self.p_dropout).long()
        # The description brackets are never dropped.
        dropout[(source == self.start_decs_token[0]) | (source == self.end_decs_token[0])] = 1
        # Neither are the first four positions (presumably the task prefix -- confirm).
        dropout[:4] = 1
        inversion_len = len(self.inversion_token if inversion else self.empty_token)
        relation_len = len(relation)
        dropout[-relation_len-inversion_len:-relation_len] = 1
        attention_mask = attention_mask * dropout

        output = {
            "input_ids": source,
            "attention_mask": attention_mask,
            "labels": target,
            'neighboors_embeddings': neighboors_embeddings,
            'neighboors_embeddings_mask': neighboors_embeddings_mask,
            'target_ent_embeddings': target_ent_embeddings,
            'triplet': self.id_triplets[split][idx],
            'neighboors_0_id': neighboors_0,
            'neighboors_2_id': neighboors_2,
        }
        return output

dataset = KGCDataset(num_ents=14541, structal_model=rotatE)

# One index per split, keyed on triple column 0. The very large
# max_context_size means lookups are never subsampled in practice.
ext_get_neigs_0 = {
    split_name: Hop1Index(
        kgt5_data[f'{split_name}_triplet_id'],
        dataset.num_ents, 0, max_context_size=1e10)
    for split_name in ('train', 'valid', 'test')
}

# One index per split, keyed on triple column 2. The very large
# max_context_size means lookups are never subsampled in practice.
ext_get_neigs_2 = {
    split_name: Hop1Index(
        kgt5_data[f'{split_name}_triplet_id'],
        dataset.num_ents, 2, max_context_size=1e10)
    for split_name in ('train', 'valid', 'test')
}

# get all ground truth
def get_neigs2(ent_id, rel_id):
    """Look ``(ent_id, rel_id)`` up in every column-2 index.

    Returns a list of three lookup results, in train, valid, test order.
    """
    return [
        ext_get_neigs_2[split_name].get_context(ent_id, rel_id)
        for split_name in ('train', 'valid', 'test')
    ]
# get all ground truth
def get_neigs0(ent_id, rel_id):
    """Look ``(ent_id, rel_id)`` up in every column-0 index.

    Returns a list of three lookup results, in train, valid, test order.
    """
    return [
        ext_get_neigs_0[split_name].get_context(ent_id, rel_id)
        for split_name in ('train', 'valid', 'test')
    ]

class SplitDatasetWrapper:
    """View of a multi-split dataset pinned to one split.

    ``full_mask_part_idx`` is forwarded unchanged to ``dataset.get`` on every
    access, so a wrapper can also pin which side of the triple is predicted.
    """

    def __init__(self, dataset, split, full_mask_part_idx=None):
        self.dataset, self.split = dataset, split
        self.full_mask_part_idx = full_mask_part_idx

    def __len__(self):
        # ``len()`` cannot pass arguments, so call the dunder directly.
        return self.dataset.__len__(split=self.split)

    def __getitem__(self, idx):
        return self.dataset.get(idx, self.split, self.full_mask_part_idx)

train_dataset = SplitDatasetWrapper(dataset, split="train")
valid_dataset = SplitDatasetWrapper(dataset, split="valid")
test_dataset = SplitDatasetWrapper(dataset, split="test")

head_test_dataset = SplitDatasetWrapper(dataset, split="test", full_mask_part_idx=0)
tail_test_dataset = SplitDatasetWrapper(dataset, split="test", full_mask_part_idx=2)

head_valid_dataset = SplitDatasetWrapper(dataset, split="valid", full_mask_part_idx=0)
tail_valid_dataset = SplitDatasetWrapper(dataset, split="valid", full_mask_part_idx=2)

head_train_dataset = SplitDatasetWrapper(dataset, split="train", full_mask_part_idx=0)
tail_train_dataset = SplitDatasetWrapper(dataset, split="train", full_mask_part_idx=2)

# model

In [None]:
from src.model import StructKS2S

model_name='t5-small'
model = StructKS2S.from_pretrained(model_name)
model_state_dict = torch.load('models/saved_models/structkgs2s_fb15k237.pt', map_location='cpu')
model.load_state_dict(model_state_dict)



ModuleNotFoundError: No module named 'src'

In [None]:

from transformers import T5Tokenizer, T5ForConditionalGeneration

model_name='t5-small'

# tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)



In [None]:
??model.forward


[31mSignature:[39m
model.forward(
    input_ids: Optional[torch.LongTensor] = [38;5;28;01mNone[39;00m,
    attention_mask: Optional[torch.FloatTensor] = [38;5;28;01mNone[39;00m,
    decoder_input_ids: Optional[torch.LongTensor] = [38;5;28;01mNone[39;00m,
    decoder_attention_mask: Optional[torch.BoolTensor] = [38;5;28;01mNone[39;00m,
    head_mask: Optional[torch.FloatTensor] = [38;5;28;01mNone[39;00m,
    decoder_head_mask: Optional[torch.FloatTensor] = [38;5;28;01mNone[39;00m,
    cross_attn_head_mask: Optional[torch.Tensor] = [38;5;28;01mNone[39;00m,
    encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = [38;5;28;01mNone[39;00m,
    past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = [38;5;28;01mNone[39;00m,
    inputs_embeds: Optional[torch.FloatTensor] = [38;5;28;01mNone[39;00m,
    decoder_inputs_embeds: Optional[torch.FloatTensor] = [38;5;28;01mNone[39;00m,
    labels: Optional[torch.LongTensor] = [38;5;28;01mNone[39;00m,
    use_cache: Option

## DataCollatorForSeq2Seq

In [None]:
# from torch.nn.utils.rnn import pad_sequence

# class DataCollatorForSeq2Seq:
#     model= None
#     padding= True
#     max_length= None
#     pad_to_multiple_of=None
#     label_pad_token_id= -100
#     data_names = None
#     def __init__(self, tokenizer, model=None, padding=True, max_length=None, pad_to_multiple_of=None, label_pad_token_id=-100,data_names=None):
#         self.tokenizer = tokenizer
#         self.model = model
#         self.data_names = data_names
#         self.label_pad_token_id = label_pad_token_id

#     def __call__(self, features):
#         features2 = {}
#         for name in self.data_names:
#           if name == 'triplet':
#             continue
#           if name in ['labels','filter_id']:
#             padding_value=self.label_pad_token_id
#           else:
#             padding_value=self.tokenizer.pad_token_id
#           x_features = [feature[name] for feature in features]
#           features2[name] = torch.nn.utils.rnn.pad_sequence(x_features, batch_first=True, padding_value=padding_value)
#         if self.model is not None and hasattr(self.model, "prepare_decoder_input_ids_from_labels"):
#             decoder_input_ids = self.model.prepare_decoder_input_ids_from_labels(labels=features2["labels"])
#             features2["decoder_input_ids"] = decoder_input_ids
#         return features2


# data_collator = DataCollatorForSeq2Seq(tokenizer, model=model, data_names=list(train_dataset[0].keys()))


In [None]:
# data = data_collator([train_dataset[0]])

In [None]:
# model(**data)

# train

In [None]:
from transformers import Seq2SeqTrainingArguments, TrainingArguments
from transformers import Seq2SeqTrainer
batch_size= 32*4

args = Seq2SeqTrainingArguments(
    "kgt5-rotatE",
    dataloader_num_workers=8,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,

    num_train_epochs=100,
    do_eval=True,

    evaluation_strategy="epoch",
    save_strategy="epoch",
    logging_strategy='epoch',

    learning_rate=1e-4,
    # torch_compile=True,
    fp16=True,

    tf32=True,
    report_to='none',
    load_best_model_at_end=True,
)



In [None]:
from transformers import Seq2SeqTrainer
from transformers import EarlyStoppingCallback

trainer = Seq2SeqTrainer(
    model,
    args,
    train_dataset=train_dataset,
    eval_dataset=valid_dataset,
    data_collator=data_collator,
)

In [None]:
!gsutil cp /content/kgt5-rotatE/checkpoint-25512.zip gs://hien7613storage2/

Copying file:///content/kgt5-rotatE/checkpoint-25512.zip [Content-Type=application/zip]...
/ [0 files][    0.0 B/618.6 MiB]                                                ==> NOTE: You are uploading one or more large file(s), which would run
significantly faster if you enable parallel composite uploads. This
feature can be enabled by editing the
"parallel_composite_upload_threshold" value in your .boto
configuration file. However, note that if you do this large files will
be uploaded as `composite objects
<https://cloud.google.com/storage/docs/composite-objects>`_,which
means that any user who downloads such objects will need to have a
compiled crcmod installed (see "gsutil help crcmod"). This is because
without a compiled crcmod, computing checksums on composite objects is
so slow that gsutil disables downloads of composite objects.

- [1 files][618.6 MiB/618.6 MiB]                                                
Operation completed over 1 objects/618.6 MiB.                          

In [None]:
trainer.train(resume_from_checkpoint='/content/kgt5-rotatE/checkpoint-12756')
# baaed1dc0ef02b02dff291c8e0cfacf571bff2f9

There were missing keys in the checkpoint model loaded: ['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight', 'lm_head.weight'].
  torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location)
  checkpoint_rng_state = torch.load(rng_file)


Epoch,Training Loss,Validation Loss
7,1.7857,1.514204
8,1.7191,1.452341
9,1.6554,1.41958
10,1.6096,1.393566
11,1.5702,1.364775
12,1.5385,1.371552
13,3.8018,4.198689
14,5.5224,4.24244


KeyboardInterrupt: 

In [None]:
trainer.train()
# baaed1dc0ef02b02dff291c8e0cfacf571bff2f9

Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


Epoch,Training Loss,Validation Loss
1,2.9965,2.332599
2,2.4924,2.038121
3,2.2626,1.857196
4,2.0933,1.739745
5,1.9688,1.647498
6,1.8674,1.564553
7,1.7857,1.514724
8,1.7196,1.453535
9,1.6562,1.420135
10,1.6098,1.395179


KeyboardInterrupt: 

In [None]:
# import torch

# trainer.model.eval()
# model_state_dict = trainer.model.state_dict()
# torch.save(model_state_dict, '/content/kgt5_rotatE_x12.pt')

# !gsutil -o GSUtil:parallel_composite_upload_threshold=150M cp  /content/kgt5_rotatE_x12.pt gs://hien7613storage2/

# New Eval

## setup eval

In [None]:
# ent2text
from tqdm.auto import tqdm




def _iter_lines_with_progress(path):
    """Yield ``(line_number, line)`` for ``path`` behind a tqdm progress bar."""
    # Explicit encoding so the result does not depend on the platform locale.
    with open(path, "r", encoding="utf-8") as f:
        total_lines = sum(1 for _ in f)
        f.seek(0)  # Reset file pointer to the beginning
        yield from tqdm(enumerate(f), total=total_lines, desc="Processing lines")


def _load_name_to_tokens(path):
    """Read ``name<TAB>text`` lines and map each name to the tokenised text."""
    mapping = {}
    for _, line in _iter_lines_with_progress(path):
        name, text = line.strip().split('\t')
        mapping[name] = _tokenize(text)
    return mapping


def _load_name_to_id(path):
    """Map the first tab-separated field of each line to its 0-based line number."""
    mapping = {}
    for i, line in _iter_lines_with_progress(path):
        mapping[line.strip().split('\t')[0]] = int(i)
    return mapping


ent2text = _load_name_to_tokens("data/raw/fb15k-237/entity2text.txt")
rel2text = _load_name_to_tokens("data/raw/fb15k-237/relation2text.txt")

# ent2id
ent2id = _load_name_to_id("data/raw/fb15k-237/entities.txt")

# rel2id
rel2id = _load_name_to_id("data/raw/fb15k-237/relations.txt")

# Token sequences indexed by id, wrapped as [0] + tokens + [1].
# NOTE(review): 0 and 1 are presumably the T5 pad/decoder-start and EOS ids -- confirm.
entid2text = [0]*len(ent2id)
for ent in tqdm(ent2id):
    entid2text[ent2id[ent]] = [0] + ent2text[ent].tolist() + [1]

relid2text = [0]*len(rel2id)
for rel in tqdm(rel2id):
    relid2text[rel2id[rel]] = [0] + rel2text[rel].tolist() + [1]

# Human-readable names, decoded from the tokens without the two wrapper ids.
ent_name_decode_list = []
for target in tqdm(entid2text):
    ent_name_decode_list.append(tokenizer.decode(target[1:-1]))

rel_name_decode_list = []
for target in tqdm(relid2text):
    rel_name_decode_list.append(tokenizer.decode(target[1:-1]))

Processing lines:   0%|          | 0/14951 [00:00<?, ?it/s]

Processing lines: 100%|██████████| 14951/14951 [00:00<00:00, 24052.52it/s]
Processing lines: 100%|██████████| 237/237 [00:00<00:00, 17128.16it/s]
Processing lines: 100%|██████████| 14541/14541 [00:00<00:00, 2476826.45it/s]
Processing lines: 100%|██████████| 237/237 [00:00<00:00, 1394179.59it/s]
100%|██████████| 14541/14541 [00:00<00:00, 866683.36it/s]
100%|██████████| 237/237 [00:00<00:00, 1048576.00it/s]
100%|██████████| 14541/14541 [00:02<00:00, 6784.51it/s]
100%|██████████| 237/237 [00:00<00:00, 3374.39it/s]


In [None]:
from typing import Dict, List
class Trie(object):
    """Prefix tree over token-id sequences, stored as nested dicts.

    ``get(prefix)`` returns the tokens that may follow ``prefix``, or ``[0]``
    when nothing can follow. A second trie can be chained with ``append``: it
    takes over when a prefix leaves this trie, and its first tokens replace
    ``bos_token_id`` wherever that token would be offered.
    """

    def __init__(self, sequences: List[List[int]] = None):
        self.trie_dict = {}
        self.len = 0
        if sequences:
            for sequence in sequences:
                self.add(sequence)
        self.append_trie = None
        self.bos_token_id = None

    def append(self, trie, bos_token_id):
        """Chain ``trie`` behind this one, entered through ``bos_token_id``."""
        self.append_trie = trie
        self.bos_token_id = bos_token_id

    def add(self, sequence: List[int]):
        """Insert one sequence. ``len`` counts insertions, duplicates included."""
        Trie._add_to_trie(sequence, self.trie_dict)
        self.len += 1

    def get(self, prefix_sequence: List[int]):
        """Return the list of tokens allowed after ``prefix_sequence``."""
        return Trie._get_from_trie(prefix_sequence, self.trie_dict, self.append_trie, self.bos_token_id)

    @staticmethod
    def load_from_dict(trie_dict):
        """Wrap an existing nested dict; ``len`` becomes its number of leaf paths."""
        trie = Trie()
        trie.trie_dict = trie_dict
        trie.len = sum(1 for _ in trie)
        return trie

    @staticmethod
    def _add_to_trie(sequence: List[int], trie_dict: Dict):
        # Iterative descent: no per-level slicing and no recursion limit.
        node = trie_dict
        for token in sequence:
            node = node.setdefault(token, {})

    @staticmethod
    def _get_from_trie(
        prefix_sequence: List[int],
        trie_dict: Dict,
        append_trie=None,
        bos_token_id: int = None,
    ):
        node = trie_dict
        for position, token in enumerate(prefix_sequence):
            if token not in node:
                # The prefix left this trie: hand the unmatched remainder to
                # the chained trie, if there is a non-empty one.
                if append_trie:
                    return append_trie.get(prefix_sequence[position:])
                return [0]
            node = node[token]

        output = list(node.keys())
        if append_trie and bos_token_id in output:
            output.remove(bos_token_id)
            output += list(append_trie.trie_dict.keys())
        if len(output) == 0:
            return [0]
        return output

    def __iter__(self):
        """Iterate over every root-to-leaf token path."""
        def _traverse(prefix_sequence, trie_dict):
            if trie_dict:
                for next_token in trie_dict:
                    yield from _traverse(prefix_sequence + [next_token], trie_dict[next_token])
            else:
                yield prefix_sequence

        return _traverse([], self.trie_dict)

    def __len__(self):
        return self.len

    def __getitem__(self, value):
        return self.get(value)
trie = Trie(entid2text)

In [None]:
import numpy as np
import pandas as pd
def _get_performance(ranks):
    ranks = np.array(ranks, dtype=np.float32)
    out = dict()
    out['mr'] = ranks.mean(axis=0)
    out['mrr'] = (1. / ranks).mean(axis=0)
    out['hit1'] = np.sum(ranks == 1, axis=0) / len(ranks)
    out['hit3'] = np.sum(ranks <= 3, axis=0) / len(ranks)
    out['hit10'] = np.sum(ranks <= 10, axis=0) / len(ranks)
    return out


def get_performance(model, tail_ranks, head_ranks):
    """Build a table of ranking metrics for tail and head prediction.

    Rows are ``'tail ranking'``, ``'head ranking'`` and their column-wise
    mean ``'mean ranking'``; columns are ``mrr``, ``mr``, ``hit@1``,
    ``hit@3`` and ``hit@10``. The hit columns are formatted as percentage
    strings. ``model`` is not used.
    """
    tail_out = _get_performance(tail_ranks)
    head_out = _get_performance(head_ranks)

    # (column label, key in the per-side metric dict), in display order.
    layout = (
        ('mrr', 'mrr'),
        ('mr', 'mr'),
        ('hit@1', 'hit1'),
        ('hit@3', 'hit3'),
        ('hit@10', 'hit10'),
    )
    columns = {}
    for label, key in layout:
        columns[label] = np.array([tail_out[key], head_out[key]])

    perf = pd.DataFrame(columns, index=['tail ranking', 'head ranking'])
    perf.loc['mean ranking'] = perf.mean(axis=0)

    def _as_percent(x):
        return '%.2f%%' % (x * 100)

    for hit in ['hit@1', 'hit@3', 'hit@5', 'hit@10']:
        if hit not in list(perf.columns):
            continue
        perf[hit] = perf[hit].apply(_as_percent)
    return perf



In [None]:
list_global_ranks = []

def compute_metrics(ranks):
    """Return ``(mrr, hit@1, hit@5, hit@10)`` for a sequence of ranks."""
    total = len(ranks)
    reciprocal_sum = sum(1.0 / rank for rank in ranks)

    def _hits_at(cutoff):
        return sum(1 for rank in ranks if rank <= cutoff) / total

    return reciprocal_sum / total, _hits_at(1), _hits_at(5), _hits_at(10)

class RunEval:
    """Rank-based evaluation of a seq2seq KG-completion model.

    For every batch the model is decoded ``configs.num_beams`` times under a
    trie constraint. Each round is barred from tokens that continue a known
    answer or a previous round's output, and the rank of an example is the
    first round whose output equals its label.

    Args:
        configs: object with ``num_beams`` and ``n_ent`` attributes.
        model: generation model; moved to ``device`` and put in eval mode.
        tokenizer: tokenizer used for ``pad_token_id`` and for decoding.
        ent_name_list: entity id -> printable name.
        target_embeddings: 2-D tensor, one padded token sequence per entity id.
        device: torch device string.
    """

    def __init__(self, configs, model, tokenizer, ent_name_list, target_embeddings, device='mps'):
        self.configs = configs
        self.ent_name_list = ent_name_list
        self.target_embeddings = target_embeddings
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.model.to(device)
        self.model.eval()

    def _next_candidate(self, batch_idx, input_ids, triple_id, dataset_idx, old_seqs=None):
        """Return the token ids allowed after the decoder prefix ``input_ids``.

        Allowed tokens are the module-level ``trie``'s continuations of the
        prefix, minus tokens that continue another known answer of the same
        query or a previous round's output. The gold continuation is added
        back when the known answers were the reason it was removed. ``[0]``
        is returned when nothing is left or the prefix already ended in
        token 0.
        """
        input_ids = input_ids.cpu()
        if old_seqs is None:
            old_seqs = []

        if input_ids[-1] == 0 and len(input_ids) != 1:
            return [0]
        prefix_len = len(input_ids)
        pred_ids = self.target_embeddings[triple_id[batch_idx][dataset_idx]]
        pred_id = int(pred_ids[prefix_len])
        all_gt_ids = torch.cat(self.get_neigs(triple_id[batch_idx][2-dataset_idx], triple_id[batch_idx][1]))

        # Next tokens of every known answer that shares the current prefix.
        all_gt_seq = torch.index_select(self.target_embeddings, 0, all_gt_ids)
        all_gt_seq_mask = (all_gt_seq[:, :prefix_len] == input_ids).all(1)
        all_gt_seq_tokens = all_gt_seq[:, prefix_len][all_gt_seq_mask]

        # Next tokens of earlier rounds' outputs that share the current prefix.
        old_seq_tokens = torch.tensor([], dtype=torch.int64)
        if len(old_seqs) > 0:
            old_seq = torch.nn.utils.rnn.pad_sequence([x[batch_idx] for x in old_seqs], batch_first=True, padding_value=0)
            if old_seq.shape[1] > prefix_len:
                old_seq_mask = (old_seq[:, :prefix_len] == input_ids).all(1)
                old_seq_tokens = old_seq[:, prefix_len][old_seq_mask]

        previous_tokens = set(old_seq_tokens.tolist())
        blocked_tokens = set(all_gt_seq_tokens.tolist()) | previous_tokens

        next_tokens = set(trie.get(input_ids.tolist())).difference(blocked_tokens)
        if pred_id in blocked_tokens:
            next_tokens.add(pred_id)
        next_tokens -= previous_tokens
        # Check emptiness only after the last subtraction: an empty
        # allowed-token list is rejected by constrained generation.
        if len(next_tokens) == 0:
            return [0]
        return list(next_tokens)

    @staticmethod
    def _merge_outputs(outs):
        """Concatenate per-batch output dicts key by key without mutating them."""
        merged = dict()
        for out in outs:
            for key, value in out.items():
                # ``+`` rather than ``+=`` so the first batch's list is not
                # extended in place.
                merged[key] = merged[key] + value if key in merged else value
        return merged

    def validation_epoch_end(self, outs):
        """Aggregate step outputs and print/return the metric table.

        ``outs`` is a pair ``(tail_step_outputs, head_step_outputs)``, each a
        list of dicts as returned by ``validation_step``.
        """
        pred_tail_out, pred_head_out = outs
        tail_ranks = self._merge_outputs(pred_tail_out).pop('ranks')
        head_ranks = self._merge_outputs(pred_head_out).pop('ranks')
        # Keywords keep tail/head from being swapped in the table rows.
        perf = get_performance(self, tail_ranks=tail_ranks, head_ranks=head_ranks)
        print(perf)
        return perf

    def _print_neighbours(self, neighbours):
        """Print up to five ``(relation, entity)`` rows, skipping all-zero rows."""
        for n in neighbours[:5]:
            rel_id = n[0]
            ent_id = n[1]
            # (0, 0) rows are presumably collator padding -- confirm.
            if rel_id == 0 and ent_id == 0:
                continue
            rel_name = rel_name_decode_list[rel_id]
            ent_name = self.ent_name_list[ent_id]
            if '.' in rel_name:
                rel_name = rel_name.split('.')[1].strip()
            print(f'--{rel_name}--> {ent_name}')

    @torch.no_grad()
    def validation_step(self, batched_data, dataset_idx):
        """Decode one batch ``num_beams`` times and rank each label.

        Args:
            batched_data: collated batch with the keys read below.
            dataset_idx: triple column being predicted; ``0`` selects
                ``get_neigs2`` for known-answer filtering, anything else
                ``get_neigs0``.

        Returns:
            ``{'ranks': [...]}`` with one rank per example. Examples never
            matched get a random rank in ``[num_beams + 1, n_ent)``.

        Side effects: prints a report per example, appends every rank to the
        module-level ``list_global_ranks`` and prints running metrics.
        """
        global list_global_ranks
        input_ids = batched_data['input_ids'].to(self.device)
        attention_mask = batched_data['attention_mask'].to(self.device)
        labels = batched_data['labels']
        labels = torch.where(labels != -100, labels, self.tokenizer.pad_token_id)
        neighboors_embeddings = batched_data['neighboors_embeddings'].to(self.device)
        neighboors_embeddings_mask = batched_data['neighboors_embeddings_mask'].to(self.device)
        target_ent_embeddings = batched_data['target_ent_embeddings'].to(self.device)
        neighboors_0 = batched_data['neighboors_0_id']
        neighboors_2 = batched_data['neighboors_2_id']
        triple_id = batched_data['triplet'].numpy()

        self.get_neigs = get_neigs2 if dataset_idx == 0 else get_neigs0

        old_seqs = []

        list_pred_texts = [[0]*self.configs.num_beams for _ in range(len(labels))]

        # Start from a random "missed" rank; a matching round lowers it.
        ranks = torch.randint(self.configs.num_beams + 1, self.configs.n_ent, (len(labels),))
        for round_idx in range(self.configs.num_beams):
            outputs = self.model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                return_dict_in_generate=True,
                max_length=512,
                prefix_allowed_tokens_fn=lambda batch_idx, m_input_ids: self._next_candidate(batch_idx, m_input_ids, triple_id, dataset_idx, old_seqs),
                neighboors_embeddings=neighboors_embeddings,
                neighboors_embeddings_mask=neighboors_embeddings_mask,
                target_ent_embeddings=target_ent_embeddings,
            )
            pred = outputs.sequences.cpu()
            # Remember the full output so later rounds cannot repeat it.
            old_seqs.append(pred)
            pred = pred[:, 1:]  # drop the first decoder token before comparing
            seq_len = min(pred.shape[1], labels.shape[1])
            pred = pred[:, :seq_len]
            cut_labels = labels[:, :seq_len]
            seq_match = (pred == cut_labels).all(1)
            new_ranks = torch.where(~seq_match, ranks, round_idx+1)
            ranks = torch.min(ranks, new_ranks)
            pred_texts = [self.tokenizer.decode(x, skip_special_tokens = True) for x in pred]
            for j, pred_text in enumerate(pred_texts):
                list_pred_texts[j][round_idx] = pred_text

        list_input_texts = [self.tokenizer.decode(x, skip_special_tokens = True) for x in input_ids]
        list_target_texts = [self.tokenizer.decode(x, skip_special_tokens = True) for x in labels]
        for row, input_text in enumerate(list_input_texts):
            print(f'Input: {input_text}')
            relation_name = rel_name_decode_list[triple_id[row][1]]
            print(f'Relation: {relation_name}')
            print('Neighbors:')
            self._print_neighbours(neighboors_0[row])
            self._print_neighbours(neighboors_2[row])
            print('...')

            print(f'Target: {list_target_texts[row]}')
            print(f'Predictions: {list_pred_texts[row]}')
            print(f'Rank: {ranks[row]}')
            print('='*100)
            list_global_ranks.append(ranks[row].item())

        mrr, hit1, hit5, hit10 = compute_metrics(list_global_ranks)
        print(f'MRR: {mrr:.4f}, Hit@1: {hit1:.4f}, Hit@5: {hit5:.4f}, Hit@10: {hit10:.4f}')
        print('='*100)

        ranks = ranks.tolist()
        out = {'ranks': ranks}
        return out

In [None]:
entid2text_emb = torch.nn.utils.rnn.pad_sequence([torch.tensor(x) for x in entid2text], batch_first=True, padding_value=0)

## run eval

In [None]:

class DataCollatorForSeq2Seq:
    """Collate a list of per-example feature mappings into one batch dict.

    For each collated field the per-example values are gathered and then:

    * ``neighboors_0`` / ``neighboors_2`` are returned as plain lists (one
      entry per example, no padding);
    * ``labels`` / ``filter_id`` are padded with ``label_pad_token_id``;
    * every other field is padded with ``tokenizer.pad_token_id``.

    Padding is done with ``torch.nn.utils.rnn.pad_sequence(batch_first=True)``.
    When a model exposing ``prepare_decoder_input_ids_from_labels`` is given
    and ``labels`` was collated, ``decoder_input_ids`` is added to the batch.

    Args:
        tokenizer: object exposing ``pad_token_id``.
        model: optional model providing
            ``prepare_decoder_input_ids_from_labels``.
        padding, max_length, pad_to_multiple_of: stored on the instance;
            ``__call__`` does not consult them.
        label_pad_token_id: fill value for ``labels`` and ``filter_id``.
        data_names: names of the fields to collate. When ``None`` the keys of
            the first example in the batch are used.
    """

    model = None
    padding = True
    max_length = None
    pad_to_multiple_of = None
    label_pad_token_id = -100
    data_names = None

    # Fields padded with ``label_pad_token_id`` rather than the tokenizer pad id.
    _LABEL_PADDED_FIELDS = ("labels", "filter_id")
    # Fields handed back as per-example lists without any padding.
    _UNPADDED_FIELDS = ("neighboors_0", "neighboors_2")

    def __init__(self, tokenizer, model=None, padding=True, max_length=None,
                 pad_to_multiple_of=None, label_pad_token_id=-100, data_names=None):
        self.tokenizer = tokenizer
        self.model = model
        # These three were accepted but silently dropped before; keep what the
        # caller passed so the instance reflects its constructor arguments.
        self.padding = padding
        self.max_length = max_length
        self.pad_to_multiple_of = pad_to_multiple_of
        self.label_pad_token_id = label_pad_token_id
        self.data_names = data_names

    def __call__(self, features):
        """Return a dict mapping each collated field name to its batched value.

        Args:
            features: sequence of mappings, one per example, each containing
                every collated field.

        Returns:
            dict of field name -> padded tensor (or list for the unpadded
            neighbour fields), plus ``decoder_input_ids`` when applicable.
        """
        names = self.data_names
        if names is None:
            # No explicit field list: fall back to what the first example has.
            names = list(features[0].keys())

        batch = {}
        for name in names:
            column = [feature[name] for feature in features]
            if name in self._UNPADDED_FIELDS:
                batch[name] = column
                continue
            if name in self._LABEL_PADDED_FIELDS:
                pad_value = self.label_pad_token_id
            else:
                pad_value = self.tokenizer.pad_token_id
            batch[name] = torch.nn.utils.rnn.pad_sequence(
                column, batch_first=True, padding_value=pad_value
            )

        # Only derive decoder inputs when labels were actually collated;
        # otherwise there is nothing to shift and the lookup would fail.
        if (
            "labels" in batch
            and self.model is not None
            and hasattr(self.model, "prepare_decoder_input_ids_from_labels")
        ):
            batch["decoder_input_ids"] = self.model.prepare_decoder_input_ids_from_labels(
                labels=batch["labels"]
            )
        return batch


# Collate exactly the fields exposed by a training example.
data_collator = DataCollatorForSeq2Seq(
    tokenizer,
    model=model,
    data_names=[*train_dataset[0].keys()],
)


In [None]:
# tail_test_dataset[9]

In [None]:
from tqdm.auto import tqdm
from torch.utils.data import DataLoader

batch_size = 16

# shuffle=True: each pass over a test split yields its examples in random order.
# num_workers is not passed, so DataLoader's default applies.
tail_data_loader = DataLoader(
    tail_test_dataset,
    collate_fn=data_collator,
    batch_size=batch_size,
    shuffle=True,
    pin_memory=True,
)
head_data_loader = DataLoader(
    head_test_dataset,
    collate_fn=data_collator,
    batch_size=batch_size,
    shuffle=True,
    pin_memory=True,
)

class Configs:
    """Bundle of settings handed to ``RunEval``.

    Every value can be overridden per instance; ``Configs()`` with no
    arguments yields the values this notebook has used so far.

    Args:
        num_beams: presumably the beam width used for generation (confirm
            against RunEval).
        num_return_sequences: presumably the number of generated sequences
            kept per input (confirm against RunEval).
        max_length: presumably the maximum generated length in tokens
            (confirm against RunEval).
        n_ent: presumably the number of entities in the knowledge graph; the
            default looks like the FB15k-237 entity count (confirm).
        n_rel: presumably the number of relations in the knowledge graph; the
            default looks like the FB15k-237 relation count (confirm).
    """

    def __init__(self, num_beams=1, num_return_sequences=1, max_length=30,
                 n_ent=14541, n_rel=237):
        self.num_beams = num_beams
        self.num_return_sequences = num_return_sequences
        self.max_length = max_length
        self.n_ent = n_ent
        self.n_rel = n_rel
configs = Configs()
runEval = RunEval(configs, model, tokenizer, ent_name_decode_list, entid2text_emb)

# Inspection pass: alternate one batch from the head loader (second argument 0)
# with one batch from the tail loader (second argument 2).
runEval.model.eval()
# Nothing below calls backward(); no_grad keeps autograd from recording these
# forward passes regardless of what validation_step does internally.
with torch.no_grad():
    for batch1, batch2 in tqdm(
        zip(head_data_loader, tail_data_loader),
        # zip() stops at the shorter loader, so size the bar to match.
        total=min(len(head_data_loader), len(tail_data_loader)),
    ):
        runEval.validation_step(batch1, 0)
        runEval.validation_step(batch2, 2)


  0%|          | 0/1280 [00:00<?, ?it/s]

Input: predict head : Turkey [ Turkey, officially the Republic of Turkey, is a contiguous transcontinental country, located mostly on Anatolia in Western Asia, and on East Thrace in Southeastern Europe. Turkey is bordered by eight countries: Bulgaria to the northwest; Greece to the west; Georgia to the northeast; ] film film release date s. film film regional release date film release region
Relation: film film release date s. film film regional release date film release region
Neighbors:
--location adjoining relationship adjoins--> Iraq
--olympics olympic medal honor olympics--> 2000 Summer Olympics
--location location contains--> <unk>zmir
--measurement unit dated money value currency--> United States Dollar
--olympics olympic medal honor olympics--> 1952 Summer Olympics
--olympics olympic athlete affiliation country--> Artistic gymnastics
--film film regional release date film release region--> Jack Reacher
--film film regional release date film release region--> Contraband
--film f

  0%|          | 1/1280 [00:01<40:56,  1.92s/it]

Input: predict tail : David Mansfield [ David Mansfield is an American violinist, mandolin player, guitarist, pedal steel guitar player, and composer.nRaised in Leonia, New Jersey, his first band was Quacky Duck and His Barnyard Friends, which also included two sons of Tony Bennett.nB ] music artist track contributions. music track contribution role
Relation: music artist track contributions. music track contribution role
Neighbors:
--people person gender--> Male
--music track contribution role--> Pedal steel guitar
--music track contribution role--> Bass guitar
--award award nomination award--> Golden Globe Award for Best Original Score
--people person profession--> Composer
--music instrument instrumentalists--> Pedal steel guitar
--film film music--> Transamerica
--film film music--> Heaven's Gate
...
Target: String
Predictions: ['Acoustic guitar']
Rank: 12807
Input: predict tail : Bachelor of Laws [ The Bachelor of Laws or LL.B. is an undergraduate, or bachelor, degree in law origi

  0%|          | 2/1280 [00:03<30:55,  1.45s/it]

Input: predict tail : 39th Daytime Emmy Awards [ The 39th Annual Daytime Emmy Awards presented by the National Academy of Television Arts and Sciences and Academy of Television Arts & Sciences, ′′recognizes outstanding achievement in all fields of daytime television production and are presented to individuals and programs broadcast from 2:00 a.m.—6:00  ] award award ceremony awards presented. award award honor award winner
Relation: award award ceremony awards presented. award award honor award winner
Neighbors:
--award award honor award winner--> Nancy Williams Watt
--award award honor award winner--> Roberto Orci
--award award honor award winner--> Jenna Bush Hager
--award award honor award winner--> Lester Holt
--award award honor award winner--> Kevin Clash
--award award honor ceremony--> Daytime Emmy Award for Outstanding Talk Show Host
--award award honor ceremony--> Daytime Emmy Award for Outstanding Children's Animated Program
--award award honor ceremony--> Daytime Emmy Award 

  0%|          | 3/1280 [00:04<28:56,  1.36s/it]

Input: predict tail : Harvard College [ Harvard College is one of two schools within Harvard University granting undergraduate degrees. Founded in 1636 in Cambridge, Massachusetts, it is the oldest institution of higher learning in the United States and one of the most prestigious in the world. ] education educational institution students graduates. education education student
Relation: education educational institution students graduates. education education student
Neighbors:
--education education major field of study--> Linguistics
--education education student--> Charles Peirce
--education education student--> John F. Kennedy
--education education major field of study--> History
--education education student--> John Lithgow
--location location contains--> Massachusetts
--education education institution--> Bachelor's degree
--education education institution--> Bachelor of Arts
--business employment tenure company--> John Rawls
--location location contains--> Cambridge
...
Target: Mi

  0%|          | 4/1280 [00:05<27:47,  1.31s/it]

Input: predict tail : Jeff Goldblum [ Jeffrey Lynn "Jeff" Goldblum is an American actor. His career began in the mid-1970s and he has appeared in major box-office successes including The Fly, Jurassic Park and its sequel Jurassic Park: The Lost World, and Independence Day. He  ] award award nominee award nominations. award award nomination award nominee
Relation: award award nominee award nominations. award award nomination award nominee
Neighbors:
--people person profession--> Actor-GB
--award award nomination award--> Independent Spirit Award for Best Supporting Male
--people person gender--> Male
--film performance film--> The Lost World: Jurassic Park
--award award nomination award nominee--> Cate Blanchett
--base popstra friendship participant--> Peter Weller
--base popstra dated participant--> Nicole Richie
--base popstra dated participant--> Kristin Davis
--base popstra dated participant--> Laura Dern
--award award nomination award nominee--> Noah Taylor
...
Target: Willem Dafoe

  0%|          | 5/1280 [00:06<28:26,  1.34s/it]

Input: predict tail : Never Say Never Again [ Never Say Never Again is a 1983 spy film based on the James Bond novel Thunderball, which was previously adapted in 1965 under that name. Unlike the majority of Bond films, Never Say Never Again was not produced by Eon Productions, but by an independent production company, one of whose members ] film film language
Relation: film film language
Neighbors:
--film film story by--> Ian Fleming
--film film film art direction by--> Leslie Dilley
--film film country--> United States of America
--film film cinematography--> Douglas Slocombe
--film film regional release date film release region--> United States of America
--film performance film--> Sean Connery
--film performance film--> Kim Basinger
--film performance film--> Amy Irving
--film performance film--> Rowan Atkinson
--film performance film--> Edward Fox
...
Target: English Language
Predictions: ['English Language']
Rank: 1
Input: predict tail : The Producers [ The Producers is a 2005 Ame

  0%|          | 6/1280 [00:07<26:22,  1.24s/it]

Input: predict tail : The Perks of Being a Wallflower [ The Perks of Being a Wallflower is a 2012 drama, romance, coming of age film written and directed by Stephen Chbosky. ] film film release date s. film film regional release date film release region
Relation: film film release date s. film film regional release date film release region
Neighbors:
--film film regional release date film release region--> Canada
--film film regional release date film release region--> Colombia
--film film country--> United States of America
--film film regional release date film release region--> United Arab Emirates
--film film produced by--> John Malkovich
--award award nomination nominated for--> Writers Guild of America Award for Best Adapted Screenplay
--award award nomination nominated for--> Broadcast Film Critics Association Award for Best Young Performer
--award award nomination nominated for--> MTV Movie Award for Best Kiss
--film performance film--> Paul Rudd
--film performance film--> Joan

  1%|          | 7/1280 [00:09<26:08,  1.23s/it]

Input: predict tail : Scary Movie 3 [ Scary Movie 3 is a 2003 American science fiction horror comedy parody film, which parodies the horror, sci-fi, and mystery genres, directed by David Zucker. It is the third film of the Scary Movie franchise, as well as the first to have no involvement from the Wayans ] film film story by
Relation: film film story by
Neighbors:
--film film cinematography--> Mark Irwin
--film film regional release date film release distribution medium--> DVD
--award award honor award winner--> Simon Cowell
--film film featured film locations--> Washington, D.C.
--film film executive produced by--> Bob Weinstein
--film film prequel--> Scary Movie 4
--film performance film--> Jenny McCarthy
--film performance film--> Leslie Nielsen
--film performance film--> Queen Latifah
--film performance film--> Jeremy Piven
...
Target: Shawn Wayans
Predictions: ['Scary Movie 3']
Rank: 13782
Input: predict tail : Joaquin Phoenix [ Joaquin Rafael Phoenix, formerly credited as Leaf Ph

  1%|          | 8/1280 [00:10<26:20,  1.24s/it]

Input: predict tail : Frankenweenie [ Frankenweenie is a 2012 American 3D stop-motion animated film directed by Tim Burton. It is a remake of Burton's 1984 short film of the same name and is a parody of and an homage to the 1931 film Frankenstein based on Mary Shelley's ] film film release date s. film film regional release date film release region
Relation: film film release date s. film film regional release date film release region
Neighbors:
--film film story by--> Tim Burton
--common webpage category--> Official Website
--film film regional release date film release region--> Paraguay
--film film regional release date film release region--> Romania
--film film regional release date film release region--> Australia
--film performance film--> Christopher Lee
--film performance film--> Frank Welker
--film performance film--> Michael Keaton
--award award nomination nominated for--> Tim Burton
--film performance film--> Martin Short
...
Target: Philippines
Predictions: ['South Korea']


  1%|          | 9/1280 [00:12<28:14,  1.33s/it]

Input: predict tail : Pop music [ Pop music is a genre of popular music which originated in its modern form in the 1950s, deriving from rock and roll. The terms "popular music" and "pop music" are often used interchangeably, even though the former is a description of music which is popular ] music genre artists
Relation: music genre artists
Neighbors:
--music genre artists--> Tracy Chapman
--music genre artists--> Robin Thicke
--music genre artists--> Adele
--music genre artists--> Robbie Williams
--music genre parent genre--> Rhythm and blues
--music genre parent genre--> New Wave
--music genre parent genre--> Synthpop
--music genre parent genre--> Downtempo
--music genre parent genre--> Psychedelic pop
--music genre parent genre--> Traditional pop music
...
Target: Leonard Cohen
Predictions: ['Meshell Ndegeocello']
Rank: 3716
Input: predict tail : Rhode Island [ Rhode Island, officially the State of Rhode Island and Providence Plantations, is a state in the New England region of the 




KeyboardInterrupt: 

In [None]:
# %pip install transformers==4.40.1

In [None]:
from tqdm.auto import tqdm
from torch.utils.data import DataLoader

# shuffle=False: both test splits are visited in dataset order, 64 examples per batch.
# num_workers is not passed, so DataLoader's default applies.
tail_data_loader = DataLoader(
    tail_test_dataset,
    collate_fn=data_collator,
    batch_size=64,
    shuffle=False,
    pin_memory=True,
)
head_data_loader = DataLoader(
    head_test_dataset,
    collate_fn=data_collator,
    batch_size=64,
    shuffle=False,
    pin_memory=True,
)


class Configs:
    """Bundle of settings handed to ``RunEval``.

    Every value can be overridden per instance; ``Configs()`` with no
    arguments yields the values this notebook has used so far.

    Args:
        num_beams: presumably the beam width used for generation (confirm
            against RunEval).
        num_return_sequences: presumably the number of generated sequences
            kept per input (confirm against RunEval).
        max_length: presumably the maximum generated length in tokens
            (confirm against RunEval).
        n_ent: presumably the number of entities in the knowledge graph; the
            default looks like the FB15k-237 entity count (confirm).
        n_rel: presumably the number of relations in the knowledge graph; the
            default looks like the FB15k-237 relation count (confirm).
    """

    def __init__(self, num_beams=1, num_return_sequences=1, max_length=30,
                 n_ent=14541, n_rel=237):
        self.num_beams = num_beams
        self.num_return_sequences = num_return_sequences
        self.max_length = max_length
        self.n_ent = n_ent
        self.n_rel = n_rel

configs = Configs()
runEval = RunEval(configs, model, tokenizer, ent_name_decode_list, entid2text_emb)

runEval.model.eval()
# Evaluation only: no_grad keeps autograd from recording these forward passes
# regardless of what validation_step does internally.
with torch.no_grad():
    # Head loader batches are scored with second argument 0.
    head_list_result = []
    for data in tqdm(head_data_loader):
        rank_rs = runEval.validation_step(data, 0)
        head_list_result.append(rank_rs)

    # Tail loader batches are scored with second argument 2.
    tail_list_result = []
    for data in tqdm(tail_data_loader):
        rank_rs = runEval.validation_step(data, 2)
        tail_list_result.append(rank_rs)

# Hand both per-batch result lists to the epoch-end hook as (head, tail).
kq = runEval.validation_epoch_end((head_list_result, tail_list_result))


  0%|          | 0/320 [00:00<?, ?it/s]

  0%|          | 0/320 [00:00<?, ?it/s]

                   mrr           mr   hit@1   hit@3  hit@10
tail ranking  0.479574  2367.001953  39.12%  53.62%  67.29%
head ranking  0.353449  3125.134521  26.15%  40.58%  57.04%
mean ranking  0.416512  2746.068359  32.63%  47.10%  62.16%
