In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import json
import torch
import numpy as np
from tqdm import tqdm
from torch import nn, optim
from torch.utils.data import DataLoader

from load_data import Tokenizer, GenderDataset, gender_data_collate_fn
from models.classifier_lstm import ClassifierLSTM

In [3]:
device = torch.device('cpu')
if torch.cuda.is_available():
    device = torch.device('cuda')

print(device)

cuda


In [4]:
classifier_embedding_size = 512
classifier_hidden_size = 512
classifier_num_layers = 2
classifier_is_bidirectional = True

classifier_lr = 3e-4
classifier_num_epoch = 100
classifier_batch_size = 128
classifier_max_norm = 2

print_every = 200

In [5]:
with open(os.path.join(os.curdir, "data", "blog.json"), "r") as file:
    json_data = json.load(file)
docs = json_data['docs'][1:] # I don't want to see the first document

In [6]:
tokenizer = Tokenizer(docs)

Cutting documents into paragraphs of length 128...


100%|██████████| 19676/19676 [00:34<00:00, 576.52it/s]


Number of documents: 559126
Counting freqeuncies of words...


100%|██████████| 559126/559126 [00:16<00:00, 32907.19it/s]


Number of documents with lengths <= 128: 554016
Number of unique words before converting to <UNK>:  505954
Converting words with frequencies less than 10 to <UNK>...


554016it [00:12, 45847.09it/s]


Number of unique words after converting <UNK>:  59178
Known occurrences rate 98.69%


In [7]:
num_docs = len(docs)
num_train_docs = int(num_docs * 0.7)
num_val_docs = int(num_docs * 0.15)
num_test_docs = num_docs - num_train_docs - num_val_docs
print(num_train_docs, num_val_docs, num_test_docs)

13773 2951 2952


In [8]:
train_docs = docs[:num_train_docs]
val_docs = docs[num_train_docs:num_train_docs+num_val_docs]
test_docs = docs[num_train_docs+num_val_docs:]

In [9]:
train_dataset = GenderDataset(train_docs, tokenizer)
val_dataset = GenderDataset(val_docs, tokenizer)
test_dataset = GenderDataset(test_docs, tokenizer)

Cutting documents into paragraphs of length 128...


100%|██████████| 13773/13773 [00:24<00:00, 563.29it/s]


Number of documents: 405965
Counting freqeuncies of words...


100%|██████████| 405965/405965 [00:11<00:00, 35529.94it/s]


Number of documents with lengths <= 128: 402316
Cutting documents into paragraphs of length 128...


100%|██████████| 2951/2951 [00:05<00:00, 566.19it/s]


Number of documents: 87422
Counting freqeuncies of words...


100%|██████████| 87422/87422 [00:02<00:00, 34650.03it/s]


Number of documents with lengths <= 128: 86653
Cutting documents into paragraphs of length 128...


100%|██████████| 2952/2952 [00:03<00:00, 764.37it/s]


Number of documents: 65739
Counting freqeuncies of words...


100%|██████████| 65739/65739 [00:02<00:00, 32005.36it/s]


Number of documents with lengths <= 128: 65047


In [10]:
print(train_dataset[1], val_dataset[1], test_dataset[1])

