In [None]:
import numpy as np
from collections import defaultdict
import matplotlib

%matplotlib inline

import matplotlib.pyplot as plt

plt.rcParams['image.cmap'] = 'RdYlGn'

import torch

from e2cnn import gspaces
from e2cnn import nn

import matplotlib.image as mpimg

In [None]:
# Blank 20 x 20 RGB canvas with float64 values (all zeros = black).
image = np.zeros(shape=(20, 20, 3), dtype=np.float64)

In [None]:
# White cross through row 10 and column 10 (all three colour channels),
# plus a grey 5 x 5 square marking the top-left corner so that the
# orientation of the picture is visible after rotations.
image[10] = 1
image[:, 10] = 1
image[:5, :5] = 0.5

In [None]:
# Show the test image on its own axes; the trailing `;` keeps the
# `AxesImage` repr out of the cell output.
fig, ax = plt.subplots()
ax.set_axis_off()
ax.imshow(image);

In [None]:
# Symmetry group used by this and the following exploratory cells.
# `g` must be defined here: otherwise this cell only works with leftover
# kernel state and fails on Restart & Run All.
# NOTE(review): C4 (N=4) is assumed from the four rotations plotted next.
g = gspaces.Rot2dOnR2(4)

# Treat the three colour channels as independent scalar (trivial-repr) fields.
# permute + unsqueeze turn the H x W x 3 array into a (1, 3, H, W) tensor.
rgb_type = nn.FieldType(g, 3 * [g.trivial_repr])
gimage = nn.GeometricTensor(torch.tensor(image).permute(2, 0, 1).unsqueeze(0), rgb_type)

In [None]:
# 2 x 2 grid: the image transformed by group elements 0, 1, 2 and 3
# (row-major order).
_, axes = plt.subplots(2, 2, figsize=(10, 10))
for element, ax in enumerate(axes.flat):
    ax.set_axis_off()
    rotated = gimage.transform(element)
    # back from (1, C, H, W) to H x W x C for imshow
    ax.imshow(rotated.tensor.squeeze(0).permute(1, 2, 0))

In [None]:
# Matrix of the regular representation evaluated at group element 3.
rho_3 = g.regular_repr.representation(3)
plt.imshow(rho_3)

In [None]:
# Compact matrix printing: at most 2 fractional digits, and tiny values
# shown as 0 instead of in scientific notation.
print_opts = {'precision': 2, 'floatmode': 'maxprec', 'suppress': True}
np.set_printoptions(**print_opts)

In [None]:
# Change-of-basis matrix Q of the regular representation times rho(1).
# NOTE(review): e2cnn writes rho(g) = Q @ (direct sum of irreps)(g) @ Q^{-1};
# if the goal is to see the block-diagonal irrep form, the product should be
# `Q_inv @ rho(1) @ Q` instead — confirm which one is intended.
regular = g.regular_repr
Q = regular.change_of_basis
Q @ regular.representation(1)

In [None]:
# Inverse of the regular representation's change-of-basis matrix.
regular = g.regular_repr
regular.change_of_basis_inv

In [None]:
# C4 convolution from one scalar field to one regular field, keeping only the
# basis elements on ring radius 1 with frequency 2 and gamma == 0; then plot
# the resulting filter for every (output, input) channel pair.
N = 4
g = gspaces.Rot2dOnR2(N)
in_type = nn.FieldType(g, [g.trivial_repr])
out_type = nn.FieldType(g, [g.regular_repr])
conv = nn.R2Conv(in_type, out_type, 13, maximum_offset=0)

conv.weights.data.fill_(0.)

with torch.no_grad():
    wanted = (
        idx
        for idx, info in enumerate(conv.basisexpansion.get_basis_info())
        if info['radius'] == 1. and info['frequency'] == 2 and info['gamma'] == 0.0
    )
    for idx in wanted:
        conv.weights[idx] = 1.

# Switching to eval mode builds `conv.filter` from the current weights.
conv.eval()

# conv.filter is indexed [output channel, input channel, ...]
O, I = conv.filter.shape[:2]
fig, axes = plt.subplots(ncols=I, nrows=O, constrained_layout=True, squeeze=False, figsize=(7, 7))

for (o, i), ax in np.ndenumerate(axes):
    ax.set_axis_off()
    ax.imshow(conv.filter[o, i, ...].detach().numpy())



# Trivial -> Regular: every basis filter (input channel -> output channel 0), one row per ring radius

In [None]:
# C8 convolution from one scalar field to one regular field. Each basis
# element is activated in turn and its filter (input channel -> output
# channel 0) is drawn in a grid: one row per ring radius, one column per
# basis element on that ring.
N = 8
g = gspaces.Rot2dOnR2(N)
in_type = nn.FieldType(g, [g.trivial_repr])
out_type = nn.FieldType(g, [g.regular_repr])
K = 9

