In [8]:
# This file will implement the main training loop for a model
from time import time
import sys
import os
import argparse
# Print the kernel's starting directory, then add the directory two levels up to the
# import path and make it the working directory (presumably the project root that holds
# the data_prep / disentanglement_transformer / components packages imported below).
# NOTE(review): both paths are relative to the current working directory, so re-running
# this cell in the same kernel climbs two more levels each time; run it once per kernel.
print(os.getcwd())
sys.path.append('../..')
os.chdir('../..')

from torch import device
import torch
from torch import optim
import numpy as np

from data_prep import NLIGenData2 as Data
from disentanglement_transformer.models import DisentanglementTransformerVAE as Model
from disentanglement_transformer.h_params import DefaultTransformerHParams as HParams
from disentanglement_transformer.graphs import *
from components.criteria import *
def _str2bool(value):
    """Convert a command-line string into a real boolean.

    `type=bool` must not be used with argparse: `bool("False")` is `True` because any
    non-empty string is truthy, so a flag could never be switched off from the command line.

    :param value: the raw argument string (or an actual bool).
    :return: the parsed boolean.
    :raises argparse.ArgumentTypeError: if the text is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string stays falsy, as it was under `type=bool`.
    if lowered in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError("Expected a boolean value, got {!r}".format(value))


parser = argparse.ArgumentParser()

# Training and Optimization
parser.add_argument("--test_name", default='nliLM/StructuredAutoreg5', type=str)
parser.add_argument("--max_len", default=17, type=int)
parser.add_argument("--batch_size", default=128, type=int)
parser.add_argument("--grad_accu", default=1, type=int)
parser.add_argument("--n_epochs", default=10000, type=int)
parser.add_argument("--test_freq", default=32, type=int)
parser.add_argument("--complete_test_freq", default=160, type=int)
parser.add_argument("--generation_weight", default=1, type=float)
parser.add_argument("--device", default='cuda:0', choices=["cuda:0", "cuda:1", "cuda:2", "cpu"], type=str)

# Model sizes
parser.add_argument("--embedding_dim", default=128, type=int)
parser.add_argument("--z_size", default=768, type=int)
parser.add_argument("--n_latents", default=16, type=int)
parser.add_argument("--text_rep_l", default=2, type=int)
parser.add_argument("--text_rep_h", default=768, type=int)
parser.add_argument("--encoder_h", default=768, type=int)
parser.add_argument("--encoder_l", default=2, type=int)
parser.add_argument("--decoder_h", default=768, type=int)
parser.add_argument("--decoder_l", default=3, type=int)
parser.add_argument("--highway", default=False, type=_str2bool)
parser.add_argument("--markovian", default=True, type=_str2bool)

# Objective
parser.add_argument("--losses", default='VAE', choices=["VAE", "IWAE"], type=str)
parser.add_argument("--training_iw_samples", default=5, type=int)
parser.add_argument("--testing_iw_samples", default=20, type=int)
parser.add_argument("--test_prior_samples", default=10, type=int)
parser.add_argument("--anneal_kl0", default=2000, type=int)
parser.add_argument("--anneal_kl1", default=4000, type=int)
parser.add_argument("--grad_clip", default=100., type=float)
# `float or None` evaluates to plain `float`, so that is spelled out here; the default is 0.0.
parser.add_argument("--kl_th", default=0/(768*3), type=float)

# Regularisation and learning-rate schedule
parser.add_argument("--dropout", default=0.0, type=float)
parser.add_argument("--word_dropout", default=.0, type=float)
parser.add_argument("--l2_reg", default=0, type=float)
parser.add_argument("--lr", default=2e-4, type=float)
parser.add_argument("--lr_reduction", default=4., type=float)
# NOTE(review): parsed as float although the default is the integer 3 - confirm the intended type.
parser.add_argument("--wait_epochs", default=3, type=float)
parser.add_argument("--save_all", default=True, type=_str2bool)

# parse_known_args ignores the extra arguments injected by the notebook kernel launcher.
flags, _ = parser.parse_known_args()

# Debugging aid, disabled by default:
# torch.autograd.set_detect_anomaly(True)

# Mirror the parsed flags into the module-level constants used by the rest of the notebook.
MAX_LEN, BATCH_SIZE = flags.max_len, flags.batch_size
GRAD_ACCU, N_EPOCHS = flags.grad_accu, flags.n_epochs
TEST_FREQ, COMPLETE_TEST_FREQ = flags.test_freq, flags.complete_test_freq
DEVICE = device(flags.device)
# This prevents illegal memory access on multigpu machines (unresolved issue on torch's github)
if flags.device.startswith('cuda'):
    # The GPU ordinal is the last character of the "cuda:N" string.
    torch.cuda.set_device(int(flags.device[-1]))

# Loss classes selected by the --losses flag.
LOSSES = {
    'IWAE': [IWLBo],
    'VAE': [ELBo],
}[flags.losses]
# KL-annealing bounds and loss weights are rescaled by the gradient-accumulation factor.
ANNEAL_KL = [n_steps*flags.grad_accu for n_steps in (flags.anneal_kl0, flags.anneal_kl1)]
LOSS_PARAMS = [1/flags.grad_accu] if flags.grad_accu > 1 else [1]

data = Data(MAX_LEN, BATCH_SIZE, N_EPOCHS, DEVICE)
h_params = HParams(
    len(data.vocab.itos),
    len(data.tags.itos),
    MAX_LEN,
    BATCH_SIZE,
    N_EPOCHS,
    device=DEVICE,
    vocab_ignore_index=data.vocab.stoi['<pad>'],
    decoder_h=flags.decoder_h,
    decoder_l=flags.decoder_l,
    encoder_h=flags.encoder_h,
    encoder_l=flags.encoder_l,
    text_rep_h=flags.text_rep_h,
    text_rep_l=flags.text_rep_l,
    test_name=flags.test_name,
    grad_accumulation_steps=GRAD_ACCU,
    optimizer_kwargs={
        'lr': flags.lr,
        'weight_decay': flags.l2_reg,
        'betas': (0.9, 0.85),
    },
    is_weighted=[],
    graph_generator=get_structured_auto_regressive_disentanglement_graph,
    z_size=flags.z_size,
    embedding_dim=flags.embedding_dim,
    anneal_kl=ANNEAL_KL,
    grad_clip=flags.grad_clip*flags.grad_accu,
    kl_th=flags.kl_th,
    highway=flags.highway,
    losses=LOSSES,
    dropout=flags.dropout,
    training_iw_samples=flags.training_iw_samples,
    testing_iw_samples=flags.testing_iw_samples,
    loss_params=LOSS_PARAMS,
    optimizer=optim.AdamW,
    markovian=flags.markovian,
    word_dropout=flags.word_dropout,
    contiguous_lm=False,
    test_prior_samples=flags.test_prior_samples,
    n_latents=flags.n_latents,
)
# Create an iterator over the validation batches.
val_iterator = iter(data.val_iter)
print("Words: ", len(data.vocab.itos), ", On device: ", DEVICE.type)
print("Loss Type: ", flags.losses)

# Build the model and move it to the selected GPU when one is requested.
model = Model(data.vocab, data.tags, h_params, wvs=data.wvs)
if DEVICE.type == 'cuda':
    model.cuda(DEVICE)

total_unsupervised_train_samples = BATCH_SIZE*len(data.train_iter)
print("Unsupervised training examples: ", total_unsupervised_train_samples)
current_time = time()

# Report trainable-parameter counts (in millions) for the whole model and its main parts.
for _label, _module in (
        ("Number of parameters: ", model),
        ("Inference parameters: ", model.infer_bn),
        ("Generation parameters: ", model.gen_bn),
        ("Embedding parameters: ", model.word_embeddings),
):
    number_parameters = sum(p.numel() for p in _module.parameters() if p.requires_grad)
    print(_label, "{0:05.2f} M".format(number_parameters/1e6))


E:\Experiments\GLUE_BENCH\tb_logs\nliLM




Mean length: 

 

8.900745464443235

 

 Quantiles .25, 0.5, 0.7, and 0.9 :

 

[ 7.  8. 10. 13. 14. 15.]




Mean length: 

 

8.920423872838148

 

 Quantiles .25, 0.5, 0.7, and 0.9 :

 

[ 7.  9. 10. 13. 14. 15.]




Mean length: 

 

8.920423872838148

 

 Quantiles .25, 0.5, 0.7, and 0.9 :

 

[ 7.  9. 10. 13. 14. 15.]




Words: 

 

11895

 

, On device: 

 

cuda




Loss Type: 

 

VAE




128

 

48




Loaded model at step

 

25936




Unsupervised training examples: 

 

90112




Number of parameters: 

 

03.77 M




Inference parameters: 

 

02.21 M




Generation parameters: 

 

03.08 M




Embedding parameters: 

 

01.52 M




In [106]:

def decode_to_text(x_hat_params, vocab_size, vocab_index):
    """Turn generated token ids (or per-token vocabulary scores) into display strings.

    It is assumed that this function is used at test time for display purposes.

    :param x_hat_params: either a 2-D tensor of token ids, or a tensor whose last axis has
        `vocab_size` entries (scores over the vocabulary). Extra leading axes of a score
        tensor with more than 3 dimensions are averaged away; for any other tensor with
        more than 2 dimensions the first element along the leading axis is taken.
    :param vocab_size: number of entries in the vocabulary.
    :param vocab_index: object whose `itos` sequence maps a token id to its string.
    :return: one string per row, cut at the first '<eos>', with '<go>'/'</go>' removed
        and '<pad>' rendered as '_'.
    :raises AssertionError: if the input cannot be reduced to a 2-D tensor of ids.

    Note: a 2-D id tensor whose sequence length happens to equal `vocab_size` cannot be
    told apart from a score matrix and will be arg-maxed.
    """
    # Average over extra leading axes while the tensor is a >3-D stack of vocabulary scores.
    while x_hat_params.shape[-1] == vocab_size and x_hat_params.ndim > 3:
        x_hat_params = x_hat_params.mean(0)
    # Fix: compare against the `vocab_size` argument. The previous `self.h_params.vocab_size`
    # raised NameError in this free function whenever the input had more than 2 dimensions.
    while x_hat_params.ndim > 2 and x_hat_params.shape[-1] != vocab_size:
        x_hat_params = x_hat_params[0]
    # Getting the argmax from the scores if it's not done
    if x_hat_params.shape[-1] == vocab_size:
        x_hat_params = torch.argmax(x_hat_params, dim=-1)
    assert x_hat_params.ndim == 2, "Mis-shaped generated sequence: {}".format(x_hat_params.shape)

    samples = []
    # tolist() yields plain ints, so the vocabulary lookup needs no per-token tensor access.
    for sentence_ids in x_hat_params.tolist():
        sentence = ' '.join([vocab_index.itos[w] for w in sentence_ids])
        sentence = sentence.split('<eos>')[0].replace('<go>', '').replace('</go>', '')
        samples.append(sentence.replace('<pad>', '_').replace('_unk', '<?>'))

    return samples


def get_sentences(mdl, n_samples, gen_len=16, sample_w=False, vary_z=True, complete=None):
    """Generate `n_samples` sentences from the model's generation network.

    :param mdl: model exposing `gen_bn`, `generated_v`, `index` and `h_params`.
    :param n_samples: number of sentences to produce.
    :param gen_len: generation budget; the tokens of `complete` are deducted from it.
    :param sample_w: if True, the next token is the argmax of a sample drawn from
        `generated_v.posterior`; otherwise it is the argmax of the logits.
    :param vary_z: if True, a separate `z` is drawn for every sentence; otherwise a single
        draw (and the `z1`/`z2` values derived from it) is repeated for all sentences.
    :param complete: optional space-separated prefix forced at the start of every sentence.
    :return: `(text, samples, params)`: the decoded strings, a dict with the 'z', 'z1' and
        'z2' values used, and a dict with the `post_params` of 'z1' and 'z2'.
    """
    device_ = mdl.h_params.device
    token_index = mdl.index[mdl.generated_v]

    # Every sequence starts with the '<go>' token: a long tensor of shape (n_samples, 1).
    x_prev = torch.full((n_samples, 1), token_index.stoi['<go>'], dtype=torch.long, device=device_)
    if complete is not None:
        prefix_tokens = complete.split(' ')
        for token in prefix_tokens:
            forced = torch.full((n_samples, 1), token_index.stoi[token], dtype=torch.long, device=device_)
            x_prev = torch.cat([x_prev, forced], dim=1)
        gen_len = gen_len - len(prefix_tokens)
    temp = 1.

    # Draw z from its prior: one draw per sentence, or one draw shared by all of them.
    z_gen = mdl.gen_bn.name_to_v['z']
    z_sample = z_gen.prior_sample((n_samples if vary_z else 1,))[0]
    if not vary_z:
        z_sample = z_sample.repeat(n_samples, 1)
    z_input = {'z': z_sample.unsqueeze(1)}

    # Structured Z case: one pass through the generation network with a zero 'x_prev'
    # fills in the z1 / z2 samples and parameters that are read back below.
    z1, z2 = mdl.gen_bn.name_to_v['z1'], mdl.gen_bn.name_to_v['z2']
    if vary_z:
        mdl.gen_bn({'z': z_sample.unsqueeze(1),
                    'x_prev': torch.zeros((n_samples, 1, mdl.generated_v.size), device=device_)})
        z1_sample = z1.post_samples.squeeze(1)
        z2_sample = z2.post_samples.squeeze(1)
        z1_params, z2_params = z1.post_params, z2.post_params
    else:
        mdl.gen_bn({'z': z_sample[0].unsqueeze(0).unsqueeze(1),
                    'x_prev': torch.zeros((1, 1, mdl.generated_v.size), device=device_)})
        z1_sample = z1.post_samples.squeeze(1).repeat(n_samples, 1)
        z2_sample = z2.post_samples.squeeze(1).repeat(n_samples, 1)
        z1_params = {name: val.squeeze(1).repeat(n_samples, 1) for name, val in z1.post_params.items()}
        z2_params = {name: val.squeeze(1).repeat(n_samples, 1) for name, val in z2.post_params.items()}
    z_input['z1'] = z1_sample.unsqueeze(1)
    z_input['z2'] = z2_sample.unsqueeze(1)

    # Normal Autoregressive generation
    # NOTE(review): the latents are expanded to step+1 positions, while x_prev holds
    # 1 + len(prefix) + step positions when `complete` is given - confirm gen_bn accepts that.
    for step in range(gen_len):
        latent_inputs = {name: val.expand(val.shape[0], step+1, val.shape[-1])
                         for name, val in z_input.items()}
        mdl.gen_bn({'x_prev': x_prev, **latent_inputs})
        logits = mdl.generated_v.post_params['logits']
        if sample_w:
            scores = mdl.generated_v.posterior(logits=logits/temp, temperature=1).rsample()
        else:
            scores = logits
        # Append the highest-scoring token at the last position.
        next_token = torch.argmax(scores, dim=-1)[..., -1].unsqueeze(-1)
        x_prev = torch.cat([x_prev, next_token], dim=-1)

    text = decode_to_text(x_prev, mdl.h_params.vocab_size, token_index)
    latent_samples = {'z': z_sample, 'z1': z1_sample, 'z2': z2_sample}
    latent_params = {'z1': z1_params, 'z2': z2_params}
    return text, latent_samples, latent_params
# Generate five sentences, each from its own prior draw of z, decoding the logits greedily.
text, samples, params = get_sentences(
    mdl=model,
    n_samples=5,
    gen_len=16,
    sample_w=False,
    vary_z=True,
    complete=None,
)
print(text)

[' a man is walking down the beach .. ', ' a man is walking outside .. ', ' two men are sitting at a park .. ', ' a man is doing a guitar on a beach .. ', ' a boy is riding a bike .. ']




In [111]:
def get_alternative_sentences(mdl, prev_latent_vals, params, var_z_ids, n_samples, gen_len, complete=None):
    """Regenerate sentences after replacing selected latent slices with fresh values.

    The latent values in `prev_latent_vals` are repeated `n_samples` times. For every id in
    `var_z_ids`, the matching slice of 'z', 'z1' or 'z2' is overwritten with a fresh value
    ('z' from its prior, 'z1'/'z2' from the samples the generation network produces), and
    the sentences are then decoded by taking the argmax of the logits at every step.

    :param mdl: model exposing `gen_bn`, `generated_v`, `index` and `h_params`.
    :param prev_latent_vals: dict with keys 'z', 'z1', 'z2' holding one latent vector each.
    :param params: unused; kept so existing calls keep working.
    :param var_z_ids: ids of the slices to resample. Id `i` addresses slice `i % n_latents`
        of latent variable `i // n_latents` (0 -> 'z', 1 -> 'z1', 2 -> 'z2').
    :param n_samples: number of alternative sentences to produce.
    :param gen_len: generation budget; the tokens of `complete` are deducted from it.
    :param complete: optional space-separated prefix forced at the start of every sentence.
    :return: `(text, samples, params)`: the decoded strings, a dict with the freshly drawn
        'z', 'z1', 'z2' values (not the mixed latents that were decoded), and a dict with
        the `post_params` of 'z1' and 'z2'.
    :raises IndexError: if an id in `var_z_ids` is outside `[0, 3 * n_latents)`.
    """
    # Read the hyper-parameters from the model that was passed in rather than from the
    # notebook-global `h_params`, so the function does not depend on hidden kernel state.
    hp = mdl.h_params
    token_index = mdl.index[mdl.generated_v]

    # Every sequence starts with the '<go>' token: a long tensor of shape (n_samples, 1).
    x_prev = torch.full((n_samples, 1), token_index.stoi['<go>'], dtype=torch.long, device=hp.device)
    if complete is not None:
        prefix_tokens = complete.split(' ')
        for token in prefix_tokens:
            forced = torch.full((n_samples, 1), token_index.stoi[token], dtype=torch.long, device=hp.device)
            x_prev = torch.cat([x_prev, forced], dim=1)
        gen_len = gen_len - len(prefix_tokens)

    # repeat() allocates new tensors, so the in-place writes below never touch the caller's values.
    orig_z = prev_latent_vals['z'].repeat(n_samples, 1)
    orig_z1 = prev_latent_vals['z1'].repeat(n_samples, 1)
    orig_z2 = prev_latent_vals['z2'].repeat(n_samples, 1)
    z_gen, z1, z2 = mdl.gen_bn.name_to_v['z'], mdl.gen_bn.name_to_v['z1'], mdl.gen_bn.name_to_v['z2']

    # One pass through the generation network fills in the z1 / z2 samples read back below.
    mdl.gen_bn({'z': orig_z.unsqueeze(1), 'z1': orig_z1.unsqueeze(1),
                'z2': orig_z2.unsqueeze(1),
                'x_prev': torch.zeros((n_samples, 1, mdl.generated_v.size)).to(hp.device)})
    z_sample = z_gen.prior_sample((n_samples,))[0]
    z1_sample, z2_sample = z1.post_samples.squeeze(1), z2.post_samples.squeeze(1)
    z1_params, z2_params = z1.post_params, z2.post_params

    fresh = [z_sample, z1_sample, z2_sample]
    kept = [orig_z, orig_z1, orig_z2]
    n_latents, z_size = hp.n_latents, hp.z_size
    for latent_id in var_z_ids:
        # Reject ids that would otherwise silently address an unintended slice.
        if not 0 <= latent_id < len(kept) * n_latents:
            raise IndexError("latent id {} is outside [0, {})".format(latent_id, len(kept) * n_latents))
        z_number, z_index = divmod(latent_id, n_latents)
        # Integer arithmetic keeps the slice bounds exact (no float rounding).
        start = z_size * z_index // n_latents
        end = z_size * (z_index + 1) // n_latents
        kept[z_number][:, start:end] = fresh[z_number][:, start:end]

    z_input = {'z': orig_z.unsqueeze(1), 'z1': orig_z1.unsqueeze(1), 'z2': orig_z2.unsqueeze(1)}

    # Normal Autoregressive generation
    # NOTE(review): the latents are expanded to step+1 positions, while x_prev holds
    # 1 + len(prefix) + step positions when `complete` is given - confirm gen_bn accepts that.
    for step in range(gen_len):
        mdl.gen_bn({'x_prev': x_prev, **{k: v.expand(v.shape[0], step+1, v.shape[-1])
                                         for k, v in z_input.items()}})
        samples_i = mdl.generated_v.post_params['logits']
        # Append the highest-scoring token at the last position.
        x_prev = torch.cat([x_prev, torch.argmax(samples_i, dim=-1)[..., -1].unsqueeze(-1)],
                           dim=-1)

    text = decode_to_text(x_prev, hp.vocab_size, token_index)
    return text, {'z': z_sample, 'z1': z1_sample, 'z2': z2_sample}, {'z1': z1_params, 'z2': z2_params}


# Take the latents of the sentence at index 3 from the previous cell and generate
# 20 alternatives in which only the slice with id 1 is resampled.
alt_text, alt_samples, alt_params = get_alternative_sentences(
    mdl=model,
    prev_latent_vals={name: value[3] for name, value in samples.items()},
    params=None,
    var_z_ids=[1],
    n_samples=20,
    gen_len=16,
    complete=None,
)
print(alt_text)
    

[' a man is doing a guitar on a beach .. ', ' a man is doing a guitar on a beach .. ', ' a man is doing a guitar on a beach .. ', ' a man is doing a guitar on a beach .. ', ' a man is doing a guitar on a beach .. ', ' a man is doing a toy .. ', ' a man is doing a toy .. ', ' a man is doing a guitar on a beach .. ', ' a man is doing a guitar on a beach .. ', ' a man is doing a guitar on a beach .. ', ' a man is doing a toy .. ', ' a man is doing a toy .. ', ' a man is doing a guitar on a beach .. ', ' a man is doing a guitar on a beach .. ', ' a man is doing a guitar on a beach .. ', ' a man is doing a guitar on a beach .. ', ' a man is doing a guitar on a beach .. ', ' a man is doing a toy .. ', ' a man is doing a guitar on a beach .. ', ' a man is doing a toy .. ']




In [112]:
from supar import Parser

# Load the supar model registered as 'biaffine-dep-en' (presumably the pretrained English
# biaffine dependency parser) and run it on one pre-tokenized sentence.
# NOTE(review): this rebinds `parser`, which the first cell bound to the
# argparse.ArgumentParser; re-running earlier cells afterwards will replace it again.
# NOTE(review): the recorded run of this cell ended with
# "AttributeError: module 'torch.distributed' has no attribute 'is_initialized'",
# so `dataset` may not exist in the kernel - confirm the installed torch build works with supar.
parser = Parser.load('biaffine-dep-en')
dataset = parser.predict([['She', 'enjoys', 'playing', 'tennis', '.']], prob=True, verbose=False)

Downloading: "https://github.com/yzhangcs/supar/releases/download/v0.1.0/ptb.biaffine.dependency.char.zip" to C:\Users\ghazy/.cache\torch\checkpoints\ptb.biaffine.dependency.char.zip


HBox(children=(IntProgress(value=0, max=346795161), HTML(value='')))




AttributeError: module 'torch.distributed' has no attribute 'is_initialized'