# Datatypes in E(3) Neural Networks

### using the `e3nn` repository

## tutorial by: Tess E. Smidt
## code by: 

[![DOI](https://zenodo.org/badge/116704656.svg)](https://zenodo.org/badge/latestdoi/116704656)
```
@misc{mario_geiger_2019_3348277,
  author       = {Mario Geiger and
                  Tess Smidt and
                  Wouter Boomsma and
                  Maurice Weiler and
                  Michał Tyszkiewicz and
                  Jes Frellsen and
                  Benjamin K. Miller and
                  Josh Rackers},
  title        = {e3nn/e3nn: Point cloud support},
  month        = jul,
  year         = 2019,
  doi          = {10.5281/zenodo.3348277},
  url          = {https://doi.org/10.5281/zenodo.3348277}
}
```

# Our data types are geometry and features on that geometry expressed as geometric tensors

Most properties of physical systems are expressed in terms of geometric tensors. Scalars (mass), vectors (velocities, forces, polarizations), matrices (polarizability, moment of inertia) and higher rank tensors are all geometric tensors.

## Geometric tensors: Cartesian tensors and Spherical Tensors

Geometric tensors are commonly expressed with Cartesian indices $(x, y, z)$ -- we will call these Cartesian tensors. However, there is an equally expressive way of representing geometric tensors as spherical tensors.

Whereas for Cartesian tensors the indices can be interpreted as information along $(x, y, z)$, spherical tensors are indexed by the spherical harmonic they are associated with. The two representations can be used interchangeably.

We use spherical tensors in our network because our convolutional filters are expressed in terms of spherical harmonics -- *more about that later*. 

[Wikipedia has a great overview of spherical harmonics](https://en.wikipedia.org/wiki/Spherical_harmonics). As a quick recap, the spherical harmonics are the Fourier basis for functions on the unit sphere. They have two indices, most commonly called the "degree" $L$ and "order" $m$ and are commonly parameterized by spherical coordinate angles $\theta$ and $\phi$. 

$Y_{L}^{m}(\theta, \phi)$ for complex spherical harmonics or
$Y_{Lm}(\theta, \phi)$ for real spherical harmonics.

In `e3nn`, we use [real spherical harmonics](https://en.wikipedia.org/wiki/Table_of_spherical_harmonics#Real_spherical_harmonics). There are $2 L + 1$ functions (indexed by $m$) for each $L$. All functions of a given degree $L$ have the same frequency. Note that these frequencies must be integral (or half-integral for $SU(2)$) because of the periodic boundary conditions of the sphere.
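
As a small illustration of the "Fourier basis on the sphere" picture, the sketch below checks numerically that the three real $L=1$ harmonics are orthonormal over the unit sphere. It uses plain NumPy, the textbook normalization $\sqrt{3/(4\pi)}$ (which may differ from the normalization `e3nn` uses internally), and the $(y, z, x)$ ordering used later in this tutorial. The function name `real_sh_l1` is made up for this example.

```python
import numpy as np

def real_sh_l1(theta, phi):
    """Real L=1 spherical harmonics, returned in (y, z, x) order."""
    x = np.sin(theta) * np.cos(phi)
    y = np.sin(theta) * np.sin(phi)
    z = np.cos(theta)
    return np.sqrt(3.0 / (4.0 * np.pi)) * np.stack([y, z, x])

# Quadrature on the sphere: Gauss-Legendre in cos(theta), uniform grid in phi
n_theta, n_phi = 16, 32
nodes, weights = np.polynomial.legendre.leggauss(n_theta)
theta = np.arccos(nodes)
phi = np.linspace(0.0, 2.0 * np.pi, n_phi, endpoint=False)
T, P = np.meshgrid(theta, phi, indexing="ij")
W = weights[:, None] * (2.0 * np.pi / n_phi) * np.ones_like(P)

Y = real_sh_l1(T, P)                              # shape (3, n_theta, n_phi)
gram = np.einsum("aij,bij,ij->ab", Y, Y, W)       # integrals of Y_a * Y_b over the sphere

print(gram.round(6))                              # close to the 3x3 identity
print(2 * 1 + 1 == Y.shape[0])                    # 2L + 1 functions for L = 1
```

The Gram matrix comes out as the identity, which is what "basis functions" means here: each harmonic has unit norm and is orthogonal to the others.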

## Representation Lists in `e3nn`
To keep track of which spherical tensor entries correspond to which spherical harmonic, we use representation lists, commonly saved as a variable `Rs`.

`Rs` is a list of tuples `(mult, L)` where `mult` is the multiplicity (or number of copies) and `L` is the degree of the spherical harmonic.

For example, the `Rs` of a single vector is 
`Rs_vec = [(1, 1)]`
and two vectors
`Rs_2vec = [(2, 1)]`

You will sometimes see an `Rs` with three integers in the tuple `(mult, L, parity)`, where the first two are the same as before and `parity` indicates whether that part of the tensor has the same parity (`0`) or the opposite parity (`1`) as the spherical harmonic of that degree. All odd $L$ spherical harmonics have odd parity (they change sign under parity) and all even $L$ spherical harmonics have even parity (they do NOT change under parity).
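
To make the bookkeeping concrete, here is a minimal sketch of how the length of a spherical tensor follows from its `Rs`: each `(mult, L)` entry contributes `mult * (2 * L + 1)` numbers. The helper `rs_dim` is a name invented for this illustration and is written in plain Python; it is not presented as part of the `e3nn` API.

```python
def rs_dim(Rs):
    """Total number of entries in a spherical tensor described by Rs.

    Hypothetical helper for illustration only. Accepts tuples of the form
    (mult, L) or (mult, L, parity); the parity entry does not affect the size.
    """
    return sum(entry[0] * (2 * entry[1] + 1) for entry in Rs)

Rs_vec = [(1, 1)]                      # a single vector
Rs_2vec = [(2, 1)]                     # two vectors
Rs_mixed = [(1, 0), (1, 1), (1, 2)]    # one scalar, one vector, one L=2 part

print(rs_dim(Rs_vec))     # 3
print(rs_dim(Rs_2vec))    # 6
print(rs_dim(Rs_mixed))   # 9
```

The last example has the same size as a 3x3 matrix, which is the case worked through later in this tutorial.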

## Spherical Harmonics
First, let's draw the spherical harmonics using the `SphericalTensor` class defined in spherical.py. This is a handy helper class that I've written for this tutorial so we can quickly manipulate and plot spherical tensors.

In [None]:
%load_ext autoreload
%autoreload 2

import torch 
import numpy as np
import e3nn.o3 as o3
from e3nn.rs import tensor_product
from spherical import SphericalTensor # a small Signal class written for ease of handling Spherical Tensors
import plotly
from plotly.subplots import make_subplots

torch.set_default_dtype(torch.float64)

L_max = 3
rows = L_max + 1
cols = 2 * L_max + 1

specs = [[{'is_3d': True} for i in range(cols)]
         for j in range(rows)]
fig = make_subplots(rows=rows, cols=cols, specs=specs)

for L in range(L_max + 1):
    for m in range(0, 2 * L + 1):
        tensor = torch.zeros(2 * L + 1)
        tensor[m] = 1.0
        sphten = SphericalTensor(tensor, Rs=[(1, L)])
        row, col = L + 1, (L_max - L) + m + 1
        trace = sphten.plot(relu=False, n=60)
        if m != 2 * L_max:
            trace.showscale = False
        fig.add_trace(trace, row=row, col=col)

fig.show()

## Spherical harmonics as linear combination of monomials

To understand how the spherical harmonics are grouped, it can be helpful to think of the spherical harmonics as being built from monomials proportional to $x^\alpha y^\beta z^\gamma$ where $L = \alpha + \beta + \gamma$. For $L=0$ there is only 1 spherical harmonic and 1 monomial ($1$), for $L=1$ there are 3 spherical harmonics and 3 monomials $(y, z, x)$, for $L=2$ there are 5 spherical harmonics but 6 monomials $(x^2, y^2, z^2, xy, yz, zx)$. 

How do we go from 6 to 5? There is a hidden redundancy in these 6 monomials: $x^2$, $y^2$, and $z^2$ are each mixtures of $L=0$ and $L=2$, because their sum $x^2 + y^2 + z^2 = r^2$ is a scalar. We can calculate how these monomials project onto spherical tensors.
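
The counting can also be seen without any `e3nn` machinery. The sketch below uses plain NumPy to represent $x^2$, $y^2$ and $z^2$ as diagonal 3x3 tensors and splits each into an isotropic part (proportional to $r^2$) and a traceless remainder. The helper `split_trace` is a name made up for this illustration, and the sketch does not use e3nn's basis or normalization, so only the structure, not the coefficients, should be compared with the `e3nn` projection.

```python
import numpy as np

def split_trace(T):
    """Split a symmetric 3x3 tensor into an isotropic (L=0) and a traceless (L=2) part."""
    iso = np.trace(T) / 3.0 * np.eye(3)
    return iso, T - iso

# Diagonal tensors standing in for the monomials x^2, y^2 and z^2
monomials = {
    "x^2": np.diag([1.0, 0.0, 0.0]),
    "y^2": np.diag([0.0, 1.0, 0.0]),
    "z^2": np.diag([0.0, 0.0, 1.0]),
}

traceless_parts = []
for name, T in monomials.items():
    iso, traceless = split_trace(T)
    traceless_parts.append(traceless)
    print(name, "L=0 part:", np.diag(iso).round(3), "L=2 part:", np.diag(traceless).round(3))

# The three traceless parts sum to zero, so only two of them are independent.
# Together with xy, yz and zx that leaves 2 + 3 = 5 components for L=2.
print(np.allclose(sum(traceless_parts), 0.0))
```

Each squared monomial carries the same $L=0$ share of one third, and the leftover pieces are linearly dependent, which is where the sixth degree of freedom goes.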

In [None]:
empty = torch.zeros(3, 3)
x2, y2, z2 = empty.clone(), empty.clone(), empty.clone()
x2[0, 0], y2[1, 1], z2[2, 2] = 1, 1, 1 # Create tensor representation of x^2, y^2 and z^2

perm_x2 = x2.clone()[torch.tensor([1, 2, 0]).unsqueeze(1), torch.tensor([1, 2, 0]).unsqueeze(0)]
perm_y2 = y2.clone()[torch.tensor([1, 2, 0]).unsqueeze(1), torch.tensor([1, 2, 0]).unsqueeze(0)]
perm_z2 = z2.clone()[torch.tensor([1, 2, 0]).unsqueeze(1), torch.tensor([1, 2, 0]).unsqueeze(0)]

# Representation lists that we use in `e3nn` indicate the order of coefficients in a spherical tensor
# A component of a representation list describes the multiplicity and degree (mult, L)
Rs_vec = [(1, 1)] # Representation list of a single vector
Rs_3x3, C = tensor_product(Rs_vec, Rs_vec) # Representation list of the product and Clebsch-Gordan tensor
C = C.permute(1,2,0)

print("x^2, y^2, and z^2 are mixtures of L=0 and L=2")
print("SH:", "  1      y      z      x      xy     yz     *      zx     %", )
print("x^2", torch.einsum('ijk,ij->k', C, perm_x2).detach().numpy().round(3))
print("y^2", torch.einsum('ijk,ij->k', C, perm_y2).detach().numpy().round(3))
print("z^2", torch.einsum('ijk,ij->k', C, perm_z2).detach().numpy().round(3))
print("* == 2z^2 - x^2 - y^2")
print("% == x^2 - y^2")

## 3x3 Matrix as a Cartesian and Spherical tensor

Geometric tensors transform predictably under rotation. Let's take the example of a 3 x 3 matrix, a Cartesian tensor of rank 2.

$M_{ij} = 
\begin{pmatrix}
    \alpha_{xx} & \alpha_{xy} & \alpha_{xz} \\
    \alpha_{yx} & \alpha_{yy} & \alpha_{yz}\\
    \alpha_{zx} & \alpha_{zy} & \alpha_{zz}
\end{pmatrix}$

where $i$ and $j$ are indexed as $(x, y, z)$. 

We can also express this matrix as a spherical harmonic tensor. The way to do this conversion is to recognize that an $L=1$ spherical tensor has the same indices as $(x, y, z)$ EXCEPT they are permuted as $(y, z, x)$.
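
A small sketch of that reordering, using NumPy and string labels instead of numbers so the result is easy to read. It assumes only the $(x, y, z) \to (y, z, x)$ permutation stated above; `np.ix_` applies the same permutation to rows and columns, and the inverse permutation undoes it.

```python
import numpy as np

perm = [1, 2, 0]        # (x, y, z) -> (y, z, x)
inv_perm = [2, 0, 1]    # (y, z, x) -> (x, y, z)

# Label each entry by its Cartesian indices so the reordering is easy to read
labels = np.array([[a + b for b in "xyz"] for a in "xyz"])

perm_labels = labels[np.ix_(perm, perm)]              # permute rows and columns together
restored = perm_labels[np.ix_(inv_perm, inv_perm)]    # undo the permutation

print(perm_labels)
print((restored == labels).all())
```

The first row of `perm_labels` reads `yy, yz, yx`, matching the axis labels drawn on the permuted matrix in the plot that follows.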

In [None]:
M = torch.randn(3, 3)
# Permute indices to ('y', 'z', 'x') to be compatible with spherical harmonic convention
perm_M = M.clone()[torch.tensor([1, 2, 0]).unsqueeze(1), torch.tensor([1, 2, 0]).unsqueeze(0)]

import matplotlib.pyplot as plt
%matplotlib inline
fig, axes = plt.subplots(1, 2, figsize=(8, 5));
axes[0].matshow(M)
axes[0].set_title('M');
axes[0].get_xaxis().set_visible(False)
axes[0].get_yaxis().set_visible(False)
for i, x in enumerate(["x", "y", "z"]):
    for j, y in enumerate(["x", "y", "z"]):
        axes[0].text(j - 0.2, i + 0.1, x + y, {'color':'white', 'fontsize': 20})
    
im = axes[1].matshow(perm_M)
axes[1].set_title('M permuted for both indices');
axes[1].get_xaxis().set_visible(False)
axes[1].get_yaxis().set_visible(False)
for i, x in enumerate(["y", "z", "x"]):
    for j, y in enumerate(["y", "z", "x"]):
        axes[1].text(j - 0.2, i + 0.1, x + y, {'color':'white', 'fontsize': 20})
        
fig.colorbar(im, ax=axes[:], shrink=0.75);

## Rotating Cartesian and Spherical tensors
Our Cartesian matrix can be rotated with a 3D rotation matrix $R$ applied to each Cartesian index.

$M'_{kl} = R_{ki} R_{lj} M_{ij}$

As shown above, we can permute our Cartesian indices $(x, y, z)$ into those of L=1 spherical harmonics $(y, z, x)$.

We can even simplify the matrix by combining its two indices into a single index using the [Clebsch-Gordan coefficients](https://en.wikipedia.org/wiki/Clebsch%E2%80%93Gordan_coefficients).

$I_k = C_{ijk} M_{ij}$

where $C_{ijk}$ is the Clebsch-Gordan tensor. See Griffiths, *Introduction to Quantum Mechanics*, Ch. 4 for more details.
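
To see what the single index $k$ contains, it helps to recall that the nine entries of a 3x3 matrix split into a trace ($L=0$, 1 component), an antisymmetric part ($L=1$, 3 components) and a symmetric traceless part ($L=2$, 5 components). The sketch below checks with plain NumPy that a rotation never mixes these three pieces. It works directly on 3x3 arrays rather than with the `e3nn` Clebsch-Gordan tensor or its packed ordering, and the helper names `rot_z`, `rot_y` and `decompose` are made up for this illustration.

```python
import numpy as np

def rot_z(a):
    """Rotation by angle a about the z axis."""
    c, s = np.cos(a), np.sin(a)
    return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])

def rot_y(b):
    """Rotation by angle b about the y axis."""
    c, s = np.cos(b), np.sin(b)
    return np.array([[c, 0.0, s], [0.0, 1.0, 0.0], [-s, 0.0, c]])

def decompose(M):
    """Split a 3x3 matrix into trace (L=0), antisymmetric (L=1), symmetric traceless (L=2)."""
    trace = np.trace(M)
    antisym = 0.5 * (M - M.T)
    sym_traceless = 0.5 * (M + M.T) - trace / 3.0 * np.eye(3)
    return trace, antisym, sym_traceless

rng = np.random.default_rng(0)
M = rng.normal(size=(3, 3))
R = rot_z(0.3) @ rot_y(1.1) @ rot_z(-0.7)

# Rotate both Cartesian indices of M
rotated_M = np.einsum("ki,ij,lj->kl", R, M, R)

t, a, s = decompose(M)
t_rot, a_rot, s_rot = decompose(rotated_M)

print(np.isclose(t_rot, t))                 # the L=0 part is invariant
print(np.allclose(a_rot, R @ a @ R.T))      # the L=1 part rotates into itself
print(np.allclose(s_rot, R @ s @ R.T))      # the L=2 part rotates into itself
```

All three checks print `True`: decomposing and rotating commute, which is why each block can be given its own rotation matrix.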

We can then rotate this index using [Wigner D-matrices](https://en.wikipedia.org/wiki/Wigner_D-matrix), rotation matrices for the irreducible basis.

$I'_{i} = D_{ij} I_j$

We can then convert back to the 3x3 matrix format to see that these rotations are indeed equivalent.

In [None]:
# random rotation Euler angles alpha, beta, gamma
angles = torch.rand(3) * torch.tensor([np.pi, 2 * np.pi, np.pi])
rot = o3.rot(*angles)

rotated_M = torch.einsum('ki,ij,lj->kl', rot, M, rot)

Rs_vec = [(1, 1)] # Representation list of a single vector
Rs_3x3, C = tensor_product(Rs_vec, Rs_vec)
C = C.permute(1,2,0)
print("Single index representation of 3x3 matrix:", Rs_3x3)
print("Shape of Clebsch-Gordan tensor:", C.shape)

# Wigner D matrix -- rotation matrix for irreducible representations
wignerD = o3.direct_sum(*[o3.irr_repr(l, *angles) for mul, l, parity in Rs_3x3 for _ in range(mul)])
print("Shape of Wigner-D matrix:", wignerD.shape)

# Convert matrix to representation vector
I = torch.einsum('ijk,ij->k', C, perm_M)
# Rotate representation vector
rotated_I = torch.einsum('ij,j->i', wignerD, I)

# And we can convert this back to our original format to compare
rotated_perm_M = torch.einsum('ijk,k->ij', C, rotated_I)
rotated_M_prime = rotated_perm_M.clone()
rotated_M_prime[torch.tensor([1, 2, 0]).unsqueeze(1), torch.tensor([1, 2, 0]).unsqueeze(0)] = rotated_perm_M.clone()

In [None]:
# Visualize M and rotated_M
import matplotlib.pyplot as plt
%matplotlib inline
fig, axes = plt.subplots(1, 3, figsize=(12, 6));
axes[0].matshow(M)
axes[0].set_title('M');
axes[0].get_xaxis().set_visible(False)
axes[0].get_yaxis().set_visible(False)
    
axes[1].matshow(rotated_M)
axes[1].set_title('Rotated M from Cartesian');
axes[1].get_xaxis().set_visible(False)
axes[1].get_yaxis().set_visible(False)

im = axes[2].matshow(rotated_M_prime)
axes[2].set_title("Rotated M from Spherical");
axes[2].get_xaxis().set_visible(False)
axes[2].get_yaxis().set_visible(False)
        
fig.colorbar(im, ax=axes[:], shrink=0.75);

## We can interpret our spherical harmonic tensors as components of a traditional geometric tensor or as geometry itself.

In [None]:
rows, cols = 1, 1
specs = [[{'is_3d': True} for i in range(cols)]
         for j in range(rows)]
fig = make_subplots(rows=rows, cols=cols, specs=specs)

L_max = 6
Rs = [(1, L) for L in range(L_max + 1)]
sum_Ls = sum(2 * L + 1 for mult, L in Rs) 

# Random spherical tensor up to L_Max
rand_sph_tensor = torch.randn(sum_Ls)

sphten = SphericalTensor(rand_sph_tensor, Rs)

trace = sphten.plot(relu=False, n=60)
fig.add_trace(trace, row=1, col=1)
fig.show()

In [None]:
# Projection of tetrahedron on origin

rows, cols = 1, 1
specs = [[{'is_3d': True} for i in range(cols)]
         for j in range(rows)]
fig = make_subplots(rows=rows, cols=cols, specs=specs)

L_max = 6
tetra_coords = torch.tensor( # The easiest way to construct a tetrahedron is using opposite corners of a box
    [[0., 0., 0.], [1., 1., 0.], [1., 0., 1.], [0., 1., 1.]]
)
tetra_coords -= tetra_coords.mean(-2)

sphten = SphericalTensor.from_geometry(tetra_coords, L_max)

trace = sphten.plot(relu=False, n=60)
fig.add_trace(trace, row=1, col=1)
fig.show()
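
As a quick sanity check on the "opposite corners of a box" construction used above, the following NumPy sketch confirms that the four points form a regular tetrahedron: all six edges have the same length and, after centering, all four vertices lie at the same distance from the origin. It is independent of `SphericalTensor` and only re-uses the coordinates from the cell above.

```python
import numpy as np
from itertools import combinations

# Alternating corners of the unit cube, then centered on the origin
corners = np.array([[0.0, 0.0, 0.0], [1.0, 1.0, 0.0], [1.0, 0.0, 1.0], [0.0, 1.0, 1.0]])
centered = corners - corners.mean(axis=0)

# Six pairwise distances (each is a face diagonal of the cube)
edge_lengths = [np.linalg.norm(a - b) for a, b in combinations(centered, 2)]
# Distance of each vertex from the origin
radii = np.linalg.norm(centered, axis=1)

print(np.allclose(edge_lengths, np.sqrt(2.0)))
print(np.allclose(radii, radii[0]))
```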