In [1]:
from __future__ import print_function
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from tensorboardX import SummaryWriter

from tqdm import tqdm

# Run configuration: mini-batch size, number of epochs, and the seed handed to
# PyTorch's random number generator.
batch_size, epochs, seed = 128, 50, 1
torch.manual_seed(seed)

# Use the GPU when PyTorch reports one, otherwise stay on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [2]:
from Tars.distributions import Categorical
from Tars.models import ML

In [3]:
# Options shared by both loaders. Pinned host memory is only requested when the
# run targets CUDA; on a CPU-only run it brings no benefit.
kwargs = {'num_workers': 1, 'pin_memory': device == "cuda"}

# Training split: downloaded into ../data when missing, images converted to
# tensors, sample order reshuffled by the loader.
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.ToTensor()),
    batch_size=batch_size, shuffle=True, **kwargs)

# Test split: evaluation only averages the loss over the whole set, so the
# sample order is irrelevant and shuffling is turned off.
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.ToTensor()),
    batch_size=batch_size, shuffle=False, **kwargs)

In [4]:
x_dim = 784  # input width of the first linear layer (fc1)
y_dim = 10   # output width of the last linear layer (fc3)


class Discriminative(Categorical):
    """Discriminative model p(y|x).

    A categorical distribution over ``y`` conditioned on ``x``. Its class
    probabilities come from a three-layer fully connected network
    (x_dim -> 512 -> 512 -> y_dim) with ReLU activations and a final softmax.
    """

    def __init__(self):
        super(Discriminative, self).__init__(cond_var=["x"], var=["y"])
        self.fc1 = nn.Linear(x_dim, 512)
        self.fc2 = nn.Linear(512, 512)
        self.fc3 = nn.Linear(512, y_dim)

    def forward(self, x):
        """Return ``{"probs": ...}`` holding the softmax over dimension 1 of the network output for ``x``."""
        hidden = F.relu(self.fc2(F.relu(self.fc1(x))))
        class_probs = F.softmax(self.fc3(hidden), dim=1)

        return {"probs": class_probs}

In [5]:
# Instantiate the classifier p(y|x) and move it to the selected device.
p = Discriminative()
# Last expression of the cell, so the value returned by .to() is shown as the cell output.
p.to(device)

Discriminative(
  (fc1): Linear(in_features=784, out_features=512, bias=True)
  (fc2): Linear(in_features=512, out_features=512, bias=True)
  (fc3): Linear(in_features=512, out_features=10, bias=True)
)

In [6]:
# Wrap p in Tars' ML model, trained with Adam at a learning rate of 1e-3.
# NOTE(review): "ML" presumably stands for maximum likelihood — confirm against Tars.
model = ML(
    p,
    optimizer=optim.Adam,
    optimizer_params={"lr": 1e-3},
)

In [7]:
def train(epoch):
    """Run one pass over ``train_loader`` and return the averaged training loss.

    Each batch of images is flattened to ``(-1, x_dim)``, its integer labels are
    one-hot encoded into rows of width ``y_dim``, and both are passed to
    ``model.train`` under the keys ``"x"`` and ``"y"``.

    Args:
        epoch: Epoch number; only used in the printed progress line.

    Returns:
        The sum of the per-batch losses scaled by
        ``train_loader.batch_size / len(train_loader.dataset)``.
    """
    # Rows of the identity matrix are the one-hot codes; build it once per call
    # rather than once per batch.
    one_hot = torch.eye(y_dim)

    train_loss = 0
    for data_x, data_y in tqdm(train_loader):
        data_x = data_x.view(-1, x_dim).to(device)
        data_y = one_hot[data_y].to(device)
        train_loss += model.train({"x": data_x, "y": data_y})

    # NOTE(review): this scaling treats every batch as full-sized, so a smaller
    # final batch is slightly over-weighted if model.train returns a per-batch
    # mean — confirm against Tars.
    train_loss = train_loss * train_loader.batch_size / len(train_loader.dataset)
    print('Epoch: {} Train loss: {:.4f}'.format(epoch, train_loss))

    return train_loss

