<a href="https://colab.research.google.com/github/maxmatical/ml-cheatsheet/blob/master/Pytorch_Lightning_BERT_Huggingface_w_SAM_%2B_EMA.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Inspirations
https://curiousily.com/posts/multi-label-text-classification-with-bert-and-pytorch-lightning/

https://github.com/mgrankin/over9000/blob/master/train.py


https://medium.com/pytorch/getting-started-with-ray-lightning-easy-multi-node-pytorch-lightning-training-e639031aff8b


In [None]:
# show the GPU assigned to this runtime (driver, memory, utilization)
!nvidia-smi

Wed Dec  1 22:59:46 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 495.44       Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla K80           Off  | 00000000:00:04.0 Off |                    0 |
| N/A   72C    P0    73W / 149W |      0MiB / 11441MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [None]:
%%capture
!pip install transformers
!pip install pytorch_lightning
!pip install torchmetrics
!pip install torch-ema

In [None]:
import copy
from typing import Optional, Sequence

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR # , ReduceLROnPlateau
from torch.utils.data import Dataset, DataLoader

from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModel

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor, StochasticWeightAveraging
from pytorch_lightning.loggers import TensorBoardLogger
import torchmetrics
from torchmetrics.functional import accuracy, f1, auroc

from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, multilabel_confusion_matrix

from torch_ema import ExponentialMovingAverage

In [None]:
# pretrained backbone and its matching tokenizer, both from one checkpoint name
checkpoint = "distilroberta-base"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModel.from_pretrained(checkpoint)

Some weights of the model checkpoint at distilroberta-base were not used when initializing RobertaModel: ['lm_head.dense.weight', 'lm_head.layer_norm.weight', 'lm_head.decoder.weight', 'lm_head.bias', 'lm_head.dense.bias', 'lm_head.layer_norm.bias']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


# Data

In [None]:
# download the dataset from Google Drive; it is saved as toxic_comments.csv
!gdown --id 1VuQ-U7TtggShMeuRSA_hzC8qGDl2LRkr

Downloading...
From: https://drive.google.com/uc?id=1VuQ-U7TtggShMeuRSA_hzC8qGDl2LRkr
To: /content/toxic_comments.csv
100% 68.8M/68.8M [00:00<00:00, 166MB/s]


In [None]:
# load the CSV downloaded above and peek at the first rows
df = pd.read_csv(filepath_or_buffer="toxic_comments.csv")
df.head(n=5)

