In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import math
import json
import torch
import pickle
import random
from tqdm import tqdm
from torch import nn, optim
from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer, util
from transformers import GPT2LMHeadModel

import load_data
from load_data import GenderDataset, gender_data_collate_fn
from models.encoder_t5 import EncoderT5
from models.classifier_bert import ClassifierBERT
from models.perplexity_gpt import PerplexityGPT2
from models.similarity_sent_enc import encode_for_similarities

In [3]:
device = torch.device('cpu')
if torch.cuda.is_available():
    device = torch.device('cuda')

print(device)

cuda


In [4]:
# Hyper parameters
num_epoch = 100
batch_size = 20
lr = 3e-4
wd = 5e-6
print_every = 500

max_norm = 2.0
random.seed(42)

In [5]:
with open(os.path.join(os.curdir, "data", "blog.json"), "r") as file:
    json_data = json.load(file)
docs = json_data['docs'][1:] # I don't want to see the first document

In [6]:
num_docs = len(docs)
num_train_docs = int(num_docs * 0.7)
num_val_docs = int(num_docs * 0.15)
num_test_docs = num_docs - num_train_docs - num_val_docs
print(num_train_docs, num_val_docs, num_test_docs)

13773 2951 2952


In [7]:
train_docs = docs[:num_train_docs]
val_docs = docs[num_train_docs:num_train_docs+num_val_docs]
test_docs = docs[num_train_docs+num_val_docs:]

In [8]:
train_dataset, val_dataset, test_dataset = None, None, None
load_from_pickled = False

if os.path.exists(os.path.join(os.curdir, "data", "train.pickle")):
    load_from_pickled = True
    with open(os.path.join(os.curdir, "data", "train.pickle"), "rb") as f:
        train_dataset = pickle.load(f)
    with open(os.path.join(os.curdir, "data", "val.pickle"), "rb") as f:
        val_dataset = pickle.load(f)
    with open(os.path.join(os.curdir, "data", "test.pickle"), "rb") as f:
        test_dataset = pickle.load(f)
else:
    train_dataset = GenderDataset(train_docs)
    val_dataset = GenderDataset(val_docs)
    test_dataset = GenderDataset(test_docs)

if not load_from_pickled:
    with open(os.path.join(os.curdir, "data", "train.pickle"), "wb") as f:
        pickle.dump(train_dataset, f)
    with open(os.path.join(os.curdir, "data", "val.pickle"), "wb") as f:
        pickle.dump(val_dataset, f)
    with open(os.path.join(os.curdir, "data", "test.pickle"), "wb") as f:
        pickle.dump(test_dataset, f)

print(load_from_pickled)   

True


In [9]:
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0,
    collate_fn=gender_data_collate_fn
)
val_dataloader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0,
    collate_fn=gender_data_collate_fn
)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0,
    collate_fn=gender_data_collate_fn
)

In [10]:
next(enumerate(test_dataloader))

(0,
 (tensor([[ 2013,  1009,  1059,  ...,  2013,  1009,  1059],
          [ 1054,  1009,  1060,  ...,  2125,  3280,  1015],
          [ 1014, 12440,  1016,  ...,  2070,  1041, 14688],
          ...,
          [ 2771,  5203,  2762,  ...,  1016,  1016,  1016],
          [ 1016,  2013,  2005,  ..., 18852,  5296,  2103],
          [ 2767,  1022,  1065,  ...,  2009,  4907,  2144]], dtype=torch.int32),
  tensor([128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128,
          128, 128, 128, 128, 128, 128], dtype=torch.int32),
  tensor([1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0, 1, 0])))

In [11]:
ppl_model = PerplexityGPT2(
    vocab_size=load_data.tokenizer.vocab_size,
    bos_token_id=load_data.tokenizer.bos_token_id,
    eos_token_id=load_data.tokenizer.eos_token_id,
).to(device)

In [12]:
print(ppl_model.config.vocab_size)
print(ppl_model.GPT2.lm_head.weight.size())

30527
torch.Size([30527, 768])


In [13]:
optimizer = optim.AdamW(ppl_model.parameters(), lr=lr, weight_decay=wd)

In [14]:
criterion = nn.CrossEntropyLoss()

