#### Includes

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
import altair as alt
from altair import datum
import pandas as pd
import time


In [3]:
import sources
from steered_cnn.utils.rotequivariance_toolbox import *
from steered_cnn.steered_conv.steerable_filters import radial_steerable_filter, plot_filter
from steered_cnn.steered_conv import SteerableKernelBase, SteeredConv2d, SteeredConvBN
from steered_cnn.utils.convbn import ConvBN
from steered_cnn.models import SteeredUNet

## Weight Variability

### Kernel normalization

In [24]:
base = SteerableKernelBase.create_radial(5)
print('Base equi square sum to 1:', all((K**2).sum()-1<1e-8 for K in base.base_equi))
print('Base complex square sum to 1:', all((torch.abs(K)**2).sum()-1<1e-8 for K in base.base_complex))

Base equi square sum to 1: True
Base complex square sum to 1: True


### Forward Propagation Numerical Stability

In [5]:
ratio = []
def hook(i):
    def print_std(module, inputs, output):
        i_mean = torch.mean(inputs[0])
        o_mean = torch.mean(output)
        i_std = torch.std(inputs[0])
        o_std = torch.std(output)
        print(f"Conv {i}:  in~{i_mean:.4f}±{i_std:.4f}; out~{o_mean:.4f}±{o_std:.4f} | ratio={o_std/i_std:.4f}")
        ratio.append(o_std/i_std)
    return print_std

In [13]:
ratio = []
opts = dict(bn=True, relu=True, steerable_base=base, attention_mode='shared', rho_nonlinearity='normalize')
model = [SteeredConvBN(1,16, **opts)]+[SteeredConvBN(16,16, **opts) for _ in range(10)] + [SteeredConvBN(16, 1, **opts)]
for i, m in enumerate(model):
    m.register_forward_hook(hook(i))
net = nn.Sequential(*model).to('cuda').eval()
x = torch.randn((32,1,128,128), device='cuda')
y = net(x)

print('__'*10)
print(f'avg ratio: {torch.mean(torch.Tensor(ratio)):.4f}±{torch.std(torch.Tensor(ratio)):.4f}')
hook('')(None, x, y)

del y
del x
del net
del model

Conv 0:  in~0.0004±1.0023; out~0.5423±0.9014 | ratio=0.8993
Conv 1:  in~0.5423±0.9014; out~0.5746±0.9452 | ratio=1.0486
Conv 2:  in~0.5746±0.9452; out~1.1176±1.4130 | ratio=1.4949
Conv 3:  in~1.1176±1.4130; out~1.7984±2.4624 | ratio=1.7427
Conv 4:  in~1.7984±2.4624; out~2.7591±3.6946 | ratio=1.5004
Conv 5:  in~2.7591±3.6946; out~3.2651±5.9446 | ratio=1.6090
Conv 6:  in~3.2651±5.9446; out~5.7545±8.0617 | ratio=1.3561
Conv 7:  in~5.7545±8.0617; out~3.5780±7.9063 | ratio=0.9807
Conv 8:  in~3.5780±7.9063; out~10.2685±16.9426 | ratio=2.1429
Conv 9:  in~10.2685±16.9426; out~13.4305±15.8873 | ratio=0.9377
Conv 10:  in~13.4305±15.8873; out~24.2476±34.9288 | ratio=2.1985
Conv 11:  in~24.2476±34.9288; out~1.2898±5.9295 | ratio=0.1698
____________________
avg ratio: 1.3401±0.5725
Conv :  in~-0.0059±1.0149; out~1.2898±5.9295 | ratio=5.8424


In [23]:
ratio = []
opts = dict(bn=True, relu=False)
model = [ConvBN(5,1,16, **opts)]+[ConvBN(5,16,16, **opts) for _ in range(10)] + [ConvBN(5,16, 1, **opts)]
for i, m in enumerate(model):
    m.register_forward_hook(hook(i))
net = nn.Sequential(*model).to('cuda').eval()
x = torch.randn((64,1,256,256), device='cuda')
y = net(x)

print('__'*10)
print(f'avg ratio: {torch.mean(torch.Tensor(ratio)):.4f}±{torch.std(torch.Tensor(ratio)):.4f}')
hook('')(None, x, y)

del y
del x
del net
del model

