# Efficiency Comparison Experiment (3): Equivariant Many-Body Interaction

In [1]:
import torch
import time
from e3nn import o3
from e3nn.o3 import TensorProduct, Irreps
from e3nn.util.codegen import CodeGenMixin
from opt_einsum import contract
import collections
from typing import List, Union, Dict, Optional

from sh2f import sh2f_batch_channel
from f2sh import f2sh_batch_channel
from fft import FFT_batch_channel

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
const_wigner2gaunt = torch.load("constants/const_wigner2gaunt.pt")
sh2f_bases_dict = torch.load("constants/coefficient_sh2f.pt")
f2sh_bases_dict = torch.load("constants/coefficient_f2sh.pt")

In [3]:
device = "cuda:0"

In [4]:
_TP = collections.namedtuple("tp", "op, args")
_INPUT = collections.namedtuple("input", "tensor, start, stop")

In [5]:
def _wigner_nj(
    irrepss: List[o3.Irreps],
    normalization: str = "none",
    filter_ir_mid=None,
    dtype=None,
    device=None,
):
    irrepss = [o3.Irreps(irreps) for irreps in irrepss]
    if filter_ir_mid is not None:
        filter_ir_mid = [o3.Irrep(ir) for ir in filter_ir_mid]

    if len(irrepss) == 1:
        (irreps,) = irrepss
        ret = []
        e = torch.eye(irreps.dim, dtype=dtype, device=device)
        i = 0
        for mul, ir in irreps:
            for _ in range(mul):
                sl = slice(i, i + ir.dim)
                ret += [(ir, _INPUT(0, sl.start, sl.stop), e[sl])]
                i += ir.dim
        return ret

    *irrepss_left, irreps_right = irrepss
    ret = []
    for ir_left, path_left, C_left in _wigner_nj(
        irrepss_left,
        normalization=normalization,
        filter_ir_mid=filter_ir_mid,
        dtype=dtype,
        device=device,
    ):
        i = 0
        for mul, ir in irreps_right:
            for ir_out in ir_left * ir:
                if filter_ir_mid is not None and ir_out not in filter_ir_mid:
                    continue

                C = o3.wigner_3j(ir_out.l, ir_left.l, ir.l, dtype=dtype, device=device) * const_wigner2gaunt[ir_out.l, ir_left.l, ir.l]
                if normalization == "component":
                    C *= ir_out.dim ** 0.5
                if normalization == "norm":
                    C *= ir_left.dim ** 0.5 * ir.dim ** 0.5

                C = torch.einsum("jk,ijl->ikl", C_left.flatten(1), C)
                C = C.reshape(
                    ir_out.dim, *(irreps.dim for irreps in irrepss_left), ir.dim
                )
                for u in range(mul):
                    E = torch.zeros(
                        ir_out.dim,
                        *(irreps.dim for irreps in irrepss_left),
                        irreps_right.dim,
                        dtype=dtype,
                        device=device,
                    )
                    sl = slice(i + u * ir.dim, i + (u + 1) * ir.dim)
                    E[..., sl] = C
                    ret += [
                        (
                            ir_out,
                            _TP(
                                op=(ir_left, ir, ir_out),
                                args=(
                                    path_left,
                                    _INPUT(len(irrepss_left), sl.start, sl.stop),
                                ),
                            ),
                            E,
                        )
                    ]
            i += mul * ir.dim
    return sorted(ret, key=lambda x: x[0])

In [6]:
def U_matrix_real(
    irreps_in: Union[str, o3.Irreps],
    irreps_out: Union[str, o3.Irreps],
    correlation: int,
    normalization: str = "none",
    filter_ir_mid=None,
    dtype=None,
    device=None,
):
    irreps_out = o3.Irreps(irreps_out)
    irrepss = [o3.Irreps(irreps_in)] * correlation
    wigners = _wigner_nj(irrepss, normalization, filter_ir_mid, dtype, device)
    current_ir = wigners[0][0]
    out = []
    stack = torch.tensor([], device=device)

    for ir, _, base_o3 in wigners:
        if ir in irreps_out and ir == current_ir:
            stack = torch.cat((stack, base_o3.squeeze().unsqueeze(-1)), dim=-1)
            last_ir = current_ir
        elif ir in irreps_out and ir != current_ir:
            if len(stack) != 0:
                out += [last_ir, stack]
            stack = base_o3.squeeze().unsqueeze(-1)
            current_ir, last_ir = ir, ir
        else:
            current_ir = ir
    out += [last_ir, stack]
    return out

