In [None]:
import os
from itertools import chain
from typing import Dict, Optional

import pandas as pd
import torch
import wandb
from sklearn.utils import class_weight
from torch.utils.data import DataLoader

import training_v2
import utils
from modeling_classes import JointNERAndREModel, JointNERAndREDataset
from utils import Config

In [None]:
# Working directory of the kernel; checkpoints and configs are downloaded beneath it.
# (The previous expression, dirname(abspath("!pwd")), resolved to this same directory:
# "!pwd" was only ever treated as a file name, not executed as a shell command.)
CURRENT_DIR = os.getcwd()
# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Forward and reverse lookup tables for NER labels and relation classes.
LABELS_TO_IDS, IDS_TO_LABELS = utils.load_labels()
RELATIONS_TO_IDS, IDS_TO_RELATIONS = utils.load_relations()
# Sweep definition (used by the commented-out wandb.sweep cell) and the base run config.
SWEEP_CONFIG = utils.load_config(Config.SWEEP_CONFIG)
CONFIG = utils.load_config(Config.CONFIG)

In [None]:
def get_labels():
    """Return the NER label names as a list, in the iteration order of IDS_TO_LABELS."""
    return list(IDS_TO_LABELS.values())
def get_relations():
    """Return the relation names as a list, in the iteration order of IDS_TO_RELATIONS."""
    return list(IDS_TO_RELATIONS.values())


def get_optimizer(model):
    """Build the optimizer and step LR scheduler selected by the active wandb config.

    Reads ``optimizer``, ``learning_rate``, ``weight_decay``, ``scheduler_step_size``
    and ``scheduler_gamma`` from ``wandb.config``.

    Args:
        model: Module whose ``parameters()`` are optimized.

    Returns:
        Tuple ``(optimizer, scheduler)`` where ``scheduler`` is a ``StepLR`` wrapping
        ``optimizer``.

    Raises:
        ValueError: If ``wandb.config['optimizer']`` is not one of
            ``'ADAM'``, ``'ADAMW'`` or ``'SGD'``.
    """
    optimizer_name = wandb.config['optimizer']
    common_kwargs = {
        "params": model.parameters(),
        "lr": wandb.config["learning_rate"],
        "weight_decay": wandb.config["weight_decay"],
    }

    if optimizer_name == 'ADAM':
        optimizer = torch.optim.Adam(**common_kwargs)
    elif optimizer_name == 'ADAMW':
        optimizer = torch.optim.AdamW(**common_kwargs)
    elif optimizer_name == 'SGD':
        optimizer = torch.optim.SGD(**common_kwargs, momentum=0.9)
    else:
        # Previously an unknown name fell through and failed later with an
        # UnboundLocalError on `optimizer`; fail early with an actionable message.
        raise ValueError(
            f"Unsupported optimizer {optimizer_name!r}; expected one of 'ADAM', 'ADAMW', 'SGD'"
        )

    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer,
        step_size=wandb.config["scheduler_step_size"],
        gamma=wandb.config["scheduler_gamma"],
    )

    return optimizer, scheduler

def get_datasets():
    """Load the train/dev/test splits as DataFrames with a fresh 0..n-1 index.

    File names are appended directly to ``CONFIG['dataset_path']`` (no separator is
    inserted), so that value must already end the way the files expect.

    Returns:
        Tuple ``(train_dataset, dev_dataset, test_dataset)``.
    """
    dataset_root = CONFIG['dataset_path']

    def _read_split(file_name):
        # Plain string concatenation mirrors how the path is configured.
        return pd.read_json(f"{dataset_root}{file_name}").reset_index(drop=True)

    return _read_split("train.json"), _read_split("dev.json"), _read_split("test.json")

def get_data_loaders(train_dataset, dev_dataset, test_dataset):
    """Wrap the three splits in JointNERAndREDataset and return shuffled DataLoaders.

    Batch size comes from ``wandb.config["batch_size"]``. Only the training dataset
    receives the ``train`` flag, taken from ``wandb.config["re_hack"]``.

    Returns:
        Tuple ``(train_loader, dev_loader, test_loader)``.
    """
    def _make_loader(frame, **dataset_kwargs):
        dataset = JointNERAndREDataset(frame, DEVICE, **dataset_kwargs)
        # NOTE(review): dev/test loaders are shuffled too; confirm that is intended.
        return DataLoader(dataset, batch_size=wandb.config["batch_size"], shuffle=True)

    return (
        _make_loader(train_dataset, train=wandb.config["re_hack"]),
        _make_loader(dev_dataset),
        _make_loader(test_dataset),
    )

