#### Includes

In [None]:
%load_ext autoreload
%autoreload 2

In [3]:
import orion.storage.base

ImportError: cannot import name 'get_storage' from 'orion.storage.base' (/home/gaby/.conda/envs/nnet/lib/python3.7/site-packages/orion/storage/base.py)

In [None]:
import numpy as np
import torch
import torch.nn.functional as F
import altair as alt
from altair import datum
import pandas as pd

import sys

In [None]:
sys.path.insert(0, '../../')
from lib.utils.rotequivariance_toolbox import *
from lib.steered_conv.steerable_filters import radial_steerable_filter, plot_filter
from lib.steered_conv import SteerableKernelBase
from lib.models import SteeredHemelingNet, HemelingNet
from experiments.trainer import BinaryClassifierNet
from experiments import load_dataset, parse_config, setup_model

## Spatial Homogeneity

In [None]:
base = SteerableKernelBase.create_from_rk(4, max_k=5)

# 3x3 Gaussian smoothing kernel (std = 1), normalised so its entries sum to 1.
std = 1
x = torch.linspace(-1, 1, 3)
x, y = torch.meshgrid(x, x)
d = (x**2 + y**2).sqrt()
G = (-d.pow(2) / (2 * std**2)).exp()
G = G / G.sum()

# k x k window: exactly 1 up to a radius of k//2, then decaying as
# exp(-2 * d**2) where d is the distance in excess of that radius.
k = 7
x = torch.linspace(-k/2, k/2, k)
x, y = torch.meshgrid(x, x)
d = ((x**2 + y**2).sqrt() - k//2).clamp(min=0)
window = (-2 * d.square()).exp()

# 2000 random k x k kernels, smoothed by G and then attenuated by the window.
K = torch.conv2d(torch.randn((2000, 1, k, k)), G[None, None], padding=1) * window

# Fit the kernels with the steerable base; `info` is filled in by the call
# and is read back below for its "mse" and "r2" entries.
info = {}
W = base.approximate_weights(K, info, ridge_alpha=1e-3)
print(f'mse:{info["mse"]:.3f}, r2:{info["r2"]:.3f}')

In [None]:
# Tabulate what `base.weights_dist` reports for the fitted weights (Q=5).
w_df = pd.DataFrame(base.weights_dist(W, Q=5))
w_df

In [None]:
chart = alt.Chart(w_df)
layered = alt.LayerChart()
for real in (True, False):
    # Rows of type 'R' are drawn 5px to the left of the tick position,
    # rows of type 'I' 5px to the right, so both stay visible side by side.
    kind = 'R' if real else 'I'
    offset = -5 if real else 5

    # Tick at the median, on a symlog scale.
    layered += (
        chart.mark_tick(thickness=2, width=10, xOffset=offset)
        .encode(
            alt.X('r:N'),
            alt.Y('median:Q', scale=alt.Scale(type='symlog')),
            alt.Color('type:N'),
        )
        .transform_filter(datum.type == kind)
    )

    # Two bars spanning the '-q1'..'q1' (faint) and '-q2'..'q2' (strong) columns.
    for col, opacity in (('q1', .2), ('q2', .8)):
        layered += (
            chart.mark_bar(opacity=opacity, width=10, xOffset=offset)
            .encode(
                alt.X('r:N'),
                alt.Y(f'-{col}:Q'),
                alt.Y2(f'{col}:Q'),
                alt.Color('type:N'),
            )
            .properties(width=80)
            .transform_filter(datum.type == kind)
        )
layered.facet('k:O').resolve_scale(x='independent').interactive()

In [None]:
# Impulse image: a single 1 at the centre of an otherwise zero 15x15 canvas.
h = w = 15
I = torch.zeros(1, 1, h, w)
I[..., h // 2, w // 2] = 1

# All-ones weights with the same shape as the fitted weights W.
W1 = torch.ones(size=W.shape)
plot_filter(base.conv2d(I, W1)[0, 0])

In [None]:
# Recompute the response to the impulse image and show its largest value.
O = base.conv2d(I, W1)[0][0]
torch.max(O)

## Variance conservation

In [None]:
def plot_variance(net):
    """Chart mean +/- std of activations and gradients around each conv block.

    Forward and backward hooks are attached to the first sub-module of
    ``net.conv1`` ... ``net.conv8``. One forward/backward pass is run on
    Gaussian noise with a random binary target, and the mean and standard
    deviation of every captured input/output (forward) and
    grad_input/grad_output (backward) are plotted per layer.

    Args:
        net: ``torch.nn.Module`` exposing indexable attributes ``conv1`` ...
            ``conv8`` and callable as ``net(I, alpha=angle)``.

    Returns:
        Two faceted Altair charts concatenated horizontally: forward
        statistics on the left, backward statistics on the right.
    """
    N_CONV = 9
    layer_names = [f'conv{i}' for i in range(1, N_CONV)]

    # Storage filled in by the hooks, keyed by '<layer>-in' / '<layer>-out'.
    forward_tensors = {}
    backward_tensors = {}

    def store_forward(name):
        def hook(module, inputs, output):
            forward_tensors[name+'-in'] = inputs
            forward_tensors[name+'-out'] = output
        return hook

    def store_backward(name):
        def hook(module, grad_input, grad_output):
            backward_tensors[name+'-in'] = grad_input
            backward_tensors[name+'-out'] = grad_output
        return hook

    def summarize(captured):
        """Return (mean, std) of a hooked value as plain floats."""
        # Module inputs and gradients arrive as tuples, whereas a module
        # output may be a bare tensor: indexing that tensor with [0] would
        # silently restrict the statistics to the first sample of the batch.
        tensor = captured[0] if isinstance(captured, (tuple, list)) else captured
        if tensor is None:
            # Gradient entries are None when no gradient flows to that input.
            return float('nan'), float('nan')
        tensor = tensor.detach()
        return tensor.mean().item(), tensor.std().item()

    # Run on whatever device the network lives on; fall back to CUDA (the
    # previously hard-coded device) for a parameter-less module.
    first_param = next(net.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cuda')

    hooks = []
    try:
        # Register hooks.
        # NOTE(review): register_backward_hook is the legacy API; it is kept
        # so that the reported grad_input is unchanged.
        for name in layer_names:
            layer = getattr(net, name)[0]
            hooks.append(layer.register_forward_hook(store_forward(name)))
            hooks.append(layer.register_backward_hook(store_backward(name)))

        # Do forward and backward pass with noise
        I = torch.randn((5, 1, 500, 500)).to(device)
        angle = (torch.rand((5, 500, 500)).to(device) - .5) * 2 * np.pi
        out = net(I, alpha=angle)
        # Random {0, 1} labels shaped like the prediction: randint(1, ...)
        # only ever yields 0, and a hard-coded size breaks as soon as the
        # network output is not exactly 496x496.
        target = torch.randint(2, tuple(out.shape)).float().to(out.device)
        loss = F.binary_cross_entropy_with_logits(out, target)
        loss.backward()
    finally:
        # Remove hooks even if the pass above raised, so they never leak
        # onto the caller's network.
        for hook in hooks:
            hook.remove()

    # Compute means and variances
    data = []
    for name in layer_names:
        for side in ('in', 'out'):
            key = f'{name}-{side}'
            row = {'name': name, 'type': side}
            for prefix, store in (('forward', forward_tensors),
                                  ('backward', backward_tensors)):
                mean, std = summarize(store[key])
                row[prefix + '_mean'] = mean
                row[prefix + '_y'] = mean - std
                row[prefix + '_y2'] = mean + std
            data.append(row)
    dist_df = pd.DataFrame(data=data)

    # Plot
    chart = alt.Chart(dist_df)

    def plot_dist(back=False):
        n = 'backward_' if back else 'forward_'
        layered = alt.LayerChart()
        # Tick at the mean (symlog scale) ...
        layered += chart.mark_tick(
                            thickness=2,
                            width=15,
                        ).encode(
                            alt.X('type:N'),
                            alt.Y(n+'mean:Q', scale=alt.Scale(type='symlog')),
                            alt.Color('type:N')
                        )
        # ... over a faint bar spanning mean - std to mean + std.
        layered += chart.mark_bar(
                            opacity=.2,
                            width=15,
                        ).encode(
                            alt.X('type:N'), alt.Y(n+'y:Q'), alt.Y2(n+'y2:Q'),
                            alt.Color('type:N')
                        ).properties(width=30)
        faceted = layered.facet('name:O').resolve_scale(x='independent').interactive()
        # properties() returns a modified copy, so its result has to be the
        # chart that is handed back for the title to show up.
        return faceted.properties(
            title='Backward Variance' if back else 'Forward Variance'
        )

    return (plot_dist(False) | plot_dist(True))

In [None]:
# Activation/gradient statistics for HemelingNet (1 -> 1 channels,
# 'same' padding, batchnorm disabled), moved to the GPU.
plot_variance(HemelingNet(1, 1, padding='same', batchnorm=False).cuda())

In [None]:
# Same statistics for SteeredHemelingNet, built with the same arguments.
plot_variance(SteeredHemelingNet(1, 1, padding='same', batchnorm=False).cuda())