In [None]:
# Change the working directory for the whole notebook. The dataset cell below uses the
# relative root 'data', so downloads land under this directory.
# NOTE(review): machine-specific absolute path; presumably it is also what makes
# `import vbll` resolve to a local checkout — confirm, and make configurable if shared.
%cd "C:\Users\mateu\Desktop\uni\phd\bayesian\vbll"

In [None]:
import os
from dataclasses import dataclass

import numpy as np
from tqdm import tqdm
from matplotlib.pyplot import cm
import matplotlib.pyplot as plt
from sklearn import metrics
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
from torch.utils.data import Dataset
from torchvision.transforms import ToTensor
from torchvision import transforms

import vbll

# Pick the compute device: CUDA if available, otherwise Apple MPS, otherwise CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"


In [None]:
# Root directory for datasets. Set the VBLL_DATA_ROOT environment variable to override
# the machine-specific default instead of editing the notebook.
# NOTE(review): the dataset loaders below use root='data' (relative to the working
# directory), not this variable — confirm which location is intended.
data_root = os.environ.get("VBLL_DATA_ROOT", r"C:\Users\mateu\Desktop\uni\phd\data")

In [None]:
def _load_split(dataset_cls, train):
    """Load one split of a torchvision image dataset as tensors, downloading it if missing.

    The root is relative ('data'), so it resolves against the current working directory.
    """
    return dataset_cls(root='data',
                       train=train,
                       transform=transforms.ToTensor(),
                       download=True)

# Every split passes download=True. The MNIST test split previously omitted it and
# relied on the train-split download having fetched its files.
mnist_train_dataset = _load_split(datasets.MNIST, train=True)
mnist_test_dataset = _load_split(datasets.MNIST, train=False)
fashion_train_dataset = _load_split(datasets.FashionMNIST, train=True)
fashion_test_dataset = _load_split(datasets.FashionMNIST, train=False)

def dict_to_data(dict):
    """Resolve a data config into (train, test, ood) datasets.

    dict: an object with string attributes TRAIN, TEST and OOD, each "mnist" or
      "fashion". (The parameter name shadows the builtin `dict`; it is kept so that
      existing keyword callers keep working.)

    Returns (train, test, ood): the training split named by TRAIN, and the test
    splits named by TEST and OOD.

    Raises ValueError if any of the three names is not recognised. Previously this
    surfaced as an opaque UnboundLocalError.
    """
    known = ("mnist", "fashion")
    # Validate before touching the dataset globals, so a bad config fails clearly.
    for field in ("TRAIN", "TEST", "OOD"):
        name = getattr(dict, field)
        if name not in known:
            raise ValueError(f"{field}={name!r} is not a known dataset; expected one of {known}")

    train_splits = {"mnist": mnist_train_dataset, "fashion": fashion_train_dataset}
    test_splits = {"mnist": mnist_test_dataset, "fashion": fashion_test_dataset}
    return train_splits[dict.TRAIN], test_splits[dict.TEST], test_splits[dict.OOD]

In [None]:
class data_cfg:
  """Which datasets to use: TRAIN (training split), TEST and OOD (test splits)."""
  TRAIN = "mnist"
  TEST = "mnist"
  OOD = "fashion"

# Derive the module-level datasets from data_cfg so that the two cannot drift apart.
# train_dataset is read later (REG_WEIGHT in a model config).
train_dataset, test_dataset, ood_dataset = dict_to_data(data_cfg)

