In [1]:
import torch
import e3nn
from torch.utils.cpp_extension import load
import json
import matplotlib.pyplot as plt
import itertools
import os
import cuequivariance as cue
import cuequivariance_torch as cuet
import random
from typing import List

In [2]:
from torchviz import make_dot

In [3]:
# Pin the CUDA architecture list used by the JIT build below.
os.environ['TORCH_CUDA_ARCH_LIST'] = "8.0"
# JIT-compile the C++ binding file and the CUDA kernel sources into a single
# extension module named `sptp_linear_bwd`; the loaded module is bound to
# `sptp_bwd` and is called directly by the custom ops and test cells below.
# NOTE(review): the source paths are absolute and machine-specific.
sptp_bwd = load(name='sptp_linear_bwd', sources=['/home2/lsy/mdsim/fused_e3nn/fused_e3nn_kernel/sptp_linear_bwd.cpp', 
                                  '/home2/lsy/mdsim/fused_e3nn/fused_e3nn_kernel/bwd_sptp_linear.cu',
                                  '/home2/lsy/mdsim/fused_e3nn/fused_e3nn_kernel/bwd_sptp_linear_shared.cu',
                                  '/home2/lsy/mdsim/fused_e3nn/fused_e3nn_kernel/fwd_sptp_linear_v2.cu',
                                  '/home2/lsy/mdsim/fused_e3nn/fused_e3nn_kernel/fwd_sptp_linear_v2_shared.cu',
                                  '/home2/lsy/mdsim/fused_e3nn/fused_e3nn_kernel/bwd_bwd_sptp_linear_v2_shared.cu',
                                  ], 
                                  extra_cuda_cflags=["-lineinfo"], verbose=True)


Using /home2/lsy/.cache/torch_extensions/py311_cu124 as PyTorch extensions root...
Detected CUDA files, patching ldflags
Emitting ninja build file /home2/lsy/.cache/torch_extensions/py311_cu124/sptp_linear_bwd/build.ninja...
Building extension module sptp_linear_bwd...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)


[1/2] c++ -MMD -MF sptp_linear_bwd.o.d -DTORCH_EXTENSION_NAME=sptp_linear_bwd -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -isystem /home2/lsy/miniconda3/envs/cueq/lib/python3.11/site-packages/torch/include -isystem /home2/lsy/miniconda3/envs/cueq/lib/python3.11/site-packages/torch/include/torch/csrc/api/include -isystem /home2/lsy/miniconda3/envs/cueq/lib/python3.11/site-packages/torch/include/TH -isystem /home2/lsy/miniconda3/envs/cueq/lib/python3.11/site-packages/torch/include/THC -isystem /usr/local/cuda-12.4/include -isystem /home2/lsy/miniconda3/envs/cueq/include/python3.11 -D_GLIBCXX_USE_CXX11_ABI=0 -fPIC -std=c++17 -c /home2/lsy/mdsim/fused_e3nn/fused_e3nn_kernel/sptp_linear_bwd.cpp -o sptp_linear_bwd.o 
[2/2] c++ sptp_linear_bwd.o bwd_sptp_linear.cuda.o bwd_sptp_linear_shared.cuda.o fwd_sptp_linear_v2.cuda.o fwd_sptp_linear_v2_shared.cuda.o bwd_bwd_sptp_linear_v2_shared.cuda.o -shared -L

Loading extension module sptp_linear_bwd...


In [4]:
@torch.library.custom_op(
    "sptp_linear::sptp_linear_fwd_v2_shared",
    mutates_args=(),
    device_types="cuda",
)
def _(
    in1: torch.Tensor,
    in2: torch.Tensor,
    weight: torch.Tensor,
    t_in1_idxing: torch.Tensor,
    t_in1_ival: torch.Tensor,
    t_in1_related_path_idx: torch.Tensor,
    t_path_array1: torch.Tensor,
    t_path_array2: torch.Tensor,
    t_per_upath_fiber_start: torch.Tensor,
    t_path_weight: torch.Tensor,
    t_per_path_weight_pos: torch.Tensor,
    t_per_upath_fiber_array: torch.Tensor,
    t_unique_cg_val: torch.Tensor,
    upath_cnt: int,
    per_block_batch: int,
    max_ir_dim: int,
    out_size: int,
) -> torch.Tensor:
    """Forward custom op: run the `sptp_linear_fwd_v2_shared` extension kernel.

    Allocates a ``(in1.shape[0], out_size)`` result tensor on ``in1``'s device
    with ``in1``'s dtype, hands it to the kernel together with all inputs and
    lookup tables, and returns it.  The kernel receives ``max_ir_dim * 2 + 1``
    rather than ``max_ir_dim`` itself.
    """
    n_rows = in1.shape[0]
    result = torch.empty((n_rows, out_size), device=in1.device, dtype=in1.dtype)

    # The extension takes 2*max_ir_dim+1 as its last integer argument.
    kernel_ir_extent = max_ir_dim * 2 + 1

    sptp_bwd.sptp_linear_fwd_v2_shared(
        in1, in2, weight, result,
        t_in1_idxing, t_in1_ival, t_in1_related_path_idx,
        t_path_array1, t_path_array2, t_per_upath_fiber_start,
        t_path_weight, t_per_path_weight_pos,
        t_per_upath_fiber_array, t_unique_cg_val,
        upath_cnt, per_block_batch, kernel_ir_extent,
    )
    return result

In [5]:
def fused_e3nn_setup_fwd_context(ctx, inputs, output):
    """Stash the forward op's arguments on ``ctx`` for the backward pass.

    ``inputs`` must hold exactly 17 entries: 13 tensor arguments (in1, in2,
    weight and the ten lookup tables) followed by ``upath_cnt``,
    ``per_block_batch``, ``max_ir_dim`` and ``out_size``.  The 13 tensors are
    saved via ``ctx.save_for_backward`` in their original order; the first
    three integers become attributes of ``ctx``.  ``out_size`` and ``output``
    are not stored.
    """
    n_tensor_args = 13
    tensor_args = inputs[:n_tensor_args]
    # Unpacking the tail enforces the expected total argument count.
    upath_cnt, per_block_batch, max_ir_dim, _out_size = inputs[n_tensor_args:]

    ctx.save_for_backward(*tensor_args)
    ctx.upath_cnt = upath_cnt
    ctx.per_block_batch = per_block_batch
    ctx.max_ir_dim = max_ir_dim



In [6]:
def fused_e3nn_bwd(ctx, grad_output):
    """Backward of the forward custom op.

    Forwards the saved tensors, ``grad_output`` and the integers stored on
    ``ctx`` to ``torch.ops.sptp_linear.sptp_linear_bwd_v2_shared`` and returns
    one entry per forward argument: gradients for in1, in2 and weight, then
    ``None`` for the ten lookup-table tensors and the four integer arguments.
    """
    in1, in2, weight, *path_tables = ctx.saved_tensors

    grads = torch.ops.sptp_linear.sptp_linear_bwd_v2_shared(
        in1,
        in2,
        weight,
        grad_output,
        *path_tables,
        ctx.upath_cnt,
        ctx.per_block_batch,
        ctx.max_ir_dim,
    )

    in1_grad, in2_grad, weight_grad = grads[0], grads[1], grads[2]
    # 10 lookup-table tensors + 4 integer arguments receive no gradient.
    no_grad_slots = (None,) * 14
    return (in1_grad, in2_grad, weight_grad) + no_grad_slots


In [7]:
@torch.library.custom_op(
    "sptp_linear::sptp_linear_bwd_v2_shared",
    mutates_args=(),
    device_types="cuda",
)
def _(
    in1: torch.Tensor,
    in2: torch.Tensor,
    weight: torch.Tensor, 
    grad_output: torch.Tensor,
    t_in1_idxing: torch.Tensor, 
    t_in1_ival: torch.Tensor, 
    t_in1_related_path_idx: torch.Tensor, 
    t_path_array1: torch.Tensor,
    t_path_array2: torch.Tensor,
    t_per_upath_fiber_start: torch.Tensor,
    t_path_weight: torch.Tensor,
    t_per_path_weight_pos: torch.Tensor,
    t_per_upath_fiber_array: torch.Tensor,
    t_unique_cg_val: torch.Tensor,
    upath_cnt:int,
    per_block_batch:int,
    max_ir_dim:int
    ) -> List[torch.Tensor]:
    """Backward custom op: run the `sptp_linear_bwd_v1_shared` extension kernel.

    Returns ``[dL/din1, dL/din2, dL/dweight]``.  The in2 gradient is produced
    by the kernel as ``upath_cnt`` consecutive chunks of width ``in2.shape[1]``
    per row and is reduced here by summing over those chunks.  The kernel
    receives ``max_ir_dim * 2 + 1`` as its last argument.

    All result buffers are zero-initialised.  The direct kernel invocations
    elsewhere in this notebook pass zero-filled buffers, and the in2 buffer is
    summed afterwards, so any slot the kernel leaves untouched must read as 0
    rather than as uninitialised memory.  NOTE(review): whether the kernel
    overwrites or accumulates is not visible from this file.
    """
    batch_size = in1.shape[0]
    in2_size = in2.shape[1]

    mem_debug = torch.zeros((1,1),device=in1.device)
    mem_dl_din1 = torch.zeros_like(in1)
    mem_dl_din2 = torch.zeros((batch_size, in2_size * upath_cnt) , device=in1.device, dtype=in1.dtype)
    mem_dl_dw = torch.zeros_like(weight)

    # grad_output may be a non-contiguous view; the kernel gets a contiguous copy.
    sptp_bwd.sptp_linear_bwd_v1_shared(in1,in2,weight, grad_output.contiguous(), mem_dl_din1, mem_dl_din2, mem_dl_dw, mem_debug,
        t_in1_idxing, t_in1_ival, t_in1_related_path_idx, t_path_array1,t_path_array2,t_per_upath_fiber_start, t_path_weight, t_per_path_weight_pos, t_per_upath_fiber_array,t_unique_cg_val, upath_cnt, per_block_batch, max_ir_dim*2+1
    )
    # Collapse the per-upath partial in2 gradients into one (batch, in2_size) tensor.
    mem_dl_din2_summed = mem_dl_din2.reshape((batch_size, upath_cnt, in2_size)).sum(dim=1)
    
    return [mem_dl_din1,mem_dl_din2_summed, mem_dl_dw]


