# Base experiment notebook

This notebook is the template for all experiments.

It contains the basic code needed to run an experiment and is the starting point for every other notebook in the `experiments` folder.

## Setup and Imports

In [None]:
# IPython autoreload: with mode 2, imported modules (e.g. the project package
# nlp_assemblee) are reloaded before executing each cell, so edits to the
# project code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import json
from pathlib import Path

import pytorch_lightning as pl
import torch
from pytorch_lightning import callbacks, seed_everything
from torch import nn

from nlp_assemblee.simple_trainer import LitModel, load_embedding, process_predictions
from nlp_assemblee.simple_visualisation import (
    calculate_metrics,
    plot_confusion_matrix,
    plot_network_graph,
    plot_precision_recall_curve,
    plot_roc_curve,
)

In [None]:
# Seed all RNGs used by Lightning (including dataloader workers) for reproducibility.
seed_everything(seed=42, workers=True)

# Use the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

## Definition of the notebook variables

In [None]:
# To change between experiments
FEATURES = True  # forwarded to the data loaders as `use_features`
TEXT_VARS = ["intervention", "titre_regexed", "contexte"]

BATCH_SIZE = 128
MAX_EPOCHS = 100

# Validation (and therefore `val_loss`) only runs every CHECK_VAL_EVERY_N_EPOCH
# epochs. Defined here because the scheduler configuration below depends on it.
CHECK_VAL_EVERY_N_EPOCH = 3

CALLBACKS = [
    # NOTE: `patience` counts validation checks, not epochs: with validation every
    # CHECK_VAL_EVERY_N_EPOCH epochs this tolerates 10 * 3 = 30 epochs without improvement.
    callbacks.EarlyStopping(
        monitor="val_loss", mode="min", min_delta=0.01, check_finite=True, patience=10
    ),
    callbacks.ModelSummary(max_depth=-1),
    # Wall-clock limit, format "DD:HH:MM:SS" (3 hours).
    callbacks.Timer(duration="00:03:00:00", interval="epoch"),
    callbacks.RichProgressBar(),
    callbacks.LearningRateMonitor(logging_interval="epoch", log_momentum=False),
    # Keep the checkpoint with the lowest val_loss so that
    # `trainer.predict(ckpt_path="best")` reloads the best model. Without a
    # monitored checkpoint, "best" only resolves to the most recent checkpoint.
    # Appended last so the indices of the callbacks above are unchanged.
    callbacks.ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=1),
]

OPTIMIZER_TYPE = "Adam"
OPTIMIZER_KWARGS = {}
LR = 1e-4
LOSS = "CrossEntropyLoss"

# NOTE(review): assumes LitModel forwards "interval", "frequency", "strict" and
# "monitor" into Lightning's lr_scheduler config — confirm in simple_trainer.
SCHEDULER_KWARGS = {
    "scheduler": "ReduceLROnPlateau",
    "mode": "min",
    "factor": 0.1,
    # Counted in scheduler steps, i.e. validation checks given the frequency below.
    "patience": 5,
    "interval": "epoch",
    # Step the plateau scheduler only on epochs where validation ran. With a
    # frequency of 1 it would read a stale val_loss (counted as "no improvement")
    # or, before the first validation, no val_loss at all (an error with strict=True).
    "frequency": CHECK_VAL_EVERY_N_EPOCH,
    "strict": True,
    "monitor": "val_loss",
}
# SCHEDULER_KWARGS = {
#     "scheduler": "OneCycleLR",
#     "max_lr": 5e-3,
#     "pct_start": 0.3,
#     "epochs": 30,
#     "steps_per_epoch": 100,
#     "interval": "epoch",
#     "frequency": 1,
#     "strict": True
# }


# Doesn't change between experiments
LABEL_VAR = "label"
DATA_ROOT = "../../data/"
NUM_WORKERS = 12
PREFETCH_FACTOR = 4
# Derive the hardware settings from what torch actually sees instead of hard-coding
# "gpu"/"cuda", which made the Trainer fail on machines without a GPU.
CUDA_AVAILABLE = torch.cuda.is_available()
PIN_MEMORY = CUDA_AVAILABLE  # pinned host memory only helps host-to-GPU copies
ACCELERATOR = "gpu" if CUDA_AVAILABLE else "cpu"
DEVICE = "cuda" if CUDA_AVAILABLE else "cpu"
LOG_EVERY_N_STEPS = 50
DETERMINISTIC = False

## First Experiment

In [None]:
MODEL_NAME = "bert-base-multilingual-cased"
# Built from DATA_ROOT rather than repeating the "../../data/" prefix, so moving
# the data folder only requires changing that one constant.
MODEL_FOLDER = str(Path(DATA_ROOT) / "precomputed" / MODEL_NAME)
RESULTS_PATH = f"../../results/{MODEL_NAME}/"
# Create the results folder before the logger (or anything else) targets it.
Path(RESULTS_PATH).mkdir(parents=True, exist_ok=True)
# log_graph=True needs an example input on the model (see Net.example_input_array).
LOGGER = pl.loggers.TensorBoardLogger(RESULTS_PATH, name=MODEL_NAME, log_graph=True)

