In [1]:
# Enable IPython's autoreload extension in mode 2 so edits to imported local
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
import wandb
from datetime import datetime
import time
from tqdm import tqdm
import os
import os.path as osp

from transformers import AutoTokenizer, AutoModel, DataCollatorWithPadding, get_linear_schedule_with_warmup
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

In [3]:
# --- Run configuration -------------------------------------------------------
MODEL = "bert-base-uncased"   # Hugging Face checkpoint used for tokenizer and encoder
BATCH_SIZE = 16
NUM_EPOCHS = 30
LEARNING_RATE = 1e-5
RANDOM_SEED = 12345
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pdate = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") # for naming

# Seed the numpy and torch RNGs as well: RANDOM_SEED was previously only passed
# to train_test_split, leaving DataLoader shuffling, dropout and the
# initialisation of the regression head different on every run.
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

print(f"device: {DEVICE}")

In [22]:
# NOTE(review): pd.read_pickle can execute arbitrary code while unpickling;
# only load this file if it comes from a trusted source.
df = pd.read_pickle("train/seq2reply_regression_data.pickle")

# Drop rows that have no text to tokenize.
df = df.loc[df.clean_text.notnull()]

# Ordinal labels run from 0 to the largest observed reply count, inclusive.
# Series.max() skips NaN, and int() guarantees a plain Python int: the value is
# later used for list repetition and tensor sizes, which a numpy float
# (e.g. when the column is stored as float) would break.
NUM_CLASSES = int(df.num_replies.max()) + 1
print(f"num_classes: {NUM_CLASSES}")

num_classes: 345


