In [1]:
import torch
import e3nn
from torch.utils.cpp_extension import load
import json
import matplotlib.pyplot as plt
import itertools
import os


In [2]:
os.environ['TORCH_CUDA_ARCH_LIST'] = "8.0"
# os.environ["MAX_JOBS"] = "16"
sptp = load(name='sptp_linear', sources=['/home2/lsy/mdsim/fused_e3nn/fused_e3nn_kernel/sptp_linear.cpp','/home2/lsy/mdsim/fused_e3nn/fused_e3nn_kernel/sptp_linear.cu'], extra_cuda_cflags=["-lineinfo"], verbose=True)


Using /home2/lsy/.cache/torch_extensions/py311_cu124 as PyTorch extensions root...
Detected CUDA files, patching ldflags
Emitting ninja build file /home2/lsy/.cache/torch_extensions/py311_cu124/sptp_linear/build.ninja...
Building extension module sptp_linear...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)


ninja: no work to do.


Loading extension module sptp_linear...


In [3]:
def mul_Irreps(mul, i_in):
    dd = []
    for ori_mul, ir in i_in:
        dd.append((ori_mul*mul, (ir.l, ir.p)))
    return e3nn.o3.Irreps(dd)
def compare(a, b):
    isclose = torch.isclose(a, b)
    diff_pos = torch.argwhere(isclose == False)
    for pos in diff_pos:
        pos_t = [x for x in pos]
        if(abs(a[pos_t] - b[pos_t]) > 1e-7):
            print(pos)
            print(a[pos_t] - b[pos_t] )
            
IR_IN1_IDX = 0
IR_IN2_IDX = 1
IR_OUT_IDX = 2
INST_IDX = 3

def load_nequip_config(h, l_max, layer_idx):
    filename = f"/home2/lsy/mdsim/nequip/benchmark_config/4_{h}_{l_max}_p_sc.txt"
    with open(filename, "r") as f:
        f_in = f.read().split("\n")

    per_layer_dict = dict()
    for l_idx, d in enumerate(f_in):
        if(d == "") : continue
        dd = json.loads(d)
        per_layer_dict[l_idx] = dd
    tp_list = per_layer_dict[layer_idx]["tp"]
    i_in1 = e3nn.o3.Irreps(tp_list[IR_IN1_IDX])
    i_in2 = e3nn.o3.Irreps(tp_list[IR_IN2_IDX])
    i_out = e3nn.o3.Irreps(tp_list[IR_OUT_IDX])
    inst_tuple = [tuple(x) for x in tp_list[INST_IDX]]

    return i_in1, i_in2, i_out, inst_tuple


In [4]:
def to_cuda_list(*arg, input_dtype = torch.float32):
    return_list = []
    for item in arg:
        if(type(item) == torch.Tensor):
            return_list.append(item.to(device="cuda"))
        else:
            return_list.append(torch.tensor(item,device="cuda", dtype=input_dtype))
    return return_list

def to_cuda_dict(strname_list, *arg):
    return_dict = {}
    for item,name in zip(arg,strname_list):
        if(type(item) == torch.Tensor):
            return_dict[name] = item.to("cuda")
        else:
            return_dict[name] = torch.tensor(item,device="cuda")
    return return_dict

def cumsum_list(s, np1 = True):
    new_s = []
    current = 0
    for e in s:
        new_s.append(current)
        current += e
    if(np1):
        new_s.append(current)
    return new_s

In [5]:
torch.manual_seed(0)

h = 32
l_max = 2
layer_idx = 3
batch_size = 4096
i_in1, i_in2, i_out, inst_tuple = load_nequip_config(h,l_max,layer_idx)
# i_in2 = mul_Irreps(3, i_in2)

uvuv_tp = e3nn.o3.FullTensorProduct(i_in1,i_in2, filter_ir_out=i_out, path_normalization="none", normalization="none")
uvuv_i_out = uvuv_tp.irreps_out

