# Efficiency Comparison Experiment (1): Equivariant Feature Interaction

In [1]:
import torch
from e3nn import o3
from e3nn.o3 import TensorProduct, Irreps
import time

from sh2f import sh2f_channel
from f2sh import f2sh_channel
from fft import FFT_channel

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
const_wigner2gaunt = torch.load("constants/const_wigner2gaunt.pt")

sh2f_bases_dict = torch.load("constants/coefficient_sh2f.pt")
f2sh_bases_dict = torch.load("constants/coefficient_f2sh.pt")

In [3]:
device = torch.device('cuda:0')

In [4]:
def e3nn_implementation(tp, in1, in2):
    
    torch.cuda.synchronize()
    torch.cuda.synchronize()
    start_time = time.time()
    
    res = tp(in1, in2, weight=torch.ones(tp.weight_numel).to(device))

    torch.cuda.synchronize()
    torch.cuda.synchronize()
    end_time = time.time()  
    
    return res, end_time - start_time

In [5]:
def efficient_implementation(in1, in2, sh2f_bases, f2sh_bases):

    torch.cuda.synchronize()
    torch.cuda.synchronize()
    start_time = time.time()
    
    in1_fourier, in2_fourier = sh2f_channel(in1, sh2f_bases), sh2f_channel(in2, sh2f_bases)
    out_fourier = FFT_channel(in1_fourier, in2_fourier)
    res = f2sh_channel(out_fourier, f2sh_bases)
    
    torch.cuda.synchronize()
    torch.cuda.synchronize()
    end_time = time.time()  

    return res.real, end_time - start_time

In [6]:
class GauntFullyConnectedTensorProduct(TensorProduct):

    def __init__(
        self, irreps_in1, irreps_in2, irreps_out, irrep_normalization: str = None, path_normalization: str = None, **kwargs
    ):
        irreps_in1 = o3.Irreps(irreps_in1)
        irreps_in2 = o3.Irreps(irreps_in2)
        irreps_out = o3.Irreps(irreps_out)

        instr = [
            (i_1, i_2, i_out, "uuu", True, const_wigner2gaunt[ir_out.l, ir_1.l, ir_2.l] ** 2)
            for i_1, (_, ir_1) in enumerate(irreps_in1)
            for i_2, (_, ir_2) in enumerate(irreps_in2)
            for i_out, (_, ir_out) in enumerate(irreps_out)
            if ir_out in ir_1 * ir_2
        ]
        super().__init__(
            irreps_in1,
            irreps_in2,
            irreps_out,
            instr,
            irrep_normalization=irrep_normalization,
            path_normalization=path_normalization,
            **kwargs,
        )

In [7]:
def flatten_irreps(irreps_3D):
    L = irreps_3D.shape[1]
    irreps_1D = irreps_3D[:, 0, L - 1 : L].flatten()
    for l in range(1, L):
        irreps_1D = torch.cat((irreps_1D, (irreps_3D[:, l, -l + L - 1 : l + L].flatten())))
    return irreps_1D

In [8]:
def random_input(L, num_channel):
    '''
    Generate random input irreps with degrees of [0, L) and channels of `num_channel`.
    '''
    in1_sh, in2_sh=torch.rand(num_channel, L, 2 * L - 1).to(device), torch.rand(num_channel, L, 2 * L - 1).to(device)
    for l in range(L):
        for m in range(-L+1,L):
            if m<-l or m>l:
                in1_sh[:, l, m + L - 1], in2_sh[:, l, m + L - 1] = 0, 0
    in1_e3nn, in2_e3nn = flatten_irreps(in1_sh).to(device), flatten_irreps(in2_sh).to(device)
    return in1_e3nn, in2_e3nn, in1_sh, in2_sh

