In [1]:
import torch
import torch.nn.functional as F
import torchsde

from torchvision import datasets, transforms

import math
import numpy as np
import pandas as pd
from tqdm import tqdm

from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader
import functorch

import gc

import matplotlib.pyplot as plt

from cfollmer.evaluation_utils import ECE
import cfollmer.functional as functional
from cfollmer.objectives import relative_entropy_control_cost, stl_control_cost
from cfollmer.drifts import SimpleForwardNetBN, ScoreNetwork, ResNetScoreNetwork
from cfollmer.sampler_utils import FollmerSDE

In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"

In [3]:
class LeNet5(torch.nn.Module):

    def __init__(self, n_classes):
        super(LeNet5, self).__init__()
        
        self.feature_extractor = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1),
            torch.nn.Tanh(),
            torch.nn.AvgPool2d(kernel_size=2),
            torch.nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1),
            torch.nn.Tanh(),
            torch.nn.AvgPool2d(kernel_size=2),
        )

        self.classifier = torch.nn.Sequential(
            torch.nn.Linear(in_features=256, out_features=120),
            torch.nn.Tanh(),
            torch.nn.Linear(in_features=120, out_features=84),
            torch.nn.Tanh(),
            torch.nn.Linear(in_features=84, out_features=n_classes),
        )

    def forward(self, x):
        x = self.feature_extractor(x)
        x = torch.flatten(x, 1)
        logits = self.classifier(x)
        return logits

In [4]:
test_transforms = transforms.Compose([transforms.ToTensor(), transforms.RandomAffine(30)])

MNIST_train = datasets.MNIST("../data/mnist/", download=True, transform=ToTensor(), train=True)
MNIST_test = datasets.MNIST("../data/mnist/", download=True, transform=test_transforms, train=False)

N_train = len(MNIST_train)
N_test = len(MNIST_test)

In [5]:
model = LeNet5(10).to(device)
func_model, params = functorch.make_functional(model)
size_list = functional.params_to_size_tuples(params)
dim = functional.get_number_of_params(size_list)
  
sigma2 = 0.1

def log_prior(params):
    return -torch.sum(params**2) / (2 * sigma2)

def log_likelihood(x, y, params):
    preds = func_model(functional.get_params_from_array(params, size_list), x)
    return -F.cross_entropy(preds, y, reduction="sum")

def log_likelihood_batch(x, y, params_batch):
    func = lambda params: log_likelihood(x, y, params)
    func = functorch.vmap(func)
    return func(params_batch)

def log_posterior(x, y, params):
    return log_prior(params) + (N_train / x.shape[0]) * log_likelihood(x, y, params)

def log_posterior_batch(x, y, params_batch):
    func = lambda params: log_posterior(x, y, params)
    func = functorch.vmap(func)
    return func(params_batch)

In [6]:
def train(gamma, n_epochs, data_batch_size, param_batch_size, dt=0.05, stl=False):
    sde = FollmerSDE(gamma, SimpleForwardNetBN(input_dim=dim, width=300)).to(device)
#     sde = FollmerSDE(gamma, ResNetScoreNetwork(dim)).to(device)
    optimizer = torch.optim.Adam(sde.parameters(), lr=1e-5)
    
    dataloader_train = DataLoader(MNIST_train, shuffle=True, batch_size=data_batch_size, num_workers=2)

    losses = []

    for _ in range(n_epochs):
        epoch_losses = []
        for x, y in tqdm(iter(dataloader_train)):
            x = x.to(device)
            y = y.to(device)
            
            optimizer.zero_grad()
        
            partial_log_p = lambda params_batch: log_posterior_batch(x, y, params_batch)
            
            if stl:
                loss = stl_control_cost(sde, partial_log_p, param_batch_size=param_batch_size, dt=dt, device=device)
            else:
                loss = relative_entropy_control_cost(sde, partial_log_p, param_batch_size=param_batch_size, dt=dt, device=device)
            loss.backward()

            epoch_losses.append(loss.detach().cpu().numpy())
            optimizer.step()
            
            if stl: # double check theres no references left
                sde.drift_network_detatched.load_state_dict((sde.drift_network.state_dict()))
            
        #  Memory leaks somewhere with sdeint / param_T = param_trajectory[-1]
        gc.collect()

        losses.append(epoch_losses)
    
    losses = np.array(losses)
    
    return sde, losses

