In [1]:
import gpytorch
import torch
from torch.autograd import Variable

# Switch off gpytorch's module-level `use_toeplitz` flag for the rest of the session.
# NOTE(review): presumably this disables gpytorch's Toeplitz-structured fast paths —
# confirm against the gpytorch version this notebook was written for.
gpytorch.functions.use_toeplitz = False

In [2]:
import os
from torchvision import transforms, datasets
from torch.utils.data import DataLoader

# One preprocessing pipeline shared by both splits, so the train and test
# normalisation can never drift apart.
# NOTE(review): (0.1307, 0.3081) are presumably the MNIST pixel mean / std — confirm.
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# Both splits are stored under /tmp and downloaded on first use.
train_dataset = datasets.MNIST('/tmp', train=True, download=True, transform=mnist_transform)
test_dataset = datasets.MNIST('/tmp', train=False, download=True, transform=mnist_transform)

# Only the training split is shuffled; the test split keeps a fixed order.
train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False, pin_memory=True)

## Define the feature extractor for our deep kernel

In [3]:
from collections import OrderedDict
from torch import nn
from torch.nn import functional as F

class LeNetFeatureExtractor(nn.Module):
    """LeNet-style CNN that maps 1x28x28 images to non-negative feature vectors.

    Two conv -> batch-norm -> ReLU -> 2x2 max-pool stages reduce a 28x28 input
    to a 32x7x7 map, which a linear layer projects to ``n_features`` values
    (batch-normalised and passed through ReLU).

    Args:
        n_features: width of the returned feature vector. Defaults to 64, the
            value that was previously hard-coded, so existing callers are
            unaffected. Keep it equal to the ``n_features`` given to the GP
            layer that consumes these features.
    """

    def __init__(self, n_features=64):
        super(LeNetFeatureExtractor, self).__init__()
        # Exposed so downstream modules can read the width instead of repeating a literal.
        self.n_features = n_features
        # padding=2 with a 5x5 kernel preserves spatial size; only the pools shrink it.
        self.conv1 = nn.Conv2d(1, 16, kernel_size=5, padding=2)
        self.norm1 = nn.BatchNorm2d(16)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=5, padding=2)
        self.norm2 = nn.BatchNorm2d(32)
        self.fc3 = nn.Linear(32 * 7 * 7, n_features)
        self.norm3 = nn.BatchNorm1d(n_features)

    def forward(self, x):
        """Return a ``(batch, n_features)`` tensor of ReLU-activated features."""
        x = F.max_pool2d(F.relu(self.norm1(self.conv1(x))), 2)
        x = F.max_pool2d(F.relu(self.norm2(self.conv2(x))), 2)
        # Flatten the 32x7x7 feature map for the fully connected layer.
        x = x.view(-1, 32 * 7 * 7)
        x = F.relu(self.norm3(self.fc3(x)))
        return x
    
# Build the extractor with its default settings and move it to the GPU
# (`.cuda()` requires a CUDA device to be available).
feature_extractor = LeNetFeatureExtractor().cuda()

### Pretrain the feature extractor a bit

In [4]:
# Linear head used only for pretraining: 64 extractor features -> 10 class scores.
classifier = nn.Linear(64, 10).cuda()

# Optimise the extractor and the head together with momentum SGD.
params = list(feature_extractor.parameters())
params.extend(classifier.parameters())
optimizer = torch.optim.SGD(params, momentum=0.9, lr=0.1)

def pretrain(epoch):
    """Run one supervised pass over ``train_loader`` and print the mean loss.

    Updates ``feature_extractor`` and ``classifier`` through the module-level
    ``optimizer`` by minimising the NLL of the log-softmax classifier output.

    Args:
        epoch: number shown in the printed summary line.
    """
    feature_extractor.train()
    running_loss = 0.
    for images, labels in train_loader:
        images = Variable(images.cuda())
        labels = Variable(labels.cuda())
        optimizer.zero_grad()
        log_probs = F.log_softmax(classifier(feature_extractor(images)), 1)
        batch_loss = F.nll_loss(log_probs, labels)
        batch_loss.backward()
        optimizer.step()
        # The batch loss is a per-sample mean, so weight it by the batch size;
        # dividing by the dataset size below then gives a per-sample average.
        running_loss += batch_loss.data[0] * len(images)
    print('Train Epoch: %d\tLoss: %.6f' % (epoch, running_loss / len(train_dataset)))

def pretest():
    """Score the pretrained extractor + linear head on ``test_loader``.

    Prints the per-sample average NLL and the classification accuracy over
    the whole test set.
    """
    feature_extractor.eval()
    n_test = len(test_loader.dataset)
    total_loss = 0
    n_correct = 0
    for images, labels in test_loader:
        # volatile=True marks the input as inference-only.
        images = Variable(images.cuda(), volatile=True)
        labels = Variable(labels.cuda())
        log_probs = F.log_softmax(classifier(feature_extractor(images)), 1)
        # size_average=False sums the loss over the batch instead of averaging it.
        total_loss += F.nll_loss(log_probs, labels, size_average=False).data[0]
        # Index of the largest log-probability is the predicted class.
        predicted = log_probs.data.max(1, keepdim=True)[1]
        n_correct += predicted.eq(labels.data.view_as(predicted)).cpu().sum()
    total_loss /= n_test
    print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.3f}%)'.format(
        total_loss, n_correct, n_test,
        100. * n_correct / n_test))