# ex_uvu_weight = e3nn.o3.experimental.FullTensorProductv2(i_in1, i_in2, filter_ir_out=i_out, irrep_normalization=None, regroup_output=False)
split_size = []
reshape_size = []
for inst in uvuv_tp.instructions:
    split_size.append(uvuv_i_out[inst.i_out].dim)
    reshape_size.append([inst.path_shape[0],inst.path_shape[1],uvuv_i_out[inst.i_out][1].dim])
weight_mul = e3nn.o3.experimental.FullTensorProduct_uvu_weight_only(i_in1, i_in2, split_size, reshape_size, filter_ir_out=i_out, irrep_normalization=None, regroup_output=False)

# uvw
# i_out = e3nn.o3.Irreps(tp_list[IR_OUT_IDX])
# tp = e3nn.o3.FullyConnectedTensorProduct(i_in1,i_in2,i_out,shared_weights=False, internal_weights=False)

# # uvu
tp = e3nn.o3.TensorProduct(i_in1,i_in2,i_out,inst_tuple,shared_weights=False, internal_weights=False, path_normalization="none", normalization="none") # 



In [6]:
# full tp -> linear
# i_out = full_tp.irreps_out
unique_cg = []
nnz_cg_cnt = 0
all_cg_cnt = 0
cg_dummy = torch.zeros(i_in1.dim, i_in2.dim, uvuv_i_out.dim)
cg_dummy_coverage = torch.zeros(i_in1.dim, i_in2.dim, uvuv_i_out.dim)
for inst in uvuv_tp.instructions:
    i = inst.i_in1
    j = inst.i_in2
    k = inst.i_out

    mul_in1, ir_in1 = i_in1[i]
    mul_in2, ir_in2 = i_in2[j]
    mul_out, ir_out = uvuv_i_out[k]

    cg = e3nn.o3.wigner_3j(ir_in1.l, ir_in2.l, ir_out.l)
    unique_cg += list(cg.unique())
    all_cg_cnt+= cg.numel()
    nnz_cg_cnt += cg.count_nonzero()

    partial_mat_cg = torch.zeros(i_in1[i].dim, i_in2[j].dim, uvuv_i_out[k].dim)
    
    ## uvuv
    for u,v in itertools.product(range(mul_in1), range(mul_in2)):
        partial_mat_cg [u*ir_in1.dim:(u+1)*ir_in1.dim,
        v*ir_in2.dim:(v+1)*ir_in2.dim,
        (u*mul_in2+v)*ir_out.dim:(u*mul_in2+v+1)*ir_out.dim] = cg 

    cg_dummy[i_in1.slices()[i], i_in2.slices()[j], uvuv_i_out.slices()[k]] = partial_mat_cg
    cg_dummy_coverage[i_in1.slices()[i], i_in2.slices()[j], uvuv_i_out.slices()[k]] = 1
print("compute density", nnz_cg_cnt / all_cg_cnt)

compute density tensor(0.2228)


In [7]:
cg_dummy_kij = cg_dummy.permute(2,0,1)
cg_dummy_kij_flat = cg_dummy_kij.reshape(uvuv_i_out.dim, -1)
# cg_kij_pos_tuple = cg_dummy_kij.nonzero(as_tuple=True)
# cg_kij_pos_feed = cg_dummy_kij.nonzero().to(dtype=torch.int32).contiguous()
# cg_kij_val_feed = cg_dummy_kij[cg_kij_pos_tuple]
per_out_fiber = cg_dummy_kij_flat.count_nonzero(dim=1)
unique_cg_lookup = list(set([x.item() for x in unique_cg]))

uni_cg_cnt = len(unique_cg_lookup)
out_size = uvuv_i_out.dim

In [8]:
in1 = torch.rand(batch_size, i_in1.dim)
in2 = torch.rand(batch_size, i_in2.dim)
out = torch.zeros((batch_size,uvuv_i_out.dim))

uni_w3j = torch.tensor(unique_cg_lookup)