In [7]:
class Contraction(torch.nn.Module):
    def __init__(
        self,
        irreps_in: o3.Irreps,
        irrep_out: o3.Irreps,
        correlation: int,
        device: Optional[str] = "cpu",
    ) -> None:
        super().__init__()

        self.dtype = torch.get_default_dtype()
        self.num_features = irreps_in.count((0, 1))
        self.coupling_irreps = o3.Irreps([irrep.ir for irrep in irreps_in])
        self.correlation = correlation
        dtype = torch.get_default_dtype()
        self.U_tensors = {
            nu: U_matrix_real(
                irreps_in=self.coupling_irreps,
                irreps_out=irrep_out,
                correlation=nu,
                dtype=dtype,
                device=device,
            )[-1]
            for nu in range(1, correlation + 1)
        }

        # Tensor contraction equations
        self.equation_main = "...ik,kc,bci -> bc..."
        self.equation_weighting = "...k,kc->c..."
        self.equation_contract = "bc...i,bci->bc..."
        self.weights = torch.nn.ParameterDict({})
        for i in range(1, correlation + 1):
            num_params = self.U_tensors[i].size()[-1]
            w = torch.nn.Parameter(
                torch.ones(num_params, self.num_features, device=device), requires_grad=False
            )
            self.weights[str(i)] = w

    def forward(self, x: torch.Tensor):
        out = contract(
            self.equation_main,
            self.U_tensors[self.correlation],
            self.weights[str(self.correlation)].type(self.dtype),
            x,
        )
        for corr in range(self.correlation - 1, 0, -1):
            c_tensor = contract(
                self.equation_weighting,
                self.U_tensors[corr],
                self.weights[str(corr)].type(self.dtype),
            )
            c_tensor = c_tensor + out
            out = contract(self.equation_contract, c_tensor, x)
        resize_shape = torch.prod(torch.tensor(out.shape[1:]))
        return out.view(out.shape[0], resize_shape)


In [8]:
class SymmetricContraction(CodeGenMixin, torch.nn.Module):
    def __init__(
        self,
        irreps_in: o3.Irreps,
        irreps_out: o3.Irreps,
        correlation: int or Dict[str, int],
        irrep_normalization: str = "none",
        path_normalization: str = "none",
        device: str = "cpu",
    ) -> None:
        super().__init__()

        self.type = torch.float64

        if irrep_normalization is None:
            irrep_normalization = "component"

        if path_normalization is None:
            path_normalization = "element"

        assert irrep_normalization in ["component", "norm", "none"]
        assert path_normalization in ["element", "path", "none"]

        self.irreps_in = o3.Irreps(irreps_in)
        self.irreps_out = o3.Irreps(irreps_out)

        del irreps_in, irreps_out

        if type(correlation) is not tuple:
            corr = correlation
            correlation = {}
            for irrep_out in self.irreps_out:
                correlation[irrep_out] = corr

        self.contractions = torch.nn.ModuleDict()
        for irrep_out in self.irreps_out:
            self.contractions[str(irrep_out)] = Contraction(
                irreps_in=self.irreps_in,
                irrep_out=o3.Irreps(str(irrep_out.ir)),
                correlation=correlation[irrep_out],
                device=device,
            )

    def forward(self, x: torch.tensor):
        outs = []
        for irrep in self.irreps_out:
            outs.append(self.contractions[str(irrep)](x))
        return torch.cat(outs, dim=-1)