In [9]:
def compare_equi_feat(L, channel, n_sample=10, err_tolerance=1e-4):
    '''
    Compare the time and results for different implementation methods.
    The input irreps have degrees of [0, L) and channels of `num_channel`.
    The final results is averaged over `n_sample` experiments over random inputs.
    The difference of the results from different method is less than `err_tolerance`.
    '''

    # e3nn needed
    irreps_in1 = Irreps([(channel, (l, (-1)**l)) for l in range(L)])
    irreps_in2 = Irreps([(channel, (l, (-1)**l)) for l in range(L)])
    irreps_out = Irreps([(channel, (l, (-1)**l)) for l in range(2 * L - 1)])
    e3nn_tp = GauntFullyConnectedTensorProduct(
        irreps_in1, irreps_in2, irreps_out,
        irrep_normalization='none', path_normalization='none', 
        internal_weights = False, shared_weights = True
    ).to(device)

    # gaunt needed
    sh2f_bases, f2sh_bases = sh2f_bases_dict[L], f2sh_bases_dict[2 * L - 1]
    sh2f_bases, f2sh_bases = sh2f_bases.to(device), f2sh_bases.to(device)
    
    # compare different methods
    e3nn_times, efficient_times = torch.zeros(n_sample),torch.zeros(n_sample)
    for i in range(n_sample):
        in1_e3nn, in2_e3nn, in1_sh, in2_sh = random_input(L, channel)

        e3nn_res, e3nn_time = e3nn_implementation(e3nn_tp, in1_e3nn, in2_e3nn)
        efficient_res, efficient_time = efficient_implementation(in1_sh, in2_sh, sh2f_bases, f2sh_bases)

        efficient_res_flatten = flatten_irreps(efficient_res)
        # TC: compare the results
        assert (abs(e3nn_res - efficient_res_flatten) < err_tolerance).all(), f"Max Error is {abs(e3nn_res - efficient_res_flatten).max()}!"
        e3nn_times[i], efficient_times[i] = e3nn_time, efficient_time
    print("Sanity Check Passed!")
    
    # @T.C: compare the time
    e3nn_mean, e3nn_std = e3nn_times.mean(), e3nn_times.std()
    efficient_mean, efficient_std = efficient_times.mean(), efficient_times.std()
    print(f"e3nn takes {e3nn_mean*1000:.2f} ± {e3nn_std*1000:.2f} ms")
    print(f"Efficient takes {efficient_mean*1000:.2f} ± {efficient_std*1000:.2f} ms")
    print(f"L = {L}, C = {channel} efficient is {e3nn_mean / efficient_mean:.2f} x faster")

## Experiments across Different Degrees (L)

In [11]:
compare_equi_feat(2, channel=128, err_tolerance=1e-6)

Sanity Check Passed!
e3nn takes 6.67 ± 12.79 ms
Efficient takes 0.76 ± 0.58 ms
L = 2, C = 128 efficient is 8.78 x faster


In [12]:
compare_equi_feat(3, channel=128, err_tolerance=5e-6)

Sanity Check Passed!
e3nn takes 11.39 ± 20.39 ms
Efficient takes 0.82 ± 0.55 ms
L = 3, C = 128 efficient is 13.95 x faster


In [13]:
compare_equi_feat(4, channel=128, err_tolerance=5e-6)

Sanity Check Passed!
e3nn takes 138.60 ± 412.91 ms
Efficient takes 2.31 ± 4.69 ms
L = 4, C = 128 efficient is 59.88 x faster


In [14]:
compare_equi_feat(5, channel=128, err_tolerance=1e-5)

Sanity Check Passed!
e3nn takes 211.10 ± 629.78 ms
Efficient takes 3.06 ± 2.80 ms
L = 5, C = 128 efficient is 69.08 x faster


In [15]:
compare_equi_feat(6, channel=128, err_tolerance=1e-5)

Sanity Check Passed!
e3nn takes 326.77 ± 986.02 ms
Efficient takes 2.63 ± 4.05 ms
L = 6, C = 128 efficient is 124.16 x faster


In [16]:
compare_equi_feat(7, channel=128, err_tolerance=1e-5)

Sanity Check Passed!
e3nn takes 403.43 ± 1186.38 ms
Efficient takes 2.28 ± 0.40 ms
L = 7, C = 128 efficient is 176.76 x faster


In [17]:
compare_equi_feat(8, channel=128, err_tolerance=1e-5)

Sanity Check Passed!
e3nn takes 678.58 ± 1972.26 ms
Efficient takes 3.48 ± 5.15 ms
L = 8, C = 128 efficient is 194.92 x faster


In [18]:
compare_equi_feat(9, channel=128, err_tolerance=5e-5)

Sanity Check Passed!
e3nn takes 852.36 ± 2195.77 ms
Efficient takes 7.16 ± 16.22 ms
L = 9, C = 128 efficient is 118.98 x faster


In [19]:
compare_equi_feat(10, channel=128, err_tolerance=5e-5)

Sanity Check Passed!
e3nn takes 914.15 ± 2529.23 ms
Efficient takes 6.14 ± 11.91 ms
L = 10, C = 128 efficient is 148.97 x faster


In [20]:
compare_equi_feat(11, channel=128, err_tolerance=5e-4)

Sanity Check Passed!
e3nn takes 1194.65 ± 3215.20 ms
Efficient takes 2.34 ± 0.76 ms
L = 11, C = 128 efficient is 510.46 x faster


In [21]:
compare_equi_feat(12, channel=128, err_tolerance=5e-4)

Sanity Check Passed!
e3nn takes 1349.64 ± 3699.24 ms
Efficient takes 2.81 ± 0.98 ms
L = 12, C = 128 efficient is 480.06 x faster


In [22]:
compare_equi_feat(13, channel=128, err_tolerance=5e-4)

Sanity Check Passed!
e3nn takes 1538.02 ± 4043.94 ms
Efficient takes 2.92 ± 2.05 ms
L = 13, C = 128 efficient is 526.64 x faster


In [23]:
compare_equi_feat(14, channel=128, err_tolerance=1e-3)

Sanity Check Passed!
e3nn takes 2270.26 ± 5585.58 ms
Efficient takes 3.07 ± 0.99 ms
L = 14, C = 128 efficient is 739.37 x faster
