# Instanciating the model


In [1]:
from time import time
import argparse
import os

from torch import device
import torch
from torch import optim
import numpy as np

from data_prep import NLIGenData2, OntoGenData, HuggingYelp2
from disentanglement_qkv.models import DisentanglementTransformerVAE, LaggingDisentanglementTransformerVAE
from disentanglement_qkv.h_params import DefaultTransformerHParams as HParams
from disentanglement_qkv.graphs import *
from components.criteria import *
parser = argparse.ArgumentParser()
from torch.nn import MultiheadAttention
# Training and Optimization
k, kz, klstm = 1, 8, 2
# k, kz, klstm = 2, 4, 2
parser.add_argument("--test_name", default='unnamed', type=str)
parser.add_argument("--data", default='nli', choices=["nli", "ontonotes", "yelp"], type=str)
parser.add_argument("--csv_out", default='disentqkv.csv', type=str)
parser.add_argument("--max_len", default=17, type=int)
parser.add_argument("--batch_size", default=128, type=int)
parser.add_argument("--grad_accu", default=1, type=int)
parser.add_argument("--n_epochs", default=20, type=int)
parser.add_argument("--test_freq", default=32, type=int)
parser.add_argument("--complete_test_freq", default=160, type=int)
parser.add_argument("--generation_weight", default=1, type=float)
parser.add_argument("--device", default='cuda:0', choices=["cuda:0", "cuda:1", "cuda:2", "cpu"], type=str)
parser.add_argument("--embedding_dim", default=128, type=int)#################"
parser.add_argument("--pretrained_embeddings", default=False, type=bool)#################"
parser.add_argument("--z_size", default=96*kz, type=int)#################"
parser.add_argument("--z_emb_dim", default=192*k, type=int)#################"
parser.add_argument("--n_keys", default=4, type=int)#################"
parser.add_argument("--n_latents", default=[16, 16, 16], nargs='+', type=int)#################"
parser.add_argument("--text_rep_l", default=3, type=int)
parser.add_argument("--text_rep_h", default=192*k, type=int)
parser.add_argument("--encoder_h", default=192*k, type=int)#################"
parser.add_argument("--encoder_l", default=2, type=int)#################"
parser.add_argument("--decoder_h", default=192*k, type=int)
parser.add_argument("--decoder_l", default=2, type=int)#################"
parser.add_argument("--highway", default=False, type=bool)
parser.add_argument("--markovian", default=True, type=bool)
parser.add_argument('--minimal_enc', dest='minimal_enc', action='store_true')
parser.add_argument('--no-minimal_enc', dest='minimal_enc', action='store_false')
parser.set_defaults(minimal_enc=False)
parser.add_argument("--losses", default='VAE', choices=["VAE", "IWAE" "LagVAE"], type=str)
parser.add_argument("--graph", default='Normal', choices=["Vanilla", "IndepInfer", "QKV", "HQKV"], type=str)
parser.add_argument("--training_iw_samples", default=1, type=int)
parser.add_argument("--testing_iw_samples", default=5, type=int)
parser.add_argument("--test_prior_samples", default=10, type=int)
parser.add_argument("--anneal_kl0", default=3000, type=int)
parser.add_argument("--anneal_kl1", default=6000, type=int)
parser.add_argument("--grad_clip", default=5., type=float)
parser.add_argument("--kl_th", default=0/(768*k/2), type=float or None)
parser.add_argument("--max_elbo1", default=6.0, type=float)
parser.add_argument("--max_elbo2", default=4.0, type=float)
parser.add_argument("--max_elbo_choice", default=10, type=int)
parser.add_argument("--kl_beta", default=0.4, type=float)
parser.add_argument("--dropout", default=0.3, type=float)
parser.add_argument("--word_dropout", default=0.1, type=float)
parser.add_argument("--l2_reg", default=0, type=float)
parser.add_argument("--lr", default=2e-4, type=float)
parser.add_argument("--lr_reduction", default=4., type=float)
parser.add_argument("--wait_epochs", default=1, type=float)
parser.add_argument("--save_all", default=True, type=bool)

flags, _ = parser.parse_known_args()

# Manual Settings, Deactivate before pushing
if True:
    flags.batch_size = 128
    flags.grad_accu = 1
    flags.max_len = 17
    # flags.test_name = "nliLM/YelpQKV_beta0.5.0.3.1.16"#"nliLM/HQKVTest2"
    flags.test_name = "nliLM/HQKVTest2"
    flags.data = "yelp"
    flags.n_latents = [8]
    # flags.n_latents = [16]
    flags.graph = "HQKV"  
    # flags.graph ="QKV"
    # flags.losses = "LagVAE"
    flags.kl_beta = 0.5
    
    # flags.anneal_kl0 = 0
    flags.max_elbo_choice = 6
    # flags.z_size = 16
    # flags.encoder_h = 256
    # flags.decoder_h = 256
    


# torch.autograd.set_detect_anomaly(True)
GRAPH = {"Vanilla": get_vanilla_graph,
         "IndepInfer": get_structured_auto_regressive_indep_graph,
         "QKV": get_qkv_graph2,
         "HQKV": get_hqkv_graph_old}[flags.graph]
if flags.graph == "NormalLSTM":
    flags.encoder_h = int(flags.encoder_h/k*klstm)
if flags.graph == "Vanilla":
    flags.n_latents = [flags.z_size]
if flags.losses == "LagVAE":
    flags.anneal_kl0 = 0
    flags.anneal_kl1 = 0
Data = {"nli": NLIGenData2, "ontonotes": OntoGenData, "yelp": HuggingYelp2}[flags.data]
MAX_LEN = flags.max_len
BATCH_SIZE = flags.batch_size
GRAD_ACCU = flags.grad_accu
N_EPOCHS = flags.n_epochs
TEST_FREQ = flags.test_freq
COMPLETE_TEST_FREQ = flags.complete_test_freq
DEVICE = device(flags.device)
# This prevents illegal memory access on multigpu machines (unresolved issue on torch's github)
if flags.device.startswith('cuda'):
    torch.cuda.set_device(int(flags.device[-1]))
LOSSES = {'IWAE': [IWLBo],
          'VAE': [ELBo],
          'LagVAE': [ELBo]}[flags.losses]

ANNEAL_KL = [flags.anneal_kl0*flags.grad_accu, flags.anneal_kl1*flags.grad_accu]
LOSS_PARAMS = [1]
if flags.grad_accu > 1:
    LOSS_PARAMS = [w/flags.grad_accu for w in LOSS_PARAMS]


data = Data(MAX_LEN, BATCH_SIZE, N_EPOCHS, DEVICE, pretrained=flags.pretrained_embeddings)
h_params = HParams(len(data.vocab.itos), len(data.tags.itos) if flags.data == 'yelp' else None, MAX_LEN, BATCH_SIZE, N_EPOCHS,
                   device=DEVICE, vocab_ignore_index=data.vocab.stoi['<pad>'], decoder_h=flags.decoder_h,
                   decoder_l=flags.decoder_l, encoder_h=flags.encoder_h, encoder_l=flags.encoder_l,
                   text_rep_h=flags.text_rep_h, text_rep_l=flags.text_rep_l,
                   test_name=flags.test_name, grad_accumulation_steps=GRAD_ACCU,
                   optimizer_kwargs={'lr': flags.lr, #'weight_decay': flags.l2_reg, 't0':100, 'lambd':0.},
                                     'weight_decay': flags.l2_reg, 'betas': (0.9, 0.99)},
                   is_weighted=[], graph_generator=GRAPH,
                   z_size=flags.z_size, embedding_dim=flags.embedding_dim, anneal_kl=ANNEAL_KL,
                   grad_clip=flags.grad_clip*flags.grad_accu, kl_th=flags.kl_th, highway=flags.highway,
                   losses=LOSSES, dropout=flags.dropout, training_iw_samples=flags.training_iw_samples,
                   testing_iw_samples=flags.testing_iw_samples, loss_params=LOSS_PARAMS, optimizer=optim.AdamW,
                   markovian=flags.markovian, word_dropout=flags.word_dropout, contiguous_lm=False,
                   test_prior_samples=flags.test_prior_samples, n_latents=flags.n_latents, n_keys=flags.n_keys,
                   max_elbo=[flags.max_elbo_choice, flags.max_elbo1],  # max_elbo is paper's beta
                   z_emb_dim=flags.z_emb_dim, minimal_enc=flags.minimal_enc, kl_beta=flags.kl_beta)
val_iterator = iter(data.val_iter)
print("Words: ", len(data.vocab.itos), ", On device: ", DEVICE.type)
print("Loss Type: ", flags.losses)
if flags.losses == 'LagVAE':
    model = LaggingDisentanglementTransformerVAE(data.vocab, data.tags, h_params, wvs=data.wvs, dataset=flags.data,
                                                 enc_iter=data.enc_train_iter)
else:
    model = DisentanglementTransformerVAE(data.vocab, data.tags, h_params, wvs=data.wvs, dataset=flags.data)
if DEVICE.type == 'cuda':
    model.cuda(DEVICE)

total_unsupervised_train_samples = len(data.train_iter)*BATCH_SIZE
total_unsupervised_val_samples = len(data.val_iter)*BATCH_SIZE
print("Unsupervised training examples: ", total_unsupervised_train_samples)
print("Unsupervised val examples: ", total_unsupervised_val_samples)
current_time = time()
#print(model)
number_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Number of parameters: ", "{0:05.2f} M".format(number_parameters/1e6))
number_parameters = sum(p.numel() for p in model.infer_bn.parameters() if p.requires_grad)
print("Inference parameters: ", "{0:05.2f} M".format(number_parameters/1e6))
number_parameters = sum(p.numel() for p in model.gen_bn.parameters() if p.requires_grad)
print("Generation parameters: ", "{0:05.2f} M".format(number_parameters/1e6))
number_parameters = sum(p.numel() for p in model.word_embeddings.parameters() if p.requires_grad)
print("Embedding parameters: ", "{0:05.2f} M".format(number_parameters/1e6))
model.eval()