### Definition of the net architecture

In [None]:
class Net(nn.Module):
    """Three-branch classifier: intervention vector plus title and context embeddings.

    ``intervention`` is a dense vector projected by ``intervention_fc``.
    ``titre_regexed`` and ``contexte`` are integer indices looked up in the
    embedding tables returned by ``load_embedding``, then projected by
    ``titre_fc`` / ``contexte_fc``. The three ``inter_dim``-sized projections are
    summed and passed through an MLP that outputs ``num_classes`` logits.

    Args:
        root: Folder handed to ``load_embedding`` for both embedding tables.
        embed_dim: Size of the intervention vector and of each embedding row
            (input size of every ``*_fc`` linear layer).
        inter_dim: Size of the summed intermediate representation.
        dropout: Dropout probability applied before every linear layer.
        freeze: Forwarded to ``load_embedding`` (presumably freezes the
            embedding weights — confirm in ``simple_trainer``).
        num_classes: Number of output logits. Defaults to 3, the previously
            hard-coded value.
    """

    def __init__(self, root, embed_dim, inter_dim, dropout=0.2, freeze=True, num_classes=3):
        super().__init__()
        # Dummy batch used by Lightning (model summary, graph logging, ONNX export).
        # Drawn first, as before, so the seeded RNG state consumed by the layer
        # initialisations below is unchanged. The intervention vector follows
        # `embed_dim` instead of a hard-coded 768 so it matches `intervention_fc`.
        example_text = {
            "intervention": torch.randn(32, embed_dim),
            "titre_regexed": torch.randint(100, (32,)).int(),
            "contexte": torch.randint(100, (32,)).int(),
        }

        self.embed_dim = embed_dim
        self.inter_dim = inter_dim
        self.dropout = dropout
        self.freeze = freeze
        self.num_classes = num_classes

        self.titre_embeddings = load_embedding(root, "titre_regexed", freeze=freeze)
        self.titre_fc = self._projection(embed_dim, inter_dim, dropout)

        self.contexte_embeddings = load_embedding(root, "contexte", freeze=freeze)
        self.contexte_fc = self._projection(embed_dim, inter_dim, dropout)

        self.intervention_fc = self._projection(embed_dim, inter_dim, dropout)

        self.mlp = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(inter_dim, 512),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(256, num_classes),
        )

        # Keep the dummy indices inside each table's vocabulary: a table with fewer
        # than 100 rows would otherwise make the example forward pass fail. Only
        # applied when the table exposes `num_embeddings` (as nn.Embedding does).
        for key, table in (
            ("titre_regexed", self.titre_embeddings),
            ("contexte", self.contexte_embeddings),
        ):
            vocab_size = getattr(table, "num_embeddings", None)
            if vocab_size:
                example_text[key] = example_text[key] % vocab_size
        self.example_input_array = {"text": example_text}

    @staticmethod
    def _projection(in_dim, out_dim, dropout):
        """Return the Dropout -> Linear -> ReLU stack shared by the three input branches."""
        return nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(in_dim, out_dim),
            nn.ReLU(),
        )

    def forward(self, **x):
        """Compute class logits.

        Expects a ``text`` keyword holding a dict with the keys ``intervention``,
        ``titre_regexed`` and ``contexte``, laid out like ``example_input_array``.
        That is, dense ``embed_dim`` vectors plus integer index tensors. Batch-first
        layout for real batches is assumed — confirm against the data loader.

        Returns:
            Logits with ``num_classes`` entries in the last dimension.
        """
        intervention = x["text"]["intervention"]
        titre_regexed = x["text"]["titre_regexed"]
        contexte = x["text"]["contexte"]

        intervention_repr = self.intervention_fc(intervention)

        titre_emb = self.titre_embeddings(titre_regexed)
        titre_repr = self.titre_fc(titre_emb)

        contexte_emb = self.contexte_embeddings(contexte)
        contexte_repr = self.contexte_fc(contexte_emb)

        # Fuse the three branches by summation (all are inter_dim wide).
        pooled_repr = intervention_repr + titre_repr + contexte_repr

        logits = self.mlp(pooled_repr)

        return logits

In [None]:
# embed_dim=768 is presumably the hidden size of MODEL_NAME (BERT-base);
# keep it in sync with the precomputed features in MODEL_FOLDER.
NET = Net(
    MODEL_FOLDER,
    embed_dim=768,
    inter_dim=1024,
    dropout=0.2,
    freeze=True,
)

### Definition of the trainer and module

In [None]:
# Lightning wrapper around NET: optimizer, LR scheduler, loss and data settings.
lit_model = LitModel(
    NET,
    optimizer_type=OPTIMIZER_TYPE,
    learning_rate=LR,
    optimizer_kwargs=OPTIMIZER_KWARGS,
    scheduler_kwargs=SCHEDULER_KWARGS,
    criterion_type=LOSS,
    batch_size=BATCH_SIZE,
    # Presumably used by LitModel to build its datasets/data loaders.
    loader_kwargs=dict(
        root=MODEL_FOLDER,
        text_vars=TEXT_VARS,
        use_features=FEATURES,
        label_var=LABEL_VAR,
        num_workers=NUM_WORKERS,
        prefetch_factor=PREFETCH_FACTOR,
        pin_memory=PIN_MEMORY,
    ),
)

