## Prepare the data

### Set up the tokenizer

In [2]:
from transformers import AutoTokenizer

# Speaker-role markers that get added to the vocabulary. The list order is
# deliberate (presumably it fixes which new token id each marker receives --
# confirm before reordering).
SPECIAL_TOKENS = ["<persuader>", "<persuadee>"]
MPATH = "microsoft/DialoGPT-medium"

tokenizer = AutoTokenizer.from_pretrained(MPATH)
tokenizer.add_special_tokens(dict(additional_special_tokens=SPECIAL_TOKENS))
# Reuse the end-of-sequence token as the padding token.
tokenizer.pad_token = tokenizer.eos_token

### Set up the data module

In [3]:
from datautils import DialogDataModule, DialogBatcher


# Corpus directory, relative to the notebook. Forward slashes resolve on
# Windows, Linux and macOS alike; a backslash-separated literal only resolves
# on Windows. This also matches the separator style used for the model paths.
DPATH = "../data/persuasionforgood_corpus"

# Role flag; later cells reuse it to pick the role name. The spelling matches
# the keyword argument accepted by DialogDataModule, so it is left unchanged.
is_pursuader = True
batcher = DialogBatcher(tokenizer)
dm = DialogDataModule(
    DPATH,
    batcher=batcher,
    batch_size=16,
    split_data=False,
    purpose_text="Convince people to donate to charity.",
    is_pursuader=is_pursuader
)

## Prepare the model

In [None]:
from modelling import DialogAgent


# Role name used in the output directory and checkpoint file names.
if is_pursuader:
    ROLE = "persuader"
else:
    ROLE = "persuadee"
SAVE_PATH = "../models/base/" + ROLE
# Not read in this cell; left empty here.
CKPT = ""

# Build the agent from the pretrained weights; embedding_size is the tokenizer
# length after the role markers were added.
model = DialogAgent(MPATH, embedding_size=len(tokenizer))
print("Loaded fresh model")

### Set up callbacks

In [None]:
from pytorch_lightning.callbacks import(
    EarlyStopping,
    ModelCheckpoint,
    RichModelSummary
)

# Save weights-only checkpoints under SAVE_PATH, ranked by validation loss.
# The monitored key must be the metric that the filename template formats and
# that early stopping watches; previously this callback monitored
# 'val_soft_loss' while the other two places used 'val_loss'.
checkpoint_callback = ModelCheckpoint(
    SAVE_PATH,
    filename=f'baseagent-{ROLE}-{{epoch}}-{{val_loss:.2f}}',
    monitor='val_loss',
    mode='min',
    save_weights_only=True,
)
# Stop once val_loss has failed to improve by at least 1e-4 for 3 consecutive
# validation checks.
early_stop_callback = EarlyStopping(
    monitor="val_loss",
    min_delta=1e-4, patience=3,
    verbose=False,
    mode="min"
)

## Train the model

In [None]:
from pytorch_lightning import Trainer, seed_everything


# Fix the global seed (including dataloader workers) before training starts.
seed_everything(42, workers=True)

trainer = Trainer(
    # Hardware and reproducibility.
    accelerator='gpu',
    deterministic=True,
    # No epoch limit: training runs until the early-stopping callback halts it.
    max_epochs=-1,
    # Step the optimizer once every 4 batches.
    accumulate_grad_batches=4,
    log_every_n_steps=16,
    callbacks=[
        checkpoint_callback,
        early_stop_callback,
        RichModelSummary(),
    ],
)

trainer.fit(model, dm)