In [1]:
import gpytorch
import torch
import torchvision
from torch import nn
from torch.autograd import Variable
from torchvision import transforms

gpytorch.functions.use_toeplitz = False

In [65]:
bottleneck_size = 5

class FeatureExtractor(nn.Sequential):
    
    def __init__(self):
        super(FeatureExtractor, self).__init__(nn.Conv2d(1, 32, kernel_size=5, padding=2),
                                 nn.BatchNorm2d(32),
                                 nn.ReLU(),
                                 nn.MaxPool2d(2, 2),
                                 nn.Conv2d(32, 64, kernel_size=5, padding=2),
                                 nn.BatchNorm2d(64),
                                 nn.ReLU(),
                                 nn.MaxPool2d(2, 2))
        
class Bottleneck(nn.Sequential):
    
    def __init__(self):
        super(Bottleneck, self).__init__(nn.Linear(64*7*7, 128),
                                         nn.BatchNorm1d(128),
                                 nn.ReLU(),
                                 nn.Linear(128, 128),
                                 nn.BatchNorm1d(128),
                                 nn.ReLU(),
                                 nn.Linear(128, bottleneck_size),
                                 nn.BatchNorm1d(bottleneck_size))

class LeNet(nn.Module):
    
    def __init__(self):
        super(LeNet, self).__init__()
        self.feature_extractor = FeatureExtractor()
        self.bottleneck = Bottleneck()
        self.final_layer = nn.Sequential(
                                 nn.ReLU(),
                                 nn.Linear(bottleneck_size,10))
    
    def forward(self, x):
        features = self.feature_extractor(x)
        bottlenecked_features = self.bottleneck(features.view(-1, 64 * 7 * 7))
        classification = self.final_layer(bottlenecked_features)
        return classification
        

In [50]:
train_mnist = torchvision.datasets.MNIST('/tmp', train=True,
                                         download=True, transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ]))                                              
test_mnist = torchvision.datasets.MNIST('/tmp', train=False,
                                        download=True, transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ]))

In [51]:
train_data_loader = torch.utils.data.DataLoader(train_mnist, shuffle=True, pin_memory=True, batch_size=256)

In [52]:
criterion = nn.CrossEntropyLoss().cuda()

In [53]:
model = LeNet().cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

In [54]:
num_epochs = 10
for i in range(num_epochs):
    for x, y in train_data_loader:
        x = Variable(x.cuda())
        y = Variable(y.cuda())
        output = model(x)
        loss = criterion(output, y)
        loss.backward()
        optimizer.step()
    print("Loss: ", loss.data[0])
    

('Loss: ', 0.5748693943023682)
('Loss: ', 0.41194358468055725)
('Loss: ', 0.4367029368877411)
('Loss: ', 0.22382181882858276)
('Loss: ', 1.4835081100463867)
('Loss: ', 0.3648984432220459)
('Loss: ', 0.8446351885795593)
('Loss: ', 0.5363819003105164)
('Loss: ', 0.013164718635380268)
('Loss: ', 0.10387003421783447)


In [55]:
from gpytorch.kernels import RBFKernel, GridInterpolationKernel
from torch import nn, optim
from gpytorch.kernels import RBFKernel, GridInterpolationKernel
from gpytorch.means import ConstantMean
from gpytorch.likelihoods import GaussianLikelihood, BernoulliLikelihood
from gpytorch.random_variables import GaussianRandomVariable


class DeepKernel(gpytorch.Module):
    def __init__(self, model):
        super(DeepKernel, self).__init__()
        self.feature_extractor = model.feature_extractor
        self.bottleneck = model.bottleneck
        self.gp_layer = GPLayer()
        
    def forward(self, x):
        features = self.feature_extractor(x)
        bottlenecked_features = self.bottleneck(features.view(-1, 64 * 7 * 7))
        bottlenecked_features *= 0.01
        gp_output = self.gp_layer(bottlenecked_features)
        return gp_output

