In [1]:
# This file will implement the main training loop for a model
# Model 1
from time import time
import sys
import os
import argparse
print(os.getcwd())
sys.path.append('../..')
# os.chdir('../..')

from torch import device
import torch
from torch import optim
import numpy as np

from data_prep import NLIGenData2 as Data
from disentanglement_transformer.models import DisentanglementTransformerVAE as Model
from disentanglement_transformer.h_params import DefaultTransformerHParams as HParams
from disentanglement_transformer.graphs import *
from disentanglement_transformer.graphs import get_structured_auto_regressive_disentanglement_graph
from components.criteria import *
parser = argparse.ArgumentParser()

# Training and Optimization
parser.add_argument("--test_name", default='nliLM/StructuredAutoreg5', type=str)
parser.add_argument("--max_len", default=17, type=int)
parser.add_argument("--batch_size", default=128, type=int)
parser.add_argument("--grad_accu", default=1, type=int)
parser.add_argument("--n_epochs", default=10000, type=int)
parser.add_argument("--test_freq", default=32, type=int)
parser.add_argument("--complete_test_freq", default=160, type=int)
parser.add_argument("--generation_weight", default=1, type=float)
parser.add_argument("--device", default='cuda:0', choices=["cuda:0", "cuda:1", "cuda:2", "cpu"], type=str)
parser.add_argument("--embedding_dim", default=128, type=int)#################"
parser.add_argument("--z_size", default=768, type=int)#################"
parser.add_argument("--n_latents", default=[16, 16, 16], type=int)#################"
parser.add_argument("--text_rep_l", default=2, type=int)
parser.add_argument("--text_rep_h", default=768, type=int)
parser.add_argument("--encoder_h", default=768, type=int)#################"
parser.add_argument("--encoder_l", default=2, type=int)#################"
parser.add_argument("--decoder_h", default=768, type=int)
parser.add_argument("--decoder_l", default=3, type=int)#################"
parser.add_argument("--highway", default=False, type=bool)
parser.add_argument("--markovian", default=True, type=bool)
parser.add_argument("--losses", default='VAE', choices=["VAE", "IWAE"], type=str)
parser.add_argument("--training_iw_samples", default=5, type=int)
parser.add_argument("--testing_iw_samples", default=20, type=int)
parser.add_argument("--test_prior_samples", default=10, type=int)
parser.add_argument("--anneal_kl0", default=2000, type=int)
parser.add_argument("--anneal_kl1", default=4000, type=int)
parser.add_argument("--grad_clip", default=100., type=float)
parser.add_argument("--kl_th", default=0/(768*3), type=float or None)
parser.add_argument("--dropout", default=0.0, type=float)
parser.add_argument("--word_dropout", default=.0, type=float)
parser.add_argument("--l2_reg", default=0, type=float)
parser.add_argument("--lr", default=2e-4, type=float)
parser.add_argument("--lr_reduction", default=4., type=float)
parser.add_argument("--wait_epochs", default=3, type=float)
parser.add_argument("--save_all", default=True, type=bool)

flags, _ = parser.parse_known_args()

# torch.autograd.set_detect_anomaly(True)
MAX_LEN = flags.max_len
BATCH_SIZE = flags.batch_size
GRAD_ACCU = flags.grad_accu
N_EPOCHS = flags.n_epochs
TEST_FREQ = flags.test_freq
COMPLETE_TEST_FREQ = flags.complete_test_freq
DEVICE = device(flags.device)
# This prevents illegal memory access on multigpu machines (unresolved issue on torch's github)
if flags.device.startswith('cuda'):
    torch.cuda.set_device(int(flags.device[-1]))
LOSSES = {'IWAE': [IWLBo],
          'VAE': [ELBo]}[flags.losses]
#  LOSSES = [IWLBo]
ANNEAL_KL = [flags.anneal_kl0*flags.grad_accu, flags.anneal_kl1*flags.grad_accu]
LOSS_PARAMS = [1]
if flags.grad_accu > 1:
    LOSS_PARAMS = [w/flags.grad_accu for w in LOSS_PARAMS]

data = Data(MAX_LEN, BATCH_SIZE, N_EPOCHS, DEVICE, False)
h_params = HParams(len(data.vocab.itos), len(data.tags.itos), MAX_LEN, BATCH_SIZE, N_EPOCHS,
                   device=DEVICE, vocab_ignore_index=data.vocab.stoi['<pad>'], decoder_h=flags.decoder_h,
                   decoder_l=flags.decoder_l, encoder_h=flags.encoder_h, encoder_l=flags.encoder_l,
                   text_rep_h=flags.text_rep_h, text_rep_l=flags.text_rep_l,
                   test_name=flags.test_name, grad_accumulation_steps=GRAD_ACCU,
                   optimizer_kwargs={'lr': flags.lr, #'weight_decay': flags.l2_reg, 't0':100, 'lambd':0.},
                                     'weight_decay': flags.l2_reg, 'betas': (0.9, 0.85)},
                   is_weighted=[], graph_generator=get_structured_auto_regressive_disentanglement_graph,
                   z_size=flags.z_size, embedding_dim=flags.embedding_dim, anneal_kl=ANNEAL_KL,
                   grad_clip=flags.grad_clip*flags.grad_accu, kl_th=flags.kl_th, highway=flags.highway,
                   losses=LOSSES, dropout=flags.dropout, training_iw_samples=flags.training_iw_samples,
                   testing_iw_samples=flags.testing_iw_samples, loss_params=LOSS_PARAMS, optimizer=optim.AdamW,
                   markovian=flags.markovian, word_dropout=flags.word_dropout, contiguous_lm=False,
                   test_prior_samples=flags.test_prior_samples, n_latents=flags.n_latents)
val_iterator = iter(data.val_iter)
print("Words: ", len(data.vocab.itos), ", On device: ", DEVICE.type)
print("Loss Type: ", flags.losses)
model = Model(data.vocab, data.tags, h_params, wvs=data.wvs)
if DEVICE.type == 'cuda':
    model.cuda(DEVICE)

total_unsupervised_train_samples = len(data.train_iter)*BATCH_SIZE
print("Unsupervised training examples: ", total_unsupervised_train_samples)
current_time = time()
#print(model)
number_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Number of parameters: ", "{0:05.2f} M".format(number_parameters/1e6))
number_parameters = sum(p.numel() for p in model.infer_bn.parameters() if p.requires_grad)
print("Inference parameters: ", "{0:05.2f} M".format(number_parameters/1e6))
number_parameters = sum(p.numel() for p in model.gen_bn.parameters() if p.requires_grad)
print("Generation parameters: ", "{0:05.2f} M".format(number_parameters/1e6))
number_parameters = sum(p.numel() for p in model.word_embeddings.parameters() if p.requires_grad)
print("Embedding parameters: ", "{0:05.2f} M".format(number_parameters/1e6))


E:\Experiments\GLUE_BENCH


Mean length:  8.900745464443235  Quantiles .25, 0.5, 0.7, and 0.9 : [ 7.  8. 10. 13. 14. 15.]


Mean length:  8.920423872838148  Quantiles .25, 0.5, 0.7, and 0.9 : [ 7.  9. 10. 13. 14. 15.]
Mean length:  8.920423872838148  Quantiles .25, 0.5, 0.7, and 0.9 : [ 7.  9. 10. 13. 14. 15.]


Words:  11895 , On device:  cuda
Loss Type:  VAE


Loaded model at step 25936
Unsupervised training examples:  90112
Number of parameters:  03.77 M
Inference parameters:  02.21 M
Generation parameters:  03.08 M
Embedding parameters:  01.52 M


In [9]:
# Model 2
import sys
import os
import argparse
print(os.getcwd())
#sys.path.append('../..')
os.chdir('../..')

from torch import device
import torch
from torch import optim
import numpy as np

from data_prep import NLIGenData2 as Data
from disentanglement_transformer.models import DisentanglementTransformerVAE as Model
from disentanglement_transformer.h_params import DefaultTransformerHParams as HParams
from disentanglement_transformer.graphs import *
from components.criteria import *
parser = argparse.ArgumentParser()

