In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import json
import torch
import pickle
import random
import math
from tqdm import tqdm
from torch import nn, optim
from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer, util

import load_data
from load_data import GenderDataset, gender_data_collate_fn
from models.encoder_t5 import EncoderT5
from models.classifier_bert import ClassifierBERT
from models.perplexity_gpt import PerplexityGPT2
from models.similarity_sent_enc import encode_for_similarities

In [3]:
# Pick the compute device. GPU index 3 is preferred (the recorded run used
# cuda:3), but a host that exposes fewer GPUs would hit an "invalid device
# ordinal" error as soon as a model is moved there, so fall back to cuda:0.
PREFERRED_GPU_INDEX = 3

if torch.cuda.is_available():
    gpu_index = PREFERRED_GPU_INDEX if torch.cuda.device_count() > PREFERRED_GPU_INDEX else 0
    device = torch.device(f'cuda:{gpu_index}')
else:
    device = torch.device('cpu')

print(device)

cuda:3


In [4]:
# Hyper parameters

# -- training schedule --
num_epoch = 100     # number of passes over the training dataloader
batch_size = 32     # samples per batch for every DataLoader
print_every = 200   # log + checkpoint every N batches (validation runs every 5 * N)

# -- AdamW settings --
lr = 1e-4           # learning rate
wd = 5e-6           # weight decay

In [5]:
# Read the blog corpus from ./data/blog.json.
blog_json_path = os.path.join(os.curdir, "data", "blog.json")
with open(blog_json_path, "r") as file:
    json_data = json.load(file)

# The first entry of 'docs' is deliberately left out of the experiment.
docs = json_data['docs'][1:]

In [6]:
# 70% / 15% / 15% train / validation / test split sizes.
# The test split takes whatever is left after truncation of the other two.
num_docs = len(docs)
num_train_docs, num_val_docs = int(num_docs * 0.7), int(num_docs * 0.15)
num_test_docs = num_docs - (num_train_docs + num_val_docs)
print(num_train_docs, num_val_docs, num_test_docs)

13773 2951 2952


In [7]:
# Contiguous slices of `docs`, in order: train, then validation, then test.
val_end = num_train_docs + num_val_docs
train_docs, val_docs, test_docs = (
    docs[:num_train_docs],
    docs[num_train_docs:val_end],
    docs[val_end:],
)

In [8]:
# Build the three GenderDataset splits once and cache them as pickles in ./data.
#
# The cache is used only when ALL three files exist: checking train.pickle alone
# would crash with FileNotFoundError if val.pickle or test.pickle were missing.
# The cache is keyed on file existence only -- delete the pickle files to force
# a rebuild after the documents or the split change.
split_docs = {"train": train_docs, "val": val_docs, "test": test_docs}
pickle_paths = {
    split: os.path.join(os.curdir, "data", f"{split}.pickle")
    for split in split_docs
}

load_from_pickled = all(os.path.exists(path) for path in pickle_paths.values())

datasets = {}
if load_from_pickled:
    # NOTE: unpickling can execute arbitrary code; only load cache files that
    # were written by this notebook.
    for split, path in pickle_paths.items():
        with open(path, "rb") as f:
            datasets[split] = pickle.load(f)
else:
    # Construct every split first, then persist them (same order as before).
    for split, split_doc_list in split_docs.items():
        datasets[split] = GenderDataset(split_doc_list)
    for split, path in pickle_paths.items():
        with open(path, "wb") as f:
            pickle.dump(datasets[split], f)

train_dataset = datasets["train"]
val_dataset = datasets["val"]
test_dataset = datasets["test"]

print(load_from_pickled)

True


In [9]:
# The three loaders are configured identically: shuffled batches of
# `batch_size`, 8 worker processes, collated by gender_data_collate_fn.
loader_options = dict(
    batch_size=batch_size,
    shuffle=True,
    num_workers=8,
    collate_fn=gender_data_collate_fn,
)

train_dataloader, val_dataloader, test_dataloader = (
    DataLoader(split_dataset, **loader_options)
    for split_dataset in (train_dataset, val_dataset, test_dataset)
)

In [10]:
# Peek at the first collated test batch, displayed as (index, batch).
(0, next(iter(test_dataloader)))

(0,
 (tensor([[ 2002,  2427,  3827,  ..., 12673,  4823,  2002],
          [ 2009,  2030,  2175,  ...,  2021,  9381,  7169],
          [ 2232,  1049,  1009,  ...,  7230,  2021, 27137],
          ...,
          [ 1049,  2035,  6870,  ...,  1015,  1015,  1015],
          [ 1049,  1009,  2314,  ...,  1016,  2047,  1049],
          [ 1049,  1009,  1053,  ...,  1049,  5227,  2000]], dtype=torch.int32),
  tensor([128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128,
          128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128,
          128, 128, 128, 128], dtype=torch.int32),
  tensor([0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1,
          0, 0, 0, 0, 0, 1, 0, 0])))

