# Efficiency Comparison Experiment (2): Equivariant Convolution

In [28]:
import torch
torch.set_float32_matmul_precision("high")

from e3nn import o3
from e3nn.o3 import TensorProduct, Irreps, wigner_3j, spherical_harmonics
import e3nn

# Turning this off for torch.compile
e3nn.set_optimization_defaults(jit_script_fx=False)

import time

from sh2f import sh2f
from f2sh import f2sh
from fft import FFT

In [29]:
const_wigner2gaunt = torch.load("constants/const_wigner2gaunt.pt")

sh2f_bases_dict = torch.load("constants/coefficient_sh2f.pt")
f2sh_bases_dict = torch.load("constants/coefficient_f2sh.pt")

  const_wigner2gaunt = torch.load("constants/const_wigner2gaunt.pt")
  sh2f_bases_dict = torch.load("constants/coefficient_sh2f.pt")
  f2sh_bases_dict = torch.load("constants/coefficient_f2sh.pt")


In [30]:
device = torch.device('cuda:0')

In [31]:
def e3nn_implementation(tp, in1, in2):   

    torch.cuda.synchronize()
    torch.cuda.synchronize()
    start_time = time.time()

    res = tp(in1, in2, weight=torch.ones(tp.weight_numel).to(device))
    
    torch.cuda.synchronize()
    torch.cuda.synchronize()
    end_time = time.time()  
    
    return res, end_time - start_time

In [32]:
# @T.C.: Code from eSCN implementation

# Borrowed from e3nn @ 0.4.0:
# https://github.com/e3nn/e3nn/blob/0.4.0/e3nn/o3/_wigner.py#L10
# _Jd is a list of tensors of shape (2l+1, 2l+1)
_Jd = torch.load("constants/Jd.pt")

# Borrowed from e3nn @ 0.4.0:
# https://github.com/e3nn/e3nn/blob/0.4.0/e3nn/o3/_wigner.py#L37
#
# In 0.5.0, e3nn shifted to torch.matrix_exp which is significantly slower:
# https://github.com/e3nn/e3nn/blob/0.5.0/e3nn/o3/_wigner.py#L92

def _z_rot_mat(angle: torch.Tensor, lv: int) -> torch.Tensor:
        shape, device, dtype = angle.shape, angle.device, angle.dtype
        M = angle.new_zeros((*shape, 2 * lv + 1, 2 * lv + 1))
        inds = torch.arange(0, 2 * lv + 1, 1, device=device)
        reversed_inds = torch.arange(2 * lv, -1, -1, device=device)
        frequencies = torch.arange(lv, -lv - 1, -1, dtype=dtype, device=device)
        M[..., inds, reversed_inds] = torch.sin(frequencies * angle[..., None])
        M[..., inds, inds] = torch.cos(frequencies * angle[..., None])
        return M

def wigner_D(lval, alpha, beta, gamma):
    if not lval < len(_Jd):
        return o3.wigner_D(lval, torch.tensor([alpha]),torch.tensor([beta]), torch.tensor([gamma]))

    alpha, beta, gamma = torch.broadcast_tensors(alpha, beta, gamma)
    J = _Jd[lval].to(dtype=alpha.dtype, device=device)
    Xa = _z_rot_mat(alpha, lval).to(device)
    Xb = _z_rot_mat(beta, lval).to(device)
    Xc = _z_rot_mat(gamma, lval).to(device)
    return Xa @ J @ Xb @ J @ Xc

  _Jd = torch.load("constants/Jd.pt")


In [33]:
#@torch.compile
def escn(C, in1, L, r):
    alpha, beta = o3.xyz_to_angles(r)
    alpha, beta = alpha.to(device), beta.to(device)
    wigner = torch.zeros(2*L-1, 4*L-3, 4*L-3, device=device)
    for l in range(2*L-1):
        block = wigner_D(l, torch.tensor([0.]), -beta, -alpha)
        start = -l + 2*(L-1)
        end = l + 2*L-1
        wigner[l, start:end, start:end] = block
    start_mid, end_mid = L-1, 3*L-2
    in1_rot = torch.bmm(wigner[:L, start_mid:end_mid, start_mid:end_mid], in1.unsqueeze(-1))
    res_rot_pos = (C[:, :, :, :, 0:1] * in1_rot.unsqueeze(0).unsqueeze(0)).sum(dim=1).sum(dim=1)
    res_rot_neg = (C[:, :, :, :, 1:2] * in1_rot.flip(dims=[1]).unsqueeze(0).unsqueeze(0)).sum(dim=1).sum(dim=1)
    res_rot = torch.zeros(2*L-1, 4*L-3, 1, device=device)
    res_rot[:, start_mid:end_mid, :] = res_rot_pos + res_rot_neg
    wigner_inv = wigner.transpose(1, 2)
    res = torch.bmm(wigner_inv, res_rot).squeeze(-1)
    return res

