In [36]:
import torch
import e3nn
from torch.utils.cpp_extension import load
import json
import matplotlib.pyplot as plt
import itertools
import os
import cuequivariance as cue
import cuequivariance_torch as cuet
import random

In [37]:
os.environ['TORCH_CUDA_ARCH_LIST'] = "8.0"
sptp_bwd = load(name='sptp_linear_bwd', sources=['/home2/lsy/mdsim/fused_e3nn/fused_e3nn_kernel/sptp_linear_bwd.cpp', 
                                  '/home2/lsy/mdsim/fused_e3nn/fused_e3nn_kernel/bwd_sptp_linear.cu',
                                  ], 
                                  extra_cuda_cflags=["-lineinfo"], verbose=True)


Using /home2/lsy/.cache/torch_extensions/py311_cu124 as PyTorch extensions root...
The input conditions for extension module sptp_linear_bwd have changed. Bumping to version 1 and re-building as sptp_linear_bwd_v1...
Detected CUDA files, patching ldflags
Emitting ninja build file /home2/lsy/.cache/torch_extensions/py311_cu124/sptp_linear_bwd/build.ninja...
Building extension module sptp_linear_bwd_v1...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)


[1/3] c++ -MMD -MF sptp_linear_bwd.o.d -DTORCH_EXTENSION_NAME=sptp_linear_bwd_v1 -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -isystem /home2/lsy/miniconda3/envs/cueq/lib/python3.11/site-packages/torch/include -isystem /home2/lsy/miniconda3/envs/cueq/lib/python3.11/site-packages/torch/include/torch/csrc/api/include -isystem /home2/lsy/miniconda3/envs/cueq/lib/python3.11/site-packages/torch/include/TH -isystem /home2/lsy/miniconda3/envs/cueq/lib/python3.11/site-packages/torch/include/THC -isystem /usr/local/cuda-12.4/include -isystem /home2/lsy/miniconda3/envs/cueq/include/python3.11 -D_GLIBCXX_USE_CXX11_ABI=0 -fPIC -std=c++17 -c /home2/lsy/mdsim/fused_e3nn/fused_e3nn_kernel/sptp_linear_bwd.cpp -o sptp_linear_bwd.o 
[2/3] /usr/local/cuda-12.4/bin/nvcc --generate-dependencies-with-compile --dependency-output bwd_sptp_linear.cuda.o.d -DTORCH_EXTENSION_NAME=sptp_linear_bwd_v1 -DTORCH_API_INCLUDE_EXTE

Loading extension module sptp_linear_bwd_v1...


In [38]:
def mul_Irreps(mul, i_in):
    dd = []
    for ori_mul, ir in i_in:
        dd.append((ori_mul*mul, (ir.l, ir.p)))
    return e3nn.o3.Irreps(dd)
def compare(a, b):
    isclose = torch.isclose(a, b)
    diff_pos = torch.argwhere(isclose == False)
    for pos in diff_pos:
        pos_t = [x for x in pos]
        if(abs(a[pos_t] - b[pos_t]) > 1e-7):
            print(pos)
            print(a[pos_t] - b[pos_t] )
            
IR_IN1_IDX = 0
IR_IN2_IDX = 1
IR_OUT_IDX = 2
INST_IDX = 3

def load_nequip_config(h, l_max, layer_idx):
    filename = f"/home2/lsy/mdsim/nequip/benchmark_config/4_{h}_{l_max}_p_sc.txt"
    with open(filename, "r") as f:
        f_in = f.read().split("\n")

    per_layer_dict = dict()
    for l_idx, d in enumerate(f_in):
        if(d == "") : continue
        dd = json.loads(d)
        per_layer_dict[l_idx] = dd
    tp_list = per_layer_dict[layer_idx]["tp"]
    i_in1 = e3nn.o3.Irreps(tp_list[IR_IN1_IDX])
    i_in2 = e3nn.o3.Irreps(tp_list[IR_IN2_IDX])
    i_out = e3nn.o3.Irreps(tp_list[IR_OUT_IDX])
    inst_tuple = [tuple(x) for x in tp_list[INST_IDX]]

    return i_in1, i_in2, i_out, inst_tuple


In [39]:
def to_cuda_list(*arg, input_dtype = torch.float32):
    return_list = []
    for item in arg:
        if(type(item) == torch.Tensor):
            return_list.append(item.to(device="cuda"))
        else:
            return_list.append(torch.tensor(item,device="cuda", dtype=input_dtype))
    return return_list

def to_cuda_dict(strname_list, *arg):
    return_dict = {}
    for item,name in zip(arg,strname_list):
        if(type(item) == torch.Tensor):
            return_dict[name] = item.to("cuda")
        else:
            return_dict[name] = torch.tensor(item,device="cuda")
    return return_dict

def cumsum_list(s, np1 = True):
    new_s = []
    current = 0
    for e in s:
        new_s.append(current)
        current += e
    if(np1):
        new_s.append(current)
    return new_s

In [40]:
torch.manual_seed(0)

h = 32
l_max = 1
layer_idx = 3
batch_size = 4096
i_in1, i_in2, i_out, inst_tuple = load_nequip_config(h,l_max,layer_idx)
# i_in2 = mul_Irreps(3, i_in2)

uvuv_tp = e3nn.o3.FullTensorProduct(i_in1,i_in2, filter_ir_out=i_out, path_normalization="none", normalization="none")
uvuv_i_out = uvuv_tp.irreps_out

split_size = []
reshape_size = []
for inst in uvuv_tp.instructions:
    split_size.append(uvuv_i_out[inst.i_out].dim)
    reshape_size.append([inst.path_shape[0],inst.path_shape[1],uvuv_i_out[inst.i_out][1].dim])
weight_mul = e3nn.o3.experimental.FullTensorProduct_uvu_weight_only(i_in1, i_in2, split_size, reshape_size, filter_ir_out=i_out, irrep_normalization=None, regroup_output=False)
# uvw
# i_out = e3nn.o3.Irreps(tp_list[IR_OUT_IDX])
# tp = e3nn.o3.FullyConnectedTensorProduct(i_in1,i_in2,i_out,shared_weights=False, internal_weights=False)

# # uvu
tp = e3nn.o3.TensorProduct(i_in1,i_in2,i_out,inst_tuple,shared_weights=False, internal_weights=False, path_normalization="none", normalization="none") # 



In [41]:
grad_uvu = e3nn.o3.experimental.FullTensorProduct_grad_uvu(i_in1,i_in2, filter_ir_out=i_out, irrep_normalization=None, regroup_output=False)

In [42]:
# full tp -> linear
# i_out = full_tp.irreps_out
unique_cg = []
unique_cg_mat = {}
nnz_cg_cnt = 0
all_cg_cnt = 0
cg_dummy = torch.zeros(i_in1.dim, i_in2.dim, uvuv_i_out.dim)
cg_dummy_coverage = torch.zeros(i_in1.dim, i_in2.dim, uvuv_i_out.dim)
for inst in uvuv_tp.instructions:
    i = inst.i_in1
    j = inst.i_in2
    k = inst.i_out

    mul_in1, ir_in1 = i_in1[i]
    mul_in2, ir_in2 = i_in2[j]
    mul_out, ir_out = uvuv_i_out[k]

    cg = e3nn.o3.wigner_3j(ir_in1.l, ir_in2.l, ir_out.l)
    unique_cg += list(cg.unique())
    all_cg_cnt+= cg.numel()
    nnz_cg_cnt += cg.count_nonzero()

    partial_mat_cg = torch.zeros(i_in1[i].dim, i_in2[j].dim, uvuv_i_out[k].dim)
    # print(cg)
    unique_cg_mat[f"{ir_in1.l}_{ir_in2.l}_{ir_out.l}"] = cg
    
    ## uvuv
    for u,v in itertools.product(range(mul_in1), range(mul_in2)):
        partial_mat_cg [u*ir_in1.dim:(u+1)*ir_in1.dim,
        v*ir_in2.dim:(v+1)*ir_in2.dim,
        (u*mul_in2+v)*ir_out.dim:(u*mul_in2+v+1)*ir_out.dim] = cg 

    cg_dummy[i_in1.slices()[i], i_in2.slices()[j], uvuv_i_out.slices()[k]] = partial_mat_cg
    cg_dummy_coverage[i_in1.slices()[i], i_in2.slices()[j], uvuv_i_out.slices()[k]] = 1
print("compute density", nnz_cg_cnt / all_cg_cnt)

compute density tensor(0.2909)