Unnamed: 0,id,comment_text,toxic,severe_toxic,obscene,threat,insult,identity_hate
0,0000997932d777bf,Explanation\nWhy the edits made under my usern...,0,0,0,0,0,0
1,000103f0d9cfb60f,D'aww! He matches this background colour I'm s...,0,0,0,0,0,0
2,000113f07ec002fd,"Hey man, I'm really not trying to edit war. It...",0,0,0,0,0,0
3,0001b41b1c6bb37e,"""\nMore\nI can't make any real suggestions on ...",0,0,0,0,0,0
4,0001d958c54c6e35,"You, sir, are my hero. Any chance you remember...",0,0,0,0,0,0


In [None]:
# fixed random_state so the split (and everything downstream) is reproducible
train_df, val_df = train_test_split(df, test_size=0.15, random_state=42)

In [None]:
# label columns: everything after `id` and `comment_text`
LABEL_COLUMNS = df.columns.tolist()[2:]

# split training rows into "any label set" vs. completely clean
is_toxic = train_df[LABEL_COLUMNS].sum(axis=1) > 0
train_toxic = train_df[is_toxic]
train_clean = train_df[~is_toxic]

# rebalance: keep every toxic comment, subsample the clean ones. The sample size
# is capped at what is available so .sample() cannot raise on smaller splits,
# and seeded so reruns give the same training set.
n_clean = min(15_000, len(train_clean))
train_df = pd.concat([
  train_toxic,
  train_clean.sample(n_clean, random_state=42)
])

train_df.shape, val_df.shape

((28772, 8), (23936, 8))

In [None]:
# take only a subsample of each split for faster iterations
# (set N_SUBSAMPLE = None to use the full splits)
N_SUBSAMPLE = 1_000
if N_SUBSAMPLE is not None:
  # seeded and capped so reruns are reproducible and small splits don't raise
  train_df = train_df.sample(min(N_SUBSAMPLE, len(train_df)), random_state=42)
  val_df = val_df.sample(min(N_SUBSAMPLE, len(val_df)), random_state=42)

train_df.shape, val_df.shape

((1000, 8), (1000, 8))

## Creating Dataset and Lightning Data Module

In [None]:
# batch size and maximum sequence length (in tokens) used by every loader below
bs, seq_len = 12, 256

In [None]:
class ToxicCommentsDataset(Dataset):
  """Map-style dataset turning rows of a comments DataFrame into model inputs.

  Each item is a dict with:
    - ``input_ids``: 1-D tensor of token ids, padded/truncated to ``max_token_len``
    - ``attention_mask``: 1-D tensor with the same length as ``input_ids``
    - ``labels``: 1-D ``torch.int32`` tensor with one entry per label column

  Args:
    data: frame with a ``comment_text`` column plus the label columns.
    tokenizer: Hugging Face tokenizer (anything exposing ``encode_plus``).
    max_token_len: fixed length every item is padded/truncated to.
    label_columns: label column names, in output order. Defaults to the
      notebook-level ``LABEL_COLUMNS`` so existing callers keep working.
  """

  def __init__(
    self,
    data: pd.DataFrame,
    tokenizer: "AutoTokenizer",
    max_token_len: int = 128,
    label_columns: Optional[Sequence[str]] = None,
  ):
    self.tokenizer = tokenizer
    self.data = data
    self.max_token_len = max_token_len
    # resolved once here instead of reading a mutable global on every item
    self.label_columns = list(LABEL_COLUMNS if label_columns is None else label_columns)

  def __len__(self):
    return len(self.data)

  def __getitem__(self, index: int):
    data_row = self.data.iloc[index]
    comment_text = data_row.comment_text
    # a row of a mixed-dtype frame is an object Series; cast explicitly
    labels = data_row[self.label_columns].to_numpy(dtype=np.int32)

    encoding = self.tokenizer.encode_plus(
      comment_text,
      add_special_tokens=True,
      max_length=self.max_token_len,
      return_token_type_ids=False,
      padding="max_length",
      truncation=True,
      return_attention_mask=True,
      return_tensors='pt',
    )

    return dict(
      # raw text deliberately not included in the item
      # return_tensors='pt' yields a leading batch dim of 1; flatten drops it
      input_ids=encoding["input_ids"].flatten(),
      attention_mask=encoding["attention_mask"].flatten(),
      # integer labels for the accuracy metric; the loss casts them to float
      labels=torch.from_numpy(labels),
    )

In [None]:
# smoke-test the dataset on the (subsampled) training frame
train_dataset = ToxicCommentsDataset(train_df, tokenizer, max_token_len=seq_len)
sample_item = train_dataset[0]
sample_item.keys()

dict_keys(['input_ids', 'attention_mask', 'labels'])

In [None]:
item_ids, item_labels = sample_item["input_ids"], sample_item["labels"]
print(item_ids, item_labels)
print(item_ids.shape)

tensor([    0,   100,   465,    24, 16437,     7,  2813,   951,     7,   213,
            7,    39, 22519,    25,  1010,    25,   678,    54, 13195,  7240,
         6566, 35561,   160,  8845,    47,     4,   407,   109,    24,  2540,
            4,   370,    40,   146,     5,   232,    10,   828,   357,   396,
          110, 22092,     4,     2,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1, 

In [None]:
class ToxicCommentsDataModule(pl.LightningDataModule):
  """LightningDataModule wrapping the train and held-out comment frames.

  ``test_df`` backs both the validation and the test dataloaders.

  Args:
    train_df: frame used for training (shuffled every epoch).
    test_df: frame used for validation and testing (never shuffled).
    tokenizer: Hugging Face tokenizer forwarded to ``ToxicCommentsDataset``.
    batch_size: batch size for every dataloader.
    max_token_len: fixed sequence length used when tokenizing.
    num_workers: worker processes per DataLoader (default 1, as before).
  """

  def __init__(self, train_df, test_df, tokenizer, batch_size=8, max_token_len=512, num_workers=1):
    super().__init__()
    self.train_df, self.test_df = train_df, test_df
    self.tokenizer = tokenizer
    self.batch_size = batch_size
    self.max_token_len = max_token_len
    self.num_workers = num_workers

  def setup(self, stage=None):
    """Build the datasets; must run before any ``*_dataloader`` call."""
    self.train_dataset = ToxicCommentsDataset(self.train_df, self.tokenizer, self.max_token_len)
    self.test_dataset = ToxicCommentsDataset(self.test_df, self.tokenizer, self.max_token_len)

  def _make_loader(self, dataset, shuffle):
    # single place for the DataLoader settings shared by all three loaders
    return DataLoader(
        dataset,
        batch_size=self.batch_size,
        shuffle=shuffle,
        num_workers=self.num_workers,
    )

  def train_dataloader(self):
    return self._make_loader(self.train_dataset, shuffle=True)

  def val_dataloader(self):
    return self._make_loader(self.test_dataset, shuffle=False)

  def test_dataloader(self):
    return self._make_loader(self.test_dataset, shuffle=False)

In [None]:
# the validation split doubles as the test split
data_module = ToxicCommentsDataModule(
    train_df, val_df, tokenizer, batch_size=bs, max_token_len=seq_len
)


In [None]:
# build the datasets now so the dataloaders can be inspected before fit()
# (fit() then skips setup, see the warning in its output)
data_module.setup(stage=None)

In [None]:
# number of training batches per epoch
n_train_batches = len(data_module.train_dataloader())
n_train_batches

84

# SAM

In [None]:
class SAM(torch.optim.Optimizer):
    """Sharpness-Aware Minimization wrapper around a base optimizer.

    One SAM update needs two gradient evaluations:
      1. ``first_step``: move every parameter that has a gradient by ``e(w)``,
         a step of size ``rho`` along the normalized (optionally adaptive)
         gradient, saving the original value in ``self.state[p]["old_p"]``.
      2. The caller recomputes gradients at the perturbed point, then
         ``second_step`` restores the original weights and lets the base
         optimizer apply those perturbed-point gradients.

    ``step(closure)`` runs both halves when a closure is available. This
    notebook's LightningModule calls ``first_step`` / ``second_step`` directly.

    NOTE(review): no GradScaler is involved in either step, so with fp16 AMP
    the gradients seen here would still be loss-scaled. Confirm before using
    with precision=16.

    Args:
        params: parameters (or param-group dicts) to optimize.
        base_optimizer: optimizer *class* (e.g. ``AdamW``), instantiated on
            SAM's param groups.
        rho: neighborhood size; must be non-negative.
        adaptive: if True, scale the perturbation by the weights (ASAM).
        **kwargs: forwarded to ``base_optimizer`` and stored in the param-group
            defaults (e.g. ``lr``, ``betas``).
    """
    def __init__(self, params, base_optimizer, rho=0.05, adaptive=False, **kwargs):
        assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"

        defaults = dict(rho=rho, adaptive=adaptive, **kwargs)
        super(SAM, self).__init__(params, defaults)

        # Build the base optimizer on SAM's groups, then point SAM at the base
        # optimizer's group list. Both now share one list, so an LR scheduler
        # attached to SAM also changes the lr the base optimizer uses.
        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        self.param_groups = self.base_optimizer.param_groups

    @torch.no_grad()
    def first_step(self, zero_grad=False):
        """Perturb weights to ``w + e(w)`` using the current gradients.

        Args:
            zero_grad: clear gradients afterwards, ready for the second pass.
        """
        grad_norm = self._grad_norm()
        for group in self.param_groups:
            # 1e-12 guards against division by a zero gradient norm
            scale = group["rho"] / (grad_norm + 1e-12)

            for p in group["params"]:
                if p.grad is None: continue
                self.state[p]["old_p"] = p.data.clone()
                # ASAM scales by w^2 here; _grad_norm uses |w| in its weighting
                e_w = (torch.pow(p, 2) if group["adaptive"] else 1.0) * p.grad * scale.to(p)
                p.add_(e_w)  # climb to the local maximum "w + e(w)"

        if zero_grad: self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad=False):
        """Restore ``w`` and apply the base optimizer with the current gradients.

        Args:
            zero_grad: clear gradients after the update.
        """
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None: continue
                p.data = self.state[p]["old_p"]  # get back to "w" from "w + e(w)"

        self.base_optimizer.step()  # do the actual "sharpness-aware" update

        if zero_grad: self.zero_grad()

    @torch.no_grad()
    def step(self, closure=None):
        """Run a full SAM update.

        ``closure`` must do a full forward/backward pass; it is re-run at the
        perturbed weights.
        """
        assert closure is not None, "Sharpness Aware Minimization requires closure, but it was not provided"
        closure = torch.enable_grad()(closure)  # the closure should do a full forward-backward pass

        self.first_step(zero_grad=True)
        closure()
        self.second_step()

    def _grad_norm(self):
        """Global L2 norm of all gradients, weighted by ``|w|`` when adaptive.

        Per-parameter norms are gathered on the first parameter's device.
        Presumably fails (empty ``torch.stack``) if no parameter has a gradient.
        """
        shared_device = self.param_groups[0]["params"][0].device  # put everything on the same device, in case of model parallelism
        norm = torch.norm(
                    torch.stack([
                        ((torch.abs(p) if group["adaptive"] else 1.0) * p.grad).norm(p=2).to(shared_device)
                        for group in self.param_groups for p in group["params"]
                        if p.grad is not None
                    ]),
                    p=2
               )
        return norm

    def load_state_dict(self, state_dict):
        """Load state, then re-link the base optimizer to SAM's new param groups."""
        super().load_state_dict(state_dict)
        # the parent call rebinds self.param_groups, so share it again
        self.base_optimizer.param_groups = self.param_groups