def escn_implementation(C, in1, L, r):

    torch.cuda.synchronize()
    torch.cuda.synchronize()

    start_time = time.time()
    
    res = escn(C, in1, L, r)

    torch.cuda.synchronize()
    torch.cuda.synchronize()
    end_time = time.time()  
    
    return res, end_time - start_time

In [34]:
#@torch.compile # Currently getting  RuntimeError: self.stride(-1) must be 1 to view ComplexFloat as Float (different element sizes), but got 2
def efficient(in1, in2, sh2f_bases, f2sh_bases):
    in1_fourier, in2_fourier = sh2f(in1, sh2f_bases), sh2f(in2, sh2f_bases)
    out_fourier = FFT(in1_fourier, in2_fourier)
    res = f2sh(out_fourier, f2sh_bases)
    return res

def efficient_implementation(in1, in2, sh2f_bases, f2sh_bases):
    
    torch.cuda.synchronize()
    torch.cuda.synchronize()
    start_time = time.time()

    res = efficient(in1, in2, sh2f_bases, f2sh_bases)

    torch.cuda.synchronize()
    torch.cuda.synchronize()
    end_time = time.time()  

    return res.real , end_time - start_time

In [35]:
class GauntFullyConnectedTensorProduct(TensorProduct):

    def __init__(
        self, irreps_in1, irreps_in2, irreps_out, irrep_normalization: str = None, path_normalization: str = None, **kwargs
    ):
        irreps_in1 = Irreps(irreps_in1)
        irreps_in2 = Irreps(irreps_in2)
        irreps_out = Irreps(irreps_out)

        instr = [
            (i_1, i_2, i_out, "uvw", True, const_wigner2gaunt[ir_out.l, ir_1.l, ir_2.l] ** 2)
            for i_1, (_, ir_1) in enumerate(irreps_in1)
            for i_2, (_, ir_2) in enumerate(irreps_in2)
            for i_out, (_, ir_out) in enumerate(irreps_out)
            if ir_out in ir_1 * ir_2
        ]
        super().__init__(
            irreps_in1,
            irreps_in2,
            irreps_out,
            instr,
            irrep_normalization=irrep_normalization,
            path_normalization=path_normalization,
            **kwargs,
        )


In [36]:
def flatten_irreps(irreps_2D):
    L = irreps_2D.shape[0]
    irreps_1D = irreps_2D[0, L - 1 : L]
    for l in range(1, L):
        irreps_1D = torch.cat((irreps_1D, irreps_2D[l, -l + L - 1 : l + L]))
    return irreps_1D

In [37]:
def random_input(L):
    sh_coef_in1 =torch.rand(L, 2 * L - 1).to(device)
    for l in range(L):
        for m in range(-L+1,L):
            if m<-l or m>l:
                sh_coef_in1[l, m + L - 1] = 0
    e3nn_in1 = flatten_irreps(sh_coef_in1).to(device)
    return e3nn_in1, sh_coef_in1

In [38]:
def full_spherical_harmonics_1D(L, r, normalize=True, normalization="norm"):
    res = spherical_harmonics(0, r, normalize, normalization)
    for l in range(1, L):
        res = torch.cat((res, o3.spherical_harmonics(l, r, normalize, normalization)))
    return res

def full_spherical_harmonics_2D(L, r, normalize=True, normalization="norm"):
    res = torch.zeros(L, 2*L-1)
    for l in range(L):
        res[l, -l+L-1:l+L] = spherical_harmonics(l, r, normalize, normalization)
    return res

In [46]:

