In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os
import json
import torch
import numpy as np
from tqdm import tqdm
from torch import nn, optim
from torch.utils.data import DataLoader

from load_data import Tokenizer, GenderDataset, gender_data_collate_fn
from models.classifier_lstm import ClassifierLSTM

In [3]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print(device)

cuda


In [4]:
# --- Model architecture (passed positionally to ClassifierLSTM) ---
classifier_embedding_size = 512
classifier_hidden_size = 512
classifier_num_layers = 2
classifier_is_bidirectional = True
# NOTE(review): not passed to ClassifierLSTM in the visible cells -- confirm
# whether the model picks up a dropout value some other way.
classifier_dropout = 0.6

# --- Optimisation ---
classifier_lr = 3e-4
classifier_batch_size = 128
classifier_num_epoch = 100
# NOTE(review): the optimizer cell hard-codes weight_decay=5e-6 and does not
# read this constant -- confirm which value is intended.
classifier_weight_decay = 1e-3
# Gradient-norm clipping threshold; the clipping call in train() is currently
# commented out, so this value is unused in the visible cells.
classifier_max_norm = 2

# --- Logging ---
# train() prints loss/accuracy on every batch whose index is a multiple of this.
print_every = 200

In [5]:
# Load the corpus from ./data/blog.json. The encoding is pinned to UTF-8 so
# decoding does not depend on the platform's locale default.
data_path = os.path.join(os.curdir, "data", "blog.json")
with open(data_path, "r", encoding="utf-8") as file:
    json_data = json.load(file)

# The first entry of json_data['docs'] is deliberately dropped
# (original author's note: "I don't want to see the first document").
docs = json_data['docs'][1:]

In [6]:
# Build the tokenizer from the full `docs` list, i.e. before the
# train/val/test split below.
# NOTE(review): judging by the logged output this also builds the vocabulary,
# so validation/test text would contribute to it -- confirm this is intended.
tokenizer = Tokenizer(docs)

Cutting documents into paragraphs of length 128...


100%|██████████| 19676/19676 [00:52<00:00, 377.61it/s]


Number of documents: 559126
Counting freqeuncies of words...


100%|██████████| 559126/559126 [00:26<00:00, 21398.39it/s]


Number of documents with lengths <= 128: 554016
Number of unique words before converting to <UNK>:  505954
Converting words with frequencies less than 20 to <UNK>...


554016it [00:17, 30898.24it/s]


Number of unique words after converting <UNK>:  40741
Known occurrences rate 98.26%


In [7]:
# Split sizes in whole documents: 70% train, 15% validation, and whatever is
# left (after integer truncation) goes to test.
num_docs = len(docs)
num_train_docs, num_val_docs = int(num_docs * 0.7), int(num_docs * 0.15)
num_test_docs = num_docs - (num_train_docs + num_val_docs)
print(num_train_docs, num_val_docs, num_test_docs)

13773 2951 2952


In [8]:
# Contiguous, unshuffled slices of `docs`: [0, train) / [train, val_end) / rest.
val_end = num_train_docs + num_val_docs
train_docs, val_docs, test_docs = (
    docs[:num_train_docs],
    docs[num_train_docs:val_end],
    docs[val_end:],
)

In [10]:
# One GenderDataset per split, all sharing the same tokenizer. They are built
# in train -> val -> test order.
train_dataset, val_dataset, test_dataset = [
    GenderDataset(split_docs, tokenizer)
    for split_docs in (train_docs, val_docs, test_docs)
]

Cutting documents into paragraphs of length 128...


100%|██████████| 2952/2952 [00:07<00:00, 405.77it/s]


Number of documents: 65739
Counting freqeuncies of words...


65739it [00:02, 25922.30it/s]


Number of documents with lengths <= 128: 65047


In [11]:
def _make_loader(dataset, shuffle):
    """Wrap `dataset` in a DataLoader using the notebook's shared batch settings.

    Args:
        dataset: dataset whose items `gender_data_collate_fn` can batch.
        shuffle: whether to draw a new random order on every pass.

    Returns:
        A DataLoader with batch size `classifier_batch_size`, loading in the
        main process (num_workers=0).
    """
    return DataLoader(
        dataset,
        batch_size=classifier_batch_size,
        shuffle=shuffle,
        num_workers=0,
        collate_fn=gender_data_collate_fn,
    )


