In [1]:
import torch
from torchvision.datasets import MNIST

In [233]:
mnist_train_data = MNIST(root='~/data/mnist', download=True, train=True)
train_x, train_y = mnist_train_data.data, mnist_train_data.targets
mnist_test_data = MNIST(root='~/data/mnist', download=True, train=False)
test_x, test_y = mnist_test_data.data, mnist_test_data.targets

In [234]:
train_y = train_y.float()
test_y = test_y.float()

In [237]:
train_x = train_x.float().view(train_x.size(0), -1)
test_x = test_x.float().view(test_x.size(0), -1)
train_x_std = train_x.std(dim=-2) + 1e-4
train_x_mean = train_x.mean(dim=-2)
train_x = (train_x - train_x_mean) / train_x_std
test_x = (test_x - train_x_mean) / train_x_std
avg_norm = train_x.norm(dim=-1).mean()

train_x = train_x / avg_norm
test_x = test_x / avg_norm

In [238]:
from torch.utils.data import TensorDataset, DataLoader
train_dataset = TensorDataset(train_x, train_y)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)

test_dataset = TensorDataset(test_x, test_y)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)

In [259]:
import gpytorch
from gpytorch.models import ApproximateGP
from gpytorch.variational import CholeskyVariationalDistribution
from gpytorch.variational import VariationalStrategy

class GPModel(ApproximateGP):
    def __init__(self, inducing_points):
        # Consider using BatchNorm2d as final layer of ConvNet too.
        
#         feature_extractor = torch.nn.Sequential(
#             torch.nn.Conv2d(1, 32, kernel_size=(3, 3), padding=1),
#             torch.nn.ReLU(),
#             torch.nn.MaxPool2d(kernel_size=(2, 2)),  # 14 x 14
#             torch.nn.Conv2d(32, 64, kernel_size=(3, 3), padding=1),
#             torch.nn.ReLU(),
#             torch.nn.MaxPool2d(kernel_size=(2, 2)),  # 64 x 7 x 7
#             torch.nn.Flatten(),
#             torch.nn.Linear(64 * 7 * 7, 10),
#             torch.nn.BatchNorm2d(10),
#         )

#         feature_extractor = torch.nn.Sequential(
#             torch.nn.Flatten(),
#             torch.nn.Linear(28*28, 64),
#             torch.nn.ReLU(),
#             torch.nn.Linear(64, 32),
#             torch.nn.ReLU(),
#             torch.nn.BatchNorm2d(32),
#         )

#         feature_extractor = torch.nn.Sequential(
#             torch.nn.Flatten(),
#             torch.nn.BatchNorm2d(768),
#         )

        # TODO: don't hardcode image shapes maybe
        inducing_points = feature_extractor(inducing_points.view(inducing_points.size(-2), 1, 28, 28))
        print(inducing_points.shape)
        # inducing points should now be m x 10
        
        variational_distribution = CholeskyVariationalDistribution(inducing_points.size(0))
        variational_strategy = VariationalStrategy(self, inducing_points, variational_distribution, learn_inducing_locations=True)
        super(GPModel, self).__init__(variational_strategy)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())
        self.feature_extractor = feature_extractor
        
        
    def forward(self, x):
        from IPython.core.debugger import set_trace
        set_trace()
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)
    
    def __call__(self, x):
        x = self.feature_extractor(x.view(x.size(-2), 1, 28, 28))
        return super().__call__(x)

num_inducing = 1024  # Can lower this if you want it to be faster
inducing_points = train_x[:num_inducing, :]
model = GPModel(inducing_points=inducing_points)

likelihood = gpytorch.likelihoods.GaussianLikelihood()

if torch.cuda.is_available():
    model = model.cuda()
    likelihood = likelihood.cuda()

torch.Size([1024, 784])


In [260]:
train_y.size(0)

60000

In [None]:
from tqdm.notebook import tqdm

num_epochs = 5

model.train()
likelihood.train()

optimizer = torch.optim.Adam([
    {'params': model.parameters()},
    {'params': likelihood.parameters()},
], lr=0.01)

# PLL is like VariationalELBO, but often gives better calibrated results (see https://arxiv.org/pdf/1910.07123.pdf)
# Can only use PLL for regression
mll = gpytorch.mlls.PredictiveLogLikelihood(likelihood, model, num_data=train_y.size(0))

epochs_iter = tqdm(range(num_epochs), desc="Epoch")
for i in epochs_iter:
    # Within each iteration, we will go over each minibatch of data
    minibatch_iter = tqdm(train_loader, desc="Minibatch", leave=False)
    for x_batch, y_batch in minibatch_iter:
        # TODO: Use pinned memory etc etc to make the next two lines fast
        x_batch = x_batch.cuda()
        y_batch = y_batch.cuda()
        optimizer.zero_grad()
        output = model(x_batch)
        loss = -mll(output, y_batch)
        minibatch_iter.set_postfix(loss=loss.item())
        loss.backward()
        optimizer.step()

Epoch:   0%|          | 0/5 [00:00<?, ?it/s]

Minibatch:   0%|          | 0/469 [00:00<?, ?it/s]

> [0;32m<ipython-input-259-af678a9d007a>[0m(49)[0;36mforward[0;34m()[0m
[0;32m     47 [0;31m        [0;32mfrom[0m [0mIPython[0m[0;34m.[0m[0mcore[0m[0;34m.[0m[0mdebugger[0m [0;32mimport[0m [0mset_trace[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     48 [0;31m        [0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 49 [0;31m        [0mmean_x[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mmean_module[0m[0;34m([0m[0mx[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     50 [0;31m        [0mcovar_x[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mcovar_module[0m[0;34m([0m[0mx[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     51 [0;31m        [0;32mreturn[0m [0mgpytorch[0m[0;34m.[0m[0mdistributions[0m[0;34m.[0m[0mMultivariateNormal[0m[0;34m([0m[0mmean_x[0m[0;34m,[0m [0mcovar_x[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m
ipdb> x
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0

In [254]:
pred_means = []
pred_vars = []
model.eval()
likelihood.eval()
with torch.no_grad():
    for x_batch_test, y_batch_test in tqdm(test_loader):
        x_batch_test = x_batch_test.cuda()
        pred = likelihood(model(x_batch_test))
        pred_means.append(pred.mean.cpu())

  0%|          | 0/79 [00:00<?, ?it/s]

In [255]:
torch.mean((torch.round(torch.cat(pred_means)) - test_y).abs())

tensor(1.6770)