In [33]:
from typing import List
import numpy as np
import math

In [34]:
def transpose(N, X):
    """Transpose an N x N matrix held row-major in the flat sequence ``X``.

    Returns a new flat list (also row-major) whose element at (r, c) is the
    input's element at (c, r). ``X`` itself is not modified.
    """
    size = N ** 2
    result = [0] * size
    for dst_row in range(N):
        for dst_col in range(N):
            # Output position (dst_row, dst_col) is filled from input position (dst_col, dst_row).
            result[dst_row * N + dst_col] = X[dst_col * N + dst_row]
    return result

In [35]:
def panel_panel_dgemm(
    rows: int,
    cols: int,
    panel_a_rows: int,
    panel_a_cols: int,
    panel_b_rows: int,
    panel_b_cols: int,
    start_row_a: int,
    start_col_a: int,
    start_row_b: int,
    start_col_b: int,
    write_row: int,
    write_col: int,
    A: List[float],
    B: List[float],
    C: List[float],
) -> None:
    """
    Accumulate the product of a panel of A and a panel of B into a block of C (C += A_panel @ B_panel).

    All three matrices are stored in row-major order in flat 1D sequences. Row-major means the row
    stride is the column count, so every offset is computed with `cols`:
      - A is viewed as a `rows` x `cols` matrix,
      - B as a `cols` x `cols` matrix (its row count must equal A's column count for A @ B to exist),
      - C as a `rows` x `cols` matrix.
    In the square case (`rows == cols`) every matrix is simply N x N. `rows` is not needed to
    compute flat offsets; it only describes the height of A and C.

    `start_row_a` / `start_col_a` give the top-left corner of the A panel, and `panel_a_rows` /
    `panel_a_cols` give its height and width. `start_row_b`, `start_col_b`, `panel_b_rows`, and
    `panel_b_cols` describe the B panel the same way.

    `write_row` / `write_col` give the top-left corner of the `panel_a_rows` x `panel_b_cols` block
    of C that receives the result. C is updated in place with `+=`, so it must be zeroed (or hold a
    partial sum to add to) before the first call.

    Raises:
        ValueError: if `panel_a_cols != panel_b_rows`, i.e. the panels' inner dimensions differ and
            the product is undefined.
    """
    if panel_a_cols != panel_b_rows:
        raise ValueError(
            f"inner dimensions differ: panel_a_cols={panel_a_cols}, panel_b_rows={panel_b_rows}"
        )
    # Row-major storage: moving down one row skips a full row of `cols` elements.
    stride = cols
    for k in range(panel_a_rows):
        a_row_base = (start_row_a + k) * stride + start_col_a
        c_row_base = (write_row + k) * stride + write_col
        for i in range(panel_a_cols):
            # A[start_row_a + k, start_col_a + i] is invariant across the j loop; load it once.
            a_val = A[a_row_base + i]
            b_row_base = (start_row_b + i) * stride + start_col_b
            for j in range(panel_b_cols):
                C[c_row_base + j] += a_val * B[b_row_base + j]

def panel_panel_dgemm_bt(
    rows: int,
    cols: int,
    panel_a_rows: int,
    panel_a_cols: int,
    panel_b_rows: int,
    panel_b_cols: int,
    start_row_a: int,
    start_col_a: int,
    start_row_b: int,
    start_col_b: int,
    write_row: int,
    write_col: int,
    A: List[float],
    B: List[float],
    C: List[float],
) -> None:
    """
    Accumulate the product of a panel of A and a panel of B into a block of C (C += A_panel @ B_panel),
    reading B from its transposed layout.

    A and C are stored in row-major order in flat 1D sequences. B is supplied *transposed*: the
    buffer holds B^T in row-major order (equivalently, B in column-major order), so logical element
    B[r, c] is read from `B[c * cols + r]`. With this layout the innermost loop walks both A and B^T
    contiguously.

    Every offset uses `cols` as the row stride:
      - A is viewed as a `rows` x `cols` matrix,
      - logical B as `cols` x `cols` (so B^T has the same shape),
      - C as `rows` x `cols`.
    In the square case (`rows == cols`) every matrix is simply N x N. `rows` is not needed to
    compute flat offsets; it only describes the height of A and C.

    `start_row_a` / `start_col_a` give the top-left corner of the A panel, and `panel_a_rows` /
    `panel_a_cols` give its height and width. `start_row_b`, `start_col_b`, `panel_b_rows`, and
    `panel_b_cols` describe the B panel the same way, in terms of the logical (untransposed) B.

    `write_row` / `write_col` give the top-left corner of the `panel_a_rows` x `panel_b_cols` block
    of C that receives the result. C is updated in place with `+=`, so it must be zeroed (or hold a
    partial sum to add to) before the first call.

    Raises:
        ValueError: if `panel_a_cols != panel_b_rows`, i.e. the panels' inner dimensions differ and
            the product is undefined.
    """
    if panel_a_cols != panel_b_rows:
        raise ValueError(
            f"inner dimensions differ: panel_a_cols={panel_a_cols}, panel_b_rows={panel_b_rows}"
        )
    # Row-major storage: moving down one row skips a full row of `cols` elements.
    stride = cols
    for k in range(panel_a_rows):
        a_row_base = (start_row_a + k) * stride + start_col_a
        c_row_base = (write_row + k) * stride + write_col
        for j in range(panel_b_cols):
            # Column (start_col_b + j) of B is row (start_col_b + j) of the stored B^T.
            bt_row_base = (start_col_b + j) * stride + start_row_b
            c_index = c_row_base + j
            for i in range(panel_a_cols):
                C[c_index] += A[a_row_base + i] * B[bt_row_base + i]

