# Interacting equivariant objects: Linear, TensorProducts, S2Grid, Equivariant Nonlinearities

Equivariant objects interact in more complex ways than invariant objects. Let's first consider multiplication. An invariant times an invariant produces another invariant. But if we interact two vectors via an outer product, we get a $3\times3$ matrix. Inside this matrix there is the trace (which is invariant, 1 component), the antisymmetric part (which transforms like a vector, 3 components), and the symmetric traceless part (which transforms like an $l=2$ object, 5 components). This outer product operation is more generally known as a tensor product.
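The decomposition of the outer product can be checked numerically. The following is a minimal NumPy sketch, independent of e3nn; the helper names `decompose_outer`, `antisym_to_vector` and `rotation_z` are illustrative only. It checks proper rotations ($\det R = +1$). Under inversion the cross product does not flip sign, which is why the vector part of `1o` times `1o` is labeled `1e`.

```python
import numpy as np

def decompose_outer(u, v):
    """Split the outer product u v^T into trace, antisymmetric and symmetric-traceless parts."""
    m = np.outer(u, v)
    trace = np.trace(m)                                         # 1 number, equals u . v
    antisym = 0.5 * (m - m.T)                                   # 3 independent entries
    sym_traceless = 0.5 * (m + m.T) - trace / 3.0 * np.eye(3)   # 5 independent entries
    return trace, antisym, sym_traceless

def antisym_to_vector(a):
    """Pack the 3 independent entries of the antisymmetric part; for u v^T this equals u x v."""
    return 2.0 * np.array([a[1, 2], a[2, 0], a[0, 1]])

def rotation_z(theta):
    """Proper rotation about the z axis."""
    c, s = np.cos(theta), np.sin(theta)
    return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])

u = np.array([1.0, 2.0, 3.0])
v = np.array([-0.5, 0.4, 2.0])
R = rotation_z(0.7)

t, a, s = decompose_outer(u, v)
t_rot, a_rot, s_rot = decompose_outer(R @ u, R @ v)

# The three pieces add back up to the full outer product.
assert np.allclose(t / 3.0 * np.eye(3) + a + s, np.outer(u, v))
# Trace: invariant under rotation.
assert np.isclose(t, t_rot)
# Antisymmetric part: the cross product, which rotates like a vector.
assert np.allclose(antisym_to_vector(a), np.cross(u, v))
assert np.allclose(antisym_to_vector(a_rot), R @ antisym_to_vector(a))
# Symmetric traceless part: rotates as R S R^T and never mixes with the other two pieces.
assert np.allclose(s_rot, R @ s @ R.T)
```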

In order to preserve equivariance, all operations we use must be equivariant.

In [None]:
import e3nn
from e3nn import o3

## Tensor Products and their many forms
Tensor products are the generalization of multiplication to tensor objects. Mathematically, the tensor product combines two vector spaces into a new one, $\otimes: X \times Y \rightarrow X \otimes Y$. For geometric tensors, we can express the vector spaces $X$ and $Y$ in terms of irreducible representations, so we can also decompose the product space into a new direct sum of irreducible representations, $X \otimes Y \cong Z$.

In the general case, this decomposition conserves the number of degrees of freedom: $\dim Z = \dim X \cdot \dim Y$. For equivariant neural networks, we include learnable parameters in the tensor product. We also often omit parts of the product to reduce memory usage and improve computation speed. This is why our `o3.TensorProduct` class is more involved than a traditional tensor product.
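Which irreps appear in the product of two irreps follows the standard angular-momentum selection rule: $l_1 \otimes l_2$ contains each $l$ with $|l_1 - l_2| \le l \le l_1 + l_2$ exactly once, with parity $p_1 p_2$. A small pure-Python sketch (the helpers `product_irreps` and `fmt` are hypothetical, not e3nn functions) lists these irreps and checks that the dimensions add up. It only says *which* irreps appear; the Clebsch-Gordan coefficients that actually perform the change of basis are what e3nn computes for you.

```python
def product_irreps(l1, p1, l2, p2):
    """Irreps (l, parity) in l1 (x) l2: |l1 - l2| <= l <= l1 + l2, parity p1 * p2."""
    return [(l, p1 * p2) for l in range(abs(l1 - l2), l1 + l2 + 1)]

def dim(l):
    """Dimension of the irrep of degree l."""
    return 2 * l + 1

def fmt(l, p):
    """e3nn-style label, e.g. (1, -1) -> '1o'."""
    return f"{l}{'e' if p == 1 else 'o'}"

for (l1, p1), (l2, p2) in [((1, -1), (1, -1)), ((3, 1), (2, 1)), ((2, -1), (0, 1))]:
    out = product_irreps(l1, p1, l2, p2)
    # Degrees of freedom are conserved: (2 l1 + 1)(2 l2 + 1) = sum of (2 l + 1) over the outputs
    assert dim(l1) * dim(l2) == sum(dim(l) for l, _ in out)
    print(f"{fmt(l1, p1)} x {fmt(l2, p2)} -> {' + '.join(fmt(l, p) for l, p in out)}")
# 1o x 1o -> 0e + 1e + 2e
# 3e x 2e -> 1e + 2e + 3e + 4e + 5e
# 2o x 0e -> 2o
```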

The `o3.TensorProduct` class is the workhorse of `e3nn`. It's very powerful and very flexible but can be challenging to grok. You will most often deal with subclasses of this class.

Let's start with the more traditional tensor product, which is implemented by `o3.FullTensorProduct`.

In [None]:
outer_product_two_vectors = o3.FullTensorProduct('1o', '1o')
outer_product_two_vectors

In [None]:
outer_product_two_vectors.instructions

In [None]:
tp = o3.FullTensorProduct('5x1o + 4x3e', '10x2e')
tp.visualize()

We have several prebuilt tensor products that use different "instruction sets". `o3.FullTensorProduct` is what is most commonly thought of as the general tensor product, as used in quantum mechanics for the addition of angular momentum. However, the size of a general tensor product grows combinatorially as you interact more and more objects. To keep the cost down, we use `o3.ElementwiseTensorProduct` and other subsets of instructions.
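To see the cost concretely, here is a pure-Python sketch that counts the paths and output dimension of a full tensor product, including the `'5x1o + 4x3e'` times `'10x2e'` example above. `parse_irreps` and `full_tensor_product` are simplified, hypothetical stand-ins for `o3.Irreps` and `o3.FullTensorProduct`. They do not merge or sort the output irreps the way e3nn may, so only the counts are meant to be compared.

```python
def parse_irreps(s):
    """Parse e.g. '5x1o + 4x3e' into [(mul, l, parity)], parity +1 for 'e' and -1 for 'o'."""
    irreps = []
    for term in s.split('+'):
        term = term.strip()
        mul, ir = term.split('x') if 'x' in term else ('1', term)
        irreps.append((int(mul), int(ir[:-1]), 1 if ir[-1] == 'e' else -1))
    return irreps

def full_tensor_product(irreps1, irreps2):
    """Every path allowed by the selection rule, and the output irreps it produces."""
    paths, out = [], []
    for i1, (mul1, l1, p1) in enumerate(irreps1):
        for i2, (mul2, l2, p2) in enumerate(irreps2):
            for l in range(abs(l1 - l2), l1 + l2 + 1):
                paths.append((i1, i2, l))
                out.append((mul1 * mul2, l, p1 * p2))  # every channel pair: mul1 * mul2 copies
    return paths, out

def dim(irreps):
    """Total number of components."""
    return sum(mul * (2 * l + 1) for mul, l, _ in irreps)

in1, in2 = parse_irreps('5x1o + 4x3e'), parse_irreps('10x2e')
paths, out = full_tensor_product(in1, in2)
print(len(paths), dim(in1), dim(in2), dim(out))  # 8 43 50 2150

# Multiplying in one more vector at a time: the output dimension grows as 3^n
# and the number of separate irrep blocks grows quickly as well.
x = parse_irreps('1o')
for n in range(2, 5):
    _, x = full_tensor_product(x, parse_irreps('1o'))
    print(n, len(x), dim(x))  # 2 3 9, then 3 7 27, then 4 19 81
```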

## ElementwiseTensorProduct
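With `o3.ElementwiseTensorProduct` we can pair up the channels of two inputs one-to-one and only compute tensor products within those pairs. Both inputs must therefore have the same total number of channels (irreps counted with their multiplicity).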

In [None]:
o3.ElementwiseTensorProduct('1x0e + 2x1o', '2x0e + 1x1o').visualize()
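As we understand it, the elementwise product expands each input into one entry per channel, zips the two lists together, and takes the full product only within each pair. This pure-Python sketch mimics that pairing for the example above. The helpers `channels` and `name` are hypothetical, and irreps are written by hand as `(multiplicity, l, parity)` triples rather than parsed by `o3.Irreps`.

```python
# '1x0e + 2x1o' and '2x0e + 1x1o' as (multiplicity, l, parity); parity +1 = 'e', -1 = 'o'
irreps1 = [(1, 0, +1), (2, 1, -1)]
irreps2 = [(2, 0, +1), (1, 1, -1)]

def channels(irreps):
    """One (l, parity) entry per channel, i.e. each irrep repeated by its multiplicity."""
    return [(l, p) for mul, l, p in irreps for _ in range(mul)]

def name(l, p):
    """e3nn-style label, e.g. (1, -1) -> '1o'."""
    return f"{l}{'e' if p == 1 else 'o'}"

ch1, ch2 = channels(irreps1), channels(irreps2)
assert len(ch1) == len(ch2), "elementwise pairing needs the same number of channels"

products = []
for (l1, p1), (l2, p2) in zip(ch1, ch2):
    outs = [name(l, p1 * p2) for l in range(abs(l1 - l2), l1 + l2 + 1)]
    products.append(f"{name(l1, p1)} x {name(l2, p2)} -> {' + '.join(outs)}")
print("\n".join(products))
# 0e x 0e -> 0e
# 1o x 0e -> 1o
# 1o x 1o -> 0e + 1e + 2e

# Cost comparison: elementwise output size vs. the full tensor product of the same inputs
out_dim = sum((2 * l1 + 1) * (2 * l2 + 1) for (l1, _), (l2, _) in zip(ch1, ch2))
full_dim = sum(2 * l + 1 for l, _ in ch1) * sum(2 * l + 1 for l, _ in ch2)
print(out_dim, full_dim)  # 13 35
```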

## FullyConnectedTensorProduct

In [None]:
o3.FullyConnectedTensorProduct('1x0e + 2x1o', '2x0e + 1x1o', irreps_out='5x0e + 5x1e').visualize()
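"Fully connected" means that every path allowed by the selection rule which lands on a requested output irrep is included, each with its own learnable weights. The pure-Python sketch below enumerates those paths for the example above and counts weights under the `'uvw'` connection mode, which we take to mean one weight per (input 1 channel, input 2 channel, output channel) triple, i.e. `mul1 * mul2 * mul_out` per path. The resulting list should match the instructions written out by hand in the explicit `o3.TensorProduct` example in the next section; e3nn's own path normalization is not modeled.

```python
# Irreps as (multiplicity, l, parity); parity +1 = 'e', -1 = 'o'
irreps_in1 = [(1, 0, +1), (2, 1, -1)]   # '1x0e + 2x1o'
irreps_in2 = [(2, 0, +1), (1, 1, -1)]   # '2x0e + 1x1o'
irreps_out = [(5, 0, +1), (5, 1, +1)]   # '5x0e + 5x1e'

instructions = []
num_weights = 0
for i1, (mul1, l1, p1) in enumerate(irreps_in1):
    for i2, (mul2, l2, p2) in enumerate(irreps_in2):
        for i_out, (mul_out, l_out, p_out) in enumerate(irreps_out):
            # Selection rule: |l1 - l2| <= l_out <= l1 + l2, and the parities must match
            if abs(l1 - l2) <= l_out <= l1 + l2 and p1 * p2 == p_out:
                instructions.append((i1, i2, i_out, 'uvw', True))
                num_weights += mul1 * mul2 * mul_out  # 'uvw': one weight per (u, v, w) triple

print(instructions)
# [(0, 0, 0, 'uvw', True), (1, 1, 0, 'uvw', True), (1, 1, 1, 'uvw', True)]
print(num_weights)  # 30
```

Note that paths such as `0e x 1o -> 1o` are allowed by the selection rule but dropped, because `1o` is not among the requested outputs.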

## [Advanced] Using TensorProduct directly
In most cases, you will use the above subclasses of `TensorProduct`. However, you can also have complete control over which "paths" you are including by using `TensorProduct`. We do NOT recommend this for the new user.

For example, here is how to explicitly rewrite the `o3.FullyConnectedTensorProduct` example above with `o3.TensorProduct`.

In [None]:
o3.TensorProduct('1x0e + 2x1o', '2x0e + 1x1o', '5x0e + 5x1e',
                 instructions=[
                     # Each instruction: (index in input 1, index in input 2, index in output, connection mode, has weights)
                     # input 1 1x0e \otimes input 2 2x0e, apply weights, contribute to the 5x0e output
                     (0, 0, 0, 'uvw', True),
                     # input 1 2x1o \otimes input 2 1x1o, apply weights, contribute to the 5x0e output
                     (1, 1, 0, 'uvw', True),
                     # input 1 2x1o \otimes input 2 1x1o, apply weights, contribute to the 5x1e output
                     (1, 1, 1, 'uvw', True),
                 ]).visualize()

## Handling permutations of indices

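Using `o3.ReducedTensorProducts`, we can additionally restrict tensor products to obey specific permutation symmetries of the inputs, written as index formulas such as `'ij=ji'`. This is helpful, for example, for computing the power spectrum and bispectrum of a signal, which are products of the signal with itself and therefore symmetric under exchange of the factors.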
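Why the power spectrum is invariant can be sketched without e3nn. For a signal expanded in real spherical harmonics, the scalar part of $x \otimes x$ has one entry per degree $l$, proportional to the squared norm of the degree-$l$ block (different degrees cannot combine into a scalar, by the selection rule). The NumPy sketch below assumes random coefficients as a hypothetical signal and uses random orthogonal matrices as stand-ins for Wigner D-matrices; real Wigner D-matrices are orthogonal, and orthogonality is the only property the check relies on. e3nn's normalization constants are not reproduced.

```python
import numpy as np

rng = np.random.default_rng(0)
lmax = 4

# Hypothetical signal: real spherical-harmonic coefficients, one block of 2l + 1 numbers per degree l
coeffs = [rng.normal(size=2 * l + 1) for l in range(lmax + 1)]

def power_spectrum(blocks):
    """One invariant per degree l: the squared norm of that block (normalization may differ from e3nn)."""
    return np.array([block @ block for block in blocks])

def random_orthogonal(n, rng):
    """Stand-in for a real Wigner D-matrix: an n x n orthogonal matrix from a QR decomposition."""
    q, _ = np.linalg.qr(rng.normal(size=(n, n)))
    return q

# Rotating the signal acts block by block with an orthogonal matrix, so every block norm survives.
rotated = [random_orthogonal(len(block), rng) @ block for block in coeffs]
assert np.allclose(power_spectrum(coeffs), power_spectrum(rotated))
print(power_spectrum(coeffs).shape)  # (5,): one scalar per degree l = 0, ..., 4
```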

In [None]:
from e3nn import io
import torch
import matplotlib.pyplot as plt

lmax = 4
p_val, p_arg = 1, -1  # parity of signal on sphere, parity of vector to sphere from origin

sph = io.SphericalTensor(lmax, p_val, p_arg)
peaks = torch.tensor([[1., 0., 0.], [-1., 0., 0.]])
signal = sph.with_peaks_at(peaks)

power_spectrum = o3.ReducedTensorProducts(
    'ij=ji', i=o3.Irreps.spherical_harmonics(lmax), 
    set_ir_out=list(o3.Irrep.iterator(lmax=0)),  # keep only scalar and pseudoscalar outputs
    set_ir_mid=list(o3.Irrep.iterator(lmax))  # skip intermediate irreps with l > lmax
)
power = lambda x: power_spectrum(x, x)

bi_spectrum = o3.ReducedTensorProducts(
    'ijk=jik=kji', i=o3.Irreps.spherical_harmonics(lmax), 
    set_ir_out=list(o3.Irrep.iterator(lmax=0)),  # keep only scalar and pseudoscalar outputs
    set_ir_mid=list(o3.Irrep.iterator(lmax))  # skip intermediate irreps with l > lmax
)
bi = lambda x: bi_spectrum(x, x, x)

In [None]:
fig, ax = plt.subplots(3, 1)

val = 0.1

ax[0].set_title('Signal')
ax[0].imshow(signal[None, :], cmap='RdBu', vmin=-val, vmax=val)
ax[0].set_xticks([])
ax[0].set_yticks([])
ax[0].set_xlabel(o3.Irreps.spherical_harmonics(lmax))

ax[1].set_title('Power Spectrum')
ax[1].imshow(power(signal)[None, :], cmap='RdBu', vmin=-val, vmax=val)
ax[1].set_xticks([])
ax[1].set_yticks([])
ax[1].set_xlabel(power_spectrum.irreps_out);

ax[2].set_title('Bispectrum')
ax[2].imshow(bi(signal)[None, :], cmap='RdBu', vmin=-val, vmax=val)
ax[2].set_xticks([])
ax[2].set_yticks([])
ax[2].set_xlabel(bi_spectrum.irreps_out);

## Linear Layers
Equivariant linear layers can mix channels of the same irrep type, but cannot mix different irreps (for example, they cannot turn scalars into vectors). Linear layers are actually built with tensor products: a scalar `1.0` stands in as the second input. The weights are contained in the `o3.Linear.tp` object.
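A minimal NumPy sketch of what such a layer does for `'5x0e + 6x1o'` to `'7x0e + 5x1o'`, assuming one independent weight matrix per irrep type and ignoring any normalization constants e3nn applies. The weight matrices only act on the channel index, never on the $x, y, z$ index of the vectors, so they commute with rotations. There is no block mapping scalars to vectors: by Schur's lemma, any rotation-commuting map between different irreps must be zero.

```python
import numpy as np

rng = np.random.default_rng(0)

# Input '5x0e + 6x1o': 5 scalar channels and 6 vector channels
x0 = rng.normal(size=5)          # scalars
x1 = rng.normal(size=(6, 3))     # vectors, one row per channel
# Output '7x0e + 5x1o': one weight matrix per irrep type
W0 = rng.normal(size=(7, 5))     # mixes scalar channels only
W1 = rng.normal(size=(5, 6))     # mixes vector channels only

def linear(x0, x1):
    """Channel mixing within each irrep type."""
    return W0 @ x0, W1 @ x1

# A proper rotation about the z axis
c, s = np.cos(0.3), np.sin(0.3)
R = np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])

y0, y1 = linear(x0, x1)
y0_rot, y1_rot = linear(x0, x1 @ R.T)   # rotate every input vector first
assert np.allclose(y0_rot, y0)          # output scalars are unchanged
assert np.allclose(y1_rot, y1 @ R.T)    # output vectors rotate the same way
print(W0.size + W1.size)  # 65 learnable weights (7 * 5 + 5 * 6)
```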

In [None]:
irreps_in = '5x0e + 6x1o'
irreps_out = '7x0e + 5x1o'
linear = o3.Linear(irreps_in, irreps_out)

In [None]:
linear.tp.visualize()

In [None]:
linear.tp.instructions

## All learnable parameters must be scalars

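In order to preserve equivariance, all learnable parameters in the model must be scalars. Consider $f: X \times W \rightarrow Y$ that is equivariant when both its input $x$ and its parameters $w$ transform, $f(D_X(g)x, D_W(g)w) = D_Y(g)f(x, w)$. Because the learned $w$ stays fixed while the data is transformed, we need

$f(D_X(g)x, w) = D_Y(g)f(x, w) \quad \forall g \in G,$

which is guaranteed if $D_W(g) = I$ for all $g \in G$.

For the irreps of $O(3)$, this condition is only met if $w$ is purely scalar (`0e`); even pseudoscalars (`0o`) flip sign under inversion.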

Otherwise, we would need to rotate our learnable parameters along with our input data, which largely defeats the purpose of creating an equivariant neural network.
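A small NumPy sketch makes this concrete. A scalar weight times a vector input is equivariant. By contrast, `w_vector` below is a hypothetical vector-valued parameter held fixed while the input is rotated, and the "invariant" it produces changes. The output only becomes invariant again if the weight is rotated together with the data.

```python
import numpy as np

x = np.array([1.0, 2.0, 3.0])                 # input vector (1o)
R = np.array([[0.0, -1.0, 0.0],               # 90 degree rotation about z
              [1.0, 0.0, 0.0],
              [0.0, 0.0, 1.0]])

w_scalar = 0.8                                # scalar weight (0e): allowed
w_vector = np.array([0.5, -1.0, 2.0])         # hypothetical vector-valued weight: not allowed

def f_scalar(x):
    """Vector output scaled by a learnable scalar."""
    return w_scalar * x

def f_vector(x):
    """Scalar output from a dot product with a fixed vector weight."""
    return np.dot(w_vector, x)

assert np.allclose(f_scalar(R @ x), R @ f_scalar(x))   # equivariant
print(f_vector(x), f_vector(R @ x))                    # 4.5 4.0: invariance is broken
# Rotating the weight along with the input restores invariance
assert np.isclose(np.dot(R @ w_vector, R @ x), f_vector(x))
```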