In [1]:
import numpy as np
import time
from torchmetrics import AUROC

import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data.dataset import random_split
from torchtext.datasets import AG_NEWS
from torchtext.data.utils import get_tokenizer
from torchtext.data.functional import to_map_style_dataset
from torchtext.vocab import build_vocab_from_iterator

In [2]:
def yield_tokens(data_iter):
    for _, text in data_iter:
        yield tokenizer(text)


def collate_batch(batch, device):
    label_list, text_list, offsets = [], [], [0]
    for (_label, _text) in batch:
        label_list.append(label_pipeline(_label))
        processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)
        text_list.append(processed_text)
        offsets.append(processed_text.shape[0])
    label_list = torch.tensor(label_list, dtype=torch.int64)
    # Call cumsum() to get absolute offset (instead of relative offset)
    offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)
    text_list = torch.cat(text_list)
    return label_list.to(device), text_list.to(device), offsets.to(device)


def train(dataloader, model, optimizers, criterion, auroc):
    model.train()
    total_cor, total_count = 0.0, 0.0
    log_interval = 500
    start_time = time.time()
    decay_acc = 0.0
    decay_auc = 0.0
    
    for idx, (labels, text, offsets) in enumerate(dataloader):
        for optimizer in optimizers:
            optimizer.zero_grad()
        pred_labels = model(text, offsets)
        loss = criterion(pred_labels, labels)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        for optimizer in optimizers:
            optimizer.step()
        total_cor += (pred_labels.argmax(1) == labels).sum().item()
        total_count += labels.size(0)
        acc = total_cor / total_count
        auc = auroc(F.softmax(pred_labels, dim=1), labels).to("cpu").item()
        if idx % log_interval == 0 and idx > 0:
            elapsed = time.time() - start_time
            print((
                f"| epoch {epoch} | {idx} / {len(dataloader)} batches | train acc {np.round(acc, 3)} "
                f"| train auc {np.round(auc, 3)} |"
            ))
        # Is this the first batch?
        if decay_acc > 0.0:
            decay_acc = 0.9 * decay_acc + 0.1 * acc
            decay_auc = 0.9 * decay_auc + 0.1 * auc
        else:
            decay_acc = acc
            decay_auc = auc
        acc = 0.0
        total_cor, total_count = 0.0, 0.0
        start_time = time.time()
    return decay_acc, decay_auc
        

def evaluate(dataloader, model, criterion, auroc):
    model.eval()
    total_acc, total_count = 0.0, 0.0
    aucs = []
    losses = []
    
    with torch.no_grad():
        for idx, (labels, text, offsets) in enumerate(dataloader):
            pred_labels = model(text, offsets)
            loss = criterion(pred_labels, labels)
            total_acc += (pred_labels.argmax(1) == labels).sum().item()
            total_count += labels.size(0)
            auc = auroc(F.softmax(pred_labels, dim=1), labels).to("cpu").item()
            aucs.append(auc)
            losses.append(loss.to("cpu").item())
            
        return total_acc / total_count, np.mean(aucs), loss

In [3]:
class TextClassifier(nn.Module):
    
    def __init__(self, vocab_size, embed_dim, hidden_dim, num_class):
        super().__init__()
        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=False)
        self.fc1 = nn.Linear(embed_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, num_class)
        self.init_weights()
        
    def init_weights(self):
        # Use Leaky ReLU since behavior is similar to ELU 
        nn.init.kaiming_uniform_(self.fc1.weight, nonlinearity="leaky_relu")
        nn.init.zeros_(self.fc1.bias)
        nn.init.xavier_uniform_(self.fc2.weight)
        nn.init.zeros_(self.fc2.bias)
        
    def forward(self, text, offsets):
        embedded = self.embedding.forward(text, offsets)
        x = self.fc1(embedded)
        x = F.elu(x)
        x = self.fc2(x)
        # x = F.softmax(x)
        return x

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
chkpt_path = "../checkpoints/news_class.pth"

In [5]:
train_iter = AG_NEWS(split="train")
tokenizer = get_tokenizer("basic_english")
vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=["<unk>"])
vocab.set_default_index(vocab["<unk>"])

text_pipeline = lambda x: vocab(tokenizer(x))
label_pipeline = lambda x: int(x) - 1

In [6]:
num_class = len(set([label for label, text in train_iter]))
vocab_size = len(vocab)
emsize = 64
hdsize = emsize
model = TextClassifier(vocab_size, emsize, hdsize, num_class).to(device)

In [7]:
model

TextClassifier(
  (embedding): EmbeddingBag(95811, 64, mode='mean')
  (fc1): Linear(in_features=64, out_features=64, bias=True)
  (fc2): Linear(in_features=64, out_features=4, bias=True)
)

In [8]:
for i, (_, txt) in enumerate(train_iter):
    print(f">>>\t{txt}")
    print(f">>>\t{tokenizer(txt)}")
    print("")
    if i == 0:
        break

