In [103]:
import torch
import e3nn
from torch.utils.cpp_extension import load
import json
import matplotlib.pyplot as plt
import itertools
import os


In [104]:
# os.environ['TORCH_CUDA_ARCH_LIST'] = "8.0"
# sptp = load(name='sptp_linear', sources=['/home2/lsy/fused_e3nn/sptp_linear/sptp_linear.cpp', 
#                                   '/home2/lsy/fused_e3nn/sptp_linear/sptp_linear.cu',
#                                   ], 
#                                   extra_cuda_cflags=["-lineinfo"], verbose=True)


In [105]:
def mul_Irreps(mul, i_in):
    dd = []
    for ori_mul, ir in i_in:
        dd.append((ori_mul*mul, (ir.l, ir.p)))
    return e3nn.o3.Irreps(dd)
def compare(a, b):
    isclose = torch.isclose(a, b)
    diff_pos = torch.argwhere(isclose == False)
    for pos in diff_pos:
        pos_t = [x for x in pos]
        if(abs(a[pos_t] - b[pos_t]) > 1e-7):
            print(pos)
            print(a[pos_t] - b[pos_t] )
            
IR_IN1_IDX = 0
IR_IN2_IDX = 1
IR_OUT_IDX = 2
INST_IDX = 3

def load_nequip_config(h, l_max, layer_idx):
    filename = f"/home2/lsy/fused_e3nn/benchmark_config/4_{h}_{l_max}_p_sc.txt"
    with open(filename, "r") as f:
        f_in = f.read().split("\n")

    per_layer_dict = dict()
    for l_idx, d in enumerate(f_in):
        if(d == "") : continue
        dd = json.loads(d)
        per_layer_dict[l_idx] = dd
    tp_list = per_layer_dict[layer_idx]["tp"]
    i_in1 = e3nn.o3.Irreps(tp_list[IR_IN1_IDX])
    i_in2 = e3nn.o3.Irreps(tp_list[IR_IN2_IDX])
    i_out = e3nn.o3.Irreps(tp_list[IR_OUT_IDX])
    inst_tuple = [tuple(x) for x in tp_list[INST_IDX]]

    return i_in1, i_in2, i_out, inst_tuple


In [106]:
def to_cuda_list(*arg, input_dtype = torch.float32):
    return_list = []
    for item in arg:
        if(type(item) == torch.Tensor):
            return_list.append(item.to(device="cuda"))
        else:
            return_list.append(torch.tensor(item,device="cuda", dtype=input_dtype))
    return return_list

def to_cuda_dict(strname_list, *arg):
    return_dict = {}
    for item,name in zip(arg,strname_list):
        if(type(item) == torch.Tensor):
            return_dict[name] = item.to("cuda")
        else:
            return_dict[name] = torch.tensor(item,device="cuda")
    return return_dict

def cumsum_list(s, np1 = True):
    new_s = []
    current = 0
    for e in s:
        new_s.append(current)
        current += e
    if(np1):
        new_s.append(current)
    return new_s

In [193]:
torch.manual_seed(0)

h = 32
l_max = 5
layer_idx = 3
batch_size = 4096
i_in1, i_in2, i_out, inst_tuple = load_nequip_config(h,l_max,layer_idx)
# i_in2 = mul_Irreps(3, i_in2)

uvuv_tp = e3nn.o3.FullTensorProduct(i_in1,i_in2, filter_ir_out=i_out, path_normalization="none", normalization="none")
uvuv_i_out = uvuv_tp.irreps_out
# split_size = []
# reshape_size = []
# for inst in uvuv_tp.instructions:
#     split_size.append(uvuv_i_out[inst.i_out].dim)
#     reshape_size.append([inst.path_shape[0],inst.path_shape[1],uvuv_i_out[inst.i_out][1].dim])
# weight_mul = e3nn.o3.experimental.FullTensorProduct_uvu_weight_only(i_in1, i_in2, split_size, reshape_size, filter_ir_out=i_out, irrep_normalization=None, regroup_output=False)
# uvw
# i_out = e3nn.o3.Irreps(tp_list[IR_OUT_IDX])
# tp = e3nn.o3.FullyConnectedTensorProduct(i_in1,i_in2,i_out,shared_weights=False, internal_weights=False)