In [9]:
class GauntFullyConnectedTensorProductChannel(TensorProduct):

    def __init__(
        self, irreps_in1, irreps_in2, irreps_out, irrep_normalization: str = None, path_normalization: str = None, **kwargs
    ):
        irreps_in1 = o3.Irreps(irreps_in1)
        irreps_in2 = o3.Irreps(irreps_in2)
        irreps_out = o3.Irreps(irreps_out)

        instr = [
            (i_1, i_2, i_out, "uuu", True, const_wigner2gaunt[ir_out.l, ir_1.l, ir_2.l] ** 2)
            for i_1, (_, ir_1) in enumerate(irreps_in1)
            for i_2, (_, ir_2) in enumerate(irreps_in2)
            for i_out, (_, ir_out) in enumerate(irreps_out)
            if ir_out in ir_1 * ir_2
        ]
        super().__init__(
            irreps_in1,
            irreps_in2,
            irreps_out,
            instr,
            irrep_normalization=irrep_normalization,
            path_normalization=path_normalization,
            **kwargs,
        )

In [10]:
def e3nn_implementation(L, correlation, input, num_channel):

    irreps_in1 = Irreps([(num_channel, (l, (-1)**l)) for l in range(L)])
    irreps_in2 = irreps_in1
    res = input.clone()
    output = input.clone()
    cutoff = input.shape[-1]
    
    torch.cuda.synchronize()
    torch.cuda.synchronize()
    start_time = time.time()

    for nu in range(2, correlation + 1):
        lmax = nu * (L - 1)
        irreps_out = Irreps([(num_channel, (l, (-1)**l)) for l in range(lmax + 1)])
        tp = GauntFullyConnectedTensorProductChannel(
            irreps_in1, irreps_in2, irreps_out,
            irrep_normalization='none', path_normalization='none', 
            internal_weights = False, shared_weights = True
        ).to(device)

        output = tp(input, output, weight=torch.ones(tp.weight_numel).to(device))
        res += output[:, :cutoff]
        irreps_in2 = irreps_out

    torch.cuda.synchronize()
    torch.cuda.synchronize()
    end_time = time.time()  
    
    return res, end_time - start_time

In [11]:
def mace_implementation(L, correlation, input, num_channel):

    irreps_in = Irreps([(num_channel, (l, (-1)**l)) for l in range(L)])
    irreps_out = Irreps([(num_channel, (l, (-1)**l)) for l in range(L)])
    symmetric_contractions = SymmetricContraction(
        irreps_in,
        irreps_out,
        correlation,
        device=device,
        irrep_normalization = "none",
        path_normalization = "none",
    )
    
    torch.cuda.synchronize()
    torch.cuda.synchronize()
    start_time = time.time()

    res = symmetric_contractions(input)
    
    torch.cuda.synchronize()
    torch.cuda.synchronize()
    end_time = time.time()  
    
    return res, end_time - start_time

