In [51]:
import numpy as np

In [52]:
# Problem size. Matrices are (kernel_size*n + k) square: a leading
# kernel_size*n block plus k trailing rows/columns.
kernel_size, n, k = 4, 1, 2

In [53]:
# Square operands: kernel_size*n leading rows/columns plus k trailing ones.
rows = kernel_size * n + k
cols = rows

In [54]:
# Seeded generator so that Restart & Run All reproduces the same A and B
# (and therefore the same outcome of the checks below).
rng = np.random.default_rng(0)
A = rng.random((rows, cols))
B = rng.random((rows, cols))
# Reference product computed directly; the blocked computation must match it.
Cstar = A @ B

In [55]:
# Zero buffer for the top kernel_size*n rows of the product (named `first4nrows`
# because kernel_size is 4). Sized from kernel_size, like `rows`/`cols`, rather
# than from a hard-coded 4.
first4nrows = np.zeros((kernel_size * n, cols))

In [56]:
# Zero buffer for the bottom k rows of the product (shape k x cols).
lastkrrows = np.zeros(shape=(k, cols))

In [57]:
# Split A into a 2x2 grid of blocks at index kernel_size*n (== 4n):
#
#     A = [[a, b],
#          [c, d]]
#
# Slicing at an explicit index (instead of `-k:`) keeps the blocks correct
# when k == 0, where `A[..., -0:]` would select every column.
split = kernel_size * n
# `a`: upper-left, first 4n rows x first 4n columns
a = A[:split, :split]
# `b`: upper-right, first 4n rows x last k columns
b = A[:split, split:]
# `c`: lower-left, last k rows x first 4n columns
c = A[split:, :split]
# `d`: lower-right, last k rows x last k columns
d = A[split:, split:]

In [58]:
# Split B into the same 2x2 grid of blocks at index kernel_size*n (== 4n):
#
#     B = [[a_prime, b_prime],
#          [c_prime, d_prime]]
#
# Slicing at an explicit index (instead of `-k:`) keeps the blocks correct
# when k == 0, where `B[..., -0:]` would select every column.
split = kernel_size * n
# `a_prime`: upper-left, first 4n rows x first 4n columns
a_prime = B[:split, :split]
# `b_prime`: upper-right, first 4n rows x last k columns
b_prime = B[:split, split:]
# `c_prime`: lower-left, last k rows x first 4n columns
c_prime = B[split:, :split]
# `d_prime`: lower-right, last k rows x last k columns
d_prime = B[split:, split:]

In [59]:
# The top 4n rows of A @ B receive two contributions:
#   alpha = [a] @ [a_prime | b_prime]   (4n x 4n) @ (4n x (4n + k))
#   beta  = [b] @ [c_prime | d_prime]   (4n x k)  @ (k x (4n + k))
alpha = a @ np.concatenate((a_prime, b_prime), axis=1)
beta = b @ np.concatenate((c_prime, d_prime), axis=1)
# Assign rather than `+=` into the zero-initialised buffer, so re-running
# this cell does not add the contributions a second time.
first4nrows = alpha + beta

In [60]:
# The bottom k rows of A @ B also receive two contributions:
#   gamma = [c] @ [a_prime | b_prime]   (k x 4n) @ (4n x (4n + k))
#   delta = [d] @ [c_prime | d_prime]   (k x k)  @ (k x (4n + k))
gamma = c @ np.concatenate((a_prime, b_prime), axis=1)
delta = d @ np.concatenate((c_prime, d_prime), axis=1)
# Assign rather than `+=` into the zero-initialised buffer, so re-running
# this cell does not add the contributions a second time.
lastkrrows = gamma + delta

In [61]:
# Stack the two row bands (top 4n rows, bottom k rows) into the full product.
combined = np.concatenate([first4nrows, lastkrrows])

In [62]:
# Unlike a bare `assert np.allclose(...)`, this check rejects shape mismatches
# (np.allclose would broadcast), still runs under `python -O`, and reports the
# differing elements on failure. rtol/atol reproduce np.allclose's defaults.
np.testing.assert_allclose(Cstar, combined, rtol=1e-5, atol=1e-8)

In [63]:
# alpha itself can be split column-wise into two parts: a @ a_prime (the first
# 4n columns) and a @ b_prime (the last k columns), computed in the next cell.

In [64]:
# Column-wise split of alpha:
#   [a] @ [a_prime | b_prime] == [a @ a_prime | a @ b_prime]
alpha_first4n = a @ a_prime  # 4n x 4n
alpha_lastk = a @ b_prime    # 4n x k
alpha_combined = np.concatenate((alpha_first4n, alpha_lastk), axis=1)
# assert_allclose also checks shapes and reports mismatches;
# rtol/atol reproduce np.allclose's defaults.
np.testing.assert_allclose(alpha, alpha_combined, rtol=1e-5, atol=1e-8)

In [65]:
# ear: (4 x 4) @ (4 x 4n) -- split the right operand column-wise and check the pieces

