In [1]:
import pickle
from pathlib import Path

import matplotlib.pyplot as plt
import matplotlib.ticker as mtick
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import seaborn as sns
import torch
from pytorch_lightning import callbacks
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning import seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter
from tqdm.autonotebook import tqdm

from nlp_assemblee.simple_datasets import get_dataloader, get_single_dataloader

In [2]:
# Seaborn plotting defaults: "paper" context, "deep" palette, "white" style.
sns.set_context("paper")
sns.set_palette("deep")
sns.set_style("white")

# Keep the "deep" palette colours available by name for later cells.
colors = sns.color_palette("deep")

# Seed the run with a fixed value (workers=True is forwarded to Lightning).
seed_everything(seed=42, workers=True)

# Pick CUDA when torch reports it as available, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

Global seed set to 42


'cuda'

In [3]:
class LitModel(pl.LightningModule):
    """Lightning wrapper that trains and evaluates ``classifier``.

    The module owns the optimizer/criterion selection, the per-step loss and
    accuracy logging, and the construction of the train/val/test dataloaders
    through ``get_single_dataloader``.

    Args:
        classifier: The ``nn.Module`` to optimise. It is called as
            ``classifier(x)`` where ``x`` is the first element of a batch.
        optimizer_type: ``"SGD"`` or ``"Adam"``.
        learning_rate: Learning rate passed to the optimizer as ``lr``.
        optimizer_kwargs: Extra keyword arguments for the optimizer.
            ``None`` means "no extra arguments".
        criterion_type: Name of the loss; only ``"CrossEntropyLoss"`` is
            supported.
        batch_size: Batch size forwarded to ``get_single_dataloader``.
        loader_kwargs: Extra keyword arguments forwarded to
            ``get_single_dataloader``. ``None`` selects the defaults in
            ``_DEFAULT_LOADER_KWARGS``.

    Raises:
        ValueError: If ``criterion_type`` is not supported (at construction),
            or if ``optimizer_type`` is not supported (when the optimizer is
            configured).
    """

    # Loader settings used when the caller does not provide ``loader_kwargs``.
    _DEFAULT_LOADER_KWARGS = {
        "root": "../../../data/",
        "bert_type": "camembert",
        "text_vars": ["intervention"],
        "use_features": False,
        "label_var": "label",
        "num_workers": 12,
        "prefetch_factor": 4,
        "pin_memory": True,
    }

    # Supported names -> implementation.
    _CRITERIA = {"CrossEntropyLoss": nn.functional.cross_entropy}
    _OPTIMIZERS = {"SGD": optim.SGD, "Adam": optim.Adam}

    def __init__(
        self,
        classifier,
        optimizer_type="Adam",
        learning_rate=1e-3,
        optimizer_kwargs=None,
        criterion_type="CrossEntropyLoss",
        batch_size=256,
        loader_kwargs=None,
    ):
        super().__init__()

        # Resolve the ``None`` sentinels into fresh dicts so that no instance
        # ever shares (or mutates) a dict with another instance or the caller.
        # The local names are rebound *before* ``save_hyperparameters`` so the
        # resolved values are what gets recorded.
        optimizer_kwargs = dict(optimizer_kwargs) if optimizer_kwargs is not None else {}
        if loader_kwargs is None:
            loader_kwargs = self._DEFAULT_LOADER_KWARGS
        loader_kwargs = dict(loader_kwargs)

        # Fail loudly instead of silently training with a different loss than
        # the one that was requested.
        if criterion_type not in self._CRITERIA:
            raise ValueError(
                f"Unsupported criterion_type {criterion_type!r}; "
                f"expected one of {sorted(self._CRITERIA)}"
            )

        self.optimizer_type = optimizer_type
        self.learning_rate = learning_rate
        self.optimizer_kwargs = optimizer_kwargs
        self.criterion_type = criterion_type
        self.batch_size = batch_size
        self.loader_kwargs = loader_kwargs
        self.save_hyperparameters(ignore=["classifier", "criterion"])

        self.criterion = self._CRITERIA[criterion_type]
        self.classifier = classifier

    def forward(self, x):
        """Run the wrapped classifier on ``x`` and return its output."""
        return self.classifier(x)

    def configure_optimizers(self):
        """Build the optimizer over the classifier's parameters.

        Raises:
            ValueError: If ``self.optimizer_type`` is not ``"SGD"`` or ``"Adam"``.
        """
        optimizer_cls = self._OPTIMIZERS.get(self.optimizer_type)
        if optimizer_cls is None:
            raise ValueError(
                f"Unsupported optimizer_type {self.optimizer_type!r}; "
                f"expected one of {sorted(self._OPTIMIZERS)}"
            )
        return optimizer_cls(
            self.classifier.parameters(), lr=self.learning_rate, **self.optimizer_kwargs
        )

    def get_loss(self, batch, model_type="train"):
        """Compute the loss for ``batch`` and log loss and accuracy.

        Args:
            batch: An ``(x, y)`` pair; ``x`` is fed to the classifier and ``y``
                holds the target class indices.
            model_type: Prefix for the logged metric names
                (``"{model_type}_loss"`` / ``"{model_type}_accuracy"``).
                Validation metrics are also shown in the progress bar.

        Returns:
            The loss tensor.
        """
        x, y = batch
        logits = self.classifier(x)
        loss = self.criterion(logits, y)

        show_in_bar = model_type == "val"
        self.log(f"{model_type}_loss", loss, prog_bar=show_in_bar)

        # Keep the accuracy as a tensor: calling ``.item()`` here would force a
        # device-to-host synchronisation on every step.
        predicted = logits.argmax(dim=1)
        accuracy = (predicted == y).float().mean()
        self.log(f"{model_type}_accuracy", accuracy, prog_bar=show_in_bar)

        return loss

    def training_step(self, batch, batch_idx):
        """Return the training loss for one batch."""
        return self.get_loss(batch, model_type="train")

    def validation_step(self, val_batch, batch_idx):
        """Return the validation loss for one batch."""
        return self.get_loss(val_batch, model_type="val")

    def test_step(self, test_batch, batch_idx):
        """Return the test loss for one batch."""
        return self.get_loss(test_batch, model_type="test")

    def _make_loader(self, phase):
        """Build the dataloader for ``phase`` (``"train"``, ``"val"`` or ``"test"``)."""
        # ``get_single_dataloader`` returns three values; only the second one
        # (the loader) is needed here.
        _, loader, _ = get_single_dataloader(
            phase=phase, batch_size=self.batch_size, **self.loader_kwargs
        )
        return loader

    def train_dataloader(self):
        """Return the dataloader for the ``"train"`` phase."""
        return self._make_loader("train")

    def val_dataloader(self):
        """Return the dataloader for the ``"val"`` phase."""
        return self._make_loader("val")

    def test_dataloader(self):
        """Return the dataloader for the ``"test"`` phase."""
        return self._make_loader("test")

