#### Includes

In [1]:
%load_ext autoreload
%autoreload 2

In [62]:
import numpy as np
import torch
import torch.nn.functional as F
import altair as alt
from altair import datum
import pandas as pd
import time

In [44]:
from steered_cnn.utils.rotequivariance_toolbox import *
from steered_cnn.steered_conv.steerable_filters import radial_steerable_filter, plot_filter
from steered_cnn.steered_conv import SteerableKernelBase

### Implementation Consistency

In [59]:
base = SteerableKernelBase.create_from_rk(4, max_k=5)
W = base.create_weights(3,3)

X = torch.randn((100,3,200,200))
Y1 = base.composite_kernels_conv2d(X, W)
Y2 = base.preconvolved_base_conv2d(X, W)
diff = torch.abs(Y1-Y2)
print(f"Mean diff: {diff.mean()}, Quantile diff: {torch.quantile(diff, torch.Tensor([.5,.75,.9,.99]))}")
diff = torch.abs((Y1-Y2)/Y1)
print(f"Mean normed diff: {diff.mean()}, Quantile normed diff: {torch.quantile(diff, torch.Tensor([.25,.5,.75,.9,.99]))}")

Mean diff: 2.797012825794809e-07, Quantile diff: tensor([2.3842e-07, 3.5763e-07, 5.9605e-07, 1.1921e-06])
Mean normed diff: 1.9564754438761156e-06, Quantile normed diff: tensor([1.0951e-07, 2.5203e-07, 5.1362e-07, 1.1555e-06, 1.1284e-05])


### Speed Comparison

In [89]:
X = torch.randn((1,3,5,5)).cuda()
W = W.cuda()
base.conv2d(X, W)

def speed_test(X=200, M=32, N=32, R=4, K=5, batchsize=50, iterations=10):
    base = SteerableKernelBase.create_from_rk(R, max_k=K)
    x = torch.randn((batchsize,M,X,X), requires_grad=True).cuda()
    W = base.create_weights(M,N).cuda()
    base.conv2d(x, W)
    
    tf_comp, tf_pre = 0, 0
    tb_comp, tb_pre = 0, 0
    for i in range(iterations):
        x = torch.randn((batchsize,M,X,X), requires_grad=True).cuda()
        
        t0 = time.perf_counter()
        Y = base.composite_kernels_conv2d(x, W)
        t1 = time.perf_counter()
        tf_comp += t1-t0
        Y = Y.sum()
        
        t0 = time.perf_counter()
        Y.backward()
        t1 = time.perf_counter()
        tb_comp += t1-t0
        del Y
        
        t0 = time.perf_counter()
        Y = base.preconvolved_base_conv2d(x, W)
        t1 = time.perf_counter()
        tf_pre += t1-t0
        Y = Y.sum()
        
        t0 = time.perf_counter()
        Y.backward()
        t1 = time.perf_counter()
        tb_pre += t1-t0
        del Y
    return tf_comp, tf_pre, tb_comp, tb_pre

In [93]:
speed_test(X=32, M=256, N=500, batchsize=4)

(0.003889273995810072,
 0.005263696997644729,
 0.26142607500332815,
 0.18343742900106008)