In [1]:
import os
from typing import Dict, List, Tuple, Union, Any
import sys

import lightning as L
import matplotlib.pyplot as plt
import numpy as np
import tensorboard
import torch
from lightning.pytorch.callbacks import (
    LearningRateMonitor,
    ModelCheckpoint,
    ModelSummary,
)
import copy
from rich import print
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
from transformers import GPT2Tokenizer, AutoModelForSequenceClassification, GPT2Config
from sklearn.datasets import fetch_20newsgroups

%load_ext autoreload
%autoreload 2
%load_ext rich
%load_ext tensorboard

# Set random seed for reproducibility
seed = 42
L.seed_everything(seed)

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")

# Set data directory
DATA_DIR = os.path.join(os.getcwd(), "data")

Seed set to 42


## Load and preprocess dataset

In [2]:
train = fetch_20newsgroups(data_home=DATA_DIR, subset="train")
test = fetch_20newsgroups(data_home=DATA_DIR, subset="test")

In [3]:
len(train.data), len(test.data)

[1m([0m[1;36m11314[0m, [1;36m7532[0m[1m)[0m

In [4]:
class TextDataset(Dataset):
    def __init__(self, data, target, tokenizer, max_length):
        self.data = data
        self.target = target
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        text = self.data[idx]
        label = self.target[idx]

        encoding = self.tokenizer(
            text,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
            add_special_tokens=True,
        ).to(device)

        input_ids = encoding["input_ids"].squeeze()
        attention_mask = encoding["attention_mask"].squeeze()

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": torch.tensor(label, dtype=torch.long).to(device),
        }

In [5]:
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.padding_side = "left"
tokenizer.pad_token = tokenizer.eos_token
tokenizer


[1;35mGPT2Tokenizer[0m[1m([0m[33mname_or_path[0m=[32m'gpt2'[0m, [33mvocab_size[0m=[1;36m50257[0m, [33mmodel_max_length[0m=[1;36m1024[0m, [33mis_fast[0m=[3;91mFalse[0m, [33mpadding_side[0m=[32m'left'[0m, [33mtruncation_side[0m=[32m'right'[0m, [33mspecial_tokens[0m=[1m{[0m[32m'bos_token'[0m: [32m'[0m[32m<[0m[32m|endoftext|[0m[32m>'[0m[39m, [0m[32m'eos_token'[0m[39m: [0m[32m'<|endoftext|>'[0m[39m, [0m[32m'unk_token'[0m[39m: [0m[32m'<|endoftext|>'[0m[39m, [0m[32m'pad_token'[0m[39m: [0m[32m'<|endoftext|>'[0m[1;39m}[0m[39m, [0m[33mclean_up_tokenization_spaces[0m[39m=[0m[3;92mTrue[0m[1;39m)[0m[39m,  [0m[33madded_tokens_decoder[0m[39m=[0m[1;39m{[0m
[39m        [0m[1;36m50256[0m[39m: [0m[1;35mAddedToken[0m[1;39m([0m[32m"<|endoftext|[0m[32m>[0m[32m"[0m, [33mrstrip[0m=[3;91mFalse[0m, [33mlstrip[0m=[3;91mFalse[0m, [33msingle_word[0m=[3;91mFalse[0m, [33mnormalized[0m=[3;92mTrue[0m,

In [6]:
train_dataset = TextDataset(train.data, train.target, tokenizer, max_length=1024)
test_dataset = TextDataset(test.data, test.target, tokenizer, max_length=1024)

train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=8, shuffle=False)


## Load pre-trained model

In [22]:
cfg = GPT2Config.from_pretrained("gpt2", num_labels=20)
model = AutoModelForSequenceClassification.from_pretrained("gpt2", config=cfg)

model.resize_token_embeddings(len(tokenizer))
model.config.pad_token_id = model.config.eos_token_id
model.to(device)

for param in model.transformer.parameters():
    param.requires_grad = False


model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

Some weights of GPT2ForSequenceClassification were not initialized from the model checkpoint at gpt2 and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [24]:
model


[1;35mGPT2ForSequenceClassification[0m[1m([0m
  [1m([0mtransformer[1m)[0m: [1;35mGPT2Model[0m[1m([0m
    [1m([0mwte[1m)[0m: [1;35mEmbedding[0m[1m([0m[1;36m50257[0m, [1;36m768[0m[1m)[0m
    [1m([0mwpe[1m)[0m: [1;35mEmbedding[0m[1m([0m[1;36m1024[0m, [1;36m768[0m[1m)[0m
    [1m([0mdrop[1m)[0m: [1;35mDropout[0m[1m([0m[33mp[0m=[1;36m0[0m[1;36m.1[0m, [33minplace[0m=[3;91mFalse[0m[1m)[0m
    [1m([0mh[1m)[0m: [1;35mModuleList[0m[1m([0m
      [1m([0m[1;36m0[0m-[1;36m11[0m[1m)[0m: [1;36m12[0m x [1;35mGPT2Block[0m[1m([0m
        [1m([0mln_1[1m)[0m: [1;35mLayerNorm[0m[1m([0m[1m([0m[1;36m768[0m,[1m)[0m, [33meps[0m=[1;36m1e[0m[1;36m-05[0m, [33melementwise_affine[0m=[3;92mTrue[0m[1m)[0m
        [1m([0mattn[1m)[0m: [1;35mGPT2Attention[0m[1m([0m
          [1m([0mc_attn[1m)[0m: [1;35mConv1D[0m[1m([0m[1m)[0m
          [1m([0mc_proj[1m)[0m: [1;35mConv1D[0m[1m([0m[1m)[0m


## Train on dataset

In [25]:
class GPT2Classifier(L.LightningModule):
    def __init__(self, model):
        super(GPT2Classifier, self).__init__()
        self.model = model

    def forward(self, input_ids, attention_mask, labels=None):
        outputs = self.model(input_ids, attention_mask=attention_mask, labels=labels)
        return outputs

    def training_step(self, batch, batch_idx):
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        labels = batch["labels"]

        outputs = self(input_ids, attention_mask, labels)
        loss = outputs.loss

        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)

        return loss

    def validation_step(self, batch, batch_idx):
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        labels = batch["labels"]

        outputs = self(input_ids, attention_mask, labels)
        loss = outputs.loss

        self.log("val_loss", loss, on_step=True, on_epoch=True, prog_bar=True)

        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=2e-5)
        return optimizer

In [26]:
classifier = GPT2Classifier(model)

In [27]:
trainer = L.Trainer(
    accelerator="auto",
    logger=False,
    max_epochs=2,
    callbacks=[
        ModelCheckpoint(monitor="val_loss"),
    ],
)


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


# Building GPT-2 from scratch

### Conv1D

This is essentially a linear transformation applied using a 1D convolutional layer to provide greater flexibility in the model as the flattening happens automatically.

In [7]:
class Conv1D(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()

        w = torch.empty(input_dim, output_dim)
        nn.init.normal_(w, std=0.02)

        self.weight = nn.Parameter(w)
        self.bias = nn.Parameter(torch.zeros(output_dim))

    def forward(self, x):
        """
        x: (batch_size, seq_len, d_model)

        Returns:
        x: (batch_size, seq_len, output_dim)
        """

        size_out = x.size()[:-1] + (self.weight.size(1),)
        x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)

        x = x.view(*size_out)

        return x

In [9]:
Conv1D(768, 20)(torch.randn(8, 1024, 768)).shape

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m8[0m, [1;36m1024[0m, [1;36m20[0m[1m][0m[1m)[0m

### Feed-forward network

This feed-forward network takes in the pre-activation values from the attention block, projects them to a higher-dimensional space, applies a non-linearity, and then projects them back down to the original dimension.

In [10]:
class FFN(nn.Module):
    def __init__(self, input_dim, output_dim, dropout=0.1):
        super().__init__()

        self.c_fc = Conv1D(input_dim, output_dim)
        self.c_proj = Conv1D(output_dim, input_dim)

        self.act = F.gelu
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        """
        x: (batch_size, seq_len, d_model)

        Returns:
        x: (batch_size, seq_len, d_model)
        """

        x = self.act(self.c_fc(x))
        x = self.drop(x)
        x = self.c_proj(x)

        return x


In [11]:
FFN(768, 768 * 4)(torch.randn(8, 1024, 768)).shape

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m8[0m, [1;36m1024[0m, [1;36m768[0m[1m][0m[1m)[0m

### Self-attention



In [55]:
class Attention(nn.Module):
    def __init__(self, d_model, n_head, n_ctx, d_head, bias=True):
        super().__init__()

        self.d_model = d_model

        self.n_head = n_head
        self.d_head = d_head

        self.c_attn = Conv1D(d_model, d_model * 3)
        self.softmax = nn.Softmax(dim=-1)

        self.register_buffer(
            "tril", torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx)
        )

        self.dropout = nn.Dropout(0.1)
        self.c_proj = Conv1D(d_model, d_model)

    def split_heads(self, x):
        """
        x: (batch_size, seq_len, d_model)

        Returns:
        x: (batch_size, n_head, seq_len, d_head)
        """

        new_shape = x.size()[:-1] + (
            self.n_head,
            self.d_head,
        )  # (batch_size, seq_len, n_head, d_head)
        x = x.view(*new_shape)

        return x.permute(0, 2, 1, 3)  # (batch_size, n_head, seq_len, d_head)

    def _attn(self, q, k, v, mask=None):
        """
        q: (batch_size, n_head, seq_len, d_head)
        k: (batch_size, n_head, seq_len, d_head)
        v: (batch_size, n_head, seq_len, d_head)

        Returns:
        out: (batch_size, n_head, seq_len, d_head)
        """

        scores = torch.matmul(q, k.transpose(-2, -1))
        scores = scores / np.sqrt(v.size(-1))  # (batch_size, n_head, seq_len, seq_len)

        if mask is None:
            mask = (
                torch.tril(torch.ones(scores.size(-2), scores.size(-1)))
                .view(1, 1, *scores.size()[-2:])
                .to(scores.device)
            )

        scores = scores.masked_fill(mask == 0, float("-inf"))

        attn = self.softmax(scores)
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)

        return out

    def merge_heads(self, x):
        """
        x: (batch_size, n_head, seq_len, d_head)

        Returns:
        x: (batch_size, seq_len, d_model)
        """

        x = x.permute(0, 2, 1, 3).contiguous()
        new_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
        x = x.view(*new_shape)

        return x

    def forward(self, x, mask=None):
        """
        x: (batch_size, seq_len, d_model)

        Returns:
        x: (batch_size, seq_len, d_model)
        """

        x = self.c_attn(x)

        q, k, v = x.split(self.d_model, dim=-1)
        q, k, v = self.split_heads(q), self.split_heads(k), self.split_heads(v)

        out = self._attn(q, k, v, mask)
        out = self.merge_heads(out)

        out = self.c_proj(out)

        return out

In [56]:
Attention(768, 8, 1024, 768 // 8)(torch.randn(8, 1024, 768)).shape

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m8[0m, [1;36m1024[0m, [1;36m768[0m[1m][0m[1m)[0m

### Transformer block

In [15]:
class TransformerBlock(nn.Module):
    def __init__(self, d_model, n_head, dropout=0.1):
        super().__init__()

        self.attn = Attention(
            d_model, n_head, n_ctx=1024, d_head=d_model // n_head, bias=True
        )
        self.mlp = FFN(d_model, d_model * 4, dropout=dropout)

        self.ln_1 = nn.LayerNorm(d_model)
        self.ln_2 = nn.LayerNorm(d_model)

    def forward(self, x, mask=None):
        """
        x: (batch_size, seq_len, d_model)

        Returns:
        x: (batch_size, seq_len, d_model)
        """

        x = x + self.attn(self.ln_1(x), mask=mask)
        x = x + self.mlp(self.ln_2(x))

        return x

In [16]:
TransformerBlock(768, 12)(torch.randn(8, 1024, 768)).shape

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m8[0m, [1;36m1024[0m, [1;36m768[0m[1m][0m[1m)[0m

### GPT-2 architecture

In [17]:
class GPT2(nn.Module):
    def __init__(
        self, n_layers=12, n_head=12, n_ctx=1024, d_model=768, n_labels=20, vcb_sz=50257
    ):
        super().__init__()

        self.n_layers = n_layers
        self.d_model = d_model

        self.wte = nn.Embedding(vcb_sz, d_model)
        self.wpe = nn.Embedding(n_ctx, d_model)

        self.drop = nn.Dropout(0.1)

        block = TransformerBlock(d_model=d_model, n_head=n_head, dropout=0.1)
        self.h = self._get_clones(block, n_layers)
        self.ln_f = nn.LayerNorm(d_model)

        self.out = nn.Linear(d_model, n_labels, bias=False)

        self._init_weights(self)

    def _get_clones(self, module, n):
        return nn.ModuleList([copy.deepcopy(module) for i in range(n)])

    def _init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding, Conv1D)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, (nn.Linear, Conv1D)) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(self, src, labels=None, mask=None, pos_ids=None):
        """
        src: (batch_size, seq_len)

        Returns:

        logits: (batch_size, n_labels)
        """
        if pos_ids is None:
            pos_ids = torch.arange(0, src.size(-1)).unsqueeze(0).to(src.device)

        with torch.no_grad():
            inp = self.drop((self.wte(src) + self.wpe(pos_ids)))

            for i in range(self.n_layers):
                inp = self.h[i](inp, mask=mask)

            inp = self.ln_f(inp)

        logits = self.out(inp[:, -1])

        return logits

In [18]:
GPT2(2, 12, 1024, 768, 20, 1000)(torch.randint(0, 1000, (8, 1024))).shape

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m8[0m, [1;36m20[0m[1m][0m[1m)[0m

In [19]:
custom_gpt2 = GPT2(3, 12, 1024, 768, 20, 50257).to(device)

custom_gpt2


[1;35mGPT2[0m[1m([0m
  [1m([0mwte[1m)[0m: [1;35mEmbedding[0m[1m([0m[1;36m50257[0m, [1;36m768[0m[1m)[0m
  [1m([0mwpe[1m)[0m: [1;35mEmbedding[0m[1m([0m[1;36m1024[0m, [1;36m768[0m[1m)[0m
  [1m([0mdrop[1m)[0m: [1;35mDropout[0m[1m([0m[33mp[0m=[1;36m0[0m[1;36m.1[0m, [33minplace[0m=[3;91mFalse[0m[1m)[0m
  [1m([0mh[1m)[0m: [1;35mModuleList[0m[1m([0m
    [1m([0m[1;36m0[0m-[1;36m2[0m[1m)[0m: [1;36m3[0m x [1;35mTransformerBlock[0m[1m([0m
      [1m([0mattn[1m)[0m: [1;35mAttention[0m[1m([0m
        [1m([0mc_attn[1m)[0m: [1;35mConv1D[0m[1m([0m[1m)[0m
        [1m([0msoftmax[1m)[0m: [1;35mSoftmax[0m[1m([0m[33mdim[0m=[1;36m-1[0m[1m)[0m
        [1m([0mdropout[1m)[0m: [1;35mDropout[0m[1m([0m[33mp[0m=[1;36m0[0m[1;36m.1[0m, [33minplace[0m=[3;91mFalse[0m[1m)[0m
        [1m([0mc_proj[1m)[0m: [1;35mConv1D[0m[1m([0m[1m)[0m
      [1m)[0m
      [1m([0mmlp[1m)[0m: [1;35mF

In [20]:
cfg = GPT2Config.from_pretrained("gpt2", num_labels=20, n_layer=3)
model = AutoModelForSequenceClassification.from_pretrained("gpt2", config=cfg)

model.resize_token_embeddings(len(tokenizer))
model.config.pad_token_id = model.config.eos_token_id
model.to(device)

Some weights of GPT2ForSequenceClassification were not initialized from the model checkpoint at gpt2 and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



[1;35mGPT2ForSequenceClassification[0m[1m([0m
  [1m([0mtransformer[1m)[0m: [1;35mGPT2Model[0m[1m([0m
    [1m([0mwte[1m)[0m: [1;35mEmbedding[0m[1m([0m[1;36m50257[0m, [1;36m768[0m[1m)[0m
    [1m([0mwpe[1m)[0m: [1;35mEmbedding[0m[1m([0m[1;36m1024[0m, [1;36m768[0m[1m)[0m
    [1m([0mdrop[1m)[0m: [1;35mDropout[0m[1m([0m[33mp[0m=[1;36m0[0m[1;36m.1[0m, [33minplace[0m=[3;91mFalse[0m[1m)[0m
    [1m([0mh[1m)[0m: [1;35mModuleList[0m[1m([0m
      [1m([0m[1;36m0[0m-[1;36m2[0m[1m)[0m: [1;36m3[0m x [1;35mGPT2Block[0m[1m([0m
        [1m([0mln_1[1m)[0m: [1;35mLayerNorm[0m[1m([0m[1m([0m[1;36m768[0m,[1m)[0m, [33meps[0m=[1;36m1e[0m[1;36m-05[0m, [33melementwise_affine[0m=[3;92mTrue[0m[1m)[0m
        [1m([0mattn[1m)[0m: [1;35mGPT2Attention[0m[1m([0m
          [1m([0mc_attn[1m)[0m: [1;35mConv1D[0m[1m([0m[1m)[0m
          [1m([0mc_proj[1m)[0m: [1;35mConv1D[0m[1m([0m[1m)[0m
  

In [21]:
state_dict = model.state_dict()
custom_gpt2_state_dict = custom_gpt2.state_dict()

In [22]:
old_keys = []
new_keys = []

for key in state_dict.keys():
    if "transformer" in key:
        old_keys.append(key)
        new_keys.append(key.replace("transformer.", ""))

for old_key, new_key in zip(old_keys, new_keys):
    state_dict[new_key] = state_dict.pop(old_key)

pretrained_dict = {k: v for k, v in state_dict.items() if k in custom_gpt2_state_dict}

custom_gpt2_state_dict.update(pretrained_dict)

custom_gpt2.load_state_dict(custom_gpt2_state_dict)
custom_gpt2.eval()



[1;35mGPT2[0m[1m([0m
  [1m([0mwte[1m)[0m: [1;35mEmbedding[0m[1m([0m[1;36m50257[0m, [1;36m768[0m[1m)[0m
  [1m([0mwpe[1m)[0m: [1;35mEmbedding[0m[1m([0m[1;36m1024[0m, [1;36m768[0m[1m)[0m
  [1m([0mdrop[1m)[0m: [1;35mDropout[0m[1m([0m[33mp[0m=[1;36m0[0m[1;36m.1[0m, [33minplace[0m=[3;91mFalse[0m[1m)[0m
  [1m([0mh[1m)[0m: [1;35mModuleList[0m[1m([0m
    [1m([0m[1;36m0[0m-[1;36m2[0m[1m)[0m: [1;36m3[0m x [1;35mTransformerBlock[0m[1m([0m
      [1m([0mattn[1m)[0m: [1;35mAttention[0m[1m([0m
        [1m([0mc_attn[1m)[0m: [1;35mConv1D[0m[1m([0m[1m)[0m
        [1m([0msoftmax[1m)[0m: [1;35mSoftmax[0m[1m([0m[33mdim[0m=[1;36m-1[0m[1m)[0m
        [1m([0mdropout[1m)[0m: [1;35mDropout[0m[1m([0m[33mp[0m=[1;36m0[0m[1;36m.1[0m, [33minplace[0m=[3;91mFalse[0m[1m)[0m
        [1m([0mc_proj[1m)[0m: [1;35mConv1D[0m[1m([0m[1m)[0m
      [1m)[0m
      [1m([0mmlp[1m)[0m: [1;35mF

### Testing on data

In [23]:
batch = next(iter(train_dataloader))

input_ids = batch["input_ids"]
attention_mask = batch["attention_mask"]

In [24]:
input_ids


[1;35mtensor[0m[1m([0m[1m[[0m[1m[[0m[1;36m50256[0m, [1;36m50256[0m, [1;36m50256[0m,  [33m...[0m,  [1;36m6930[0m,    [1;36m13[0m,   [1;36m628[0m[1m][0m,
        [1m[[0m[1;36m50256[0m, [1;36m50256[0m, [1;36m50256[0m,  [33m...[0m,    [1;36m13[0m, [1;36m15532[0m,   [1;36m198[0m[1m][0m,
        [1m[[0m[1;36m50256[0m, [1;36m50256[0m, [1;36m50256[0m,  [33m...[0m,   [1;36m628[0m,   [1;36m628[0m,   [1;36m198[0m[1m][0m,
        [33m...[0m,
        [1m[[0m[1;36m50256[0m, [1;36m50256[0m, [1;36m50256[0m,  [33m...[0m, [1;36m19618[0m,    [1;36m29[0m,   [1;36m198[0m[1m][0m,
        [1m[[0m[1;36m50256[0m, [1;36m50256[0m, [1;36m50256[0m,  [33m...[0m,   [1;36m198[0m, [1;36m14679[0m,   [1;36m198[0m[1m][0m,
        [1m[[0m[1;36m50256[0m, [1;36m50256[0m, [1;36m50256[0m,  [33m...[0m,  [1;36m4992[0m, [1;36m23500[0m,   [1;36m198[0m[1m][0m[1m][0m, [33mdevice[0m=[32m'cuda:0'[0m[1m)[0m

In [60]:
custom_gpt2.train()
logits = custom_gpt2(input_ids)

In [61]:
logits


[1;35mtensor[0m[1m([0m[1m[[0m[1m[[0m[1;36m-2.0713e+00[0m,  [1;36m1.2661e+00[0m,  [1;36m8.9001e-01[0m, [1;36m-1.8227e-01[0m, [1;36m-3.4259e-01[0m,
          [1;36m1.6165e+00[0m, [1;36m-1.3114e-01[0m,  [1;36m1.6069e-01[0m,  [1;36m1.1070e+00[0m,  [1;36m2.8233e-01[0m,
          [1;36m5.8022e-02[0m, [1;36m-9.4528e-01[0m,  [1;36m4.9517e-01[0m, [1;36m-8.8699e-01[0m, [1;36m-9.2799e-01[0m,
          [1;36m8.8554e-01[0m, [1;36m-5.4683e-01[0m, [1;36m-9.4678e-01[0m, [1;36m-5.8822e-02[0m, [1;36m-7.6902e-02[0m[1m][0m,
        [1m[[0m[1;36m-2.2478e+00[0m,  [1;36m1.6846e+00[0m,  [1;36m1.6972e+00[0m,  [1;36m5.4963e-01[0m,  [1;36m1.8086e-01[0m,
          [1;36m1.3459e+00[0m, [1;36m-2.2058e-01[0m, [1;36m-8.4645e-01[0m,  [1;36m1.3589e+00[0m, [1;36m-1.1583e+00[0m,
          [1;36m8.9867e-02[0m,  [1;36m9.2684e-02[0m,  [1;36m1.6044e-03[0m, [1;36m-1.4605e+00[0m, [1;36m-1.1566e+00[0m,
          [1;36m3.0560e-02[0m, [1;36m-1.1

In [62]:
probs = F.softmax(logits, dim=-1)
preds = torch.argmax(probs, dim=-1)

preds


[1;35mtensor[0m[1m([0m[1m[[0m[1;36m5[0m, [1;36m2[0m, [1;36m1[0m, [1;36m3[0m, [1;36m8[0m, [1;36m2[0m, [1;36m2[0m, [1;36m2[0m[1m][0m, [33mdevice[0m=[32m'cuda:0'[0m[1m)[0m

## Training implementation

In [57]:
class GPT2Classifier_v2(L.LightningModule):
    def __init__(
        self, n_layers=12, n_head=12, n_ctx=1024, d_model=768, n_labels=20, vcb_sz=50257
    ):
        super().__init__()

        self.model = GPT2(n_layers, n_head, n_ctx, d_model, n_labels, vcb_sz)
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, src, labels=None, mask=None, pos_ids=None):
        if mask is not None:
            mask = mask.unsqueeze(1).unsqueeze(2)

        return self.model(src, labels, mask, pos_ids)

    def training_step(self, batch, batch_idx):
        input_ids, attention_mask, labels = (
            batch["input_ids"],
            batch["attention_mask"],
            batch["labels"],
        )

        logits = self(input_ids, labels=labels, mask=attention_mask)

        loss = self.loss_fn(logits, labels)

        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)

        return loss

    def validation_step(self, batch, batch_idx):
        input_ids, attention_mask, labels = (
            batch["input_ids"],
            batch["attention_mask"],
            batch["labels"],
        )

        logits = self(input_ids, labels=labels, mask=attention_mask)

        loss = self.loss_fn(logits, labels)

        self.log("val_loss", loss, on_step=True, on_epoch=True, prog_bar=True)

        if batch_idx == 0:
            probs = F.softmax(logits, dim=-1)
            preds = torch.argmax(probs, dim=-1)

            print(tokenizer.decode(input_ids[0], skip_special_tokens=True))
            print(f"True: {labels[0]}")

            print(f"Pred: {preds[0]}")

        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=2e-5)
        return optimizer


In [58]:
classifier_v2 = GPT2Classifier_v2(
    n_layers=3, n_head=12, n_ctx=1024, d_model=768, n_labels=20, vcb_sz=50257
)

trainer = L.Trainer(
    accelerator="auto",
    logger=False,
    max_epochs=2,
    callbacks=[
        ModelCheckpoint(monitor="val_loss"),
    ],
)


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [59]:
trainer.fit(classifier_v2, train_dataloader, test_dataloader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type             | Params
---------------------------------------------
0 | model   | GPT2             | 60.7 M
1 | loss_fn | CrossEntropyLoss | 0     
---------------------------------------------
60.7 M    Trainable params
0         Non-trainable params
60.7 M    Total params
242.657   Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

c:\Users\Lakshya Agarwal\AppData\Local\Programs\Python\Python310\lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


c:\Users\Lakshya Agarwal\AppData\Local\Programs\Python\Python310\lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=2` reached.


In [60]:
trainer.validate(classifier_v2, test_dataloader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
c:\Users\Lakshya Agarwal\AppData\Local\Programs\Python\Python310\lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Validation: |          | 0/? [00:00<?, ?it/s]

[1m[[0m[1m{[0m[32m'val_loss_epoch'[0m: [1;36m2.979785442352295[0m[1m}[0m[1m][0m