trainer = pl.Trainer(
    # hardware / reproducibility
    accelerator=ACCELERATOR,
    deterministic=DETERMINISTIC,
    # training loop
    max_epochs=MAX_EPOCHS,
    check_val_every_n_epoch=CHECK_VAL_EVERY_N_EPOCH,
    callbacks=CALLBACKS,
    # logging
    logger=LOGGER,
    log_every_n_steps=LOG_EVERY_N_STEPS,
)

### Training

In [None]:
# Train until MAX_EPOCHS, or earlier if a callback (EarlyStopping, Timer) stops the run.
trainer.fit(model=lit_model)

### Evaluation and visualization

#### Prediction on test set

In [None]:
# Predict with the model attached by trainer.fit; ckpt_path="best" asks Lightning
# to reload the best checkpoint tracked by the trainer's checkpoint callback first.
preds = trainer.predict(ckpt_path="best")

#### Metrics and logs

In [None]:
# Convert the per-batch prediction outputs into the structure expected by the
# metric and plotting helpers below.
results = process_predictions(preds)

In [None]:
# Compute evaluation metrics from the processed predictions (also saved to
# metrics.json and logs.json below); the bare name displays them in the notebook.
metrics = calculate_metrics(results)
metrics

In [None]:
# Serialize before touching the file: previously the file was opened (and
# truncated) first, so a non-JSON-serializable metric left an empty metrics.json.
metrics_json = json.dumps(metrics, indent=2)
(Path(RESULTS_PATH) / "metrics.json").write_text(metrics_json, encoding="utf-8")

In [None]:
# Summarize once: each `summarize` call walks the whole model (and may run a
# forward pass on the example input), so calling it three times was wasteful.
model_summary = pl.utilities.model_summary.summarize(lit_model)

# Find the Timer by type instead of `trainer.callbacks[2]`: Lightning may add or
# reorder entries in `trainer.callbacks`, which makes a fixed index fragile.
timer_callback = next(
    (cb for cb in trainer.callbacks if isinstance(cb, callbacks.Timer)), None
)

logs_dict = {
    "last_epoch": trainer.current_epoch,
    "log_dir": trainer.log_dir,
    "ckpt_path": trainer.ckpt_path,
    "total_parameters": model_summary.total_parameters,
    "trainable_parameters": model_summary.trainable_parameters,
    "model_size": model_summary.model_size,
    # NOTE(review): must be JSON-serializable for logs.json — confirm what
    # LitModel stores in its initial hparams.
    "hparams": dict(lit_model.hparams_initial),
    "time_elapsed": timer_callback.time_elapsed() if timer_callback is not None else None,
    "metrics": metrics,
}

In [None]:
# Serialize before touching the file: previously the file was opened (and
# truncated) first, so a non-serializable value (e.g. in hparams) left an empty logs.json.
logs_json = json.dumps(logs_dict, indent=2)
(Path(RESULTS_PATH) / "logs.json").write_text(logs_json, encoding="utf-8")

#### Plots

In [None]:
# Confusion matrix without normalization, saved next to the other results.
confusion_fig = plot_confusion_matrix(results, figsize=(6, 6), normalized=None)
confusion_fig.savefig(Path(RESULTS_PATH).joinpath("confusion_matrix.png"))

In [None]:
# Confusion matrix normalized with normalized="true" (presumably per true class).
confusion_true_fig = plot_confusion_matrix(results, figsize=(6, 6), normalized="true")
confusion_true_fig.savefig(Path(RESULTS_PATH).joinpath("confusion_matrix_true_normed.png"))

In [None]:
# ROC curve(s) for the classifier, saved as PNG in the results folder.
roc_fig = plot_roc_curve(results, figsize=(6, 6), palette="deep")
roc_fig.savefig(Path(RESULTS_PATH).joinpath("roc_curve.png"))

In [None]:
# Precision-recall curve(s), saved as PNG in the results folder.
pr_fig = plot_precision_recall_curve(results, figsize=(6, 6), palette="deep")
pr_fig.savefig(Path(RESULTS_PATH).joinpath("precision_recall_curve.png"))

In [None]:
# Architecture graph of NET; RESULTS_PATH is passed as `path` (presumably the
# output location — confirm in simple_visualisation).
network_fig = plot_network_graph(
    NET,
    model_name=MODEL_NAME,
    device=DEVICE,
    path=RESULTS_PATH,
)

In [None]:
# Export the trained module to ONNX, named after the backbone model. The previous
# hard-coded "textual_camembert.onnx" did not match MODEL_NAME in this template.
lit_model.to_onnx(Path(RESULTS_PATH) / f"{MODEL_NAME}.onnx")