In [1]:
import pickle
from pathlib import Path

import matplotlib.pyplot as plt
import matplotlib.ticker as mtick
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import seaborn as sns
import torch
from pytorch_lightning import callbacks
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning import seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter
from tqdm.autonotebook import tqdm

from nlp_assemblee.simple_datasets import get_dataloader, get_single_dataloader

In [2]:
# Seaborn styling shared by the figures drawn in this notebook.
sns.set_context("paper")
sns.set_palette("deep")
sns.set_style("white")

# Keep the palette entries at hand so plots can pick colours by index.
colors = sns.color_palette("deep")

# Fix the global seed up front so runs are repeatable.
seed_everything(42, workers=True)

# Prefer the GPU when torch reports one; the bare name below shows the choice as cell output.
device = "cuda" if torch.cuda.is_available() else "cpu"
device

Global seed set to 42


'cuda'

In [3]:
class SimplestBert(nn.Module):
    """Small fully connected classifier over a single text representation.

    The head chains four linear layers (768 -> 512 -> 512 -> 256 -> 3) and puts a
    LeakyReLU with negative slope 0.01 between consecutive linear layers; the
    last linear layer's output is returned as-is.
    """

    def __init__(self):
        super().__init__()
        widths = (768, 512, 512, 256, 3)
        last = len(widths) - 2

        stack = []
        for position, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:])):
            stack.append(nn.Linear(fan_in, fan_out))
            # No activation after the final layer: it emits the raw class scores.
            if position != last:
                stack.append(nn.LeakyReLU(0.01))

        self.fc = nn.Sequential(*stack)

    def forward(self, x):
        """Apply the head to the tensor stored under ``x["text"]["intervention"]``.

        Presumably that tensor is a 768-wide sentence embedding (the first layer
        expects 768 input features) -- TODO confirm against the dataset.
        """
        return self.fc(x["text"]["intervention"])

In [71]:
def compute_test_loss(net, loader, length, criterion, prefix="val"):
    """Evaluate ``net`` on every batch of ``loader`` with gradients disabled.

    Args:
        net: Model called as ``net(inputs)``. It is switched to eval mode and
            left in that mode.
        loader: Iterable yielding ``(inputs, targets)`` batches.
        length: Sample count used to normalise the aggregate loss and accuracy.
        criterion: Callable ``criterion(outputs, targets)`` returning a scalar
            loss. The batch loss is re-weighted by the batch size, so it is
            assumed to be a per-batch mean -- TODO confirm.
        prefix: Prefix used for the keys of the returned dict.

    Returns:
        dict with four entries:
            ``{prefix}_loss_list``: loss of each batch,
            ``{prefix}_loss``: size-weighted sum of the batch losses / ``length``,
            ``{prefix}_accuracy_list``: fraction of correct predictions per batch,
            ``{prefix}_accuracy``: total correct predictions / ``length``.
    """
    net.eval()
    total_loss = 0.0
    total_correct = 0
    loss_list = []
    accuracy_list = []

    with torch.no_grad():
        for inputs, targets in loader:
            outputs = net(inputs)
            # Put the targets wherever the model produced its outputs instead of
            # relying on a notebook-level ``device`` global.
            targets = targets.to(outputs.device)
            batch_size = targets.size(0)

            loss_item = criterion(outputs, targets).item()
            total_loss += loss_item * batch_size
            loss_list.append(loss_item)

            _, predicted = outputs.max(1)
            n_correct = predicted.eq(targets).sum().item()
            total_correct += n_correct
            # Store a fraction (not a raw count) so this series is on the same
            # scale as the per-batch train accuracy recorded by ``train_net``.
            accuracy_list.append(n_correct / batch_size)

    return {
        f"{prefix}_loss_list": loss_list,
        f"{prefix}_loss": total_loss / length,
        f"{prefix}_accuracy_list": accuracy_list,
        f"{prefix}_accuracy": total_correct / length,
    }