In [None]:
def viz_performance(logs):
  """
  Plot training curves for one or more runs: losses, error rates and OOD AUROC.

  logs: dict mapping a run label to its metrics dict (e.g. the dict returned by
  `train`), with keys 'train_loss', 'val_loss', 'train_acc', 'val_acc',
  'ood_auroc' and optionally 'train_loss_empirical'.

  Draws three figures, one per metric family, each shown immediately.
  """
  palette = cm.rainbow(np.linspace(0, 1, len(logs)))
  runs = list(logs.items())

  # Figure 1: losses (solid = train, dashed = val, dotted = empirical train loss).
  for colour, (name, run) in zip(palette, runs):
    plt.plot(run['train_loss'], label=name + ' (train)', color=colour)
    plt.plot(run['val_loss'], label=name + ' (val)', linestyle='--', color=colour)
    if 'train_loss_empirical' in run:
      plt.plot(run['train_loss_empirical'], label=name + ' (train empirical)', linestyle='dotted', color=colour)
  plt.legend()
  plt.ylabel('Loss')
  plt.xlabel('Epoch')
  plt.show()

  # Figure 2: error rate = 1 - accuracy, on a logarithmic y-axis.
  for colour, (name, run) in zip(palette, runs):
    train_err = [1 - acc for acc in run['train_acc']]
    val_err = [1 - acc for acc in run['val_acc']]
    plt.plot(train_err, label=name + ' (train)', color=colour)
    plt.plot(val_err, label=name + ' (val)', linestyle='--', color=colour)
  plt.ylabel('Error rate')
  plt.xlabel('Epoch')
  plt.legend()
  plt.semilogy()
  plt.show()

  # Figure 3: OOD detection AUROC.
  for colour, (name, run) in zip(palette, runs):
    plt.plot(run['ood_auroc'], label=name, color=colour)
  plt.legend()
  plt.ylabel('OOD AUROC')
  plt.xlabel('Epoch')
  plt.show()


In [None]:
def aggregate_seeds(d, k):
  """Mean and standard deviation of metric `k` across seeds.

  d: dict mapping seed -> metrics dict; every run's d[seed][k] must have the same length.
  Returns (mean, std): numpy arrays reduced over the seed axis (population std, ddof=0).
  """
  stacked = np.vstack([run[k] for run in d.values()])
  return stacked.mean(axis=0), stacked.std(axis=0)


def _fill_band(mean, std, color):
  """Shade the band mean ± std (alpha 0.2) on the current axes at x = 0..len(mean)-1."""
  mean = np.asarray(mean)
  std = np.asarray(std)
  plt.fill_between(range(len(mean)), mean - std, mean + std, alpha=0.2, color=color)


def viz_performance_seeds(logs, plot_title=("trained on ... using ...", "ood is ...")):
  """
  Plot seed-averaged training curves: losses, error rates and OOD AUROC.

  logs: dict mapping an experiment label to a dict {seed: metrics}. Each metrics
    dict has 'train_loss', 'val_loss', 'train_acc', 'val_acc', 'ood_auroc' and
    optionally 'train_loss_empirical' (e.g. as returned by `train`). Curves show
    the mean over seeds with a ±1 std band.
  plot_title: a pair of strings. The first prefixes the loss and error figure
    titles; the second is the OOD AUROC figure title. (The default is a tuple so
    that it cannot be mutated between calls; any indexable pair works.)

  Raises ValueError for an experiment with no seeds, and TypeError when an entry
  is a bare metrics dict rather than {seed: metrics}.
  """
  # Fail early with a clear message instead of an indexing error deep in aggregate_seeds.
  for k, v in logs.items():
    if not v:
      raise ValueError(f"logs[{k!r}] has no seeds to aggregate")
    if not isinstance(next(iter(v.values())), dict):
      raise TypeError(f"logs[{k!r}] must map seed -> metrics dict, not be a metrics dict itself")

  color = cm.rainbow(np.linspace(0, 1, len(logs)))

  # Figure 1: losses (solid = train, dashed = val, dotted = empirical train loss).
  plt.figure(figsize=(10,6))
  for i, (k, v) in enumerate(logs.items()):
    mean_train_loss, std_train_loss = aggregate_seeds(v, 'train_loss')
    mean_val_loss, std_val_loss = aggregate_seeds(v, 'val_loss')
    if 'train_loss_empirical' in next(iter(v.values())):
      mean_emp_loss, std_emp_loss = aggregate_seeds(v, 'train_loss_empirical')
      plt.plot(mean_emp_loss, label=k + ' (train empirical)', linestyle='dotted', color=color[i])
      _fill_band(mean_emp_loss, std_emp_loss, color[i])
    plt.plot(mean_train_loss, label=k + ' (train)', color=color[i])
    _fill_band(mean_train_loss, std_train_loss, color[i])
    plt.plot(mean_val_loss, label=k + ' (val)', linestyle='--', color=color[i])
    _fill_band(mean_val_loss, std_val_loss, color[i])
  plt.legend()
  plt.title(plot_title[0] + " - losses")
  plt.ylabel('Loss')
  plt.xlabel('Epoch')
  plt.show()

  # Figure 2: error rate = 1 - accuracy (log y-axis); the band is ±1 std of accuracy.
  plt.figure(figsize=(10,6))
  for i, (k, v) in enumerate(logs.items()):
    mean_train_acc, std_train_acc = aggregate_seeds(v, 'train_acc')
    mean_val_acc, std_val_acc = aggregate_seeds(v, 'val_acc')
    train_err = 1 - mean_train_acc
    val_err = 1 - mean_val_acc
    plt.plot(train_err, label=k + ' (train)', color=color[i])
    _fill_band(train_err, std_train_acc, color[i])
    plt.plot(val_err, label=k + ' (val)', linestyle='--', color=color[i])
    _fill_band(val_err, std_val_acc, color[i])
  plt.title(plot_title[0] + " - errors")
  plt.ylabel('Error rate')
  plt.xlabel('Epoch')
  plt.legend()
  plt.semilogy()
  plt.show()

  # Figure 3: OOD detection AUROC.
  plt.figure(figsize=(10,6))
  for i, (k, v) in enumerate(logs.items()):
    mean_ood_auroc, std_ood_auroc = aggregate_seeds(v, 'ood_auroc')
    plt.plot(mean_ood_auroc, label=k, color=color[i])
    _fill_band(mean_ood_auroc, std_ood_auroc, color[i])
  plt.legend()
  plt.title(plot_title[1])
  plt.ylabel('OOD AUROC')
  plt.xlabel('Epoch')
  plt.show()

