## Variational Recurrent Neural Network (VRNN)

Implementation based on Chung et al., *A Recurrent Latent Variable Model for Sequential Data* [arXiv:1506.02216v6].

## Network design



There are three types of layers: input (x), hidden (h), and latent (z). Comparing VRNN side by side with a plain RNN shows how it works in the generation phase.


- RNN: $h_0 + x_0 \to h_1 + x_1 \to h_2 + x_2 \to \dots$
- VRNN: from $h_0$, $\left\{
\begin{array}{l}
      h_0 \to z_1 \\
      z_1 + h_0 \to x_1 \\
      z_1 + x_1 + h_0 \to h_1
\end{array}
\right.$
then from $h_1$, $\left\{
\begin{array}{l}
      h_1 \to z_2 \\
      z_2 + h_1 \to x_2 \\
      z_2 + x_2 + h_1 \to h_2
\end{array}
\right.$
and so on.


This is easier to follow in the code below: the same loop generates new text once the network is trained. Here x is the desired output, h is the deterministic hidden state, and z is the latent (stochastic hidden) state. Both h and z change over time.
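
The generation recurrence can be sketched as a plain loop. This is a toy illustration with hypothetical sizes (`H`, `Z`, `V`) and untrained layers, not the notebook's model: it omits the `phi_x`/`phi_z` feature extractors apart from a character embedding, and only mirrors the order $h \to z \to x \to h$.

```python
import torch
from torch import nn

torch.manual_seed(0)

H, Z, V = 8, 4, 16   # hypothetical toy sizes: hidden, latent, vocabulary

prior = nn.Linear(H, 2 * Z)     # h_{t-1} -> (mu, log_sigma) of z_t
decoder = nn.Linear(H + Z, V)   # h_{t-1} + z_t -> logits of x_t
embed = nn.Embedding(V, Z)      # x_t -> feature vector
rnn = nn.GRUCell(2 * Z, H)      # x_t + z_t (+ h_{t-1}) -> h_t

@torch.no_grad()
def toy_generate(steps=5):
    h = torch.zeros(1, H)       # h_0
    tokens = []
    for _ in range(steps):
        mu, log_sigma = prior(h).chunk(2, dim=1)
        z = mu + log_sigma.exp() * torch.randn_like(mu)          # z_t ~ p(z_t | h_{t-1})
        logits = decoder(torch.cat([h, z], dim=1))
        x = torch.distributions.Categorical(logits=logits).sample()  # x_t, shape (1,)
        h = rnn(torch.cat([embed(x), z], dim=1), h)              # h_t
        tokens.append(x.item())
    return tokens

print(toy_generate())
```

Unlike the RNN line above, nothing observed is fed in: each $x_t$ is sampled and then fed back, and the randomness enters through $z_t$ as well as through the output sampling.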

## Training



The VRNN above contains three components: a latent-layer generator (the prior) $h_0 \to z_1$, a decoder that produces $x_1$, and a recurrent net that produces $h_1$ for the next cycle.


The training objective is to make sure the generated $x_1$ is realistic. To do that, an encoder is added that maps $x_1 + h_0 \to z_1$; the decoder should then map $z_1 + h_0 \to x_1$ correctly. This calls for a cross-entropy loss on the "tiny Shakespeare" characters, or an MSE loss for image reconstruction.
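
A minimal sketch of the encoder-side sampling and the two reconstruction losses, assuming (as the code below does) that the encoder output is laid out as `[mu | log_sigma]`. Random tensors stand in for real encoder and decoder outputs; the sizes and the 784-pixel "image" are hypothetical.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, latent, vocab = 3, 4, 128

# Stand-in encoder output: first half = mean, second half = log standard deviation
enc_out = torch.randn(batch, 2 * latent, requires_grad=True)
mu, log_sigma = enc_out.chunk(2, dim=1)

# Reparameterization: z = mu + sigma * eps, so gradients flow back into mu and log_sigma
eps = torch.randn_like(mu)
z = mu + log_sigma.exp() * eps

# Character data: cross-entropy between decoder logits and the true character codes
logits = torch.randn(batch, vocab, requires_grad=True)   # stand-in for decoder output
targets = torch.tensor([70, 105, 114])                   # ASCII codes of "F", "i", "r"
ce = F.cross_entropy(logits, targets)

# Real-valued data (e.g. images): mean squared error instead
x_true, x_recon = torch.rand(batch, 784), torch.rand(batch, 784)
mse = F.mse_loss(x_recon, x_true)

print(z.shape, ce.item(), mse.item())
```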


Another loose end is the prior $h_0 \to z_1$. Statistically, the encoder's $x_1 + h_0 \to z_1$ should match the prior's $h_0 \to z_1$ when $x_1$ is sampled at random from the data. This constraint is formulated as a KL divergence between the two distributions.



>### KL Divergence of Multivariate Normal Distribution
>![](https://wikimedia.org/api/rest_v1/media/math/render/svg/8dad333d8c5fc46358036ced5ab8e5d22bae708c)
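
For the diagonal Gaussians used in this model (the encoder and the prior each output one mean and one log standard deviation per latent dimension), the general formula reduces to a per-dimension sum. Here $q = \mathcal{N}(\mu_q, \mathrm{diag}\,\sigma_q^2)$ is the encoder's distribution, $p = \mathcal{N}(\mu_p, \mathrm{diag}\,\sigma_p^2)$ is the prior's, and $d$ is the latent size:

$$
D_{KL}(q \,\|\, p) = \frac{1}{2} \sum_{j=1}^{d} \left[ \frac{\sigma_{q,j}^2}{\sigma_{p,j}^2} + \frac{(\mu_{q,j} - \mu_{p,j})^2}{\sigma_{p,j}^2} - 1 + \ln \frac{\sigma_{p,j}^2}{\sigma_{q,j}^2} \right]
$$

Writing $\sigma_{q,j}^2 / \sigma_{p,j}^2 = \exp\left(2(\log\sigma_{q,j} - \log\sigma_{p,j})\right)$ and $\ln(\sigma_{p,j}^2/\sigma_{q,j}^2) = -2(\log\sigma_{q,j} - \log\sigma_{p,j})$ gives the expression computed for `loss2` in `calculate_loss`, before it is averaged over the batch.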


Now putting everything together for one training cycle.

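$\left\{
\begin{array}{l}
      h_0 \to z_{1,\text{prior}} \\
      x_1 + h_0 \to z_{1,\text{infer}} \\
      z_1 \sim \mathcal{N}(z_{1,\text{infer}}) \\
      z_1 + h_0 \to x_{1,\text{reconstruct}} \\
      z_1 + x_1 + h_0 \to h_1
\end{array}
\right.$
$\Rightarrow$
$
\left\{
\begin{array}{l}
      loss\_latent = D_{KL}\left(z_{1,\text{infer}} \,\|\, z_{1,\text{prior}}\right) \\
      loss\_reconstruct = \ell\left(x_1, x_{1,\text{reconstruct}}\right)
\end{array}
\right.
$

where $z_{1,\text{infer}}$ and $z_{1,\text{prior}}$ denote the Gaussians parameterized by the encoder and prior outputs, and $\ell$ is the cross-entropy for characters (or MSE for images).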
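
The latent term can be sanity-checked numerically: the closed form used in `calculate_loss` should agree with `torch.distributions.kl_divergence` for independent normals. Random tensors stand in for the `z_infer` and `z_prior` outputs here; the batch size is arbitrary.

```python
import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)
batch, latent = 5, 64

# Stand-in encoder / prior outputs laid out as [mu | log_sigma], like z_infer and z_prior
z_infer = torch.randn(batch, 2 * latent)
z_prior = torch.randn(batch, 2 * latent)
mu_infer, log_sigma_infer = z_infer[:, :latent], z_infer[:, latent:]
mu_prior, log_sigma_prior = z_prior[:, :latent], z_prior[:, latent:]

# Closed form used in calculate_loss: summed over latent dims, averaged over the batch
kl_manual = (2 * (log_sigma_infer - log_sigma_prior)).exp() \
    + ((mu_infer - mu_prior) / log_sigma_prior.exp()) ** 2 \
    - 2 * (log_sigma_infer - log_sigma_prior) - 1
kl_manual = 0.5 * kl_manual.sum(dim=1).mean()

# Reference: elementwise KL between univariate normals, reduced the same way
q = Normal(mu_infer, log_sigma_infer.exp())
p = Normal(mu_prior, log_sigma_prior.exp())
kl_ref = kl_divergence(q, p).sum(dim=1).mean()

print(kl_manual.item(), kl_ref.item())
```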


In [2]:


```python
import torch
from torch import nn, optim

class VRNNCell(nn.Module):
    """Single-layer character-level VRNN over the 128 ASCII codes."""
    def __init__(self):
        super().__init__()
        self.phi_x = nn.Sequential(nn.Embedding(128, 64), nn.Linear(64, 64), nn.ELU())  # x features
        self.encoder = nn.Linear(128, 64 * 2)  # q(z | x, h): [mu | log_sigma]
        self.phi_z = nn.Sequential(nn.Linear(64, 64), nn.ELU())  # z features
        self.decoder = nn.Linear(128, 128)     # p(x | z, h): logits over characters
        self.prior = nn.Linear(64, 64 * 2)     # p(z | h): [mu | log_sigma]
        self.rnn = nn.GRUCell(128, 64)

    def forward(self, x, hidden):
        x_feat = self.phi_x(x)
        # 1. h => z (prior)
        z_prior = self.prior(hidden)
        # 2. x + h => z (encoder)
        z_infer = self.encoder(torch.cat([x_feat, hidden], dim=1))
        # sampling with the reparameterization trick: z = mu + sigma * eps
        mu, log_sigma = z_infer[:, :64], z_infer[:, 64:]
        z = self.phi_z(mu + log_sigma.exp() * torch.randn_like(mu))
        # 3. h + z => x
        x_out = self.decoder(torch.cat([hidden, z], dim=1))
        # 4. x + z + h => next h
        hidden_next = self.rnn(torch.cat([x_feat, z], dim=1), hidden)
        return x_out, hidden_next, z_prior, z_infer

    def calculate_loss(self, x, hidden):
        x_out, hidden_next, z_prior, z_infer = self.forward(x, hidden)
        # 1. reconstruction loss: cross-entropy over the 128 character classes
        loss1 = nn.functional.cross_entropy(x_out, x)
        # 2. KL divergence between the diagonal Gaussians q(z | x, h) and p(z | h)
        mu_infer, log_sigma_infer = z_infer[:, :64], z_infer[:, 64:]
        mu_prior, log_sigma_prior = z_prior[:, :64], z_prior[:, 64:]
        loss2 = (2 * (log_sigma_infer - log_sigma_prior)).exp() \
            + ((mu_infer - mu_prior) / log_sigma_prior.exp()) ** 2 \
            - 2 * (log_sigma_infer - log_sigma_prior) - 1
        loss2 = 0.5 * loss2.sum(dim=1).mean()
        return loss1, loss2, hidden_next

    @torch.no_grad()
    def generate(self, hidden=None, temperature=None):
        if hidden is None:
            hidden = torch.zeros(1, 64)
        if temperature is None:
            temperature = 0.8
        # 1. h => z, sampled from the prior
        z_prior = self.prior(hidden)
        mu, log_sigma = z_prior[:, :64], z_prior[:, 64:]
        z = self.phi_z(mu + log_sigma.exp() * torch.randn_like(mu))
        # 2. h + z => x, sampled from the temperature-scaled softmax
        x_out = self.decoder(torch.cat([hidden, z], dim=1))
        x_sample = torch.softmax(x_out / temperature, dim=1).multinomial(1).squeeze(1)  # shape (1,)
        # 3. x + z + h => next h
        hidden_next = self.rnn(torch.cat([self.phi_x(x_sample), z], dim=1), hidden)
        return x_sample, hidden_next

    def generate_text(self, hidden=None, temperature=None, n=100):
        res = []
        for _ in range(n):
            x_sample, hidden = self.generate(hidden, temperature)
            res.append(chr(x_sample.item()))
        return "".join(res)

# Test
net = VRNNCell()
x = torch.tensor([12, 13, 14])
hidden = torch.rand(3, 64)
output, hidden_next, z_prior, z_infer = net(x, hidden)
loss1, loss2, _ = net.calculate_loss(x, hidden)
loss1, loss2
net.generate_text()
```



'w,\x0c+4\x08\x11^IPc\x1b\x048\x13,qIRG\x0bdo\x03~;BR681F}L\x00\x18{#xc3<E\x02ud8<\x07~\x08*\n}\x1b\x1eSN\x191$s\x009Lp{41QvE\x04*cl~67\x13\x19N{y{HO{d]}\rIz"Y\x0c\x1c")'

## Download the tiny Shakespeare text

In [2]:
```python
from urllib import request

url = "https://raw.githubusercontent.com/jcjohnson/torch-rnn/master/data/tiny-shakespeare.txt"
text = request.urlopen(url).read().decode()

print('-----SAMPLE----\n')
print(text[:100])
```

-----SAMPLE----

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You


### A convenient function to sample text batches

In [3]:
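```python
import numpy as np
import torch

def batch_generator(seq_size=300, batch_size=64):
    """Yield lists of seq_size LongTensors of shape (batch_size,); step t holds character t of each random window."""
    cap = len(text) - seq_size + 1  # exclusive upper bound on window starts
    while True:
        idx = np.random.randint(0, cap, batch_size)
        res = []
        for _ in range(seq_size):
            batch = torch.LongTensor([ord(text[i]) for i in idx])
            res.append(batch)
            idx += 1
        yield res

g = batch_generator()
batch = next(g)
```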
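
A variant that takes the corpus as an argument is easier to test in isolation. `batch_generator_from` is a hypothetical name, the indexing is vectorized with NumPy instead of looping over characters, and it assumes an ASCII corpus (true for tiny Shakespeare) so every code fits the 128-entry embedding.

```python
import numpy as np
import torch

def batch_generator_from(text, seq_size=300, batch_size=64, seed=None):
    """Hypothetical variant of batch_generator with the corpus passed explicitly.

    Yields a list of seq_size LongTensors of shape (batch_size,); element t holds the
    character codes at offset t of batch_size randomly chosen windows.
    """
    rng = np.random.default_rng(seed)
    codes = np.frombuffer(text.encode("ascii"), dtype=np.uint8).astype(np.int64)
    cap = len(codes) - seq_size + 1  # exclusive upper bound on window starts
    while True:
        starts = rng.integers(0, cap, batch_size)
        # (seq_size, batch_size) matrix of character codes
        window = codes[starts[None, :] + np.arange(seq_size)[:, None]]
        yield [torch.from_numpy(row) for row in window]

toy = "First Citizen:\nBefore we proceed any further, hear me speak.\n"
toy_batch = next(batch_generator_from(toy, seq_size=10, batch_size=4, seed=0))
print(len(toy_batch), toy_batch[0].shape)
```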


## Model Training


In [4]:
```python
import os

net = VRNNCell()
max_epoch = 2000  # number of optimizer steps; each step draws one fresh batch
optimizer = optim.Adam(net.parameters(), lr=0.01)
g = batch_generator()
model_path = "save/model.pt"
os.makedirs(os.path.dirname(model_path), exist_ok=True)

hidden = torch.zeros(64, 64)  # batch_size x hidden_size
for epoch in range(max_epoch):
    batch = next(g)
    loss_seq = 0
    loss1_seq, loss2_seq = 0, 0
    optimizer.zero_grad()
    for x in batch:
        loss1, loss2, hidden = net.calculate_loss(x, hidden)
        loss1_seq += loss1.item()
        loss2_seq += loss2.item()
        loss_seq = loss_seq + loss1 + loss2
    loss_seq.backward()
    optimizer.step()
    hidden = hidden.detach()  # cut the graph at the batch boundary
    if epoch % 200 == 0:
        print('>> epoch {}, loss {:12.4f}, decoder loss {:12.4f}, latent loss {:12.4f}'.format(
            epoch, loss_seq.item(), loss1_seq, loss2_seq))
        print(net.generate_text())
        print()
torch.save(net, model_path)
```



>> epoch 0, loss    4114.2339, decoder loss    1460.9972, latent loss    2653.2354
@{k>VP^*V T~=|!pQQQrp[F{$~rN^xI.=c/fU/1m Vi"2nD8oYg\4(0

>> epoch 200, loss     634.5780, decoder loss     627.3054, latent loss       7.2730
thald thy a tre.

BOMIPIO:
Thame diour as bear? I howaw the theath!

LONVENIO:
I' as of in mor elipl

>> epoch 400, loss     543.5396, decoder loss     540.6561, latent loss       2.8834
ing Some!
'To dear stoul I goes to past
Seepese the sore of foul with my doust.
Serveren that preato

>> epoch 600, loss     515.4258, decoder loss     512.3687, latent loss       3.0574
BOLESSER:
Which you are and thy hast at a grovesed and
To his grave for my franted have ance time
Th

>> epoch 800, loss     500.3734, decoder loss     497.9365, latent loss       2.4368
tent and then report,
When you sale none be of accuce
The dukes the day of Withter us with look.

FR

>> epoch 1000, loss     487.2025, decoder loss     483.1525, latent loss       4.0500
re dis m
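
The training cells carry `hidden` from one batch into the next and detach it after each optimizer step, so each backward pass unrolls only the `seq_size` steps of the current batch (truncated backpropagation through time). The toy GRU below, with arbitrary sizes, is a small illustration of what detaching does to the autograd graph.

```python
import torch
from torch import nn

torch.manual_seed(0)
cell = nn.GRUCell(3, 5)
h = torch.zeros(2, 5)

# Chunk 1: the hidden state becomes part of the autograd graph
for _ in range(4):
    h = cell(torch.randn(2, 3), h)
print(h.grad_fn is not None)               # True

# Detach at the chunk boundary: same values, no history
h = h.detach()
print(h.grad_fn is None, h.requires_grad)  # True False

# Chunk 2: backward() stops at the detached state instead of unrolling chunk 1
loss = 0
for _ in range(4):
    h = cell(torch.randn(2, 3), h)
    loss = loss + h.pow(2).mean()
loss.backward()
print(cell.weight_ih.grad is not None)     # True
```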



## Evaluation

In [4]:
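```python
model_path = "save/model.pt"
# the whole module was pickled with torch.save(net, ...), so full unpickling is needed
model = torch.load(model_path, weights_only=False)
sample = model.generate_text(n=1000, temperature=1)
print(sample)
```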



a each as:
Of my lord untemb, last do well up,
Well, and stalm not.

BUCKINGHAM:
Mark to more too, women te love is a dearte:
Hake you, Mancia; that Bonenday, one fortun
To spirly -troeable, there injul'd the darchicly capsent
vice, best viciunties, soon soule
Duke these then is not break itoar on our treat
And that we't him into her sea-mend, thy queen I in
This hearth beasugnes fellow. Wh' lords to thyself?
My sword mascholeny, to wear your late,
Provershing goest luins in thy genize hearth.

Second Roman,
He's behonan may grace of her and reason on their siltry,
And so, and breaks a hundey, give your hopety fall
And cruidy love, mayself bevil in that,
With Tripst Auforable too peatue,
Shall thou dranch in eye--nown, this?

First JEWIS:
Will go I do you entruct seem;
From the comastice accondily imprast:
Be stand yet, if a hout by thy pities
Tell? beings, being va generalt couse,
I,
she duke weep to our lame is in York, there up the gentless?

RICHARD:
A knee he but ro'd friend, sosh

## Comments

- Definitely train longer to get better results.
- Keep in mind that the recurrent core is a single GRU layer with only 64 hidden units.
- Temperature does not seem to need tuning here: temperature = 0.8 produces many odd spellings, while temperature = 1 works fine.
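
Temperature divides the decoder logits before the softmax that `generate` samples from. The sketch below (the helper name `temperature_probs` and the three logits are hypothetical) shows how the probability of the most likely character falls as the temperature rises.

```python
import torch

def temperature_probs(logits, temperature):
    """Softmax of logits / temperature: T < 1 sharpens, T > 1 flattens the distribution."""
    return torch.softmax(logits / temperature, dim=-1)

logits = torch.tensor([2.0, 1.0, 0.0])
for t in (0.5, 0.8, 1.0, 2.0):
    print(t, [round(p, 3) for p in temperature_probs(logits, t).tolist()])
```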

In [12]:
```python
model_path = "save/model.pt"
net = torch.load(model_path, weights_only=False)  # full pickled module
max_epoch = 100
optimizer = optim.Adam(net.parameters(), lr=0.0001)  # smaller learning rate for fine-tuning
g = batch_generator()

hidden = torch.zeros(64, 64)  # batch_size x hidden_size
for epoch in range(max_epoch):
    batch = next(g)
    loss_seq = 0
    loss1_seq, loss2_seq = 0, 0
    optimizer.zero_grad()
    for x in batch:
        loss1, loss2, hidden = net.calculate_loss(x, hidden)
        loss1_seq += loss1.item()
        loss2_seq += loss2.item()
        loss_seq = loss_seq + loss1 + loss2
    loss_seq.backward()
    optimizer.step()
    hidden = hidden.detach()  # cut the graph at the batch boundary
    if epoch % 20 == 0:
        print('>> epoch {}, loss {:12.4f}, decoder loss {:12.4f}, latent loss {:12.4f}'.format(
            epoch, loss_seq.item(), loss1_seq, loss2_seq))
        print(net.generate_text())
        print()
```



>> epoch 0, loss     464.4120, decoder loss     464.2681, latent loss       0.1439
and,
I shall be meer it of you in the speak.

GLOUCESTER:
Thine take me my lord, but it is scarned t

>> epoch 20, loss     463.3658, decoder loss     463.2422, latent loss       0.1236
 not not that his hage:
My peiconds, so wite thee come to the breest;
Remour else, him say the purtu

>> epoch 40, loss     462.0584, decoder loss     461.9486, latent loss       0.1099
ites:
From Henry, this mounth gods and Petual:
O, I sige suppose; wate all thine Rome,
And the fair 

>> epoch 60, loss     452.8532, decoder loss     452.7492, latent loss       0.1040
ar,
That we have duke, but another, God's stand
how, and give on on our jeam and first;
Pert in the 

>> epoch 80, loss     447.1253, decoder loss     447.0273, latent loss       0.0979
seth!

AUFIDIUS:
It brauds is thy flien, I pady more for the deapetts,
And the call to be instartine



In [13]:
torch.save(net, model_path)    

