In [1]:
import gpytorch
import torch
from torch.autograd import Variable

# NOTE(review): clears the module-level `use_toeplitz` flag on gpytorch.functions for the
# whole session. Presumably this disables Toeplitz-structured matrix routines — confirm
# against the installed GPyTorch version.
gpytorch.functions.use_toeplitz = False

In [2]:
import os
from torchvision import transforms, datasets
from torch.utils.data import DataLoader

# Both splits share one preprocessing pipeline: convert to a tensor, then normalise.
# NOTE(review): (0.1307, 0.3081) are presumably the MNIST pixel mean / std — confirm.
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# The data is downloaded into /tmp on first use.
train_dataset = datasets.MNIST('/tmp', train=True, download=True, transform=mnist_transform)
test_dataset = datasets.MNIST('/tmp', train=False, download=True, transform=mnist_transform)

# Only the training batches are shuffled; everything else is identical for the two loaders.
loader_kwargs = dict(batch_size=256, pin_memory=True)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

## Define the feature extractor for our deep kernel

In [3]:
from collections import OrderedDict
from torch import nn
from torch.nn import functional as F

class LeNetFeatureExtractor(nn.Module):
    """LeNet-style convolutional network mapping 1-channel images to 64 features.

    Two conv -> batch-norm -> ReLU -> 2x2 max-pool stages are followed by a
    fully connected layer that is also batch-normalised and passed through ReLU.
    The flatten step expects 32 x 7 x 7 activations after the second pooling,
    which is what 28 x 28 inputs produce (padding=2 keeps the size under the
    5x5 convolutions and each pooling halves it).
    """

    def __init__(self):
        super(LeNetFeatureExtractor, self).__init__()
        # Stage 1: 1 -> 16 channels.
        self.conv1 = nn.Conv2d(1, 16, kernel_size=5, padding=2)
        self.norm1 = nn.BatchNorm2d(16)
        # Stage 2: 16 -> 32 channels.
        self.conv2 = nn.Conv2d(16, 32, kernel_size=5, padding=2)
        self.norm2 = nn.BatchNorm2d(32)
        # Dense head: flattened 32*7*7 activations -> 64 features.
        self.fc3 = nn.Linear(32 * 7 * 7, 64)
        self.norm3 = nn.BatchNorm1d(64)

    def forward(self, x):
        """Return a (batch, 64) tensor of non-negative features for the image batch `x`."""
        stage1 = self.conv1(x)
        stage1 = F.relu(self.norm1(stage1))
        stage1 = F.max_pool2d(stage1, 2)

        stage2 = self.conv2(stage1)
        stage2 = F.relu(self.norm2(stage2))
        stage2 = F.max_pool2d(stage2, 2)

        # Collapse channel and spatial dimensions into one vector per sample.
        flat = stage2.view(-1, 32 * 7 * 7)
        return F.relu(self.norm3(self.fc3(flat)))
    
# Instantiate the extractor and move its parameters to the GPU (requires CUDA).
feature_extractor = LeNetFeatureExtractor().cuda()

### Pretrain the feature extractor a bit

In [4]:
# Linear head (64 features -> 10 classes) placed on top of the extractor for pre-training.
classifier = nn.Linear(64, 10).cuda()
# The extractor and the head are optimised jointly, so collect both sets of parameters.
params = list(feature_extractor.parameters()) + list(classifier.parameters())
optimizer = torch.optim.SGD(params, lr=0.1, momentum=0.9)

def pretrain(epoch):
    """Run one epoch of supervised pre-training and print the mean training loss.

    The feature extractor and the linear `classifier` head are updated together
    through the module-level `optimizer`, using batches from `train_loader`.

    Args:
        epoch: Epoch number; only used to label the printed line.
    """
    feature_extractor.train()
    running_loss = 0.
    for images, labels in train_loader:
        images = Variable(images.cuda())
        labels = Variable(labels.cuda())

        optimizer.zero_grad()
        log_probs = F.log_softmax(classifier(feature_extractor(images)), 1)
        batch_loss = F.nll_loss(log_probs, labels)
        batch_loss.backward()
        optimizer.step()

        # nll_loss returns the batch mean by default, so weight it by the batch size
        # to make the final division a per-sample average.
        running_loss += batch_loss.data[0] * len(images)

    mean_loss = running_loss / len(train_dataset)
    print('Train Epoch: %d\tLoss: %.6f' % (epoch, mean_loss))