In [None]:
class MLP(nn.Module):
  """
  A plain MLP classifier that returns log-probabilities.

  Layout: in_layer (Linear, no activation) -> NUM_LAYERS x (Linear -> ELU) ->
  out_layer (Linear) -> log_softmax over the last dim.
  NOTE(review): with no activation after in_layer, in_layer and the first core
  Linear compose into a single linear map — confirm this is intended.

  cfg: config with IN_FEATURES, HIDDEN_FEATURES, OUT_FEATURES and NUM_LAYERS.
  """
  def __init__(self, cfg):
    super().__init__()
    width = cfg.HIDDEN_FEATURES
    depth = cfg.NUM_LAYERS
    # Submodules are created in the same order as before (in, core, out), so seeded
    # initialisation is unchanged. The ModuleDict keys are read by `train`.
    in_layer = nn.Linear(cfg.IN_FEATURES, width)
    core = nn.ModuleList([nn.Linear(width, width) for _ in range(depth)])
    out_layer = nn.Linear(width, cfg.OUT_FEATURES)
    self.params = nn.ModuleDict({'in_layer': in_layer, 'core': core, 'out_layer': out_layer})
    self.activations = nn.ModuleList([nn.ELU() for _ in range(depth)])
    self.cfg = cfg

  def forward(self, x):
    """Flatten all non-batch dims of x and return log-probabilities (batch, OUT_FEATURES)."""
    h = self.params['in_layer'](x.view(x.shape[0], -1))
    for linear, act in zip(self.params['core'], self.activations):
      h = act(linear(h))
    logits = self.params['out_layer'](h)
    return F.log_softmax(logits, dim=-1)

