In [2]:
import torch

In [53]:
# Broadcast result shape of two shapes (NumPy/PyTorch-style broadcasting rules).
def calculateBroadcastShape(shape1, shape2):
    """Return the broadcast shape of two shapes.

    Shapes are right-aligned and a missing leading dimension is treated as
    size 1. Each aligned pair of sizes must be equal or contain a 1;
    otherwise an AssertionError is raised.

    Args:
        shape1: Sequence of ints, the first operand's shape.
        shape2: Sequence of ints, the second operand's shape.

    Returns:
        A new list of ints of length max(len(shape1), len(shape2)).
    """
    rank = max(len(shape1), len(shape2))
    # Left-pad both shapes with 1s so their dimensions line up position by position.
    padded1 = [1] * (rank - len(shape1)) + list(shape1)
    padded2 = [1] * (rank - len(shape2)) + list(shape2)

    result = []
    for a, b in zip(padded1, padded2):
        assert a == 1 or b == 1 or a == b
        # A size-1 dimension stretches to match the other operand.
        result.append(b if a == 1 else a)
    return result


# Row-major (C-contiguous) element strides for a shape.
def calculateStrides(shape_f):
    """Return row-major (C-contiguous) element strides for a shape.

    The stride of a dimension is the number of flat elements skipped when
    that dimension's index increases by one. It equals the product of all
    sizes to its right, so the last dimension always has stride 1.

    Args:
        shape_f: Sequence of ints, the tensor shape. May be empty (a scalar).

    Returns:
        A new list of ints with the same length as ``shape_f``. An empty shape
        yields an empty list. Previously this case raised IndexError.
    """
    # Start every stride at 1; this also covers the last dimension and the
    # empty-shape case without indexing into an empty list.
    outstrides = [1] * len(shape_f)
    # Walk right-to-left, accumulating the product of the sizes to the right.
    for i in range(len(shape_f) - 1, 0, -1):
        outstrides[i - 1] = outstrides[i] * shape_f[i]
    return outstrides


def threadToIndices(thread, outstrides):
    """Convert a flat (linear) element offset into per-dimension indices.

    For row-major strides such as those produced by ``calculateStrides``,
    this undoes ``indicesToThread``.

    Args:
        thread: Non-negative int, the flat element offset.
        outstrides: Sequence of ints, per-dimension strides.

    Returns:
        A list of ints with one index per stride. An empty ``outstrides``
        yields an empty list. The first index is not reduced modulo anything,
        so an offset past the end gives an out-of-range first index.
    """
    if not outstrides:
        # A scalar has no dimensions, hence no indices.
        return []
    # Integer floor division: the previous int(a / b) went through float
    # division and silently lost precision for offsets above 2**53.
    outArr = [thread // outstrides[0]]
    for st in range(1, len(outstrides)):
        outArr.append((thread % outstrides[st - 1]) // outstrides[st])
    return outArr


def indicesToThread(idx, outstrides):
    """Convert per-dimension indices into a flat (linear) element offset.

    Computes the dot product of ``idx`` and ``outstrides``. Pairs are formed
    with ``zip``, so extra entries in the longer argument are ignored.

    Args:
        idx: Sequence of per-dimension indices.
        outstrides: Sequence of per-dimension strides.

    Returns:
        The sum of ``index * stride`` over the paired entries (0 when empty).
    """
    return sum(index * stride for index, stride in zip(idx, outstrides))


def execute(shape1, shape2, arr_1, arr_2, op=None):
    """Apply an element-wise binary op to two flat arrays with broadcasting.

    ``arr_1`` and ``arr_2`` hold row-major (C-order) data for tensors of shape
    ``shape1`` and ``shape2``. The result has the broadcast shape of the two
    inputs and is returned as a flat row-major list.

    Args:
        shape1: Sequence of ints, shape of the first operand. Not modified.
            The previous version padded this list in place.
        shape2: Sequence of ints, shape of the second operand. Not modified.
        arr_1: Flat indexable data of the first operand. Presumably
            len(arr_1) == prod(shape1); this is not checked here.
        arr_2: Flat indexable data of the second operand (same caveat).
        op: Optional binary callable ``op(a, b)`` applied per element pair.
            Defaults to addition, which preserves the original behaviour.

    Returns:
        A list with prod(broadcast shape) elements.

    Raises:
        AssertionError: If the shapes cannot be broadcast together
            (raised inside ``calculateBroadcastShape``).
    """
    if op is None:
        import operator
        op = operator.add

    f_shape = calculateBroadcastShape(shape1, shape2)
    f_strides = calculateStrides(f_shape)
    rank = len(f_shape)

    outsize = 1
    for size in f_shape:
        outsize *= size

    # Left-pad *copies* of the shapes with 1s up to the broadcast rank.
    # Building new lists, rather than calling insert() on the arguments,
    # leaves the caller's shape objects untouched and also accepts tuples.
    padded1 = [1] * (rank - len(shape1)) + list(shape1)
    padded2 = [1] * (rank - len(shape2)) + list(shape2)

    strides1 = calculateStrides(padded1)
    strides2 = calculateStrides(padded2)

    # A size-1 (broadcast) dimension repeats its single element, so its
    # stride becomes 0 and every output index along it maps to offset 0.
    for i in range(rank):
        if padded1[i] == 1:
            strides1[i] = 0
        if padded2[i] == 1:
            strides2[i] = 0

    out_arr = [0] * outsize
    for f_thread in range(outsize):
        # Decode the flat output offset into per-dimension indices, then map
        # those indices into each input's flat offset via its strides.
        remainder = f_thread
        thread1 = 0
        thread2 = 0
        for i in range(rank):
            this_idx, remainder = divmod(remainder, f_strides[i])
            thread1 += this_idx * strides1[i]
            thread2 += this_idx * strides2[i]
        out_arr[f_thread] = op(arr_1[thread1], arr_2[thread2])

    return out_arr

In [None]:
# A (3, 1) tensor broadcast against a (2,) tensor yields a (3, 2) result,
# printed flat in row-major order: [5, 6, 6, 7, 7, 8].
print(execute(shape1=[3, 1], shape2=[2], arr_1=[1, 2, 3], arr_2=[4, 5]))

In [None]:
import torch

# Reference check: compute the same broadcast sum with PyTorch and compare it
# against the hand-written `execute`.
shape1 = [3, 1]
shape2 = [2]
data1 = [1, 2, 3]
data2 = [4, 5]

tensor1 = torch.tensor(data1).reshape(shape1)
tensor2 = torch.tensor(data2).reshape(shape2)
torch_sum = tensor1 + tensor2

print(tensor1)
print(tensor2)
print(torch_sum)

# Pass copies so the shapes above stay intact even if execute() modifies its
# shape arguments. torch_sum is flattened in row-major order to match
# execute()'s flat output.
assert execute(list(shape1), list(shape2), data1, data2) == torch_sum.flatten().tolist(), \
    "execute() disagrees with torch broadcasting"