In [44]:
class ClassicBert(nn.Module):
    """Classifier over three text embeddings plus two numeric features.

    Each of the three text inputs (``intervention``, ``titre_complet``,
    ``profession``) goes through its own two-layer MLP (``fc1``..``fc3``), the
    numeric features go through ``fc4``, the four 256-dimensional results are
    summed element-wise, and ``fc5`` maps the sum to 3 output scores.

    The first ``Linear`` of every text branch expects 768 input features and
    ``fc4`` expects 2 input features.
    """

    def __init__(self):
        super().__init__()
        # One independent branch per text variable. The branches are created
        # in the same order as before so parameter names and seeded
        # initialisation stay unchanged.
        self.fc1 = self._text_branch()
        self.fc2 = self._text_branch()
        self.fc3 = self._text_branch()
        # Branch for the numeric features (no dropout).
        self.fc4 = nn.Sequential(
            nn.Linear(2, 256),
            nn.LeakyReLU(0.01),
            nn.Linear(256, 256),
            nn.LeakyReLU(0.01),
        )
        # Classification head applied to the summed branch outputs.
        # NOTE(review): the trailing LeakyReLU is applied to the scores that
        # are later fed to a cross-entropy loss; confirm this is intended.
        self.fc5 = nn.Sequential(
            nn.Linear(256, 512),
            nn.LeakyReLU(0.01),
            nn.Dropout(0.2),
            nn.Linear(512, 3),
            nn.LeakyReLU(0.01),
        )

    @staticmethod
    def _text_branch():
        """Build one 768 -> 256 -> 256 MLP branch for a text embedding."""
        return nn.Sequential(
            nn.Linear(768, 256),
            nn.LeakyReLU(0.01),
            nn.Dropout(0.2),
            nn.Linear(256, 256),
            nn.LeakyReLU(0.01),
            nn.Dropout(0.2),
        )

    def forward(self, x):
        """Compute the 3 output scores for a batch.

        Args:
            x: Mapping with ``x["text"]["intervention"]``,
                ``x["text"]["titre_complet"]``, ``x["text"]["profession"]``
                (tensors whose last dimension is 768) and ``x["features"]``
                (tensor whose last dimension is 2; cast to float).

        Returns:
            Tensor whose last dimension is 3.
        """
        # Move inputs to wherever this module's weights live instead of
        # relying on a notebook-global ``device`` variable.
        target = next(self.parameters()).device

        texts = x["text"]
        intervention = self.fc1(texts["intervention"].to(target))
        titre = self.fc2(texts["titre_complet"].to(target))
        profession = self.fc3(texts["profession"].to(target))
        features = self.fc4(x["features"].to(target).float())

        # Fuse the four branches by element-wise sum, then classify.
        fused = intervention + titre + profession + features
        return self.fc5(fused)

