In [None]:
from torch.utils.data import DataLoader
import pandas as pd
import torch
import wandb

from modeling_classes import JointNERAndREModel, JointNERAndREDataset
import utils

In [None]:
# Lookup tables used to decode model outputs: each loader returns a pair that is
# unpacked into a name->id map and an id->name map (per the variable names).
(LABELS_TO_IDS, IDS_TO_LABELS) = utils.load_labels()
(RELATIONS_TO_IDS, IDS_TO_RELATIONS) = utils.load_relations()

# Use the GPU when torch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

In [None]:
# NOTE(review): `init_model` is not defined or imported anywhere in this notebook,
# so this cell raises NameError on a fresh kernel (Restart & Run All). Import it or
# define it in an earlier cell — TODO confirm where it is meant to come from.
# The inference cell sends its inputs to DEVICE, so the model must live there too.
# Assumes init_model() returns a torch.nn.Module: its .to() returns the module and
# is a no-op if the parameters are already on DEVICE.
model = init_model().to(DEVICE)

In [None]:
# Word-level input for the model: the result of `utils.string_to_list_1` is tokenized
# below with is_split=True and indexed word-by-word when printing predictions.
# Another example to try:
#   "Roland Rajcsanyi is my brother from another mother, we have been friends since High School."
sentence = utils.string_to_list_1(
    "@HuggingFace is a New York company, it has employees in Paris since 1923, but it has been down today 12:30"
)

model.eval()
with torch.inference_mode():
    encoded = JointNERAndREDataset.tokenize(sentence, is_split=True, return_tensors='pt').to(DEVICE)

    model_out = model(encoded["input_ids"], attention_mask=encoded["attention_mask"])
    # One NER label id per token, with all batch rows flattened into a single sequence.
    ner_predictions = torch.argmax(model_out.ner_probs.view(-1, model.num_labels), dim=1).tolist()
    # Relation id of the first (and here only) sequence in the batch.
    re_predictions = torch.argmax(model_out.re_probs, dim=1).tolist()[0]
    # (start, end) offsets per token, flattened the same way as ner_predictions.
    token_offsets = encoded["offset_mapping"].view(-1, 2).tolist()

# A token with start == 0 and a non-empty span is treated as the first sub-token of a
# word; (0, 0) offsets are excluded by the `end != 0` check. This relies on offsets
# being relative to each pre-split word — presumably true for is_split=True, confirm
# for the tokenizer in use.
word_label_ids = [
    label_id
    for label_id, (start, end) in zip(ner_predictions, token_offsets)
    if start == 0 and end != 0
]
# Fail with a clear message instead of an IndexError on sentence[...] partway through
# the output when the heuristic finds more word starts than there are words.
if len(word_label_ids) > len(sentence):
    raise ValueError(
        f"Found {len(word_label_ids)} word-initial tokens but the sentence has only "
        f"{len(sentence)} words; offset-based word alignment failed for this tokenizer."
    )

metadata = {"relation": IDS_TO_RELATIONS.get(re_predictions), 'entities': []}
for word_index, label_id in enumerate(word_label_ids):
    label = IDS_TO_LABELS.get(label_id)
    metadata['entities'].append({'type': label, 'location': word_index})
    print(f'{sentence[word_index]:20}  {label}')

metadata