# VRNN
* Original paper: A Recurrent Latent Variable Model for Sequential Data (https://arxiv.org/pdf/1506.02216.pdf)
* Original code: https://github.com/jych/nips2015_vrnn

## VRNN summary
VRNN extends the VAE into a recurrent framework for modelling high-dimensional sequences.  
VRNN introduces latent random variables into the hidden state of an RNN and models the dependencies between the latent random variables at neighboring timesteps.

In [1]:
from tqdm import tqdm

import torch
from torch import optim
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, datasets
from tensorboardX import SummaryWriter

# Run configuration: RNG seed, number of training epochs and mini-batch size.
seed = 1
epochs = 5
batch_size = 256

# Seed PyTorch's global random number generator.
torch.manual_seed(seed)

# Use the GPU when PyTorch reports CUDA as available, otherwise the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [2]:
# Experiment setting:
# MNIST images are treated as sequences by reading them row by row
# (one image row per time step).
def init_dataset(f_batch_size):
    """Build the MNIST train and test data loaders.

    Args:
        f_batch_size: mini-batch size used by both loaders.

    Returns:
        A ``(train_loader, test_loader, fixed_t_size)`` tuple, where
        ``fixed_t_size`` is the constant 28.
    """
    data_dir = '../data'
    loader_kwargs = {'num_workers': 1, 'pin_memory': True}

    # ToTensor() followed by indexing with [0]: only the first entry along
    # the leading axis of each converted image is kept.
    row_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda data: data[0]),
    ])

    # Only the training split is created with download=True.
    train_set = datasets.MNIST(
        data_dir, train=True, download=True, transform=row_transform)
    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=f_batch_size, shuffle=True, **loader_kwargs)

    test_set = datasets.MNIST(data_dir, train=False, transform=row_transform)
    test_loader = torch.utils.data.DataLoader(
        test_set, batch_size=f_batch_size, shuffle=True, **loader_kwargs)

    n_time_steps = 28
    return train_loader, test_loader, n_time_steps


train_loader, test_loader, t_max = init_dataset(batch_size)

In [3]:
import pixyz
# Print the installed pixyz version (the recorded run of this cell printed 0.1.3).
print(pixyz.__version__)

0.1.3


In [4]:
from pixyz.utils import print_latex

## Define probability distribution
### In the original paper
Prior(equation (5) in the paper): $p\left(\mathbf{z}_{t} | \mathbf{x}_{<t}, \mathbf{z}_{<t}\right) = \mathcal{N}\left(\boldsymbol{\mu}_{0, t}, \operatorname{diag}\left(\boldsymbol{\sigma}_{0, t}^{2}\right)\right), \text { where }\left[\boldsymbol{\mu}_{0, t}, \boldsymbol{\sigma}_{0, t}\right]=\varphi_{\tau}^{\text {prior }}\left(\mathbf{h}_{t-1}\right)$

Generator(equation (6) in the paper): $p\left(\mathbf{x}_{t} | \mathbf{z}_{\leq t}, \mathbf{x}_{<t}\right) = \mathcal{N}\left(\boldsymbol{\mu}_{x, t}, \operatorname{diag}\left(\boldsymbol{\sigma}_{x, t}^{2}\right)\right), \text { where }\left[\boldsymbol{\mu}_{x, t}, \boldsymbol{\sigma}_{x, t}\right]=\varphi_{\tau}^{\mathrm{dec}}\left(\varphi_{\tau}^{\mathbf{z}}\left(\mathbf{z}_{t}\right), \mathbf{h}_{t-1}\right)$

Recurrence(equation (7) in the paper): $p\left(\mathbf{h}_{t} | \mathbf{z}_{t}, \mathbf{x}_{t}, \mathbf{h}_{t-1}\right) = f_{\theta}\left(\varphi_{\tau}^{\mathbf{x}}\left(\mathbf{x}_{t}\right), \varphi_{\tau}^{\mathbf{z}}\left(\mathbf{z}_{t}\right), \mathbf{h}_{t-1}\right)$

Inference(equation (9) in the paper): $q\left(\mathbf{z}_{t} | \mathbf{x}_{\leq t}, \mathbf{z}_{<t}\right) = \mathcal{N}\left(\boldsymbol{\mu}_{z, t}, \operatorname{diag}\left(\boldsymbol{\sigma}_{z, t}^{2}\right)\right), \text { where }\left[\boldsymbol{\mu}_{z, t}, \boldsymbol{\sigma}_{z, t}\right]=\varphi_{\tau}^{\mathrm{enc}}\left(\varphi_{\tau}^{\mathbf{x}}\left(\mathbf{x}_{t}\right), \mathbf{h}_{t-1}\right)$

