# Systolic Array Design


In [1]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
import pyrtl
from pyrtl import *

from hardware_accelerators import *
from hardware_accelerators.simulation import *
from hardware_accelerators.dtypes import BaseFloat

from IPython.display import *
import numpy as np
from dataclasses import dataclass
from typing import Callable, Type, Self, List, Type

## Configuration Class


In [6]:
from pydantic import BaseModel, Field, validator
from typing import Callable, Type, Annotated
from enum import Enum
from pyrtl import WireVector


class ProcessingElementType(Enum):
    STANDARD = "standard"
    SIMPLE = "simple"


class SystolicArrayConfig(BaseModel):
    # Array configuration
    array_size: Annotated[
        int, Field(gt=0, description="Size N of the NxN systolic array matrix")
    ]

    # Data types
    data_type: Annotated[
        Type[BaseFloat],
        Field(
            description="Floating point format for input data and weights (e.g. Float8, BF16)"
        ),
    ]

    accum_type: Annotated[
        Type[BaseFloat],
        Field(
            description="Floating point format for accumulation, typically wider than data_type"
        ),
    ]

    # Arithmetic operations
    multiplier: Annotated[
        Callable[[WireVector, WireVector, Type[BaseFloat]], WireVector],
        Field(description="Floating point multiplier implementation function"),
    ]

    adder: Annotated[
        Callable[[WireVector, WireVector, Type[BaseFloat]], WireVector],
        Field(description="Floating point adder implementation function"),
    ]

    # Processing element configuration
    pe_type: Annotated[
        ProcessingElementType,
        Field(
            default=ProcessingElementType.STANDARD,
            description="Type of processing element to use in the array",
        ),
    ]

    pipeline_mult: Annotated[
        bool,
        Field(
            default=False,
            description="Whether to add a pipeline register after multiplication",
        ),
    ]

    class Config:
        arbitrary_types_allowed = True
        validate_assignment = True

    @property
    def data_width(self) -> int:
        """Bit width of data/weight values"""
        return self.data_type.bitwidth()

    @property
    def accum_width(self) -> int:
        """Bit width of accumulator values"""
        return self.accum_type.bitwidth()

    @classmethod
    def standard_config(
        cls,
        array_size: int,
        data_type: Type[BaseFloat],
        accum_type: Type[BaseFloat],
        multiplier: Callable[[WireVector, WireVector, Type[BaseFloat]], WireVector],
        adder: Callable[[WireVector, WireVector, Type[BaseFloat]], WireVector],
        pipeline_mult: bool = False,
    ) -> "SystolicArrayConfig":
        """Create a standard configuration with the STANDARD processing element type

        Args:
            array_size: Size N of the NxN systolic array
            data_type: Floating point format for input data
            accum_type: Floating point format for accumulation
            multiplier: Multiplier implementation function
            adder: Adder implementation function
            pipeline_mult: Whether to pipeline multiplication

        Returns:
            Configured SystolicArrayConfig instance
        """
        return cls(
            array_size=array_size,
            data_type=data_type,
            accum_type=accum_type,
            multiplier=multiplier,
            adder=adder,
            pipeline_mult=pipeline_mult,
            pe_type=ProcessingElementType.STANDARD,
        )

    def model_post_init(self) -> None:
        """Additional validation after initialization"""
        if self.accum_width < self.data_width:
            raise ValueError(
                f"Accumulator width ({self.accum_width}) must be >= "
                f"data width ({self.data_width})"
            )


SystolicArrayConfig(
    array_size=4,
)

ValidationError: 4 validation errors for SystolicArrayConfig
data_type
  Field required [type=missing, input_value={'array_size': 4}, input_type=dict]
    For further information visit https://errors.pydantic.dev/2.10/v/missing
accum_type
  Field required [type=missing, input_value={'array_size': 4}, input_type=dict]
    For further information visit https://errors.pydantic.dev/2.10/v/missing
multiplier
  Field required [type=missing, input_value={'array_size': 4}, input_type=dict]
    For further information visit https://errors.pydantic.dev/2.10/v/missing
adder
  Field required [type=missing, input_value={'array_size': 4}, input_type=dict]
    For further information visit https://errors.pydantic.dev/2.10/v/missing

# Weight Stationary Dataflow

like TPU architecture


## Simple Processing Element


We will use a weight stationary dataflow in our systolic array. Since weights need to be updated less frequently than the inputs or activations flowing through them, this will reduce the amount of memory IO and therefore the power requirements.

To start, lets focus on building the most simple version of the processing element. Each element in the array will have the following inputs and outputs:

- activation value
- weight value
- partial sum

Activations will flow from left to right, weights will flow from top to bottom, and both are passed through unchanged.  
The accumulated value output will be the result of adding the accumulation input with the product of the current activation and weight.


In [3]:
@dataclass
class PEOutputs:
    """Container for PE outputs to make connections clear"""

    data: Register
    weight: Register
    accum: Register