In [7]:
gamma = 0.1**2
n_epochs = 10
data_batch_size = 32
param_batch_size = 32

In [None]:
num_exp = 1

for i in range(num_exp):
    sde, losses = train(gamma, n_epochs, data_batch_size, param_batch_size, dt=0.05, stl=True)
    torch.save(sde.state_dict(), "weights/bnn/weights-stl-{:d}.pt".format(i))

  return torch.batch_norm(
  return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)
100%|██████████| 1875/1875 [10:40<00:00,  2.93it/s]
100%|██████████| 1875/1875 [10:51<00:00,  2.88it/s]
 20%|██        | 376/1875 [02:12<08:49,  2.83it/s]

In [None]:
# torch.save(sde.state_dict(), "weights/bnn/weights-resnet-{:d}.pt".format(i))
print(1)

In [None]:
def evaluate(param_samples):
    dataloader_test = DataLoader(MNIST_test, shuffle=False, batch_size=data_batch_size, num_workers=2)
    
    all_predictions = []
    all_confidences = []
    all_logps = []
    
    for x, y in tqdm(iter(dataloader_test)):
        with torch.no_grad():
            x = x.to(device)
            y = y.to(device)
            
            predict_func = lambda params : func_model(functional.get_params_from_array(params, size_list), x)
            predict_func = functorch.vmap(predict_func)

            out = F.softmax(predict_func(param_samples), dim=-1)
            out = torch.mean(out, dim=0)
            
            confidences, predictions = torch.max(out, dim=1)

            all_predictions.append(predictions)
            all_confidences.append(confidences)
            
            all_logps.append(torch.mean(log_likelihood_batch(x, y, param_samples)))
    
    all_predictions = torch.hstack(all_predictions).cpu().numpy()
    all_confidences = torch.hstack(all_confidences).cpu().numpy()
    true_labels = MNIST_test.targets.numpy()
    
    accuracy = np.mean(all_predictions == true_labels)
    ece = ECE(all_confidences, all_predictions, true_labels)
    
    logp = torch.sum(torch.stack(all_logps)) / N_test
    logp = logp.cpu().numpy()
    return accuracy, ece, logp

In [17]:
accuracies, eces, logps = [], [], []

for i in range(num_exp):
    sde = FollmerSDE(gamma, SimpleForwardNetBN(input_dim=dim, width=300)).to(device)
#     sde = FollmerSDE(gamma, ResNetScoreNetwork(dim)).to(device)
    sde.load_state_dict(torch.load("weights/bnn/weights-stl-{:d}.pt".format(i)))
    
    with torch.no_grad():
        param_samples = sde.sample(100, dt=0.005, device=device)
    
    accuracy, ece, logp = evaluate(param_samples)
    
    accuracies.append(accuracy)
    eces.append(ece)
    logps.append(logp)
    
accuracies = np.array(accuracies)
eces = np.array(eces)
logps = np.array(logps)

  return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)
100%|██████████| 313/313 [00:02<00:00, 126.40it/s]


In [18]:
SBP_df = pd.DataFrame({"Accuracy": accuracies, "ECE": eces, "log predictive": logps})

In [19]:
SBP_df

Unnamed: 0,Accuracy,ECE,log predictive
0,0.9404,0.009862,-0.299734


In [25]:
SBP_df.describe()

Unnamed: 0,Accuracy,ECE,log predictive
count,5.0,5.0,5.0
mean,0.94002,0.00563,-0.290168
std,0.003786,0.002728,0.021163
min,0.935,0.00346,-0.319881
25%,0.9382,0.004022,-0.303128
50%,0.9394,0.004327,-0.285371
75%,0.9432,0.006186,-0.271571
max,0.9443,0.010154,-0.270888