In [64]:
def plot_results(results, loss_baseline=0.98978976, accuracy_baseline=0.39, palette="deep"):
    """Plot per-epoch and per-batch loss and accuracy curves of a training run.

    Args:
        results: Per-epoch records, either a dict keyed by epoch index (as built
            by ``train_net``) or a sequence of records. Each record provides
            ``train_loss``, ``val_loss``, ``train_accuracy``, ``val_accuracy``
            and the matching ``*_list`` per-batch series.
        loss_baseline: Height of the horizontal reference line on the loss panel.
        accuracy_baseline: Height of the horizontal reference line on the
            accuracy panel, on the same 0-1 scale as the accuracies.
        palette: Seaborn palette name; entries 0, 2 and 3 colour the train,
            validation and baseline curves respectively.

    Returns:
        The matplotlib figure holding the two panels (loss left, accuracy right).
    """
    if isinstance(results, dict):
        # Walk the epochs actually present: a resumed run may not start at key 0.
        records = [results[key] for key in sorted(results)]
    else:
        records = list(results)
    n_epoch = len(records)

    def epoch_series(key):
        return [record[key] for record in records]

    def batch_series(key):
        return [value for record in records for value in record[key]]

    def batch_axis(series):
        # Train and validation loaders generally yield different batch counts,
        # so every per-batch series gets an x axis of its own length.
        return np.linspace(1, n_epoch + 1, len(series))

    epoch_x_range = np.arange(1, n_epoch + 1)
    colors = sns.color_palette(palette)
    splits = (("train", "Train", colors[0]), ("val", "Validation", colors[2]))

    fig, axs = plt.subplots(1, 2, figsize=(12, 6))
    panels = ((axs[0], "loss", loss_baseline), (axs[1], "accuracy", accuracy_baseline))

    for ax, metric, baseline in panels:
        # Solid lines: one point per epoch.
        for split, label, color in splits:
            ax.plot(
                epoch_x_range,
                epoch_series(f"{split}_{metric}"),
                label=f"{label} {metric}",
                color=color,
            )
        # Dashed lines: one point per batch, stretched over the same epoch span.
        for split, _, color in splits:
            per_batch = batch_series(f"{split}_{metric}_list")
            ax.plot(batch_axis(per_batch), per_batch, linestyle="--", color=color)

        ax.axhline(y=baseline, linestyle="-.", label="Baseline", color=colors[3])
        ax.set(title=f"Evolution of the {metric}", xlabel="Epoch", ylabel=metric.capitalize())
        # The curves carry labels, so actually show them.
        ax.legend()

    # Accuracies are fractions in [0, 1]; the formatter's default xmax of 100
    # would render 0.39 as "0.39%" instead of "39%".
    axs[1].yaxis.set_major_formatter(mtick.PercentFormatter(xmax=1.0))

    return fig