In [None]:
class DiscVBLLMLP(nn.Module):
  """
  MLP feature extractor with a VBLL discriminative classification head.

  The trunk matches MLP: in_layer (Linear, no activation), then NUM_LAYERS x
  (Linear -> ELU). out_layer is `vbll.DiscClassification`, and forward returns
  that layer's output unchanged. `train` reads .train_loss_fn, .val_loss_fn and
  .predictive from it; `eval_ood` reads .ood_scores.

  cfg: config with IN_FEATURES, HIDDEN_FEATURES, OUT_FEATURES, NUM_LAYERS and the
  head options REG_WEIGHT, SOFTMAX_BOUND, RETURN_EMPIRICAL,
  SOFTMAX_BOUND_EMPIRICAL, PARAM, RETURN_OOD, PRIOR_SCALE.
  """
  def __init__(self, cfg):
    super().__init__()
    width = cfg.HIDDEN_FEATURES
    depth = cfg.NUM_LAYERS
    # Same creation order as before (in, core, head) so seeded initialisation is unchanged.
    in_layer = nn.Linear(cfg.IN_FEATURES, width)
    core = nn.ModuleList([nn.Linear(width, width) for _ in range(depth)])
    head = vbll.DiscClassification(
      width,
      cfg.OUT_FEATURES,
      cfg.REG_WEIGHT,
      softmax_bound=cfg.SOFTMAX_BOUND,
      return_empirical=cfg.RETURN_EMPIRICAL,
      softmax_bound_empirical=cfg.SOFTMAX_BOUND_EMPIRICAL,
      parameterization=cfg.PARAM,
      return_ood=cfg.RETURN_OOD,
      prior_scale=cfg.PRIOR_SCALE,
    )
    self.params = nn.ModuleDict({'in_layer': in_layer, 'core': core, 'out_layer': head})
    self.activations = nn.ModuleList([nn.ELU() for _ in range(depth)])
    self.cfg = cfg

  def forward(self, x):
    """Flatten all non-batch dims of x, run the trunk, and return the VBLL head's output."""
    h = self.params['in_layer'](x.view(x.shape[0], -1))
    for linear, act in zip(self.params['core'], self.activations):
      h = act(linear(h))
    return self.params['out_layer'](h)

In [None]:
def eval_acc(preds, y):
  """Fraction of rows whose argmax over dim 1 equals the label in y, as a 0-dim float tensor."""
  hits = preds.argmax(dim=1).eq(y)
  return hits.float().mean()

def eval_ood(model, ind_dataloader, ood_dataloader, VBLL=True):
  """
  AUROC for separating in-distribution from OOD inputs using a per-example score.

  model: called as model(x) on each batch after moving x to `device`. Its train/eval
    mode is not changed here; callers are expected to have put it in eval mode.
  ind_dataloader / ood_dataloader: yield (x, y) batches; y is ignored.
  VBLL: if True, the score is `out.ood_scores`; otherwise it is the max of the output
    over the last dim (for MLP, the max log-probability).

  In-distribution examples are the positive class (label 2, OOD = 1), so higher
  scores should indicate in-distribution. Returns the area under the ROC curve.
  """
  def get_score(out):
    if VBLL:
      score = out.ood_scores
    else:
      score = torch.max(out, dim=-1)[0]
    return score.detach().cpu().numpy()

  def collect_scores(dataloader):
    # Gather per-batch scores and concatenate once (avoids quadratic re-concatenation).
    batches = [get_score(model(x.to(device))) for x, _ in dataloader]
    return np.concatenate(batches) if batches else np.empty(0)

  # Pure inference: do not build autograd graphs.
  with torch.no_grad():
    ind_vals = collect_scores(ind_dataloader)
    ood_vals = collect_scores(ood_dataloader)

  # Label each set by its own size. The OOD labels were previously sized from the
  # in-distribution scores, which only worked when both sets had equal length.
  labels = np.concatenate((np.full(len(ind_vals), 2.0), np.ones(len(ood_vals))))
  scores = np.concatenate((ind_vals, ood_vals))
  fpr, tpr, _ = metrics.roc_curve(labels, scores, pos_label=2)
  return metrics.auc(fpr, tpr)