# Only the training loader is shuffled. Evaluation loaders iterate in a fixed
# order so validation/test passes are repeatable and draw nothing from the RNG.
train_dataloader = _make_loader(train_dataset, shuffle=True)
val_dataloader = _make_loader(val_dataset, shuffle=False)
test_dataloader = _make_loader(test_dataset, shuffle=False)

In [12]:
# Instantiate the classifier from the config constants, then move it to the
# selected device.
# NOTE(review): classifier_dropout is defined in the config cell but is not
# passed here -- confirm whether ClassifierLSTM should receive it.
classifier_model = ClassifierLSTM(
    tokenizer.vocab_size(),
    classifier_embedding_size,
    classifier_hidden_size,
    classifier_num_layers,
    classifier_is_bidirectional,
)
classifier_model = classifier_model.to(device)



In [13]:
# AdamW over all classifier parameters.
# NOTE(review): weight_decay is hard-coded to 5e-6 here, while the config cell
# defines classifier_weight_decay = 1e-3 -- confirm which value is intended.
optimizer = optim.AdamW(
    classifier_model.parameters(),
    lr=classifier_lr,
    weight_decay=5e-6,
)

In [14]:
# Cross-entropy over the classifier's logits; train() calls it as
# criterion(logits, tgt).
criterion = nn.CrossEntropyLoss()

In [15]:
def train(train_dataloader, val_dataloader, model, criterion, optimizer, num_epoch,
          max_norm=None, save_dir='./save'):
    """Train `model`, then validate and checkpoint it after every epoch.

    Besides its arguments, this function reads the notebook-level names
    `device`, `print_every`, `classifier_hidden_size`, `classifier_batch_size`,
    `classifier_num_layers` and `classifier_is_bidirectional` (the last four
    only to build the checkpoint file name).

    Args:
        train_dataloader: iterable of (src_ids, src_len, tgt) training batches.
        val_dataloader: iterable of the same triples; `len(val_dataloader.dataset)`
            is used as the denominator of the validation accuracy.
        model: module called as `model(src_ids, src_len)`; its output is
            arg-maxed over dim 1 to obtain predictions.
        criterion: loss called as `criterion(logits, tgt)`.
        optimizer: optimizer that updates `model`'s parameters.
        num_epoch: number of passes over `train_dataloader`.
        max_norm: if not None, clip the total gradient norm to this value
            before each optimizer step. The default (None) performs no
            clipping, which matches the previous behaviour.
        save_dir: directory that receives one checkpoint per epoch. It is
            created if it does not exist.

    Returns:
        None. Progress is printed and checkpoints are written to `save_dir`.
    """
    # Create the checkpoint directory up front so a missing folder cannot
    # crash the run after a full epoch of training has already been spent.
    os.makedirs(save_dir, exist_ok=True)

    for epoch in range(num_epoch):
        print(f"Epoch {epoch}, total {len(train_dataloader)} batches\n")
        model.train()

        for batch, (src_ids, src_len, tgt) in enumerate(train_dataloader):
            optimizer.zero_grad()

            src_ids = src_ids.to(device)
            tgt = tgt.to(device)

            logits = model(src_ids, src_len)
            loss = criterion(logits, tgt)
            if batch % print_every == 0:
                batch_acc = (logits.argmax(1) == tgt).sum().item() / tgt.size(0)
                print(f"Epoch Step: {batch} Loss: {loss.item()} Acc: {batch_acc}")

            loss.backward()
            if max_norm is not None:
                nn.utils.clip_grad_norm_(model.parameters(), max_norm)
            optimizer.step()

        print("\nBegin Evaluation")
        model.eval()
        total_acc = 0
        with torch.no_grad():
            for src_ids, src_len, tgt in val_dataloader:
                src_ids = src_ids.to(device)
                tgt = tgt.to(device)
                logits = model(src_ids, src_len)
                total_acc += (logits.argmax(1) == tgt).sum().item()

        acc = total_acc / len(val_dataloader.dataset)

        # Save first, so the "model saved" message below is only printed once
        # the checkpoint is actually on disk.
        checkpoint_name = (
            f'classifier_model_clip_{classifier_hidden_size}_{classifier_batch_size}'
            f'_{classifier_num_layers}_{classifier_is_bidirectional}_epoch_{epoch}.file'
        )
        torch.save(model.state_dict(), os.path.join(save_dir, checkpoint_name))
        print(f"Validation Accuracy: {acc}, model saved\n")


