#### Includes

In [3]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [4]:
import numpy as np
import torch
import torch.nn.functional as F
import altair as alt
from altair import datum
import pandas as pd


In [5]:
from sources import load_dataset, parse_config, setup_model
from steered_cnn.utils.rotequivariance_toolbox import *
from steered_cnn.steered_conv.steerable_filters import radial_steerable_filter, plot_filter
from steered_cnn.steered_conv import SteerableKernelBase
from steered_cnn.models import SteeredHemelingNet, HemelingNet
from src.trainer import BinaryClassifierNet

## Spatial Homogeneity

In [9]:
# Fit the steerable basis to 2000 smooth random k x k kernels and report the reconstruction quality.
torch.manual_seed(0)  # fixed seed so the random kernels, and hence mse / r2 below, are reproducible

base = SteerableKernelBase.create_from_rk(4, max_k=5)

# 3x3 Gaussian (std=1 on a [-1, 1] grid), normalised to sum to 1, used to smooth the random kernels.
x = torch.linspace(-1, 1, 3)
x, y = torch.meshgrid(x, x)
d = torch.sqrt(x**2 + y**2)
std = 1
G = torch.exp(-d**2 / (2 * std**2))
G /= G.sum()

# Radial taper: 1 within distance k//2 of the centre, fall-off exp(-2 d^2) beyond it.
# NOTE(review): linspace(-k/2, k/2, k) spaces samples k/(k-1) apart rather than 1 pixel,
# so the border pixels are already partly attenuated — confirm this is intended.
k = 7
x = torch.linspace(-k/2, k/2, k)
x, y = torch.meshgrid(x, x)
d = (torch.sqrt(x**2 + y**2) - k//2).clamp(min=0)
window = torch.exp(-torch.square(d) * 2)

K = torch.randn((2000, 1, k, k))
K = torch.conv2d(K, G[None, None], padding=G.shape[-1] // 2)  # keep the k x k size while smoothing
K = K * window

info = {}
W = base.approximate_weights(K, info, ridge_alpha=1e-2)
# NOTE(review): a recorded run printed r2 ≈ -361 (far worse than predicting the mean);
# check that the basis built by create_from_rk(4, ...) matches the k x k kernels fitted here.
print(f'mse:{info["mse"]:.3f}, r2:{info["r2"]:.3f}')

mse:20.133, r2:-361.259


In [7]:
# Median and symmetric quantile bands (q0..q4 / -q0..-q4) of the fitted weights,
# one row per (r, k, Real/Imag) basis component.
w_df = pd.DataFrame(base.weights_dist(W, Q=5))
w_df

Unnamed: 0,r,k,type,name,median,q0,-q0,q1,-q1,q2,-q2,q3,-q3,q4,-q4
0,0,0,R,"r=0, k=0, Real",0.091622,5.250425,-5.274933,3.41965,-3.626982,2.207469,-2.19027,1.123819,-1.118738,0.091622,0.091622
1,1,0,R,"r=1, k=0, Real",6.9e-05,0.075616,-0.076233,0.050943,-0.049263,0.032029,-0.02978,0.015394,-0.015701,6.9e-05,6.9e-05
2,2,0,R,"r=2, k=0, Real",-0.000217,0.016161,-0.015963,0.010605,-0.011036,0.006429,-0.007077,0.00314,-0.003392,-0.000217,-0.000217
3,3,0,R,"r=3, k=0, Real",2.4e-05,0.002182,-0.002248,0.001539,-0.001478,0.000991,-0.000936,0.000512,-0.000417,2.4e-05,2.4e-05
4,1,1,R,"r=1, k=1, Real",0.002278,0.14556,-0.148237,0.103275,-0.101147,0.065784,-0.064675,0.034966,-0.0324,0.002278,0.002278
5,1,1,I,"r=1, k=1, Imag",-0.000726,0.207075,-0.203766,0.146792,-0.148141,0.090117,-0.095937,0.043531,-0.044528,-0.000726,-0.000726
6,2,1,R,"r=2, k=1, Real",-0.000554,0.129475,-0.1336,0.094026,-0.09357,0.061169,-0.056627,0.028055,-0.027674,-0.000554,-0.000554
7,2,1,I,"r=2, k=1, Imag",-8.7e-05,0.002633,-0.002742,0.001851,-0.001984,0.001122,-0.001334,0.000498,-0.000666,-8.7e-05,-8.7e-05
8,3,1,R,"r=3, k=1, Real",8.4e-05,0.003205,-0.003057,0.002289,-0.002155,0.001533,-0.001315,0.000783,-0.000584,8.4e-05,8.4e-05
9,3,1,I,"r=3, k=1, Imag",-1.7e-05,0.000793,-0.000827,0.000558,-0.000593,0.000338,-0.000396,0.000154,-0.000203,-1.7e-05,-1.7e-05


In [None]:
def _weight_dist_layers(source, component, x_offset):
    """Return the median tick, ±q1 band and ±q2 band layers for one weight component.

    ``component`` is the value of the ``type`` column to keep ('R' or 'I');
    ``x_offset`` shifts the marks so both components sit side by side at each ``r``.
    All layers carry the same explicit symlog Y scale so the shared axis is unambiguous.
    """
    filtered = source.transform_filter(datum.type == component)
    y_scale = alt.Scale(type='symlog')
    tick = filtered.mark_tick(thickness=2, width=10, xOffset=x_offset).encode(
        alt.X('r:N'),
        alt.Y('median:Q', scale=y_scale, title='weight'),
        alt.Color('type:N'),
    )
    # Light bar spans ±q1, dark bar spans ±q2.
    bands = [
        filtered.mark_bar(opacity=opacity, width=10, xOffset=x_offset).encode(
            alt.X('r:N'),
            alt.Y(f'-{q}:Q', scale=y_scale, title='weight'),
            alt.Y2(f'{q}:Q'),
            alt.Color('type:N'),
        ).properties(width=80)
        for q, opacity in (('q1', .2), ('q2', .8))
    ]
    return [tick, *bands]


# Distribution of the fitted weights per radius r, one facet per angular frequency k.
chart = alt.Chart(w_df)
layered = alt.LayerChart()
for component, x_offset in (('R', -5), ('I', 5)):
    for layer in _weight_dist_layers(chart, component, x_offset):
        layered += layer
layered.facet('k:O').resolve_scale(x='independent').interactive()

In [None]:
# Impulse response of the steerable convolution with every basis weight set to 1.
h = w = 15
I = torch.zeros(1, 1, h, w)
I[..., h // 2, w // 2] = 1  # unit impulse at the image centre
W1 = torch.ones(*W.shape)
plot_filter(base.conv2d(I, W1)[0, 0])

In [None]:
# Peak value of the all-ones impulse response (first batch element, first channel).
O = base.conv2d(I, W1)[0][0]
torch.max(O)

## Variance conservation

In [None]:
def _first_tensor(value):
    """Return the primary tensor carried by a hook argument.

    Hook arguments are either a tuple/list (forward ``input``, backward ``grad_input`` /
    ``grad_output``) or a bare tensor (e.g. the forward ``output`` of ``nn.Conv2d``).
    Indexing a bare tensor with ``[0]`` would silently select the first batch sample,
    so sequences are unpacked and bare tensors are returned unchanged.
    Returns ``None`` for an empty sequence or when its first entry is ``None``.
    """
    if isinstance(value, (tuple, list)):
        return value[0] if len(value) > 0 else None
    return value


def _mean_std_band(tensor):
    """Return ``(mean, mean - std, mean + std)`` of ``tensor`` as Python floats.

    ``None`` (e.g. the gradient w.r.t. an input that does not require grad) yields NaNs
    instead of raising, so the remaining statistics can still be plotted.
    """
    if tensor is None:
        nan = float('nan')
        return nan, nan, nan
    tensor = tensor.detach()
    mean = tensor.mean().item()
    std = tensor.std().item()
    return mean, mean - std, mean + std


def _variance_panel(dist_df, back=False):
    """Faceted (one column per conv layer) mean ± std chart of forward or backward statistics.

    ``back=False`` plots the ``forward_*`` columns and ``back=True`` the ``backward_*`` ones.
    The tick marks the mean (symlog scale) and the translucent bar spans mean ± std.
    """
    prefix = 'backward_' if back else 'forward_'
    chart = alt.Chart(dist_df)
    layered = alt.LayerChart()
    layered += chart.mark_tick(thickness=2, width=15).encode(
        alt.X('type:N'),
        alt.Y(prefix + 'mean:Q', scale=alt.Scale(type='symlog')),
        alt.Color('type:N'),
    )
    layered += chart.mark_bar(opacity=.2, width=15).encode(
        alt.X('type:N'),
        alt.Y(prefix + 'y:Q'),
        alt.Y2(prefix + 'y2:Q'),
        alt.Color('type:N'),
    ).properties(width=30)
    faceted = layered.facet('name:O').resolve_scale(x='independent').interactive()
    # properties() returns a new chart rather than mutating in place, so its result must be kept.
    return faceted.properties(title='Backward Variance' if back else 'Forward Variance')


def plot_variance(net):
    """Plot per-layer mean ± std of activations and gradients of ``net`` on random input.

    A forward hook and a (legacy) backward hook are attached to ``getattr(net, f'conv{i}')[0]``
    for i = 1..8. One forward/backward pass is then run on Gaussian-noise images with
    uniformly random angles ``alpha`` and random binary targets. The mean ± std of each
    hooked layer's input/output (forward) and grad-input/grad-output (backward) is plotted.

    Parameters
    ----------
    net : torch.nn.Module
        Must expose attributes ``conv1`` ... ``conv8`` indexable with ``[0]`` and accept an
        ``alpha`` keyword in its forward call. Inputs are created on the device of the
        net's first parameter (CPU if it has none). ``loss.backward()`` leaves gradients
        accumulated in the net's parameters.

    Returns
    -------
    Altair chart
        The forward-statistics panel horizontally concatenated with the backward one.
    """
    N_CONV = 9  # hooks conv1 .. conv{N_CONV - 1}
    N_SAMPLES, SIZE = 5, 500
    first_param = next(net.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    # Statistics are computed inside the hooks instead of storing the tensors. This avoids
    # keeping every activation/gradient alive, and later in-place ops (e.g. an in-place
    # activation after the conv) cannot overwrite what is measured.
    forward_stats = {}
    backward_stats = {}

    def store_forward(name):
        def hook(module, inputs, output):
            forward_stats[name + '-in'] = _mean_std_band(_first_tensor(inputs))
            forward_stats[name + '-out'] = _mean_std_band(_first_tensor(output))
        return hook

    def store_backward(name):
        def hook(module, grad_input, grad_output):
            backward_stats[name + '-in'] = _mean_std_band(_first_tensor(grad_input))
            backward_stats[name + '-out'] = _mean_std_band(_first_tensor(grad_output))
        return hook

    hooks = []
    try:
        for i in range(1, N_CONV):
            name = f'conv{i}'
            layer = getattr(net, name)[0]
            hooks.append(layer.register_forward_hook(store_forward(name)))
            # NOTE: legacy module backward hook; its grad_input covers only the module's last
            # autograd op. Full backward hooks can fail when the module output is later
            # modified in place, so the legacy API is kept.
            hooks.append(layer.register_backward_hook(store_backward(name)))

        # Noise forward/backward pass.
        noise = torch.randn((N_SAMPLES, 1, SIZE, SIZE)).to(device)
        angle = ((torch.rand((N_SAMPLES, SIZE, SIZE)) - .5) * 2 * np.pi).to(device)
        out = net(noise, alpha=angle)
        # The target is shaped like the output so it works for any padding mode. randint's
        # upper bound is exclusive, so (0, 2) yields random {0, 1} labels rather than all zeros.
        target = torch.randint(0, 2, out.shape).to(out)
        loss = F.binary_cross_entropy_with_logits(out, target)
        loss.backward()
    finally:
        # Always detach the hooks, even if the pass fails, so they do not leak into later calls.
        for hook in hooks:
            hook.remove()

    # One row per (layer, in/out) with mean and mean ± std bounds for both passes.
    nan_band = (float('nan'),) * 3
    rows = []
    for i in range(1, N_CONV):
        name = f'conv{i}'
        for kind in ('in', 'out'):
            key = f'{name}-{kind}'
            f_mean, f_lo, f_hi = forward_stats.get(key, nan_band)
            b_mean, b_lo, b_hi = backward_stats.get(key, nan_band)
            rows.append({
                'name': name,
                'type': kind,
                'forward_mean': f_mean,
                'forward_y': f_lo,
                'forward_y2': f_hi,
                'backward_mean': b_mean,
                'backward_y': b_lo,
                'backward_y2': b_hi,
            })
    dist_df = pd.DataFrame(data=rows)

    return _variance_panel(dist_df, back=False) | _variance_panel(dist_df, back=True)

In [None]:
# Activation/gradient statistics for HemelingNet ('same' padding, no batch norm).
plot_variance(
    HemelingNet(1, 1, padding='same', batchnorm=False).cuda()
)

In [None]:
# Activation/gradient statistics for SteeredHemelingNet ('same' padding, no batch norm).
plot_variance(
    SteeredHemelingNet(1, 1, padding='same', batchnorm=False).cuda()
)