In [1]:
import cuequivariance as cue
import cuequivariance_torch as cuet
import json
import torch
import e3nn

In [None]:
import cuequivariance_ops_torch
import itertools

In [3]:
def mul_Irreps(mul, i_in):
    """Return a copy of ``i_in`` with every multiplicity scaled by ``mul``.

    Each ``(ori_mul, ir)`` entry of ``i_in`` becomes ``(ori_mul * mul, (ir.l, ir.p))``.
    Entry order is preserved. The result is built with ``e3nn.o3.Irreps``.
    """
    return e3nn.o3.Irreps(
        [(count * mul, (irrep.l, irrep.p)) for count, irrep in i_in]
    )
def compare(a, b, atol=1e-7):
    """Print every position where tensors ``a`` and ``b`` meaningfully differ.

    A position is reported when ``torch.isclose(a, b)`` (default rtol/atol) is
    False *and* ``|a - b| > atol``. For each reported position, in row-major
    order, the index tensor is printed, followed by the difference ``a - b``.

    Args:
        a, b: tensors, expected to have the same shape.
        atol: absolute-difference threshold below which non-``isclose``
            positions are ignored. Defaults to 1e-7, the value previously
            hard-coded.
    """
    diff = a - b
    # Same criterion as before, computed for all positions at once.
    mismatch = ~torch.isclose(a, b) & (diff.abs() > atol)
    for pos in torch.argwhere(mismatch):
        # Index with an explicit tuple. A list of tensors only works through
        # PyTorch's "treat sequence as tuple" heuristic, and for 0-d inputs an
        # empty list would index incorrectly.
        idx = tuple(pos.tolist())
        print(pos)
        print(diff[idx])
            
# Positions of the fields inside a layer's "tp" list in the config file.
IR_IN1_IDX = 0
IR_IN2_IDX = 1
IR_OUT_IDX = 2
INST_IDX = 3

# Default directory of the NequIP benchmark configs; every loader below
# accepts ``config_dir`` to override it.
NEQUIP_CONFIG_DIR = "/home2/lsy/fused_e3nn/benchmark_config"


def read_nequip_tp_list(h, l_max, layer_idx, config_dir=NEQUIP_CONFIG_DIR):
    """Return the raw ``"tp"`` list for one layer of a NequIP benchmark config.

    The file ``{config_dir}/4_{h}_{l_max}_p_sc.txt`` holds one JSON object per
    line. ``layer_idx`` is the 0-based line number of the wanted layer; blank
    lines still count toward the numbering but cannot be selected.

    Raises:
        KeyError: if ``layer_idx`` does not name a non-blank line of the file.
    """
    filename = f"{config_dir}/4_{h}_{l_max}_p_sc.txt"
    with open(filename, "r", encoding="utf-8") as f:
        lines = f.read().split("\n")
    # Explicit range check: plain list indexing would silently accept negatives.
    if not (0 <= layer_idx < len(lines)) or not lines[layer_idx].strip():
        raise KeyError(f"layer {layer_idx} not found in {filename}")
    return json.loads(lines[layer_idx])["tp"]


def load_nequip_config_e3nn(h, l_max, layer_idx, config_dir=NEQUIP_CONFIG_DIR):
    """Load one layer's tensor-product config as e3nn objects.

    Returns:
        ``(irreps_in1, irreps_in2, irreps_out, instructions)``: three
        ``e3nn.o3.Irreps`` and a list of instruction tuples.
    """
    tp_list = read_nequip_tp_list(h, l_max, layer_idx, config_dir)
    i_in1 = e3nn.o3.Irreps(tp_list[IR_IN1_IDX])
    i_in2 = e3nn.o3.Irreps(tp_list[IR_IN2_IDX])
    i_out = e3nn.o3.Irreps(tp_list[IR_OUT_IDX])
    inst_tuple = [tuple(x) for x in tp_list[INST_IDX]]
    return i_in1, i_in2, i_out, inst_tuple