In [11]:
# The encoder, classifier and perplexity model are all sized from the same tokenizer.
shared_tokenizer = load_data.tokenizer

encoder_model = EncoderT5(vocab_size=shared_tokenizer.vocab_size).to(device)
classifier_model = ClassifierBERT(vocab_size=shared_tokenizer.vocab_size).to(device)
ppl_model = PerplexityGPT2(
    vocab_size=shared_tokenizer.vocab_size,
    bos_token_id=shared_tokenizer.bos_token_id,
    eos_token_id=shared_tokenizer.eos_token_id,
).to(device)

# Sentence encoder used for the similarity term; only the underlying
# auto_model of its first module is moved to `device`.
st_model = SentenceTransformer('all-mpnet-base-v2')
st_first_module = st_model[0]
st_first_module.auto_model = st_first_module.auto_model.to(device)

In [12]:
# Restore the previously saved classifier and perplexity-model weights.
# map_location remaps the stored tensors onto `device`; without it torch.load
# tries to restore them to the device they were saved from, which fails on a
# host that lacks that GPU (or has no GPU at all).
classifier_model.load_state_dict(
    torch.load('./save/cls_model_271_128_epoch_0.file', map_location=device)
)
ppl_model.load_state_dict(
    torch.load('./save/ppl_model_223_128_epoch_1.file', map_location=device)
)

<All keys matched successfully>

In [13]:
# Only the encoder's parameters are handed to the optimizer.
optimizer = optim.AdamW(
    encoder_model.parameters(),
    lr=lr,
    weight_decay=wd,
)

In [14]:
# Loss objects handed to train()/test() through the `criterions` dict.
criterion_classification = nn.CrossEntropyLoss()
criterion_perplexity = nn.CrossEntropyLoss()  # Placeholder
# Passed along as criterions['sim'], but neither train() nor test() reads it:
# both compute the similarity term directly from encode_for_similarities.
criterion_similarity = nn.MSELoss()

In [15]:
def _validate_encoder(val_dataloader, models, criterions, max_batches=100):
    """Score the encoder on at most ``max_batches`` validation batches.

    Switches ``models['enc']`` to eval mode (the caller must switch it back)
    and runs without gradient tracking.

    Returns:
        tuple: ``(accuracy, perplexity, similarity)`` where
            * accuracy is the fraction of evaluated samples whose
              ``models['cls']`` argmax equals the label,
            * perplexity is ``exp`` of the mean per-batch ``criterions['ppl']`` loss,
            * similarity is the mean per-batch ``encode_for_similarities`` score.
        All three are NaN when no batch was evaluated.
    """
    limit = min(len(val_dataloader), max_batches)
    total_correct = 0
    total_samples = 0
    total_ppl = 0.0
    total_sim = 0.0
    num_batches = 0

    models['enc'].eval()
    with torch.no_grad():
        # zip() against range(limit) stops after exactly `limit` batches; the
        # previous "break after processing" check evaluated limit + 1 batches
        # while still dividing by limit.
        for _, (src_ids, src_len, tgt) in tqdm(zip(range(limit), val_dataloader), total=limit):
            src_ids = src_ids.to(device)
            tgt = tgt.to(device)

            obf_logits = models['enc'](src_ids)
            obf_logits = nn.functional.softmax(obf_logits, dim=-1)

            result = encode_for_similarities(models['st'], device, src_ids, obf_logits)
            gender_logits = models['cls'](obf_logits)

            ppl_logits = models['ppl'](obf_logits)
            shift_ppl = ppl_logits[..., :-1, :].contiguous()
            shift_obf = obf_logits[..., 1:, :].contiguous()
            loss_ppl = criterions['ppl'](
                shift_ppl.view(-1, shift_ppl.size(-1)),
                shift_obf.view(-1, shift_obf.size(-1)),
            )

            total_correct += (gender_logits.argmax(1) == tgt).sum().item()
            # Count real samples: the last batch may hold fewer than batch_size.
            total_samples += tgt.size(0)
            # .item() keeps the accumulators as plain floats instead of tensors.
            total_ppl += loss_ppl.item()
            total_sim += torch.mean(result).item()
            num_batches += 1

    if num_batches == 0 or total_samples == 0:
        nan = float('nan')
        return nan, nan, nan

    return (
        total_correct / total_samples,
        math.exp(total_ppl / num_batches),
        total_sim / num_batches,
    )


