In [3]:
import numpy as np
import tvm 
from tvm import te

def broadcast_add(shape1, shape2, dtype='float32'):
    """Declare a 2-D broadcast addition ``C = A + B`` as a TVM tensor expression.

    Along each axis the two extents must either be equal or one of them must
    be 1; an operand with extent 1 on an axis is always read at index 0 there.

    Parameters
    ----------
    shape1, shape2 : sequence of two ints
        Shapes of the placeholders ``A`` and ``B``.
    dtype : str, optional
        Element type used for both placeholders. Defaults to ``'float32'``.

    Returns
    -------
    (A, B, C)
        The two input placeholders and the compute tensor of shape ``(m, n)``.

    Raises
    ------
    AssertionError
        If either shape is not 2-D, or the shapes cannot be broadcast together.
    """
    assert len(shape1) == 2 and len(shape2) == 2, \
        "broadcast tensors should both be 2-dimensional"
    for dim1, dim2 in zip(shape1, shape2):
        assert dim1 == dim2 or dim1 == 1 or dim2 == 1, \
            "tensor shapes do not fit for broadcasting"
    A = te.placeholder(shape1, dtype=dtype, name='A')
    B = te.placeholder(shape2, dtype=dtype, name='B')
    # After validation, the output extent on an axis is whichever extent is not 1
    # (or the common extent when both agree).
    m = shape1[0] if shape2[0] == 1 else shape2[0]
    n = shape1[1] if shape2[1] == 1 else shape2[1]
    # Pin the index to 0 on every axis where an operand has extent 1, so the
    # single row/column is reused across the whole output axis.
    f = lambda x, y: A[0 if shape1[0] == 1 else x, 0 if shape1[1] == 1 else y] + \
        B[0 if shape2[0] == 1 else x, 0 if shape2[1] == 1 else y]
    C = te.compute((m, n), f, name='C')
    return A, B, C


In [9]:
# Broadcast case 1: a column vector of shape (m, 1) added to a full (m, n) matrix.
m, n = 3, 4
shape1, shape2 = (m, 1), (m, n)
A, B, C = broadcast_add(shape1, shape2)
s = te.create_schedule(C.op)
# To inspect the lowered IR: print(tvm.lower(s, [A, B, C], simple_mode=True))
mod = tvm.build(s, [A, B, C])

In [11]:
def get_bcast_data(shape1, shape2, constructor=None):
    """Create seeded random operands and an output buffer for a 2-D broadcast.

    The global NumPy random state is reseeded with 0 on every call, so the
    generated operands are the same for the same shapes.

    Parameters
    ----------
    shape1, shape2 : sequence of two ints
        Shapes of the two float32 operands drawn from a normal distribution.
    constructor : callable, optional
        If given (and truthy), it is applied to each of the three arrays and
        the converted objects are returned instead of the NumPy arrays.

    Returns
    -------
    tuple
        ``(a, b, c)`` where ``c`` is an uninitialised float32 buffer with the
        broadcast output shape.
    """
    np.random.seed(0)
    lhs = np.random.normal(size=shape1).astype('float32')
    rhs = np.random.normal(size=shape2).astype('float32')
    # Per axis: use shape2's extent unless it is 1, in which case use shape1's.
    rows = shape2[0] if shape2[0] != 1 else shape1[0]
    cols = shape2[1] if shape2[1] != 1 else shape1[1]
    out = np.empty((rows, cols), dtype='float32')
    if not constructor:
        return lhs, rhs, out
    return tuple(constructor(arr) for arr in (lhs, rhs, out))

# Run the compiled module on TVM arrays and compare against NumPy's own broadcast add.
a, b, c = get_bcast_data(shape1, shape2, constructor=tvm.nd.array)
mod(a, b, c)
np.testing.assert_allclose(a.asnumpy() + b.asnumpy(), c.asnumpy(), atol=1e-5)

In [12]:
# Broadcast case 2: both operands are stretched, (m, 1) + (1, n) -> (m, n).
shape3, shape4 = (m, 1), (1, n)
A1, B1, C1 = broadcast_add(shape3, shape4)
# NOTE: `s` and `mod` are rebound here, replacing the schedule and module
# that were built for the first broadcast case.
s = te.create_schedule(C1.op)
# To inspect the lowered IR: print(tvm.lower(s, [A1, B1, C1], simple_mode=True))
mod = tvm.build(s, [A1, B1, C1])
a1, b1, c1 = get_bcast_data(shape3, shape4, constructor=tvm.nd.array)
mod(a1, b1, c1)
np.testing.assert_allclose(a1.asnumpy() + b1.asnumpy(), c1.asnumpy(), atol=1e-5)
print(a1.shape, b1.shape, c1.shape)


(3, 1) (1, 4) (3, 4)