In [16]:
# Run the full training schedule (classifier_num_epoch epochs). train() runs a
# validation pass and writes a checkpoint under ./save after every epoch.
train(train_dataloader, val_dataloader, classifier_model, criterion, optimizer, classifier_num_epoch)

Epoch 0, total 3144 batches

Epoch Step: 0 Loss: 0.6946387887001038 Acc: 0.515625
Epoch Step: 200 Loss: 0.6872751712799072 Acc: 0.578125
Epoch Step: 400 Loss: 0.6896743178367615 Acc: 0.5390625
Epoch Step: 600 Loss: 0.6924378871917725 Acc: 0.53125
Epoch Step: 800 Loss: 0.7062233686447144 Acc: 0.4609375
Epoch Step: 1000 Loss: 0.6824397444725037 Acc: 0.6015625
Epoch Step: 1200 Loss: 0.6917186975479126 Acc: 0.5390625
Epoch Step: 1400 Loss: 0.6932624578475952 Acc: 0.5234375
Epoch Step: 1600 Loss: 0.7041881084442139 Acc: 0.4453125
Epoch Step: 1800 Loss: 0.6872370839118958 Acc: 0.5546875
Epoch Step: 2000 Loss: 0.6757984161376953 Acc: 0.6328125
Epoch Step: 2200 Loss: 0.6879671216011047 Acc: 0.546875
Epoch Step: 2400 Loss: 0.6847577095031738 Acc: 0.578125
Epoch Step: 2600 Loss: 0.6857582330703735 Acc: 0.5546875
Epoch Step: 2800 Loss: 0.6962257623672485 Acc: 0.53125
Epoch Step: 3000 Loss: 0.6852021217346191 Acc: 0.5625

Begin Evaluation
Validation Accuracy: 0.5352959505152736, model saved

Epoch

In [28]:
# Per-batch accuracy of the current model on the training set.
# Switch to eval mode first: if training was interrupted mid-epoch the model is
# still in train mode, and train-mode layers such as dropout would distort the
# measured accuracy.
classifier_model.eval()
with torch.no_grad():
    for batch, (src_ids, src_len, tgt) in enumerate(train_dataloader):
        src_ids = src_ids.to(device)
        tgt = tgt.to(device)
        logits = classifier_model(src_ids, src_len)
        print(f"Epoch Step: {batch} Acc: {(logits.argmax(1) == tgt).sum().item() / tgt.size(0)}")

Epoch Step: 0 Acc: 0.953125
Epoch Step: 1 Acc: 0.9609375
Epoch Step: 2 Acc: 0.9921875
Epoch Step: 3 Acc: 0.984375
Epoch Step: 4 Acc: 0.9453125
Epoch Step: 5 Acc: 0.9609375
Epoch Step: 6 Acc: 0.96875
Epoch Step: 7 Acc: 0.90625
Epoch Step: 8 Acc: 0.9609375
Epoch Step: 9 Acc: 0.9609375
Epoch Step: 10 Acc: 0.9609375
Epoch Step: 11 Acc: 0.953125
Epoch Step: 12 Acc: 0.96875
Epoch Step: 13 Acc: 0.96875
Epoch Step: 14 Acc: 0.96875
Epoch Step: 15 Acc: 1.0
Epoch Step: 16 Acc: 0.96875
Epoch Step: 17 Acc: 0.9921875
Epoch Step: 18 Acc: 0.96875
Epoch Step: 19 Acc: 0.9453125
Epoch Step: 20 Acc: 0.9765625
Epoch Step: 21 Acc: 0.9765625
Epoch Step: 22 Acc: 0.96875
Epoch Step: 23 Acc: 0.9609375
Epoch Step: 24 Acc: 0.96875
Epoch Step: 25 Acc: 0.96875
Epoch Step: 26 Acc: 0.9765625
Epoch Step: 27 Acc: 0.984375
Epoch Step: 28 Acc: 0.9609375
Epoch Step: 29 Acc: 0.953125
Epoch Step: 30 Acc: 0.953125
Epoch Step: 31 Acc: 0.953125
Epoch Step: 32 Acc: 0.9609375
Epoch Step: 33 Acc: 0.9609375
Epoch Step: 34 Acc: 0.9

KeyboardInterrupt: 