In [1]:
import torch
import torch.nn.functional as F
import torchsde

from torchvision import datasets, transforms

import math
import numpy as np
import pandas as pd
from tqdm import tqdm

from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader
import functorch

import gc

import matplotlib.pyplot as plt

from cfollmer.evaluation_utils import ECE
import cfollmer.functional as functional
from cfollmer.objectives import relative_entropy_control_cost, stl_control_cost
from cfollmer.drifts import SimpleForwardNetBN, ScoreNetwork, ResNetScoreNetwork
from cfollmer.sampler_utils import FollmerSDE

In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"

In [3]:
class LeNet5(torch.nn.Module):

    def __init__(self, n_classes):
        super(LeNet5, self).__init__()
        
        self.feature_extractor = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1),
            torch.nn.Tanh(),
            torch.nn.AvgPool2d(kernel_size=2),
            torch.nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1),
            torch.nn.Tanh(),
            torch.nn.AvgPool2d(kernel_size=2),
        )

        self.classifier = torch.nn.Sequential(
            torch.nn.Linear(in_features=256, out_features=120),
            torch.nn.Tanh(),
            torch.nn.Linear(in_features=120, out_features=84),
            torch.nn.Tanh(),
            torch.nn.Linear(in_features=84, out_features=n_classes),
        )

    def forward(self, x):
        x = self.feature_extractor(x)
        x = torch.flatten(x, 1)
        logits = self.classifier(x)
        return logits

In [4]:
test_transforms = transforms.Compose([transforms.ToTensor(), transforms.RandomAffine(30)])

MNIST_train = datasets.MNIST("../data/mnist/", download=True, transform=ToTensor(), train=True)
MNIST_test = datasets.MNIST("../data/mnist/", download=True, transform=test_transforms, train=False)

N_train = len(MNIST_train)
N_test = len(MNIST_test)

In [5]:
model = LeNet5(10).to(device)
func_model, params = functorch.make_functional(model)
size_list = functional.params_to_size_tuples(params)
dim = functional.get_number_of_params(size_list)
  
sigma2 = 1

def log_prior(params):
    return -torch.sum(params**2) / (2 * sigma2)

def log_likelihood(x, y, params):
    preds = func_model(functional.get_params_from_array(params, size_list), x)
    return -F.cross_entropy(preds, y, reduction="sum")

def log_likelihood_batch(x, y, params_batch):
    func = lambda params: log_likelihood(x, y, params)
    func = functorch.vmap(func)
    return func(params_batch)

def log_posterior(x, y, params):
    return log_prior(params) + (N_train / x.shape[0]) * log_likelihood(x, y, params)

def log_posterior_batch(x, y, params_batch):
    func = lambda params: log_posterior(x, y, params)
    func = functorch.vmap(func)
    return func(params_batch)

In [6]:
def train(gamma, n_epochs, data_batch_size, param_batch_size, dt=0.05, stl=False):
#     sde = FollmerSDE(gamma, SimpleForwardNetBN(input_dim=dim, width=300)).to(device)
    sde = FollmerSDE(gamma, ResNetScoreNetwork(dim)).to(device)
    optimizer = torch.optim.Adam(sde.parameters(), lr=1e-5)
    
    dataloader_train = DataLoader(MNIST_train, shuffle=True, batch_size=data_batch_size, num_workers=2)

    losses = []

    for _ in range(n_epochs):
        epoch_losses = []
        for x, y in tqdm(iter(dataloader_train)):
            x = x.to(device)
            y = y.to(device)
            
            optimizer.zero_grad()
        
            partial_log_p = lambda params_batch: log_posterior_batch(x, y, params_batch)
            
            if stl:
                loss = stl_control_cost(sde, partial_log_p, param_batch_size=param_batch_size, dt=dt, device=device)
            else:
                loss = relative_entropy_control_cost(sde, partial_log_p, param_batch_size=param_batch_size, dt=dt, device=device)
            loss.backward()

            epoch_losses.append(loss.detach().cpu().numpy())
            optimizer.step()
            
            if stl: # double check theres no references left
                sde.drift_network_detatched.load_state_dict((sde.drift_network.state_dict()))
            
        #  Memory leaks somewhere with sdeint / param_T = param_trajectory[-1]
        gc.collect()

        losses.append(epoch_losses)
    
    losses = np.array(losses)
    
    return sde, losses