In [15]:
def train(train_dataloader, val_dataloader, model, criterion, optimizer, num_epoch):
    train_id = random.randint(0,1000)
    log = open(f'./save/ppl_{train_id}.txt','w')
    for epoch in range(num_epoch):
        print(f"Epoch {epoch}, total {len(train_dataloader)} batches\n")
        log.write(f"Epoch {epoch}, total {len(train_dataloader)} batches\n")
        log.flush()

        for batch, (src_ids, src_len, tgt) in enumerate(train_dataloader):
            torch.cuda.empty_cache()
            model.train()
            optimizer.zero_grad()

            src_ids = src_ids.to(device)
            tgt = tgt.to(device)

            src_logits = torch.zeros(batch_size, src_ids.size(1), load_data.tokenizer.vocab_size).to(device)
            for i in range(batch_size):
                for j, label in enumerate(src_ids[i]):
                    src_logits[i][j][label] = 1.0

            tgt_logits = model(src_logits)
            
            shift_tgt = tgt_logits[..., :-1, :].contiguous()
            shift_src = src_logits[..., 1:, :].contiguous()
            loss = criterion(shift_tgt.view(-1, shift_tgt.size(-1)), shift_src.view(-1, shift_src.size(-1)))
            
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), max_norm)
            optimizer.step()
            torch.cuda.empty_cache()
            
            if batch % print_every == 0:
                print(f"Epoch Step: {batch} Loss: {loss}")
                log.write(f"Epoch Step: {batch} Loss: {loss}\n")
                log.flush()
                print(f"\nBegin Evaluation")
                model.eval()
                total_loss = 0
                limit=len(val_dataloader)
                if limit > 75: 
                    limit = 75
                with torch.no_grad():
                    for batch, (src_ids, src_len, tgt) in tqdm(enumerate(val_dataloader),total=limit):

                        src_ids = src_ids.to(device)
                        tgt = tgt.to(device)

                        src_logits = torch.zeros(batch_size, src_ids.size(1), load_data.tokenizer.vocab_size).to(device)
                        for i in range(batch_size):
                            for j, label in enumerate(src_ids[i]):
                                src_logits[i][j][label] = 1.0

                        tgt_logits = model(src_logits)
                        
                        shift_tgt = tgt_logits[..., :-1, :].contiguous()
                        shift_src = src_logits[..., 1:, :].contiguous()
                        loss = criterion(shift_tgt.view(-1, shift_tgt.size(-1)), shift_src.view(-1, shift_src.size(-1)))

                        total_loss += loss
                        if batch>=limit: break

                ppl = math.exp(total_loss / limit)
                print(f"Validation Perplexity: {ppl}\n")
                log.write(f"Validation Perplexity: {ppl}\n")
                log.flush()
                torch.save(model.state_dict(), f'./save/ppl_model_{train_id}_{batch_size}_epoch_{epoch}.file')

            



In [16]:
train(train_dataloader, val_dataloader, ppl_model, criterion, optimizer, num_epoch)

Epoch 0, total 22466 batches

Epoch Step: 0 Loss: 10.47973918914795

Begin Evaluation


100%|██████████| 75/75 [00:33<00:00,  2.27it/s]


Validation Perplexity: 10669.429778537775

Epoch Step: 500 Loss: 5.05801248550415

Begin Evaluation


100%|██████████| 75/75 [00:33<00:00,  2.23it/s]


Validation Perplexity: 184.53466530783209

Epoch Step: 1000 Loss: 5.047084808349609

Begin Evaluation


100%|██████████| 75/75 [00:34<00:00,  2.18it/s]


Validation Perplexity: 148.69537041239164

Epoch Step: 1500 Loss: 4.5915093421936035

Begin Evaluation


100%|██████████| 75/75 [00:33<00:00,  2.21it/s]


Validation Perplexity: 132.086648797218

Epoch Step: 2000 Loss: 4.713454246520996

Begin Evaluation


100%|██████████| 75/75 [00:34<00:00,  2.20it/s]


Validation Perplexity: 110.56062599745074

Epoch Step: 2500 Loss: 4.643686294555664

Begin Evaluation


100%|██████████| 75/75 [00:34<00:00,  2.20it/s]


Validation Perplexity: 102.70954763639354

Epoch Step: 3000 Loss: 4.7584733963012695

Begin Evaluation


100%|██████████| 75/75 [00:34<00:00,  2.15it/s]


Validation Perplexity: 94.07126194591892

Epoch Step: 3500 Loss: 4.4227447509765625

Begin Evaluation


100%|██████████| 75/75 [00:35<00:00,  2.12it/s]


Validation Perplexity: 87.50574803479874

Epoch Step: 4000 Loss: 4.493597507476807

Begin Evaluation


100%|██████████| 75/75 [00:35<00:00,  2.09it/s]


Validation Perplexity: 82.73543474107036