if(tp.shared_weights):
    weight = torch.rand(tp.weight_numel)
else:
    weight = torch.rand(batch_size,tp.weight_numel)
    # weight = torch.ones(batch_size,tp.weight_numel)

    # wrong_weight = torch.ones(batch_size,tp.weight_numel)
    # # weight[0,:] = torch.arange(weight.shape[1])
    # wrong_weight[:,834] = 2.45
    # weight[:,834] = 2.45

    # weight[:,500:532] = 4.45

path_weight = torch.zeros(i_out.dim) 
for inst in tp.instructions:
    k = inst.i_out
    path_weight[i_out.slices()[k]] = inst.path_weight

In [9]:
w_idx = 0
cnt = 0
blk_size = 32
perout_w_idx = []
per_blk_k_dim = []

for mul,ir in i_out:
    for _ in range(mul):
        for i in range(ir.dim):
            cnt+=1
            perout_w_idx.append(w_idx)
            if(cnt%32==0):
                per_blk_k_dim.append(ir.dim)
        w_idx+=1

In [None]:
weight_sptp = []
for inst_idx in weight_mul.inv:
    u,v, slice = weight_mul.weight_uv_pair_sorted_slice[inst_idx]
    weight_sptp.append(weight[:,slice])
weight_sptp = torch.cat(weight_sptp,dim=1)

In [11]:
blk_u_in1_idx = []
blk_u_in2_idx = []

per_blk_w_start_idx = []
per_blk_w_end_idx = []
per_blk_w_precnt = []

per_blk_u_in1_idx_range = []
per_blk_u_in2_idx_range = []
per_blk_fiber_start_range = []

per_fiber_local_idx = []
per_fiber_global_idx = []

max_u_in1_dim = 0
max_u_in2_dim = 0 
max_fiber_cnt = 0

pattern_len_array = []
stride_mul_array = []
rem_cum_idx_array = []
rem_cumval_array = []
fiber_cnt_array = []

step = 32
i = 0
prev_w_idx_cnt = 0
prev_idx = -1

