In [1]:
# This file will implement the main training loop for a model
from time import time
import argparse
import os

from torch import device
import torch
from torch import optim
import numpy as np

from disentanglement_final.data_prep import NLIGenData2, OntoGenData, HuggingYelp2, ParaNMTCuratedData
from disentanglement_final.models import DisentanglementTransformerVAE, LaggingDisentanglementTransformerVAE
from disentanglement_final.h_params import DefaultTransformerHParams as HParams
from disentanglement_final.graphs import *
from components.criteria import *
from torch.nn import MultiheadAttention


def _str2bool(value):
    """Convert a command-line token such as "true", "0" or "no" into a bool.

    Passing ``type=bool`` to argparse does not work as expected: ``bool("False")``
    is ``True`` because every non-empty string is truthy, so such a flag could
    never be turned off from the command line.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("Expected a boolean value, got {!r}".format(value))


parser = argparse.ArgumentParser()
# Training and Optimization
# Multipliers applied to the base layer sizes used in the defaults below.
k, kz, klstm = 2, 4, 2
parser.add_argument("--test_name", default='unnamed', type=str)
parser.add_argument("--data", default='nli', choices=["nli", "ontonotes", "yelp", 'paranmt'], type=str)
parser.add_argument("--csv_out", default='disentqkv.csv', type=str)
parser.add_argument("--max_len", default=17, type=int)
parser.add_argument("--batch_size", default=128, type=int)
parser.add_argument("--grad_accu", default=1, type=int)
parser.add_argument("--n_epochs", default=20, type=int)
parser.add_argument("--test_freq", default=32, type=int)
parser.add_argument("--complete_test_freq", default=160, type=int)
parser.add_argument("--generation_weight", default=1, type=float)
parser.add_argument("--device", default='cuda:0', choices=["cuda:0", "cuda:1", "cuda:2", "cpu"], type=str)
parser.add_argument("--embedding_dim", default=128, type=int)
parser.add_argument("--pretrained_embeddings", default=False, type=_str2bool)
parser.add_argument("--z_size", default=96*kz, type=int)
parser.add_argument("--z_emb_dim", default=192*k, type=int)
parser.add_argument("--n_keys", default=4, type=int)
parser.add_argument("--n_latents", default=[16, 16, 16], nargs='+', type=int)
parser.add_argument("--text_rep_l", default=3, type=int)
parser.add_argument("--text_rep_h", default=192*k, type=int)
parser.add_argument("--encoder_h", default=192*k, type=int)
parser.add_argument("--encoder_l", default=2, type=int)
parser.add_argument("--decoder_h", default=int(192*k), type=int)
parser.add_argument("--decoder_l", default=2, type=int)
parser.add_argument("--highway", default=False, type=_str2bool)
parser.add_argument("--markovian", default=True, type=_str2bool)
parser.add_argument('--minimal_enc', dest='minimal_enc', action='store_true')
parser.add_argument('--no-minimal_enc', dest='minimal_enc', action='store_false')
parser.set_defaults(minimal_enc=False)
# The choices need a comma between every entry: adjacent string literals are
# concatenated by Python, which would silently create a single "IWAELagVAE" choice.
parser.add_argument("--losses", default='VAE', choices=["VAE", "IWAE", "LagVAE"], type=str)
# The default has to be one of the choices, since argparse does not validate
# defaults and the value is later used as a key into the graph table.
# NOTE(review): "Vanilla" is assumed to be the intended default -- confirm.
parser.add_argument("--graph", default='Vanilla', choices=["Vanilla", "IndepInfer", "QKV", "HQKV"], type=str)
parser.add_argument("--training_iw_samples", default=1, type=int)
parser.add_argument("--testing_iw_samples", default=5, type=int)
parser.add_argument("--test_prior_samples", default=10, type=int)
parser.add_argument("--anneal_kl0", default=3000, type=int)
parser.add_argument("--anneal_kl1", default=6000, type=int)
parser.add_argument("--anneal_kl_type", default="linear", choices=["linear", "sigmoid"], type=str)
parser.add_argument("--grad_clip", default=5., type=float)
# `float or None` evaluates to plain `float`, so spell it that way.
parser.add_argument("--kl_th", default=0/(768*k/2), type=float)
parser.add_argument("--max_elbo1", default=6.0, type=float)
parser.add_argument("--max_elbo2", default=4.0, type=float)
parser.add_argument("--max_elbo_choice", default=10, type=int)
parser.add_argument("--kl_beta", default=0.5, type=float)
parser.add_argument("--kl_beta_zs", default=0.1, type=float)
parser.add_argument("--kl_beta_zg", default=0.5/8, type=float)
parser.add_argument("--dropout", default=0.3, type=float)
parser.add_argument("--word_dropout", default=0.1, type=float)
parser.add_argument("--l2_reg", default=0, type=float)
parser.add_argument("--lr", default=2e-4, type=float)
parser.add_argument("--lr_reduction", default=4., type=float)
parser.add_argument("--wait_epochs", default=1, type=float)
parser.add_argument("--save_all", default=True, type=_str2bool)

# parse_known_args ignores arguments it does not recognise (e.g. the ones a
# notebook kernel is launched with) instead of aborting.
flags, _ = parser.parse_known_args()

# Manual Settings, Deactivate before pushing
# These values overwrite whatever was parsed from the command line.
if True:
    vars(flags).update(
        # data / run identification
        test_name="nliLM/HQKVParanmtMini5",
        data="paranmt",
        batch_size=128,
        grad_accu=1,
        max_len=20,
        # latent structure ("Vanilla" is the other graph that was tried here)
        graph="QKV",
        n_latents=[16],
        n_keys=16,
        # KL weighting and annealing
        kl_beta=0.4,
        kl_beta_zg=0.1,
        kl_beta_zs=0.1,
        anneal_kl_type="sigmoid",
        max_elbo_choice=6,
        # regularisation
        word_dropout=0.4,
    )
    # Alternatives kept for reference (currently disabled):
    #   losses = "LagVAE"
    #   anneal_kl0, anneal_kl1 = 3900, 6900
    #   anneal_kl0 = 0
    #   z_size = 16
    #   encoder_h = 256
    #   decoder_h = 256


# torch.autograd.set_detect_anomaly(True)
# Graph-building function selected by --graph; an unknown name raises KeyError here.
GRAPH = {"Vanilla": get_vanilla_graph,
         "IndepInfer": get_structured_auto_regressive_indep_graph,
         "QKV": get_qkv_graph,
         "HQKV": get_hqkv_graph}[flags.graph]
if flags.graph == "Vanilla":
    # For the vanilla graph, n_latents is replaced by a single entry equal to z_size.
    flags.n_latents = [flags.z_size]
if flags.losses == "LagVAE":
    # KL annealing is switched off for the lagging-inference variant.
    flags.anneal_kl0 = 0
    flags.anneal_kl1 = 0
# Dataset class selected by --data.
Data = {"nli": NLIGenData2, "ontonotes": OntoGenData, "yelp": HuggingYelp2, "paranmt": ParaNMTCuratedData}[flags.data]
MAX_LEN = flags.max_len
BATCH_SIZE = flags.batch_size
GRAD_ACCU = flags.grad_accu
N_EPOCHS = flags.n_epochs
TEST_FREQ = flags.test_freq
COMPLETE_TEST_FREQ = flags.complete_test_freq
DEVICE = device(flags.device)
# This prevents illegal memory access on multigpu machines (unresolved issue on torch's github)
# The torch.device object is passed directly instead of parsing the last
# character of the flag string, which only works for single-digit indices.
if DEVICE.type == 'cuda':
    torch.cuda.set_device(DEVICE)
# Loss classes per --losses value; an unknown name raises KeyError here.
LOSSES = {'IWAE': [IWLBo],
          'VAE': [ELBo],
          'LagVAE': [ELBo]}[flags.losses]

# Annealing boundaries are scaled by the number of gradient-accumulation steps.
ANNEAL_KL = [flags.anneal_kl0*flags.grad_accu, flags.anneal_kl1*flags.grad_accu]
LOSS_PARAMS = [1]
if flags.grad_accu > 1:
    # Divide the loss weights so that accumulated gradients keep the same scale.
    LOSS_PARAMS = [w/flags.grad_accu for w in LOSS_PARAMS]


# Dataset wrapper selected earlier through --data.
data = Data(MAX_LEN, BATCH_SIZE, N_EPOCHS, DEVICE, pretrained=flags.pretrained_embeddings)

# Collect every model / optimisation hyper-parameter in a single HParams object.
h_params = HParams(
    len(data.vocab.itos),
    len(data.tags.itos) if flags.data == 'yelp' else None,
    MAX_LEN,
    BATCH_SIZE,
    N_EPOCHS,
    device=DEVICE,
    vocab_ignore_index=data.vocab.stoi['<pad>'],
    # network sizes
    decoder_h=flags.decoder_h,
    decoder_l=flags.decoder_l,
    encoder_h=flags.encoder_h,
    encoder_l=flags.encoder_l,
    text_rep_h=flags.text_rep_h,
    text_rep_l=flags.text_rep_l,
    test_name=flags.test_name,
    grad_accumulation_steps=GRAD_ACCU,
    # optimiser
    optimizer_kwargs={
        'lr': flags.lr,
        'weight_decay': flags.l2_reg,
        'betas': (0.9, 0.99),
    },
    is_weighted=[],
    graph_generator=GRAPH,
    z_size=flags.z_size,
    embedding_dim=flags.embedding_dim,
    anneal_kl=ANNEAL_KL,
    grad_clip=flags.grad_clip*flags.grad_accu,
    kl_th=flags.kl_th,
    highway=flags.highway,
    losses=LOSSES,
    dropout=flags.dropout,
    training_iw_samples=flags.training_iw_samples,
    testing_iw_samples=flags.testing_iw_samples,
    loss_params=LOSS_PARAMS,
    optimizer=optim.AdamW,
    markovian=flags.markovian,
    word_dropout=flags.word_dropout,
    contiguous_lm=False,
    test_prior_samples=flags.test_prior_samples,
    n_latents=flags.n_latents,
    n_keys=flags.n_keys,
    max_elbo=[flags.max_elbo_choice, flags.max_elbo1],  # max_elbo is paper's beta
    z_emb_dim=flags.z_emb_dim,
    minimal_enc=flags.minimal_enc,
    kl_beta=flags.kl_beta,
    kl_beta_zs=flags.kl_beta_zs,
    kl_beta_zg=flags.kl_beta_zg,
    anneal_kl_type=flags.anneal_kl_type,
)
val_iterator = iter(data.val_iter)
print("Words: ", len(data.vocab.itos), ", On device: ", DEVICE.type)
print("Loss Type: ", flags.losses)

# Only the lagging variant receives the extra encoder-training iterator.
if flags.losses != 'LagVAE':
    model = DisentanglementTransformerVAE(data.vocab, data.tags, h_params, wvs=data.wvs, dataset=flags.data)
else:
    model = LaggingDisentanglementTransformerVAE(data.vocab, data.tags, h_params, wvs=data.wvs,
                                                 dataset=flags.data, enc_iter=data.enc_train_iter)
if DEVICE.type == 'cuda':
    model.cuda(DEVICE)

# Sample counts are estimated from the number of batches in each iterator.
total_unsupervised_train_samples = len(data.train_iter)*BATCH_SIZE
total_unsupervised_val_samples = len(data.val_iter)*(BATCH_SIZE/data.divide_bs)
print("Unsupervised training examples: ", total_unsupervised_train_samples)
print("Unsupervised val examples: ", total_unsupervised_val_samples)
current_time = time()

# Trainable parameter counts (in millions) for the whole model and its parts.
for _label, _part in (("Number of parameters: ", model),
                      ("Inference parameters: ", model.infer_bn),
                      ("Generation parameters: ", model.gen_bn),
                      ("Embedding parameters: ", model.word_embeddings)):
    number_parameters = sum(p.numel() for p in _part.parameters() if p.requires_grad)
    print(_label, "{0:05.2f} M".format(number_parameters/1e6))
model.eval()

Using custom data configuration default-b97e0f2e99bae8b9


Reusing dataset csv (C:\Users\ghazy\.cache\huggingface\datasets\csv\default-b97e0f2e99bae8b9\0.0.0\2dc6629a9ff6b5697d82c25b73731dd440507a69cbce8b425db50b751e8fcfd0)


Words:  10000 , On device:  cuda
Loss Type:  VAE


reconstruction net size: 09.71 M
prior net sizes:


Loaded model at step 84766
Unsupervised training examples:  493056
Unsupervised val examples:  384.0
Number of parameters:  28.47 M
Inference parameters:  20.04 M
Generation parameters:  09.71 M
Embedding parameters:  01.28 M


In [71]:

from datasets import load_metric
from tqdm import tqdm

# Bound `compute` method of the "bleu" metric; called below with
# `predictions=` / `references=` keyword arguments and indexed with ['bleu'].
bleu_score = load_metric("bleu").compute

def get_paraphrase_bleu(model, iterator):
    """Score greedy decodes obtained by swapping latents between a sentence and its paraphrase.

    For every batch the posterior locations of ``zs`` and of each ``z<i>``
    variable are inferred for ``batch.text`` and ``batch.para`` (first token
    dropped). Three latent combinations are then decoded greedily:

      1. ``zs`` and ``z<i>`` both from ``batch.text``  -> kept as ``rec``
      2. ``zs`` from ``batch.text``, ``z<i>`` from ``batch.para`` -> kept as ``para_mod``
      3. ``zs`` from ``batch.para``, ``z<i>`` from ``batch.text`` -> kept as ``orig_mod``

    Args:
        model: model exposing ``infer_bn``, ``gen_bn``, ``generated_v``, ``index``,
            ``h_params`` and ``decode_to_text2`` as used below.
        iterator: iterable of batches with ``text`` and ``para`` attributes.

    Returns:
        Tuple ``(orig_mod_bleu, para_mod_bleu, rec_bleu)``: BLEU of ``orig_mod``
        against the paraphrases, of ``para_mod`` against the originals, and of
        ``rec`` against the originals.
    """
    with torch.no_grad():
        latent_names = ['z{}'.format(n + 1) for n in range(len(model.h_params.n_latents))]
        zs_infer = model.infer_bn.name_to_v['zs']
        z_infer = {name: model.infer_bn.name_to_v[name] for name in latent_names}
        x_gen = model.gen_bn.name_to_v['x']

        go_id = model.index[model.generated_v].stoi['<go>']
        go_symbol = (torch.ones((1, 1)).long() * go_id).to(model.h_params.device)

        def _posterior_locs(tokens):
            # Run inference on `tokens` and read the posterior 'loc' at position 0.
            model.infer_bn({'x': tokens})
            zs_loc = zs_infer.post_params['loc'][..., 0, :]
            z_locs = {name: var.post_params['loc'][..., 0, :] for name, var in z_infer.items()}
            return zs_loc, z_locs

        orig, para, orig_mod, para_mod, rec = [], [], [], [], []
        for batch in tqdm(iterator, desc="Getting Model Paraphrase Bleu stats"):
            # Nothing is left to encode once the first token is removed.
            if batch.text.shape[1] < 2:
                continue

            # get source and target sentence latent variable values
            orig_zs, orig_z = _posterior_locs(batch.text[..., 1:])
            para_zs, para_z = _posterior_locs(batch.para[..., 1:])

            # Stack the three latent combinations along the batch axis.
            z_input = {'zs': torch.cat([orig_zs, orig_zs, para_zs]).unsqueeze(1)}
            for name in para_z.keys():
                z_input[name] = torch.cat([orig_z[name], para_z[name], orig_z[name]]).unsqueeze(1)

            # Greedy decoding: append the arg-max token of the newest position each step.
            x_prev = go_symbol.repeat((para_zs.shape[0]*3, 1))
            for step in range(model.h_params.max_len):
                z_step = {name: v.expand(v.shape[0], step + 1, v.shape[-1])
                          for name, v in z_input.items()}
                model.gen_bn({'x_prev': x_prev, **z_step}, target=x_gen)
                logits = model.generated_v.post_params['logits']
                next_token = torch.argmax(logits, dim=-1)[..., -1].unsqueeze(-1)
                x_prev = torch.cat([x_prev, next_token], dim=-1)

            # store original sentences, the 2 resulting "paraphrases", and the reconstruction of the original
            text = model.decode_to_text2(x_prev, model.h_params.vocab_size,
                                         model.index[model.generated_v])
            first_cut, second_cut = len(text) // 3, len(text) * 2 // 3
            rec_i = text[:first_cut]
            para_mod_i = text[first_cut:second_cut]
            orig_mod_i = text[second_cut:]
            orig_i = model.decode_to_text2(batch.text[..., 1:], model.h_params.vocab_size,
                                           model.index[model.generated_v])
            para_i = model.decode_to_text2(batch.para[..., 1:], model.h_params.vocab_size,
                                           model.index[model.generated_v])

            # References are wrapped in a one-element list (one reference per prediction).
            orig += [[sentence.split()] for sentence in orig_i]
            para += [[sentence.split()] for sentence in para_i]
            orig_mod += [sentence.split() for sentence in orig_mod_i]
            para_mod += [sentence.split() for sentence in para_mod_i]
            rec += [sentence.split() for sentence in rec_i]

        # Calculate the 3 bleu scores
        orig_mod_bleu = bleu_score(predictions=orig_mod, references=para)['bleu']
        para_mod_bleu = bleu_score(predictions=para_mod, references=orig)['bleu']
        rec_bleu = bleu_score(predictions=rec, references=orig)['bleu']

        return orig_mod_bleu, para_mod_bleu, rec_bleu
    
def get_reconstructions(model, sens, beam_size=1, mask_unk=False):
    """Encode sentences to their posterior locations and decode them again.

    Args:
        model: model exposing ``infer_bn``, ``generated_v``, ``index``,
            ``h_params`` and ``decode_to_text2`` as used below.
        sens: token tensor fed to the inference network as ``'x'``.
            NOTE(review): ``get_paraphrase_bleu`` drops the first token before
            inference while this function uses ``sens`` unchanged -- confirm
            which form callers are expected to pass.
        beam_size: forwarded to ``generate_from_z``; when larger than 1 only
            the first ``rows / beam_size`` rows of its output are kept.
        mask_unk: forwarded to ``generate_from_z``.

    Returns:
        ``(rec, orig)``: decoded reconstructions and decoded inputs, both as
        returned by ``model.decode_to_text2``.
    """
    with torch.no_grad():
        inferred = model.infer_bn.name_to_v
        zs_infer = inferred['zs']
        z_infer = {'z{}'.format(n + 1): inferred['z{}'.format(n + 1)]
                   for n in range(len(model.h_params.n_latents))}
        vocab_index = model.index[model.generated_v]

        go_symbol = (torch.ones((1, 1)).long() * vocab_index.stoi['<go>']).to(model.h_params.device)

        # Posterior 'loc' at position 0 of every latent variable.
        model.infer_bn({'x': sens})
        orig_zs = zs_infer.post_params['loc'][..., 0, :]
        orig_z = {name: var.post_params['loc'][..., 0, :] for name, var in z_infer.items()}

        z_input = {'zs': torch.cat([orig_zs]).unsqueeze(1)}
        for name in orig_z.keys():
            z_input[name] = torch.cat([orig_z[name]]).unsqueeze(1)

        # One <go> token per sentence starts the decoding.
        x_prev = generate_from_z(model, z_input, go_symbol.repeat((orig_zs.shape[0], 1)),
                                 only_z_sampling=True, temp=1.0, mask_unk=mask_unk,
                                 beam_size=beam_size)
        if beam_size > 1:
            # Keep only the leading block of rows out of the beam_size blocks returned.
            x_prev = x_prev[:x_prev.shape[0] // beam_size]

        rec = model.decode_to_text2(x_prev, model.h_params.vocab_size, vocab_index)
        orig = model.decode_to_text2(sens, model.h_params.vocab_size, vocab_index)

        return rec, orig
    

def generate_from_z(model, z_input, x_prev, gen_len=None, only_z_sampling=True, temp=1.0,
                    mask_unk=True, beam_size=1):
    """Greedily decode a token sequence conditioned on fixed latent values.

    Note: a later cell of this notebook redefines ``generate_from_z`` with a
    beam-search implementation; this version only supports greedy decoding.

    Args:
        model: model exposing ``gen_bn``, ``generated_v``, ``index`` and
            ``h_params`` as used below.
        z_input: dict of latent tensors; each one is expanded along dim 1 to
            the current sequence length, so that dim must have size 1.
        x_prev: tensor holding the tokens generated so far (one row per
            sequence); new tokens are appended along the last dim.
        gen_len: number of tokens to generate; ``model.h_params.max_len`` when
            ``None`` or 0.
        only_z_sampling: when True the arg-max of the logits is taken,
            otherwise the arg-max of a sample drawn from
            ``model.generated_v.posterior`` at temperature ``temp``.
        temp: temperature used when ``only_z_sampling`` is False.
        mask_unk: when True the ``<unk>`` token can never be selected.
        beam_size: accepted so that ``get_reconstructions`` can call this
            function; only the value 1 is supported here.

    Returns:
        ``x_prev`` with the generated tokens appended.

    Raises:
        NotImplementedError: if ``beam_size`` is not 1.
    """
    if beam_size != 1:
        raise NotImplementedError("This greedy decoder only supports beam_size=1, "
                                  "got {}".format(beam_size))

    # True for the vocabulary entries that must never be produced.
    banned = torch.zeros(model.h_params.vocab_size, dtype=torch.bool,
                         device=model.h_params.device)
    if mask_unk:
        banned[model.index[model.generated_v].stoi['<unk>']] = True

    for i in range(gen_len or model.h_params.max_len):
        model.gen_bn({'x_prev': x_prev, **{k: v.expand(v.shape[0], i + 1, v.shape[-1])
                                          for k, v in z_input.items()}})
        if only_z_sampling:
            samples_i = model.generated_v.post_params['logits']
        else:
            samples_i = model.generated_v.posterior(logits=model.generated_v.post_params['logits'],
                                                   temperature=temp).rsample()
        # Only the newest position matters. Banned tokens get -inf rather than
        # being multiplied by zero: a zeroed score still wins the arg-max
        # whenever every other score is negative.
        last_scores = samples_i[..., -1, :].masked_fill(banned, float('-inf'))
        x_prev = torch.cat([x_prev, torch.argmax(last_scores, dim=-1).unsqueeze(-1)],
                           dim=-1)
    return x_prev


In [3]:
# Take the first batch of the validation iterator; later cells slice
# `sens.text` to pick the sentences that are reconstructed.
sens = next(iter(data.val_iter))
print(sens.text.shape)



torch.Size([128, 20])


In [85]:


def generate_from_z(model, z_input, x_prev, gen_len=None, only_z_sampling=True, temp=1.0,
                    mask_unk=True, beam_size=1):
    """Decode token sequences from fixed latent values, greedily or with beam search.

    Args:
        model: model exposing ``gen_bn``, ``generated_v``, ``index`` and
            ``h_params`` as used below.
        z_input: dict of latent tensors; each one is expanded along its
            second-to-last dim to the current sequence length, so that dim
            must have size 1.
        x_prev: 2-D tensor of start tokens, one row per sequence.
        gen_len: number of decoding steps; ``model.h_params.max_len`` when
            ``None`` or 0.
        only_z_sampling: when True the logits are used as scores, otherwise a
            sample drawn from ``model.generated_v.posterior`` at temperature
            ``temp`` is used.
        temp: temperature used when ``only_z_sampling`` is False.
        mask_unk: when True the ``<unk>`` token can never be selected.
        beam_size: number of hypotheses kept per sequence; 1 means greedy.

    Returns:
        For ``beam_size == 1`` the input rows with the generated tokens
        appended. Otherwise a 2-D tensor with ``beam_size`` consecutive blocks
        of rows; the first block holds, for every sequence, the hypothesis
        with the highest accumulated score.
    """
    stoi = model.index[model.generated_v].stoi
    # A hypothesis is considered finished once it contains any of these tokens.
    eos_idx = (stoi["?"], stoi["!"], stoi["."], stoi["<eos>"])
    vocab_size = model.h_params.vocab_size

    # True for the vocabulary entries that must never be produced.
    banned = torch.zeros(vocab_size, dtype=torch.bool, device=model.h_params.device)
    if mask_unk:
        banned[stoi['<unk>']] = True

    n_sequences = x_prev.shape[0]
    ended = [False] * n_sequences
    # Accumulated score of every hypothesis, indexed as [beam, sequence].
    seq_scores = torch.zeros(beam_size, n_sequences, device=x_prev.device)
    if beam_size > 1:
        # Add a leading beam dimension to the latents and to the token prefix.
        z_input = {k: v.unsqueeze(0).expand(beam_size, v.shape[0], 1, v.shape[-1])
                   for k, v in z_input.items()}
        x_prev = x_prev.unsqueeze(0).expand(beam_size, *x_prev.shape)

    for i in range(gen_len or model.h_params.max_len):
        if beam_size == 1:
            z_i = {k: v.expand(v.shape[0], i + 1, v.shape[-1])
                   for k, v in z_input.items()}
        else:
            z_i = {k: v.expand(beam_size, v.shape[1], i + 1, v.shape[-1])
                   for k, v in z_input.items()}
        model.gen_bn({'x_prev': x_prev, **z_i})
        if only_z_sampling:
            samples_i = model.generated_v.post_params['logits']
        else:
            samples_i = model.generated_v.posterior(logits=model.generated_v.post_params['logits'],
                                                   temperature=temp).rsample()
        # Scores of the newest position only. Banned tokens get -inf rather
        # than being multiplied by zero: a zeroed score still wins the arg-max
        # or top-k whenever the remaining scores are negative.
        step_scores = samples_i[..., -1, :].masked_fill(banned, float('-inf'))

        if beam_size == 1:
            x_prev = torch.cat([x_prev, torch.argmax(step_scores, dim=-1).unsqueeze(-1)],
                               dim=-1)
            continue

        next_xprev = torch.zeros((x_prev.shape[0],
                                  x_prev.shape[1], x_prev.shape[2]+1)).long().to(x_prev.device)
        for j in range(x_prev.shape[1]):
            # Finished sequences are only padded with <eos> in every beam.
            if ended[j] or any(tok in eos_idx for tok in x_prev[0, j].tolist()):
                ended[j] = True
                padding = torch.full_like(x_prev[:, j, -1:], eos_idx[-1])
                next_xprev[:, j] = torch.cat([x_prev[:, j], padding], dim=-1)
                continue
            if i == 0:
                # All beams are identical at the first step, so only beam 0 is expanded.
                candidates = step_scores[0, j].reshape(-1)
            else:
                # Flattened (beam, vocab) table of accumulated scores.
                candidates = (step_scores[:, j] + seq_scores[:, j].unsqueeze(-1)).reshape(-1)

            tk = torch.topk(candidates, k=beam_size, dim=-1)
            # Recover which beam and which word each flat index refers to.
            b_idx, w_idx = tk.indices.floor_divide(vocab_size), tk.indices % vocab_size
            seq_scores[:, j] = seq_scores[b_idx, j] + tk.values
            next_xprev[:, j] = torch.cat([x_prev[b_idx, j], w_idx.unsqueeze(-1)], dim=-1)
        x_prev = next_xprev

    if beam_size > 1:
        # reshape (not view) also works when x_prev is still an expanded tensor.
        x_prev = x_prev.reshape(x_prev.shape[0]*x_prev.shape[1], x_prev.shape[2])
    return x_prev
# Reconstruct a few validation sentences with greedy decoding and two beam
# widths, then print every source next to its three reconstructions.
start, end, bs = 40, 45, 5
rec1, orig = get_reconstructions(model, sens.text[start:end], beam_size=1, mask_unk=False)
rec2, orig = get_reconstructions(model, sens.text[start:end], beam_size=bs, mask_unk=False)
rec3, orig = get_reconstructions(model, sens.text[start:end], beam_size=100, mask_unk=False)
print("=========================")
for source, greedy, narrow_beam, wide_beam in zip(orig, rec1, rec2, rec3):
    print(source, '===>', '->'.join([greedy, narrow_beam, wide_beam]))

 you ' re afraid you ' re gon na get caught  ===>  i ' m not gon ' re gon na get out -> i ' m gon na do n ' t get any chance -> i ' m gon na i ' m gon na get out 
 our theme today will be mental health  ===>  the <unk> <unk> , will be to <unk> -> the <unk> of <unk> will be to change -> the <unk> of <unk> will be to change 
 it is not gon na be very easy  ===>  it ' s not gon na be easy easy -> it ' s not gon na be easy easy -> it ' s not gon na be easy easy 
 looks like you got our <unk> on your religion  ===>  i ' m not going to your own your ass -> i ' m gon na make your own your ass -> i ' m gon na make your own your ass 
 did she open the door and leave no prints here  ===>  what ' s the <unk> of the <unk> to leave -> you ' ve got the <unk> on here to the doors -> you ' ve got the <unk> on here to the doors 


In [13]:

# Scratch check: element-wise remainder on a tensor (including a negative
# entry) next to Python's true and floor division.
a = torch.Tensor([
    [[1, 2, 3], [4, 5, 6]],
    [[0, 4, -5], [3, 9, 1]],
])
print(a % 3)
print(5 / 3, 5 // 3)

tensor([[[1., 2., 0.],
         [1., 2., 0.]],

        [[0., 1., 1.],
         [0., 0., 1.]]])
1.6666666666666667 1