In [5]:
class OrdinalUfoDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing tokenized texts with ordinal targets.

    Each item carries the tokenizer's fields for one text, the integer label
    under ``"labels"``, and a float32 ``"levels"`` vector of length
    ``num_classes - 1`` holding ``label`` ones followed by zeros.

    Args:
        X: object exposing ``tolist()`` that yields the raw texts.
        y: object exposing ``tolist()`` that yields the integer labels.
        tokenizer: callable invoked as ``tokenizer(texts, truncation=True)``
            returning a mapping from field name to per-example sequences.
        num_classes: number of ordinal classes.
    """

    def __init__(self, X, y, tokenizer, num_classes):
        texts = X.tolist()
        # Everything is tokenized once up front; padding is left to the collator.
        self.encodings = tokenizer(texts, truncation=True)
        self.labels = y.tolist()
        self.num_classes = num_classes

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        target = self.labels[index]

        sample = {}
        for field, values in self.encodings.items():
            sample[field] = torch.tensor(values[index])
        sample["labels"] = torch.tensor(target)

        # Extended binary encoding of the rank: `target` ones, then zeros.
        passed = [1] * target
        remaining = [0] * (self.num_classes - 1 - target)
        sample["levels"] = torch.tensor(passed + remaining, dtype=torch.float32)
        return sample

In [6]:
# Split the data: 80/20 train/test, then 80/20 train/val on the remainder.
X_train, X_test, y_train, y_test = train_test_split(
    df.clean_text,
    df.num_replies.astype("int16"),
    test_size=0.2,
    random_state=RANDOM_SEED,
)
X_train, X_val, y_train, y_val = train_test_split(
    X_train, y_train, test_size=0.2, random_state=RANDOM_SEED
)

# initialize tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL)

# initialize datasets
train_dataset = OrdinalUfoDataset(X_train, y_train, tokenizer, NUM_CLASSES)
val_dataset = OrdinalUfoDataset(X_val, y_val, tokenizer, NUM_CLASSES)
test_dataset = OrdinalUfoDataset(X_test, y_test, tokenizer, NUM_CLASSES)

# initialize data collator (pads each batch to its longest sequence)
data_collator = DataCollatorWithPadding(
    tokenizer=tokenizer, padding="longest", return_tensors="pt"
)

# initialize dataloaders
# Only the training loader is shuffled. Shuffling evaluation data brings no
# benefit, consumes RNG state, and changes padded batch composition from run
# to run, so validation/test iterate in a fixed order.
train_loader = DataLoader(
    dataset=train_dataset, collate_fn=data_collator, batch_size=BATCH_SIZE, shuffle=True
)
val_loader = DataLoader(
    dataset=val_dataset, collate_fn=data_collator, batch_size=BATCH_SIZE, shuffle=False
)
test_loader = DataLoader(
    dataset=test_dataset, collate_fn=data_collator, batch_size=BATCH_SIZE, shuffle=False
)

In [7]:
class OrdinalRegressionBERT(nn.Module):
    """Pretrained encoder with a rank-consistent ordinal regression head.

    The head is a single bias-free linear unit shared by all thresholds, plus
    one learnable bias per threshold (``num_classes - 1`` of them), so every
    binary "label > k" task uses the same weights and differs only in its bias.

    Args:
        bert: model name or path passed to ``AutoModel.from_pretrained``.
        num_classes: number of ordinal classes.
        dropout_rate: dropout probability applied to the pooled token.
    """

    def __init__(self, bert, num_classes, dropout_rate=0.1):
        super(OrdinalRegressionBERT, self).__init__()
        self.num_classes = num_classes

        self.bert = AutoModel.from_pretrained(bert)
        self.dropout = nn.Dropout(p=dropout_rate)
        # Take the head's input width from the loaded encoder instead of
        # hard-coding 768, so checkpoints with another hidden size also work.
        hidden_size = self.bert.config.hidden_size
        self.fc = nn.Linear(hidden_size, 1, bias=False)
        self.linear_1_bias = nn.Parameter(torch.zeros(self.num_classes - 1).float())

    def forward(self, input_ids, attention_mask):
        """Return ``(logits, probas)``, one column per threshold.

        ``probas`` is the element-wise sigmoid of ``logits``.
        """
        hidden = self.bert(input_ids=input_ids, attention_mask=attention_mask)[0]
        x = hidden[:, 0]  # representation of the first token of each sequence
        x = self.dropout(x)
        logits = self.fc(x)
        # The (batch, 1) score broadcasts against the per-threshold biases.
        logits = logits + self.linear_1_bias
        probas = torch.sigmoid(logits)
        return logits, probas

In [8]:
def cost_fn(logits, levels, device):
    """Mean binary cross-entropy summed over the ordinal threshold tasks.

    Args:
        logits: tensor of shape (batch, num_tasks) with raw threshold scores.
        levels: tensor of the same shape with 0/1 targets per threshold.
        device: device on which the task-importance weights are created.

    Returns:
        Scalar tensor: per-example loss summed over tasks, averaged over batch.
    """
    # Size the weights from the logits rather than from the notebook-level
    # NUM_CLASSES global, so the loss works for any number of thresholds.
    num_tasks = logits.shape[1]
    imp = torch.ones(num_tasks, dtype=torch.float32, device=device)  # task importance weights

    log_p = F.logsigmoid(logits)
    # log(1 - sigmoid(z)) == logsigmoid(z) - z, which stays numerically stable.
    log_not_p = log_p - logits
    val = -torch.sum((log_p * levels + log_not_p * (1 - levels)) * imp, dim=1)
    return torch.mean(val)

In [None]:
def eval_metrics(model, data_loader, device):
    """Compute mean absolute error and mean squared error over a data loader.

    The predicted label of an example is the number of thresholds whose
    probability exceeds 0.5.

    Args:
        model: callable as ``model(input_ids, attention_mask)`` returning
            ``(logits, probas)``.
        data_loader: iterable of batches with ``"input_ids"``,
            ``"attention_mask"`` and ``"labels"`` tensors.
        device: device the batch tensors are moved to.

    Returns:
        Tuple ``(mae, mse)`` of scalar float tensors.

    Raises:
        ValueError: if the loader yields no examples.
    """
    abs_err_sum, sq_err_sum, num_examples = 0, 0, 0
    # Evaluation never needs gradients; disabling them here keeps the function
    # safe (no autograd graph is built) even if the caller forgets to.
    with torch.no_grad():
        for batch in data_loader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)
            logits, probas = model(input_ids, attention_mask)

            predicted_labels = torch.sum(probas > 0.5, dim=1)
            num_examples += labels.size(0)

            errors = predicted_labels - labels
            abs_err_sum += torch.sum(torch.abs(errors))
            sq_err_sum += torch.sum(errors ** 2)

    if num_examples == 0:
        # Previously this surfaced as an opaque AttributeError on `int.float()`.
        raise ValueError("eval_metrics received no examples; cannot compute MAE/MSE")

    mae = abs_err_sum.float() / num_examples
    mse = sq_err_sum.float() / num_examples
    return mae, mse

In [None]:
# Build the network on the target device, then the optimizer over its parameters.
model = OrdinalRegressionBERT(MODEL, NUM_CLASSES).to(DEVICE)
optimizer = torch.optim.Adam(params=model.parameters(), lr=LEARNING_RATE)

# One scheduler step per batch: linear warm-up over the first 10% of all
# optimisation steps, then linear decay for the remainder.
training_steps = NUM_EPOCHS * len(train_loader)
warmup_steps = int(training_steps * 0.1)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=training_steps,
)

In [None]:
# The run name encodes warm-up steps, epoch count and launch timestamp; it is
# used both as the wandb project and as the checkpoint directory name.
project_name = (
    f"BERT-FT-Ordinal-Regression-lr_warmup-{warmup_steps}-epochs-{NUM_EPOCHS}-{pdate}"
)
SAVE_PATH = osp.join("/output", project_name)  # for the docker volume
os.makedirs(name=SAVE_PATH, exist_ok=True)

# Start experiment tracking and hook the model so wandb can monitor it.
wandb.init(entity="jisoo", project=project_name)
wandb.watch(model)

In [21]:
# Train for NUM_EPOCHS, validating after every epoch and checkpointing whenever
# the validation MAE improves (at most three checkpoints are kept on disk).
start = time.time()
best_mae, best_rmse, best_epoch = float("inf"), float("inf"), -1
for epoch in range(NUM_EPOCHS):
    model.train()
    with tqdm(enumerate(train_loader), unit=" batch", total=len(train_loader)) as tepoch:
        for i, batch in tepoch:
            tepoch.set_description(f"Epoch {epoch}")

            input_ids = batch["input_ids"].to(DEVICE)
            attention_mask = batch["attention_mask"].to(DEVICE)
            levels = batch["levels"].to(DEVICE)

            # forward
            optimizer.zero_grad()
            logits, probas = model(input_ids, attention_mask)
            cost = cost_fn(logits, levels, DEVICE)

            # backward
            cost.backward()
            wandb.log({"cost": cost.item(), "epoch": epoch})

            # update (the scheduler advances once per batch)
            optimizer.step()
            scheduler.step()

            # log
            tepoch.set_postfix(cost=round(cost.item(), 3))

    model.eval()
    with torch.no_grad():
        val_mae, val_mse = eval_metrics(model, val_loader, device=DEVICE)
        # Convert to plain floats so logging and comparison never hold on to
        # device tensors.
        val_mae, val_mse = float(val_mae), float(val_mse)
        wandb.log({"validation mae": val_mae, "validation mse": val_mse, "epoch": epoch})

        if val_mae < best_mae:
            best_mae, best_rmse, best_epoch = val_mae, val_mse ** 0.5, epoch
            print(f"best_mae: {val_mae}, best_rmse: {best_rmse}, epoch: {epoch}")

            ##### save model #####
            # Order existing checkpoints oldest-first by modification time.
            # Sorting the names lexicographically would place
            # "model_checkpoint_10.pt" before "model_checkpoint_2.pt" and
            # delete the wrong file.
            ckpt_files = sorted(
                (x for x in os.listdir(SAVE_PATH) if x.endswith(".pt")),
                key=lambda x: osp.getmtime(osp.join(SAVE_PATH, x)),
            )
            if len(ckpt_files) >= 3:
                oldest_ckpt = osp.join(SAVE_PATH, ckpt_files[0])
                print(f"Removing ckpt at {oldest_ckpt}")
                os.remove(oldest_ckpt)

            ckpt_path = osp.join(SAVE_PATH, f"model_checkpoint_{epoch}.pt")
            print(f"Saving ckpt at {ckpt_path}")
            torch.save(model.state_dict(), ckpt_path)

dict_keys(['input_ids', 'token_type_ids', 'attention_mask', 'labels', 'levels'])
dict_keys(['input_ids', 'token_type_ids', 'attention_mask', 'labels', 'levels'])


In [None]:
# Wall-clock time since `start` was recorded at the top of the training cell.
elapsed_seconds = round(time.time() - start, 2)
print(f"Time elapsed: {elapsed_seconds}s")