while(i< uvuv_i_out.dim):
    local_k_idx = cg_dummy_kij[i:i+step].nonzero()[:,0]
    blk_in1_idx = cg_dummy_kij[i:i+step].nonzero()[:,1]
    blk_in2_idx = cg_dummy_kij[i:i+step].nonzero()[:,2]

    used_weight_idx = perout_w_idx[i:i+step]

    per_blk_w_start_idx.append(used_weight_idx[0])
    per_blk_w_end_idx.append(used_weight_idx[-1]+1)
    
    if (prev_idx == used_weight_idx[0]):
        per_blk_w_precnt.append(prev_w_idx_cnt)
    else:
        per_blk_w_precnt.append(0)
    prev_idx = used_weight_idx[-1]

    prev_w_idx_cnt=0
    for w_idx_val in used_weight_idx:
        if w_idx_val == used_weight_idx[-1]: prev_w_idx_cnt+=1

    u_in1_list = blk_in1_idx.unique().tolist()
    u_in2_list = blk_in2_idx.unique().tolist()

    max_u_in1_dim = max(max_u_in1_dim, len(u_in1_list))
    max_u_in2_dim = max(max_u_in2_dim, len(u_in2_list))
    max_fiber_cnt = max(max_fiber_cnt, len(blk_in1_idx))

    blk_u_in1_idx.append(u_in1_list)
    blk_u_in2_idx.append(u_in2_list)

    per_blk_u_in1_idx_range.append(len(u_in1_list))
    per_blk_u_in2_idx_range.append(len(u_in2_list))

    per_blk_fiber_start_range.append(len(blk_in1_idx))

    local_in1_idx = [u_in1_list.index(x.item()) for x in blk_in1_idx]
    local_in2_idx = [u_in2_list.index(x.item()) for x in blk_in2_idx]

    for l_a,l_b,l_k,g_a,g_b in zip(local_in1_idx,local_in2_idx, local_k_idx,blk_in1_idx,blk_in2_idx):
        real_k = i + l_k
        cg_idx = unique_cg_lookup.index(cg_dummy_kij[real_k,g_a,g_b])

        per_fiber_local_idx.append(l_a)
        per_fiber_local_idx.append(l_b)
        per_fiber_local_idx.append(cg_idx)
        per_fiber_local_idx.append(l_k)

        per_fiber_global_idx.append(g_a)
        per_fiber_global_idx.append(g_b)
        per_fiber_global_idx.append(cg_idx)
        per_fiber_global_idx.append(real_k)


    pattern = per_out_fiber[i:i+step].tolist()
    pattern_length = 0
    if ([pattern[0]]*len(pattern) == pattern):
        pattern_length=1
    elif (list(itertools.chain.from_iterable(([pattern[0:3]]*len(pattern))))[0:len(pattern)] == pattern):
        pattern_length=3
    elif (list(itertools.chain.from_iterable(([pattern[0:5]]*len(pattern))))[0:len(pattern)] == pattern):
        pattern_length=5
    elif (list(itertools.chain.from_iterable(([pattern[0:7]]*len(pattern))))[0:len(pattern)] == pattern):
        pattern_length=7
    elif (list(itertools.chain.from_iterable(([pattern[0:9]]*len(pattern))))[0:len(pattern)] == pattern):
        pattern_length=9
    elif (list(itertools.chain.from_iterable(([pattern[0:11]]*len(pattern))))[0:len(pattern)] == pattern):
        pattern_length=11
    else:
        print(pattern, i)
        print("bad")

    unit_pattern = pattern[0:pattern_length]

    pattern_len_array.append(pattern_length)
    stride_mul_array.append(sum(unit_pattern))
    rem_cum_idx_array.append(pattern_length)

    rem_cumval_array.append(cumsum_list(unit_pattern, False))
    fiber_cnt_array.append(unit_pattern)

    i += step

rem_cum_idx_array = cumsum_list(rem_cum_idx_array,False)
per_blk_u_in1_idx_range  = cumsum_list(per_blk_u_in1_idx_range)
per_blk_u_in2_idx_range  = cumsum_list(per_blk_u_in2_idx_range)
per_blk_fiber_start_range = cumsum_list(per_blk_fiber_start_range)

rem_cumval_array = list(itertools.chain.from_iterable(rem_cumval_array))
fiber_cnt_array = list(itertools.chain.from_iterable(fiber_cnt_array))
blk_u_in1_idx = list(itertools.chain.from_iterable(blk_u_in1_idx))
blk_u_in2_idx = list(itertools.chain.from_iterable(blk_u_in2_idx))

per_fiber_local_idx = torch.tensor(per_fiber_local_idx).to(torch.uint8)
per_fiber_global_idx = torch.tensor(per_fiber_global_idx)

In [12]:
READ_IDX = 0
WRITE_IDX = 1
MUL_FLOPS_IDX = 2
ADD_FLOPS_IDX = 3
def total_flops(input):
    return input[MUL_FLOPS_IDX] + input[ADD_FLOPS_IDX]
def total_MEM_traffic(input):
    return input[READ_IDX] + input[WRITE_IDX]

def compute_ai(AI_dict_item):
    return total_flops(AI_dict_item)/ total_MEM_traffic(AI_dict_item)

def time_estimate(AI_dict_item, FLOPS=20, MEM_BW=1.5):
    # in u sec
    ai = compute_ai(AI_dict_item)
    time = (total_flops(AI_dict_item)/10**12) / min(FLOPS, MEM_BW*ai)
    return time*(10**6)


