In [2]:
import numpy as np
import matplotlib.pyplot as plt
import torch as t
import torch.nn as nn
import torchvision as tv
import os

In [245]:
from torch.nn.functional import elu, dropout, log_softmax, nll_loss

In [16]:
def dynamic_binarization(img):
    return t.distributions.bernoulli.Bernoulli(img).sample()

In [218]:
def svgd_kernel(x):
    n = t.Tensor([x.size(0)])
    assert n > 1
    
    norm = (x ** 2).sum(1).view(-1, 1)
    dist_mat = (norm + norm.view(1, -1)) - 2.0 * t.mm(x , x.t())
    
    h = t.median(dist_mat) / t.log(n)
    
    kxy = t.exp(- dist_mat / h)
    dxkxy = (-t.mm(kxy,x) + t.sum(kxy,1).view(-1,1)*x) / (2*h) 
    
    return kxy, dxkxy

In [63]:
path = '/T480/AnacondaProjects/svgd/'
batch_size_train = 128
batch_size_test = 1000
n_epochs = 5
log_interval = 100

In [64]:
train_data = tv.datasets.MNIST(path,train = True,download = True, 
                               transform = tv.transforms.Compose(
                                   [tv.transforms.ToTensor(),tv.transforms.Lambda(dynamic_binarization)]))
test_data = tv.datasets.MNIST(path,train = False, download = True, transform = tv.transforms.Compose(
                                   [tv.transforms.ToTensor(),tv.transforms.Lambda(dynamic_binarization)]))
train_loader = t.utils.data.DataLoader(train_data,batch_size = batch_size_train, shuffle = True)
test_loader = t.utils.data.DataLoader(test_data,batch_size = batch_size_test, shuffle = True)

In [65]:
class CNN(nn.Module):
    def __init__(self):
        super(CNN,self).__init__()
        self.drop_rate = 0.3
        
        self.conv1 = nn.Conv2d(1,16,kernel_size = 5,stride = 2)
        self.conv2 = nn.Conv2d(16,32,kernel_size = 5,stride = 2)
        self.dense = nn.Linear(512,512)
        self.out = nn.Linear(512,10)
        
    def forward(self,x):
        x = elu(self.conv1(x))
        x = dropout(x,self.drop_rate)
        x = elu(self.conv2(x))
        x = dropout(x,self.drop_rate)
        x = x.view(-1,512)
        x = elu(self.dense(x))
        x = dropout(x,self.drop_rate)
        x = log_softmax(self.out(x))
        return x

In [341]:
class cnn_encoder(nn.Module):
    def __init__(self,n_hidden,input_dim, output_dim, drop_rate):
        super(cnn_encoder,self).__init__()
        self.drop_rate = drop_rate
        
        self.conv1 = nn.Conv2d(1,16,kernel_size = 5,stride = 2)
        self.conv2 = nn.Conv2d(16,32,kernel_size = 5,stride = 2)
        self.out = nn.Linear(n_hidden,output_dim)
        
    def forward(self,x):
        x = elu(self.conv1(x))
        x = dropout(x,self.drop_rate)
        x = elu(self.conv2(x))
        x = dropout(x,self.drop_rate)
        x = x.view(-1,512)
        x = self.out(x)
        return x
    
class cnn_decoder(nn.Module):
    def __init__(self,n_hidden,input_dim, output_dim, drop_rate):
        super(cnn_decoder,self).__init__()
        self.drop_rate = drop_rate

        self.input = nn.Linear(input_dim, n_hidden)
        self.deconv1 = nn.ConvTranspose2d(32,16,kernel_size = 5, stride = 2)
        self.deconv2 = nn.ConvTranspose2d(16,1,kernel_size = 5, stride = 2)
        
    def forward(self,x):
        x = self.input(x)
        x = dropout(x, self.drop_rate)
        x = x.view(-1,32,4,4)
        x = elu(self.deconv1(x))
        x = dropout(x,self.drop_rate)
        x = elu(self.deconv2(x))
        x = t.sigmoid(x)
        return x

In [342]:
encoder = cnn_encoder(512,784,32,0)
decoder = cnn_decoder(512,32,784,0)

In [343]:
x = t.rand([5,1,28,28])
z = encoder(x)
x_r = decoder(z)


In [247]:
class mlp_encoder(nn.Module):
    def __init__(self,n_hidden,input_dim, output_dim, drop_rate):
        super(mlp_encoder,self).__init__()
        
        self.drop_rate = drop_rate
        
        self.hidden = nn.Linear(input_dim,n_hidden)
        self.out = nn.Linear(n_hidden,output_dim)
    
    def forward(self,x):
        x = elu(self.hidden(x))
        x = dropout(x,self.drop_rate)
        x = self.out(x)
        return x
    
class mlp_decoder(nn.Module):
    def __init__(self,n_hidden,input_dim, output_dim, drop_rate):
        super(mlp_decoder,self).__init__()
        
        self.drop_rate = drop_rate
        
        self.hidden = nn.Linear(input_dim,n_hidden)
        self.out = nn.Linear(n_hidden, output_dim)
        
    def forward(self,x):
        x = elu(self.hidden(x))
        x = dropout(x,self.drop_rate)
        x = t.sigmoid(self.out(x))
        return x

In [248]:
encoder = mlp_encoder(400,784,32,0)
decoder = mlp_decoder(400,32,784,0)
enc_opt = t.optim.Adam(encoder.parameters())
dec_opt = t.optim.Adam(decoder.parameters())

In [68]:
def train(epoch):
    network.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = network(data)
        loss = nll_loss(output, target)
        loss.backward()
        optimizer.step()

        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch, batch_idx * len(data), len(train_loader.dataset),
            100. * batch_idx / len(train_loader), loss.item()))
            train_losses.append(loss.item())
            train_counter.append(
            (batch_idx*64) + ((epoch-1)*len(train_loader.dataset)))
            
def test():
    network.eval()
    test_loss = 0
    correct = 0
    with t.no_grad():
        for data, target in test_loader:
            output = network(data)
            test_loss += nll_loss(output, target, size_average=False).item()
            pred = output.data.max(1, keepdim=True)[1]
            correct += pred.eq(target.data.view_as(pred)).sum()
    test_loss /= len(test_loader.dataset)
    test_losses.append(test_loss)
    print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
    test_loss, correct, len(test_loader.dataset),
    100. * correct / len(test_loader.dataset)))

In [None]:
network = CNN()
optimizer = t.optim.Adam(network.parameters())

In [70]:
train_losses = []
train_counter = []
test_losses = []
test_counter = [i*len(train_loader.dataset) for i in range(n_epochs + 1)]

test()
for epoch in range(1,n_epochs + 1):
    train(epoch)
    test()
    
fig = plt.figure()
plt.plot(train_counter, train_losses, color='blue')
plt.scatter(test_counter, test_losses, color='red')
plt.legend(['Train Loss', 'Test Loss'], loc='upper right')
plt.xlabel('number of training examples seen')
plt.ylabel('negative log likelihood loss')





Test set: Avg. loss: 2.3105, Accuracy: 911/10000 (9%)


Test set: Avg. loss: 0.1880, Accuracy: 9399/10000 (93%)


Test set: Avg. loss: 0.1383, Accuracy: 9572/10000 (95%)


Test set: Avg. loss: 0.1340, Accuracy: 9597/10000 (95%)


Test set: Avg. loss: 0.1099, Accuracy: 9647/10000 (96%)


Test set: Avg. loss: 0.0972, Accuracy: 9700/10000 (97%)