class LatentFunction(gpytorch.AdditiveGridInducingPointModule):
    def __init__(self):
        super(LatentFunction, self).__init__(grid_size=100, grid_bounds=[(-3, 3)])
        self.mean_module = ConstantMean(constant_bounds=(-1e-5, 1e-5))
        self.covar_module = RBFKernel(log_lengthscale_bounds=(-5, 5))
        self.register_parameter('log_outputscale', nn.Parameter(torch.Tensor([0])), bounds=(-5,6))
        
    def forward(self,x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        covar_x = covar_x.mul(self.log_outputscale.exp())
        latent_pred = GaussianRandomVariable(mean_x, covar_x)
        return latent_pred

class GPLayer(gpytorch.GPModel):
    def __init__(self):
        super(GPLayer,self).__init__(BernoulliLikelihood())
        self.latent_function = LatentFunction()
        
    def forward(self, x):
        return self.latent_function(x)
    

In [56]:
deep_kernel = DeepKernel(model).cuda()

In [59]:
gp_data_loader = torch.utils.data.DataLoader(train_mnist, batch_size=128, pin_memory=True)

In [62]:
# Find optimal model hyperparameters
deep_kernel.train()
optimizer = torch.optim.Adam(deep_kernel.gp_layer.parameters(), lr=0.005)
optimizer.n_iter = 0
num_epochs = 3
for i in range(num_epochs):
    for j, (train_x_batch, train_y_batch) in enumerate(gp_data_loader):
        train_y_batch[train_y_batch!=1]=-1
        train_x_batch = Variable(train_x_batch).cuda()
        train_y_batch = Variable(train_y_batch.cuda())
        optimizer.zero_grad()
        output = deep_kernel(train_x_batch)
        loss = -deep_kernel.gp_layer.marginal_log_likelihood(output, train_y_batch.float())
        loss.backward()
        optimizer.n_iter += 1
        print('Iter %d/%d - Loss: %.3f' % (
            i + 1, num_epochs, loss.data[0],
        ))
        optimizer.step()
    
# Set back to eval mode
deep_kernel.eval()

Iter 1/3 - Loss: 291.962
Iter 1/3 - Loss: 222.469
Iter 1/3 - Loss: 191.826
Iter 1/3 - Loss: 264.767
Iter 1/3 - Loss: 94.878
Iter 1/3 - Loss: 435.070
Iter 1/3 - Loss: 492.511
Iter 1/3 - Loss: 443.101
Iter 1/3 - Loss: 424.483
Iter 1/3 - Loss: 314.415
Iter 1/3 - Loss: 242.884
Iter 1/3 - Loss: 366.626
Iter 1/3 - Loss: 316.214
Iter 1/3 - Loss: 452.091
Iter 1/3 - Loss: 317.611
Iter 1/3 - Loss: 312.950
Iter 1/3 - Loss: 329.193
Iter 1/3 - Loss: 240.427
Iter 1/3 - Loss: 539.236
Iter 1/3 - Loss: 559.298
Iter 1/3 - Loss: 426.317
Iter 1/3 - Loss: 318.377
Iter 1/3 - Loss: 510.419
Iter 1/3 - Loss: 395.583
Iter 1/3 - Loss: 347.130
Iter 1/3 - Loss: 660.529
Iter 1/3 - Loss: 333.490
Iter 1/3 - Loss: 230.737
Iter 1/3 - Loss: 316.067
Iter 1/3 - Loss: 295.117
Iter 1/3 - Loss: 156.262
Iter 1/3 - Loss: 432.384
Iter 1/3 - Loss: 638.158
Iter 1/3 - Loss: 359.473
Iter 1/3 - Loss: 396.712
Iter 1/3 - Loss: 495.437
Iter 1/3 - Loss: 194.968
Iter 1/3 - Loss: 500.123
Iter 1/3 - Loss: 368.459
Iter 1/3 - Loss: 366.816
I

Iter 1/3 - Loss: 126.276
Iter 1/3 - Loss: 31.703
Iter 1/3 - Loss: 112.536
Iter 1/3 - Loss: 118.965
Iter 1/3 - Loss: 80.115
Iter 1/3 - Loss: 182.692
Iter 1/3 - Loss: 72.409
Iter 1/3 - Loss: 53.449
Iter 1/3 - Loss: 58.988
Iter 1/3 - Loss: 131.254
Iter 1/3 - Loss: 3.573
Iter 1/3 - Loss: 93.605
Iter 1/3 - Loss: 106.199
Iter 1/3 - Loss: 22.971
Iter 1/3 - Loss: 60.309
Iter 1/3 - Loss: 65.180
Iter 1/3 - Loss: 64.233
Iter 1/3 - Loss: 86.824
Iter 1/3 - Loss: 14.741
Iter 1/3 - Loss: 89.412
Iter 1/3 - Loss: 60.615
Iter 1/3 - Loss: 113.363
Iter 1/3 - Loss: 70.971
Iter 1/3 - Loss: 109.663
Iter 1/3 - Loss: 101.380
Iter 1/3 - Loss: 87.349
Iter 1/3 - Loss: 62.582
Iter 1/3 - Loss: 51.176
Iter 1/3 - Loss: 84.619
Iter 1/3 - Loss: 39.632
Iter 1/3 - Loss: 109.013
Iter 1/3 - Loss: 121.465
Iter 1/3 - Loss: 99.178
Iter 1/3 - Loss: 102.585
Iter 1/3 - Loss: 104.920
Iter 1/3 - Loss: 84.282
Iter 1/3 - Loss: 19.743
Iter 1/3 - Loss: 30.943
Iter 1/3 - Loss: 138.196
Iter 1/3 - Loss: 81.957
Iter 1/3 - Loss: 115.135
It

KeyboardInterrupt: 

In [63]:
deep_kernel.gp_layer.condition(Variable(torch.randn(2, 2)).cuda(), Variable(torch.zeros(2)).cuda())
deep_kernel.eval()
test_data_loader = torch.utils.data.DataLoader(test_mnist, shuffle=False, pin_memory=True, batch_size=256)

for test_batch_x, test_batch_y in test_data_loader:
    predictions = deep_kernel(Variable(test_batch_x).cuda()).probability.round()
    #print(predictions)
    test_batch_y = Variable(test_batch_y.fmod(2)).cuda().float()
    print(torch.eq(predictions, test_batch_y).float().mean().data[0])


0.55859375
0.59375
0.58984375
0.58984375
0.58203125
0.55859375
0.5703125
0.53125
0.5859375
0.5625
0.55078125
0.55078125
0.609375
0.578125
0.578125
0.5390625
0.53125
0.5625
0.5390625
0.6015625
0.5703125
0.609375
0.59765625
0.59765625
0.57421875
0.6015625
0.58203125
0.59375
0.60546875
0.5390625
0.59375
0.5546875
0.58984375
0.58203125
0.5859375
0.59765625
0.5703125
0.5234375
0.59765625
0.5


In [64]:
from torch.nn import functional as F

model.eval()
test_data_loader = torch.utils.data.DataLoader(test_mnist, shuffle=False, pin_memory=True, batch_size=256)

for test_batch_x, test_batch_y in test_data_loader:
    _, predictions = model(Variable(test_batch_x).cuda()).max(dim=1)
    predictions = predictions.fmod(2)
    #print(predictions)
    test_batch_y = Variable(test_batch_y).cuda().fmod(2)
    #print(test_batch_y)
    print(torch.eq(predictions, test_batch_y).float().mean().data[0])


0.9765625
0.98046875
0.98828125
0.99609375
0.9765625
0.98828125
0.98828125
0.98828125
0.98046875
0.99609375
0.96875
0.984375
0.9921875
0.98828125
0.97265625
0.98046875
0.984375
0.96875
0.984375
0.99609375
0.9921875
1.0
0.98828125
0.984375
1.0
0.98046875
0.99609375
1.0
1.0
0.99609375
0.9921875
1.0
0.9921875
1.0
0.99609375
0.9921875
1.0
0.984375
0.98828125
1.0
