BERT


In [None]:
!pip install datasets scikit-learn

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW

import os
import numpy as np
from tqdm.auto import tqdm

import random
from torch.utils.data import DataLoader

from datasets import load_dataset
from transformers import (
    BertTokenizerFast,
    BertForSequenceClassification,
    DataCollatorWithPadding,
    get_linear_schedule_with_warmup,
)

# Seed the Python, NumPy and PyTorch RNGs from one constant so repeated runs
# of the notebook draw the same random numbers.
RANDOM_SEED = 42
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

# Use the GPU when PyTorch reports one as available; otherwise run on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Hyperparameters shared by the cells below.
MODEL_NAME: str = "bert-base-uncased"  # checkpoint used for both tokenizer and model
BATCH_SIZE: int = 16                   # examples per DataLoader batch
EPOCHS: int = 5                        # full passes over the training split
LR: float = 3e-5                       # learning rate handed to AdamW
WARMUP_RATIO: float = 0.1              # fraction of all training steps used for warmup
THRESHOLD: float = 0.4                 # sigmoid cutoff above which a label is predicted
LOG_STEP: int = 100                    # refresh the progress-bar loss every N steps


In [None]:
# Load the "simplified" GoEmotions configuration (28 emotions) and read the
# label vocabulary from the training split's feature metadata.
ds = load_dataset("go_emotions", "simplified")
label_names = ds["train"].features["labels"].feature.names

# Derive the label count from the metadata instead of hard-coding 28, so the
# multi-hot width and the classifier head always match the dataset.
NUM_LABELS = len(label_names)


tokenizer = BertTokenizerFast.from_pretrained(MODEL_NAME)
def tokenize_and_encode(examples):
    """Tokenize a batch of texts and attach multi-hot label vectors.

    Args:
        examples: Batch mapping (as passed by a batched ``map`` call); only
            its ``"text"`` and ``"labels"`` entries are read.

    Returns:
        The tokenizer output with an added ``"labels"`` entry holding one
        list of ``NUM_LABELS`` 0/1 integers per example.
    """
    batch = tokenizer(examples["text"], truncation=True)

    label_lists = examples["labels"]
    # One row per example, one column per label id.
    multi_hot = np.zeros((len(label_lists), NUM_LABELS), dtype=np.int8)
    for row, active_ids in zip(multi_hot, label_lists):
        # `row` is a view into `multi_hot`, so this assignment fills the matrix.
        row[active_ids] = 1

    batch["labels"] = multi_hot.tolist()
    return batch

# Encode every split; the raw "text" and "id" columns are dropped afterwards.
ds = ds.map(
    tokenize_and_encode,
    batched=True,
    remove_columns=["text", "id"],
)
# Collate function that pads each batch and returns PyTorch tensors.
collator = DataCollatorWithPadding(tokenizer, return_tensors="pt")

# DataLoaders: both share batch size and collator; only training is shuffled.
shared_loader_args = {"batch_size": BATCH_SIZE, "collate_fn": collator}
train_loader = DataLoader(ds["train"], shuffle=True, **shared_loader_args)
val_loader = DataLoader(ds["validation"], shuffle=False, **shared_loader_args)

# Quick sanity output: remaining columns and validation set size.
print(ds["train"].column_names)
print(len(ds["validation"]))


In [None]:
# Build the BERT sequence classifier with NUM_LABELS outputs, configured for
# multi-label classification, then move it to the selected device.
model = BertForSequenceClassification.from_pretrained(
    MODEL_NAME,
    num_labels=NUM_LABELS,
    problem_type="multi_label_classification",
)
model = model.to(device)

In [None]:
# AdamW over every model parameter.
optimizer = AdamW(model.parameters(), lr=LR)

# The scheduler is stepped once per batch, so the schedule length is
# (batches per epoch) x (epochs); WARMUP_RATIO of those steps are warmup.
total_steps = EPOCHS * len(train_loader)
warmup_steps = int(WARMUP_RATIO * total_steps)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps,
)

# Loss applied to the raw logits in the training loop.
criterion = nn.BCEWithLogitsLoss()

In [None]:
from sklearn.metrics import accuracy_score, precision_score, recall_score

# Evaluate
def evaluate(model, loader, threshold=None):
    """Score `model` on every batch of `loader` with multi-label metrics.

    The model is switched to eval mode and left in that mode on return.
    Two summary statistics (mean number of positive labels per example in
    the data and in the predictions) are printed as a side effect.

    Args:
        model: Model called as ``model(**inputs)`` whose output exposes
            ``.logits``.
        loader: Iterable of batch dicts of tensors that include a
            ``"labels"`` entry; every other entry is passed to the model.
        threshold: Sigmoid probability above which a label counts as
            predicted. Defaults to the module-level ``THRESHOLD``.

    Returns:
        Tuple ``(acc, precision, recall, f1)``: ``accuracy_score`` over the
        label matrices, micro-averaged precision and recall, and their
        harmonic mean.
    """
    if threshold is None:
        threshold = THRESHOLD

    model.eval()
    all_logits, all_labels = [], []
    with torch.no_grad():
        for batch in loader:
            # Labels are only used for CPU-side metrics, so they are never
            # copied to the device; only the model inputs are moved.
            all_labels.append(batch["labels"])

            inputs = {k: v.to(device) for k, v in batch.items() if k != "labels"}
            all_logits.append(model(**inputs).logits.cpu())

    logits = torch.cat(all_logits)
    labels = torch.cat(all_labels).int().cpu().numpy()
    preds = (torch.sigmoid(logits) > threshold).int().numpy()

    # accuracy, precision, recall
    acc = accuracy_score(labels, preds)
    precision = precision_score(labels, preds, average="micro", zero_division=0)
    recall = recall_score(labels, preds, average="micro", zero_division=0)

    # F1 as the harmonic mean; the epsilon avoids 0/0 when both terms are zero.
    f1 = 2 * precision * recall / (precision + recall + 1e-8)

    # stats
    print("average 1's in data", labels.sum(axis=1).mean())
    print("average 1's in prediction:", preds.sum(axis=1).mean())

    return acc, precision, recall, f1


In [None]:
# train
# One optimizer step and one scheduler step per batch; after every epoch the
# model is scored on the validation loader and a summary line is printed.
global_step = 0
for epoch in range(1, EPOCHS + 1):
    model.train()
    epoch_loss = 0.0

    prog_bar = tqdm(train_loader, desc=f"Epoch {epoch}", leave=False)
    for batch in prog_bar:
        # The loss is computed against float targets on the model's device.
        labels = batch["labels"].to(device).float()
        inputs = {k: v.to(device) for k, v in batch.items() if k != "labels"}

        # zero grad
        optimizer.zero_grad()

        # forward / backward
        outputs = model(**inputs)
        loss = criterion(outputs.logits, labels)
        loss.backward()

        optimizer.step()
        scheduler.step()

        # Read the scalar once and reuse it for the running sum and the log.
        step_loss = loss.item()
        epoch_loss += step_loss
        global_step += 1

        if global_step % LOG_STEP == 0:
            prog_bar.set_postfix(loss=f"{step_loss:.4f}")

    avg_loss = epoch_loss / len(train_loader)
    acc, prec, recall, f1 = evaluate(model, val_loader)

    print(
        f"Epoch {epoch} | "
        f"loss {avg_loss:.4f} | "
        f"accuracy {acc:.4f} | "
        f"Precision {prec:.4f} | "
        f"Recall {recall:.4f} | "
        f"F1 {f1:.4f}")