# Pretrain the feature extractor for a few epochs, checking test accuracy
# after each one. Epochs are numbered from 1 for the printed progress lines.
n_epochs = 3
for epoch in range(1, n_epochs + 1):
    pretrain(epoch)
    pretest()

Train Epoch: 1	Loss: 0.149140
Test set: Average loss: 0.0613, Accuracy: 9805/10000 (98.050%)
Train Epoch: 2	Loss: 0.034949
Test set: Average loss: 0.0372, Accuracy: 9875/10000 (98.750%)
Train Epoch: 3	Loss: 0.025075
Test set: Average loss: 0.0288, Accuracy: 9895/10000 (98.950%)


## Define the deep kernel GP

In [5]:
class DKLModel(gpytorch.Module):
    """Deep kernel model: a feature extractor feeding ``LatentFunctions``.

    Args:
        feature_extractor: module mapping inputs to feature vectors.
        n_features: number of features, passed on to ``LatentFunctions``.
        grid_bounds: ``(lower, upper)`` range the features are rescaled into
            before being handed to the GP layer.
    """

    def __init__(self, feature_extractor, n_features=64, grid_bounds=(-10., 10.)):
        super(DKLModel, self).__init__()
        self.feature_extractor = feature_extractor
        self.latent_functions = LatentFunctions(n_features=n_features, grid_bounds=grid_bounds)

        self.grid_bounds = grid_bounds
        self.n_features = n_features

    def forward(self, x):
        """Extract features, rescale them into the grid bounds, and apply the GP layer."""
        lower = self.grid_bounds[0]
        upper = self.grid_bounds[1]
        # Scale to fit inside grid bounds
        scaled = gpytorch.utils.scale_to_bounds(self.feature_extractor(x), lower, upper)
        # The GP layer receives the features with an extra trailing dimension of size 1.
        return self.latent_functions(scaled.unsqueeze(-1))
    
    
class LatentFunctions(gpytorch.models.AdditiveGridInducingVariationalGP):
    """Grid-inducing variational GP layer with a zero mean and an RBF kernel.

    Args:
        n_features: forwarded to the base class as ``n_components``.
        grid_bounds: ``(lower, upper)`` pair; forwarded to the base class
            wrapped in a one-element list.
        grid_size: forwarded to the base class unchanged.
    """

    def __init__(self, n_features=64, grid_bounds=(-10., 10.), grid_size=128):
        super(LatentFunctions, self).__init__(
            grid_size=grid_size,
            grid_bounds=[grid_bounds],
            n_components=n_features,
            mixing_params=False,
            sum_output=False,
        )
        # RBF kernel whose log-lengthscale starts at 0.
        rbf = gpytorch.kernels.RBFKernel()
        rbf.initialize(log_lengthscale=0)
        self.cov_module = rbf
        self.grid_bounds = grid_bounds

    def forward(self, x):
        """Return the Gaussian prior at ``x``: zero mean, RBF covariance."""
        # One zero per entry along x's first dimension, allocated like x.data.
        zero_mean = Variable(x.data.new(len(x)).zero_())
        prior_covar = self.cov_module(x)
        return gpytorch.random_variables.GaussianRandomVariable(zero_mean, prior_covar)
    
    
# Wrap the pretrained feature extractor in the deep-kernel GP and move it to the GPU.
model = DKLModel(feature_extractor).cuda()
# Softmax likelihood over 10 classes, sized to the model's n_features.
# NOTE(review): presumably it mixes the n_features latent outputs into class scores —
# confirm against the gpytorch SoftmaxLikelihood implementation in use.
likelihood = gpytorch.likelihoods.SoftmaxLikelihood(n_features=model.n_features, n_classes=10).cuda()

In [6]:
# Rebuild the training loader with a larger batch (2048) for the GP training stage.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=2048,
    shuffle=True,
    pin_memory=True,
)

# Adam jointly updates the deep-kernel model and the likelihood parameters.
# This rebinds the module-level `optimizer` used during pretraining.
optimizer = torch.optim.Adam(
    list(model.parameters()) + list(likelihood.parameters()),
    lr=0.01,
)