### MNIST Settings
Prior: $p_{\theta}(z_t | h_{t-1}) = \mathcal{N}\left(\mu=f_{prior_\mu}(h_{t-1}),\ \sigma^2=f_{prior_{\sigma^2}}(h_{t-1})\right)$

Generator: $p_{\theta}(x_t | z_t, h_{t-1})=\mathcal{B}\left(x_t ; \lambda=g_{x}(z_t, h_{t-1})\right)$

Recurrence: $p(h_{t} | z_t, x_t, h_{t-1}) = RNN(\varphi_{\tau}^{\mathbf{x}}\left(\mathbf{x}_{t}\right), \varphi_{\tau}^{\mathbf{z}}\left(\mathbf{z}_{t}\right), \mathbf{h}_{t-1})$

Inference: $q_\phi(z_t | h_{t-1}, x_t) = \mathcal{N}\left(\mu=f_{infer_\mu}(h_{t-1}, \varphi_{\tau}^{\mathbf{x}}(x_t)),\ \sigma^2=f_{infer_{\sigma^2}}(h_{t-1}, \varphi_{\tau}^{\mathbf{x}}(x_t))\right)$

In [5]:
# Model dimensions.
x_dim = 28    # input width of Phi_x (one observation x)
h_dim = 100   # width of the extracted feature vectors
z_dim = 16    # input width of Phi_z (the latent z)
t_max = x_dim  # number of time steps, set equal to x_dim


class Phi_x(nn.Module):
    """Feature extractor for x: one linear layer followed by a ReLU."""

    def __init__(self):
        super().__init__()
        self.fc0 = nn.Linear(x_dim, h_dim)

    def forward(self, x):
        features = self.fc0(x)
        return F.relu(features)


class Phi_z(nn.Module):
    """Feature extractor for z: one linear layer followed by a ReLU."""

    def __init__(self):
        super().__init__()
        self.fc0 = nn.Linear(z_dim, h_dim)

    def forward(self, z):
        features = self.fc0(z)
        return F.relu(features)

# One instance of each feature extractor, moved to the selected device.
# These module-level objects are the ones the distribution classes store as
# attributes, so every class that references them uses the same weights.
f_phi_x = Phi_x().to(device)
f_phi_z = Phi_z().to(device)

In [6]:
from pixyz.distributions import Bernoulli, Normal, Deterministic

class Generator(Bernoulli):
    """Bernoulli observation likelihood p(x_t | z_t, h_{t-1}) used for MNIST.

    Given the latent z of a time step and the previous hidden state,
    ``forward`` returns the vector of probabilities that parameterizes the
    Bernoulli distribution over x.
    """

    def __init__(self):
        super().__init__(cond_var=["z", "h_prev"], var=["x"])
        # Input is the concatenation of phi_z(z) and h_prev.
        self.fc1 = nn.Linear(h_dim + h_dim, h_dim)
        self.fc2 = nn.Linear(h_dim, h_dim)
        self.fc3 = nn.Linear(h_dim, x_dim)
        # Module-level feature extractor for z (not a private copy).
        self.f_phi_z = f_phi_z

    def forward(self, z, h_prev):
        joint = torch.cat((self.f_phi_z(z), h_prev), dim=-1)
        hidden = F.relu(self.fc1(joint))
        hidden = F.relu(self.fc2(hidden))
        logits = self.fc3(hidden)
        # The sigmoid maps the last layer's output into (0, 1).
        return {"probs": torch.sigmoid(logits)}

class Prior(Normal):
    """Prior over the latent z, parameterized by the hidden state h_{t-1}.

    z ~ N(loc(h_{t-1}), scale(h_{t-1}))
    """

    def __init__(self):
        super().__init__(cond_var=["h_prev"], var=["z"])
        self.fc1 = nn.Linear(h_dim, h_dim)
        self.fc21 = nn.Linear(h_dim, z_dim)  # head producing "loc"
        self.fc22 = nn.Linear(h_dim, z_dim)  # head producing "scale" (before softplus)

    def forward(self, h_prev):
        hidden = F.relu(self.fc1(h_prev))
        loc = self.fc21(hidden)
        # softplus keeps the scale positive.
        scale = F.softplus(self.fc22(hidden))
        return {"loc": loc, "scale": scale}