def train(train_dataloader, val_dataloader, models, criterions, optimizer, num_epoch):
    """Train ``models['enc']`` against three scoring models kept in eval mode.

    For every training batch the encoder output is softmaxed over its last
    dimension and scored three ways:
      * similarity: ``1 - mean(encode_for_similarities(...))``,
      * classification: ``criterions['cls']`` between the ``models['cls']``
        output and a uniform 0.5 target (aiming for a 50/50 prediction),
      * perplexity: ``criterions['ppl']`` between the ``models['ppl']`` output at
        position t and the softmaxed encoder output at position t + 1.
    The total loss is ``25 * sim + 25 * cls + 1 * ppl``.

    Side effects:
      * writes a text log to ``./save/enc_{train_id}.txt`` (random ``train_id``),
      * every ``print_every`` batches: prints/logs the losses and saves the
        encoder ``state_dict`` under ``./save``,
      * every ``5 * print_every`` batches: validates on up to 100 batches of
        ``val_dataloader``.

    Relies on the notebook globals ``device``, ``print_every`` and ``batch_size``
    (the latter only for the checkpoint file name).

    Args:
        train_dataloader: yields ``(src_ids, src_len, tgt)`` batches.
        val_dataloader: same batch layout, used for periodic validation.
        models: dict with keys ``'enc'``, ``'cls'``, ``'st'``, ``'ppl'``.
        criterions: dict; keys ``'cls'`` and ``'ppl'`` are used.
        optimizer: optimizer stepped once per training batch.
        num_epoch: number of passes over ``train_dataloader``.
    """
    train_id = random.randint(0, 1000)
    models['st'].eval()
    models['ppl'].eval()
    models['cls'].eval()

    # Make sure the output directory exists before the log/checkpoints are written.
    os.makedirs('./save', exist_ok=True)

    # `with` closes the log even when training is interrupted or raises.
    with open(f'./save/enc_{train_id}.txt', 'w') as log:
        for epoch in range(num_epoch):

            print(f"Epoch {epoch}, total {len(train_dataloader)} batches\n")
            log.write(f"Epoch {epoch}, total {len(train_dataloader)} batches\n")
            log.flush()

            models['enc'].train()

            for batch, (src_ids, _src_len, _labels) in enumerate(train_dataloader):

                torch.cuda.empty_cache()
                optimizer.zero_grad()
                src_ids = src_ids.to(device)

                # Calculate the obfuscated logits via the encoder & softmax
                obf_logits = models['enc'](src_ids)
                obf_logits = nn.functional.softmax(obf_logits, dim=-1)

                # Calculate similarity loss
                result = encode_for_similarities(models['st'], device, src_ids, obf_logits)
                loss_sim = torch.tensor(1.0) - torch.mean(result)

                # Calculate gender classification loss aiming for 50/50.
                # The target is shaped after the classifier output so a short
                # final batch no longer mismatches a (batch_size, 2) tensor.
                gender_logits = models['cls'](obf_logits)
                uniform_tgt = torch.full_like(gender_logits, 0.5)
                loss_cls = criterions['cls'](gender_logits, uniform_tgt)

                # Calculate perplexity loss
                ppl_logits = models['ppl'](obf_logits)
                shift_ppl = ppl_logits[..., :-1, :].contiguous()
                shift_obf = obf_logits[..., 1:, :].contiguous()
                loss_ppl = criterions['ppl'](
                    shift_ppl.view(-1, shift_ppl.size(-1)),
                    shift_obf.view(-1, shift_obf.size(-1)),
                )

                loss = loss_sim * 25 + loss_cls * 25 + loss_ppl * 1
                loss.backward()
                optimizer.step()
                torch.cuda.empty_cache()

                if batch % print_every == 0:
                    print(f"Epoch Step: {batch} Sim Loss: {loss_sim} Class Loss: {loss_cls} PPL Loss: {loss_ppl}")
                    log.write(f"Epoch Step: {batch} Sim Loss: {loss_sim} Class Loss: {loss_cls} PPL Loss: {loss_ppl}\n")
                    log.flush()

                    torch.save(models['enc'].state_dict(), f'./save/enc_model_{train_id}_{batch_size}_epoch_{epoch}.file')

                if batch % (print_every * 5) == 0:
                    print(f"\nBegin Evaluation")
                    acc, ppl, sim = _validate_encoder(val_dataloader, models, criterions)
                    # Validation leaves the encoder in eval mode; switch back so
                    # the remaining batches of this epoch train in train mode.
                    models['enc'].train()

                    print(f"Validation Accuracy: {acc}\nValidation Perplexity: {ppl}\nValidation Semantic Similarity: {sim}\n")
                    log.write(f"Validation Accuracy: {acc}\nValidation Perplexity: {ppl}\nValidation Semantic Similarity: {sim}\n")
                    log.flush()