# Training and Optimization
k, kz =1, 10
parser.add_argument("--test_name", default='unnamed', type=str)
parser.add_argument("--max_len", default=20, type=int)
parser.add_argument("--batch_size", default=512, type=int)
parser.add_argument("--grad_accu", default=1, type=int)
parser.add_argument("--n_epochs", default=10000, type=int)
parser.add_argument("--test_freq", default=32, type=int)
parser.add_argument("--complete_test_freq", default=32, type=int)
parser.add_argument("--generation_weight", default=1, type=float)
parser.add_argument("--device", default='cuda:0', choices=["cuda:0", "cuda:1", "cuda:2", "cpu"], type=str)
parser.add_argument("--embedding_dim", default=300, type=int)#################"
parser.add_argument("--pretrained_embeddings", default=True, type=bool)#################"
parser.add_argument("--z_size", default=768*kz, type=int)#################"
parser.add_argument("--z_emb_dim", default=768*k, type=int)#################"
parser.add_argument("--n_latents", default=[16, 16, 16], type=list)#################"
parser.add_argument("--text_rep_l", default=2, type=int)
parser.add_argument("--text_rep_h", default=768*k, type=int)
parser.add_argument("--encoder_h", default=768*k, type=int)#################"
parser.add_argument("--encoder_l", default=2, type=int)#################"
parser.add_argument("--decoder_h", default=768*k, type=int)
parser.add_argument("--decoder_l", default=2, type=int)#################"
parser.add_argument("--highway", default=False, type=bool)
parser.add_argument("--markovian", default=True, type=bool)
parser.add_argument("--losses", default='VAE', choices=["VAE", "IWAE"], type=str)
parser.add_argument("--graph", default='Discrete', choices=["Discrete", "Normal"], type=str)
parser.add_argument("--training_iw_samples", default=5, type=int)
parser.add_argument("--testing_iw_samples", default=20, type=int)
parser.add_argument("--test_prior_samples", default=10, type=int)
parser.add_argument("--anneal_kl0", default=3000, type=int)
parser.add_argument("--anneal_kl1", default=6000, type=int)
parser.add_argument("--grad_clip", default=10., type=float)
parser.add_argument("--kl_th", default=0*12/(1536*28/16), type=float or None)
parser.add_argument("--dropout", default=0.0, type=float)
parser.add_argument("--word_dropout", default=.0, type=float)
parser.add_argument("--l2_reg", default=0, type=float)
parser.add_argument("--lr", default=2e-4, type=float)
parser.add_argument("--lr_reduction", default=4., type=float)
parser.add_argument("--wait_epochs", default=3, type=float)
parser.add_argument("--save_all", default=True, type=bool)

flags, _ = parser.parse_known_args()

# Manual Settings, Deactivate before pushing
if True:
    flags.losses = 'VAE'
    flags.batch_size = 64
    flags.grad_accu = 1
    flags.max_len = 17
    flags.test_name = "nliLM/Discrete3"

# torch.autograd.set_detect_anomaly(True)
GRAPH = {"Discrete": get_discrete_auto_regressive_disentanglement_graph,
         "Normal": get_structured_auto_regressive_disentanglement_graph}[flags.graph]
MAX_LEN = flags.max_len
BATCH_SIZE = flags.batch_size
MAS_ELBO = 5
GRAD_ACCU = flags.grad_accu
N_EPOCHS = flags.n_epochs
TEST_FREQ = flags.test_freq
COMPLETE_TEST_FREQ = flags.complete_test_freq
DEVICE = device(flags.device)
# This prevents illegal memory access on multigpu machines (unresolved issue on torch's github)
if flags.device.startswith('cuda'):
    torch.cuda.set_device(int(flags.device[-1]))
LOSSES = {'IWAE': [IWLBo],
          'VAE': [ELBo]}[flags.losses]
#  LOSSES = [IWLBo]
ANNEAL_KL = [flags.anneal_kl0*flags.grad_accu, flags.anneal_kl1*flags.grad_accu]
LOSS_PARAMS = [1]
if flags.grad_accu > 1:
    LOSS_PARAMS = [w/flags.grad_accu for w in LOSS_PARAMS]

data = Data(MAX_LEN, BATCH_SIZE, N_EPOCHS, DEVICE, flags.pretrained_embeddings)
h_params = HParams(len(data.vocab.itos), len(data.tags.itos), MAX_LEN, BATCH_SIZE, N_EPOCHS,
                   device=DEVICE, vocab_ignore_index=data.vocab.stoi['<pad>'], decoder_h=flags.decoder_h,
                   decoder_l=flags.decoder_l, encoder_h=flags.encoder_h, encoder_l=flags.encoder_l,
                   text_rep_h=flags.text_rep_h, text_rep_l=flags.text_rep_l,
                   test_name=flags.test_name, grad_accumulation_steps=GRAD_ACCU,
                   optimizer_kwargs={'lr': flags.lr, #'weight_decay': flags.l2_reg, 't0':100, 'lambd':0.},
                                     'weight_decay': flags.l2_reg, 'betas': (0.9, 0.85)},
                   is_weighted=[], graph_generator=GRAPH,
                   z_size=flags.z_size, embedding_dim=flags.embedding_dim, anneal_kl=ANNEAL_KL,
                   grad_clip=flags.grad_clip*flags.grad_accu, kl_th=flags.kl_th, highway=flags.highway,
                   losses=LOSSES, dropout=flags.dropout, training_iw_samples=flags.training_iw_samples,
                   testing_iw_samples=flags.testing_iw_samples, loss_params=LOSS_PARAMS, optimizer=optim.AdamW,
                   markovian=flags.markovian, word_dropout=flags.word_dropout, contiguous_lm=False,
                   test_prior_samples=flags.test_prior_samples, n_latents=flags.n_latents, max_elbo=5,
                   z_emb_dim=flags.z_emb_dim)
val_iterator = iter(data.val_iter)
print("Words: ", len(data.vocab.itos), ", On device: ", DEVICE.type)
print("Loss Type: ", flags.losses)
model = Model(data.vocab, data.tags, h_params, wvs=data.wvs)
if DEVICE.type == 'cuda':
    model.cuda(DEVICE)

total_unsupervised_train_samples = len(data.train_iter)*BATCH_SIZE
print("Unsupervised training examples: ", total_unsupervised_train_samples)
current_time = time()
#print(model)
number_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Number of parameters: ", "{0:05.2f} M".format(number_parameters/1e6))
number_parameters = sum(p.numel() for p in model.infer_bn.parameters() if p.requires_grad)
print("Inference parameters: ", "{0:05.2f} M".format(number_parameters/1e6))
number_parameters = sum(p.numel() for p in model.gen_bn.parameters() if p.requires_grad)
print("Generation parameters: ", "{0:05.2f} M".format(number_parameters/1e6))
number_parameters = sum(p.numel() for p in model.word_embeddings.parameters() if p.requires_grad)
print("Embedding parameters: ", "{0:05.2f} M".format(number_parameters/1e6))

E:\Experiments\GLUE_BENCH\tb_logs\nliLM


Mean length:  8.900745464443235  Quantiles .25, 0.5, 0.7, and 0.9 : [ 7.  8. 10. 13. 14. 15.]
Mean length:  8.920423872838148  Quantiles .25, 0.5, 0.7, and 0.9 : [ 7.  9. 10. 13. 14. 15.]


Mean length:  8.920423872838148  Quantiles .25, 0.5, 0.7, and 0.9 : [ 7.  9. 10. 13. 14. 15.]


Words:  11895 , On device:  cuda
Loss Type:  VAE


Loaded model at step 39396
Unsupervised training examples:  90048
Number of parameters:  07.75 M
Inference parameters:  05.48 M
Generation parameters:  05.84 M
Embedding parameters:  03.57 M


In [7]:
model.eval()
def decode_to_text(x_hat_params, vocab_size, vocab_index):
    # It is assumed that this function is used at test time for display purposes
    # Getting the argmax from the one hot if it's not done
    while x_hat_params.shape[-1] == vocab_size and x_hat_params.ndim > 3:
        x_hat_params = x_hat_params.mean(0)
    while x_hat_params.ndim > 2 and x_hat_params.shape[-1] != self.h_params.vocab_size:
        x_hat_params = x_hat_params[0]
    if x_hat_params.shape[-1] == vocab_size:
        x_hat_params = torch.argmax(x_hat_params, dim=-1)
    assert x_hat_params.ndim == 2, "Mis-shaped generated sequence: {}".format(x_hat_params.shape)
    
    samples = [' '.join([vocab_index.itos[w]
                         for w in sen]).split('<eos>')[0].replace('<go>', '').replace('</go>', '')
               .replace('<pad>', '_').replace('_unk', '<?>')
               for sen in x_hat_params]

    return samples