class Inference(Normal):
    """Approximate posterior q(z_t | h_{t-1}, x_t).

    inferred z ~ N(loc(h_{t-1}, x_t), scale(h_{t-1}, x_t))
    """

    def __init__(self):
        super().__init__(cond_var=["x", "h_prev"], var=["z"], name="q")
        # Input is the concatenation of phi_x(x) and h_prev.
        self.fc1 = nn.Linear(h_dim + h_dim, h_dim)
        self.fc21 = nn.Linear(h_dim, z_dim)  # head producing "loc"
        self.fc22 = nn.Linear(h_dim, z_dim)  # head producing "scale" (before softplus)
        # Module-level feature extractor for x (not a private copy).
        self.f_phi_x = f_phi_x

    def forward(self, x, h_prev):
        joint = torch.cat((self.f_phi_x(x), h_prev), dim=-1)
        hidden = F.relu(self.fc1(joint))
        loc = self.fc21(hidden)
        # softplus keeps the scale positive.
        scale = F.softplus(self.fc22(hidden))
        return {"loc": loc, "scale": scale}

class Recurrence(Deterministic):
    """Deterministic hidden-state update p(h_t | x_t, z_t, h_{t-1}).

    The new hidden state is produced by a GRU cell.
    """

    def __init__(self):
        super().__init__(cond_var=["x", "z", "h_prev"], var=["h"])
        # The cell input is phi_z(z) concatenated with phi_x(x): 2 * h_dim features.
        self.rnncell = nn.GRUCell(h_dim * 2, h_dim).to(device)
        # Module-level feature extractors (not private copies).
        self.f_phi_x = f_phi_x
        self.f_phi_z = f_phi_z
        # Exposed so callers can size the initial hidden state.
        self.hidden_size = self.rnncell.hidden_size

    def forward(self, x, z, h_prev):
        z_features = self.f_phi_z(z)
        x_features = self.f_phi_x(x)
        cell_input = torch.cat((z_features, x_features), dim=-1)
        return {"h": self.rnncell(cell_input, h_prev)}

# Instantiate the four distributions and move them to the selected device.
prior = Prior().to(device)
decoder = Generator().to(device)
encoder = Inference().to(device)
recurrence = Recurrence().to(device)

In [7]:
# Product of the inference model and the recurrence; it appears in the printed
# loss as p(h,z|x,h_{prev}).
encoder_with_recurrence = encoder * recurrence
# Product of prior, decoder and recurrence; it is sampled from given only
# h_prev, and the samples expose "z", "h_prev" and "h".
generate_from_prior = prior * decoder * recurrence

In [8]:
# Render the factorization of the inference-with-recurrence distribution.
print_latex(encoder_with_recurrence)

<IPython.core.display.Math object>

In [9]:
# Render the factorization of the generative (prior * decoder * recurrence) distribution.
print_latex(generate_from_prior)

<IPython.core.display.Math object>

## Define Loss function
### In the original paper(equation (11) in the original paper)
$\mathbb{E}_{q(\mathbf{z}_{\leq T} | \mathbf{x}_{\leq T})}\left[\sum_{t=1}^{T}\left(-\mathrm{KL}\left(q\left(\mathbf{z}_{t} | \mathbf{x}_{\leq t}, \mathbf{z}_{<t}\right) \| p\left(\mathbf{z}_{t} | \mathbf{x}_{<t}, \mathbf{z}_{<t}\right)\right)+\log p\left(\mathbf{x}_{t} | \mathbf{z}_{\leq t}, \mathbf{x}_{<t}\right)\right)\right]$

In [10]:
from pixyz.losses import KullbackLeibler
from pixyz.losses import LogProb
from pixyz.losses import Expectation as E
from pixyz.losses import IterativeLoss

# Reconstruction term of one time step: the expectation of log p(x|z,h_prev)
# under the encoder-with-recurrence distribution.
reconst_loss = E(encoder_with_recurrence, LogProb(decoder))
# KL term of one time step: KL[q(z|x,h_prev) || p(z|h_prev)].
kl_loss = KullbackLeibler(encoder, prior)
# Loss of one time step (KL minus expected log-likelihood), reduced with mean().
step_loss = (kl_loss - reconst_loss).mean()

# Sum the step loss over t_max iterations. NOTE(review): presumably
# series_var=['x'] feeds one slice of x per iteration and update_value passes
# each step's "h" on as the next step's "h_prev" - confirm against the pixyz docs.
loss = IterativeLoss(step_loss, max_iter=t_max,
                             series_var=['x'],
                             update_value={"h": "h_prev"})
print_latex(loss)