In [None]:
def train(model, train_cfg):
  """Train a classification model (standard MLP or VBLL model) and return its metrics.

  model: if train_cfg.VBLL is False, a module returning log-probabilities (e.g. MLP).
    If VBLL is True, a module with `params.in_layer`, `params.core` and a VBLL head at
    `params.out_layer` (e.g. DiscVBLLMLP). Its output must provide `train_loss_fn`,
    `val_loss_fn`, `predictive.probs`, and `train_loss_fn_empirical` when the head
    has `return_empirical` set.
  train_cfg: config with DATA, NUM_EPOCHS, BATCH_SIZE, LR, WD, OPT, VAL_FREQ and VBLL.
    When the VBLL head returns empirical losses, N_SAMPLES and VBLL_EMPIRICAL are
    also read.

  Returns a dict of lists:
    'train_loss', 'train_acc'            one value per epoch;
    'train_loss_empirical'               one value per epoch (only when the VBLL
                                         head returns empirical losses);
    'val_loss', 'val_acc', 'ood_auroc'   one value per validation epoch
                                         (epochs where epoch % VAL_FREQ == 0).

  NOTE(review): train_cfg.CLIP_VAL exists in the configs but is not read here, so no
  gradient clipping is applied.
  """
  train_dataset, test_dataset, ood_dataset = dict_to_data(train_cfg.DATA)

  # Does the VBLL head also report the sampled ("empirical") training loss, and is that
  # loss (rather than out.train_loss_fn) the one that gets backpropagated?
  track_empirical = bool(train_cfg.VBLL and model.params.out_layer.return_empirical)
  backprop_empirical = bool(track_empirical and train_cfg.VBLL_EMPIRICAL and train_cfg.N_SAMPLES)

  if train_cfg.VBLL:
    # for VBLL models, set weight decay to zero on last layer
    param_list = [
        {'params': model.params.in_layer.parameters(), 'weight_decay': train_cfg.WD},
        {'params': model.params.core.parameters(), 'weight_decay': train_cfg.WD},
        {'params': model.params.out_layer.parameters(), 'weight_decay': 0.}
    ]
  else:
    param_list = model.parameters()
    # MLP already outputs log-probabilities. CrossEntropyLoss re-applies log_softmax,
    # which leaves log-probabilities unchanged, so this equals the NLL loss.
    loss_fn = nn.CrossEntropyLoss()

  # Per-group 'weight_decay' entries override this default.
  optimizer = train_cfg.OPT(param_list,
                            lr=train_cfg.LR,
                            weight_decay=train_cfg.WD)

  train_dataloader = DataLoader(train_dataset, batch_size=train_cfg.BATCH_SIZE, shuffle=True)
  # Evaluation loaders are not shuffled. Order does not affect AUROC, and shuffling
  # them would consume the global torch RNG between training epochs.
  val_dataloader = DataLoader(test_dataset, batch_size=train_cfg.BATCH_SIZE, shuffle=False)
  ood_dataloader = DataLoader(ood_dataset, batch_size=train_cfg.BATCH_SIZE, shuffle=False)

  output_metrics = {
      'train_loss': [],
      'val_loss': [],
      'train_acc': [],
      'val_acc': [],
      'ood_auroc': []
  }
  if track_empirical:
    # Created once, before the epoch loop. It used to be re-created every epoch, so
    # only the final epoch's value survived.
    output_metrics['train_loss_empirical'] = []

  for epoch in range(train_cfg.NUM_EPOCHS):
    model.train()
    running_loss, running_acc, running_loss_empirical = [], [], []

    for x, y in train_dataloader:
      optimizer.zero_grad()
      x = x.to(device)
      y = y.to(device)

      out = model(x)
      if train_cfg.VBLL:
        loss = out.train_loss_fn(y)
        if track_empirical:
          loss_empirical = out.train_loss_fn_empirical(y, train_cfg.N_SAMPLES)
          running_loss_empirical.append(loss_empirical.item())
        acc = eval_acc(out.predictive.probs, y).item()
      else:
        loss = loss_fn(out, y)
        acc = eval_acc(out, y).item()

      running_loss.append(loss.item())
      running_acc.append(acc)

      if backprop_empirical:
        loss_empirical.backward()
      else:
        loss.backward()
      optimizer.step()

    output_metrics['train_loss'].append(np.mean(running_loss))
    if track_empirical:
      output_metrics['train_loss_empirical'].append(np.mean(running_loss_empirical))
    output_metrics['train_acc'].append(np.mean(running_acc))

    if epoch % train_cfg.VAL_FREQ == 0:
      running_val_loss, running_val_acc = [], []

      model.eval()
      with torch.no_grad():
        for x, y in val_dataloader:
          x = x.to(device)
          y = y.to(device)

          out = model(x)
          if train_cfg.VBLL:
            loss = out.val_loss_fn(y)
            acc = eval_acc(out.predictive.probs, y).item()
          else:
            loss = loss_fn(out, y)
            acc = eval_acc(out, y).item()

          running_val_loss.append(loss.item())
          running_val_acc.append(acc)

        # OOD evaluation also runs without building autograd graphs.
        ood_auroc = eval_ood(model, val_dataloader, ood_dataloader, VBLL=train_cfg.VBLL)

      val_loss = np.mean(running_val_loss)
      val_acc = np.mean(running_val_acc)
      output_metrics['val_loss'].append(val_loss)
      output_metrics['val_acc'].append(val_acc)
      output_metrics['ood_auroc'].append(ood_auroc)
      print('Epoch: {:2d}, train loss: {:4.4f}, train acc: {:4.4f}'.format(epoch, np.mean(running_loss), np.mean(running_acc)))
      if track_empirical:
        print('Epoch: {:2d}, train loss empirical: {:4.4f}'.format(epoch, np.mean(running_loss_empirical)))
      print('Epoch: {:2d}, valid loss: {:4.4f}, valid acc: {:4.4f}'.format(epoch, val_loss, val_acc))
  return output_metrics

