In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from functools import partial
import numpy as np
import torch
import e3nn
from spherical import plot_data_on_grid, SphericalTensor
import e3nn.o3 as o3
import e3nn.rs as rs

import plotly
import plotly.graph_objects as go

In [None]:
def FixedGaussianRadialModel(max_radius, number_of_basis, min_radius=0.):
    """Build a fixed (non-trainable) Gaussian radial basis.

    The basis is made of ``number_of_basis`` Gaussians whose centres are
    evenly spaced on ``[min_radius, max_radius]`` (both end points included).

    Args:
        max_radius: centre of the last Gaussian.
        number_of_basis: number of Gaussians; this is the size of the trailing
            dimension of the returned function's output.
        min_radius: centre of the first Gaussian (default ``0.``).

    Returns:
        A callable ``radial_function(x)`` that maps radii of shape ``(...)``
        to basis values of shape ``(..., number_of_basis)``.
    """
    # The Gaussian sharpness is derived from range / number_of_basis, while
    # linspace puts the centres range / (number_of_basis - 1) apart. The two
    # differ slightly; the definition is kept as-is so existing outputs do not
    # change.
    spacing = (max_radius - min_radius) / number_of_basis
    radii = torch.linspace(min_radius, max_radius, number_of_basis)
    gamma = 1. / spacing

    def radial_function(x):
        """Evaluate every Gaussian at each entry of ``x``.

        Args:
            x: tensor of radii of any shape; Python scalars and sequences are
                converted with ``torch.as_tensor``.

        Returns:
            Tensor of shape ``x.shape + (number_of_basis,)`` holding
            ``exp(-gamma * (x - centre) ** 2)`` for every centre.
        """
        # Accept plain numbers / lists as well as tensors (tensors pass through untouched).
        x = torch.as_tensor(x)
        # The centres are created on the default device; follow the input so
        # that inputs living on another device do not trigger a device mismatch.
        centers = radii.to(x.device)
        # (..., 1) - (number_of_basis,) broadcasts to (..., number_of_basis).
        return torch.exp(-gamma * (x.unsqueeze(-1) - centers) ** 2)

    return radial_function

In [None]:
# Fix the RNG seed so the random coefficients, and therefore the figure at the
# end of the notebook, are identical after every Restart & Run All.
torch.manual_seed(0)

n_radial = 5  # number of radial basis functions; also used as the multiplicity of every l in Rs
lmax = 3
# NOTE(review): range(lmax) yields l = 0 .. lmax - 1, so the highest degree in
# Rs is lmax - 1 = 2. Confirm whether range(lmax + 1) was intended.
Rs = [(n_radial, l) for l in range(lmax)]
# One random coefficient per component described by Rs.
coefficients = torch.randn(rs.dim(Rs))

In [None]:
# Gaussian radial basis with n_radial centres spread over [0, 3.0].
radial_function = FixedGaussianRadialModel(max_radius=3.0, number_of_basis=n_radial)

In [None]:
# Smoke test of the radial basis on a random (10, 5) batch; the cell displays
# the shape of the result.
r = torch.randn((10, 5))
radial_function(r).size()

In [None]:
# Pair the random coefficient vector (length rs.dim(Rs)) with the Rs list it was sized for.
sphten = SphericalTensor(coefficients, Rs)

In [None]:
# Sample the signal using the Gaussian radial basis. The plotting cell below
# reads `x` as columns x[:, 0], x[:, 1], x[:, 2] and uses `f` as the plotted value.
# NOTE(review): 5.0 presumably sets the spatial extent of the sampling grid, while
# the radial basis was built with max_radius=3.0 — confirm the two are meant to differ.
x, f = sphten.plot_with_radial(5.0, radial_model=radial_function)

In [None]:
# Use a colour range symmetric around zero so that the diverging 'RdBu'
# colourscale is centred on f == 0.
plot_max = float(f.abs().max())

trace = go.Volume(
    value=f,
    x=x[:, 0],
    y=x[:, 1],
    z=x[:, 2],
    colorscale='RdBu',
    isomin=-plot_max,
    isomax=plot_max,
    surface_count=50,  # needs to be a large number for good volume rendering
    opacity=0.3,  # needs to be small to see through all surfaces
)

go.Figure(data=[trace])