In [8]:
def test(epoch):
    """Evaluate the model on ``test_loader`` and return the averaged test loss.

    Each batch of images is flattened to ``(-1, x_dim)``, its integer labels are
    one-hot encoded into rows of width ``y_dim``, and both are passed to
    ``model.test`` under the keys ``"x"`` and ``"y"``.

    Args:
        epoch: Accepted for symmetry with ``train``; not used in the body.

    Returns:
        The sum of the per-batch losses scaled by
        ``test_loader.batch_size / len(test_loader.dataset)``.
    """
    # Rows of the identity matrix are the one-hot codes; build it once per call
    # rather than once per batch.
    one_hot = torch.eye(y_dim)

    test_loss = 0
    for data_x, data_y in test_loader:
        data_x = data_x.view(-1, x_dim).to(device)
        data_y = one_hot[data_y].to(device)
        test_loss += model.test({"x": data_x, "y": data_y})

    # NOTE(review): this scaling treats every batch as full-sized, so a smaller
    # final batch is slightly over-weighted if model.test returns a per-batch
    # mean — confirm against Tars.
    test_loss = test_loss * test_loader.batch_size / len(test_loader.dataset)
    print('Test loss: {:.4f}'.format(test_loss))

    return test_loss

In [9]:
# Record the per-epoch train and test losses with a tensorboardX SummaryWriter.
writer = SummaryWriter()

try:
    for epoch in range(1, epochs + 1):
        train_loss = train(epoch)
        test_loss = test(epoch)

        # .item() converts the returned loss values to plain Python numbers for logging.
        writer.add_scalar('train_loss', train_loss.item(), epoch)
        writer.add_scalar('test_loss', test_loss.item(), epoch)
finally:
    # Close the writer even when training raises or is interrupted, so the
    # scalars logged so far are not left in an unclosed writer.
    writer.close()

100%|██████████| 469/469 [00:05<00:00, 93.78it/s]

Epoch: 1 Train loss: 0.2704



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 0.1224


100%|██████████| 469/469 [00:05<00:00, 92.72it/s]

Epoch: 2 Train loss: 0.0952



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 0.0885


100%|██████████| 469/469 [00:04<00:00, 95.71it/s]


Epoch: 3 Train loss: 0.0607


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 0.0741


100%|██████████| 469/469 [00:04<00:00, 96.30it/s]

Epoch: 4 Train loss: 0.0431



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 0.0782


100%|██████████| 469/469 [00:05<00:00, 91.91it/s]

Epoch: 5 Train loss: 0.0316



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 0.0678


100%|██████████| 469/469 [00:04<00:00, 95.93it/s]

Epoch: 6 Train loss: 0.0277



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 0.0775


100%|██████████| 469/469 [00:04<00:00, 94.93it/s]

Epoch: 7 Train loss: 0.0206



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 0.0816


100%|██████████| 469/469 [00:04<00:00, 97.09it/s]


Epoch: 8 Train loss: 0.0183


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 0.0733


100%|██████████| 469/469 [00:05<00:00, 93.68it/s]

Epoch: 9 Train loss: 0.0166



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 0.0707


100%|██████████| 469/469 [00:04<00:00, 94.64it/s]

Epoch: 10 Train loss: 0.0146



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 0.0829


100%|██████████| 469/469 [00:04<00:00, 96.00it/s]

Epoch: 11 Train loss: 0.0136



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 0.0848


100%|██████████| 469/469 [00:05<00:00, 92.17it/s]


Epoch: 12 Train loss: 0.0120


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 0.0813


100%|██████████| 469/469 [00:04<00:00, 98.64it/s]

Epoch: 13 Train loss: 0.0091



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 0.0997


100%|██████████| 469/469 [00:04<00:00, 93.89it/s]

Epoch: 14 Train loss: 0.0119



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 0.1047


100%|██████████| 469/469 [00:04<00:00, 94.88it/s]

Epoch: 15 Train loss: 0.0101



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 0.0998


100%|██████████| 469/469 [00:04<00:00, 96.54it/s]


Epoch: 16 Train loss: 0.0093


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 0.1097


100%|██████████| 469/469 [00:04<00:00, 98.07it/s]

Epoch: 17 Train loss: 0.0085



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 0.1020


100%|██████████| 469/469 [00:05<00:00, 91.15it/s]

Epoch: 18 Train loss: 0.0101



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 0.1086


100%|██████████| 469/469 [00:04<00:00, 101.90it/s]