In [8]:
@torch.library.custom_op(
    "sptp_linear::sptp_linear_bwd_bwd_v2_shared",
    mutates_args=(),
    device_types="cuda",
)
def _(
    dF_in1: torch.Tensor,
    dF_in2: torch.Tensor,
    dF_dw: torch.Tensor,
    dE_dout: torch.Tensor,
    
    in1: torch.Tensor, 
    in2: torch.Tensor,
    weight: torch.Tensor,

    t_in1_idxing: torch.Tensor,
    t_in1_ival: torch.Tensor,
    t_in1_related_path_idx: torch.Tensor,

    t_path_array1: torch.Tensor,
    t_path_array2: torch.Tensor,
    t_per_upath_fiber_start: torch.Tensor,
    t_path_weight: torch.Tensor,
    t_per_path_weight_pos: torch.Tensor,

    t_per_upath_fiber_array: torch.Tensor,
    t_unique_cg_val: torch.Tensor,

    upath_cnt:int,
    per_block_batch:int,
    max_ir_dim:int) -> List[torch.Tensor]:
    """Double-backward custom op: run the `sptp_linear_bwd_bwd_v2_shared` kernel.

    Takes the gradients flowing into the three first-order gradients
    (``dF_in1``, ``dF_in2``, ``dF_dw``), the first-order output gradient
    ``dE_dout`` and the original forward inputs, and returns
    ``[dL_din1, dL_din2, dL_dw, dF_dout]``.  The in2 result is produced by the
    kernel as ``upath_cnt`` chunks of width ``in2.shape[1]`` per row and is
    summed over those chunks here.  The kernel receives ``max_ir_dim * 2 + 1``
    as its last argument.

    Incoming gradient tensors are made contiguous before being handed to the
    extension, mirroring what the first-order backward op does with its
    ``grad_output``.  Result buffers are zero-initialised to match the direct
    kernel invocations in this notebook.  NOTE(review): whether the kernel
    overwrites or accumulates is not visible from this file.
    """
    batch_size = in2.shape[0]
    in2_size = in2.shape[1]

    dF_dout = torch.zeros_like(dE_dout)
    dL_din1 = torch.zeros_like(in1)
    dL_din2_duplicate = torch.zeros((batch_size, in2_size * upath_cnt) , device=in2.device, dtype=in2.dtype)
    dL_dw = torch.zeros_like(weight)
    mem_debug = torch.zeros((1,1),device=in1.device)

    sptp_bwd.sptp_linear_bwd_bwd_v2_shared(
                                dF_in1.contiguous(), dF_in2.contiguous(), dF_dw.contiguous(), dE_dout.contiguous(),
                                in1, in2, weight,
                                dF_dout, dL_dw, dL_din1, dL_din2_duplicate, mem_debug,
                                t_in1_idxing, t_in1_ival, t_in1_related_path_idx, t_path_array1,t_path_array2,t_per_upath_fiber_start, t_path_weight, t_per_path_weight_pos, t_per_upath_fiber_array,t_unique_cg_val, upath_cnt, per_block_batch, max_ir_dim*2+1
                                )
    
    # Collapse the per-upath partial in2 results into one (batch, in2_size) tensor.
    dL_din2 = dL_din2_duplicate.reshape((batch_size, upath_cnt, in2_size)).sum(dim=1)
    
    return [dL_din1, dL_din2, dL_dw, dF_dout]

In [9]:
def fused_e3nn_setup_bwd_context(ctx, inputs, output):
    """Stash the backward op's arguments on ``ctx`` for the double-backward pass.

    ``inputs`` must hold exactly 17 entries: in1, in2, weight, dE_dout, the ten
    lookup-table tensors, then ``upath_cnt``, ``per_block_batch`` and
    ``max_ir_dim``.  Tensors are saved with ``dE_dout`` moved to the front
    (dE_dout, in1, in2, weight, tables...); the integers become attributes of
    ``ctx``.  ``output`` is not used.
    """
    n_tensor_args = 14
    # Unpacking the tail enforces the expected total argument count.
    upath_cnt, per_block_batch, max_ir_dim = inputs[n_tensor_args:]
    in1, in2, weight, dE_dout = inputs[:4]
    path_tables = inputs[4:n_tensor_args]

    ctx.save_for_backward(dE_dout, in1, in2, weight, *path_tables)
    ctx.upath_cnt = upath_cnt
    ctx.per_block_batch = per_block_batch
    ctx.max_ir_dim = max_ir_dim



In [10]:
def fused_e3nn_bwd_bwd(ctx, grad_output):
    """Backward of the backward custom op (second-order gradients).

    ``grad_output`` is indexed as a sequence holding the gradients with respect
    to the three first-order outputs (in1 grad, in2 grad, weight grad).  They
    are forwarded, together with the detached saved ``dE_dout``, the saved
    forward inputs and lookup tables, to
    ``torch.ops.sptp_linear.sptp_linear_bwd_bwd_v2_shared``.

    Returns one entry per argument of the backward op: its four results for
    (in1, in2, weight, grad_output), then ``None`` for the ten lookup-table
    tensors and the three integer arguments.
    """
    dE_dout, in1, in2, weight, *path_tables = ctx.saved_tensors
    dF_in1, dF_in2, dF_w = grad_output[0], grad_output[1], grad_output[2]

    second_order = torch.ops.sptp_linear.sptp_linear_bwd_bwd_v2_shared(
        dF_in1,
        dF_in2,
        dF_w,
        dE_dout.detach(),
        in1,
        in2,
        weight,
        *path_tables,
        ctx.upath_cnt,
        ctx.per_block_batch,
        ctx.max_ir_dim,
    )

    # in1, in2, weight and grad_output gradients, in the op's argument order.
    tensor_grads = (second_order[0], second_order[1], second_order[2], second_order[3])
    # 10 lookup-table tensors + 3 integer arguments receive no gradient.
    return tensor_grads + (None,) * 13


In [11]:
# Attach the backward formula (and its context setup) to the forward custom op,
# so autograd can differentiate through `sptp_linear_fwd_v2_shared`.
torch.library.register_autograd(
    "sptp_linear::sptp_linear_fwd_v2_shared",
    fused_e3nn_bwd,
    setup_context=fused_e3nn_setup_fwd_context
)

In [12]:
# Attach a backward formula to the backward custom op itself; this is what makes
# second-order differentiation (create_graph=True) go through the fused kernels.
torch.library.register_autograd(
    "sptp_linear::sptp_linear_bwd_v2_shared",
    fused_e3nn_bwd_bwd,
    setup_context=fused_e3nn_setup_bwd_context
)

In [13]:
def mul_Irreps(mul, i_in):
    """Return a new ``e3nn.o3.Irreps`` whose multiplicities are those of ``i_in`` times ``mul``.

    Each entry of ``i_in`` is read as ``(multiplicity, ir)`` and rebuilt from
    ``ir.l`` and ``ir.p``; the entry order is preserved.
    """
    scaled_entries = [(ori_mul * mul, (ir.l, ir.p)) for ori_mul, ir in i_in]
    return e3nn.o3.Irreps(scaled_entries)
def compare(a, b):
    """Print every position where ``a`` and ``b`` differ noticeably.

    A position is reported (its index, then the signed difference ``a - b``)
    when ``torch.isclose`` rejects it and the absolute difference also exceeds
    1e-6.  If nothing is reported, ``All Good`` is printed.  Returns ``None``.
    """
    mismatch_mask = ~torch.isclose(a, b)
    reported = False
    for pos in mismatch_mask.nonzero():
        index = list(pos)
        delta = a[index] - b[index]
        if abs(delta) > 1e-6:
            reported = True
            print(pos)
            print(a[index] - b[index] )
    if reported:
        return
    print("All Good")
            
# Positions inside the per-layer "tp" list read by the config loaders:
# first-input irreps, second-input irreps, output irreps, instruction list.
IR_IN1_IDX = 0
IR_IN2_IDX = 1
IR_OUT_IDX = 2
INST_IDX = 3

def load_nequip_config(h, l_max, layer_idx,
                       config_dir="/home2/lsy/mdsim/nequip/benchmark_config"):
    """Load the tensor-product description of one layer from a benchmark config file.

    Args:
        h, l_max: values substituted into the file name ``4_{h}_{l_max}_p_sc.txt``.
        layer_idx: raw line index of the wanted layer.  Blank lines are skipped
            but still occupy an index, so this is the line number in the file,
            not the count of non-blank lines.
        config_dir: directory containing the config file.  Defaults to the
            location originally hard-coded here, so existing calls are unchanged.

    Returns:
        ``(i_in1, i_in2, i_out, inst_tuple)``: three ``e3nn.o3.Irreps`` built
        from the layer's ``"tp"`` entry and the instruction list converted to
        a list of tuples.

    Raises:
        KeyError: if ``layer_idx`` is not a non-blank line of the file or the
            line has no ``"tp"`` key.
    """
    filename = os.path.join(config_dir, f"4_{h}_{l_max}_p_sc.txt")
    with open(filename, "r") as f:
        f_in = f.read().split("\n")

    # One JSON document per non-blank line, keyed by its raw line index.
    per_layer_dict = {
        l_idx: json.loads(d) for l_idx, d in enumerate(f_in) if d != ""
    }
    tp_list = per_layer_dict[layer_idx]["tp"]
    i_in1 = e3nn.o3.Irreps(tp_list[IR_IN1_IDX])
    i_in2 = e3nn.o3.Irreps(tp_list[IR_IN2_IDX])
    i_out = e3nn.o3.Irreps(tp_list[IR_OUT_IDX])
    inst_tuple = [tuple(x) for x in tp_list[INST_IDX]]

    return i_in1, i_in2, i_out, inst_tuple