In [4]:
class SimpleProcessingElement:
    def __init__(
        self,
        data_type: Type[BaseFloat],
        accum_type: Type[BaseFloat],
        multiplier: Callable[[WireVector, WireVector, Type[BaseFloat]], WireVector],
        adder: Callable[[WireVector, WireVector, Type[BaseFloat]], WireVector],
        *args,
        **kwargs
    ):
        """Initialize processing element hardware

        Args:
            data_type: Float type for data/weight (Float8, BF16 etc)
            accum_type: Float type for accumulation
            multiplier_type: Floating point multiplier implementation
            adder_type: Floating point adder implementation
            pipeline_mult: If True, register multiplication output before passing to accumulator
        """
        # Get bit widths from format specs
        data_width = data_type.bitwidth()
        accum_width = accum_type.bitwidth()

        # Input/output registers
        self.data_reg = Register(bitwidth=data_width)
        self.weight_reg = Register(bitwidth=data_width)
        self.accum_in = WireVector(bitwidth=accum_width)
        self.accum_reg = Register(bitwidth=accum_width)

        # Control signals
        self.weight_we = WireVector(bitwidth=1)  # Weight write enable

        # Multiply-accumulate logic
        product = multiplier(self.data_reg, self.weight_reg, data_type)

        # TODO: Add float type conversion logic to pass different bitwidths to the accumulator

        sum_result = adder(product, self.accum_in, accum_type)

        self.accum_reg.next <<= sum_result

        # Store registers in output container
        self.outputs = PEOutputs(
            data=self.data_reg, weight=self.weight_reg, accum=self.accum_reg
        )

    def connect_data(self, source: Self | WireVector):
        """Connect data input from source (PE or external input)"""
        if isinstance(source, SimpleProcessingElement):
            self.data_reg.next <<= source.outputs.data
        else:
            self.data_reg.next <<= source

    def connect_weight(self, source: Self | WireVector):
        """Connect weight input from source (PE or external input)"""
        if isinstance(source, SimpleProcessingElement):
            weight_in = source.outputs.weight
        else:
            weight_in = source

        # Conditional weight update based on enable signal
        with pyrtl.conditional_assignment:
            with self.weight_we:
                self.weight_reg.next |= weight_in

    def connect_accum(self, source: Self | WireVector):
        """Connect accumulator input from source (PE or external input)"""
        if isinstance(source, SimpleProcessingElement):
            self.accum_in <<= source.outputs.accum
        else:
            self.accum_in <<= source

    def connect_weight_enable(self, enable: WireVector):
        """Connect weight write enable signal"""
        self.weight_we <<= enable

## Processing Element with Control Signals
(currently in `rtllib`)


In [28]:
import warnings


class ProcessingElement:
    def __init__(
        self,
        data_type: Type[BaseFloat],
        accum_type: Type[BaseFloat],
        multiplier: Callable[[WireVector, WireVector, Type[BaseFloat]], WireVector],
        adder: Callable[[WireVector, WireVector, Type[BaseFloat]], WireVector],
        pipeline_mult: bool = False,
    ):
        """Initialize processing element hardware

        Args:
            data_type: Float type for data/weight (Float8, BF16 etc)
            accum_type: Float type for accumulation
            multiplier_type: Floating point multiplier implementation
            adder_type: Floating point adder implementation
            pipeline_mult: If True, register multiplication output before passing to accumulator
        """
        self.pipeline = pipeline_mult

        # Get bit widths from format specs
        data_width = data_type.bitwidth()
        acc_width = accum_type.bitwidth()

        # Input wires
        self.data_in = WireVector(data_width)
        self.weight_in = WireVector(data_width)
        self.accum_in = WireVector(acc_width)

        # Registers
        self.data_reg = Register(data_width)
        self.weight_reg = Register(data_width)
        self.accum_reg = Register(acc_width)

        # Control signals
        self.weight_en = WireVector(1)  # Weight write enable
        self.data_en = WireVector(1)  # Enable writing to the data input register
        self.adder_en = WireVector(1)  # Enable writing to the accumulator register

        # Multiply logic
        product = multiplier(self.data_reg, self.weight_reg, data_type)
        self.adder_input = product

        # TODO: Add float type conversion logic to pass different bitwidths to the accumulator

        # Optionally build a pipeline register to hold the multiplier result
        if self.pipeline:
            self.mul_en = WireVector(1)
            product_reg = Register(data_width)
            self.adder_input = product_reg
            with conditional_assignment:
                with self.mul_en:  # Enable writing to product register
                    product_reg.next |= product

        # Add the product and previous accumulation value to get partial sum
        sum_result = adder(self.adder_input, self.accum_in, accum_type)

        # Enable writing to data input register
        with conditional_assignment:
            with self.data_en:
                self.data_reg.next |= self.data_in
            with otherwise:
                self.data_reg.next |= 0

        # Enable writing to weight input register
        with conditional_assignment:
            with self.weight_en:
                self.weight_reg.next |= self.weight_in

        # Enable writing to accumulator register
        with conditional_assignment:
            with self.adder_en:
                self.accum_reg.next |= sum_result

        # Store registers in output container
        self.outputs = PEOutputs(
            data=self.data_reg, weight=self.weight_reg, accum=self.accum_reg
        )

    def connect_data(self, source: Self | WireVector):
        """Connect data input from source (PE or external input)"""
        if isinstance(source, ProcessingElement):
            self.data_in <<= source.outputs.data
        else:
            self.data_in <<= source

    def connect_weight(self, source: Self | WireVector):
        """Connect weight input from source (PE or external input)"""
        if isinstance(source, ProcessingElement):
            self.weight_in <<= source.outputs.weight
        else:
            self.weight_in <<= source

    def connect_accum(self, source: Self | WireVector):
        """Connect accumulator input from source (PE or external input)"""
        if isinstance(source, ProcessingElement):
            self.accum_in <<= source.outputs.accum
        else:
            self.accum_in <<= source

    def connect_weight_enable(self, enable: WireVector):
        """Connect weight write enable signal"""
        self.weight_en <<= enable

    def connect_data_enable(self, enable: WireVector):
        """Connect PE enable signal. Controls writing to the data input register"""
        self.data_en <<= enable

    def connect_mul_enable(self, enable: WireVector):
        """Connect multiplier enable signal. Controls writing to the product register"""
        if self.pipeline:
            self.mul_en <<= enable
        else:
            warnings.warn(
                "Pipelining is disabled. There is no product register to enable. Skipping."
            )

    def connect_adder_enable(self, enable: WireVector):
        """Connect adder enable signal. Controls writing to the accumulator register"""
        self.adder_en <<= enable

    def connect_control_signals(
        self,
        weight_en: WireVector | None = None,
        data_en: WireVector | None = None,
        mul_en: WireVector | None = None,
        adder_en: WireVector | None = None,
    ):
        """Connect control signals to the processing element

        Args:
            weight_en (WireVector): Weight write enable signal. Controls writing to the weight register
            data_en (WireVector): PE enable signal. Controls writing to the data input register
            mul_en (WireVector): Multiplier enable signal. Controls writing to the product register
            adder_en (WireVector): Adder enable signal. Controls writing to the accumulator register
        """
        if data_en is not None:
            self.connect_data_enable(data_en)
        if weight_en is not None:
            self.connect_weight_enable(weight_en)
        if mul_en is not None:
            self.connect_mul_enable(mul_en)
        if adder_en is not None:
            self.connect_adder_enable(adder_en)

        return self

