In [None]:
from __future__ import print_function
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image

# Run configuration shared by the cells below.
# Seed first so every later use of torch's global RNG is reproducible.
seed = 1
torch.manual_seed(seed)

batch_size = 128    # samples per DataLoader batch
epochs = 10         # number of train + test passes
log_interval = 10   # print training progress every this many batches
device = "cpu"      # where tensors and model parameters are placed

In [None]:
from Tars.distributions import CategoricalModel
from Tars.models import ML

In [None]:
# Page-locked (pinned) host memory only helps host-to-GPU copies, so request it
# only when the configured device is CUDA; on a CPU run it is wasted work (and
# newer PyTorch versions warn about it). The worker count is unchanged.
kwargs = {'num_workers': 1,
          'pin_memory': torch.device(device).type == 'cuda'}

# Training split: fetched into ../data on first use, served in shuffled batches
# of float tensors produced by ToTensor().
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.ToTensor()),
    batch_size=batch_size, shuffle=True, **kwargs)

# Test split: built without download=True, so it expects the files to already
# be under ../data (presumably put there by the training dataset above --
# confirm). Shuffling is kept as in the original setup.
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.ToTensor()),
    batch_size=batch_size, shuffle=True, **kwargs)

In [None]:
x_dim = 784  # width of the flattened input rows fed to the first linear layer
y_dim = 10   # number of output classes


# discriminative model p(y|x)
class Discriminative(CategoricalModel):
    """Classifier p(y|x) built from three fully connected layers.

    Registered with the base class as a distribution over ``y`` conditioned
    on ``x``. ``forward`` maps a batch of ``x_dim``-wide rows to ``y_dim``
    class probabilities.
    """

    def __init__(self):
        super(Discriminative, self).__init__(cond_var=["x"], var=["y"])
        self.fc1 = nn.Linear(x_dim, 512)
        self.fc2 = nn.Linear(512, 512)
        self.fc3 = nn.Linear(512, y_dim)

    def forward(self, x):
        """Return class probabilities for ``x``.

        Args:
            x: tensor whose last axis has ``x_dim`` entries.

        Returns:
            Tensor with ``y_dim`` entries on the last axis, each slice along
            that axis summing to 1.
        """
        h1 = F.relu(self.fc1(x))
        h2 = F.relu(self.fc2(h1))
        # Normalise explicitly over the class axis (the last one produced by
        # fc3). Omitting ``dim`` is deprecated and would pick axis 0 for
        # 3-D inputs.
        return F.softmax(self.fc3(h2), dim=-1)

In [None]:
# Instantiate the classifier p(y|x) and move it to the configured device.
# Because the .to() call is the cell's last expression, its return value is
# what the notebook displays for this cell.
p = Discriminative()
p.to(device)

In [None]:
# Wrap p in Tars' ML model, handing it the Adam optimizer class and that
# optimizer's settings (learning rate 1e-3). Presumably ML constructs the
# optimizer itself from these two arguments -- confirm against Tars.
model = ML(
    p,
    optim.Adam,
    {"lr": 1e-3},
)

In [None]:
def train(epoch):
    """Run one pass over ``train_loader``, updating ``model`` and logging progress.

    Every batch is flattened to ``x_dim`` columns, its integer labels are
    turned into one-hot rows, and both are passed to ``model.train``. The
    loss is printed every ``log_interval`` batches and the accumulated loss
    divided by the dataset size is printed at the end.

    Args:
        epoch: epoch number, used only in the printed messages.
    """
    # Rows of the identity matrix are the one-hot targets; build it once
    # instead of once per batch.
    one_hot_rows = torch.eye(y_dim)
    train_loss = 0.0
    for batch_idx, (data_x, data_y) in enumerate(train_loader):
        data_x = data_x.view(-1, x_dim).to(device)
        data_y = one_hot_rows[data_y].to(device)
        _, loss = model.train({"x": data_x, "y": data_y})
        # Accumulate a plain float: summing the loss tensors themselves would
        # chain every batch's autograd history onto the running total.
        train_loss += loss.item()
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data_x), len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                loss.item()))

    # NOTE(review): the sum of per-batch losses is divided by the number of
    # samples. If model.train returns a per-batch mean, len(train_loader)
    # would be the matching denominator -- confirm against Tars.
    print('====> Epoch: {} Average loss: {:.4f}'.format(
          epoch, train_loss / len(train_loader.dataset)))

In [None]:
def test(epoch):
    """Evaluate ``model`` on ``test_loader`` and print the accumulated loss.

    Every batch is flattened to ``x_dim`` columns, its integer labels are
    turned into one-hot rows, and both are passed to ``model.test``. The
    summed loss divided by the dataset size is printed.

    Args:
        epoch: accepted for symmetry with ``train``; not used.
    """
    # Rows of the identity matrix are the one-hot targets; build it once
    # instead of once per batch.
    one_hot_rows = torch.eye(y_dim)
    test_loss = 0.0
    # Evaluation needs no gradients; disabling them here avoids building
    # autograd graphs regardless of what model.test does internally.
    with torch.no_grad():
        for data_x, data_y in test_loader:
            data_x = data_x.view(-1, x_dim).to(device)
            data_y = one_hot_rows[data_y].to(device)
            _, loss = model.test({"x": data_x, "y": data_y})
            # float() accepts a scalar tensor as well as a plain number, and
            # keeps the running total a Python float.
            test_loss += float(loss)

    # NOTE(review): same denominator question as in train() -- if the loss is
    # a per-batch mean, len(test_loader) would be the matching divisor.
    test_loss /= len(test_loader.dataset)
    print('====> Test set loss: {:.4f}'.format(test_loss))

In [10]:
# Each epoch runs one training pass followed by one evaluation on the test
# set. Epoch numbers start at 1 so the messages printed by train() count
# from 1 up to `epochs`.
for epoch in range(1, epochs + 1):
    train(epoch)
    test(epoch)

====> Epoch: 3 Average loss: 0.0005
====> Test set loss: 0.0006
====> Epoch: 4 Average loss: 0.0003
====> Test set loss: 0.0006
====> Epoch: 5 Average loss: 0.0002
====> Test set loss: 0.0006
====> Epoch: 6 Average loss: 0.0002
====> Test set loss: 0.0005
====> Epoch: 7 Average loss: 0.0002
====> Test set loss: 0.0005
====> Epoch: 8 Average loss: 0.0001
====> Test set loss: 0.0005
====> Epoch: 9 Average loss: 0.0001
====> Test set loss: 0.0006
====> Epoch: 10 Average loss: 0.0001
====> Test set loss: 0.0007
