In [1]:
def memory_usage():
    process = psutil.Process()
    mem_info = process.memory_info()
    return mem_info.rss / (1024 ** 2)  # Memory usage in MB

def available_memory():
    mem = psutil.virtual_memory()
    return mem.available / (1024 ** 2)  # Available memory in MB

In [2]:
import psutil
print(f"Available memory before loading: {available_memory()} MB")
print(f"Current memory usage before loading: {memory_usage()} MB")

import sys
import os

import argparse
import json
import random
import shutil
import copy
from tokenizers import ByteLevelBPETokenizer

import torch
from torch import cuda
import torch.nn as nn
from torch.autograd import Variable
from torch.nn.parameter import Parameter

import torch.nn.functional as F
import numpy as np
import time
import logging
from data import Dataset
from tokenized_models import RNNG
from utils import *


print(f"Available memory after loading: {available_memory()} MB")
print(f"Current memory usage after loading: {memory_usage()} MB")

Available memory before loading: 1004754.15234375 MB
Current memory usage before loading: 57.9296875 MB
Available memory after loading: 1004668.5 MB
Current memory usage after loading: 228.2578125 MB


In [3]:
def eval(data, model, samples = 0, count_eos_ppl = 0):
    model.eval()
    num_sents = 0
    num_words = 0
    total_nll_recon = 0.
    total_kl = 0.
    total_nll_iwae = 0.
    corpus_f1 = [0., 0., 0.]
    sent_f1 = [] 
    with torch.no_grad():
        for i in list(reversed(range(len(data)))):
            sents, length, batch_size, gold_actions, gold_spans, gold_binary_trees, other_data = data[i] 
            if length == 1: # length 1 sents are ignored since URNNG needs at least length 2 sents
                continue
            if count_eos_ppl == 1:
                tree_length = length
                length += 1 
            else:
                sents = sents[:, :-1] 
                tree_length = length
            sents = sents.cuda()
            print("Passing through the EVAL forward method")
            ll_word_all, ll_action_p_all, ll_action_q_all, actions_all, q_entropy = model(sents, 
                        samples = samples, has_eos = count_eos_ppl == 1)
            print("Made it through the EVAL forward method!")
            print()
            ll_word, ll_action_p, ll_action_q = ll_word_all.mean(1), ll_action_p_all.mean(1), ll_action_q_all.mean(1)
            kl = ll_action_q - ll_action_p
            _, binary_matrix, argmax_spans = model.q_crf._viterbi(model.scores)
            actions = []
            for b in range(batch_size):    
                tree = get_tree_from_binary_matrix(binary_matrix[b], tree_length)
                actions.append(get_actions(tree))
            actions = torch.Tensor(actions).long()
            total_nll_recon += -ll_word.sum().item()
            total_kl += kl.sum().item()
            num_sents += batch_size
            num_words += batch_size * length
            if samples > 0:
                #PPL estimate based on IWAE
                sample_ll = torch.zeros(batch_size, samples)
                for j in range(samples):
                    ll_word_j, ll_action_p_j, ll_action_q_j = ll_word_all[:, j], ll_action_p_all[:, j], ll_action_q_all[:, j]
                    sample_ll[:, j].copy_(ll_word_j + ll_action_p_j - ll_action_q_j)
                ll_iwae = model.logsumexp(sample_ll, 1) - np.log(samples)
                total_nll_iwae -= ll_iwae.sum().item()      
            for b in range(batch_size):
                action = list(actions[b].numpy())
                span_b = get_spans(action)
                span_b = argmax_spans[b]
                span_b_set = set(span_b[:-1])        
                gold_b_set = set(gold_spans[b][:-1])
                tp, fp, fn = get_stats(span_b_set, gold_b_set) 
                corpus_f1[0] += tp
                corpus_f1[1] += fp
                corpus_f1[2] += fn

                # sent-level F1 is based on L83-89 from https://github.com/yikangshen/PRPN/test_phrase_grammar.py
                model_out = span_b_set
                std_out = gold_b_set
                overlap = model_out.intersection(std_out)
                prec = float(len(overlap)) / (len(model_out) + 1e-8)
                reca = float(len(overlap)) / (len(std_out) + 1e-8)
                if len(std_out) == 0:
                    reca = 1. 
                    if len(model_out) == 0:
                        prec = 1.
                f1 = 2 * prec * reca / (prec + reca + 1e-8)
                sent_f1.append(f1)
    tp, fp, fn = corpus_f1  
    prec = tp / (tp + fp)
    recall = tp / (tp + fn)
    corpus_f1 = 2*prec*recall/(prec+recall)*100 if prec+recall > 0 else 0.
    sent_f1 = np.mean(np.array(sent_f1))*100

    elbo_ppl = np.exp((total_nll_recon + total_kl) / num_words)
    recon_ppl = np.exp(total_nll_recon / num_words)
    iwae_ppl = np.exp(total_nll_iwae /num_words)
    kl = total_kl / num_sents  
    print('ElboPPL: %.2f, ReconPPL: %.2f, KL: %.4f, IwaePPL: %.2f, CorpusF1: %.2f, SentAvgF1: %.2f' % 
          (elbo_ppl, recon_ppl, kl, iwae_ppl, corpus_f1, sent_f1))
    #note that corpus F1 printed here is different from what you should get from
    #evalb since we do not ignore any tags (e.g. punctuation), while evalb ignores it
    model.train()
    return iwae_ppl, corpus_f1