def pretest():
    """Evaluate the extractor + linear head on `test_loader`; print mean loss and accuracy."""
    feature_extractor.eval()
    total_nll = 0
    n_correct = 0
    for images, labels in test_loader:
        # volatile=True marks the input as inference-only in the legacy Variable API.
        images = Variable(images.cuda(), volatile=True)
        labels = Variable(labels.cuda())

        log_probs = F.log_softmax(classifier(feature_extractor(images)), 1)
        # Sum (rather than average) within the batch; the division happens once at the end.
        total_nll += F.nll_loss(log_probs, labels, size_average=False).data[0]
        # The predicted class is the index of the largest log-probability.
        predicted = log_probs.data.max(1, keepdim=True)[1]
        n_correct += predicted.eq(labels.data.view_as(predicted)).cpu().sum()

    n_samples = len(test_loader.dataset)
    mean_nll = total_nll / n_samples
    print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.3f}%)'.format(
        mean_nll, n_correct, n_samples, 100. * n_correct / n_samples))

# Pre-train for a few epochs, reporting held-out loss/accuracy after each one.
n_epochs = 3
for epoch in range(1, n_epochs + 1):
    pretrain(epoch)
    pretest()

Train Epoch: 1	Loss: 0.145056




Test set: Average loss: 0.0398, Accuracy: 9866/10000 (98.660%)
Train Epoch: 2	Loss: 0.035001
Test set: Average loss: 0.0380, Accuracy: 9884/10000 (98.840%)
Train Epoch: 3	Loss: 0.023860
Test set: Average loss: 0.0265, Accuracy: 9904/10000 (99.040%)


## Define the deep kernel GP

In [11]:
class DKLModel(gpytorch.GPModel):
    """Deep kernel learning classifier: a feature extractor feeding a set of latent GPs.

    Args:
        feature_extractor: Module that maps raw inputs to `n_features` features.
        n_features: Number of extracted features; passed to the likelihood and
            to `LatentFunctions`.
        n_classes: Number of classes handled by the softmax likelihood.
        grid_bounds: (lower, upper) pair the features are rescaled into; also
            passed on to `LatentFunctions`.
    """

    def __init__(self, feature_extractor, n_features=64, n_classes=10, grid_bounds=(-10., 10.)):
        super(DKLModel, self).__init__(
            gpytorch.likelihoods.SoftmaxLikelihood(n_features=n_features, n_classes=n_classes)
        )

        self.feature_extractor = feature_extractor
        self.latent_functions = LatentFunctions(n_features=n_features, grid_bounds=grid_bounds)
        self.grid_bounds = grid_bounds

    def forward(self, x):
        """Extract features from `x`, rescale them, and evaluate the latent functions."""
        lower, upper = self.grid_bounds[0], self.grid_bounds[1]
        # Scale to fit inside grid bounds
        scaled = gpytorch.utils.scale_to_bounds(self.feature_extractor(x), lower, upper)
        # A trailing singleton dimension is appended before the latent functions are called.
        return self.latent_functions(scaled.unsqueeze(-1))
    
    
class LatentFunctions(gpytorch.AdditiveGridInducingPointModule):
    """Zero-mean, RBF-kernel latent functions defined on an inducing-point grid.

    Args:
        n_features: Forwarded to the base class as `n_components`.
        grid_bounds: Single (lower, upper) pair, wrapped in a list for the base class.
        grid_size: Forwarded to the base class as `grid_size`.
    """

    def __init__(self, n_features, grid_bounds, grid_size=128):
        super(LatentFunctions, self).__init__(
            grid_size=grid_size,
            grid_bounds=[grid_bounds],
            n_components=n_features,
            mixing_params=False,
            sum_output=False,
        )
        rbf = gpytorch.kernels.RBFKernel()
        # log_lengthscale=0 — presumably a unit lengthscale (exp(0) = 1); confirm.
        rbf.initialize(log_lengthscale=0)
        self.cov_module = rbf

    def forward(self, x):
        """Return the Gaussian prior at `x`: an all-zero mean and the RBF covariance."""
        n_points = len(x)
        # Allocate the zero mean via x.data.new so it has the same tensor type as the input.
        zero_mean = Variable(x.data.new(n_points).zero_())
        covariance = self.cov_module(x)
        return gpytorch.random_variables.GaussianRandomVariable(zero_mean, covariance)
    
    
# Wrap the (pre-trained) extractor in the deep-kernel GP and move the model to the GPU.
model = DKLModel(feature_extractor).cuda()

In [12]:
# NOTE: this rebinds the global `train_loader` (batch size 256 -> 2048) and, below, the
# global `optimizer`. pretrain() reads both globals, so re-running the pre-training cell
# after this one would silently use these objects instead of the originals.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=2048, shuffle=True, pin_memory=True)
# Adam over everything returned by model.parameters().
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

