## Dataset Preparation

### Tokenization

In [5]:
from transformers import BertTokenizer
from datasets import load_dataset

from functools import partial

In [9]:
# Load every available split of the 'stanfordnlp/imdb' dataset.
dataset = load_dataset("stanfordnlp/imdb")

# Tokenizer built from the 'bert-base-uncased' checkpoint; reused by the cells below.
bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

def tokenize_function(examples, tokenizer):
    """Tokenize the ``'text'`` field of ``examples`` with ``tokenizer``.

    Args:
        examples: mapping that holds the raw texts under the ``'text'`` key.
        tokenizer: callable invoked as
            ``tokenizer(texts, padding='max_length', truncation=True)``.

    Returns:
        Whatever ``tokenizer`` returns for those texts.
    """
    texts = examples['text']
    encoded = tokenizer(texts, truncation=True, padding='max_length')
    return encoded

# Bind the tokenizer up front so the mapped callable only takes the examples.
bert_tokenize = partial(tokenize_function, tokenizer=bert_tokenizer)

# batched=True passes a batch of examples to `bert_tokenize` per call.
tokenized_train = dataset['train'].map(
    bert_tokenize,
    batched=True,
)



### DataLoaders

In [11]:
from torch.utils.data import DataLoader

In [12]:
# Hold out part of the tokenized training set for validation. The fixed seed
# makes the (shuffled) split identical after every kernel restart, so results
# from different runs are comparable.
split = tokenized_train.train_test_split(seed=42)
train, validation = split['train'], split['test']

# with_format('torch') makes the datasets return torch tensors to the loaders.
# NOTE(review): both loaders use the DataLoader defaults (no batch_size or
# shuffle given) - confirm this is intended before training.
train_loader = DataLoader(train.with_format('torch'))
validation_loader = DataLoader(validation.with_format('torch'))

len(train_loader), len(validation_loader)

(18750, 6250)

## Training

### Model

In [13]:
from transformers import DistilBertModel
from torch import nn

#### Example DistilBert model usage

In [36]:
bert = DistilBertModel.from_pretrained('distilbert-base-uncased')

ex_input = bert_tokenizer('this text will get tokenized and passed to the model', return_tensors='pt')
del ex_input['token_type_ids'] # DistilBert doesn't expect 'token_type_ids'
print(ex_input)

# run the model loaded at the top of this cell on the tokenized example
ex_output = bert(**ex_input)
print(ex_output) # the full output object
print(ex_output.last_hidden_state) # the .last_hidden_state is what we really care about
print(ex_output.last_hidden_state.shape)

: 

: 

#### DistilBert with binary classification head

In [35]:
class BertBinaryClassifier(nn.Module):
    """Binary classifier: a frozen encoder followed by one linear layer.

    Args:
        bert: encoder module. It must expose ``config.dim`` (used as the input
            size of the linear head) and, when called, return an object with a
            ``last_hidden_state`` attribute, as ``DistilBertModel`` does.
    """

    def __init__(self, bert: nn.Module):
        super().__init__()

        # freeze the bert parameters so only the classification head is trained
        for param in bert.parameters():
            param.requires_grad = False

        self.bert = bert

        bert_output_dim = bert.config.dim
        self.classify = nn.Linear(bert_output_dim, 1)

    def forward(self, x, attention_mask=None):
        """Compute one raw logit per input sequence.

        Args:
            x: token ids, passed to the encoder as its first argument.
            attention_mask: optional mask forwarded to the encoder; pass it
                when the inputs contain padding.

        Returns:
            The output of the linear head applied to the hidden state at
            sequence position 0. No sigmoid is applied.
        """
        # The encoder returns an output object, not a tensor, so the hidden
        # states have to be pulled out before the linear layer can use them.
        hidden_states = self.bert(x, attention_mask=attention_mask).last_hidden_state

        # Pool by keeping the first position of every sequence
        # (assumes hidden_states is (batch, seq_len, dim) with the [CLS]
        # token first, as produced by a BERT tokenizer - TODO confirm).
        pooled = hidden_states[:, 0]
        return self.classify(pooled)


def make_model(bert: nn.Module) -> nn.Module:
    """Wrap ``bert`` in a ``BertBinaryClassifier`` and return the wrapper."""
    classifier = BertBinaryClassifier(bert)
    return classifier


# Load the 'distilbert-base-uncased' checkpoint to use as the encoder.
bert = DistilBertModel.from_pretrained("distilbert-base-uncased")

# Last expression of the cell: displays the assembled classifier's module tree.
make_model(bert)



BertBinaryClassifier(
  (bert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0-5): 6 x TransformerBlock(
          (attention): MultiHeadSelfAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)
            (lin1): Li