conv = nn.R2Conv(in_type, out_type, K, sigma=0.8, frequencies_cutoff=lambda r: .5*r, rings=[0, 2, 4])
conv.eval()

basis_info = list(conv.basisexpansion.get_basis_info())

# Frequencies of the basis elements on each ring radius.
freqs = defaultdict(list)
for p in basis_info:
    freqs[p['radius']].append(p['frequency'])

R = len(freqs)
F = max(len(f) for f in freqs.values())
print(f'{R} rings, at most {F} basis filters per ring')

# Colour range shared by all panels so that filter amplitudes are comparable.
V = .08

# Optional output file for the figure (e.g. 'basis_hd.pdf'); None = don't save.
SAVE_PATH = None

# No constrained_layout: it is incompatible with subplots_adjust, and
# matplotlib drops one of the two with a warning, which loses hspace=0.
fig, axes = plt.subplots(nrows=R, ncols=F, squeeze=False, figsize=(12, 6))
for ax in axes.flat:
    # also hides panels left empty on rings with fewer basis elements
    ax.set_axis_off()

row_of_radius = {r: i for i, r in enumerate(sorted(freqs))}
next_col = defaultdict(int)

with torch.no_grad():
    for idx, p in enumerate(basis_info):
        # Activate only the idx-th basis element; train().eval() rebuilds conv.filter.
        conv.weights.data.fill_(0.)
        conv.weights[idx] = 1.
        conv.train().eval()

        basis_filter = conv.filter[0, 0, ...].detach().numpy()

        row = row_of_radius[p['radius']]
        col = next_col[p['radius']]
        next_col[p['radius']] += 1

        axes[row][col].imshow(basis_filter, vmin=-V, vmax=V)

fig.subplots_adjust(hspace=0)
if SAVE_PATH is not None:
    fig.savefig(SAVE_PATH, bbox_inches='tight', dpi=100)
# plt.savefig('basis_hd.pdf', bbox_inches='tight', dpi=100)

# Trivial -> Regular: all output channels of every basis filter, grouped by (ring radius, output irrep)


In [None]:
# C9 convolution from one scalar field to one regular field on a single ring.
# For each (ring radius, output irrep) group, plot one figure with a row per
# basis element and a column per output channel of the regular field.
N = 9
g = gspaces.Rot2dOnR2(N)
in_type = nn.FieldType(g, [g.trivial_repr])
out_type = nn.FieldType(g, [g.regular_repr])

K = 23

conv = nn.R2Conv(in_type, out_type, K, sigma=0.8, maximum_offset=0, frequencies_cutoff=lambda r: 1*r, rings=[6])
conv.eval()

# Indices of the basis elements, grouped by (ring radius, output irrep).
irreps_basis = defaultdict(list)
for i, p in enumerate(conv.basisexpansion.get_basis_info()):
    irreps_basis[(p['radius'], p['out_irrep'])].append(i)
irreps_basis = sorted(irreps_basis.items())

# Colour range shared by all panels so that filter amplitudes are comparable.
V = .08

# Number of output channels of the (single) regular field; equals N here.
n_channels = out_type.size

for (radius, irrep), filters_idxs in irreps_basis:
    F = len(filters_idxs)
    fig, axes = plt.subplots(nrows=F, ncols=n_channels, squeeze=False, figsize=(12, 6))

    with torch.no_grad():
        for row, i in enumerate(filters_idxs):
            # Activate only the i-th basis element; train().eval() rebuilds conv.filter.
            conv.weights.data.fill_(0.)
            conv.weights[i] = 1.
            conv.train().eval()

            # Filters from the single input channel to every output channel
            # (raw regular-representation channels, no Fourier transform).
            channel_filters = conv.filter[:, 0, ...].detach().numpy()

            axes[row][0].set_title(f'{irrep}: basis {row}')
            for col in range(n_channels):
                axes[row][col].set_axis_off()
                axes[row][col].imshow(channel_filters[col, ...], vmin=-V, vmax=V)

    plt.subplots_adjust(hspace=0)


# Trivial -> Fourier transform of Regular: output channels mapped to the irrep basis, grouped by (ring radius, output irrep)


In [None]:
# Same setup as the previous cell, but the output channels of every basis
# filter are first mapped with the inverse change of basis of the regular
# representation (its Fourier transform) before plotting.
N = 9
g = gspaces.Rot2dOnR2(N)
in_type = nn.FieldType(g, [g.trivial_repr])
out_type = nn.FieldType(g, [g.regular_repr])

