In [None]:
from e3nn import o3
from e3nn.o3 import Irreps
from e3nn.nn.batchnorm import BatchNorm
from e3nn.nn.gate import Gate
# NOTE(review): `torch.tensor` is a function, not a module. This line only works on
# older PyTorch releases that still shipped a `torch/tensor.py` module (renamed in
# newer versions, around 1.9), and `t` is not used anywhere in the visible
# notebook — consider deleting it.
import torch.tensor as t
import matplotlib.pyplot as plt
import torch

import sys
# Make the parent directory importable — presumably so `models.e3nn_models`
# (imported further down) resolves when the notebook runs from its own folder.
sys.path.append('..')

In [None]:
# Presumably draws the paths of the full tensor product between the two irreps.
o3.FullTensorProduct(
    Irreps('1x0e+1x1e+1x2e'),
    Irreps('3x0e+3x1e+3x2e'),
).visualize()

In [None]:
# Irreps of the spherical harmonics up to l = 11 — presumably one copy of each l
# (0e + 1o + ... + 11o), i.e. 12**2 = 144 components in total.
irreps = o3.Irreps.spherical_harmonics(11)

In [None]:
# Display the e3nn BatchNorm module built for the spherical-harmonics irreps.
BatchNorm(irreps)

In [None]:
# Presumably Gate(irreps_scalars, act_scalars, irreps_gates, act_gates, irreps_gated):
# the 32 gate scalars match the 16 + 16 = 32 gated vectors. The scalars and gates
# are odd (0o), so they need an activation with definite parity; tanh is odd.
Gate("16x0o", [torch.tanh], "32x0o", [torch.tanh], "16x1e+16x1o")

In [None]:
# Same gate layout with even (0e) scalars and gates.
# NOTE(review): the gated irreps are written as "16x1e+16x1e" (two identical terms)
# — confirm this is intended and not a typo for "16x1e+16x1o" as in the previous cell.
Gate("16x0e", [torch.tanh], "32x0e", [torch.tanh], "16x1e+16x1e")

In [None]:
# Random matrix with its sign flipped. For a 3x3 matrix det(-R) = -det(R), so if
# rand_matrix returns a proper rotation this is an improper one (a rotation combined
# with inversion), which exercises the parity of the irreps.
# TODO: no seed is set, so the figure below changes from run to run.
rot = torch.neg(o3.rand_matrix())

In [None]:
# Representation matrix of `irreps` for `rot` — presumably square
# (irreps.dim x irreps.dim) and block diagonal with one block per l.
D = irreps.D_from_matrix(rot)

In [None]:
# Show the l = 11 block of D. Assuming `irreps` holds one copy of each l = 0..11 in
# order, the blocks for l = 0..10 occupy the first sum_{l=0}^{10} (2l + 1) = 11**2
# = 121 rows/columns, so slicing from 11**2 isolates the last (l = 11) block.
d_block = D[11 ** 2:, 11 ** 2:]
# 'bwr' is a diverging colormap: symmetric limits make zero map to white, so the
# sign structure of the block is readable.
vmax = float(d_block.abs().max())
fig, ax = plt.subplots()
im = ax.imshow(d_block, cmap='bwr', vmin=-vmax, vmax=vmax)
fig.colorbar(im, ax=ax)
ax.set_title('l = 11 block of D(rot)')
ax.set_xlabel('m (column)')
ax.set_ylabel('m (row)')
plt.show()

In [None]:
import torch
import torch.nn as nn

from e3nn import o3
from e3nn.o3 import FullyConnectedTensorProduct, Linear
from e3nn.nn.batchnorm import BatchNorm
from e3nn.nn.gate import Gate
from e3nn.math import soft_one_hot_linspace

import matplotlib.pyplot as plt

