## Preamble

In [1]:
import sys
from pathlib import Path

# Make the project's helper modules in ../scripts importable (presumably where
# `datautils` and `modelling`, imported below, live). The path is resolved
# against the kernel's working directory when this cell runs and stored as an
# absolute path, so a later os.chdir() cannot break those imports. It is only
# added once, so re-running this cell does not keep growing sys.path.
SCRIPTS_DIR = str(Path('../scripts').resolve())
if SCRIPTS_DIR not in sys.path:
    sys.path.append(SCRIPTS_DIR)

## Prepare the data

In [2]:
# Label name -> integer class id. Each label's id is its position in the list
# below, so reordering (or inserting into the middle of) the list changes the
# ids of all following labels; append new labels at the end instead.
label2id = {
    label: class_id
    for class_id, label in enumerate([
        'ISSUE',
        'NONE',
        'STA',
        'ANALYSIS',
        'PRE_RELIED',
        'RATIO',
        'RPC',
        'PRE_NOT_RELIED',
        'ARG_PETITIONER',
        'PREAMBLE',
        'RLC',
        'ARG_RESPONDENT',
        'FAC',
    ])
}

# Inverse lookup: class id -> label name.
id2label = dict(zip(label2id.values(), label2id.keys()))

In [3]:
from datautils import RRDataModule, RRBatcher

DPATH = "../data/train.json"  # data file handed to RRDataModule
# Pretrained encoder id; also reused in the model cell to build a fresh model.
# NOTE(review): presumably RRBatcher loads the matching tokenizer from it — confirm in datautils.
MPATH = "sentence-transformers/all-MiniLM-L6-v2"

batcher = RRBatcher(MPATH)
# label2id fixes the label -> class-id encoding used for the targets (assumed; verify in datautils).
datamodule = RRDataModule(DPATH, batcher, label2id)
# Eager setup, presumably to surface data errors before training starts;
# Lightning's Trainer.fit also calls datamodule.setup() itself.
datamodule.setup()


  from .autonotebook import tqdm as notebook_tqdm


## Prepare the model

In [4]:
from pathlib import Path

from modelling import CoherenceAwareSentenceEmbedder, AlterMiningStrategy
from pytorch_lightning import seed_everything
from pytorch_lightning.callbacks import (
    EarlyStopping,
    ModelCheckpoint,
    RichModelSummary,
)

SAVE_PATH = "../models/coherence-aware/all-MiniLM-L6-v2"
# NOTE(review): this file name uses a "masked-" prefix, but the ModelCheckpoint
# below writes "imp=<surrogate_imp>-epoch=<n>-val_sem_loss=<x>.ckpt" files, so
# this cell never resumes from a checkpoint produced by this notebook.
CKPT = Path(SAVE_PATH).joinpath("masked-epoch=11-val_sem_loss=0.29.ckpt")

# Passed to a freshly built model and embedded in new checkpoint file names.
# NOTE(review): when CKPT is loaded instead, the model presumably keeps the
# value stored in its saved hyperparameters, so the "imp=" tag may not match.
surrogate_imp = 0.2

# Seed before building the model so freshly initialised weights are
# reproducible (the training cell re-seeds again right before fitting).
seed_everything(42, workers=True)

# Resume from CKPT only when the file exists. Errors raised while loading an
# existing checkpoint (e.g. hyperparameter mismatch, corrupt file) propagate
# instead of being replaced silently by an untrained model.
if CKPT.exists():
    model = CoherenceAwareSentenceEmbedder.load_from_checkpoint(CKPT, num_classes=len(label2id))
    print("Loaded checkpoint")
else:
    model = CoherenceAwareSentenceEmbedder(
        MPATH,
        num_classes=len(label2id),
        surrogate_imp=surrogate_imp,
        triplet_margin=1.0,
    )
    print("Loaded fresh model")

# Keep the checkpoint with the lowest validation semantic loss. The doubled
# braces survive the f-string as Lightning's {epoch} / {val_sem_loss:.2f}
# filename placeholders.
checkpoint_callback = ModelCheckpoint(
    dirpath=SAVE_PATH,
    filename=f'imp={surrogate_imp}-{{epoch}}-{{val_sem_loss:.2f}}',
    monitor='val_sem_loss',
    mode='min',
    # Add save_weights_only=True to omit optimizer/scheduler state from checkpoints.
)

# Stop once val_sem_loss has not improved by at least 1e-4 for 8 consecutive
# validation checks.
early_stop_callback = EarlyStopping(
    monitor="val_sem_loss",
    min_delta=1e-4,
    patience=8,
    verbose=False,
    mode="min",
)

# NOTE(review): project callback from `modelling`; presumably switches the
# triplet-mining strategy based on val_sem_loss — confirm there.
mining_callback = AlterMiningStrategy(monitor="val_sem_loss")

Loaded fresh model


## Train the model

In [5]:
from pytorch_lightning import Trainer, seed_everything


# Re-seed Python, NumPy and torch RNGs (plus dataloader workers) immediately
# before fitting so the training run itself is reproducible.
seed_everything(42, workers=True)

trainer = Trainer(
    accelerator='gpu',
    deterministic=True,         # ask PyTorch for deterministic algorithms
    max_epochs=-1,              # no epoch cap: EarlyStopping (or a manual interrupt) ends training
    accumulate_grad_batches=4,  # accumulate gradients over 4 batches per optimizer step
    log_every_n_steps=16,
    callbacks=[
        checkpoint_callback,
        early_stop_callback,
        RichModelSummary(),
        mining_callback,
    ],
)

trainer.fit(model, datamodule)

Global seed set to 42
Trainer already configured with model summary callbacks: [<class 'pytorch_lightning.callbacks.rich_model_summary.RichModelSummary'>]. Skipping setting a default `ModelSummary` callback.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]

  rank_zero_warn(


                                                                           

  rank_zero_warn(


Epoch 1:  96%|█████████▌| 219/228 [00:31<00:01,  6.93it/s, loss=0.752, v_num=9, train_loss=0.729, train_sem_loss=0.838, train_surr_loss=0.293, val_loss=0.769, val_sem_loss=0.863, val_surr_loss=0.390]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