# # uvu
tp = e3nn.o3.TensorProduct(i_in1,i_in2,i_out,inst_tuple,shared_weights=False, internal_weights=False, path_normalization="none", normalization="none") # 



In [194]:
i_in2.dim

36

In [138]:
grad_uvu = e3nn.o3.experimental.FullTensorProduct_grad_uvu(i_in1,i_in2, filter_ir_out=i_out, irrep_normalization=None, regroup_output=False)

In [139]:
in1 = torch.rand(batch_size, i_in1.dim,requires_grad=True)
in2 = torch.rand(batch_size, i_in2.dim,requires_grad=True)
weight = torch.rand(batch_size,tp.weight_numel,requires_grad=True)

correct_out = tp(in1,in2,weight)

In [140]:
out = grad_uvu(in1,in2,weight)

In [141]:
out.retain_grad()
y = out.sum()
y.retain_grad()

In [142]:
y.backward()

In [143]:
y.grad

tensor(1.)

In [144]:
back_w = grad_uvu.used_w_list[0]
dl_dO = grad_uvu.w_result_list[0].grad

In [145]:
# v = 1
# uvuv can easily retrievd by dividing it with path_weight * weight as v is 1

In [146]:
path_idx = 4
dL_dO = grad_uvu.w_result_list[path_idx].grad
W = grad_uvu.used_w_list[path_idx]
O = grad_uvu.w_result_list[path_idx]
uvuv = grad_uvu.cg_result_list[path_idx]
uvuv_gen = (O / W).reshape(uvuv.shape)
cg = grad_uvu.used_cg_list[path_idx]

dL_duvuv = torch.einsum("zuk,zuv -> zuvk", dL_dO, W)
dL_dW = torch.einsum("zuk,zuvk -> zuv", dL_dO, uvuv_gen)
# only place with sparsity
dL_dOuter = torch.einsum("zuvk, ijk -> zuvij", dL_duvuv, cg)
# dL_dA = torch.einsum("zuvij,zvj -> zui", dL_dOuter, in2)
# dL_dB = torch.einsum("zuvij,zui -> zvj", dL_dOuter, in1)

In [147]:
W.shape

torch.Size([4096, 32, 1])

In [148]:
dL_duvuv = torch.einsum("zuk,zuv -> zuvk", dl_dO, back_w)

In [149]:
dL_dw = torch.einsum("zuk,zuvk -> zuv", dl_dO, grad_uvu.cg_result_list[0])

In [150]:
out[0]

tensor([ 0.0056,  0.0788,  0.0625,  ...,  0.0286,  0.0494, -0.0057],
       grad_fn=<SelectBackward0>)

In [151]:
y = torch.rand(correct_out.shape)
loss = (correct_out-y)[0].sum()
loss.backward(retain_graph=True)

