In [5]:
import sys
import os

# Make the repository root (the notebook's parent directory) importable so the
# `src.*` imports below resolve. Adjust the path if the notebook is moved.
# The membership check keeps this cell idempotent: re-running it does not pile
# duplicate entries onto sys.path.
project_root = os.path.abspath("..")
if project_root not in sys.path:
    sys.path.append(project_root)

In [6]:
import os
import torch
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer, T5Tokenizer, AutoConfig
from src.struct_models.rotatE import RotatE
from src.model import StructKGS2S
from src.utils import DataCollatorForSeq2Seq
from src.dataset import KGCDataset, SplitDatasetWrapper



# --- Dataset configuration --------------------------------------------------
dataset_name = 'fb15k-237'
# Repository root. Set the STRUCT_KGS2S_ROOT environment variable to point
# elsewhere; by default this is the parent of the notebook's directory, i.e. the
# same directory the first cell adds to sys.path so that `src.*` imports resolve.
root = os.environ.get('STRUCT_KGS2S_ROOT', os.path.abspath('..'))
# Dataset-specific sizes: FB15k-237 has 237 relation types and 14,541 entities.
# Update both together with `dataset_name` when switching datasets.
max_rel_size = 237
num_ents = 14541

# `padding` is a per-call tokenizer option (tokenizer(..., padding=...)), not a
# constructor option, so it is not passed here. Batch padding is presumably left
# to the data collator -- verify in src.utils.DataCollatorForSeq2Seq.
tokenizer = T5Tokenizer.from_pretrained('t5-small')

# --- Datasets ---------------------------------------------------------------
kg_data_path = os.path.join(root, 'data/processed', dataset_name, 'kg_data.pt')
if not os.path.isfile(kg_data_path):
    raise FileNotFoundError(
        f"{kg_data_path} not found; set STRUCT_KGS2S_ROOT to the repository root"
    )
# map_location='cpu' loads every tensor on the CPU regardless of the device it
# was saved from, so the file also loads on machines without that device and the
# dataset (read by several dataloader worker processes) does not hold GPU tensors.
kg_data = torch.load(kg_data_path, map_location='cpu')
# Structural model built from the precomputed KG embeddings stored in kg_data.
rotatE = RotatE(
    k=350,
    entity_embedding=kg_data['struct_ent_emb'],
    relation_embedding=kg_data['struct_rel_emb'],
    max_rel_size=max_rel_size,
)
dataset = KGCDataset(num_ents=num_ents, structal_model=rotatE, kg_data=kg_data, tokenizer=tokenizer)
train_dataset = SplitDatasetWrapper(dataset, split="train")
valid_dataset = SplitDatasetWrapper(dataset, split="valid")




In [7]:

# --- Model ------------------------------------------------------------------
# Pretrained T5 checkpoint the model is initialised from.
# NOTE: keep in sync with the checkpoint the tokenizer was loaded from.
ckpt_name = 't5-small'

config = AutoConfig.from_pretrained(ckpt_name)
# Custom (non-T5) config attribute handed to StructKGS2S: width of the structural
# embeddings. 700 == 2 * 350 (the RotatE `k` used for the dataset) -- presumably
# real + imaginary parts of the RotatE embeddings; TODO confirm.
config.struct_d_model = 700
model = StructKGS2S.from_pretrained(ckpt_name, config=config)

# The collator is told which fields a sample carries; read them off the first
# training example.
sample_fields = list(train_dataset[0].keys())
data_collator = DataCollatorForSeq2Seq(
    tokenizer,
    model=model,
    data_names=sample_fields,
)



Some weights of StructKGS2S were not initialized from the model checkpoint at t5-small and are newly initialized: ['act.weight', 'key_projection.bias', 'key_projection.weight', 'value_projection1.bias', 'value_projection1.weight', 'value_projection2.bias', 'value_projection2.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:

# --- Training configuration -------------------------------------------------
import inspect

batch_size = 32 * 4
num_train_epochs = 1000
learning_rate = 1e-4
# With save_strategy="epoch" and 1000 epochs, every epoch would otherwise leave a
# full checkpoint (weights + optimizer state) on disk. Cap the number kept; with
# load_best_model_at_end=True the Trainer additionally retains the best checkpoint.
save_total_limit = 2

# `evaluation_strategy` was renamed to `eval_strategy` in newer transformers
# releases (the old name is deprecated and later removed). Use whichever name the
# installed version accepts so this cell runs on both old and new releases.
_eval_strategy_kw = (
    'eval_strategy'
    if 'eval_strategy' in inspect.signature(Seq2SeqTrainingArguments).parameters
    else 'evaluation_strategy'
)

args = Seq2SeqTrainingArguments(
    "kgs2s-rotatE",
    dataloader_num_workers=8,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=num_train_epochs,
    do_eval=True,
    save_strategy="epoch",
    logging_strategy='epoch',
    learning_rate=learning_rate,
    report_to='none',
    load_best_model_at_end=True,
    save_total_limit=save_total_limit,
    **{_eval_strategy_kw: "epoch"},
)

trainer = Seq2SeqTrainer(
    model,
    args,
    train_dataset=train_dataset,
    eval_dataset=valid_dataset,
    data_collator=data_collator,
)


In [None]:

# --- Training ---------------------------------------------------------------
# None starts a fresh run (the Trainer default). Set to True to resume from the
# latest checkpoint in the output directory, or to a checkpoint path, to continue
# an interrupted run instead of restarting the 1000 epochs from scratch.
resume_from_checkpoint = None
train_result = trainer.train(resume_from_checkpoint=resume_from_checkpoint)
train_result