## Preamble

In [None]:
!pip install pytorch-lightning sentence-transformers torchmetrics rich

In [None]:
# Mount Google Drive inside the Colab VM; `force_remount=True` is passed so
# the mount is redone even if the drive is already mounted.
from google.colab import drive

drive.mount(
    '/content/drive',
    force_remount=True,
)

In [None]:
import os

# Switch the working directory to the project's code folder on Drive.
# This allows imports from the scripts directory, and the relative
# "../data" and "../models" paths used later are resolved from here.
os.chdir(
    '/content/drive/MyDrive/sem-eval23/code'
)

## Prepare the data

In [1]:
from datautils import MdaBatcher, MdaDataModule, MdaMasker

# Dataset file (relative to the code directory) and the sentence-transformers
# model identifier shared by the batcher and, later, the model.
DPATH = "../data/data_practicephase_cleardev/MD-Agreement_dataset/MD-Agreement_final.json"
MPATH = "sentence-transformers/all-MiniLM-L12-v2"

# The batcher is built from the same model identifier as the backbone.
batcher = MdaBatcher(MPATH)
# NOTE(review): presumably a masked-LM based text augmenter; confirm in datautils.
augmenter = MdaMasker('xlm-roberta-large')

datamodule = MdaDataModule(
    DPATH,
    ["offensiveness detection"],
    batcher,
    augmenter=augmenter,
)
# Run setup eagerly here, before the trainer is created.
datamodule.setup()

  from .autonotebook import tqdm as notebook_tqdm


## Prepare the model

In [None]:
from mda_modelling import AgreementModel
from pytorch_lightning.callbacks import (
    EarlyStopping,
    ModelCheckpoint,
    RichModelSummary,  # used by the training cell
)


# Directory that receives the checkpoints written during this run.
SAVE_PATH = "../models/md-agreement/all-MiniLM-L12-v2"
# Existing checkpoint to start from, if it can be loaded.
CKPT = "../models/md-agreement/all-MiniLM-L12-v2/pt-epoch=8-val_soft_loss=0.15.ckpt"


# Only used when a fresh model is built below.
# NOTE(review): presumably the weight of the soft-label loss term; confirm
# against AgreementModel.
soft_label_imp = 0.88

# Best-effort resume: try the checkpoint first and fall back to a freshly
# initialised model. `except Exception` (instead of a bare `except:`) lets
# KeyboardInterrupt/SystemExit propagate, and the failure reason is printed so
# a missing or incompatible checkpoint does not go unnoticed.
try:
    model = AgreementModel.load_from_checkpoint(CKPT)
    print("Loaded checkpoint")
except Exception as load_error:
    print(f"Could not load checkpoint {CKPT!r}: {load_error!r}")
    model = AgreementModel(
        MPATH,
        soft_label_imp=soft_label_imp,
        task_head_lr=2e-3,
        backbone_lr=2e-5,
    )
    print("Loaded fresh model")

# Save checkpoints ranked by validation soft loss (lower is better).
# The filename is a template filled in by Lightning, so it is a plain string,
# not an f-string.
checkpoint_callback = ModelCheckpoint(
    SAVE_PATH,
    filename='ft-{epoch}-{val_soft_loss:.2f}',
    monitor='val_soft_loss',
    mode='min',
    save_weights_only=True,
)
# Stop once the validation soft loss has not improved by at least `min_delta`
# for `patience` consecutive validation checks.
early_stop_callback = EarlyStopping(
    monitor="val_soft_loss",
    min_delta=1e-4,
    patience=8,
    verbose=False,
    mode="min",
)

## Train the model

In [None]:
from pytorch_lightning import Trainer, seed_everything


# Fix the random seeds (including dataloader workers) before building the trainer.
seed_everything(42, workers=True)

# max_epochs=-1 puts no upper bound on the number of epochs, so the run ends
# through the early-stopping callback.
trainer = Trainer(
    accelerator='gpu',
    deterministic=True,
    max_epochs=-1,
    accumulate_grad_batches=4,
    log_every_n_steps=16,
    callbacks=[
        checkpoint_callback,
        early_stop_callback,
        RichModelSummary(),
    ],
)

trainer.fit(model, datamodule)