def load_nequip_config_cueq(h, l_max, layer_idx, config_dir=NEQUIP_CONFIG_DIR):
    """Load one layer's tensor-product config as cuequivariance objects.

    Returns:
        ``(irreps_in1, irreps_in2, irreps_out, instructions)``: three
        ``cue.Irreps("O3", ...)`` and a list of instruction tuples.
    """
    tp_list = read_nequip_tp_list(h, l_max, layer_idx, config_dir)
    i_in1 = cue.Irreps("O3", tp_list[IR_IN1_IDX])
    i_in2 = cue.Irreps("O3", tp_list[IR_IN2_IDX])
    i_out = cue.Irreps("O3", tp_list[IR_OUT_IDX])
    inst_tuple = [tuple(x) for x in tp_list[INST_IDX]]
    return i_in1, i_in2, i_out, inst_tuple


In [4]:
# Seed before any random tensors are drawn so the comparison is reproducible.
torch.manual_seed(0)

# h and l_max select the config file 4_{h}_{l_max}_p_sc.txt; layer_idx picks
# the line (layer) inside it; batch_size is the number of random samples.
h, l_max, layer_idx, batch_size = 32, 1, 3, 32

cueq_config = load_nequip_config_cueq(h=h, l_max=l_max, layer_idx=layer_idx)
e3nn_config = load_nequip_config_e3nn(h=h, l_max=l_max, layer_idx=layer_idx)


In [5]:
# Reference e3nn tensor product. No internal or shared weights: a per-sample
# weight tensor is passed at call time.
tp = e3nn.o3.TensorProduct(
    *e3nn_config,
    shared_weights=False,
    internal_weights=False,
    irrep_normalization="component",
    path_normalization="element",
)



In [6]:
# cuequivariance channel-wise TP built from the same irreps: input 1, input 2
# and output irreps (presumably used as an output filter). The instruction
# list (cueq_config[3]) is not passed.
cuet_tp = cuet.ChannelWiseTensorProduct(
    cueq_config[0],
    cueq_config[1],
    cueq_config[2],
    shared_weights=False,
    internal_weights=False,
)

32x0e+32x1o+32x0o+32x1e+32x1e+32x0o+32x1o+32x1o+32x0e+32x1e
(0, 8, 2, 5, 1, 6, 7, 3, 4, 9)




In [9]:
# Enumerate the tensor-product paths in the same order as the channel-wise
# descriptor construction this was copied from: product of input-1 x input-2
# irreps, keeping only outputs present in cueq_config[2], with one entry per
# Clebsch-Gordan solution. Each entry is (mul1 * mul2, ir3). Its multiplicity
# is later used as the number of weight columns of that path.
# NOTE(review): this assumes cuet's ChannelWiseTensorProduct enumerates paths in
# exactly this order; verify against the installed cuequivariance version.
# A comprehension keeps the loop variables out of the notebook namespace.
irreps3 = [
    (mul1 * mul2, ir3)
    for (mul1, ir1), (mul2, ir2) in itertools.product(cueq_config[0], cueq_config[1])
    for ir3 in ir1 * ir2
    if ir3 in cueq_config[2]
    for _cg in cue.clebsch_gordan(ir1, ir2, ir3)
]


In [None]:
# Wrap the path list as O3 irreps. The guard keeps this cell idempotent when it
# is re-run after irreps3 has already been converted.
if not isinstance(irreps3, cue.Irreps):
    irreps3 = cue.Irreps("O3", irreps3)

# Named targets instead of `_`: assigning `_` clobbers IPython's last-output variable.
irreps3_sorted, irreps3_perm, inv = irreps3.sort()

# Path k owns `mul_k` consecutive weight columns. Build one slice per path from
# the cumulative offsets without leaking loop variables into the namespace.
_weight_bounds = list(itertools.accumulate((m for m, _ir in irreps3), initial=0))
ceq_weight_slice = [
    slice(lo, hi, None) for lo, hi in zip(_weight_bounds, _weight_bounds[1:])
]

