In [12]:
import torch
from torchtext.datasets import IMDB, AG_NEWS, YahooAnswers
from torchtext.vocab import GloVe
from torchtext.data import to_map_style_dataset
from torchtext.data.utils import get_tokenizer
from torch.nn import LSTM, GRU, Linear, Softmax, CrossEntropyLoss
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, random_split, Dataset
from torch.optim import Adam
from tqdm import tqdm

In [2]:
DATASET = 'AG_NEWS'
MODEL = 'LSTM'
VALIDATION_SPLIT = 0.5 # of test data
BATCH_SIZE = 64
SHUFFLE = True
NUM_EPOCHS = 10

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
from google.colab import drive
drive.mount('/content/drive')
PATH = '/content/drive/MyDrive/Checkpoints/model'
!mkdir '/content/drive/MyDrive/Checkpoints/model'

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
mkdir: cannot create directory ‘/content/drive/MyDrive/Checkpoints/model’: File exists


In [4]:
class BidirectionalLSTMClassifier(torch.nn.Module):
    def __init__(self, num_classes, hidden_size, num_layers):
        super().__init__()
        self.num_layers = num_layers
        self.LSTM = LSTM(50, hidden_size, num_layers=num_layers, batch_first=True, bidirectional=True)
        self.linear = Linear(2 * hidden_size, num_classes)
        self.softmax = Softmax(dim=1)
        
    def forward(self, x):
        _, (h_n, _) = self.LSTM(x)
        h_forward = h_n[2 * self.num_layers - 2]
        h_backward = h_n[2 * self.num_layers - 1]
        y = self.linear(torch.cat((h_forward, h_backward), 1))
        return self.softmax(y)
    
    
class BidirectionalGRUClassifier(torch.nn.Module):
    def __init__(self, num_classes, hidden_size, num_layers):
        super().__init__()
        self.num_layers = num_layers
        self.GRU = GRU(50, hidden_size, num_layers=num_layers, batch_first=True, bidirectional=True)
        self.linear = Linear(2 * hidden_size, num_classes)
        self.softmax = Softmax(dim=1)
        
    def forward(self, x):
        _, h_n = self.GRU(x)
        h_forward = h_n[2 * self.num_layers - 2]
        h_backward = h_n[2 * self.num_layers - 1]
        y = self.linear(torch.cat((h_forward, h_backward), 1))
        return self.softmax(y)

In [5]:
class ClassificationDataset(Dataset):
    def __init__(self, dataset, num_classes, tokenizer):
        self.num_classes = num_classes
        self.dataset = dataset
        self.tokenizer = tokenizer
    
    def __len__(self):
        return self.dataset.__len__()
    
    def __getitem__(self, idx):
        label, text = self.dataset.__getitem__(idx)
        return int(label) - 1, self.tokenizer(text)

In [6]:
if DATASET == 'IMDB':
    train_set = IMDB(split='train')
    test_set = IMDB(split='test')
    num_classes = 2
elif DATASET == 'AG_NEWS':
    train_set = AG_NEWS(split='train')
    test_set = AG_NEWS(split='test')
    num_classes = 4
elif DATASET == 'YahooAnswers':
    train_set = YahooAnswers(split='train')
    test_set = YahooAnswers(split='test')
    num_classes = 10
else:
    raise ValueError()

tokenizer = get_tokenizer('basic_english')
embedding = GloVe(name='6B', dim=50)

train_set = to_map_style_dataset(train_set)
test_set = to_map_style_dataset(test_set)

train_set = ClassificationDataset(train_set, num_classes, tokenizer)
test_set = ClassificationDataset(test_set, num_classes, tokenizer)
test_set, val_set = random_split(test_set, [test_set.__len__() - int(VALIDATION_SPLIT * test_set.__len__()), int(VALIDATION_SPLIT * test_set.__len__())], generator=torch.Generator().manual_seed(42))