def get_sentences(mdl, n_samples, gen_len=16, sample_w=False, vary_z=True, complete=None):
    
            go_symbol = torch.ones([n_samples]).long() * \
                        mdl.index[mdl.generated_v].stoi['<go>']
            go_symbol = go_symbol.to(mdl.h_params.device).unsqueeze(-1)
            x_prev = go_symbol
            if complete is not None:
                for token in complete.split(' '):
                    x_prev = torch.cat([x_prev, torch.ones([n_samples, 1]).long().to(mdl.h_params.device) * \
                        mdl.index[mdl.generated_v].stoi[token]], dim=1)
                gen_len = gen_len - len(complete.split(' '))
            temp = 1.
            z_gen = mdl.gen_bn.name_to_v['z1']
            if vary_z:
                z_sample = z_gen.prior_sample((n_samples,))[0]
            else:
                z_sample = z_gen.prior_sample((1,))[0]
                z_sample = z_sample.repeat(n_samples, 1)
            z_input = {'z1': z_sample.unsqueeze(1)}
            # Structured Z case
            z1, z2 = mdl.gen_bn.name_to_v['z2'], mdl.gen_bn.name_to_v['z3']
            if vary_z:
                mdl.gen_bn({'z1': z_sample.unsqueeze(1),
                             'x_prev':torch.zeros((n_samples, 1, mdl.generated_v.size)).to(mdl.h_params.device)})
                z1_sample, z2_sample = z1.post_samples.squeeze(1), z2.post_samples.squeeze(1)
                z1_params, z2_params = z1.post_params, z2.post_params
            else:
                mdl.gen_bn({'z1': z_sample[0].unsqueeze(0).unsqueeze(1),
                             'x_prev':torch.zeros((1, 1, mdl.generated_v.size)).to(mdl.h_params.device)})
                z1_sample, z2_sample = z1.post_samples.squeeze(1).repeat(n_samples, 1), z2.post_samples.squeeze(1).repeat(n_samples, 1)
                z1_params, z2_params = {k: v.squeeze(1).repeat(n_samples, 1) for k, v in z1.post_params.items()}, \
                                       {k: v.squeeze(1).repeat(n_samples, 1) for k, v in z2.post_params.items()}
            z_input['z2'] = z1_sample.unsqueeze(1)
            z_input['z3'] = z2_sample.unsqueeze(1)
            
            # Normal Autoregressive generation
            for i in range(gen_len):
                mdl.gen_bn({'x_prev': x_prev, **{k: v.expand(v.shape[0], i+1, v.shape[-1])
                                                  for k, v in z_input.items()}})
                if not sample_w:
                    samples_i = mdl.generated_v.post_params['logits']
                else:
                    samples_i = mdl.generated_v.posterior(logits=mdl.generated_v.post_params['logits']/temp,
                                                           temperature=1).rsample()
                x_prev = torch.cat([x_prev, torch.argmax(samples_i,     dim=-1)[..., -1].unsqueeze(-1)],
                                   dim=-1)

            text = decode_to_text(x_prev, mdl.h_params.vocab_size, mdl.index[mdl.generated_v])
            return text, {'z1':z_sample, 'z2':z1_sample, 'z3':z2_sample} , {'z2':z1_params, 'z3':z2_params}
text, samples, params = get_sentences(model, 5, 16, sample_w=False, vary_z=True, complete=None)
print(text)

[' a man is playing with a ball .. ', ' a boy is riding a horse .. ', ' a man is sitting in a park .. ', ' a group of people are gathered by a building .. ', ' a group of people are walking in a competition .. ']


In [9]:
def get_alternative_sentences(mdl, prev_latent_vals, params, var_z_ids, n_samples, gen_len, complete=None):
            n_orig_sentences = prev_latent_vals['z1'].shape[0]
            go_symbol = torch.ones([n_samples * n_orig_sentences]).long() * \
                        mdl.index[mdl.generated_v].stoi['<go>']
            go_symbol = go_symbol.to(mdl.h_params.device).unsqueeze(-1)
            x_prev = go_symbol
            if complete is not None:
                for token in complete.split(' '):
                    x_prev = torch.cat([x_prev, torch.ones([n_samples * n_orig_sentences, 1]).long().to(mdl.h_params.device) * \
                        mdl.index[mdl.generated_v].stoi[token]], dim=1)
                gen_len = gen_len - len(complete.split(' '))
            temp = 1.
            orig_z = prev_latent_vals['z1'].repeat(n_samples, 1)
            print(orig_z.shape)
            orig_z1 = prev_latent_vals['z2'].repeat(n_samples, 1)
            orig_z2 = prev_latent_vals['z3'].repeat(n_samples, 1)
            z_gen, z1, z2 = mdl.gen_bn.name_to_v['z1'], mdl.gen_bn.name_to_v['z2'], mdl.gen_bn.name_to_v['z3']
            
            mdl.gen_bn({'z1': orig_z.unsqueeze(1), 'z2':orig_z1.unsqueeze(1),
                        'z3': orig_z2.unsqueeze(1), 
                        'x_prev': torch.zeros((n_samples * n_orig_sentences, 1, mdl.generated_v.size)).to(mdl.h_params.device)})
            z_sample = z_gen.prior_sample((n_samples * n_orig_sentences,))[0]
            z1_sample, z2_sample = z1.post_samples.squeeze(1), z2.post_samples.squeeze(1)
            z1_params, z2_params = z1.post_params, z2.post_params
            for id in var_z_ids:
                z_number = sum([id> sum(h_params.n_latents[:i+1]) for i in range(len(h_params.n_latents))])
                z_index = id - sum(h_params.n_latents[:z_number])
                start, end = int(h_params.z_size/max(h_params.n_latents)*z_index), int(h_params.z_size/max(h_params.n_latents)*(z_index+1))
                source, destination = [z_sample, z1_sample, z2_sample][z_number], [orig_z, orig_z1, orig_z2][z_number]
                destination[:, start:end] = source[:, start:end]
            
            z_input = {'z1': orig_z.unsqueeze(1), 'z2': orig_z1.unsqueeze(1), 'z3': orig_z2.unsqueeze(1)}
            
            # Normal Autoregressive generation
            for i in range(gen_len):
                mdl.gen_bn({'x_prev': x_prev, **{k: v.expand(v.shape[0], i+1, v.shape[-1])
                                                  for k, v in z_input.items()}})
                samples_i = mdl.generated_v.post_params['logits']
                
                x_prev = torch.cat([x_prev, torch.argmax(samples_i,     dim=-1)[..., -1].unsqueeze(-1)],
                                   dim=-1)
            
            text = decode_to_text(x_prev, mdl.h_params.vocab_size, mdl.index[mdl.generated_v])
            return text, {'z1': z_sample.tolist(), 'z2': z1_sample.tolist(), 'z3': z2_sample} , None#{'z1': z1_params, 'z2': z2_params}



# for i in range(48):
#     alt_text, alt_samples, alt_params = get_alternative_sentences(model, {k:v[:5] for k, v in samples.items()},
#                                                                   None,# [i for i in range(36,37)],
#                                                                   [i],
#                                                                   5, 16, complete=None)
#     print(i, alt_text[2::5])

alt_text, alt_samples, alt_params = get_alternative_sentences(model, {k:v[:5] for k, v in samples.items()},
                                                              None,# [i for i in range(36,37)],
                                                              [30],
                                                              10, 16, complete=None)
print(alt_text[0::5])
print(alt_text[1::5])
print(alt_text[2::5])
print(alt_text[3::5])
print(alt_text[4::5])

    

torch.Size([50, 768])