In [13]:
def print_ai(batch_size , i_in1, i_in2, i_out, uvuv_tp, FLOPS=20, MEM_BW=1.5):
    # e3nn outer product
    # scale with z
    AI_dict= {"op":[0,0,0,0], "cg":[0,0,0,0], "wmul":[0,0,0,0]}

    z = batch_size
    outer_product_dict = {}

    for inst in uvuv_tp.instructions:
        u = i_in1[inst.i_in1].mul
        v = i_in2[inst.i_in2].mul

        i = i_in1[inst.i_in1][1].dim
        j = i_in2[inst.i_in2][1].dim
        k = i_out[inst.i_out][1].dim

        op_str = f"{inst.i_in1}_{inst.i_in2}"
        if(op_str not in outer_product_dict):
            outer_product_dict[op_str] = 1
            AI_dict["op"][READ_IDX] += (u*i + v*j) * z *4
            AI_dict["op"][WRITE_IDX] += u*i*v*j * z*4
            AI_dict["op"][MUL_FLOPS_IDX] += u*i*v*j * z            

        AI_dict["cg"][READ_IDX] += u*i*v*j * z*4
        AI_dict["cg"][READ_IDX] += i*j*k*4
        AI_dict["cg"][WRITE_IDX] += u*v*k * z*4
        AI_dict["cg"][MUL_FLOPS_IDX] += z*u*v*i*j*k
        AI_dict["cg"][ADD_FLOPS_IDX] += z*u*v*(i*j-1)*k

        AI_dict["wmul"][READ_IDX] += u*v*k * z*4
        AI_dict["wmul"][READ_IDX] += z*u*v*4
        AI_dict["wmul"][WRITE_IDX] += z* u*k*4
        
        AI_dict["wmul"][MUL_FLOPS_IDX] += z*u*v*k
        AI_dict["wmul"][ADD_FLOPS_IDX] += z*u*k*(v-1)

    op_ai = compute_ai(AI_dict["op"])
    cg_ai = compute_ai(AI_dict["cg"])
    w_ai = compute_ai(AI_dict["wmul"])
    return (op_ai, cg_ai, w_ai), (
        time_estimate(AI_dict["op"], FLOPS,MEM_BW), 
        time_estimate(AI_dict["cg"], FLOPS,MEM_BW), 
        time_estimate(AI_dict["wmul"],FLOPS,MEM_BW)), AI_dict


In [14]:
outs = print_ai(batch_size, i_in1, i_in2, i_out, uvuv_tp)

In [15]:
naive_read = [x[READ_IDX]/10**6 for x in outs[2].values()]
naive_write = [x[WRITE_IDX]/10**6 for x in outs[2].values()]

In [16]:
print(naive_read)
print(naive_write)
naive_total = sum(naive_write) + sum(naive_read)

[29.196288, 187.700024, 69.206016]
[84.934656, 53.477376, 53.477376]


In [17]:
ideal_mem_traffic = [0,0,0,0]
real_mem_traffic = [0,0,0,0]
fp32_size = 4

# read in1 in2
ideal_mem_traffic[0] += (i_in1.dim + i_in2.dim) * batch_size * fp32_size
real_mem_traffic[0] += (len(blk_u_in1_idx) + len(blk_u_in2_idx)) * batch_size * fp32_size
# read weight
ideal_mem_traffic[1] += weight.numel() * fp32_size
real_mem_traffic[1] += sum(torch.tensor(per_blk_w_end_idx) - torch.tensor(per_blk_w_start_idx)).item() * batch_size * fp32_size

# metadata time
real_mem_traffic[2] += len(per_fiber_local_idx) * (batch_size /32)
# write traffic
real_mem_traffic[3] += out.numel()*fp32_size
ideal_mem_traffic[3] += out.numel()*fp32_size

In [18]:
print("ideal_mem_traffic ", sum(ideal_mem_traffic)/10**6 , "MB")
print("ideal_mem_traffic ", [x / 10**6 for x in ideal_mem_traffic] , "MB")

print("real_mem_traffic ", sum(real_mem_traffic)/10**6 , "MB")
print("real_mem_traffic ", [x / 10**6 for x in real_mem_traffic] , "MB")

print("naive_mem_traffic ", naive_total , "MB")