In [7]:
gamma = 0.2**2
# gamma = 1**2
n_epochs = 10
data_batch_size = 32
param_batch_size = 32

In [8]:
num_exp = 3

for i in range(num_exp):
    sde, losses = train(gamma, n_epochs, data_batch_size, param_batch_size, dt=0.05, stl=False)
    torch.save(sde.state_dict(), "weights/bnn/weights-resnet-sigma1-{:d}.pt".format(i))

  return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)
100%|██████████| 1875/1875 [04:02<00:00,  7.72it/s]
100%|██████████| 1875/1875 [04:16<00:00,  7.30it/s]
100%|██████████| 1875/1875 [04:05<00:00,  7.63it/s]
100%|██████████| 1875/1875 [04:08<00:00,  7.55it/s]
100%|██████████| 1875/1875 [04:03<00:00,  7.70it/s]
100%|██████████| 1875/1875 [04:02<00:00,  7.73it/s]
100%|██████████| 1875/1875 [04:03<00:00,  7.68it/s]
100%|██████████| 1875/1875 [04:02<00:00,  7.73it/s]
100%|██████████| 1875/1875 [04:03<00:00,  7.71it/s]
100%|██████████| 1875/1875 [04:01<00:00,  7.75it/s]
100%|██████████| 1875/1875 [04:02<00:00,  7.73it/s]
100%|██████████| 1875/1875 [04:03<00:00,  7.71it/s]
100%|██████████| 1875/1875 [04:03<00:00,  7.71it/s]
100%|██████████| 1875/1875 [04:03<00:00,  7.71it/s]
100%|██████████| 1875/1875 [04:02<00:00,  7.72it/s]
100%|██████████| 1875/1875 [04:04<00:00,  7.67it/s]
100%|██████████| 1875/1875 [04:04<00:00, 

In [9]:
# torch.save(sde.state_dict(), "weights/bnn/weights-resnet-{:d}.pt".format(i))
print(1)

1


In [10]:
def evaluate(param_samples):
    dataloader_test = DataLoader(MNIST_test, shuffle=False, batch_size=data_batch_size, num_workers=2)
    
    all_predictions = []
    all_confidences = []
    all_logps = []
    
    for x, y in tqdm(iter(dataloader_test)):
        with torch.no_grad():
            x = x.to(device)
            y = y.to(device)
            
            predict_func = lambda params : func_model(functional.get_params_from_array(params, size_list), x)
            predict_func = functorch.vmap(predict_func)

            out = F.softmax(predict_func(param_samples), dim=-1)
            out = torch.mean(out, dim=0)
            
            confidences, predictions = torch.max(out, dim=1)

            all_predictions.append(predictions)
            all_confidences.append(confidences)
            
            all_logps.append(torch.mean(log_likelihood_batch(x, y, param_samples)))
    
    all_predictions = torch.hstack(all_predictions).cpu().numpy()
    all_confidences = torch.hstack(all_confidences).cpu().numpy()
    true_labels = MNIST_test.targets.numpy()
    
    accuracy = np.mean(all_predictions == true_labels)
    ece = ECE(all_confidences, all_predictions, true_labels)
    
    logp = torch.sum(torch.stack(all_logps)) / N_test
    logp = logp.cpu().numpy()
    return accuracy, ece, logp

In [11]:
accuracies, eces, logps = [], [], []

for i in range(num_exp):
#     sde = FollmerSDE(gamma, SimpleForwardNetBN(input_dim=dim, width=300)).to(device)
    sde = FollmerSDE(gamma, ResNetScoreNetwork(dim)).to(device)
    sde.load_state_dict(torch.load("weights/bnn/weights-resnet-sigma1-{:d}.pt".format(i)))
    
    with torch.no_grad():
        param_samples = sde.sample(100, dt=0.005, device=device)
    
    accuracy, ece, logp = evaluate(param_samples)
    
    accuracies.append(accuracy)
    eces.append(ece)
    logps.append(logp)
    
accuracies = np.array(accuracies)
eces = np.array(eces)
logps = np.array(logps)

100%|██████████| 313/313 [00:02<00:00, 123.58it/s]
100%|██████████| 313/313 [00:02<00:00, 123.86it/s]
100%|██████████| 313/313 [00:02<00:00, 126.36it/s]


In [16]:
SBP_df = pd.DataFrame({"Accuracy": accuracies, "ECE": eces, "log predictive": logps})

In [17]:
sigma2

1

In [18]:
SBP_df

Unnamed: 0,Accuracy,ECE,log predictive
0,0.9484,0.004773,-0.386982
1,0.9428,0.005549,-0.436452
2,0.9524,0.006941,-0.35488


In [19]:
SBP_df.describe()

Unnamed: 0,Accuracy,ECE,log predictive
count,3.0,3.0,3.0
mean,0.947867,0.005754,-0.392771
std,0.004822,0.001099,0.041093
min,0.9428,0.004773,-0.436452
25%,0.9456,0.005161,-0.411717
50%,0.9484,0.005549,-0.386982
75%,0.9504,0.006245,-0.370931
max,0.9524,0.006941,-0.35488


SGLD from here onwards

In [36]:
@torch.enable_grad()
def gradient(x, y, params):
    params_ = params.clone().requires_grad_(True)
    loss = log_posterior(x, y, params_)
    grad, = torch.autograd.grad(loss, params_)
    return loss.detach().cpu().numpy(), grad

In [37]:
def step_size(n):
    return 1e-4 / (1 + n)**0.55

In [38]:
def sgld(n_epochs, data_batch_size):
    dataloader_train = DataLoader(MNIST_train, shuffle=True, batch_size=data_batch_size, num_workers=2)
    params = torch.cat([param.flatten() for param in model.parameters()]).detach()
    losses = []
    step = 0
    for _ in range(n_epochs):
        epoch_losses = []
        for x, y in tqdm(iter(dataloader_train)):
            x = x.to(device)
            y = y.to(device)

            eps = step_size(step)
            loss, grad = gradient(x, y, params)
            params = params + 0.5 * eps * grad #+ np.sqrt(eps) * torch.randn_like(params)
            step += 1
            epoch_losses.append(loss)
        
        losses.append(epoch_losses)
    
    param_samples = []
    
    iterator = iter(dataloader_train)
    for _ in range(100):
        x = x.to(device)
        y = y.to(device)
        
        eps = step_size(step)
        loss, grad = gradient(x, y, params)
        params = params + 0.5 * eps * grad + np.sqrt(eps) * torch.randn_like(params)
        param_samples.append(params)
        step += 1
        
    param_samples = torch.stack(param_samples)
    losses = np.array(losses)
    
    return param_samples, losses

In [39]:
accuracies, eces, logps = [], [], []

for i in range(5):
    param_samples, losses = sgld(n_epochs, data_batch_size)
    
    accuracy, ece, logp = evaluate(param_samples)
    
    accuracies.append(accuracy)
    eces.append(ece)
    logps.append(logp)
    
accuracies = np.array(accuracies)
eces = np.array(eces)
logps = np.array(logps)

100%|██████████| 1875/1875 [00:05<00:00, 341.45it/s]
100%|██████████| 1875/1875 [00:06<00:00, 306.48it/s]
100%|██████████| 1875/1875 [00:05<00:00, 315.63it/s]
100%|██████████| 1875/1875 [00:06<00:00, 306.48it/s]
100%|██████████| 1875/1875 [00:06<00:00, 309.51it/s]
100%|██████████| 1875/1875 [00:06<00:00, 307.62it/s]
100%|██████████| 1875/1875 [00:06<00:00, 304.97it/s]
100%|██████████| 1875/1875 [00:06<00:00, 302.22it/s]
100%|██████████| 1875/1875 [00:06<00:00, 304.57it/s]
100%|██████████| 1875/1875 [00:06<00:00, 301.56it/s]
  return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)
