# TyXe 

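TyXe (named after Tyche, the Ancient Greek goddess of chance) turns PyTorch neural networks into Bayesian neural networks, building on Pyro. This notebook walks through training a (Bayesian) ResNet with it.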

In [1]:
# IPython shell escape: enable the `widgetsnbextension` notebook extension.
# NOTE(review): presumably needed for widget-based output such as progress bars — confirm.
!jupyter nbextension enable --py widgetsnbextension

Enabling notebook extension jupyter-js-widgets/extension...
      - Validating: [32mOK[0m


In [2]:
# first some imports:

import contextlib
import functools
import os
from typing import List, Optional

import torch
import torch.nn as nn
import torch.utils.data as data

import torchvision

import pyro
import pyro.distributions as dist

import tyxe

In [3]:
# dataset helper functions

from utils import make_loaders_resnet_cifar


In [4]:
# ----- PARAMETERS OF BAYESIAN RESNET TRAINING ------
# Every tunable setting of the notebook lives in this cell; later cells only read
# these names (except `test_samples`, which the guide cell resets to 1 for the
# "ml" and "map" inference schemes).

inference: str = "ml"  # inference scheme, validated against `inference_options` below
architecture: str = "resnet18"  # name of a constructor in torchvision.models
dataset: str = "cifar10"  # e.g. "cifar10" or "cifar100", validated against `datasets` below
train_batch_size: int = 10
test_batch_size: int = 10
local_reparameterization: bool = False # important: variance reduction for gradients!
flipout: bool = False  # alternative variance reduction; cannot be combined with the above
num_epochs: int = 1
test_samples: int = 20  # number of predictions drawn when evaluating / predicting
max_guide_scale: float = 0.1 # to prevent underfitting
rank: int = 10  # rank of the low-rank guide ("last-layer-low-rank" only)
# Dataset root directory. The correctly spelled DATASETS_PATH is consulted first;
# the historical misspelling DATSETS_PATH is still honoured so that environments
# which exported the old name keep working.
root: str = os.environ.get("DATASETS_PATH", os.environ.get("DATSETS_PATH", "./data"))
seed: int = 42
output_dir: Optional[str] = None  # directory for saved results; None disables saving
pretrained_weights: Optional[str] = None # path to pretrained weights
scale_only: bool = False  # if True the mean-field guide is built with train_loc=False
lr: float = 0.001
milestones: Optional[List[int]] = None  # MultiStepLR milestones; None means plain Adam
gamma: float = 0.1  # MultiStepLR decay factor (only used when `milestones` is set)
mock_dataset: bool = False  # forwarded to the data-loader helper

# ----- check args: inference, architecture, dataset ------
# Inference schemes that the guide / prior cells further down know how to handle.
inference_options = [
    "ml", "map", "mean-field",
    "last-layer-mean-field", "last-layer-full", "last-layer-low-rank",
]
assert inference in inference_options, inference

# Every torchvision.models attribute that looks like a (wide) ResNet constructor,
# i.e. has the right prefix and ends in a digit (the depth).
resnets = [
    name
    for name in dir(torchvision.models)
    if name.startswith(("resnet", "wide_resnet")) and name[-1].isdigit()
]
assert architecture in resnets, architecture

datasets = ["cifar10", "cifar100", "mnist"]
assert dataset in datasets, dataset

### Initialize our Dataset and Model
* Arbitrary PyTorch datasets and models work.
* It is straightforward to integrate TyXe into any existing PyTorch workflow, since we simply start out with an arbitrary `torch.nn.Module`!

In [5]:
# ----- set up pyro & torch -----
# Seed pyro's random number generation from the configured `seed`.
pyro.set_rng_seed(seed)
# Use the GPU whenever CUDA is available, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")

# ----- set up dataset and model -----
def make_net(dataset, architecture):
    """Build a pretrained torchvision model and adapt it to CIFAR when requested.

    Args:
        dataset: dataset name. Names starting with "cifar" trigger the adaptation;
            a name ending in "10" yields 10 output classes, anything else 100.
        architecture: name of a constructor in ``torchvision.models``.

    Returns:
        The constructed network. For non-CIFAR names the model is returned exactly
        as torchvision built it.
    """
    constructor = getattr(torchvision.models, architecture)
    model = constructor(pretrained=True)
    if not dataset.startswith("cifar"):
        return model

    # Replace the stem with a 3x3, stride-1 convolution and drop the max-pool.
    stem_channels = model.conv1.out_channels
    model.conv1 = nn.Conv2d(3, stem_channels, kernel_size=3, stride=1, padding=1, bias=False)
    model.maxpool = nn.Identity()

    # Fresh classification head sized for the dataset.
    n_outputs = 10 if dataset.endswith("10") else 100
    model.fc = nn.Linear(model.fc.in_features, n_outputs)
    return model

# Build the train / test / out-of-distribution loaders with the helper that the
# notebook imports from `utils` (`make_loaders_resnet_cifar`). Calling the bare
# name `make_loaders` raises a NameError on a fresh kernel because nothing in the
# notebook defines it.
# NOTE(review): assumes the helper accepts these six positional arguments — confirm against utils.
train_loader, test_loader, ood_loader = make_loaders_resnet_cifar(
    dataset, root, train_batch_size, test_batch_size, use_cuda, mock_dataset
)

# Instantiate the network on the selected device and, if a checkpoint path was
# configured, load its state dict (mapped onto the same device).
net: torch.nn.Module = make_net(dataset, architecture).to(device)
if pretrained_weights is not None:
    sd = torch.load(pretrained_weights, map_location=device)
    net.load_state_dict(sd)

Files already downloaded and verified
Files already downloaded and verified
Using downloaded and verified file: ./data/test_32x32.mat


### Set up ResNet to be Bayesian using TyXe

To set up a Bayesian Model in TyXe we need to select one of a few options for each of:

1. **Likelihood** - The Likelihood of our Training Data $p(D|\theta)$.
2. **Guide** - The Variational Distribution $q(\theta|D)$. 
3. **Prior** - Our Prior Belief $p(\theta)$ about our parameters. In this notebook the prior is an i.i.d. standard normal; the pretrained weights' values are used to initialize the location of the guide.

Our choice for each of these three components does two things:
* Determine the family of distributions the particular object may belong to.
* Initialize the distribution's parameters.

Let's go through the meaning and options for each of the three components, one by one.

#### 1. Setting up the Likelihood

The Likelihood of our Training Data $p(D|\theta)$. The likelihood object must be told the size of the training set, i.e. the number of training samples.

In [None]:
# Likelihood of the training data, constructed from the number of samples in the
# training sampler.
# NOTE(review): presumably the dataset size is used to scale mini-batch terms — confirm in the TyXe docs.
# Likelihood classes to choose from:
# Bernoulli, Categorical, HeteroskedasticGaussian, HomoskedasticGaussian
likelihood = tyxe.likelihoods.Categorical(len(train_loader.sampler))

# uncomment for documentation:
# tyxe.likelihoods.HeteroskedasticGaussian?

#### 2. Setting up our Guide

The choice of the variational distribution is where most of our flexibility lies. 
It is recommended and most convenient to use one of the `Autoguide`s, either from pyro or TyXe.
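The TyXe `BNN`s expect an only partially initialized `guide` object, which is why the guides below are wrapped in `functools.partial`.
Let's go through some of the options:
1. Maximum likelihood: it is straightforward to do maximum likelihood by letting the guide be `None`.
2. Maximum a posteriori (MAP) inference: use pyro's `AutoDelta` guide.
3. Mean-field variational inference: use `tyxe.guides.AutoNormal`.
4. Last-layer inference: only the final fully connected layer is treated as Bayesian, with a mean-field, full or low-rank multivariate normal guide.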

In [None]:
# Select the guide (variational distribution) that matches `inference`.
# Guides are wrapped in `functools.partial`, i.e. handed over only partially
# initialized. Wherever a guide is used, its location is initialized through
# `tyxe.guides.PretrainedInitializer.from_net(net)`.
if inference == "ml":
    # do maximum likelihood: no guide, and a single prediction at test time
    test_samples = 1
    guide = None
elif inference == "map":
    # maximum a posteriori inference: a single prediction at test time
    test_samples = 1
    guide = functools.partial(
        pyro.infer.autoguide.AutoDelta,
        init_loc_fn=tyxe.guides.PretrainedInitializer.from_net(net),
    )
elif inference == "mean-field":
    guide = functools.partial(
        tyxe.guides.AutoNormal,
        init_loc_fn=tyxe.guides.PretrainedInitializer.from_net(net),
        init_scale=1e-4,
        max_guide_scale=max_guide_scale,  # prevent underfitting
        train_loc=not scale_only,  # train mean parameter?
    )
elif inference.startswith("last-layer"):
    if pretrained_weights is None:
        raise ValueError("Asked to do last-layer inference, but no pre-trained weights were provided.")

    # Guide class and extra keyword arguments for each last-layer variant.
    last_layer_guides = {
        "last-layer-mean-field": (tyxe.guides.AutoNormal, {}),
        "last-layer-full": (pyro.infer.autoguide.AutoMultivariateNormal, {}),
        "last-layer-low-rank": (pyro.infer.autoguide.AutoLowRankMultivariateNormal, {"rank": rank}),
    }
    # Reject an unknown variant *before* touching the network, so that a typo
    # does not leave `net` half-converted.
    if inference not in last_layer_guides:
        raise RuntimeError(f"Invalid option for inference: '{inference}'")

    # turning parameters except for last layer in buffers to avoid training them
    # this might be avoidable via poutine.block
    for module in net.modules():
        if module is net.fc:
            continue
        # list(...) takes a snapshot, because the loop body removes the parameters
        for param_name, param in list(module.named_parameters(recurse=False)):
            delattr(module, param_name)
            module.register_buffer(param_name, param.detach().data)

    # The initializer is created only after the conversion above, as before.
    guide_cls, guide_kwargs = last_layer_guides[inference]
    guide = functools.partial(
        guide_cls,
        init_loc_fn=tyxe.guides.PretrainedInitializer.from_net(net),
        init_scale=1e-4,
        **guide_kwargs,
    )
else:
    raise RuntimeError(f"Invalid option for inference: '{inference}'")

print("Automatic Guide Families:")
print([g for g in dir(pyro.infer.autoguide) if g.startswith("Auto")])

# uncomment for documentation:
# pyro.infer.autoguide.AutoDelta?

#### 3. Setting up our Prior

In [None]:
# It is standard practice not to be Bayesian about batch-norm modules, so they
# are hidden from the prior by default.
prior_kwargs = dict(
    expose_all=False,  # do not treat all nn.Modules with pyro
    hide_module_types=(nn.BatchNorm2d,),  # specifically, ignore batchnorms
)

# The inference scheme decides which parts of the network receive a prior.
if inference == "ml":
    # maximum likelihood: nothing is Bayesian
    prior_kwargs["hide_all"] = True
elif inference.startswith("last-layer"):
    # only the final, fully connected layer is Bayesian
    prior_kwargs.pop("hide_module_types")
    prior_kwargs["expose_modules"] = [net.fc]

# Normal(0, 1) prior, created on the same device as the network.
prior = tyxe.priors.IIDPrior(
    dist.Normal(torch.zeros(1, device=device), torch.ones(1, device=device)),
    **prior_kwargs,
)

# IIDPrior, DictPrior, LambdaPrior, LayerwiseNormalPrior
print("Available Prior Distributions:")
print([name for name in dir(tyxe.priors) if name[0].upper() == name[0] and "_" not in name])

# uncomment for documentation:
# tyxe.priors.IIDPrior?

In [None]:
# Finally set up our VariationalBNN!
# It combines the four pieces built above: the network, the prior, the
# likelihood and the guide (a functools.partial, or None for maximum likelihood).
bnn = tyxe.VariationalBNN(
    net, prior, likelihood, guide
)

# uncomment for documentation:
# bnn?

In [13]:
# gradient variance reduction techniques:
# the two options are mutually exclusive, so reject the combination up front
if local_reparameterization and flipout:
    raise RuntimeError("Can't use both local reparameterization and flipout, pick one.")

if local_reparameterization:
    train_context = tyxe.poutine.local_reparameterization
elif flipout:
    train_context = tyxe.poutine.flipout
else:
    # neither technique requested: train inside a do-nothing context manager
    train_context = contextlib.nullcontext

In [14]:
# pyro-specific: optimizer must come from pyro.optim
if milestones is not None:
    # step-wise learning-rate decay: wrap torch's Adam in pyro's MultiStepLR
    optimizer = torch.optim.Adam
    optim = pyro.optim.MultiStepLR({
        "optimizer": optimizer,
        "optim_args": {"lr": lr},
        "milestones": milestones,
        "gamma": gamma,
    })
else:
    # constant learning rate
    optim = pyro.optim.Adam({"lr": lr})

print("All typical optimizers & schedulers are supported by pyro.optim:")
print([name for name in dir(pyro.optim) if "_" not in name and name[0] == name[0].upper()])
    
# tyXe-specific: evaluation and logging may be done using a callback function, passed to the bnn.fit() method
# callback is called after every epoch with the following arguments:
def callback(
        b: tyxe.VariationalBNN, # bnn
        i: int, # epoch number
        avg_elbo: float # mean elbo this epoch
    ):
    """Evaluate the BNN on the test set and print ELBO, test error and log-likelihood.

    The model is switched to evaluation mode while the test set is processed, so
    that layers such as batch-norm neither use per-batch statistics nor update
    their running statistics on test data. The mode the model was in before the
    call is restored afterwards, even if evaluation raises.

    Args:
        b: the BNN being trained.
        i: epoch number (unused here).
        avg_elbo: mean ELBO of the epoch that just finished.

    Returns:
        None.
    """
    num_test = len(test_loader.sampler)
    avg_err, avg_ll = 0., 0.

    was_training = b.training
    b.eval()
    try:
        for x, y in test_loader:
            err, ll = b.evaluate(x.to(device), y.to(device), num_predictions=test_samples)
            # per-batch values are normalised by the full test-set size, so the
            # running sums end up as dataset-wide averages
            avg_err += err / num_test
            avg_ll += ll / num_test
    finally:
        b.train(was_training)

    print(f"ELBO={avg_elbo}; test error={100 * avg_err:.2f}%; LL={avg_ll:.4f}")

['ASGD', 'Adadelta', 'Adagrad', 'AdagradRMSProp', 'Adam', 'AdamW', 'Adamax', 'ChainedScheduler', 'ClippedAdam', 'ConstantLR', 'CosineAnnealingLR', 'CosineAnnealingWarmRestarts', 'CyclicLR', 'DCTAdam', 'ExponentialLR', 'LambdaLR', 'LinearLR', 'MultiStepLR', 'MultiplicativeLR', 'NAdam', 'OneCycleLR', 'PyroLRScheduler', 'PyroOptim', 'RAdam', 'RMSprop', 'ReduceLROnPlateau', 'Rprop', 'SGD', 'SequentialLR', 'SparseAdam', 'StepLR']


In [None]:

# ------ TRAIN THE MODEL ------
# Fit for `num_epochs` epochs inside the selected gradient-variance-reduction
# context (`contextlib.nullcontext` when neither option is enabled). `callback`
# is handed to `fit` for per-epoch evaluation and logging.
with train_context():
    bnn.fit(train_loader, optim, num_epochs, callback=callback, device=device)


In [None]:
# optionally store results by simply using torch.save:
if output_dir is not None:
    # Create the target directory first; otherwise the very first save fails
    # when `output_dir` does not exist yet.
    os.makedirs(output_dir, exist_ok=True)

    pyro.get_param_store().save(os.path.join(output_dir, "param_store.pt"))
    torch.save(bnn.state_dict(), os.path.join(output_dir, "state_dict.pt"))

    # Predict in evaluation mode so that batch-norm layers neither use per-batch
    # statistics nor update their running statistics, then restore the previous mode.
    was_training = bnn.training
    bnn.eval()
    try:
        test_predictions = torch.cat([bnn.predict(x.to(device), num_predictions=test_samples)
                                      for x, _ in test_loader])
        ood_predictions = torch.cat([bnn.predict(x.to(device), num_predictions=test_samples)
                                     for x, _ in ood_loader])
    finally:
        bnn.train(was_training)

    torch.save(test_predictions.detach().cpu(), os.path.join(output_dir, "test_predictions.pt"))
    torch.save(ood_predictions.detach().cpu(), os.path.join(output_dir, "ood_predictions.pt"))