In [None]:
class Convolution(torch.nn.Module):
    r"""Equivariant 3D convolution on voxels.

    The kernel is a tensor product between the input features and the spherical
    harmonics of the lattice offsets. The tensor-product weights depend on the
    offset length through a Gaussian radial basis. A linear self-connection is
    added to the scaled convolution output.

    Parameters
    ----------
    irreps_in : `Irreps`
        irreps of the input features (channel axis of ``x``)
    irreps_out : `Irreps`
        irreps of the output features
    irreps_sh : `Irreps`
        set typically to ``o3.Irreps.spherical_harmonics(lmax)``
    size : int
        number of lattice points along the axis(es) with the smallest step;
        also used as the number of radial basis functions
    steps : tuple of int
        voxel spacing along x, y, z. Lattice coordinates are scaled by
        ``step / min(steps)`` and only coordinates in ``[-1, 1]`` are kept, so
        axes with a larger step get fewer kernel points (anisotropic kernel).
    """
    def __init__(self, irreps_in, irreps_out, irreps_sh, size, steps=(1, 1, 1)):
        super().__init__()

        self.irreps_in = o3.Irreps(irreps_in)
        self.irreps_out = o3.Irreps(irreps_out)
        self.irreps_sh = o3.Irreps(irreps_sh)
        self.size = size
        self.num_rbfs = self.size

        # self-connection
        self.sc = Linear(self.irreps_in, self.irreps_out)

        # connection with neighbors: per-axis lattice coordinates, trimmed to [-1, 1]
        r = torch.linspace(-1, 1, self.size)
        x = r * steps[0] / min(steps)
        x = x[x.abs() <= 1]
        y = r * steps[1] / min(steps)
        y = y[y.abs() <= 1]
        z = r * steps[2] / min(steps)
        z = z[z.abs() <= 1]
        lattice = torch.stack(torch.meshgrid(x, y, z), dim=-1)  # [x, y, z, R^3]
        self.register_buffer('d', lattice.norm(dim=-1))  # offset lengths [x, y, z]

        sh = o3.spherical_harmonics(self.irreps_sh, lattice, True, 'component')  # [x, y, z, irreps_sh.dim]
        self.register_buffer('sh', sh)

        self.tp = FullyConnectedTensorProduct(self.irreps_in, self.irreps_sh, self.irreps_out, shared_weights=False)

        # One row of tensor-product weights per radial basis function.
        self.weight = torch.nn.Parameter(torch.randn(self.num_rbfs, self.tp.weight_numel))

    def forward(self, x):
        r"""Apply the convolution.

        Parameters
        ----------
        x : `torch.Tensor`
            tensor of shape ``(batch, irreps_in.dim, x, y, z)``

        Returns
        -------
        `torch.Tensor`
            tensor of shape ``(batch, irreps_out.dim, x, y, z)``
        """
        # Linear acts on the last axis: move the channel axis there and back.
        sc = self.sc(x.transpose(1, 4)).transpose(1, 4)

        weight = soft_one_hot_linspace(
            x=self.d,
            start=0.0,
            end=1.0,
            number=self.num_rbfs,
            base='gaussian',
            endpoint=True,
        ) @ self.weight  # [x, y, z, num_rbfs] @ [num_rbfs, tp.weight_numel] => [x, y, z, tp.weight_numel]

        # Presumably normalizes by sqrt(size**3), the number of points of an isotropic kernel.
        weight = weight / (self.size ** (3 / 2))
        kernel = self.tp.right(self.sh, weight)  # [x, y, z, irreps_in.dim, irreps_out.dim]
        kernel = torch.einsum('xyzio->oixyz', kernel)  # conv3d expects [out, in, kx, ky, kz]

        # Pad each spatial axis by half of its own kernel extent. With anisotropic
        # `steps` the kernel is not cubic, and a single `size // 2` padding made the
        # conv output larger than `sc` along the shorter axes.
        padding = tuple(k // 2 for k in kernel.shape[2:])
        return sc + 0.1 * torch.nn.functional.conv3d(x, kernel, padding=padding)

In [None]:
%load_ext autoreload
%autoreload 2
import pytorch_lightning as plt
from models.e3nn_models import e3nnCNN, Convolution

In [None]:
# Restore the trained e3nnCNN from its checkpoint file (path relative to the
# notebook directory); presumably a Lightning checkpoint.
model = e3nnCNN.load_from_checkpoint(
    '../logs/e3nn_cnn-1618260121/version_0/checkpoints/epoch=2-step=2399.ckpt'
)

In [None]:
# Every Convolution submodule of the loaded model, in module-traversal order.
layers = list(filter(lambda module: isinstance(module, Convolution), model.modules()))

In [None]:
# Materialize each layer's kernel: `tp.right` combines the spherical harmonics `sh`
# with the per-point weights `emb @ weight` (presumably radial basis x learned weights).
# Computing under no_grad avoids building an autograd graph that `.detach()`
# would only discard afterwards.
# NOTE(review): the notebook's own Convolution.forward also divides these weights
# by size**1.5 before tp.right — confirm whether models.e3nn_models does the same
# before interpreting kernel magnitudes.
with torch.no_grad():
    kernels = [layer.tp.right(layer.sh, layer.emb @ layer.weight) for layer in layers]

In [None]:
def get_expanded_kernel_size(layer):
    """Return the number of weights materialized over the kernel lattice.

    This is ``layer.size ** 3`` lattice points times the number of
    tensor-product weights per point (the product of the trailing dimensions
    of ``layer.weight``).

    Parameters
    ----------
    layer : Convolution
        convolution layer exposing ``size`` and ``weight``

    Returns
    -------
    int
    """
    # math.prod instead of np.prod: numpy is only imported further down in the
    # notebook, so np.prod raised NameError on a fresh "Restart & Run All".
    import math

    tp_weights = math.prod(layer.weight.shape[1:])
    return layer.size ** 3 * tp_weights

In [None]:
def get_total_params(layer):
    """Return a rough element count for a convolution layer.

    Sums the expanded kernel size (``get_expanded_kernel_size``), the learned
    ``weight``, the ``emb`` and ``sh`` tensors, and every parameter of the
    self-connection ``sc``.

    Parameters
    ----------
    layer : Convolution
        layer exposing ``weight``, ``emb``, ``sh``, ``sc`` and ``size``

    Returns
    -------
    int
    """
    return sum((
        get_expanded_kernel_size(layer),
        layer.weight.numel(),
        layer.emb.numel(),
        layer.sh.numel(),
        # Count every self-connection parameter (not only the first one); an
        # empty parameter list contributes 0 instead of raising IndexError.
        sum(p.numel() for p in layer.sc.parameters()),
    ))

In [None]:
# Show each convolution layer next to its total element count.
for conv in layers:
    total = get_total_params(conv)
    print(conv, total)

In [None]:
# Replace exact zeros with -10, far below the -0.2 lower bound of the volume plot's
# color_range — presumably so empty lattice points are visually separated from
# small non-zero values (confirm how k3d renders out-of-range values).
# Rebinding instead of assigning in place leaves the tensors produced by the
# previous cell untouched.
kernels = [kernel.masked_fill(kernel == 0, -10) for kernel in kernels]

In [None]:
# k3d is not imported anywhere else in the notebook, and numpy is only imported
# further down, so import both here to survive "Restart & Run All".
import k3d
import numpy as np

# Render four slices of the fourth layer's kernel as 3D volumes: index -1 on the
# second-to-last axis and i on the last axis.
# NOTE(review): assumes the kernel layout is [x, y, z, irreps_in.dim, irreps_out.dim]
# as in the notebook's own Convolution — confirm for models.e3nn_models.
plot = k3d.plot(camera_auto_fit=True)

for i in range(4):
    plot += k3d.volume(
        kernels[3][..., -1, i].numpy().astype(np.float32),
        alpha_coef=1000,
        samples=600,
        color_range=[-.2, 1],
        color_map=k3d.colormaps.paraview_color_maps.Coolwarm,
    )

plot.display()

## Dot and cross product using tensor products

In [None]:
from e3nn.o3 import TensorProduct
import numpy as np

In [None]:
# Two orthogonal unit vectors (third and second standard basis vectors).
x = torch.tensor([0.0, 0.0, 1.0])
y = torch.tensor([0.0, 1.0, 0.0])

In [None]:
# Batches of three identical copies: shape (3, 3) = (batch, vector component).
xs, ys = x.repeat(3, 1), y.repeat(3, 1)

In [None]:
# Tensor product restricted to the single path 1e x 1e -> 1e.
# Presumably the instruction tuple is (input-1 index, input-2 index, output index,
# connection mode, has_weight): with "uuu" and no weight, output copy u combines
# only copy u of each input. This path is proportional to the cross product, up to
# e3nn's normalization and l=1 basis convention (rescaled in the next cell).
cross = TensorProduct(
    '1e', '1e', '1e',
    [
        (0,0,0, "uuu", False)
    ],
)

In [None]:
# Multiplying by sqrt(2) presumably undoes the path normalization so the rows can be
# compared with the plain cross product of x and y — confirm against
# torch.cross(xs, ys), keeping e3nn's l=1 component ordering in mind.
cross(xs, ys) * np.sqrt(2)

In [None]:
# Single path 1e x 1e -> 0e without weights: the scalar output is proportional to
# the dot product of the two vectors (rescaled in the next cell).
dot = TensorProduct(
    '1e', '1e', '0e',
    [
        (0,0,0, 'uuw', False)
    ]
)

In [None]:
# x is a unit vector (x . x = 1). If sqrt(3) exactly undoes the path
# normalization, each row of the result should be ~1.
dot(xs, xs) * np.sqrt(3)