In [14]:
def load_nequip_config_e3nn_cueq(h, l_max, layer_idx):
    """Load one layer's tensor-product config and return it for e3nn and cuEquivariance.

    The layer is read from line ``layer_idx`` of ``4_{h}_{l_max}_p_sc.txt``
    (one JSON document per line, blank lines skipped but still counted).  The
    multiplicities of the first input are then overridden per ``ir.l`` using a
    fixed table (``{0: 128, 1: 64}``), and every output irrep reached by an
    instruction from an overridden input takes that same multiplicity (the
    first such instruction wins).

    Returns:
        ``([ei_in1, ei_in2, ei_out, inst_tuple], [ci_in1, ci_in2, ci_out, inst_tuple])``
        where the ``ei_*`` are ``e3nn.o3.Irreps`` and the ``ci_*`` are
        ``cue.Irreps`` in the ``"O3"`` group.
    """
    filename = f"/home2/lsy/mdsim/nequip/benchmark_config/4_{h}_{l_max}_p_sc.txt"
    with open(filename, "r") as f:
        raw_lines = f.read().split("\n")

    per_layer_dict = {}
    for l_idx, line in enumerate(raw_lines):
        if line != "":
            per_layer_dict[l_idx] = json.loads(line)
    tp_list = per_layer_dict[layer_idx]["tp"]

    ei_in1 = e3nn.o3.Irreps(tp_list[IR_IN1_IDX])
    ei_in2 = e3nn.o3.Irreps(tp_list[IR_IN2_IDX])
    ei_out = e3nn.o3.Irreps(tp_list[IR_OUT_IDX])
    inst_tuple = [tuple(x) for x in tp_list[INST_IDX]]

    # Multiplicity override per angular momentum l of the first input.
    mul_list = {0:128, 1:64}

    new_in1_list = []
    # output irrep index -> overriding multiplicity (first writer wins)
    out_mul_override = {}
    for idx, (mul, ir) in enumerate(ei_in1):
        if ir.l not in mul_list:
            new_in1_list.append((mul, ir))
            continue
        forced_mul = mul_list[ir.l]
        new_in1_list.append((forced_mul, ir))
        for inst in inst_tuple:
            if idx == inst[0]:
                out_mul_override.setdefault(inst[2], forced_mul)

    new_out_list = [
        (out_mul_override.get(idx, mul), ir) for idx, (mul, ir) in enumerate(ei_out)
    ]

    ei_in1 = e3nn.o3.Irreps(new_in1_list)
    ei_out = e3nn.o3.Irreps(new_out_list)

    # The cuEquivariance irreps are built from the string form of the e3nn ones;
    # the second input comes straight from the config entry.
    ci_in1 = cue.Irreps("O3", str(ei_in1))
    ci_in2 = cue.Irreps("O3", tp_list[IR_IN2_IDX])
    ci_out = cue.Irreps("O3", str(ei_out))

    e3nn_side = [ei_in1, ei_in2, ei_out, inst_tuple]
    cueq_side = [ci_in1, ci_in2, ci_out, inst_tuple]
    return e3nn_side, cueq_side


In [15]:
def to_cuda_list(*arg, input_dtype = torch.float32):
    """Return a list with every positional argument as a tensor on the CUDA device.

    Tensors (including subclasses such as ``torch.nn.Parameter``) are moved
    with ``.to(device="cuda")`` and keep their dtype.  Anything else is
    converted with ``torch.tensor(..., dtype=input_dtype)``.

    The check uses ``isinstance`` rather than an exact type comparison so that
    Tensor subclasses are moved instead of being re-wrapped (and dtype-cast)
    through ``torch.tensor``.
    """
    return_list = []
    for item in arg:
        if isinstance(item, torch.Tensor):
            return_list.append(item.to(device="cuda"))
        else:
            return_list.append(torch.tensor(item,device="cuda", dtype=input_dtype))
    return return_list

def to_cuda_dict(strname_list, *arg):
    """Return ``{name: tensor_on_cuda}`` for names and values paired positionally.

    ``strname_list`` and the positional values are zipped, so extra entries on
    either side are ignored.  Tensors (including subclasses such as
    ``torch.nn.Parameter``) are moved with ``.to("cuda")``; anything else is
    converted with ``torch.tensor(..., device="cuda")``.

    The check uses ``isinstance`` rather than an exact type comparison so that
    Tensor subclasses are moved instead of being re-wrapped through
    ``torch.tensor``.
    """
    return_dict = {}
    for item,name in zip(arg,strname_list):
        if isinstance(item, torch.Tensor):
            return_dict[name] = item.to("cuda")
        else:
            return_dict[name] = torch.tensor(item,device="cuda")
    return return_dict

def cumsum_list(s, np1 = True):
    """Return the exclusive running sums of ``s``, starting at 0.

    Entry ``i`` of the result is the sum of ``s[:i]``.  When ``np1`` is true
    the grand total is appended as one extra trailing entry, giving
    ``len(s) + 1`` values; otherwise the result has ``len(s)`` values.
    """
    running = 0
    prefix = []
    for step in s:
        prefix.append(running)
        running += step
    if not np1:
        return prefix
    prefix.append(running)
    return prefix

In [16]:
# Benchmark configuration: `h` and `l_max` select the config file, `layer_idx`
# the line within it, and `batch_size` the number of rows of the test inputs.
h = 32
l_max = 1
layer_idx = 2
batch_size = 4096
# i_in1, i_in2, i_out, inst_tuple = load_nequip_config(h,l_max,layer_idx)
# i_in2 = mul_Irreps(3, i_in2)
# Only the e3nn half of the loaded config is unpacked below; `cueq_config` is
# not read in the cells shown here.
e3nn_config, cueq_config = load_nequip_config_e3nn_cueq(h,l_max,layer_idx)
i_in1, i_in2, i_out, inst_tuple = e3nn_config

# not really needed for v=1
# Full tensor product restricted to the output irreps of `i_out`; its
# instructions and output irreps drive the Clebsch-Gordan bookkeeping below.
uvuv_tp = e3nn.o3.FullTensorProduct(i_in1,i_in2, filter_ir_out=i_out)
uvuv_i_out = uvuv_tp.irreps_out

# split_size = []
# reshape_size = []
# for inst in uvuv_tp.instructions:
#     split_size.append(uvuv_i_out[inst.i_out].dim)
#     reshape_size.append([inst.path_shape[0],inst.path_shape[1],uvuv_i_out[inst.i_out][1].dim])
# weight_mul = e3nn.o3.experimental.FullTensorProduct_uvu_weight_only(i_in1, i_in2, split_size, reshape_size, filter_ir_out=i_out, irrep_normalization=None, regroup_output=False)
# uvw
# i_out = e3nn.o3.Irreps(tp_list[IR_OUT_IDX])
# tp = e3nn.o3.FullyConnectedTensorProduct(i_in1,i_in2,i_out,shared_weights=False, internal_weights=False)

# # uvu
# Reference e3nn tensor product built from the loaded instruction list, with
# externally supplied, per-sample weights (no shared or internal weights).
tp = e3nn.o3.TensorProduct(i_in1,i_in2,i_out,inst_tuple,shared_weights=False, internal_weights=False)



In [57]:
# Largest angular momentum l among the output irreps.
max(ir.l for _mul, ir in i_out)

1

In [17]:
# bwd_uvu = e3nn.o3.experimental.FullTensorProduct_bwd_uvu(i_in1,i_in2, filter_ir_out=i_out, irrep_normalization="component", regroup_output=False)

In [18]:
# grad_uvu = e3nn.o3.experimental.FullTensorProduct_grad_uvu(i_in1,i_in2, filter_ir_out=i_out, irrep_normalization=None, regroup_output=False)

In [19]:
# full tp -> linear
# i_out = full_tp.irreps_out
# For every instruction of the full (uvuv) tensor product: fetch the Wigner-3j
# block of its (l_in1, l_in2, l_out) triple, collect its distinct values and
# sparsity statistics, and scatter it into a dense (in1.dim, in2.dim, out.dim)
# tensor.
unique_cg = []
unique_cg_mat = {}
nnz_cg_cnt = 0
all_cg_cnt = 0
# Dense coefficient tensor and a 0/1 mask of the regions that were written.
# NOTE(review): neither is read again in the cells shown here.
cg_dummy = torch.zeros(i_in1.dim, i_in2.dim, uvuv_i_out.dim)
cg_dummy_coverage = torch.zeros(i_in1.dim, i_in2.dim, uvuv_i_out.dim)
for inst in uvuv_tp.instructions:
    i = inst.i_in1
    j = inst.i_in2
    k = inst.i_out

    mul_in1, ir_in1 = i_in1[i]
    mul_in2, ir_in2 = i_in2[j]
    mul_out, ir_out = uvuv_i_out[k]

    cg = e3nn.o3.wigner_3j(ir_in1.l, ir_in2.l, ir_out.l)
    unique_cg += list(cg.unique())
    all_cg_cnt+= cg.numel()
    nnz_cg_cnt += cg.count_nonzero()

    partial_mat_cg = torch.zeros(i_in1[i].dim, i_in2[j].dim, uvuv_i_out[k].dim)
    # print(cg)
    # One block per l-triple; a repeated triple overwrites the same key.
    unique_cg_mat[f"{ir_in1.l}_{ir_in2.l}_{ir_out.l}"] = cg
    
    ## uvuv
    # Copy the block once per (u, v) multiplicity pair; the output copy index
    # is u*mul_in2 + v.
    for u,v in itertools.product(range(mul_in1), range(mul_in2)):
        partial_mat_cg [u*ir_in1.dim:(u+1)*ir_in1.dim,
        v*ir_in2.dim:(v+1)*ir_in2.dim,
        (u*mul_in2+v)*ir_out.dim:(u*mul_in2+v+1)*ir_out.dim] = cg 

    cg_dummy[i_in1.slices()[i], i_in2.slices()[j], uvuv_i_out.slices()[k]] = partial_mat_cg
    cg_dummy_coverage[i_in1.slices()[i], i_in2.slices()[j], uvuv_i_out.slices()[k]] = 1