### Test PE

In [30]:
reset_working_block()
set_debug_mode(False)


pe = ProcessingElement(BF16, BF16, lmul_fast, float_adder, pipeline_mult=False)


w_en, d_en, acc_en, mul_en = input_list(["w_en", "d_en", "acc_en", "mul_en"], 1)

data_in, weight_in, acc_in = input_list(["data_in", "weight_in", "acc_in"], 16)
data_reg, weight_reg, acc_reg = output_list(["data_reg", "weight_reg", "acc_reg"], 16)

pe.connect_control_signals(weight_en=w_en, data_en=d_en, adder_en=acc_en, mul_en=mul_en)
pe.connect_data(data_in)

pe.connect_weight(weight_in)
pe.connect_accum(acc_in)

product = probe(pe.adder_input)
product.name = "product"

data_reg <<= pe.outputs.data
weight_reg <<= pe.outputs.weight
acc_reg <<= pe.outputs.accum

wires = [
    w_en,
    d_en,
    mul_en,
    acc_en,
    data_in,
    weight_in,
    acc_in,
    product,
    data_reg,
    weight_reg,
    acc_reg,
]
trace = SimulationTrace(wires)

repr_map = {
    wire.name: lambda x: str(BF16(binint=x).decimal_approx)
    for wire in wires
    if len(wire) == 16
}

sim = Simulation(tracer=trace)


inputs = {
    "w_en": [1, 0, 0, 0, 0],
    "weight_in": [BF16(2).binint, 0, 0, 0, 0],
    "d_en": [1, 0, 0, 0, 0],
    "data_in": [BF16(2).binint, 0, 0, 0, 0],
    "mul_en": [0, 0, 0, 0, 0],
    "acc_en": [0, 1, 0, 0, 0],
    "acc_in": [0, BF16(2).binint, 0, 0, 0],
}


sim.step_multiple(inputs)
render_waveform(
    sim,
    repr_func=bin,
    repr_per_name=repr_map,
)



<IPython.core.display.Javascript object>

# Systolic Array Class


