In [1]:
%load_ext autoreload
%autoreload 2

In [22]:
import os
import json
import torch
import pickle
import random
from tqdm import tqdm
from torch import nn, optim
from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer, util

import load_data
from load_data import GenderDataset, gender_data_collate_fn
from models.encoder_t5 import EncoderT5
from models.classifier_bert import ClassifierBERT
from models.similarity_sent_enc import encode_for_similarities

In [3]:
# Prefer the GPU whenever CUDA is usable; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print(device)

cuda


In [20]:
# Hyper parameters read by the DataLoader and optimizer cells further down
batch_size = 8  # batch_size given to the train/val/test DataLoaders
lr = 1e-4  # learning rate handed to optim.AdamW
wd = 5e-6  # weight_decay handed to optim.AdamW
print_every = 200  # not referenced by any visible cell; presumably a logging interval in batches — TODO confirm

In [5]:
# Read the raw blog corpus from ./data/blog.json.
blog_json_path = os.path.join(os.curdir, "data", "blog.json")
with open(blog_json_path, "r") as file:
    json_data = json.load(file)

# The first entry of 'docs' is deliberately left out of every split.
docs = json_data['docs'][1:]

In [6]:
# 70% train / 15% validation (both truncated to ints); test takes whatever is left.
num_docs = len(docs)
num_train_docs, num_val_docs = int(num_docs * 0.7), int(num_docs * 0.15)
num_test_docs = num_docs - (num_train_docs + num_val_docs)
print(num_train_docs, num_val_docs, num_test_docs)

13773 2951 2952


In [7]:
# Contiguous, order-preserving split: [train | val | test].
_val_end = num_train_docs + num_val_docs
train_docs = docs[:num_train_docs]
val_docs = docs[num_train_docs:_val_end]
test_docs = docs[_val_end:]

In [11]:
# Build the three GenderDataset splits once and cache them as pickles under ./data,
# so later runs can skip the preprocessing done by the GenderDataset constructor.
_split_docs = {"train": train_docs, "val": val_docs, "test": test_docs}
_pickle_paths = {
    name: os.path.join(os.curdir, "data", f"{name}.pickle")
    for name in _split_docs
}

# Only trust the cache when ALL three files exist. Checking train.pickle alone
# would crash with FileNotFoundError on a partially written cache.
load_from_pickled = all(os.path.exists(path) for path in _pickle_paths.values())

_datasets = {}
if load_from_pickled:
    # SECURITY NOTE: pickle.load can execute arbitrary code. Only load cache
    # files this notebook wrote itself; never point this at untrusted pickles.
    for name, path in _pickle_paths.items():
        with open(path, "rb") as f:
            _datasets[name] = pickle.load(f)
else:
    # Build every split before writing anything, so that a failure while
    # building leaves no half-populated cache behind.
    for name, split_docs in _split_docs.items():
        _datasets[name] = GenderDataset(split_docs)
    for name, path in _pickle_paths.items():
        with open(path, "wb") as f:
            pickle.dump(_datasets[name], f)

train_dataset = _datasets["train"]
val_dataset = _datasets["val"]
test_dataset = _datasets["test"]

print(load_from_pickled)

Cutting documents into paragraphs of length 128...


  0%|          | 0/13773 [00:00<?, ?it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1154 > 512). Running this sequence through the model will result in indexing errors
100%|██████████| 13773/13773 [02:48<00:00, 81.55it/s] 


Number of documents: 449320
Discarded ratio (due to MIN_LENGTH): 0.165
Number of unique words before converting to [UNK]:  27221
Converting words with frequencies less than 10 to [UNK]...


100%|██████████| 449320/449320 [00:39<00:00, 11304.71it/s]


Number of unique words after converting [UNK]:  25086
Known occurrences rate 99.98%
Cutting documents into paragraphs of length 128...


100%|██████████| 2951/2951 [00:33<00:00, 89.10it/s] 


Number of documents: 97159
Discarded ratio (due to MIN_LENGTH): 0.151
Number of unique words before converting to [UNK]:  26020
Converting words with frequencies less than 10 to [UNK]...


100%|██████████| 97159/97159 [00:07<00:00, 12872.15it/s]


