In [1]:
from typing import List
import numpy as np
import math
from tqdm.auto import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def panel_panel_dgemm_a_t(
    rows: int,
    cols: int,
    panel_a_rows: int,
    panel_a_cols: int,
    panel_b_rows: int,
    panel_b_cols: int,
    start_row_a: int,
    start_col_a: int,
    start_row_b: int,
    start_col_b: int,
    write_row: int,
    write_col: int,
    A: List[float],
    B: List[float],
    C: List[float],
) -> None:
    """
    We perform panel-panel multiplication using the values in A and B and write the result to C

    The matrices are assumed to be stored in row-major order in 1D arrays with the exception of A

    `start_row_a` tells us which row of A the panel starts at if we were to view A as a 2D matrix.
    `start_col_a` tells us which column of A the panel starts at if we were to view A as a 2D matrix.
    `panel_a_rows` tells us how many rows the panel has if we were to view A as a 2D matrix.
    `panel_a_cols` tells us how many columns the panel has if we were to view A as a 2D matrix.

    Similarly, `start_row_b`, `start_col_b`, `panel_b_rows`, and `panel_b_cols` tell us the same
    information about B.

    `write_row` tells us which row of C we should start writing to if we were to view C as a 2D matrix.
    `write_col` tells us which column of C we should start writing to if we were to view C as a 2D matrix.
    """
    for i in range(panel_a_cols):
        a_c = start_col_a + i
        b_r = start_row_b + i
        for k in range(panel_a_rows):
            a_r = start_row_a + k
            for j in range(panel_b_cols):
                b_c = start_col_b + j
                a_flat = a_c * cols + a_r
                b_flat = b_r * rows + b_c
                out_flat = (write_row + k) * rows + j + write_col
                #print(f"writing to {out_flat} from a: {a_flat} and b: {b_flat}")
                #print(f"current value: {C[out_flat]}, product: {A[a_flat]} * {B[b_flat]} = {A[a_flat] * B[b_flat]}")
                C[out_flat] += A[a_flat] * B[b_flat]
                #print(f"new value: {C[out_flat]}")

In [3]:
def transpose(N, X):
    out = [0] * (N**2)
    for i in range(N):
        for j in range(N):
            original = i*N + j
            tranposed = j*N + i
            out[tranposed] = X[original]
    return out

In [4]:
def panel_panel_dgemm(
    rows: int,
    cols: int,
    panel_a_rows: int,
    panel_a_cols: int,
    panel_b_rows: int,
    panel_b_cols: int,
    start_row_a: int,
    start_col_a: int,
    start_row_b: int,
    start_col_b: int,
    write_row: int,
    write_col: int,
    A: List[float],
    B: List[float],
    C: List[float],
) -> None:
    """
    We perform panel-panel multiplication using the values in A and B and write the result to C.

    The matrices are assumed to be stored in row-major order in 1D arrays

    `start_row_a` tells us which row of A the panel starts at if we were to view A as a 2D matrix.
    `start_col_a` tells us which column of A the panel starts at if we were to view A as a 2D matrix.
    `panel_a_rows` tells us how many rows the panel has if we were to view A as a 2D matrix.
    `panel_a_cols` tells us how many columns the panel has if we were to view A as a 2D matrix.

    Similarly, `start_row_b`, `start_col_b`, `panel_b_rows`, and `panel_b_cols` tell us the same
    information about B.

    `write_row` tells us which row of C we should start writing to if we were to view C as a 2D matrix.
    `write_col` tells us which column of C we should start writing to if we were to view C as a 2D matrix.
    """
    for k in range(panel_a_rows):
        a_row = start_row_a + k
        for i in range(panel_a_cols):
            a_col = start_col_a + i
            b_row = start_row_b + i
            for j in range(panel_b_cols):
                b_col = start_col_b + j
                a_flat = a_row * rows + a_col
                b_flat = b_row * rows + b_col
                c_flat = (write_row + k) * rows + j + write_col
                C[c_flat] += A[a_flat] * B[b_flat]

def panel_panel_dgemm_bt(
    rows: int,
    cols: int,
    panel_a_rows: int,
    panel_a_cols: int,
    panel_b_rows: int,
    panel_b_cols: int,
    start_row_a: int,
    start_col_a: int,
    start_row_b: int,
    start_col_b: int,
    write_row: int,
    write_col: int,
    A: List[float],
    B: List[float],
    C: List[float],
) -> None:
    """
    We perform panel-panel multiplication using the values in A and B and write the result to C

    The matrices are assumed to be stored in row-major order in 1D arrays with the exception of B

    `start_row_a` tells us which row of A the panel starts at if we were to view A as a 2D matrix.
    `start_col_a` tells us which column of A the panel starts at if we were to view A as a 2D matrix.
    `panel_a_rows` tells us how many rows the panel has if we were to view A as a 2D matrix.
    `panel_a_cols` tells us how many columns the panel has if we were to view A as a 2D matrix.

    Similarly, `start_row_b`, `start_col_b`, `panel_b_rows`, and `panel_b_cols` tell us the same
    information about B.

    `write_row` tells us which row of C we should start writing to if we were to view C as a 2D matrix.
    `write_col` tells us which column of C we should start writing to if we were to view C as a 2D matrix.
    """
    for k in range(panel_a_rows):
        a_row = start_row_a + k
        for j in range(panel_b_cols):
            b_col = start_col_b + j
            for i in range(panel_a_cols):
                a_col = start_col_a + i
                b_row = start_row_b + i
                a_flat = a_row * rows + a_col
                b_flat = b_col * rows + b_row
                c_flat = (write_row + k) * rows + j + write_col
                C[c_flat] += A[a_flat] * B[b_flat]