In [8]:
class SystolicArray:
    def __init__(
        self,
        size: int,
        data_type: Type[BaseFloat],
        accum_type: Type[BaseFloat],
        multiplier: Callable[[WireVector, WireVector, Type[BaseFloat]], WireVector],
        adder: Callable[[WireVector, WireVector, Type[BaseFloat]], WireVector],
        dip_dataflow: bool = True,
        pipeline_mult: bool = False,
    ):
        """Initialize systolic array hardware structure

        Args:
            size: N for NxN array
            data_type: Number format for inputs (Float8, BF16)
            accum_type: Number format for accumulation
            multiplier_type: Multiplier implementation to use
            adder_type: Adder implementation to use
            dip_dataflow: Ḏiagonal-I̱nput and P̱ermutated weight-stationary (DiP) dataflow
            pipeline_mult: Add pipeline register after multiplication
        """
        # Set configuration attributes
        self.size = size
        self.dip = dip_dataflow
        self.pipeline_mult = pipeline_mult
        self.data_type = data_type
        self.accum_type = accum_type
        data_width = data_type.bitwidth()
        accum_width = accum_type.bitwidth()

        # Input wires
        self.data_in = [WireVector(data_width) for _ in range(size)]
        self.weights_in = [WireVector(data_width) for _ in range(size)]
        self.results_out = [WireVector(accum_width) for _ in range(size)]

        # Control wires
        self.weight_enable = WireVector(1)  # Enable writing weights, shared by all PEs
        self.enable_in = WireVector(1)  # Enable writing to the data input register

        # Control signal registers to propogate signal down the array
        self.control_registers = [Register(1) for _ in range(size)]

        # If PEs contain an extra pipeline stage, 1 additional control reg is needed
        if self.pipeline_mult:
            self.control_registers.append(Register(1))

        # Create PE array
        self.pe_array = [
            [
                ProcessingElement(
                    data_type,
                    accum_type,
                    multiplier,
                    adder,
                    pipeline_mult,
                )
                for _ in range(size)
            ]
            for _ in range(size)
        ]

        # Connect PEs in systolic pattern based on dataflow type
        if self.dip:
            self._connect_dip_array()
        else:
            self._connect_ws_array()  # WS = Weight Stationary (original TPU-style)

    def _connect_dip_array(self):
        """Connect processing elements in DiP configuration
        - All data flows top to bottom diagonally shifted
        - Weights are loaded simultaneously across array
        - Data inputs arrive synchronously
        """
        # Connect control signal registers that propagate down the array
        self.control_registers[0].next <<= self.enable_in
        for i in range(1, len(self.control_registers)):
            self.control_registers[i].next <<= self.control_registers[i - 1]

        for row in range(self.size):
            for col in range(self.size):
                pe = self.pe_array[row][col]

                # Connect PE inputs:
                # First row gets external input, others connect to PE above
                if row == 0:
                    pe.connect_data_enable(self.enable_in)
                    pe.connect_data(self.data_in[col])
                    pe.connect_weight(self.weights_in[col])
                    pe.connect_accum(Const(0))

                # DiP config: PEs connected to previous row and data diagonally shifted by 1
                else:
                    pe.connect_data_enable(self.control_registers[row - 1])
                    pe.connect_data(self.pe_array[row - 1][col - self.size + 1])
                    pe.connect_weight(self.pe_array[row - 1][col])
                    pe.connect_accum(self.pe_array[row - 1][col])

                # Delay the control signal for accumulator by num pipeline stages
                if self.pipeline_mult:
                    pe.connect_mul_enable(self.control_registers[row])
                    pe.connect_adder_enable(self.control_registers[row + 1])
                else:
                    pe.connect_adder_enable(self.control_registers[row])

                # Connect weight enable signal (shared by all PEs)
                pe.connect_weight_enable(self.weight_enable)

                # Connect bottom row results to output ports
                if row == self.size - 1:
                    self.results_out[col] <<= pe.outputs.accum

    # TODO: Connect control registers properly for weight stationary
    def _connect_ws_array(self):
        """Connect processing elements in systolic pattern

        Data flow patterns:
        - Activations flow left to right, must be diagonally buffered with FIFO
        - Weights flow top to bottom
        - Partial sums flow top to bottom
        """

        # self._connect_control_registers()

        for row in range(self.size):
            for col in range(self.size):
                pe = self.pe_array[row][col]

                # Connect activation input:
                if col == 0:
                    # First column gets external input
                    pe.connect_data(self.data_in[row])
                else:
                    # PE data comes from PE to the left
                    pe.connect_data(self.pe_array[row][col - 1])

                # Connect weight input:
                # First row gets external input, others connect to PE above
                if row == 0:
                    pe.connect_weight(self.weights_in[col])
                    pe.connect_accum(Const(0))
                else:
                    pe.connect_weight(self.pe_array[row - 1][col])
                    pe.connect_accum(self.pe_array[row - 1][col])

                # Connect bottom row results to output ports
                if row == self.size - 1:
                    self.results_out[col] <<= pe.outputs.accum

    # -----------------------------------------------------------------------------
    # Connection functions return their inputs to allow for more concise simulation
    # -----------------------------------------------------------------------------
    def connect_weight_enable(self, source: WireVector):
        """Connect weight load enable signal"""
        self.weight_enable <<= source
        return source

    def connect_enable_input(self, source: WireVector):
        """Connect PE enable signal. Controls writing to the data input register"""
        self.enable_in <<= source
        return source

    def connect_data_input(self, row: int, source: WireVector):
        """Connect data input for specified row"""
        assert 0 <= row < self.size
        self.data_in[row] <<= source
        return source

    def connect_weight_input(self, col: int, source: WireVector):
        """Connect weight input for specified column"""
        assert 0 <= col < self.size
        self.weights_in[col] <<= source
        return source

    def connect_result_output(self, col: int, dest: WireVector):
        """Connect result output from specified column"""
        assert 0 <= col < self.size
        dest <<= self.results_out[col]
        return dest

    # -----------------------------------------------------------------------------
    # Simulation helper methods
    # -----------------------------------------------------------------------------
    def inspect_weights(self, sim: Simulation, verbose: bool = True):
        weights = np.zeros((self.size, self.size))
        enabled = sim.inspect(self.weight_enable.name) == 1
        for row in range(self.size):
            for col in range(self.size):
                w = sim.inspect(self.pe_array[row][col].weight_reg.name)
                weights[row][col] = self.data_type(binint=w).decimal_approx
        if verbose:
            print(f"Weights: {enabled=}")
            print(np.array_str(weights, precision=3, suppress_small=True), "\n")
        return weights

    def inspect_data(self, sim: Simulation, verbose: bool = True):
        data = np.zeros((self.size, self.size))
        for row in range(self.size):
            for col in range(self.size):
                d = sim.inspect(self.pe_array[row][col].data_reg.name)
                data[row][col] = self.data_type(binint=d).decimal_approx
        if verbose:
            print("Data:")
            print(np.array_str(data, precision=4, suppress_small=True), "\n")
        return data

    def inspect_outputs(self, sim: Simulation, verbose: bool = True):
        current_results = np.zeros(self.size)
        for i in range(self.size):
            r = sim.inspect(self.results_out[i].name)
            current_results[i] = self.accum_type(binint=r).decimal_approx
        if verbose:
            print("Output:")
            print(np.array_str(current_results, precision=4, suppress_small=True), "\n")
        return current_results

## Use the simulator class to test DiP
code imported from `hardware_accelerators`

In [None]:
from hardware_accelerators.simulation.systolic import SystolicArraySimulator

SIZE = 3

weights = np.identity(SIZE)
activations = np.random.rand(SIZE, SIZE)
print(f"Weights:\n{weights}")
print("Data:\n", activations)

systolic_sim = SystolicArraySimulator(size=SIZE)

result = systolic_sim.simulate(weights, activations)

print(f"Result:\n{result}")

Weights:
[[1. 0. 0.]
 [0. 1. 0.]
 [0. 0. 1.]]
Data:
 [[0.96938999 0.4226538  0.76405986]
 [0.04827431 0.25570676 0.26495031]
 [0.05079308 0.81155553 0.91206032]]
Result:
[[1.         0.4375     0.79296875]
 [0.05004883 0.26953125 0.27929688]
 [0.05273438 0.83984375 0.94140625]]


In [139]:
systolic_sim.history