Number of unique words after converting [UNK]:  19062
Known occurrences rate 99.72%
Cutting documents into paragraphs of length 128...


100%|██████████| 2952/2952 [00:24<00:00, 120.52it/s]


Number of documents: 71898
Discarded ratio (due to MIN_LENGTH): 0.164
Number of unique words before converting to [UNK]:  25551
Converting words with frequencies less than 10 to [UNK]...


100%|██████████| 71898/71898 [00:06<00:00, 11405.14it/s]


Number of unique words after converting [UNK]:  17366
Known occurrences rate 99.57%
False


In [12]:
# All three loaders are configured identically; only the dataset differs.
_loader_kwargs = dict(
    batch_size=batch_size,
    shuffle=True,
    num_workers=0,
    collate_fn=gender_data_collate_fn,
)

train_dataloader = DataLoader(train_dataset, **_loader_kwargs)
val_dataloader = DataLoader(val_dataset, **_loader_kwargs)
test_dataloader = DataLoader(test_dataset, **_loader_kwargs)

In [10]:
# Sanity check: display (batch index, first collated batch) from the test loader.
(0, next(iter(test_dataloader)))

(0,
 (tensor([[3305, 2023, 1012,  ..., 1998, 6389, 2126],
          [2061, 3376, 1998,  ..., 1996, 2048, 2060],
          [2041, 3782, 7632,  ..., 2000, 2131, 2070],
          ...,
          [2043, 2016, 2001,  ..., 2033, 1012, 5206],
          [4532, 1045, 6293,  ..., 2061, 2919, 1010],
          [2009, 2357, 2041,  ..., 4123, 2052, 2566]], dtype=torch.int32),
  tensor([128, 128, 128, 128, 128, 128, 128, 128], dtype=torch.int32),
  tensor([0, 1, 0, 0, 0, 1, 1, 1])))

In [None]:
# Encoder whose output train() treats as the obfuscated logits.
encoder_model = EncoderT5(
    vocab_size=load_data.tokenizer.vocab_size
).to(device)

# The classifier has to live on the same device as the encoder: train() moves
# every batch to `device` before calling it, so leaving it on the CPU would
# raise a device-mismatch error as soon as CUDA is in use.
classifier_model = ClassifierBERT(
    vocab_size=load_data.tokenizer.vocab_size
).to(device)

# Sentence encoder; it is never handed to the optimizer in this notebook, so
# its underlying transformer is moved to `device` and switched to eval mode.
st_model = SentenceTransformer('all-mpnet-base-v2')
st_model[0].auto_model = st_model[0].auto_model.to(device)
st_model[0].auto_model.eval()

In [17]:
# One AdamW instance updates the encoder and the classifier together.
_trainable_params = [*encoder_model.parameters(), *classifier_model.parameters()]
optimizer = optim.AdamW(_trainable_params, lr=lr, weight_decay=wd)

In [18]:
# Loss functions; presumably bundled into the `criterions` dict that train() expects
# (keys 'sim', 'cls', 'ppl') — TODO confirm against the cell that calls train().
criterion_similarity = nn.MSELoss()  # mean squared error
criterion_classification = nn.CrossEntropyLoss()  # cross-entropy over class logits
criterion_perplexity = nn.L1Loss() # Placeholder: stands in until a real perplexity loss exists

