In [3]:
import numpy as np
import jax.numpy as jnp
import jax
import jax.random as jr
import pandas as pd
import plotly.express as px
import matplotlib.pyplot as plt
import seaborn as sns
import rho_plus as rp
import wat

is_dark = False
theme, cs = rp.mpl_setup(is_dark)
rp.plotly_setup(is_dark)

In [121]:
import e3nn_jax as e3nn
from e3nn_jax import Irreps, IrrepsArray

rng = jr.key(29205)

sh_irreps = '0e + 1o + 2e + 3o + 4e + 5o + 6e'

ir_out = sum(e3nn.Irrep.iterator(4), start=e3nn.Irrep('0e'))


x = e3nn.normal('128x0e + 64x1o + 32x2e', rng, (32,1), normalize=False)
y = e3nn.normal('1o', rng, (32,4), normalize=True, normalization='norm')
y_sh = e3nn.spherical_harmonics(sh_irreps, y, normalize=True)
y_sh.array.shape

(32, 4, 49)

In [122]:
from cdv.utils import debug_structure, debug_stat
@jax.jit
def tp_standard(x, y):
    y_sh = e3nn.spherical_harmonics(sh_irreps, y, normalize=False)
    return e3nn.tensor_product(x, y_sh, filter_ir_out=ir_out)

out = {}
out['standard'] = tp_standard(x, y)
e3nn.mean(e3nn.norm(out['standard']))

1x0e+1x0e+1x0e+1x0e+1x0e+1x0e+1x0e+1x0e+1x0e
[0.79078025 1.4910177  1.5319812  1.9508361  2.0144486  2.307917
 2.3939414  2.616518   2.7189445 ]

In [123]:
%%timeit

jax.block_until_ready(tp_standard(x, y))

206 μs ± 4.16 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [124]:
from e3nn_jax.experimental.linear_shtp import shtp
from functools import partial

@jax.jit
def tp_so2(x, y):    
    @partial(e3nn.vmap, in_axes=(None, 1), out_axes=1)    
    @partial(e3nn.vmap, in_axes=(0, 0))    
    def inner_shtp(xi, yi):
        return shtp(xi, yi, ir_out)
    
    return inner_shtp(x[:, 0], y)

out['so2'] = tp_so2(x, y)

debug_structure(out);

In [125]:
%%timeit

jax.block_until_ready(tp_so2(x, y))

438 μs ± 2.32 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [133]:
@jax.jit
def tp_sh(x, y): 
    return e3nn.tensor_product_with_spherical_harmonics(x, y, degree=Irreps(sh_irreps).lmax).filter(lmax=4)

out['tpsh'] = tp_sh(x, y)

In [134]:
%%timeit

jax.block_until_ready(tp_sh(x, y))

649 μs ± 2.4 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [135]:
from functools import reduce
{
    k: e3nn.mean(e3nn.norm(v))
    for k, v in out.items()
}

{'standard': 1x0e+1x0e+1x0e+1x0e+1x0e+1x0e+1x0e+1x0e+1x0e
 [0.79078025 1.4910177  1.5319812  1.9508361  2.0144486  2.307917
  2.3939414  2.616518   2.7189445 ],
 'so2': 1x0e+1x0e+1x0e+1x0e+1x0e+1x0e+1x0e+1x0e+1x0e
 [0.7908689  0.92872524 1.2502549  0.95677906 1.2463126  0.95678645
  1.2462797  0.9567873  1.2462709 ],
 'tpsh': 1x0e+1x0e+1x0e+1x0e+1x0e+1x0e+1x0e+1x0e+1x0e
 [0.7908688 1.4911186 1.5320606 1.9495199 2.0145316 2.306254  2.3937762
  2.6144505 2.7187533]}

In [113]:
z1 = out['standard']
z2 = out['so2']

z1.filter('1o')[0].array / z2.filter('1o')[0].array

Array([[ 1.7305142 ,  1.7315321 ,  1.7308031 , ..., -0.7276232 ,
        -0.6896798 , -0.97105074],
       [ 1.7305272 ,  1.731149  ,  1.7306291 , ..., -0.6879926 ,
        -0.9553743 , -0.69765127],
       [ 1.7302281 ,  1.7318865 ,  1.7311018 , ..., -3.4879332 ,
        -1.022996  , -0.14467084],
       [ 1.7309673 ,  1.7316449 ,  1.7303954 , ..., -3.8500829 ,
        -0.75589716, -0.7644943 ]], dtype=float32)

# Weighted Tensor Product

In [159]:
sh_irreps = '0e + 1o + 2e + 3o + 4e'

ir_out = sum(e3nn.Irrep.iterator(4), start=e3nn.Irrep('0e'))


x = e3nn.normal('128x0e + 64x1o + 32x2e', rng, (32,1), normalize=False)
y = e3nn.normal('1o', rng, (32,4), normalize=True, normalization='norm')
y_sh = e3nn.spherical_harmonics(sh_irreps, y, normalize=True)
z_ir = e3nn.tensor_product(x.irreps, y_sh.irreps, filter_ir_out=ir_out)
print(x.shape, y_sh.shape, z_ir.dim)

(32, 1, 480) (32, 4, 25) 10176


In [161]:
z = e3nn.tensor_product(x, y_sh, filter_ir_out=ir_out)
z.shape

(32, 4, 10176)

