<a href="https://colab.research.google.com/github/guanjiew/csc412_vae/blob/main/ControlVAE_Text_Generation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Code is adapted from [the original ControlVAE repository](https://github.com/shj1987/ControlVAE-ICML2020/tree/master/Language_modeling/Text_gen_PTB)

# Preparation

In [1]:
%pip install texar-pytorch

Collecting texar-pytorch
[?25l  Downloading https://files.pythonhosted.org/packages/2d/7e/20aa39ee9d19dcd1a1c0db4dad878088cfba216f45aeaa8fa89615ef46c0/texar_pytorch-0.1.2.post1-py3-none-any.whl (434kB)
[K     |▊                               | 10kB 22.0MB/s eta 0:00:01[K     |█▌                              | 20kB 15.0MB/s eta 0:00:01[K     |██▎                             | 30kB 13.3MB/s eta 0:00:01[K     |███                             | 40kB 12.0MB/s eta 0:00:01[K     |███▊                            | 51kB 7.9MB/s eta 0:00:01[K     |████▌                           | 61kB 8.6MB/s eta 0:00:01[K     |█████▎                          | 71kB 8.5MB/s eta 0:00:01[K     |██████                          | 81kB 9.4MB/s eta 0:00:01[K     |██████▉                         | 92kB 9.0MB/s eta 0:00:01[K     |███████▌                        | 102kB 7.7MB/s eta 0:00:01[K     |████████▎                       | 112kB 7.7MB/s eta 0:00:01[K     |█████████                       |

In [2]:
import math, os
from typing import Any, Dict, Optional, Tuple, Union

import torch
import torch.nn as nn
from torch import Tensor
from torch.optim.lr_scheduler import ExponentialLR

import texar.torch as tx
from texar.torch.custom import MultivariateNormalDiag
from tqdm import tqdm

import pandas as pd
import numpy as np

In [3]:
# Legacy Penn Treebank (PTB) loader, kept commented out for reference only; the Yelp loading cell below supersedes it.

# data_path = "./simple-examples/data"
# train_path = os.path.join(data_path, "ptb.train.txt")
# if not os.path.exists(train_path):
#     url = 'http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz'
#     tx.data.maybe_download(url, './', extract=True)

# train_path = os.path.join(data_path, "ptb.train.txt")
# vocab_path = os.path.join(data_path, "vocab.txt")
# word_to_id = tx.data.make_vocab(
#     train_path, return_type="dict")

# with open(vocab_path, 'w') as fvocab:
#     for word in word_to_id:
#         fvocab.write("%s\n" % word)

In [4]:
# Loading and preprocessing yelp review data

data_path = "./simple-examples/data"
# Create the data directory relative to the working directory. This replaces
# the Colab-only `%mkdir -p /content/...` and `%cd /content` magics: on Colab
# the default working directory is already /content, and every other path in
# this notebook is relative.
os.makedirs(data_path, exist_ok=True)

train_path = os.path.join(data_path, "train.txt")
validate_path = os.path.join(data_path, "validate.txt")
test_path = os.path.join(data_path, "test.txt")

# Rebuild the splits unless *all three* files are present; a partially
# populated directory would otherwise make the later cells fail.
if not all(os.path.exists(path) for path in (train_path, validate_path, test_path)):
  url = 'https://raw.githubusercontent.com/rekiksab/Yelp-Data-Challenge-2013/master/yelp_challenge/yelp_phoenix_academic_dataset/yelp_academic_dataset_review.json'
  df = pd.read_json(url, lines=True)
  text = df['text']

  # Work on a reproducible 1% sample of the reviews, split 60/20/20.
  percent = .01
  sampled = text.sample(frac=percent, random_state=42)
  train_end = int(.6*percent*len(text))
  validate_end = int(.8*percent*len(text))
  train = sampled.iloc[:train_end]
  validate = sampled.iloc[train_end:validate_end]
  test = sampled.iloc[validate_end:]

  # NOTE(review): reviews are written verbatim, so any newline inside a review
  # starts a new line in the output file -- confirm that is the intended
  # sample granularity for the text loader.
  np.savetxt(train_path, train.values, fmt='%s')
  np.savetxt(validate_path, validate.values, fmt='%s')
  np.savetxt(test_path, test.values, fmt='%s')

# The vocabulary is built from the training split only.
vocab_path = os.path.join(data_path, "vocab.txt")
word_to_id = tx.data.make_vocab(
    train_path, return_type="dict")

with open(vocab_path, 'w') as fvocab:
    for word in word_to_id:
        fvocab.write("%s\n" % word)

/content


# VAE Model

In [5]:
def kl_divergence(means: Tensor, logvars: Tensor) -> Tensor:
    """KL divergence of a diagonal Gaussian from the standard normal.

    Evaluates ``-0.5 * (logvar - mean**2 - exp(logvar) + 1)`` element-wise,
    i.e. the closed form of ``KL(N(mean, exp(logvar)) || N(0, I))`` per
    dimension.

    Args:
        means: Gaussian means.
        logvars: Log-variances, broadcastable against ``means``.

    Returns:
        A scalar tensor: the element-wise terms averaged over dim 0
        (presumably the batch axis) and then summed over what remains.
    """
    per_element = (logvars - means ** 2 - torch.exp(logvars) + 1.0) * -0.5
    return per_element.mean(dim=0).sum()

In [6]:
class VAE(nn.Module):
    """Text VAE with an LSTM encoder and an LSTM or Transformer decoder.

    The state returned by the encoder is mapped by ``connector_mlp`` to the
    mean and log-variance of a diagonal Gaussian with
    ``config_model.latent_dims`` dimensions. A sample ``z`` conditions the
    decoder: a linear projection of ``z`` seeds the LSTM state (and ``z`` is
    concatenated to every decoder input embedding), or is used as a
    length-one memory for the Transformer decoder.
    """
    # Latent sample for the batch currently being decoded. It is written by
    # `decode` and read back by `_embed_fn_rnn`.
    _latent_z: Tensor

    def __init__(self, vocab_size: int, config_model):
        """Build the embedders, encoder, decoder and latent connectors.

        Args:
            vocab_size: Number of entries in the vocabulary.
            config_model: Object exposing the hyperparameters read below
                (``enc_emb_hparams``, ``enc_cell_hparams``,
                ``dec_emb_hparams``, ``decoder_type``, ``latent_dims`` and,
                depending on the decoder, ``dec_cell_hparams`` or
                ``max_pos`` / ``dec_pos_emb_hparams`` / ``trans_hparams``).

        Raises:
            ValueError: If ``config_model.decoder_type`` is neither
                ``"lstm"`` nor ``"transformer"``.
        """
        super().__init__()
        # Model architecture
        self._config = config_model
        # Kept for callers that read `model.device`; `forward` follows the
        # parameters' actual device instead.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.encoder_w_embedder = tx.modules.WordEmbedder(
            vocab_size=vocab_size, hparams=config_model.enc_emb_hparams)

        self.encoder = tx.modules.UnidirectionalRNNEncoder[tx.core.LSTMState](
            input_size=self.encoder_w_embedder.dim,
            hparams={
                "rnn_cell": config_model.enc_cell_hparams,
            })

        self.decoder_w_embedder = tx.modules.WordEmbedder(
            vocab_size=vocab_size, hparams=config_model.dec_emb_hparams)

        if config_model.decoder_type == "lstm":
            self.lstm_decoder = tx.modules.BasicRNNDecoder(
                # Every decoder input is [word embedding ; latent z].
                input_size=(self.decoder_w_embedder.dim +
                            config_model.latent_dims),
                vocab_size=vocab_size,
                token_embedder=self._embed_fn_rnn,
                hparams={"rnn_cell": config_model.dec_cell_hparams})
            # The projection of z is later chunked into two equal halves to
            # form the initial LSTM state, hence twice the hidden size.
            sum_state_size = self.lstm_decoder.cell.hidden_size * 2

        elif config_model.decoder_type == 'transformer':
            # position embedding
            self.decoder_p_embedder = tx.modules.SinusoidsPositionEmbedder(
                position_size=config_model.max_pos,
                hparams=config_model.dec_pos_emb_hparams)
            # decoder
            self.transformer_decoder = tx.modules.TransformerDecoder(
                # tie word embedding with output layer
                output_layer=self.decoder_w_embedder.embedding,
                token_pos_embedder=self._embed_fn_transformer,
                hparams=config_model.trans_hparams)
            sum_state_size = self._config.dec_emb_hparams["dim"]

        else:
            raise ValueError("Decoder type must be 'lstm' or 'transformer'")

        # Maps the encoder state to the concatenation [mean ; logvar].
        self.connector_mlp = tx.modules.MLPTransformConnector(
            config_model.latent_dims * 2,
            linear_layer_dim=self.encoder.cell.hidden_size * 2)

        # Projects a latent sample to the decoder's conditioning vector.
        self.mlp_linear_layer = nn.Linear(
            config_model.latent_dims, sum_state_size)

    def forward(self,  # type: ignore
                data_batch: tx.data.Batch,
                kl_weight: float, start_tokens: torch.LongTensor,
                end_token: int) -> Dict[str, Tensor]:
        """Encode a batch, sample z, decode with teacher forcing, return losses.

        Args:
            data_batch: Batch providing ``"text_ids"`` and ``"length"``.
            kl_weight: Multiplier applied to the KL term in ``"nll"``.
            start_tokens: Start-token ids handed to the LSTM decoding helper.
            end_token: End-token id handed to the LSTM decoding helper.

        Returns:
            Dict with ``"nll"`` (``rc_loss + kl_weight * kl_loss``),
            ``"kl_loss"``, ``"rc_loss"`` and ``"lengths"`` (the input
            lengths minus one).
        """
        # Follow the parameters rather than `self.device`, so the model keeps
        # working after being moved with `.to(...)`.
        device = next(self.parameters()).device

        # encoder -> connector -> decoder
        text_ids = data_batch["text_ids"].to(device)
        lengths = data_batch["length"].to(device)
        mean, logvar = self.encode(text_ids, lengths)
        kl_loss = kl_divergence(mean, logvar)
        # Standard deviation is exp(0.5 * log-variance).
        dst = MultivariateNormalDiag(
            loc=mean, scale_diag=torch.exp(0.5 * logvar))

        # Reparameterised sample, so gradients reach the encoder.
        latent_z = dst.rsample()
        helper = None
        if self._config.decoder_type == "lstm":
            helper = self.lstm_decoder.create_helper(
                decoding_strategy="train_greedy",
                start_tokens=start_tokens,
                end_token=end_token)

        # decode: inputs drop the last token, labels drop the first one
        seq_lengths = lengths - 1
        outputs = self.decode(
            helper=helper, latent_z=latent_z,
            text_ids=text_ids[:, :-1], seq_lengths=seq_lengths)

        logits = outputs.logits

        # Losses & train ops
        rc_loss = tx.losses.sequence_sparse_softmax_cross_entropy(
            labels=text_ids[:, 1:], logits=logits,
            sequence_length=seq_lengths)

        nll = rc_loss + kl_weight * kl_loss

        ret = {
            "nll": nll,
            "kl_loss": kl_loss,
            "rc_loss": rc_loss,
            "lengths": seq_lengths,
        }

        return ret

    def _embed_fn_rnn(self, tokens: torch.LongTensor) -> Tensor:
        r"""Embeds tokens and appends the current latent sample to each one.
        """
        embedding = self.decoder_w_embedder(tokens)
        latent_z = self._latent_z
        if len(embedding.size()) > 2:
            # Whole-sequence input: tile z along dim 0 of `tokens`
            # (presumably the time axis of a time-major input -- TODO confirm).
            latent_z = latent_z.unsqueeze(0).repeat(tokens.size(0), 1, 1)
        return torch.cat([embedding, latent_z], dim=-1)

    def _embed_fn_transformer(self,
                              tokens: torch.LongTensor,
                              positions: torch.LongTensor) -> Tensor:
        r"""Generates word embeddings combined with positional embeddings
        """
        output_p_embed = self.decoder_p_embedder(positions)
        output_w_embed = self.decoder_w_embedder(tokens)
        # Scale word embeddings by sqrt(hidden_size) before adding positions.
        output_w_embed = output_w_embed * self._config.hidden_size ** 0.5
        output_embed = output_w_embed + output_p_embed
        return output_embed

    def encode(self, text_ids: Tensor, seq_lengths: Tensor) -> Tuple[Tensor, Tensor]:
        """Return the posterior ``(mean, logvar)`` for a batch of token ids.

        Args:
            text_ids: Token ids fed to the encoder embedder.
            seq_lengths: Length of each sequence in ``text_ids``.

        Returns:
            The two halves of the connector output, split along dim 1.
        """
        input_embed = self.encoder_w_embedder(text_ids)
        _, encoder_states = self.encoder(
            input_embed,
            sequence_length=seq_lengths)
        mean_logvar = self.connector_mlp(encoder_states)
        mean, logvar = torch.chunk(mean_logvar, 2, 1)
        return mean, logvar

    @property
    def decoder(self) -> tx.modules.DecoderBase:
        """The decoder module selected by ``decoder_type``."""
        if self._config.decoder_type == "lstm":
            return self.lstm_decoder
        else:
            return self.transformer_decoder

    def decode(self,
               helper: Optional[tx.modules.Helper],
               latent_z: Tensor,
               text_ids: Optional[torch.LongTensor] = None,
               seq_lengths: Optional[Tensor] = None,
               max_decoding_length: Optional[int] = None) \
            -> Union[tx.modules.BasicRNNDecoderOutput,
                     tx.modules.TransformerDecoderOutput]:
        """Run the configured decoder conditioned on ``latent_z``.

        Args:
            helper: Decoding helper passed through to the decoder.
            latent_z: Latent samples, one row per sequence.
            text_ids: Decoder input ids (used for teacher forcing).
            seq_lengths: Lengths of ``text_ids`` (LSTM decoder only).
            max_decoding_length: Upper bound on generated length.

        Returns:
            Whatever the selected decoder returns; for the LSTM decoder only
            the first element of its 3-tuple.
        """
        # Stash z so `_embed_fn_rnn` can append it to every input embedding.
        self._latent_z = latent_z
        fc_output = self.mlp_linear_layer(latent_z)

        if self._config.decoder_type == "lstm":
            # Two equal halves form the initial LSTM state tuple.
            lstm_states = torch.chunk(fc_output, 2, dim=1)
            outputs, _, _ = self.lstm_decoder(
                initial_state=lstm_states,
                inputs=text_ids,
                helper=helper,
                sequence_length=seq_lengths,
                max_decoding_length=max_decoding_length)
        else:
            # The projected z acts as a memory of length one per sequence.
            transformer_states = fc_output.unsqueeze(1)
            # Lengths must be integer-valued and live on the same device as
            # the memory; a default `torch.ones(...)` would be a CPU float
            # tensor and break when the model runs on GPU.
            memory_lengths = torch.ones(
                transformer_states.size(0),
                dtype=torch.long,
                device=transformer_states.device)
            outputs = self.transformer_decoder(
                inputs=text_ids,
                memory=transformer_states,
                memory_sequence_length=memory_lengths,
                helper=helper,
                max_decoding_length=max_decoding_length)
        return outputs
    

# PID Control

In [7]:
class PIDControl():
    """PI controller that turns the KL error into a KL weight (beta).

    State carried between calls:
        I_k1: integral term from the previous step.
        W_k1: previous *unclamped* controller output.
        e_k1: previous error.
    """
    def __init__(self):
        """Start with zero integral, zero output and zero error."""
        self.I_k1 = 0.0
        self.W_k1 = 0.0
        self.e_k1 = 0.0

    def _Kp_fun(self, Err, scale=1):
        """Return ``1 / (1 + scale * exp(Err))``.

        For very large positive ``Err`` ``math.exp`` overflows; the limit of
        the expression for a positive ``scale`` (0.0) is returned instead of
        raising.
        """
        try:
            return 1.0/(1.0 + float(scale)*math.exp(Err))
        except OverflowError:
            return 0.0

    def pid(self, exp_KL, kl_loss, Kp=0.001, Ki=-0.001, Kd=0.01):
        """
        position PID algorithm

        Args:
            exp_KL: desired (set-point) KL value.
            kl_loss: KL value observed at the current step.
            Kp: proportional gain.
            Ki: integral gain.
            Kd: accepted for interface compatibility; no derivative term is
                computed.

        Returns:
            The weight for the KL loss (beta), clamped to ``[0, 1]``.
        """
        error_k = exp_KL - kl_loss
        ## compute U as the control factor
        Pk = Kp * self._Kp_fun(error_k)

        ## anti-windup: while the previous raw output was outside [0, 1],
        ## hold the integrator instead of accumulating further error.
        if self.W_k1 < 0 or self.W_k1 > 1:
            Ik = self.I_k1
        else:
            Ik = self.I_k1 + Ki * error_k

        Wk = Pk + Ik
        # Store the unclamped output: it drives the anti-windup check above.
        self.W_k1 = Wk
        self.I_k1 = Ik
        self.e_k1 = error_k

        ## min and max value
        if Wk > 1:
            Wk = 1.0
        if Wk < 0:
            Wk = 0.0

        return Wk
    

# Hyperparameters

In [8]:
class Config():
  """Hyperparameters for the text VAE, its data pipeline and its optimizer.

  Everything is a plain class attribute. The nested dictionaries are passed
  as `hparams` to the texar embedders, RNN cells, datasets and optimizer
  created in the model and training cells.
  """

  # ---- scalar settings -------------------------------------------------
  dataset = "ptb"
  num_epochs = 100
  hidden_size = 256
  dec_dropout_in = 0.5
  dec_dropout_out = 0.5
  enc_dropout_in = 0.
  enc_dropout_out = 0.
  word_keep_prob = 0.5
  batch_size = 32
  embed_dim = 256

  latent_dims = 32

  decoder_type = 'lstm'

  # ---- learning-rate decay schedule -------------------------------------
  lr_decay_hparams = {
      "init_lr": 0.001,
      "threshold": 2,
      "decay_factor": 0.5,
      "max_decay": 5,
  }

  # ---- builders for the repeated dictionary shapes -----------------------
  # These run while the class body executes and are removed again at the
  # end, so they never become attributes of Config.
  def _lstm_cell(num_units, output_keep_prob):
    """Single-layer LSTM cell hparams with the given output keep-prob."""
    return {
        "type": "LSTMCell",
        "kwargs": {"num_units": num_units, "bias": 0.},
        "dropout": {"output_keep_prob": output_keep_prob},
        "num_layers": 1,
    }

  def _embedder(dim, dropout_rate):
    """Embedding hparams, normal-initialised with std = dim ** -0.5."""
    return {
        "name": "lookup_table",
        "dim": dim,
        "dropout_rate": dropout_rate,
        "initializer": {
            "type": "normal_",
            "kwargs": {"mean": 0.0, "std": dim**-0.5},
        },
    }

  def _text_data(files, batch_size, seed=None):
    """One-epoch dataset hparams; `seed` is only included when given."""
    hparams = {"num_epochs": 1, "batch_size": batch_size}
    if seed is not None:
      hparams["seed"] = seed
    hparams["dataset"] = {
        "files": files,
        "vocab_file": './simple-examples/data/vocab.txt',
    }
    return hparams

  # ---- encoder / decoder -------------------------------------------------
  enc_cell_hparams = _lstm_cell(hidden_size, 1. - enc_dropout_out)
  dec_cell_hparams = _lstm_cell(hidden_size, 1. - dec_dropout_out)

  enc_emb_hparams = _embedder(embed_dim, enc_dropout_in)
  dec_emb_hparams = _embedder(embed_dim, dec_dropout_in)

  # KL annealing
  kl_anneal_hparams = {"warm_up": 10, "start": 0.1}

  # ---- datasets (only train and validation carry a shuffle seed) ---------
  train_data_hparams = _text_data(
      './simple-examples/data/train.txt', batch_size, seed=123)
  val_data_hparams = _text_data(
      './simple-examples/data/validate.txt', batch_size, seed=123)
  test_data_hparams = _text_data(
      './simple-examples/data/test.txt', batch_size)

  # ---- optimizer -----------------------------------------------------------
  opt_hparams = {
      "optimizer": {
          "type": "Adam",
          "kwargs": {"lr": 0.001},
      },
      "gradient_clip": {
          "type": "clip_grad_norm_",
          "kwargs": {"max_norm": 5, "norm_type": 2},
      },
  }

  del _lstm_cell, _embedder, _text_data

# Training the model

In [9]:
# Build everything the training cells share as globals: datasets and iterator,
# model, optimizer/scheduler, bookkeeping state and the PI controller.
config = Config()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Datasets are created on the CPU; VAE.forward moves each batch to its device.
train_data = tx.data.MonoTextData(config.train_data_hparams, device=torch.device("cpu"))
val_data = tx.data.MonoTextData(config.val_data_hparams, device=torch.device("cpu"))
test_data = tx.data.MonoTextData(config.test_data_hparams, device=torch.device("cpu"))

# The keys are the `mode` names that _run_epoch passes to switch_to_dataset.
iterator = tx.data.DataIterator({"train": train_data, "valid": val_data, "test": test_data})

# Mutable training state read and updated by the training cells.
opt_vars = {
    'learning_rate': config.lr_decay_hparams["init_lr"],
    'best_valid_nll': 1e100,
    'steps_not_improved': 0,
    'kl_weight': config.kl_anneal_hparams["start"]
}

# Learning-rate decay bookkeeping used by the epoch loop.
decay_cnt = 0
max_decay = config.lr_decay_hparams["max_decay"]
decay_factor = config.lr_decay_hparams["decay_factor"]
decay_ts = config.lr_decay_hparams["threshold"]

save_path = './checkpoint.ckpt'

# Per-step KL annealing increment. It is not referenced by the other cells
# visible in this notebook; the KL weight comes from the PI controller instead.
anneal_r = 1.0 / (config.kl_anneal_hparams["warm_up"] * (len(train_data) / config.batch_size))

vocab = train_data.vocab
model = VAE(train_data.vocab.size, config)
model.to(device)

# One <BOS> id per sequence of a full-size batch.
start_tokens = torch.full(
    (config.batch_size,),
    vocab.bos_token_id,
    dtype=torch.long).to(device)
end_token = vocab.eos_token_id
# NOTE(review): the training loop calls optimizer.step() directly -- confirm
# whether get_optimizer applies the "gradient_clip" section of opt_hparams.
optimizer = tx.core.get_optimizer(
    params=model.parameters(),
    hparams=config.opt_hparams)
scheduler = ExponentialLR(optimizer, decay_factor)

# Training-step budget: num_epochs worth of batches, capped at 80000.
# True division makes this a float (see the printed value).
max_iter = min(config.num_epochs*len(train_data)/config.batch_size, 80000)
print('max steps:', max_iter)

# A dict so that _run_epoch can increment the counter without `global`.
global_steps = {}
global_steps['step'] = 0
pid = PIDControl()
# Overrides the annealing start value set in opt_vars above: training starts
# with a KL weight of zero.
opt_vars["kl_weight"] = 0.0
# Gains and set-point passed to pid.pid(...) on every training step.
Kp = 0.01
Ki = -0.0001
exp_kl = 0

max steps: 18690.625


In [10]:
def _run_epoch(epoch: int, mode: str, display: int = 10) -> Tuple[float, float]:
    """Run one pass over the `mode` dataset and report averaged metrics.

    In ``'train'`` mode every batch updates the model and the PI controller
    recomputes the KL weight from the batch KL; training stops early once
    ``global_steps['step']`` reaches ``max_iter``. In any other mode the
    model is evaluated with a fixed KL weight of 1.0 over the whole dataset
    and no gradients are tracked.

    Args:
        epoch: Epoch index, used only in the printed summary.
        mode: Dataset name registered with the iterator
            (``'train'``, ``'valid'`` or ``'test'``).
        display: Kept for backward compatibility; currently unused.

    Returns:
        ``(nll, ppl)``: the batch-size-weighted mean of ``ret["nll"]`` and
        the per-word perplexity ``exp(sum(nll * batch_size) / num_words)``.
    """
    iterator.switch_to_dataset(mode)
    is_train = mode == 'train'

    if is_train:
        model.train()
        kl_weight = opt_vars["kl_weight"]
    else:
        model.eval()
        kl_weight = 1.0

    num_words = 0
    nll_total = 0.

    avg_rec = tx.utils.AverageRecorder()
    # Autograd graphs are only needed when we are going to call backward().
    with torch.set_grad_enabled(is_train):
        for batch in iterator:
            # The step budget limits training only. Applying it to valid/test
            # would leave the final evaluation with zero batches.
            if is_train and global_steps['step'] >= max_iter:
                break
            ## run model to get loss function
            ret = model(batch, kl_weight, start_tokens, end_token)
            if is_train:
                pbar.update(1)
                global_steps['step'] += 1
                # The weight computed here is applied to the *next* batch;
                # the loss for this batch was already built above.
                kl_weight = pid.pid(exp_kl, ret['kl_loss'].item(), Kp, Ki)
                opt_vars["kl_weight"] = kl_weight

                ## total loss
                ret["nll"].backward()
                optimizer.step()
                optimizer.zero_grad()

            batch_size = len(ret["lengths"])
            num_words += torch.sum(ret["lengths"]).item()
            nll_total += ret["nll"].item() * batch_size
            avg_rec.add(
                [ret["nll"].item(),
                 ret["kl_loss"].item(),
                 ret["rc_loss"].item()],
                batch_size)

    nll = avg_rec.avg(0)
    KL = avg_rec.avg(1)
    rc = avg_rec.avg(2)
    if num_words > 0:
        log_ppl = nll_total / num_words
        ppl = math.exp(log_ppl)
    else:
        # No batch was processed: report sentinel values instead of dividing
        # by zero.
        log_ppl = 100
        ppl = math.exp(log_ppl)
        nll = 1000
        KL = exp_kl

    print(f"\n{mode}: epoch {epoch}, nll {nll:.4f}, KL {KL:.4f}, "
          f"rc {rc:.4f}, log_ppl {log_ppl:.4f}, ppl {ppl:.4f}")
    return nll, ppl  # type: ignore

In [11]:
# Counts trainable parameters
total_parameters = sum(param.numel() for param in model.parameters())
print(f"{total_parameters} total parameters")

best_nll = best_ppl = 0.

## start running model
pbar = tqdm(total = int(max_iter))
for epoch in range(config.num_epochs):
    _, _ = _run_epoch(epoch, 'train', display=200)
    val_nll, _ = _run_epoch(epoch, 'valid')
    test_nll, test_ppl = _run_epoch(epoch, 'test')

    if val_nll < opt_vars['best_valid_nll']:
        opt_vars['best_valid_nll'] = val_nll
        opt_vars['steps_not_improved'] = 0
        best_nll = test_nll
        best_ppl = test_ppl

        states = {
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "scheduler": scheduler.state_dict()
        }
        torch.save(states, save_path)
    else:
        opt_vars['steps_not_improved'] += 1
        if opt_vars['steps_not_improved'] == decay_ts:
            old_lr = opt_vars['learning_rate']
            opt_vars['learning_rate'] *= decay_factor
            opt_vars['steps_not_improved'] = 0
            new_lr = opt_vars['learning_rate']
            ckpt = torch.load(save_path)
            model.load_state_dict(ckpt['model'])
            optimizer.load_state_dict(ckpt['optimizer'])
            scheduler.load_state_dict(ckpt['scheduler'])
            scheduler.step()
            print(f"-----\nchange lr, old lr: {old_lr}, "
                  f"new lr: {new_lr}\n-----")

            decay_cnt += 1
            if decay_cnt == max_decay:
                break
    if global_steps['step'] >= max_iter:
        break

print(f"\nbest testing nll: {best_nll:.4f},"
      f"best testing ppl {best_ppl:.4f}\n")

  0%|          | 0/18690 [00:00<?, ?it/s]

18768856 total parameters


  1%|          | 187/18690 [01:03<1:37:53,  3.15it/s]


train: epoch 0, nll 242.1186, KL 8.4507, rc 241.4173, log_ppl 7.5939, ppl 1986.1229

valid: epoch 0, nll 217.5254, KL 1.2186, rc 216.3068, log_ppl 7.1137, ppl 1228.6390

test: epoch 0, nll 212.4831, KL 1.2257, rc 211.2574, log_ppl 7.1226, ppl 1239.6523


  2%|▏         | 374/18690 [02:23<1:38:35,  3.10it/s]


train: epoch 1, nll 219.0069, KL 0.9301, rc 218.8452, log_ppl 6.8691, ppl 962.0360

valid: epoch 1, nll 204.9556, KL 0.9211, rc 204.0344, log_ppl 6.7026, ppl 814.5178

test: epoch 1, nll 200.2742, KL 0.9212, rc 199.3530, log_ppl 6.7133, ppl 823.3134


  3%|▎         | 561/18690 [03:43<1:49:40,  2.75it/s]


train: epoch 2, nll 205.5673, KL 1.1363, rc 205.3474, log_ppl 6.4475, ppl 631.1379

valid: epoch 2, nll 198.8487, KL 1.1501, rc 197.6987, log_ppl 6.5029, ppl 667.0646

test: epoch 2, nll 194.2608, KL 1.1413, rc 193.1195, log_ppl 6.5118, ppl 673.0120


  4%|▍         | 748/18690 [05:03<1:40:31,  2.97it/s]


train: epoch 3, nll 197.4882, KL 1.2958, rc 197.2076, log_ppl 6.1941, ppl 489.8639

valid: epoch 3, nll 196.3364, KL 1.1455, rc 195.1909, log_ppl 6.4207, ppl 614.4490

test: epoch 3, nll 191.8794, KL 1.1473, rc 190.7321, log_ppl 6.4319, ppl 621.3756


  5%|▌         | 935/18690 [06:23<2:01:44,  2.43it/s]


train: epoch 4, nll 191.5834, KL 1.3136, rc 191.2669, log_ppl 6.0089, ppl 407.0467

valid: epoch 4, nll 195.0562, KL 1.1105, rc 193.9457, log_ppl 6.3789, ppl 589.2567

test: epoch 4, nll 190.5888, KL 1.1193, rc 189.4695, log_ppl 6.3887, ppl 595.0675


  6%|▌         | 1122/18690 [07:43<1:30:17,  3.24it/s]


train: epoch 5, nll 186.6979, KL 1.3128, rc 186.3494, log_ppl 5.8557, ppl 349.2173

valid: epoch 5, nll 194.7297, KL 1.0741, rc 193.6556, log_ppl 6.3682, ppl 582.9982

test: epoch 5, nll 190.2367, KL 1.0827, rc 189.1540, log_ppl 6.3769, ppl 588.0848


  7%|▋         | 1309/18690 [09:04<1:30:58,  3.18it/s]


train: epoch 6, nll 182.4448, KL 1.2814, rc 182.0736, log_ppl 5.7223, ppl 305.6067

valid: epoch 6, nll 194.9913, KL 1.1892, rc 193.8022, log_ppl 6.3767, ppl 588.0067


  7%|▋         | 1310/18690 [09:20<24:24:30,  5.06s/it]


test: epoch 6, nll 190.5193, KL 1.1967, rc 189.3226, log_ppl 6.3863, ppl 593.6828


  8%|▊         | 1496/18690 [10:23<1:33:50,  3.05it/s]


train: epoch 7, nll 178.5352, KL 1.2816, rc 178.1332, log_ppl 5.5997, ppl 270.3389

valid: epoch 7, nll 194.8460, KL 1.1167, rc 193.7293, log_ppl 6.3720, ppl 585.2186

test: epoch 7, nll 190.4121, KL 1.1241, rc 189.2880, log_ppl 6.3827, ppl 591.5520
-----
change lr, old lr: 0.001, new lr: 0.0005
-----


  9%|▉         | 1683/18690 [11:41<1:36:30,  2.94it/s]


train: epoch 8, nll 181.8517, KL 1.1839, rc 181.4533, log_ppl 5.7037, ppl 299.9743

valid: epoch 8, nll 194.6275, KL 0.9629, rc 193.6646, log_ppl 6.3648, ppl 581.0527

test: epoch 8, nll 190.1776, KL 0.9678, rc 189.2099, log_ppl 6.3749, ppl 586.9218


 10%|█         | 1870/18690 [13:01<1:45:32,  2.66it/s]


train: epoch 9, nll 179.3956, KL 1.1839, rc 178.9709, log_ppl 5.6267, ppl 277.7334

valid: epoch 9, nll 194.7425, KL 1.0081, rc 193.7344, log_ppl 6.3686, ppl 583.2413


 10%|█         | 1871/18690 [13:18<25:14:58,  5.40s/it]


test: epoch 9, nll 190.3735, KL 1.0144, rc 189.3591, log_ppl 6.3815, ppl 590.7874


 11%|█         | 2057/18690 [14:20<2:18:18,  2.00it/s]


train: epoch 10, nll 177.2370, KL 1.1741, rc 176.7900, log_ppl 5.5590, ppl 259.5525

valid: epoch 10, nll 195.0363, KL 1.0577, rc 193.9786, log_ppl 6.3782, ppl 588.8726


 11%|█         | 2058/18690 [14:37<25:10:44,  5.45s/it]


test: epoch 10, nll 190.5787, KL 1.0633, rc 189.5154, log_ppl 6.3883, ppl 594.8666
-----
change lr, old lr: 0.0005, new lr: 0.00025
-----


 12%|█▏        | 2244/18690 [15:40<1:39:06,  2.77it/s]


train: epoch 11, nll 179.0867, KL 1.0991, rc 178.6450, log_ppl 5.6170, ppl 275.0553

valid: epoch 11, nll 194.8005, KL 0.9810, rc 193.8195, log_ppl 6.3705, ppl 584.3491


 12%|█▏        | 2245/18690 [15:56<23:29:13,  5.14s/it]


test: epoch 11, nll 190.3333, KL 0.9894, rc 189.3439, log_ppl 6.3801, ppl 589.9931


 13%|█▎        | 2431/18690 [17:00<1:33:28,  2.90it/s]


train: epoch 12, nll 177.8710, KL 1.1247, rc 177.3957, log_ppl 5.5788, ppl 264.7656

valid: epoch 12, nll 194.8330, KL 0.9301, rc 193.9029, log_ppl 6.3716, ppl 584.9704

test: epoch 12, nll 190.4129, KL 0.9352, rc 189.4776, log_ppl 6.3828, ppl 591.5680
-----
change lr, old lr: 0.00025, new lr: 0.000125
-----


 14%|█▍        | 2618/18690 [18:21<1:18:39,  3.41it/s]


train: epoch 13, nll 179.1431, KL 1.0663, rc 178.6707, log_ppl 5.6187, ppl 275.5425

valid: epoch 13, nll 194.7159, KL 0.9482, rc 193.7677, log_ppl 6.3677, ppl 582.7336


 14%|█▍        | 2619/18690 [18:38<22:56:54,  5.14s/it]


test: epoch 13, nll 190.2822, KL 0.9533, rc 189.3288, log_ppl 6.3784, ppl 588.9825


 15%|█▌        | 2805/18690 [19:41<1:29:48,  2.95it/s]


train: epoch 14, nll 177.9657, KL 1.0513, rc 177.4791, log_ppl 5.5818, ppl 265.5528

valid: epoch 14, nll 194.8098, KL 0.8750, rc 193.9348, log_ppl 6.3708, ppl 584.5272


 15%|█▌        | 2806/18690 [19:58<22:46:24,  5.16s/it]


test: epoch 14, nll 190.3667, KL 0.8809, rc 189.4858, log_ppl 6.3812, ppl 590.6535
-----
change lr, old lr: 0.000125, new lr: 6.25e-05
-----


 16%|█▌        | 2992/18690 [21:01<1:05:00,  4.02it/s]


train: epoch 15, nll 179.2369, KL 1.0301, rc 178.7402, log_ppl 5.6217, ppl 276.3546

valid: epoch 15, nll 194.6984, KL 0.9015, rc 193.7969, log_ppl 6.3672, ppl 582.4015


 16%|█▌        | 2993/18690 [21:18<21:49:06,  5.00s/it]


test: epoch 15, nll 190.3200, KL 0.9072, rc 189.4128, log_ppl 6.3797, ppl 589.7292


 17%|█▋        | 3179/18690 [22:21<1:12:29,  3.57it/s]


train: epoch 16, nll 177.9588, KL 0.9827, rc 177.4665, log_ppl 5.5816, ppl 265.4952

valid: epoch 16, nll 194.8302, KL 0.8232, rc 194.0070, log_ppl 6.3715, ppl 584.9162

test: epoch 16, nll 190.3695, KL 0.8280, rc 189.5415, log_ppl 6.3813, ppl 590.7082
-----
change lr, old lr: 6.25e-05, new lr: 3.125e-05
-----

best testing nll: 190.1776,best testing ppl 586.9218



# Generate text

In [14]:
# Encode one sentence, draw `batch_size` latent samples from its posterior and
# decode each of them by sampling.
model.eval()

batch_size = train_data.batch_size

test_sentence = 'The service was terrible.'
tokens = (test_sentence+' <EOS>').split()

# Pure inference: no autograd bookkeeping is needed.
with torch.no_grad():
    # Use the shared `device` rather than hard-coding CUDA, and derive the
    # sequence length from the tokens instead of a literal.
    text_ids = torch.tensor(vocab.map_tokens_to_ids_py([tokens])).to(device)
    seq_length = torch.tensor([len(tokens)]).to(device)
    mean, logvar = model.encode(text_ids, seq_length)
    # The scale is the standard deviation, exp(0.5 * logvar), matching how
    # VAE.forward builds the posterior during training.
    dst = MultivariateNormalDiag(loc=mean[0], scale_diag=torch.exp(0.5 * logvar[0]))
    latent_z = dst.sample((batch_size,))

    # latent_z = torch.FloatTensor(batch_size, config.latent_dims).uniform_(-1, 1).to(device)
    # latent_z = torch.randn(batch_size, config.latent_dims).to(device)

    helper = model.decoder.create_helper(
        decoding_strategy='infer_sample',
        start_tokens=start_tokens,
        end_token=end_token)
    outputs = model.decode(
        helper=helper,
        latent_z=latent_z,
        max_decoding_length=100)

if config.decoder_type == "transformer":
    outputs = outputs[0]


sample_tokens = vocab.map_ids_to_tokens_py(outputs.sample_id.cpu())
for sent in sample_tokens:
    sent = tx.utils.compat_as_text(list(sent))
    # Print up to and including the first end-of-sequence token, if any.
    end_id = len(sent)
    if vocab.eos_token in sent:
        end_id = sent.index(vocab.eos_token)
    print(' '.join(sent[:end_id + 1]) + '\n')

My food is fresh and no was two selections of just disappointed. The portions did have another convenient bike adobo, food polishes and I did appears some again..maybe assume I will not will have time it, on a convenient owned afternoon dark med overpriced), <EOS>

We have found bad days with an experts for her toward both and she recommended across a person? with huge amount with us. The workers came in (every flavors and needed though. <EOS>

The <EOS>

It's adapt 10:30 fast food, fairly wasnt asking out to get to this to be expensive people, isn't a dessert can find the street around a friend's crowd toast. It's dry and those sugar next in myself atmosphere. This place is friendly. We were extremely sometime as much toppings did for no good mojito AND the mistake is quickly. The biscuits and - fast, bad. <EOS>

near ahead of pho, that were very really toppings: and chose: shaker I'm stage growing to them to try the diet of Scottsdale before the dishes were cos for friendly and Short