In [36]:
# Small square demo problem: A holds 0 .. rows*cols-1 as float32, B is its elementwise negation.
rows, cols = 4, 4
A = np.arange(rows * cols, dtype=np.float32)
B = -A

In [37]:
# View the flat demo data as rows x cols matrices.
A_matrix, B_matrix = A.reshape(rows, cols), B.reshape(rows, cols)

In [38]:
# Row strip of A covering rows 0-1, and column strip of B covering columns 0-1.
A_panel_1 = A_matrix[:2]
B_panel_1 = B_matrix[:, :2]

In [39]:
# Row strip of A covering rows 2-3, and column strip of B covering columns 2-3.
A_panel_2, B_panel_2 = A_matrix[2:4], B_matrix[:, 2:4]

In [40]:
# Display the four panels: A rows 0-1, B columns 0-1, A rows 2-3, B columns 2-3.
A_panel_1, B_panel_1, A_panel_2, B_panel_2

(array([[0., 1., 2., 3.],
        [4., 5., 6., 7.]], dtype=float32),
 array([[ -0.,  -1.],
        [ -4.,  -5.],
        [ -8.,  -9.],
        [-12., -13.]], dtype=float32),
 array([[ 8.,  9., 10., 11.],
        [12., 13., 14., 15.]], dtype=float32),
 array([[ -2.,  -3.],
        [ -6.,  -7.],
        [-10., -11.],
        [-14., -15.]], dtype=float32))

In [41]:
# Each quadrant of the product is (row strip of A) @ (column strip of B).
upper_left, upper_right = A_panel_1 @ B_panel_1, A_panel_1 @ B_panel_2
lower_left, lower_right = A_panel_2 @ B_panel_1, A_panel_2 @ B_panel_2

In [42]:
# Stitch the four quadrant products back into one rows x cols float32 matrix.
combined = np.zeros((rows, cols), dtype=np.float32)
combined[:2, :2], combined[:2, 2:4] = upper_left, upper_right
combined[2:4, :2], combined[2:4, 2:4] = lower_left, lower_right

In [43]:
# NumPy reference for A @ B, assembled from the four quadrant products.
combined

array([[ -56.,  -62.,  -68.,  -74.],
       [-152., -174., -196., -218.],
       [-248., -286., -324., -362.],
       [-344., -398., -452., -506.]], dtype=float32)

In [44]:
# Plain-Python row-major copies of A and B for the pure-Python kernels; C starts zeroed.
flat_a, flat_b = A.ravel().tolist(), B.ravel().tolist()
flat_c = [0 for _ in range(rows * cols)]

In [45]:
# B stored transposed (B^T in row-major order), the layout panel_panel_dgemm_bt reads from.
flat_b_t = transpose(rows, flat_b)
# Sanity-check the pure-Python transpose against NumPy before relying on it.
assert flat_b_t == B_matrix.T.ravel().tolist(), "transpose() disagrees with NumPy's B.T"

In [46]:
# Tile C into panel_block_size x panel_block_size blocks; each block is
# (row panel of A) @ (column panel of B), computed by the pure-Python kernel.
panel_block_size = 3
panel_a_cols = cols  # every A panel spans the whole inner dimension ...
panel_b_rows = cols  # ... and so does every B panel (B has `cols` rows), so the two always match
start_col_a = 0
start_row_b = 0
# Ceil-divide so a trailing partial panel is still visited. Rows of A and columns of B are
# tiled independently, so each loop gets its own trip count.
n_row_panels = math.ceil(rows / panel_block_size)
n_col_panels = math.ceil(cols / panel_block_size)
# The kernel accumulates into C with +=, so start from zeros; this keeps the cell re-runnable.
flat_c = [0.0] * (rows * cols)
for i in range(n_row_panels):
    start_row_a = i * panel_block_size
    panel_a_rows = min(panel_block_size, rows - start_row_a)
    for j in range(n_col_panels):
        start_col_b = j * panel_block_size
        panel_b_cols = min(panel_block_size, cols - start_col_b)
        # The output block sits where this A row panel and B column panel intersect.
        write_row = start_row_a
        write_col = start_col_b
        panel_panel_dgemm(
            rows, cols,
            panel_a_rows, panel_a_cols,
            panel_b_rows, panel_b_cols,
            start_row_a, start_col_a,
            start_row_b, start_col_b,
            write_row, write_col,
            flat_a, flat_b, flat_c,
        )

2


In [47]:
# Bring the kernel's flat output back to rows x cols for comparison with the NumPy reference.
matrix_c = np.asarray(flat_c).reshape(rows, cols)

In [48]:
# Fail loudly if the blocked pure-Python kernel disagrees with the NumPy reference,
# rather than relying on eyeballing the two matrices below.
np.testing.assert_allclose(matrix_c, combined)
combined, matrix_c

(array([[ -56.,  -62.,  -68.,  -74.],
        [-152., -174., -196., -218.],
        [-248., -286., -324., -362.],
        [-344., -398., -452., -506.]], dtype=float32),
 array([[ -56.,  -62.,  -68.,  -74.],
        [-152., -174., -196., -218.],
        [-248., -286., -324., -362.],
        [-344., -398., -452., -506.]]))