In [127]:
def train_net(
    net,
    optimizer,
    criterion,
    loaders,
    lengths,
    n_epoch=10,
    resume_epoch=0,
    results=None,
):
    """Train ``net`` for ``n_epoch`` epochs, validating after each one.

    Args:
        net: Model called as ``net(inputs)``.
        optimizer: Optimizer stepped once per training batch.
        criterion: Callable ``criterion(outputs, labels)`` returning a scalar loss.
        loaders: Mapping with ``"train"`` and ``"val"`` iterables of
            ``(inputs, labels)`` batches.
        lengths: Mapping with the ``"train"`` and ``"val"`` sample counts used
            to normalise the epoch-level loss and accuracy.
        n_epoch: Number of epochs to run in this call.
        resume_epoch: Offset added to the epoch index when storing results, so a
            continued run appends after the epochs already recorded.
        results: Dict of earlier per-epoch records to extend; a fresh dict is
            created when omitted.

    Returns:
        ``results``, keyed by ``resume_epoch + epoch``; each value merges the
        train metrics with the dict returned by ``compute_test_loss``.

    A ``KeyboardInterrupt`` stops training early and still plots and returns
    the epochs completed so far. Any other exception propagates to the caller.
    """
    # A ``{}`` default would be shared between calls and leak old epochs.
    if results is None:
        results = {}

    trainloader = loaders["train"]
    valloader = loaders["val"]
    train_length = lengths["train"]
    val_length = lengths["val"]

    try:
        # Context managers make sure the bars are closed on every exit path.
        with tqdm(total=n_epoch, leave=False) as epoch_pbar:
            for epoch in range(n_epoch):
                net.train()
                train_loss = 0.0
                correct = 0

                train_res = {
                    "train_loss_list": [],
                    "train_accuracy_list": [],
                }

                with tqdm(leave=True, total=len(trainloader)) as pbar:
                    # Train loop
                    for inputs, labels in trainloader:
                        optimizer.zero_grad()

                        # Prediction
                        outputs = net(inputs)
                        # Follow the model's device rather than a notebook global.
                        labels = labels.to(outputs.device)

                        # Loss and step
                        loss = criterion(outputs, labels)
                        loss.backward()
                        optimizer.step()

                        # Loss and accuracy logging
                        loss_item = loss.item()
                        train_loss += loss_item * labels.size(0)
                        train_res["train_loss_list"].append(loss_item)

                        _, predicted = outputs.max(1)
                        correct_item = predicted.eq(labels).sum().item()
                        correct += correct_item
                        train_res["train_accuracy_list"].append(correct_item / labels.size(0))

                        pbar.set_description(f"Train Loss : {loss_item: .4f}")
                        pbar.update()

                    train_res["train_loss"] = train_loss / train_length
                    train_res["train_accuracy"] = correct / train_length

                    # Validation metrics for this epoch
                    val_res = compute_test_loss(
                        net, valloader, val_length, criterion, prefix="val"
                    )

                    results[resume_epoch + epoch] = {**train_res, **val_res}

                    # Leave the epoch summary on the finished batch bar.
                    pbar.set_description(
                        f'Train Loss : {train_res["train_loss"]: .4f}'
                        + f' | Test Loss : {val_res["val_loss"]: .4f}'
                        + f' | Train Accuracy : {train_res["train_accuracy"]: .2%}'
                        + f' | Test Accuracy : {val_res["val_accuracy"]: .2%}'
                    )

                epoch_pbar.set_description(
                    f"Epoch : {resume_epoch + epoch}/{resume_epoch + n_epoch}"
                    + f' | Train Loss : {train_res["train_loss"]: .4f}'
                    + f' | Test Loss : {val_res["val_loss"]: .4f}'
                    + f' | Train Accuracy : {train_res["train_accuracy"]: .2%}'
                    + f' | Test Accuracy : {val_res["val_accuracy"]: .2%}'
                )
                epoch_pbar.update()

    except KeyboardInterrupt:
        # Manual stop: keep what was recorded so far. ``print(e)`` would only
        # emit an empty line, so say what happened instead.
        print("Training interrupted; returning the epochs completed so far.")

    # Returning from a ``finally`` block would silently discard any in-flight
    # exception, so plotting and returning happen after the ``try`` instead.
    plot_results(results)
    return results

In [6]:
# Build the datasets, loaders and lengths returned by get_dataloader.
# NOTE(review): this cell's execution count (In [6]) collides with the Trainer cell
# further down, so it looks like a leftover from an earlier kernel session --
# confirm it still runs in order on a fresh kernel.
datasets, loaders, lengths = get_dataloader(
    root="../../../data/",
    bert_type="bert",
    batch_size=256,
    text_vars=["intervention"],
    use_features=False,
    label_var="label",
    num_workers=12,
    prefetch_factor=4,
    pin_memory=True,
)

