## Implement Broadcasting

1. flat index -> multidimensional index
2. set strides of broadcasting dimensions to 0
2. multidimensional index -> flat index

In [145]:
import numpy as np
from functools import reduce
from operator import mul

In [146]:
numbElements = lambda shape: reduce(mul, shape)

In [147]:
def getStrides(shape: tuple) -> tuple:
    #build strides 
    strides = [1]
    for i in reversed(shape[1:]):
       strides.append(strides[-1] * i)
    strides.reverse()
    return tuple(strides)

getStrides((2, 3, 4))

(12, 4, 1)

In [148]:
def multiToFlat(shape: tuple, idxs: tuple, strides: tuple) -> int:
    # check if shape is compatible to array
    assert len(idxs) == len(shape)
    assert len(strides) == len(shape)

    for i in range(len(shape)):
        # A stride of 0 means broadcasting that dimension
        # The index can be larger than the shape's dimension size (which will be 1).
        # Only check bounds for non-broadcasted dimensions:
        if strides[i] != 0:
            assert idxs[i] < shape[i], f"Index {idxs[i]} out of bounds for shape {shape[i]} at dimension {i}"
        assert idxs[i] >= 0, "No negative index allowed"

    return np.dot(strides, idxs).item()

In [149]:
multiToFlat((2, 2, 3), (0, 0, 1), getStrides((2, 2, 3)))

1

In [150]:
a = (1, 2, 3)
b = (2, 3, 0)
np.dot(a, b)

np.int64(8)