In [43]:
unique_cg_val = list(set([x.item() for x in unique_cg]))

In [44]:
per_path_fiber_start_orignal = [0]
fiber_array_orignal = []

for inst in tp.instructions:
    path_cg = unique_cg_mat[f"{i_in1[inst.i_in1][1].l}_{i_in2[inst.i_in2][1].l}_{i_out[inst.i_out][1].l}"]
    for i,j,k in path_cg.nonzero():
        cg_idx = unique_cg_val.index(path_cg[i,j,k])
        fiber_array_orignal.append([i.item(),j.item(),k.item(),cg_idx])
    per_path_fiber_start_orignal.append(len(fiber_array_orignal))

In [45]:
per_in1_path = {}
for inst in tp.instructions:
    if(inst.i_in1 not in per_in1_path):
        per_in1_path[inst.i_in1] = []
    per_in1_path[inst.i_in1].append([inst.i_out, inst.i_in2, inst.path_weight])

In [108]:
tp_inst_outorder = sorted(tp.instructions, key=lambda x : x.i_out)

(2, 5, 0, 8, 1, 6, 7, 3, 4, 9)

In [77]:
weight_mul.weight_uv_pair_sorted_slice

[Chunk(mul=32, dim=1, slice=slice(0, 32, None)),
 Chunk(mul=32, dim=1, slice=slice(32, 64, None)),
 Chunk(mul=32, dim=1, slice=slice(64, 96, None)),
 Chunk(mul=32, dim=1, slice=slice(96, 128, None)),
 Chunk(mul=32, dim=1, slice=slice(128, 160, None)),
 Chunk(mul=32, dim=1, slice=slice(160, 192, None)),
 Chunk(mul=32, dim=1, slice=slice(192, 224, None)),
 Chunk(mul=32, dim=1, slice=slice(224, 256, None)),
 Chunk(mul=32, dim=1, slice=slice(256, 288, None)),
 Chunk(mul=32, dim=1, slice=slice(288, 320, None))]

In [None]:
WARPSIZE = 32
in1_idxing = [0]
in1_ival = []
in1_related_path_idx = [0]

path_array1 = []
path_array2 = []
path_weight = []
per_path_weight_pos = []

per_path_fiber_start = []
fiber_array = []

in1_slices = i_in1.slices()
in2_slices = i_in2.slices()
out_slices = i_out.slices()

real_path_cnt = 0

