<a href="https://colab.research.google.com/github/guanjiew/csc412_vae/blob/main/ControlVAE_Text_Generation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Code is adapted from [the original ControlVAE repository](https://github.com/shj1987/ControlVAE-ICML2020/tree/master/Language_modeling/Text_gen_PTB)

# Preparation

In [1]:
%pip install texar-pytorch

Collecting texar-pytorch
[?25l  Downloading https://files.pythonhosted.org/packages/2d/7e/20aa39ee9d19dcd1a1c0db4dad878088cfba216f45aeaa8fa89615ef46c0/texar_pytorch-0.1.2.post1-py3-none-any.whl (434kB)
[K     |▊                               | 10kB 19.9MB/s eta 0:00:01[K     |█▌                              | 20kB 26.4MB/s eta 0:00:01[K     |██▎                             | 30kB 23.7MB/s eta 0:00:01[K     |███                             | 40kB 26.9MB/s eta 0:00:01[K     |███▊                            | 51kB 24.6MB/s eta 0:00:01[K     |████▌                           | 61kB 27.2MB/s eta 0:00:01[K     |█████▎                          | 71kB 18.6MB/s eta 0:00:01[K     |██████                          | 81kB 19.9MB/s eta 0:00:01[K     |██████▉                         | 92kB 18.5MB/s eta 0:00:01[K     |███████▌                        | 102kB 18.3MB/s eta 0:00:01[K     |████████▎                       | 112kB 18.3MB/s eta 0:00:01[K     |█████████                 

In [2]:
import math, os
from typing import Any, Dict, Optional, Tuple, Union

import torch
import torch.nn as nn
from torch import Tensor
from torch.optim.lr_scheduler import ExponentialLR

import texar.torch as tx
from texar.torch.custom import MultivariateNormalDiag
from tqdm import tqdm

import pandas as pd
import numpy as np

In [3]:
# Loading the dataset (TODO: Load and preprocess Yelp dataset instead)

# data_path = "./simple-examples/data"
# train_path = os.path.join(data_path, "ptb.train.txt")
# if not os.path.exists(train_path):
#     url = 'http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz'
#     tx.data.maybe_download(url, './', extract=True)

# train_path = os.path.join(data_path, "ptb.train.txt")
# vocab_path = os.path.join(data_path, "vocab.txt")
# word_to_id = tx.data.make_vocab(
#     train_path, return_type="dict")

# with open(vocab_path, 'w') as fvocab:
#     for word in word_to_id:
#         fvocab.write("%s\n" % word)

In [4]:
# Loading and preprocessing yelp review data

%mkdir -p /content/simple-examples/data/
%cd /content

data_path = "./simple-examples/data"
train_path = os.path.join(data_path, "train.txt")
validate_path = os.path.join(data_path, "validate.txt")
test_path = os.path.join(data_path, "test.txt")
if not (os.path.exists(train_path) or os.path.exists(validate_path) or os.path.exists(test_path)):
  url = 'https://raw.githubusercontent.com/rekiksab/Yelp-Data-Challenge-2013/master/yelp_challenge/yelp_phoenix_academic_dataset/yelp_academic_dataset_review.json'
  df = pd.read_json(url, lines=True)
  text = df['text']
  percent = .01
  train, validate, test = np.split(text.sample(frac=percent, random_state=42), [int(.6*percent*len(text)), int(.8*percent*len(text))])
  np.savetxt(train_path, train.values, fmt='%s')
  np.savetxt(validate_path, validate.values, fmt='%s')
  np.savetxt(test_path, test.values, fmt='%s')
  train_path = os.path.join(data_path, "train.txt")
  validate_path = os.path.join(data_path, "validate.txt")
  test_path = os.path.join(data_path, "test.txt")

vocab_path = os.path.join(data_path, "vocab.txt")
word_to_id = tx.data.make_vocab(
    train_path, return_type="dict")

with open(vocab_path, 'w') as fvocab:
    for word in word_to_id:
        fvocab.write("%s\n" % word)

/content


# VAE Model

In [5]:
def kl_divergence(means: Tensor, logvars: Tensor) -> Tensor:
    """Compute the KL divergence between Gaussian distribution
    """
    kl_cost = -0.5 * (logvars - means ** 2 -
                      torch.exp(logvars) + 1.0)
    kl_cost = torch.mean(kl_cost, 0)
    return torch.sum(kl_cost)

In [6]:
class VAE(nn.Module):
    _latent_z: Tensor

    def __init__(self, vocab_size: int, config_model):
        super().__init__()
        # Model architecture
        self._config = config_model
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.encoder_w_embedder = tx.modules.WordEmbedder(
            vocab_size=vocab_size, hparams=config_model.enc_emb_hparams)

        self.encoder = tx.modules.UnidirectionalRNNEncoder[tx.core.LSTMState](
            input_size=self.encoder_w_embedder.dim,
            hparams={
                "rnn_cell": config_model.enc_cell_hparams,
            })

        self.decoder_w_embedder = tx.modules.WordEmbedder(
            vocab_size=vocab_size, hparams=config_model.dec_emb_hparams)

        if config_model.decoder_type == "lstm":
            self.lstm_decoder = tx.modules.BasicRNNDecoder(
                input_size=(self.decoder_w_embedder.dim +
                            config_model.latent_dims),
                vocab_size=vocab_size,
                token_embedder=self._embed_fn_rnn,
                hparams={"rnn_cell": config_model.dec_cell_hparams})
            sum_state_size = self.lstm_decoder.cell.hidden_size * 2

        elif config_model.decoder_type == 'transformer':
            # position embedding
            self.decoder_p_embedder = tx.modules.SinusoidsPositionEmbedder(
                position_size=config_model.max_pos,
                hparams=config_model.dec_pos_emb_hparams)
            # decoder
            self.transformer_decoder = tx.modules.TransformerDecoder(
                # tie word embedding with output layer
                output_layer=self.decoder_w_embedder.embedding,
                token_pos_embedder=self._embed_fn_transformer,
                hparams=config_model.trans_hparams)
            sum_state_size = self._config.dec_emb_hparams["dim"]

        else:
            raise ValueError("Decoder type must be 'lstm' or 'transformer'")

        self.connector_mlp = tx.modules.MLPTransformConnector(
            config_model.latent_dims * 2,
            linear_layer_dim=self.encoder.cell.hidden_size * 2)

        self.mlp_linear_layer = nn.Linear(
            config_model.latent_dims, sum_state_size)

    def forward(self,  # type: ignore
                data_batch: tx.data.Batch,
                kl_weight: float, start_tokens: torch.LongTensor,
                end_token: int) -> Dict[str, Tensor]:
        # encoder -> connector -> decoder
        text_ids = data_batch["text_ids"].to(self.device)
        mean, logvar = self.encode(text_ids, data_batch["length"].to(self.device))
        kl_loss = kl_divergence(mean, logvar)
        dst = MultivariateNormalDiag(
            loc=mean, scale_diag=torch.exp(0.5 * logvar))

        latent_z = dst.rsample()
        helper = None
        if self._config.decoder_type == "lstm":
            helper = self.lstm_decoder.create_helper(
                decoding_strategy="train_greedy",
                start_tokens=start_tokens,
                end_token=end_token)

        # decode
        seq_lengths = data_batch["length"].to(self.device) - 1
        outputs = self.decode(
            helper=helper, latent_z=latent_z,
            text_ids=text_ids[:, :-1], seq_lengths=seq_lengths)

        logits = outputs.logits

        # Losses & train ops
        rc_loss = tx.losses.sequence_sparse_softmax_cross_entropy(
            labels=text_ids[:, 1:], logits=logits,
            sequence_length=seq_lengths)

        nll = rc_loss + kl_weight * kl_loss

        ret = {
            "nll": nll,
            "kl_loss": kl_loss,
            "rc_loss": rc_loss,
            "lengths": seq_lengths,
        }

        return ret

    def _embed_fn_rnn(self, tokens: torch.LongTensor) -> Tensor:
        r"""Generates word embeddings
        """
        embedding = self.decoder_w_embedder(tokens)
        latent_z = self._latent_z
        if len(embedding.size()) > 2:
            latent_z = latent_z.unsqueeze(0).repeat(tokens.size(0), 1, 1)
        return torch.cat([embedding, latent_z], dim=-1)

    def _embed_fn_transformer(self,
                              tokens: torch.LongTensor,
                              positions: torch.LongTensor) -> Tensor:
        r"""Generates word embeddings combined with positional embeddings
        """
        output_p_embed = self.decoder_p_embedder(positions)
        output_w_embed = self.decoder_w_embedder(tokens)
        output_w_embed = output_w_embed * self._config.hidden_size ** 0.5
        output_embed = output_w_embed + output_p_embed
        return output_embed
    
    def encode(self, text_ids, seq_lengths):
        input_embed = self.encoder_w_embedder(text_ids)
        _, encoder_states = self.encoder(
            input_embed,
            sequence_length=seq_lengths)
        mean_logvar = self.connector_mlp(encoder_states)
        mean, logvar = torch.chunk(mean_logvar, 2, 1)
        return mean, logvar

    @property
    def decoder(self) -> tx.modules.DecoderBase:
        if self._config.decoder_type == "lstm":
            return self.lstm_decoder
        else:
            return self.transformer_decoder

    def decode(self,
               helper: Optional[tx.modules.Helper],
               latent_z: Tensor,
               text_ids: Optional[torch.LongTensor] = None,
               seq_lengths: Optional[Tensor] = None,
               max_decoding_length: Optional[int] = None) \
            -> Union[tx.modules.BasicRNNDecoderOutput,
                     tx.modules.TransformerDecoderOutput]:
        self._latent_z = latent_z
        fc_output = self.mlp_linear_layer(latent_z)

        if self._config.decoder_type == "lstm":
            lstm_states = torch.chunk(fc_output, 2, dim=1)
            outputs, _, _ = self.lstm_decoder(
                initial_state=lstm_states,
                inputs=text_ids,
                helper=helper,
                sequence_length=seq_lengths,
                max_decoding_length=max_decoding_length)
        else:
            transformer_states = fc_output.unsqueeze(1)
            outputs = self.transformer_decoder(
                inputs=text_ids,
                memory=transformer_states,
                memory_sequence_length=torch.ones(transformer_states.size(0)),
                helper=helper,
                max_decoding_length=max_decoding_length)
        return outputs
    

# PID Control

In [7]:
class PIDControl():
    """docstring for ClassName"""
    def __init__(self):
        """define them out of loop"""
        # self.exp_KL = exp_KL
        self.I_k1 = 0.0
        self.W_k1 = 0.0
        self.e_k1 = 0.0
        
    def _Kp_fun(self, Err, scale=1):
        return 1.0/(1.0 + float(scale)*math.exp(Err))
        

    def pid(self, exp_KL, kl_loss, Kp=0.001, Ki=-0.001, Kd=0.01):
        """
        position PID algorithm
        Input: KL_loss
        return: weight for KL loss, beta
        """
        error_k = exp_KL - kl_loss
        ## comput U as the control factor
        Pk = Kp * self._Kp_fun(error_k)
        Ik = self.I_k1 + Ki * error_k

        ## window up for integrator
        if self.W_k1 < 0 and self.W_k1 > 1:
            Ik = self.I_k1
            
        Wk = Pk + Ik
        self.W_k1 = Wk
        self.I_k1 = Ik
        self.e_k1 = error_k
        
        ## min and max value
        if Wk > 1:
            Wk = 1.0
        if Wk < 0:
            Wk = 0.0
        
        return Wk
    

# Hyperparameters

In [8]:
class Config():
  dataset = "ptb"
  num_epochs = 100
  hidden_size = 256
  dec_dropout_in = 0.5
  dec_dropout_out = 0.5
  enc_dropout_in = 0.
  enc_dropout_out = 0.
  word_keep_prob = 0.5
  batch_size = 32
  embed_dim = 256

  latent_dims = 32

  lr_decay_hparams = {
      "init_lr": 0.001,
      "threshold": 2,
      "decay_factor": 0.5,
      "max_decay": 5
  }


  decoder_type = 'lstm'

  enc_cell_hparams = {
      "type": "LSTMCell",
      "kwargs": {
          "num_units": hidden_size,
          "bias": 0.
      },
      "dropout": {"output_keep_prob": 1. - enc_dropout_out},
      "num_layers": 1
  }

  dec_cell_hparams = {
      "type": "LSTMCell",
      "kwargs": {
          "num_units": hidden_size,
          "bias": 0.,
      },
      "dropout": {"output_keep_prob": 1. - dec_dropout_out},
      "num_layers": 1,
  }

  enc_emb_hparams = {
      'name': 'lookup_table',
      "dim": embed_dim,
      "dropout_rate": enc_dropout_in,
      'initializer': {
          'type': 'normal_',
          'kwargs': {
              'mean': 0.0,
              'std': embed_dim**-0.5,
          },
      }
  }

  dec_emb_hparams = {
      'name': 'lookup_table',
      "dim": embed_dim,
      "dropout_rate": dec_dropout_in,
      'initializer': {
          'type': 'normal_',
          'kwargs': {
              'mean': 0.0,
              'std': embed_dim**-0.5,
          },
      }
  }

  # KL annealing
  kl_anneal_hparams = {
      "warm_up": 10,
      "start": 0.1
  }

  train_data_hparams = {
      "num_epochs": 1,
      "batch_size": batch_size,
      "seed": 123,
      "dataset": {
          "files": './simple-examples/data/train.txt',
          "vocab_file": './simple-examples/data/vocab.txt'
      }
  }

  val_data_hparams = {
      "num_epochs": 1,
      "batch_size": batch_size,
      "seed": 123,
      "dataset": {
          "files": './simple-examples/data/validate.txt',
          "vocab_file": './simple-examples/data/vocab.txt'
      }
  }

  test_data_hparams = {
      "num_epochs": 1,
      "batch_size": batch_size,
      "dataset": {
          "files": './simple-examples/data/test.txt',
          "vocab_file": './simple-examples/data/vocab.txt'
      }
  }

  opt_hparams = {
      'optimizer': {
          'type': 'Adam',
          'kwargs': {
              'lr': 0.001
          }
      },
      'gradient_clip': {
          "type": "clip_grad_norm_",
          "kwargs": {
              "max_norm": 5,
              "norm_type": 2
          }
      }
  }

# Training the model

In [9]:
config = Config()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

train_data = tx.data.MonoTextData(config.train_data_hparams, device=torch.device("cpu"))
val_data = tx.data.MonoTextData(config.val_data_hparams, device=torch.device("cpu"))
test_data = tx.data.MonoTextData(config.test_data_hparams, device=torch.device("cpu"))

iterator = tx.data.DataIterator({"train": train_data, "valid": val_data, "test": test_data})

opt_vars = {
    'learning_rate': config.lr_decay_hparams["init_lr"],
    'best_valid_nll': 1e100,
    'steps_not_improved': 0,
    'kl_weight': config.kl_anneal_hparams["start"]
}

decay_cnt = 0
max_decay = config.lr_decay_hparams["max_decay"]
decay_factor = config.lr_decay_hparams["decay_factor"]
decay_ts = config.lr_decay_hparams["threshold"]

save_path = './checkpoint.ckpt'

anneal_r = 1.0 / (config.kl_anneal_hparams["warm_up"] * (len(train_data) / config.batch_size))

vocab = train_data.vocab
model = VAE(train_data.vocab.size, config)
model.to(device)

start_tokens = torch.full(
    (config.batch_size,),
    vocab.bos_token_id,
    dtype=torch.long).to(device)
end_token = vocab.eos_token_id
optimizer = tx.core.get_optimizer(
    params=model.parameters(),
    hparams=config.opt_hparams)
scheduler = ExponentialLR(optimizer, decay_factor)

max_iter = min(config.num_epochs*len(train_data)/config.batch_size, 80000)
print('max steps:', max_iter)

global_steps = {}
global_steps['step'] = 0
pid = PIDControl()
opt_vars["kl_weight"] = 0.0
Kp = 0.01
Ki = -0.0001
exp_kl = 0

max steps: 18690.625


In [10]:
def _run_epoch(epoch: int, mode: str, display: int = 10) -> Tuple[Tensor, float]:
    iterator.switch_to_dataset(mode)

    if mode == 'train':
        model.train()
        kl_weight = opt_vars["kl_weight"]
    else:
        model.eval()
        kl_weight = 1.0
    
    num_words = 0
    nll_total = 0.

    avg_rec = tx.utils.AverageRecorder()
    for batch in iterator:
        ## run model to get loss function
        if global_steps['step']>= max_iter:
            break
        ret = model(batch, kl_weight, start_tokens, end_token)
        if mode == "train":
            pbar.update(1)
            global_steps['step'] += 1
            kl_loss = ret['kl_loss'].item()
            rec_loss = ret['rc_loss'].item()
            total_loss = ret["nll"].item()
            kl_weight = pid.pid(exp_kl, kl_loss, Kp, Ki)

            opt_vars["kl_weight"] = kl_weight
            
            ## total loss
            ret["nll"].backward()
            optimizer.step()
            optimizer.zero_grad()

        batch_size = len(ret["lengths"])
        num_words += torch.sum(ret["lengths"]).item()
        nll_total += ret["nll"].item() * batch_size
        avg_rec.add(
            [ret["nll"].item(),
              ret["kl_loss"].item(),
              ret["rc_loss"].item()],
            batch_size)
            
        if global_steps['step'] % display == 1 and mode == 'train':
            nll = avg_rec.avg(0)
            klw = opt_vars["kl_weight"]
            KL = avg_rec.avg(1)
            rc = avg_rec.avg(2)
            
    nll = avg_rec.avg(0)
    KL = avg_rec.avg(1)
    rc = avg_rec.avg(2)
    if num_words > 0:
        log_ppl = nll_total / num_words
        ppl = math.exp(log_ppl)
    else:
        log_ppl = 100
        ppl = math.exp(log_ppl)
        nll = 1000
        KL = args.exp_kl
    
    print(f"\n{mode}: epoch {epoch}, nll {nll:.4f}, KL {KL:.4f}, "
          f"rc {rc:.4f}, log_ppl {log_ppl:.4f}, ppl {ppl:.4f}")
    return nll, ppl  # type: ignore

In [11]:
# Counts trainable parameters
total_parameters = sum(param.numel() for param in model.parameters())
print(f"{total_parameters} total parameters")

best_nll = best_ppl = 0.

## start running model
pbar = tqdm(total = int(max_iter))
for epoch in range(config.num_epochs):
    _, _ = _run_epoch(epoch, 'train', display=200)
    val_nll, _ = _run_epoch(epoch, 'valid')
    test_nll, test_ppl = _run_epoch(epoch, 'test')

    if val_nll < opt_vars['best_valid_nll']:
        opt_vars['best_valid_nll'] = val_nll
        opt_vars['steps_not_improved'] = 0
        best_nll = test_nll
        best_ppl = test_ppl

        states = {
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "scheduler": scheduler.state_dict()
        }
        torch.save(states, save_path)
    else:
        opt_vars['steps_not_improved'] += 1
        if opt_vars['steps_not_improved'] == decay_ts:
            old_lr = opt_vars['learning_rate']
            opt_vars['learning_rate'] *= decay_factor
            opt_vars['steps_not_improved'] = 0
            new_lr = opt_vars['learning_rate']
            ckpt = torch.load(save_path)
            model.load_state_dict(ckpt['model'])
            optimizer.load_state_dict(ckpt['optimizer'])
            scheduler.load_state_dict(ckpt['scheduler'])
            scheduler.step()
            print(f"-----\nchange lr, old lr: {old_lr}, "
                  f"new lr: {new_lr}\n-----")

            decay_cnt += 1
            if decay_cnt == max_decay:
                break
    if global_steps['step'] >= max_iter:
        break

print(f"\nbest testing nll: {best_nll:.4f},"
      f"best testing ppl {best_ppl:.4f}\n")

  0%|          | 0/18690 [00:00<?, ?it/s]

18768856 total parameters


  1%|          | 187/18690 [01:06<1:43:23,  2.98it/s]


train: epoch 0, nll 243.3305, KL 7.9083, rc 242.7153, log_ppl 7.6320, ppl 2063.0751

valid: epoch 0, nll 218.4918, KL 1.3941, rc 217.0978, log_ppl 7.1453, ppl 1268.0921

test: epoch 0, nll 213.4397, KL 1.3956, rc 212.0441, log_ppl 7.1547, ppl 1280.0504


  2%|▏         | 374/18690 [02:26<1:55:04,  2.65it/s]


train: epoch 1, nll 220.1756, KL 1.0346, rc 220.0050, log_ppl 6.9057, ppl 997.9560

valid: epoch 1, nll 207.0242, KL 1.2159, rc 205.8083, log_ppl 6.7702, ppl 871.5277

test: epoch 1, nll 202.3507, KL 1.2338, rc 201.1169, log_ppl 6.7829, ppl 882.6612


  3%|▎         | 561/18690 [03:50<1:37:04,  3.11it/s]


train: epoch 2, nll 206.7424, KL 1.1462, rc 206.5299, log_ppl 6.4844, ppl 654.8353

valid: epoch 2, nll 199.8725, KL 1.2302, rc 198.6424, log_ppl 6.5364, ppl 689.7762

test: epoch 2, nll 195.2884, KL 1.2434, rc 194.0450, log_ppl 6.5462, ppl 696.5976


  4%|▍         | 748/18690 [05:12<1:56:45,  2.56it/s]


train: epoch 3, nll 198.5139, KL 1.3069, rc 198.2412, log_ppl 6.2263, ppl 505.8796

valid: epoch 3, nll 197.3921, KL 1.0863, rc 196.3058, log_ppl 6.4552, ppl 636.0323

test: epoch 3, nll 192.8432, KL 1.1090, rc 191.7341, log_ppl 6.4642, ppl 641.7777


  5%|▌         | 935/18690 [06:34<1:42:19,  2.89it/s]


train: epoch 4, nll 192.6214, KL 1.2698, rc 192.3260, log_ppl 6.0415, ppl 420.5162

valid: epoch 4, nll 196.0353, KL 1.0015, rc 195.0338, log_ppl 6.4109, ppl 608.4285

test: epoch 4, nll 191.4714, KL 1.0136, rc 190.4578, log_ppl 6.4183, ppl 612.9356


  6%|▌         | 1122/18690 [07:55<2:36:10,  1.87it/s]


train: epoch 5, nll 187.8666, KL 1.2807, rc 187.5381, log_ppl 5.8924, ppl 362.2557

valid: epoch 5, nll 195.9889, KL 1.1092, rc 194.8797, log_ppl 6.4094, ppl 607.5066

test: epoch 5, nll 191.4240, KL 1.1229, rc 190.3011, log_ppl 6.4167, ppl 611.9631


  7%|▋         | 1309/18690 [09:16<1:30:21,  3.21it/s]


train: epoch 6, nll 183.6492, KL 1.2884, rc 183.2877, log_ppl 5.7601, ppl 317.3718

valid: epoch 6, nll 195.5329, KL 1.2834, rc 194.2496, log_ppl 6.3945, ppl 598.5147

test: epoch 6, nll 191.0512, KL 1.2928, rc 189.7584, log_ppl 6.4042, ppl 604.3623


  8%|▊         | 1496/18690 [10:38<1:59:49,  2.39it/s]


train: epoch 7, nll 179.7846, KL 1.2677, rc 179.3987, log_ppl 5.6389, ppl 281.1432

valid: epoch 7, nll 195.8458, KL 1.1895, rc 194.6564, log_ppl 6.4047, ppl 604.6706


  8%|▊         | 1497/18690 [10:55<25:19:49,  5.30s/it]


test: epoch 7, nll 191.4298, KL 1.2035, rc 190.2263, log_ppl 6.4169, ppl 612.0812


  9%|▉         | 1683/18690 [12:00<1:45:46,  2.68it/s]


train: epoch 8, nll 176.1263, KL 1.2554, rc 175.7145, log_ppl 5.5241, ppl 250.6660

valid: epoch 8, nll 196.1982, KL 1.0577, rc 195.1405, log_ppl 6.4162, ppl 611.6779

test: epoch 8, nll 191.7211, KL 1.0700, rc 190.6511, log_ppl 6.4266, ppl 618.0871
-----
change lr, old lr: 0.001, new lr: 0.0005
-----


 10%|█         | 1870/18690 [13:22<2:00:47,  2.32it/s]


train: epoch 9, nll 179.0136, KL 1.1859, rc 178.5978, log_ppl 5.6147, ppl 274.4262

valid: epoch 9, nll 195.8132, KL 1.0390, rc 194.7743, log_ppl 6.4036, ppl 604.0262


 10%|█         | 1871/18690 [13:39<25:20:07,  5.42s/it]


test: epoch 9, nll 191.2808, KL 1.0453, rc 190.2355, log_ppl 6.4119, ppl 609.0313


 11%|█         | 2057/18690 [14:44<1:43:21,  2.68it/s]


train: epoch 10, nll 176.8416, KL 1.1971, rc 176.3951, log_ppl 5.5466, ppl 256.3538

valid: epoch 10, nll 195.9819, KL 1.0066, rc 194.9753, log_ppl 6.4091, ppl 607.3661

test: epoch 10, nll 191.4847, KL 1.0168, rc 190.4678, log_ppl 6.4187, ppl 613.2082
-----
change lr, old lr: 0.0005, new lr: 0.00025
-----


 12%|█▏        | 2244/18690 [16:06<1:34:37,  2.90it/s]


train: epoch 11, nll 179.0873, KL 1.1189, rc 178.6458, log_ppl 5.6170, ppl 275.0607

valid: epoch 11, nll 195.8138, KL 1.1855, rc 194.6283, log_ppl 6.4036, ppl 604.0375


 12%|█▏        | 2245/18690 [16:24<24:56:07,  5.46s/it]


test: epoch 11, nll 191.2732, KL 1.1988, rc 190.0744, log_ppl 6.4116, ppl 608.8770


 13%|█▎        | 2431/18690 [17:29<1:33:22,  2.90it/s]


train: epoch 12, nll 176.9096, KL 1.1283, rc 176.4407, log_ppl 5.5487, ppl 256.9005

valid: epoch 12, nll 195.9980, KL 1.1540, rc 194.8440, log_ppl 6.4097, ppl 607.6869

test: epoch 12, nll 191.4305, KL 1.1640, rc 190.2665, log_ppl 6.4169, ppl 612.0959
-----
change lr, old lr: 0.00025, new lr: 0.000125
-----


 14%|█▍        | 2618/18690 [18:49<1:06:46,  4.01it/s]


train: epoch 13, nll 179.0834, KL 1.0813, rc 178.6118, log_ppl 5.6169, ppl 275.0276

valid: epoch 13, nll 195.4964, KL 0.8417, rc 194.6547, log_ppl 6.3933, ppl 597.7996

test: epoch 13, nll 191.0971, KL 0.8484, rc 190.2487, log_ppl 6.4057, ppl 605.2925


 15%|█▌        | 2805/18690 [20:12<1:17:08,  3.43it/s]


train: epoch 14, nll 176.9938, KL 1.0646, rc 176.5082, log_ppl 5.5513, ppl 257.5805

valid: epoch 14, nll 195.9019, KL 0.8696, rc 195.0323, log_ppl 6.4065, ppl 605.7796


 15%|█▌        | 2806/18690 [20:29<23:43:48,  5.38s/it]


test: epoch 14, nll 191.4777, KL 0.8787, rc 190.5990, log_ppl 6.4185, ppl 613.0650


 16%|█▌        | 2992/18690 [21:34<1:22:31,  3.17it/s]


train: epoch 15, nll 175.0213, KL 1.0643, rc 174.5146, log_ppl 5.4895, ppl 242.1271

valid: epoch 15, nll 196.1106, KL 0.9588, rc 195.1519, log_ppl 6.4133, ppl 609.9295

test: epoch 15, nll 191.6083, KL 0.9692, rc 190.6391, log_ppl 6.4228, ppl 615.7547
-----
change lr, old lr: 0.000125, new lr: 6.25e-05
-----


 17%|█▋        | 3179/18690 [22:56<1:18:34,  3.29it/s]


train: epoch 16, nll 176.6669, KL 1.0302, rc 176.1564, log_ppl 5.5411, ppl 254.9528

valid: epoch 16, nll 195.7139, KL 0.9303, rc 194.7836, log_ppl 6.4004, ppl 602.0672


 17%|█▋        | 3180/18690 [23:13<22:42:20,  5.27s/it]


test: epoch 16, nll 191.2531, KL 0.9391, rc 190.3140, log_ppl 6.4109, ppl 608.4665


 18%|█▊        | 3366/18690 [24:18<1:02:12,  4.11it/s]


train: epoch 17, nll 175.4979, KL 1.0251, rc 174.9703, log_ppl 5.5044, ppl 245.7743

valid: epoch 17, nll 195.8720, KL 0.8937, rc 194.9783, log_ppl 6.4055, ppl 605.1879

test: epoch 17, nll 191.4011, KL 0.9033, rc 190.4978, log_ppl 6.4159, ppl 611.4926
-----
change lr, old lr: 6.25e-05, new lr: 3.125e-05
-----

best testing nll: 191.0971,best testing ppl 605.2925



# Generate text

In [12]:
model.eval()

batch_size = train_data.batch_size

test_sentence = 'The service was terrible.'
text_ids = torch.tensor(vocab.map_tokens_to_ids_py([(test_sentence+' <EOS>').split()])).cuda()
mean, logvar = model.encode(text_ids, torch.tensor([5]).cuda())
dst = MultivariateNormalDiag(loc=mean[0], scale_diag=torch.exp(logvar[0]))
latent_z = dst.sample((batch_size,))

# latent_z = torch.FloatTensor(batch_size, config.latent_dims).uniform_(-1, 1).to(device)
# latent_z = torch.randn(batch_size, config.latent_dims).to(device)

helper = model.decoder.create_helper(
    decoding_strategy='infer_sample',
    start_tokens=start_tokens,
    end_token=end_token)
outputs = model.decode(
    helper=helper,
    latent_z=latent_z,
    max_decoding_length=100)

if config.decoder_type == "transformer":
    outputs = outputs[0]


sample_tokens = vocab.map_ids_to_tokens_py(outputs.sample_id.cpu())
for sent in sample_tokens:
    sent = tx.utils.compat_as_text(list(sent))
    end_id = len(sent)
    if vocab.eos_token in sent:
        end_id = sent.index(vocab.eos_token)
    print(' '.join(sent[:end_id + 1]) + '\n')

Oh anymore. <EOS>

It's given this several beds, 100 LOVE! <EOS>

When I enjoy this experience a little hotel that so we went on people So, she would go out if we were very served. Martini Always had yet Kickass for small and It's two greatly Whenever I love our booster. burritos were a variety of choi, unreasonable. <EOS>

Oh food. Also, the times were the lucky very dirt to pick up and I've really reasonable. <EOS>

You sat very stadium here and I have hair, lost my tiramisu amount of me in 3 happening and complain I didn't order up to lump people there, the (round, corn pesto started may ensure it would suggest to have what gives me feel slightly burnt to be becoming a try. <EOS>

- also is a nice Burger with my favorite time that they have been on. It was great, the morning. Nice team two tires. By the menu here has the blueberry pita came in an practice breakfast - Saw theres very it? <EOS>

My dish was really three of a wheelchair look a-a-a-amazing!! <EOS>

To be Loews...like po

In [16]:
def mutual_information(model, data_loader, dataset='train', N1=10000, N2=10000):
    data_loader.switch_to_dataset(dataset)
    model.eval()

    kl1 = None
    kl2 = torch.tensor(0.) # TODO: Compute second KL term
    for batch in data_loader:
        mean, logvar = model.encode(batch['text_ids'].cuda(), batch['length'].cuda())
        dst = MultivariateNormalDiag(loc=mean, scale_diag=torch.exp(logvar))
        z1 = dst.sample((N1,))
        batch_kl1 = torch.mean(torch.sum(((z1 ** 2 - ((z1 - mean) / torch.exp(logvar)) ** 2) / 2) - logvar, axis=-1), axis=0)
        if kl1 is None:
            kl1 = batch_kl1
        else:
            kl1 = torch.cat((kl1, batch_kl1))
    kl1 = torch.mean(kl1)
    return kl1 - kl2

with torch.no_grad():
    print(mutual_information(model, iterator))

tensor(1.7999, device='cuda:0')