In [12]:
def efficient_implementation(L, correlation, input, sh2f_bases, f2sh_bases_nu, offsets_st, offsets_ed):

    torch.cuda.synchronize()
    torch.cuda.synchronize()
    start_time = time.time()
    
    fs_out = {}
    res = input
    fs_out[1] = sh2f_batch_channel(input, sh2f_bases)
    for nu in range(2, correlation + 1):
        if nu % 2 == 0:
            fs_out[nu] = FFT_batch_channel(fs_out[nu//2], fs_out[nu//2])
        else:
            fs_out[nu] = FFT_batch_channel(fs_out[nu//2], fs_out[nu//2 + 1])
        res += f2sh_batch_channel(fs_out[nu], f2sh_bases_nu[nu]).real[:, :, :L, offsets_st[nu]:offsets_ed[nu]]
    
    torch.cuda.synchronize()
    torch.cuda.synchronize()
    end_time = time.time()  

    return res, end_time - start_time

In [13]:
def flatten_irreps_4to2(irreps_4D):
    L = irreps_4D.shape[2]
    irreps_2D = irreps_4D[:, :, 0, L - 1 : L].flatten(start_dim=1)
    for l in range(1, L):
        irreps_2D = torch.cat((irreps_2D, (irreps_4D[:, :, l, -l + L - 1 : l + L].flatten(start_dim=1))), dim=1)
    return irreps_2D

In [14]:
def flatten_irreps_4to3(irreps_4D):
    L = irreps_4D.shape[2]
    irreps_3D = irreps_4D[:, :, 0, L - 1 : L].flatten(start_dim=2)
    for l in range(1, L):
        irreps_3D = torch.cat((irreps_3D, (irreps_4D[:, :, l, -l + L - 1 : l + L].flatten(start_dim=2))), dim=2)
    return irreps_3D

In [15]:
def random_input_batch_channel(L, batch, num_channel):
    sh_coef_in=torch.rand(batch, num_channel, L, 2 * L - 1).to(device)
    for l in range(L):
        for m in range(-L+1,L):
            if m<-l or m>l:
                sh_coef_in[:, :, l, m + L - 1] = 0
    e3nn_in = flatten_irreps_4to2(sh_coef_in).to(device)
    mace_in = flatten_irreps_4to3(sh_coef_in).to(device)
    return e3nn_in, mace_in, sh_coef_in

In [16]:
def compare_equi_many_body(L, correlation=3, batch=32, num_channel=128, n_sample=10, err_tolerance=1e-4):

    '''
    Compare the time and results for different implementation methods.
    The input irreps have degrees of [0, L) and batch size of `batch` channels of `num_channel`.
    The final results is averaged over `n_sample` experiments over random inputs.
    The difference of the results from different method is less than `err_tolerance`.
    '''
    
    sh2f_bases = sh2f_bases_dict[L].to(device)
    f2sh_bases_nu = {}
    offsets_st, offsets_ed = [0,0], [0,0]
    for nu in range(2, correlation + 1):
        lmax = nu * (L - 1)
        f2sh_bases_nu[nu] = f2sh_bases_dict[lmax + 1].to(device)
        offsets_st.append(lmax - L + 1)
        offsets_ed.append(lmax + L)

    # compare different methods
    e3nn_times, mace_times, efficient_times = torch.zeros(n_sample), torch.zeros(n_sample), torch.zeros(n_sample)    
    for i in range(n_sample):
        e3nn_in, mace_in, sh_coef_in = random_input_batch_channel(L, batch, num_channel)
        e3nn_res, e3nn_time = e3nn_implementation(L, correlation, e3nn_in, num_channel)
        mace_res, mace_time = mace_implementation(L, correlation, mace_in, num_channel)
        efficient_res, efficient_time = efficient_implementation(L, correlation, sh_coef_in, sh2f_bases, f2sh_bases_nu, offsets_st, offsets_ed)
        
        efficient_res_flatten = flatten_irreps_4to2(efficient_res)
        
        # TC: compare the results
        assert (abs(e3nn_res - mace_res) < err_tolerance).all(), f"Max Error is {abs(e3nn_res - mace_res).max()}!"
        assert (abs(e3nn_res - efficient_res_flatten) < err_tolerance).all(), f"Max Error is {abs(e3nn_res - efficient_res_flatten).max()}!"
        assert (abs(mace_res - efficient_res_flatten) < err_tolerance).all(), f"Max Error is {abs(mace_res - efficient_res_flatten).max()}!"

        e3nn_times[i], mace_times[i], efficient_times[i] = e3nn_time, mace_time, efficient_time
    print("Sanity Check Passed!")

    # @T.C: compare the time
    e3nn_mean, e3nn_std = e3nn_times.mean(), e3nn_times.std()
    mace_mean, mace_std = mace_times.mean(), mace_times.std()
    efficient_mean, efficient_std = efficient_times.mean(), efficient_times.std()

    print(f"e3nn takes {e3nn_mean*1000:.2f} ± {e3nn_std*1000:.2f} ms")
    print(f"mace takes {mace_mean*1000:.2f} ± {mace_std*1000:.2f} ms")
    print(f"Efficient takes {efficient_mean*1000:.2f} ± {efficient_std*1000:.2f} ms")
    print(f"L = {L}, correlation = {correlation}, batch_size = {batch}, channels = {num_channel}")
    print(f"efficient is {mace_mean / efficient_mean:.2f} x faster than mace")

## Experiments across Different Degrees (L), holding correlation (ν = 3) fixed

In [18]:
compare_equi_many_body(L=2, correlation=3, err_tolerance=1e-5)

Sanity Check Passed!
e3nn takes 249.73 ± 37.27 ms
mace takes 3.00 ± 0.07 ms
Efficient takes 1.29 ± 0.04 ms
L = 2, correlation = 3, batch_size = 32, channels = 128
efficient is 2.32 x faster than mace


In [19]:
compare_equi_many_body(L=3, correlation=3, err_tolerance=1e-5)

Sanity Check Passed!
e3nn takes 493.65 ± 45.74 ms
mace takes 4.86 ± 0.22 ms
Efficient takes 1.71 ± 0.88 ms
L = 3, correlation = 3, batch_size = 32, channels = 128
efficient is 2.84 x faster than mace


In [20]:
compare_equi_many_body(L=4, correlation=3, err_tolerance=1e-5)

Sanity Check Passed!
e3nn takes 970.76 ± 76.84 ms
mace takes 7.06 ± 0.52 ms
Efficient takes 2.27 ± 0.40 ms
L = 4, correlation = 3, batch_size = 32, channels = 128
efficient is 3.11 x faster than mace


In [21]:
compare_equi_many_body(L=5, correlation=3, err_tolerance=1e-5)

Sanity Check Passed!
e3nn takes 1767.27 ± 149.84 ms
mace takes 13.59 ± 1.07 ms
Efficient takes 3.80 ± 0.75 ms
L = 5, correlation = 3, batch_size = 32, channels = 128
efficient is 3.57 x faster than mace


In [22]:
compare_equi_many_body(L=6, correlation=3, err_tolerance=1e-4)

Sanity Check Passed!
e3nn takes 3648.32 ± 986.26 ms
mace takes 44.02 ± 4.16 ms
Efficient takes 8.60 ± 3.22 ms
L = 6, correlation = 3, batch_size = 32, channels = 128
efficient is 5.12 x faster than mace


## Experiments across Different correlation (ν), holding degree (L=2) fixed

In [23]:
compare_equi_many_body(L=2, correlation=4, err_tolerance=1e-4)

Sanity Check Passed!
e3nn takes 568.19 ± 166.10 ms
mace takes 3.81 ± 1.50 ms
Efficient takes 1.41 ± 0.42 ms
L = 2, correlation = 4, batch_size = 32, channels = 128
efficient is 2.71 x faster than mace


In [24]:
compare_equi_many_body(L=2, correlation=5, err_tolerance=1e-4)

Sanity Check Passed!
e3nn takes 734.46 ± 241.98 ms
mace takes 4.55 ± 1.10 ms
Efficient takes 2.05 ± 0.40 ms
L = 2, correlation = 5, batch_size = 32, channels = 128
efficient is 2.21 x faster than mace


In [25]:
compare_equi_many_body(L=2, correlation=6, err_tolerance=1e-4)

Sanity Check Passed!
e3nn takes 1087.00 ± 265.75 ms
mace takes 8.25 ± 4.38 ms
Efficient takes 2.68 ± 0.37 ms
L = 2, correlation = 6, batch_size = 32, channels = 128
efficient is 3.08 x faster than mace


In [26]:
compare_equi_many_body(L=2, correlation=7, err_tolerance=1e-4)

Sanity Check Passed!
e3nn takes 1140.07 ± 231.60 ms
mace takes 11.40 ± 3.32 ms
Efficient takes 3.46 ± 0.89 ms
L = 2, correlation = 7, batch_size = 32, channels = 128
efficient is 3.29 x faster than mace