# Fraction of non-zero entries over all Wigner-3j blocks visited above.
print("compute density", nnz_cg_cnt / all_cg_cnt)

compute density tensor(0.2800)


In [20]:
# Number of distinct (l_in1, l_in2, l_out) triples collected above.
len(unique_cg_mat)

5

In [21]:
# De-duplicate the collected coefficient values as Python floats; positions in
# this list are used later as coefficient indices (`cg_idx`).
unique_cg_val = list(set([x.item() for x in unique_cg]))

In [22]:
# Instructions of the reference tensor product, ordered by output irrep index.
tp_inst_outorder = sorted(tp.instructions, key=lambda x : x.i_out)

In [23]:
# already duplicate as num_path > num unique cg matrix
# Flattened list of non-zero coefficients per path (in output-index order):
# `per_path_fiber_array` holds [i, j, k, cg_idx] entries and
# `per_path_fiber_start[p] : per_path_fiber_start[p+1]` delimits path p.
per_path_fiber_start = [0]
per_path_fiber_array = []

for inst in tp_inst_outorder:
    path_cg = unique_cg_mat[f"{i_in1[inst.i_in1][1].l}_{i_in2[inst.i_in2][1].l}_{i_out[inst.i_out][1].l}"]
    for i,j,k in path_cg.nonzero():
        # Position of this coefficient value inside `unique_cg_val`.
        cg_idx = unique_cg_val.index(path_cg[i,j,k])
        per_path_fiber_array.append([i.item(),j.item(),k.item(),cg_idx])
    per_path_fiber_start.append(len(per_path_fiber_array))

In [24]:
# Group the instructions by their first-input irrep index; each entry records
# [output irrep index, second-input irrep index, path weight].
per_in1_ir_pathinfo = {}
for inst in tp_inst_outorder:
    per_in1_ir_pathinfo.setdefault(inst.i_in1, []).append(
        [inst.i_out, inst.i_in2, inst.path_weight]
    )

In [25]:
# Weight layout: instruction n (in `tp.instructions` order) owns a contiguous
# chunk of u*v weights, where u and v are the multiplicities of its two inputs.
weight_uv_pair = []
weight_uv_pair_sorted_chunk = []
out_order = []
current = 0

for inst in tp.instructions:
    weight_uv_pair.append((i_in1[inst.i_in1][0], i_in2[inst.i_in2][0] ))
    out_order.append(inst.i_out)
for u,v in weight_uv_pair:
    weight_uv_pair_sorted_chunk.append(slice(current,current+u*v))
    current+=u*v
# Permutation that orders the instructions by their output irrep index.
out2weight_order = torch.tensor(out_order).argsort()

In [26]:
# Build the lookup tables consumed by the fused kernels.  Every first-input
# irrep is cut into chunks of WARPSIZE multiplicity copies (called "upaths"
# below); for each chunk the tables record where it ends in in1, which paths it
# feeds, the output / in2 ranges of those paths, their non-zero coefficients
# and the position of their weights.
WARPSIZE = 32
in1_idxing = [0]
in1_ival = []
in1_related_path_idx = [0]

path_array1 = []
path_array2 = []
path_weight = []
per_path_weight_pos = []

per_upath_fiber_start = []
per_upath_fiber_array = []

in1_slices = i_in1.slices()
in2_slices = i_in2.slices()
out_slices = i_out.slices()