# Model

In [None]:
# first batch of the (unshuffled) training dataset, to check tensor shapes
sample_loader = DataLoader(train_dataset, batch_size=bs, num_workers=1)
sample_batch = next(iter(sample_loader))
sample_batch["input_ids"].shape, sample_batch["attention_mask"].shape

(torch.Size([12, 256]), torch.Size([12, 256]))

In [None]:
# TODO: test lr schedules following https://www.kaggle.com/isbhargav/guide-to-pytorch-learning-rate-scheduling
# TODO: test with a flat + cosine ("flat cos") lr schedule

In [None]:
class BertModel(nn.Module):
  """Transformer encoder plus a linear multi-label head on its pooled output.

  Args:
    n_classes: number of output logits (one per label column).
    backbone: encoder to fine-tune. When omitted, a deep copy of the
      notebook-level pretrained ``model`` is used, so every BertModel owns
      independent backbone weights. Sharing one backbone object would mean
      that loading EMA or saved weights into one instance silently changes
      every other instance, including the one being trained.
  """
  def __init__(self, n_classes: int, backbone: Optional[nn.Module] = None):
    super().__init__()
    self.model = copy.deepcopy(model) if backbone is None else backbone
    self.classifier = nn.Linear(self.model.config.hidden_size, n_classes)

  def forward(self, input_ids, attention_mask):
    """Return raw logits (no sigmoid); the last dimension has size ``n_classes``."""
    out = self.model(input_ids, attention_mask=attention_mask)
    out = self.classifier(out.pooler_output)
    return out