In [4]:
class LitModel(pl.LightningModule):
    """LightningModule wrapping a classifier with its optimizer, loss and loaders.

    Args:
        classifier: Module called on the batch inputs; its parameters are the
            ones handed to the optimizer.
        optimizer_type: Name of an optimizer class in ``torch.optim``
            (e.g. ``"Adam"``, ``"SGD"``).
        learning_rate: Learning rate passed to the optimizer as ``lr``.
        optimizer_kwargs: Extra keyword arguments for the optimizer; defaults to
            none.
        criterion_type: Name of the loss; only ``"CrossEntropyLoss"`` is
            supported.
        batch_size: Batch size forwarded to ``get_single_dataloader``.
        loader_kwargs: Remaining keyword arguments for ``get_single_dataloader``;
            defaults to the values returned by ``_default_loader_kwargs``.

    Raises:
        ValueError: If ``criterion_type`` is not supported (at construction) or
            ``optimizer_type`` is not a ``torch.optim`` optimizer (when the
            optimizer is configured).
    """

    # Loss callables selectable through ``criterion_type``.
    _CRITERIA = {"CrossEntropyLoss": nn.functional.cross_entropy}

    @staticmethod
    def _default_loader_kwargs():
        """Return a fresh copy of the default ``get_single_dataloader`` settings."""
        return {
            "root": "../../../data/",
            "bert_type": "bert",
            "text_vars": ["intervention"],
            "use_features": False,
            "label_var": "label",
            "num_workers": 12,
            "prefetch_factor": 4,
            "pin_memory": True,
        }

    def __init__(
        self,
        classifier,
        optimizer_type="Adam",
        learning_rate=1e-3,
        optimizer_kwargs=None,
        criterion_type="CrossEntropyLoss",
        batch_size=256,
        loader_kwargs=None,
    ):
        super().__init__()

        # Resolve the ``None`` defaults into per-instance dicts (dict literals in
        # the signature would be shared by every instance). This happens before
        # ``save_hyperparameters`` so the recorded values are the resolved ones
        # -- NOTE(review): verify the saved hparams against an existing run.
        if optimizer_kwargs is None:
            optimizer_kwargs = {}
        if loader_kwargs is None:
            loader_kwargs = self._default_loader_kwargs()

        # Fail loudly instead of silently falling back to cross-entropy.
        if criterion_type not in self._CRITERIA:
            raise ValueError(
                f"Unsupported criterion_type {criterion_type!r}; "
                f"expected one of {sorted(self._CRITERIA)}"
            )

        self.optimizer_type = optimizer_type
        self.learning_rate = learning_rate
        self.optimizer_kwargs = optimizer_kwargs
        self.criterion_type = criterion_type
        self.batch_size = batch_size
        self.loader_kwargs = loader_kwargs
        self.save_hyperparameters(ignore=["classifier", "criterion"])

        self.criterion = self._CRITERIA[criterion_type]
        self.classifier = classifier

    def forward(self, x):
        """Delegate to the wrapped classifier."""
        return self.classifier(x)

    def configure_optimizers(self):
        """Build the ``torch.optim`` optimizer named by ``optimizer_type``."""
        optimizer_cls = getattr(optim, self.optimizer_type, None)
        is_optimizer = (
            isinstance(optimizer_cls, type)
            and issubclass(optimizer_cls, optim.Optimizer)
            and optimizer_cls is not optim.Optimizer
        )
        if not is_optimizer:
            # Previously an unknown name ended in an UnboundLocalError.
            raise ValueError(f"Unknown optimizer_type {self.optimizer_type!r}")
        return optimizer_cls(
            self.classifier.parameters(), lr=self.learning_rate, **self.optimizer_kwargs
        )

    def get_loss(self, batch, model_type="train"):
        """Compute and log the loss and accuracy of one ``(inputs, labels)`` batch.

        Both metrics are logged as ``{model_type}_loss`` and
        ``{model_type}_accuracy``; they appear in the progress bar only for
        ``model_type == "val"``. Returns the loss tensor.
        """
        x, y = batch
        z = self.classifier(x)
        # Use the criterion selected in __init__ rather than a hard-coded loss.
        loss = self.criterion(z, y)
        show_in_bar = model_type == "val"
        self.log(f"{model_type}_loss", loss, prog_bar=show_in_bar)

        _, predicted = z.max(1)
        # Keep the accuracy as a tensor: ``.item()`` would force a device-to-host
        # sync on every step.
        accuracy = predicted.eq(y).float().mean()
        self.log(f"{model_type}_accuracy", accuracy, prog_bar=show_in_bar)

        return loss

    def training_step(self, batch, batch_idx):
        """Return the training loss for ``batch``."""
        return self.get_loss(batch, model_type="train")

    def validation_step(self, val_batch, batch_idx):
        """Return the validation loss for ``val_batch``."""
        return self.get_loss(val_batch, model_type="val")

    def test_step(self, test_batch, batch_idx):
        """Return the test loss for ``test_batch``."""
        return self.get_loss(test_batch, model_type="test")

    def _loader_for(self, phase):
        """Return the loader that ``get_single_dataloader`` builds for ``phase``."""
        _, loader, _ = get_single_dataloader(
            phase=phase, batch_size=self.batch_size, **self.loader_kwargs
        )
        return loader

    def train_dataloader(self):
        """Loader for the ``"train"`` phase."""
        return self._loader_for("train")

    def val_dataloader(self):
        """Loader for the ``"val"`` phase."""
        return self._loader_for("val")

    def test_dataloader(self):
        """Loader for the ``"test"`` phase."""
        return self._loader_for("test")

