## Dataset Preperation

### Tokenization

In [1]:
from transformers import BertTokenizer
from datasets import load_dataset

from functools import partial

In [2]:
dataset = load_dataset('stanfordnlp/imdb')

bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

def tokenize_function(examples, tokenizer):
    return tokenizer(examples['text'], padding='max_length', truncation=True)

bert_tokenize = partial(tokenize_function, tokenizer=bert_tokenizer)

tokenized_train = dataset['train'].map(bert_tokenize, batched=True)



### DataLoaders

In [3]:
from torch.utils.data import DataLoader

In [9]:
split = tokenized_train.train_test_split()
train, validation = split['train'], split['test']

train_loader = DataLoader(train.with_format('torch'), batch_size=8)
validation_loader = DataLoader(validation.with_format('torch'), batch_size=8)

len(train_loader), len(validation_loader)

(2344, 782)

## Training

### Model

In [None]:
!pip install lightning -q

In [5]:
# pytorch
import torch
from torch import nn
# lightning
import lightning as L
# transformers
from transformers import DistilBertModel

2024-06-08 19:49:40.303775: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE4.1 SSE4.2 AVX AVX2 AVX512F FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


#### Example DistilBert model usage

In [6]:
distilbert = DistilBertModel.from_pretrained('distilbert-base-uncased')

ex_input = bert_tokenizer('this text will get tokenized and passed to the model', return_tensors='pt')
del ex_input['token_type_ids'] # DistilBert doesn't expect 'token_type_ids'
print(ex_input)

ex_output = distilbert(**ex_input)
print(ex_output, ex_output.last_hidden_state, ex_output.last_hidden_state.shape, sep='\n\n') # the .last_hidden_state is what we really care about



{'input_ids': tensor([[  101,  2023,  3793,  2097,  2131, 19204,  3550,  1998,  2979,  2000,
          1996,  2944,   102]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}
BaseModelOutput(last_hidden_state=tensor([[[-0.3926, -0.4085,  0.1309,  ..., -0.2571,  0.0637,  0.5803],
         [-0.4795, -0.3352,  0.1271,  ..., -0.3095,  0.2035,  0.3989],
         [ 0.0195,  0.0232,  0.2583,  ..., -0.4280, -0.3019,  0.3976],
         ...,
         [-0.7386, -0.5224,  0.2975,  ..., -0.3621, -0.0776,  0.3723],
         [-0.4253, -0.3041, -0.0472,  ..., -0.4011, -0.0266,  0.1344],
         [ 0.8935,  0.0227, -0.4569,  ...,  0.1036, -0.7282, -0.3009]]],
       grad_fn=<NativeLayerNormBackward0>), hidden_states=None, attentions=None)

tensor([[[-0.3926, -0.4085,  0.1309,  ..., -0.2571,  0.0637,  0.5803],
         [-0.4795, -0.3352,  0.1271,  ..., -0.3095,  0.2035,  0.3989],
         [ 0.0195,  0.0232,  0.2583,  ..., -0.4280, -0.3019,  0.3976],
         ...,
         [-0.7386, -

#### Bert with classification head

In [23]:
class BertBinaryClassifier(L.LightningModule):
    def __init__(self, bert:nn.Module):
        super().__init__()
        # freeze the bert parameters
        for param in bert.parameters():
            param.requires_grad = False

        self.bert = bert 

        bert_output_dim = bert.config.dim
        self.classify = nn.Linear(bert_output_dim, 1)

        self.loss = nn.CrossEntropyLoss()


    def forward(self, x):
        outputs = self.bert(x)
        cls_tokens = outputs.last_hidden_state[:, 0] # the [CLS] token is at index 0
        classifications = self.classify(cls_tokens)
        return classifications


    def training_step(self, batch, batch_idx):
        x = batch['input_ids']
        y = batch['label']
        pred = self.forward(x)
        loss = self.loss(pred, y)

        print(f'{type(loss) = }')

        self.log('train_loss', loss)
        return loss

    # def validation_step(self, )

    def configure_optimizers(self):
        return torch.optim.Adam(lr=.01, params=self.parameters())


def make_model(bert: nn.Module) -> nn.Module:
    return BertBinaryClassifier(bert)

In [24]:
distilbert = DistilBertModel.from_pretrained('distilbert-base-uncased')
model = make_model(distilbert)

# Initialize the PyTorch Lightning trainer
trainer = L.Trainer(max_epochs=1, log_every_n_steps=1)

# Train the mode
trainer.fit(model, train_dataloaders=train_loader)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name     | Type             | Params
----------------------------------------------
0 | bert     | DistilBertModel  | 66.4 M
1 | classify | Linear           | 769   
2 | loss     | CrossEntropyLoss | 0     
----------------------------------------------
769       Trainable params
66.4 M    Non-trainable params
66.4 M    Total params
265.455   Total estimated model params size (MB)


Training: |          | 0/? [00:00<?, ?it/s]

: 

: 

In [17]:
x = torch.randn(5, 5)
x.squeeze(-1)

tensor([[-0.5683, -0.3293,  0.9554, -1.0273, -1.0586],
        [ 1.0877, -0.6478, -1.3211,  0.6739,  1.0002],
        [ 2.0045, -1.5789,  0.9868,  0.5266, -0.3147],
        [-1.3763,  1.1556,  1.0450, -0.0393,  0.3330],
        [-0.3574, -1.3078, -0.0788,  0.0850, -1.9643]])

In [35]:
class BertBinaryClassifier(L.LightningModule):

# class BertBinaryClassifier(nn.Module):
#     def __init__(self, bert: nn.Module):
#         super(BertBinaryClassifier, self).__init__()

#         # freeze the bert parameters
#         for param in bert.parameters():
#             param.requires_grad = False

#         self.bert = bert 

#         bert_output_dim = bert.config.dim
#         self.classify = nn.Linear(bert_output_dim, 1)

#     def forward(self, x):
#         x = self.bert(x)
#         x = self.classify(x)
#         return x


# def make_model(bert: nn.Module) -> nn.Module:
#     return BertBinaryClassifier(bert)


# bert = DistilBertModel.from_pretrained('distilbert-base-uncased')

# make_model(bert)



BertBinaryClassifier(
  (bert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0-5): 6 x TransformerBlock(
          (attention): MultiHeadSelfAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)
            (lin1): Li