In [None]:
# move to the GPU *before* building the LightningModule: its EMA copy is
# created from these parameters
bert_model = BertModel(n_classes=len(LABEL_COLUMNS))
bert_model = bert_model.to("cuda")



In [None]:
# bert_model(sample_batch["input_ids"], sample_batch["attention_mask"]).shape  # should be bs x 6

In [None]:
class ToxicCommentClassifier(pl.LightningModule):
  """Multi-label toxic-comment classifier trained with SAM (+ optional EMA).

  Optimization is manual because SAM needs two forward/backward passes per
  update. Gradient accumulation is also handled here. Every micro-batch is
  back-propagated immediately and stored. Every ``accumulate_grad_batches``
  batches, the stored micro-batches are replayed at the SAM-perturbed weights
  before the real optimizer step.

  Args:
    pytorch_model: module mapping ``(input_ids, attention_mask)`` to logits.
      With ``use_ema=True`` it should already be on its training device,
      because the EMA is built from its parameters in ``__init__``.
    total_steps: number of *optimizer* steps for the OneCycle schedule (the
      scheduler is stepped once per optimizer step, not once per batch).
    lr: peak learning rate of the OneCycle schedule.
    rho: SAM neighborhood size.
    asam: use adaptive SAM (ASAM).
    fit_func: schedule name; only stored for now (OneCycle is always used).
    is_ddp: run the per-batch backward inside DDP ``no_sync``.
    use_ema: keep an EMA of the weights and validate with the averaged weights.
    accumulate_grad_batches: micro-batches per optimizer step (>= 1).
    ema_decay: EMA decay factor (only used when ``use_ema``).

  Raises:
    ValueError: if ``accumulate_grad_batches`` < 1.
  """

  def __init__(
      self,
      pytorch_model: nn.Module,
      total_steps: int,
      lr: float = 2e-5,
      rho: float = 0.05,
      asam: bool = False,
      fit_func: str = "one_cycle",
      is_ddp: bool = False,
      use_ema: bool = False,
      accumulate_grad_batches: int = 1,
      ema_decay: float = 0.995,
  ):
    super().__init__()
    if accumulate_grad_batches < 1:
      raise ValueError(f"accumulate_grad_batches must be >= 1, got {accumulate_grad_batches}")

    self.lr = lr
    self.rho = rho  # neighborhood size for SAM
    self.asam = asam  # whether to use adaptive SAM
    self.total_steps = total_steps  # optimizer steps for the lr schedule
    self.fit_func = fit_func
    self.is_ddp = is_ddp
    self.use_ema = use_ema
    self.ema = None
    self.grad_acc_batches = accumulate_grad_batches

    self.pytorch_model = pytorch_model
    self.criterion = nn.BCEWithLogitsLoss()  # BCE on logits for multi-label targets

    if self.use_ema:
      # self.pytorch_model should already be on its device at this point
      self.ema = ExponentialMovingAverage(self.pytorch_model.parameters(), decay=ema_decay)

    # Separate metric objects for train and val: torchmetrics metrics keep
    # internal state, and one shared instance would mix both phases.
    self.accuracy = torchmetrics.Accuracy()
    self.val_accuracy = torchmetrics.Accuracy()

    # SAM needs two backward passes per update, so optimization is manual
    self.automatic_optimization = False

    # lr read at every training batch (to inspect the schedule)
    self.lr_schedule = []

    # micro-batches since the last optimizer step, replayed for SAM's 2nd pass
    self.input_list, self.attn_mask_list, self.labels_list = [], [], []

  def forward(self, input_ids, attention_mask):
    """Return the wrapped model's logits."""
    out = self.pytorch_model(input_ids, attention_mask=attention_mask)
    return out

  def training_step(self, batch, batch_idx):
    """Backward one micro-batch; run a SAM update every ``grad_acc_batches`` batches.

    Because automatic optimization is off, Lightning only handles precision
    and device placement. Everything else is done here: optimizer steps, lr
    scheduler steps and gradient accumulation.
    """
    input_ids = batch["input_ids"]
    attention_mask = batch["attention_mask"]
    labels = batch["labels"]

    # keep the inputs so SAM's second pass can replay every accumulated micro-batch
    self.input_list.append(input_ids)
    self.attn_mask_list.append(attention_mask)
    self.labels_list.append(labels)

    opt = self.optimizers()
    self.lr_schedule.append(opt.param_groups[0]["lr"])

    out = self(input_ids, attention_mask)
    loss = self.criterion(out, labels.to(dtype=torch.float32))  # BCE needs float targets
    # scale so the accumulated gradient is the mean over micro-batches
    loss = loss / self.grad_acc_batches
    # Explicit sigmoid: otherwise the metric guesses "logits vs. probabilities"
    # from the value range, and would threshold raw logits that all fall in [0, 1].
    acc = self.accuracy(torch.sigmoid(out.detach()), labels)

    if self.is_ddp:
      # maybe should be `with self.pytorch_model.no_sync()`? (test both on ddp)
      with self.trainer.model.no_sync():
        self.manual_backward(loss)
    else:
      self.manual_backward(loss)

    if (batch_idx + 1) % self.grad_acc_batches == 0:
      self._sam_update(opt)

    # undo the accumulation scaling for logging; detach so no graph is kept alive
    loss = loss.detach() * self.grad_acc_batches
    log_values = {"loss": loss, "train_acc": acc}
    self.log_dict(log_values, sync_dist=True, prog_bar=True, on_step=True, on_epoch=False)
    return {"loss": loss, "train_accuracy": acc}

  def _sam_update(self, opt):
    """Do one SAM optimizer step over the micro-batches in the replay buffers.

    Expects the per-batch gradients to be accumulated in ``.grad``. Leaves the
    gradients zeroed and the replay buffers empty.
    """
    # climb to w + e(w) using the accumulated gradients, then clear them
    opt.first_step(zero_grad=True)

    # recompute the accumulated gradient at the perturbed weights
    for input_ids, attention_mask, labels in zip(self.input_list, self.attn_mask_list, self.labels_list):
      out_2 = self(input_ids, attention_mask)
      loss_2 = self.criterion(out_2, labels.to(dtype=torch.float32))
      self.manual_backward(loss_2 / self.grad_acc_batches)

    # restore w and let the base optimizer apply the perturbed-point gradients
    opt.second_step(zero_grad=True)

    self.input_list, self.attn_mask_list, self.labels_list = [], [], []

    if self.use_ema:
      self.ema.update()  # note: might need self.ema.update(self.trainer.model.parameters())

    # one scheduler step per optimizer step
    self.lr_schedulers().step()

  def validation_step(self, batch, batch_idx):
    """Compute validation loss/accuracy, using the EMA weights when enabled."""
    input_ids = batch["input_ids"]
    attention_mask = batch["attention_mask"]
    labels = batch["labels"]

    if self.use_ema:
      with self.ema.average_parameters():
        out = self(input_ids, attention_mask)
    else:
      out = self(input_ids, attention_mask)

    loss = self.criterion(out, labels.to(dtype=torch.float32))  # BCE needs float targets
    acc = self.val_accuracy(torch.sigmoid(out), labels)
    log_values = {"val_loss": loss, "val_accuracy": acc}
    self.log_dict(log_values, sync_dist=True, prog_bar=True, on_step=True, on_epoch=False)
    return {"val_loss": loss, "val_accuracy": acc}

  def training_epoch_end(self, outputs):
    """Log the mean training loss/accuracy over the epoch's batches."""
    avg_loss = torch.stack([x["loss"] for x in outputs]).mean()
    avg_acc = torch.stack([x["train_accuracy"] for x in outputs]).mean()

    self.log("avg_train_loss", avg_loss, sync_dist=True, prog_bar=True)
    self.log("avg_train_accuracy", avg_acc, sync_dist=True, prog_bar=True)

  def validation_epoch_end(self, outputs):
    """Log the mean validation loss/accuracy over the epoch's batches."""
    avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
    avg_acc = torch.stack([x["val_accuracy"] for x in outputs]).mean()
    self.log("avg_val_loss", avg_loss, sync_dist=True, prog_bar=True)
    self.log("avg_val_accuracy", avg_acc, sync_dist=True, prog_bar=True)

  def configure_optimizers(self):
    """Build SAM (wrapping AdamW) and a OneCycle schedule with ``total_steps`` steps."""
    # TODO:
    # 1. automatically find steps for OneCycleLR
    # 2. add flat cos
    # 3. make the lr scheduler configurable via `fit_func` (one_cycle vs flat_cos)
    optimizer = SAM(
        self.pytorch_model.parameters(),
        base_optimizer=AdamW,
        lr=self.lr,
        betas=(0.9, 0.99),
        rho=self.rho,
        adaptive=self.asam,
    )
    # SAM shares its param_groups list with the wrapped AdamW, so scheduling
    # SAM's lr also sets the lr AdamW uses in second_step
    scheduler = OneCycleLR(
        optimizer=optimizer,
        max_lr=self.lr,
        pct_start=0.3,
        total_steps=self.total_steps,
    )
    return [optimizer], [scheduler]

