<a href="https://colab.research.google.com/github/guanjiew/csc412_vae/blob/main/ControlVAE_Text_Generation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

The code in this notebook is adapted from [the original ControlVAE repository](https://github.com/shj1987/ControlVAE-ICML2020/tree/master/Language_modeling/Text_gen_PTB).

# Preparation

In [4]:
%pip install texar-pytorch

Collecting texar-pytorch
[?25l  Downloading https://files.pythonhosted.org/packages/2d/7e/20aa39ee9d19dcd1a1c0db4dad878088cfba216f45aeaa8fa89615ef46c0/texar_pytorch-0.1.2.post1-py3-none-any.whl (434kB)
[K     |▊                               | 10kB 22.3MB/s eta 0:00:01[K     |█▌                              | 20kB 29.7MB/s eta 0:00:01[K     |██▎                             | 30kB 20.9MB/s eta 0:00:01[K     |███                             | 40kB 24.1MB/s eta 0:00:01[K     |███▊                            | 51kB 23.5MB/s eta 0:00:01[K     |████▌                           | 61kB 25.8MB/s eta 0:00:01[K     |█████▎                          | 71kB 18.3MB/s eta 0:00:01[K     |██████                          | 81kB 19.3MB/s eta 0:00:01[K     |██████▉                         | 92kB 18.2MB/s eta 0:00:01[K     |███████▌                        | 102kB 18.6MB/s eta 0:00:01[K     |████████▎                       | 112kB 18.6MB/s eta 0:00:01[K     |█████████                 

In [5]:
import math, os
from typing import Any, Dict, Optional, Tuple, Union

import torch
import torch.nn as nn
from torch import Tensor
from torch.optim.lr_scheduler import ExponentialLR

import texar.torch as tx
from texar.torch.custom import MultivariateNormalDiag
from tqdm import tqdm

import pandas as pd
import numpy as np

In [None]:
# Legacy PTB loading, kept for reference only: the Yelp dataset is loaded and preprocessed in the next cell.

# data_path = "./simple-examples/data"
# train_path = os.path.join(data_path, "ptb.train.txt")
# if not os.path.exists(train_path):
#     url = 'http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz'
#     tx.data.maybe_download(url, './', extract=True)

# train_path = os.path.join(data_path, "ptb.train.txt")
# vocab_path = os.path.join(data_path, "vocab.txt")
# word_to_id = tx.data.make_vocab(
#     train_path, return_type="dict")

# with open(vocab_path, 'w') as fvocab:
#     for word in word_to_id:
#         fvocab.write("%s\n" % word)

In [8]:
# Loading and preprocessing yelp review data

%mkdir -p /content/simple-examples/data/
%cd /content

data_path = "./simple-examples/data"
train_path = os.path.join(data_path, "train.txt")
validate_path = os.path.join(data_path, "validate.txt")
test_path = os.path.join(data_path, "test.txt")
if not (os.path.exists(train_path) or os.path.exists(validate_path) or os.path.exists(test_path)):
  url = 'https://raw.githubusercontent.com/rekiksab/Yelp-Data-Challenge-2013/master/yelp_challenge/yelp_phoenix_academic_dataset/yelp_academic_dataset_review.json'
  df = pd.read_json(url, lines=True)
  text = df['text']
  train, validate, test = np.split(text.sample(frac=1, random_state=42), [int(.6*len(text)), int(.8*len(text))])
  np.savetxt(train_path, train.values, fmt='%s')
  np.savetxt(validate_path, validate.values, fmt='%s')
  np.savetxt(test_path, test.values, fmt='%s')
  train_path = os.path.join(data_path, "train.txt")
  validate_path = os.path.join(data_path, "validate.txt")
  test_path = os.path.join(data_path, "test.txt")

vocab_path = os.path.join(data_path, "vocab.txt")
word_to_id = tx.data.make_vocab(
    train_path, return_type="dict")

with open(vocab_path, 'w') as fvocab:
    for word in word_to_id:
        fvocab.write("%s\n" % word)

/content


# VAE Model

In [None]:
def kl_divergence(means: Tensor, logvars: Tensor) -> Tensor:
    """KL divergence of a diagonal Gaussian from the standard normal.

    The element-wise term ``-0.5 * (logvar - mean^2 - exp(logvar) + 1)`` is
    averaged over dimension 0 and the resulting vector is summed to a scalar.

    Args:
        means: Gaussian means.
        logvars: Log-variances, same shape as ``means``.

    Returns:
        A scalar tensor.
    """
    inner = logvars - means.pow(2) - logvars.exp() + 1.0
    per_element = -0.5 * inner
    return per_element.mean(0).sum()

In [None]:
class VAE(nn.Module):
    """Text VAE: LSTM encoder -> Gaussian latent -> LSTM or Transformer decoder.

    The encoder's final state is mapped to the mean and log-variance of a
    diagonal Gaussian; a reparameterised sample of it conditions the decoder.
    The decoder flavour is chosen by ``config_model.decoder_type``
    (``"lstm"`` or ``"transformer"``).
    """
    # Latent sample of the current decode call; read by `_embed_fn_rnn`, which
    # the LSTM decoder invokes as its token embedder.
    _latent_z: Tensor

    def __init__(self, vocab_size: int, config_model):
        """Build the embedders, encoder, decoder and latent connectors.

        Args:
            vocab_size: Vocabulary size shared by encoder and decoder.
            config_model: Configuration object providing the ``*_hparams``
                attributes, ``latent_dims`` and ``decoder_type`` read below.

        Raises:
            ValueError: If ``decoder_type`` is neither ``"lstm"`` nor
                ``"transformer"``.
        """
        super().__init__()
        # Model architecture
        self._config = config_model
        # Device the input batches are moved to in `forward`. It is decided
        # here from CUDA availability, not from where the module is later
        # moved, so callers must place the model on the same device.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.encoder_w_embedder = tx.modules.WordEmbedder(
            vocab_size=vocab_size, hparams=config_model.enc_emb_hparams)

        self.encoder = tx.modules.UnidirectionalRNNEncoder[tx.core.LSTMState](
            input_size=self.encoder_w_embedder.dim,
            hparams={
                "rnn_cell": config_model.enc_cell_hparams,
            })

        self.decoder_w_embedder = tx.modules.WordEmbedder(
            vocab_size=vocab_size, hparams=config_model.dec_emb_hparams)

        if config_model.decoder_type == "lstm":
            # Every decoder input is the word embedding concatenated with the
            # latent vector (see `_embed_fn_rnn`), hence the widened input.
            self.lstm_decoder = tx.modules.BasicRNNDecoder(
                input_size=(self.decoder_w_embedder.dim +
                            config_model.latent_dims),
                vocab_size=vocab_size,
                token_embedder=self._embed_fn_rnn,
                hparams={"rnn_cell": config_model.dec_cell_hparams})
            # Twice the hidden size: `decode` chunks the projected latent
            # into the two halves of the initial LSTM state.
            sum_state_size = self.lstm_decoder.cell.hidden_size * 2

        elif config_model.decoder_type == 'transformer':
            # position embedding
            self.decoder_p_embedder = tx.modules.SinusoidsPositionEmbedder(
                position_size=config_model.max_pos,
                hparams=config_model.dec_pos_emb_hparams)
            # decoder
            self.transformer_decoder = tx.modules.TransformerDecoder(
                # tie word embedding with output layer
                output_layer=self.decoder_w_embedder.embedding,
                token_pos_embedder=self._embed_fn_transformer,
                hparams=config_model.trans_hparams)
            sum_state_size = self._config.dec_emb_hparams["dim"]

        else:
            raise ValueError("Decoder type must be 'lstm' or 'transformer'")

        # Maps the encoder state to [mean, logvar], hence 2 * latent_dims.
        self.connector_mlp = tx.modules.MLPTransformConnector(
            config_model.latent_dims * 2,
            linear_layer_dim=self.encoder.cell.hidden_size * 2)

        # Projects a latent sample to the decoder's initial state / memory.
        self.mlp_linear_layer = nn.Linear(
            config_model.latent_dims, sum_state_size)

    def forward(self,  # type: ignore
                data_batch: tx.data.Batch,
                kl_weight: float, start_tokens: torch.LongTensor,
                end_token: int) -> Dict[str, Tensor]:
        """Encode a batch, sample the latent, decode, and compute the losses.

        Args:
            data_batch: Batch providing ``"text_ids"`` and ``"length"``.
            kl_weight: Weight applied to the KL term in the returned ``nll``.
            start_tokens: Passed to the LSTM decoder's helper factory.
            end_token: Passed to the LSTM decoder's helper factory.

        Returns:
            Dict with ``"nll"`` (``rc_loss + kl_weight * kl_loss``),
            ``"kl_loss"``, ``"rc_loss"`` and ``"lengths"`` (the input lengths
            minus one, i.e. the number of predicted positions per sequence).
        """
        # encoder -> connector -> decoder
        text_ids = data_batch["text_ids"].to(self.device)
        lengths = data_batch["length"].to(self.device)
        input_embed = self.encoder_w_embedder(text_ids)
        _, encoder_states = self.encoder(
            input_embed,
            sequence_length=lengths)

        mean_logvar = self.connector_mlp(encoder_states)
        mean, logvar = torch.chunk(mean_logvar, 2, 1)
        kl_loss = kl_divergence(mean, logvar)
        dst = MultivariateNormalDiag(
            loc=mean, scale_diag=torch.exp(0.5 * logvar))

        # rsample() keeps the draw differentiable w.r.t. mean and logvar.
        latent_z = dst.rsample()
        helper = None
        if self._config.decoder_type == "lstm":
            helper = self.lstm_decoder.create_helper(
                decoding_strategy="train_greedy",
                start_tokens=start_tokens,
                end_token=end_token)

        # decode
        # The decoder reads tokens [0, T-1) and is scored against tokens
        # [1, T), so each sequence contributes one position fewer.
        seq_lengths = lengths - 1
        outputs = self.decode(
            helper=helper, latent_z=latent_z,
            text_ids=text_ids[:, :-1], seq_lengths=seq_lengths)

        logits = outputs.logits

        # Losses & train ops
        rc_loss = tx.losses.sequence_sparse_softmax_cross_entropy(
            labels=text_ids[:, 1:], logits=logits,
            sequence_length=seq_lengths)

        nll = rc_loss + kl_weight * kl_loss

        ret = {
            "nll": nll,
            "kl_loss": kl_loss,
            "rc_loss": rc_loss,
            "lengths": seq_lengths,
        }

        return ret

    def _embed_fn_rnn(self, tokens: torch.LongTensor) -> Tensor:
        r"""Embeds tokens and appends the current latent vector to each one.

        Uses the latent stored in ``self._latent_z`` by :meth:`decode`.
        """
        embedding = self.decoder_w_embedder(tokens)
        latent_z = self._latent_z
        if len(embedding.size()) > 2:
            # Whole-sequence input: tile the latent along a new leading axis
            # of length tokens.size(0).
            # NOTE(review): presumably `tokens` is time-major here, so the
            # tiled latent lines up with the embedding -- confirm against the
            # decoder helper.
            latent_z = latent_z.unsqueeze(0).repeat(tokens.size(0), 1, 1)
        return torch.cat([embedding, latent_z], dim=-1)

    def _embed_fn_transformer(self,
                              tokens: torch.LongTensor,
                              positions: torch.LongTensor) -> Tensor:
        r"""Generates word embeddings combined with positional embeddings.

        Word embeddings are scaled by ``sqrt(config.hidden_size)`` before the
        positional embeddings are added.
        """
        output_p_embed = self.decoder_p_embedder(positions)
        output_w_embed = self.decoder_w_embedder(tokens)
        output_w_embed = output_w_embed * self._config.hidden_size ** 0.5
        output_embed = output_w_embed + output_p_embed
        return output_embed

    @property
    def decoder(self) -> tx.modules.DecoderBase:
        """The decoder module selected by ``config.decoder_type``."""
        if self._config.decoder_type == "lstm":
            return self.lstm_decoder
        else:
            return self.transformer_decoder

    def decode(self,
               helper: Optional[tx.modules.Helper],
               latent_z: Tensor,
               text_ids: Optional[torch.LongTensor] = None,
               seq_lengths: Optional[Tensor] = None,
               max_decoding_length: Optional[int] = None) \
            -> Union[tx.modules.BasicRNNDecoderOutput,
                     tx.modules.TransformerDecoderOutput]:
        """Run the configured decoder conditioned on ``latent_z``.

        Args:
            helper: Decoding helper forwarded to the decoder.
            latent_z: Latent vectors, one row per sequence.
            text_ids: Decoder input token ids, forwarded as ``inputs``.
            seq_lengths: Input lengths; forwarded to the LSTM decoder only.
            max_decoding_length: Forwarded to the decoder.

        Returns:
            The decoder output object. For the LSTM decoder only the first
            element of its returned triple is kept.
        """
        # Stash the latent so `_embed_fn_rnn` can read it during decoding.
        self._latent_z = latent_z
        fc_output = self.mlp_linear_layer(latent_z)

        if self._config.decoder_type == "lstm":
            # The two halves of the projection form the initial LSTM state.
            lstm_states = torch.chunk(fc_output, 2, dim=1)
            outputs, _, _ = self.lstm_decoder(
                initial_state=lstm_states,
                inputs=text_ids,
                helper=helper,
                sequence_length=seq_lengths,
                max_decoding_length=max_decoding_length)
        else:
            # The projected latent acts as a length-1 "memory" sequence.
            transformer_states = fc_output.unsqueeze(1)
            outputs = self.transformer_decoder(
                inputs=text_ids,
                memory=transformer_states,
                # Create the lengths on the memory's device: a bare
                # torch.ones(...) lives on the CPU and would not match a
                # model that runs on CUDA.
                memory_sequence_length=torch.ones(
                    transformer_states.size(0),
                    device=transformer_states.device),
                helper=helper,
                max_decoding_length=max_decoding_length)
        return outputs
    

# PID Control

In [None]:
class PIDControl():
    """PI controller that turns the KL error into a KL weight in [0, 1].

    State kept between calls:
      * ``I_k1`` -- integral term of the previous step,
      * ``W_k1`` -- unclamped controller output of the previous step,
      * ``e_k1`` -- error of the previous step.
    """
    def __init__(self):
        """Start with zero integral, zero output and zero error."""
        self.I_k1 = 0.0
        self.W_k1 = 0.0
        self.e_k1 = 0.0

    def _Kp_fun(self, Err, scale=1):
        """Return ``1 / (1 + scale * exp(Err))``.

        For errors so large that ``exp`` overflows, the limit of the
        expression is returned instead of raising ``OverflowError``.
        """
        try:
            return 1.0/(1.0 + float(scale)*math.exp(Err))
        except OverflowError:
            # exp(Err) is effectively infinite: the fraction tends to 0,
            # unless scale == 0, in which case the denominator is exactly 1.
            return 1.0 if float(scale) == 0.0 else 0.0

    def pid(self, exp_KL, kl_loss, Kp=0.001, Ki=-0.001, Kd=0.01):
        """
        position PID algorithm
        Input: KL_loss
        return: weight for KL loss, beta

        Args:
            exp_KL: Desired (set-point) KL value.
            kl_loss: KL value measured at the current step.
            Kp: Gain of the proportional term.
            Ki: Gain of the integral term.
            Kd: Accepted for interface compatibility; not used.

        Returns:
            The controller output clamped to ``[0, 1]``.
        """
        error_k = exp_KL - kl_loss
        ## compute U as the control factor
        Pk = Kp * self._Kp_fun(error_k)
        integral_step = Ki * error_k

        ## anti-windup for the integrator
        # The previous condition (`W_k1 < 0 and W_k1 > 1`) could never be
        # true. Freeze the integral only while the previous unclamped output
        # was saturated AND this step would push it further out of range;
        # steps pointing back into [0, 1] are still integrated so the
        # controller can leave saturation.
        winding_up = self.W_k1 > 1 and integral_step > 0
        winding_down = self.W_k1 < 0 and integral_step < 0
        if winding_up or winding_down:
            Ik = self.I_k1
        else:
            Ik = self.I_k1 + integral_step

        Wk = Pk + Ik
        # Store the UNCLAMPED output: the saturation test above needs it.
        self.W_k1 = Wk
        self.I_k1 = Ik
        self.e_k1 = error_k

        ## min and max value
        if Wk > 1:
            Wk = 1.0
        if Wk < 0:
            Wk = 0.0

        return Wk
    

# Hyperparameters

In [None]:
class Config():
  """Hyperparameters for the model, the data pipelines and the optimizer.

  All values are class attributes, and the nested ``*_hparams`` dicts are
  built from the scalar attributes defined above them.
  """
  # Label for the corpus in use: the data files below are the Yelp review
  # splits written by the data-loading cell.
  dataset = "yelp"
  num_epochs = 100
  hidden_size = 256
  dec_dropout_in = 0.5
  dec_dropout_out = 0.5
  enc_dropout_in = 0.
  enc_dropout_out = 0.
  word_keep_prob = 0.5
  batch_size = 32
  embed_dim = 256

  latent_dims = 32

  # Learning-rate decay schedule consumed by the training loop.
  lr_decay_hparams = {
      "init_lr": 0.001,
      "threshold": 2,
      "decay_factor": 0.5,
      "max_decay": 5
  }


  decoder_type = 'lstm'

  enc_cell_hparams = {
      "type": "LSTMCell",
      "kwargs": {
          "num_units": hidden_size,
          "bias": 0.
      },
      "dropout": {"output_keep_prob": 1. - enc_dropout_out},
      "num_layers": 1
  }

  dec_cell_hparams = {
      "type": "LSTMCell",
      "kwargs": {
          "num_units": hidden_size,
          "bias": 0.,
      },
      "dropout": {"output_keep_prob": 1. - dec_dropout_out},
      "num_layers": 1,
  }

  enc_emb_hparams = {
      'name': 'lookup_table',
      "dim": embed_dim,
      "dropout_rate": enc_dropout_in,
      'initializer': {
          'type': 'normal_',
          'kwargs': {
              'mean': 0.0,
              'std': embed_dim**-0.5,
          },
      }
  }

  dec_emb_hparams = {
      'name': 'lookup_table',
      "dim": embed_dim,
      "dropout_rate": dec_dropout_in,
      'initializer': {
          'type': 'normal_',
          'kwargs': {
              'mean': 0.0,
              'std': embed_dim**-0.5,
          },
      }
  }

  # KL annealing
  kl_anneal_hparams = {
      "warm_up": 10,
      "start": 0.1
  }

  # The three dataset files must match the names written by the data-loading
  # cell (train.txt / validate.txt / test.txt); the vocabulary is the one
  # built from train.txt.
  train_data_hparams = {
      "num_epochs": 1,
      "batch_size": batch_size,
      "seed": 123,
      "dataset": {
          "files": './simple-examples/data/train.txt',
          "vocab_file": './simple-examples/data/vocab.txt'
      }
  }

  val_data_hparams = {
      "num_epochs": 1,
      "batch_size": batch_size,
      "seed": 123,
      "dataset": {
          "files": './simple-examples/data/validate.txt',
          "vocab_file": './simple-examples/data/vocab.txt'
      }
  }

  test_data_hparams = {
      "num_epochs": 1,
      "batch_size": batch_size,
      "dataset": {
          "files": './simple-examples/data/test.txt',
          "vocab_file": './simple-examples/data/vocab.txt'
      }
  }

  opt_hparams = {
      'optimizer': {
          'type': 'Adam',
          'kwargs': {
              'lr': 0.001
          }
      },
      'gradient_clip': {
          "type": "clip_grad_norm_",
          "kwargs": {
              "max_norm": 5,
              "norm_type": 2
          }
      }
  }

# Training the model

In [None]:
# ---- Configuration and compute device --------------------------------------
config = Config()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ---- Data: all three datasets are created on the CPU ------------------------
train_data = tx.data.MonoTextData(config.train_data_hparams, device=torch.device("cpu"))
val_data = tx.data.MonoTextData(config.val_data_hparams, device=torch.device("cpu"))
test_data = tx.data.MonoTextData(config.test_data_hparams, device=torch.device("cpu"))

iterator = tx.data.DataIterator({"train": train_data, "valid": val_data, "test": test_data})

# ---- Mutable training state shared with the epoch runner --------------------
opt_vars = dict(
    learning_rate=config.lr_decay_hparams["init_lr"],
    best_valid_nll=1e100,
    steps_not_improved=0,
    kl_weight=config.kl_anneal_hparams["start"],
)

# ---- Learning-rate decay bookkeeping ---------------------------------------
decay_cnt = 0
max_decay = config.lr_decay_hparams["max_decay"]
decay_factor = config.lr_decay_hparams["decay_factor"]
decay_ts = config.lr_decay_hparams["threshold"]

save_path = './checkpoint.ckpt'

# Per-step increment of a linear KL warm-up over `warm_up` epochs.
anneal_r = 1.0 / (config.kl_anneal_hparams["warm_up"] * (len(train_data) / config.batch_size))

# ---- Model -----------------------------------------------------------------
vocab = train_data.vocab
model = VAE(vocab.size, config)
model.to(device)

# One BOS id per sequence of a full batch, plus the EOS id, for the decoder helper.
start_tokens = torch.full(
    (config.batch_size,),
    vocab.bos_token_id,
    dtype=torch.long,
    device=device)
end_token = vocab.eos_token_id

# ---- Optimisation ----------------------------------------------------------
optimizer = tx.core.get_optimizer(
    params=model.parameters(),
    hparams=config.opt_hparams)
scheduler = ExponentialLR(optimizer, decay_factor)

# Training stops after this many optimisation steps (capped at 80000).
max_iter = min(config.num_epochs*len(train_data)/config.batch_size, 80000)
print('max steps:', max_iter)

global_steps = {'step': 0}

# ---- PI controller for the KL weight ---------------------------------------
pid = PIDControl()
opt_vars["kl_weight"] = 0.0  # the controller starts from 0, overriding the annealing start value
Kp = 0.01
Ki = -0.0001
exp_kl = 0  # KL set point handed to the controller

In [None]:
def _run_epoch(epoch: int, mode: str, display: int = 10) -> Tuple[Tensor, float]:
    """Run one pass over the dataset selected by ``mode``.

    Relies on notebook-level globals: ``iterator``, ``model``, ``opt_vars``,
    ``global_steps``, ``max_iter``, ``start_tokens``, ``end_token``, ``pid``,
    ``exp_kl``, ``Kp``, ``Ki``, ``optimizer`` and (in train mode) ``pbar``.

    In ``'train'`` mode every batch performs one optimisation step and
    refreshes the KL weight through the PI controller; the pass stops early
    once ``max_iter`` steps have been taken. In any other mode the model is
    evaluated with a fixed KL weight of 1.0 and gradient tracking disabled.

    Args:
        epoch: Epoch index, used only in the printed summary.
        mode: Dataset name understood by ``iterator`` ('train', 'valid' or
            'test').
        display: In train mode, the running averages shown on the progress
            bar are refreshed whenever ``step % display == 1``.

    Returns:
        ``(nll, ppl)``: the example-weighted average of the ``nll`` loss and
        the per-word perplexity ``exp(total nll / number of words)``.
    """
    iterator.switch_to_dataset(mode)

    is_train = mode == 'train'
    if is_train:
        model.train()
        kl_weight = opt_vars["kl_weight"]
    else:
        model.eval()
        kl_weight = 1.0

    num_words = 0
    nll_total = 0.

    avg_rec = tx.utils.AverageRecorder()
    # Evaluation needs no autograd graph; training keeps gradients enabled.
    with torch.set_grad_enabled(is_train):
        for batch in iterator:
            # The step budget limits TRAINING only. Applying it to valid/test
            # as well would skip the final evaluation entirely once the
            # budget is exhausted.
            if is_train and global_steps['step'] >= max_iter:
                break
            ## run model to get loss function
            ret = model(batch, kl_weight, start_tokens, end_token)
            if is_train:
                pbar.update(1)
                global_steps['step'] += 1
                # The controller output becomes the KL weight of the NEXT batch.
                kl_weight = pid.pid(exp_kl, ret['kl_loss'].item(), Kp, Ki)
                opt_vars["kl_weight"] = kl_weight

                ## total loss
                ret["nll"].backward()
                optimizer.step()
                optimizer.zero_grad()

            batch_size = len(ret["lengths"])
            num_words += torch.sum(ret["lengths"]).item()
            nll_total += ret["nll"].item() * batch_size
            avg_rec.add(
                [ret["nll"].item(),
                 ret["kl_loss"].item(),
                 ret["rc_loss"].item()],
                batch_size)

            if is_train and global_steps['step'] % display == 1:
                # Surface running averages on the progress bar instead of
                # printing a line per report.
                pbar.set_postfix(
                    nll=avg_rec.avg(0),
                    KL=avg_rec.avg(1),
                    rc=avg_rec.avg(2),
                    kl_weight=opt_vars["kl_weight"],
                    refresh=False)

    nll = avg_rec.avg(0)
    KL = avg_rec.avg(1)
    rc = avg_rec.avg(2)
    if num_words > 0:
        log_ppl = nll_total / num_words
        ppl = math.exp(log_ppl)
    else:
        # No batch was processed: report sentinel values.
        log_ppl = 100
        ppl = math.exp(log_ppl)
        nll = 1000
        # `exp_kl` is the notebook-level KL set point (there is no `args`
        # object in this notebook).
        KL = exp_kl

    print(f"\n{mode}: epoch {epoch}, nll {nll:.4f}, KL {KL:.4f}, "
          f"rc {rc:.4f}, log_ppl {log_ppl:.4f}, ppl {ppl:.4f}")
    return nll, ppl  # type: ignore

In [None]:
# Count all model parameters
total_parameters = sum(param.numel() for param in model.parameters())
print(f"{total_parameters} total parameters")

# Test metrics recorded at the epoch with the best validation nll.
best_nll = best_ppl = 0.

## start running model
pbar = tqdm(total = int(max_iter))
for epoch in range(config.num_epochs):
    _, _ = _run_epoch(epoch, 'train', display=200)
    val_nll, _ = _run_epoch(epoch, 'valid')
    test_nll, test_ppl = _run_epoch(epoch, 'test')

    if val_nll < opt_vars['best_valid_nll']:
        # New best validation score: remember the matching test metrics and
        # checkpoint model, optimizer and scheduler together.
        opt_vars['best_valid_nll'] = val_nll
        opt_vars['steps_not_improved'] = 0
        best_nll = test_nll
        best_ppl = test_ppl

        states = {
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "scheduler": scheduler.state_dict()
        }
        torch.save(states, save_path)
    else:
        opt_vars['steps_not_improved'] += 1
        if opt_vars['steps_not_improved'] == decay_ts:
            # No improvement for `decay_ts` epochs: roll back to the best
            # checkpoint and continue with a decayed learning rate.
            old_lr = opt_vars['learning_rate']
            opt_vars['learning_rate'] *= decay_factor
            opt_vars['steps_not_improved'] = 0
            new_lr = opt_vars['learning_rate']
            ckpt = torch.load(save_path)
            model.load_state_dict(ckpt['model'])
            optimizer.load_state_dict(ckpt['optimizer'])
            scheduler.load_state_dict(ckpt['scheduler'])
            scheduler.step()
            # The restored optimizer carries the learning rate from when the
            # checkpoint was written. If that checkpoint predates an earlier
            # decay, a single scheduler step leaves the optimizer above the
            # rate tracked in opt_vars, so set the tracked rate explicitly.
            for param_group in optimizer.param_groups:
                param_group['lr'] = new_lr
            print(f"-----\nchange lr, old lr: {old_lr}, "
                  f"new lr: {new_lr}\n-----")

            decay_cnt += 1
            if decay_cnt == max_decay:
                break
    if global_steps['step'] >= max_iter:
        break

# Release the progress bar so later cell output is not interleaved with it.
pbar.close()

print(f"\nbest testing nll: {best_nll:.4f},"
      f"best testing ppl {best_ppl:.4f}\n")

# Generate text

In [None]:
# Sample one batch of sentences from the trained decoder.
model.eval()

batch_size = train_data.batch_size

# Latent codes are drawn uniformly from [-1, 1) per dimension. To sample from
# a standard-normal prior instead, use
# torch.randn(batch_size, config.latent_dims).to(device).
latent_z = torch.empty(
    batch_size, config.latent_dims,
    dtype=torch.float32).uniform_(-1, 1).to(device)

# Generation needs no gradients; skipping the autograd graph saves memory
# across the decoding steps.
with torch.no_grad():
    helper = model.decoder.create_helper(
        decoding_strategy='infer_sample',
        start_tokens=start_tokens,
        end_token=end_token)
    outputs = model.decode(
        helper=helper,
        latent_z=latent_z,
        max_decoding_length=100)

if config.decoder_type == "transformer":
    outputs = outputs[0]

sample_tokens = vocab.map_ids_to_tokens_py(outputs.sample_id.cpu())

for sent in sample_tokens:
    sent = tx.utils.compat_as_text(list(sent))
    # Print up to and including the first end-of-sequence token, if present.
    end_id = len(sent)
    if vocab.eos_token in sent:
        end_id = sent.index(vocab.eos_token)
    print(' '.join(sent[:end_id + 1]) + '\n')