# Q^{-1}: maps regular-representation channels to coefficients in the irrep basis.
FT = g.regular_repr.change_of_basis_inv

K = 23

conv = nn.R2Conv(in_type, out_type, K, sigma=0.8, maximum_offset=0, frequencies_cutoff=lambda r: r, rings=[6])
conv.eval()

# Indices of the basis elements, grouped by (ring radius, output irrep).
irreps_basis = defaultdict(list)
for i, p in enumerate(conv.basisexpansion.get_basis_info()):
    irreps_basis[(p['radius'], p['out_irrep'])].append(i)
irreps_basis = sorted(irreps_basis.items())

# Colour range shared by all panels so that filter amplitudes are comparable.
V = .08

# Number of output channels of the (single) regular field; equals N here.
n_channels = out_type.size

for (radius, irrep), filters_idxs in irreps_basis:
    F = len(filters_idxs)
    # No constrained_layout: it is incompatible with subplots_adjust, and
    # matplotlib drops one of the two with a warning, which loses hspace=0.
    fig, axes = plt.subplots(nrows=F, ncols=n_channels, squeeze=False, figsize=(12, 6))

    with torch.no_grad():
        for row, i in enumerate(filters_idxs):
            # Activate only the i-th basis element; train().eval() rebuilds conv.filter.
            conv.weights.data.fill_(0.)
            conv.weights[i] = 1.
            conv.train().eval()

            channel_filters = conv.filter[:, 0, ...].detach().numpy()
            # Fourier transform over the output channels, applied pixel-wise.
            channel_filters = np.einsum('fc,cxy->fxy', FT, channel_filters)

            axes[row][0].set_title(f'{irrep}: basis {row}')
            for col in range(n_channels):
                axes[row][col].set_axis_off()
                axes[row][col].imshow(channel_filters[col, ...], vmin=-V, vmax=V)

    fig.subplots_adjust(hspace=0)


# On the expressivity of the filters: how well can an arbitrary K x K pattern be approximated by the basis filters?

In [None]:
# Collect, for every basis element of a C8 trivial -> regular convolution,
# the K x K filter it produces between the input channel and output channel 0.
N = 8
g = gspaces.Rot2dOnR2(N)
in_type = nn.FieldType(g, [g.trivial_repr])
out_type = nn.FieldType(g, [g.regular_repr])

K = 9

conv = nn.R2Conv(in_type, out_type, K, sigma=0.8, frequencies_cutoff=lambda r: 3*r, rings=[0, 1, 2, 3, 4])
conv.eval()

basis = []

with torch.no_grad():
    for i, _ in enumerate(conv.basisexpansion.get_basis_info()):
        # Activate only the i-th basis element; train().eval() rebuilds conv.filter.
        conv.weights.data.fill_(0.)
        conv.weights[i] = 1.
        conv.train().eval()

        basis.append(conv.filter[0, 0, ...].detach().numpy())

basis = np.stack(basis, axis=0)  # (num_basis, K, K)
print(basis.shape)

# Normalise every filter to unit L2 norm. Filters that vanish on this channel
# are left as zeros instead of turning into NaN through a division by zero.
norms = np.linalg.norm(basis.reshape(len(basis), K * K), axis=1)
basis /= np.where(norms > 0, norms, 1.).reshape(-1, 1, 1)

In [None]:
# Synthetic K x K test pattern: Gaussian noise overlaid with a central cross,
# the main diagonal (1.5), the anti-diagonal (1.75, wins at the centre) and a
# bright 3 x 3 top-left square.
# The fixed seed makes this figure and the reconstruction below reproducible.
rng = np.random.default_rng(0)
size = K  # must match the filter size used to build `basis`
centre = size // 2

x = rng.standard_normal((size, size))
x[:, centre] = 1.
x[centre, :] = 1.
diag = np.arange(size)
x[diag, diag] = 1.5
x[diag, size - 1 - diag] = 1.75
x[0:3, 0:3] = 2

plt.imshow(x);

In [None]:
# Best approximation of `x` within the span of the basis filters.
# The rows of `B` are unit-norm but not mutually orthogonal (and may be
# linearly dependent), so `B.T @ (B @ x)` is NOT a projection; a
# least-squares solve yields the coefficients of the orthogonal projection.
B = basis.reshape(len(basis), -1)  # (num_basis, K*K)
w, *_ = np.linalg.lstsq(B.T, x.reshape(-1), rcond=None)
x_rec = (B.T @ w).reshape(x.shape)

print(f'relative reconstruction error: {np.linalg.norm(x - x_rec) / np.linalg.norm(x):.3f}')
plt.imshow(x_rec);