# Chunking strategies for a Wide-ResNet

This tutorial shows how to use the hypernetwork container [HContainer](../hnets/hnet_container.py) and the class [StructuredHMLP](../hnets/structured_mlp_hnet.py) (a hypernetwork that allows *smart* chunking) in combination with a Wide-ResNet [WRN](../mnets/wide_resnet.py).

In [1]:
# Ensure code of repository is visible to this tutorial.
import sys
sys.path.insert(0, '..')

import numpy as np
import torch

from hypnettorch.hnets.structured_hmlp_examples import wrn_chunking
from hypnettorch.hnets import HContainer, StructuredHMLP
from hypnettorch.mnets import WRN

## Instantiate a WRN-28-10-B(3,3)

First, we instantiate a WRN-28-10, i.e., a WRN containing $28$ convolutional layers (plus an additional fully-connected output layer) and a widening factor of $k=10$, with no internal weights (`no_weights=True`). Thus, its weights are expected to originate externally (in our case from a hypernetwork) and to be passed to its `forward` method.

In particular, we are interested in instantiating a network that matches the one used in the study [Sacramento et al., "Economical ensembles with hypernetworks", 2020](https://arxiv.org/abs/2007.12927) (accessed August 18th, 2020). Therefore, the convolutional layers won't have bias terms (but the final fully-connected layer will).

In [2]:
net = WRN(in_shape=(32, 32, 3), num_classes=10, n=4, k=10,
          num_feature_maps=(16, 16, 32, 64), use_bias=False,
          use_fc_bias=True, no_weights=True, use_batch_norm=True,
          dropout_rate=-1)

Creating a WideResnet "WRN-28-10-B(3,3)" with 36479194 weights. The network uses batchnorm.
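
The weight count in the message above can be checked by hand. The following is a minimal sketch, not the library's code. It assumes 3x3 kernels in all regular conv layers, a 1x1 conv on the shortcut of the first block of each group, no conv biases, and a batchnorm layer (gain and shift) after every 3x3 conv. The helper `wrn_param_counts` is hypothetical and only introduced for illustration.

```python
def wrn_param_counts(n=4, k=10, in_channels=3, num_classes=10,
                     num_feature_maps=(16, 16, 32, 64)):
    """Count WRN weights under the assumptions stated in the lead-in."""
    # Channel widths: the first conv is not widened, the three groups are.
    widths = [num_feature_maps[0]] + [m * k for m in num_feature_maps[1:]]

    # Initial 3x3 conv (no bias) followed by one batchnorm layer.
    num_conv = 1
    conv = widths[0] * in_channels * 3 * 3
    bn_channels = widths[0]

    c_in = widths[0]
    for c_out in widths[1:]:
        # Assumed 1x1 shortcut conv where the channel count changes.
        conv += c_out * c_in
        num_conv += 1
        # Each group has n residual blocks with two 3x3 convs each.
        for j in range(2 * n):
            fan_in = c_in if j == 0 else c_out
            conv += c_out * fan_in * 3 * 3
            bn_channels += c_out
            num_conv += 1
        c_in = c_out

    bn = 2 * bn_channels                          # gain and shift per channel
    out = widths[-1] * num_classes + num_classes  # fully-connected layer + bias
    return {'num_conv': num_conv, 'conv': conv, 'bn': bn, 'out': out,
            'total': conv + bn + out}

counts = wrn_param_counts()
print(counts)
```

With the defaults, this sketch counts 28 conv layers and 36479194 weights in total, which agrees with the number reported by the constructor.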


## Reproduce the chunking strategy from Sacramento et al.

We first design a hypernetwork that matches the chunking strategy described in [Sacramento et al.](https://arxiv.org/abs/2007.12927). In this strategy, not all parameters are produced by a hypernetwork. Batchnorm weights are shared among conditions (in their case, each condition represents one ensemble member), while the output layer weights are condition-specific (ensemble-member-specific). The remaining weights are produced via linear hypernetworks (no bias terms in the hypernets) using a specific chunking strategy, which is described in the paper and in the docstring of function [wrn_chunking](../hnets/structured_hmlp_examples.py). To realize this mixture of shared weights (batchnorm), condition-specific weights (output weights) and hypernetwork-produced weights, we employ the special hypernetwork class [HContainer](../hnets/hnet_container.py).

We first create an instance of class [StructuredHMLP](../hnets/structured_mlp_hnet.py) for all hypernetwork-produced weights.

In [3]:
# Number of conditions (ensemble members). Arbitrarily chosen!
num_conds = 10

# Split the network's parameter shapes into shapes corresponding to batchnorm-weights,
# hypernet-produced weights and output weights.
# Here, we make use of implementation-specific knowledge, which could also be retrieved
# via the network's "param_shapes_meta" attribute, which contains meta information
# about all parameters.
bn_shapes = net.param_shapes[:2*len(net.batchnorm_layers)] # Batchnorm weight shapes
hnet_shapes = net.param_shapes[2*len(net.batchnorm_layers):-2] # Conv layer weight shapes
out_shapes = net.param_shapes[-2:] # Output layer weight shapes

# This function already defines the network chunking in the same way the paper
# specifies it.
chunk_shapes, num_per_chunk, assembly_fct = wrn_chunking(net, ignore_bn_weights=True,
                                                         ignore_out_weights=True,
                                                         gcd_chunking=False)
# Taken from table S1 in the paper.
chunk_emb_sizes = [10, 7, 14, 14, 14, 7, 7, 7]

# Important: the underlying hypernetworks should be linear, i.e., have no hidden layers:
# ``'layers': []``
# They also should not use bias vectors. Hence, weights are simply generated via a
# matrix-vector product (the hypernet, which is a weight matrix, times the chunk
# embedding input).
# Note, we make the chunk embeddings conditional and tell the hypernetwork that
# it doesn't have to expect any input other than those learned condition-specific
# embeddings.
shnet = StructuredHMLP(hnet_shapes, chunk_shapes, num_per_chunk, chunk_emb_sizes,
                       {'layers': [], 'use_bias': False}, assembly_fct,
                       cond_chunk_embs=True, uncond_in_size=0,
                       cond_in_size=0, num_cond_embs=num_conds)

Created Structured Chunked MLP Hypernet.
It manages 8 full hypernetworks internally that produce 42 chunks in total.
The internal hypernetworks have a combined output size of 2816432 compared to 36454832 weights produced by this network.
Hypernetwork with 37462680 weights and 36454832 outputs (compression ratio: 1.03).
The network consists of 37457120 unconditional weights (37457120 internally maintained) and 5560 conditional weights (5560 internally maintained).
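
To make the matrix-vector product mentioned in the code comments concrete, here is a toy, dependency-free sketch of a linear, bias-free hypernetwork that produces one chunk from a condition-specific chunk embedding. All names and sizes (`linear_hnet`, the tiny `2 x 3` chunk, `emb_size = 4`) are made up for illustration and are not part of `hypnettorch`. The real `StructuredHMLP` manages several such hypernetworks and many chunks.

```python
import random

random.seed(0)

emb_size = 4          # size of one chunk embedding (toy value)
chunk_shape = (2, 3)  # shape of the weight chunk to produce (toy value)
chunk_numel = chunk_shape[0] * chunk_shape[1]
num_conds = 3         # number of conditions / ensemble members (toy value)

# The "hypernetwork" is a single weight matrix shared by all conditions.
W = [[random.gauss(0, 1) for _ in range(emb_size)]
     for _ in range(chunk_numel)]

# One learned chunk embedding per condition.
chunk_embs = [[random.gauss(0, 1) for _ in range(emb_size)]
              for _ in range(num_conds)]

def linear_hnet(W, emb, shape):
    """Compute W @ emb (no bias, no nonlinearity) and reshape to `shape`."""
    flat = [sum(w * e for w, e in zip(row, emb)) for row in W]
    rows, cols = shape
    return [flat[r * cols:(r + 1) * cols] for r in range(rows)]

# One chunk per condition, all generated by the same matrix W.
chunks = [linear_hnet(W, emb, chunk_shape) for emb in chunk_embs]

num_uncond = chunk_numel * emb_size  # entries of the shared matrix W
num_cond = num_conds * emb_size      # entries of all condition embeddings
print(num_uncond, num_cond)
```

In this toy setting, the shared matrix plays the role of the unconditional weights and the embeddings play the role of the conditional weights. Since there is neither a bias nor a nonlinearity, a zero embedding yields an all-zero chunk.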


Now, we combine the `shnet` created above with shared batchnorm weights and condition-specific output weights in an instance of class [HContainer](../hnets/hnet_container.py), which will represent the final hypernetwork.

In [4]:
# We first have to create a simple function handle that tells the `HContainer` how to
# recombine the batchnorm-weights, hypernet-produced weights and output weights.
def simple_assembly_func(list_of_hnet_tensors, uncond_tensors, cond_tensors):
    # `list_of_hnet_tensors`: Contains outputs of all linear hypernets (conv 
    #                         layer weights).
    # `uncond_tensors`: Contains the single set of shared batchnorm weights.
    # `cond_tensors`: Contains the condition-specific output weights.
    return uncond_tensors + list_of_hnet_tensors[0] + cond_tensors

hnet = HContainer(net.param_shapes, simple_assembly_func, hnets=[shnet],
                  uncond_param_shapes=bn_shapes, cond_param_shapes=out_shapes,
                  num_cond_embs=num_conds)

Created Hypernet Container for 1 hypernet(s). Container maintains 50 plain unconditional parameter tensors. Container maintains 2 plain conditional parameter tensors for each of 10 condiditions.
Hypernetwork with 37544732 weights and 36479194 outputs (compression ratio: 1.03).
The network consists of 37475072 unconditional weights (37475072 internally maintained) and 69660 conditional weights (69660 internally maintained).
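
The numbers reported for `shnet` and for the container are consistent with each other, which the following sketch checks by plain arithmetic. The variable names are only labels for values copied from the printed messages. The interpretation of the differences (shared batchnorm weights, per-condition output weights) follows from how the container was set up above, and the `640 * 10 + 10` cross-check assumes a 640-to-10 output layer with bias.

```python
num_conds = 10

# Values reported when creating `shnet`.
shnet_uncond = 37457120
shnet_cond = 5560

# Values reported when creating the container `hnet`.
hnet_uncond = 37475072
hnet_cond = 69660
hnet_total = 37544732
num_outputs = 36479194

# The container adds the shared batchnorm weights to the unconditional weights.
bn_weights = hnet_uncond - shnet_uncond

# It adds one set of output-layer weights per condition to the conditional ones.
out_weights_per_cond = (hnet_cond - shnet_cond) // num_conds

# Chunk-embedding entries maintained per condition inside `shnet`.
emb_entries_per_cond = shnet_cond // num_conds

compression_ratio = hnet_total / num_outputs
print(bn_weights, out_weights_per_cond, emb_entries_per_cond,
      round(compression_ratio, 2))
```

The ratio is slightly above 1 because the linear hypernetworks hold more weights than they produce for a single condition. The saving comes from the fact that each additional ensemble member only adds its chunk embeddings and output weights.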


Create sample predictions for 3 different ensemble members.

In [5]:
# Batch of inputs.
batch_size = 1
x = torch.rand((batch_size, 32*32*3))

# Which ensemble members to consider?
cond_ids = [2,3,7]

# Generate weights for ensemble members defined above.
weights = hnet.forward(cond_id=cond_ids)

# Compute prediction for each ensemble member.
for i in range(len(cond_ids)):
    pred = net.forward(x, weights=weights[i])
    # Apply softmax.
    pred = torch.nn.functional.softmax(pred, dim=1).cpu().detach().numpy()
    print('Prediction of ensemble member %d: %s' \
          % (cond_ids[i], np.array2string(pred, precision=3, separator=', ')))

Prediction of ensemble member 2: [[0.099, 0.102, 0.102, 0.102, 0.099, 0.104, 0.097, 0.097, 0.097, 0.1  ]]
Prediction of ensemble member 3: [[0.1  , 0.095, 0.102, 0.1  , 0.101, 0.101, 0.102, 0.102, 0.101, 0.097]]
Prediction of ensemble member 7: [[0.101, 0.098, 0.099, 0.1  , 0.106, 0.098, 0.098, 0.1  , 0.101, 0.099]]


## Create a batch-ensemble network

Now, we consider the special case where all parameters are shared except for batchnorm weights and output weights. Thus, no hypernetworks are required. Nevertheless, we use the class [HContainer](../hnets/hnet_container.py) for convenience.

In [6]:
def simple_assembly_func2(list_of_hnet_tensors, uncond_tensors, cond_tensors):
    # `list_of_hnet_tensors`: None
    # `uncond_tensors`: Contains all conv layer weights.
    # `cond_tensors`: Contains the condition-specific batchnorm and output weights.
    return cond_tensors[:-2] + uncond_tensors + cond_tensors[-2:]

hnet2 = HContainer(net.param_shapes, simple_assembly_func2, hnets=None,
                   uncond_param_shapes=hnet_shapes,
                   cond_param_shapes=bn_shapes+out_shapes,
                   num_cond_embs=num_conds)

Created Hypernet Container for 0 hypernet(s). Container maintains 28 plain unconditional parameter tensors. Container maintains 52 plain conditional parameter tensors for each of 10 condiditions.
Hypernetwork with 36698452 weights and 36479194 outputs (compression ratio: 1.01).
The network consists of 36454832 unconditional weights (36454832 internally maintained) and 243620 conditional weights (243620 internally maintained).
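
The role of the assembly function can be illustrated without any tensors. The sketch below uses placeholder strings instead of parameter tensors (the labels `bn_*`, `conv_*`, `out_w` and `out_b` are hypothetical) and takes the tensor counts from the message above: 28 unconditional tensors and 52 conditional tensors per condition. It shows that the function restores the order batchnorm, conv, output, i.e., the order in which `net.param_shapes` was split earlier.

```python
def simple_assembly_func2(list_of_hnet_tensors, uncond_tensors, cond_tensors):
    # Same logic as in the tutorial cell: batchnorm first, output weights last.
    return cond_tensors[:-2] + uncond_tensors + cond_tensors[-2:]

# Placeholder labels standing in for the actual parameter tensors.
bn = ['bn_%d' % i for i in range(50)]      # 25 batchnorm layers, 2 tensors each
conv = ['conv_%d' % i for i in range(28)]  # 28 conv layers, no bias
out = ['out_w', 'out_b']                   # output layer weight and bias

assembled = simple_assembly_func2(None, conv, bn + out)
print(len(assembled), assembled[0], assembled[50], assembled[-1])

# Conditional weights per condition, derived from the reported total.
cond_weights_per_cond = 243620 // 10
print(cond_weights_per_cond)
```

Each ensemble member therefore owns 24362 weights (batchnorm plus output layer), while the 36454832 conv weights are shared by all members.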


In [7]:
# Batch of inputs.
batch_size = 1
x = torch.rand((batch_size, 32*32*3))

# Which ensemble members to consider?
cond_ids = [2,3,7]

# Generate weights for ensemble members defined above.
weights = hnet2.forward(cond_id=cond_ids)

# Compute prediction for each ensemble member.
for i in range(len(cond_ids)):
    pred = net.forward(x, weights=weights[i])
    # Apply softmax.
    pred = torch.nn.functional.softmax(pred, dim=1).cpu().detach().numpy()
    print('Prediction of ensemble member %d: %s' \
          % (cond_ids[i], np.array2string(pred, precision=3, separator=', ')))

Prediction of ensemble member 2: [[0.1  , 0.099, 0.1  , 0.096, 0.101, 0.103, 0.098, 0.097, 0.102, 0.103]]
Prediction of ensemble member 3: [[0.102, 0.098, 0.098, 0.102, 0.101, 0.1  , 0.098, 0.102, 0.102, 0.097]]
Prediction of ensemble member 7: [[0.1  , 0.099, 0.096, 0.096, 0.102, 0.101, 0.102, 0.101, 0.101, 0.103]]