def train(epoch):
    """Run one pass over ``train_loader``, printing the loss after every batch.

    The loss is the negated ``marginal_log_likelihood`` of the model's latent
    functions, evaluated with ``n_data`` set to the training-set size.

    Args:
        epoch: number shown in the printed progress lines.
    """
    model.train()
    likelihood.train()

    n_batches = len(train_loader)
    # Count batches from 1 so the progress line reads [001/NNN] .. [NNN/NNN].
    for step, (images, labels) in enumerate(train_loader, 1):
        images = Variable(images.cuda())
        labels = Variable(labels.cuda())
        optimizer.zero_grad()
        latent = model(images)
        neg_mll = -model.latent_functions.marginal_log_likelihood(
            likelihood, latent, labels, n_data=len(train_dataset))
        neg_mll.backward()
        optimizer.step()
        print('Train Epoch: %d [%03d/%03d], Loss: %.6f' % (epoch, step, n_batches, neg_mll.data[0]))

def test():
    """Evaluate the deep-kernel model on ``test_loader`` and print its accuracy.

    Puts ``model`` and ``likelihood`` in eval mode, takes the argmax of the
    likelihood's predictive output for each batch, and counts the matches
    against the labels.

    No loss is computed here. The previous version printed an
    "Average loss" that was initialised to 0 and never updated, so it always
    read 0.0000; that misleading field has been dropped from the report.
    """
    model.eval()
    likelihood.eval()

    correct = 0
    for data, target in test_loader:
        data, target = data.cuda(), target.cuda()
        # volatile=True marks the input as inference-only.
        data, target = Variable(data, volatile=True), Variable(target)
        output = likelihood(model(data))
        pred = output.argmax()
        correct += pred.eq(target.view_as(pred)).data.cpu().sum()
    n_test = len(test_loader.dataset)
    print('Test set: Accuracy: {}/{} ({:.3f}%)'.format(
        correct, n_test, 100. * correct / n_test))

# Train the deep-kernel model and likelihood, evaluating on the test set after
# every epoch. Epochs are numbered from 1 for the printed progress lines.
n_epochs = 10
for epoch in range(1, n_epochs + 1):
    # `%time` is an IPython line magic: it reports how long this epoch's training took.
    %time train(epoch)
    test()

Train Epoch: 1 [001/030], Loss: 42.115540
Train Epoch: 1 [002/030], Loss: 44.211239
Train Epoch: 1 [003/030], Loss: 56.054520
Train Epoch: 1 [004/030], Loss: 53.056030
Train Epoch: 1 [005/030], Loss: 46.779018
Train Epoch: 1 [006/030], Loss: 35.498089
Train Epoch: 1 [007/030], Loss: 27.772362
Train Epoch: 1 [008/030], Loss: 27.916121
Train Epoch: 1 [009/030], Loss: 25.482174
Train Epoch: 1 [010/030], Loss: 23.497782
Train Epoch: 1 [011/030], Loss: 17.294296
Train Epoch: 1 [012/030], Loss: 17.066601
Train Epoch: 1 [013/030], Loss: 14.218574
Train Epoch: 1 [014/030], Loss: 33.584183
Train Epoch: 1 [015/030], Loss: 14.517918
Train Epoch: 1 [016/030], Loss: 12.990355
Train Epoch: 1 [017/030], Loss: 11.168376
Train Epoch: 1 [018/030], Loss: 11.682179
Train Epoch: 1 [019/030], Loss: 9.016045
Train Epoch: 1 [020/030], Loss: 7.732582
Train Epoch: 1 [021/030], Loss: 8.022996
Train Epoch: 1 [022/030], Loss: 8.573727
Train Epoch: 1 [023/030], Loss: 7.009734
Train Epoch: 1 [024/030], Loss: 5.99574

  softmax = nn.functional.softmax(mixed_fs.t()).view(n_data, n_samples, self.n_classes)


Test set: Average loss: 0.0000, Accuracy: 9637/10000 (96.370%)
Train Epoch: 2 [001/030], Loss: 4.418161
Train Epoch: 2 [002/030], Loss: 4.201461
Train Epoch: 2 [003/030], Loss: 3.327814
Train Epoch: 2 [004/030], Loss: 4.318752
Train Epoch: 2 [005/030], Loss: 3.698612
Train Epoch: 2 [006/030], Loss: 3.005366
Train Epoch: 2 [007/030], Loss: 3.816376
Train Epoch: 2 [008/030], Loss: 3.133369
Train Epoch: 2 [009/030], Loss: 3.284032
Train Epoch: 2 [010/030], Loss: 2.920748
Train Epoch: 2 [011/030], Loss: 3.216592
Train Epoch: 2 [012/030], Loss: 3.209044
Train Epoch: 2 [013/030], Loss: 2.328241
Train Epoch: 2 [014/030], Loss: 2.509367
Train Epoch: 2 [015/030], Loss: 2.482173
Train Epoch: 2 [016/030], Loss: 2.645454
Train Epoch: 2 [017/030], Loss: 2.657871
Train Epoch: 2 [018/030], Loss: 2.333062
Train Epoch: 2 [019/030], Loss: 1.792445
Train Epoch: 2 [020/030], Loss: 1.824884
Train Epoch: 2 [021/030], Loss: 1.522922
Train Epoch: 2 [022/030], Loss: 1.978374
Train Epoch: 2 [023/030], Loss: 1.9