>>>	Wall St. Bears Claw Back Into the Black (Reuters) Reuters - Short-sellers, Wall Street's dwindling\band of ultra-cynics, are seeing green again.
>>>	['wall', 'st', '.', 'bears', 'claw', 'back', 'into', 'the', 'black', '(', 'reuters', ')', 'reuters', '-', 'short-sellers', ',', 'wall', 'street', "'", 's', 'dwindling\\band', 'of', 'ultra-cynics', ',', 'are', 'seeing', 'green', 'again', '.']



In [9]:
print(
    vocab(["my", "dog", "'", "s", "name", "is", "fido"])
)

print(
    text_pipeline("My dog's name is Fido")
)

print(
    label_pipeline("9")
)

[1300, 5383, 16, 9, 951, 21, 45387]
[1300, 5383, 16, 9, 951, 21, 45387]
8


In [10]:
epochs = 15
lr = 0.1
batch_size = 64

In [11]:
auroc = AUROC(task="multiclass", num_classes=num_class)

In [12]:
criterion = torch.nn.CrossEntropyLoss()
# optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
optimizer1 = torch.optim.AdamW(
    nn.ParameterList(
        list(model.fc1.parameters()) +\
        list(model.fc2.parameters())
    ),
    lr=lr,
    amsgrad=True
)
optimizer2 = torch.optim.Adamax(model.embedding.parameters(), lr=lr)
scheduler = None
# scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1.0, gamma=0.1)

In [13]:
train_iter, test_iter = AG_NEWS()
train_dataset = to_map_style_dataset(train_iter)
test_dataset = to_map_style_dataset(test_iter)
num_train = int(len(train_dataset) * 0.95)
split_train_, split_valid_ = random_split(train_dataset, [num_train, len(train_dataset) - num_train])

train_dataloader = DataLoader(split_train_, batch_size=batch_size, shuffle=True, collate_fn=lambda x: collate_batch(x, device))
valid_dataloader = DataLoader(split_valid_, batch_size=batch_size, shuffle=True, collate_fn=lambda x: collate_batch(x, device))
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True, collate_fn=lambda x: collate_batch(x, device))

In [14]:
best_acc = 0.0
for epoch in range(1, epochs + 1):
    epoch_start_time = time.time()
    acc_train, auc_train = train(train_dataloader, model, (optimizer1, optimizer2), criterion, auroc)
    acc_val, auc_val, loss_val = evaluate(valid_dataloader, model, criterion, auroc)
    if acc_val > best_acc:
        print("Best performance so far. Saving model checkpoint.")
        best_acc = acc_val
        checkpoint = {
            "model_state_dict": model.state_dict(),
            "optimizer1_state_dict": optimizer1.state_dict(),
            "optimizer2_state_dict": optimizer2.state_dict(),
            "criterion": criterion
        }
        torch.save(checkpoint, chkpt_path)
        
    print("-" * 59)
    print((
        f"| end of epoch {epoch} | "
        f"time: {np.round(time.time() - epoch_start_time, 3)} | "
        f"valid acc {np.round(acc_val, 3)} | "
        f"valid avg. auc {np.round(auc_val, 3)} | "
        f"decay train acc {np.round(acc_train, 3)} | "
        f"decay train auc {np.round(auc_train, 3)}"
    ))
    print("-" * 59)

| epoch 1 | 500 / 1782 batches | train acc 0.781 | train auc 0.961 |
| epoch 1 | 1000 / 1782 batches | train acc 0.766 | train auc 0.958 |
| epoch 1 | 1500 / 1782 batches | train acc 0.844 | train auc 0.98 |
Best performance so far. Saving model checkpoint.
-----------------------------------------------------------
| end of epoch 1 | time: 26.103 | valid acc 0.871 | valid avg. auc 0.977 | decay train acc 0.887 | decay train auc 0.982
-----------------------------------------------------------
| epoch 2 | 500 / 1782 batches | train acc 0.906 | train auc 0.987 |
| epoch 2 | 1000 / 1782 batches | train acc 0.859 | train auc 0.983 |
| epoch 2 | 1500 / 1782 batches | train acc 0.906 | train auc 0.983 |
-----------------------------------------------------------
| end of epoch 2 | time: 26.248 | valid acc 0.865 | valid avg. auc 0.979 | decay train acc 0.885 | decay train auc 0.986
-----------------------------------------------------------
| epoch 3 | 500 / 1782 batches | train acc 0.891 | 

In [15]:
checkpoint = torch.load(chkpt_path)
model.load_state_dict(checkpoint["model_state_dict"])
optimizer1.load_state_dict(checkpoint["optimizer1_state_dict"])
optimizer2.load_state_dict(checkpoint["optimizer2_state_dict"])
criterion = checkpoint["criterion"]

In [16]:
print("Test data performance")
acc_test, auc_test, _ = evaluate(test_dataloader, model, criterion, auroc)
print((
        f"time: {np.round(time.time() - epoch_start_time, 3)} | "
        f"test acc {np.round(acc_test, 3)} | "
        f"test avg. auc {np.round(auc_test, 3)} | "
    ))

Test data performance
time: 28.778 | test acc 0.9 | test avg. auc 0.98 | 
