In [11]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

%load_ext autoreload
%autoreload 2

from se3_cnn import basis_kernels
from se3_cnn import SO3

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [28]:
# Kernel geometry: a cube of side `size`, with n_radial = size//2 + 1 radial
# shells (the function prints shell radii [0 1 2 3] for size 7).
size = 7
n_radial = 1 + size // 2  # hard coded
upsampling = 1  # hard coded

# reps[l] is SO3.repr{2l+1}; for l = 2 this is repr5, and the basis below has
# matching channel axes of length 5.
reps = [SO3.repr1, SO3.repr3, SO3.repr5, SO3.repr7, SO3.repr9, SO3.repr11]
order_in, order_out = 2, 2
R_in, R_out = reps[order_in], reps[order_out]

basis = basis_kernels.cube_basis_kernels_analytical(size, n_radial, upsampling, R_out, R_in)
print('basis of shape', basis.shape)


kernel size: 7
shell radii: [0 1 2 3]
shell bandlimit: [0 4 6 8]

check_basis_equivariance for R_in=repr5 -> R_out=repr5:
[ 0.87521099  0.87521099  0.26460547  0.14271229  0.10268137  0.11542211
  0.87521099  0.26460547  0.14271229  0.10268137  0.11542211  0.87521099
  0.26460547  0.14271229  0.10268137  0.11542211]
basis of shape (16, 5, 5, 7, 7, 7)


In [None]:
# Average the per-element result of check_basis_equivariance over N_samples
# random rotations, each given by three angles (a, b, c) drawn uniformly in
# [0, 2*pi) -- presumably Euler angles; confirm the convention in basis_kernels.
# NOTE(review): uniformly drawn angles are not Haar-uniform over SO(3); that is
# fine for a spot check but not for an unbiased average over all rotations.
# NOTE(review): the meaning of the returned per-element values (e.g. a relative
# equivariance error) is defined in basis_kernels -- confirm there.
N_samples = 100
rng = np.random.RandomState(0)  # fixed seed so the printed average is reproducible on re-run
avg_equiv = np.zeros(len(basis))
for a, b, c in 2*np.pi*rng.rand(N_samples, 3):
    avg_equiv += basis_kernels.check_basis_equivariance(basis, R_out, R_in, a, b, c)
avg_equiv /= N_samples
print(avg_equiv)

In [None]:
def plot_cube(cube):
    """Show every slice of a 3D cube along its last axis, side by side.

    All panels share one colour scale (the cube's global min/max), so
    intensities are comparable across slices.

    Parameters
    ----------
    cube : array indexed as cube[:, :, z]
        Volume to display; one panel is drawn per index of the last axis.
    """
    vmin, vmax = cube.min(), cube.max()
    size = cube.shape[-1]
    # One panel per slice: the column count must follow the cube, not a fixed 7,
    # otherwise cubes with more than 7 slices fail and smaller ones are mis-laid out.
    # squeeze=False keeps `axes` 2D even when size == 1.
    fig, axes = plt.subplots(1, size, figsize=(1.5*size, 1.5), squeeze=False)
    for idx_z, ax in enumerate(axes[0]):
        ax.imshow(cube[:, :, idx_z], vmin=vmin, vmax=vmax)
        ax.axis("off")
    fig.tight_layout()
    plt.show()

# Plot every basis element: for each (out channel m, in channel n) pair, one
# row of z-slices of that kernel component.
# zip() below would silently truncate (and mislabel) if the channel axes did not
# match the orders, e.g. after editing order_in/order_out without re-running the
# basis cell, so check the shape up front.
n_out, n_in = 2*order_out + 1, 2*order_in + 1
assert basis.shape[1:3] == (n_out, n_in), \
    'basis channel axes {} do not match orders out={}, in={}'.format(basis.shape[1:3], order_out, order_in)

for idx_basis, b_elem in enumerate(basis):
    print('#################################################################')
    print('Basis element {}'.format(idx_basis))
    print('#################################################################')
    for m, out_ch in zip(np.arange(-order_out, order_out+1), b_elem):
        print('\nout channel m = {}'.format(m))
        for n, in_ch in zip(np.arange(-order_in, order_in+1), out_ch):
            print('in channel n = {}'.format(n))
            plot_cube(in_ch)