<img src="images/hypernet.svg" width=50% align="right">
# HyperNetworks
Author: Jin Yeom (jinyeom@utexas.edu)  
Original authors: David Ha, Andrew Dai, Quoc V. Le

## Contents
- [Configuration](#Configuration)
- [Utility functions](#Utility-functions)
- [Static HyperNet](#Static-HyperNet)
- [Dynamic HyperNet](#Dynamic-HyperNet)
- [References](#References)

In [1]:
import torch
from torch import nn
from torch import optim
from torch.nn import functional as F
from torchvision.utils import make_grid
from torchvision import transforms
from torchvision import datasets
import numpy as np
from matplotlib import pyplot as plt
from tqdm.notebook import tqdm

In [2]:
%matplotlib notebook

## Configuration

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device:", device)

device: cpu


## Utility functions

In [4]:
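def imshow(img, title):
    # move the tensor to the CPU first so this also works for CUDA tensors
    npimg = img.detach().cpu().numpy()
    plt.figure()
    plt.title(title)
    # matplotlib expects (H, W, C), while PyTorch images are (C, H, W)
    plt.imshow(np.transpose(npimg, (1, 2, 0)))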

In [5]:
def num_params(model, requires_grad=False):
    if requires_grad:
        return sum(p.numel() for p in model.parameters() if p.requires_grad)
    return sum(p.numel() for p in model.parameters())

## Static HyperNet

In [6]:
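class HyperNet(nn.Module):
    def __init__(self, z_dim, in_channels, out_channels, kernel_size):
        super(HyperNet, self).__init__()
        # keep the layer dimensions on the module so that forward() does not
        # depend on global variables
        self.z_dim = z_dim
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.fc1 = nn.Linear(z_dim, in_channels * z_dim)
        self.fc2 = nn.Linear(z_dim, out_channels * kernel_size * kernel_size)

    def forward(self, z):
        # project the embedding z to one z_dim-sized vector per input channel
        z = self.fc1(z)
        z = z.view(self.in_channels, self.z_dim)
        # project each of those vectors to out_channels * k * k kernel weights
        z = self.fc2(z)
        z = z.view(self.out_channels, self.in_channels,
                   self.kernel_size, self.kernel_size)
        return z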
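
The reshaping inside `HyperNet.forward` is easier to follow as a list of shapes. The sketch below traces them in plain Python for the configuration used in the next cell (`z_dim=4`, 16 input and output channels, a 7x7 kernel). The `trace` list and `kernel_shape` are illustrative names, not part of the notebook, and no tensors are involved.

```python
from math import prod

# Same sizes as the notebook's example configuration.
z_dim, in_channels, out_channels, kernel_size = 4, 16, 16, 7

# (step, shape) pairs for a single embedding z passing through the hypernetwork.
trace = [
    ("z", (z_dim,)),
    ("fc1", (in_channels * z_dim,)),
    ("view", (in_channels, z_dim)),
    ("fc2", (in_channels, out_channels * kernel_size * kernel_size)),
    ("view", (out_channels, in_channels, kernel_size, kernel_size)),
]

for step, shape in trace:
    print(f"{step:<5} {shape} -> {prod(shape)} elements")

# A view never changes the number of elements, so each reshape must preserve it.
assert prod(trace[1][1]) == prod(trace[2][1])
assert prod(trace[3][1]) == prod(trace[4][1])

kernel_shape = trace[-1][1]
```

Because `fc2` is applied to the last dimension only, the same linear layer is shared by all 16 rows produced by `fc1`, which is where the parameter savings come from.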

In [7]:
z_dim = 4
in_channels = 16
out_channels = 16
kernel_size = 7

z = torch.randn(z_dim, requires_grad=True)
hn = HyperNet(z_dim, in_channels, out_channels, kernel_size)
W = hn(z)

n_params = num_params(hn, requires_grad=True)
n_weights = in_channels * out_channels * kernel_size * kernel_size
print(f"number of trainable parameters: {n_params}")
print(f"number of generated weights: {n_weights}")

W_vis = W.view(in_channels * out_channels, 1, kernel_size, kernel_size)
imshow(make_grid(W_vis, nrow=in_channels, normalize=True), 
       "HyperNet output weights before training (normalized)")

number of trainable parameters: 4240
number of generated weights: 12544
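
Both counts printed above can be checked by hand. The sketch below redoes the arithmetic in plain Python; `linear_params` is a hypothetical helper introduced here for illustration, and it assumes each `nn.Linear` layer has a weight matrix plus a bias vector (the default).

```python
def linear_params(in_features, out_features):
    """Parameter count of a fully connected layer with bias."""
    return in_features * out_features + out_features

z_dim, in_channels, out_channels, kernel_size = 4, 16, 16, 7

# fc1 maps z_dim -> in_channels * z_dim
fc1_params = linear_params(z_dim, in_channels * z_dim)
# fc2 maps z_dim -> out_channels * kernel_size * kernel_size
fc2_params = linear_params(z_dim, out_channels * kernel_size * kernel_size)

hypernet_params = fc1_params + fc2_params
generated_weights = in_channels * out_channels * kernel_size * kernel_size

print("fc1:", fc1_params)
print("fc2:", fc2_params)
print("hypernetwork total:", hypernet_params)
print("generated kernel weights:", generated_weights)
```

With these sizes the hypernetwork uses roughly a third as many trainable parameters as the kernel it generates.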



In [18]:
class ConvNet(nn.Module):
    def __init__(self, in_channels=1, out_features=10, use_hn=False):
        super(ConvNet, self).__init__()
        self._use_hn = use_hn
        self.conv1 = nn.Conv2d(in_channels, 16, 7, stride=1, padding=2)
        if self._use_hn:
            self.hn = HyperNet(4, 16, 16, 7)
            # register z and the bias as parameters so that the optimizer
            # updates them and .to(device) moves them with the model
            self.z = nn.Parameter(torch.randn(4))
            self.conv2_b = nn.Parameter(torch.zeros(16))
        else:
            self.conv2 = nn.Conv2d(16, 16, 7, padding=2)
        self.fc3 = nn.Linear(784, out_features)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = F.max_pool2d(x, (2, 2), stride=(2, 2), padding=1)
        if self._use_hn:
            # regenerate the kernel on every forward pass so that gradients
            # reach both the hypernetwork and z
            conv2_W = self.hn(self.z)
            x = F.conv2d(x, conv2_W, self.conv2_b, padding=2)
        else:
            x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, (2, 2), stride=(2, 2), padding=1)
        x = x.view(x.size(0), -1)
        return F.log_softmax(self.fc3(x), dim=1)

In [10]:
model1 = ConvNet()
model2 = ConvNet(use_hn=True)

print("Trainable parameters without HN: {}".format(num_params(model1, requires_grad=True)))
print("Trainable parameters with HN: {}".format(num_params(model2, requires_grad=True)))

Trainable parameters without HN: 21210
Trainable parameters with HN: 12910
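
The 784 input features of `fc3` follow from the convolution and pooling settings in `ConvNet`. As a rough check, the sketch below applies the standard output-size formula to each layer, assuming 28x28 MNIST inputs and PyTorch's default floor rounding; `out_size` is a hypothetical helper, not part of the notebook.

```python
def out_size(size, kernel, stride=1, padding=0):
    """Spatial output size of a convolution or pooling layer (floor rounding)."""
    return (size + 2 * padding - kernel) // stride + 1

size = 28                                          # MNIST images are 28x28
size = out_size(size, kernel=7, padding=2)         # conv1
after_conv1 = size
size = out_size(size, kernel=2, stride=2, padding=1)  # first max pool
after_pool1 = size
size = out_size(size, kernel=7, padding=2)         # conv2 (or generated kernel)
after_conv2 = size
size = out_size(size, kernel=2, stride=2, padding=1)  # second max pool
after_pool2 = size

flat_features = 16 * after_pool2 * after_pool2     # 16 output channels
print(after_conv1, after_pool1, after_conv2, after_pool2, flat_features)
```

If the kernel size or padding changes, the first argument of `fc3` has to be recomputed the same way.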


In [11]:
# MNIST images have a single channel, so Normalize takes one mean and one std
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5,), (0.5,))])

mnist_train = datasets.MNIST("datasets/mnist",
                             train=True,
                             transform=transform,
                             download=True)
train_loader = torch.utils.data.DataLoader(mnist_train,
                                           batch_size=64,
                                           shuffle=True,
                                           num_workers=2)

mnist_test = datasets.MNIST("datasets/mnist",
                            train=False,
                            transform=transform,
                            download=True)
test_loader = torch.utils.data.DataLoader(mnist_test,
                                          batch_size=64,
                                          shuffle=False,
                                          num_workers=2)
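
`ToTensor` scales pixel values to the range [0, 1], and `Normalize` with mean 0.5 and standard deviation 0.5 then maps that range to [-1, 1]. The sketch below applies the same per-pixel formula in plain Python; `normalize` is a hypothetical stand-in for illustration, not the torchvision implementation.

```python
def normalize(pixel, mean=0.5, std=0.5):
    """Per-pixel normalization: (pixel - mean) / std."""
    return (pixel - mean) / std

darkest = normalize(0.0)    # black pixel after ToTensor
mid_gray = normalize(0.5)
brightest = normalize(1.0)  # white pixel after ToTensor

print(darkest, mid_gray, brightest)
```

This is also why `make_grid(..., normalize=True)` is used when displaying samples: it rescales the [-1, 1] values back into a displayable range.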

In [12]:
print("datasets/mnist/raw:")
!ls datasets/mnist/raw
print("datasets/mnist/processed:")
!ls datasets/mnist/processed

datasets/mnist/raw:
t10k-images-idx3-ubyte  train-images-idx3-ubyte
t10k-labels-idx1-ubyte  train-labels-idx1-ubyte
datasets/mnist/processed:
test.pt     training.pt


In [13]:
train_iter = iter(train_loader)
images, labels = next(train_iter)
imshow(make_grid(images, nrow=8, normalize=True), "MNIST training sample")


In [14]:
def validate(model, test_loader):
    model.eval()
    accuracy = 0.0
    with torch.no_grad():
        for data, label in test_loader:
            data = data.to(device)
            label = label.to(device)
            pred = model(data).max(1, keepdim=True)[1] # index of max log-prob
            accuracy += pred.eq(label.view_as(pred)).sum().item()
    return 100.0 * accuracy / len(test_loader.dataset)

In [15]:
def update_figure(fig, ax, data, title, xlabel, ylabel):
    plt.tight_layout()
    ax.clear()
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.plot(data)
    plt.show()
    fig.canvas.draw()

In [19]:
def train_model(model, train_loader, test_loader, filename):
    fig, (ax_train, ax_validate) = plt.subplots(2, 1)

    train_losses = [] # training losses every 10 iterations
    validate_acc = [] # validation accuracy every 100 iterations

    best_acc = 0.0 # best accuracy during validation
    
    optimizer = optim.Adam(model.parameters())
    for epoch in tqdm(range(10), desc="epochs"):
        model.train()
        for i, (data, label) in enumerate(tqdm(train_loader, 
                                               desc="iters.", 
                                               leave=False)):
            optimizer.zero_grad()

            data = data.to(device)
            label = label.to(device)
            pred = model(data)

            loss = F.nll_loss(pred, label)
            loss.backward(retain_graph=True)
            optimizer.step()

            if i % 10 == 0:
                train_losses.append(loss.item())
                update_figure(fig,
                              ax_train,
                              train_losses,
                              "Training loss",
                              "iterations (10 iters.)",
                              "loss")

            if i % 100 == 0:
                # every 100 iterations, record the model's validation accuracy
                accuracy = validate(model, test_loader)
                model.train() # validate() switches to eval mode; switch back
                if accuracy > best_acc:
                    best_acc = accuracy
                    torch.save(model, filename)
                validate_acc.append(accuracy)
                update_figure(fig,
                              ax_validate,
                              validate_acc,
                              "Validation accuracy",
                              "iterations (100 iters.)",
                              "accuracy (%)")
                print(f"Epoch {epoch} iter {i}: "
                      f"validation accuracy {accuracy}% "
                      f"best accuracy {best_acc}%")

In [20]:
model = ConvNet().to(device)
train_model(model, train_loader, test_loader, "MNIST_ConvNet.pt")



Epoch 0 iter 0: validation accuracy 12.05% best accuracy 12.05%
Epoch 0 iter 100: validation accuracy 89.99% best accuracy 89.99%
Epoch 0 iter 200: validation accuracy 94.23% best accuracy 94.23%
Epoch 0 iter 300: validation accuracy 95.02% best accuracy 95.02%
Epoch 0 iter 400: validation accuracy 96.94% best accuracy 96.94%
Epoch 0 iter 500: validation accuracy 97.15% best accuracy 97.15%
Epoch 0 iter 600: validation accuracy 97.66% best accuracy 97.66%
Epoch 0 iter 700: validation accuracy 97.68% best accuracy 97.68%
Epoch 0 iter 800: validation accuracy 97.9% best accuracy 97.9%
Epoch 0 iter 900: validation accuracy 97.99% best accuracy 97.99%


Epoch 1 iter 0: validation accuracy 97.81% best accuracy 97.99%
Epoch 1 iter 100: validation accuracy 98.15% best accuracy 98.15%
Epoch 1 iter 200: validation accuracy 98.27% best accuracy 98.27%
Epoch 1 iter 300: validation accuracy 97.78% best accuracy 98.27%
Epoch 1 iter 400: validation accuracy 98.27% best accuracy 98.27%
Epoch 1 iter 500: validation accuracy 98.4% best accuracy 98.4%
Epoch 1 iter 600: validation accuracy 98.07% best accuracy 98.4%
Epoch 1 iter 700: validation accuracy 98.5% best accuracy 98.5%
Epoch 1 iter 800: validation accuracy 98.63% best accuracy 98.63%
Epoch 1 iter 900: validation accuracy 98.36% best accuracy 98.63%


Epoch 2 iter 0: validation accuracy 98.29% best accuracy 98.63%
Epoch 2 iter 100: validation accuracy 98.62% best accuracy 98.63%
Epoch 2 iter 200: validation accuracy 98.62% best accuracy 98.63%
Epoch 2 iter 300: validation accuracy 98.53% best accuracy 98.63%
Epoch 2 iter 400: validation accuracy 98.35% best accuracy 98.63%
Epoch 2 iter 500: validation accuracy 98.66% best accuracy 98.66%
Epoch 2 iter 600: validation accuracy 98.65% best accuracy 98.66%
Epoch 2 iter 700: validation accuracy 98.61% best accuracy 98.66%
Epoch 2 iter 800: validation accuracy 98.65% best accuracy 98.66%
Epoch 2 iter 900: validation accuracy 98.75% best accuracy 98.75%


Epoch 3 iter 0: validation accuracy 98.77% best accuracy 98.77%
Epoch 3 iter 100: validation accuracy 98.58% best accuracy 98.77%
Epoch 3 iter 200: validation accuracy 98.6% best accuracy 98.77%
Epoch 3 iter 300: validation accuracy 98.92% best accuracy 98.92%
Epoch 3 iter 400: validation accuracy 98.82% best accuracy 98.92%
Epoch 3 iter 500: validation accuracy 98.96% best accuracy 98.96%
Epoch 3 iter 600: validation accuracy 98.57% best accuracy 98.96%
Epoch 3 iter 700: validation accuracy 98.89% best accuracy 98.96%
Epoch 3 iter 800: validation accuracy 98.8% best accuracy 98.96%
Epoch 3 iter 900: validation accuracy 98.66% best accuracy 98.96%


Epoch 4 iter 0: validation accuracy 98.86% best accuracy 98.96%
Epoch 4 iter 100: validation accuracy 98.94% best accuracy 98.96%
Epoch 4 iter 200: validation accuracy 98.83% best accuracy 98.96%
Epoch 4 iter 300: validation accuracy 98.81% best accuracy 98.96%
Epoch 4 iter 400: validation accuracy 98.92% best accuracy 98.96%
Epoch 4 iter 500: validation accuracy 98.69% best accuracy 98.96%
Epoch 4 iter 600: validation accuracy 98.9% best accuracy 98.96%
Epoch 4 iter 700: validation accuracy 98.93% best accuracy 98.96%
Epoch 4 iter 800: validation accuracy 98.71% best accuracy 98.96%
Epoch 4 iter 900: validation accuracy 98.95% best accuracy 98.96%


Epoch 5 iter 0: validation accuracy 99.03% best accuracy 99.03%
Epoch 5 iter 100: validation accuracy 98.83% best accuracy 99.03%
Epoch 5 iter 200: validation accuracy 98.95% best accuracy 99.03%
Epoch 5 iter 300: validation accuracy 98.93% best accuracy 99.03%
Epoch 5 iter 400: validation accuracy 98.7% best accuracy 99.03%
Epoch 5 iter 500: validation accuracy 98.82% best accuracy 99.03%
Epoch 5 iter 600: validation accuracy 98.77% best accuracy 99.03%
Epoch 5 iter 700: validation accuracy 98.77% best accuracy 99.03%
Epoch 5 iter 800: validation accuracy 98.87% best accuracy 99.03%
Epoch 5 iter 900: validation accuracy 98.79% best accuracy 99.03%


Epoch 6 iter 0: validation accuracy 98.84% best accuracy 99.03%
Epoch 6 iter 100: validation accuracy 98.66% best accuracy 99.03%
Epoch 6 iter 200: validation accuracy 98.94% best accuracy 99.03%
Epoch 6 iter 300: validation accuracy 99.02% best accuracy 99.03%
Epoch 6 iter 400: validation accuracy 98.95% best accuracy 99.03%
Epoch 6 iter 500: validation accuracy 98.98% best accuracy 99.03%
Epoch 6 iter 600: validation accuracy 98.88% best accuracy 99.03%
Epoch 6 iter 700: validation accuracy 99.07% best accuracy 99.07%
Epoch 6 iter 800: validation accuracy 99.01% best accuracy 99.07%
Epoch 6 iter 900: validation accuracy 98.8% best accuracy 99.07%


Epoch 7 iter 0: validation accuracy 98.95% best accuracy 99.07%
Epoch 7 iter 100: validation accuracy 99.01% best accuracy 99.07%
Epoch 7 iter 200: validation accuracy 99.01% best accuracy 99.07%
Epoch 7 iter 300: validation accuracy 98.87% best accuracy 99.07%
Epoch 7 iter 400: validation accuracy 98.94% best accuracy 99.07%
Epoch 7 iter 500: validation accuracy 98.87% best accuracy 99.07%
Epoch 7 iter 600: validation accuracy 98.72% best accuracy 99.07%
Epoch 7 iter 700: validation accuracy 98.95% best accuracy 99.07%
Epoch 7 iter 800: validation accuracy 98.78% best accuracy 99.07%
Epoch 7 iter 900: validation accuracy 99.06% best accuracy 99.07%


Epoch 8 iter 0: validation accuracy 98.93% best accuracy 99.07%
Epoch 8 iter 100: validation accuracy 98.82% best accuracy 99.07%
Epoch 8 iter 200: validation accuracy 98.97% best accuracy 99.07%
Epoch 8 iter 300: validation accuracy 99.02% best accuracy 99.07%
Epoch 8 iter 400: validation accuracy 99.01% best accuracy 99.07%
Epoch 8 iter 500: validation accuracy 98.95% best accuracy 99.07%
Epoch 8 iter 600: validation accuracy 99.02% best accuracy 99.07%
Epoch 8 iter 700: validation accuracy 98.89% best accuracy 99.07%
Epoch 8 iter 800: validation accuracy 99.07% best accuracy 99.07%
Epoch 8 iter 900: validation accuracy 99.04% best accuracy 99.07%


Epoch 9 iter 0: validation accuracy 99.02% best accuracy 99.07%
Epoch 9 iter 100: validation accuracy 99.14% best accuracy 99.14%
Epoch 9 iter 200: validation accuracy 98.8% best accuracy 99.14%
Epoch 9 iter 300: validation accuracy 98.93% best accuracy 99.14%
Epoch 9 iter 400: validation accuracy 99.04% best accuracy 99.14%
Epoch 9 iter 500: validation accuracy 99.15% best accuracy 99.15%
Epoch 9 iter 600: validation accuracy 98.86% best accuracy 99.15%
Epoch 9 iter 700: validation accuracy 99.06% best accuracy 99.15%
Epoch 9 iter 800: validation accuracy 98.99% best accuracy 99.15%
Epoch 9 iter 900: validation accuracy 98.87% best accuracy 99.15%



## Dynamic HyperNet

## References

- HyperNetworks, [arXiv:1609.09106v4](https://arxiv.org/abs/1609.09106v4) \[cs.LG\]