In [1]:
from __future__ import print_function
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms

from tqdm import tqdm

In [2]:
batch_size = 128
epochs = 5
seed = 1
torch.manual_seed(seed)

if torch.cuda.is_available(): device = "cuda"
else: device = "cpu"

In [3]:
from Tars.distributions import Normal, Bernoulli
from Tars.distributions.divergences import KullbackLeibler

x_dim = 28
h_dim = 100
z_dim = 64

In [4]:
class Generator(Bernoulli):
    
    def __init__(self):
        super(Generator, self).__init__(cond_var=["z", "h"], var=["x"])
        
        self.fc1 = nn.Linear(h_dim + h_dim, h_dim)
        self.fc2 = nn.Linear(h_dim, h_dim)
        self.fc3 = nn.Linear(h_dim, x_dim)
        
    def forward(self, z, h):
        h = torch.cat((z, h), dim=-1)
        h = F.relu(self.fc1(h))
        h = F.relu(self.fc2(h))
        return {"probs": torch.sigmoid(self.fc3(h))}

In [5]:
class Prior(Normal):
    
    def __init__(self):
        super(Prior, self).__init__(cond_var=["h"], var=["z"])
        
        self.fc1 = nn.Linear(h_dim, h_dim)
        self.fc21 = nn.Linear(h_dim, z_dim)
        self.fc22 = nn.Linear(h_dim, z_dim)
        
    def forward(self, h):
        h = F.relu(self.fc1(h))
        return {"loc": self.fc21(h), "scale": F.softplus(self.fc22(h))}

In [6]:
class Inference(Normal):
    
    def __init__(self):
        super(Inference, self).__init__(cond_var=["x", "h"], var=["z"])
        
        self.fc1 = nn.Linear(h_dim + h_dim, h_dim)
        self.fc21 = nn.Linear(h_dim, z_dim)
        self.fc22 = nn.Linear(h_dim, z_dim)
        
    def forward(self, x, h):
        h = torch.cat((x, h), dim=-1)
        h = F.relu(self.fc1(h))
        return {"loc": self.fc21(h), "scale": F.softplus(self.fc22(h))}

In [7]:
class phi_x(nn.Module):
    def __init__(self):
        super(phi_x, self).__init__()
        
        self.fc0 = nn.Linear(x_dim, h_dim)
        
    def forward(self, x):
        return F.relu(self.fc0(x))

class phi_z(nn.Module):
    def __init__(self):
        super(phi_z, self).__init__()
        
        self.fc0 = nn.Linear(z_dim, h_dim)
        
    def forward(self, z):
        return F.relu(self.fc0(z))

In [8]:
# from Tars.distributions import Distribution

# class phi_z(Distribution):
#     def __init__(self):
#         self.params_keys = ["h","z"]
#         super(phi_z, self).__init__(cond_var=["h"], var=["z"])
        
#         self.fc0 = nn.Linear(z_dim, h_dim)
        
#     def forward(self, z):
#         return {"z": F.relu(self.fc0(**z))}

In [9]:
prior = Prior()
p = Generator()
q = Inference()
phi_X = phi_x()
phi_Z = phi_z()
rnn_cell = nn.GRUCell(h_dim*2, h_dim)

prior.to(device)
p.to(device)
q.to(device)
phi_X.to(device)
phi_Z.to(device)
rnn_cell.to(device)

GRUCell(200, 100)

In [10]:
kwargs = {'num_workers': 1, 'pin_memory': True}
data_dir = '../../data/mnist'
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(data_dir, train=True, download=True,
                   transform=transforms.ToTensor()),
    batch_size=batch_size, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(data_dir, train=False, transform=transforms.ToTensor()),
    batch_size=batch_size, shuffle=True, **kwargs)

In [11]:
from vrnn import VRNN

In [12]:
kl = KullbackLeibler(q, prior)
model = VRNN(q, p, prior, phi_X, phi_Z, rnn_cell, regularizer=[kl],
             optimizer=optim.Adam, optimizer_params={"lr":1e-3})

In [13]:
def train(epoch):
    train_loss = 0
    for batch_idx, (data, _) in enumerate(tqdm(train_loader)):
        data = data.to(device)
        lower_bound, loss = model.train({"x": data.view(28, -1, 28)})
        train_loss += loss
 
    train_loss = train_loss * train_loader.batch_size / len(train_loader.dataset)
    print('Epoch: {} Train loss: {:.4f}'.format(epoch, train_loss))
    return train_loss

In [14]:
def test(epoch):
    test_loss = 0
    for i, (data, _) in enumerate(test_loader):
        data = data.to(device)
        lower_bound, loss = model.test({"x": data.view(28, -1, 28)})
        test_loss += loss

    test_loss = test_loss * test_loader.batch_size / len(test_loader.dataset)
    print('Test loss: {:.4f}'.format(test_loss))
    return test_loss

In [15]:
for epoch in range(1, epochs + 1):
    train_loss = train(epoch)
    test_loss = test(epoch)

100%|██████████| 469/469 [00:50<00:00,  9.23it/s]

Epoch: 1 Train loss: 238.3954



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 217.4828


100%|██████████| 469/469 [00:54<00:00,  8.58it/s]


Epoch: 2 Train loss: 214.3644


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 216.5038


100%|██████████| 469/469 [00:55<00:00,  8.51it/s]


Epoch: 3 Train loss: 213.7738


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 215.8487


100%|██████████| 469/469 [00:53<00:00,  8.83it/s]


Epoch: 4 Train loss: 213.5715


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 215.8300


 47%|████▋     | 219/469 [00:23<00:27,  9.14it/s]Process Process-9:
Traceback (most recent call last):
  File "/home/naokinonaka/.pyenv/versions/2.7.15/lib/python2.7/multiprocessing/process.py", line 267, in _bootstrap
    self.run()
  File "/home/naokinonaka/.pyenv/versions/2.7.15/lib/python2.7/multiprocessing/process.py", line 114, in run
    self._target(*self._args, **self._kwargs)
  File "/home/naokinonaka/.pyenv/versions/2.7.15/envs/env_tars/lib/python2.7/site-packages/torch/utils/data/dataloader.py", line 96, in _worker_loop
    r = index_queue.get(timeout=MANAGER_STATUS_CHECK_INTERVAL)
  File "/home/naokinonaka/.pyenv/versions/2.7.15/lib/python2.7/multiprocessing/queues.py", line 131, in get
    if not self._poll(timeout):
KeyboardInterrupt


KeyboardInterrupt: 