[
 Simulation State - Step 0
 ----------------------------------------
 Inputs:
   w_en: 1
   enable: 0
   weights: [0. 0. 0.]
   data: [0. 0. 0.]
 
 Weights Matrix:
 [[0. 0. 0.]
  [0. 0. 0.]
  [0. 0. 0.]]
 
 Data Matrix:
 [[0. 0. 0.]
  [0. 0. 0.]
  [0. 0. 0.]]
 
 Accumulators:
 [[0. 0. 0.]
  [0. 0. 0.]
  [0. 0. 0.]]
 
 Outputs:
 [0. 0. 0.]
 ----------------------------------------,
 
 Simulation State - Step 1
 ----------------------------------------
 Inputs:
   w_en: 1
   enable: 0
   weights: [0. 0. 0.]
   data: [0. 0. 0.]
 
 Weights Matrix:
 [[0. 0. 0.]
  [0. 0. 0.]
  [0. 0. 0.]]
 
 Data Matrix:
 [[0. 0. 0.]
  [0. 0. 0.]
  [0. 0. 0.]]
 
 Accumulators:
 [[0. 0. 0.]
  [0. 0. 0.]
  [0. 0. 0.]]
 
 Outputs:
 [0. 0. 0.]
 ----------------------------------------,
 
 Simulation State - Step 2
 ----------------------------------------
 Inputs:
   w_en: 1
   enable: 1
   weights: [1. 1. 1.]
   data: [0.0508 0.8086 0.9102]
 
 Weights Matrix:
 [[0. 0. 0.]
  [0. 0. 0.]
  [0. 0. 0.]]
 
 Data Ma

### Testing wire concatenation order

In [None]:
reset_working_block()

product = Output(16, "product")
product <<= float_multiplier(Input(16, "a"), Input(16, "b"), BF16)

sim = Simulation()

a, b = 4, 5

sim.step({"a": BF16(a).binint, "b": BF16(b).binint})
sim.step({"a": 0, "b": 0})

render_waveform(sim, repr_func=repr_bf16_binary)

<IPython.core.display.Javascript object>

In [46]:
reset_working_block()

ins = input_list(["a_in", "b_in", "c_in"], 8)
wires = wirevector_list(["a", "b", "c"], 8, WireVector)

for i, w in zip(ins, wires):
    w <<= i

concat_wires = Output(8 * len(wires), "concat")
concat_wires <<= concat(*wires)

concat_list_wires = Output(8 * len(wires), "concat_list")
concat_list_wires <<= concat_list(wires)

sim = Simulation(SimulationTrace([*ins, concat_wires, concat_list_wires]))

sim.step({wire.name: i + 1 for i, wire in enumerate(ins)})


def fmt_bigwire(x):
    bin = format(x, "024b")
    return "-".join(list(map("".join, zip(*[iter(bin)] * 8))))


render_waveform(
    sim,
    repr_func=str,
    repr_per_name={"concat": fmt_bigwire, "concat_list": fmt_bigwire},
)

<IPython.core.display.Javascript object>

## Systolic Data Setup


In [59]:
class SystolicSetup:
    """Creates diagonal delay pattern for systolic array I/O

    For a 3x3 array, creates following pattern of registers:
    (R = register, -> = connection)

    Row 0:  [R] ------->
    Row 1:  [R]->[R] -->
    Row 2:  [R]->[R]->[R]

    - Each row i contains i+1 registers
    - Input connects to leftmost register
    - Output reads from rightmost register
    - Can be used for both input and output buffering
    """

    def __init__(self, size: int, dtype: Type[BaseFloat]):
        """Initialize delay register network

        Args:
            size: Number of rows in network
            bitwidth: Bit width of data values
        """
        self.size = size
        self.dtype = dtype
        self.bitwidth = dtype.bitwidth()

        # Create input wires for each row
        self.inputs = [WireVector(self.bitwidth) for _ in range(size)]

        # Create delay register network - more delays for lower rows
        self.delay_regs = []
        self.outputs = [WireVector(self.bitwidth) for _ in range(size)]

        for i in range(size):  # Create num rows equal to the size of systolic array
            row: List[Register] = []
            # Number of buffer registers equals row index for lower triangular config
            for j in range(i + 1):
                row.append(Register(self.bitwidth))
                if j != 0:
                    # Left most register connects to inputs, others connect to previous reg
                    row[j].next <<= row[j - 1]

            # Connect row input and output
            row[0].next <<= self.inputs[i]
            self.outputs[i] <<= row[-1]
            self.delay_regs.append(row)

    def connect_input(self, row: int, source: WireVector):
        """Connect input for specified row"""
        assert (
            len(source) == self.bitwidth
        ), f"Source bitwidth ({len(source)}) must match configured bitwidth ({self.bitwidth})"
        self.inputs[row] <<= source

    def connect_output(self, row, dest: WireVector):
        """Connect final register in a buffer row to an output destination"""
        assert (
            len(dest) == self.bitwidth
        ), f"Destination bitwidth ({len(dest)}) must match configured bitwidth ({self.bitwidth})"
        dest <<= self.outputs[row]

## Matrix Multiply Unit Top Level


In [60]:
class MatrixMultiplier:
    """Top level systolic array matrix multiplier hardware"""

    def __init__(
        self,
        size: int,
        data_type: Type[BaseFloat],
        accum_type: Type[BaseFloat],
        multiplier: Callable[[WireVector, WireVector, Type[BaseFloat]], WireVector],
        adder: Callable[[WireVector, WireVector, Type[BaseFloat]], WireVector],
        pipeline_mult: bool = False,
    ):
        self.size = size
        self.data_type = data_type
        self.accum_type = accum_type
        self.data_width = data_type.bitwidth()
        self.accum_width = accum_type.bitwidth()

        # Create hardware components
        self.systolic_array = SystolicArray(
            size, data_type, accum_type, multiplier, adder, pipeline_mult
        )
        self.systolic_setup = SystolicSetup(size, data_type)
        self.result_buffer = SystolicSetup(size, accum_type)

        # Connect internal components
        self._connect_internal_components()

    def _connect_internal_components(self):
        """Connect systolic array to input/output buffers"""
        for i in range(self.size):
            self.systolic_array.connect_data_input(i, self.systolic_setup.outputs[i])
            self.systolic_array.connect_result_output(
                i, self.result_buffer.inputs[-i - 1]
            )

    def _validate_wire_list(
        self, wires: List[WireVector], expected_width: int, purpose: str
    ):
        """Validate a list of wires meets requirements"""
        if len(wires) != self.size:
            raise ValueError(f"{purpose} requires {self.size} wires, got {len(wires)}")
        if not all(isinstance(w, WireVector) for w in wires):
            raise TypeError(f"All {purpose} must be WireVector instances")
        if not all(w.bitwidth == expected_width for w in wires):
            raise ValueError(f"All {purpose} must have bitwidth {expected_width}")

    def _split_wide_wire(
        self, wire: WireVector, width_per_slice: int
    ) -> List[WireVector]:
        """Split a wide wire into equal slices"""
        expected_width = width_per_slice * self.size
        if wire.bitwidth != expected_width:
            raise ValueError(
                f"Wide wire must have bitwidth {expected_width}, got {wire.bitwidth}"
            )
        # Use chop instead of manual slicing
        return chop(wire, *([width_per_slice] * self.size))

    def connect_weight_enable(self, enable: WireVector):
        """Connect weight enable signal"""
        if not isinstance(enable, WireVector) or enable.bitwidth != 1:
            raise ValueError("Weight enable must be 1-bit WireVector")
        self.systolic_array.connect_weight_load(enable)

    def connect_weight(self, index: int, weight: WireVector):
        """Connect an individual weight wire to the systolic array"""
        assert len(weight) == self.data_width
        self.systolic_array.connect_weight_input(index, weight)

    def connect_data(self, index: int, data: WireVector):
        """Connect an individual data wire to the systolic setup buffer"""
        assert len(data) == self.data_width
        self.systolic_setup.connect_input(index, data)

    def connect_output(self, index: int, output: WireVector):
        """Connect an individual output wire to the result buffer"""
        assert len(output) == self.accum_width
        self.result_buffer.connect_output(-index - 1, output)

    def connect_all_weights(self, weights: WireVector | List[WireVector]):
        """Connect weight inputs either as list of wires or single wide wire"""
        if isinstance(weights, list):
            self._validate_wire_list(weights, self.data_width, "weight inputs")
            weight_wires = weights
        else:
            # Split wide wire into individual weight wires
            weight_wires = chop(weights, *([self.data_width] * self.size))

        for i, wire in enumerate(weight_wires):
            self.systolic_array.connect_weight_input(i, wire)

    def connect_all_data(self, data: WireVector | List[WireVector]):
        """Connect data inputs either as list of wires or single wide wire"""
        if isinstance(data, list):
            self._validate_wire_list(data, self.data_width, "data inputs")
            data_wires = data
        else:
            # Split wide wire into individual data wires
            data_wires = chop(data, *([self.data_width] * self.size))

        for i, wire in enumerate(data_wires):
            self.systolic_setup.connect_input(i, wire)

    def connect_all_outputs(self, results: WireVector | List[WireVector]):
        """Connect result outputs either as list of wires or single wide wire"""
        if isinstance(results, list):
            self._validate_wire_list(results, self.accum_width, "result outputs")
            result_wires = results
        else:
            # Split wide wire into individual result wires
            result_wires = chop(results, *([self.accum_width] * self.size))

        for i, wire in enumerate(result_wires):
            self.result_buffer.connect_output(-i - 1, wire)

## Simulation Class


In [61]:
class MatrixMultiplierSimulator:
    def __init__(self, matrix_multiplier: MatrixMultiplier):
        self.mmu = matrix_multiplier
        self.size = matrix_multiplier.size
        self.data_type = matrix_multiplier.data_type
        self.accum_type = matrix_multiplier.accum_type

        # Create I/O ports
        self.mmu.connect_weight_enable(Input(1, "weight_enable"))

        for i in range(self.size):
            self.mmu.connect_weight(i, Input(self.mmu.data_width, f"weight_{i}"))
            self.mmu.connect_data(i, Input(self.mmu.data_width, f"data_{i}"))
            self.mmu.connect_output(i, Output(self.mmu.accum_width, f"result_{i}"))

        # Initialize simulation
        self.sim = Simulation()
        self.sim_inputs = {
            "weight_enable": 0,
            **{f"weight_{i}": 0 for i in range(self.size)},
            **{f"data_{i}": 0 for i in range(self.size)},
        }

        self._iter_state = None
        self.result_matrix = np.zeros((self.size, self.size))

    def set_matrices(self, matrix_a: np.ndarray, matrix_b: np.ndarray):
        """Set input matrices and prepare simulation state"""
        # Verify dimensions
        assert (
            matrix_a.shape == matrix_b.shape == (self.size, self.size)
        ), f"Matrices must be {self.size}x{self.size}"

        # Convert matrices to specified datatype
        self.matrix_a = self._convert_matrix(matrix_a)
        self.matrix_b = self._convert_matrix(matrix_b)

        # Load weights into PEs
        self._load_weights()

    def calculate(self):
        while next(self):
            continue
        return self.result_matrix

    def matmul(self, matrix_a: np.ndarray, matrix_b: np.ndarray):
        self.set_matrices(matrix_a, matrix_b)
        return self.calculate()

    def _convert_matrix(self, matrix: np.ndarray) -> List[List[int]]:
        """Convert numpy matrix to list of binary values in specified datatype"""
        return [[self.data_type(x).binint for x in row] for row in matrix]

    def _load_weights(self):
        """Load weights into processing elements in reverse row order"""
        for row in reversed(range(self.size)):
            for col in range(self.size):
                self.sim_inputs[f"weight_{col}"] = self.matrix_b[row][col]
            self.sim_inputs["weight_enable"] = 1
            self.sim.step(self.sim_inputs)

        # Reset weight inputs
        for i in range(self.size):
            self.sim_inputs[f"weight_{i}"] = 0
        self.sim_inputs["weight_enable"] = 0
        self.sim.step(self.sim_inputs)

    def __iter__(self):
        """Initialize iterator state"""
        if self.matrix_a is None or self.matrix_b is None:
            raise RuntimeError("Matrices must be set before iteration")

        self._iter_state = {
            "row": self.size - 1,  # Start from last row
            "extra_cycles": self.size * 3 - 1,  # Cycles needed to flush results
            "phase": "input",  # 'input' or 'flush' phase
        }
        return self

    def __next__(self):
        """Return next simulation step results"""
        if self._iter_state is None:
            raise RuntimeError("Iterator not initialized")

        # If we're done with both phases, stop iteration
        if (
            self._iter_state["phase"] == "flush"
            and self._iter_state["extra_cycles"] == 0
        ):
            raise StopIteration

        # Handle input phase
        if self._iter_state["phase"] == "input":
            if self._iter_state["row"] < 0:
                # Transition to flush phase
                self._iter_state["phase"] = "flush"
                # Clear inputs
                for i in range(self.size):
                    self.sim_inputs[f"data_{i}"] = 0
            else:
                # Load next row of input data
                for col in range(self.size):
                    self.sim_inputs[f"data_{col}"] = self.matrix_a[
                        self._iter_state["row"]
                    ][col]
                self._iter_state["row"] -= 1

        # Handle flush phase
        if self._iter_state["phase"] == "flush":
            self._iter_state["extra_cycles"] -= 1

        # Step simulation
        self.sim.step(self.sim_inputs)

        # Return current results
        current_outputs = self.get_current_results()

        # Shift previous results down and insert new results at top
        self.result_matrix[1:] = self.result_matrix[:-1]
        self.result_matrix[0] = current_outputs

        return current_outputs

    def get_current_results(self) -> List[BaseFloat]:
        """Get current values from result output ports

        Returns:
            List of values currently present on the result output ports.
            Length will equal systolic array size (one value per column).
        """
        return [
            self.accum_type(binint=self.sim.inspect(f"result_{i}"))
            for i in range(self.size)
        ]

    def inspect_pe_array(self) -> dict[str, np.ndarray]:
        """Get current state of processing element array as dictionary of matrices

        Returns:
            Dictionary with keys 'data', 'weights', 'accum', where each value is
            a matrix showing the current values in the corresponding registers
            across the PE array
        """
        # Initialize matrices to store PE values
        data_matrix = np.zeros((self.size, self.size))
        weight_matrix = np.zeros((self.size, self.size))
        accum_matrix = np.zeros((self.size, self.size))

        # Populate matrices with current PE values
        for row in range(self.size):
            for col in range(self.size):
                pe = self.mmu.systolic_array.pe_array[row][col]

                # Convert binary values to float using appropriate data types
                data_matrix[row, col] = self.data_type(
                    binint=self.sim.inspect(pe.outputs.data.name)
                )
                weight_matrix[row, col] = self.data_type(
                    binint=self.sim.inspect(pe.outputs.weight.name)
                )
                accum_matrix[row, col] = self.accum_type(
                    binint=self.sim.inspect(pe.outputs.accum.name)
                )

        return {"data": data_matrix, "weights": weight_matrix, "accum": accum_matrix}

    def inspect_systolic_setup(self) -> str:
        """Visualize current state of systolic setup registers"""
        repr_str = ""
        for row in range(self.size):
            input_val = self.data_type(binint=self.sim.inspect(f"data_{row}"))
            repr_str += f"(input={input_val}) => "

            for reg in self.mmu.systolic_setup.delay_regs[row]:
                val = self.data_type(binint=self.sim.inspect(reg.name))
                repr_str += f"{val} -> "
            repr_str += "\n"
        return repr_str

In [64]:
reset_working_block()
set_debug_mode()

# Create hardware and simulator
systolic = MatrixMultiplier(
    size=4,
    data_type=BF16,
    accum_type=BF16,
    multiplier=lmul_simple,
    adder=float_adder,
)
simulator = MatrixMultiplierSimulator(systolic)

# Set input matrices
data = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
weights = np.identity(4)
simulator.set_matrices(weights, weights)

# Iterate over simulation steps
for step, results in enumerate(simulator):
    print(f"Step {step} ({simulator._iter_state}):")

    print("Systolic Setup State:")
    print(simulator.inspect_systolic_setup())

    # Get PE array state
    print("PE Array State:")
    pe_state = simulator.inspect_pe_array()
    print("Data Values:")
    print(pe_state["data"])
    print("\nWeight Values:")
    print(pe_state["weights"])
    print("\nAccumulator Values:")
    print(pe_state["accum"])

    print("Current Results:")
    print([f"{float(x):.3f}" for x in results])

    print("-" * 80 + "\n")

PyrtlError: Wires used but never driven: ['tmp54994', 'tmp54993'] 

 tmp54994/1W:
Wire Traceback, most recent call last 
  File "<frozen runpy>", line 198, in _run_module_as_main
   File "<frozen runpy>", line 88, in _run_code
   File "/home/vscode/.local/lib/python3.12/site-packages/ipykernel_launcher.py", line 18, in <module>
    app.launch_new_instance()
   File "/home/vscode/.local/lib/python3.12/site-packages/traitlets/config/application.py", line 1075, in launch_instance
    app.start()
   File "/home/vscode/.local/lib/python3.12/site-packages/ipykernel/kernelapp.py", line 739, in start
    self.io_loop.start()
   File "/home/vscode/.local/lib/python3.12/site-packages/tornado/platform/asyncio.py", line 205, in start
    self.asyncio_loop.run_forever()
   File "/usr/local/lib/python3.12/asyncio/base_events.py", line 640, in run_forever
    self._run_once()
   File "/usr/local/lib/python3.12/asyncio/base_events.py", line 1992, in _run_once
    handle._run()
   File "/usr/local/lib/python3.12/asyncio/events.py", line 88, in _run
    self._context.run(self._callback, *self._args)
   File "/home/vscode/.local/lib/python3.12/site-packages/ipykernel/kernelbase.py", line 545, in dispatch_queue
    await self.process_one()
   File "/home/vscode/.local/lib/python3.12/site-packages/ipykernel/kernelbase.py", line 534, in process_one
    await dispatch(*args)
   File "/home/vscode/.local/lib/python3.12/site-packages/ipykernel/kernelbase.py", line 437, in dispatch_shell
    await result
   File "/home/vscode/.local/lib/python3.12/site-packages/ipykernel/ipkernel.py", line 362, in execute_request
    await super().execute_request(stream, ident, parent)
   File "/home/vscode/.local/lib/python3.12/site-packages/ipykernel/kernelbase.py", line 778, in execute_request
    reply_content = await reply_content
   File "/home/vscode/.local/lib/python3.12/site-packages/ipykernel/ipkernel.py", line 449, in do_execute
    res = shell.run_cell(
   File "/home/vscode/.local/lib/python3.12/site-packages/ipykernel/zmqshell.py", line 549, in run_cell
    return super().run_cell(*args, **kwargs)
   File "/home/vscode/.local/lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 3009, in run_cell
    result = self._run_cell(
   File "/home/vscode/.local/lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 3064, in _run_cell
    result = runner(coro)
   File "/home/vscode/.local/lib/python3.12/site-packages/IPython/core/async_helpers.py", line 129, in _pseudo_sync_runner
    coro.send(None)
   File "/home/vscode/.local/lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 3269, in run_cell_async
    has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
   File "/home/vscode/.local/lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 3448, in run_ast_nodes
    if await self.run_code(code, result, async_=asy):
   File "/home/vscode/.local/lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 3508, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
   File "/tmp/ipykernel_86117/814662598.py", line 5, in <module>
    mmu = MatrixMultiplier(
   File "/tmp/ipykernel_86117/1758568244.py", line 20, in __init__
    self.systolic_array = SystolicArray(
   File "/tmp/ipykernel_86117/2341231167.py", line 54, in __init__
    self.adder_enable_in = WireVector(1)  # Enable writing to the accum register


tmp54993/1W:
Wire Traceback, most recent call last 
  File "<frozen runpy>", line 198, in _run_module_as_main
   File "<frozen runpy>", line 88, in _run_code
   File "/home/vscode/.local/lib/python3.12/site-packages/ipykernel_launcher.py", line 18, in <module>
    app.launch_new_instance()
   File "/home/vscode/.local/lib/python3.12/site-packages/traitlets/config/application.py", line 1075, in launch_instance
    app.start()
   File "/home/vscode/.local/lib/python3.12/site-packages/ipykernel/kernelapp.py", line 739, in start
    self.io_loop.start()
   File "/home/vscode/.local/lib/python3.12/site-packages/tornado/platform/asyncio.py", line 205, in start
    self.asyncio_loop.run_forever()
   File "/usr/local/lib/python3.12/asyncio/base_events.py", line 640, in run_forever
    self._run_once()
   File "/usr/local/lib/python3.12/asyncio/base_events.py", line 1992, in _run_once
    handle._run()
   File "/usr/local/lib/python3.12/asyncio/events.py", line 88, in _run
    self._context.run(self._callback, *self._args)
   File "/home/vscode/.local/lib/python3.12/site-packages/ipykernel/kernelbase.py", line 545, in dispatch_queue
    await self.process_one()
   File "/home/vscode/.local/lib/python3.12/site-packages/ipykernel/kernelbase.py", line 534, in process_one
    await dispatch(*args)
   File "/home/vscode/.local/lib/python3.12/site-packages/ipykernel/kernelbase.py", line 437, in dispatch_shell
    await result
   File "/home/vscode/.local/lib/python3.12/site-packages/ipykernel/ipkernel.py", line 362, in execute_request
    await super().execute_request(stream, ident, parent)
   File "/home/vscode/.local/lib/python3.12/site-packages/ipykernel/kernelbase.py", line 778, in execute_request
    reply_content = await reply_content
   File "/home/vscode/.local/lib/python3.12/site-packages/ipykernel/ipkernel.py", line 449, in do_execute
    res = shell.run_cell(
   File "/home/vscode/.local/lib/python3.12/site-packages/ipykernel/zmqshell.py", line 549, in run_cell
    return super().run_cell(*args, **kwargs)
   File "/home/vscode/.local/lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 3009, in run_cell
    result = self._run_cell(
   File "/home/vscode/.local/lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 3064, in _run_cell
    result = runner(coro)
   File "/home/vscode/.local/lib/python3.12/site-packages/IPython/core/async_helpers.py", line 129, in _pseudo_sync_runner
    coro.send(None)
   File "/home/vscode/.local/lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 3269, in run_cell_async
    has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
   File "/home/vscode/.local/lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 3448, in run_ast_nodes
    if await self.run_code(code, result, async_=asy):
   File "/home/vscode/.local/lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 3508, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
   File "/tmp/ipykernel_86117/814662598.py", line 5, in <module>
    mmu = MatrixMultiplier(
   File "/tmp/ipykernel_86117/1758568244.py", line 20, in __init__
    self.systolic_array = SystolicArray(
   File "/tmp/ipykernel_86117/2341231167.py", line 53, in __init__
    self.data_enable_in = WireVector(1)  # Enable writing to the data input register



In [63]:
[[x.decimal_approx for x in result] for result in simulator][::-1]

[[12.75, 12.75, 12.75, 12.75],
 [13.8125, 8.5, 8.5, 8.5],
 [8.5, 13.8125, 8.5, 8.5],
 [8.5, 8.5, 13.8125, 8.5],
 [8.5, 8.5, 8.5, 13.8125],
 [12.75, 12.75, 12.75, 12.75],
 [12.75, 12.75, 12.75, 12.75],
 [12.75, 12.75, 12.75, 12.75],
 [12.75, 12.75, 12.75, 12.75],
 [12.75, 12.75, 12.75, 12.75],
 [12.75, 12.75, 12.75, 12.75],
 [12.75, 12.75, 12.75, 12.75],
 [12.75, 12.75, 12.75, 12.75],
 [12.75, 12.75, 12.75, 12.75],
 [12.75, 12.75, 12.75, 12.75]]

In [36]:
simulator.result_matrix

array([[1.0625, 2.125 , 3.125 ],
       [4.25  , 5.25  , 6.25  ],
       [7.25  , 8.5   , 9.5   ]])

---


# DiP: (diagonal input and permutated weight stationary)


Implementing the design found in the recent paper **[DiP: A Scalable, Energy-Efficient Systolic Array for Matrix Multiplication Acceleration](https://arc.net/l/quote/wllrwuvk)**


## Weight Matrix Permutation


In [316]:
def permutate_weight_matrix(arr: np.ndarray):
    # verify matrix is square
    rows, cols = arr.shape
    permutated = np.zeros((rows, cols))
    for i in range(cols):
        for j in range(rows):
            permutated[j][i] = arr[(j + i) % rows][i]
    return permutated

In [None]:
x = np.array([[1, 2, 3], [3, 4, 5], [5, 6, 7]])
x, permutate_weight_matrix(x)

(array([[1, 2, 3],
        [3, 4, 5],
        [5, 6, 7]]),
 array([[1., 4., 7.],
        [3., 6., 3.],
        [5., 2., 5.]]))