In [1]:
import os.path as op

from datasets import load_dataset

import lightning as L
from lightning.pytorch.loggers import CSVLogger
from lightning.pytorch.callbacks import ModelCheckpoint

import numpy as np
import pandas as pd
import torch

from sklearn.feature_extraction.text import CountVectorizer

from local_dataset_utilities import download_dataset, load_dataset_into_to_dataframe, partition_dataset
from local_dataset_utilities import IMDBDataset

In [2]:
# Map each split name to the CSV file (relative to the working directory)
# that holds it, then load all three splits with the "csv" builder.
split_files = {
    "train": "train.csv",
    "validation": "val.csv",
    "test": "test.csv",
}
imdb_dataset = load_dataset("csv", data_files=split_files)

print(imdb_dataset)

Found cached dataset csv (/Users/czq/.cache/huggingface/datasets/csv/default-2654090031c6f84b/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1)


  0%|          | 0/3 [00:00<?, ?it/s]

DatasetDict({
    train: Dataset({
        features: ['index', 'text', 'label'],
        num_rows: 35000
    })
    validation: Dataset({
        features: ['index', 'text', 'label'],
        num_rows: 5000
    })
    test: Dataset({
        features: ['index', 'text', 'label'],
        num_rows: 10000
    })
})


In [3]:
# 2. Tokenization

In [4]:
from transformers import AutoTokenizer

In [5]:
# Load the tokenizer published with the "distilbert-base-uncased" checkpoint,
# i.e. the same checkpoint name the classification model is loaded from later
# in this notebook.
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

In [6]:
# Inspect the tokenizer's limits: the maximum input length is the reason
# tokenize_text passes truncation=True, and the vocabulary size is the range
# of token ids the tokenizer can emit.
print("Tokenizer max input length:", tokenizer.model_max_length)
print("Tokenizer vocabulary size:", tokenizer.vocab_size)

Tokenizer max input length: 512
Tokenizer vocabulary size: 30522


In [7]:
def tokenize_text(batch):
    """Tokenize the ``"text"`` entries of ``batch`` with the notebook-level ``tokenizer``.

    Args:
        batch: Mapping with a ``"text"`` entry holding the raw text(s) to encode.

    Returns:
        The tokenizer's output for those texts, produced with truncation
        enabled and with padding applied across the texts passed in this call.
    """
    texts = batch["text"]
    encoded = tokenizer(texts, truncation=True, padding=True)
    return encoded

In [8]:
# batched=True with batch_size=None hands each split to tokenize_text as one
# single batch, so padding=True inside tokenize_text pads every example of a
# split to the same length (the longest sequence in that split after truncation).
imdb_tokenized = imdb_dataset.map(tokenize_text, batched=True, batch_size=None)

Loading cached processed dataset at /Users/czq/.cache/huggingface/datasets/csv/default-2654090031c6f84b/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1/cache-eea593fa664c340e.arrow


Map:   0%|          | 0/5000 [00:00<?, ? examples/s]

Loading cached processed dataset at /Users/czq/.cache/huggingface/datasets/csv/default-2654090031c6f84b/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1/cache-f5a415e473a98384.arrow


In [9]:
# The raw (untokenized) splits are not referenced by the following cells;
# drop the name so only imdb_tokenized remains in the notebook namespace.
del imdb_dataset

In [10]:
# Return the three model-relevant columns as torch tensors when indexing.
# Columns not listed here ("index", "text") are left out of the returned items.
imdb_tokenized.set_format("torch", columns=["input_ids", "attention_mask", "label"])

In [11]:
# 3. Set up DataLoaders

In [12]:
from torch.utils.data import DataLoader, Dataset

In [15]:
class IMDBDataset(Dataset):
    """Map-style dataset that exposes a single split of a dict-like collection.

    Args:
        dataset_dict: Mapping from split name to a split object. The split
            object must support indexing and expose a ``num_rows`` attribute.
        partition_key: Name of the split to wrap (defaults to ``"train"``).
    """

    def __init__(self, dataset_dict, partition_key="train"):
        # Hold on to the selected split only; other splits are not referenced.
        self.partition = dataset_dict[partition_key]

    def __len__(self):
        """Return the row count reported by the wrapped split."""
        return self.partition.num_rows

    def __getitem__(self, index):
        """Return whatever the wrapped split yields for ``index``."""
        return self.partition[index]

In [17]:
# Inspection only: show which container type map() produced.
type(imdb_tokenized)

datasets.dataset_dict.DatasetDict

In [18]:
# Wrap each tokenized split in an IMDBDataset; the order of the split names
# matches the order of the target variables on the left-hand side.
train_dataset, val_dataset, test_dataset = (
    IMDBDataset(imdb_tokenized, partition_key=split_name)
    for split_name in ("train", "validation", "test")
)

In [19]:
# Only the training loader shuffles: a fresh sample order each epoch matters for
# stochastic optimization. Validation and test loaders keep a fixed order because
# evaluation metrics do not depend on batch order, a fixed order keeps evaluation
# runs reproducible, and Lightning warns when an evaluation dataloader shuffles.
train_loader = DataLoader(dataset=train_dataset, batch_size=16, shuffle=True, num_workers=4)
val_loader = DataLoader(dataset=val_dataset, batch_size=16, shuffle=False, num_workers=4)
test_loader = DataLoader(dataset=test_dataset, batch_size=16, shuffle=False, num_workers=4)

In [20]:
# 4. Initialize DistilBERT

In [24]:
from transformers import AutoModelForSequenceClassification

In [25]:
# Load DistilBERT with a sequence-classification head for 2 labels. As the
# loading message reports, the classifier / pre_classifier weights are newly
# initialized rather than taken from the checkpoint, so they still need training.
model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=2)

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_layer_norm.bias', 'vocab_transform.bias', 'vocab_transform.weight', 'vocab_projector.bias', 'vocab_layer_norm.weight']
- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.weight', 'pre_classifier.weight', 'pre_classifier.bias', 'classifier.

In [26]:
# Freeze all layers: switch off gradient tracking for every parameter of the
# model so that nothing is trainable until selected layers are re-enabled.
for params in model.parameters():
    params.requires_grad_(False)

In [27]:
# Unfreeze last layers

In [28]:
model

DistilBertForSequenceClassification(
  (distilbert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0-5): 6 x TransformerBlock(
          (attention): MultiHeadSelfAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)
 

In [29]:
# Unfreeze the last two classification layers: re-enable gradients for the
# parameters of pre_classifier and classifier only.
for head in (model.pre_classifier, model.classifier):
    for param in head.parameters():
        param.requires_grad_(True)

In [30]:
# 5. Fine tuning

In [31]:
import lightning
import torch
import torchmetrics

In [None]:
class CustomLightningModule(lightning.LightningModule):
    def __init__(self, model, learning_rate=5e-5):
        super().__init__()
        
        self.learning_rate = learning_rate
        self.model = model
        
        self.val_acc = torchmetrics.Accuracy(task="multiclass", num_classes=2)
        self.test_acc = torchmetrics.Accuracy(task="multiclass", num_classes=2)
    
    def forward(self, input_ids, attention_mask, labels):
        return self.model(input_ids, attention_mask=attention_mask, labels=labels)
    
    def training_step(self, batch, batch_idx):
        outputs = self(
            batch["input_ids"],
            attention_mask=batch["attention_mask"],
            labels=batch["label"]
        )
        
        self.log("train_loss", outputs["loss"])
        
        return outputs["loss"]
    
    def 