In [None]:
def resume_state(model, optimizer, scheduler, metrics, model_version: str='latest', config_version: str='latest', config_overwrites: Optional[Dict[str, str]]=None):
    """Restore model, optimizer, scheduler and metrics from wandb artifacts.

    Downloads the model artifact into ``{CURRENT_DIR}/models/`` and the config artifact
    into ``{CURRENT_DIR}/conf/``, replaces ``wandb.config`` with the backup config merged
    with ``config_overwrites``, then loads the checkpoint named after the (new)
    ``wandb.config["model"]``.

    Args:
        model: Module that receives ``checkpoint['model_state_dict']``.
        optimizer: Optimizer that receives ``checkpoint['optimizer_state_dict']``.
        scheduler: Scheduler that receives ``checkpoint['scheduler_state_dict']``.
        metrics: Ignored; the returned metrics come from the checkpoint. Kept so
            existing callers do not break.
        model_version: Version/alias of the model artifact to fetch.
        config_version: Version/alias of the config artifact to fetch.
        config_overwrites: Optional entries that override the restored config.
            ``None`` (the default) means no overrides.

    Returns:
        Tuple ``(model, optimizer, scheduler, metrics)``.
    """
    # A None default avoids sharing one mutable dict between calls.
    overwrites = config_overwrites if config_overwrites is not None else {}

    artifact = wandb.run.use_artifact(f'kripso/{wandb.config["project_name"]}/{wandb.config["model"]}:{model_version}', type='model')
    artifact.download(f'{CURRENT_DIR}/models/')
    artifact = wandb.run.use_artifact(f'kripso/{wandb.config["project_name"]}/config:{config_version}', type='config')
    artifact.download(f'{CURRENT_DIR}/conf/')
    # Rebinds wandb.config to a plain dict: the backup config with overrides applied last.
    # NOTE(review): Config.BACKUP presumably points at the config just downloaded - confirm.
    wandb.config = {**utils.load_config(Config.BACKUP), **overwrites}

    # map_location lets a checkpoint saved on a GPU be resumed on a CPU-only host
    # (and vice versa); without it torch tries to restore onto the saving device.
    checkpoint = torch.load(f'{CURRENT_DIR}/models/{wandb.config["model"]}.pt', map_location=DEVICE)

    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    metrics = checkpoint['metrics']
    # NOTE(review): presumably rewinds one step so training resumes at the
    # checkpointed step - confirm against training_v2.fit.
    metrics['step'] -= 1

    return model, optimizer, scheduler, metrics

In [None]:
@utils.wandb_init(CONFIG)
def main(resume: bool=False, *args, **kwargs):
    """Build data, model and optimizer from the wandb config and run training.

    Args:
        resume: When true, restore model/optimizer/scheduler/metrics through
            ``resume_state`` before building the data loaders.
        *args, **kwargs: Forwarded to ``resume_state`` (e.g. ``model_version``,
            ``config_version``, ``config_overwrites``); unused when ``resume`` is false.

    Returns:
        Whatever ``training_v2.fit`` returns.
    """
    train_dataset, dev_dataset, test_dataset = get_datasets()

    relations = get_relations()
    # One weight per relation class: the first class gets `no_relation_weight`, every
    # other class gets `relation_weight`. The count is derived from the relation list
    # handed to the model instead of a hard-coded 41, so it tracks the relation file.
    # NOTE(review): assumes relation id 0 is the "no relation" class - confirm.
    class_weights = torch.tensor(
        [wandb.config['no_relation_weight'], *[wandb.config['relation_weight']] * (len(relations) - 1)],
        dtype=torch.float,  # integer-only config values would otherwise give an int64 tensor
    ).to(DEVICE)
    model = JointNERAndREModel(labels=get_labels(), relations=relations, re_class_weights=class_weights).to(DEVICE)
    optimizer, scheduler = get_optimizer(model)
    metrics = {"loss": 0, "ner_accuracy": 0, "ner_f1_score": 0,"re_accuracy": 0, "re_f1_score": 0, "index": 1, "step": 0}

    if resume:
        model, optimizer, scheduler, metrics = resume_state(model, optimizer, scheduler, metrics, *args, **kwargs)

    # Loaders are built after a possible resume because resume_state replaces
    # wandb.config, and the batch size is read from it.
    train_loader, dev_loader, test_loader = get_data_loaders(train_dataset, dev_dataset, test_dataset)

    torch.cuda.empty_cache()
    return training_v2.fit(model, optimizer, scheduler, metrics, train_loader, dev_loader, test_loader, DEVICE)