In [7]:
def collate_batch(batch):
    label_list, text_list = [], []
    for (_label, _tokens) in batch:
        label_list.append(_label)
        embed = embedding.get_vecs_by_tokens(_tokens)
        text_list.append(embed)
    label_list = torch.tensor(label_list, dtype=torch.int64)
    text_list = pad_sequence(text_list, batch_first=True)
    return label_list.to(device), text_list.to(device)

train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, collate_fn=collate_batch, shuffle=SHUFFLE)
test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, collate_fn=collate_batch, shuffle=SHUFFLE)
val_loader = DataLoader(val_set, batch_size=BATCH_SIZE, collate_fn=collate_batch, shuffle=SHUFFLE)

In [8]:
def evaluate(model, data_loader, loss=CrossEntropyLoss()):
    model.eval()
    total_acc, total_count = 0, 0

    with torch.no_grad():
        for idx, (labels, text) in enumerate(data_loader):
            predicted_label = model(text)
            loss_ = loss(predicted_label, labels)
            total_acc += (predicted_label.argmax(1) == labels).sum().item()
            total_count += labels.size(0)
    return total_acc / total_count


def train(model, optimizer, train_loader, loss=CrossEntropyLoss(), log_interval=50):
    model.train()
    total_acc, total_count = 0, 0
    pbar = tqdm(total=len(train_loader), desc=f'Epoch [{epoch + 1}/{NUM_EPOCHS}]')
    for idx, (labels, text) in enumerate(train_loader):
        output = model(text)
        loss_ = loss(output, labels)
        optimizer.zero_grad()
        loss_.backward()
        optimizer.step()
        total_acc += (output.argmax(1) == labels).sum().item()
        total_count += labels.size(0)
        pbar.update()
        if idx % log_interval == 0 and idx > 0:
            pbar.set_postfix(loss=loss_, accuracy=total_acc / total_count)
            total_acc, total_count = 0, 0
    
    pbar.close()

In [10]:
model = BidirectionalGRUClassifier(num_classes, 64, 1).to(device)
optim = Adam(model.parameters())

for epoch in range(NUM_EPOCHS):
    train(model, optim, train_loader)
    val_accuracy = evaluate(model, val_loader)

    torch.save({
        'epoch' : epoch,
        'model_state_dict' : model.state_dict(),
        'optimizer_state_dict': optim.state_dict(),
        'val_accuracy' : val_accuracy
    }, PATH + '_' + str(epoch) + '.pt' )

    #How to load a model
    #checkpoint = torch.load(PATH)
    #model.load_state_dict(checkpoint['model_state_dict'])
    #optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    #epoch = checkpoint['epoch']
    #val_accuracy = checkpoint['val_accuracy']

Epoch [0/10]: 100%|██████████| 1875/1875 [01:09<00:00, 27.15it/s, accuracy=0.906]
Epoch [1/10]: 100%|██████████| 1875/1875 [01:09<00:00, 27.14it/s, accuracy=0.893]
Epoch [2/10]: 100%|██████████| 1875/1875 [01:09<00:00, 27.06it/s, accuracy=0.91]
Epoch [3/10]: 100%|██████████| 1875/1875 [01:09<00:00, 27.04it/s, accuracy=0.911]
Epoch [4/10]: 100%|██████████| 1875/1875 [01:09<00:00, 27.12it/s, accuracy=0.918]
Epoch [5/10]: 100%|██████████| 1875/1875 [01:09<00:00, 27.10it/s, accuracy=0.911]
Epoch [6/10]: 100%|██████████| 1875/1875 [01:09<00:00, 26.96it/s, accuracy=0.917]
Epoch [7/10]: 100%|██████████| 1875/1875 [01:08<00:00, 27.23it/s, accuracy=0.919]
Epoch [8/10]: 100%|██████████| 1875/1875 [01:09<00:00, 27.14it/s, accuracy=0.93]
Epoch [9/10]: 100%|██████████| 1875/1875 [01:08<00:00, 27.18it/s, accuracy=0.927]


In [11]:
test_accuracy = evaluate(model, test_loader)
print(f'Test accuracy: {test_accuracy}')

Test accuracy: 0.9142105263157895
