In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import math
import json
import torch
import pickle
import random
from tqdm import tqdm
from torch import nn, optim
from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer, util
from transformers import GPT2LMHeadModel

import load_data
from load_data import GenderDataset, gender_data_collate_fn
from models.classifier_bert import ClassifierBERT

In [3]:
device = torch.device('cpu')
if torch.cuda.is_available():
    device = torch.device('cuda:1')

print(device)

cuda:1


In [4]:
# Hyper parameters
num_epoch = 100
batch_size = 192
lr = 3e-4
wd = 5e-6
print_every = 300

max_norm = 2.0

In [5]:
with open(os.path.join(os.curdir, "data", "blog.json"), "r") as file:
    json_data = json.load(file)
docs = json_data['docs'][1:] # I don't want to see the first document

In [6]:
num_docs = len(docs)
num_train_docs = int(num_docs * 0.7)
num_val_docs = int(num_docs * 0.15)
num_test_docs = num_docs - num_train_docs - num_val_docs
print(num_train_docs, num_val_docs, num_test_docs)

13773 2951 2952


In [7]:
train_docs = docs[:num_train_docs]
val_docs = docs[num_train_docs:num_train_docs+num_val_docs]
test_docs = docs[num_train_docs+num_val_docs:]

In [8]:
train_dataset, val_dataset, test_dataset = None, None, None
load_from_pickled = False

if os.path.exists(os.path.join(os.curdir, "data", "train.pickle")):
    load_from_pickled = True
    with open(os.path.join(os.curdir, "data", "train.pickle"), "rb") as f:
        train_dataset = pickle.load(f)
    with open(os.path.join(os.curdir, "data", "val.pickle"), "rb") as f:
        val_dataset = pickle.load(f)
    with open(os.path.join(os.curdir, "data", "test.pickle"), "rb") as f:
        test_dataset = pickle.load(f)
else:
    train_dataset = GenderDataset(train_docs)
    val_dataset = GenderDataset(val_docs)
    test_dataset = GenderDataset(test_docs)

if not load_from_pickled:
    with open(os.path.join(os.curdir, "data", "train.pickle"), "wb") as f:
        pickle.dump(train_dataset, f)
    with open(os.path.join(os.curdir, "data", "val.pickle"), "wb") as f:
        pickle.dump(val_dataset, f)
    with open(os.path.join(os.curdir, "data", "test.pickle"), "wb") as f:
        pickle.dump(test_dataset, f)

print(load_from_pickled)   

True


In [9]:
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=8,
    collate_fn=gender_data_collate_fn
)
val_dataloader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=8,
    collate_fn=gender_data_collate_fn
)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=8,
    collate_fn=gender_data_collate_fn
)

In [10]:
next(enumerate(test_dataloader))

(0,
 (tensor([[ 1049,  2035,  2468,  ...,  1015,  4016,  4506],
          [ 2133,  2061,  2028,  ...,  2070,  2013,  1016],
          [ 9008,  1014,  2058,  ...,  2106,  2432,  1004],
          ...,
          [ 2066,  1016,  1016,  ..., 15669,  1014, 11785],
          [19786,  2079,  1003,  ...,  1016,  1016,  1016],
          [ 4658,  1016,  4597,  ...,  1014,  2006, 26354]], dtype=torch.int32),
  tensor([128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128,
          128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128,
          128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128,
          128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128,
          128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128,
          128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128,
          128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128,
          128, 128, 128, 128, 128, 128, 

In [11]:
classifier_model = ClassifierBERT(
    vocab_size=load_data.tokenizer.vocab_size,
).to(device)

In [12]:
classifier_model.load_state_dict(torch.load('./save/cls_model_768_192_epoch_0.file'))

<All keys matched successfully>

In [13]:
optimizer = optim.AdamW(classifier_model.parameters(), lr=lr, weight_decay=wd)

In [14]:
criterion = nn.CrossEntropyLoss()

In [15]:
def train(train_dataloader, val_dataloader, model, criterion, optimizer, num_epoch):
    train_id = random.randint(0,1000)
    log = open(f'./save/cls_{train_id}.txt','w')
    for epoch in range(num_epoch):
        print(f"Epoch {epoch}, total {len(train_dataloader)} batches\n")
        log.write(f"Epoch {epoch}, total {len(train_dataloader)} batches\n")
        log.flush()

        for batch, (src_ids, src_len, tgt) in enumerate(train_dataloader):
            torch.cuda.empty_cache()
            model.train()
            optimizer.zero_grad()

            src_ids = src_ids.to(device)
            tgt = tgt.to(device)
            
            src_logits = torch.zeros(src_ids.size(0), src_ids.size(1), load_data.tokenizer.vocab_size).to(device)
            for i in range(src_ids.size(0)):
                for j, label in enumerate(src_ids[i]):
                    src_logits[i][j][label] = 1.0

            logits = model(src_logits)
            loss = criterion(logits, tgt)
            
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), max_norm)
            optimizer.step()
            torch.cuda.empty_cache()
            
            if batch % print_every == 0:
                print(f"Epoch Step: {batch} Loss: {loss} Acc: {(logits.argmax(1) == tgt).sum().item() / tgt.size(0)}")
                log.write(f"Epoch Step: {batch} Loss: {loss} Acc: {(logits.argmax(1) == tgt).sum().item() / tgt.size(0)}\n")

                torch.save(model.state_dict(), f'./save/cls_model_{train_id}_{batch_size}_epoch_{epoch}.file')
            
            if batch % (print_every * 5) == 0:
                print(f"\nBegin Evaluation")
                model.eval()
                total_acc = 0
                limit=len(val_dataloader)
                if limit > 75: 
                    limit=75

                with torch.no_grad():
                    for batch, (src_ids, src_len, tgt) in tqdm(enumerate(val_dataloader),total=limit):
                        src_ids = src_ids.to(device)
                        tgt = tgt.to(device)
                        
                        src_logits = torch.zeros(src_ids.size(0), src_ids.size(1), load_data.tokenizer.vocab_size).to(device)
                        for i in range(src_ids.size(0)):
                            for j, label in enumerate(src_ids[i]):
                                src_logits[i][j][label] = 1.0
                        
                        logits = model(src_logits)
                        total_acc += (logits.argmax(1) == tgt).sum().item()
                        if batch>=limit: break

                acc = total_acc / limit  / batch_size #len(val_dataloader.dataset)
                print(f"Validation Accuracy: {acc}\n")
                log.write(f"Validation Accuracy: {acc}\n")
                log.flush()

            



In [None]:
train(train_dataloader, val_dataloader, classifier_model, criterion, optimizer, num_epoch)

Epoch 0, total 2341 batches

Epoch Step: 0 Loss: 0.6973282694816589 Acc: 0.5052083333333334

Begin Evaluation


100%|███████████████████████████████████████████████████████████████████████████████████| 75/75 [02:13<00:00,  1.78s/it]


Validation Accuracy: 0.5530555555555555