In [198]:
z_ir

224x0e+320x1o+96x1e+352x2e+128x2o+320x3o+128x3e+256x4e+96x4o

In [200]:
flin = e3nn.FunctionalLinear(z_ir, x.irreps)
w = jr.normal(rng, (32, 4, flin.num_weights))


@jax.jit
def tpw(x, y_sh, w):
    flin = e3nn.FunctionalLinear(z_ir, x.irreps)
    z = e3nn.tensor_product(x, y_sh, filter_ir_out=ir_out)
    o = e3nn.vmap(e3nn.vmap(flin))(w, z)
    return o


o = tpw(x, y_sh, w)
o

128x0e+64x1o+32x2e
[[[-1.4129685   1.1402715   0.25504923 ...  1.8924981  -0.00531835
   -1.8662033 ]
  [-1.0756253   1.0809479  -0.05144887 ... -0.60608894  1.2200073
   -0.36909977]
  [-1.2273952   0.76915544  0.49694243 ... -1.0615445  -0.20397419
   -0.42872605]
  [ 0.1639668   1.3861811  -0.6238708  ...  1.0226527   0.8628838
    1.4443743 ]]

 [[ 1.7903879   0.3035054  -0.33021867 ...  1.1557155   0.44907847
   -0.3035723 ]
  [-0.9015899  -0.64314073  1.881139   ...  0.68339175 -0.36422247
    0.6404423 ]
  [ 0.18454598  1.3068503  -0.27890283 ...  0.1611563   1.7762641
    0.01823836]
  [-1.0667871  -0.4293332   1.1588671  ... -0.01767952  0.8883663
   -0.06292846]]

 [[-1.0441962  -0.82064533  2.5622804  ...  0.8163843   0.57320595
    0.47717857]
  [-0.81432915 -0.646554   -0.01390615 ... -0.5682297  -1.1935803
   -1.1351529 ]
  [ 0.03235404  0.6888337   0.55429244 ... -0.5925441  -1.1648116
   -0.15122524]
  [-1.4448063   0.7323848  -1.0343683  ...  0.95806503  0.33416122
   

In [201]:
%%timeit

jax.block_until_ready(tpw(x, y_sh, w))

273 μs ± 1.11 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [202]:
from e3nn_jax.legacy import FunctionalTensorProduct

e3nn.tens

[0;31mInit signature:[0m
[0mFunctionalTensorProduct[0m[0;34m([0m[0;34m[0m
[0;34m[0m    [0mirreps_in1[0m[0;34m:[0m [0me3nn_jax[0m[0;34m.[0m[0m_src[0m[0;34m.[0m[0mirreps[0m[0;34m.[0m[0mIrreps[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mirreps_in2[0m[0;34m:[0m [0me3nn_jax[0m[0;34m.[0m[0m_src[0m[0;34m.[0m[0mirreps[0m[0;34m.[0m[0mIrreps[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mirreps_out[0m[0;34m:[0m [0me3nn_jax[0m[0;34m.[0m[0m_src[0m[0;34m.[0m[0mirreps[0m[0;34m.[0m[0mIrreps[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0minstructions[0m[0;34m:[0m [0mList[0m[0;34m[[0m[0mTuple[0m[0;34m[[0m[0mint[0m[0;34m,[0m [0mint[0m[0;34m,[0m [0mint[0m[0;34m,[0m [0mstr[0m[0;34m,[0m [0mbool[0m[0;34m,[0m [0mOptional[0m[0;34m[[0m[0mfloat[0m[0;34m][0m[0;34m][0m[0;34m][0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0min1_var[0m[0;34m:[0m [0mOptional[0m[0;34m[[0m[0mList[0m[0;34m[[0m[0mfloat[0m[0;3

In [204]:
jax.make_jaxpr(tpw)(x, y_sh, w)

{ lambda ; a:f32[32,1,480] b:f32[32,4,25] c:f32[32,4,60416]. let
    d:f32[32,4,480] = pjit[
      name=tpw
      jaxpr={ lambda e:f32[1,1,1] f:f32[1,3,3] g:f32[1,5,5] h:f32[1,7,7] i:f32[1,9,9]
          j:f32[3,1,3] k:f32[3,3,1] l:f32[3,3,3] m:f32[3,3,5] n:f32[3,5,3] o:f32[3,5,5]
          p:f32[3,5,7] q:f32[3,7,5] r:f32[3,7,7] s:f32[3,7,9] t:f32[3,9,7] u:f32[3,9,9]
          v:f32[5,1,5] w:f32[5,3,3] x:f32[5,3,5] y:f32[5,3,7] z:f32[5,5,1] ba:f32[5,5,3]
          bb:f32[5,5,5] bc:f32[5,5,7] bd:f32[5,5,9] be:f32[5,7,3] bf:f32[5,7,5] bg:f32[5,7,7]
          bh:f32[5,7,9] bi:f32[5,9,5] bj:f32[5,9,7] bk:f32[5,9,9]; bl:f32[32,1,480]
          bm:f32[32,4,25] bn:f32[32,4,60416]. let
          bo:f32[32,480] = squeeze[dimensions=(1,)] bl
          _:f32[32,4,480] = broadcast_in_dim[
            broadcast_dimensions=(0, 2)
            shape=(32, 4, 480)
          ] bo
          bp:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 0
          bq:f32[32,1,128] = gather[
            