def compare_equi_conv(L, n_warmup=10, n_sample=100, err_tolerance=1e-4):
    '''
    Compare the time and results for different implementation methods.
    The input irreps have degrees of [0, L).
    The final results is averaged over `n_sample` experiments over random inputs.
    The difference of the results from different method is less than `err_tolerance`.
    '''

    # e3nn needed
    irreps_in1 = Irreps([(1, (l, (-1)**l)) for l in range(L)])
    irreps_in2 = Irreps([(1, (l, (-1)**l)) for l in range(L)])
    irreps_out = Irreps([(1, (l, (-1)**l)) for l in range(2 * L - 1)])
    tp = GauntFullyConnectedTensorProduct(
        irreps_in1, irreps_in2, irreps_out,
        irrep_normalization='none', path_normalization='none', 
        internal_weights = False, shared_weights = True
    ).to(device)

    torch._dynamo.reset()
    tp = torch.compile(tp, fullgraph=True, mode='reduce-overhead')

    # eSCN needed
    C = torch.zeros(2*L-1, L, L, 2*L-1, 2, device=device)
    for lo in range(2*L-1):
        for lf in range(L):
            for li in range(L):
                if abs(lf - li) <= lo <= lf + li:
                    w3j = wigner_3j(lo, lf, li)
                    if (lo + li + lf) % 2 == 0:
                        m = min(li, lo)
                        for mo in range(-m, m+1):
                            C[lo, lf, li, mo+L-1, 0] = w3j[mo+lo, lf, mo+li]
                            C[lo, lf, li, mo+L-1, 1] = 0 if mo==0 else w3j[mo+lo, lf, -mo+li]
                        C[lo, lf, li] *= const_wigner2gaunt[lo, lf, li]  

    # gaunt needed
    sh2f_bases, f2sh_bases = sh2f_bases_dict[L], f2sh_bases_dict[2 * L - 1]
    sh2f_bases, f2sh_bases = sh2f_bases.to(device), f2sh_bases.to(device)
    
    # compare different methods
    e3nn_times, escn_times, efficient_times = torch.zeros(n_sample), torch.zeros(n_sample), torch.zeros(n_sample)
    
    for i in range(n_warmup):
        e3nn_in1, sh_coef_in1 = random_input(L)
        r = torch.rand(3).to(device)
        r /= r.norm()
        
        e3nn_in2 = full_spherical_harmonics_1D(L, r, normalize=True, normalization="norm").to(device)
        efficient_in2 = full_spherical_harmonics_2D(L, r, normalize=True, normalization="norm").to(device)
        
        e3nn_res, e3nn_time = e3nn_implementation(tp, e3nn_in1, e3nn_in2)
        escn_res, escn_time = escn_implementation(C, sh_coef_in1, L, r)
        efficient_res, efficient_time = efficient_implementation(sh_coef_in1, efficient_in2, sh2f_bases, f2sh_bases)

        # TC: compare the results
        escn_flatten = flatten_irreps(escn_res)
        efficient_res_flatten = flatten_irreps(efficient_res)
        assert (abs(e3nn_res - escn_flatten) < err_tolerance).all(), f"Max Error is {abs(e3nn_res - escn_flatten).max()}"
        assert (abs(e3nn_res - efficient_res_flatten) < err_tolerance).all(), f"Max Error is {abs(e3nn_res - efficient_res_flatten).max()}"
    
    for i in range(n_sample):
        e3nn_in1, sh_coef_in1 = random_input(L)
        r = torch.rand(3).to(device)
        r /= r.norm()
        
        e3nn_in2 = full_spherical_harmonics_1D(L, r, normalize=True, normalization="norm").to(device)
        efficient_in2 = full_spherical_harmonics_2D(L, r, normalize=True, normalization="norm").to(device)
        
        e3nn_res, e3nn_time = e3nn_implementation(tp, e3nn_in1, e3nn_in2)
        escn_res, escn_time = escn_implementation(C, sh_coef_in1, L, r)
        efficient_res, efficient_time = efficient_implementation(sh_coef_in1, efficient_in2, sh2f_bases, f2sh_bases)

        # TC: compare the results
        escn_flatten = flatten_irreps(escn_res)
        efficient_res_flatten = flatten_irreps(efficient_res)
        assert (abs(e3nn_res - escn_flatten) < err_tolerance).all(), f"Max Error is {abs(e3nn_res - escn_flatten).max()}"
        assert (abs(e3nn_res - efficient_res_flatten) < err_tolerance).all(), f"Max Error is {abs(e3nn_res - efficient_res_flatten).max()}"
        e3nn_times[i], escn_times[i], efficient_times[i] = e3nn_time, escn_time, efficient_time
    print("Sanity Check Passed!")
    
    # @T.C: compare the time
    e3nn_mean, e3nn_std = e3nn_times.mean(), e3nn_times.std()
    escn_mean, escn_std = escn_times.mean(), escn_times.std()
    efficient_mean, efficient_std = efficient_times.mean(), efficient_times.std()
    print(f"e3nn takes {e3nn_mean*1000:.2f} ± {e3nn_std*1000:.2f} ms")
    print(f"escn takes {escn_mean*1000:.2f} ± {escn_std*1000:.2f} ms")
    print(f"Efficient takes {efficient_mean*1000:.2f} ± {efficient_std*1000:.2f} ms")
    print(f"L = {L} efficient is {e3nn_mean / efficient_mean:.2f} x faster than e3nn")
    print(f"L = {L} efficient is {escn_mean / efficient_mean:.2f} x faster than escn")

