In [1]:
# Import the necessary libraries
import numpy as np
import random
import os
import sys
from FPU import cmul, cadd
from utils import read_binary, binary_to_fp16, store_binary, generate_fp16

def generate_matrix(rows, cols, complex=False, seed=123):
    """
    Build a rows x cols matrix (list of row lists) from generate_fp16().

    A single flat sequence of rows * cols elements is requested from
    generate_fp16() and then laid out row by row, so element (r, c) of the
    result is item r * cols + c of that sequence.

    Args:
        rows: number of rows in the matrix
        cols: number of columns in the matrix
        complex: forwarded unchanged to generate_fp16()
        seed: forwarded unchanged to generate_fp16()

    Returns:
        list of `rows` lists, each holding `cols` generated elements
    """
    flat = generate_fp16(size=rows * cols, complex=complex, seed=seed)
    return [
        [flat[r * cols + c] for c in range(cols)]
        for r in range(rows)
    ]

def matmul(A_mat, B_mat, use_complex=True):
    """
    Multiply A(MxK) by B(KxN) column-by-column.
    For each column of B, produce the corresponding column of C.
    Complex multiply/add uses cmul and cadd; the real-valued path rounds
    every product and every partial sum to np.float16.

    Args:
        A_mat: list of list of elements (real or [re, im])
        B_mat: list of list of elements (real or [re, im])
        use_complex: True if A and B are complex-valued

    Returns:
        C_mat: list of list, same format as A/B (real or complex),
               with shape MxN. An empty A_mat yields an empty list.

    Raises:
        ValueError: if the number of rows of B_mat does not equal the
                    number of columns of A_mat.
    """
    M = len(A_mat)
    if M == 0:
        # No rows in A means no rows in C; avoid indexing A_mat[0].
        return []

    K = len(A_mat[0])
    if len(B_mat) != K:
        # Fail early instead of silently ignoring extra rows of B or
        # hitting an IndexError deep inside the accumulation loop.
        raise ValueError(
            "matmul shape mismatch: A has %d columns but B has %d rows"
            % (K, len(B_mat))
        )
    # With K == 0, B has no rows to take a width from.
    N = len(B_mat[0]) if K else 0

    print(M, K, N)

    # One (initially empty) output row per row of A.
    C_mat = [[] for _ in range(M)]

    for col in range(N):
        # process one column of B at a time
        for i in range(M):
            if use_complex:
                acc = [np.float16(0.0), np.float16(0.0)]
                for k in range(K):
                    prod = cmul(A_mat[i][k], B_mat[k][col])
                    acc = cadd(acc, prod)
                C_mat[i].append([np.float16(acc[0]), np.float16(acc[1])])
            else:
                acc = np.float16(0.0)
                for k in range(K):
                    # Round the product, then the running sum, to fp16.
                    acc = np.float16(acc + np.float16(A_mat[i][k] * B_mat[k][col]))
                C_mat[i].append(np.float16(acc))
    return C_mat

def flatten_matrix_row_major(mat, complex=False):
    """
    Flatten a matrix in row-major order:
        index(A[i][j]) = i*cols + j

    Elements are copied by reference, so with complex data each [re, im]
    pair stays a single entry of the result. The `complex` argument is
    accepted for signature symmetry and is not consulted.
    """
    return [element for row in mat for element in row]

def flatten_matrix_col_major(mat, complex=False):
    """
    Column-major flatten:
        index(A[i][j]) = j*rows + i
    For complex=True, each element is [re, im] (np.float16) and is kept
    as a single entry of the result.

    An empty matrix flattens to an empty list, matching
    flatten_matrix_row_major. The `complex` argument is accepted for
    signature symmetry and is not consulted.
    """
    rows = len(mat)
    if rows == 0:
        # Nothing to index; mat[0] would raise IndexError.
        return []
    cols = len(mat[0])

    # Walk columns in the outer loop so consecutive outputs share a column.
    return [mat[r][c] for c in range(cols) for r in range(rows)]

## Vector × two-column multiplication in a single PE

In [None]:
# Problem size: A is (M x K) and B is (K x N), so C is (M x N).
M, K, N = 1, 128, 2
use_complex = True  # Set False for real-valued matmul

# Separate seeds are used for the two operands.
A_mat = generate_matrix(rows=M, cols=K, complex=use_complex, seed=101)
B_mat = generate_matrix(rows=K, cols=N, complex=use_complex, seed=202)

# Reference result computed in software.
C_mat = matmul(A_mat, B_mat, use_complex=use_complex)

# A is flattened row by row, B column by column.
A_flat_row = flatten_matrix_row_major(A_mat, complex=use_complex)
B_flat_col = flatten_matrix_col_major(B_mat, complex=use_complex)
# NOTE(review): despite its name, C_flat_col is produced by the row-major
# flatten. The two orders coincide while M == 1 - confirm the intended
# layout before using M > 1.
C_flat_col = flatten_matrix_row_major(C_mat, complex=use_complex)

1 128 2


In [3]:
# Write the flattened A, B and C data to files under ../PE/DATA via
# store_binary (imported from utils).
# Only the A (input) dump passes zero_padding=1; what that option does is
# defined in utils.store_binary - TODO confirm against that helper.
store_binary(A_flat_row, "../PE/DATA/mm_pg_input.txt", zero_padding=1)
store_binary(B_flat_col, "../PE/DATA/mm_pg_weight.txt")
store_binary(C_flat_col, "../PE/DATA/mm_pg_output.txt")

In [4]:
B_flat_col

[[1.292, 0.9346],
 [1.03, 1.908],
 [-0.8916, 0.5835],
 [0.834, 1.567],
 [-0.01967, -0.6147],
 [1.124, -0.8896],
 [0.4148, 1.585],
 [1.694, 1.607],
 [-0.2179, -0.4697],
 [-0.6177, -0.3381],
 [1.95, -0.07556],
 [-0.12384, 1.597],
 [0.0485, -0.8276],
 [0.5635, 0.7183],
 [0.667, 0.04675],
 [0.7617, -0.9546],
 [0.1698, 1.938],
 [-0.4934, 1.592],
 [-0.09314, 0.962],
 [-0.4446, 0.6714],
 [-0.9536, 0.0409],
 [0.8125, -0.921],
 [-0.53, -0.9614],
 [1.88, 0.489],
 [1.044, 1.378],
 [-0.891, 0.6206],
 [1.355, -0.2974],
 [0.8105, -0.7817],
 [-0.2974, 0.6616],
 [1.823, 0.941],
 [1.081, 0.527],
 [0.693, 0.4626],
 [-0.4285, 0.8076],
 [0.3474, 0.1536],
 [1.885, 0.906],
 [1.522, 1.853],
 [1.42, 0.102],
 [-0.8438, 0.7744],
 [0.01967, 1.932],
 [-0.3171, 1.802],
 [-0.4187, -0.4055],
 [0.3289, 1.104],
 [-0.1334, -0.3943],
 [1.999, 0.1884],
 [1.299, 0.2161],
 [-0.852, 0.716],
 [-0.6226, -0.1476],
 [1.082, -0.2179],
 [0.2375, 1.113],
 [1.225, -0.1432],
 [1.592, 1.13],
 [0.0335, 1.206],
 [1.148, 0.815],
 [0.869