In [None]:
# Collected results: experiment name -> {seed: metrics dict returned by `train`}.
outputs = dict()

In [None]:
# Hyperparameters shared by the experiment configs in this notebook.
GLOBAL_LR = 1e-3               # learning rate for the point-estimate MLP
GLOBAL_LR_VBLL = 3e-3          # learning rate for the VBLL models
GLOBAL_NUM_EPOCHS = 30
GLOBAL_OPT = torch.optim.AdamW
GLOBAL_SEEDS = list(range(3))  # seeds 0, 1, 2
GLOBAL_PS = 1.0                # PRIOR_SCALE of the VBLL head

In [None]:
class train_cfg:
  """Training config for the point-estimate (non-VBLL) MLP baseline."""
  DATA = data_cfg()
  # optimisation
  OPT = GLOBAL_OPT
  LR = GLOBAL_LR
  WD = 1e-4
  BATCH_SIZE = 512
  NUM_EPOCHS = GLOBAL_NUM_EPOCHS
  # bookkeeping
  VAL_FREQ = 1  # validate every epoch
  CLIP_VAL = 1  # NOTE(review): not read by train(); no clipping is applied
  VBLL = False

class cfg:
  """Architecture config consumed by MLP."""
  IN_FEATURES = 28 * 28  # a flattened 1x28x28 MNIST / FashionMNIST image
  HIDDEN_FEATURES = 128
  NUM_LAYERS = 2         # number of (Linear -> ELU) blocks after in_layer
  OUT_FEATURES = 10      # number of classes


In [None]:
# Baseline: point-estimate MLP, one training run per seed.
mod_c, tr_c = cfg(), train_cfg()
exp_name = f'MLP-POINT opt:{tr_c.OPT.__name__} lr:{tr_c.LR}'
exp_setup = f'model:{mod_c.NUM_LAYERS}x{mod_c.HIDDEN_FEATURES} data:{tr_c.DATA.TRAIN} ood:{tr_c.DATA.OOD}'
outputs[exp_name] = seed_runs = {}

for s in GLOBAL_SEEDS:
    tr_c.SEED = s
    # Seed torch and numpy before building the model so initialisation is reproducible.
    torch.manual_seed(s)
    np.random.seed(s)
    model = MLP(mod_c).to(device)
    seed_runs[s] = train(model, tr_c)