In [151]:
def flatToMulti(index: int, shape: tuple) -> tuple:
    assert index < reduce(lambda x, y: x*y, shape), "Index out of bounds"
    assert index >= 0, "No negative index allowed"
    
    strides = getStrides(shape)

    res = []
    for i in strides:
        res.append(index // i)
        index %= i
    return tuple(res)

In [152]:
flatToMulti(0, (2, 2, 3))

(0, 0, 0)

In [153]:
def outputShape(a: tuple, b: tuple) -> tuple:
    smaller, larger = (b, a) if len(a) > len(b) else (a, b)
    smaller = (1,) * abs(len(a)-len(b)) + smaller
    res = []
    for x, y in zip(smaller, larger):
        if x == y or x == 1 or y == 1:
            res += [max(x,y)]
        else:
            raise RuntimeError("shapes are not compatible.")
    return tuple(res)

outputShape((3, 2), (1, 2, 3, 1))

(1, 2, 3, 2)

In [154]:
def AddOperator(a: np.ndarray, b: np.ndarray):
    shapeA = a.shape
    shapeB = b.shape
    a = a.flatten()
    b = b.flatten()

    outShape = outputShape(shapeA, shapeB)
    print("Output shape:", outShape)
    out = np.zeros(outShape).flatten()

    out_ndim = len(outShape)

    padded_shapeA = (1,) * (out_ndim - len(shapeA)) + shapeA
    padded_shapeB = (1,) * (out_ndim - len(shapeB)) + shapeB

    stridesA = getStrides(padded_shapeA) 
    stridesB = getStrides(padded_shapeB)

    bcast_stridesA = tuple((0 if x == 1 else y for x, y in zip(padded_shapeA, stridesA)))
    bcast_stridesB = tuple((0 if x == 1 else y for x, y in zip(padded_shapeB, stridesB)))

    for i in range(numbElements(outShape)):
        multi_idx = flatToMulti(i, outShape)
        idxA = multiToFlat(padded_shapeA, multi_idx, bcast_stridesA)
        idxB = multiToFlat(padded_shapeB, multi_idx, bcast_stridesB)
        out[i] = a[idxA] + b[idxB]

    
    return out.reshape(outShape)
        
a = np.linspace(1, 10, 10).reshape(2, 1, 5)
b= np.linspace(1, 5, 5).reshape(1, 5)
AddOperator(a, b) 

Output shape: (2, 1, 5)


array([[[ 2.,  4.,  6.,  8., 10.]],

       [[ 7.,  9., 11., 13., 15.]]])

In [155]:
# broadcasting matmul
# 1. split in batch_dims and matrix_dims (m, n)
# 2. check matrix multiplication compatibility (k, n) x (n, m) -> (k, m)
# 3. check batch_dims with broadcasting rules
# 4. output shape is (broadcasted_batch_dims, k, m)
# 5. perform matrix multiplication with the broadcasted batch dimensions
a =np.ones((1, 3, 2))
b =np.ones((2, 4, 2, 3))
out = a @ b
out.shape

(2, 4, 3, 3)

In [156]:
def getMatmulOutputShape(shapeA: tuple, shapeB: tuple) -> tuple:
    # 1. split in batch_dims and matrix_dims (m, n)
    assert len(shapeA) >= 2 and len(shapeB) >= 2
    matrix_dims_A = shapeA[-2:]
    matrix_dims_B = shapeB[-2:]
    batch_dims_A = shapeA[:-2]
    batch_dims_B = shapeB[:-2]
    # 2. check matrix multiplication compatibility (k, n) x (n, m) -> (k, m)
    assert matrix_dims_A[1] == matrix_dims_B[0], f"matrix-multiplication shape mismatch: {matrix_dims_A} x {matrix_dims_B}"
    # 3. check batch_dims with broadcasting rules
    batch_dims_brcast = outputShape(batch_dims_A, batch_dims_B)
    # 4. output shape is (broadcasted_batch_dims, k, m)
    out_shape = batch_dims_brcast + (matrix_dims_A[0], matrix_dims_B[1])
    return out_shape

getMatmulOutputShape((1, 3, 2), (2, 4, 2, 3))

(2, 4, 3, 3)

In [157]:
a_shape = (4,5)
b_shape = (5,3)
a = np.ones(a_shape)
b = np.ones(b_shape)
print(a)
print(b)
a @ b

[[1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1.]]
[[1. 1. 1.]
 [1. 1. 1.]
 [1. 1. 1.]
 [1. 1. 1.]
 [1. 1. 1.]]


array([[5., 5., 5.],
       [5., 5., 5.],
       [5., 5., 5.],
       [5., 5., 5.]])

In [158]:
# naive matmul
a_shape = (4,5)
b_shape = (5,3)
a = np.ones(a_shape).flatten()
b = np.ones(b_shape).flatten()
out_shape = getMatmulOutputShape(a_shape, b_shape)
out = np.zeros(numbElements(out_shape)).flatten()

# dot product first row of A with first col of B
for row in range(a_shape[1]):
    a_idx = multiToFlat(a_shape, (0, row), getStrides(a_shape)) 
    b_idx = multiToFlat(b_shape, (row, 0), getStrides(b_shape)) 
    out[0] += a[a_idx] * b[b_idx]

# dot product first row of A with second col of B
for row in range(a_shape[1]):
    a_idx = multiToFlat(a_shape, (0, row), getStrides(a_shape)) 
    b_idx = multiToFlat(b_shape, (row, 1), getStrides(b_shape)) 
    out[1] += a[a_idx] * b[b_idx]

# dot product first row of A with third col of B
for row in range(a_shape[1]):
    a_idx = multiToFlat(a_shape, (0, row), getStrides(a_shape)) 
    b_idx = multiToFlat(b_shape, (row, 2), getStrides(b_shape)) 
    out[2] += a[a_idx] * b[b_idx]

# dot product second row of A with first col of B
for row in range(a_shape[1]):
    a_idx = multiToFlat(a_shape, (1, row), getStrides(a_shape)) 
    b_idx = multiToFlat(b_shape, (row, 0), getStrides(b_shape)) 
    out[3] += a[a_idx] * b[b_idx]

# ...

out

array([5., 5., 5., 5., 0., 0., 0., 0., 0., 0., 0., 0.])

In [207]:
# naive matmul
a = np.ones((4, 3))
b = np.ones((3, 3))

def naiveMatmul(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    a_shape = a.shape
    b_shape = b.shape
    out_shape = getMatmulOutputShape(a_shape, b_shape)
    a = a.flatten()
    b = b.flatten()
    out = np.zeros(numbElements(out_shape))

    for i in range(a_shape[0]):
        for j in range(b_shape[1]):
            for k in range(a_shape[1]):
                a_idx = multiToFlat(a_shape, (i, k), getStrides(a_shape)) 
                b_idx = multiToFlat(b_shape, (k, j), getStrides(b_shape))
                out_idx = multiToFlat(out_shape, (i, j), getStrides(out_shape))
                out[out_idx] += a[a_idx] * b[b_idx]
    
    return out.reshape(out_shape)

naiveMatmul(a, b)

array([[3., 3., 3.],
       [3., 3., 3.],
       [3., 3., 3.],
       [3., 3., 3.]])

In [213]:
def naiveMatmul(a: np.ndarray, b: np.ndarray, out: np.ndarray):
    a_shape = a.shape
    b_shape = b.shape
    out_shape = getMatmulOutputShape(a_shape, b_shape)
    a = a.flatten()
    b = b.flatten()

    for i in range(a_shape[0]):
        for j in range(b_shape[1]):
            for k in range(a_shape[1]):
                a_idx = multiToFlat(a_shape, (i, k), getStrides(a_shape)) 
                b_idx = multiToFlat(b_shape, (k, j), getStrides(b_shape))
                out_idx = multiToFlat(out_shape, (i, j), getStrides(out_shape))
                out[out_idx] += a[a_idx] * b[b_idx]

In [209]:
def getTensor(shape: tuple, idx: tuple, tensor_strides: tuple) -> tuple[int, tuple]: 
   assert len(idx) <= len(shape), f"Too many indices for tensor with shape {shape}"
   assert len(tensor_strides) == len(shape)
   for i in range(len(idx)):
      if tensor_strides[i] != 0:  # Broadcasting is enabled for that dimension
         assert idx[i] < shape[i], f"Index {idx[i]} is out of bounds for axis {i} with size {shape[i]}"
      assert 0 <= idx[i]

   # calculate offset
   idx_padd = idx + (0, ) * (len(shape) - len(idx))
   offset = np.dot(idx_padd, tensor_strides) 
   # new shape
   new_shape = shape[len(idx):]

   return (offset.item(), new_shape)

 
t = np.zeros((3, 2, 2, 3))
getTensor(t.shape, (2,0, 0), getStrides(t.shape))

(24, (3,))

In [245]:
def MatmulOperator(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    out_shape = getMatmulOutputShape(a.shape,  b.shape)
    # naive batch matmul
    out = np.zeros(numbElements(out_shape))
    batch_dims = out_shape[:-2]

    padded_shapeA = (1,) * (len(out_shape) - len(a.shape)) + a.shape
    padded_shapeB = (1,) * (len(out_shape) - len(b.shape)) + b.shape

    stridesA = getStrides(padded_shapeA) 
    stridesB = getStrides(padded_shapeB)

    bcast_stridesA = tuple((0 if x == 1 else y for x, y in zip(padded_shapeA, stridesA)))
    bcast_stridesB = tuple((0 if x == 1 else y for x, y in zip(padded_shapeB, stridesB)))

    a = a.flatten()
    b = b.flatten()

    stridesOut = getStrides(out_shape)
    # flat batch iteration
    for i in range(numbElements(batch_dims)):
        batch_idx = flatToMulti(i, batch_dims)

        matrixA_offset, matrixA_shape = getTensor(padded_shapeA, batch_idx, bcast_stridesA)
        matrixB_offset, matrixB_shape = getTensor(padded_shapeB, batch_idx, bcast_stridesB)

        matA = a[matrixA_offset : matrixA_offset + numbElements(matrixA_shape)]
        matB = b[matrixB_offset : matrixB_offset + numbElements(matrixB_shape)]
        
        out_ref_offset, out_ref_shape = getTensor(out_shape, batch_idx, stridesOut)
        out_ref = out[out_ref_offset : out_ref_offset + numbElements(out_ref_shape)]
        
        naiveMatmul(matA.reshape(matrixA_shape), matB.reshape(matrixB_shape), out_ref) # only reshape to conveniently pass the shape
    
    return out.reshape(out_shape)        

shapeA = (2, 4, 3)
shapeB = (3, 2)
A = np.arange(1, np.prod(shapeA) + 1).reshape(shapeA) 
B = np.arange(1, np.prod(shapeB) + 1).reshape(shapeB) 
MatmulOperator(A, B)

array([[[ 22.,  28.],
        [ 49.,  64.],
        [ 76., 100.],
        [103., 136.]],

       [[130., 172.],
        [157., 208.],
        [184., 244.],
        [211., 280.]]])

In [247]:
import time

shapeA = (1, 10, 200)
shapeB = (5, 1, 200, 100)
A = np.arange(1, np.prod(shapeA) + 1).reshape(shapeA) 
B = np.arange(1, np.prod(shapeB) + 1).reshape(shapeB) 

start = time.perf_counter()
MatmulOperator(A, B)
end = time.perf_counter()
print(f"Own implementation: {end - start:.6f} seconds")

start = time.perf_counter()
res = A @ B
end = time.perf_counter()
print(f"Numpy implementation: {end - start:.6f} seconds")

Own implementation: 7.035489 seconds
Numpy implementation: 0.000707 seconds