Epoch: 19 Train loss: 0.0063



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 0.0959


100%|██████████| 469/469 [00:04<00:00, 95.88it/s]

Epoch: 20 Train loss: 0.0057



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 0.1051


100%|██████████| 469/469 [00:04<00:00, 95.01it/s]

Epoch: 21 Train loss: 0.0073



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 0.1013


100%|██████████| 469/469 [00:05<00:00, 93.77it/s]

Epoch: 22 Train loss: 0.0097



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 0.1164


100%|██████████| 469/469 [00:04<00:00, 96.41it/s]

Epoch: 23 Train loss: 0.0067



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 0.1193


100%|██████████| 469/469 [00:05<00:00, 91.41it/s]

Epoch: 24 Train loss: 0.0072



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 0.1241


100%|██████████| 469/469 [00:04<00:00, 98.63it/s]

Epoch: 25 Train loss: 0.0103



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 0.0956


100%|██████████| 469/469 [00:04<00:00, 95.24it/s]

Epoch: 26 Train loss: 0.0046



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 0.0970


100%|██████████| 469/469 [00:05<00:00, 84.57it/s]

Epoch: 27 Train loss: 0.0047



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 0.0901


100%|██████████| 469/469 [00:04<00:00, 96.85it/s]

Epoch: 28 Train loss: 0.0020



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 0.1010


100%|██████████| 469/469 [00:04<00:00, 98.84it/s]


Epoch: 29 Train loss: 0.0114


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 0.1145


100%|██████████| 469/469 [00:04<00:00, 95.36it/s]

Epoch: 30 Train loss: 0.0059



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 0.1052


100%|██████████| 469/469 [00:04<00:00, 95.96it/s]

Epoch: 31 Train loss: 0.0049



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 0.1116


100%|██████████| 469/469 [00:05<00:00, 91.40it/s]

Epoch: 32 Train loss: 0.0066



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 0.1296


100%|██████████| 469/469 [00:04<00:00, 94.65it/s]

Epoch: 33 Train loss: 0.0051



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 0.0935


100%|██████████| 469/469 [00:04<00:00, 97.26it/s]

Epoch: 34 Train loss: 0.0026



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 0.1177


100%|██████████| 469/469 [00:04<00:00, 97.47it/s]

Epoch: 35 Train loss: 0.0058



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 0.1180


100%|██████████| 469/469 [00:04<00:00, 96.71it/s]

Epoch: 36 Train loss: 0.0091



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 0.1235


100%|██████████| 469/469 [00:04<00:00, 95.25it/s]

Epoch: 37 Train loss: 0.0076



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 0.0958


100%|██████████| 469/469 [00:04<00:00, 94.02it/s]

Epoch: 38 Train loss: 0.0036



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 0.1017


100%|██████████| 469/469 [00:04<00:00, 97.65it/s]

Epoch: 39 Train loss: 0.0031



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 0.1122


100%|██████████| 469/469 [00:04<00:00, 113.68it/s]

Epoch: 40 Train loss: 0.0052



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 0.1125


100%|██████████| 469/469 [00:04<00:00, 100.20it/s]

Epoch: 42 Train loss: 0.0047



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 0.1296


100%|██████████| 469/469 [00:05<00:00, 90.77it/s]

Epoch: 43 Train loss: 0.0091



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 0.1236


100%|██████████| 469/469 [00:04<00:00, 95.04it/s]

Epoch: 44 Train loss: 0.0037



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 0.1069


100%|██████████| 469/469 [00:04<00:00, 104.25it/s]

Epoch: 45 Train loss: 0.0006



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 0.1033


100%|██████████| 469/469 [00:05<00:00, 87.46it/s]

Epoch: 46 Train loss: 0.0003



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 0.1082


100%|██████████| 469/469 [00:05<00:00, 91.51it/s]

Epoch: 47 Train loss: 0.0003



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 0.1142


100%|██████████| 469/469 [00:04<00:00, 95.24it/s]

Epoch: 48 Train loss: 0.0003



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 0.1062


100%|██████████| 469/469 [00:04<00:00, 96.50it/s]

Epoch: 49 Train loss: 0.0003



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 0.1063


100%|██████████| 469/469 [00:05<00:00, 91.75it/s]

Epoch: 50 Train loss: 0.0003





Test loss: 0.1063