In [None]:
class train_cfg:
  """Training config for VBLL models backpropagating out.train_loss_fn (VBLL_EMPIRICAL = False)."""
  DATA = data_cfg()
  # Epochs and optimiser follow the global settings instead of hard-coded copies
  # (30 / AdamW), like the baseline config does.
  NUM_EPOCHS = GLOBAL_NUM_EPOCHS
  BATCH_SIZE = 512
  LR = GLOBAL_LR_VBLL
  WD = 1e-4
  OPT = GLOBAL_OPT
  CLIP_VAL = 1  # NOTE(review): not read by train(); no clipping is applied
  VAL_FREQ = 1
  VBLL = True
  VBLL_EMPIRICAL = False  # the empirical loss is only logged, not backpropagated
  N_SAMPLES = 10          # sample count passed to out.train_loss_fn_empirical

class cfg:
    """Architecture and VBLL-head config consumed by DiscVBLLMLP."""
    # trunk
    IN_FEATURES = 784
    HIDDEN_FEATURES = 128
    NUM_LAYERS = 2
    OUT_FEATURES = 10
    # VBLL head
    REG_WEIGHT = 1. / len(train_dataset)  # 1 / number of training examples
    PARAM = 'diagonal'
    SOFTMAX_BOUND = "jensen"
    RETURN_EMPIRICAL = True
    SOFTMAX_BOUND_EMPIRICAL = "montecarlo"
    RETURN_OOD = True
    PRIOR_SCALE = GLOBAL_PS

In [None]:
# VBLL model with the preceding train_cfg (VBLL_EMPIRICAL = False), one run per seed.
mod_c = cfg()
tr_c = train_cfg()
exp_name = f'MLP-VBLL-PS{mod_c.PRIOR_SCALE} opt:{tr_c.OPT.__name__} lr:{tr_c.LR}'
exp_setup = f'model:{mod_c.NUM_LAYERS}x{mod_c.HIDDEN_FEATURES} data:{tr_c.DATA.TRAIN} ood:{tr_c.DATA.OOD}'
outputs[exp_name] = dict()

for s in GLOBAL_SEEDS:
    tr_c.SEED = s
    torch.manual_seed(s)
    np.random.seed(s)
    # Must be the VBLL model. A plain MLP's nn.Linear out_layer has no
    # `return_empirical` or VBLL outputs, so train() with VBLL=True would fail.
    model = DiscVBLLMLP(mod_c).to(device)
    outputs[exp_name][s] = train(model, tr_c)

In [None]:
class data_cfg:
  """Dataset selection (re-declared with the same values as the earlier data_cfg)."""
  TRAIN, TEST, OOD = "mnist", "mnist", "fashion"

class train_cfg:
  """Training config for VBLL models backpropagating the empirical loss (VBLL_EMPIRICAL = True)."""
  DATA = data_cfg()
  # Epochs and optimiser follow the global settings instead of hard-coded copies
  # (30 / AdamW), like the baseline config does.
  NUM_EPOCHS = GLOBAL_NUM_EPOCHS
  BATCH_SIZE = 512
  LR = GLOBAL_LR_VBLL
  WD = 1e-4
  OPT = GLOBAL_OPT
  CLIP_VAL = 1  # NOTE(review): not read by train(); no clipping is applied
  VAL_FREQ = 1
  VBLL = True
  VBLL_EMPIRICAL = True  # backprop out.train_loss_fn_empirical instead of out.train_loss_fn
  N_SAMPLES = 10         # sample count passed to out.train_loss_fn_empirical

class cfg:
    """Architecture and VBLL-head config consumed by DiscVBLLMLP (empirical-loss runs)."""
    # trunk
    IN_FEATURES = 784
    HIDDEN_FEATURES = 128
    NUM_LAYERS = 2
    OUT_FEATURES = 10
    # VBLL head
    # NOTE(review): uses the MNIST training set directly rather than train_dataset;
    # identical while data_cfg.TRAIN is "mnist".
    REG_WEIGHT = 1. / len(mnist_train_dataset)  # 1 / number of training examples
    PARAM = 'diagonal'
    SOFTMAX_BOUND = "jensen"
    RETURN_EMPIRICAL = True
    SOFTMAX_BOUND_EMPIRICAL = "montecarlo"
    RETURN_OOD = True
    PRIOR_SCALE = GLOBAL_PS

