In [None]:
from torch.utils.data import DataLoader
import pandas as pd
import torch
import wandb

from modeling_classes import CustomBertForTokenClassification, CustomDataset
import utils

In [None]:
# Two label lookups unpacked from utils.load_labels(); IDS_TO_LABELS is the one
# used further down to turn predicted ids back into label names.
LABELS_TO_IDS, IDS_TO_LABELS = utils.load_labels()

# Use the GPU when PyTorch reports CUDA as available, otherwise fall back to CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

In [None]:
def get_labels():
    """Return the label names stored in ``IDS_TO_LABELS`` as a new list.

    Returns:
        list: The values of the module-level ``IDS_TO_LABELS`` mapping, in the
        mapping's own iteration order.
    """
    # list() copies the values view directly; an identity comprehension adds nothing.
    return list(IDS_TO_LABELS.values())

In [None]:
@utils.wandb_init({'project_name': 'dp-project-validate'})
def init_model():
    """Build the token-classification model and pass it through ``utils.load_model``.

    The model is constructed with the label list from ``get_labels()`` and moved
    to ``DEVICE`` before being handed to ``utils.load_model``.

    NOTE(review): ``utils.load_model`` presumably restores trained weights into
    the model, and the ``utils.wandb_init`` decorator presumably sets up a
    Weights & Biases run for the given project name -- confirm both in ``utils``.

    Returns:
        Whatever ``utils.load_model`` returns for the constructed model.
    """
    label_names = get_labels()
    classifier = CustomBertForTokenClassification(labels=label_names).to(DEVICE)
    return utils.load_model(classifier)

In [None]:
# Instantiate the model once for the notebook; the inference cell below reuses it.
model = init_model()

In [None]:
# Sample input. utils.string_to_list_1 converts the raw string into the sequence
# that is tokenized below and indexed item by item when printing.
sentence = utils.string_to_list_1(
    # "Roland is my brother from another mother, we have been friends since High School."
    "@HuggingFace is a New York company, it has employees in Paris since 1923, but it has been down today 12:30"
)

# Switch the model to evaluation mode before running the forward pass.
model.eval()

with torch.inference_mode():
    # Tokenize the already-split input, then move the encoded tensors to DEVICE.
    encoded = CustomDataset.tokenize(sentence, is_split=True, return_tensors='pt')
    encoded = encoded.to(DEVICE)

    probs = model(encoded["input_ids"], attention_mask=encoded["attention_mask"])

    # Pick the highest-scoring class along dim 1.
    # NOTE(review): assumes the model output is already flattened to
    # (num_tokens, num_labels) -- confirm against CustomBertForTokenClassification.
    flattened_predictions = torch.argmax(probs, dim=1).cpu().numpy()

    # One [start, end] pair per token position.
    offsets = encoded["offset_mapping"].view(-1, 2).tolist()

    index = 0
    for token, mapping in zip(flattened_predictions, offsets):
        # Only positions whose offset starts at 0 and has a non-zero end are printed.
        # Presumably these are the first sub-token of each input item, while
        # special/padding tokens carry (0, 0) -- verify against the tokenizer.
        if mapping[0] != 0 or mapping[1] == 0:
            continue
        print(f'{sentence[index]:20}  {IDS_TO_LABELS.get(token)}')
        index += 1