In [5]:
def panel_panel_dgemm_recurse_a_t(
    rows: int,
    cols: int,
    panel_a_rows: int,
    panel_a_cols: int,
    panel_b_rows: int,
    panel_b_cols: int,
    start_row_a: int,
    start_col_a: int,
    start_row_b: int,
    start_col_b: int,
    write_row: int,
    write_col: int,
    A: List[float],
    B: List[float],
    C: List[float],
    block_size: int
) -> None:
    original_start_col_a = start_col_a
    loop_cap = math.ceil(cols / block_size) if block_size < cols else 0 + 1 # remove +1 for C++
    for i in range(loop_cap):
        start_col_a = original_start_col_a + i * block_size
        start_row_b = start_col_a
        panel_a_cols = min(block_size, cols - start_col_a)
        panel_b_rows = panel_a_cols
        #print(f"looking at start_row_a: {start_row_a}, start_col_a: {start_col_a}, start_row_b: {start_row_b}, start_col_b: {start_col_b}")
        #print(f"panel_a_rows: {panel_a_rows}, panel_a_cols: {panel_a_cols}, panel_b_rows: {panel_b_rows}, panel_b_cols: {panel_b_cols}")
        panel_panel_dgemm_a_t(
            rows,
            cols,
            panel_a_rows,
            panel_a_cols,
            panel_b_rows,
            panel_b_cols,
            start_row_a,
            start_col_a,
            start_row_b,
            start_col_b,
            write_row,
            write_col,
            A,
            B,
            C
        )

In [6]:
rows = 10
cols = 10
A = np.array([i for i in range(rows*cols)], dtype=np.float32)
B = A.copy() * -1

In [7]:
A_matrix = A.reshape((rows, cols))
B_matrix = B.reshape((rows, cols))

In [8]:
matrix_c = A_matrix @ B_matrix

In [9]:
flat_a = A.flatten().tolist()
flat_b = B.flatten().tolist()
flat_c = [0] * (rows * cols)

In [10]:
flat_b_t = transpose(rows, flat_b)
flat_a_t = transpose(cols, flat_a)

In [11]:
flat_c = [0] * (rows * cols)

In [12]:
def square_dgemm(N, A, B, C, panel_block_size, subpartition_block_size):
    panel_a_cols = N
    panel_b_rows = N
    start_col_a = 0
    start_row_b = 0
    AT = transpose(N, A)
    rows = cols = N
    loop_cap = math.ceil(N / panel_block_size) if panel_block_size < rows else 0 + 1 #remove the +1 when C++
    for i in range(loop_cap):
        start_row_a = i * panel_block_size
        panel_a_rows = min(panel_block_size, rows - start_row_a)
        for j in range(loop_cap):
            start_col_b = j * panel_block_size
            panel_b_cols = min(panel_block_size, cols - start_col_b)
            write_row = start_row_a
            write_col = start_col_b
            #print(f"outer loop has start_row_a: {start_row_a}, start_col_a: {start_col_a}, start_row_b: {start_row_b}, start_col_b: {start_col_b}")
            #print(f"outer loop has panel_a_rows: {panel_a_rows}, panel_a_cols: {panel_a_cols}, panel_b_rows: {panel_b_rows}, panel_b_cols: {panel_b_cols}")
            #print(f"writing to {write_row}, {write_col}")
            panel_panel_dgemm_recurse_a_t(
                rows, cols, panel_a_rows, panel_a_cols, panel_b_rows, panel_b_cols, start_row_a, start_col_a, start_row_b, start_col_b, write_row, write_col, AT, B, C, subpartition_block_size
            )

In [13]:
for n in tqdm(range(1, 100)):
    rows = cols = n
    A = np.array([i for i in range(rows*cols)], dtype=np.float32)
    B = A.copy() * -1
    A_matrix = A.reshape((rows, cols))
    B_matrix = B.reshape((rows, cols))
    matrix_c = A_matrix @ B_matrix
    flat_a = A.flatten().tolist()
    flat_b = B.flatten().tolist()
    for panel_block_size in range(1, 120):
        for subpartition_block_size in range(1, 120):
            flat_c = [0] * (rows * cols)
            square_dgemm(rows, flat_a, flat_b, flat_c, panel_block_size, subpartition_block_size)
            if not np.allclose(matrix_c, np.array(flat_c).reshape((rows, cols))):
                print(f"n: {n}, panel_block_size: {panel_block_size}, subpartition_block_size: {subpartition_block_size} failed")

 13%|█▎        | 13/99 [00:31<06:56,  4.84s/it]