100%|██████████| 313/313 [00:02<00:00, 127.88it/s]
100%|██████████| 1875/1875 [00:05<00:00, 313.93it/s]
100%|██████████| 1875/1875 [00:06<00:00, 311.23it/s]
100%|██████████| 1875/1875 [00:06<00:00, 294.66it/s]
100%|██████████| 1875/1875 [00:05<00:00, 314.42it/s]
100%|██████████| 1875/1875 [00:06<00:00, 303.36it/s]
100%|██████████| 1875/1875 

In [40]:
SGLD_df = pd.DataFrame({"Accuracy": accuracies, "ECE": eces, "log predictive": logps})

In [41]:
SGLD_df

Unnamed: 0,Accuracy,ECE,log predictive
0,0.0892,0.093154,-2.407456
1,0.9277,0.014658,-0.2211
2,0.9332,0.017694,-0.230874
3,0.098,0.084175,-2.425903
4,0.9365,0.013428,-0.211999


In [42]:
SGLD_df.describe()

Unnamed: 0,Accuracy,ECE,log predictive
count,5.0,5.0,5.0
mean,0.59692,0.044622,-1.099466
std,0.459487,0.04036,1.202482
min,0.0892,0.013428,-2.425903
25%,0.098,0.014658,-2.407456
50%,0.9277,0.017694,-0.230874
75%,0.9332,0.084175,-0.2211
max,0.9365,0.093154,-0.211999