<IPython.core.display.Math object>

## Define VRNN model using Model class

In [11]:
from pixyz.models import Model
# Bundle the loss with the four distributions to be trained; the optimizer is
# Adam with a learning rate of 5e-3.
vrnn = Model(loss, distributions=[encoder, decoder, prior, recurrence],
             optimizer=optim.Adam, optimizer_params={'lr': 5e-3})

# Print the model summary as text, then render it with print_latex.
print(vrnn)
print_latex(vrnn)

Distributions (for training): 
  q(z|x,h_{prev}), p(x|z,h_{prev}), p(z|h_{prev}), p(h|x,z,h_{prev}) 
Loss function: 
  \sum_{t=1}^{28} mean \left(D_{KL} \left[q(z|x,h_{prev})||p(z|h_{prev}) \right] - \mathbb{E}_{p(h,z|x,h_{prev})} \left[\log p(x|z,h_{prev}) \right] \right) 
Optimizer: 
  Adam (
  Parameter Group 0
      amsgrad: False
      betas: (0.9, 0.999)
      eps: 1e-08
      lr: 0.005
      weight_decay: 0
  )


<IPython.core.display.Math object>

## Define Train and Test loop using model

In [12]:
def data_loop(epoch, loader, model, device, train_mode=False):
    """Run one pass over ``loader`` and return the average loss per sample.

    Args:
        epoch: epoch number, only used in the training log line.
        loader: data loader yielding ``(data, label)`` batches; labels are ignored.
        model: object exposing ``train(dict)`` and ``test(dict)``, each
            returning a scalar loss tensor.
        device: device the batches and the initial hidden state are moved to.
        train_mode: call ``model.train`` when True, ``model.test`` otherwise.

    Returns:
        Sum over batches of (batch loss * batch size), divided by
        ``len(loader.dataset)``.
    """
    run_step = model.train if train_mode else model.test

    weighted_total = 0
    for images, _ in tqdm(loader):
        images = images.to(device)
        n_samples = images.size(0)
        inputs = {
            # Swap the first two axes so the leading axis is iterated over as
            # time (assumes batches arrive as (batch, row, column) - TODO confirm).
            'x': images.transpose(0, 1),
            # The hidden state starts at zero for every batch.
            'h_prev': torch.zeros(n_samples, recurrence.hidden_size).to(device),
        }
        # Weight by the batch size so a smaller final batch is averaged correctly.
        weighted_total += run_step(inputs).item() * n_samples

    average = weighted_total / len(loader.dataset)
    if train_mode:
        print('Epoch: {} Train loss: {:.4f}'.format(epoch, average))
    else:
        print('Test loss: {:.4f}'.format(average))
    return average

## Reconstruction and generation

In [13]:
def plot_image_from_latent(batch_size):
    """Generate ``batch_size`` images by running the model from the prior.

    Starting from a zero hidden state, each of the ``t_max`` steps samples
    from ``generate_from_prior`` given the current hidden state, records the
    decoder mean for the sampled z, and continues from the sampled next
    hidden state.

    Args:
        batch_size: number of images to generate.

    Returns:
        The per-step decoder means stacked along a new leading axis and then
        transposed so the batch axis comes first (batch, step, pixel).
    """
    rows = []
    h_prev = torch.zeros(batch_size, recurrence.hidden_size).to(device)
    # Inference only: do not record an autograd graph across the recurrent steps.
    with torch.no_grad():
        for _ in range(t_max):
            samples = generate_from_prior.sample({'h_prev': h_prev})
            # Keep the decoder mean (not the binary sample) as the image row.
            x_t = decoder.sample_mean({"z": samples["z"], "h_prev": samples["h_prev"]})
            h_prev = samples["h"]
            rows.append(x_t)
    return torch.stack(rows, dim=0).transpose(0, 1)