In [5]:
# Instantiate the classifier head and put it in training mode.
net = SimplestBert()
net.train()

# Wrap it in the Lightning module. Compared with LitModel's defaults this run
# changes the learning rate (5e-3) and the batch size (512); the loader
# settings repeat the defaults.
lit_model = LitModel(
    net,
    optimizer_type="Adam",
    learning_rate=5e-3,
    optimizer_kwargs={},
    criterion_type="CrossEntropyLoss",
    batch_size=512,
    loader_kwargs={
        "root": "../../../data/",
        "bert_type": "bert",
        "text_vars": ["intervention"],
        "use_features": False,
        "label_var": "label",
        "num_workers": 12,
        "prefetch_factor": 4,
        "pin_memory": True,
    },
)

In [6]:
# Configure the Lightning trainer: up to 100 epochs, with early stopping and a time cap.
# NOTE(review): accelerator is hard-coded to "gpu" although the notebook computes
# `device` with a CPU fallback -- confirm this is intended for CPU-only machines.
trainer = pl.Trainer(
    accelerator="gpu",
    # profiler="simple",
    max_epochs=100,
    default_root_dir="../../../",
    # Debugging / tuning switches, kept disabled:
    # fast_dev_run=True,
    # overfit_batches=1,
    # auto_scale_batch_size="binsearch",
    # auto_lr_find=True,
    callbacks=[
        # Monitors the "val_loss" value that LitModel.get_loss logs for validation batches.
        callbacks.EarlyStopping(monitor="val_loss", mode="min", check_finite=True, patience=10),
        callbacks.ModelSummary(max_depth=-1),
        callbacks.Timer(duration="00:03:00:00", interval="epoch"),  # Max three hours
    ],
)

Trainer already configured with model summary callbacks: [<class 'pytorch_lightning.callbacks.model_summary.ModelSummary'>]. Skipping setting a default `ModelSummary` callback.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [7]:
# Start training. Only the module is passed, so the data presumably comes from
# LitModel's train_dataloader / val_dataloader hooks.
trainer.fit(lit_model)

Missing logger folder: ../../../lightning_logs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name            | Type         | Params
-------------------------------------------------
0 | classifier      | SimplestBert | 788 K 
1 | classifier.fc   | Sequential   | 788 K 
2 | classifier.fc.0 | Linear       | 393 K 
3 | classifier.fc.1 | LeakyReLU    | 0     
4 | classifier.fc.2 | Linear       | 262 K 
5 | classifier.fc.3 | LeakyReLU    | 0     
6 | classifier.fc.4 | Linear       | 131 K 
7 | classifier.fc.5 | LeakyReLU    | 0     
8 | classifier.fc.6 | Linear       | 771   
-------------------------------------------------
788 K     Trainable params
0         Non-trainable params
788 K     Total params
3.154     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]