ideal_mem_traffic  78.790656 MB
ideal_mem_traffic  [9.58464, 15.72864, 0.0, 53.477376] MB
real_mem_traffic  137.166848 MB
real_mem_traffic  [62.291968, 16.908288, 4.489216, 53.477376] MB
naive_mem_traffic  477.99173600000006 MB


In [19]:
32 K
ideal_mem_traffic  1730.1504 MB
ideal_mem_traffic  [136.31488, 285.212672, 0.0, 1308.622848] MB
real_mem_traffic  3409.18272 MB
real_mem_traffic  [1623.195648, 317.19424, 160.169984, 1308.622848] MB
naive_mem_traffic  13103.033184 MB


SyntaxError: invalid syntax (1980152098.py, line 1)

In [20]:
other_meta = sum([len(x) for x in [per_blk_w_start_idx,
                    per_blk_w_end_idx,
                    per_blk_w_precnt,
                    per_blk_k_dim,

                    per_blk_u_in1_idx_range,
                    per_blk_u_in2_idx_range,
                    per_blk_fiber_start_range,
                    
                    pattern_len_array,
                    stride_mul_array,
                    rem_cum_idx_array,
                    rem_cumval_array,
                    fiber_cnt_array]]) * (batch_size/32) / 10**6


In [21]:
print(i_in1.dim, i_in2.dim)
print(len(blk_u_in1_idx))
print(len(blk_u_in2_idx))

576 9
3444
358


In [22]:
per_fiber_local_idx

tensor([ 0,  0,  1,  ...,  0,  7, 31], dtype=torch.uint8)

In [23]:
print(len(per_fiber_local_idx)) 

35072


In [24]:
# num_ z block 
half_size = 2
(len(per_blk_w_start_idx) + len(per_blk_w_end_idx) + len(per_blk_w_precnt) + len(per_blk_k_dim) + len(per_blk_u_in1_idx_range))* half_size 

# per block


1022

In [25]:
# # Need to fix for u != 32
# per_blk_w_idx = []
# for w_idx, i in enumerate(uvuv_i_out):
#     for j in range(i.ir.dim):
#         # need to add u//32 time
#         # will need seperate w_idx counter 
#         per_blk_w_idx.append(w_idx)

In [26]:
out_cuda = out.to(device="cuda")
main1_cuda = to_cuda_list(in1.T.contiguous(),in2.T.contiguous())
main2_cuda = to_cuda_list(uni_w3j,weight_sptp.T.contiguous(),path_weight)
meta_cuda = to_cuda_list(
                    per_blk_w_start_idx,
                    per_blk_w_end_idx,
                    per_blk_w_precnt,
                    per_blk_k_dim,

                    per_blk_u_in1_idx_range,
                    per_blk_u_in2_idx_range,
                    per_blk_fiber_start_range,
                    
                    blk_u_in1_idx,
                    blk_u_in2_idx,

                    pattern_len_array,
                    stride_mul_array,
                    rem_cum_idx_array,
                    rem_cumval_array,
                    fiber_cnt_array, 
                    
                    per_fiber_local_idx,
                    input_dtype=torch.uint16)
size_input = [max_u_in1_dim,max_u_in2_dim,max_fiber_cnt,uni_cg_cnt]
size_input = [int(x) for x in size_input]

In [27]:
sptp.sptp_linear_v1(*main1_cuda, out_cuda, *main2_cuda,*meta_cuda,*size_input)

In [28]:
out_cuda

tensor([[ 0.0353,  0.0048,  0.0091,  ...,  0.0575,  0.0452, -0.0060],
        [ 0.0990,  0.1063,  0.0881,  ..., -0.0040,  0.0010, -0.0073],
        [ 0.1247,  0.0288,  0.0212,  ...,  0.0492, -0.0135,  0.0300],
        ...,
        [ 0.0707,  0.1817,  0.4029,  ...,  0.1425,  0.0968, -0.0146],
        [ 0.1039,  0.3051,  0.0421,  ...,  0.0953, -0.0294,  0.0814],
        [ 0.2821,  0.0105,  0.1127,  ...,  0.0288, -0.1635, -0.0489]],
       device='cuda:0')