In [4]:
np.random.seed(3435)
torch.manual_seed(3435)

print(f"Available memory before loading: {available_memory()} MB")
print(f"Current memory usage before loading: {memory_usage()} MB")

# train_data = Dataset("babylm_data/tokenized/babylm_final_dataset-train.pkl")
# val_data = Dataset("babylm_data/tokenized/babylm_final_dataset-val.pkl")
# test_data = Dataset("babylm_data/tokenized/babylm_final_dataset-test.pkl")

train_data = Dataset("bllip_data/bllip-train.pkl")
val_data = Dataset("bllip_data/bllip-val.pkl")

# train_data = Dataset("data/tokenized_data/ptb-train.pkl")
# val_data = Dataset("data/tokenized_data/ptb-val.pkl")

print(f"Available memory after loading: {available_memory()} MB")
print(f"Current memory usage after loading: {memory_usage()} MB")

vocab_size = int(train_data.vocab_size) 
# vocab_size = int(og_data.vocab_size) # For comparison, comment out if using the tokenized data
print('Train: %d sents / %d batches, Val: %d sents / %d batches' % 
      (train_data.sents.size(0), len(train_data), val_data.sents.size(0), 
       len(val_data)))
print('Vocab size: %d' % vocab_size)
cuda.set_device(0)
count_eos_ppl = 0

Available memory before loading: 1004668.5078125 MB
Current memory usage before loading: 228.51171875 MB
Available memory after loading: 990897.98046875 MB
Current memory usage after loading: 13946.01171875 MB
Train: 1327870 sents / 83063 batches, Val: 165984 sents / 10446 batches
Vocab size: 12000


In [5]:
# test_data[74900][0].shape

In [12]:
val_data[10445][0].shape

torch.Size([1, 174])

In [18]:
# val_data[2644]

In [13]:
val_data[364][1]

153

In [9]:
model = RNNG(vocab = vocab_size,
             w_dim = 650, 
             h_dim = 650,
             dropout = 0.5,
             num_layers = 2,
             q_dim = 256)
if 0.1 > 0:
    for param in model.parameters():    
        param.data.uniform_(-0.1, 0.1)

In [10]:
print("model architecture")
print(model)
q_params = []
action_params = []
model_params = []
for name, param in model.named_parameters():    
    if 'action' in name:
        print(name)
        action_params.append(param)
    elif 'q_' in name:
        print(name)
        q_params.append(param)
    else:
        model_params.append(param)