In [152]:
# full tp -> linear
# i_out = full_tp.irreps_out
unique_cg = []
nnz_cg_cnt = 0
all_cg_cnt = 0
cg_dummy = torch.zeros(i_in1.dim, i_in2.dim, uvuv_i_out.dim)
cg_dummy_coverage = torch.zeros(i_in1.dim, i_in2.dim, uvuv_i_out.dim)
for inst in uvuv_tp.instructions:
    i = inst.i_in1
    j = inst.i_in2
    k = inst.i_out

    mul_in1, ir_in1 = i_in1[i]
    mul_in2, ir_in2 = i_in2[j]
    mul_out, ir_out = uvuv_i_out[k]

    cg = e3nn.o3.wigner_3j(ir_in1.l, ir_in2.l, ir_out.l)
    unique_cg += list(cg.unique())
    all_cg_cnt+= cg.numel()
    nnz_cg_cnt += cg.count_nonzero()

    partial_mat_cg = torch.zeros(i_in1[i].dim, i_in2[j].dim, uvuv_i_out[k].dim)
    print(cg)
    
    ## uvuv
    for u,v in itertools.product(range(mul_in1), range(mul_in2)):
        partial_mat_cg [u*ir_in1.dim:(u+1)*ir_in1.dim,
        v*ir_in2.dim:(v+1)*ir_in2.dim,
        (u*mul_in2+v)*ir_out.dim:(u*mul_in2+v+1)*ir_out.dim] = cg 

    cg_dummy[i_in1.slices()[i], i_in2.slices()[j], uvuv_i_out.slices()[k]] = partial_mat_cg
    cg_dummy_coverage[i_in1.slices()[i], i_in2.slices()[j], uvuv_i_out.slices()[k]] = 1
print("compute density", nnz_cg_cnt / all_cg_cnt)

tensor([[[1.]]])
tensor([[[0.5774, 0.0000, 0.0000],
         [0.0000, 0.5774, 0.0000],
         [0.0000, 0.0000, 0.5774]]])
tensor([[[0.4472, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.4472, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.4472, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.4472, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.4472]]])
tensor([[[0.3780, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.3780, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.3780, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.3780, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.3780, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.3780, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.3780]]])
tensor([[[1.]]])
tensor([[[0.5774, 0.0000, 0.0000],
         [0.0000, 0.5774, 0.0000],
         [0.0000, 0.0000, 0.5774]]])
