In [1]:
import torch
import numpy as np
import torch.utils.data
from torch import nn, optim
from torch.autograd import Variable
from torchvision import datasets, transforms
from torchvision.utils import save_image
import os
from task import AutoEncoder, train, test_work
from naturalityscore import naturality_score

In [2]:
def get_loader(train, batch_size=50, root='./data', shuffle=True):
    """Build a DataLoader over MNIST, downloading the files into ``root`` if missing.

    Args:
        train: True for the MNIST training split, False for the test split.
        batch_size: number of samples per batch.
        root: directory where torchvision stores the MNIST files.
        shuffle: reshuffle samples every epoch. Kept True for both splits to
            match the notebook's original behavior; note that this makes the
            test batches (and therefore the saved reconstruction images) vary
            between runs.

    Returns:
        A ``torch.utils.data.DataLoader`` yielding ``(images, labels)`` batches,
        with images converted by ``transforms.ToTensor()``.
    """
    dataset = datasets.MNIST(root, train=train, download=True,
                             transform=transforms.ToTensor())
    return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)


train_loader, test_loader = get_loader(True), get_loader(False)

In [11]:
# Autoencoder over flattened 28x28 MNIST images (28 * 28 = 784 inputs).
# hid_size is presumably the bottleneck width — TODO confirm in task.AutoEncoder.
model = AutoEncoder(hid_size=20, inp_size=28 * 28)

In [12]:
# Place the model on the first GPU when CUDA is available. Otherwise fall back to
# the CPU, so a fresh "Restart & Run All" does not crash on machines without a GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

In [13]:
# Adam over every parameter of the autoencoder.
learning_rate = 1e-3
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [14]:
# Train the autoencoder, report the average train/test loss for each epoch, and save
# an image comparing test inputs with their reconstructions.
# Requires PyTorch >= 0.4 (`Tensor.item`, `torch.no_grad`, `Tensor.to`). The old
# `loss.data[0]` idiom raises IndexError on 0-dim loss tensors in current PyTorch,
# and `Variable(..., volatile=True)` no longer disables graph construction.
n_epochs = 1
pics_dir = './pics'
# Send batches to wherever the model's parameters live instead of hard-coding GPU 0.
device = next(model.parameters()).device

for epoch in range(n_epochs):
    model.train()
    train_loss = 0.0
    for data, _ in train_loader:
        data = data.to(device).view(-1, 784)  # flatten each 28x28 image to a 784-vector
        x_rec = model(data)
        loss = model.loss_function(x_rec, data)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    # NOTE(review): dividing by the dataset size gives a per-sample average only if
    # model.loss_function returns a per-batch *sum* — TODO confirm in task.py.
    print('=> Epoch: %s Average loss: %.3f' % (epoch, train_loss / len(train_loader.dataset)))

    model.eval()
    test_loss = 0.0
    # no_grad replaces the removed `volatile=True` flag: no autograd graph is built here.
    with torch.no_grad():
        for data, _ in test_loader:
            data = data.to(device).view(-1, 784)
            x_rec = model(data)
            test_loss += model.loss_function(x_rec, data).item()
    test_loss /= len(test_loader.dataset)
    print('=> Test set loss: %.3f' % test_loss)

    # Grid of the last test batch: first row holds up to 8 inputs, second row
    # their reconstructions (save_image puts `nrow` images per row).
    n = min(data.size(0), 8)
    comparison = torch.cat([data.view(-1, 1, 28, 28)[:n], x_rec.view(-1, 1, 28, 28)[:n]])
    os.makedirs(pics_dir, exist_ok=True)
    save_image(comparison.cpu(), os.path.join(pics_dir, 'reconstruction_%d.png' % epoch), nrow=n)

=> Epoch: 0 Average loss: 0.003
=> Test set loss: 0.003


In [17]:
# Weights file handed to `naturality_score` (the path suggests a trained CapsuleNet).
# Point CAPSULENET_WEIGHTS at the file instead of editing an absolute, user-specific
# path here. The default assumes the notebook runs from autoencoder_task/ — TODO confirm.
# NOTE(review): a .pkl file is presumably unpickled downstream; only load trusted files.
weights_path = os.environ.get(
    'CAPSULENET_WEIGHTS', os.path.join('CapsuleNet', 'result', 'epoch17.pkl'))
if not os.path.isfile(weights_path):
    raise FileNotFoundError(
        'CapsuleNet weights not found at %r; set the CAPSULENET_WEIGHTS environment variable'
        % weights_path)
# Reuse the test loader's batch size rather than repeating the magic number 50.
mean, std = naturality_score(test_loader, batch_size=test_loader.batch_size, gpu_id=0,
                             weights_path=weights_path, model=model)

In [8]:
# NOTE(review): duplicate of the later `print(mean, std)` cells. Its execution count
# (In [8]) is lower than that of the cell computing `mean, std` (In [17]), so the
# saved output below comes from an earlier kernel state. Consider deleting this cell.
print(mean, std, sep=' ', end='\n')

0.799666488291 0.305373423257


In [10]:
# NOTE(review): duplicate of the later `print(mean, std)` cells. It ran as In [10],
# before the cell computing `mean, std` (In [17]), so the saved output below is
# stale hidden state. Consider deleting this cell.
print(mean, std, sep=' ', end='\n')

1.21560689718 0.205535922691


In [18]:
# Show the naturality score summary (mean, std) returned by naturality_score.
print(mean, std, sep=' ', end='\n')

0.872227445882 0.29434422142