## Experiments across Different Degrees (L)

In [47]:
compare_equi_conv(L=2, err_tolerance=1e-6)

Sanity Check Passed!
e3nn takes 0.20 ± 0.01 ms
escn takes 1.70 ± 0.02 ms
Efficient takes 0.38 ± 0.02 ms
L = 2 efficient is 0.53 x faster than e3nn
L = 2 efficient is 4.51 x faster than escn


In [48]:
compare_equi_conv(L=3, err_tolerance=1e-6)

Sanity Check Passed!
e3nn takes 0.24 ± 0.03 ms
escn takes 2.58 ± 0.02 ms
Efficient takes 0.38 ± 0.00 ms
L = 3 efficient is 0.64 x faster than e3nn
L = 3 efficient is 6.86 x faster than escn


In [49]:
compare_equi_conv(L=4, err_tolerance=1e-6)

Sanity Check Passed!
e3nn takes 0.34 ± 0.02 ms
escn takes 3.51 ± 0.10 ms
Efficient takes 0.38 ± 0.01 ms
L = 4 efficient is 0.88 x faster than e3nn
L = 4 efficient is 9.20 x faster than escn


In [50]:
compare_equi_conv(L=5, err_tolerance=1e-6)

Sanity Check Passed!
e3nn takes 0.48 ± 0.01 ms
escn takes 4.32 ± 0.02 ms
Efficient takes 0.38 ± 0.01 ms
L = 5 efficient is 1.29 x faster than e3nn
L = 5 efficient is 11.49 x faster than escn


In [54]:
compare_equi_conv(L=6, err_tolerance=1e-5)

Sanity Check Passed!
e3nn takes 0.69 ± 0.00 ms
escn takes 5.15 ± 0.02 ms
Efficient takes 0.37 ± 0.00 ms
L = 6 efficient is 1.87 x faster than e3nn
L = 6 efficient is 13.90 x faster than escn


In [55]:
compare_equi_conv(L=7, err_tolerance=1e-5)

Sanity Check Passed!
e3nn takes 1.03 ± 0.09 ms
escn takes 7.97 ± 1.37 ms
Efficient takes 0.39 ± 0.04 ms
L = 7 efficient is 2.65 x faster than e3nn
L = 7 efficient is 20.54 x faster than escn


In [56]:
compare_equi_conv(L=8, err_tolerance=1e-4)

Sanity Check Passed!
e3nn takes 1.48 ± 0.00 ms
escn takes 12.04 ± 0.12 ms
Efficient takes 0.37 ± 0.00 ms
L = 8 efficient is 3.98 x faster than e3nn
L = 8 efficient is 32.32 x faster than escn


In [57]:
compare_equi_conv(L=9, err_tolerance=1e-4)

Sanity Check Passed!
e3nn takes 1.99 ± 0.08 ms
escn takes 19.00 ± 0.22 ms
Efficient takes 0.47 ± 0.03 ms
L = 9 efficient is 4.28 x faster than e3nn
L = 9 efficient is 40.80 x faster than escn


In [58]:
compare_equi_conv(L=10, err_tolerance=1e-4)

Sanity Check Passed!
e3nn takes 2.73 ± 0.06 ms
escn takes 21.93 ± 0.92 ms
Efficient takes 0.39 ± 0.03 ms
L = 10 efficient is 6.92 x faster than e3nn
L = 10 efficient is 55.60 x faster than escn


In [59]:
compare_equi_conv(L=11, err_tolerance=5e-4)

Sanity Check Passed!
e3nn takes 3.63 ± 0.04 ms
escn takes 26.37 ± 1.56 ms
Efficient takes 0.40 ± 0.04 ms
L = 11 efficient is 9.13 x faster than e3nn
L = 11 efficient is 66.38 x faster than escn


In [60]:
compare_equi_conv(L=12, err_tolerance=5e-4)

Sanity Check Passed!
e3nn takes 4.78 ± 0.18 ms
escn takes 32.27 ± 2.19 ms
Efficient takes 0.40 ± 0.03 ms
L = 12 efficient is 11.94 x faster than e3nn
L = 12 efficient is 80.57 x faster than escn