In [None]:
# sweep_id = wandb.sweep(sweep=SWEEP_CONFIG, project=CONFIG['project_name'])
# wandb.agent(sweep_id, function=main, count=10)

In [None]:
# wandb.agent('fga67i6v', project=CONFIG['project_name'], function=main, count=10)

In [None]:
# Run main() with no arguments and keep only the first of its four return values
# (the model); the inference cell below relies on this `model` variable.
model, _, _, _ = main()

In [None]:
# model, _, _, _ = main(resume=True, config_overwrites={'epochs': 5, 'batch_size': 32})
# model, _, _, _ = main(resume=True)

In [None]:
# Smoke test: run the model returned by main() on one hand-written sentence, then print
# the predicted relation followed by one NER label per input word.
sentence = utils.string_to_list_1(
    # "Roland Rajcsanyi and Elon Musk own company Tesla"
    "@HuggingFace is a New York company, it has employees in Paris since 1923, but it has been down today 12:30"
)
model.eval()
with torch.inference_mode():
    encoded = JointNERAndREDataset.tokenize(sentence, is_split=True, return_tensors='pt').to(DEVICE)
    model_out = model(encoded["input_ids"], attention_mask=encoded["attention_mask"])

    # Relation prediction: arg-max over dim 1, first entry looked up by id.
    re_prediction = model_out.re_probs.argmax(dim=1)
    print(IDS_TO_RELATIONS.get(re_prediction.tolist()[0]))

    # NER prediction: flatten to one row per token position, arg-max over the label dim.
    flattened_predictions = model_out.ner_probs.view(-1, model.num_labels).argmax(dim=1).cpu().numpy()
    offsets = encoded["offset_mapping"].view(-1, 2).tolist()

    index = 0
    for label_id, (start, end) in zip(flattened_predictions, offsets):
        # Only positions whose offset is (0, n) with n > 0 are printed - presumably the
        # first sub-token of each word; all other positions are skipped.
        if start != 0 or end == 0:
            continue
        print(f'{sentence[index]:20}  {IDS_TO_LABELS.get(label_id)}')
        index += 1

In [None]:
# from seqeval.metrics import classification_report, accuracy_score
# import torchmetrics

# ner_accuracy = torchmetrics.Accuracy(num_classes=model.num_labels, average="weighted").to(DEVICE)
# ner_f1_score = torchmetrics.F1Score(num_classes=model.num_labels, average="weighted").to(DEVICE)
# re_accuracy = torchmetrics.Accuracy(num_classes=model.num_relations, average="weighted").to(DEVICE)
# re_f1_score = torchmetrics.F1Score(num_classes=model.num_relations, average="weighted").to(DEVICE)

# ner_labels, ner_predictions, re_labels, re_predictions = training_v2.valid(model, test_loader, ner_accuracy, ner_f1_score, re_accuracy, re_f1_score, validation_loop=False)

# print(accuracy_score(ner_labels, ner_predictions))
# print(accuracy_score(re_labels, re_predictions))
# print(classification_report([ner_labels], [ner_predictions], zero_division=False))
# print(classification_report([[f'B-{item}' for item in re_labels]], [[f'B-{item}' for item in re_predictions]], zero_division=False))

In [None]:
# import numpy as np
# from torch import nn

# class_weights = torch.tensor([1 / count for count in np.bincount([0,0,0,0,1,2,2,4,4,5,3,2,1,0,0,0,1,2])])
# print(class_weights,np.bincount([0,0,0,0,1,2,2,4,4,5,3,2,1,0,0,0,1,2]))
# # loss_fct = nn.CrossEntropyLoss(weight=class_weights)
# # re_loss = loss_fct(re_logits, relation)

In [None]:
# import numpy as np
# class_weights = torch.tensor([round(1 / count, 8) for count in np.bincount([RELATIONS_TO_IDS[item] for item in train_dataset['relation']])])
# len(class_weights)
# len([0.4,*[1.0]*41])