In [45]:
# Fresh classifier, wrapped in the Lightning module together with the
# optimizer, loss and dataloader settings used for this run.
net = ClassicBert()

lit_model = LitModel(
    classifier=net,
    optimizer_type="Adam",
    learning_rate=5e-3,
    optimizer_kwargs=dict(),
    criterion_type="CrossEntropyLoss",
    batch_size=256,
    # Forwarded to the dataloader factory: three text variables plus features.
    loader_kwargs=dict(
        root="../../../data/",
        bert_type="bert",
        text_vars=["intervention", "titre_complet", "profession"],
        use_features=True,
        label_var="label",
        num_workers=12,
        prefetch_factor=4,
        pin_memory=True,
    ),
)

In [46]:
# Build the trainer: at most 100 epochs, early stopping on `val_loss`, a full
# model summary and a three-hour wall-clock limit.
trainer = pl.Trainer(
    # Use the GPU when CUDA is available and fall back to the CPU otherwise,
    # so the notebook also runs on machines without a GPU.
    accelerator="gpu" if torch.cuda.is_available() else "cpu",
    # profiler="simple",
    max_epochs=100,
    default_root_dir="../../../",
    # Debugging toggles (uncomment as needed):
    # fast_dev_run=True,
    # overfit_batches=1,
    # auto_scale_batch_size="binsearch",
    # auto_lr_find=True,
    callbacks=[
        # Stop when `val_loss` has not improved for 10 validation checks.
        callbacks.EarlyStopping(monitor="val_loss", mode="min", check_finite=True, patience=10),
        callbacks.ModelSummary(max_depth=-1),
        # Max three hours
        callbacks.Timer(duration="00:03:00:00", interval="epoch"),
    ],
)

Trainer already configured with model summary callbacks: [<class 'pytorch_lightning.callbacks.model_summary.ModelSummary'>]. Skipping setting a default `ModelSummary` callback.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [47]:
# Run training; the dataloaders come from the Lightning module itself.
trainer.fit(model=lit_model)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

   | Name             | Type        | Params
--------------------------------------------------
0  | classifier       | ClassicBert | 987 K 
1  | classifier.fc1   | Sequential  | 262 K 
2  | classifier.fc1.0 | Linear      | 196 K 
3  | classifier.fc1.1 | LeakyReLU   | 0     
4  | classifier.fc1.2 | Dropout     | 0     
5  | classifier.fc1.3 | Linear      | 65.8 K
6  | classifier.fc1.4 | LeakyReLU   | 0     
7  | classifier.fc1.5 | Dropout     | 0     
8  | classifier.fc2   | Sequential  | 262 K 
9  | classifier.fc2.0 | Linear      | 196 K 
10 | classifier.fc2.1 | LeakyReLU   | 0     
11 | classifier.fc2.2 | Dropout     | 0     
12 | classifier.fc2.3 | Linear      | 65.8 K
13 | classifier.fc2.4 | LeakyReLU   | 0     
14 | classifier.fc2.5 | Dropout     | 0     
15 | classifier.fc3   | Sequential  | 262 K 
16 | classifier.fc3.0 | Linear      | 196 K 
17 | classifier.fc3.1 | LeakyReLU   | 0     
18 | classifier.fc3.2 | Dropout     | 0     
19 | c

Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]