([30275, 52036, 14412, 2662, 49226, 10062, 11762, 20664, 6772, 52364, 48343, 10062, 49309, 33403, 38804, 52036, 53265, 10062, 24796, 49226, 37486, 38804, 37486, 56870, 7998, 12358, 37486, 21323, 30832, 24785, 37486, 1530, 30556, 38804, 20065, 38804, 28927, 52036, 4514, 56936, 42360, 14412, 42607, 42719, 10062, 50735, 49226, 10062, 56936, 51328, 38804, 52036, 56632, 24820, 56802, 14412, 10062, 33186, 19301, 42719, 1078, 16301, 38804, 24785, 18994, 14412, 11371, 1530, 27207, 1530, 466, 44896, 52036, 1596, 24785, 19945, 24820, 28853, 466, 49411, 38804, 48343, 51790, 21323, 10062, 25057, 32466, 48343, 51170, 33597, 34179, 45849, 57746, 52747, 57074, 34356, 24785, 49144, 38804, 7362], 0) ([30275, 10349, 41147, 46695, 5492, 9836, 41756, 44196, 21323, 1280, 44120, 47077, 20325, 4688, 43985, 42719, 57023, 34415, 56452, 21323, 24785, 10349, 57673, 57074, 19169, 23810, 52036, 42719, 41869, 16639, 38804, 55987, 10349, 39996, 9343, 48353, 48343, 10062, 46818, 51617, 49116, 50723, 55151, 1280, 2587

In [11]:
train_dataloader = DataLoader(
    train_dataset,
    batch_size=classifier_batch_size,
    shuffle=True,
    num_workers=0,
    collate_fn=gender_data_collate_fn
)
val_dataloader = DataLoader(
    val_dataset,
    batch_size=classifier_batch_size,
    shuffle=True,
    num_workers=0,
    collate_fn=gender_data_collate_fn
)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=classifier_batch_size,
    shuffle=True,
    num_workers=0,
    collate_fn=gender_data_collate_fn
)

In [52]:
classifier_model = ClassifierLSTM(
    tokenizer.vocab_size(), 
    classifier_embedding_size, 
    classifier_hidden_size, 
    classifier_num_layers, 
    classifier_is_bidirectional
).to(device)

In [53]:
optimizer = optim.Adam(classifier_model.parameters(), lr = classifier_lr)

In [54]:
criterion = nn.CrossEntropyLoss()

In [55]:
def train(train_dataloader, val_dataloader, model, criterion, optimizer, num_epoch):

    for epoch in range(num_epoch):
        print(f"Epoch {epoch}, total {len(train_dataloader)} batches\n")
        model.train()

        for batch, (src_ids, src_len, tgt) in enumerate(train_dataloader):
            optimizer.zero_grad()

            src_ids = src_ids.to(device)
            tgt = tgt.to(device)

            logits = model(src_ids, src_len)
            loss = criterion(logits, tgt)
            if batch % print_every == 0:
                print(f"Epoch Step: {batch} Loss: {loss} Acc: {(logits.argmax(1) == tgt).sum().item() / tgt.size(0)}")

            loss.backward()
            # nn.utils.clip_grad_norm_(model.parameters(), classifier_max_norm)
            optimizer.step()

        print("\nBegin Evaluation")
        model.eval()
        total_acc = 0
        with torch.no_grad():
            for batch, (src_ids, src_len, tgt) in enumerate(val_dataloader):
                src_ids = src_ids.to(device)
                tgt = tgt.to(device)
                logits = model(src_ids, src_len)
                total_acc += (logits.argmax(1) == tgt).sum().item()
        
        acc = total_acc / len(val_dataloader.dataset)
        print(f"Validation Accuracy: {acc}, model saved\n")

        torch.save(model.state_dict(), f'./save/classifier_model_{classifier_hidden_size}_{classifier_batch_size}_{classifier_num_layers}_{classifier_is_bidirectional}_epoch_{epoch}.file')


In [57]:
train(train_dataloader, val_dataloader, classifier_model, criterion, optimizer, classifier_num_epoch)

Epoch 0, total 3144 batches

Epoch Step: 0 Loss: 0.6901217699050903 Acc: 0.515625
Epoch Step: 200 Loss: 0.6968944668769836 Acc: 0.515625
Epoch Step: 400 Loss: 0.6965958476066589 Acc: 0.515625
Epoch Step: 600 Loss: 0.691948413848877 Acc: 0.5390625
Epoch Step: 800 Loss: 0.6927976012229919 Acc: 0.5234375
Epoch Step: 1000 Loss: 0.7107475399971008 Acc: 0.4375
Epoch Step: 1200 Loss: 0.6919336318969727 Acc: 0.53125
Epoch Step: 1400 Loss: 0.6899346113204956 Acc: 0.53125
Epoch Step: 1600 Loss: 0.68748539686203 Acc: 0.546875
Epoch Step: 1800 Loss: 0.6952205896377563 Acc: 0.515625
Epoch Step: 2000 Loss: 0.6928135752677917 Acc: 0.53125
Epoch Step: 2200 Loss: 0.6911121606826782 Acc: 0.546875
Epoch Step: 2400 Loss: 0.6961835622787476 Acc: 0.4921875
Epoch Step: 2600 Loss: 0.6776772737503052 Acc: 0.578125
Epoch Step: 2800 Loss: 0.6839342713356018 Acc: 0.5703125
Epoch Step: 3000 Loss: 0.6882370114326477 Acc: 0.53125

Begin Evaluation
Validation Accuracy: 0.5340726806919552, model saved

Epoch 1, total 

KeyboardInterrupt: 

In [12]:
trained_model = ClassifierLSTM(
    tokenizer.vocab_size(), 
    classifier_embedding_size, 
    classifier_hidden_size, 
    classifier_num_layers, 
    classifier_is_bidirectional
).to(device)
trained_model.load_state_dict(torch.load("./save/classifier_model_512_128_2_True_epoch_7.file"))

<All keys matched successfully>

In [13]:
for batch, (src_ids, src_len, tgt) in enumerate(test_dataloader):
    src_ids = src_ids.to(device)
    tgt = tgt.to(device)
    logits = trained_model(src_ids, src_len)
    print(f"Epoch Step: {batch} Acc: {(logits.argmax(1) == tgt).sum().item() / tgt.size(0)}")

Epoch Step: 0 Acc: 0.5625
Epoch Step: 1 Acc: 0.453125
Epoch Step: 2 Acc: 0.4453125
Epoch Step: 3 Acc: 0.5078125
Epoch Step: 4 Acc: 0.5
Epoch Step: 5 Acc: 0.4765625
Epoch Step: 6 Acc: 0.484375
Epoch Step: 7 Acc: 0.453125
Epoch Step: 8 Acc: 0.4765625
Epoch Step: 9 Acc: 0.4453125
Epoch Step: 10 Acc: 0.40625
Epoch Step: 11 Acc: 0.4296875
Epoch Step: 12 Acc: 0.421875
Epoch Step: 13 Acc: 0.4375
Epoch Step: 14 Acc: 0.5078125
Epoch Step: 15 Acc: 0.4765625
Epoch Step: 16 Acc: 0.4921875
Epoch Step: 17 Acc: 0.53125
Epoch Step: 18 Acc: 0.4453125
Epoch Step: 19 Acc: 0.5390625
Epoch Step: 20 Acc: 0.46875
Epoch Step: 21 Acc: 0.5234375
Epoch Step: 22 Acc: 0.4765625
Epoch Step: 23 Acc: 0.5234375
Epoch Step: 24 Acc: 0.5859375
Epoch Step: 25 Acc: 0.4765625
Epoch Step: 26 Acc: 0.4921875
Epoch Step: 27 Acc: 0.453125
Epoch Step: 28 Acc: 0.5
Epoch Step: 29 Acc: 0.4140625
Epoch Step: 30 Acc: 0.40625
Epoch Step: 31 Acc: 0.5078125
Epoch Step: 32 Acc: 0.453125
Epoch Step: 33 Acc: 0.4609375
Epoch Step: 34 Acc: 0.

In [14]:
for batch, (src_ids, src_len, tgt) in enumerate(train_dataloader):
    src_ids = src_ids.to(device)
    tgt = tgt.to(device)
    logits = trained_model(src_ids, src_len)
    print(f"Epoch Step: {batch} Acc: {(logits.argmax(1) == tgt).sum().item() / tgt.size(0)}")

Epoch Step: 0 Acc: 0.4609375
Epoch Step: 1 Acc: 0.453125
Epoch Step: 2 Acc: 0.4609375
Epoch Step: 3 Acc: 0.484375
Epoch Step: 4 Acc: 0.5703125
Epoch Step: 5 Acc: 0.46875
Epoch Step: 6 Acc: 0.5078125
Epoch Step: 7 Acc: 0.46875
Epoch Step: 8 Acc: 0.453125
Epoch Step: 9 Acc: 0.390625
Epoch Step: 10 Acc: 0.4296875
Epoch Step: 11 Acc: 0.484375
Epoch Step: 12 Acc: 0.453125
Epoch Step: 13 Acc: 0.5
Epoch Step: 14 Acc: 0.46875
Epoch Step: 15 Acc: 0.5625
Epoch Step: 16 Acc: 0.46875
Epoch Step: 17 Acc: 0.46875
Epoch Step: 18 Acc: 0.5
Epoch Step: 19 Acc: 0.484375
Epoch Step: 20 Acc: 0.4609375
Epoch Step: 21 Acc: 0.5703125
Epoch Step: 22 Acc: 0.453125
Epoch Step: 23 Acc: 0.4453125
Epoch Step: 24 Acc: 0.4921875
Epoch Step: 25 Acc: 0.5078125
Epoch Step: 26 Acc: 0.4140625
Epoch Step: 27 Acc: 0.4609375
Epoch Step: 28 Acc: 0.5078125
Epoch Step: 29 Acc: 0.5390625
Epoch Step: 30 Acc: 0.453125
Epoch Step: 31 Acc: 0.46875
Epoch Step: 32 Acc: 0.5078125
Epoch Step: 33 Acc: 0.5390625
Epoch Step: 34 Acc: 0.4375

KeyboardInterrupt: 