# Parity and SphericalTensor

In [None]:
import torch
from e3nn.io import SphericalTensor
import plotly.graph_objects as go

# Styling shared by all three scene axes: hide background, grid, zero line,
# tick labels and title so only the plotted surfaces remain visible.
axis = {
    "showbackground": False,
    "showticklabels": False,
    "showgrid": False,
    "zeroline": False,
    "title": "",
    "nticks": 3,
}

# Figure layout: a wide, short 3D scene whose x extent ([-4, 4]) is four times
# its y/z extent ([-1, 1]), matching the 4:1:1 manual aspect ratio so the
# surfaces are not distorted. The camera sits on the -y axis, looks at the
# origin with z pointing up, and uses an orthographic projection. Transparent
# backgrounds and zero margins let the figure blend into the notebook.
layout = {
    "width": 800,
    "height": 300,
    "scene": {
        "xaxis": {**axis, "range": [-4, 4]},
        "yaxis": {**axis, "range": [-1, 1]},
        "zaxis": {**axis, "range": [-1, 1]},
        "aspectmode": "manual",
        "aspectratio": {"x": 4, "y": 1, "z": 1},
        "camera": {
            "up": {"x": 0, "y": 0, "z": 1},
            "center": {"x": 0, "y": 0, "z": 0},
            "eye": {"x": 0, "y": -5, "z": 0},
            "projection": {"type": "orthographic"},
        },
    },
    "paper_bgcolor": "rgba(0,0,0,0)",
    "plot_bgcolor": "rgba(0,0,0,0)",
    "margin": {"l": 0, "r": 0, "t": 0, "b": 0},
}

# Diverging blue -> grey -> red colorscale. Grey sits at the normalized
# midpoint, so with a color range symmetric about 0 it marks a zero value.
cmap_bwr = [
    [0, "rgb(0,50,255)"],
    [0.5, "rgb(200,200,200)"],
    [1, "rgb(255,50,0)"],
]
    
def plot(traces, *, cmin=-4, cmax=4):
    """Draw surface traces in a single plotly figure and show it.

    Every surface uses the shared ``cmap_bwr`` colorscale and the same color
    range, so colors are comparable across surfaces. The module-level
    ``layout`` is applied to the figure.

    Parameters
    ----------
    traces : iterable of dict
        Keyword arguments for ``go.Surface``, one dict per surface (for
        example the output of ``SphericalTensor.plotly_surface``). The dicts
        must not already contain ``colorscale``, ``cmin`` or ``cmax``.
    cmin, cmax : float, optional
        Lower and upper bounds of the color axis (keyword-only). Keeping them
        symmetric about 0 maps 0 onto the grey midpoint of ``cmap_bwr``. The
        defaults reproduce the original fixed range [-4, 4].

    Returns
    -------
    None
        The figure is only displayed. It is intentionally not returned, so a
        call on the last line of a notebook cell does not render it twice.
    """
    surfaces = [
        go.Surface(**trace, colorscale=cmap_bwr, cmin=cmin, cmax=cmax)
        for trace in traces
    ]
    fig = go.Figure(data=surfaces, layout=layout)
    fig.show()

In [None]:
lmax = 6
# One random coefficient vector for a SphericalTensor with degrees
# l = 0..lmax, i.e. (lmax + 1) ** 2 entries.
x = torch.randn((lmax + 1) ** 2)

# Apply the inversion matrix -I to the same coefficients under each of the
# four parity conventions and stack the results. Row i of the new `x` is the
# transformed coefficient vector for the i-th (p_val, p_arg) pair below, in
# the order (+1, +1), (+1, -1), (-1, +1), (-1, -1).
x = torch.stack([
    SphericalTensor(lmax, p_val, p_arg).D_from_matrix(-torch.eye(3)) @ x
    for p_val, p_arg in [(+1, +1), (+1, -1), (-1, +1), (-1, -1)]
])

# One center per row of `x`: spaced 2 apart along the x axis so the four
# signals are drawn side by side.
centers = torch.tensor([[offset, 0.0, 0.0] for offset in (-3.0, -1.0, 1.0, 3.0)])

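Applying the inversion $-I$ to the coefficients of a signal $f$ yields the signal $[Pf](x) = p_{\text{val}}\, f(p_{\text{arg}}\, x)$. For the four panels, from left to right:
- `p_val=1, p_arg=1` $\Rightarrow$ $f(x)$
- `p_val=1, p_arg=-1` $\Rightarrow$ $f(-x)$
- `p_val=-1, p_arg=1` $\Rightarrow$ $-f(x)$
- `p_val=-1, p_arg=-1` $\Rightarrow$ $-f(-x)$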

In [None]:
# Draw each row of `x` (already transformed above) as a function on the
# sphere, one surface per entry of `centers`.
# NOTE(review): radius=False presumably keeps every surface a unit sphere
# colored by the signal value instead of scaling its radius by |f|; confirm
# against e3nn's SphericalTensor.plotly_surface documentation.
st = SphericalTensor(lmax, 1, 1)
plot(
    st.plotly_surface(
        x,
        centers=centers,
        radius=False,
    )
)