Dataset has 443259  examples. statistics:
 -words: 8.881732801815643+-3.6417630572571147(quantiles(0.5, 0.7, 0.9, 0.95, 0.99:9.0,11.0,14.0,15.0,15.0)
Dataset has 4000  examples. statistics:
 -words: 8.9255+-3.668371539252806(quantiles(0.5, 0.7, 0.9, 0.95, 0.99:9.0,11.0,14.0,15.0,15.0)
Dataset has 1000  examples. statistics:
 -words: 10.325+-2.8399603870476784(quantiles(0.5, 0.7, 0.9, 0.95, 0.99:10.0,12.0,14.0,15.0,15.0)
data loading took 5.219523668289185


Words:  9600 , On device:  cuda
Loss Type:  VAE
reconstruction net size: 25.54 M
prior net sizes:


Loaded model at step 27704
Unsupervised training examples:  443264
Unsupervised val examples:  42752
Number of parameters:  30.71 M
Inference parameters:  06.02 M
Generation parameters:  25.91 M
Embedding parameters:  01.23 M


In [2]:
text, samples, params = model.get_sentences(5, gen_len=16, sample_w=False, vary_z=True, complete=None, contains=None, max_tries=100)

print(text)


[" i 've been sorry ", ' the service was great and the food was great ', " i have been here for years and i 've had a better experience ", ' i had a great experience with this place ', ' the food is great ']


In [7]:
var_ids = [8]
alt_text, alt_params = model._get_alternative_sentences(samples, None, var_ids, 2, 16, complete=None,)
for i in range(len(text)):
    print(text[i], ':', alt_text[i::len(text)])


 i 've been sorry  : [' i had a great experience with my friend at this location ', " i 've been a regular for a year "]
 the service was great and the food was great  : [' service was great ', ' service was great ']
 i have been here for years and i 've had a better experience  : [" i 've been here for years , and love ", ' unfortunately , we were very disappointed ']
 i had a great experience with this place  : [" it 's pretty good ", " it 's a great place to go "]
 the food is great  : [' the food is a great place to go ', ' good food ']


In [8]:
def _get_alternative_sentences(mdl, prev_latent_vals, params, var_z_ids, n_samples, gen_len, complete=None):
        h_params = mdl.h_params
        has_struct = mdl.h_params.graph_generator in (get_qkv_graph, get_hqkv_graph)
        has_zg = mdl.h_params.graph_generator = get_hqkv_graph

        n_orig_sentences = prev_latent_vals['z1'].shape[0]
        go_symbol = torch.ones([n_samples * n_orig_sentences]).long() * \
                    mdl.index[mdl.generated_v].stoi['<go>']
        go_symbol = go_symbol.to(mdl.h_params.device).unsqueeze(-1)
        x_prev = go_symbol
        if complete is not None:
            for token in complete.split(' '):
                x_prev = torch.cat(
                    [x_prev, torch.ones([n_samples * n_orig_sentences, 1]).long().to(mdl.h_params.device) * \
                     mdl.index[mdl.generated_v].stoi[token]], dim=1)
            gen_len = gen_len - len(complete.split(' '))

        orig_zs = [prev_latent_vals['z{}'.format(i+1)].repeat(n_samples, 1) for i in range(len(h_params.n_latents))]
        zs = [mdl.gen_bn.name_to_v['z{}'.format(i+1)] for i in range(len(h_params.n_latents))]
        gen_input = {**{'z{}'.format(i+1): orig_zs[i].unsqueeze(1) for i in range(len(orig_zs))},
                     'x_prev': torch.zeros((n_samples * n_orig_sentences, 1, mdl.generated_v.size)).to(
                         mdl.h_params.device)}
        if has_struct:
            orig_zst = prev_latent_vals['zs'].repeat(n_samples, 1)
            zst = mdl.gen_bn.name_to_v['zs']
            gen_input['zs'] = orig_zst.unsqueeze(1)
        if has_zg:
            orig_zg = prev_latent_vals['zg'].repeat(n_samples, 1)
            zg = mdl.gen_bn.name_to_v['zg']
            # gen_input['zg'] = zg.prior_sample((n_samples * n_orig_sentences,))[0]
            gen_input['zg'] = orig_zg.unsqueeze(1)
        mdl.gen_bn(gen_input)
        if has_zg:
            z1_sample = zs[0].posterior_sample(mdl.gen_bn.name_to_v['z1'].post_params)[0].squeeze(1)
            if has_struct:
                zst_sample = zst.posterior_sample(mdl.gen_bn.name_to_v['zs'].post_params)[0].squeeze(1)
        else:
            z1_sample = zs[0].prior_sample((n_samples * n_orig_sentences,))[0]
            if has_struct:
                zst_sample = zst.prior_sample((n_samples * n_orig_sentences,))[0]
        zs_sample = [z1_sample] +\
                    [z.post_samples.squeeze(1) for z in zs[1:]]

        for id in var_z_ids:
            # id == sum(h_params.n_latents) means its zst
            if id == sum(h_params.n_latents) and has_struct:
                orig_zst = zst_sample
            else:
                assert id < sum(h_params.n_latents)
                z_number = sum([id > sum(h_params.n_latents[:i + 1]) for i in range(len(h_params.n_latents))])
                z_index = id - sum(h_params.n_latents[:z_number])
                start, end = int(h_params.z_size / max(h_params.n_latents) * z_index), int(
                    h_params.z_size / max(h_params.n_latents) * (z_index + 1))
                source, destination = zs_sample[z_number], orig_zs[z_number]
                destination[:, start:end] = source[:, start:end]

        z_input = {'z{}'.format(i+1): orig_zs[i].unsqueeze(1) for i in range(len(orig_zs))}
        if has_struct:
            z_input['zs'] = orig_zst.unsqueeze(1)
        if has_zg:
            z_input['zg'] = orig_zg.unsqueeze(1)

        # Normal Autoregressive generation
        for i in range(gen_len):
            mdl.gen_bn({'x_prev': x_prev, **{k: v.expand(v.shape[0], i + 1, v.shape[-1])
                                             for k, v in z_input.items()}})
            samples_i = mdl.generated_v.post_params['logits']

            x_prev = torch.cat([x_prev, torch.argmax(samples_i, dim=-1)[..., -1].unsqueeze(-1)],
                               dim=-1)

        text = mdl.decode_to_text2(x_prev, mdl.h_params.vocab_size, mdl.index[mdl.generated_v])
        samples = {'z{}'.format(i+1): zs_sample[i].tolist() for i in range(len(orig_zs))}
        if has_struct:
            samples['zs'] = zst_sample.tolist()
        if has_zg:
            samples['zg'] = orig_zg.tolist()
        return text, samples
print("==varying content==")
var_ids = list(range(8))
alt_text, alt_params = _get_alternative_sentences(model, samples, None, var_ids, 2, 16, complete=None)
for i in range(len(text)):
    print(text[i], ':', alt_text[i::len(text)])
print("==varying structure==")
var_ids = [8]
alt_text, alt_params = _get_alternative_sentences(model, samples, None, var_ids, 4, 16, complete=None)
for i in range(len(text)):
    print(text[i], ':', alt_text[i::len(text)])


==varying content==


 i 've been sorry  : [" i 'm great for the price ", ' i love this place ']
 the service was great and the food was great  : [' the service is great , and the food is great ', ' great place to go for a quick lunch ']
 i have been here for years and i 've had a better experience  : [' the food is great ', ' they have a good selection of food and the service is great ']
 i had a great experience with this place  : [' i enjoy their variety of beers ', " i 've tried this location "]
 the food is great  : [' the bar is great ', ' this is a great place ']
==varying structure==


 i 've been sorry  : [' great place ', " i 'm a fan of this year ", ' i love this place ', ' my favorite thing about this place is the best ']
 the service was great and the food was great  : [' service was very good ', ' service was great ', ' the service is great ', ' the food is great ']
 i have been here for years and i 've had a better experience  : [' we had a fantastic experience ', " i 've been here for years , and i love it ", " i 've been here for years , and again ", " i have been here for years and i 've ever had "]
 i had a great experience with this place  : [" it 's a great place to go ", ' i had a great experience with this place ', " i had a great experience with a friend and it 's a disappointment ", " it 's a great place to go "]
 the food is great  : [' the food is good , but the service is great ', ' the food is great and the price is great ', ' the food is good and a good job ', ' the food is good and good ']


In [29]:
var_ids = [3, 4, 5, 6, 7]
alt_text, alt_params = model._get_alternative_sentences(samples, None, var_ids, 2, 16, complete=None)
for i in range(len(text)):
    print(text[i], ':', alt_text[i::len(text)])

 no server , he was very friendly and helpful  : [' great service and great food ', ' great service ']
 the staff is friendly and helpful and friendly  : [' our food was very good and loved it and it disappeared ', ' the staff is friendly and helpful and so hard to pay her cafe ']
 the place is way more than the food is good  : [' the food was delicious and the service was great ', ' the food was delicious and the service was great ']
 amazing service , what a good time  : [' amazing people ', ' the food is always good , service is great ']
 i 've been here twice and we 're quite happy with my experience  : [' the food is always fresh and delicious ', ' i will definitely be back again ']


In [21]:
def decode_to_text(x_hat_params, vocab_size, vocab_index):
    # It is assumed that this function is used at test time for display purposes
    # Getting the argmax from the one hot if it's not done
    while x_hat_params.shape[-1] == vocab_size and x_hat_params.ndim > 3:
        x_hat_params = x_hat_params.mean(0)
    while x_hat_params.ndim > 2 and x_hat_params.shape[-1] != self.h_params.vocab_size:
        x_hat_params = x_hat_params[0]
    if x_hat_params.shape[-1] == vocab_size:
        x_hat_params = torch.argmax(x_hat_params, dim=-1)
    assert x_hat_params.ndim == 2, "Mis-shaped generated sequence: {}".format(x_hat_params.shape)
    
    samples = [' '.join([vocab_index.itos[w]
                         for w in sen]).split('<eos>')[0].split(' !')[0].split(' .')[0].replace('<go>', '').replace('</go>', '')
               .replace('<pad>', '_').replace('_unk', '<?>')
               for sen in x_hat_params]

    return samples
print(model.gen_bn.name_to_v.keys())

def swap_latents(mdl, prev_latent_vals, var_z_ids, gen_len, complete=None, no_unk=True):
            has_struct = 'zs' in mdl.gen_bn.name_to_v
            has_zg = 'zg' in mdl.gen_bn.name_to_v
            
            
            n_orig_sentences = prev_latent_vals['z1'].shape[0]
            n_samples = n_orig_sentences
            go_symbol = torch.ones([n_samples * n_orig_sentences]).long() * \
                        mdl.index[mdl.generated_v].stoi['<go>']
            go_symbol = go_symbol.to(mdl.h_params.device).unsqueeze(-1)
            x_prev = go_symbol
            if complete is not None:
                for token in complete.split(' '):
                    x_prev = torch.cat([x_prev, torch.ones([n_samples * n_orig_sentences, 1]).long().to(mdl.h_params.device) * \
                        mdl.index[mdl.generated_v].stoi[token]], dim=1)
                gen_len = gen_len - len(complete.split(' '))
            temp = 1.
            orig_z = prev_latent_vals['z1'].unsqueeze(1).repeat(1, n_samples, 1)
            z_sample = orig_z.reshape(n_samples*n_orig_sentences, -1)
            orig_z = orig_z.transpose(0, 1).reshape(n_samples*n_orig_sentences, -1)
            if has_struct:
                orig_zst = prev_latent_vals['zs'].unsqueeze(1).repeat(1, n_samples, 1)
                zst_sample = orig_zst.reshape(n_samples*n_orig_sentences, -1)
                orig_zst = orig_zst.transpose(0, 1).reshape(n_samples*n_orig_sentences, -1)
            if has_zg:
                orig_zg = prev_latent_vals['zg'].unsqueeze(1).repeat(1, n_samples, 1)
                orig_zg = orig_zg.transpose(0, 1).reshape(n_samples*n_orig_sentences, -1)
            

            for id in var_z_ids:
                if id < sum(h_params.n_latents):
                    z_number = sum([id> sum(h_params.n_latents[:i+1]) for i in range(len(h_params.n_latents))])
                    z_index = id - sum(h_params.n_latents[:z_number])
                    start, end = int(h_params.z_size/max(h_params.n_latents)*z_index),\
                                 int(h_params.z_size/max(h_params.n_latents)*(z_index+1))
                    source, destination = [z_sample][z_number], [orig_z][z_number]
                    destination[:, start:end] = source[:, start:end]
                elif id == sum(h_params.n_latents) and has_struct:
                    orig_zst = zst_sample
                else:
                    raise IndexError("You gave a too high z_id for swapping with this model")
                    
            z_input = {'z1': orig_z.unsqueeze(1), **({'zs':orig_zst.unsqueeze(1)} if has_struct else {}), 
                       **({'zg':orig_zg.unsqueeze(1)} if has_zg else {})}
            
            # Normal Autoregressive generation
            for i in range(gen_len):
                mdl.gen_bn({'x_prev': x_prev, **{k: v.expand(v.shape[0], i+1, v.shape[-1])
                                                  for k, v in z_input.items()}})
                samples_i = mdl.generated_v.post_params['logits']
                if no_unk:
                    annul_vector = 1-F.one_hot(torch.tensor([data.vocab.stoi['<unk>']]).to(DEVICE), h_params.vocab_size)
                    samples_i *= annul_vector
                
                x_prev = torch.cat([x_prev, torch.argmax(samples_i,     dim=-1)[..., -1].unsqueeze(-1)],
                                   dim=-1)
            
            text = decode_to_text(x_prev, mdl.h_params.vocab_size, mdl.index[mdl.generated_v])
            return text, {'z1': orig_z}
sw_zs = [8]
sw_text, sw_samples = swap_latents(model, samples, sw_zs, 16, complete=None, no_unk=True)
print(text)
for i in range(len(text)):
    for j in range(len(text)):
        if i!=j:
            print("z_from: ", text[i], "|z_to: ", text[j], "|result: ", sw_text[len(text)*i+j])



dict_keys(['z1', 'x', 'x_prev', 'zs', 'zg'])


[" i 've been sorry ", ' the service was great and the food was great ', " i have been here for years and i 've had a better experience ", ' i had a great experience with this place ', ' the food is great ']
z_from:   i 've been sorry  |z_to:   the service was great and the food was great  |result:   service was very good
z_from:   i 've been sorry  |z_to:   i have been here for years and i 've had a better experience  |result:   we love this location
z_from:   i 've been sorry  |z_to:   i had a great experience with this place  |result:   it 's a disappointment
z_from:   i 've been sorry  |z_to:   the food is great  |result:   the food is great
z_from:   the service was great and the food was great  |z_to:   i 've been sorry  |result:   i 've been a lot of food and the service is great
z_from:   the service was great and the food was great  |z_to:   i have been here for years and i 've had a better experience  |result:   i 've been here for years and i love this place
z_from:   the se

dict_keys(['z1', 'x', 'x_prev', 'zs', 'zg'])


[" i 've been sorry ", ' the service was great and the food was great ', " i have been here for years and i 've had a better experience ", ' i had a great experience with this place ', ' the food is great ']
z_from:   i 've been sorry  |z_to:   the service was great and the food was great  |result:   service was very good
z_from:   i 've been sorry  |z_to:   i have been here for years and i 've had a better experience  |result:   we love this location
z_from:   i 've been sorry  |z_to:   i had a great experience with this place  |result:   it 's a disappointment
z_from:   i 've been sorry  |z_to:   the food is great  |result:   the food is great
z_from:   the service was great and the food was great  |z_to:   i 've been sorry  |result:   i 've been a lot of food and the service is great
z_from:   the service was great and the food was great  |z_to:   i have been here for years and i 've had a better experience  |result:   i 've been here for years and i love this place
z_from:   the se

In [18]:
text, samples, params = model.get_sentences(5, gen_len=16, sample_w=False, vary_z=True, complete=None, 
                                            contains=None, max_tries=100)
print(text)

[' i love this place ', ' i would not recommend this place to anyone ', " i was n't even run to the restaurant ", " i 've had a great experience ", ' i would not recommend this place to anyone who ']


In [9]:
alt_text, alt_samples = model._get_alternative_sentences(samples, params, [8], 4, 16, complete=None)
print("====== Changing Structure=======")
for i in range(len(text)):
    print("-->", text[i], '|<', '><'.join(alt_text[i::len(text)]), '>')
alt_text, alt_samples = model._get_alternative_sentences(samples, params, list(range(8)), 4, 16, complete=None)
print("====== Changing Content=======")
for i in range(len(text)):
    print("-->", text[i], '|<', '><'.join(alt_text[i::len(text)]), '>')

-->  i 've been sorry  |<  i had a great experience with my friend at this location >< my favorite thing about this place is the best >< i 'm a fan of my favorite joint >< i 've been a regular for years and love  >
-->  the service was great and the food was great  |<  service was very friendly and the food was great >< service was very friendly and the food was great >< the food is great >< the staff is always friendly  >
-->  i have been here for years and i 've had a better experience  |<  we had a fantastic experience with this >< we love this location >< i love this location >< unfortunately , i was n't impressed  >
-->  i had a great experience with this place  |<  i had a great experience tonight >< i had a great experience with a friend and it 's a great experience >< it 's a disappointment >< it 's a great place to eat  >
-->  the food is great  |<  i had a great experience here >< good food and a good price >< the food is good , but the service is great >< the food is good an

-->  i 've been sorry  |<  great place and fast >< the atmosphere was great >< i just went there >< i have always had  >
-->  the service was great and the food was great  |<  the staff is always friendly and helpful >< i 've never had a bad experience >< i have been here for years and it 's a great place >< i love this place  >
-->  i have been here for years and i 've had a better experience  |<  i 'm not sure why i do n't go to the bar >< i have been here for years and it 's a bad experience >< i had a great experience with the other day , i 'm disappointed >< i had a great experience with the staff and they were very helpful  >
-->  i had a great experience with this place  |<  i love this place >< i 'm not sure why i 'm not sure why i 'm going back >< i love this place >< i 'm not sure why i 'm not sure why  >
-->  the food is great  |<  the food was ok >< i love this place >< i love this place >< i love this store  >


In [None]:
from disentanglement_qkv.models import template_match, tqdm, pd
def get_generation_TMA(self, n_samples=2000, n_alterations=1, batch_size=100):
        stats = []
        has_struct = 'zs' in self.gen_bn.name_to_v
        assert has_struct
        alter_lvs = [list(range(sum(self.h_params.n_latents))), [sum(self.h_params.n_latents)]]
        n_lvs = sum(self.h_params.n_latents) + 1
        # Generating n_samples sentences
        text, samples, _ = self.get_sentences(n_samples=batch_size, gen_len=self.h_params.max_len - 1,
                                              sample_w=False, vary_z=True, complete=None)
        for _ in tqdm(range(int(n_samples / batch_size)), desc="Generating original sentences"):
            text_i, samples_i, _ = self.get_sentences(n_samples=batch_size, gen_len=self.h_params.max_len - 1,
                                                       sample_w=False, vary_z=True, complete=None)
            text.extend(text_i)
            for k in samples.keys():
                samples[k] = torch.cat([samples[k], samples_i[k]])
        for i in range(int(n_samples / batch_size)):
            for alvs in tqdm(alter_lvs, desc="Processing sample {}".format(str(i))):
                # Altering the sentences
                alt_text, _ = self._get_alternative_sentences(
                    prev_latent_vals={k: v[i * batch_size:(i + 1) * batch_size]
                                      for k, v in samples.items()},
                    params=None, var_z_ids=alvs, n_samples=n_alterations,
                    gen_len=self.h_params.max_len - 1, complete=None)
                # Getting alteration statistics
                orig_texts = [text[(i * batch_size) + k % batch_size] for k in range(n_alterations * batch_size)]
                tma2 = template_match(orig_texts, alt_text, 2)
                tma3 = template_match(orig_texts, alt_text, 3)
                altered_var = 'zc' if alvs[0]!=(n_lvs-1) else 'zs'
                for k in range(n_alterations * batch_size):
                    stats.append([orig_texts[k], alt_text[k], altered_var, tma2[k], tma3[k]])

        header = ['original', 'altered', 'alteration_id', 'tma2', 'tma3']
        df = pd.DataFrame(stats, columns=header)
        var_wise_scores = df.groupby('alteration_id').mean()[['tma2', 'tma3']]
        return var_wise_scores


Generating original sentences:   0%|          | 0/2 [00:00<?, ?it/s]

Generating original sentences:  50%|█████     | 1/2 [00:01<00:01,  1.68s/it]

Generating original sentences: 100%|██████████| 2/2 [00:02<00:00,  1.33s/it]

Generating original sentences: 100%|██████████| 2/2 [00:02<00:00,  1.09s/it]


Processing sample 0:   0%|          | 0/2 [00:00<?, ?it/s]

Processing sample 0:  50%|█████     | 1/2 [00:11<00:11, 11.55s/it]

Processing sample 0: 100%|██████████| 2/2 [00:22<00:00, 11.51s/it]

Processing sample 0: 100%|██████████| 2/2 [00:22<00:00, 11.48s/it]


Processing sample 1:   0%|          | 0/2 [00:00<?, ?it/s]

Processing sample 1:  50%|█████     | 1/2 [00:11<00:11, 11.73s/it]

Processing sample 1: 100%|██████████| 2/2 [00:22<00:00, 11.53s/it]

Processing sample 1: 100%|██████████| 2/2 [00:22<00:00, 11.41s/it]




In [22]:
tma_mat = get_generation_TMA(model, 200)

Generating original sentences:   0%|          | 0/2 [00:00<?, ?it/s]

Generating original sentences:  50%|█████     | 1/2 [00:01<00:01,  1.68s/it]

Generating original sentences: 100%|██████████| 2/2 [00:02<00:00,  1.33s/it]

Generating original sentences: 100%|██████████| 2/2 [00:02<00:00,  1.09s/it]


Processing sample 0:   0%|          | 0/2 [00:00<?, ?it/s]

Processing sample 0:  50%|█████     | 1/2 [00:11<00:11, 11.55s/it]

Processing sample 0: 100%|██████████| 2/2 [00:22<00:00, 11.51s/it]

Processing sample 0: 100%|██████████| 2/2 [00:22<00:00, 11.48s/it]


Processing sample 1:   0%|          | 0/2 [00:00<?, ?it/s]

Processing sample 1:  50%|█████     | 1/2 [00:11<00:11, 11.73s/it]

Processing sample 1: 100%|██████████| 2/2 [00:22<00:00, 11.53s/it]

Processing sample 1: 100%|██████████| 2/2 [00:22<00:00, 11.41s/it]




In [23]:
print(tma_mat)

                tma2   tma3
alteration_id              
zc             0.315  0.160
zs             0.275  0.105


In [2]:
from disentanglement_qkv.models import bleu_score, tqdm, template_match
import os
def get_swap_tma(self, n_samples=2000, batch_size=50, beam_size=2):
    with torch.no_grad():
        has_struct = 'zs' in self.gen_bn.name_to_v
        assert has_struct
        # Generating n_samples sentences
        text, samples, _ = self.get_sentences(n_samples=batch_size, gen_len=self.h_params.max_len - 1,
                                              sample_w=False, vary_z=True, complete=None)
        for _ in tqdm(range(int(n_samples / batch_size)-1), desc="Generating original sentences"):
            text_i, samples_i, _ = self.get_sentences(n_samples=batch_size, gen_len=self.h_params.max_len - 1,
                                                      sample_w=False, vary_z=True, complete=None)
            text.extend(text_i)
            for k in samples.keys():
                samples[k] = torch.cat([samples[k], samples_i[k]])
        source_sents, target_sents = text[:int(n_samples / 2)], text[int(n_samples / 2):]
        source_lvs, target_lvs = {k: v[:int(n_samples / 2)] for k, v in samples.items()}, \
                                 {k: v[int(n_samples / 2):] for k, v in samples.items()}
        result_sents = []
        inv_result_sents = []
        go_symbol = torch.ones((1, 1)).long() * self.index[self.generated_v].stoi['<go>']
        go_symbol = go_symbol.to(self.h_params.device)
        temp = 1.
        for i in tqdm(range(int(n_samples / (2 * batch_size))),
                      desc="Getting Model Swap TMA"):
            z_input = {'zs': source_lvs['zs'][i * batch_size:(i + 1) * batch_size].unsqueeze(1),
                       **{'z{}'.format(i + 1): target_lvs['z{}'.format(i + 1)][
                                               i * batch_size:(i + 1) * batch_size].unsqueeze(1)
                          for i in range(len(self.h_params.n_latents))}}
            inv_z_input = {'zs': target_lvs['zs'][i * batch_size:(i + 1) * batch_size].unsqueeze(1),
                           **{'z{}'.format(i + 1): source_lvs['z{}'.format(i + 1)][
                                                   i * batch_size:(i + 1) * batch_size].unsqueeze(1)
                              for i in range(len(self.h_params.n_latents))}}
            # for z_in, sents in zip([z_input, inv_z_input], [result_sents, inv_result_sents]):
            #     x_prev = go_symbol.repeat((batch_size, 1))
            #     x_prev = self.generate_from_z2(z_in, x_prev, mask_unk=False, beam_size=beam_size)
            #     if beam_size > 1:
            #         x_prev = x_prev[:int(x_prev.shape[0] / beam_size)]
            #     sents.extend(self.decode_to_text2(x_prev, self.h_params.vocab_size,
            #                                       self.index[self.generated_v]))
            x_prev = go_symbol.repeat((batch_size, 1))
            x_prev = self.generate_from_z2(z_input, x_prev, mask_unk=False, beam_size=beam_size)
            if beam_size > 1:
                x_prev = x_prev[:int(x_prev.shape[0] / beam_size)]
            result_sents.extend(self.decode_to_text2(x_prev, self.h_params.vocab_size,
                                              self.index[self.generated_v]))
            x_prev = go_symbol.repeat((batch_size, 1))
            x_prev = self.generate_from_z2(inv_z_input, x_prev, mask_unk=False, beam_size=beam_size)
            if beam_size > 1:
                x_prev = x_prev[:int(x_prev.shape[0] / beam_size)]
            inv_result_sents.extend(self.decode_to_text2(x_prev, self.h_params.vocab_size,
                                              self.index[self.generated_v]))
        test_name = self.h_params.test_name.split("\\")[-1].split("/")[-1]
        dump_location = os.path.join(".data", 
                    "{}_tempdump.tsv".format(test_name))
        with open(dump_location, 'w', encoding="UTF-8") as f:
            for s, t, r, i in zip(source_sents, target_sents, result_sents, 
                                  inv_result_sents):
                f.write('\t'.join([s, t, r, i])+'\n')
            
        print("Calculating zs tma...")
        zs_tma2, zs_tma3 = np.mean(template_match(source_sents, result_sents, 2))*100, \
                               np.mean(template_match(source_sents, result_sents, 3))*100
        print("Calculating zc tma...")
        zc_tma2, zc_tma3 = np.mean(template_match(source_sents, inv_result_sents, 2))*100, \
                                       np.mean(template_match(source_sents, inv_result_sents, 3))*100
        print("Calculating copy tma...")
        copy_tma2, copy_tma3 = np.mean(template_match(source_sents, target_sents, 2))*100, \
                               np.mean(template_match(source_sents, target_sents, 3))*100

        print("Calculating zs bleu...")
        zs_bleu = bleu_score(predictions=[s.split() for s in source_sents],
                                   references=[[s.split()] for s in result_sents])['bleu']*100
        print("Calculating zc bleu...")
        zc_bleu = bleu_score(predictions=[s.split() for s in source_sents],
                                   references=[[s.split()] for s in inv_result_sents])['bleu']*100
        print("Calculating copy bleu...")
        copy_bleu = bleu_score(predictions=[s.split() for s in source_sents],
                                   references=[[s.split()] for s in target_sents])['bleu']*100
        
        return zs_tma2, zs_tma3, zc_tma2, zc_tma3, copy_tma2, copy_tma3, zs_bleu, zc_bleu, copy_bleu


In [11]:
# zs_tma2, zs_tma3, zc_tma2, zc_tma3, copy_tma2, copy_tma3, zs_bleu, zc_bleu, copy_bleu\
#     = get_swap_tma(model, n_samples=200, batch_size=20)
tma_res = model.get_swap_tma(n_samples=2000, batch_size=20)

Generating original sentences:   0%|          | 0/99 [00:00<?, ?it/s]

Generating original sentences:   1%|          | 1/99 [00:00<00:27,  3.62it/s]

Generating original sentences:   2%|▏         | 2/99 [00:00<00:26,  3.62it/s]

Generating original sentences:   3%|▎         | 3/99 [00:00<00:26,  3.60it/s]

Generating original sentences:   4%|▍         | 4/99 [00:01<00:26,  3.60it/s]

Generating original sentences:   5%|▌         | 5/99 [00:01<00:25,  3.72it/s]

Generating original sentences:   6%|▌         | 6/99 [00:01<00:25,  3.67it/s]

Generating original sentences:   7%|▋         | 7/99 [00:01<00:24,  3.72it/s]

Generating original sentences:   8%|▊         | 8/99 [00:02<00:24,  3.76it/s]

Generating original sentences:   9%|▉         | 9/99 [00:02<00:23,  3.85it/s]

Generating original sentences:  10%|█         | 10/99 [00:02<00:23,  3.84it/s]

Generating original sentences:  11%|█         | 11/99 [00:02<00:22,  3.87it/s]

Generating original sentences:  12%|█▏        | 12/99 [00:03<00:22,  3.90it/s]

Generating original sentences:  13%|█▎        | 13/99 [00:03<00:22,  3.89it/s]

Generating original sentences:  14%|█▍        | 14/99 [00:03<00:22,  3.82it/s]

Generating original sentences:  15%|█▌        | 15/99 [00:03<00:22,  3.81it/s]

Generating original sentences:  16%|█▌        | 16/99 [00:04<00:21,  3.81it/s]

Generating original sentences:  17%|█▋        | 17/99 [00:04<00:21,  3.88it/s]

Generating original sentences:  18%|█▊        | 18/99 [00:04<00:20,  3.90it/s]

Generating original sentences:  19%|█▉        | 19/99 [00:04<00:20,  3.89it/s]

Generating original sentences:  20%|██        | 20/99 [00:05<00:20,  3.85it/s]

Generating original sentences:  21%|██        | 21/99 [00:05<00:20,  3.89it/s]

Generating original sentences:  22%|██▏       | 22/99 [00:05<00:19,  3.87it/s]

Generating original sentences:  23%|██▎       | 23/99 [00:06<00:20,  3.79it/s]

Generating original sentences:  24%|██▍       | 24/99 [00:06<00:19,  3.85it/s]

Generating original sentences:  25%|██▌       | 25/99 [00:06<00:19,  3.87it/s]

Generating original sentences:  26%|██▋       | 26/99 [00:06<00:18,  3.96it/s]

Generating original sentences:  27%|██▋       | 27/99 [00:07<00:18,  3.90it/s]

Generating original sentences:  28%|██▊       | 28/99 [00:07<00:18,  3.85it/s]

Generating original sentences:  29%|██▉       | 29/99 [00:07<00:18,  3.84it/s]

Generating original sentences:  30%|███       | 30/99 [00:07<00:18,  3.72it/s]

Generating original sentences:  31%|███▏      | 31/99 [00:08<00:19,  3.40it/s]

Generating original sentences:  32%|███▏      | 32/99 [00:08<00:19,  3.50it/s]

Generating original sentences:  33%|███▎      | 33/99 [00:08<00:18,  3.55it/s]

Generating original sentences:  34%|███▍      | 34/99 [00:09<00:19,  3.39it/s]

Generating original sentences:  35%|███▌      | 35/99 [00:09<00:19,  3.23it/s]

Generating original sentences:  36%|███▋      | 36/99 [00:09<00:20,  3.08it/s]

Generating original sentences:  37%|███▋      | 37/99 [00:10<00:19,  3.11it/s]

Generating original sentences:  38%|███▊      | 38/99 [00:10<00:19,  3.17it/s]

Generating original sentences:  39%|███▉      | 39/99 [00:10<00:18,  3.24it/s]

Generating original sentences:  40%|████      | 40/99 [00:10<00:17,  3.34it/s]

Generating original sentences:  41%|████▏     | 41/99 [00:11<00:17,  3.39it/s]

Generating original sentences:  42%|████▏     | 42/99 [00:11<00:16,  3.43it/s]

Generating original sentences:  43%|████▎     | 43/99 [00:11<00:16,  3.46it/s]

Generating original sentences:  44%|████▍     | 44/99 [00:12<00:15,  3.55it/s]

Generating original sentences:  45%|████▌     | 45/99 [00:12<00:15,  3.49it/s]

Generating original sentences:  46%|████▋     | 46/99 [00:12<00:15,  3.51it/s]

Generating original sentences:  47%|████▋     | 47/99 [00:12<00:14,  3.61it/s]

Generating original sentences:  48%|████▊     | 48/99 [00:13<00:13,  3.73it/s]

Generating original sentences:  49%|████▉     | 49/99 [00:13<00:12,  3.87it/s]

Generating original sentences:  51%|█████     | 50/99 [00:13<00:12,  3.99it/s]

Generating original sentences:  52%|█████▏    | 51/99 [00:13<00:11,  4.12it/s]

Generating original sentences:  53%|█████▎    | 52/99 [00:14<00:11,  4.24it/s]

Generating original sentences:  54%|█████▎    | 53/99 [00:14<00:10,  4.31it/s]

Generating original sentences:  55%|█████▍    | 54/99 [00:14<00:10,  4.25it/s]

Generating original sentences:  56%|█████▌    | 55/99 [00:14<00:10,  4.17it/s]

Generating original sentences:  57%|█████▋    | 56/99 [00:15<00:11,  3.86it/s]

Generating original sentences:  58%|█████▊    | 57/99 [00:15<00:12,  3.28it/s]

Generating original sentences:  59%|█████▊    | 58/99 [00:15<00:12,  3.23it/s]

Generating original sentences:  60%|█████▉    | 59/99 [00:16<00:12,  3.25it/s]

Generating original sentences:  61%|██████    | 60/99 [00:16<00:12,  3.21it/s]

Generating original sentences:  62%|██████▏   | 61/99 [00:16<00:11,  3.18it/s]

Generating original sentences:  63%|██████▎   | 62/99 [00:17<00:11,  3.11it/s]

Generating original sentences:  64%|██████▎   | 63/99 [00:17<00:11,  3.06it/s]

Generating original sentences:  65%|██████▍   | 64/99 [00:17<00:11,  3.10it/s]

Generating original sentences:  66%|██████▌   | 65/99 [00:18<00:10,  3.24it/s]

Generating original sentences:  67%|██████▋   | 66/99 [00:18<00:09,  3.31it/s]

Generating original sentences:  68%|██████▊   | 67/99 [00:18<00:09,  3.41it/s]

Generating original sentences:  69%|██████▊   | 68/99 [00:18<00:09,  3.27it/s]

Generating original sentences:  70%|██████▉   | 69/99 [00:19<00:09,  3.10it/s]

Generating original sentences:  71%|███████   | 70/99 [00:19<00:09,  3.08it/s]

Generating original sentences:  72%|███████▏  | 71/99 [00:19<00:09,  3.09it/s]

Generating original sentences:  73%|███████▎  | 72/99 [00:20<00:08,  3.03it/s]

Generating original sentences:  74%|███████▎  | 73/99 [00:20<00:08,  2.95it/s]

Generating original sentences:  75%|███████▍  | 74/99 [00:21<00:08,  2.91it/s]

Generating original sentences:  76%|███████▌  | 75/99 [00:21<00:08,  2.86it/s]

Generating original sentences:  77%|███████▋  | 76/99 [00:21<00:07,  2.99it/s]

Generating original sentences:  78%|███████▊  | 77/99 [00:21<00:07,  3.13it/s]

Generating original sentences:  79%|███████▉  | 78/99 [00:22<00:06,  3.23it/s]

Generating original sentences:  80%|███████▉  | 79/99 [00:22<00:06,  3.23it/s]

Generating original sentences:  81%|████████  | 80/99 [00:22<00:05,  3.31it/s]

Generating original sentences:  82%|████████▏ | 81/99 [00:23<00:05,  3.36it/s]

Generating original sentences:  83%|████████▎ | 82/99 [00:23<00:04,  3.42it/s]

Generating original sentences:  84%|████████▍ | 83/99 [00:23<00:04,  3.39it/s]

Generating original sentences:  85%|████████▍ | 84/99 [00:24<00:04,  3.42it/s]

Generating original sentences:  86%|████████▌ | 85/99 [00:24<00:04,  3.43it/s]

Generating original sentences:  87%|████████▋ | 86/99 [00:24<00:03,  3.45it/s]

Generating original sentences:  88%|████████▊ | 87/99 [00:24<00:03,  3.39it/s]

Generating original sentences:  89%|████████▉ | 88/99 [00:25<00:03,  3.36it/s]

Generating original sentences:  90%|████████▉ | 89/99 [00:25<00:02,  3.46it/s]

Generating original sentences:  91%|█████████ | 90/99 [00:25<00:02,  3.49it/s]

Generating original sentences:  92%|█████████▏| 91/99 [00:26<00:02,  3.48it/s]

Generating original sentences:  93%|█████████▎| 92/99 [00:26<00:01,  3.51it/s]

Generating original sentences:  94%|█████████▍| 93/99 [00:26<00:01,  3.50it/s]

Generating original sentences:  95%|█████████▍| 94/99 [00:26<00:01,  3.54it/s]

Generating original sentences:  96%|█████████▌| 95/99 [00:27<00:01,  3.54it/s]

Generating original sentences:  97%|█████████▋| 96/99 [00:27<00:00,  3.49it/s]

Generating original sentences:  98%|█████████▊| 97/99 [00:27<00:00,  3.46it/s]

Generating original sentences:  99%|█████████▉| 98/99 [00:28<00:00,  3.43it/s]

Generating original sentences: 100%|██████████| 99/99 [00:28<00:00,  3.43it/s]

Generating original sentences: 100%|██████████| 99/99 [00:28<00:00,  3.49it/s]


Getting Model Swap TMA:   0%|          | 0/50 [00:00<?, ?it/s]

Getting Model Swap TMA:   2%|▏         | 1/50 [00:03<02:31,  3.09s/it]

Getting Model Swap TMA:   4%|▍         | 2/50 [00:05<02:22,  2.97s/it]

Getting Model Swap TMA:   6%|▌         | 3/50 [00:08<02:15,  2.89s/it]

Getting Model Swap TMA:   8%|▊         | 4/50 [00:11<02:13,  2.91s/it]

Getting Model Swap TMA:  10%|█         | 5/50 [00:14<02:08,  2.85s/it]

Getting Model Swap TMA:  12%|█▏        | 6/50 [00:16<02:03,  2.81s/it]

Getting Model Swap TMA:  14%|█▍        | 7/50 [00:19<01:59,  2.77s/it]

Getting Model Swap TMA:  16%|█▌        | 8/50 [00:22<01:54,  2.72s/it]

Getting Model Swap TMA:  18%|█▊        | 9/50 [00:25<01:54,  2.78s/it]

Getting Model Swap TMA:  20%|██        | 10/50 [00:29<02:05,  3.13s/it]

Getting Model Swap TMA:  22%|██▏       | 11/50 [00:32<02:05,  3.22s/it]

Getting Model Swap TMA:  24%|██▍       | 12/50 [00:36<02:06,  3.34s/it]

Getting Model Swap TMA:  26%|██▌       | 13/50 [00:39<02:01,  3.28s/it]

Getting Model Swap TMA:  28%|██▊       | 14/50 [00:43<02:05,  3.47s/it]

Getting Model Swap TMA:  30%|███       | 15/50 [00:47<02:15,  3.87s/it]

Getting Model Swap TMA:  32%|███▏      | 16/50 [00:51<02:10,  3.85s/it]

Getting Model Swap TMA:  34%|███▍      | 17/50 [00:55<02:04,  3.78s/it]

Getting Model Swap TMA:  36%|███▌      | 18/50 [00:58<01:51,  3.47s/it]

Getting Model Swap TMA:  38%|███▊      | 19/50 [01:00<01:41,  3.26s/it]

Getting Model Swap TMA:  40%|████      | 20/50 [01:03<01:31,  3.06s/it]

Getting Model Swap TMA:  42%|████▏     | 21/50 [01:07<01:37,  3.36s/it]

Getting Model Swap TMA:  44%|████▍     | 22/50 [01:11<01:40,  3.59s/it]

Getting Model Swap TMA:  46%|████▌     | 23/50 [01:15<01:39,  3.69s/it]

Getting Model Swap TMA:  48%|████▊     | 24/50 [01:19<01:38,  3.81s/it]

Getting Model Swap TMA:  50%|█████     | 25/50 [01:23<01:37,  3.92s/it]

Getting Model Swap TMA:  52%|█████▏    | 26/50 [01:27<01:33,  3.88s/it]

Getting Model Swap TMA:  54%|█████▍    | 27/50 [01:30<01:21,  3.55s/it]

Getting Model Swap TMA:  56%|█████▌    | 28/50 [01:33<01:12,  3.28s/it]

Getting Model Swap TMA:  58%|█████▊    | 29/50 [01:35<01:05,  3.11s/it]

Getting Model Swap TMA:  60%|██████    | 30/50 [01:38<00:59,  2.99s/it]

Getting Model Swap TMA:  62%|██████▏   | 31/50 [01:41<00:55,  2.90s/it]

Getting Model Swap TMA:  64%|██████▍   | 32/50 [01:43<00:51,  2.86s/it]

Getting Model Swap TMA:  66%|██████▌   | 33/50 [01:46<00:48,  2.83s/it]

Getting Model Swap TMA:  68%|██████▊   | 34/50 [01:49<00:44,  2.80s/it]

Getting Model Swap TMA:  70%|███████   | 35/50 [01:52<00:42,  2.82s/it]

Getting Model Swap TMA:  72%|███████▏  | 36/50 [01:55<00:39,  2.85s/it]

Getting Model Swap TMA:  74%|███████▍  | 37/50 [01:57<00:36,  2.79s/it]

Getting Model Swap TMA:  76%|███████▌  | 38/50 [02:00<00:33,  2.79s/it]

Getting Model Swap TMA:  78%|███████▊  | 39/50 [02:03<00:30,  2.76s/it]

Getting Model Swap TMA:  80%|████████  | 40/50 [02:06<00:27,  2.77s/it]

Getting Model Swap TMA:  82%|████████▏ | 41/50 [02:08<00:25,  2.79s/it]

Getting Model Swap TMA:  84%|████████▍ | 42/50 [02:11<00:22,  2.79s/it]

Getting Model Swap TMA:  86%|████████▌ | 43/50 [02:14<00:19,  2.79s/it]

Getting Model Swap TMA:  88%|████████▊ | 44/50 [02:17<00:16,  2.79s/it]

Getting Model Swap TMA:  90%|█████████ | 45/50 [02:20<00:14,  2.80s/it]

Getting Model Swap TMA:  92%|█████████▏| 46/50 [02:22<00:11,  2.79s/it]

Getting Model Swap TMA:  94%|█████████▍| 47/50 [02:25<00:08,  2.81s/it]

Getting Model Swap TMA:  96%|█████████▌| 48/50 [02:28<00:05,  2.81s/it]

Getting Model Swap TMA:  98%|█████████▊| 49/50 [02:31<00:02,  2.82s/it]

Getting Model Swap TMA: 100%|██████████| 50/50 [02:34<00:00,  2.79s/it]

Getting Model Swap TMA: 100%|██████████| 50/50 [02:34<00:00,  3.08s/it]




Calculating zs tma...


Calculating zc tma...


Calculating copy tma...


Calculating bleu scores...


In [16]:
# print(tma_res)
res_enc = model.get_syn_disent_encoder(batch_size=20)
print(res_enc)

Paraphrase results 1 :  {'template': {'zs': 0.53272, 'zc': 0.51238}, 'paraphrase': {'zs': 0.55836, 'zc': 0.50162}}


Paraphrase detection: with zs 0.502, with zc 0.468
{'template': {'zs': 0.53272, 'zc': 0.51238}, 'paraphrase': {'zs': 0.55836, 'zc': 0.50162}, 'hard': {'zs': 0.468, 'zc': 0.468}}


In [4]:
print("        \t tma2 \t tma3 \t bleu")
print("copy\t", copy_tma2, "\t", copy_tma3, "\t", copy_bleu)
print("zc    \t", zc_tma2, "\t", zc_tma3, "\t", zc_bleu)
print("zs    \t", zs_tma2, "\t", zs_tma3, "\t", zs_bleu)
 

        	 tma2 	 tma3 	 bleu
copy	 34.599999999999994 	 6.3 	 3.2425404362197012
zc    	 39.0 	 7.000000000000001 	 3.211958605276792
zs    	 46.800000000000004 	 13.5 	 5.234059907070587


In [34]:
sents1 = ["This place is great .",
         "I will not be back .",
         "The food is delicious .",
         "The service is outstanding .",
         "This place is the best I 've ever been to .",
          "The chicken was good .",
          "The service was great, and the food was delicious .",
          "also way too much cheese .",
          "gross ."]

# sents1 = [
# "no more service from me .",
# "worst service i 've ever experienced .",
# "did n't even acknowledge us .",
# "do n't go here if you want any kind of service .",
# "horrible , horrible service they do n't deserve any stars .",
# ]
model.eval()
# print(data.tokenizer.encode(sents1))
def embed_sents(self, sents):
    zs_infer, z_infer, x_gen = self.infer_bn.name_to_v['zs'], \
                               {'z{}'.format(i + 1): self.infer_bn.name_to_v['z{}'.format(i + 1)]
                                for i in range(len(self.h_params.n_latents))}, self.gen_bn.name_to_v['x']

    bsz, max_len = len(sents), max([len(s) for s in sents])
    stoi = self.index[self.generated_v].stoi
    inputs = torch.zeros((bsz, max_len)).to(self.h_params.device).long()+stoi['<pad>']
    for i, sen in enumerate(sents):
        for j, tok in enumerate(sen.lower().split()):
            inputs[i, j] = stoi[tok] if tok in stoi else stoi['<unk>']

    self.infer_bn({'x': inputs})
    orig_zs, orig_z = zs_infer.rep(zs_infer.infer(zs_infer.post_params))[..., 0, :], \
                      torch.cat([v.post_params['loc'][..., 0, :] for k, v in z_infer.items()], dim=-1)

    return orig_zs, orig_z
ezs, ezc = embed_sents(model, sents1)
enc_samples = {"z1":ezc, "zs":ezs, "zg":torch.zeros_like(ezs)}
print(ezs.shape, ezc.shape)


this place is great  _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ |||| i will not be back  _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ |||| the food is delicious  _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ |||| the service is outstanding  _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ |||| this place is the best i 've ever been to  _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ |||| the chicken was good  _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ |||| the service was <unk> and the food was delicious  _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ |||| also way too much cheese  _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ |||| gross  _ _ _ _ _ 

In [35]:
sents = sents1
model.h_params.zs_anneal_kl, model.h_params.zg_anneal_kl = [7000, 10000], [7000, 10000]
bsz, max_len = len(sents), max([len(s) for s in sents])+1
stoi = model.index[model.generated_v].stoi
inputs = torch.zeros((bsz, max_len)).to(model.h_params.device).long()+stoi['<pad>']
for i, sen in enumerate(sents):
    inputs[i, 0] = stoi['<go>']
    for j, tok in enumerate(sen.lower().split()):
        inputs[i, j+1] = stoi[tok] if tok in stoi else stoi['<unk>']
    # inputs[i, j+1] = stoi['<eos>']
model({'x': inputs[..., 1:], 'x_prev': inputs[..., :-1]}, eval=True)
orig_text = model.decode_to_text(model.gen_bn.variables_star[model.generated_v])
dec_text = model.decode_to_text(model.generated_v.post_params['logits'])
print(orig_text.strip().replace('\n', '')) 
print(dec_text.strip().replace('\n', '')) 

this place is great  _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ |||| i will not be back  _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ |||| the food is delicious  _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ |||| the service is outstanding  _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ |||| this place is the best i 've ever been to  _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ |||| the chicken was good  _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ |||| the service was <unk> and the food was delicious  _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ |||| also way too much cheese  _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ |||| gross  _ _ _ _ _ 

In [38]:
# sw_zs = [8]
# sw_zs = [14]
# sw_zs = [16]
# sw_zs = [6]
# sw_zs = [9]

sw_zs = [8]

# 8 ==> subject and 6
#14 ==> dobj
sw_text, sw_samples = swap_latents(model, enc_samples, sw_zs, 16, complete=None, no_unk=True)
print(sw_text)
for i in range(len(sents1)):
    for j in range(len(sents1)):
        if i!=j:
            print("z_from: ", sents1[i], "|z_to: ", sents1[j], "|result: ", sw_text[len(sents1)*i+j])
        

[' great place', ' not impressed', ' food is delicious', ' service is outstanding', ' love this place', ' good food', ' food was delicious', ' too bad', ' gross', ' great place', ' i will not', ' the food is delicious', ' the service is outstanding', ' love this place', ' the food was delicious', ' food was delicious', ' it was good', ' it was terrible', ' great place', ' not impressed', ' food is delicious', ' service is outstanding', ' love this place', ' food was good', ' food was delicious', ' it was good', ' love it', ' great place', ' not impressed', ' food is delicious', ' service is outstanding', ' love this place', ' food was good', ' food was delicious', ' it was good', ' love it', ' this place is great', ' i will not be back', ' the food is delicious', ' the service is outstanding', ' i love this place', ' the food was good', ' the food was delicious', ' the food was good', ' the food was terrible', ' great place', ' not impressed', ' food is delicious', ' service is outstan

In [95]:
def my_repeat(tens, n):
    return tens.unsqueeze(0).expand(n, *tens.shape).reshape(tens.shape[0]*n, *tens.shape[1:])


batch_size = 100
for file_names in [["E:\\Experiments\\GLUE_BENCH\\.data\\paranmt2\\dev_input.txt",
                   "E:\\Experiments\\GLUE_BENCH\\.data\\paranmt2\\test_input.txt"],
                   ["E:\\Experiments\\GLUE_BENCH\\.data\\paranmt2\\dev.txt",
                   "E:\\Experiments\\GLUE_BENCH\\.data\\paranmt2\\test.txt"],
                   ["E:\\Experiments\\GLUE_BENCH\\.data\\qqp\\pos_hard.tsv"],
                   ["E:\\Experiments\\GLUE_BENCH\\.data\\qqp\\pos.tsv"],
                   ["E:\\Experiments\\GLUE_BENCH\\.data\\qqp\\neg.tsv"]]:
    print("---------------------------------------------------")
    t1, t2 = [], []
    for file_name in file_names:
        with open(file_name, encoding="UTF-8") as f:
            for i, l in enumerate(f):
                if "\t" in l:
                    t1.append(l.split("\t")[0])
                    t2.append(l.split("\t")[1])
    
    f_names = "\""+', '.join([fn.split(os.sep)[-1]for fn in file_names])+"\""
    print("For files {} with {} samples:".format(f_names, len(t1)))
    ezs1, ezc1, ezs2, ezc2, ezs3, ezc3 = None, None, None, None, None, None
    for i in range(int(len(t1)/batch_size)):
        ezs1i, ezc1i = model.embed_sents(t1[i*batch_size:(i+1)*batch_size])
        ezs2i, ezc2i = model.embed_sents(t2[i*batch_size:(i+1)*batch_size])
        if ezs1 is None:
            ezs1, ezc1 = ezs1i, ezc1i
            ezs2, ezc2 = ezs2i, ezc2i
        else:     
            ezs1, ezc1 = torch.cat([ezs1, ezs1i]), torch.cat([ezc1, ezc1i])
            ezs2, ezc2 = torch.cat([ezs2, ezs2i]), torch.cat([ezc2, ezc2i])
    rep_n = 100
    perm_idx = torch.randperm(ezs1.shape[0]*rep_n)
    ezs1, ezc1 = my_repeat(ezs1, rep_n), my_repeat(ezc1, rep_n)
    ezs2, ezc2 = my_repeat(ezs2, rep_n), my_repeat(ezc2, rep_n)
    ezs3, ezc3 = ezs1[perm_idx], ezc1[perm_idx]
    
    s12sims, s13sims = torch.cosine_similarity(ezs1, ezs2), torch.cosine_similarity(ezs1, ezs3)
    c12sims, c13sims = torch.cosine_similarity(ezc1, ezc2), torch.cosine_similarity(ezc1, ezc3)
    print("expanded measure size to ", len(s12sims))
    print("syntactic accuracy", np.mean(s12sims.cpu().detach().numpy()>s13sims.cpu().detach().numpy()))
    print("semantic accuracy", np.mean(c12sims.cpu().detach().numpy()>c13sims.cpu().detach().numpy()))


---------------------------------------------------
For files "dev_input.txt, test_input.txt" with 1300 samples:


RuntimeError: CUDA out of memory. Tried to allocate 192.00 MiB (GPU 0; 6.00 GiB total capacity; 1.44 GiB already allocated; 2.69 GiB free; 1.64 GiB reserved in total by PyTorch)

In [6]:
def my_repeat(tens, n):
    return tens.unsqueeze(0).expand(n, *tens.shape).reshape(tens.shape[0]*n, *tens.shape[1:])

def my_sim(a, b):
    dist = (a-b).square().sum(-1).sqrt()
    sim = 1/(1+dist)
    return sim

sim = my_sim
# sim = torch.cosine_similarity
def _get_syn_disent_encoder_easy(self, split="valid", batch_size=100):
    template_file = {"valid": "E:\\Experiments\\GLUE_BENCH\\.data\\paranmt2\\dev_input.txt",
                       "test": "E:\\Experiments\\GLUE_BENCH\\.data\\paranmt2\\test_input.txt"}[split]
    paraphrase_file = {"valid": "E:\\Experiments\\GLUE_BENCH\\.data\\paranmt2\\dev.txt",
                       "test": "E:\\Experiments\\GLUE_BENCH\\.data\\paranmt2\\test.txt"}[split]
    file_names = {"template": template_file, "paraphrase": paraphrase_file}
    accuracies = {"template": {}, "paraphrase": {}}
    for task, file_n in file_names.items():
        t1, t2 = [], []
        with open(file_n, encoding="UTF-8") as f:
            for i, l in enumerate(f):
                if "\t" in l:
                    t1.append(l.split("\t")[0])
                    t2.append(l.split("\t")[1])
        
        ezs1, ezc1, ezs2, ezc2 = None, None, None, None
        for i in range(int(len(t1)/batch_size)):
            ezs1i, ezc1i = self.embed_sents(t1[i*batch_size:(i+1)*batch_size])
            ezs2i, ezc2i = self.embed_sents(t2[i*batch_size:(i+1)*batch_size])
            if ezs1 is None:
                ezs1, ezc1 = ezs1i, ezc1i
                ezs2, ezc2 = ezs2i, ezc2i
            else:     
                ezs1, ezc1 = torch.cat([ezs1, ezs1i]), torch.cat([ezc1, ezc1i])
                ezs2, ezc2 = torch.cat([ezs2, ezs2i]), torch.cat([ezc2, ezc2i])
        rep_n = 100
        perm_idx = torch.randperm(ezs1.shape[0]*rep_n)
        ezs1, ezc1 = my_repeat(ezs1, rep_n), my_repeat(ezc1, rep_n)
        ezs2, ezc2 = my_repeat(ezs2, rep_n), my_repeat(ezc2, rep_n)
        ezs3, ezc3 = ezs1[perm_idx], ezc1[perm_idx]
        
        s12sims, s13sims = sim(ezs1, ezs2), sim(ezs1, ezs3)
        c12sims, c13sims = sim(ezc1, ezc2), sim(ezc1, ezc3)
        syn_emb_sc = np.mean(s12sims.cpu().detach().numpy()>s13sims.cpu().detach().numpy())
        cont_emb_sc = np.mean(c12sims.cpu().detach().numpy()>c13sims.cpu().detach().numpy())
        accuracies[task] = {"syn_emb": syn_emb_sc, "cont_emb": cont_emb_sc}
    self.writer.add_scalar('test/zs_enc_para_acc', accuracies["paraphrase"]["syn_emb"], self.step)
    self.writer.add_scalar('test/zc_enc_para_acc', accuracies["paraphrase"]["cont_emb"], self.step)
    self.writer.add_scalar('test/zs_enc_temp_acc', accuracies["template"]["syn_emb"], self.step)
    self.writer.add_scalar('test/zc_enc_temp_acc', accuracies["template"]["cont_emb"], self.step)
    return accuracies
print(_get_syn_disent_encoder_easy(model, split="valid", batch_size=100))

{'template': {'syn_emb': 0.6106, 'cont_emb': 0.6307}, 'paraphrase': {'syn_emb': 0.66988, 'cont_emb': 0.90536}}


In [7]:
def _get_syn_disent_encoder_hard(self, split="valid", batch_size=100):
    pair_fn = {"valid": ".data\\paranmt2\\dev_input.txt", 
               "test": ".data\\paranmt2\\test_input.txt"}[split]
    ref_fn = {"valid": ".data\\paranmt2\\dev_ref.txt", 
              "test": ".data\\paranmt2\\test_ref.txt"}[split]
    t1, t2, t3 = [], [], []
    with open(pair_fn, encoding="UTF-8") as f:
        for i, l in enumerate(f):
            if "\t" in l:
                t1.append(l.split("\t")[0])
                t2.append(l.split("\t")[1][:-1])
    with open(ref_fn, encoding="UTF-8") as f:
        for i, l in enumerate(f):
            if len(l):
                t3.append(l[:-1])    
        
    ezs1, ezc1, ezs2, ezc2, ezs3, ezc3 = None, None, None, None, None, None
    for i in range(int(len(t1)/batch_size)):
        ezs1i, ezc1i = self.embed_sents(t1[i*batch_size:(i+1)*batch_size])
        ezs2i, ezc2i = self.embed_sents(t2[i*batch_size:(i+1)*batch_size])
        ezs3i, ezc3i = self.embed_sents(t3[i*batch_size:(i+1)*batch_size])
        if ezs1 is None:
            ezs1, ezc1 = ezs1i, ezc1i
            ezs2, ezc2 = ezs2i, ezc2i
            ezs3, ezc3 = ezs3i, ezc3i
        else:     
            ezs1, ezc1 = torch.cat([ezs1, ezs1i]), torch.cat([ezc1, ezc1i])
            ezs2, ezc2 = torch.cat([ezs2, ezs2i]), torch.cat([ezc2, ezc2i])
            ezs3, ezc3 = torch.cat([ezs3, ezs3i]), torch.cat([ezc3, ezc3i])
        
    s13sims, s23sims = sim(ezs1, ezs3), sim(ezs2, ezs3)
    c13sims, c23sims = sim(ezc1, ezc3), sim(ezc2, ezc3)
    
    zs_acc = np.mean(s13sims.cpu().detach().numpy()>s23sims.cpu().detach().numpy())
    zc_acc = np.mean(c13sims.cpu().detach().numpy()>c23sims.cpu().detach().numpy())
    print("expanded measure size to ", len(s13sims))
    print("Paraphrase detection: with zs {}, with zc {}".format(zs_acc, zc_acc))
    self.writer.add_scalar('test/hard_zs_enc_acc', 1-zs_acc, self.step)
    self.writer.add_scalar('test/hard_zc_enc_acc', zc_acc, self.step)
    return 1-zs_acc, zc_acc
print(_get_syn_disent_encoder_hard(model, split="test", batch_size=100))

expanded measure size to  800
Paraphrase detection: with zs 0.61875, with zc 0.7475
(0.38125, 0.7475)


In [3]:
print(ezs1.shape)


torch.Size([100000, 96])


In [2]:
from disentanglement_qkv.models import SE
import logging
def get_sent_eval(self):
        def prepare(params, samples):
            pass

        def batcher_zs(params, batch):
            batch = [' '.join(sent) if sent != [] else '.' for sent in batch]
            embeddings = self.embed_sents(batch)[0]
            return embeddings.detach().cpu().clone()
        

        def batcher_zc(params, batch):
            batch = [' '.join(sent) if sent != [] else '.' for sent in batch]
            embeddings = self.embed_sents(batch)[1]
            return embeddings.detach().cpu().clone()

        logging.basicConfig(format='%(asctime)s : %(message)s', level=logging.DEBUG)
        # Set params for SentEval
        print("Performing evaluation with zs")
        task_path = os.path.join("disentanglement_qkv", "senteval", "data")
        params = {'task_path': task_path, 'usepytorch': True, 'kfold': 10}
        params['classifier'] = {'nhid': 0, 'optim': 'rmsprop', 'batch_size': 128,
                                 'tenacity': 3, 'epoch_size': 2}
            #{'nhid': 50, 'optim': 'adam', 'batch_size': 64, 'tenacity': 5, 'epoch_size': 4}
        se = SE(params, batcher_zs, prepare)

        transfer_tasks = [#'STS12', 'STS13', 'STS14', 'STS15', 'STS16', 'STSBenchmark',
                          'BigramShift', 'Depth', 'TopConstituents']

        results_zs = se.eval(transfer_tasks)

        print("Performing evaluation with zc")
        task_path = os.path.join("disentanglement_qkv", "senteval", "data")
        params = {'task_path': task_path, 'usepytorch': True, 'kfold': 10}
        params['classifier'] = {'nhid': 0, 'optim': 'rmsprop', 'batch_size': 128,
                                 'tenacity': 3, 'epoch_size': 2}
        #{'nhid': 50, 'optim': 'adam', 'batch_size': 64, 'tenacity': 5, 'epoch_size': 4}
        se = SE(params, batcher_zc, prepare)

        transfer_tasks = [#'STS12', 'STS13', 'STS14', 'STS15', 'STS16', 'STSBenchmark',
                          'BigramShift', 'Depth', 'TopConstituents']

        results_zc = se.eval(transfer_tasks)
        return results_zs, results_zc 

se_result_zs, se_results_zc  = get_sent_eval(model)

2021-08-20 15:04:03,746 : ***** (Probing) Transfer task : BIGRAMSHIFT classification *****


Performing evaluation with zs


2021-08-20 15:04:05,441 : Loaded 100000 train - 10000 dev - 10000 test for BigramShift


2021-08-20 15:04:05,452 : Computing embeddings for train/dev/test


2021-08-20 15:07:33,968 : Computed embeddings


2021-08-20 15:07:33,969 : Training pytorch-MLP-nhid0-rmsprop-bs128 with standard validation..


2021-08-20 15:08:03,845 : [('reg:1e-05', 51.57), ('reg:0.0001', 51.54), ('reg:0.001', 51.53), ('reg:0.01', 51.46)]


2021-08-20 15:08:03,848 : Validation : best param found is reg = 1e-05 with score             51.57


2021-08-20 15:08:03,848 : Evaluating...


2021-08-20 15:08:11,485 : 
Dev acc : 51.6 Test acc : 50.6 for BIGRAMSHIFT classification



2021-08-20 15:08:11,489 : ***** (Probing) Transfer task : DEPTH classification *****


2021-08-20 15:08:11,885 : Loaded 100000 train - 10000 dev - 10000 test for Depth


2021-08-20 15:08:11,952 : Computing embeddings for train/dev/test


2021-08-20 15:11:38,548 : Computed embeddings


2021-08-20 15:11:38,549 : Training pytorch-MLP-nhid0-rmsprop-bs128 with standard validation..


2021-08-20 15:12:05,075 : [('reg:1e-05', 18.56), ('reg:0.0001', 18.52), ('reg:0.001', 18.49), ('reg:0.01', 18.43)]


2021-08-20 15:12:05,076 : Validation : best param found is reg = 1e-05 with score             18.56


2021-08-20 15:12:05,076 : Evaluating...


2021-08-20 15:12:11,655 : 
Dev acc : 18.6 Test acc : 18.4 for DEPTH classification



2021-08-20 15:12:11,660 : ***** (Probing) Transfer task : TOPCONSTITUENTS classification *****


2021-08-20 15:12:11,984 : Loaded 100000 train - 10000 dev - 10000 test for TopConstituents


2021-08-20 15:12:12,041 : Computing embeddings for train/dev/test


2021-08-20 15:15:25,016 : Computed embeddings


2021-08-20 15:15:25,016 : Training pytorch-MLP-nhid0-rmsprop-bs128 with standard validation..


2021-08-20 15:16:05,660 : [('reg:1e-05', 12.79), ('reg:0.0001', 11.39), ('reg:0.001', 9.58), ('reg:0.01', 8.19)]


2021-08-20 15:16:05,661 : Validation : best param found is reg = 1e-05 with score             12.79


2021-08-20 15:16:05,661 : Evaluating...


2021-08-20 15:16:16,579 : 
Dev acc : 12.8 Test acc : 12.3 for TOPCONSTITUENTS classification



2021-08-20 15:16:16,659 : ***** (Probing) Transfer task : BIGRAMSHIFT classification *****


Performing evaluation with zc


2021-08-20 15:16:16,938 : Loaded 100000 train - 10000 dev - 10000 test for BigramShift


2021-08-20 15:16:16,946 : Computing embeddings for train/dev/test


2021-08-20 15:20:34,001 : Computed embeddings


2021-08-20 15:20:34,002 : Training pytorch-MLP-nhid0-rmsprop-bs128 with standard validation..


2021-08-20 15:21:06,924 : [('reg:1e-05', 50.67), ('reg:0.0001', 50.57), ('reg:0.001', 50.39), ('reg:0.01', 50.39)]


2021-08-20 15:21:06,925 : Validation : best param found is reg = 1e-05 with score             50.67


2021-08-20 15:21:06,925 : Evaluating...


2021-08-20 15:21:15,007 : 
Dev acc : 50.7 Test acc : 50.5 for BIGRAMSHIFT classification



2021-08-20 15:21:15,035 : ***** (Probing) Transfer task : DEPTH classification *****


2021-08-20 15:21:15,318 : Loaded 100000 train - 10000 dev - 10000 test for Depth


2021-08-20 15:21:15,387 : Computing embeddings for train/dev/test


MemoryError: Unable to allocate 293. MiB for an array with shape (100000, 768) and data type float32

In [None]:

from supar import Parser

const_parser = Parser.load('crf-con-en')

def truncate_tree(tree, lv):
    tok_i = 0
    curr_lv = 0
    tree_toks = tree.split()
    while tok_i != len(tree_toks):
        if tree_toks[tok_i].startswith('('):
            curr_lv += 1
        else:
            closed_lvs = int(tree_toks[tok_i].count(')'))
            if curr_lv - closed_lvs <= lv:
                tree_toks[tok_i] = ')'*(closed_lvs - (curr_lv-lv))
            curr_lv -= closed_lvs
        if lv >= curr_lv and tree_toks[tok_i]!='':
            tok_i += 1
        else:
            tree_toks.pop(tok_i)
    return ' '.join(tree_toks)

def get_lin_parse_tree(sens):
    tree_parses = const_parser.predict(sens, lang='en', verbose=False)
    lin_parses = []
    for p in tree_parses:
        lin_p = repr(p)
        if lin_p.startswith("(TOP"):
            lin_p = lin_p[5:-1]
        lin_parses.append(lin_p)
    return lin_parses


In [22]:

def template_match(l1, l2, lv, verbose=0, filter_empty=True):
    if filter_empty:
        not_empty1 = [any([c != " " for c in li1]) for li1 in l1]
        not_empty2 = [any([c != " " for c in li2]) for li2 in l2]
        l1 = [li1 for li1, ne1, ne2 in zip(l1, not_empty1, not_empty2) if ne1 and ne2]
        l2 = [li2 for li2, ne1, ne2 in zip(l2, not_empty1, not_empty2) if ne1 and ne2]
        print(not_empty1)
        print(not_empty2)
    docs1, docs2 = get_lin_parse_tree(l1), get_lin_parse_tree(l2)
    temps1 = [truncate_tree(doc, lv) for doc in docs1]
    temps2 = [truncate_tree(doc, lv) for doc in docs2]
    if verbose:
        for l, t in zip(l1+l2, temps1+temps2):
            print(l, "-->", t)
        print("+++++++++++++++++++++++++")
    return [int(t1 == t2) for t1, t2 in zip(temps1, temps2)]

In [25]:
sens1 = ['Hello dear friend', "how is the day ?", "how is the weather ?"]
sens2 = ['My feet are on the table', "     .", "where are my ladies ?"]
res = template_match(sens1, sens2, 2)
print(res)

[True, True, True]
[True, True, True]
[0, 0, 1]