for idx, (mul,ir) in enumerate(i_in1):
    assert (mul%WARPSIZE ==0)
    in1_idx_start = in1_slices[idx].start
    in1_idx_end = in1_slices[idx].stop
    i_val = ir.dim
    
    if mul >= WARPSIZE:
        for i in range(mul//WARPSIZE):
            in1_idxing.append(in1_idx_start + WARPSIZE*i_val*(i+1))
            in1_ival.append(i_val)
            in1_related_path_idx.append(in1_related_path_idx[-1] + len(per_in1_path[idx]))
            
            dummy_list = []
            dummy_list2 = []
            # Bug? TODO:
            for k_idx, j_idx, pw in per_in1_path[idx]:
                print(idx, j_idx, k_idx)
                # should be in order
                fiber_start = per_path_fiber_start_orignal[real_path_cnt]
                fiber_end = per_path_fiber_start_orignal[real_path_cnt+1]
                
                new_fiber_start = len(fiber_array)
                new_fiber_end = new_fiber_start + fiber_end - fiber_start

                per_path_fiber_start.append([new_fiber_start, new_fiber_end])
                # print(fiber_array_orignal[1:4])
                # print(fiber_start,fiber_end)
                # print(fiber_array_orignal[fiber_start:fiber_end])
                
                fiber_array+= fiber_array_orignal[fiber_start:fiber_end]

                dummy_list.append([out_slices[k_idx].start + WARPSIZE*i_out[k_idx].ir.dim * i,
                                   out_slices[k_idx].start + WARPSIZE*i_out[k_idx].ir.dim * (i+1)
                                   ])
                dummy_list2.append([i_out[k_idx].ir.dim,
                                    in2_slices[j_idx].start,
                                    i_in2[j_idx].ir.dim,
                                    in2_slices[j_idx].stop])
                path_weight.append(pw)
                
                # TODO:??
                per_path_weight_pos.append(weight_mul.weight_uv_pair_sorted_slice[real_path_cnt].slice.start + WARPSIZE*i)
                real_path_cnt+=1


            path_array1.append(dummy_list)
            path_array2.append(dummy_list2)

0 0 2
0 1 4
1 0 0
1 1 7
2 0 8
2 1 1
2 1 5
3 0 6
3 1 3
3 1 9


In [None]:
# copy for case i screw up

# WARPSIZE = 32
# in1_idxing = [0]
# in1_ival = []
# in1_related_path_idx = [0]

# path_array1 = []
# path_array2 = []
# path_weight = []
# per_path_weight_pos = []

# per_path_fiber_start = []
# fiber_array = []

# in1_slices = i_in1.slices()
# in2_slices = i_in2.slices()
# out_slices = i_out.slices()

# real_path_cnt = 0

# for idx, (mul,ir) in enumerate(i_in1):
#     assert (mul%WARPSIZE ==0)
#     in1_idx_start = in1_slices[idx].start
#     in1_idx_end = in1_slices[idx].stop
#     i_val = ir.dim
    
#     if mul >= WARPSIZE:
#         for i in range(mul//WARPSIZE):
#             in1_idxing.append(in1_idx_start + WARPSIZE*i_val*(i+1))
#             in1_ival.append(i_val)
#             in1_related_path_idx.append(in1_related_path_idx[-1] + len(per_in1_path[idx]))
            
#             dummy_list = []
#             dummy_list2 = []
#             # Bug? TODO:
#             for k_idx, j_idx, pw in per_in1_path[idx]:
#                 print(idx, j_idx, k_idx)
#                 # should be in order
#                 fiber_start = per_path_fiber_start_orignal[real_path_cnt]
#                 fiber_end = per_path_fiber_start_orignal[real_path_cnt+1]
                
#                 new_fiber_start = len(fiber_array)
#                 new_fiber_end = new_fiber_start + fiber_end - fiber_start

#                 per_path_fiber_start.append([new_fiber_start, new_fiber_end])
#                 # print(fiber_array_orignal[1:4])
#                 # print(fiber_start,fiber_end)
#                 # print(fiber_array_orignal[fiber_start:fiber_end])
                
#                 fiber_array+= fiber_array_orignal[fiber_start:fiber_end]

#                 dummy_list.append([out_slices[k_idx].start + WARPSIZE*i_out[k_idx].ir.dim * i,
#                                    out_slices[k_idx].start + WARPSIZE*i_out[k_idx].ir.dim * (i+1)
#                                    ])
#                 dummy_list2.append([i_out[k_idx].ir.dim,
#                                     in2_slices[j_idx].start,
#                                     i_in2[j_idx].ir.dim,
#                                     in2_slices[j_idx].stop])
#                 path_weight.append(pw)
                
#                 # TODO:??
#                 per_path_weight_pos.append(weight_mul.weight_uv_pair_sorted_slice[real_path_cnt].slice.start + WARPSIZE*i)
#                 real_path_cnt+=1


#             path_array1.append(dummy_list)
#             path_array2.append(dummy_list2)

In [95]:
per_path_weight_pos

[0, 32, 64, 96, 128, 160, 192, 224, 256, 288]

In [96]:
t_in1_idxing = torch.tensor(in1_idxing, dtype=torch.int32, device="cuda")
t_in1_ival = torch.tensor(in1_ival, dtype=torch.int32, device="cuda")
t_in1_related_path_idx = torch.tensor(in1_related_path_idx, dtype=torch.int32, device="cuda")

t_path_array1 = torch.tensor(list(itertools.chain.from_iterable(path_array1)), dtype=torch.uint16, device="cuda")
t_path_array2 = torch.tensor(list(itertools.chain.from_iterable(path_array2)), dtype=torch.uint8, device="cuda")
t_path_weight = torch.tensor(path_weight, dtype=torch.float32, device="cuda")
t_per_path_weight_pos = torch.tensor(per_path_weight_pos, dtype=torch.int32, device="cuda")

t_per_path_fiber_start = torch.tensor(per_path_fiber_start, dtype=torch.uint16, device="cuda")
t_fiber_array = torch.tensor(fiber_array, dtype=torch.uint8, device="cuda")

t_unique_cg_val = torch.tensor(unique_cg_val, dtype=torch.float32, device="cuda")


In [97]:
in1 = torch.rand(batch_size, i_in1.dim, device="cuda", requires_grad=True)
in2 = torch.rand(batch_size, i_in2.dim, device="cuda", requires_grad=True)
# weight = torch.rand(batch_size,tp.weight_numel, device="cuda", requires_grad=True)
weight = torch.ones(batch_size,tp.weight_numel, device="cuda", requires_grad=True)
grad_uvu = grad_uvu.cuda()
tp = tp.cuda()
out = tp(in1,in2,weight)
# out = grad_uvu(in1,in2,weight)
out.retain_grad()
y = out.sum()
y.retain_grad()
y.backward()

In [98]:
out2 = grad_uvu(in1,in2,weight)

In [99]:
# sum = 0
# for x in grad_uvu.w_result_list:
#     sum += x.reshape(batch_size,-1).shape[1]
# print(sum)

In [100]:
path_cnt = len(in1_idxing)-1
mem_dl_din1 = torch.zeros_like(in1)
mem_dl_din2 = torch.zeros((batch_size, i_in2.dim * path_cnt), device="cuda")
mem_dl_dw = torch.zeros_like(weight)
# mem_dl_do = torch.cat([x.grad.reshape(batch_size,-1) for x in grad_uvu.w_result_list],dim=1)
mem_dl_do = torch.ones((batch_size, i_out.dim), device="cuda")
mem_debug = torch.ones((batch_size, i_out.dim), device="cuda") * -1

In [101]:
sptp_bwd.sptp_linear_bwd_v1(in1,in2,weight, mem_dl_do, mem_dl_din1, mem_dl_din2, mem_dl_dw, mem_debug,
                            t_in1_idxing, t_in1_ival , t_in1_related_path_idx, t_path_array1,t_path_array2,t_per_path_fiber_start, t_path_weight, t_per_path_weight_pos, t_fiber_array,t_unique_cg_val, len(t_path_array1), 1
                            )

In [107]:
in1.grad[0]

tensor([1.4903, 1.4903, 1.4903, 1.4903, 1.4903, 1.4903, 1.4903, 1.4903, 1.4903,
        1.4903, 1.4903, 1.4903, 1.4903, 1.4903, 1.4903, 1.4903, 1.4903, 1.4903,
        1.4903, 1.4903, 1.4903, 1.4903, 1.4903, 1.4903, 1.4903, 1.4903, 1.4903,
        1.4903, 1.4903, 1.4903, 1.4903, 1.4903, 1.4903, 1.4903, 1.4903, 1.4903,
        1.4903, 1.4903, 1.4903, 1.4903, 1.4903, 1.4903, 1.4903, 1.4903, 1.4903,
        1.4903, 1.4903, 1.4903, 1.4903, 1.4903, 1.4903, 1.4903, 1.4903, 1.4903,
        1.4903, 1.4903, 1.4903, 1.4903, 1.4903, 1.4903, 1.4903, 1.4903, 1.4903,
        1.4903, 0.6396, 0.5487, 0.4844, 0.6396, 0.5487, 0.4844, 0.6396, 0.5487,
        0.4844, 0.6396, 0.5487, 0.4844, 0.6396, 0.5487, 0.4844, 0.6396, 0.5487,
        0.4844, 0.6396, 0.5487, 0.4844, 0.6396, 0.5487, 0.4844, 0.6396, 0.5487,
        0.4844, 0.6396, 0.5487, 0.4844, 0.6396, 0.5487, 0.4844, 0.6396, 0.5487,
        0.4844, 0.6396, 0.5487, 0.4844, 0.6396, 0.5487, 0.4844, 0.6396, 0.5487,
        0.4844, 0.6396, 0.5487, 0.4844, 

In [104]:
mem_dl_din1[0]

tensor([1.4903, 1.4903, 1.4903, 1.4903, 1.4903, 1.4903, 1.4903, 1.4903, 1.4903,
        1.4903, 1.4903, 1.4903, 1.4903, 1.4903, 1.4903, 1.4903, 1.4903, 1.4903,
        1.4903, 1.4903, 1.4903, 1.4903, 1.4903, 1.4903, 1.4903, 1.4903, 1.4903,
        1.4903, 1.4903, 1.4903, 1.4903, 1.4903, 1.4903, 1.4903, 1.4903, 1.4903,
        1.4903, 1.4903, 1.4903, 1.4903, 1.4903, 1.4903, 1.4903, 1.4903, 1.4903,
        1.4903, 1.4903, 1.4903, 1.4903, 1.4903, 1.4903, 1.4903, 1.4903, 1.4903,
        1.4903, 1.4903, 1.4903, 1.4903, 1.4903, 1.4903, 1.4903, 1.4903, 1.4903,
        1.4903, 0.6396, 0.5487, 0.4844, 0.6396, 0.5487, 0.4844, 0.6396, 0.5487,
        0.4844, 0.6396, 0.5487, 0.4844, 0.6396, 0.5487, 0.4844, 0.6396, 0.5487,
        0.4844, 0.6396, 0.5487, 0.4844, 0.6396, 0.5487, 0.4844, 0.6396, 0.5487,
        0.4844, 0.6396, 0.5487, 0.4844, 0.6396, 0.5487, 0.4844, 0.6396, 0.5487,
        0.4844, 0.6396, 0.5487, 0.4844, 0.6396, 0.5487, 0.4844, 0.6396, 0.5487,
        0.4844, 0.6396, 0.5487, 0.4844, 

In [None]:
for idx, (i, j) in enumerate(zip(weight.grad[0], mem_dl_dw[0])):
    print(idx, torch.isclose(i, j).item(), i.item(), j.item())

0 True 0.10493971407413483 0.10493971407413483
1 True 0.47238099575042725 0.47238099575042725
2 True 0.4347875416278839 0.4347875416278839
3 True 0.12449125200510025 0.12449125200510025
4 True 0.42273008823394775 0.42273008823394775
5 True 0.39295703172683716 0.39295703172683716
6 True 0.17435835301876068 0.17435835301876068
7 True 0.4256122410297394 0.4256122410297394
8 True 0.04595226049423218 0.04595226049423218
9 True 0.43674543499946594 0.43674543499946594
10 True 0.4691219627857208 0.4691219627857208
11 True 0.2431338131427765 0.2431338131427765
12 True 0.3432729244232178 0.3432729244232178
13 True 0.3002142012119293 0.3002142012119293
14 True 0.4162345230579376 0.4162345230579376
15 True 0.09332073479890823 0.09332073479890823
16 True 0.42511647939682007 0.42511647939682007
17 True 0.32560786604881287 0.32560786604881287
18 True 0.039402563124895096 0.039402563124895096
19 True 0.12959089875221252 0.12959089875221252
20 True 0.08066482096910477 0.08066482096910477
21 True 0.0166

In [61]:
t_in1_related_path_idx

tensor([ 0,  2,  4,  7, 10], device='cuda:0', dtype=torch.int32)

In [62]:
path_idx = 6
t_path_array1[path_idx]

tensor([224, 320], device='cuda:0', dtype=torch.uint16)

In [63]:
t_per_path_fiber_start[path_idx]

tensor([14, 20], device='cuda:0', dtype=torch.uint16)

In [64]:
t_fiber_array[14:20]

tensor([[0, 1, 2, 2],
        [0, 2, 1, 4],
        [1, 0, 2, 4],
        [1, 2, 0, 2],
        [2, 0, 1, 2],
        [2, 1, 0, 4]], device='cuda:0', dtype=torch.uint8)

In [65]:
t_unique_cg_val

tensor([ 0.0000,  1.0000,  0.4082,  0.5774, -0.4082], device='cuda:0')

In [66]:
inst = tp.instructions[5]
print(inst)

Instruction(i_in1=2, i_in2=1, i_out=1, connection_mode='uvu', has_weight=True, path_weight=1.0, path_shape=(32, 1))


In [67]:
unique_cg_mat[f"{i_in1[inst.i_in1][1].l}_{i_in2[inst.i_in2][1].l}_{i_out[inst.i_out][1].l}"]

tensor([[[0.5774],
         [0.0000],
         [0.0000]],

        [[0.0000],
         [0.5774],
         [0.0000]],

        [[0.0000],
         [0.0000],
         [0.5774]]])

In [71]:
for idx, (i, j) in enumerate(zip(out[3], mem_debug[3])):
    print(idx, torch.isclose(i, j).item(), i.item(), j.item())

0 True 0.0675240084528923 0.0675240084528923
1 True 0.03700258955359459 0.03700258955359459
2 True 0.07221638411283493 0.07221638411283493
3 True 0.12463438510894775 0.12463438510894775
4 True 0.017856916412711143 0.017856916412711143
5 True 0.018384918570518494 0.018384918570518494
6 True 0.0911894366145134 0.0911894366145134
7 True 0.11837322264909744 0.11837322264909744
8 True 0.043513428419828415 0.043513428419828415
9 True 0.06094146519899368 0.06094146519899368
10 True 0.11206122487783432 0.11206122487783432
11 True 0.05237733572721481 0.05237733572721481
12 True 0.08933717757463455 0.08933717757463455
13 True 0.01772143319249153 0.01772143319249153
14 True 0.07016905397176743 0.07016905397176743
15 True 0.008081362582743168 0.008081362582743168
16 True 0.036737993359565735 0.036737993359565735
17 True 0.1359482854604721 0.1359482854604721
18 True 0.024527616798877716 0.024527616798877716
19 True 0.12778133153915405 0.12778133153915405
20 True 0.12556631863117218 0.12556631863117

In [None]:
correct_dummy = torch.zeros(i_out.dim)

In [None]:
for idx, i in enumerate(t_path_array1):
    dd = i.to(int).tolist()
    j_start = t_path_array2[idx][1]
    j_val = t_path_array2[idx][2]
    print(j_val.item())
    # correct_dummy[dd[0]:dd[1]] = t_fiber_array[t_per_path_fiber_start[idx][0]][1]
    fiber_y = t_fiber_array[t_per_path_fiber_start[idx][0]][1].item()
    correct_dummy[dd[0]:dd[1]] = in2[0][j_start.item()+fiber_y]
print(correct_dummy)

1
3
1
3
1
3
3
1
3
3
tensor([0.0194, 0.0194, 0.0194, 0.0194, 0.0194, 0.0194, 0.0194, 0.0194, 0.0194,
        0.0194, 0.0194, 0.0194, 0.0194, 0.0194, 0.0194, 0.0194, 0.0194, 0.0194,
        0.0194, 0.0194, 0.0194, 0.0194, 0.0194, 0.0194, 0.0194, 0.0194, 0.0194,
        0.0194, 0.0194, 0.0194, 0.0194, 0.0194, 0.7733, 0.7733, 0.7733, 0.7733,
        0.7733, 0.7733, 0.7733, 0.7733, 0.7733, 0.7733, 0.7733, 0.7733, 0.7733,
        0.7733, 0.7733, 0.7733, 0.7733, 0.7733, 0.7733, 0.7733, 0.7733, 0.7733,
        0.7733, 0.7733, 0.7733, 0.7733, 0.7733, 0.7733, 0.7733, 0.7733, 0.7733,
        0.7733, 0.0194, 0.0194, 0.0194, 0.0194, 0.0194, 0.0194, 0.0194, 0.0194,
        0.0194, 0.0194, 0.0194, 0.0194, 0.0194, 0.0194, 0.0194, 0.0194, 0.0194,
        0.0194, 0.0194, 0.0194, 0.0194, 0.0194, 0.0194, 0.0194, 0.0194, 0.0194,
        0.0194, 0.0194, 0.0194, 0.0194, 0.0194, 0.0194, 0.7733, 0.7733, 0.7733,
        0.7733, 0.7733, 0.7733, 0.7733, 0.7733, 0.7733, 0.7733, 0.7733, 0.7733,
        0.7733, 0.77

In [None]:
for idx, x in enumerate(mem_debug[0]):
    print(idx, x.item())

0 18.0
1 18.0
2 18.0
3 18.0
4 18.0
5 18.0
6 18.0
7 18.0
8 18.0
9 18.0
10 18.0
11 18.0
12 18.0
13 18.0
14 18.0
15 18.0
16 18.0
17 18.0
18 18.0
19 18.0
20 18.0
21 18.0
22 18.0
23 18.0
24 18.0
25 18.0
26 18.0
27 18.0
28 18.0
29 18.0
30 18.0
31 18.0
32 27.0
33 27.0
34 27.0
35 27.0
36 27.0
37 27.0
38 27.0
39 27.0
40 27.0
41 27.0
42 27.0
43 27.0
44 27.0
45 27.0
46 27.0
47 27.0
48 27.0
49 27.0
50 27.0
51 27.0
52 27.0
53 27.0
54 27.0
55 27.0
56 27.0
57 27.0
58 27.0
59 27.0
60 27.0
61 27.0
62 27.0
63 27.0
64 47.0
65 47.0
66 47.0
67 47.0
68 47.0
69 47.0
70 47.0
71 47.0
72 47.0
73 47.0
74 47.0
75 47.0
76 47.0
77 47.0
78 47.0
79 47.0
80 47.0
81 47.0
82 47.0
83 47.0
84 47.0
85 47.0
86 47.0
87 47.0
88 47.0
89 47.0
90 47.0
91 47.0
92 47.0
93 47.0
94 47.0
95 47.0
96 94.0
97 94.0
98 94.0
99 94.0
100 94.0
101 94.0
102 94.0
103 94.0
104 94.0
105 94.0
106 94.0
107 94.0
108 94.0
109 94.0
110 94.0
111 94.0
112 94.0
113 94.0
114 94.0
115 94.0
116 94.0
117 94.0
118 94.0
119 94.0
120 94.0
121 94.0
122 94.0
123

In [None]:
in1_idx = 3
path_start = in1_related_path_idx[in1_idx]
path_end = in1_related_path_idx[in1_idx+1]
for path_idx in range(path_start, path_end):
    for fiber_idx in range(per_path_fiber_start[path_idx][0], per_path_fiber_start[path_idx][1]):
        print(fiber_idx)

27
28
29
30
31
32
33
34
35


In [None]:
path_array1

[[[192, 224], [384, 480], [4608, 4768]],
 [[224, 256], [480, 576], [4768, 4928]],
 [[0, 32], [1536, 1632], [2688, 2848]],
 [[32, 64], [1632, 1728], [2848, 3008]],
 [[1728, 1824],
  [64, 96],
  [576, 672],
  [3008, 3168],
  [1920, 2016],
  [4928, 5088]],
 [[1824, 1920],
  [96, 128],
  [672, 768],
  [3168, 3328],
  [2016, 2112],
  [5088, 5248]],
 [[5248, 5408],
  [768, 864],
  [3328, 3488],
  [256, 288],
  [2112, 2208],
  [5568, 5728]],
 [[5408, 5568],
  [864, 960],
  [3488, 3648],
  [288, 320],
  [2208, 2304],
  [5728, 5888]],
 [[960, 1056],
  [320, 352],
  [2304, 2400],
  [5888, 6048],
  [1152, 1248],
  [3648, 3808]],
 [[1056, 1152],
  [352, 384],
  [2400, 2496],
  [6048, 6208],
  [1248, 1344],
  [3808, 3968]],
 [[3968, 4128],
  [2496, 2592],
  [6208, 6368],
  [128, 160],
  [1344, 1440],
  [4288, 4448]],
 [[4128, 4288],
  [2592, 2688],
  [6368, 6528],
  [160, 192],
  [1440, 1536],
  [4448, 4608]]]

In [None]:
weight.grad

tensor([[ 0.2848,  0.3687,  0.0178,  ...,  0.1173,  0.1546, -0.2277],
        [ 0.1044,  0.1084,  0.1919,  ...,  0.5142,  0.1884,  0.3163],
        [ 0.1645,  0.0727,  0.0392,  ...,  0.4114,  0.0661,  0.2192],
        ...,
        [ 0.2368,  0.1596,  0.1389,  ...,  0.2766,  0.3911,  0.3093],
        [ 0.1196,  0.0221,  0.4875,  ..., -0.1189,  0.3322, -0.1699],
        [ 0.7679,  0.5451,  0.1425,  ...,  0.1081,  0.1311,  0.5657]],
       device='cuda:0')

In [None]:
mem_dl_dw

tensor([[0.5778, 0.4255, 0.3935,  ..., 0.0000, 0.0000, 0.0000],
        [0.1009, 0.0397, 0.1317,  ..., 0.0000, 0.0000, 0.0000],
        [0.1894, 0.1852, 0.0642,  ..., 0.0000, 0.0000, 0.0000],
        ...,
        [0.0094, 0.0797, 0.5361,  ..., 0.0000, 0.0000, 0.0000],
        [0.4011, 0.5824, 0.6059,  ..., 0.0000, 0.0000, 0.0000],
        [0.3460, 0.5593, 0.1543,  ..., 0.0000, 0.0000, 0.0000]],
       device='cuda:0')

In [None]:
print(mem_debug)

tensor([[0.5778, 0.4255, 0.3935,  ..., 0.0000, 0.0000, 0.0000],
        [0.1009, 0.0397, 0.1317,  ..., 0.0000, 0.0000, 0.0000],
        [0.1894, 0.1852, 0.0642,  ..., 0.0000, 0.0000, 0.0000],
        ...,
        [0.0094, 0.0797, 0.5361,  ..., 0.0000, 0.0000, 0.0000],
        [0.4011, 0.5824, 0.6059,  ..., 0.0000, 0.0000, 0.0000],
        [0.3460, 0.5593, 0.1543,  ..., 0.0000, 0.0000, 0.0000]],
       device='cuda:0')


In [None]:
path_idx = 4
dL_dO = torch.cat([x.reshape(batch_size,-1) for x in grad_uvu.w_result_list],dim=1)
W = grad_uvu.used_w_list[path_idx]


# dL_dO = grad_uvu.w_result_list[path_idx].grad
W = grad_uvu.used_w_list[path_idx]
O = grad_uvu.w_result_list[path_idx]
uvuv = grad_uvu.cg_result_list[path_idx]
uvuv_gen = (O / W).reshape(uvuv.shape)
cg = grad_uvu.used_cg_list[path_idx]

dL_duvuv = torch.einsum("zuk,zuv -> zuvk", dL_dO, W)
dL_dW = torch.einsum("zuk,zuvk -> zuv", dL_dO, uvuv_gen)
# only place with sparsity
dL_dOuter = torch.einsum("zuvk, ijk -> zuvij", dL_duvuv, cg)
# dL_dA = torch.einsum("zuvij,zvj -> zui", dL_dOuter, in2)
# dL_dB = torch.einsum("zuvij,zui -> zvj", dL_dOuter, in1)

In [None]:
weight.grad

tensor([[ 2.7819e-01,  2.9000e-01,  4.4694e-01,  ..., -1.1819e-01,
          2.1405e-01,  3.0741e-01],
        [ 1.1526e-02,  3.8194e-04,  2.1035e-03,  ...,  3.2508e-02,
          2.3617e-01,  2.9659e-01],
        [ 2.1938e-01,  4.6870e-01,  5.5882e-01,  ..., -1.1362e-01,
         -3.5503e-03,  2.2558e-01],
        ...,
        [ 1.9734e-01,  1.2222e-01,  1.0069e-02,  ...,  1.2758e-01,
          3.7898e-01,  9.5604e-02],
        [ 7.9995e-01,  3.2470e-01,  6.9050e-01,  ...,  1.2743e-01,
          6.8164e-02,  2.8999e-01],
        [ 2.4270e-02,  4.9620e-01,  4.9715e-01,  ..., -4.0044e-02,
          1.7810e-01, -1.1853e-01]], device='cuda:0')

In [None]:
for idx, i in enumerate(mem_dl_dw[0]):
    print(idx, i.item())

0 0.5777977108955383
1 0.42547371983528137
2 0.3934994339942932
3 0.29751643538475037
4 0.21127715706825256
5 0.4872279763221741
6 0.33741071820259094
7 0.30186495184898376
8 0.5012400150299072
9 0.5035048127174377
10 0.42413073778152466
11 0.12352830916643143
12 0.44583016633987427
13 0.44152140617370605
14 0.5126194357872009
15 0.6264858245849609
16 0.6752227544784546
17 0.3462996184825897
18 0.00413484638556838
19 0.6628735065460205
20 0.5218164324760437
21 0.21840308606624603
22 0.6934933066368103
23 0.5333660244941711
24 0.18424765765666962
25 0.2628057301044464
26 0.4452141225337982
27 0.06925690919160843
28 0.5831878781318665
29 0.4710877537727356
30 0.25288960337638855
31 0.4058476388454437
32 0.638543426990509
33 0.1728799194097519
34 0.4561007022857666
35 0.5625309348106384
36 0.47250548005104065
37 0.6292127370834351
38 0.5340367555618286
39 0.2722375988960266
40 0.02227296493947506
41 0.39838480949401855
42 0.5334016680717468
43 0.4973806142807007
44 0.5723694562911987
45 0

In [None]:
path_cnt

12

In [None]:
mem_dl_din1

tensor([[0.6131, 0.7756, 0.4873,  ..., 0.0000, 0.0000, 0.0000],
        [0.3711, 0.7441, 0.7201,  ..., 0.0000, 0.0000, 0.0000],
        [0.0597, 0.1072, 0.1581,  ..., 0.0000, 0.0000, 0.0000],
        ...,
        [0.4090, 0.4015, 0.7214,  ..., 0.0000, 0.0000, 0.0000],
        [0.9178, 0.3077, 0.6242,  ..., 0.0000, 0.0000, 0.0000],
        [0.6195, 0.9485, 1.1998,  ..., 0.0000, 0.0000, 0.0000]],
       device='cuda:0')

In [None]:
mem_dl_din2

tensor([[ 7.9286,  4.1022,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 6.6009,  3.9261,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 6.1457,  4.3972,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [ 6.8278, 11.8569,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 8.1584, 14.4260,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 7.1088, 13.0698,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')

In [None]:
in1.grad

tensor([[ 1.7356,  0.8162,  1.1408,  ...,  0.1279,  0.2492,  0.2180],
        [ 0.6783,  0.8351,  0.5105,  ..., -0.0989,  0.2028,  0.2027],
        [ 1.0968,  0.3952,  0.3919,  ...,  0.2731,  0.6295,  0.0823],
        ...,
        [ 0.2124,  1.5009,  1.2702,  ...,  0.1347,  0.5793,  0.0400],
        [ 0.8752,  0.7649,  1.6928,  ...,  0.1046,  0.5329,  0.1785],
        [ 2.1436,  1.8458,  1.6488,  ...,  0.4176,  0.8223,  0.0997]],
       device='cuda:0')

In [None]:
for idx, x in enumerate(mem_dl_din1[0]):
    print(idx, x.item())

0 0.00761428801342845
1 0.008617337793111801
2 0.04493430629372597
3 0.0024862992577254772
4 0.008869743905961514
5 0.046693701297044754
6 0.01662699319422245
7 0.012019012123346329
8 0.011984731070697308
9 -0.00966546218842268
10 0.016029082238674164
11 0.0029243382159620523
12 0.0036543344613164663
13 0.0028815194964408875
14 0.008998558856546879
15 0.003233320079743862
16 0.019768288359045982
17 0.016636699438095093
18 0.017954720184206963
19 0.007235220167785883
20 0.01593690924346447
21 0.0008762934594415128
22 0.005930520128458738
23 0.0035504354164004326
24 -0.0022801035083830357
25 0.0056976983323693275
26 0.059797707945108414
27 0.018554290756583214
28 0.001596739748492837
29 0.01827997900545597
30 0.012856731191277504
31 0.012358070351183414
32 0.04330095276236534
33 0.044015418738126755
34 0.022041011601686478
35 0.01623249426484108
36 0.002287591341882944
37 0.017636748030781746
38 0.0042628697119653225
39 0.003500852268189192
40 0.0031508684623986483
41 0.00348128285259008

In [None]:
mem_dl_dw

tensor([[ 6.0158e-02,  9.9694e-03,  9.1432e-04,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [ 5.1375e-04,  1.0229e-04,  5.1135e-05,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [ 7.6702e-02,  5.0314e-03,  3.5407e-03,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        ...,
        [ 2.6665e-02,  7.4945e-01,  6.2319e-01,  ...,  2.1814e-01,
          3.1595e-01,  3.3164e-01],
        [ 2.2057e-02,  6.9063e-03,  4.8448e-02,  ..., -2.6143e-01,
          1.6202e-02, -9.9778e-01],
        [ 6.9430e-01,  2.3100e-01,  3.1813e-01,  ...,  1.0128e-01,
         -3.6585e-01,  5.1858e-01]], device='cuda:0')

In [None]:
weight.grad

tensor([[ 2.7819e-01,  2.9000e-01,  4.4694e-01,  ..., -1.1819e-01,
          2.1405e-01,  3.0741e-01],
        [ 1.1526e-02,  3.8194e-04,  2.1035e-03,  ...,  3.2508e-02,
          2.3617e-01,  2.9659e-01],
        [ 2.1938e-01,  4.6870e-01,  5.5882e-01,  ..., -1.1362e-01,
         -3.5503e-03,  2.2558e-01],
        ...,
        [ 1.9734e-01,  1.2222e-01,  1.0069e-02,  ...,  1.2758e-01,
          3.7898e-01,  9.5604e-02],
        [ 7.9995e-01,  3.2470e-01,  6.9050e-01,  ...,  1.2743e-01,
          6.8164e-02,  2.8999e-01],
        [ 2.4270e-02,  4.9620e-01,  4.9715e-01,  ..., -4.0044e-02,
          1.7810e-01, -1.1853e-01]], device='cuda:0')

In [None]:
def generate_edgepair (num_node, max_neighbour):
    edge_src = []
    edge_dst = []
    for i in range(num_node):
        num_neighbour = random.randint(0,max_neighbour)
        for j in range(num_neighbour):
            edge_dst.append(i)
            src_idx = i
            while(src_idx == i):
                src_idx = random.randint(0,num_node-1)
            edge_src.append(src_idx)
    return edge_src, edge_dst 

In [None]:
total_node = 128
max_neighbour = 64
edge_src, edge_dst = generate_edgepair(total_node, max_neighbour)

In [None]:
MAX_IR = 11
MAX_IN2 = 36
MAX_IN1_IR_CNT = 32
MAX_NUM_PATH = 512
MAX_U_FIBER_CNT = 5265
MAX_U_CG_VAL_CNT = 344

print(MAX_IN1_IR_CNT*12 + MAX_NUM_PATH*20 + MAX_U_FIBER_CNT*4 + MAX_U_CG_VAL_CNT*4)

33060


In [None]:
print(MAX_IR*5+MAX_IN2)

91


In [None]:
back_w = grad_uvu.used_w_list[0]
dl_dO = grad_uvu.w_result_list[0].grad

In [None]:
# v = 1
# uvuv can easily retrievd by dividing it with path_weight * weight as v is 1

In [None]:
path_idx = 4
dL_dO = grad_uvu.w_result_list[path_idx].grad
W = grad_uvu.used_w_list[path_idx]
O = grad_uvu.w_result_list[path_idx]
uvuv = grad_uvu.cg_result_list[path_idx]
uvuv_gen = (O / W).reshape(uvuv.shape)
cg = grad_uvu.used_cg_list[path_idx]

dL_duvuv = torch.einsum("zuk,zuv -> zuvk", dL_dO, W)
dL_dW = torch.einsum("zuk,zuvk -> zuv", dL_dO, uvuv_gen)
# only place with sparsity
dL_dOuter = torch.einsum("zuvk, ijk -> zuvij", dL_duvuv, cg)
# dL_dA = torch.einsum("zuvij,zvj -> zui", dL_dOuter, in2)
# dL_dB = torch.einsum("zuvij,zui -> zvj", dL_dOuter, in1)

In [None]:
W.shape

torch.Size([4096, 32, 1])

In [None]:
dL_duvuv = torch.einsum("zuk,zuv -> zuvk", dl_dO, back_w)

In [None]:
dL_dw = torch.einsum("zuk,zuvk -> zuv", dl_dO, grad_uvu.cg_result_list[0])

In [None]:
out[0]

tensor([ 0.0056,  0.0788,  0.0625,  ...,  0.0286,  0.0494, -0.0057],
       grad_fn=<SelectBackward0>)

In [None]:
y = torch.rand(correct_out.shape)
loss = (correct_out-y)[0].sum()
loss.backward(retain_graph=True)

In [None]:
total_fiber = 0
for key,val in unique_cg_mat.items():
    # print(key, val.count_nonzero().item())
    total_fiber += val.count_nonzero().item()
print(total_fiber)

5265


In [None]:
tp.weight_numel

7104

In [None]:
len(tp.instructions)

222

In [None]:
i_in1

32x0e+32x0o+32x1e+32x2e+32x3e+32x4e+32x5e+32x1o+32x2o+32x3o+32x4o+32x5o

In [None]:
unique_cg_lookup = list(set([x.item() for x in unique_cg]))


In [None]:
len(unique_cg_lookup)

340

In [None]:
5265 * 4 / 1024 

20.56640625

In [None]:
cg_dummy.count_nonzero()

tensor(39104)

In [None]:
39104 * (1+1+1) / 1024

114.5625

In [None]:
cg_dummy_kij = cg_dummy.permute(2,0,1)
cg_dummy_kij_flat = cg_dummy_kij.reshape(uvuv_i_out.dim, -1)
cg_kij_pos_tuple = cg_dummy_kij.nonzero(as_tuple=True)
cg_kij_pos_feed = cg_dummy_kij.nonzero().to(dtype=torch.int32).contiguous()
cg_kij_val_feed = cg_dummy_kij[cg_kij_pos_tuple]
per_out_fiber = cg_dummy_kij_flat.count_nonzero(dim=1)
unique_cg_lookup = list(set([x.item() for x in unique_cg]))

uni_cg_cnt = len(unique_cg_lookup)
out_size = uvuv_i_out.dim

In [None]:
in1_occurance = torch.bincount(cg_kij_pos_tuple[1])

In [None]:
in2_occurance = torch.bincount(cg_kij_pos_tuple[2])

In [None]:
per_threadblock_data = {}
i_out_slice = i_out.slices()
for inst in tp.instructions:
    if inst.i_in1 not in per_threadblock_data.keys():
        per_threadblock_data[inst.i_in1] = {}
        per_threadblock_data[inst.i_in1]["out_slice"] = []
        per_threadblock_data[inst.i_in1][""] = 0
    per_threadblock_data[inst.i_in1]["out_slice"].append(i_out_slice[inst.i_out])

In [None]:
cg_dummy_kij

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])

In [None]:
mac2nnz_ij = torch.bincount(cg_dummy_kij_flat.nonzero()[:,1])

In [None]:
mac2nnz_ij.to(float).mean()

tensor(2.3867, dtype=torch.float64)

In [None]:
nnz_cnt_for_each_path = []
for inst in tp.instructions:
    nnz_cnt_for_each_path.append((inst.i_in1, i_in1[inst.i_in1].dim//32, i_in2[inst.i_in2].dim, cg_dummy_kij_flat[i_out_slice[inst.i_out]].count_nonzero()//32))

In [None]:
nnz_cnt_for_each_path

[(0, 1, 1, tensor(1)),
 (0, 1, 3, tensor(3)),
 (0, 1, 5, tensor(5)),
 (0, 1, 7, tensor(7)),
 (1, 1, 1, tensor(1)),
 (1, 1, 3, tensor(3)),
 (1, 1, 5, tensor(5)),
 (1, 1, 7, tensor(7)),
 (2, 3, 1, tensor(3)),
 (2, 3, 3, tensor(3)),
 (2, 3, 3, tensor(6)),
 (2, 3, 3, tensor(11)),
 (2, 3, 5, tensor(11)),
 (2, 3, 5, tensor(16)),
 (2, 3, 5, tensor(21)),
 (2, 3, 7, tensor(21)),
 (2, 3, 7, tensor(26)),
 (3, 5, 1, tensor(5)),
 (3, 5, 3, tensor(11)),
 (3, 5, 3, tensor(16)),
 (3, 5, 3, tensor(21)),
 (3, 5, 5, tensor(5)),
 (3, 5, 5, tensor(16)),
 (3, 5, 5, tensor(25)),
 (3, 5, 5, tensor(28)),
 (3, 5, 7, tensor(21)),
 (3, 5, 7, tensor(28)),
 (3, 5, 7, tensor(41)),
 (4, 7, 1, tensor(7)),
 (4, 7, 3, tensor(21)),
 (4, 7, 3, tensor(26)),
 (4, 7, 5, tensor(21)),
 (4, 7, 5, tensor(28)),
 (4, 7, 5, tensor(41)),
 (4, 7, 7, tensor(7)),
 (4, 7, 7, tensor(26)),
 (4, 7, 7, tensor(41)),
 (4, 7, 7, tensor(42)),
 (5, 3, 1, tensor(3)),
 (5, 3, 3, tensor(3)),
 (5, 3, 3, tensor(6)),
 (5, 3, 3, tensor(11)),
 (5, 3, 5,

In [None]:
in1_0_ij_nnz = torch.cat([cg_dummy_kij_flat[s] for s in per_threadblock_data[5]["out_slice"]], dim=0).nonzero()

In [None]:
len(in1_0_ij_nnz[:,1].unique())

1536

In [None]:
for x in in1_0_ij_nnz:
    print(x[:,1].unique())

IndexError: too many indices for tensor of dimension 1

In [None]:
in2_occurance

tensor([ 576, 1152,  896, 1152,  896, 1152,  896, 1152,  896])

In [None]:
in1_occurance

tensor([ 9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,
         9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,
         9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,
         9,  9,  9,  9,  9,  9,  9,  9,  9,  9, 18, 14, 18, 18, 14, 18, 18, 14,
        18, 18, 14, 18, 18, 14, 18, 18, 14, 18, 18, 14, 18, 18, 14, 18, 18, 14,
        18, 18, 14, 18, 18, 14, 18, 18, 14, 18, 18, 14, 18, 18, 14, 18, 18, 14,
        18, 18, 14, 18, 18, 14, 18, 18, 14, 18, 18, 14, 18, 18, 14, 18, 18, 14,
        18, 18, 14, 18, 18, 14, 18, 18, 14, 18, 18, 14, 18, 18, 14, 18, 18, 14,
        18, 18, 14, 18, 18, 14, 18, 18, 14, 18, 18, 14, 18, 18, 14, 18, 14, 18,
        14, 18, 14, 14, 18, 14, 18, 14, 14, 18, 14, 18, 14, 14, 18, 14, 18, 14,
        14, 18, 14, 18, 14, 14, 18, 14, 18, 14, 14, 18, 14, 18, 14, 14, 18, 14,
        18, 14, 14, 18, 14, 18, 14, 14, 18, 14, 18, 14, 14, 18, 14, 18, 14, 14,
        18, 14, 18, 14, 14, 18, 14, 18, 

In [None]:
per_in1_fiber = cg_dummy_kij_flat.count_nonzero(dim=1)


In [None]:
in1 = torch.rand(batch_size, i_in1.dim)
in2 = torch.rand(batch_size, i_in2.dim)
out = torch.zeros((batch_size,uvuv_i_out.dim))

uni_w3j = torch.tensor(unique_cg_lookup)

if(tp.shared_weights):
    weight = torch.rand(tp.weight_numel)
else:
    # weight = torch.rand(batch_size,tp.weight_numel)
    weight = torch.ones(batch_size,tp.weight_numel)
    # weight[:,928:960] = 2

path_weight = torch.zeros(i_out.dim) 
for inst in tp.instructions:
    k = inst.i_out
    path_weight[i_out.slices()[k]] = inst.path_weight

In [None]:
blk_u_in1_idx = []
blk_u_in2_idx = []

per_blk_u_in1_idx_range = []
per_blk_u_in2_idx_range = []
per_blk_fiber_start_range = []

per_fiber_local_idx = []
per_fiber_global_idx = []

max_u_in1_dim = 0
max_u_in2_dim = 0 
max_fiber_cnt = 0

pattern_len_array = []
stride_mul_array = []
rem_cum_idx_array = []
rem_cumval_array = []
fiber_cnt_array = []

step = 32
i = 0
while(i< uvuv_i_out.dim):
    local_k_idx = cg_dummy_kij[i:i+step].nonzero()[:,0]
    blk_in1_idx = cg_dummy_kij[i:i+step].nonzero()[:,1]
    blk_in2_idx = cg_dummy_kij[i:i+step].nonzero()[:,2]

    u_in1_list = blk_in1_idx.unique().tolist()
    u_in2_list = blk_in2_idx.unique().tolist()

    max_u_in1_dim = max(max_u_in1_dim, len(u_in1_list))
    max_u_in2_dim = max(max_u_in2_dim, len(u_in2_list))
    max_fiber_cnt = max(max_fiber_cnt, len(blk_in1_idx))

    blk_u_in1_idx.append(u_in1_list)
    blk_u_in2_idx.append(u_in2_list)

    per_blk_u_in1_idx_range.append(len(u_in1_list))
    per_blk_u_in2_idx_range.append(len(u_in2_list))

    per_blk_fiber_start_range.append(len(blk_in1_idx))

    local_in1_idx = [u_in1_list.index(x.item()) for x in blk_in1_idx]
    local_in2_idx = [u_in2_list.index(x.item()) for x in blk_in2_idx]

    for l_a,l_b,l_k,g_a,g_b in zip(local_in1_idx,local_in2_idx, local_k_idx,blk_in1_idx,blk_in2_idx):
        real_k = i + l_k
        cg_idx = unique_cg_lookup.index(cg_dummy_kij[real_k,g_a,g_b])

        per_fiber_local_idx.append(l_a)
        per_fiber_local_idx.append(l_b)
        per_fiber_local_idx.append(cg_idx)
        per_fiber_local_idx.append(l_k)

        per_fiber_global_idx.append(g_a)
        per_fiber_global_idx.append(g_b)
        per_fiber_global_idx.append(cg_idx)
        per_fiber_global_idx.append(real_k)


    pattern = per_out_fiber[i:i+step].tolist()
    pattern_length = 0
    if ([pattern[0]]*len(pattern) == pattern):
        pattern_length=1
    elif (list(itertools.chain.from_iterable(([pattern[0:3]]*len(pattern))))[0:len(pattern)] == pattern):
        pattern_length=3
    elif (list(itertools.chain.from_iterable(([pattern[0:5]]*len(pattern))))[0:len(pattern)] == pattern):
        pattern_length=5
    elif (list(itertools.chain.from_iterable(([pattern[0:7]]*len(pattern))))[0:len(pattern)] == pattern):
        pattern_length=7
    elif (list(itertools.chain.from_iterable(([pattern[0:9]]*len(pattern))))[0:len(pattern)] == pattern):
        pattern_length=9
    elif (list(itertools.chain.from_iterable(([pattern[0:11]]*len(pattern))))[0:len(pattern)] == pattern):
        pattern_length=11
    else:
        print("bad")

    unit_pattern = pattern[0:pattern_length]

    pattern_len_array.append(pattern_length)
    stride_mul_array.append(sum(unit_pattern))
    rem_cum_idx_array.append(pattern_length)

    rem_cumval_array.append(cumsum_list(unit_pattern, False))
    fiber_cnt_array.append(unit_pattern)

    i += step

rem_cum_idx_array = cumsum_list(rem_cum_idx_array,False)
per_blk_u_in1_idx_range  = cumsum_list(per_blk_u_in1_idx_range)
per_blk_u_in2_idx_range  = cumsum_list(per_blk_u_in2_idx_range)
per_blk_fiber_start_range = cumsum_list(per_blk_fiber_start_range)

rem_cumval_array = list(itertools.chain.from_iterable(rem_cumval_array))
fiber_cnt_array = list(itertools.chain.from_iterable(fiber_cnt_array))
blk_u_in1_idx = list(itertools.chain.from_iterable(blk_u_in1_idx))
blk_u_in2_idx = list(itertools.chain.from_iterable(blk_u_in2_idx))

per_fiber_local_idx = torch.tensor(per_fiber_local_idx).to(torch.uint8)
per_fiber_global_idx = torch.tensor(per_fiber_global_idx)

In [None]:
max_u_in1_dim

288

In [None]:
per_blk_w_idx = []
for w_idx, i in enumerate(uvuv_i_out):
    for j in range(i.ir.dim):
        per_blk_w_idx.append(w_idx)

In [None]:
out_cuda = out.to(device="cuda")
main1_cuda = to_cuda_list(in1.T.contiguous(),in2.T.contiguous())
main2_cuda = to_cuda_list(uni_w3j,weight.T.contiguous(),path_weight)
meta_cuda = to_cuda_list(
                    per_blk_w_idx,
                    per_blk_u_in1_idx_range,
                    per_blk_u_in2_idx_range,
                    per_blk_fiber_start_range,
                    
                    blk_u_in1_idx,
                    blk_u_in2_idx,

                    pattern_len_array,
                    stride_mul_array,
                    rem_cum_idx_array,
                    rem_cumval_array,
                    fiber_cnt_array, 
                    
                    per_fiber_local_idx,
                    input_dtype=torch.uint16)
size_input = [max_u_in1_dim,max_u_in2_dim,max_fiber_cnt,uni_cg_cnt]
size_input = [int(x) for x in size_input]

In [None]:
sptp.sptp_linear_v1(*main1_cuda, out_cuda, *main2_cuda,*meta_cuda,*size_input)

In [None]:
out_cuda

tensor([[ 0.0628,  0.0963,  0.0538,  ...,  0.0593,  0.0466, -0.0062],
        [ 0.1053,  0.1894,  0.1876,  ..., -0.0967,  0.0249, -0.1779],
        [ 0.1389,  0.0715,  0.3133,  ...,  0.0856, -0.0235,  0.0522],
        ...,
        [ 0.6284,  0.5477,  0.4419,  ...,  0.1437,  0.0976, -0.0147],
        [ 0.1207,  0.3705,  0.0709,  ...,  0.1003, -0.0310,  0.0856],
        [ 0.3683,  0.0145,  0.4971,  ...,  0.0489, -0.2779, -0.0830]],
       device='cuda:0')

In [None]:
correct_out = tp(in1,in2,weight)

In [None]:
correct_out

tensor([[ 0.0628,  0.0963,  0.0538,  ...,  0.0593,  0.0466, -0.0062],
        [ 0.1053,  0.1894,  0.1876,  ..., -0.0967,  0.0249, -0.1779],
        [ 0.1389,  0.0715,  0.3133,  ...,  0.0856, -0.0235,  0.0522],
        ...,
        [ 0.6284,  0.5477,  0.4419,  ...,  0.1437,  0.0976, -0.0147],
        [ 0.1207,  0.3705,  0.0709,  ...,  0.1003, -0.0310,  0.0856],
        [ 0.3683,  0.0145,  0.4971,  ...,  0.0489, -0.2779, -0.0830]])

In [None]:
index = 4033
torch.isclose(correct_out[index], out_cuda.cpu()[index]).all()
for a,b in zip(correct_out[index], out_cuda.cpu()[index]):
    print(torch.isclose(a,b),a,b)

tensor(True) tensor(0.0449) tensor(0.0449)
tensor(True) tensor(0.0436) tensor(0.0436)
tensor(True) tensor(0.0244) tensor(0.0244)
tensor(True) tensor(0.0281) tensor(0.0281)
tensor(True) tensor(0.0289) tensor(0.0289)
tensor(True) tensor(0.0088) tensor(0.0088)
tensor(True) tensor(0.0269) tensor(0.0269)
tensor(True) tensor(0.0476) tensor(0.0476)
tensor(True) tensor(0.0165) tensor(0.0165)
tensor(True) tensor(0.0056) tensor(0.0056)
tensor(True) tensor(0.0185) tensor(0.0185)
tensor(True) tensor(0.0059) tensor(0.0059)
tensor(True) tensor(0.0575) tensor(0.0575)
tensor(True) tensor(0.0567) tensor(0.0567)
tensor(True) tensor(0.0191) tensor(0.0191)
tensor(True) tensor(0.0153) tensor(0.0153)
tensor(True) tensor(0.0493) tensor(0.0493)
tensor(True) tensor(0.0405) tensor(0.0405)
tensor(True) tensor(0.0193) tensor(0.0193)
tensor(True) tensor(0.0102) tensor(0.0102)
tensor(True) tensor(0.0299) tensor(0.0299)
tensor(True) tensor(0.0065) tensor(0.0065)
tensor(True) tensor(0.0468) tensor(0.0468)
tensor(True

In [None]:
a_u_cnt = []
b_u_cnt = []
fiber_cnt = []
i_cnt = []
step = 32
i = 0
fiber_cnt_pattern = {}
while(i< uvuv_i_out.dim):
    u_in1 = cg_dummy_kij[i:i+step].nonzero()[:,1].unique()
    u_in2 = cg_dummy_kij[i:i+step].nonzero()[:,2].unique()


    l_a_u_cnt = (cg_dummy_kij[i:i+step].nonzero()[:,1].unique().numel())
    l_b_u_cnt = (cg_dummy_kij[i:i+step].nonzero()[:,2].unique().numel())
    l_fiber_cnt = (cg_dummy_kij[i:i+step].nonzero().shape[0])

    pattern = per_out_fiber[i:i+step].tolist()

    # print([pattern[0]]*len(pattern))
    

    # print(list(itertools.chain.from_iterable(([pattern[0:3]]*len(pattern))))[0:len(pattern)])
    # print(list(itertools.chain.from_iterable(([pattern[0:5]]*len(pattern))))[0:len(pattern)])
    
    if ([pattern[0]]*len(pattern) == pattern):
        pattern_tuple = (1, pattern[0])
    elif (list(itertools.chain.from_iterable(([pattern[0:3]]*len(pattern))))[0:len(pattern)] == pattern):
        pattern_tuple = (3, pattern[0:3])
    elif (list(itertools.chain.from_iterable(([pattern[0:5]]*len(pattern))))[0:len(pattern)] == pattern):
        pattern_tuple = (5, pattern[0:5])
    elif (list(itertools.chain.from_iterable(([pattern[0:7]]*len(pattern))))[0:len(pattern)] == pattern):
        pattern_tuple = (7, pattern[0:7])
    else:
        print("bad")
    

    if str(pattern_tuple) not in fiber_cnt_pattern:
        fiber_cnt_pattern[str(pattern_tuple)] = []
    fiber_cnt_pattern[str(pattern_tuple)].append(i//32)
    # # plt.scatter(cg_dummy_kij[i:i+step].nonzero()[:,0], cg_dummy_kij[i:i+32].nonzero()[:,1])
    # # plt.show()
    # # plt.close()
    # if(l_a_u_cnt > 60):
    #     l_a_u_cnt = (cg_dummy_kij[i:i+step//2].nonzero()[:,1].unique().numel())
    #     l_b_u_cnt = (cg_dummy_kij[i:i+step//2].nonzero()[:,2].unique().numel())
    #     l_fiber_cnt = (cg_dummy_kij[i:i+step//2].nonzero().shape[0])
    #     if(l_a_u_cnt > 60):
    #         l_a_u_cnt = (cg_dummy_kij[i:i+step//4].nonzero()[:,1].unique().numel())
    #         l_b_u_cnt = (cg_dummy_kij[i:i+step//4].nonzero()[:,2].unique().numel())
    #         l_fiber_cnt = (cg_dummy_kij[i:i+step//4].nonzero().shape[0])

    #         a_u_cnt.append(l_a_u_cnt)
    #         b_u_cnt.append(l_b_u_cnt)
    #         fiber_cnt.append(l_fiber_cnt)
    #         i_cnt.append(step//4)
    #         i += step//4
    #     else:
    #         a_u_cnt.append(l_a_u_cnt)
    #         b_u_cnt.append(l_b_u_cnt)
    #         fiber_cnt.append(l_fiber_cnt)
    #         i_cnt.append(step//2)
    #         i += step//2
    # else:
    a_u_cnt.append(l_a_u_cnt)
    b_u_cnt.append(l_b_u_cnt)
    fiber_cnt.append(l_fiber_cnt)
    i_cnt.append(step)
    i += step

In [None]:
smem_size = []
for a, b, f, i in zip(a_u_cnt, b_u_cnt, fiber_cnt, i_cnt):
    smem_size.append((32 * (a+b+i+i) +f) * 4 )
    print(a,b,f,i)

32 1 32 32
96 3 96 32
160 5 160 32
32 1 32 32
160 5 160 32
96 3 96 32
11 3 32 32
12 3 32 32
11 3 32 32
33 3 64 32
34 3 64 32
33 3 64 32
55 3 117 32
58 3 118 32
55 3 117 32
32 1 32 32
32 1 32 32
32 1 32 32
33 5 117 32
36 5 118 32
33 5 117 32
55 5 170 32
60 5 172 32
55 5 170 32
11 3 32 32
12 3 32 32
11 3 32 32
32 1 32 32
32 1 32 32
32 1 32 32
33 5 117 32
36 5 118 32
33 5 117 32
55 5 170 32
60 5 172 32
55 5 170 32
33 3 64 32
34 3 64 32
33 3 64 32
55 3 117 32
58 3 118 32
55 3 117 32
7 5 32 32
7 5 32 32
8 5 32 32
7 5 32 32
7 5 32 32
21 3 70 32
21 3 71 32
22 3 70 32
21 3 71 32
21 3 70 32
35 3 103 32
35 3 102 32
36 3 102 32
35 3 102 32
35 3 103 32
21 5 103 32
21 5 102 32
24 5 102 32
21 5 102 32
21 5 103 32
32 1 32 32
32 1 32 32
32 1 32 32
32 1 32 32
32 1 32 32
35 5 160 32
35 5 161 32
38 5 158 32
35 5 161 32
35 5 160 32
7 5 32 32
7 5 32 32
8 5 32 32
7 5 32 32
7 5 32 32
21 5 103 32
21 5 102 32
24 5 102 32
21 5 102 32
21 5 103 32
32 1 32 32
32 1 32 32
32 1 32 32
32 1 32 32
32 1 32 32
35 5 160 32

In [None]:
max(smem_size)

29952