def generate_image_after_nsteps(n_step_num, original_data):
    """Reconstruct the first rows of each image, then generate the remainder.

    For row indices ``t < n_step_num - 1`` the latent comes from the encoder
    mean given the real row, and the real row drives the recurrence. From row
    index ``n_step_num - 1`` onward the latent comes from the prior mean and
    the decoded row is fed back into the recurrence.

    Args:
        n_step_num: step number (1-based) from which rows are generated
            instead of reconstructed.
        original_data: batch of images with the batch axis first.

    Returns:
        The decoded rows stacked along a new leading axis and then transposed
        so the batch axis comes first; one decoded row per input row.
    """
    x = original_data.transpose(0, 1)
    batch_size = original_data.size(0)
    h_prev = torch.zeros(batch_size, recurrence.hidden_size).to(device)

    decoded_rows = []
    # Inference only: do not record an autograd graph across the recurrent steps.
    with torch.no_grad():
        # Use the sequence length of the data instead of a hard-coded 28.
        for t in range(x.size(0)):
            if t < n_step_num - 1:
                # Reconstruction phase: infer z from the observed row.
                rnn_input = x[t]
                z_t = encoder.sample_mean({'x': rnn_input, 'h_prev': h_prev})
                dec_x = decoder.sample_mean({'h_prev': h_prev, 'z': z_t})
            else:
                # Generation phase: take z from the prior and feed the decoded
                # row back in place of an observation.
                z_t = prior.sample_mean({'h_prev': h_prev})
                dec_x = decoder.sample_mean({'h_prev': h_prev, 'z': z_t})
                rnn_input = dec_x
            h_prev = recurrence.sample_mean({'x': rnn_input, 'h_prev': h_prev, 'z': z_t})
            decoded_rows.append(dec_x)
    return torch.stack(decoded_rows, dim=0).transpose(0, 1)


def reconst_image(original_data):
    """Reconstruct every row of each image through the encoder and decoder.

    At each step the latent is the encoder mean given the real row and the
    current hidden state, the output row is the decoder mean, and the real row
    drives the recurrence.

    Args:
        original_data: batch of images with the batch axis first.

    Returns:
        The decoded rows stacked along a new leading axis and then transposed
        so the batch axis comes first; one decoded row per input row.
    """
    x = original_data.transpose(0, 1)
    batch_size = original_data.size(0)
    h_prev = torch.zeros(batch_size, recurrence.hidden_size).to(device)

    decoded_rows = []
    # Inference only: do not record an autograd graph across the recurrent steps.
    with torch.no_grad():
        # Iterate over the rows of the data instead of a hard-coded range(28).
        for x_t in x:
            z_t = encoder.sample_mean({'x': x_t, 'h_prev': h_prev})
            dec_x = decoder.sample_mean({'h_prev': h_prev, 'z': z_t})
            h_prev = recurrence.sample_mean({'x': x_t, 'h_prev': h_prev, 'z': z_t})
            decoded_rows.append(dec_x)
    return torch.stack(decoded_rows, dim=0).transpose(0, 1)

In [14]:
writer = SummaryWriter('VRNN_Model_Class_without_stochastic')
# Fixed batch _x, drawn once, for watching the reconstruction improve over epochs.
# Use the built-in next(): DataLoader iterators in current PyTorch releases
# have no Python-2-style .next() method.
_x, _ = next(iter(test_loader))
_x = _x.to(device)

try:
    for epoch in range(1, epochs + 1):
        train_loss = data_loop(epoch, train_loader, vrnn, device, train_mode=True)
        test_loss = data_loop(epoch, test_loader, vrnn, device)

        writer.add_scalar('train_loss', train_loss, epoch)
        writer.add_scalar('test_loss', test_loss, epoch)

        # [:, None] inserts the channel axis before the images are logged.
        sample = plot_image_from_latent(batch_size)[:, None]
        writer.add_images('Image_from_latent', sample, epoch)
        generated_img_7 = generate_image_after_nsteps(7, _x)
        writer.add_images('Generate_after_7steps', generated_img_7[:, None], epoch)

        generated_img_14 = generate_image_after_nsteps(14, _x)
        writer.add_images('Generate_after_14steps', generated_img_14[:, None], epoch)

        recon_img = reconst_image(_x)
        writer.add_images('Reconstructed',  recon_img[:, None], epoch)

        writer.add_images('orignal', _x[:, None], epoch)
finally:
    # Flush and close the event file even when training is interrupted.
    writer.close()

100%|██████████| 235/235 [00:44<00:00,  5.25it/s]
  0%|          | 0/40 [00:00<?, ?it/s]

Epoch: 1 Train loss: 141.2540


100%|██████████| 40/40 [00:03<00:00, 11.12it/s]


Test loss: 90.0821


100%|██████████| 235/235 [00:45<00:00,  5.15it/s]
  0%|          | 0/40 [00:00<?, ?it/s]

Epoch: 2 Train loss: 87.3032


100%|██████████| 40/40 [00:03<00:00, 11.38it/s]


Test loss: 83.2476


100%|██████████| 235/235 [00:45<00:00,  5.17it/s]
  0%|          | 0/40 [00:00<?, ?it/s]

Epoch: 3 Train loss: 82.2253