Conv 0:  in~-0.0002±1.0005; out~-0.0000±0.2516 | ratio=0.2515
Conv 1:  in~-0.0000±0.2516; out~-0.0000±0.2499 | ratio=0.9933
Conv 2:  in~-0.0000±0.2499; out~-0.0000±0.2470 | ratio=0.9882
Conv 3:  in~-0.0000±0.2470; out~-0.0000±0.2399 | ratio=0.9715
Conv 4:  in~-0.0000±0.2399; out~0.0000±0.2425 | ratio=1.0107
Conv 5:  in~0.0000±0.2425; out~0.0000±0.2471 | ratio=1.0188
Conv 6:  in~0.0000±0.2471; out~0.0000±0.2470 | ratio=0.9997
Conv 7:  in~0.0000±0.2470; out~0.0000±0.2420 | ratio=0.9797
Conv 8:  in~0.0000±0.2420; out~0.0000±0.2460 | ratio=1.0164
Conv 9:  in~0.0000±0.2460; out~0.0000±0.2501 | ratio=1.0167
Conv 10:  in~0.0000±0.2501; out~0.0000±0.2524 | ratio=1.0092
Conv 11:  in~0.0000±0.2524; out~-0.0004±0.9078 | ratio=3.5974
____________________
avg ratio: 1.1544±0.7990
Conv :  in~-0.0026±0.9984; out~-0.0004±0.9078 | ratio=0.9093


## Preconvolve vs Composite Kernel

### Implementation Consistency

In [None]:
base = SteerableKernelBase.create_radial(4, max_k=5)
W = base.init_weights(3,3)

X = torch.randn((100,3,200,200))
alpha = torch.exp(torch.rand((100,3,200,200))*2j*np.pi)
alpha = torch.stack([alpha.real, alpha.imag])

Y1 = base.composite_kernels_conv2d(X, W, alpha=alpha)
Y2 = base.preconvolved_base_conv2d(X, W, alpha=alpha)

diff = torch.abs(Y1-Y2)
print(f"Mean diff: {diff.mean()}, Quantile diff: {torch.quantile(diff, torch.Tensor([.5,.75,.9,.99]))}")
diff = torch.abs((Y1-Y2)/(Y1+1e-8))
print(f"Mean normed diff: {diff.mean()}, Quantile normed diff: {torch.quantile(diff, torch.Tensor([.25,.5,.75,.9,.99]))}")


### Speed Comparison

In [None]:

def speed_test(X=200, M=32, N=32, R=4, K=5, batchsize=50, iterations=10):
    base = SteerableKernelBase.create_radial(R, max_k=K)
    x = torch.randn((batchsize,M,X,X), requires_grad=True).cuda()
    W = base.init_weights(M,N).cuda()
    alpha = torch.exp(torch.rand((batchsize,N,X,X))*2j*np.pi).cuda()
    alpha = torch.stack([alpha.real, alpha.imag])
    
    tf_comp, tf_pre = 0, 0
    tb_comp, tb_pre = 0, 0
    tn_comp, tn_pre = 0, 0
    for i in range(iterations):
        x = torch.randn((batchsize,M,X,X), requires_grad=True).cuda()
        
        t0 = time.perf_counter()
        Y = base.composite_kernels_conv2d(x, W, alpha=alpha)
        t1 = time.perf_counter()
        tf_comp += t1-t0
        Y = Y.sum()
        
        t0 = time.perf_counter()
        Y.backward()
        t1 = time.perf_counter()
        tb_comp += t1-t0
        del Y
        
        with torch.no_grad():
            t0 = time.perf_counter()
            Y = base.composite_kernels_conv2d(x, W, alpha=alpha)
            t1 = time.perf_counter()
            tn_comp += t1-t0
            del Y
        
        t0 = time.perf_counter()
        Y = base.preconvolved_base_conv2d(x, W, alpha=alpha)
        t1 = time.perf_counter()
        tf_pre += t1-t0
        Y = Y.sum()
        
        t0 = time.perf_counter()
        Y.backward()
        t1 = time.perf_counter()
        tb_pre += t1-t0
        del Y
        
        with torch.no_grad():
            t0 = time.perf_counter()
            Y = base.preconvolved_base_conv2d(x, W, alpha=alpha)
            t1 = time.perf_counter()
            tn_pre += t1-t0
            del Y
        
    return pd.DataFrame(
            {'Composite':{'forward': tf_comp, 'backward': tb_comp, 'no grad': tn_comp},
             'Preconv':{'forward': tf_pre, 'backward': tb_pre, 'no grad': tn_pre}})

In [None]:
speed_test(X=32, M=256, N=500, batchsize=4)
