## Preamble

In [None]:
!pip install pytorch-lightning sentence-transformers torchmetrics rich

In [None]:
from google.colab import drive
# Mount Google Drive into the Colab filesystem; `force_remount=True` is
# passed so that re-running this cell mounts again.
DRIVE_MOUNT_POINT = "/content/drive"
drive.mount(DRIVE_MOUNT_POINT, force_remount=True)

In [None]:
import os

# Switch the working directory to the project's scripts directory so that its
# modules can be imported and the relative paths used below resolve from there.
CODE_DIR = "/content/drive/MyDrive/sem-eval23/code"
os.chdir(CODE_DIR)

## Prepare the data

In [None]:
from datautils import CnvAbBatcher, CnvAbDataModule, AnnotatorTokenizer

# Training split of the ConvAbuse dataset (relative to the working directory).
DPATH = "../data/data_practicephase_cleardev/ConvAbuse_dataset/ConvAbuse_train.json"
# Identifier of the sentence-transformers model; also passed to the model later.
MPATH = "sentence-transformers/paraphrase-MiniLM-L12-v2"

# NOTE(review): the argument 8 matches `modalities = 8` used when the model is
# built -- presumably both must stay in sync; confirm against AnnotatorTokenizer.
ann_tknzr = AnnotatorTokenizer(8)

batcher = CnvAbBatcher(
    MPATH,
    ann_tknzr=ann_tknzr,
    use_raw_text=False,
)

# NOTE(review): the task name is passed verbatim; verify its spelling against
# the keys expected by CnvAbDataModule / the dataset file.
task_names = ["abusivness detection"]
datamodule = CnvAbDataModule(
    DPATH,
    task_names,
    batcher,
    batch_size=4,
)
datamodule.setup()

## Prepare the model

In [None]:
from cnvab_modelling import AgreementModel, InteractionModel
from pytorch_lightning import seed_everything
from pytorch_lightning.callbacks import (
    EarlyStopping,
    ModelCheckpoint,
    RichModelSummary,
)

# Directory in which checkpoints are written.
SAVE_PATH = "../models/conv-abuse/paraphrase-MiniLM-L12-v2"
# Path of a checkpoint to resume from; leave empty to start from a fresh model.
CKPT = ""


modalities = 8

# Seed *before* the model is constructed so that the randomly initialised
# weights of a fresh model are reproducible across kernel restarts.
seed_everything(42, workers=True)

model = None
if CKPT:
    # Only attempt a restore when a checkpoint was actually requested, and
    # report why it failed instead of silently swallowing every error
    # (a bare `except:` would also trap KeyboardInterrupt).
    try:
        model = AgreementModel.load_from_checkpoint(CKPT)
        print("Loaded checkpoint")
    except Exception as err:
        print(f"Could not load checkpoint {CKPT!r}: {err!r}")

if model is None:
    # Fall back to a freshly initialised model.
    # NOTE(review): text_dim=384 is presumably the embedding width of the
    # MPATH encoder -- confirm if MPATH is changed.
    model = AgreementModel(
        MPATH,
        InteractionModel(text_dim=384, modalities=modalities),
        task_head_lr=2e-3,
        backbone_lr=2e-5,
    )
    print("Loaded fresh model")

# Stop training once `val_soft_loss` has failed to improve by at least 1e-4
# for 8 consecutive validation checks.
early_stop_callback = EarlyStopping(
    monitor="val_soft_loss",
    mode="min",
    patience=8,
    min_delta=1e-4,
    verbose=False,
)

# Save weights-only checkpoints under SAVE_PATH, selected by the lowest
# `val_soft_loss`; the file name records the modality count, epoch and loss.
checkpoint_callback = ModelCheckpoint(
    dirpath=SAVE_PATH,
    filename=f'int-mod={modalities}-{{epoch}}-{{val_soft_loss:.2f}}',
    monitor='val_soft_loss',
    mode='min',
    save_weights_only=True,
)

## Train the model

In [None]:
from pytorch_lightning import Trainer, seed_everything


# Fix all random number generators (including dataloader workers) before
# the trainer is created.
seed_everything(42, workers=True)

training_callbacks = [
    checkpoint_callback,
    early_stop_callback,
    RichModelSummary(),
]

# `max_epochs=-1` places no limit on the number of epochs, so the run ends
# only when the early-stopping callback triggers.
trainer = Trainer(
    accelerator='gpu',
    deterministic=True,
    max_epochs=-1,
    accumulate_grad_batches=4,
    log_every_n_steps=16,
    callbacks=training_callbacks,
)

trainer.fit(model, datamodule)