In [29]:
correct_out = tp(in1,in2,weight)

In [30]:
correct_out[0]

tensor([ 0.0353,  0.0048,  0.0091,  ...,  0.0575,  0.0452, -0.0060])

In [57]:
index = 0
torch.isclose(correct_out[index], out_cuda.cpu()[index]).all()
for a,b in zip(correct_out[index], out_cuda.cpu()[index]):
    print(torch.isclose(a,b),a,b)

tensor(True) tensor(0.0028) tensor(0.0028)
tensor(True) tensor(0.0122) tensor(0.0122)
tensor(True) tensor(0.0805) tensor(0.0805)
tensor(True) tensor(0.1163) tensor(0.1163)
tensor(True) tensor(0.1643) tensor(0.1643)
tensor(True) tensor(0.0852) tensor(0.0852)
tensor(True) tensor(0.0435) tensor(0.0435)
tensor(True) tensor(0.1736) tensor(0.1736)
tensor(True) tensor(0.2035) tensor(0.2035)
tensor(True) tensor(0.3153) tensor(0.3153)
tensor(True) tensor(0.0228) tensor(0.0228)
tensor(True) tensor(0.2755) tensor(0.2755)
tensor(True) tensor(0.1566) tensor(0.1566)
tensor(True) tensor(0.0192) tensor(0.0192)
tensor(True) tensor(0.1669) tensor(0.1669)
tensor(True) tensor(0.0688) tensor(0.0688)
tensor(True) tensor(0.0983) tensor(0.0983)
tensor(True) tensor(0.1652) tensor(0.1652)
tensor(True) tensor(0.0401) tensor(0.0401)
tensor(True) tensor(0.0971) tensor(0.0971)
tensor(True) tensor(0.3147) tensor(0.3147)
tensor(True) tensor(0.2743) tensor(0.2743)
tensor(True) tensor(0.0619) tensor(0.0619)
tensor(True

In [21]:
a_u_cnt = []
b_u_cnt = []
fiber_cnt = []
i_cnt = []
step = 32
i = 0
fiber_cnt_pattern = {}
while(i< uvuv_i_out.dim):
    u_in1 = cg_dummy_kij[i:i+step].nonzero()[:,1].unique()
    u_in2 = cg_dummy_kij[i:i+step].nonzero()[:,2].unique()


    l_a_u_cnt = (cg_dummy_kij[i:i+step].nonzero()[:,1].unique().numel())
    l_b_u_cnt = (cg_dummy_kij[i:i+step].nonzero()[:,2].unique().numel())
    l_fiber_cnt = (cg_dummy_kij[i:i+step].nonzero().shape[0])

    pattern = per_out_fiber[i:i+step].tolist()

    # print([pattern[0]]*len(pattern))
    

    # print(list(itertools.chain.from_iterable(([pattern[0:3]]*len(pattern))))[0:len(pattern)])
    # print(list(itertools.chain.from_iterable(([pattern[0:5]]*len(pattern))))[0:len(pattern)])
    
    if ([pattern[0]]*len(pattern) == pattern):
        pattern_tuple = (1, pattern[0])
    elif (list(itertools.chain.from_iterable(([pattern[0:3]]*len(pattern))))[0:len(pattern)] == pattern):
        pattern_tuple = (3, pattern[0:3])
    elif (list(itertools.chain.from_iterable(([pattern[0:5]]*len(pattern))))[0:len(pattern)] == pattern):
        pattern_tuple = (5, pattern[0:5])
    elif (list(itertools.chain.from_iterable(([pattern[0:7]]*len(pattern))))[0:len(pattern)] == pattern):
        pattern_tuple = (7, pattern[0:7])
    else:
        print("bad")
    

    if str(pattern_tuple) not in fiber_cnt_pattern:
        fiber_cnt_pattern[str(pattern_tuple)] = []
    fiber_cnt_pattern[str(pattern_tuple)].append(i//32)
    # # plt.scatter(cg_dummy_kij[i:i+step].nonzero()[:,0], cg_dummy_kij[i:i+32].nonzero()[:,1])
    # # plt.show()
    # # plt.close()
    # if(l_a_u_cnt > 60):
    #     l_a_u_cnt = (cg_dummy_kij[i:i+step//2].nonzero()[:,1].unique().numel())
    #     l_b_u_cnt = (cg_dummy_kij[i:i+step//2].nonzero()[:,2].unique().numel())
    #     l_fiber_cnt = (cg_dummy_kij[i:i+step//2].nonzero().shape[0])
    #     if(l_a_u_cnt > 60):
    #         l_a_u_cnt = (cg_dummy_kij[i:i+step//4].nonzero()[:,1].unique().numel())
    #         l_b_u_cnt = (cg_dummy_kij[i:i+step//4].nonzero()[:,2].unique().numel())
    #         l_fiber_cnt = (cg_dummy_kij[i:i+step//4].nonzero().shape[0])

    #         a_u_cnt.append(l_a_u_cnt)
    #         b_u_cnt.append(l_b_u_cnt)
    #         fiber_cnt.append(l_fiber_cnt)
    #         i_cnt.append(step//4)
    #         i += step//4
    #     else:
    #         a_u_cnt.append(l_a_u_cnt)
    #         b_u_cnt.append(l_b_u_cnt)
    #         fiber_cnt.append(l_fiber_cnt)
    #         i_cnt.append(step//2)
    #         i += step//2
    # else:
    a_u_cnt.append(l_a_u_cnt)
    b_u_cnt.append(l_b_u_cnt)
    fiber_cnt.append(l_fiber_cnt)
    i_cnt.append(step)
    i += step

In [22]:
smem_size = []
for a, b, f, i in zip(a_u_cnt, b_u_cnt, fiber_cnt, i_cnt):
    smem_size.append((32 * (a+b+i+i) +f) * 4 )
    print(a,b,f,i)

32 1 32 32
96 3 96 32
160 5 160 32
32 1 32 32
160 5 160 32
96 3 96 32
11 3 32 32
12 3 32 32
11 3 32 32
33 3 64 32
34 3 64 32
33 3 64 32
55 3 117 32
58 3 118 32
55 3 117 32
32 1 32 32
32 1 32 32
32 1 32 32
33 5 117 32
36 5 118 32
33 5 117 32
55 5 170 32
60 5 172 32
55 5 170 32
11 3 32 32
12 3 32 32
11 3 32 32
32 1 32 32
32 1 32 32
32 1 32 32
33 5 117 32
36 5 118 32
33 5 117 32
55 5 170 32
60 5 172 32
55 5 170 32
33 3 64 32
34 3 64 32
33 3 64 32
55 3 117 32
58 3 118 32
55 3 117 32
7 5 32 32
7 5 32 32
8 5 32 32
7 5 32 32
7 5 32 32
21 3 70 32
21 3 71 32
22 3 70 32
21 3 71 32
21 3 70 32
35 3 103 32
35 3 102 32
36 3 102 32
35 3 102 32
35 3 103 32
21 5 103 32
21 5 102 32
24 5 102 32
21 5 102 32
21 5 103 32
32 1 32 32
32 1 32 32
32 1 32 32
32 1 32 32
32 1 32 32
35 5 160 32
35 5 161 32
38 5 158 32
35 5 161 32
35 5 160 32
7 5 32 32
7 5 32 32
8 5 32 32
7 5 32 32
7 5 32 32
21 5 103 32
21 5 102 32
24 5 102 32
21 5 102 32
21 5 103 32
32 1 32 32
32 1 32 32
32 1 32 32
32 1 32 32
32 1 32 32
35 5 160 32

In [23]:
max(smem_size)

29952