# Single seeded run of the empirical-loss VBLL model. The result is stored as
# {seed: metrics} like every other entry in `outputs`, because viz_performance_seeds
# aggregates over seeds and cannot handle a bare metrics dict.
single_seed = 0
torch.manual_seed(single_seed)
np.random.seed(single_seed)
disc_vbll_model = DiscVBLLMLP(cfg()).to(device)
outputs[f'DiscVBLL empirical {train_cfg.N_SAMPLES}'] = {single_seed: train(disc_vbll_model, train_cfg())}

In [None]:
# VBLL model trained through the empirical loss (preceding train_cfg), one run per seed.
mod_c, tr_c = cfg(), train_cfg()
exp_name = f'MLP-VBLL-PS{mod_c.PRIOR_SCALE}-MC{tr_c.N_SAMPLES} opt:{tr_c.OPT.__name__} lr:{tr_c.LR}'
exp_setup = f'model:{mod_c.NUM_LAYERS}x{mod_c.HIDDEN_FEATURES} data:{tr_c.DATA.TRAIN} ood:{tr_c.DATA.OOD}'
outputs[exp_name] = seed_runs = {}

for s in GLOBAL_SEEDS:
    tr_c.SEED = s
    # Seed torch and numpy before building the model so initialisation is reproducible.
    torch.manual_seed(s)
    np.random.seed(s)
    model = DiscVBLLMLP(mod_c).to(device)
    seed_runs[s] = train(model, tr_c)

In [None]:
# Same empirical-loss setup with many more samples per step (N_SAMPLES = 1000).
mod_c, tr_c = cfg(), train_cfg()
tr_c.N_SAMPLES = 1000  # instance override; the train_cfg class keeps its default
exp_name = f'MLP-VBLL-PS{mod_c.PRIOR_SCALE}-MC{tr_c.N_SAMPLES} opt:{tr_c.OPT.__name__} lr:{tr_c.LR}'
exp_setup = f'model:{mod_c.NUM_LAYERS}x{mod_c.HIDDEN_FEATURES} data:{tr_c.DATA.TRAIN} ood:{tr_c.DATA.OOD}'
outputs[exp_name] = seed_runs = {}

for s in GLOBAL_SEEDS:
    tr_c.SEED = s
    # Seed torch and numpy before building the model so initialisation is reproducible.
    torch.manual_seed(s)
    np.random.seed(s)
    model = DiscVBLLMLP(mod_c).to(device)
    seed_runs[s] = train(model, tr_c)

In [None]:
# Same empirical-loss setup with a much larger prior scale (PRIOR_SCALE = 10000).
mod_c, tr_c = cfg(), train_cfg()
mod_c.PRIOR_SCALE = 10000  # instance override; the cfg class keeps GLOBAL_PS
exp_name = f'MLP-VBLL-PS{mod_c.PRIOR_SCALE}-MC{tr_c.N_SAMPLES} opt:{tr_c.OPT.__name__} lr:{tr_c.LR}'
exp_setup = f'model:{mod_c.NUM_LAYERS}x{mod_c.HIDDEN_FEATURES} data:{tr_c.DATA.TRAIN} ood:{tr_c.DATA.OOD}'
outputs[exp_name] = seed_runs = {}

for s in GLOBAL_SEEDS:
    tr_c.SEED = s
    # Seed torch and numpy before building the model so initialisation is reproducible.
    torch.manual_seed(s)
    np.random.seed(s)
    model = DiscVBLLMLP(mod_c).to(device)
    seed_runs[s] = train(model, tr_c)

In [None]:
# Seed-averaged losses, error rates and OOD AUROC for every experiment. The optimiser
# tag comes from GLOBAL_OPT (AdamW here); it was previously hard-coded as "[sgd]".
viz_performance_seeds(
    outputs,
    [f"[{GLOBAL_OPT.__name__}] trained on {data_cfg.TRAIN}",
     f"[{GLOBAL_OPT.__name__}] OOD is {data_cfg.OOD}"],
)

In [None]:
# The same figures again, using the function's default (placeholder) titles.
viz_performance_seeds(logs=outputs)