q_lr = 0.0001
lr = 1
action_lr = 0.1
optimizer = torch.optim.SGD(model_params, lr=lr)
q_optimizer = torch.optim.Adam(q_params, lr=q_lr)
action_optimizer = torch.optim.SGD(action_params, lr=0.1)
model.train()
model.cuda()

model architecture
RNNG(
  (emb): Embedding(12000, 650)
  (dropout): Dropout(p=0.5, inplace=False)
  (stack_rnn): SeqLSTM(
    (linears): ModuleList(
      (0): Linear(in_features=1300, out_features=2600, bias=True)
      (1): Linear(in_features=1300, out_features=2600, bias=True)
    )
    (dropout_layer): Dropout(p=0.5, inplace=False)
  )
  (tree_rnn): TreeLSTM(
    (linear): Linear(in_features=1300, out_features=3250, bias=True)
  )
  (vocab_mlp): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=650, out_features=12000, bias=True)
  )
  (q_binary): Sequential(
    (0): Linear(in_features=512, out_features=512, bias=True)
    (1): ReLU()
    (2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    (3): Dropout(p=0.5, inplace=False)
    (4): Linear(in_features=512, out_features=1, bias=True)
  )
  (action_mlp_p): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=650, out_features=1, bias=True)
  )
  (q_leaf_rnn): LSTM(65

RNNG(
  (emb): Embedding(12000, 650)
  (dropout): Dropout(p=0.5, inplace=False)
  (stack_rnn): SeqLSTM(
    (linears): ModuleList(
      (0): Linear(in_features=1300, out_features=2600, bias=True)
      (1): Linear(in_features=1300, out_features=2600, bias=True)
    )
    (dropout_layer): Dropout(p=0.5, inplace=False)
  )
  (tree_rnn): TreeLSTM(
    (linear): Linear(in_features=1300, out_features=3250, bias=True)
  )
  (vocab_mlp): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=650, out_features=12000, bias=True)
  )
  (q_binary): Sequential(
    (0): Linear(in_features=512, out_features=512, bias=True)
    (1): ReLU()
    (2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    (3): Dropout(p=0.5, inplace=False)
    (4): Linear(in_features=512, out_features=1, bias=True)
  )
  (action_mlp_p): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=650, out_features=1, bias=True)
  )
  (q_leaf_rnn): LSTM(650, 256, batch_first

In [11]:
epoch = 0
decay= 0
kl_warmup = 2
if kl_warmup > 0:
    kl_pen = 0.
    kl_warmup_batch = 1./(kl_warmup * len(train_data))
else:
    kl_pen = 1.
    
best_val_ppl = 5e5
best_val_f1 = 0
samples = 8
mc_samples = 5
mode = "supervised"
num_epochs = 18
train_q_epochs = 2
kl_warmup = 2
max_grad_norm = 5
q_max_grad_norm = 1
print_every = 500
mc_samples = 5
min_epochs = 8
save_path = "debug_train.pt"

In [12]:
# best_val_ppl, best_val_f1 = eval(val_data, model, samples = mc_samples, 
#                                    count_eos_ppl = count_eos_ppl)

In [23]:
def pad_actions_lists(sequences):
    max_seq = max(len(l) for l in sequences)
    padded_sequences = []
    masks = []
    for seq in sequences:
        padded_sequence = seq + [-1]*(max_seq - len(seq))
        mask = [1]*len(seq) + [0]*(max_seq - len(seq))
        padded_sequences.append(padded_sequence)
        masks.append(mask)
    return torch.tensor(padded_sequences), torch.tensor(masks)

In [11]:
tokenizer = ByteLevelBPETokenizer("tokenizers/rnng/vocab.json", "tokenizers/rnng/merges.txt")

In [11]:
trained_rnng = torch.load("saved_models/tokenized_rnng.pt")

In [85]:
trained_rnng = trained_rnng['model'].cuda()

In [109]:
sents, length, batch_size, gold_actions, gold_spans, gold_binary_trees, other_data = train_data[100]

In [12]:
def wordify(sent):
    solo_tokens = ["<s>", "</s", "<unk>", "<pad>", "<mask>", "'s", "'d",
                   "'re", "$", "%", "*", ":", "'", "--", "`", ";", "&"]
    sentence = []
    new_word = ""
    for i in range(len(sent)): 
        token = train_data.idx2word[sent[i].item()]
#         print(token)
        if token == "<s>":
            sentence.append(token)
            continue
        elif token == "</s>":
            sentence.append(new_word)
            sentence.append(token)
            break

        if token[0] == "Ġ":
            if new_word != "":
                sentence.append(new_word)
            new_word = token
        elif token not in solo_tokens:
            new_word = new_word + token
        elif token in solo_tokens:
            if new_word != "":
                sentence.append(new_word)
            new_word = token
    return sentence

In [13]:
all_wrongs = []
for j in range(len(train_data)):
    wrong = []
    sents, length, batch_size, gold_actions, gold_spans, gold_binary_trees, other_data = train_data[j]
    for i in range(len(sents)):
        sent = sents[i]
#         print(i, ": Length of sentence: ", len(sent), "\tLength of actions: ", len(gold_binary_trees[i]))
        if (((len(sent)-2))*2) - 1 != len(gold_binary_trees[i]):
            wrong.append(i)
    all_wrongs.append(wrong)

In [13]:
all_stats = [[0., 0., 0.]] #true pos, false pos, false neg for f1 calc

while epoch < num_epochs:
    start_time = time.time()
    epoch += 1  
    if epoch > train_q_epochs:
        #stop training q after this many epochs
        q_lr = 0.
        for param_group in q_optimizer.param_groups:
            param_group['lr'] = q_lr
    print('Starting epoch %d' % epoch)
    train_nll_recon = 0.
    train_nll_iwae = 0.
    train_kl = 0.
    train_q_entropy = 0.
    num_sents = 0.
    num_words = 0.
    b = 0
    
    for i in np.random.permutation(len(train_data)):
#     for i in list(reversed(range(len(train_data)))):
        if kl_warmup > 0:
            kl_pen = min(1., kl_pen + kl_warmup_batch)
        sents, length, batch_size, gold_actions, gold_spans, gold_binary_trees, other_data = train_data[i]

        # This is the data processed at the word level, not the token level. Using it for comparison
        
        if length == 1:
            continue
        sents = sents.cuda()
        b += 1
        q_optimizer.zero_grad()
        optimizer.zero_grad()
        action_optimizer.zero_grad()
        if mode == 'unsupervised':
            ll_word, ll_action_p, ll_action_q, all_actions, q_entropy = model(sents, samples=samples, 
                                                                              has_eos = True)
            print(f"Batch: {b}\n")
            log_f = ll_word + kl_pen*ll_action_p
            iwae_ll = log_f.mean(1).detach() + kl_pen*q_entropy.detach()
            obj = log_f.mean(1)
            if epoch < train_q_epochs:
                obj += kl_pen*q_entropy
                baseline = torch.zeros_like(log_f)
                baseline_k = torch.zeros_like(log_f)
                for k in range(samples):
                    baseline_k.copy_(log_f)
                    baseline_k[:, k].fill_(0)
                    baseline[:, k] =  baseline_k.detach().sum(1) / (samples - 1)        
                obj += ((log_f.detach() - baseline.detach())*ll_action_q).mean(1)
            kl = (ll_action_q - ll_action_p).mean(1).detach()
            ll_word = ll_word.mean(1)
            train_q_entropy += q_entropy.sum().item()
        else:
            gold_actions = gold_binary_trees
            
            # Testing against the data that works
#             sents = comparison[0]
#             sents = sents.cuda()
#             gold_actions = comparison[5] # This is og data gold_binary_tress
            
            ll_action_q = model.forward_tree(sents, gold_actions, has_eos=True)        
            ll_word, ll_action_p, all_actions = model.forward_actions(sents, gold_actions)
            print("Made it through both forward methods!!")
            
            obj = ll_word + ll_action_p + ll_action_q
            kl = -ll_action_q
            iwae_ll = ll_word + ll_action_p
            print("Batch: ", b)
        
        train_nll_iwae += -iwae_ll.sum().item()
        actions = all_actions[:, 0].long().cpu()
        train_nll_recon += -ll_word.sum().item()
        train_kl += kl.sum().item()
        (-obj.mean()).backward()
        
        if max_grad_norm > 0:
            torch.nn.utils.clip_grad_norm_(model_params + action_params, max_grad_norm)        
        if q_max_grad_norm > 0:
            torch.nn.utils.clip_grad_norm_(q_params, q_max_grad_norm)        
        q_optimizer.step()
        optimizer.step()
        action_optimizer.step()
        num_sents += batch_size
        num_words += batch_size * length
        for bb in range(batch_size):
            action = list(actions[bb].numpy())
            span_b = get_spans(action)
            span_b_set = set(span_b[:-1]) #ignore the sentence-level trivial span
            update_stats(span_b_set, [set(gold_spans[bb][:-1])], all_stats)
        if b % print_every == 0:
            all_f1 = get_f1(all_stats)
            param_norm = sum([p.norm()**2 for p in model.parameters()]).item()**0.5
            log_str = 'Epoch: %d, Batch: %d/%d, LR: %.4f, qLR: %.5f, qEnt: %.4f, TrainVAEPPL: %.2f, ' + \
                  'TrainReconPPL: %.2f, TrainKL: %.2f, TrainIWAEPPL: %.2f, ' + \
                  '|Param|: %.2f, BestValPerf: %.2f, BestValF1: %.2f, KLPen: %.4f, ' + \
                  'GoldTreeF1: %.2f, Throughput: %.2f examples/sec'
            print(log_str %
                  (epoch, b, len(train_data), lr, q_lr, train_q_entropy / num_sents, 
                   np.exp((train_nll_recon + train_kl)/ num_words),
                   np.exp(train_nll_recon/num_words), train_kl / num_sents, 
                   np.exp(train_nll_iwae/num_words),
                   param_norm, best_val_ppl, best_val_f1, kl_pen, 
                   all_f1[0], num_sents / (time.time() - start_time)))
            sent_str = [train_data.idx2word[word_idx] for word_idx in list(sents[-1][1:-1].cpu().numpy())]
            print("PRED:", get_tree(action[:-2], sent_str))
            print("GOLD:", get_tree(gold_binary_trees[-1], sent_str))
        
    print('--------------------------------')
    print('Checking validation perf...')    
    val_ppl, val_f1 = eval(val_data, model, 
                           samples = mc_samples, count_eos_ppl = count_eos_ppl)
#         log_progress(epoch=epoch, criterion=val_ppl, lr=lr)
    print('--------------------------------')
    if val_ppl < best_val_ppl:
        best_val_ppl = val_ppl
        best_val_f1 = val_f1
        checkpoint = {
#                 'args': args.__dict__,
            'model': model.cpu(),
            'word2idx': train_data.word2idx,
            'idx2word': train_data.idx2word
        }
        print('Saving checkpoint to %s' % save_path)
        torch.save(checkpoint, save_path)
        model.cuda()
    else:
        if epoch > min_epochs:
            decay = 1
    if decay == 1:
        lr = decay*lr
        q_lr = decay*q_lr
        action_lr = decay*action_lr
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        for param_group in q_optimizer.param_groups:
            param_group['lr'] = q_lr
        for param_group in action_optimizer.param_groups:
            param_group['lr'] = action_lr
    if lr < 0.03:
        break
print("Finished training!")

Starting epoch 1
Made it through both forward methods!!
Batch:  1
Made it through both forward methods!!
Batch:  2
Made it through both forward methods!!
Batch:  3
Made it through both forward methods!!
Batch:  4
Made it through both forward methods!!
Batch:  5
Made it through both forward methods!!
Batch:  6
Made it through both forward methods!!
Batch:  7


KeyboardInterrupt: 