100%|██████████| 40/40 [00:03<00:00, 10.73it/s]


Test loss: 80.1809


100%|██████████| 235/235 [00:45<00:00,  5.16it/s]
  0%|          | 0/40 [00:00<?, ?it/s]

Epoch: 4 Train loss: 79.7854


100%|██████████| 40/40 [00:03<00:00, 11.56it/s]


Test loss: 78.0523


100%|██████████| 235/235 [00:45<00:00,  5.11it/s]
  0%|          | 0/40 [00:00<?, ?it/s]

Epoch: 5 Train loss: 78.5019


100%|██████████| 40/40 [00:03<00:00, 10.98it/s]


Test loss: 77.4152


## Stochastic version

In [15]:
from tqdm import tqdm

import torch
from torch import optim
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, datasets
from tensorboardX import SummaryWriter

# Run configuration: RNG seed, number of training epochs and mini-batch size.
seed = 1
epochs = 5
batch_size = 256

# Seed PyTorch's global random number generator.
torch.manual_seed(seed)

# Use the GPU when PyTorch reports CUDA as available, otherwise the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Experiment setting:
# MNIST images are treated as sequences by reading them row by row
# (one image row per time step).
def init_dataset(f_batch_size):
    """Build the MNIST train and test data loaders.

    Args:
        f_batch_size: mini-batch size used by both loaders.

    Returns:
        A ``(train_loader, test_loader, fixed_t_size)`` tuple, where
        ``fixed_t_size`` is the constant 28.
    """
    data_dir = '../data'
    loader_kwargs = {'num_workers': 1, 'pin_memory': True}

    # ToTensor() followed by indexing with [0]: only the first entry along
    # the leading axis of each converted image is kept.
    row_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda data: data[0]),
    ])

    # Only the training split is created with download=True.
    train_set = datasets.MNIST(
        data_dir, train=True, download=True, transform=row_transform)
    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=f_batch_size, shuffle=True, **loader_kwargs)

    test_set = datasets.MNIST(data_dir, train=False, transform=row_transform)
    test_loader = torch.utils.data.DataLoader(
        test_set, batch_size=f_batch_size, shuffle=True, **loader_kwargs)

    n_time_steps = 28
    return train_loader, test_loader, n_time_steps


train_loader, test_loader, t_max = init_dataset(batch_size)

# Model dimensions.
x_dim = 28    # input width of Phi_x (one observation x)
h_dim = 100   # width of the extracted feature vectors
z_dim = 16    # input width of Phi_z (the latent z)
t_max = x_dim  # number of time steps, set equal to x_dim


class Phi_x(nn.Module):
    """Feature extractor for x: one linear layer followed by a ReLU."""

    def __init__(self):
        super().__init__()
        self.fc0 = nn.Linear(x_dim, h_dim)

    def forward(self, x):
        features = self.fc0(x)
        return F.relu(features)


class Phi_z(nn.Module):
    """Feature extractor for z: one linear layer followed by a ReLU."""

    def __init__(self):
        super().__init__()
        self.fc0 = nn.Linear(z_dim, h_dim)

    def forward(self, z):
        features = self.fc0(z)
        return F.relu(features)

# One instance of each feature extractor, moved to the selected device.
# These module-level objects are the ones the distribution classes store as
# attributes, so every class that references them uses the same weights.
f_phi_x = Phi_x().to(device)
f_phi_z = Phi_z().to(device)

from pixyz.distributions import Bernoulli, Normal, Deterministic

class Generator(Bernoulli):
    """Bernoulli observation likelihood p(x_t | z_t, h_{t-1}) used for MNIST.

    Given the latent z of a time step and the previous hidden state,
    ``forward`` returns the vector of probabilities that parameterizes the
    Bernoulli distribution over x.
    """

    def __init__(self):
        super().__init__(cond_var=["z", "h_prev"], var=["x"])
        # Input is the concatenation of phi_z(z) and h_prev.
        self.fc1 = nn.Linear(h_dim + h_dim, h_dim)
        self.fc2 = nn.Linear(h_dim, h_dim)
        self.fc3 = nn.Linear(h_dim, x_dim)
        # Module-level feature extractor for z (not a private copy).
        self.f_phi_z = f_phi_z

    def forward(self, z, h_prev):
        joint = torch.cat((self.f_phi_z(z), h_prev), dim=-1)
        hidden = F.relu(self.fc1(joint))
        hidden = F.relu(self.fc2(hidden))
        logits = self.fc3(hidden)
        # The sigmoid maps the last layer's output into (0, 1).
        return {"probs": torch.sigmoid(logits)}

