<a href="https://colab.research.google.com/github/maxmatical/ml-cheatsheet/blob/master/Pytorch_BERT_Huggingface_w_SAM_%2B_EMA.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Inspirations



In [1]:
# Print the GPU attached to this runtime (device name, driver/CUDA version, memory usage).
!nvidia-smi

Sun Dec 26 01:49:21 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 495.44       Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla K80           Off  | 00000000:00:04.0 Off |                    0 |
| N/A   57C    P8    32W / 149W |      0MiB / 11441MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [2]:
%%capture
# Install the third-party packages used by this notebook; %%capture hides pip's output.
# NOTE(review): no versions are pinned, so a fresh run may resolve to releases that differ
# from the ones this notebook was originally executed with — consider pinning.
!pip install transformers
!pip install torchmetrics
!pip install torch-ema
!pip install koila

In [3]:
import math
import time
from typing import Optional, Tuple
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import OneCycleLR, LambdaLR, CosineAnnealingLR, ReduceLROnPlateau
from torch.optim import AdamW

from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModel
  
import torchmetrics
from torchmetrics.classification import F1, Accuracy
from torchmetrics.functional import accuracy, f1, auroc

from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, multilabel_confusion_matrix

from torch_ema import ExponentialMovingAverage

In [4]:
# pretrained model and tokenizer, both loaded from the same checkpoint name
PRETRAINED_NAME = "distilroberta-base"
tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_NAME)
model = AutoModel.from_pretrained(PRETRAINED_NAME)

Downloading:   0%|          | 0.00/480 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/878k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/446k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.29M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/316M [00:00<?, ?B/s]

Some weights of the model checkpoint at distilroberta-base were not used when initializing RobertaModel: ['lm_head.layer_norm.weight', 'lm_head.dense.weight', 'lm_head.decoder.weight', 'lm_head.dense.bias', 'lm_head.layer_norm.bias', 'lm_head.bias']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


# Getting Data

In [5]:
# Download the dataset from Google Drive by file id; it is saved as toxic_comments.csv.
!gdown --id 1VuQ-U7TtggShMeuRSA_hzC8qGDl2LRkr

Downloading...
From: https://drive.google.com/uc?id=1VuQ-U7TtggShMeuRSA_hzC8qGDl2LRkr
To: /content/toxic_comments.csv
100% 68.8M/68.8M [00:00<00:00, 87.6MB/s]


In [6]:
# Load the downloaded comments file and preview the first rows.
df = pd.read_csv("toxic_comments.csv")

df.head()

Unnamed: 0,id,comment_text,toxic,severe_toxic,obscene,threat,insult,identity_hate
0,0000997932d777bf,Explanation\nWhy the edits made under my usern...,0,0,0,0,0,0
1,000103f0d9cfb60f,D'aww! He matches this background colour I'm s...,0,0,0,0,0,0
2,000113f07ec002fd,"Hey man, I'm really not trying to edit war. It...",0,0,0,0,0,0
3,0001b41b1c6bb37e,"""\nMore\nI can't make any real suggestions on ...",0,0,0,0,0,0
4,0001d958c54c6e35,"You, sir, are my hero. Any chance you remember...",0,0,0,0,0,0


In [7]:
# Hold out 15% of the rows for validation. The fixed random_state makes the split
# identical after every kernel restart, so results from different runs are comparable.
train_df, val_df = train_test_split(df, test_size=0.15, random_state=42)

In [8]:
# subsample clean comments
LABEL_COLUMNS = df.columns.tolist()[2:]

# number of positive labels per row; 0 means the comment carries no toxic label at all
label_counts = train_df[LABEL_COLUMNS].sum(axis=1)
train_toxic = train_df[label_counts > 0]
train_clean = train_df[label_counts == 0]

# keep every toxic row and a seeded (reproducible) draw of 15,000 clean rows
train_df = pd.concat([
  train_toxic,
  train_clean.sample(15_000, random_state=42)
])

train_df.shape, val_df.shape

((28818, 8), (23936, 8))

In [9]:
# take only a subsample of each train_df and val_df for faster iterations
# (seeded so that the same 1000 rows are drawn on every run)
train_df = train_df.sample(1000, random_state=42)
val_df = val_df.sample(1000, random_state=42)

train_df.shape, val_df.shape

((1000, 8), (1000, 8))

# Creating Dataset

In [10]:
# bs: batch size passed to the DataLoaders below
# seq_len: token length every comment is padded/truncated to by the dataset
bs = 12
seq_len = 256

In [11]:
class ToxicCommentsDataset(Dataset):
  """Dataset that tokenizes one comment per item.

  Every item is a dict holding the flattened `input_ids` and `attention_mask`
  tensors produced by the tokenizer (padded/truncated to `max_token_len`) and
  the row's `LABEL_COLUMNS` values as an int tensor under `labels`.
  """

  def __init__(
    self,
    data: pd.DataFrame,
    tokenizer: AutoTokenizer,
    max_token_len: int = 128
  ):
    self.data = data
    self.tokenizer = tokenizer
    self.max_token_len = max_token_len

  def __len__(self):
    return len(self.data)

  def __getitem__(self, index: int):
    row = self.data.iloc[index]
    text = row.comment_text
    label_values = row[LABEL_COLUMNS]

    # always pad and truncate to the configured length and ask for PyTorch tensors
    tokenize_kwargs = dict(
      add_special_tokens=True,
      max_length=self.max_token_len,
      return_token_type_ids=False,
      padding="max_length",
      truncation=True,
      return_attention_mask=True,
      return_tensors='pt',
    )
    encoded = self.tokenizer.encode_plus(text, **tokenize_kwargs)

    # the raw comment text is deliberately not part of the returned item
    item = {
      "input_ids": encoded["input_ids"].flatten(),
      "attention_mask": encoded["attention_mask"].flatten(),
      "labels": torch.IntTensor(label_values),
    }
    return item

Validation: build the dataset and inspect a single item.

In [12]:
# test: build the training dataset and look at the keys of one item
train_dataset = ToxicCommentsDataset(data=train_df, tokenizer=tokenizer, max_token_len=seq_len)

sample_item = train_dataset[0]
sample_item.keys()

dict_keys(['input_ids', 'attention_mask', 'labels'])

In [13]:
# show the token ids next to the labels, then the shape of the id tensor
sample_ids = sample_item["input_ids"]
print(sample_ids, sample_item["labels"])
print(sample_ids.shape)

tensor([    0, 43854,  1912,  1437, 50118, 50118, 31414,  6121,     4, 29221,
            4,   318,    47,   535,    19,     5,  1081,  1912,     6,    24,
         3374, 15655,  1402,    14,    47,    40,    28,  4953,     4,   370,
          581,    67,  2217,   143, 10796,    14,    47,   189,   202,    33,
           15,   110,  1383,   743,     6,   187,    82,    40,  6929,    47,
           25,    10, 17561,  4474,     4,   318,    47,    33,   143,   205,
        11304,   259,     6,     8,   888,   236,     7,  1477, 28274,     6,
           38,  8745,    47,     7,  6327,   159,     4,     2,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1, 

In [14]:
# data sets: same tokenizer and sequence length for both splits, only the frame differs
train_dataset = ToxicCommentsDataset(data=train_df, tokenizer=tokenizer, max_token_len=seq_len)
val_dataset = ToxicCommentsDataset(data=val_df, tokenizer=tokenizer, max_token_len=seq_len)



# Creating Dataloaders

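The intent is to use `koila` to determine the largest batch size that fits in GPU memory; the cell below does not call it yet and simply uses the fixed `bs` defined earlier.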

In [15]:
# options shared by both loaders; only the training loader shuffles its samples
loader_kwargs = dict(batch_size=bs, num_workers=1)

train_dl = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_dl = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

print(len(train_dl), len(val_dl))

84 84


# Model

In [17]:
class BertModel(nn.Module):
  """Pretrained encoder followed by a linear classification head.

  Args:
      n_classes: number of logits produced by the head.
      backbone: optional encoder module to wrap as-is. It must expose
          `config.hidden_size` and its forward output must have a
          `pooler_output` attribute. When omitted, a deep copy of the
          module-level pretrained `model` is used.
  """
  def __init__(self, n_classes: int, backbone: Optional[nn.Module] = None):
    super().__init__()
    if backbone is None:
      # Copy rather than alias the global: otherwise every BertModel instance shares
      # one set of backbone weights, so re-creating the model would not reset them and
      # moving/training one instance would silently affect all others.
      import copy
      backbone = copy.deepcopy(model)
    self.model = backbone
    self.classifier = nn.Linear(self.model.config.hidden_size, n_classes)

  def forward(self, input_ids, attention_mask):
    """Run the encoder and return the head's raw logits (one row of n_classes per input)."""
    out = self.model(input_ids, attention_mask=attention_mask)
    out = self.classifier(out.pooler_output)
    return out

In [18]:
# one output logit per label column
bert_model = BertModel(n_classes=len(LABEL_COLUMNS))



# Validate model with dataloader

In [19]:
# pull the first (unshuffled) batch of the training dataset and check the tensor shapes
probe_loader = DataLoader(train_dataset, batch_size=bs, num_workers=1)
sample_batch = next(iter(probe_loader))
sample_batch["input_ids"].shape, sample_batch["attention_mask"].shape

(torch.Size([12, 256]), torch.Size([12, 256]))

In [20]:
# forward pass on the sample batch; the result should be bs x 6
sample_logits = bert_model(sample_batch["input_ids"], sample_batch["attention_mask"])
sample_logits.shape

torch.Size([12, 6])

# Training enhancements

## SAM optimizer

In [21]:
class SAM(torch.optim.Optimizer):
    """Two-step optimizer wrapper (sharpness-aware minimization).

    `first_step` moves each parameter along its gradient, scaled by
    `rho / grad_norm`, and remembers the original value; `second_step`
    restores the original value and lets the wrapped base optimizer apply
    the update using the gradients present at that moment.

    Args:
        params: parameters (or parameter groups) to optimize.
        base_optimizer: optimizer class/callable, invoked as
            `base_optimizer(self.param_groups, **kwargs)`.
        rho: non-negative scale of the perturbation applied in `first_step`.
        adaptive: when true, the perturbation is additionally scaled
            element-wise by the squared parameter value.
        **kwargs: stored in the defaults and forwarded to `base_optimizer`.
    """
    def __init__(self, params, base_optimizer, rho=0.05, adaptive=False, **kwargs):
        assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"

        defaults = dict(rho=rho, adaptive=adaptive, **kwargs)
        super(SAM, self).__init__(params, defaults)

        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        # share the very same param_groups list object with the base optimizer
        self.param_groups = self.base_optimizer.param_groups

    @torch.no_grad()
    def first_step(self, zero_grad=False):
        """Perturb every parameter that has a gradient and store its previous value in the state."""
        grad_norm = self._grad_norm()
        for group in self.param_groups:
            # 1e-12 guards against division by a zero gradient norm
            scale = group["rho"] / (grad_norm + 1e-12)

            for p in group["params"]:
                if p.grad is None: continue
                self.state[p]["old_p"] = p.data.clone()
                e_w = (torch.pow(p, 2) if group["adaptive"] else 1.0) * p.grad * scale.to(p)
                p.add_(e_w)  # climb to the local maximum "w + e(w)"

        if zero_grad: self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad=False):
        """Restore the values saved by `first_step`, then run the base optimizer's step."""
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None: continue
                p.data = self.state[p]["old_p"]  # get back to "w" from "w + e(w)"

        self.base_optimizer.step()  # do the actual "sharpness-aware" update

        if zero_grad: self.zero_grad()

    @torch.no_grad()
    def step(self, closure=None):
        """Single-call variant: `first_step`, re-evaluate via `closure`, then `second_step`.

        `closure` is mandatory (asserted) and is executed with gradients enabled.
        """
        assert closure is not None, "Sharpness Aware Minimization requires closure, but it was not provided"
        closure = torch.enable_grad()(closure)  # the closure should do a full forward-backward pass

        self.first_step(zero_grad=True)
        closure()
        self.second_step()

    def _grad_norm(self):
        """Return the 2-norm over the per-parameter gradient 2-norms (parameters without grad are skipped).

        In adaptive mode each gradient is first multiplied element-wise by |p|.
        """
        shared_device = self.param_groups[0]["params"][0].device  # put everything on the same device, in case of model parallelism
        norm = torch.norm(
                    torch.stack([
                        ((torch.abs(p) if group["adaptive"] else 1.0) * p.grad).norm(p=2).to(shared_device)
                        for group in self.param_groups for p in group["params"]
                        if p.grad is not None
                    ]),
                    p=2
               )
        return norm

    def load_state_dict(self, state_dict):
        """Load the state, then point the base optimizer at this optimizer's param_groups again."""
        super().load_state_dict(state_dict)
        self.base_optimizer.param_groups = self.param_groups

## model tracking (loss and metrics)

In [22]:
# keep running average of loss values
class AverageMeter:
    """Tracks the latest value of a quantity together with its weighted running mean."""
    def __init__(self, name, fmt=':f'):
        self.name, self.fmt = name, fmt
        self.reset()

    def reset(self):
        """Zero the latest value, the mean, the weighted sum and the sample count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record `val` as the latest value and fold it into the mean with weight `n`."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

    def __str__(self):
        # build "{name} {val<fmt>} ({avg<fmt>})" first, then fill it from the instance fields
        template = '{{name}} {{val{spec}}} ({{avg{spec}}})'.format(spec=self.fmt)
        return template.format(**self.__dict__)

## LR scheduler

In [23]:
# flat_cos scheduler
def d(x):
  """
  Constant multiplier: ignores `x` and always returns 1.
  """
  return 1
    
class ConcatLR(torch.optim.lr_scheduler._LRScheduler):
    """Scheduler that delegates to `scheduler1` first and to `scheduler2` afterwards.

    The switch point is `step_start = pct_start * total_steps - 1`: while
    `last_epoch <= step_start` the first scheduler is stepped/queried,
    otherwise the second one.

    Args:
        optimizer: optimizer passed on to the `_LRScheduler` base class.
        scheduler1: scheduler used up to and including `step_start`.
        scheduler2: scheduler used after `step_start`.
        total_steps: total number of steps the schedule is planned for.
        pct_start: fraction of `total_steps` handled by `scheduler1`.
        last_epoch: forwarded to the `_LRScheduler` base class.
    """
    def __init__(self, optimizer, scheduler1, scheduler2, total_steps, pct_start=0.5, last_epoch=-1):
        # these attributes are assigned before the base-class constructor runs
        self.scheduler1 = scheduler1
        self.scheduler2 = scheduler2
        self.step_start = float(pct_start * total_steps) - 1
        super(ConcatLR, self).__init__(optimizer, last_epoch)
    
    def step(self):
        """Step the currently active sub-scheduler, then run the base-class step."""
        if self.last_epoch <= self.step_start:
            self.scheduler1.step()
        else:
            self.scheduler2.step()
        super().step()
        
    def get_lr(self):
        """Return the learning rates reported by the currently active sub-scheduler.

        NOTE(review): this calls the sub-schedulers' `get_lr()` directly rather than
        `get_last_lr()`; confirm this yields the intended values for both phases.
        """
        if self.last_epoch <= self.step_start:
            return self.scheduler1.get_lr()
        else:
            return self.scheduler2.get_lr()

## Helper functions for configuring optimizers and lr scheduler

In [24]:
# helper funcs

def configure_optimizer(
    model: nn.Module,
    lr: float = 2e-5,
    betas: Tuple[float, float] = (0.9, 0.999),
    eps: float = 1e-8,
    wd: float = 0.01,
    use_sam: bool = False,
    rho: float = 0.05,
    asam: bool = False
):
  """Build the optimizer for `model`.

  Returns a plain AdamW, or a SAM optimizer with AdamW as its base optimizer
  when `use_sam` is set. `lr`, `betas`, `eps` and `wd` (weight decay) configure
  AdamW in both cases; `rho` and `asam` (adaptive variant) are only passed to SAM.
  """
  adamw_kwargs = dict(lr=lr, betas=betas, weight_decay=wd, eps=eps)
  if use_sam:
    return SAM(
        model.parameters(),
        base_optimizer=AdamW,
        rho=rho,
        adaptive=asam,
        **adamw_kwargs
    )
  return AdamW(model.parameters(), **adamw_kwargs)

def configure_scheduler(
    fit_func: str,
    lr: float,
    total_steps: int,
    optimizer,
    pct_start: float = 0.3,
    use_sam: bool = False
):
  """Build the per-step learning-rate scheduler.

  `fit_func` selects "one_cycle" (OneCycleLR with `lr` as the peak rate) or
  "flat_cos" (a constant-multiplier phase followed by cosine annealing, joined
  by ConcatLR at `pct_start`). With `use_sam` the scheduler is attached to
  `optimizer.base_optimizer` instead of `optimizer` itself.
  """
  assert fit_func in ["one_cycle", "flat_cos"], f"fit function {fit_func} not found"

  # if using sam, lr scheduler is on base optimizer
  target_opt = optimizer
  if use_sam:
    target_opt = optimizer.base_optimizer

  if fit_func == "one_cycle":
    return OneCycleLR(
        optimizer=target_opt,
        max_lr=lr,
        pct_start=pct_start,
        total_steps=total_steps
    )

  if fit_func == "flat_cos":
    flat_phase = LambdaLR(target_opt, d)
    cosine_phase = CosineAnnealingLR(target_opt, total_steps*(1-pct_start))
    return ConcatLR(target_opt, flat_phase, cosine_phase, total_steps, pct_start)

  raise ValueError(f"fit_func {fit_func} not found")

## Model saver

In [25]:
class ModelSaver:
  """Remembers the best metric value seen so far and checkpoints the model when it is matched or beaten."""

  def __init__(self, save_path: str, mode: str = "max"):
    """
    args:
        save_path: file the model's state_dict is written to
        mode: "max" if larger metric values are better, "min" if smaller ones are
    """
    assert mode in ["min", "max"], f"mode {mode} not found"
    self.save_path = save_path
    self.mode = mode
    # start from the worst possible value so that the first call always saves
    self.best_value = float("inf") if mode == "min" else float("-inf")

  def save_model(self, epoch: int, model: nn.Module, current_value: float):
    """
    compares current_value with self.best_value
    if current_value is at least as good then
    1. save model weights to self.save_path
    2. update self.best_value with current_value

    returns True when the weights were saved, False otherwise
    """
    improved = (
      (self.mode == "min" and current_value <= self.best_value)
      or (self.mode == "max" and current_value >= self.best_value)
    )
    if improved:
      torch.save(model.state_dict(), self.save_path)
      print(f"better model found at epoch {epoch} with value {current_value}")
      print(f"model weights saved to {self.save_path}, to load model weights, create new model and use new_model.load_state_dict(torch.load({self.save_path}))")
      self.best_value = current_value
    return bool(improved)





# Training Model

FP16 with SAM:

discussion thread: https://github.com/davda54/sam/issues/7

should be something like:
```
# first pass
with torch.cuda.amp.autocast():
    out = model(input_ids, attention_mask)
    loss = criterion(out, labels.to(dtype=torch.float32))

scaler.scale(loss).backward()
scaler.unscale_(optimizer)
optimizer.first_step(zero_grad=True)
scaler.update()

# 2nd pass
with torch.cuda.amp.autocast():
    out_2 = model(input_ids, attention_mask)
    loss_2 = criterion(out_2, labels.to(dtype=torch.float32))

scaler.scale(loss_2).backward()
scaler.unscale_(optimizer)
optimizer.second_step(zero_grad=True)
scaler.update()
```

Gradient accumulation with SAM: https://github.com/davda54/sam/issues/3

4 different types of training modes:
1. FP32
2. SAM
3. FP16
4. FP16 + SAM

TODO:
- [x] Saving best model (with/without EMA) based on measured metric
- Early stopping? based on metric
- [x] `ReduceLROnPlateau` configurable
- [x] terminate training on NaN

In [31]:
def train(
    model: nn.Module,
    train_dl: DataLoader,
    val_dl: DataLoader,
    n_epochs: int,
    lr: float = 2e-5,
    b1: float = 0.9,
    b2: float = 0.999,
    eps: float = 1e-7,
    wd: float = 0.01,
    use_sam: bool = False,
    rho: float = 0.05,
    asam: bool = False,
    fit_func: str = "one_cycle",
    pct_start: float = 0.3,
    fp16: bool = False,
    is_ddp: bool = False,
    n_gpus: int = 1,
    use_ema: bool = False,
    ema_decay: float = 0.995,
    grad_acc_batches: int = 1,
    model_saver: Optional[ModelSaver] = None,
    reduce_lr_on_plateau: bool = False,
    reduce_lr_on_plateau_mode: str = "min",
    reduce_lr_on_plateau_patience: int = 10,
    reduce_lr_on_plateau_factor: float = 0.1,
    terminate_on_nan: bool = False
):
  """
  Train `model` with BCE-with-logits loss, validate after every epoch and save the
  final weights to "saved_model.pth".

  Four training modes are supported, selected by `fp16` and `use_sam`:
  FP32, FP32 + SAM, FP16 (AMP) and FP16 + SAM. Gradient accumulation is available
  in all of them; with SAM the accumulated batches are kept and replayed for the
  second forward/backward pass.

  args:
      model: nn.Module the model to train
      train_dl: DataLoader, train dataloader
      val_dl: DataLoader, validation dataloader
      n_epochs: int, number of epochs to train
      lr: float = 2e-5, learning rate
      b1: float = 0.9, beta1 for optimizers like adam, ranger etc.
      b2: float = 0.999, beta2 for optimizers like adam, ranger etc.
      eps: float = 1e-7, eps for optimizers (set to >1e-7 for fp16)
      wd: float = 0.01, weight decay regularization
      use_sam: bool = False, whether to use SAM with optimizer as base optimizer
      rho: float = 0.05, neighborhood size for SAM (set to 10x larger for ASAM)
      asam: bool = False, whether to use ASAM variant of SAM
      fit_func: str = "one_cycle", what type of training, one_cycle or flat_cos
      pct_start: float = 0.3, pct to start cosine annealing for fit func
      fp16: bool = False, whether to use mixed precision training with AMP
      is_ddp: bool = False, whether to use DDP (when n_gpus >1); currently not
        referenced in the body
      n_gpus: int = 1, number of gpus used for training (only used to compute the
        total number of scheduler steps)
      use_ema: bool = False, whether to use EMA to average model weights
      ema_decay: float = 0.995, decay factor for EMA
      grad_acc_batches: int = 1, number of batches to accumulate for gradient accumulation
      model_saver: Optional[ModelSaver] = None, if given, its
        save_model(epoch, model, value) is called after every epoch with the
        validation F1 (with the EMA weights swapped in when use_ema is set)
      reduce_lr_on_plateau: bool = False, whether to use ReduceLROnPlateau scheduler
      reduce_lr_on_plateau_mode: str = "min", mode for ReduceLROnPlateau; the
        scheduler is stepped with the validation F1, so "max" is the setting
        that treats a higher score as an improvement
      reduce_lr_on_plateau_patience: int = 10, epochs without improvement before reducing lr
      reduce_lr_on_plateau_factor: float = 0.1, factor to reduce lr by
      terminate_on_nan: bool = False, stop the whole training run (not only the
        current epoch) as soon as a NaN training loss is seen

  """
  # set device
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  # set model to device if haven't already done so
  model.to(device)

  # set loss
  criterion = nn.BCEWithLogitsLoss()

  # configure optimizer
  optimizer = configure_optimizer(model, lr=lr, betas=(b1, b2), eps=eps, wd=wd, use_sam=use_sam, rho=rho, asam=asam)

  # configure lr scheduler (one scheduler step per optimizer step)
  total_steps = int(n_epochs * len(train_dl) / n_gpus / grad_acc_batches)
  lr_schedule = configure_scheduler(fit_func, lr, total_steps, optimizer, pct_start, use_sam)

  # configure ema
  if use_ema:
    ema = ExponentialMovingAverage(model.parameters(), decay=ema_decay)

  # configure scaler for fp16
  scaler = torch.cuda.amp.GradScaler() if fp16 else None

  # used for SAM: batches of the current accumulation window, replayed in the 2nd pass
  input_list, attn_mask_list, labels_list = [], [], []

  # reduce lr on plateau
  reduce_lr_on_plateau_scheduler = ReduceLROnPlateau(
      optimizer,
      mode=reduce_lr_on_plateau_mode,
      patience=reduce_lr_on_plateau_patience,
      factor=reduce_lr_on_plateau_factor,
      verbose=True
  ) if reduce_lr_on_plateau else None

  # set once a NaN loss is seen so that the epoch loop stops as well
  nan_encountered = False

  # start training
  for epoch in range(n_epochs):
    #############################################
    # training step
    #############################################
    model.train()
    # initialize tracking variables
    start = time.time()
    losses = AverageMeter('Loss', ':.4e')
    f1_metrics = AverageMeter("F1", ':6.2f')

    for batch_idx, batch in enumerate(train_dl):
      input_ids = batch["input_ids"].to(device)
      attention_mask = batch["attention_mask"].to(device)
      labels = batch["labels"].to(device)

      # with fp16
      if fp16 and not use_sam:
        with torch.cuda.amp.autocast():
          out = model(input_ids, attention_mask)
          loss = criterion(out, labels.to(dtype=torch.float32))
          # scale the loss by gradient accumulation batches
          loss = loss / grad_acc_batches
        # backwards step
        scaler.scale(loss).backward()
        # optimizer
        if (batch_idx + 1) % grad_acc_batches == 0:
          scaler.step(optimizer)
          scaler.update()
          optimizer.zero_grad()

      # with fp16 + sam
      elif fp16 and use_sam:
        with torch.cuda.amp.autocast():
          out = model(input_ids, attention_mask)

          # save input and output for 2nd step
          input_list.append(input_ids)
          attn_mask_list.append(attention_mask)
          labels_list.append(labels)

          loss = criterion(out, labels.to(dtype=torch.float32))
          # scale the loss by gradient accumulation batches
          loss = loss / grad_acc_batches
        # backwards step
        scaler.scale(loss).backward()
        # optimizer step
        if (batch_idx + 1) % grad_acc_batches == 0:
          scaler.unscale_(optimizer)
          optimizer.first_step(zero_grad=True)
          scaler.update()

          # 2nd forward pass with saved input_list and labels_list
          # to get the accumulated gradients again
          for (input_ids, attention_mask, labels) in list(zip(input_list, attn_mask_list, labels_list)):
            with torch.cuda.amp.autocast():
              out_2 = model(input_ids, attention_mask)
              loss_2 = criterion(out_2, labels.to(dtype=torch.float32))
              loss_2 = loss_2 / grad_acc_batches
            # 2nd backwards step
            scaler.scale(loss_2).backward()

          # 2nd optimizer step (outside the for loop)
          scaler.unscale_(optimizer)
          optimizer.second_step(zero_grad=True)
          scaler.update()

          # clear saved lists
          input_list, attn_mask_list, labels_list = [], [], []

      # with fp32 + sam
      elif not fp16 and use_sam:
        # save input and output for 2nd step
        input_list.append(input_ids)
        attn_mask_list.append(attention_mask)
        labels_list.append(labels)

        out = model(input_ids, attention_mask)
        loss = criterion(out, labels.to(dtype=torch.float32))
        # scale the loss by gradient accumulation batches
        loss = loss / grad_acc_batches
        loss.backward()

        if (batch_idx + 1) % grad_acc_batches == 0:
          optimizer.first_step(zero_grad=True)

          # 2nd step
          for (input_ids, attention_mask, labels) in list(zip(input_list, attn_mask_list, labels_list)):
            out_2 = model(input_ids, attention_mask)
            loss_2 = criterion(out_2, labels.to(dtype=torch.float32))
            loss_2 = loss_2 / grad_acc_batches
            loss_2.backward()
          optimizer.second_step(zero_grad=True)
          input_list, attn_mask_list, labels_list = [], [], []

      # without fp16 or sam
      else:
        out = model(input_ids, attention_mask)
        loss = criterion(out, labels.to(dtype=torch.float32))
        # scale the loss by gradient accumulation batches
        loss = loss / grad_acc_batches
        # backwards step
        loss.backward()
        if (batch_idx + 1) % grad_acc_batches == 0:
          optimizer.step()
          optimizer.zero_grad()

      # terminate training if nan encountered
      if terminate_on_nan and torch.isnan(loss).item():
        print("NaN encountered in training loss, terminating training.")
        nan_encountered = True
        break

      # log loss and metrics
      losses.update(loss.item() * grad_acc_batches)  # scale loss back up by grad_acc_batches
      f1_metrics.update(f1(torch.sigmoid(out), labels))

      # update EMA and lr schedule
      if (batch_idx + 1) % grad_acc_batches == 0:
        # update ema
        if use_ema:
          ema.update()
        lr_schedule.step()

    # a NaN loss ends the run: skip validation and all remaining epochs
    if nan_encountered:
      break

    #############################################
    # validation step
    #############################################
    model.eval()
    # initialize val variables
    val_losses = AverageMeter('Loss', ':.4e')
    val_f1_metrics = AverageMeter("F1", ':6.2f')

    with torch.no_grad():
      for batch_idx, batch in enumerate(val_dl):
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)

        # eval with fp16
        if fp16:
          with torch.cuda.amp.autocast():
            # use EMA
            if use_ema:
              with ema.average_parameters():
                out = model(input_ids, attention_mask)
            else:
              out = model(input_ids, attention_mask)

        else:
          # use EMA
          if use_ema:
            with ema.average_parameters():
              out = model(input_ids, attention_mask)
          else:
            out = model(input_ids, attention_mask)

        # calculate val loss and val_f1
        loss = criterion(out, labels.to(dtype=torch.float32))
        val_losses.update(loss.item())
        val_f1_metrics.update(f1(torch.sigmoid(out), labels))

    #############################################
    # end of epoch
    # save best model, step the plateau scheduler, log training stats
    #############################################
    if model_saver is not None:
      current_val_f1 = float(val_f1_metrics.avg)
      if use_ema:
        # temporarily swap in the EMA weights so the checkpoint matches what was validated
        with ema.average_parameters():
          model_saver.save_model(epoch + 1, model, current_val_f1)
      else:
        model_saver.save_model(epoch + 1, model, current_val_f1)

    if reduce_lr_on_plateau_scheduler:
      reduce_lr_on_plateau_scheduler.step(val_f1_metrics.avg)

    end = time.time()
    elapsed = end - start

    # log relevant metrics at the end of epoch
    print(f"Epoch {epoch+1}: train loss: {losses.avg}, val loss: {val_losses.avg}, training f1: {f1_metrics.avg}, val f1: {val_f1_metrics.avg}, time: {elapsed}")

  # saving model weights
  saved_model_pth = "saved_model.pth"

  # if using ema weights, copy those weights to model before saving
  if use_ema:
    ema.copy_to(model.parameters())
  torch.save(model.state_dict(), saved_model_pth)
  print(f"model weights saved to {saved_model_pth}, to load model weights, create new model and use new_model.load_state_dict(torch.load({saved_model_pth}))")



In [32]:
# discard the instance used for the shape checks above and build a new BertModel for training
del bert_model
bert_model = BertModel(len(LABEL_COLUMNS))

In [33]:
# switches for this run (use_ema is read again by a later cell)
use_sam = False
fp16 = False
use_ema = False
grad_acc_batches = 10


train(
    model=bert_model,
    train_dl=train_dl,
    val_dl=val_dl,
    n_epochs=1,
    use_sam=use_sam,
    fp16=fp16,
    use_ema=use_ema,
    grad_acc_batches=grad_acc_batches,
)

Epoch 1: train loss: 0.5984141291784388, val loss: 0.4876699823708761, training f1: 0.32552164793014526, val f1: 0.3752131462097168, time: 59.73013091087341
model weights saved to saved_model.pth, to load model weights, create new model and use new_model.load_state_dict(torch.load(saved_model.pth))


### Some results
Each configuration was trained for 1 epoch only, with gradient accumulation over 12 batches (effective batch size 144).


baseline (no fp16, sam, ema):
```
Epoch 1: train loss: 0.6239880334053721, val loss: 0.5045745234404292, training f1: 0.4051700830459595, val f1: 0.3108885884284973, time: 55.422086238861084
```

fp16:
```
Epoch 1: train loss: 0.5987142457493714, val loss: 0.47237819823480787, training f1: 0.27484646439552307, val f1: 0.261325478553772, time: 110.72927689552307
```
sam:
```
Epoch 1: train loss: 0.46754554366426815, val loss: 0.3434885683513823, training f1: 0.4873872697353363, val f1: 0.3218401372432709, time: 93.9434859752655
```

ema:
```
Epoch 1: train loss: 0.5785652372453894, val loss: 0.47438419610261917, training f1: 0.5454299449920654, val f1: 0.41533538699150085, time: 56.51772093772888
```


sam + fp16

```

```

fp16 + ema

```

```


sam+ema
```
Epoch 1: train loss: 0.6115093609052045, val loss: 0.516309067606926, training f1: 0.3987903594970703, val f1: 0.2628635764122009, time: 94.63389563560486
```

fp16 + sam + ema
```
Epoch 1: train loss: 0.5853038478110517, val loss: 0.47891736775636673, training f1: 0.39333027601242065, val f1: 0.1278771311044693, time: 192.29889297485352
```


In [None]:
# score one validation batch with the trained model and compare predictions with labels
# pick the device with the same rule train() uses, so the cell also runs on CPU-only runtimes
device = "cuda" if torch.cuda.is_available() else "cpu"
sample_batch = next(iter(DataLoader(val_dataset, batch_size=bs, num_workers=1)))
input_ids = sample_batch["input_ids"].to(device)
attention_mask = sample_batch["attention_mask"].to(device)
labels = sample_batch["labels"].to(device)

# inference only: no autograd graph is needed
with torch.no_grad():
  out = bert_model(input_ids, attention_mask)

print(f1(torch.sigmoid(out), labels))
print(torch.sigmoid(out))
print(labels)

tensor(0., device='cuda:0')
tensor([[0.3604, 0.0749, 0.1837, 0.0692, 0.1929, 0.0806],
        [0.3803, 0.0752, 0.1921, 0.0698, 0.2088, 0.0787],
        [0.3515, 0.0763, 0.1836, 0.0712, 0.1901, 0.0817],
        [0.3820, 0.0757, 0.1940, 0.0667, 0.1922, 0.0804],
        [0.4409, 0.0791, 0.2198, 0.0593, 0.2180, 0.0834],
        [0.3872, 0.0754, 0.1918, 0.0665, 0.1922, 0.0803],
        [0.3724, 0.0750, 0.1772, 0.0717, 0.1873, 0.0847],
        [0.3428, 0.0758, 0.1795, 0.0735, 0.1838, 0.0833],
        [0.3711, 0.0760, 0.1883, 0.0677, 0.1894, 0.0823],
        [0.3813, 0.0759, 0.1940, 0.0691, 0.1950, 0.0782],
        [0.4030, 0.0771, 0.2016, 0.0631, 0.1977, 0.0796],
        [0.3810, 0.0765, 0.1917, 0.0652, 0.1971, 0.0807]], device='cuda:0',
       grad_fn=<SigmoidBackward0>)
tensor([[0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0],

In [None]:
# same metric call as in train(); torch.sigmoid replaces the legacy F.sigmoid spelling
print(f1(torch.sigmoid(out), labels))

tensor(0.4444, device='cuda:0')




# Saving and Exporting model

In [None]:
# Demo of copying EMA weights into another model:
# build bert1 with an EMA tracker over its parameters, build a second model bert2,
# print both classifier weights, then (only when use_ema is set) copy the EMA
# weights into bert2 and print its classifier weight again.
bert1 = BertModel(len(LABEL_COLUMNS))
print("bert1")
print(bert1.classifier.weight)
ema = ExponentialMovingAverage(bert1.parameters(), decay=0.995)
bert2 = BertModel(len(LABEL_COLUMNS))
print("bert2")
print(bert2.classifier.weight)


# use_ema is the switch defined in the training-run cell
if use_ema:
  ema.copy_to(bert2.parameters())

print("bert2 after ema copy weight")
print(bert2.classifier.weight)


# torch.save(bert_model.state_dict(), saved_model_pth)

In [None]:
# rebuild the architecture and load the weights that train() wrote to disk
# (the path only exists as a local variable inside train(), so it is defined here too)
saved_model_pth = "saved_model.pth"
new_bert_model = BertModel(len(LABEL_COLUMNS))

# map_location keeps loading independent of the device the weights were saved from
new_bert_model.load_state_dict(torch.load(saved_model_pth, map_location="cpu"))

In [None]:
# check values are the same

inf_preds = new_bert_model(sample_batch["input_ids"], sample_batch["attention_mask"])
print(inf_preds.shape) # should be 8 x 6
inf_preds