# Structured Hyper-Chunking of a ResNet-32

This tutorial shows how to use a [StructuredHMLP](../hnets/structured_mlp_hnet.py) (a hypernetwork that allows *smart* chunking of the main network's weights) in combination with a [ResNet](../mnets/resnet.py).

In [1]:
# Ensure code of repository is visible to this tutorial.
import sys
sys.path.insert(0, '..')

import numpy as np
import torch

from hypnettorch.hnets.structured_hmlp_examples import resnet_chunking
from hypnettorch.hnets import StructuredHMLP
from hypnettorch.mnets import ResNet

## Instantiate a ResNet-32

First, we instantiate a ResNet-32 (a resnet contains $6n+2$ layers $\rightarrow$ 32 layers for $n=5$) with no internal weights (`no_weights=True`). Thus, its weights are expected to originate externally (in our case from a hypernetwork) and to be passed to its `forward` method.

In [2]:
net = ResNet(use_bias=False, use_batch_norm=True, no_weights=True, n=5, num_feature_maps=[8, 16, 32, 64])

A ResNet with 32 layers and 462760 weights is created. The network uses batchnorm.
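
The depth formula can be checked with a few lines of plain Python. The sketch below assumes the breakdown used in this tutorial (one first layer, $6n$ hidden convolutional layers, one last layer); the helper name `resnet_depth` is made up for this illustration and is not part of the repository.

```python
def resnet_depth(n):
    """Layer count, assuming 1 first layer + 6n hidden layers + 1 last layer."""
    return 1 + 6 * n + 1

# n=5 is the configuration used in this tutorial.
for n in [3, 5, 7, 9]:
    print('n=%d -> ResNet-%d' % (n, resnet_depth(n)))
```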


## Decide how to *chunk* the resnet

Next, we need to decide how to chunk the weights of the main network `net` in a smart way. To do so, we use a helper function provided in the module [structured_hmlp_examples](../hnets/structured_hmlp_examples.py). This helper function treats all weights of the first layer and all weights of the last layer as one chunk each. The $6n = 30$ hidden layers are chunked in a smart way. In our case (using `gcd_chunking=True`), the function first computes the greatest common divisor (gcd) of the channel sizes `num_feature_maps=[8, 16, 32, 64]`, which is 8. The output channels of each hidden layer can then be split into chunks of this size. This way, only 4 distinct chunk shapes are needed to build all weight tensors of the hidden layers (note that the hidden convolutional layers can have 4 different input channel sizes).

In [3]:
chunk_shapes, num_per_chunk, assembly_fct = resnet_chunking(net, gcd_chunking=True)

# Print specified way of chunking.
print('The hypernetwork is expected to produce the following chunks which are then assembled to the resnet its internal weights:')
for i, s in enumerate(chunk_shapes):
    print('* %d chunks of shape %s' % (num_per_chunk[i], s))

The hypernetwork is expected to produce the following chunks which are then assembled to the resnet its internal weights:
* 1 chunks of shape [[8, 3, 3, 3], [8], [8]]
* 2 chunks of shape [[8, 8, 3, 3], [8], [8]]
* 22 chunks of shape [[8, 16, 3, 3], [8], [8]]
* 44 chunks of shape [[8, 32, 3, 3], [8], [8]]
* 72 chunks of shape [[8, 64, 3, 3], [8], [8]]
* 1 chunks of shape [[10, 64]]
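
The chunk counts of the hidden layers can be reproduced by hand. The following is a minimal sketch, not the code of `resnet_chunking`: it assumes that the 30 hidden layers form three groups of $2n$ convolutional layers with 16, 32 and 64 output channels, and that the first layer of each group takes the previous group's channels as input. This layout is an assumption, but it is consistent with the counts printed above.

```python
from collections import Counter
from functools import reduce
from math import gcd

n = 5
num_feature_maps = [8, 16, 32, 64]

# Number of output channels covered by a single chunk.
chunk_out = reduce(gcd, num_feature_maps)

# Assumed (in_channels, out_channels) of the 6n hidden conv layers.
hidden_layers = []
in_ch = num_feature_maps[0]
for out_ch in num_feature_maps[1:]:
    for _ in range(2 * n):
        hidden_layers.append((in_ch, out_ch))
        in_ch = out_ch

# A layer with `out_ch` output channels needs `out_ch // chunk_out` chunks,
# all of which share the layer's input channel size.
chunks_per_in_size = Counter()
for in_ch, out_ch in hidden_layers:
    chunks_per_in_size[in_ch] += out_ch // chunk_out

for in_ch, num in chunks_per_in_size.items():
    print('* %d chunks with %d output and %d input channels' % (num, chunk_out, in_ch))
```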


## Instantiate hypernetwork

Now we can instantiate the hypernetwork, providing all the information we collected above. Note that the function handle `assembly_fct` tells the hypernetwork how to reassemble the chunks above into the target shapes expected by the resnet (`net.hyper_shapes_learned`).

We instantiate the hypernetwork such that it doesn't expect any external input. Instead, the chunk embeddings are the only input to the internal hypernetworks that create the chunks specified above. The chunk embeddings are conditional (`cond_chunk_embs=True`): upon receiving a condition ID, the hypernetwork internally selects the corresponding set of chunk embeddings.
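
Conceptually, selecting conditional chunk embeddings amounts to a table lookup. The sketch below is only an illustration in plain Python and not how `StructuredHMLP` is implemented; the names `cond_chunk_embs` and `select_chunk_embs` are hypothetical, and random numbers stand in for the learned embeddings. The sizes are the ones used in this tutorial (10 conditions, 142 chunks in total, embedding size 8).

```python
import random

num_cond_embs = 10   # number of conditions
num_chunks = 142     # total number of chunks (sum of the counts printed above)
chunk_emb_size = 8   # size of a single chunk embedding

# Hypothetical stand-in for the learned embeddings:
# one list of `num_chunks` embedding vectors per condition.
random.seed(0)
cond_chunk_embs = [
    [[random.gauss(0.0, 1.0) for _ in range(chunk_emb_size)]
     for _ in range(num_chunks)]
    for _ in range(num_cond_embs)
]

def select_chunk_embs(cond_id):
    """Return the set of chunk embeddings that belongs to one condition."""
    return cond_chunk_embs[cond_id]

embs = select_chunk_embs(3)
print('%d chunk embeddings of size %d' % (len(embs), len(embs[0])))
```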

In [4]:
chunk_emb_size = 8
hnet = StructuredHMLP(net.hyper_shapes_learned, chunk_shapes, num_per_chunk, 
                      chunk_emb_size, {'layers': [10,10]}, assembly_fct,
                      cond_chunk_embs=True, uncond_in_size=0, cond_in_size=0,
                      verbose=True, no_uncond_weights=False, no_cond_weights=False,
                      num_cond_embs=10)

Created Structured Chunked MLP Hypernet.
It manages 6 full hypernetworks internally that produce 142 chunks in total.
The internal hypernetworks have a combined output size of 9576 compared to 462760 weights produced by this network.
Hypernetwork with 117896 weights and 462760 outputs (compression ratio: 0.25).
The network consists of 106536 unconditional weights (106536 internally maintained) and 11360 conditional weights (11360 internally maintained).
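
The numbers in this summary can be re-derived from the chunk shapes printed earlier. The sketch below copies those shapes by hand and uses only the standard library. The formula for the conditional weights (one embedding per chunk and condition) is an assumption, though it matches the reported count.

```python
from math import prod

# Chunk shapes and counts, copied from the output of `resnet_chunking` above.
chunk_shapes = [
    [[8, 3, 3, 3], [8], [8]],
    [[8, 8, 3, 3], [8], [8]],
    [[8, 16, 3, 3], [8], [8]],
    [[8, 32, 3, 3], [8], [8]],
    [[8, 64, 3, 3], [8], [8]],
    [[10, 64]],
]
num_per_chunk = [1, 2, 22, 44, 72, 1]
chunk_emb_size = 8
num_cond_embs = 10

# Number of weights contained in a single chunk of each type.
chunk_sizes = [sum(prod(s) for s in shapes) for shapes in chunk_shapes]

num_chunks = sum(num_per_chunk)
combined_out_size = sum(chunk_sizes)
num_main_weights = sum(k * size for k, size in zip(num_per_chunk, chunk_sizes))
# Assumption: one embedding of size `chunk_emb_size` per chunk and condition.
num_cond_weights = num_cond_embs * num_chunks * chunk_emb_size
# 117896 is the hypernetwork's weight count reported above.
compression = 117896 / num_main_weights

print('chunks: %d, combined output size: %d' % (num_chunks, combined_out_size))
print('main network weights: %d' % num_main_weights)
print('conditional weights: %d' % num_cond_weights)
print('compression ratio: %.2f' % compression)
```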


## Generate weights and use them in the main network

Lastly, we show how to generate two sets of weights and how to use those weights to make predictions with the main network.

In [5]:
cond_ids = [1, 3]
weights = hnet.forward(cond_id=cond_ids)

# Generate batch of random images.
# Note, due to the data handlers currently used in this repository,
# the resnet expects a batch of flattened images as input.
x = torch.rand(9, 32 * 32 * 3)
y1 = net.forward(x, weights=weights[0])
y2 = net.forward(x, weights=weights[1])

y1 = torch.nn.functional.softmax(y1, dim=1)
y2 = torch.nn.functional.softmax(y2, dim=1)

for cid, pred in zip(cond_ids, [y1, y2]):
    print('Predictions made with weights of condition "%d": %s' \
          % (cid, np.array2string(pred[0,:].detach().numpy(), precision=1, suppress_small=True)))

Predictions made with weights of condition "1": [0.  0.  0.  0.  0.7 0.  0.  0.  0.1 0.2]
Predictions made with weights of condition "3": [0.1 0.2 0.  0.  0.1 0.  0.  0.  0.1 0.5]