for in1_ir_idx, (mul,ir) in enumerate(i_in1):
    # Multiplicities must be a whole number of WARPSIZE-sized chunks.
    assert (mul%WARPSIZE ==0)
    in1_idx_start = in1_slices[in1_ir_idx].start
    i_val = ir.dim
    
    if mul >= WARPSIZE:
        for i in range(mul//WARPSIZE):
            # End offset of chunk i inside in1, the irrep dimension of the
            # chunk, and the running count of paths attached to chunks so far.
            in1_idxing.append(in1_idx_start + WARPSIZE*i_val*(i+1))
            in1_ival.append(i_val)
            in1_related_path_idx.append(in1_related_path_idx[-1] + len(per_in1_ir_pathinfo[in1_ir_idx]))
            
            dummy_list = []
            dummy_list2 = []
            # Bug? TODO:
            for out_ir_idx, in2_ir_idx, pw in per_in1_ir_pathinfo[in1_ir_idx]:
                # should be in order
                # NOTE(review): `per_path_fiber_start` was filled in
                # output-sorted instruction order but is indexed here by the
                # output irrep index; presumably valid only when each output
                # irrep has exactly one instruction - confirm.
                fiber_start = per_path_fiber_start[out_ir_idx]
                fiber_end = per_path_fiber_start[out_ir_idx+1]
                
                upath_fiber_start = len(per_upath_fiber_array)
                upath_fiber_end = upath_fiber_start + fiber_end - fiber_start

                per_upath_fiber_start.append([upath_fiber_start, upath_fiber_end])
                # print(fiber_array_orignal[1:4])
                # print(fiber_start,fiber_end)
                # print(fiber_array_orignal[fiber_start:fiber_end])
                
                # The path's coefficient entries are duplicated for every chunk.
                per_upath_fiber_array += per_path_fiber_array[fiber_start:fiber_end]

                # [start, end) of chunk i inside the output irrep's slice.
                dummy_list.append([out_slices[out_ir_idx].start + WARPSIZE*i_out[out_ir_idx].ir.dim * i,
                                   out_slices[out_ir_idx].start + WARPSIZE*i_out[out_ir_idx].ir.dim * (i+1)
                                   ])
                # [output irrep dim, in2 slice start, in2 irrep dim, in2 slice stop]
                dummy_list2.append([i_out[out_ir_idx].ir.dim,
                                    in2_slices[in2_ir_idx].start,
                                    i_in2[in2_ir_idx].ir.dim,
                                    in2_slices[in2_ir_idx].stop])
                path_weight.append(pw)
                
                # TODO:??
                # Start of this path's weight chunk, shifted by WARPSIZE per chunk.
                # NOTE(review): `out2weight_order` is an argsort permutation but
                # is indexed by the output irrep index - confirm this mapping.
                per_path_weight_pos.append(weight_uv_pair_sorted_chunk[out2weight_order[out_ir_idx]].start + WARPSIZE*i)

            path_array1.append(dummy_list)
            path_array2.append(dummy_list2)

In [27]:
# Upload the lookup tables to the GPU with the integer widths used below.
# NOTE(review): the uint16 / uint8 tables can only hold values up to 65535 / 255;
# confirm the offsets and indices stay in range for larger configurations.
t_in1_idxing = torch.tensor(in1_idxing, dtype=torch.int32, device="cuda")
t_in1_ival = torch.tensor(in1_ival, dtype=torch.int32, device="cuda")
t_in1_related_path_idx = torch.tensor(in1_related_path_idx, dtype=torch.int32, device="cuda")

# The per-chunk lists are flattened into one row per (chunk, path) pair.
t_path_array1 = torch.tensor(list(itertools.chain.from_iterable(path_array1)), dtype=torch.uint16, device="cuda")
t_path_array2 = torch.tensor(list(itertools.chain.from_iterable(path_array2)), dtype=torch.uint8, device="cuda")
t_path_weight = torch.tensor(path_weight, dtype=torch.float32, device="cuda")
t_per_path_weight_pos = torch.tensor(per_path_weight_pos, dtype=torch.int32, device="cuda")

t_per_upath_fiber_start = torch.tensor(per_upath_fiber_start, dtype=torch.uint16, device="cuda")
t_per_upath_fiber_array = torch.tensor(per_upath_fiber_array, dtype=torch.uint8, device="cuda")

t_unique_cg_val = torch.tensor(unique_cg_val, dtype=torch.float32, device="cuda")
# Number of WARPSIZE-sized in1 chunks (in1_idxing has one leading 0 entry).
upath_cnt = len(in1_idxing)-1

In [28]:
# Move the e3nn reference tensor product to the GPU.
tp = tp.cuda()
# bwd_uvu = bwd_uvu.cuda()

In [29]:
def zero_grad(grad):
    """Return the incoming gradient multiplied by 0.

    Written as a tensor hook (see the commented-out ``register_hook`` calls
    further down) to suppress the gradient flowing through a tensor.
    """
    return grad * 0  # multiplies by 0 (not 2): the gradient is cancelled, not scaled

In [30]:
# Fixed seed so the random test inputs are reproducible.
torch.manual_seed(0)

# Leaf inputs for the e3nn reference path.  The weights are per-sample and
# all ones (the random variant is kept commented out).
in1 = torch.rand(batch_size, i_in1.dim, device="cuda", requires_grad=True)
in2 = torch.rand(batch_size, i_in2.dim, device="cuda", requires_grad=True)
weight = torch.ones(batch_size,tp.weight_numel, device="cuda", requires_grad=True)
# weight = torch.rand(batch_size,tp.weight_numel, device="cuda", requires_grad=True)

# Independent leaf copies with identical values, used for the fused custom-op
# path so its gradients accumulate separately from the reference ones.
in1_c = in1.clone().detach()
in2_c = in2.clone().detach()
weight_c = weight.clone().detach()
in1_c.requires_grad =True
in2_c.requires_grad =True
weight_c.requires_grad =True

# A second set of independent leaf copies.
# NOTE(review): the *_b tensors are not used in the cells shown here.
in1_b = in1.clone().detach()
in2_b = in2.clone().detach()
weight_b = weight.clone().detach()
in1_b.requires_grad =True
in2_b.requires_grad =True
weight_b.requires_grad =True

In [31]:
# Reference forward pass and first-order gradients through e3nn.
out_exp = tp(in1,in2,weight)
out_exp.retain_grad()
y = torch.nn.functional.gelu(out_exp).sum()
y.retain_grad()
# y.backward()
# create_graph=True keeps the first-order gradients differentiable so a second
# backward pass can run through them.
f_in1, f_in2, f_weight = torch.autograd.grad(y, [in1,in2,weight], create_graph=True)
f_in1.retain_grad()
f_in2.retain_grad()
f_weight.retain_grad()
print(out_exp.grad)
# Snapshot of dy/d(out_exp) as retained on `out_exp`; fed to the kernels later.
dE_dout =  out_exp.grad.detach().clone()

tensor([[0.7704, 0.6867, 0.8563,  ..., 0.5528, 0.2079, 0.5731],
        [0.9822, 0.9533, 0.6543,  ..., 0.5984, 0.8178, 0.3355],
        [0.8418, 0.5885, 0.6154,  ..., 0.3496, 0.6252, 0.6307],
        ...,
        [0.9611, 0.6230, 0.9658,  ..., 0.5955, 0.2656, 0.7359],
        [0.9286, 0.9267, 0.7335,  ..., 0.7273, 0.5112, 0.1011],
        [0.7893, 0.7277, 0.6409,  ..., 0.4744, 0.4633, 0.5597]],
       device='cuda:0', grad_fn=<CloneBackward0>)


In [32]:
# Drop the retained first-order gradient so that the value read from
# `out_exp.grad` after the next backward pass comes from that pass only.
out_exp.grad = None

In [33]:
# Build a scalar that depends on the first-order gradients and back-propagate
# it; this exercises the double-backward of the reference tensor product.
f_in1_gelu = torch.nn.functional.gelu(f_in1)
f_in2_gelu = torch.nn.functional.gelu(f_in2)
f_weight_gelu = torch.nn.functional.gelu(f_weight)

fake_loss = f_in1_gelu.sum() + f_in2_gelu.sum() + f_weight_gelu.sum()
print(fake_loss)
fake_loss.backward()

tensor(4094213., device='cuda:0', grad_fn=<AddBackward0>)


In [34]:
# Gradient of `fake_loss` with respect to the reference output, as retained on
# `out_exp` by the backward pass above.
dL_dout =  out_exp.grad.detach().clone()

In [35]:
# out_exp.register_hook(zero_grad)
# f_weight.register_hook(zero_grad)
# f_in1.register_hook(zero_grad)
# f_in2.register_hook(zero_grad)

In [36]:
# Zero-filled result buffers for the direct call of the first-order backward
# kernel; the in2 buffer has one in2-sized chunk per upath.
mem_dL_dW_2 =  torch.zeros_like(weight_c)
mem_dL_din1_2 =  torch.zeros_like(in1_c)
mem_dL_din2_2 = torch.zeros((batch_size, i_in2.dim * upath_cnt), device="cuda")
mem_debug =  torch.zeros_like(out_exp)

In [37]:
# Call the first-order backward kernel directly with dL_dout (the gradient of
# fake_loss w.r.t. the output), using per_block_batch=1 and l_max*2+1.
sptp_bwd.sptp_linear_bwd_v1_shared(in1, in2, weight, dL_dout, 
                                           mem_dL_din1_2, mem_dL_din2_2,mem_dL_dW_2, mem_debug,
                                           t_in1_idxing, t_in1_ival, t_in1_related_path_idx, 
                                           
                                           t_path_array1,t_path_array2,t_per_upath_fiber_start, t_path_weight,t_per_path_weight_pos, 
                                           t_per_upath_fiber_array,t_unique_cg_val, 
                                           upath_cnt, 1, l_max*2+1
                                           )
# Reduce the per-upath partial in2 gradients to one (batch, in2.dim) tensor.
mem_dl_din2_2_summed = mem_dL_din2_2.reshape((batch_size, upath_cnt, i_in2.dim)).sum(dim=1)

In [38]:
# Zero-filled result buffers for the direct call of the double-backward kernel.
# `mem_debug` is re-created here, replacing the earlier buffer of the same name.
mem_dF_dO = torch.zeros_like(out_exp)
mem_dL_dW =  torch.zeros_like(weight_c)
mem_dL_din1 =  torch.zeros_like(in1_c)
mem_dL_din2 = torch.zeros((batch_size, i_in2.dim * upath_cnt), device="cuda")
mem_debug =  torch.zeros_like(out_exp)

In [39]:
# Call the double-backward kernel directly, feeding it the gradients retained
# on the reference first-order gradients (f_*.grad) and the saved dE_dout.
sptp_bwd.sptp_linear_bwd_bwd_v2_shared(f_in1.grad, f_in2.grad, f_weight.grad, dE_dout, in1, in2, weight, 
                                           mem_dF_dO, mem_dL_dW, mem_dL_din1, mem_dL_din2, mem_debug,
                                           t_in1_idxing, t_in1_ival, t_in1_related_path_idx, 
                                           
                                           t_path_array1,t_path_array2,t_per_upath_fiber_start, t_path_weight,t_per_path_weight_pos, 
                                           t_per_upath_fiber_array,t_unique_cg_val, 
                                           upath_cnt, 1, l_max*2+1
                                           )
# Reduce the per-upath partial in2 results to one (batch, in2.dim) tensor.
mem_dl_din2_summed = mem_dL_din2.reshape((batch_size, upath_cnt, i_in2.dim)).sum(dim=1)

In [40]:
# Reference: gradient accumulated on `in1` by the e3nn path.
in1.grad

tensor([[5.5773, 6.0905, 3.7802,  ..., 3.9337, 3.0472, 3.8047],
        [7.8987, 6.1085, 6.3980,  ..., 1.6678, 3.6508, 4.9597],
        [5.3784, 5.7791, 3.7292,  ..., 3.0424, 1.6832, 2.6684],
        ...,
        [5.8595, 7.0677, 7.4968,  ..., 2.5093, 2.4028, 1.2549],
        [8.5208, 8.2731, 7.2394,  ..., 3.7490, 4.9139, 3.5372],
        [5.6741, 5.3575, 5.9812,  ..., 2.3310, 2.0964, 2.1986]],
       device='cuda:0')

In [41]:
# Sum of the two directly computed in1 contributions (double-backward kernel +
# first-order backward kernel driven by dL_dout), for comparison with in1.grad.
mem_dL_din1 + mem_dL_din1_2

tensor([[5.5773, 6.0905, 3.7802,  ..., 3.9337, 3.0472, 3.8047],
        [7.8987, 6.1085, 6.3980,  ..., 1.6678, 3.6508, 4.9597],
        [5.3784, 5.7791, 3.7292,  ..., 3.0424, 1.6832, 2.6684],
        ...,
        [5.8595, 7.0677, 7.4968,  ..., 2.5093, 2.4028, 1.2549],
        [8.5208, 8.2731, 7.2394,  ..., 3.7490, 4.9139, 3.5372],
        [5.6741, 5.3575, 5.9812,  ..., 2.3310, 2.0964, 2.1986]],
       device='cuda:0')

In [42]:
# Same first-order computation as the reference path, but through the
# registered forward custom op on the independent *_c leaf tensors.  The op
# receives l_max itself and converts it to l_max*2+1 internally.
out_ours = torch.ops.sptp_linear.sptp_linear_fwd_v2_shared(
    in1_c, in2_c, weight_c,
    t_in1_idxing, t_in1_ival, t_in1_related_path_idx,                                           
    t_path_array1,t_path_array2,t_per_upath_fiber_start, t_path_weight,t_per_path_weight_pos, 
    t_per_upath_fiber_array,t_unique_cg_val, 
    upath_cnt, 1, l_max, i_out.dim
)
out_ours.retain_grad()
y_ours = torch.nn.functional.gelu(out_ours).sum()
# y_ours.backward()
# create_graph=True so the registered backward op is itself recorded and can be
# differentiated again.
f_in1_c, f_in2_c, f_weight_c = torch.autograd.grad(y_ours, [in1_c,in2_c,weight_c], create_graph=True)
# print(out_ours.grad)
dE_dout_c = out_ours.grad.detach().clone()
# out_ours.grad = None
# NOTE(review): unlike the reference path, the retained out_ours.grad is not
# reset before the second backward pass - confirm this is intended.
f_in1_c.retain_grad()
f_in2_c.retain_grad() 
f_weight_c.retain_grad()

In [43]:
# Same scalar as `fake_loss`, built from the fused path's first-order gradients.
f_in1_c_gelu = torch.nn.functional.gelu(f_in1_c)
f_in2_c_gelu = torch.nn.functional.gelu(f_in2_c)
f_weight_c_gelu = torch.nn.functional.gelu(f_weight_c)
fake_loss_c = f_in1_c_gelu.sum() + f_in2_c_gelu.sum() + f_weight_c_gelu.sum()

In [None]:
# Second backward pass through the fused custom ops (uses the registered
# backward-of-backward formula).
fake_loss_c.backward()

In [45]:
# Fresh zero-filled in1 result buffer for the next direct kernel call.
mem_dL_din1_c =  torch.zeros_like(in1_c)

In [46]:
# Direct double-backward kernel call driven by the fused path's retained
# gradients (f_*_c.grad, dE_dout_c).
# NOTE(review): the forward inputs passed are the reference leaves
# (in1, in2, weight), not the *_c copies, and mem_dF_dO / mem_dL_dW /
# mem_dL_din2 / mem_debug are reused from the earlier call without being
# zeroed again - confirm both are intended.
sptp_bwd.sptp_linear_bwd_bwd_v2_shared(f_in1_c.grad, f_in2_c.grad, f_weight_c.grad, dE_dout_c, in1, in2, weight, 
                                           mem_dF_dO, mem_dL_dW, mem_dL_din1_c, mem_dL_din2, mem_debug,
                                           t_in1_idxing, t_in1_ival, t_in1_related_path_idx, 
                                           
                                           t_path_array1,t_path_array2,t_per_upath_fiber_start, t_path_weight,t_per_path_weight_pos, 
                                           t_per_upath_fiber_array,t_unique_cg_val, 
                                           upath_cnt, 1, l_max*2+1
                                           )

In [49]:
# Reference gradient on `in1` (e3nn path), shown next to the fused result below.
in1.grad

tensor([[5.5773, 6.0905, 3.7802,  ..., 3.9337, 3.0472, 3.8047],
        [7.8987, 6.1085, 6.3980,  ..., 1.6678, 3.6508, 4.9597],
        [5.3784, 5.7791, 3.7292,  ..., 3.0424, 1.6832, 2.6684],
        ...,
        [5.8595, 7.0677, 7.4968,  ..., 2.5093, 2.4028, 1.2549],
        [8.5208, 8.2731, 7.2394,  ..., 3.7490, 4.9139, 3.5372],
        [5.6741, 5.3575, 5.9812,  ..., 2.3310, 2.0964, 2.1986]],
       device='cuda:0')

In [50]:
# Gradient accumulated on `in1_c` by the fused custom-op path.
in1_c.grad

tensor([[5.5773, 6.0905, 3.7802,  ..., 3.9337, 3.0472, 3.8047],
        [7.8987, 6.1085, 6.3980,  ..., 1.6678, 3.6508, 4.9597],
        [5.3784, 5.7791, 3.7292,  ..., 3.0424, 1.6832, 2.6684],
        ...,
        [5.8595, 7.0677, 7.4968,  ..., 2.5093, 2.4028, 1.2549],
        [8.5208, 8.2731, 7.2394,  ..., 3.7490, 4.9139, 3.5372],
        [5.6741, 5.3575, 5.9812,  ..., 2.3310, 2.0964, 2.1986]],
       device='cuda:0')

In [47]:
# Shape check of the second input: (batch_size, i_in2.dim).
in2.shape

torch.Size([4096, 4])

In [48]:
# Shape check of the dF/dout buffer (allocated with the shape of out_exp).
mem_dF_dO.shape

torch.Size([4096, 1408])

In [49]:
# Dump every input of the direct double-backward kernel call for debugging.
# Tensors that belong to the autograd side are printed as "label value".
for label, value in (
    ("f_in1", f_in1.grad),
    ("f_in2", f_in2.grad),
    ("f_weight", f_weight.grad),
    ("dE_dout", dE_dout),
    ("in1", in1),
    ("in2", in2),
    ("weight", weight),
    ("dL_din1", mem_dL_din1),
):
    print(label, value)


# Lookup tables and integer arguments are printed as a label line followed by
# the value on its own line.
for label, value in (
    ("t_in1_idxing", t_in1_idxing),
    ("t_in1_ival", t_in1_ival),
    ("t_in1_related_path_idx", t_in1_related_path_idx),
    ("t_path_array1", t_path_array1),
    ("t_path_array2", t_path_array2),
    ("t_per_upath_fiber_start", t_per_upath_fiber_start),
    ("t_path_weight", t_path_weight),
    ("t_per_path_weight_pos", t_per_path_weight_pos),
    ("t_per_upath_fiber_array", t_per_upath_fiber_array),
    ("t_unique_cg_val", t_unique_cg_val),
    ("upath_cnt", upath_cnt),
    ("per_block_batch", 1),
    ("lmax*2+1", l_max*2+1),
):
    print(label)
    print(value)


f_in1 tensor([[1.0835, 1.0992, 0.9990,  ..., 0.7717, 0.8860, 0.9795],
        [1.1251, 1.1014, 1.1067,  ..., 0.8032, 0.8488, 1.0750],
        [0.9974, 1.0096, 0.9414,  ..., 0.8907, 0.7220, 0.7496],
        ...,
        [0.9995, 1.0388, 1.0745,  ..., 0.6675, 0.9395, 0.6983],
        [1.1185, 1.1241, 1.1287,  ..., 0.9539, 0.9254, 0.9626],
        [0.9864, 0.9759, 0.9965,  ..., 0.7334, 0.7592, 0.7311]],
       device='cuda:0')
f_in2 tensor([[1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        ...,
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.]], device='cuda:0')
f_weight tensor([[0.6828, 0.7529, 0.5081,  ..., 0.5357, 0.4965, 0.4815],
        [0.7294, 0.5938, 0.6124,  ..., 0.5583, 0.5552, 0.6262],
        [0.6096, 0.6350, 0.5247,  ..., 0.5151, 0.5374, 0.5493],
        ...,
        [0.5171, 0.5307, 0.5492,  ..., 0.4736, 0.5697, 0.5699],
        [1.0690, 1.0005, 0.8217,  ..., 0.5710, 0.5248, 0.5242],
        [0.7259, 0.6936, 0.7597,  ...

In [56]:
# Display mem_dL_din1_2, a buffer handed to sptp_bwd.sptp_linear_bwd_v1_shared
# elsewhere in this notebook (presumably filled in place by that kernel — TODO confirm).
mem_dL_din1_2

tensor([[1.2085, 1.3191, 0.7220,  ..., 0.5664, 0.4250, 0.9886],
        [1.5657, 1.2067, 1.2670,  ..., 0.5222, 0.5092, 1.1629],
        [0.7433, 0.8249, 0.4065,  ..., 0.6783, 0.3948, 0.3597],
        ...,
        [0.7110, 0.9278, 1.1433,  ..., 0.3238, 0.7436, 0.5749],
        [1.6389, 1.7620, 1.7631,  ..., 1.0391, 0.8856, 0.7171],
        [0.8572, 0.7910, 0.9210,  ..., 0.3631, 0.3887, 0.3929]],
       device='cuda:0')

In [55]:
# Display mem_dL_din1; a later cell prints in1.grad next to
# mem_dL_din1 + mem_dL_din1_2 for comparison.
mem_dL_din1

tensor([[2.3939, 2.6076, 1.7698,  ..., 1.3619, 1.2394, 1.4429],
        [2.9043, 2.3094, 2.3960,  ..., 1.1597, 1.3672, 1.8853],
        [2.2201, 2.3311, 1.8035,  ..., 1.3438, 1.0559, 1.1976],
        ...,
        [2.0722, 2.4082, 2.8295,  ..., 1.0296, 1.1911, 0.9216],
        [3.9204, 3.6235, 2.9922,  ..., 1.2741, 1.6628, 1.2556],
        [2.4327, 2.3346, 2.5331,  ..., 1.0610, 1.0325, 1.0366]],
       device='cuda:0')

In [47]:
# First row of out_exp.grad; compare with out_ours.grad[0] in the next cell.
# NOTE(review): presumably "exp" = reference implementation, "ours" = custom kernel — confirm.
out_exp.grad[0]

tensor([ 1.3374,  0.9280,  1.2032,  ...,  0.0454, -0.2542,  0.1906],
       device='cuda:0')

In [48]:
# First row of out_ours.grad; compare with out_exp.grad[0] in the previous cell.
out_ours.grad[0]

tensor([ 1.3374,  0.9280,  1.2032,  ...,  0.0454, -0.2542,  0.1906],
       device='cuda:0')

In [51]:
# Gradient currently stored on in1 (autograd accumulates into .grad across backward() calls).
in1.grad

tensor([[3.6023, 3.9266, 2.4918,  ..., 1.9283, 1.6644, 2.4315],
        [4.4700, 3.5160, 3.6630,  ..., 1.6819, 1.8764, 3.0482],
        [2.9633, 3.1560, 2.2101,  ..., 2.0220, 1.4507, 1.5573],
        ...,
        [2.7832, 3.3361, 3.9729,  ..., 1.3533, 1.9347, 1.4965],
        [5.5593, 5.3854, 4.7553,  ..., 2.3132, 2.5484, 1.9727],
        [3.2900, 3.1256, 3.4540,  ..., 1.4241, 1.4212, 1.4295]],
       device='cuda:0')

In [54]:
# Gradient currently stored on in1_c, the copy of the input that is fed to bwd_uvu
# in the second-order cells of this notebook.
in1_c.grad

tensor([[2.1383, 2.4349, 1.1404,  ..., 2.3862, 1.4359, 2.6144],
        [2.4970, 2.2493, 2.2964,  ..., 2.1748, 3.0055, 3.1140],
        [0.9395, 0.8798, 0.6203,  ..., 1.5471, 0.9014, 1.4099],
        ...,
        [1.8932, 1.6205, 2.1985,  ..., 0.3500, 0.8847, 1.1050],
        [2.5475, 2.3638, 2.3195,  ..., 1.0526, 0.9460, 0.8123],
        [1.1587, 1.1379, 1.8229,  ..., 1.5200, 1.1463, 0.8046]],
       device='cuda:0')

In [52]:
# Display mem_dL_din1_2, a buffer handed to sptp_bwd.sptp_linear_bwd_v1_shared
# elsewhere in this notebook (presumably filled in place by that kernel — TODO confirm).
mem_dL_din1_2

tensor([[1.2085, 1.3191, 0.7220,  ..., 0.5664, 0.4250, 0.9886],
        [1.5657, 1.2067, 1.2670,  ..., 0.5222, 0.5092, 1.1629],
        [0.7433, 0.8249, 0.4065,  ..., 0.6783, 0.3948, 0.3597],
        ...,
        [0.7110, 0.9278, 1.1433,  ..., 0.3238, 0.7436, 0.5749],
        [1.6389, 1.7620, 1.7631,  ..., 1.0391, 0.8856, 0.7171],
        [0.8572, 0.7910, 0.9210,  ..., 0.3631, 0.3887, 0.3929]],
       device='cuda:0')

In [53]:
# Display mem_dL_din1; a later cell prints in1.grad next to
# mem_dL_din1 + mem_dL_din1_2 for comparison.
mem_dL_din1

tensor([[2.3939, 2.6076, 1.7698,  ..., 1.3619, 1.2394, 1.4429],
        [2.9043, 2.3094, 2.3960,  ..., 1.1597, 1.3672, 1.8853],
        [2.2201, 2.3311, 1.8035,  ..., 1.3438, 1.0559, 1.1976],
        ...,
        [2.0722, 2.4082, 2.8295,  ..., 1.0296, 1.1911, 0.9216],
        [3.9204, 3.6235, 2.9922,  ..., 1.2741, 1.6628, 1.2556],
        [2.4327, 2.3346, 2.5331,  ..., 1.0610, 1.0325, 1.0366]],
       device='cuda:0')

In [46]:
# Gradient currently stored on in1_c (re-displayed; this cell was executed out of order).
in1_c.grad

tensor([[2.1383, 2.4349, 1.1404,  ..., 2.3862, 1.4359, 2.6144],
        [2.4970, 2.2493, 2.2964,  ..., 2.1748, 3.0055, 3.1140],
        [0.9395, 0.8798, 0.6203,  ..., 1.5471, 0.9014, 1.4099],
        ...,
        [1.8932, 1.6205, 2.1985,  ..., 0.3500, 0.8847, 1.1050],
        [2.5475, 2.3638, 2.3195,  ..., 1.0526, 0.9460, 0.8123],
        [1.1587, 1.1379, 1.8229,  ..., 1.5200, 1.1463, 0.8046]],
       device='cuda:0')

In [268]:
# Disabled torchviz visualisations of the autograd graph behind f_in1; kept for reference only.
# make_dot(f_in1, params={"in1":in1, "in2":in2, "weight":weight, "f_in1":f_in1}).render("attached", format="png")
# make_dot(f_in1, params={"in1":in1, "in2":in2, "weight":weight, "f_in1":f_in1}, show_attrs=True, show_saved=True)


In [271]:
# Display the mem_dF_dO buffer (where it is filled is outside this part of the notebook).
mem_dF_dO

tensor([[1.7816, 2.0296, 1.0832,  ..., 1.4670, 1.0924, 0.9306],
        [0.9562, 1.3873, 0.8707,  ..., 0.7697, 1.1433, 1.0237],
        [0.1781, 0.6431, 0.9443,  ..., 0.8144, 0.6786, 0.9715],
        ...,
        [1.3204, 0.6390, 1.1996,  ..., 1.6154, 0.5895, 0.7243],
        [0.4803, 1.2466, 1.2268,  ..., 0.1781, 0.3921, 0.4132],
        [1.2405, 0.5305, 0.8399,  ..., 1.2313, 0.6930, 0.6711]],
       device='cuda:0')

In [258]:
# Gradient stored on weight by autograd; compare with mem_dL_dW in the next cell.
weight.grad

tensor([[1.1880, 1.4026, 0.5561,  ..., 1.7543, 1.0620, 1.9509],
        [0.5432, 0.8636, 0.4849,  ..., 1.0886, 0.5023, 1.6023],
        [0.0880, 0.3345, 0.5114,  ..., 0.6527, 0.8226, 1.3222],
        ...,
        [0.8195, 0.3325, 0.7256,  ..., 0.4892, 0.7790, 1.6820],
        [0.2458, 0.7442, 0.7297,  ..., 0.6555, 1.4635, 0.4917],
        [0.7267, 0.2733, 0.4580,  ..., 0.4321, 0.6894, 1.4175]],
       device='cuda:0')

In [260]:
# Display mem_dL_dW for comparison with weight.grad shown in the previous cell.
mem_dL_dW

tensor([[1.1880, 1.4026, 0.5561,  ..., 1.7543, 1.0620, 1.9509],
        [0.5432, 0.8636, 0.4849,  ..., 1.0886, 0.5023, 1.6023],
        [0.0880, 0.3345, 0.5114,  ..., 0.6527, 0.8226, 1.3222],
        ...,
        [0.8195, 0.3325, 0.7256,  ..., 0.4892, 0.7790, 1.6820],
        [0.2458, 0.7442, 0.7297,  ..., 0.6555, 1.4635, 0.4917],
        [0.7267, 0.2733, 0.4580,  ..., 0.4321, 0.6894, 1.4175]],
       device='cuda:0')

In [53]:
# Snapshot the gradient stored on out_exp as an independent tensor cut off
# from the autograd graph, so it can be reused as the upstream gradient.
dL_dout = torch.clone(out_exp.grad.detach())

In [41]:
# Run bwd_uvu on the "_b" copies with dL_dout as upstream gradient; it returns
# one gradient each for in1, in2 and the weight.
in1_b_grad, in2_b_grad, w_b_grad = bwd_uvu(
    in1_b, in2_b, dL_dout, weight_b
)

In [None]:
# Flatten every entry of bwd_uvu.uvuv_result_list to batch_size rows, visiting
# them in bwd_uvu.sort2path order, then join the pieces along the last axis.
bwd_bwd_dF_dOut = torch.cat(
    [
        bwd_uvu.uvuv_result_list[path_idx].reshape(batch_size, -1)
        for path_idx in bwd_uvu.sort2path
    ],
    dim=-1,
)

In [45]:
# Display bwd_bwd_dF_dOut with batch_size rows. The tensor was concatenated from pieces
# already reshaped to batch_size rows, so this reshape presumably changes nothing — TODO confirm.
bwd_bwd_dF_dOut.reshape(batch_size,-1)

tensor([[0.3880, 0.3156, 0.1872,  ..., 0.7488, 0.4440, 0.3124],
        [0.1694, 0.1916, 0.3869,  ..., 0.2643, 0.5337, 0.4475],
        [0.0095, 0.0309, 0.0177,  ..., 0.3095, 0.1770, 0.4630],
        ...,
        [0.3021, 0.5303, 0.0098,  ..., 0.9175, 0.0170, 0.1353],
        [0.0589, 0.0205, 0.1172,  ..., 0.0187, 0.1065, 0.1152],
        [0.1767, 0.8879, 0.3235,  ..., 0.6278, 0.2288, 0.2125]],
       device='cuda:0', grad_fn=<ViewBackward0>)

In [133]:
# Display mem_debug, the buffer passed to sptp_bwd.sptp_linear_bwd_v1_shared after the
# three gradient buffers (presumably a kernel-side debug dump — TODO confirm).
mem_debug

tensor([[0.3880, 0.5023, 0.0242,  ..., 0.4323, 0.2564, 0.1804],
        [0.1694, 0.2809, 0.1469,  ..., 0.1526, 0.3082, 0.2583],
        [0.0095, 0.0575, 0.0885,  ..., 0.1787, 0.1022, 0.2673],
        ...,
        [0.3021, 0.0585, 0.2606,  ..., 0.5297, 0.0098, 0.0781],
        [0.0589, 0.2111, 0.2072,  ..., 0.0108, 0.0615, 0.0665],
        [0.1767, 0.0599, 0.1110,  ..., 0.3625, 0.1321, 0.1227]],
       device='cuda:0')

In [None]:
# Display dL_dout, the detached clone of out_exp.grad used as the upstream gradient.
dL_dout

tensor([[0.5060, 0.7949, 0.5167,  ..., 0.8254, 0.7008, 0.6429],
        [0.6071, 0.6771, 0.5564,  ..., 0.5421, 0.5847, 0.5711],
        [0.5028, 0.5178, 0.5271,  ..., 0.6647, 0.5951, 0.7419],
        ...,
        [0.6042, 0.5466, 0.5999,  ..., 0.7856, 0.5055, 0.5441],
        [0.5322, 0.6620, 0.5791,  ..., 0.5134, 0.5761, 0.5823],
        [0.6344, 0.5033, 0.5818,  ..., 0.7029, 0.5754, 0.5701]],
       device='cuda:0')

In [None]:
# Display the mem_dF_dO buffer (where it is filled is outside this part of the notebook).
mem_dF_dO

tensor([[0.2805, 1.5944, 0.9703,  ..., 1.8918, 1.3451, 1.1090],
        [0.7541, 1.0335, 0.4534,  ..., 0.4600, 0.7867, 0.6821],
        [0.0735, 0.2649, 0.3865,  ..., 1.0998, 0.8784, 1.3561],
        ...,
        [0.6070, 0.6866, 0.5809,  ..., 1.7051, 0.4276, 0.5955],
        [0.3539, 1.1995, 0.6624,  ..., 0.2852, 0.6560, 0.6927],
        [1.1623, 0.0623, 0.7820,  ..., 1.2202, 0.6200, 0.5955]],
       device='cuda:0')

In [38]:
# Run bwd_uvu on the "_b" copies with dL_dout as upstream gradient; it returns
# one gradient each for in1, in2 and the weight.
in1_b_grad, in2_b_grad, w_b_grad = bwd_uvu(
    in1_b, in2_b, dL_dout, weight_b
)

In [40]:
# First-order gradients from bwd_uvu on the "_c" copies, driven by dE_dout.
in1_c_grad, in2_c_grad, w_c_grad = bwd_uvu(
    in1_c, in2_c, dE_dout, weight_c
)

# Push each gradient through GELU so the scalar below depends non-linearly on them.
f_in1_gelu = torch.nn.functional.gelu(in1_c_grad)
f_in2_gelu = torch.nn.functional.gelu(in2_c_grad)
f_weight_gelu = torch.nn.functional.gelu(w_c_grad)

# Scalar stand-in loss: sum of all three activations, added left to right.
fake_loss_f = (
    f_in1_gelu.sum()
    + f_in2_gelu.sum()
    + f_weight_gelu.sum()
)
print(fake_loss_f)
# Differentiate through bwd_uvu itself.
# NOTE(review): .grad accumulates, so re-running this cell adds to existing gradients.
fake_loss_f.backward()

tensor(887212.1250, device='cuda:0', grad_fn=<AddBackward0>)


In [38]:
# Elementwise gap between the gradient stored on in1 and in1_b_grad returned by bwd_uvu.
# NOTE(review): in1.grad accumulates across backward() calls, so a non-zero gap may come
# from accumulation rather than a mismatch — confirm on a fresh kernel.
in1.grad - in1_b_grad

tensor([[1.7423, 1.6745, 1.3295,  ..., 1.8285, 1.3473, 3.0079],
        [1.7320, 1.4259, 1.6523,  ..., 0.8184, 1.3094, 1.6805],
        [1.4283, 1.2062, 1.2244,  ..., 1.2111, 0.9035, 2.0908],
        ...,
        [1.0257, 1.6077, 0.9544,  ..., 1.2068, 1.8411, 1.8960],
        [1.6774, 1.6838, 2.2131,  ..., 1.5985, 1.0611, 1.5779],
        [1.9894, 0.9254, 2.0164,  ..., 0.8554, 1.0441, 1.5951]],
       device='cuda:0', grad_fn=<SubBackward0>)

In [39]:
# Gradient stored on in1_c, one of the inputs passed to bwd_uvu when building the fake loss.
in1_c.grad

tensor([[2.4746, 2.5673, 2.4915,  ..., 2.7547, 1.7052, 4.6607],
        [2.3667, 1.6227, 2.3653,  ..., 0.8481, 2.2329, 2.0880],
        [1.9892, 1.5336, 1.5013,  ..., 1.5270, 1.1403, 2.6657],
        ...,
        [1.2998, 2.4359, 1.2184,  ..., 1.7500, 2.7503, 2.5633],
        [2.2998, 2.1113, 2.9679,  ..., 2.1193, 1.2690, 2.1742],
        [2.4509, 1.1871, 2.7462,  ..., 1.2247, 1.3526, 2.0707]],
       device='cuda:0')

In [33]:
# Invoke the custom extension's backward kernel. The return value is discarded, so
# results are presumably written into the mem_* buffers in place — TODO confirm.
sptp_bwd.sptp_linear_bwd_v1_shared(
    # operands and upstream gradient
    in1, in2, weight, dL_dout,
    # buffers handed to the kernel (three gradient buffers + debug buffer)
    mem_dL_din1_2, mem_dL_din2_2, mem_dL_dW_2, mem_debug,
    # in1 indexing metadata
    t_in1_idxing, t_in1_ival, t_in1_related_path_idx,
    # path / fiber / weight layout metadata
    t_path_array1, t_path_array2, t_per_upath_fiber_start, t_path_weight, t_per_path_weight_pos,
    t_per_upath_fiber_array, t_unique_cg_val,
    # scalars: the debug-print cell labels the last two "per_block_batch" and "lmax*2+1"
    upath_cnt, 1, l_max*2+1,
)

In [36]:
# Print the gradient stored on in1, then the sum of the two gradient buffers,
# one under the other (same text as two separate print() calls).
print(in1.grad, mem_dL_din1 + mem_dL_din1_2, sep="\n")

tensor([[5.2646, 4.7706, 1.1848,  ..., 0.9037, 1.5067, 1.3808],
        [1.3026, 6.7545, 5.2440,  ..., 0.8022, 1.0752, 0.9668],
        [1.6280, 3.3079, 4.2651,  ..., 0.6249, 0.3846, 0.5027],
        ...,
        [5.1874, 1.7463, 1.9636,  ..., 0.6805, 0.9507, 0.5868],
        [1.1540, 2.9928, 1.3317,  ..., 0.5994, 0.5564, 0.8041],
        [2.4731, 2.5021, 2.0758,  ..., 0.9075, 1.2623, 1.0643]],
       device='cuda:0')
tensor([[6.5785, 5.9241, 1.4275,  ..., 1.1915, 1.8857, 1.7943],
        [1.5393, 8.3717, 6.3028,  ..., 0.8520, 1.2310, 1.1216],
        [1.9913, 4.0285, 5.1296,  ..., 0.7297, 0.4196, 0.5860],
        ...,
        [6.2440, 1.9335, 2.3568,  ..., 0.8560, 1.2979, 0.5248],
        [1.3531, 3.7285, 1.6639,  ..., 0.5955, 0.5359, 0.9069],
        [2.9506, 2.9886, 2.4973,  ..., 1.1122, 1.5329, 1.2736]],
       device='cuda:0')


tensor([[1.5027, 3.3319, 5.1228,  ..., 1.7868, 1.8400, 1.8842],
        [2.5770, 3.3867, 1.8557,  ..., 1.1394, 0.8838, 1.1472],
        [3.1812, 1.8583, 3.4825,  ..., 1.4393, 0.9468, 1.0517],
        ...,
        [2.5716, 0.9875, 2.1731,  ..., 1.3933, 1.4764, 1.8670],
        [1.9476, 2.6564, 2.9469,  ..., 1.0605, 0.7719, 1.0328],
        [2.7277, 1.7549, 1.5600,  ..., 2.0296, 1.0919, 1.8511]],
       device='cuda:0')

In [47]:
# Detached copy of the gradient stored on out_exp, re-registered as a leaf that
# tracks gradients so the second-order pass can differentiate with respect to it.
new_out = out_exp.grad.detach().clone().requires_grad_(True)

in1_c_grad, in2_c_grad, w_c_grad = bwd_uvu(
    in1_c, in2_c, new_out, weight_c
)
# Ask autograd to keep .grad on the returned gradients after backward().
in1_c_grad.retain_grad()
in2_c_grad.retain_grad()
w_c_grad.retain_grad()


# GELU on each first-order gradient, then a scalar stand-in loss summed left to right.
f_in1_gelu_c = torch.nn.functional.gelu(in1_c_grad)
f_in2_gelu_c = torch.nn.functional.gelu(in2_c_grad)
f_weight_gelu_c = torch.nn.functional.gelu(w_c_grad)
fake_loss_c = (
    f_in1_gelu_c.sum()
    + f_in2_gelu_c.sum()
    + f_weight_gelu_c.sum()
)


In [48]:
# Backpropagate through fake_loss_c (built in the previous cell from bwd_uvu's outputs).
# NOTE(review): .grad accumulates, so running this cell twice changes the stored gradients.
fake_loss_c.backward()

In [55]:
# Gradient stored on in2_c after the backward pass; compare with mem_dl_din2_summed in the next cell.
in2_c.grad

tensor([[302.5776, 217.1613, 226.8174, 176.6435],
        [453.4338, 234.4316, 198.2770, 158.4990],
        [444.6406, 204.6053, 227.6118, 153.1021],
        ...,
        [335.7768, 134.1655, 155.7681, 230.7632],
        [236.0521, 182.1424, 220.3352, 208.7974],
        [296.4327, 177.2991, 186.4896, 233.1756]], device='cuda:0')

In [56]:
# Display mem_dl_din2_summed for comparison with in2_c.grad in the previous cell
# (where it is computed is outside this part of the notebook).
mem_dl_din2_summed

tensor([[302.5777, 217.1614, 226.8174, 176.6435],
        [453.4338, 234.4316, 198.2770, 158.4990],
        [444.6406, 204.6053, 227.6118, 153.1021],
        ...,
        [335.7768, 134.1655, 155.7681, 230.7632],
        [236.0521, 182.1424, 220.3352, 208.7974],
        [296.4327, 177.2991, 186.4896, 233.1756]], device='cuda:0')

In [45]:
# Display the mem_dL_dW buffer (compared against weight.grad elsewhere in this notebook).
mem_dL_dW

tensor([[2.2671, 2.4855, 0.2930,  ..., 0.1088, 0.9660, 0.0597],
        [0.3870, 1.8454, 3.6596,  ..., 0.2371, 0.0480, 0.4035],
        [0.4839, 1.2464, 1.1771,  ..., 0.3033, 0.1836, 0.0693],
        ...,
        [1.2005, 0.5062, 0.2942,  ..., 0.1568, 0.0677, 0.1113],
        [0.8240, 1.8094, 0.7419,  ..., 0.0719, 0.4216, 0.1555],
        [2.0162, 1.8861, 0.6361,  ..., 0.3202, 0.2585, 0.0748]],
       device='cuda:0')

In [37]:
# Print f_in2 and f_in2_c one under the other for a visual comparison
# (same text as two separate print() calls).
print(f_in2, f_in2_c, sep="\n")

tensor([[1.2376, 0.8852, 0.9344, 0.8920],
        [1.0128, 0.6753, 0.7480, 0.8928],
        [0.5494, 1.3505, 1.2065, 0.7359],
        ...,
        [3.5244, 2.2287, 2.2657, 2.8133],
        [2.8726, 4.8894, 3.0728, 4.4809],
        [2.0120, 4.9605, 1.9776, 1.4573]], device='cuda:0',
       grad_fn=<ViewBackward0>)
tensor([[0.7867, 0.6858, 0.5769, 0.7672],
        [0.4979, 0.4621, 0.5937, 0.8309],
        [0.5413, 0.6093, 0.4245, 0.8698],
        ...,
        [2.1065, 2.5849, 1.6669, 3.0937],
        [1.0955, 2.1882, 1.2061, 2.7810],
        [2.1567, 2.3180, 2.6381, 2.1826]], device='cuda:0',
       grad_fn=<GeneratedBackwardFor_sptp_linear_sptp_linear_bwd_v2_shared_defaultBackward>)


In [46]:
# Gradient currently stored on in2 (autograd accumulates into .grad across backward() calls).
in2.grad

tensor([[1.4324, 4.2339, 3.3840, 1.5389],
        [0.9778, 1.3487, 2.5574, 4.8300],
        [2.3737, 1.1076, 1.0728, 3.2205],
        ...,
        [1.2971, 1.2289, 0.8152, 0.4706],
        [0.4982, 0.8714, 0.5937, 0.6644],
        [1.1654, 0.8855, 1.3845, 0.6539]], device='cuda:0')

In [None]:
# Gradient currently stored on in1_c (this cell has no execution count, so the
# output below may be stale relative to the cells above).
in1_c.grad

tensor([[ 3.9973,  3.6291,  0.4055,  ...,  0.9197,  2.1953,  1.9050],
        [ 0.2527,  3.7130,  2.6742,  ...,  0.6850,  0.9896,  0.7729],
        [ 1.5004,  4.6002,  4.0388,  ...,  0.4224,  0.8066,  0.7461],
        ...,
        [ 5.3038,  2.3369,  2.4174,  ...,  0.5876,  1.8144,  0.8881],
        [ 0.9836,  1.5320,  0.5223,  ...,  0.0271, -0.0212,  0.0568],
        [ 1.9541,  2.1570,  0.9595,  ...,  1.1385,  2.7140,  1.6297]],
       device='cuda:0')