SGLD from here onwards

In [29]:
@torch.enable_grad()
def gradient(x, y, params):
    params_ = params.clone().requires_grad_(True)
    loss = log_posterior(x, y, params_)
    grad, = torch.autograd.grad(loss, params_)
    return loss.detach().cpu().numpy(), grad

In [30]:
def step_size(n):
    return 1e-4 / (1 + n)**0.55

In [31]:
def sgld(n_epochs, data_batch_size):
    dataloader_train = DataLoader(MNIST_train, shuffle=True, batch_size=data_batch_size, num_workers=2)
    params = torch.cat([param.flatten() for param in model.parameters()]).detach()
    losses = []
    step = 0
    for _ in range(n_epochs):
        epoch_losses = []
        for x, y in tqdm(iter(dataloader_train)):
            x = x.to(device)
            y = y.to(device)

            eps = step_size(step)
            loss, grad = gradient(x, y, params)
            params = params + 0.5 * eps * grad #+ np.sqrt(eps) * torch.randn_like(params)
            step += 1
            epoch_losses.append(loss)
        
        losses.append(epoch_losses)
    
    param_samples = []
    
    iterator = iter(dataloader_train)
    for _ in range(100):
        x = x.to(device)
        y = y.to(device)
        
        eps = step_size(step)
        loss, grad = gradient(x, y, params)
        params = params + 0.5 * eps * grad + np.sqrt(eps) * torch.randn_like(params)
        param_samples.append(params)
        step += 1
        
    param_samples = torch.stack(param_samples)
    losses = np.array(losses)
    
    return param_samples, losses

In [34]:
accuracies, eces, logps = [], [], []

for i in range(5):
    param_samples, losses = sgld(n_epochs, data_batch_size)
    
    accuracy, ece, logp = evaluate(param_samples)
    
    accuracies.append(accuracy)
    eces.append(ece)
    logps.append(logp)
    
accuracies = np.array(accuracies)
eces = np.array(eces)
logps = np.array(logps)