In [None]:
def train(train_dataloader, val_dataloader, models, criterions, optimizer, num_epoch):
    """Train the encoder and classifier together, validating and checkpointing every epoch.

    Each training step sums three losses (similarity, classification and a
    placeholder "perplexity" term), back-propagates the sum and takes one
    optimizer step. After every epoch the validation loader is evaluated
    under ``torch.no_grad()``, the metrics are printed and appended to a log
    file, and both trainable models are checkpointed.

    Args:
        train_dataloader: Iterable of ``(src_ids, src_len, tgt)`` batches used
            for training. ``src_len`` is not used by this function.
        val_dataloader: Iterable of batches of the same form used for validation.
        models: Dict of modules. ``'enc'`` (encoder), ``'cls'`` (classifier) and
            ``'st'`` (sentence encoder passed to ``encode_for_similarities``)
            are required. ``'ppl'`` is optional and, if present, is only
            switched to eval mode.
        criterions: Dict of loss callables under the keys ``'sim'``, ``'cls'``
            and ``'ppl'``.
        optimizer: Optimizer that is zeroed and stepped once per training batch.
        num_epoch: Number of passes over ``train_dataloader``.

    Returns:
        None.

    Side effects:
        Creates ``./save`` if needed, writes ``./save/{train_id}.txt`` and one
        encoder plus one classifier state dict per epoch. Reads the
        module-level ``device`` (where batches are moved) and ``batch_size``
        (used only in the checkpoint file names).
    """
    # The log file and every checkpoint live under ./save; make sure it exists.
    os.makedirs('./save', exist_ok=True)

    # Random run id so that successive runs are unlikely to overwrite each other.
    train_id = random.randint(0, 1000)

    models['st'].eval()
    # The perplexity model is not used anywhere else in this function, so it is
    # optional; when supplied it is kept in eval mode as before.
    if 'ppl' in models:
        models['ppl'].eval()

    # The context manager guarantees the log is closed even if training fails.
    with open(f'./save/{train_id}.txt', 'w') as log:
        for epoch in range(num_epoch):

            header = f"Epoch {epoch}, total {len(train_dataloader)} batches\n"
            print(header)
            log.write(header)
            log.flush()

            models['enc'].train()
            models['cls'].train()

            for src_ids, _src_len, tgt in train_dataloader:

                torch.cuda.empty_cache()
                optimizer.zero_grad()
                src_ids = src_ids.to(device)
                tgt = tgt.to(device)

                # Calculate the obfuscated logits via the encoder
                obf_logits = models['enc'](src_ids)

                src_encode, obf_encode = encode_for_similarities(models['st'], src_ids, obf_logits)
                loss_sim = criterions['sim'](src_encode, obf_encode)

                # NOTE(review): the classifier is fed the original token ids,
                # not the encoder output — confirm this is intended.
                gender_logits = models['cls'](src_ids)
                loss_cls = criterions['cls'](gender_logits, tgt)

                # Placeholder perplexity term.
                # NOTE(review): it compares class logits with the label tensor;
                # with an element-wise criterion the shapes may not line up —
                # replace once a real perplexity loss exists.
                loss_ppl = criterions['ppl'](gender_logits, tgt)

                loss = loss_sim + loss_cls + loss_ppl
                loss.backward()
                optimizer.step()
                torch.cuda.empty_cache()

            print("\nBegin Evaluation")
            models['enc'].eval()
            models['cls'].eval()

            total_correct = 0
            total_ppl = 0
            total_sim = 0.0
            total_samples = 0

            with torch.no_grad():
                for src_ids, _src_len, tgt in tqdm(val_dataloader, total=len(val_dataloader)):
                    src_ids = src_ids.to(device)
                    tgt = tgt.to(device)

                    obf_logits = models['enc'](src_ids)
                    src_encode, obf_encode = encode_for_similarities(models['st'], src_ids, obf_logits)
                    gender_logits = models['cls'](src_ids)

                    # `.sum` must be called before `.item()`; the bare attribute
                    # is a bound method and has no `.item`.
                    total_correct += (gender_logits.argmax(1) == tgt).sum().item()
                    total_ppl += 1  # placeholder
                    # `.item()` keeps the running total a plain float instead of
                    # a device tensor.
                    total_sim += torch.sum(torch.mean(src_encode * obf_encode, dim=1)).item()
                    total_samples += tgt.size(0)

            # Normalise by the number of samples actually seen: the last batch
            # may be short, and the loader's batch size need not match the
            # global `batch_size`. Guard against an empty validation loader.
            denom = max(total_samples, 1)
            acc = total_correct / denom
            ppl = total_ppl / denom
            sim = total_sim / denom

            report = f"Validation Accuracy: {acc}\nValidation Perplexity: {ppl}\nValidation Semantic Similarity: {sim}\n"
            print(report)
            log.write(report)
            log.flush()

            torch.save(models['cls'].state_dict(), f'./save/classifier_model_{train_id}_{batch_size}_epoch_{epoch}.file')
            torch.save(models['enc'].state_dict(), f'./save/encoder_model_{train_id}_{batch_size}_epoch_{epoch}.file')
