# 02a - Fine-tune DistilBERT for Sequence Classification

In [1]:
import numpy as np
import pandas as pd

import torch
from datasets import load_dataset

from src import data, models, metrics

DATA_DIR = 'data/'
OUTPUT_DIR = 'output/distilbert/'
MODEL_NAME = 'distilbert_monitors_3epoch'
CLASSES = ['Monitor', 'Tv', 'Noise']

# Seed the global RNGs so that the newly initialized classification head
# (created in the next section) and any numpy-based sampling are reproducible
# across Restart & Run All.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Device: {device}')

Device: cuda


## Create DistilBERT Model with Tokenizer

In [2]:
# Instantiate the project's DistilBERT wrapper for the CLASSES label set,
# starting from the pre-trained `distilbert-base-uncased` weights.
net = models.DistilBERT(
    classes=CLASSES,
    pretrained_checkpoint='distilbert-base-uncased',
)

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_layer_norm.bias', 'vocab_transform.weight', 'vocab_transform.bias', 'vocab_layer_norm.weight', 'vocab_projector.bias', 'vocab_projector.weight']
- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.weight', 'pre_classifier.bias', 'pre_classi

In [3]:
n_trainable_params = net.num_trainable_params()
print(f'Number of trainable parameters: {n_trainable_params:,}')

Number of trainable parameters: 66,955,779


## Example of Classification

Classify one example to verify that the code executes without errors. The output distribution is roughly uniform because the classification head is newly initialized and has not yet been fine-tuned on this task.

In [4]:
# Example input text; before fine-tuning, the predicted distribution over
# CLASSES should be roughly uniform.
x = (
    '32 inch curved screen 144hz monitor 1k 2k 4k fhd ips '
    'curved lcd pc hd-mi power vga cable'
)
net.predict_sample(x, return_dict=True)

[{'Monitor': 0.33961165, 'Tv': 0.29880092, 'Noise': 0.36158744}]

## Load the Data

In [5]:
# Load the train/validation CSV splits into a DatasetDict. The test split
# (monitors_classification_202107_test.csv) is not loaded in this notebook.
data_files = {
    'train': DATA_DIR + 'monitors_classification_202107_train.csv',
    'validation': DATA_DIR + 'monitors_classification_202107_val.csv',
}
datasets = load_dataset('csv', data_files=data_files)

# Tokenize every split through the model wrapper.
tokenized_datasets = net.tokenize_dataset(datasets)

datasets

Using custom data configuration default-3d875e4c7602e5a9
Reusing dataset csv (/home/ec2-user/.cache/huggingface/datasets/csv/default-3d875e4c7602e5a9/0.0.0/e138af468cb14e747fb46a19c787ffcfa5170c821476d20d5304287ce12bbc23)


  0%|          | 0/120 [00:00<?, ?ba/s]

  0%|          | 0/12 [00:00<?, ?ba/s]

DatasetDict({
    train: Dataset({
        features: ['inp', 'trg', 'metadata'],
        num_rows: 120000
    })
    validation: Dataset({
        features: ['inp', 'trg', 'metadata'],
        num_rows: 12000
    })
})

## Fine-tune the Model

In [6]:
# Draw a fixed random subset of the training split. The row indices are cached
# in `idx.npy` so every run uses the same subset. On a fresh checkout the file
# does not exist yet, so it is generated with a fixed seed and saved, instead of
# the cell failing with FileNotFoundError.
# NOTE(review): the trainer cell below trains on the full
# `tokenized_datasets['train']`, not on this sample -- confirm which is intended.
SAMPLE_IDX_PATH = 'idx.npy'
SAMPLE_SIZE = 36_000
SAMPLE_SEED = 42

n_train = len(tokenized_datasets['train'])
try:
    idx = np.load(SAMPLE_IDX_PATH)
except FileNotFoundError:
    rng = np.random.default_rng(SAMPLE_SEED)
    idx = rng.choice(n_train, size=SAMPLE_SIZE, replace=False)
    np.save(SAMPLE_IDX_PATH, idx)

# Guard against a stale cache written for a different (larger) training split.
assert idx.ndim == 1 and (idx < n_train).all(), 'idx.npy does not match the training split'

traindataset_sample = tokenized_datasets['train'].select(idx)
traindataset_sample

Dataset({
    features: ['attention_mask', 'inp', 'input_ids', 'labels', 'metadata', 'trg'],
    num_rows: 36000
})

In [7]:
# Choose the training data explicitly: the cached sample from the previous cell
# or the full training split. Defaults to the full split, which is what the
# recorded training results below were produced with.
TRAIN_ON_SAMPLE = False
train_dataset = traindataset_sample if TRAIN_ON_SAMPLE else tokenized_datasets['train']

# create trainer instance
trainer = net.get_trainer(
    output_dir=OUTPUT_DIR,
    train_dataset=train_dataset,
    eval_dataset=tokenized_datasets['validation'],
    no_epochs=3,  # keep in sync with MODEL_NAME ('..._3epoch')
    bs=32,
    gradient_accumulation_steps=2,
    lr=0.0001,
    wd=0.01,
    lr_scheduler_type='linear',
    fp16=False,
    compute_metrics_cb=metrics.ClassificationMetricsCallback(),
    log_level='error')

In [8]:
# Train the network with the trainer configured above; the per-epoch
# training/validation loss, accuracy and F1 are reported in the cell output.
training_output = trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy,F1
1,0.0875,0.077077,0.978083,0.971114
2,0.0511,0.059294,0.98425,0.979348
3,0.0267,0.059236,0.985,0.98031


In [9]:
# Persist the fine-tuned checkpoint under OUTPUT_DIR/MODEL_NAME.
save_dir = f'{OUTPUT_DIR}{MODEL_NAME}'
net.save_pretrained(save_dir)

DistilBERT(distilbert-base-uncased)