In [None]:
import os
from typing import Dict, Optional

import pandas as pd
import torch
import wandb
from torch.utils.data import DataLoader

import training
import utils
from modeling_classes import CustomBertForTokenClassification, CustomDataset
from utils import Config

In [None]:
# Working directory of the kernel; model and config artifacts are downloaded
# beneath it. The previous expression, dirname(abspath("!pwd")), treated the
# shell magic as a literal file name and only resolved to the working
# directory by accident.
CURRENT_DIR = os.getcwd()
# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Two mappings returned by utils.load_labels(); presumably label -> id and
# id -> label, judging by the names (verify against utils).
LABELS_TO_IDS, IDS_TO_LABELS = utils.load_labels()
# Configuration for wandb sweeps (only used by the commented-out sweep cell).
SWEEP_CONFIG = utils.load_config(Config.SWEEP_CONFIG)
# Configuration handed to utils.wandb_init for regular runs.
CONFIG = utils.load_config(Config.CONFIG)

In [None]:
def get_labels():
    """Return the values of the module-level IDS_TO_LABELS mapping as a new list."""
    return list(IDS_TO_LABELS.values())


def get_data_loaders(only_ner=True):
    """Build the train, dev and test DataLoaders for the current wandb run.

    Reads ``train.json``, ``dev.json`` and ``test.json`` from the location
    given by ``wandb.config['dataset_path']`` (the file name is appended
    directly, so the path is expected to end with a separator).

    Args:
        only_ner: When true, every frame is passed through ``utils.prepare``
            before being wrapped in a ``CustomDataset``.

    Returns:
        A ``(train_loader, dev_loader, test_loader)`` tuple.
    """
    split_names = ("train", "dev", "test")

    # Phase 1: load every split and give it a fresh 0..n-1 index.
    frames = [
        pd.read_json(f"{wandb.config['dataset_path']}{split}.json").reset_index(drop=True)
        for split in split_names
    ]

    # Phase 2: optional preprocessing, applied to all splits alike.
    if only_ner:
        frames = [utils.prepare(frame) for frame in frames]

    # Phase 3: wrap each frame in a dataset and a shuffling loader.
    # NOTE(review): dev and test loaders are shuffled as well; confirm that
    # this is intended, since evaluation does not normally need shuffling.
    loaders = [
        DataLoader(CustomDataset(frame, DEVICE), batch_size=wandb.config["batch_size"], shuffle=True)
        for frame in frames
    ]

    return tuple(loaders)


def get_optimizer(model):
    """Create the optimizer selected by ``wandb.config['optimizer']``.

    Recognised names are ``'ADAM'``, ``'ADAMW'`` and ``'SGD'`` (exact,
    case-sensitive match). Any other value falls back to Adam, as before,
    but now emits a warning so a mistyped config value does not go unnoticed.

    Learning rate and weight decay are read from ``wandb.config`` for every
    optimizer; SGD additionally uses a fixed momentum of 0.9.

    Args:
        model: Module whose ``parameters()`` will be optimized.

    Returns:
        A single, freshly constructed ``torch.optim`` optimizer.
    """
    import warnings

    # Hyper-parameters shared by every supported optimizer.
    shared_kwargs = dict(
        params=model.parameters(),
        lr=wandb.config["learning_rate"],
        weight_decay=wandb.config["weight_decay"],
    )
    name = wandb.config['optimizer']

    # Build exactly one optimizer instead of creating a default Adam first
    # and then throwing it away when another branch matches.
    if name == 'ADAMW':
        return torch.optim.AdamW(**shared_kwargs)
    if name == 'SGD':
        return torch.optim.SGD(**shared_kwargs, momentum=0.9)  # noqa
    if name != 'ADAM':
        warnings.warn(
            f"Unknown optimizer {name!r}; falling back to Adam.",
            stacklevel=2,
        )
    return torch.optim.Adam(**shared_kwargs)