In [None]:
# Uniform random inputs plus one weight row per sample (shared_weights=False).
# Draw order in1 -> in2 -> weight fixes the values under the seed set earlier.
in1 = torch.rand((batch_size, e3nn_config[0].dim)).to(dtype=torch.float32)
in2 = torch.rand((batch_size, e3nn_config[1].dim)).to(dtype=torch.float32)
weight = torch.rand((batch_size, tp.weight_numel)).to(dtype=torch.float32)

In [28]:
# Reorder the per-sample weight columns path by path, following `inv`, to
# produce the weight tensor that is passed to `cuet_tp` below.
# The previous loop assigned to a variable named `slice`, shadowing the builtin
# for the rest of the session. Re-running the weight-slice cell then failed
# with "'slice' object is not callable". The comprehension leaks no names.
weight_sptp = torch.cat(
    [weight[:, ceq_weight_slice[path_idx]] for path_idx in inv],
    dim=1,
)

In [29]:
# Reference output: e3nn TP with per-sample weights in e3nn's own layout.
e3nn_out = tp(
    in1, in2, weight
)

In [30]:
# cuequivariance output, fed the reordered weights.
cuet_out = cuet_tp(
    in1, in2, weight_sptp
)



In [None]:
# Permute cuet's output columns irrep-block by irrep-block, following the sort
# permutation of the output irreps, so they can be compared with `e3nn_out`.
# NOTE(review): blocks are cut with the *unsorted* output slices and emitted in
# sorted order. This is only correct if `cuet_out` is laid out in
# cueq_config[2] order; confirm for configs other than the one checked here.
# Named targets instead of `_` (IPython's last-output variable), and no loop
# variable called `slice` (which shadowed the builtin).
out_irreps_sorted, out_perm, out_inv = cueq_config[2].sort()
out_slices = cueq_config[2].slices()  # hoisted: previously recomputed every iteration
cuet_out_changed = torch.cat(
    [cuet_out[:, out_slices[block_idx]] for block_idx in out_inv],
    dim=1,
)

In [38]:
# zip() silently truncates on a length mismatch, which would hide a missing
# output block, so check the layouts agree first.
assert e3nn_out.shape == cuet_out_changed.shape, (e3nn_out.shape, cuet_out_changed.shape)

# Element-wise view of the first sample: (isclose, e3nn value, cuet value).
for e3nn_val, cuet_val in zip(e3nn_out[0], cuet_out_changed[0]):
    print(torch.isclose(e3nn_val, cuet_val), e3nn_val, cuet_val)

# The loop above inspects sample 0 only; summarise the whole batch with the same tolerance.
n_mismatch = int((~torch.isclose(e3nn_out, cuet_out_changed)).sum())
print(f"mismatching elements over the full batch: {n_mismatch} / {e3nn_out.numel()}")

tensor(True) tensor(0.0516) tensor(0.0516)
tensor(True) tensor(0.0047) tensor(0.0047)
tensor(True) tensor(0.0632) tensor(0.0632)
tensor(True) tensor(0.1245) tensor(0.1245)
tensor(True) tensor(0.0306) tensor(0.0306)
tensor(True) tensor(0.0186) tensor(0.0186)
tensor(True) tensor(0.0373) tensor(0.0373)
tensor(True) tensor(0.0313) tensor(0.0313)
tensor(True) tensor(0.0644) tensor(0.0644)
tensor(True) tensor(0.0205) tensor(0.0205)
tensor(True) tensor(0.0007) tensor(0.0007)
tensor(True) tensor(0.0489) tensor(0.0489)
tensor(True) tensor(0.0057) tensor(0.0057)
tensor(True) tensor(0.0599) tensor(0.0599)
tensor(True) tensor(0.0812) tensor(0.0812)
tensor(True) tensor(0.0650) tensor(0.0650)
tensor(True) tensor(0.0902) tensor(0.0902)
tensor(True) tensor(0.0935) tensor(0.0935)
tensor(True) tensor(0.0055) tensor(0.0055)
tensor(True) tensor(0.0493) tensor(0.0493)
tensor(True) tensor(0.0763) tensor(0.0763)
tensor(True) tensor(0.0440) tensor(0.0440)
tensor(True) tensor(0.0247) tensor(0.0247)
tensor(True