In [66]:
# NOTE: this rebinds `n` (it was 1 for the block-matmul cells above); re-running
# those cells after this one would pick up n = 2.
n = 2
# Left operand: 4 x 4 holding 0..15 in row-major order.
a_4x4 = np.arange(16).reshape(4, 4)
# Right operand: 4 x 4n holding 0, -1, ..., -(16n - 1) in row-major order.
b_4x4n = -np.arange(4 * 4 * n).reshape(4, 4 * n)

In [67]:
# Display both operands together (rendered as a tuple).
(a_4x4, b_4x4n)

(array([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15]]),
 array([[  0,  -1,  -2,  -3,  -4,  -5,  -6,  -7],
        [ -8,  -9, -10, -11, -12, -13, -14, -15],
        [-16, -17, -18, -19, -20, -21, -22, -23],
        [-24, -25, -26, -27, -28, -29, -30, -31]]))

In [68]:
# Split the 4 x 4n right operand at column 4 (the kernel width). For n = 2 both
# halves are then 4 x 4 tiles, so each partial product is a 4x4 @ 4x4 product.
# Splitting at column `n` is also a valid column split, but its pieces
# (4 x n and 4 x 3n) are not kernel-sized.
# NOTE(review): assumes tile-sized pieces are the intent, given kernel_size = 4.
b_4x4n_1 = b_4x4n[:, :4]
b_4x4n_2 = b_4x4n[:, 4:]

In [69]:
# Reference (4 x 4n) product computed in one shot.
expected = np.matmul(a_4x4, b_4x4n)

In [70]:
# Multiply each column slice separately, then place the partial products side by
# side. NOTE: this rebinds `combined` from the earlier block-matmul check.
partial_products = [a_4x4 @ b_4x4n_1, a_4x4 @ b_4x4n_2]
part1, part2 = partial_products
combined = np.concatenate(partial_products, axis=1)

In [71]:
# Integer matmul is exact, so require exact equality (also checks that the shapes
# match and reports differing elements); survives `python -O`, unlike `assert`.
np.testing.assert_array_equal(combined, expected)

In [72]:
# Reference product a_4x4 @ b_4x4n, shape (4, 4n).
expected

array([[ -112,  -118,  -124,  -130,  -136,  -142,  -148,  -154],
       [ -304,  -326,  -348,  -370,  -392,  -414,  -436,  -458],
       [ -496,  -534,  -572,  -610,  -648,  -686,  -724,  -762],
       [ -688,  -742,  -796,  -850,  -904,  -958, -1012, -1066]])

In [73]:
# field: (4 x 4n) @ (4n x 4n)

In [76]:
n = 2
# Left operand: 4 x 4n holding 0 .. 16n - 1 in row-major order.
a_4x4n = np.arange(4 * 4 * n).reshape(4, 4 * n)
# Right operand: 4n x 4n holding 0, -1, ..., -(16n^2 - 1) in row-major order.
b_4nx4n = -np.arange(4 * 4 * n * n).reshape(4 * n, 4 * n)

In [77]:
# Reference (4 x 4n) product computed in one shot.
expected = np.matmul(a_4x4n, b_4nx4n)

In [78]:
# Reference product a_4x4n @ b_4nx4n, shape (4, 4n).
expected

array([[-1120, -1148, -1176, -1204, -1232, -1260, -1288, -1316],
       [-2912, -3004, -3096, -3188, -3280, -3372, -3464, -3556],
       [-4704, -4860, -5016, -5172, -5328, -5484, -5640, -5796],
       [-6496, -6716, -6936, -7156, -7376, -7596, -7816, -8036]])

In [81]:
# farm: (4n x 4n) @ (4n x 4n)

In [82]:
n = 2
# Left operand: 4n x 4n holding 0 .. 16n^2 - 1 in row-major order.
a_4nx4n = np.arange(4 * n * 4 * n).reshape(4 * n, 4 * n)
# Right operand: the element-wise negation of the left operand (a new array).
b_4nx4n = -a_4nx4n

In [83]:
# Reference (4n x 4n) product computed in one shot.
expected = np.matmul(a_4nx4n, b_4nx4n)

In [84]:
# Reference product a_4nx4n @ b_4nx4n, shape (4n, 4n).
expected

array([[ -1120,  -1148,  -1176,  -1204,  -1232,  -1260,  -1288,  -1316],
       [ -2912,  -3004,  -3096,  -3188,  -3280,  -3372,  -3464,  -3556],
       [ -4704,  -4860,  -5016,  -5172,  -5328,  -5484,  -5640,  -5796],
       [ -6496,  -6716,  -6936,  -7156,  -7376,  -7596,  -7816,  -8036],
       [ -8288,  -8572,  -8856,  -9140,  -9424,  -9708,  -9992, -10276],
       [-10080, -10428, -10776, -11124, -11472, -11820, -12168, -12516],
       [-11872, -12284, -12696, -13108, -13520, -13932, -14344, -14756],
       [-13664, -14140, -14616, -15092, -15568, -16044, -16520, -16996]])