class Prior(Normal):
    """Prior over the latent z, parameterized by the hidden state h_{t-1}.

    z ~ N(loc(h_{t-1}), scale(h_{t-1}))
    """

    def __init__(self):
        super().__init__(cond_var=["h_prev"], var=["z"])
        self.fc1 = nn.Linear(h_dim, h_dim)
        self.fc21 = nn.Linear(h_dim, z_dim)  # head producing "loc"
        self.fc22 = nn.Linear(h_dim, z_dim)  # head producing "scale" (before softplus)

    def forward(self, h_prev):
        hidden = F.relu(self.fc1(h_prev))
        loc = self.fc21(hidden)
        # softplus keeps the scale positive.
        scale = F.softplus(self.fc22(hidden))
        return {"loc": loc, "scale": scale}

class Inference(Normal):
    """Approximate posterior q(z_t | h_{t-1}, x_t).

    inferred z ~ N(loc(h_{t-1}, x_t), scale(h_{t-1}, x_t))
    """

    def __init__(self):
        super().__init__(cond_var=["x", "h_prev"], var=["z"], name="q")
        # Input is the concatenation of phi_x(x) and h_prev.
        self.fc1 = nn.Linear(h_dim + h_dim, h_dim)
        self.fc21 = nn.Linear(h_dim, z_dim)  # head producing "loc"
        self.fc22 = nn.Linear(h_dim, z_dim)  # head producing "scale" (before softplus)
        # Module-level feature extractor for x (not a private copy).
        self.f_phi_x = f_phi_x

    def forward(self, x, h_prev):
        joint = torch.cat((self.f_phi_x(x), h_prev), dim=-1)
        hidden = F.relu(self.fc1(joint))
        loc = self.fc21(hidden)
        # softplus keeps the scale positive.
        scale = F.softplus(self.fc22(hidden))
        return {"loc": loc, "scale": scale}

class Recurrence(Deterministic):
    """Deterministic hidden-state update p(h_t | x_t, z_t, h_{t-1}).

    The new hidden state is produced by a GRU cell.
    """

    def __init__(self):
        super().__init__(cond_var=["x", "z", "h_prev"], var=["h"])
        # The cell input is phi_z(z) concatenated with phi_x(x): 2 * h_dim features.
        self.rnncell = nn.GRUCell(h_dim * 2, h_dim).to(device)
        # Module-level feature extractors (not private copies).
        self.f_phi_x = f_phi_x
        self.f_phi_z = f_phi_z
        # Exposed so callers can size the initial hidden state.
        self.hidden_size = self.rnncell.hidden_size

    def forward(self, x, z, h_prev):
        z_features = self.f_phi_z(z)
        x_features = self.f_phi_x(x)
        cell_input = torch.cat((z_features, x_features), dim=-1)
        return {"h": self.rnncell(cell_input, h_prev)}

# Instantiate fresh copies of the four distributions for the stochastic run
# and move them to the selected device.
prior = Prior().to(device)
decoder = Generator().to(device)
encoder = Inference().to(device)
recurrence = Recurrence().to(device)

In [16]:
# Product of the inference model and the recurrence; it appears in the printed
# loss as p(h,z|x,h_{prev}).
encoder_with_recurrence = encoder * recurrence
# Product of prior, decoder and recurrence; it is sampled from given only
# h_prev, and the samples expose "z", "h_prev" and "h".
generate_from_prior = prior * decoder * recurrence

In [17]:
from pixyz.losses import KullbackLeibler, StochasticReconstructionLoss
from pixyz.losses import IterativeLoss
# Reconstruction term built with StochasticReconstructionLoss from the
# encoder-with-recurrence distribution and the decoder.
reconst = StochasticReconstructionLoss(encoder_with_recurrence, decoder)
# KL term of one time step: KL[q(z|x,h_prev) || p(z|h_prev)].
kl = KullbackLeibler(encoder, prior)

# Loss of a single time step, reduced with mean().
step_loss = (reconst + kl).mean()
# Sum the step loss over t_max iterations. NOTE(review): presumably
# series_var=['x'] feeds one slice of x per iteration and update_value passes
# each step's "h" on as the next step's "h_prev" - confirm against the pixyz docs.
loss = IterativeLoss(step_loss, max_iter=t_max,
                     series_var=['x'],
                     update_value={"h": "h_prev"})
print_latex(loss)

<IPython.core.display.Math object>