def train(epoch):
    """Run one epoch of DKL training, printing the loss of every mini-batch.

    Uses the module-level `model`, `optimizer`, `train_loader` and `train_dataset`.

    Args:
        epoch: Epoch number; only used to label the printed lines.
    """
    model.train()
    n_batches = len(train_loader)
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.cuda(), target.cuda()
        data, target = Variable(data), Variable(target)
        optimizer.zero_grad()
        output = model(data)
        # The objective is the negated marginal log likelihood; n_data is the size of the
        # full training set rather than of the current mini-batch.
        loss = -model.marginal_log_likelihood(output, target, n_data=len(train_dataset))
        loss.backward()
        optimizer.step()
        print('Train Epoch: %d [%03d/%03d], Loss: %.6f' % (epoch, batch_idx + 1, n_batches, loss.data[0]))

def test():
    """Evaluate the DKL model on `test_loader` and print its classification accuracy.

    No loss is reported: the previous version printed an "Average loss" that was
    never accumulated and therefore always showed 0.0000.
    """
    model.eval()
    feature_extractor.eval()
    correct = 0
    for data, target in test_loader:
        data, target = data.cuda(), target.cuda()
        # volatile=True marks the input as inference-only in the legacy Variable API.
        data, target = Variable(data, volatile=True), Variable(target)
        output = model(data)
        # The predicted class comes from the model output's own argmax().
        pred = output.argmax()
        correct += pred.eq(target.view_as(pred)).data.cpu().sum()
    n_samples = len(test_loader.dataset)
    print('Test set: Accuracy: {}/{} ({:.3f}%)'.format(
        correct, n_samples, 100. * correct / n_samples))

# Train the DKL model, evaluating on the held-out set after every epoch.
# NOTE: rebinds the global `n_epochs` that the pre-training cell set to 3.
n_epochs = 10
for epoch in range(1, n_epochs + 1):
    train(epoch)
    test()

Train Epoch: 1 [001/030], Loss: 42.384556
Train Epoch: 1 [002/030], Loss: 59.947166
Train Epoch: 1 [003/030], Loss: 46.958019
Train Epoch: 1 [004/030], Loss: 37.468971
Train Epoch: 1 [005/030], Loss: 38.515999
Train Epoch: 1 [006/030], Loss: 33.736656
Train Epoch: 1 [007/030], Loss: 61.322174
Train Epoch: 1 [008/030], Loss: 27.759411
Train Epoch: 1 [009/030], Loss: 21.316530
Train Epoch: 1 [010/030], Loss: 24.071991
Train Epoch: 1 [011/030], Loss: 21.866581
Train Epoch: 1 [012/030], Loss: 16.668083
Train Epoch: 1 [013/030], Loss: 19.503393
Train Epoch: 1 [014/030], Loss: 13.266150
Train Epoch: 1 [015/030], Loss: 12.656313
Train Epoch: 1 [016/030], Loss: 14.961938
Train Epoch: 1 [017/030], Loss: 11.993261
Train Epoch: 1 [018/030], Loss: 9.440063
Train Epoch: 1 [019/030], Loss: 12.061893
Train Epoch: 1 [020/030], Loss: 9.291729
Train Epoch: 1 [021/030], Loss: 7.797052
Train Epoch: 1 [022/030], Loss: 6.539400
Train Epoch: 1 [023/030], Loss: 7.787019
Train Epoch: 1 [024/030], Loss: 6.34382



Test set: Average loss: 0.0000, Accuracy: 8902/10000 (89.020%)
Train Epoch: 2 [001/030], Loss: 5.266603
Train Epoch: 2 [002/030], Loss: 3.753875
Train Epoch: 2 [003/030], Loss: 3.828629
Train Epoch: 2 [004/030], Loss: 3.631824
Train Epoch: 2 [005/030], Loss: 4.894492
Train Epoch: 2 [006/030], Loss: 2.812186
Train Epoch: 2 [007/030], Loss: 6.411754
Train Epoch: 2 [008/030], Loss: 2.969193
Train Epoch: 2 [009/030], Loss: 3.277704
Train Epoch: 2 [010/030], Loss: 3.264744
Train Epoch: 2 [011/030], Loss: 2.062227
Train Epoch: 2 [012/030], Loss: 4.527439
Train Epoch: 2 [013/030], Loss: 3.147879
Train Epoch: 2 [014/030], Loss: 2.526008
Train Epoch: 2 [015/030], Loss: 2.639375
Train Epoch: 2 [016/030], Loss: 2.211667
Train Epoch: 2 [017/030], Loss: 2.150765
Train Epoch: 2 [018/030], Loss: 2.678360
Train Epoch: 2 [019/030], Loss: 2.035064
Train Epoch: 2 [020/030], Loss: 1.618116
Train Epoch: 2 [021/030], Loss: 1.722281
Train Epoch: 2 [022/030], Loss: 2.044909
Train Epoch: 2 [023/030], Loss: 1.6

In [13]:
# Re-evaluate the model in its current state on the held-out set.
test()



Test set: Average loss: 0.0000, Accuracy: 9896/10000 (98.960%)