tensor([[[0.4472, 0.0000

In [153]:
cg_dummy_kij = cg_dummy.permute(2,0,1)
cg_dummy_kij_flat = cg_dummy_kij.reshape(uvuv_i_out.dim, -1)
cg_kij_pos_tuple = cg_dummy_kij.nonzero(as_tuple=True)
cg_kij_pos_feed = cg_dummy_kij.nonzero().to(dtype=torch.int32).contiguous()
cg_kij_val_feed = cg_dummy_kij[cg_kij_pos_tuple]
per_out_fiber = cg_dummy_kij_flat.count_nonzero(dim=1)
unique_cg_lookup = list(set([x.item() for x in unique_cg]))

uni_cg_cnt = len(unique_cg_lookup)
out_size = uvuv_i_out.dim

In [154]:
in1_occurance = torch.bincount(cg_kij_pos_tuple[1])

In [155]:
in2_occurance = torch.bincount(cg_kij_pos_tuple[2])

In [156]:
per_threadblock_data = {}
i_out_slice = i_out.slices()
for inst in tp.instructions:
    if inst.i_in1 not in per_threadblock_data.keys():
        per_threadblock_data[inst.i_in1] = {}
        per_threadblock_data[inst.i_in1]["out_slice"] = []
        per_threadblock_data[inst.i_in1][""] = 0
    per_threadblock_data[inst.i_in1]["out_slice"].append(i_out_slice[inst.i_out])

In [None]:
cg_dummy_kij

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])

In [179]:
mac2nnz_ij = torch.bincount(cg_dummy_kij_flat.nonzero()[:,1])

In [199]:
mac2nnz_ij.to(float).mean()

tensor(2.3867, dtype=torch.float64)

In [191]:
nnz_cnt_for_each_path = []
for inst in tp.instructions:
    nnz_cnt_for_each_path.append((inst.i_in1, i_in1[inst.i_in1].dim//32, i_in2[inst.i_in2].dim, cg_dummy_kij_flat[i_out_slice[inst.i_out]].count_nonzero()//32))

In [192]:
nnz_cnt_for_each_path

[(0, 1, 1, tensor(1)),
 (0, 1, 3, tensor(3)),
 (0, 1, 5, tensor(5)),
 (0, 1, 7, tensor(7)),
 (1, 1, 1, tensor(1)),
 (1, 1, 3, tensor(3)),
 (1, 1, 5, tensor(5)),
 (1, 1, 7, tensor(7)),
 (2, 3, 1, tensor(3)),
 (2, 3, 3, tensor(3)),
 (2, 3, 3, tensor(6)),
 (2, 3, 3, tensor(11)),
 (2, 3, 5, tensor(11)),
 (2, 3, 5, tensor(16)),
 (2, 3, 5, tensor(21)),
 (2, 3, 7, tensor(21)),
 (2, 3, 7, tensor(26)),
 (3, 5, 1, tensor(5)),
 (3, 5, 3, tensor(11)),
 (3, 5, 3, tensor(16)),
 (3, 5, 3, tensor(21)),
 (3, 5, 5, tensor(5)),
 (3, 5, 5, tensor(16)),
 (3, 5, 5, tensor(25)),
 (3, 5, 5, tensor(28)),
 (3, 5, 7, tensor(21)),
 (3, 5, 7, tensor(28)),
 (3, 5, 7, tensor(41)),
 (4, 7, 1, tensor(7)),
 (4, 7, 3, tensor(21)),
 (4, 7, 3, tensor(26)),
 (4, 7, 5, tensor(21)),
 (4, 7, 5, tensor(28)),
 (4, 7, 5, tensor(41)),
 (4, 7, 7, tensor(7)),
 (4, 7, 7, tensor(26)),
 (4, 7, 7, tensor(41)),
 (4, 7, 7, tensor(42)),
 (5, 3, 1, tensor(3)),
 (5, 3, 3, tensor(3)),
 (5, 3, 3, tensor(6)),
 (5, 3, 3, tensor(11)),
 (5, 3, 5,

In [168]:
in1_0_ij_nnz = torch.cat([cg_dummy_kij_flat[s] for s in per_threadblock_data[5]["out_slice"]], dim=0).nonzero()

In [173]:
len(in1_0_ij_nnz[:,1].unique())

1536

In [170]:
for x in in1_0_ij_nnz:
    print(x[:,1].unique())

IndexError: too many indices for tensor of dimension 1

In [82]:
in2_occurance

tensor([ 576, 1152,  896, 1152,  896, 1152,  896, 1152,  896])

In [None]:
in1_occurance

tensor([ 9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,
         9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,
         9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,
         9,  9,  9,  9,  9,  9,  9,  9,  9,  9, 18, 14, 18, 18, 14, 18, 18, 14,
        18, 18, 14, 18, 18, 14, 18, 18, 14, 18, 18, 14, 18, 18, 14, 18, 18, 14,
        18, 18, 14, 18, 18, 14, 18, 18, 14, 18, 18, 14, 18, 18, 14, 18, 18, 14,
        18, 18, 14, 18, 18, 14, 18, 18, 14, 18, 18, 14, 18, 18, 14, 18, 18, 14,
        18, 18, 14, 18, 18, 14, 18, 18, 14, 18, 18, 14, 18, 18, 14, 18, 18, 14,
        18, 18, 14, 18, 18, 14, 18, 18, 14, 18, 18, 14, 18, 18, 14, 18, 14, 18,
        14, 18, 14, 14, 18, 14, 18, 14, 14, 18, 14, 18, 14, 14, 18, 14, 18, 14,
        14, 18, 14, 18, 14, 14, 18, 14, 18, 14, 14, 18, 14, 18, 14, 14, 18, 14,
        18, 14, 14, 18, 14, 18, 14, 14, 18, 14, 18, 14, 14, 18, 14, 18, 14, 14,
        18, 14, 18, 14, 14, 18, 14, 18, 

In [None]:
per_in1_fiber = cg_dummy_kij_flat.count_nonzero(dim=1)


In [13]:
in1 = torch.rand(batch_size, i_in1.dim)
in2 = torch.rand(batch_size, i_in2.dim)
out = torch.zeros((batch_size,uvuv_i_out.dim))

uni_w3j = torch.tensor(unique_cg_lookup)

if(tp.shared_weights):
    weight = torch.rand(tp.weight_numel)
else:
    # weight = torch.rand(batch_size,tp.weight_numel)
    weight = torch.ones(batch_size,tp.weight_numel)
    # weight[:,928:960] = 2

path_weight = torch.zeros(i_out.dim) 
for inst in tp.instructions:
    k = inst.i_out
    path_weight[i_out.slices()[k]] = inst.path_weight

In [9]:
blk_u_in1_idx = []
blk_u_in2_idx = []

per_blk_u_in1_idx_range = []
per_blk_u_in2_idx_range = []
per_blk_fiber_start_range = []

per_fiber_local_idx = []
per_fiber_global_idx = []

max_u_in1_dim = 0
max_u_in2_dim = 0 
max_fiber_cnt = 0

pattern_len_array = []
stride_mul_array = []
rem_cum_idx_array = []
rem_cumval_array = []
fiber_cnt_array = []

step = 32
i = 0
while(i< uvuv_i_out.dim):
    local_k_idx = cg_dummy_kij[i:i+step].nonzero()[:,0]
    blk_in1_idx = cg_dummy_kij[i:i+step].nonzero()[:,1]
    blk_in2_idx = cg_dummy_kij[i:i+step].nonzero()[:,2]

    u_in1_list = blk_in1_idx.unique().tolist()
    u_in2_list = blk_in2_idx.unique().tolist()

    max_u_in1_dim = max(max_u_in1_dim, len(u_in1_list))
    max_u_in2_dim = max(max_u_in2_dim, len(u_in2_list))
    max_fiber_cnt = max(max_fiber_cnt, len(blk_in1_idx))

    blk_u_in1_idx.append(u_in1_list)
    blk_u_in2_idx.append(u_in2_list)

    per_blk_u_in1_idx_range.append(len(u_in1_list))
    per_blk_u_in2_idx_range.append(len(u_in2_list))

    per_blk_fiber_start_range.append(len(blk_in1_idx))

    local_in1_idx = [u_in1_list.index(x.item()) for x in blk_in1_idx]
    local_in2_idx = [u_in2_list.index(x.item()) for x in blk_in2_idx]

    for l_a,l_b,l_k,g_a,g_b in zip(local_in1_idx,local_in2_idx, local_k_idx,blk_in1_idx,blk_in2_idx):
        real_k = i + l_k
        cg_idx = unique_cg_lookup.index(cg_dummy_kij[real_k,g_a,g_b])

        per_fiber_local_idx.append(l_a)
        per_fiber_local_idx.append(l_b)
        per_fiber_local_idx.append(cg_idx)
        per_fiber_local_idx.append(l_k)

        per_fiber_global_idx.append(g_a)
        per_fiber_global_idx.append(g_b)
        per_fiber_global_idx.append(cg_idx)
        per_fiber_global_idx.append(real_k)


    pattern = per_out_fiber[i:i+step].tolist()
    pattern_length = 0
    if ([pattern[0]]*len(pattern) == pattern):
        pattern_length=1
    elif (list(itertools.chain.from_iterable(([pattern[0:3]]*len(pattern))))[0:len(pattern)] == pattern):
        pattern_length=3
    elif (list(itertools.chain.from_iterable(([pattern[0:5]]*len(pattern))))[0:len(pattern)] == pattern):
        pattern_length=5
    elif (list(itertools.chain.from_iterable(([pattern[0:7]]*len(pattern))))[0:len(pattern)] == pattern):
        pattern_length=7
    elif (list(itertools.chain.from_iterable(([pattern[0:9]]*len(pattern))))[0:len(pattern)] == pattern):
        pattern_length=9
    elif (list(itertools.chain.from_iterable(([pattern[0:11]]*len(pattern))))[0:len(pattern)] == pattern):
        pattern_length=11
    else:
        print("bad")

    unit_pattern = pattern[0:pattern_length]

    pattern_len_array.append(pattern_length)
    stride_mul_array.append(sum(unit_pattern))
    rem_cum_idx_array.append(pattern_length)

    rem_cumval_array.append(cumsum_list(unit_pattern, False))
    fiber_cnt_array.append(unit_pattern)

    i += step

rem_cum_idx_array = cumsum_list(rem_cum_idx_array,False)
per_blk_u_in1_idx_range  = cumsum_list(per_blk_u_in1_idx_range)
per_blk_u_in2_idx_range  = cumsum_list(per_blk_u_in2_idx_range)
per_blk_fiber_start_range = cumsum_list(per_blk_fiber_start_range)

rem_cumval_array = list(itertools.chain.from_iterable(rem_cumval_array))
fiber_cnt_array = list(itertools.chain.from_iterable(fiber_cnt_array))
blk_u_in1_idx = list(itertools.chain.from_iterable(blk_u_in1_idx))
blk_u_in2_idx = list(itertools.chain.from_iterable(blk_u_in2_idx))

per_fiber_local_idx = torch.tensor(per_fiber_local_idx).to(torch.uint8)
per_fiber_global_idx = torch.tensor(per_fiber_global_idx)

In [11]:
max_u_in1_dim

288

In [10]:
per_blk_w_idx = []
for w_idx, i in enumerate(uvuv_i_out):
    for j in range(i.ir.dim):
        per_blk_w_idx.append(w_idx)

In [11]:
out_cuda = out.to(device="cuda")
main1_cuda = to_cuda_list(in1.T.contiguous(),in2.T.contiguous())
main2_cuda = to_cuda_list(uni_w3j,weight.T.contiguous(),path_weight)
meta_cuda = to_cuda_list(
                    per_blk_w_idx,
                    per_blk_u_in1_idx_range,
                    per_blk_u_in2_idx_range,
                    per_blk_fiber_start_range,
                    
                    blk_u_in1_idx,
                    blk_u_in2_idx,

                    pattern_len_array,
                    stride_mul_array,
                    rem_cum_idx_array,
                    rem_cumval_array,
                    fiber_cnt_array, 
                    
                    per_fiber_local_idx,
                    input_dtype=torch.uint16)
size_input = [max_u_in1_dim,max_u_in2_dim,max_fiber_cnt,uni_cg_cnt]
size_input = [int(x) for x in size_input]

In [None]:
sptp.sptp_linear_v1(*main1_cuda, out_cuda, *main2_cuda,*meta_cuda,*size_input)

In [17]:
out_cuda

tensor([[ 0.0628,  0.0963,  0.0538,  ...,  0.0593,  0.0466, -0.0062],
        [ 0.1053,  0.1894,  0.1876,  ..., -0.0967,  0.0249, -0.1779],
        [ 0.1389,  0.0715,  0.3133,  ...,  0.0856, -0.0235,  0.0522],
        ...,
        [ 0.6284,  0.5477,  0.4419,  ...,  0.1437,  0.0976, -0.0147],
        [ 0.1207,  0.3705,  0.0709,  ...,  0.1003, -0.0310,  0.0856],
        [ 0.3683,  0.0145,  0.4971,  ...,  0.0489, -0.2779, -0.0830]],
       device='cuda:0')

In [18]:
correct_out = tp(in1,in2,weight)

In [19]:
correct_out

tensor([[ 0.0628,  0.0963,  0.0538,  ...,  0.0593,  0.0466, -0.0062],
        [ 0.1053,  0.1894,  0.1876,  ..., -0.0967,  0.0249, -0.1779],
        [ 0.1389,  0.0715,  0.3133,  ...,  0.0856, -0.0235,  0.0522],
        ...,
        [ 0.6284,  0.5477,  0.4419,  ...,  0.1437,  0.0976, -0.0147],
        [ 0.1207,  0.3705,  0.0709,  ...,  0.1003, -0.0310,  0.0856],
        [ 0.3683,  0.0145,  0.4971,  ...,  0.0489, -0.2779, -0.0830]])

In [20]:
index = 4033
torch.isclose(correct_out[index], out_cuda.cpu()[index]).all()
for a,b in zip(correct_out[index], out_cuda.cpu()[index]):
    print(torch.isclose(a,b),a,b)

tensor(True) tensor(0.0449) tensor(0.0449)
tensor(True) tensor(0.0436) tensor(0.0436)
tensor(True) tensor(0.0244) tensor(0.0244)
tensor(True) tensor(0.0281) tensor(0.0281)
tensor(True) tensor(0.0289) tensor(0.0289)
tensor(True) tensor(0.0088) tensor(0.0088)
tensor(True) tensor(0.0269) tensor(0.0269)
tensor(True) tensor(0.0476) tensor(0.0476)
tensor(True) tensor(0.0165) tensor(0.0165)
tensor(True) tensor(0.0056) tensor(0.0056)
tensor(True) tensor(0.0185) tensor(0.0185)
tensor(True) tensor(0.0059) tensor(0.0059)
tensor(True) tensor(0.0575) tensor(0.0575)
tensor(True) tensor(0.0567) tensor(0.0567)
tensor(True) tensor(0.0191) tensor(0.0191)
tensor(True) tensor(0.0153) tensor(0.0153)
tensor(True) tensor(0.0493) tensor(0.0493)
tensor(True) tensor(0.0405) tensor(0.0405)
tensor(True) tensor(0.0193) tensor(0.0193)
tensor(True) tensor(0.0102) tensor(0.0102)
tensor(True) tensor(0.0299) tensor(0.0299)
tensor(True) tensor(0.0065) tensor(0.0065)
tensor(True) tensor(0.0468) tensor(0.0468)
tensor(True

In [21]:
a_u_cnt = []
b_u_cnt = []
fiber_cnt = []
i_cnt = []
step = 32
i = 0
fiber_cnt_pattern = {}
while(i< uvuv_i_out.dim):
    u_in1 = cg_dummy_kij[i:i+step].nonzero()[:,1].unique()
    u_in2 = cg_dummy_kij[i:i+step].nonzero()[:,2].unique()


    l_a_u_cnt = (cg_dummy_kij[i:i+step].nonzero()[:,1].unique().numel())
    l_b_u_cnt = (cg_dummy_kij[i:i+step].nonzero()[:,2].unique().numel())
    l_fiber_cnt = (cg_dummy_kij[i:i+step].nonzero().shape[0])

    pattern = per_out_fiber[i:i+step].tolist()

    # print([pattern[0]]*len(pattern))
    

    # print(list(itertools.chain.from_iterable(([pattern[0:3]]*len(pattern))))[0:len(pattern)])
    # print(list(itertools.chain.from_iterable(([pattern[0:5]]*len(pattern))))[0:len(pattern)])
    
    if ([pattern[0]]*len(pattern) == pattern):
        pattern_tuple = (1, pattern[0])
    elif (list(itertools.chain.from_iterable(([pattern[0:3]]*len(pattern))))[0:len(pattern)] == pattern):
        pattern_tuple = (3, pattern[0:3])
    elif (list(itertools.chain.from_iterable(([pattern[0:5]]*len(pattern))))[0:len(pattern)] == pattern):
        pattern_tuple = (5, pattern[0:5])
    elif (list(itertools.chain.from_iterable(([pattern[0:7]]*len(pattern))))[0:len(pattern)] == pattern):
        pattern_tuple = (7, pattern[0:7])
    else:
        print("bad")
    

    if str(pattern_tuple) not in fiber_cnt_pattern:
        fiber_cnt_pattern[str(pattern_tuple)] = []
    fiber_cnt_pattern[str(pattern_tuple)].append(i//32)
    # # plt.scatter(cg_dummy_kij[i:i+step].nonzero()[:,0], cg_dummy_kij[i:i+32].nonzero()[:,1])
    # # plt.show()
    # # plt.close()
    # if(l_a_u_cnt > 60):
    #     l_a_u_cnt = (cg_dummy_kij[i:i+step//2].nonzero()[:,1].unique().numel())
    #     l_b_u_cnt = (cg_dummy_kij[i:i+step//2].nonzero()[:,2].unique().numel())
    #     l_fiber_cnt = (cg_dummy_kij[i:i+step//2].nonzero().shape[0])
    #     if(l_a_u_cnt > 60):
    #         l_a_u_cnt = (cg_dummy_kij[i:i+step//4].nonzero()[:,1].unique().numel())
    #         l_b_u_cnt = (cg_dummy_kij[i:i+step//4].nonzero()[:,2].unique().numel())
    #         l_fiber_cnt = (cg_dummy_kij[i:i+step//4].nonzero().shape[0])

    #         a_u_cnt.append(l_a_u_cnt)
    #         b_u_cnt.append(l_b_u_cnt)
    #         fiber_cnt.append(l_fiber_cnt)
    #         i_cnt.append(step//4)
    #         i += step//4
    #     else:
    #         a_u_cnt.append(l_a_u_cnt)
    #         b_u_cnt.append(l_b_u_cnt)
    #         fiber_cnt.append(l_fiber_cnt)
    #         i_cnt.append(step//2)
    #         i += step//2
    # else:
    a_u_cnt.append(l_a_u_cnt)
    b_u_cnt.append(l_b_u_cnt)
    fiber_cnt.append(l_fiber_cnt)
    i_cnt.append(step)
    i += step

In [22]:
smem_size = []
for a, b, f, i in zip(a_u_cnt, b_u_cnt, fiber_cnt, i_cnt):
    smem_size.append((32 * (a+b+i+i) +f) * 4 )
    print(a,b,f,i)

32 1 32 32
96 3 96 32
160 5 160 32
32 1 32 32
160 5 160 32
96 3 96 32
11 3 32 32
12 3 32 32
11 3 32 32
33 3 64 32
34 3 64 32
33 3 64 32
55 3 117 32
58 3 118 32
55 3 117 32
32 1 32 32
32 1 32 32
32 1 32 32
33 5 117 32
36 5 118 32
33 5 117 32
55 5 170 32
60 5 172 32
55 5 170 32
11 3 32 32
12 3 32 32
11 3 32 32
32 1 32 32
32 1 32 32
32 1 32 32
33 5 117 32
36 5 118 32
33 5 117 32
55 5 170 32
60 5 172 32
55 5 170 32
33 3 64 32
34 3 64 32
33 3 64 32
55 3 117 32
58 3 118 32
55 3 117 32
7 5 32 32
7 5 32 32
8 5 32 32
7 5 32 32
7 5 32 32
21 3 70 32
21 3 71 32
22 3 70 32
21 3 71 32
21 3 70 32
35 3 103 32
35 3 102 32
36 3 102 32
35 3 102 32
35 3 103 32
21 5 103 32
21 5 102 32
24 5 102 32
21 5 102 32
21 5 103 32
32 1 32 32
32 1 32 32
32 1 32 32
32 1 32 32
32 1 32 32
35 5 160 32
35 5 161 32
38 5 158 32
35 5 161 32
35 5 160 32
7 5 32 32
7 5 32 32
8 5 32 32
7 5 32 32
7 5 32 32
21 5 103 32
21 5 102 32
24 5 102 32
21 5 102 32
21 5 103 32
32 1 32 32
32 1 32 32
32 1 32 32
32 1 32 32
32 1 32 32
35 5 160 32

In [23]:
max(smem_size)

29952