100%|███████████████████████████████████████████████████████████| 1875/1875 [00:03<00:00, 583.49it/s]
100%|███████████████████████████████████████████████████████████| 1875/1875 [00:03<00:00, 586.28it/s]
100%|███████████████████████████████████████████████████████████| 1875/1875 [00:03<00:00, 588.93it/s]
100%|███████████████████████████████████████████████████████████| 1875/1875 [00:03<00:00, 575.57it/s]
100%|███████████████████████████████████████████████████████████| 1875/1875 [00:03<00:00, 588.42it/s]
100%|███████████████████████████████████████████████████████████| 1875/1875 [00:03<00:00, 551.23it/s]
100%|███████████████████████████████████████████████████████████| 1875/1875 [00:03<00:00, 581.09it/s]
100%|███████████████████████████████████████████████████████████| 1875/1875 [00:03<00:00, 566.18it/s]
100%|███████████████████████████████████████████████████████████| 1875/1875 [00:03<00:00, 572.30it/s]
100%|███████████████████████████████████████████████████████████| 1875/1875 [00:03

In [35]:
SGLD_df = pd.DataFrame({"Accuracy": accuracies, "ECE": eces, "log predictive": logps})

In [36]:
SGLD_df

Unnamed: 0,Accuracy,ECE,log predictive
0,0.937,0.009406,-0.201103
1,0.937,0.011076,-0.198261
2,0.8647,0.019926,-0.436007
3,0.9313,0.012448,-0.22101
4,0.9261,0.019318,-0.243154


In [37]:
SGLD_df.describe()

Unnamed: 0,Accuracy,ECE,log predictive
count,5.0,5.0,5.0
mean,0.91922,0.014435,-0.259907
std,0.030814,0.004861,0.100079
min,0.8647,0.009406,-0.436007
25%,0.9261,0.011076,-0.243154
50%,0.9313,0.012448,-0.22101
75%,0.937,0.019318,-0.201103
max,0.937,0.019926,-0.198261


In [13]:
def sgd(n_epochs, data_batch_size):
    
    dataloader_train = DataLoader(MNIST_train, shuffle=True, batch_size=data_batch_size, num_workers=2)
    model = LeNet5(10).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
    losses = []

    for i in range(n_epochs):
        for x, y in tqdm(iter(dataloader_train)):

            x = x.to(device)
            y = y.to(device)

            optimizer.zero_grad()

            out = model(x)

            l = F.cross_entropy(out, y, reduction="mean")

            l.backward()

            losses.append(l.detach().cpu().numpy())

            optimizer.step()

    losses = np.array(losses)
    
    return model, losses

In [19]:
accuracies, eces, logps = [], [], []

for i in range(5):
    model, losses = sgd(n_epochs, data_batch_size)
    params = model.parameters()
    params = torch.cat([param.flatten() for param in params]).detach()
    params = params.view(1, -1)
    accuracy, ece, logp = evaluate(params)
    
    accuracies.append(accuracy)
    eces.append(ece)
    logps.append(logp)
    
accuracies = np.array(accuracies)
eces = np.array(eces)
logps = np.array(logps)

100%|███████████████████████████████████████████████████████████| 1875/1875 [00:02<00:00, 715.47it/s]
100%|███████████████████████████████████████████████████████████| 1875/1875 [00:02<00:00, 694.22it/s]
100%|███████████████████████████████████████████████████████████| 1875/1875 [00:02<00:00, 703.74it/s]
100%|███████████████████████████████████████████████████████████| 1875/1875 [00:02<00:00, 691.18it/s]
100%|███████████████████████████████████████████████████████████| 1875/1875 [00:02<00:00, 692.61it/s]
100%|███████████████████████████████████████████████████████████| 1875/1875 [00:02<00:00, 703.72it/s]
100%|███████████████████████████████████████████████████████████| 1875/1875 [00:02<00:00, 689.86it/s]
100%|███████████████████████████████████████████████████████████| 1875/1875 [00:02<00:00, 701.62it/s]
 84%|█████████████████████████████████████████████████▋         | 1580/1875 [00:02<00:00, 754.53it/s]Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f49ede

<function _MultiProcessingDataLoaderIter.__del__ at 0x7f49ededb700>
    Traceback (most recent call last):
self._shutdown_workers()  File "/home/ao464@ad.eng.cam.ac.uk/repos/ControlledFollmerDrift/env/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1328, in __del__

  File "/home/ao464@ad.eng.cam.ac.uk/repos/ControlledFollmerDrift/env/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1320, in _shutdown_workers
    self._shutdown_workers()    
if w.is_alive():  File "/home/ao464@ad.eng.cam.ac.uk/repos/ControlledFollmerDrift/env/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1320, in _shutdown_workers

  File "/usr/lib/python3.8/multiprocessing/process.py", line 160, in is_alive
        if w.is_alive():
assert self._parent_pid == os.getpid(), 'can only test a child process'  File "/usr/lib/python3.8/multiprocessing/process.py", line 160, in is_alive

    AssertionErrorassert self._parent_pid == os.getpid(), 'can only test a child proce

In [20]:
SGD_df = pd.DataFrame({"Accuracy": accuracies, "ECE": eces, "log predictive": logps})

In [21]:
SGD_df

Unnamed: 0,Accuracy,ECE,log predictive
0,0.9115,0.009907,-0.271943
1,0.9166,0.00878,-0.269324
2,0.9174,0.009571,-0.262576
3,0.9079,0.007665,-0.290798
4,0.9102,0.010044,-0.273837


In [22]:
SGD_df.describe()

Unnamed: 0,Accuracy,ECE,log predictive
count,5.0,5.0,5.0
mean,0.91272,0.009193,-0.273696
std,0.004124,0.000985,0.010468
min,0.9079,0.007665,-0.290798
25%,0.9102,0.00878,-0.273837
50%,0.9115,0.009571,-0.271943
75%,0.9166,0.009907,-0.269324
max,0.9174,0.010044,-0.262576