In [18]:
from pixyz.models import Model
# Bundle the loss with the four distributions to be trained; the optimizer is
# Adam with a learning rate of 5e-3.
vrnn = Model(loss, distributions=[encoder, decoder, prior, recurrence],
             optimizer=optim.Adam, optimizer_params={'lr': 5e-3})

# Print the model summary as text, then render it with print_latex.
print(vrnn)
print_latex(vrnn)

Distributions (for training): 
  q(z|x,h_{prev}), p(x|z,h_{prev}), p(z|h_{prev}), p(h|x,z,h_{prev}) 
Loss function: 
  \sum_{t=1}^{28} mean \left(D_{KL} \left[q(z|x,h_{prev})||p(z|h_{prev}) \right] - \mathbb{E}_{p(h,z|x,h_{prev})} \left[\log p(x|z,h_{prev}) \right] \right) 
Optimizer: 
  Adam (
  Parameter Group 0
      amsgrad: False
      betas: (0.9, 0.999)
      eps: 1e-08
      lr: 0.005
      weight_decay: 0
  )


<IPython.core.display.Math object>

In [19]:
writer = SummaryWriter('VRNN_Model_Class_with_stochastic')
# Fixed batch _x, drawn once, for watching the reconstruction improve over epochs.
# Use the built-in next(): DataLoader iterators in current PyTorch releases
# have no Python-2-style .next() method.
_x, _ = next(iter(test_loader))
_x = _x.to(device)

try:
    for epoch in range(1, epochs + 1):
        train_loss = data_loop(epoch, train_loader, vrnn, device, train_mode=True)
        test_loss = data_loop(epoch, test_loader, vrnn, device)

        writer.add_scalar('train_loss', train_loss, epoch)
        writer.add_scalar('test_loss', test_loss, epoch)

        # [:, None] inserts the channel axis before the images are logged.
        sample = plot_image_from_latent(batch_size)[:, None]
        writer.add_images('Image_from_latent', sample, epoch)
        generated_img_7 = generate_image_after_nsteps(7, _x)
        writer.add_images('Generate_after_7steps', generated_img_7[:, None], epoch)

        generated_img_14 = generate_image_after_nsteps(14, _x)
        writer.add_images('Generate_after_14steps', generated_img_14[:, None], epoch)

        recon_img = reconst_image(_x)
        writer.add_images('Reconstructed',  recon_img[:, None], epoch)

        writer.add_images('orignal', _x[:, None], epoch)
finally:
    # Flush and close the event file even when training is interrupted.
    writer.close()

100%|██████████| 235/235 [00:44<00:00,  5.33it/s]
  0%|          | 0/40 [00:00<?, ?it/s]

Epoch: 1 Train loss: 225.6257


100%|██████████| 40/40 [00:03<00:00, 10.47it/s]


Test loss: 204.2892


100%|██████████| 235/235 [00:42<00:00,  5.54it/s]
  0%|          | 0/40 [00:00<?, ?it/s]

Epoch: 2 Train loss: 202.1745


100%|██████████| 40/40 [00:03<00:00, 11.64it/s]


Test loss: 198.2287


100%|██████████| 235/235 [00:43<00:00,  5.45it/s]
  0%|          | 0/40 [00:00<?, ?it/s]

Epoch: 3 Train loss: 196.5438


100%|██████████| 40/40 [00:03<00:00, 11.71it/s]


Test loss: 193.4067


100%|██████████| 235/235 [00:42<00:00,  5.48it/s]
  0%|          | 0/40 [00:00<?, ?it/s]

Epoch: 4 Train loss: 191.5279


100%|██████████| 40/40 [00:03<00:00, 11.32it/s]


Test loss: 189.7882


100%|██████████| 235/235 [00:42<00:00,  5.51it/s]
  0%|          | 0/40 [00:00<?, ?it/s]

Epoch: 5 Train loss: 187.6465


100%|██████████| 40/40 [00:03<00:00, 11.64it/s]


Test loss: 185.3731


100%|██████████| 235/235 [00:42<00:00,  5.47it/s]
  0%|          | 0/40 [00:00<?, ?it/s]

Epoch: 6 Train loss: 185.5341


100%|██████████| 40/40 [00:03<00:00, 10.29it/s]


Test loss: 183.2504


100%|██████████| 235/235 [00:42<00:00,  5.48it/s]
  0%|          | 0/40 [00:00<?, ?it/s]

Epoch: 7 Train loss: 184.8477


 38%|███▊      | 15/40 [00:01<00:02, 10.28it/s]


KeyboardInterrupt: 