In [1]:
import numpy as np
# Seed NumPy's legacy global RNG so the np.random.randn draws used for the
# test inputs below are the same on every Restart & Run All.
np.random.seed(2)

In [2]:
class CustomMatmul:
    """Matrix multiply that reduces the inner dimension in fixed-width chunks.

    The inner (reduction) dimension of ``a @ b`` is walked in slices of at
    most ``num_multipliers`` elements. Each slice yields one partial product,
    and the partial products are summed into a float64 accumulator.
    """

    def __init__(self, num_multipliers):
        """Store the chunk width.

        Args:
            num_multipliers: Number of inner-dimension elements consumed per
                partial product. Must be a positive integer.

        Raises:
            ValueError: If ``num_multipliers`` is smaller than 1.
        """
        if num_multipliers < 1:
            raise ValueError(
                f"num_multipliers must be >= 1, got {num_multipliers}"
            )
        self.num_multipliers = num_multipliers

    def __call__(self, a, b):
        """Compute ``a @ b`` by accumulating chunked partial products.

        Args:
            a: 2-D array-like of shape ``(M, K)``.
            b: 2-D array-like of shape ``(K, N)``.

        Returns:
            A float64 ``numpy.ndarray`` of shape ``(M, N)``.

        Raises:
            ValueError: If either operand is not 2-D or the inner dimensions
                do not match.
        """
        a = np.asarray(a)
        b = np.asarray(b)
        if a.ndim != 2 or b.ndim != 2:
            raise ValueError(
                f"expected 2-D operands, got shapes {a.shape} and {b.shape}"
            )

        rows, inner = a.shape
        if b.shape[0] != inner:
            raise ValueError(
                f"inner dimensions do not match: {a.shape} vs {b.shape}"
            )

        o = np.zeros((rows, b.shape[1]))
        width = self.num_multipliers

        # Step through the inner dimension by slicing rather than reshaping:
        # slicing clamps at the array end, so the final chunk may be narrower
        # than `width` when K is not a multiple of it (a reshape-based split
        # cannot represent that remainder).
        for start in range(0, inner, width):
            stop = start + width
            o += np.matmul(a[:, start:stop], b[start:stop])

        return o

# test custom matmul

In [3]:
# Check the chunked implementation (device under test) against NumPy's
# reference matmul on random inputs.
inputs = np.random.randn(16, 256)
weights = np.random.randn(256, 10)
matmul_custom = CustomMatmul(16)

o_ref = np.matmul(inputs, weights)
o_dut = matmul_custom(inputs, weights)

# The chunked sum adds partial products in a different order than np.matmul,
# so the two results can differ by floating-point rounding. Compare with a
# tolerance instead of demanding bit-exact equality.
np.testing.assert_allclose(o_dut, o_ref, rtol=1e-10, atol=1e-10)