# Moments of $O(3)$ Irrep Distributions (Part 1)
> A discussion on the mathematical properties of distributions of irreps.

- toc: true 
- badges: true
- comments: true
- categories: [jupyter]
- image: images/blob.gif
- use_plotly: true

The aim of this post is to show that distributions of irreps of $O(3)$ can be described by higher-order irreps of $O(3)$. For an introduction to $O(3)$ irreps, see [this page from the `e3nn` docs](https://docs.e3nn.org/en/stable/guide/irreps.html).

## [Spherical harmonics](https://en.wikipedia.org/wiki/Spherical_harmonics) can describe distributions of 3D unit vectors.

As we saw in a previous post, we can use spherical harmonics to describe signals on the sphere. We can also use them to describe distributions of unit vectors as a function of direction. By increasing $L_{max}$ we can describe these distributions with higher and higher accuracy.

Note that spherical harmonic expansions can take both positive and negative values, whereas distributions typically take only non-negative values. One can opt to apply a `ReLU` to the signal to remove the negative values.

Below we will show examples of how spherical harmonic expansions look for "distributions" of one (top row) and two (bottom row) example vectors and increasing $L_{max}$. As we will see, increasing $L_{max}$ increases the angular resolution of the distribution.


```python
#hide_input
import torch
from e3nn import o3, io
from plotly.subplots import make_subplots
import plotly.graph_objects as go

# Band limits to compare, one column per value
Ls = [1, 3, 7]

rows = 2
cols = len(Ls)
specs = [[{'is_3d': True} for i in range(cols)]
         for j in range(rows)]
fig = make_subplots(
    rows=rows, cols=cols, specs=specs, horizontal_spacing=0, vertical_spacing=0)

# Top row: a single unit vector; bottom row: two unit vectors
one_vector = torch.tensor([0., 0., 1.]).reshape(1, 3)
two_vectors = torch.tensor([[0., 0., 1.], [-1., 1., 1.]])
two_vectors /= two_vectors.norm(dim=-1, keepdim=True)

for row, vectors in enumerate([one_vector, two_vectors], start=1):
    for i, L in enumerate(Ls):
        sph = io.SphericalTensor(L, p_val=1, p_arg=-1)
        # Expansion coefficients (up to L) of a signal peaked at the given vectors
        sig = sph.with_peaks_at(vectors)
        # relu=True clips negative values before plotting
        trace = go.Surface(**sph.plotly_surface(sig, radius=False, relu=True, res=50)[0])
        trace.showscale = False
        fig.add_trace(trace, row=row, col=i + 1)

eye = 2

fig.update_scenes(
    xaxis=dict(range=[-1, 1]),
    yaxis=dict(range=[-1, 1]),
    zaxis=dict(range=[-1, 1]),
    aspectmode="cube",
    camera=dict(eye=dict(x=eye, y=eye, z=eye))
)

fig.update_layout(autosize=False, width=800, height=600, margin=dict(l=0, r=0))

# Add labels, centered above each column
for i, L in enumerate(Ls):
    fig.add_annotation(
        x=(i + 0.5) / len(Ls), y=0.95,
        text='<b>$L_{max}=' + '{}$</b>'.format(L), showarrow=False,
    )

fig.show()
```
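To see why a larger $L_{max}$ sharpens the peaks and why negative values appear at all, here is a minimal numpy sketch. It is not `e3nn`'s implementation of `with_peaks_at` (whose normalization may differ); it assumes a single peak at the north pole. By the spherical harmonic addition theorem, the band-limited delta function $\sum_{L \le L_{max}} \sum_m Y_{Lm}(\hat n) Y_{Lm}(\hat z)$ collapses to the Legendre series $\sum_{L \le L_{max}} \frac{2L+1}{4\pi} P_L(\cos\theta)$. The helper name `truncated_peak` is hypothetical.

```python
import numpy as np
from numpy.polynomial import legendre


def truncated_peak(cos_theta, l_max):
    """Band-limited delta function on the sphere, peaked at +z.

    Uses the addition theorem: sum_m Y_lm(n) Y_lm(z) = (2l + 1) / (4 pi) * P_l(cos theta),
    so the expansion truncated at l_max is a Legendre series in cos(theta).
    """
    coeffs = [(2 * l + 1) / (4 * np.pi) for l in range(l_max + 1)]
    return legendre.legval(cos_theta, coeffs)


cos_theta = np.linspace(-1.0, 1.0, 2001)
for l_max in [1, 3, 7]:
    f = truncated_peak(cos_theta, l_max)
    f_relu = np.maximum(f, 0.0)  # ReLU removes the negative lobes
    # The peak height is (l_max + 1)^2 / (4 pi): it grows, so the peak gets sharper.
    print(f"l_max={l_max}: peak={f.max():.3f}, min={f.min():.3f}, "
          f"min after ReLU={f_relu.min():.3f}")
```

Only the $L=0$ term survives integration over the sphere, so every truncation still integrates to 1; increasing $L_{max}$ only concentrates that weight near the peak, at the cost of negative side lobes.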

## Why does this work? Higher-order spherical harmonics are higher moments of the distribution.

Moments describe the shape of a function. [This Wikipedia article](https://en.wikipedia.org/wiki/Moment_(mathematics)) gives some great examples for mass densities and probability distributions.

The "raw" $n^{th}$ moment is given by $\left<x^n\right>$. Given this, what are the "moments" of a set of 3D vectors? Let's take a representation theory approach and simply determine how these moments transform under $O(3)$ symmetry.

Let $x$ transform as a 3D vector. In `e3nn` lingo, this is equivalent to `o3.Irrep('1o')`, where `1` indicates $L=1$ (angular frequency $1$) and `o` indicates odd parity (the vector changes sign under inversion).

The zeroth moment of a distribution of unit vectors is trivial: it is just the total probability, 1, which is a scalar, i.e. `o3.Irrep('0e')`. The first moment, $\left<x\right>$, is the sum of all possible vectors weighted by their probability, which also transforms as a vector, `o3.Irrep('1o')`. What about the second moment? For this we can use `e3nn`'s `o3.ReducedTensorProducts` to compute the irreps of the product. We use the fact that the $n^{th}$ moment is the product of the same vector with itself $n$ times, so it is symmetric under any permutation of its indices (this is what the formula `'ij=ji'` expresses).

```python
from e3nn import o3
o3.ReducedTensorProducts('ij=ji', i=o3.Irreps('1o'))
```

```
ReducedTensorProducts(
    in: 1x1o times 1x1o
    out: 1x0e+1x2e
)
```

So we get a scalar `o3.Irrep('0e')` and an `o3.Irrep('2e')`. What does the scalar correspond to? It is just the squared norm of the vector, $x \cdot x = 1$ for unit vectors, which is already included in the zeroth moment. So the only thing that is new is the `o3.Irrep('2e')`. Let's try this for the $3^{rd}$ moment.
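The same decomposition can be checked numerically. The following is a minimal numpy sketch, assuming an arbitrary illustrative sample of unit vectors as the "distribution". The second moment $\left<x_i x_j\right>$ is a symmetric $3\times3$ matrix with 6 independent entries: its trace is $\left<|x|^2\right> = 1$ (the `0e` part, nothing new), and its symmetric traceless remainder carries the 5 new numbers of the `2e` part, written here in a Cartesian basis rather than `e3nn`'s spherical basis.

```python
import numpy as np

rng = np.random.default_rng(0)

# A toy "distribution": unit vectors clustered around +z.
vecs = rng.normal(size=(1000, 3)) + np.array([0.0, 0.0, 2.0])
vecs /= np.linalg.norm(vecs, axis=1, keepdims=True)

# Each sample carries probability 1/N, so the zeroth moment (0e) is 1.
first = vecs.mean(axis=0)                                # <x_i>: 3 numbers (1o)
second = np.einsum("ni,nj->ij", vecs, vecs) / len(vecs)  # <x_i x_j>: 6 numbers

scalar_part = np.trace(second)                         # <|x|^2> = 1, nothing new (0e)
traceless_part = second - scalar_part / 3 * np.eye(3)  # 5 new numbers (2e)

print("first moment:", first)
print("trace of second moment:", scalar_part)
print("traceless part:\n", traceless_part)
```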

```python
o3.ReducedTensorProducts('ijk=jik=ikj', i=o3.Irreps('1o'))
```

```
ReducedTensorProducts(
    in: 1x1o times 1x1o times 1x1o
    out: 1x1o+1x3o
)
```

So we get a vector `o3.Irrep('1o')` and an `o3.Irrep('3o')`. What does the vector correspond to? It is just the squared norm of the vector times the vector, $|x|^2 x = x$ for unit vectors, which is already included in the $1^{st}$ moment. So the only thing that is new is the `o3.Irrep('3o')`.
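Here is the same bookkeeping for the third moment, again as a numpy sketch on an illustrative sample of unit vectors. The fully symmetric tensor $\left<x_i x_j x_k\right>$ has 10 independent entries, matching $3 + 7$ for `1o` + `3o`. Contracting any pair of its indices gives $\left<x_i |x|^2\right> = \left<x_i\right>$, i.e. the `1o` part is just the first moment again.

```python
import numpy as np

rng = np.random.default_rng(1)
vecs = rng.normal(size=(500, 3))
vecs /= np.linalg.norm(vecs, axis=1, keepdims=True)

# Third moment <x_i x_j x_k>: a fully symmetric 3x3x3 tensor (10 independent entries).
third = np.einsum("ni,nj,nk->ijk", vecs, vecs, vecs) / len(vecs)

# Contracting one pair of indices gives <x_i |x|^2> = <x_i>, the first moment (1o part).
contracted = np.einsum("ijj->i", third)
first = vecs.mean(axis=0)
print(np.allclose(contracted, first))  # True for unit vectors
```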

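As you can imagine, this pattern continues for higher moments: the only new ("independent") piece of the $n^{th}$ moment of a distribution of unit vectors on the sphere is the $L = n$ irrep, i.e. the coefficients of the $L = n$ spherical harmonics. This is why an expansion up to $L_{max}$ captures the moments up to order $L_{max}$.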
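A plain-Python dimension check of this pattern (it counts components only and does not use `e3nn`; the helper names are illustrative): the $n^{th}$ moment tensor is symmetric, with $\binom{n+2}{2}$ independent components, and splits into the irreps $L = n, n-2, \ldots$ down to 0 or 1. Only $L = n$ is new; the rest repeat lower moments.

```python
from math import comb


def symmetric_tensor_dim(n):
    """Independent components of a symmetric rank-n tensor in 3D."""
    return comb(n + 2, 2)


def moment_irreps(n):
    """L values contained in the n-th moment of a unit vector: n, n - 2, ..., 0 or 1."""
    return list(range(n, -1, -2))


for n in range(6):
    ls = moment_irreps(n)
    dims = sum(2 * l + 1 for l in ls)
    print(n, ls, dims, symmetric_tensor_dim(n))
```

For $n = 2$ and $n = 3$ this reproduces `1x0e+1x2e` and `1x1o+1x3o` from `ReducedTensorProducts`, up to parity, which the count ignores.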

## But what if I have an irrep other than `o3.Irrep('1o')`? This still works!

Let's take two cases: scalars, `o3.Irrep('0e')`, and `o3.Irrep('2e')` (things that transform like $d$-orbitals or symmetric traceless $3\times3$ matrices). Note that for irreps with an $L$ we've seen before but the opposite parity, you simply have to make sure that the irreps of your moments have the corresponding parity.

For the scalar case, we simply revert to the usual definition of raw moments. Done!

For `o3.Irrep('2e')` we can build our intuition with the same exercise as above. However, rather than unit vectors on $S^2$, we will now think of unit vectors on $S^4$, the unit sphere in 5 dimensions (one dimension per component of the $L=2$ irrep).

So let's think about the action of $O(3)$ on 5-dimensional vectors. If we pick a unit vector on $S^4$, we cannot rotate it into every other unit vector on $S^4$: $O(3)$ acts on these vectors through a 3-dimensional subgroup of $O(5)$, $O(3) \subset O(5)$, which is too small to act transitively on the 4-dimensional sphere $S^4$.
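One way to make this concrete is sketched below in numpy, under the assumption that we represent the `2e` irrep as a symmetric traceless $3\times3$ matrix $A$ (a Cartesian basis rather than `e3nn`'s), on which a rotation acts as $A \mapsto R A R^T$. This action preserves the eigenvalues of $A$, so besides the norm it also preserves $\mathrm{tr}(A^3)$. Two unit-norm matrices with different $\mathrm{tr}(A^3)$ are therefore both points on $S^4$ that can never be rotated into each other. The helpers `random_rotation` and `cubic_invariant` are hypothetical names; rotations are drawn by QR-decomposing a Gaussian matrix.

```python
import numpy as np

rng = np.random.default_rng(0)


def random_rotation(rng):
    """Random 3x3 rotation from the QR decomposition of a Gaussian matrix."""
    q, r = np.linalg.qr(rng.normal(size=(3, 3)))
    q = q * np.sign(np.diag(r))  # fix column signs to make the decomposition unique
    if np.linalg.det(q) < 0:     # flip one column to get a proper rotation
        q[:, 0] = -q[:, 0]
    return q


def cubic_invariant(m):
    """tr(m^3), unchanged under m -> R m R^T."""
    return np.trace(m @ m @ m)


# Two unit-norm symmetric traceless matrices, i.e. two points on S^4.
a = np.diag([-1.0, -1.0, 2.0]) / np.sqrt(6.0)  # ~ (3 z^2 - r^2), the L=2, m=0 shape
b = np.diag([1.0, -1.0, 0.0]) / np.sqrt(2.0)   # ~ (x^2 - y^2)

print(cubic_invariant(a), cubic_invariant(b))  # 1/sqrt(6) vs 0: different orbits
for _ in range(5):
    rot = random_rotation(rng)
    a_rot = rot @ a @ rot.T
    print(np.linalg.norm(a_rot), cubic_invariant(a_rot))  # both stay fixed
```

Since the unit-norm condition and $\mathrm{tr}(A^3)$ together pin down the eigenvalues, each orbit is at most 3-dimensional, one dimension short of $S^4$.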

### Example: A vector with $L=2$, $m=0$ does not span $S^4$ under $O(3)$

In the plot below, we demonstrate that we cannot rotate an arbitrary $L=2$ vector into every other $L=2$ vector using only $O(3)$. First, we apply 500 different random rotations to the $L=2$ vector `[0., 0., 1., 0., 0.]`, in other words the polynomial with $L=2, m=0$. Then, we plot the 500 resulting vectors for each of the $\binom{5}{3} = 10$ different choices of 3 of the 5 vector components. For example, the first plot shows components 0, 1, and 2 in three dimensions, the second shows components 0, 1, and 3, and so on. The color of each point indicates the magnitude of the vector restricted to those three components (darker = 0, lighter = 1). If the rotations of $O(3)$ fully spanned $S^4$, every projection would be a filled solid ball.

```python
#hide_input

from itertools import combinations

import plotly.graph_objects as go
import torch
from e3nn import o3
from plotly.subplots import make_subplots

l = 2
x = torch.zeros(2 * l + 1)
x[2] = 1. # L=2, m=0 # For non-accidental reasons, this looks like the projective plane
# x[0] = 1. # L=2, m=-2
# x[1] = 1. # L=2, m=-1

max_iter = 500
xs = []

# Rotate x by random rotations using the L=2 Wigner D-matrix
for _ in range(max_iter):
    angles = o3.rand_angles()
    rot = o3.wigner_D(l, *angles)
    xs.append(torch.einsum('ij,j->i', rot, x))

xs = torch.stack(xs, dim=0)  # shape (max_iter, 5)

opacity = 0.1

# One 3D scatter plot for each choice of 3 out of the 5 components
projections = []
for a, b, c in combinations(range(5), 3):
    projections.append(
        go.Scatter3d(x=xs[:, a], y=xs[:, b], z=xs[:, c],
                     mode="markers", opacity=opacity, marker=dict(
                         # color = norm of the vector restricted to these 3 components
                         color=xs[:, [a, b, c]].norm(2, -1), size=3
                     ))
    )

row, col = 2, 5
specs = [[{'type': 'scene'} for i in range(col)] for j in range(row)]
fig = make_subplots(row, col, specs=specs)

# Place the 10 projections on a 2 x 5 grid
fig.add_traces(projections,
               rows=[1] * 5 + [2] * 5,
               cols=list(range(1, 6)) + list(range(1, 6)))

# Add labels
for i, (a, b, c) in enumerate(combinations(range(5), 3)):
    fig.add_annotation(
        x=(i % 5) / 5 * 1.1 + 0, y=1 - 0.5 * (i // 5),
        text='<b>{},{},{}</b>'.format(a, b, c), showarrow=False,
    )

fig.update_layout(showlegend=False)

fig.show()
```

## Deriving $S^4$ harmonics
To represent any function on the $S^4$ sphere we can use the $S^4$ harmonics, but where do we get these functions? We can actually build them from tensor products of `o3.Irrep('2e')` objects. 

While we may not know the $S^4$ harmonics explicitly, we can easily compute their dimensionality using a bit of combinatorics and the fact that the degree-$J$ harmonics on $S^n$ are restrictions to the sphere of homogeneous polynomials of degree $J$ in $n+1$ variables that satisfy Laplace's equation. We use the formula from Ref. {% cite Frye2012-ac %} below. Note that there is also always the degree-0 (scalar) harmonic, which has dimension 1 and is invariant under the group action.

Because representations of compact groups are completely reducible, and Schur's lemma makes the multiplicities in such a decomposition unique, each $O(5)$ irrep of the harmonics breaks down into a uniquely determined set of $O(3)$ irreps when restricted to $O(3)$. This means that once we've identified which $O(3)$ irreps make up which $O(5)$ irreps, we can use tensor products of `o3.Irrep('2e')` to construct the $S^4$ harmonics. Let's see this in practice.

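```python
# From https://arxiv.org/pdf/1205.3548.pdf
import math

def num_harm_poly(n_var, degree):
    """Dimension of the space of harmonic polynomials of a given degree in n_var variables."""
    if degree == 0:
        return 1  # the constant (scalar) harmonic
    return math.comb(n_var + degree - 2, degree) + math.comb(n_var + degree - 3, degree - 1)
```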

First, we do a simple check that this formula reproduces the well-known $2L + 1$ spherical harmonics for degree $L$:

```python
assert num_harm_poly(n_var=3, degree=1) == 3
assert num_harm_poly(n_var=3, degree=2) == 5
assert num_harm_poly(n_var=3, degree=3) == 7
```
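As an independent cross-check of the formula (a sketch; the helper names `num_homogeneous` and `num_harmonic` are illustrative), recall that every homogeneous polynomial of degree $k$ in $n$ variables splits uniquely into a harmonic part plus $|x|^2$ times a homogeneous polynomial of degree $k-2$. The harmonic dimension is therefore $\binom{n+k-1}{k} - \binom{n+k-3}{k-2}$, which for $n=5$ simplifies to $(k+1)(k+2)(2k+3)/6$.

```python
from math import comb


def num_homogeneous(n_var, degree):
    """Number of monomials of the given degree in n_var variables."""
    return comb(n_var + degree - 1, degree) if degree >= 0 else 0


def num_harmonic(n_var, degree):
    """Harmonic = homogeneous of this degree minus |x|^2 times homogeneous of degree - 2."""
    return num_homogeneous(n_var, degree) - num_homogeneous(n_var, degree - 2)


print([num_harmonic(3, k) for k in range(7)])  # [1, 3, 5, 7, 9, 11, 13]
print([num_harmonic(5, k) for k in range(7)])  # [1, 5, 14, 30, 55, 91, 140]
```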

Next, let's compute the dimensionality of the harmonic functions of $S^4$ for several degrees $J$ (using $J$ to distinguish from $L$).

```python
print("-- S^4 harmonics --")
degree_max = 6
n_var = 5
print("degree: \t", list(range(1, degree_max + 1)))
print("dimension:\t", [num_harm_poly(n_var, degree=i) for i in range(1, degree_max + 1)])
```

```
-- S^4 harmonics --
degree: 	 [1, 2, 3, 4, 5, 6]
dimension:	 [5, 14, 30, 55, 91, 140]
```


```python
from e3nn import o3
rtp = o3.ReducedTensorProducts('ij=ji', i=o3.Irreps('2e'))
print(rtp)
print("Dimension of irreps_out:\t", rtp.irreps_out.dim)
```

```
ReducedTensorProducts(
    in: 1x2e times 1x2e
    out: 1x0e+1x2e+1x4e
)
Dimension of irreps_out:	 15
```



Looking at the dimensionality list above, we can break down the 15-dimensional output as

$$
\begin{align}
15 = 1 + 14 &= (J=0) \oplus (J=2) \\ 
&= (L=0) \oplus \left((L=2) \oplus (L=4)\right)
\end{align}
$$

Let's take a moment to understand what this means. When restricted to $O(3)$, irreps of $O(5)$ can become <b>reducible</b> and break up into independent (invariant) subspaces. Here, the $O(5)$ irrep $J=2$ reduces to the $O(3)$ irreps $(L=2) \oplus (L=4)$.

Since `2e` has even parity, every irrep appearing in its tensor products also has even parity. So we have the $0^{th}$ through $2^{nd}$ moments covered; what about the $3^{rd}$ moment?

```python
from e3nn import o3
rtp = o3.ReducedTensorProducts('ijk=jik=ikj', i=o3.Irreps('2e'))
print(rtp)
print("Dimension of irreps_out:\t", rtp.irreps_out.dim)
```

```
ReducedTensorProducts(
    in: 1x2e times 1x2e times 1x2e
    out: 1x0e+1x2e+1x3e+1x4e+1x6e
)
Dimension of irreps_out:	 35
```


$$
\begin{align}
35 = 5 + 30 &= (J=1) \oplus (J=3) \\ 
&= (L=2) \oplus \left((L=0) \oplus (L=3) \oplus (L=4) \oplus (L=6)\right)
\end{align}
$$

```python
from e3nn import o3
rtp = o3.ReducedTensorProducts('ijkl=jikl=ikjl=ijlk', i=o3.Irreps('2e'))
print(rtp)
print("Dimension of irreps_out:\t", rtp.irreps_out.dim)
```

```
ReducedTensorProducts(
    in: 1x2e times 1x2e times 1x2e times 1x2e
    out: 1x0e+2x2e+2x4e+1x5e+1x6e+1x8e
)
Dimension of irreps_out:	 70
```


$$
\begin{align}
70 = 1 + 14 + 55 &= (J=0) \oplus (J=2) \oplus (J=4) \\ 
&= (L=0) \oplus \left((L=2) \oplus (L=4)\right) \oplus \left((L=2) \oplus (L=4) \oplus (L=5) \oplus (L=6) \oplus (L=8)\right)
\end{align}
$$

```python
from e3nn import o3
rtp = o3.ReducedTensorProducts('ijklm=jiklm=ikjlm=ijlkm=ijkml', i=o3.Irreps('2e'))
print(rtp)
print("Dimension of irreps_out:\t", rtp.irreps_out.dim)
```

```
ReducedTensorProducts(
    in: 1x2e times 1x2e times 1x2e times 1x2e times 1x2e
    out: 1x0e+2x2e+1x3e+2x4e+1x5e+2x6e+1x7e+1x8e+1x10e
)
Dimension of irreps_out:	 126
```


$$ 126 = 5 + 30 + 91 = (J=1) \oplus (J=3) \oplus (J=5) $$

```python
from e3nn import o3
# Symmetric under every adjacent transposition of the six indices
rtp = o3.ReducedTensorProducts('ijklmn=jiklmn=ikjlmn=ijlkmn=ijkmln=ijklnm',
                               i=o3.Irreps('2e'))
print(rtp)
print("Dimension of irreps_out:\t", rtp.irreps_out.dim)
```

```
ReducedTensorProducts(
    in: 1x2e times 1x2e times 1x2e times 1x2e times 1x2e times 1x2e
    out: 2x0e+2x2e+1x3e+3x4e+1x5e+3x6e+1x7e+2x8e+1x9e+1x10e+1x12e
)
Dimension of irreps_out:	 210
```


$$ 210 = 1 + 14 + 55 + 140 = (J=0) \oplus (J=2) \oplus (J=4) \oplus (J=6) $$

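What we notice is a recurring structure: the $n^{th}$ symmetric product of the $J=1$ irrep (our `2e` vector) decomposes into the $S^4$ harmonics $J = n, n-2, \ldots$, so each additional multiplication by $J=1$ contributes exactly one new harmonic, $J = n$. This structure is analogous to the one for the $S^2$ spherical harmonics shown in Figure 1 of Ref. {%cite Geiger2022-yq %}.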
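The dimension side of this pattern can be checked by counting alone. This sketch verifies dimensions only, not the $O(3)$ content, and its helper names are illustrative. The $n^{th}$ symmetric product of the 5-dimensional `2e` vector has $\binom{n+4}{n}$ components, and it splits into the $S^4$ harmonics $J = n, n-2, \ldots$ because every homogeneous polynomial of degree $n$ is a harmonic part plus $|x|^2$ times a polynomial of degree $n-2$.

```python
from math import comb


def sym_power_dim(n_var, n):
    """Dimension of the n-th symmetric power of an n_var-dimensional vector space."""
    return comb(n_var + n - 1, n)


def s4_harmonic_dim(j):
    """Dimension of the degree-j harmonics on S^4 (n_var = 5)."""
    return (j + 1) * (j + 2) * (2 * j + 3) // 6


for n in range(2, 7):
    js = list(range(n, -1, -2))  # J = n, n - 2, ..., 0 or 1
    parts = [s4_harmonic_dim(j) for j in js]
    print(n, sym_power_dim(5, n), "=", " + ".join(map(str, parts)))
```

For $n = 2, \ldots, 6$ this reproduces the totals 15, 35, 70, 126, and 210 from the `ReducedTensorProducts` outputs.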

<!-- ![](notebook_images/e3nn_Figure_1.png){:height="36px" width="36px"}. -->

{% bibliography --cited %}