In [13]:
def sgd(n_epochs, data_batch_size):
    
    dataloader_train = DataLoader(MNIST_train, shuffle=True, batch_size=data_batch_size, num_workers=2)
    model = LeNet5(10).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
    losses = []

    for i in range(n_epochs):
        for x, y in tqdm(iter(dataloader_train)):

            x = x.to(device)
            y = y.to(device)

            optimizer.zero_grad()

            out = model(x)

            l = F.cross_entropy(out, y, reduction="mean")

            l.backward()

            losses.append(l.detach().cpu().numpy())

            optimizer.step()

    losses = np.array(losses)
    
    return model, losses

In [19]:
accuracies, eces, logps = [], [], []

for i in range(5):
    model, losses = sgd(n_epochs, data_batch_size)
    params = model.parameters()
    params = torch.cat([param.flatten() for param in params]).detach()
    params = params.view(1, -1)
    accuracy, ece, logp = evaluate(params)
    
    accuracies.append(accuracy)
    eces.append(ece)
    logps.append(logp)
    
accuracies = np.array(accuracies)
eces = np.array(eces)
logps = np.array(logps)

100%|███████████████████████████████████████████████████████████| 1875/1875 [00:02<00:00, 715.47it/s]
100%|███████████████████████████████████████████████████████████| 1875/1875 [00:02<00:00, 694.22it/s]
100%|███████████████████████████████████████████████████████████| 1875/1875 [00:02<00:00, 703.74it/s]
100%|███████████████████████████████████████████████████████████| 1875/1875 [00:02<00:00, 691.18it/s]
100%|███████████████████████████████████████████████████████████| 1875/1875 [00:02<00:00, 692.61it/s]
100%|███████████████████████████████████████████████████████████| 1875/1875 [00:02<00:00, 703.72it/s]
100%|███████████████████████████████████████████████████████████| 1875/1875 [00:02<00:00, 689.86it/s]
100%|███████████████████████████████████████████████████████████| 1875/1875 [00:02<00:00, 701.62it/s]
 84%|█████████████████████████████████████████████████▋         | 1580/1875 [00:02<00:00, 754.53it/s]Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f49ede

In [20]:
SGD_df = pd.DataFrame({"Accuracy": accuracies, "ECE": eces, "log predictive": logps})

In [21]:
SGD_df

Unnamed: 0,Accuracy,ECE,log predictive
0,0.9115,0.009907,-0.271943
1,0.9166,0.00878,-0.269324
2,0.9174,0.009571,-0.262576
3,0.9079,0.007665,-0.290798
4,0.9102,0.010044,-0.273837


In [22]:
SGD_df.describe()

Unnamed: 0,Accuracy,ECE,log predictive
count,5.0,5.0,5.0
mean,0.91272,0.009193,-0.273696
std,0.004124,0.000985,0.010468
min,0.9079,0.007665,-0.290798
25%,0.9102,0.00878,-0.273837
50%,0.9115,0.009571,-0.271943
75%,0.9166,0.009907,-0.269324
max,0.9174,0.010044,-0.262576