[' a man and two men are playing a game .. ', ' a man is playing with a ball .. ', ' a little girl is getting ready to go to a train .. ', ' a person is playing with a ball .. ', ' a young girl is getting ready to go on a couch .. ', ' a young man is getting ready to go on a couch .. ', ' a man is playing with a ball .. ', ' there is a <?> outside .. ', ' a child is playing with a toy .. ', ' a man is playing with a ball .. ']
[' a man is looking up .. ', ' a boy is riding a horse .. ', ' a man is looking up .. ', ' a man is riding a horse .. ', ' a small child is looking at a crowd .. ', ' a man is riding a horse .. ', ' a man is singing .. ', ' a little girl is looking at a boy .. ', ' a man is looking up .. ', ' a man is looking up .. ']
[' a man is sitting in a park .. ', ' a man is sitting in a park .. ', ' a man is sitting in a park .. ', ' a band is playing in the sand .. ', ' a group of people are walking in front of a building .. ', ' a man is sitting in a park .. ', ' a man i

In [11]:
def swap_latents(mdl, prev_latent_vals, var_z_ids, gen_len, complete=None):
            n_orig_sentences = prev_latent_vals['z1'].shape[0]
            n_samples = n_orig_sentences
            go_symbol = torch.ones([n_samples * n_orig_sentences]).long() * \
                        mdl.index[mdl.generated_v].stoi['<go>']
            go_symbol = go_symbol.to(mdl.h_params.device).unsqueeze(-1)
            x_prev = go_symbol
            if complete is not None:
                for token in complete.split(' '):
                    x_prev = torch.cat([x_prev, torch.ones([n_samples * n_orig_sentences, 1]).long().to(mdl.h_params.device) * \
                        mdl.index[mdl.generated_v].stoi[token]], dim=1)
                gen_len = gen_len - len(complete.split(' '))
            temp = 1.
            orig_z = prev_latent_vals['z1'].unsqueeze(1).repeat(1, n_samples, 1)
            orig_z1 = prev_latent_vals['z2'].unsqueeze(1).repeat(1, n_samples, 1)
            orig_z2 = prev_latent_vals['z3'].unsqueeze(1).repeat(1, n_samples, 1)
            z_sample, z1_sample, z2_sample = orig_z.reshape(n_samples*n_orig_sentences, -1), orig_z1.reshape(n_samples*n_orig_sentences, -1), orig_z2.reshape(n_samples*n_orig_sentences, -1)
            orig_z, orig_z1, orig_z2 = orig_z.transpose(0, 1).reshape(n_samples*n_orig_sentences, -1), orig_z1.transpose(0, 1).reshape(n_samples*n_orig_sentences, -1), \
                                             orig_z2.transpose(0, 1).reshape(n_samples*n_orig_sentences, -1)
            

            for id in var_z_ids:
                z_number = sum([id> sum(h_params.n_latents[:i+1]) for i in range(len(h_params.n_latents))])
                z_index = id - sum(h_params.n_latents[:z_number])
                start, end = int(h_params.z_size/max(h_params.n_latents)*z_index), int(h_params.z_size/max(h_params.n_latents)*(z_index+1))
                source, destination = [z_sample, z1_sample, z2_sample][z_number], [orig_z, orig_z1, orig_z2][z_number]
                destination[:, start:end] = source[:, start:end]
            
            z_input = {'z1': orig_z.unsqueeze(1), 'z2': orig_z1.unsqueeze(1), 'z3': orig_z2.unsqueeze(1)}
            
            # Normal Autoregressive generation
            for i in range(gen_len):
                mdl.gen_bn({'x_prev': x_prev, **{k: v.expand(v.shape[0], i+1, v.shape[-1])
                                                  for k, v in z_input.items()}})
                samples_i = mdl.generated_v.post_params['logits']
                
                x_prev = torch.cat([x_prev, torch.argmax(samples_i,     dim=-1)[..., -1].unsqueeze(-1)],
                                   dim=-1)
            
            text = decode_to_text(x_prev, mdl.h_params.vocab_size, mdl.index[mdl.generated_v])
            return text, {'z1': z_sample.tolist(), 'z2': z1_sample.tolist(), 'z3': z2_sample} , None#{'z1': z1_params, 'z2': z2_params}
 
alt_text, alt_samples, alt_params = swap_latents(model, {k:v[:5] for k, v in samples.items()},
                                                              [30], 16, complete=None)

print(text)
print(alt_text[0::5])
print(alt_text[1::5])
print(alt_text[2::5])
print(alt_text[3::5])
print(alt_text[4::5])

[' a man is playing with a ball .. ', ' a boy is riding a horse .. ', ' a man is sitting in a park .. ', ' a group of people are gathered by a building .. ', ' a group of people are walking in a competition .. ']
[' a man is playing with a ball .. ', ' a group of people are outside .. ', ' a man is playing with a ball .. ', ' a group of people are outside .. ', ' a couple of people are outside .. ']
[' a man is riding a horse .. ', ' a boy is riding a horse .. ', ' a man is riding a horse .. ', ' a group of people are outside .. ', ' a couple of people are outside in a restaurant .. ']
[' a man is sitting in a park .. ', ' a group of people are walking in front of a building .. ', ' a man is sitting in a park .. ', ' a group of people are walking in front of a building .. ', ' a couple of people are walking in front of a building .. ']
[' a man is jumping over a piece of tree .. ', ' a group of people are gathered by a building .. ', ' a man is jumping over a piece of tree .. ', ' a gr

In [37]:
import spacy

nlp = spacy.load("en_core_web_sm")



In [124]:
from itertools import tee
def get_depth(root, toks, tree, depth=0):
    root_tree = list([tok for tok in tree[root]])
    if len(root_tree)>0:
        child_ids = [i for i, tok in enumerate(toks) if tok in root_tree]
        return 1+max([get_depth(child_id, toks, tree) for child_id in child_ids])
    else: return depth

def get_sentence_statistics(orig, sen):
    # Orig properties
    orig_doc = nlp(orig)
    orig_pos = [token.pos_ for token in orig_doc]
    orig_toks = [token.text for token in orig_doc]
    orig_length = len(orig_pos)
    orig_dep_label = [token.dep_ for token in orig_doc]
    orig_dep_tree = [list(token.children) for token in orig_doc]
    orig_depth = get_depth(orig_dep_label.index('ROOT'), orig_doc, orig_dep_tree)
    orig_root_children_text = [tok.text for tok in orig_dep_tree[orig_dep_label.index('ROOT')]]
    orig_root_children_dep = [tok.dep_ for tok in orig_dep_tree[orig_dep_label.index('ROOT')]]
    n_orig_root_children = len([tok.text for tok in orig_dep_tree[orig_dep_label.index('ROOT')]])
    # Alt properties
    doc = nlp(sen)
    pos = [token.pos_ for token in doc]
    toks = [token.text for token in doc]
    length = len(pos)
    dep_label = [token.dep_ for token in doc]
    dep_tree = [list(token.children) for token in doc]
    depth = get_depth(dep_label.index('ROOT'), doc, dep_tree)
    root_children_text = [tok.text for tok in dep_tree[dep_label.index('ROOT')]]
    root_children_dep = [tok.dep_ for tok in dep_tree[dep_label.index('ROOT')]]
    n_root_children = len([tok.text for tok in dep_tree[dep_label.index('ROOT')]])
    # Differences 
    len_diff = np.abs(length - orig_length)
    depth_diff = np.abs(depth - orig_depth)
    n_root_children_diff = np.abs(n_root_children-n_orig_root_children)
    root_children_text_diff = np.union1d(np.setdiff1d(orig_root_children_text, root_children_text),
                                        np.setdiff1d(root_children_text, orig_root_children_text)).tolist()
    root_children_dep_diff = np.union1d(np.setdiff1d(orig_root_children_dep, root_children_dep),
                                        np.setdiff1d(root_children_dep, orig_root_children_dep)).tolist()
    new_deps = np.union1d(np.setdiff1d(orig_dep_label, dep_label), np.setdiff1d(dep_label, orig_dep_label)).tolist()
    if len_diff:
        word_diff = []
        diff_pos = []
        diff_dep = []
        n_word_diff = -1
    else:
        word_diff = [(orig_tok, tok) for orig_tok, tok in zip(orig_toks, toks) if orig_tok != tok]
        n_word_diff = len(word_diff)
        diff_pos = [(orig_tok.pos_, tok.pos_) for orig_tok, tok in zip(orig_doc, doc) if orig_tok.text != tok.text]
        diff_dep = [(orig_tok.dep_, tok.dep_) for orig_tok, tok in zip(orig_doc, doc) if orig_tok.text != tok.text]
    return len_diff, depth_diff, n_root_children_diff, n_word_diff, word_diff, diff_pos, diff_dep, new_deps, root_children_dep_diff, \
           root_children_text_diff
    
print(text[0], alt_text[2])
# print(get_sentence_statistics('a blond woman wearing a white shirt and white shirt .. ', 'a blond woman wearing a white shirt and white shirt .. '))
print(get_sentence_statistics('a blond woman wearing a shirt .. ', 'a blond woman sitting  .. '))

 a man is holding a child ..   a toddler is taking a break from his mother .. 
(1, 1, 0, -1, [], [], [], ['', 'dobj'], [], ['sitting', 'wearing'])


In [125]:
from tqdm import tqdm
header = ['original', 'altered', 'alteration_id', 'len_diff', 'depth_diff', 'n_root_children_diff', 
          'n_word_diff', 'word_diff', 'diff_pos', 'diff_dep', 'new_deps', 'root_children_dep_diff',
           'root_children_text_diff']
stats = []
n_samples, n_alterations, nlatents = 100, 10, h_params.n_latents
# Generating a hundred sentences
text, samples, params = get_sentences(model, n_samples=n_samples, gen_len=16, sample_w=False, vary_z=True, complete=None)
batch_size = 20
for i in range(int(n_samples/batch_size)):
    for j in tqdm(range(nlatents*3), desc="Processing sample {}".format(str(i))):
        # Altering the sentences
        alt_text, _, _ = get_alternative_sentences(model, prev_latent_vals={k:v[i*batch_size:(i+1)*batch_size] for k, v in samples.items()},
                                                   params=None, var_z_ids=[j], n_samples=n_alterations,
                                                   gen_len=16, complete=None)
        # Getting alteration statistics
        for k in range(n_alterations*batch_size):
            orig_text = text[(i*batch_size)+k%batch_size]
            try:
                len_diff, depth_diff, n_root_children_diff, n_word_diff, word_diff, diff_pos, diff_dep, new_deps, root_children_dep_diff, \
               root_children_text_diff = get_sentence_statistics(orig_text, alt_text[k])
            except RecursionError:
                print(orig_text, alt_text[k])
                continue
            stats.append([orig_text, alt_text[k], j, len_diff, depth_diff, n_root_children_diff, n_word_diff, 
                          word_diff, diff_pos, diff_dep, new_deps, root_children_dep_diff, 
                          root_children_text_diff])



Processing sample 0:   0%|                                                                      | 0/48 [00:00<?, ?it/s]

Processing sample 0:   2%|█▎                                                            | 1/48 [00:03<02:49,  3.61s/it]

Processing sample 0:   4%|██▌                                                           | 2/48 [00:07<02:46,  3.61s/it]

Processing sample 0:   6%|███▉                                                          | 3/48 [00:10<02:43,  3.62s/it]

Processing sample 0:   8%|█████▏                                                        | 4/48 [00:14<02:38,  3.61s/it]

Processing sample 0:  10%|██████▍                                                       | 5/48 [00:17<02:33,  3.57s/it]

Processing sample 0:  12%|███████▊                                                      | 6/48 [00:21<02:28,  3.53s/it]

Processing sample 0:  15%|█████████                                                     | 7/48 [00:24<02:24,  3.53s/it]

Processing sample 0:  17%|██████████▎                                                   | 8/48 [00:28<02:21,  3.54s/it]

Processing sample 0:  19%|███████████▋                                                  | 9/48 [00:31<02:16,  3.49s/it]

Processing sample 0:  21%|████████████▋                                                | 10/48 [00:35<02:10,  3.44s/it]

Processing sample 0:  23%|█████████████▉                                               | 11/48 [00:38<02:06,  3.41s/it]

Processing sample 0:  25%|███████████████▎                                             | 12/48 [00:41<02:02,  3.39s/it]

Processing sample 0:  27%|████████████████▌                                            | 13/48 [00:45<01:58,  3.38s/it]

Processing sample 0:  29%|█████████████████▊                                           | 14/48 [00:48<01:54,  3.37s/it]

Processing sample 0:  31%|███████████████████                                          | 15/48 [00:51<01:50,  3.36s/it]

Processing sample 0:  33%|████████████████████▎                                        | 16/48 [00:55<01:47,  3.37s/it]

Processing sample 0:  35%|█████████████████████▌                                       | 17/48 [00:58<01:44,  3.38s/it]

Processing sample 0:  38%|██████████████████████▉                                      | 18/48 [01:02<01:41,  3.37s/it]

Processing sample 0:  40%|████████████████████████▏                                    | 19/48 [01:05<01:37,  3.38s/it]

Processing sample 0:  42%|█████████████████████████▍                                   | 20/48 [01:08<01:34,  3.36s/it]

Processing sample 0:  44%|██████████████████████████▋                                  | 21/48 [01:12<01:30,  3.36s/it]

Processing sample 0:  46%|███████████████████████████▉                                 | 22/48 [01:15<01:27,  3.35s/it]

Processing sample 0:  48%|█████████████████████████████▏                               | 23/48 [01:19<01:27,  3.51s/it]

Processing sample 0:  50%|██████████████████████████████▌                              | 24/48 [01:22<01:23,  3.48s/it]

Processing sample 0:  52%|███████████████████████████████▊                             | 25/48 [01:26<01:19,  3.45s/it]

Processing sample 0:  54%|█████████████████████████████████                            | 26/48 [01:29<01:15,  3.44s/it]

Processing sample 0:  56%|██████████████████████████████████▎                          | 27/48 [01:32<01:12,  3.43s/it]

Processing sample 0:  58%|███████████████████████████████████▌                         | 28/48 [01:36<01:07,  3.40s/it]

Processing sample 0:  60%|████████████████████████████████████▊                        | 29/48 [01:39<01:04,  3.39s/it]

Processing sample 0:  62%|██████████████████████████████████████▏                      | 30/48 [01:42<01:00,  3.38s/it]

Processing sample 0:  65%|███████████████████████████████████████▍                     | 31/48 [01:46<00:57,  3.36s/it]

Processing sample 0:  67%|████████████████████████████████████████▋                    | 32/48 [01:49<00:53,  3.36s/it]

Processing sample 0:  69%|█████████████████████████████████████████▉                   | 33/48 [01:53<00:50,  3.37s/it]

Processing sample 0:  71%|███████████████████████████████████████████▏                 | 34/48 [01:56<00:47,  3.36s/it]

Processing sample 0:  73%|████████████████████████████████████████████▍                | 35/48 [01:59<00:43,  3.35s/it]

Processing sample 0:  75%|█████████████████████████████████████████████▊               | 36/48 [02:03<00:40,  3.34s/it]

Processing sample 0:  77%|███████████████████████████████████████████████              | 37/48 [02:06<00:37,  3.37s/it]

Processing sample 0:  79%|████████████████████████████████████████████████▎            | 38/48 [02:09<00:33,  3.38s/it]

Processing sample 0:  81%|█████████████████████████████████████████████████▌           | 39/48 [02:13<00:30,  3.38s/it]

Processing sample 0:  83%|██████████████████████████████████████████████████▊          | 40/48 [02:16<00:27,  3.38s/it]

Processing sample 0:  85%|████████████████████████████████████████████████████         | 41/48 [02:19<00:23,  3.37s/it]

Processing sample 0:  88%|█████████████████████████████████████████████████████▍       | 42/48 [02:23<00:20,  3.35s/it]

Processing sample 0:  90%|██████████████████████████████████████████████████████▋      | 43/48 [02:26<00:17,  3.41s/it]

Processing sample 0:  92%|███████████████████████████████████████████████████████▉     | 44/48 [02:30<00:13,  3.41s/it]

Processing sample 0:  94%|█████████████████████████████████████████████████████████▏   | 45/48 [02:33<00:10,  3.40s/it]

Processing sample 0:  96%|██████████████████████████████████████████████████████████▍  | 46/48 [02:36<00:06,  3.38s/it]

Processing sample 0:  98%|███████████████████████████████████████████████████████████▋ | 47/48 [02:40<00:03,  3.38s/it]

Processing sample 0: 100%|█████████████████████████████████████████████████████████████| 48/48 [02:43<00:00,  3.46s/it]

Processing sample 0: 100%|█████████████████████████████████████████████████████████████| 48/48 [02:43<00:00,  3.42s/it]


Processing sample 1:   0%|                                                                      | 0/48 [00:00<?, ?it/s]

Processing sample 1:   2%|█▎                                                            | 1/48 [00:03<02:35,  3.30s/it]

Processing sample 1:   4%|██▌                                                           | 2/48 [00:06<02:32,  3.31s/it]

Processing sample 1:   6%|███▉                                                          | 3/48 [00:10<02:30,  3.35s/it]

Processing sample 1:   8%|█████▏                                                        | 4/48 [00:13<02:30,  3.42s/it]

Processing sample 1:  10%|██████▍                                                       | 5/48 [00:17<02:27,  3.42s/it]

Processing sample 1:  12%|███████▊                                                      | 6/48 [00:20<02:21,  3.38s/it]

Processing sample 1:  15%|█████████                                                     | 7/48 [00:23<02:18,  3.37s/it]

Processing sample 1:  17%|██████████▎                                                   | 8/48 [00:26<02:13,  3.35s/it]

Processing sample 1:  19%|███████████▋                                                  | 9/48 [00:30<02:10,  3.33s/it]

Processing sample 1:  21%|████████████▋                                                | 10/48 [00:33<02:06,  3.34s/it]

Processing sample 1:  23%|█████████████▉                                               | 11/48 [00:37<02:06,  3.42s/it]

Processing sample 1:  25%|███████████████▎                                             | 12/48 [00:40<02:01,  3.39s/it]

Processing sample 1:  27%|████████████████▌                                            | 13/48 [00:44<01:59,  3.43s/it]

Processing sample 1:  29%|█████████████████▊                                           | 14/48 [00:47<01:56,  3.41s/it]

Processing sample 1:  31%|███████████████████                                          | 15/48 [00:50<01:51,  3.39s/it]

Processing sample 1:  33%|████████████████████▎                                        | 16/48 [00:54<01:48,  3.39s/it]

Processing sample 1:  35%|█████████████████████▌                                       | 17/48 [00:57<01:44,  3.39s/it]

Processing sample 1:  38%|██████████████████████▉                                      | 18/48 [01:00<01:40,  3.36s/it]

Processing sample 1:  40%|████████████████████████▏                                    | 19/48 [01:04<01:37,  3.36s/it]

Processing sample 1:  42%|█████████████████████████▍                                   | 20/48 [01:07<01:33,  3.34s/it]

Processing sample 1:  44%|██████████████████████████▋                                  | 21/48 [01:10<01:30,  3.36s/it]

Processing sample 1:  46%|███████████████████████████▉                                 | 22/48 [01:14<01:26,  3.35s/it]

Processing sample 1:  48%|█████████████████████████████▏                               | 23/48 [01:17<01:23,  3.33s/it]

Processing sample 1:  50%|██████████████████████████████▌                              | 24/48 [01:20<01:19,  3.32s/it]

Processing sample 1:  52%|███████████████████████████████▊                             | 25/48 [01:24<01:18,  3.40s/it]

Processing sample 1:  54%|█████████████████████████████████                            | 26/48 [01:27<01:15,  3.41s/it]

Processing sample 1:  56%|██████████████████████████████████▎                          | 27/48 [01:31<01:11,  3.40s/it]

Processing sample 1:  58%|███████████████████████████████████▌                         | 28/48 [01:34<01:08,  3.44s/it]

Processing sample 1:  60%|████████████████████████████████████▊                        | 29/48 [01:38<01:05,  3.42s/it]

Processing sample 1:  62%|██████████████████████████████████████▏                      | 30/48 [01:41<01:00,  3.39s/it]

Processing sample 1:  65%|███████████████████████████████████████▍                     | 31/48 [01:44<00:57,  3.40s/it]

Processing sample 1:  67%|████████████████████████████████████████▋                    | 32/48 [01:48<00:53,  3.37s/it]

Processing sample 1:  69%|█████████████████████████████████████████▉                   | 33/48 [01:51<00:51,  3.44s/it]

Processing sample 1:  71%|███████████████████████████████████████████▏                 | 34/48 [01:55<00:47,  3.40s/it]

Processing sample 1:  73%|████████████████████████████████████████████▍                | 35/48 [01:58<00:44,  3.39s/it]

Processing sample 1:  75%|█████████████████████████████████████████████▊               | 36/48 [02:01<00:40,  3.37s/it]

Processing sample 1:  77%|███████████████████████████████████████████████              | 37/48 [02:05<00:37,  3.37s/it]

Processing sample 1:  79%|████████████████████████████████████████████████▎            | 38/48 [02:08<00:33,  3.35s/it]

Processing sample 1:  81%|█████████████████████████████████████████████████▌           | 39/48 [02:11<00:30,  3.33s/it]

Processing sample 1:  83%|██████████████████████████████████████████████████▊          | 40/48 [02:15<00:26,  3.35s/it]

Processing sample 1:  85%|████████████████████████████████████████████████████         | 41/48 [02:18<00:23,  3.33s/it]

Processing sample 1:  88%|█████████████████████████████████████████████████████▍       | 42/48 [02:21<00:19,  3.33s/it]

Processing sample 1:  90%|██████████████████████████████████████████████████████▋      | 43/48 [02:25<00:16,  3.38s/it]

Processing sample 1:  92%|███████████████████████████████████████████████████████▉     | 44/48 [02:28<00:13,  3.36s/it]

Processing sample 1:  94%|█████████████████████████████████████████████████████████▏   | 45/48 [02:32<00:10,  3.40s/it]

Processing sample 1:  96%|██████████████████████████████████████████████████████████▍  | 46/48 [02:35<00:06,  3.36s/it]

Processing sample 1:  98%|███████████████████████████████████████████████████████████▋ | 47/48 [02:38<00:03,  3.41s/it]

Processing sample 1: 100%|█████████████████████████████████████████████████████████████| 48/48 [02:42<00:00,  3.37s/it]

Processing sample 1: 100%|█████████████████████████████████████████████████████████████| 48/48 [02:42<00:00,  3.38s/it]


Processing sample 2:   0%|                                                                      | 0/48 [00:00<?, ?it/s]

Processing sample 2:   2%|█▎                                                            | 1/48 [00:03<02:35,  3.32s/it]

Processing sample 2:   4%|██▌                                                           | 2/48 [00:06<02:33,  3.35s/it]

Processing sample 2:   6%|███▉                                                          | 3/48 [00:10<02:32,  3.40s/it]

Processing sample 2:   8%|█████▏                                                        | 4/48 [00:13<02:28,  3.38s/it]

Processing sample 2:  10%|██████▍                                                       | 5/48 [00:16<02:24,  3.35s/it]

Processing sample 2:  12%|███████▊                                                      | 6/48 [00:20<02:20,  3.33s/it]

Processing sample 2:  15%|█████████                                                     | 7/48 [00:23<02:16,  3.33s/it]

Processing sample 2:  17%|██████████▎                                                   | 8/48 [00:26<02:14,  3.35s/it]

Processing sample 2:  19%|███████████▋                                                  | 9/48 [00:30<02:09,  3.33s/it]

Processing sample 2:  21%|████████████▋                                                | 10/48 [00:33<02:06,  3.32s/it]

Processing sample 2:  23%|█████████████▉                                               | 11/48 [00:36<02:04,  3.35s/it]

Processing sample 2:  25%|███████████████▎                                             | 12/48 [00:40<02:00,  3.34s/it]

Processing sample 2:  27%|████████████████▌                                            | 13/48 [00:43<01:56,  3.33s/it]

Processing sample 2:  29%|█████████████████▊                                           | 14/48 [00:46<01:52,  3.32s/it]

Processing sample 2:  31%|███████████████████                                          | 15/48 [00:50<01:49,  3.32s/it]

Processing sample 2:  33%|████████████████████▎                                        | 16/48 [00:53<01:45,  3.31s/it]

Processing sample 2:  35%|█████████████████████▌                                       | 17/48 [00:56<01:43,  3.33s/it]

Processing sample 2:  38%|██████████████████████▉                                      | 18/48 [01:00<01:40,  3.35s/it]

Processing sample 2:  40%|████████████████████████▏                                    | 19/48 [01:03<01:36,  3.33s/it]

Processing sample 2:  42%|█████████████████████████▍                                   | 20/48 [01:06<01:34,  3.38s/it]

Processing sample 2:  44%|██████████████████████████▋                                  | 21/48 [01:10<01:32,  3.43s/it]

Processing sample 2:  46%|███████████████████████████▉                                 | 22/48 [01:13<01:29,  3.42s/it]

Processing sample 2:  48%|█████████████████████████████▏                               | 23/48 [01:17<01:25,  3.43s/it]

Processing sample 2:  50%|██████████████████████████████▌                              | 24/48 [01:20<01:22,  3.42s/it]

Processing sample 2:  52%|███████████████████████████████▊                             | 25/48 [01:24<01:17,  3.38s/it]

Processing sample 2:  54%|█████████████████████████████████                            | 26/48 [01:27<01:14,  3.37s/it]

Processing sample 2:  56%|██████████████████████████████████▎                          | 27/48 [01:30<01:10,  3.38s/it]

Processing sample 2:  58%|███████████████████████████████████▌                         | 28/48 [01:34<01:07,  3.35s/it]

Processing sample 2:  60%|████████████████████████████████████▊                        | 29/48 [01:37<01:03,  3.36s/it]

Processing sample 2:  62%|██████████████████████████████████████▏                      | 30/48 [01:40<01:00,  3.37s/it]

Processing sample 2:  65%|███████████████████████████████████████▍                     | 31/48 [01:44<00:56,  3.35s/it]

Processing sample 2:  67%|████████████████████████████████████████▋                    | 32/48 [01:47<00:53,  3.37s/it]

Processing sample 2:  69%|█████████████████████████████████████████▉                   | 33/48 [01:51<00:51,  3.41s/it]

Processing sample 2:  71%|███████████████████████████████████████████▏                 | 34/48 [01:54<00:47,  3.41s/it]

Processing sample 2:  73%|████████████████████████████████████████████▍                | 35/48 [01:57<00:44,  3.42s/it]

Processing sample 2:  75%|█████████████████████████████████████████████▊               | 36/48 [02:01<00:40,  3.40s/it]

Processing sample 2:  77%|███████████████████████████████████████████████              | 37/48 [02:04<00:37,  3.40s/it]

Processing sample 2:  79%|████████████████████████████████████████████████▎            | 38/48 [02:07<00:33,  3.38s/it]

Processing sample 2:  81%|█████████████████████████████████████████████████▌           | 39/48 [02:11<00:30,  3.43s/it]

Processing sample 2:  83%|██████████████████████████████████████████████████▊          | 40/48 [02:14<00:27,  3.40s/it]

Processing sample 2:  85%|████████████████████████████████████████████████████         | 41/48 [02:18<00:23,  3.37s/it]

Processing sample 2:  88%|█████████████████████████████████████████████████████▍       | 42/48 [02:21<00:20,  3.39s/it]

Processing sample 2:  90%|██████████████████████████████████████████████████████▋      | 43/48 [02:24<00:16,  3.38s/it]

Processing sample 2:  92%|███████████████████████████████████████████████████████▉     | 44/48 [02:28<00:13,  3.36s/it]

Processing sample 2:  94%|█████████████████████████████████████████████████████████▏   | 45/48 [02:31<00:10,  3.39s/it]

Processing sample 2:  96%|██████████████████████████████████████████████████████████▍  | 46/48 [02:35<00:06,  3.45s/it]

Processing sample 2:  98%|███████████████████████████████████████████████████████████▋ | 47/48 [02:38<00:03,  3.46s/it]

Processing sample 2: 100%|█████████████████████████████████████████████████████████████| 48/48 [02:42<00:00,  3.47s/it]

Processing sample 2: 100%|█████████████████████████████████████████████████████████████| 48/48 [02:42<00:00,  3.38s/it]


Processing sample 3:   0%|                                                                      | 0/48 [00:00<?, ?it/s]

Processing sample 3:   2%|█▎                                                            | 1/48 [00:03<02:38,  3.37s/it]

Processing sample 3:   4%|██▌                                                           | 2/48 [00:06<02:35,  3.38s/it]

Processing sample 3:   6%|███▉                                                          | 3/48 [00:10<02:34,  3.42s/it]

Processing sample 3:   8%|█████▏                                                        | 4/48 [00:13<02:30,  3.42s/it]

Processing sample 3:  10%|██████▍                                                       | 5/48 [00:16<02:25,  3.38s/it]

Processing sample 3:  12%|███████▊                                                      | 6/48 [00:20<02:22,  3.38s/it]

Processing sample 3:  15%|█████████                                                     | 7/48 [00:23<02:17,  3.36s/it]

Processing sample 3:  17%|██████████▎                                                   | 8/48 [00:26<02:13,  3.34s/it]

Processing sample 3:  19%|███████████▋                                                  | 9/48 [00:30<02:10,  3.34s/it]

Processing sample 3:  21%|████████████▋                                                | 10/48 [00:33<02:06,  3.33s/it]

Processing sample 3:  23%|█████████████▉                                               | 11/48 [00:36<02:02,  3.32s/it]

Processing sample 3:  25%|███████████████▎                                             | 12/48 [00:40<02:01,  3.37s/it]

Processing sample 3:  27%|████████████████▌                                            | 13/48 [00:43<01:58,  3.39s/it]

Processing sample 3:  29%|█████████████████▊                                           | 14/48 [00:47<01:57,  3.46s/it]

Processing sample 3:  31%|███████████████████                                          | 15/48 [00:50<01:52,  3.41s/it]

Processing sample 3:  33%|████████████████████▎                                        | 16/48 [00:54<01:48,  3.38s/it]

Processing sample 3:  35%|█████████████████████▌                                       | 17/48 [00:57<01:44,  3.36s/it]

Processing sample 3:  38%|██████████████████████▉                                      | 18/48 [01:00<01:40,  3.36s/it]

Processing sample 3:  40%|████████████████████████▏                                    | 19/48 [01:04<01:37,  3.38s/it]

Processing sample 3:  42%|█████████████████████████▍                                   | 20/48 [01:07<01:34,  3.36s/it]

Processing sample 3:  44%|██████████████████████████▋                                  | 21/48 [01:10<01:30,  3.34s/it]

Processing sample 3:  46%|███████████████████████████▉                                 | 22/48 [01:14<01:26,  3.33s/it]

Processing sample 3:  48%|█████████████████████████████▏                               | 23/48 [01:17<01:23,  3.34s/it]

Processing sample 3:  50%|██████████████████████████████▌                              | 24/48 [01:20<01:19,  3.32s/it]

Processing sample 3:  52%|███████████████████████████████▊                             | 25/48 [01:24<01:16,  3.34s/it]

Processing sample 3:  54%|█████████████████████████████████                            | 26/48 [01:27<01:14,  3.39s/it]

Processing sample 3:  56%|██████████████████████████████████▎                          | 27/48 [01:30<01:10,  3.37s/it]

Processing sample 3:  58%|███████████████████████████████████▌                         | 28/48 [01:34<01:06,  3.35s/it]

Processing sample 3:  60%|████████████████████████████████████▊                        | 29/48 [01:37<01:03,  3.34s/it]

Processing sample 3:  62%|██████████████████████████████████████▏                      | 30/48 [01:41<01:01,  3.40s/it]

Processing sample 3:  65%|███████████████████████████████████████▍                     | 31/48 [01:44<00:57,  3.40s/it]

Processing sample 3:  67%|████████████████████████████████████████▋                    | 32/48 [01:47<00:54,  3.38s/it]

Processing sample 3:  69%|█████████████████████████████████████████▉                   | 33/48 [01:51<00:50,  3.36s/it]

Processing sample 3:  71%|███████████████████████████████████████████▏                 | 34/48 [01:54<00:46,  3.34s/it]

Processing sample 3:  73%|████████████████████████████████████████████▍                | 35/48 [01:57<00:43,  3.33s/it]

Processing sample 3:  75%|█████████████████████████████████████████████▊               | 36/48 [02:01<00:40,  3.34s/it]

Processing sample 3:  77%|███████████████████████████████████████████████              | 37/48 [02:04<00:37,  3.37s/it]

Processing sample 3:  79%|████████████████████████████████████████████████▎            | 38/48 [02:07<00:33,  3.37s/it]

Processing sample 3:  81%|█████████████████████████████████████████████████▌           | 39/48 [02:11<00:30,  3.37s/it]

Processing sample 3:  83%|██████████████████████████████████████████████████▊          | 40/48 [02:14<00:26,  3.35s/it]

Processing sample 3:  85%|████████████████████████████████████████████████████         | 41/48 [02:17<00:23,  3.33s/it]

Processing sample 3:  88%|█████████████████████████████████████████████████████▍       | 42/48 [02:21<00:19,  3.33s/it]

Processing sample 3:  90%|██████████████████████████████████████████████████████▋      | 43/48 [02:24<00:16,  3.34s/it]

Processing sample 3:  92%|███████████████████████████████████████████████████████▉     | 44/48 [02:27<00:13,  3.33s/it]

Processing sample 3:  94%|█████████████████████████████████████████████████████████▏   | 45/48 [02:31<00:09,  3.33s/it]

Processing sample 3:  96%|██████████████████████████████████████████████████████████▍  | 46/48 [02:34<00:06,  3.36s/it]

Processing sample 3:  98%|███████████████████████████████████████████████████████████▋ | 47/48 [02:37<00:03,  3.34s/it]

Processing sample 3: 100%|█████████████████████████████████████████████████████████████| 48/48 [02:41<00:00,  3.38s/it]

Processing sample 3: 100%|█████████████████████████████████████████████████████████████| 48/48 [02:41<00:00,  3.36s/it]


Processing sample 4:   0%|                                                                      | 0/48 [00:00<?, ?it/s]

Processing sample 4:   2%|█▎                                                            | 1/48 [00:03<02:39,  3.40s/it]

Processing sample 4:   4%|██▌                                                           | 2/48 [00:06<02:35,  3.39s/it]

Processing sample 4:   6%|███▉                                                          | 3/48 [00:10<02:31,  3.36s/it]

Processing sample 4:   8%|█████▏                                                        | 4/48 [00:13<02:32,  3.46s/it]

Processing sample 4:  10%|██████▍                                                       | 5/48 [00:17<02:29,  3.48s/it]

Processing sample 4:  12%|███████▊                                                      | 6/48 [00:20<02:23,  3.43s/it]

Processing sample 4:  15%|█████████                                                     | 7/48 [00:23<02:19,  3.40s/it]

Processing sample 4:  17%|██████████▎                                                   | 8/48 [00:27<02:14,  3.37s/it]

Processing sample 4:  19%|███████████▋                                                  | 9/48 [00:30<02:10,  3.35s/it]

Processing sample 4:  21%|████████████▋                                                | 10/48 [00:34<02:10,  3.43s/it]

Processing sample 4:  23%|█████████████▉                                               | 11/48 [00:37<02:06,  3.41s/it]

Processing sample 4:  25%|███████████████▎                                             | 12/48 [00:40<02:01,  3.38s/it]

Processing sample 4:  27%|████████████████▌                                            | 13/48 [00:44<01:57,  3.36s/it]

Processing sample 4:  29%|█████████████████▊                                           | 14/48 [00:47<01:55,  3.41s/it]

Processing sample 4:  31%|███████████████████                                          | 15/48 [00:51<01:53,  3.43s/it]

Processing sample 4:  33%|████████████████████▎                                        | 16/48 [00:54<01:48,  3.39s/it]

Processing sample 4:  35%|█████████████████████▌                                       | 17/48 [00:57<01:44,  3.38s/it]

Processing sample 4:  38%|██████████████████████▉                                      | 18/48 [01:01<01:40,  3.35s/it]

Processing sample 4:  40%|████████████████████████▏                                    | 19/48 [01:04<01:38,  3.39s/it]

Processing sample 4:  42%|█████████████████████████▍                                   | 20/48 [01:07<01:34,  3.36s/it]

Processing sample 4:  44%|██████████████████████████▋                                  | 21/48 [01:11<01:30,  3.35s/it]

Processing sample 4:  46%|███████████████████████████▉                                 | 22/48 [01:14<01:26,  3.33s/it]

Processing sample 4:  48%|█████████████████████████████▏                               | 23/48 [01:17<01:23,  3.34s/it]

Processing sample 4:  50%|██████████████████████████████▌                              | 24/48 [01:21<01:21,  3.38s/it]

Processing sample 4:  52%|███████████████████████████████▊                             | 25/48 [01:24<01:17,  3.36s/it]

Processing sample 4:  54%|█████████████████████████████████                            | 26/48 [01:27<01:13,  3.34s/it]

Processing sample 4:  56%|██████████████████████████████████▎                          | 27/48 [01:31<01:09,  3.33s/it]

Processing sample 4:  58%|███████████████████████████████████▌                         | 28/48 [01:34<01:06,  3.32s/it]

Processing sample 4:  60%|████████████████████████████████████▊                        | 29/48 [01:37<01:03,  3.32s/it]

Processing sample 4:  62%|██████████████████████████████████████▏                      | 30/48 [01:41<00:59,  3.32s/it]

Processing sample 4:  65%|███████████████████████████████████████▍                     | 31/48 [01:44<00:56,  3.32s/it]

Processing sample 4:  67%|████████████████████████████████████████▋                    | 32/48 [01:47<00:52,  3.31s/it]

Processing sample 4:  69%|█████████████████████████████████████████▉                   | 33/48 [01:51<00:50,  3.36s/it]

Processing sample 4:  71%|███████████████████████████████████████████▏                 | 34/48 [01:54<00:46,  3.34s/it]

Processing sample 4:  73%|████████████████████████████████████████████▍                | 35/48 [01:57<00:43,  3.33s/it]

Processing sample 4:  75%|█████████████████████████████████████████████▊               | 36/48 [02:01<00:39,  3.32s/it]

Processing sample 4:  77%|███████████████████████████████████████████████              | 37/48 [02:04<00:36,  3.31s/it]

Processing sample 4:  79%|████████████████████████████████████████████████▎            | 38/48 [02:08<00:33,  3.40s/it]

Processing sample 4:  81%|█████████████████████████████████████████████████▌           | 39/48 [02:11<00:30,  3.41s/it]

Processing sample 4:  83%|██████████████████████████████████████████████████▊          | 40/48 [02:14<00:27,  3.43s/it]

Processing sample 4:  85%|████████████████████████████████████████████████████         | 41/48 [02:18<00:23,  3.40s/it]

Processing sample 4:  88%|█████████████████████████████████████████████████████▍       | 42/48 [02:21<00:20,  3.38s/it]

Processing sample 4:  90%|██████████████████████████████████████████████████████▋      | 43/48 [02:25<00:16,  3.38s/it]

Processing sample 4:  92%|███████████████████████████████████████████████████████▉     | 44/48 [02:28<00:13,  3.42s/it]

Processing sample 4:  94%|█████████████████████████████████████████████████████████▏   | 45/48 [02:32<00:10,  3.45s/it]

Processing sample 4:  96%|██████████████████████████████████████████████████████████▍  | 46/48 [02:35<00:06,  3.44s/it]

Processing sample 4:  98%|███████████████████████████████████████████████████████████▋ | 47/48 [02:38<00:03,  3.41s/it]

Processing sample 4: 100%|█████████████████████████████████████████████████████████████| 48/48 [02:42<00:00,  3.42s/it]

Processing sample 4: 100%|█████████████████████████████████████████████████████████████| 48/48 [02:42<00:00,  3.38s/it]




In [1]:
import pandas as pd
header = ['original', 'altered', 'alteration_id', 'len_diff', 'depth_diff', 'n_root_children_diff', 
          'n_word_diff', 'word_diff', 'diff_pos', 'diff_dep', 'new_deps', 'root_children_dep_diff',
           'root_children_text_diff']
#df = pd.DataFrame(stats, columns=header)
df = pd.read_csv('Autoreg5stats.csv')
print(df['alteration_id'].unique())
grouped = df.groupby('alteration_id')
diff_df = grouped.mean()[['len_diff', 'depth_diff', 'n_root_children_diff']]
diff_df['len_diff'] /= diff_df['len_diff'].max()
diff_df['depth_diff'] /= diff_df['depth_diff'].max()
diff_df['n_root_children_diff'] /= diff_df['n_root_children_diff'].max()
print(diff_df.sort_values('depth_diff', axis=0))

[ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47]
               len_diff  depth_diff  n_root_children_diff
alteration_id                                            
18             0.012422    0.011952              0.033333
20             0.013199    0.035857              0.040741
31             0.010093    0.035857              0.033333
25             0.028727    0.035857              0.018519
26             0.020963    0.035857              0.014815
41             0.012422    0.039841              0.014815
21             0.014752    0.043825              0.022222
27             0.054348    0.047809              0.059259
19             0.058230    0.051793              0.048148
17             0.060559    0.051793              0.070370
32             0.016304    0.051793              0.029630
24             0.034161    0.055777              0.066667
28             0.027950    0.055777      

In [2]:
# print(np.unique(df['new_deps'].array, return_counts=True))#'root_children_diff',
from tqdm import tqdm
import numpy as np
header = ['original', 'altered', 'alteration_id', 'len_diff', 'depth_diff', 'n_root_children_diff', 
          'n_word_diff', 'word_diff', 'diff_pos', 'diff_dep', 'new_deps', 'root_children_dep_diff',
           'root_children_text_diff']
print(df['diff_dep'].array[116], )
def revert_to_l1(el):
    if len(el[1:-1]):
        el = el.replace('(', '').replace("'", '').replace(' ', '').replace('),', ')').replace(']', '').replace('[', '')
        output = [el_i.split(",") for el_i in el.split(')') if len(el_i)>3]
        if len(output)>1: 
            output = np.concatenate(output)
        return np.unique(output)
    else:
        return []
# df['diff_dep'] = df['diff_dep'].map(revert_to_l1)
# df['new_deps'] = df['new_deps'].map(revert_to_l1)
# df['root_children_dep_diff'] = df['root_children_dep_diff'].map(revert_to_l1)
d_dep_types = ['d_'+ty for ty in np.unique(np.concatenate(df['diff_dep'].array))]
for ty in tqdm(d_dep_types):
    concerned = []
    for deps in df['diff_dep'].array.astype(list):
        concerned.append(ty[2:] in deps)
    df[ty] = concerned

dep_types = np.unique(np.concatenate(df['new_deps'].array))
for ty in tqdm(dep_types):
    concerned = []
    for deps in df['new_deps'].array:
        concerned.append(ty in deps)
    df[ty] = concerned
c_dep_types = ['c_'+ty for ty in np.unique(np.concatenate(df['root_children_dep_diff'].array))]
for ty in tqdm(c_dep_types):
    concerned = []
    for deps in df['root_children_dep_diff'].array:
        concerned.append(ty[2:] in deps)
    df[ty] = concerned



[('amod', 'compound'), ('pobj', 'pobj')]


ValueError: zero-dimensional arrays cannot be concatenated

In [61]:


header = ['original', 'altered', 'alteration_id', 'len_diff', 'depth_diff', 'n_root_children_diff', 
          'n_word_diff', 'word_diff', 'diff_pos', 'diff_dep', 'new_deps', 'root_children_dep_diff',
           'root_children_text_diff']
pd.set_option("display.max_columns", 100)
pd.set_option('display.width', 150)
pd.options.display.max_rows = 10000


grouped = df.groupby('alteration_id')
# print(grouped.mean()[dep_types])
print('************ ANY ****************')
print('       *** Values ***')
print(grouped.mean()[dep_types].max())
print('       ***  IDX   ***')
print(grouped.mean()[dep_types].idxmax())
print('************ ROOT ****************')
print('       *** Values ***')
print(grouped.mean()[c_dep_types].max())
print('       ***  IDX   ***')
print(grouped.mean()[c_dep_types].idxmax())
print('************ Same length differences ****************')
print('       *** Values ***')
print(grouped.mean()[d_dep_types].max())
print('       ***  IDX   ***')
print(grouped.mean()[d_dep_types].idxmax())

************ ANY ****************
       *** Values ***
acl          0.010
acomp        0.078
advcl        0.013
advmod       0.068
agent        0.014
amod         0.326
attr         0.030
aux          0.111
auxpass      0.017
cc           0.029
compound     0.024
conj         0.029
dep          0.013
det          0.004
dobj         0.369
expl         0.020
mark         0.005
neg          0.005
npadvmod     0.001
nsubj        0.030
nsubjpass    0.017
nummod       0.047
pobj         0.343
poss         0.073
prep         0.348
prt          0.054
relcl        0.002
xcomp        0.090
dtype: float64
       ***  IDX   ***
acl          30
acomp        10
advcl        43
advmod       10
agent        10
amod         30
attr         30
aux          30
auxpass      10
cc           30
compound     43
conj         30
dep          10
det          10
dobj         10
expl         30
mark         43
neg          33
npadvmod     10
nsubj        10
nsubjpass    10
nummod       30
pobj         10
poss   

In [14]:

pd.set_option("display.max_columns", 100)
pd.set_option('display.width', 150)
pd.options.display.max_rows = 10000
for line in df[df['alteration_id'] == 39][['original', 'altered']].iterrows():    
    if line[1]['original'] != line[1]['altered']:
        print(line[1]['original'], line[1]['altered'])

 a man is standing in front of a street ..   a man with a black shirt is walking down a street .. 
 a young boy is walking with a red toy ..   a young boy is running outside .. 
 a man is sitting in a park ..   a man is sitting on a bench .. 
 a man with a pink hat is swimming ..   a man with a blue hat is swimming .. 
 a dog is sitting in a field with a blue shirt ..   a dog is sitting in a race with a blue shirt .. 
 a man in a green shirt is playing a rock outside ..   a man is laying on a couch .. 
 a boy is sitting on a couch ..   a boy is sitting in a chair .. 
 a little boy is preparing to play with a ball ..   a little boy is jumping off a mountain with a red shirt .. 
 a little boy is swimming in a pond ..   a little boy is jumping in a pond .. 
 a young boy is walking with a red toy ..   a young boy is sitting outside .. 
 a man is sitting in a park ..   a man is sitting outside .. 
 a man with a pink hat is swimming ..   a man with a hat is looking at a statue .. 
 a man in 

In [140]:
df.to_csv('Autoreg5stats.csv')

In [None]:
pd.read_csv('Autoreg5stats.csv')