def resume_state(model, optimizer, scheduler, metrics, model_version: str='latest', config_version: str='latest', config_overwrites: Optional[Dict[str, str]]=None):
    """Restore model, optimizer, scheduler and metrics from wandb artifacts.

    Downloads the model artifact into ``{CURRENT_DIR}/models/`` and the config
    artifact into ``{CURRENT_DIR}/conf/``, replaces ``wandb.config`` with the
    backup configuration merged with ``config_overwrites``, then loads the
    checkpoint ``{CURRENT_DIR}/models/<model>.pt`` into the given objects.

    Args:
        model: Module that receives ``checkpoint['model_state_dict']``.
        optimizer: Optimizer that receives ``checkpoint['optimizer_state_dict']``.
        scheduler: Scheduler that receives ``checkpoint['scheduler_state_dict']``.
        metrics: Ignored; kept for interface compatibility. The returned
            metrics always come from the checkpoint.
        model_version: Version/alias of the model artifact to fetch.
        config_version: Version/alias of the config artifact to fetch.
        config_overwrites: Entries that take precedence over the backup
            configuration. ``None`` (the default) means no overrides.

    Returns:
        ``(model, optimizer, scheduler, metrics)`` with restored state.
    """
    # Artifact names are resolved against the configuration that is active
    # *before* it gets replaced below.
    artifact = wandb.run.use_artifact(f'kripso/{wandb.config["project_name"]}/{wandb.config["model"]}:{model_version}', type='model')
    artifact.download(f'{CURRENT_DIR}/models/')
    artifact = wandb.run.use_artifact(f'kripso/{wandb.config["project_name"]}/config:{config_version}', type='config')
    artifact.download(f'{CURRENT_DIR}/conf/')

    # NOTE(review): this rebinds wandb.config to a plain dict rather than
    # updating the run's config object; kept as-is to preserve behavior.
    wandb.config = {**utils.load_config(Config.BACKUP), **(config_overwrites or {})}

    # map_location lets a checkpoint saved on a GPU be resumed on a CPU-only
    # host (and vice versa) instead of failing while deserializing tensors.
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    checkpoint = torch.load(f'{CURRENT_DIR}/models/{wandb.config["model"]}.pt', map_location=DEVICE)

    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

    restored_metrics = checkpoint['metrics']
    # Step the counter back by one; presumably the training loop increments
    # it before use -- TODO confirm against training.fit.
    restored_metrics['step'] -= 1

    return model, optimizer, scheduler, restored_metrics


@utils.wandb_init(CONFIG)
def main(resume: bool=False, *args, **kwargs):
    """Assemble model, optimizer, scheduler and data, then run training.

    Wrapped by ``utils.wandb_init(CONFIG)``, which presumably sets up the
    wandb run that ``wandb.config`` is read from (verify in ``utils``).

    Args:
        resume: When true, state is restored through ``resume_state`` before
            training starts.
        *args, **kwargs: Forwarded to ``resume_state`` (only when resuming).

    Returns:
        Whatever ``training.fit`` returns.
    """
    model = CustomBertForTokenClassification(labels=get_labels()).to(DEVICE)
    optimizer = get_optimizer(model)
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer,
        step_size=wandb.config["scheduler_step_size"],
        gamma=wandb.config["scheduler_gamma"],
    )
    # Starting metrics for a fresh run; replaced by the checkpoint's when resuming.
    metrics = dict(loss=0, accuracy=0, f1_score=0, index=1, step=0)

    if resume:
        restored = resume_state(model, optimizer, scheduler, metrics, *args, **kwargs)
        model, optimizer, scheduler, metrics = restored

    train_dl, dev_dl, test_dl = get_data_loaders(True)

    # Release cached GPU memory before the training loop starts.
    torch.cuda.empty_cache()
    return training.fit(model, optimizer, scheduler, metrics, train_dl, dev_dl, test_dl, DEVICE)


# Sweep Run

In [None]:
# sweep_id = wandb.sweep(sweep=SWEEP_CONFIG, project=CONFIG['project_name'])
# wandb.agent(sweep_id, function=main, count=20)

# Manual Run 

In [None]:
# Fresh training run: `resume` keeps its default of False, so no checkpoint is loaded.
main()

# Continuation Run

In [None]:
# main(resume=True, model_version='v31',config_overwrites=CONFIG)