In [None]:
# Start encoder training with the models and criterions built above.
train(
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    models=dict(
        enc=encoder_model,
        cls=classifier_model,
        st=st_model,
        ppl=ppl_model,
    ),
    criterions=dict(
        sim=criterion_similarity,
        cls=criterion_classification,
        ppl=criterion_perplexity,
    ),
    optimizer=optimizer,
    num_epoch=num_epoch,
)

Epoch 0, total 14042 batches

Epoch Step: 0 Sim Loss: 0.8519104719161987 Class Loss: 1.0731971263885498 PPL Loss: 15.138219833374023

Begin Evaluation


100%|█████████████████████████████████████████████████████████████████████████████████| 100/100 [00:27<00:00,  3.64it/s]


Validation Accuracy: 0.4653125
Validation Perplexity: 4227079.571668542
Validation Semantic Similarity: 0.150627002120018



In [None]:
def test(test_dataloader, models, criterions):
    """Evaluate the encoder on at most 100 batches of ``test_dataloader``.

    All four models are put in eval mode and run without gradient tracking.
    For each batch, the first source sequence and the argmax decoding of its
    encoder output are printed via the notebook-global ``test_dataset.idx2str``.
    Finally prints:
      * accuracy: fraction of evaluated samples whose ``models['cls']`` argmax
        equals the label,
      * perplexity: ``exp`` of the mean per-batch ``criterions['ppl']`` loss,
      * semantic similarity: mean per-batch ``encode_for_similarities`` score.

    Relies on the notebook globals ``device`` and ``test_dataset``.

    Args:
        test_dataloader: yields ``(src_ids, src_len, tgt)`` batches.
        models: dict with keys ``'enc'``, ``'cls'``, ``'st'``, ``'ppl'``.
        criterions: dict; only key ``'ppl'`` is used.
    """
    models['st'].eval()
    models['ppl'].eval()
    models['cls'].eval()
    models['enc'].eval()

    print(f"\nBegin Testing")
    limit = min(len(test_dataloader), 100)
    total_correct = 0
    total_samples = 0
    total_ppl = 0.0
    total_sim = 0.0
    num_batches = 0

    with torch.no_grad():
        # zip() against range(limit) stops after exactly `limit` batches; the
        # previous "break after processing" check evaluated limit + 1 batches
        # while still dividing by limit.
        for _, (src_ids, src_len, tgt) in tqdm(zip(range(limit), test_dataloader), total=limit):
            src_ids = src_ids.to(device)
            tgt = tgt.to(device)

            obf_logits = models['enc'](src_ids)
            obf_logits = nn.functional.softmax(obf_logits, dim=-1)

            result = encode_for_similarities(models['st'], device, src_ids, obf_logits)
            gender_logits = models['cls'](obf_logits)

            ppl_logits = models['ppl'](obf_logits)
            shift_ppl = ppl_logits[..., :-1, :].contiguous()
            shift_obf = obf_logits[..., 1:, :].contiguous()
            loss_ppl = criterions['ppl'](
                shift_ppl.view(-1, shift_ppl.size(-1)),
                shift_obf.view(-1, shift_obf.size(-1)),
            )

            # Show the first source sequence next to its obfuscated decoding.
            print(test_dataset.idx2str(src_ids[0].cpu().detach()))
            print(test_dataset.idx2str(obf_logits[0].argmax(-1).cpu().detach()))

            total_correct += (gender_logits.argmax(1) == tgt).sum().item()
            # Count real samples: the last batch may hold fewer than batch_size.
            total_samples += tgt.size(0)
            # .item() keeps the accumulators as plain floats instead of tensors.
            total_ppl += loss_ppl.item()
            total_sim += torch.mean(result).item()
            num_batches += 1

    if num_batches == 0 or total_samples == 0:
        # Nothing to average; avoid the ZeroDivisionError an empty loader caused.
        print("No test batches were evaluated.")
        return

    acc = total_correct / total_samples
    ppl = math.exp(total_ppl / num_batches)
    sim = total_sim / num_batches

    print(f"Test Accuracy: {acc}\nTest Perplexity: {ppl}\nTest Semantic Similarity: {sim}\n")

In [None]:
# Evaluate the encoder on the held-out test loader.
test(
    test_dataloader=test_dataloader,
    models=dict(
        enc=encoder_model,
        cls=classifier_model,
        st=st_model,
        ppl=ppl_model,
    ),
    criterions=dict(
        sim=criterion_similarity,
        cls=criterion_classification,
        ppl=criterion_perplexity,
    ),
)