In [None]:
# training-length hyperparameters
n_epochs = 2
accumulate_grad_batches = 4
steps_per_epoch = len(data_module.train_dataloader())  # batches per epoch

# NOTE: this counts *batches*. With gradient accumulation, the number of
# optimizer (and scheduler) steps is smaller by a factor of accumulate_grad_batches.
total_steps = n_epochs * steps_per_epoch
print(f"total_steps: {total_steps}")

total_steps: 168


Note: when training on multiple devices (GPUs or TPU cores), each device processes only its share of every epoch, so the number of optimizer steps is roughly:
```
steps_per_epoch = len(data_module.train_dataloader())

# number of GPUs (or TPU cores) training in parallel
num_devices = 4
tpu_cores = 0  # set to the number of TPU cores when training on TPU
if tpu_cores:
    num_devices = max(num_devices, tpu_cores)

accumulate_grad_batches = 1 # 1 = no gradient accumulation

effective_accum = accumulate_grad_batches * num_devices
total_steps = (steps_per_epoch // effective_accum) * n_epochs

```

In [None]:
# The OneCycle schedule is stepped once per *optimizer* step, i.e. once every
# `accumulate_grad_batches` batches (batch_idx restarts every epoch).
optimizer_steps = (steps_per_epoch // accumulate_grad_batches) * n_epochs

toxic_comment_model = ToxicCommentClassifier(
    bert_model,
    total_steps=optimizer_steps,
    use_ema=True,
    # NOTE(review): 0.01 is far above the class default (2e-5) for fine-tuning
    # a pretrained transformer; confirm this is intentional
    lr=0.01,
    accumulate_grad_batches=accumulate_grad_batches,
)

In [None]:
# sam, adamw = toxic_comment_model.configure_optimizers()[0][0], toxic_comment_model.configure_optimizers()[0][1]

In [None]:
# sam.param_groups[0]["lr"], adamw.param_groups[0]["lr"]

In [None]:
# configure_optimizers() builds a *fresh* SAM + scheduler. The printed lr is the
# schedule's starting lr (well below `lr`), set when the scheduler is created.
optimizers, _schedulers = toxic_comment_model.configure_optimizers()
sam = optimizers[0]
sam.param_groups[0]["lr"]

0.0003999999999999993

In [None]:
# toxic_comment_model(sample_batch["input_ids"], sample_batch["attention_mask"])[1].shape  # should be bs x 6

# Trainer

Define the trainer, then run `trainer.fit`.

In [None]:
# TensorBoard logging, plus the learning rate recorded at every step
logger = TensorBoardLogger("lightning_logs", name="toxic-comments")
lr_monitor_cb = LearningRateMonitor(logging_interval='step')
callbacks = [lr_monitor_cb]

In [None]:
# Full precision on purpose: with native AMP (precision=16), Lightning applies
# its GradScaler (unscale / inf-check / scale update) only around
# optimizer.step(). SAM's first_step/second_step bypass that path, so AdamW
# would receive loss-scaled and possibly inf gradients.
# NOTE(review): re-verify against the installed Lightning version before
# switching back to 16-bit.
trainer = pl.Trainer(
    logger=logger,
    callbacks=callbacks,
    precision=32,
    max_epochs=n_epochs,
    accelerator="gpu",
    devices=1,
)

Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [None]:
# optional: currently not working?
# lr_finder = trainer.tuner.lr_find(model=toxic_comment_model, datamodule=data_module)

In [None]:
# the datamodule supplies both the train and the validation dataloaders
trainer.fit(toxic_comment_model, datamodule=data_module)

  f"DataModule.{name} has already been called, so it will not be called again. "
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type              | Params
----------------------------------------------------
0 | pytorch_model | BertModel         | 82.1 M
1 | criterion     | BCEWithLogitsLoss | 0     
2 | accuracy      | Accuracy          | 0     
----------------------------------------------------
82.1 M    Trainable params
0         Non-trainable params
82.1 M    Total params
164.246   Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

  f"One of the returned values {set(extra.keys())} has a `grad_fn`. We will detach it automatically"


Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

In [None]:
# latest values logged via self.log / self.log_dict during fit
# NOTE(review): the saved output below uses older metric names
# (e.g. train_acc_step), so it predates the current logging code
trainer.logged_metrics

{'avg_train_accuracy': 0.8191137909889221,
 'avg_train_loss': 0.4239652454853058,
 'avg_val_accuracy': 0.9619709253311157,
 'avg_val_loss': 0.2339302897453308,
 'train_acc_epoch': 0.8193333745002747,
 'train_acc_step': tensor(0.7917),
 'train_loss_epoch': 0.42350253462791443,
 'train_loss_step': tensor(0.4818),
 'val_accuracy_epoch': 0.9623332619667053,
 'val_accuracy_step': 0.9166666865348816,
 'val_loss_epoch': 0.2337791621685028,
 'val_loss_step': 0.2528218924999237}

In [None]:
# %reload_ext tensorboard
# %tensorboard --logdir lightning_logs/


In [None]:
# raw (non-EMA) classifier weights after training, for comparison with the EMA weights
toxic_comment_model.pytorch_model.classifier.weight



Parameter containing:
tensor([[-0.0099, -0.0108,  0.0002,  ..., -0.0067,  0.0195,  0.0189],
        [-0.0098, -0.0243, -0.0092,  ..., -0.0401, -0.0533, -0.0349],
        [-0.0247, -0.0069, -0.0084,  ..., -0.0209, -0.0066,  0.0006],
        [-0.0666, -0.0643, -0.0458,  ..., -0.0863, -0.0143, -0.0616],
        [ 0.0505,  0.0348,  0.0083,  ...,  0.0122, -0.0357,  0.0014],
        [-0.0413, -0.0203,  0.0004,  ..., -0.0118, -0.0084,  0.0015]],
       requires_grad=True)

In [None]:
# Deep-copy the trained model so loading the EMA weights cannot modify it.
# BertModel instances may share one backbone object, so copy_to into a
# "new" BertModel could overwrite the trained backbone.
bert_model2 = copy.deepcopy(toxic_comment_model.pytorch_model)
print("weights before ema \n\n", bert_model2.classifier.weight)
# copy_to expects an iterable of parameters in the same order as the tracked ones
toxic_comment_model.ema.copy_to(bert_model2.parameters())

bert_model2.classifier.weight

weights before ema 

 Parameter containing:
tensor([[ 1.0467e-02, -6.6118e-03, -1.8938e-02,  ...,  2.3005e-02,
         -8.5002e-03,  2.3651e-02],
        [-2.6170e-02, -1.7333e-02, -1.4113e-02,  ..., -2.7001e-02,
          2.4300e-05,  2.5617e-02],
        [-4.2284e-03, -3.5021e-02,  1.5847e-03,  ...,  1.5148e-02,
         -2.3250e-02, -2.8635e-02],
        [ 2.9981e-02,  2.0158e-02, -2.7264e-02,  ...,  2.6321e-02,
         -1.7058e-02,  3.0486e-02],
        [-4.8996e-03,  1.2535e-03, -4.2920e-03,  ..., -2.0642e-02,
         -1.9914e-02,  3.1620e-02],
        [-2.0354e-02, -2.5691e-03, -2.8449e-02,  ..., -2.5412e-02,
         -1.4536e-03,  2.2498e-02]], requires_grad=True)


Parameter containing:
tensor([[-0.0100, -0.0109,  0.0003,  ..., -0.0068,  0.0196,  0.0188],
        [-0.0094, -0.0239, -0.0096,  ..., -0.0396, -0.0537, -0.0345],
        [-0.0248, -0.0070, -0.0083,  ..., -0.0211, -0.0065,  0.0005],
        [-0.0666, -0.0643, -0.0458,  ..., -0.0864, -0.0143, -0.0616],
        [ 0.0504,  0.0347,  0.0085,  ...,  0.0120, -0.0355,  0.0012],
        [-0.0416, -0.0206,  0.0007,  ..., -0.0121, -0.0081,  0.0012]],
       requires_grad=True)

In [None]:
# lr read from the optimizer at every training batch (traces the OneCycle schedule)
toxic_comment_model.lr_schedule

NameError: ignored

# Saving only the PyTorch model

In [None]:
# saving model weights
saved_model_pth = "saved_model.pth"

# If EMA was used, copy the averaged weights into the model before saving.
# The flag and the EMA live on the LightningModule (there are no such globals),
# and copy_to expects an iterable of parameters, not a module.
if toxic_comment_model.use_ema:
  toxic_comment_model.ema.copy_to(toxic_comment_model.pytorch_model.parameters())

torch.save(toxic_comment_model.pytorch_model.state_dict(), saved_model_pth)

NameError: ignored

# Exporting the model as pure PyTorch

In [None]:
# a fresh BertModel with one output per label column
new_bert_model = BertModel(len(LABEL_COLUMNS))

# load onto CPU explicitly so this works regardless of where the weights were saved from
new_bert_model.load_state_dict(torch.load(saved_model_pth, map_location="cpu"))

In [None]:
# check values are the same: run the exported model in inference mode
# (eval() disables dropout so the outputs are deterministic)
new_bert_model.eval()
with torch.no_grad():
  inf_preds = new_bert_model(sample_batch["input_ids"], sample_batch["attention_mask"])
print(inf_preds.shape)  # (bs, number of labels) -> 12 x 6 here
assert inf_preds.shape == (bs, len(LABEL_COLUMNS))
inf_preds

In [None]:
# Compare the trained LightningModule with the exported model. Both run in eval
# mode without grad so dropout cannot make them differ. forward() returns the
# logits tensor itself, so there is no tuple to index into.
toxic_comment_model.eval()
new_bert_model.eval()
with torch.no_grad():
  lightning_preds = toxic_comment_model(sample_batch["input_ids"], sample_batch["attention_mask"])
  exported_preds = new_bert_model(sample_batch["input_ids"], sample_batch["attention_mask"])
assert torch.allclose(lightning_preds, exported_preds, atol=1e-5), "exported model disagrees with the trained one"
lightning_preds