A rough copy of https://towardsdatascience.com/text-classification-with-bert-in-pytorch-887965e5820f

In [1]:
from pathlib import Path
import numpy as np
import pandas as pd
from tqdm import tqdm
import torch
from torch import nn
from transformers import BertTokenizer, BertModel
from helpers import get_gpu
device = get_gpu()

In [2]:
data_dir = Path("..").resolve(strict=True) / "data"
data = pd.read_csv(data_dir / "bbc-text.csv")
labels = {"business": 0, "entertainment": 1, "sport": 2, "tech": 3, "politics": 4}
tokenizer = BertTokenizer.from_pretrained("bert-base-cased")

In [3]:
class Dataset(torch.utils.data.Dataset):
    def __init__(self, data):
        self.labels = [labels[category] for category in data["category"]]
        self.texts = [
            tokenizer(
                text,
                padding="max_length",
                max_length=512,
                truncation=True,
                return_tensors="pt",
            )
            for text in data["text"]
        ]
    
    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.texts[idx], np.array(self.labels[idx])

    
class BERTClassifier(nn.Module):
    def __init__(self, dropout_p=0.5):
        super().__init__()
        self.bert = BertModel.from_pretrained("bert-base-cased")
        self.dropout = nn.Dropout(dropout_p)
        self.fc = nn.Linear(768, 5)
    
    def __call__(self, input_ids, mask):
        _, output = self.bert(input_ids=input_ids, attention_mask=mask, return_dict=False)
        output = self.fc(self.dropout(output)).relu()
        return output

In [4]:
train_df, valid_df, test_df = np.split(
    data.sample(frac=1, random_state=1337), [int(0.8*len(data)), int(0.9*len(data))]
)
train_dataloader = torch.utils.data.DataLoader(Dataset(train_df), batch_size=2, shuffle=True)
valid_dataloader = torch.utils.data.DataLoader(Dataset(valid_df), batch_size=2)
criterion = nn.CrossEntropyLoss()
model = BERTClassifier().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-6)

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [5]:
for epoch in range(5):
    total_train_loss = 0
    total_train_correct = 0
    for train_input, train_label in tqdm(train_dataloader):
        mask = train_input["attention_mask"].to(device)
        input_id = train_input["input_ids"].squeeze(1).to(device)
        output = model(input_id, mask)
        train_label = train_label.to(device)
        loss = criterion(output, train_label.long())
        total_train_loss += loss.item()
        total_train_correct += (output.argmax(dim=1) == train_label).sum().item()
        model.zero_grad()
        loss.backward()
        optimizer.step()
    total_valid_loss = 0
    total_valid_correct = 0
    with torch.no_grad():
        for valid_input, valid_label in valid_dataloader:
            mask = valid_input["attention_mask"].to(device)
            input_ids = valid_input["input_ids"].squeeze(1).to(device)
            output = model(input_ids, mask)
            valid_label = valid_label.to(device)
            loss = criterion(output, valid_label.long())
            total_valid_loss += loss.item()
            total_valid_correct += (output.argmax(dim=1) == valid_label).sum().item()
    print(
        f"epoch {epoch+1} | "
        f"train loss {total_train_loss / len(train_df):.3f} | "
        f"train accuracy {total_train_correct / len(train_df):.3f} | "
        f"valid loss {total_valid_loss / len(valid_df):.3f} | "
        f"valid accuracy {total_valid_correct / len(valid_df):.3f}"
    )

100%|████████████████████████████████████████████████████████████████████| 890/890 [04:37<00:00,  3.20it/s]


epoch 1 | train loss 0.750 | train accuracy 0.345 | valid loss 0.567 | valid accuracy 0.649


100%|████████████████████████████████████████████████████████████████████| 890/890 [04:44<00:00,  3.13it/s]


epoch 2 | train loss 0.387 | train accuracy 0.778 | valid loss 0.314 | valid accuracy 0.784


100%|████████████████████████████████████████████████████████████████████| 890/890 [04:41<00:00,  3.16it/s]


epoch 3 | train loss 0.176 | train accuracy 0.946 | valid loss 0.113 | valid accuracy 0.991


100%|████████████████████████████████████████████████████████████████████| 890/890 [04:42<00:00,  3.15it/s]


epoch 4 | train loss 0.074 | train accuracy 0.990 | valid loss 0.062 | valid accuracy 0.982


100%|████████████████████████████████████████████████████████████████████| 890/890 [04:44<00:00,  3.13it/s]


epoch 5 | train loss 0.042 | train accuracy 0.996 | valid loss 0.040 | valid accuracy 0.986


In [6]:
test_dataloader = torch.utils.data.DataLoader(Dataset(test_df), batch_size=2)
total_test_correct = 0
with torch.no_grad():
    for test_input, test_label in test_dataloader:
        mask = test_input["attention_mask"].to(device)
        input_ids = test_input["input_ids"].squeeze(1).to(device)
        output = model(input_ids, mask)
        test_label = test_label.to(device)
        total_test_correct += (output.argmax(dim=1) == test_label).sum().item()
print(f"test accuracy {total_test_correct / len(test_df):.3f}")

test accuracy 0.987
