# Matrix Acceleration Engine

### Imports

In [7]:
from hardware_accelerators.rtllib import (
    SystolicArrayDiP,
    BufferMemory,
    TiledAccumulatorMemoryBank,
    float_adder,
    float_multiplier,
    lmul_fast,
)
from hardware_accelerators.simulation import *
from hardware_accelerators.dtypes import *
import pyrtl
from pyrtl import *
import numpy as np
from dataclasses import dataclass
from typing import Type, Callable, List

## Configuration Helper Class

In [15]:
@dataclass
class AcceleratorConfig:
    """Configuration class for a systolic array accelerator.

    This class defines the parameters and specifications for a systolic array
    accelerator including array dimensions, data types, arithmetic operations,
    and memory configuration.
    """

    array_size: int
    """Dimension of systolic array (always square)"""

    data_type: Type[BaseFloat]
    """Floating point format of input data to systolic array"""

    weight_type: Type[BaseFloat]
    """Floating point format of weight inputs"""

    accum_type: Type[BaseFloat]
    """Floating point format to accumulate values in"""

    pe_adder: Callable[[WireVector, WireVector, Type[BaseFloat]], WireVector]
    """Function to generate adder hardware for the processing elements"""

    accum_adder: Callable[[WireVector, WireVector, Type[BaseFloat]], WireVector]
    """Function to generate adder hardware for the accumulator buffer"""

    pe_multiplier: Callable[[WireVector, WireVector, Type[BaseFloat]], WireVector]
    """Function to generate multiplier hardware for the processing elements"""

    pipeline: bool
    """Whether to add a pipeline stage in processing elements between multiplier and adder"""

    accumulator_tiles: int
    """Number of tiles in the accumulator memory, each tile is equal to the size of the systolic array"""

    @property
    def accum_addr_width(self):
        """Get the width of the accumulator address bus in bits"""
        return (self.accumulator_tiles - 1).bit_length()

## Putting everything together

In [16]:
config = AcceleratorConfig(
    array_size=4,
    data_type=BF16,
    weight_type=BF16,
    accum_type=BF16,
    pe_adder=float_adder,
    pe_multiplier=lmul_fast,
    accum_adder=float_adder,
    pipeline=False,
    accumulator_tiles=4,
)

In [17]:
getattr(config, "array_size")

4

In [7]:
# Input Signals
data_start = Input(1, "data_start")  # Start data streaming
data_bank = Input(1, "data_bank")  # Select data memory bank
weight_start = Input(1, "weight_start")  # Start weight streaming
weight_bank = Input(1, "weight_bank")  # Select weight memory bank

accum_start = Input(1, "accum_start")
accum_tile_addr = Input(config.accum_addr_width, "accum_tile_addr")
accum_mode = Input(1, "accum_mode")  # 0 = overwrite, 1 = accumulate

accum_read_start = Input(1, "accum_read_start")
accum_read_tile_addr = Input(config.accum_addr_width, "accum_read_tile_addr")

buffer = BufferMemory(
    array_size=config.array_size,
    data_type=config.data_type,
    weight_type=config.weight_type,
)

buffer.connect_inputs(
    data_start=data_start,
    data_bank=data_bank,
    weight_start=weight_start,
    weight_bank=weight_bank,
)

buffer_outputs = buffer.get_outputs()

systolic_array = SystolicArrayDiP(
    size=config.array_size,
    data_type=config.data_type,
    accum_type=config.accum_type,
    multiplier=config.pe_multiplier,
    adder=config.pe_adder,
    pipeline=config.pipeline,
)

systolic_array.connect_inputs(
    weight_enable=buffer_outputs.weight_valid,
    weight_inputs=buffer_outputs.weights_out,
    data_inputs=buffer_outputs.datas_out,
    enable_input=buffer_outputs.data_valid,
)

accumulator = TiledAccumulatorMemoryBank(
    tile_addr_width=config.accum_addr_width,
    array_size=config.array_size,
    data_type=config.accum_type,
    adder=config.accum_adder,
)

accumulator.connect_inputs(
    data_in=systolic_array.results_out,
    write_start=accum_start,
    write_tile_addr=accum_tile_addr,
    write_mode=accum_mode,
    write_valid=systolic_array.control_out,
    read_start=accum_read_start,
    read_tile_addr=accum_read_tile_addr,
)

# Matrix Engine

In [9]:
class MatrixEngine:
    """Hardware implementation of the matrix engine accelerator"""

    def __init__(self, config: AcceleratorConfig):
        """Initialize matrix engine hardware with given configuration

        Args:
            config: Configuration parameters for the accelerator
        """
        self.config = config

        # Create input control signals
        self.data_start = Input(1, "data_start")
        self.data_bank = Input(1, "data_bank")
        self.weight_start = Input(1, "weight_start")
        self.weight_bank = Input(1, "weight_bank")

        self.accum_start = Input(1, "accum_start")
        self.accum_tile_addr = Input(config.accum_addr_width, "accum_tile_addr")
        self.accum_mode = Input(1, "accum_mode")

        self.accum_read_start = Input(1, "accum_read_start")
        self.accum_read_tile_addr = Input(
            config.accum_addr_width, "accum_read_tile_addr"
        )

        # Initialize hardware components
        self.buffer = BufferMemory(
            array_size=config.array_size,
            data_type=config.data_type,
            weight_type=config.weight_type,
        )

        self.systolic_array = SystolicArrayDiP(
            size=config.array_size,
            data_type=config.data_type,
            accum_type=config.accum_type,
            multiplier=config.pe_multiplier,
            adder=config.pe_adder,
            pipeline=config.pipeline,
        )

        self.accumulator = TiledAccumulatorMemoryBank(
            tile_addr_width=config.accum_addr_width,
            array_size=config.array_size,
            data_type=config.accum_type,
            adder=config.accum_adder,
        )

        # Connect components
        self._connect_components()

    def _connect_components(self):
        """Connect all hardware components together"""
        # Connect buffer control signals
        self.buffer.connect_inputs(
            data_start=self.data_start,
            data_bank=self.data_bank,
            weight_start=self.weight_start,
            weight_bank=self.weight_bank,
        )

        # Get buffer outputs
        buffer_outputs = self.buffer.get_outputs()

        # Connect systolic array to buffer
        self.systolic_array.connect_inputs(
            weight_enable=buffer_outputs.weight_valid,
            weight_inputs=buffer_outputs.weights_out,
            data_inputs=buffer_outputs.datas_out,
            enable_input=buffer_outputs.data_valid,
        )

        # Connect accumulator
        self.accumulator.connect_inputs(
            data_in=self.systolic_array.results_out,
            write_start=self.accum_start,
            write_tile_addr=self.accum_tile_addr,
            write_mode=self.accum_mode,
            write_valid=self.systolic_array.control_out,
            read_start=self.accum_read_start,
            read_tile_addr=self.accum_read_tile_addr,
        )

## Simulation Class

In [14]:
@dataclass
class SimulationState:
    """Stores the state of the matrix engine at a given simulation step"""

    step: int
    inputs: dict
    buffer_state: dict
    systolic_state: dict
    accumulator_state: dict

    def __repr__(self) -> str:
        width = 60
        sep = "-" * width
        return (
            f"\nMatrix Engine State - Step {self.step}\n{sep}\n"
            f"Inputs:\n{self.inputs}\n\n"
            f"Buffer State:\n{self.buffer_state}\n\n"
            f"Systolic Array State:\n{self.systolic_state}\n\n"
            f"Accumulator State:\n{self.accumulator_state}\n"
            f"{sep}\n"
        )

In [17]:
class MatrixEngineSimulator:
    """Simulator for the matrix engine accelerator"""

    def __init__(self, config: AcceleratorConfig):
        """Initialize matrix engine simulator

        Args:
            config: Configuration parameters for the accelerator
        """
        self.config = config
        self.history: List[SimulationState] = []

        # Create hardware
        reset_working_block()
        self.engine = MatrixEngine(config)
        self.sim = Simulation()

    def _get_default_inputs(self, updates: dict = {}) -> dict:
        """Get dictionary of default input values with optional updates"""
        defaults = {
            "data_start": 0,
            "data_bank": 0,
            "weight_start": 0,
            "weight_bank": 0,
            "accum_start": 0,
            "accum_tile_addr": 0,
            "accum_mode": 0,
            "accum_read_start": 0,
            "accum_read_tile_addr": 0,
        }
        defaults.update(updates)
        # Convert all values to integers for PyRTL
        return {k: int(v) for k, v in defaults.items()}

    def _step(self, inputs: dict = None):
        """Advance simulation one step and record state"""
        sim_inputs = self._get_default_inputs(inputs or {})

        # Create the simulation input dictionary with proper wire names
        wire_inputs = {
            self.engine.data_start.name: sim_inputs["data_start"],
            self.engine.data_bank.name: sim_inputs["data_bank"],
            self.engine.weight_start.name: sim_inputs["weight_start"],
            self.engine.weight_bank.name: sim_inputs["weight_bank"],
            self.engine.accum_start.name: sim_inputs["accum_start"],
            self.engine.accum_tile_addr.name: sim_inputs["accum_tile_addr"],
            self.engine.accum_mode.name: sim_inputs["accum_mode"],
            self.engine.accum_read_start.name: sim_inputs["accum_read_start"],
            self.engine.accum_read_tile_addr.name: sim_inputs["accum_read_tile_addr"],
        }

        # Step simulation with wire name mapped inputs
        self.sim.step(wire_inputs)

        # Record simulation state
        state = SimulationState(
            step=len(self.history),
            inputs=sim_inputs,
            buffer_state={
                "data": self._inspect_buffer_data(),
                "weights": self._inspect_buffer_weights(),
            },
            systolic_state={
                "weights": self.engine.systolic_array.inspect_weights(self.sim, False),
                "data": self.engine.systolic_array.inspect_data(self.sim, False),
                "accumulators": self.engine.systolic_array.inspect_accumulators(
                    self.sim, False
                ),
                "outputs": self.engine.systolic_array.inspect_outputs(self.sim, False),
            },
            accumulator_state=self._inspect_accumulator_state(),
        )
        self.history.append(state)

    # def _get_default_inputs(self, updates: dict = {}) -> dict:
    #     """Get dictionary of default input values with optional updates"""
    #     defaults = {
    #         "data_start": 0,
    #         "data_bank": 0,
    #         "weight_start": 0,
    #         "weight_bank": 0,
    #         "accum_start": 0,
    #         "accum_tile_addr": 0,
    #         "accum_mode": 0,
    #         "accum_read_start": 0,
    #         "accum_read_tile_addr": 0
    #     }
    #     defaults.update(updates)
    #     return defaults

    def load_weights(self, weights: np.ndarray, bank: int) -> None:
        """Load weight matrix into specified buffer bank

        Args:
            weights: Weight matrix to load
            bank: Buffer bank to load into (0 or 1)
        """
        if weights.shape != (self.config.array_size, self.config.array_size):
            raise ValueError(
                f"Weights must be {self.config.array_size}x{self.config.array_size}"
            )

        # Load weights into buffer memory
        self._step({"weight_start": 1, "weight_bank": bank})

        # Convert weights to binary and load into memory
        binary_weights = self._convert_to_binary(weights, self.config.weight_type)
        weight_mem = self.sim.inspect_mem(self.engine.buffer.weight_mems[bank])
        for i, row in enumerate(binary_weights[::-1]):
            weight_mem[i] = row

    def load_activations(self, activations: np.ndarray, bank: int) -> None:
        """Load activation matrix into specified buffer bank

        Args:
            activations: Activation matrix to load
            bank: Buffer bank to load into (0 or 1)
        """
        if activations.shape != (self.config.array_size, self.config.array_size):
            raise ValueError(
                f"Activations must be {self.config.array_size}x{self.config.array_size}"
            )

        # Load activations into buffer memory
        self._step({"data_start": 1, "data_bank": bank})

        # Convert activations to binary and load into memory
        binary_acts = self._convert_to_binary(activations, self.config.data_type)
        data_mem = self.sim.inspect_mem(self.engine.buffer.data_mems[bank])
        for i, row in enumerate(binary_acts[::-1]):
            data_mem[i] = row

    def matrix_multiply(
        self,
        data_tile: int,
        weight_tile: int,
        accum_tile: int,
        accumulate: bool = False,
    ) -> None:
        """Perform matrix multiplication using specified tiles

        Args:
            data_tile: Buffer bank containing input activations
            weight_tile: Buffer bank containing weights
            accum_tile: Accumulator tile to store results
            accumulate: Whether to accumulate with existing values
        """
        # Start data and weight streaming
        self._step(
            {
                "data_start": 1,
                "data_bank": data_tile,
                "weight_start": 1,
                "weight_bank": weight_tile,
                "accum_start": 1,
                "accum_tile_addr": accum_tile,
                "accum_mode": int(accumulate),
            }
        )

        # Run for required cycles
        cycles = self.config.array_size * 2 + int(self.config.pipeline)
        for _ in range(cycles):
            self._step(
                {
                    "data_bank": data_tile,
                    "weight_bank": weight_tile,
                    "accum_tile_addr": accum_tile,
                    "accum_mode": int(accumulate),
                }
            )

    def read_accumulator_tile(self, tile: int) -> None:
        """Initiate read operation for specified accumulator tile

        Args:
            tile: Accumulator tile to read
        """
        # Start read operation
        self._step({"accum_read_start": 1, "accum_read_tile_addr": tile})

        # Step simulation until read is complete
        while not self.sim.inspect(self.engine.accumulator.read_done.name):
            self._step({"accum_read_tile_addr": tile})

    def _convert_to_binary(
        self, matrix: np.ndarray, dtype: Type[BaseFloat]
    ) -> np.ndarray:
        """Convert matrix to binary representation"""
        return np.array([[dtype(val).binint for val in row] for row in matrix])

    # def _step(self, inputs: dict = None):
    #     """Advance simulation one step and record state"""
    #     sim_inputs = self._get_default_inputs(inputs or {})
    #     self.sim.step(sim_inputs)

    #     # Record simulation state
    #     state = SimulationState(
    #         step=len(self.history),
    #         inputs=sim_inputs,
    #         buffer_state={
    #             "data": self._inspect_buffer_data(),
    #             "weights": self._inspect_buffer_weights(),
    #         },
    #         systolic_state={
    #             "weights": self.engine.systolic_array.inspect_weights(self.sim, False),
    #             "data": self.engine.systolic_array.inspect_data(self.sim, False),
    #             "accumulators": self.engine.systolic_array.inspect_accumulators(self.sim, False),
    #             "outputs": self.engine.systolic_array.inspect_outputs(self.sim, False)
    #         },
    #         accumulator_state=self._inspect_accumulator_state()
    #     )
    #     self.history.append(state)

    def _inspect_buffer_data(self) -> dict:
        """Inspect current state of buffer data memories"""
        return {
            bank: self.sim.inspect_mem(self.engine.buffer.data_mems[bank])
            for bank in range(2)
        }

    def _inspect_buffer_weights(self) -> dict:
        """Inspect current state of buffer weight memories"""
        return {
            bank: self.sim.inspect_mem(self.engine.buffer.weight_mems[bank])
            for bank in range(2)
        }

    def _inspect_accumulator_state(self) -> dict:
        """Inspect current state of accumulator memory banks"""
        mems = self.engine.accumulator.memory_banks
        result = {}

        # Initialize empty lists for each address
        for addr in range(self.config.array_size * self.config.accumulator_tiles):
            result[addr] = []

        # Collect memory contents
        for mem in mems:
            d = self.sim.inspect_mem(mem)
            for addr in range(self.config.array_size * self.config.accumulator_tiles):
                result[addr].append(d.get(addr, 0))

        # Convert to numpy array and reshape into tiles
        tiles = [
            [float(self.config.accum_type(binint=x)) for x in tile]
            for tile in result.values()
        ]
        tiles = np.array(tiles)

        # Reshape into tile matrices
        result_3d = []
        for i in range(self.config.accumulator_tiles):
            start_idx = i * self.config.array_size
            end_idx = start_idx + self.config.array_size
            result_3d.append(tiles[start_idx:end_idx])

        return {"tiles": result_3d}

    def get_accumulator_outputs(self) -> np.ndarray:
        """Get current values on accumulator output ports"""
        return np.array(
            [
                float(
                    self.config.accum_type(
                        binint=self.sim.inspect(
                            self.engine.accumulator.get_output(i).name
                        )
                    )
                )
                for i in range(self.config.array_size)
            ]
        )

    def print_state(self, step: int = -1):
        """Print simulation state at specified step"""
        if not self.history:
            print("No simulation history available")
            return

        state = self.history[step]
        print(state)

In [18]:
config = AcceleratorConfig(
    array_size=3,
    data_type=BF16,
    weight_type=BF16,
    accum_type=BF16,
    pe_adder=float_adder,
    pe_multiplier=lmul_fast,
    accum_adder=float_adder,
    pipeline=False,
    accumulator_tiles=4,
)

sim = MatrixEngineSimulator(config)

weights_matrix = np.identity(config.array_size)
activations_matrix = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])


# Load matrices
sim.load_weights(weights_matrix, bank=0)
sim.load_activations(activations_matrix, bank=0)

# Perform matrix multiplication
result = sim.matrix_multiply(data_tile=0, weight_tile=0, accum_tile=0, accumulate=False)

# Inspect simulation history
sim.print_state()  # Print latest state
sim.print_state(step=0)  # Print initial state

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

# Prompt

## Code

```python
# systolic.py
from abc import ABC, abstractmethod
from typing import Callable, List, Type

import numpy as np
from pyrtl import Const, Register, Simulation, WireVector, chop

from hardware_accelerators import *
from hardware_accelerators.simulation import *

from ..dtypes.base import BaseFloat
from .processing_element import ProcessingElement

# TODO: Add float type conversion logic to pass different bitwidths to the accumulator
# TODO: specify different dtypes for weights and activations


class BaseSystolicArray(ABC):
    def __init__(
        self,
        size: int,
        data_type: Type[BaseFloat],
        accum_type: Type[BaseFloat],
        multiplier: Callable[[WireVector, WireVector, Type[BaseFloat]], WireVector],
        adder: Callable[[WireVector, WireVector, Type[BaseFloat]], WireVector],
    ):
        """Base class for implementing systolic array hardware structures

        Args:
            size: N for NxN array
            data_type: Number format for inputs (Float8, BF16)
            accum_type: Number format for accumulation
            multiplier: Multiplier implementation to use
            adder: Adder implementation to use
        """
        # Set configuration attributes
        self.size = size
        self.data_type = data_type
        self.accum_type = accum_type
        data_width = data_type.bitwidth()
        accum_width = accum_type.bitwidth()
        self.multiplier = multiplier
        self.adder = adder

        # Input wires
        self.data_in = [WireVector(data_width) for _ in range(size)]
        self.weights_in = [WireVector(data_width) for _ in range(size)]
        self.results_out = [WireVector(accum_width) for _ in range(size)]

        # Control wires
        self.weight_enable = WireVector(1)

        # Create PE array
        self.pe_array = self._create_pe_array()

        # Connect PEs in systolic pattern based on dataflow type
        self._connect_array()

    @abstractmethod
    def _create_pe_array(self) -> List[List[ProcessingElement]]:
        return [
            [
                ProcessingElement(
                    self.data_type,
                    self.accum_type,
                    self.multiplier,
                    self.adder,
                )
                for _ in range(self.size)
            ]
            for _ in range(self.size)
        ]

    @abstractmethod
    def _connect_array(self):
        pass

    # -----------------------------------------------------------------------------
    # Connection functions return their inputs to allow for more concise simulation
    # -----------------------------------------------------------------------------
    def connect_weight_enable(self, source: WireVector):
        """Connect weight load enable signal"""
        self.weight_enable <<= source
        return source

    def connect_data_input(self, row: int, source: WireVector):
        """Connect data input for specified row"""
        assert 0 <= row < self.size
        self.data_in[row] <<= source
        return source

    def connect_weight_input(self, col: int, source: WireVector):
        """Connect weight input for specified column"""
        assert 0 <= col < self.size
        self.weights_in[col] <<= source
        return source

    def connect_result_output(self, col: int, dest: WireVector):
        """Connect result output from specified column"""
        assert 0 <= col < self.size
        dest <<= self.results_out[col]
        return dest

    # -----------------------------------------------------------------------------
    # Simulation helper methods
    # -----------------------------------------------------------------------------
    def inspect_weights(self, sim: Simulation, verbose: bool = True):
        weights = np.zeros((self.size, self.size))
        enabled = sim.inspect(self.weight_enable.name) == 1
        for row in range(self.size):
            for col in range(self.size):
                w = sim.inspect(self.pe_array[row][col].weight_reg.name)
                weights[row][col] = self.data_type(binint=w).decimal_approx
        if verbose:
            print(f"Weights: {enabled=}")
            print(np.array_str(weights, precision=3, suppress_small=True), "\n")
        return weights

    def inspect_data(self, sim: Simulation, verbose: bool = True):
        data = np.zeros((self.size, self.size))
        for row in range(self.size):
            for col in range(self.size):
                d = sim.inspect(self.pe_array[row][col].data_reg.name)
                data[row][col] = self.data_type(binint=d).decimal_approx
        if verbose:
            print("Data:")
            print(np.array_str(data, precision=4, suppress_small=True), "\n")
        return data

    def inspect_accumulators(self, sim: Simulation, verbose: bool = True):
        acc_regs = np.zeros((self.size, self.size))
        for row in range(self.size):
            for col in range(self.size):
                d = sim.inspect(self.pe_array[row][col].accum_reg.name)
                acc_regs[row][col] = self.accum_type(binint=d).decimal_approx
        if verbose:
            print("Data:")
            print(np.array_str(acc_regs, precision=4, suppress_small=True), "\n")
        return acc_regs

    def inspect_outputs(self, sim: Simulation, verbose: bool = True):
        current_results = np.zeros(self.size)
        for i in range(self.size):
            r = sim.inspect(self.results_out[i].name)
            current_results[i] = self.accum_type(binint=r).decimal_approx
        if verbose:
            print("Output:")
            print(np.array_str(current_results, precision=4, suppress_small=True), "\n")
        return current_results


class SystolicArrayDiP(BaseSystolicArray):
    def __init__(
        self,
        size: int,
        data_type: Type[BaseFloat],
        accum_type: Type[BaseFloat],
        multiplier: Callable[[WireVector, WireVector, Type[BaseFloat]], WireVector],
        adder: Callable[[WireVector, WireVector, Type[BaseFloat]], WireVector],
        pipeline: bool = False,
    ):
        """Initialize systolic array hardware structure.

        This class uses Ḏiagonal-I̱nput and P̱ermutated weight-stationary (DiP) dataflow.

        Args:
            size: N for NxN array
            data_type: Number format for inputs (Float8, BF16)
            accum_type: Number format for accumulation
            multiplier: Multiplier implementation to use
            adder: Adder implementation to use
            pipeline: Add pipeline register after multiplication in processing element
        """
        self.pipeline = pipeline

        # Control signal registers to propogate signal down the array
        self.enable_in = WireVector(1)
        self.control_registers = [Register(1) for _ in range(size)]

        # Control signal output
        self.control_out = WireVector(1)
        self.control_out <<= self.control_registers[-1]

        # If PEs contain an extra pipeline stage, 1 additional control reg is needed
        if self.pipeline:
            self.control_registers.append(Register(1))

        super().__init__(size, data_type, accum_type, multiplier, adder)

    def _create_pe_array(self) -> List[List[ProcessingElement]]:
        # Create PE array
        return [
            [
                ProcessingElement(
                    self.data_type,
                    self.accum_type,
                    self.multiplier,
                    self.adder,
                    self.pipeline,
                )
                for _ in range(self.size)
            ]
            for _ in range(self.size)
        ]

    def _connect_array(self):
        """Connect processing elements in DiP configuration
        - All data flows top to bottom diagonally shifted
        - Weights are loaded simultaneously across array
        - Data inputs arrive synchronously
        """
        # Connect control signal registers that propagate down the array
        self.control_registers[0].next <<= self.enable_in
        for i in range(1, len(self.control_registers)):
            self.control_registers[i].next <<= self.control_registers[i - 1]

        for row in range(self.size):
            for col in range(self.size):
                pe = self.pe_array[row][col]

                # Connect PE inputs:
                # First row gets external input, others connect to PE above
                if row == 0:
                    pe.connect_data_enable(self.enable_in)
                    pe.connect_data(self.data_in[col])
                    pe.connect_weight(self.weights_in[col])
                    pe.connect_accum(Const(0))

                # DiP config: PEs connected to previous row and data diagonally shifted by 1
                else:
                    pe.connect_data_enable(self.control_registers[row - 1])
                    pe.connect_data(self.pe_array[row - 1][col - self.size + 1])
                    pe.connect_weight(self.pe_array[row - 1][col])
                    pe.connect_accum(self.pe_array[row - 1][col])

                # Delay the control signal for accumulator by num pipeline stages
                if self.pipeline:
                    pe.connect_mul_enable(self.control_registers[row])
                    pe.connect_adder_enable(self.control_registers[row + 1])
                else:
                    pe.connect_adder_enable(self.control_registers[row])

                # Connect weight enable signal (shared by all PEs)
                pe.connect_weight_enable(self.weight_enable)

                # Connect bottom row results to output ports
                if row == self.size - 1:
                    self.results_out[col] <<= pe.outputs.accum

    def connect_enable_input(self, source: WireVector):
        """Connect PE enable signal. Controls writing to the data input register"""
        self.enable_in <<= source
        return source

    def connect_inputs(
        self,
        data_inputs: list[WireVector] | None = None,
        weight_inputs: list[WireVector] | None = None,
        enable_input: WireVector | None = None,
        weight_enable: WireVector | None = None,
    ) -> None:
        """Connect input control and data wires to the systolic array.

        Args:
            data_inputs: List of data input wires (data_width bits each)
                Input data for each row of the systolic array.
                Length must match array size.

            weight_inputs: List of weight input wires (data_width bits each)
                Input weights for each column of the systolic array.
                Length must match array size.

            enable_input: Enable signal for data streaming (1 bit)
                Controls writing to the data input register.

            weight_enable: Weight load enable signal (1 bit)
                Controls writing to the weight registers.

        Raises:
            AssertionError: If input wire widths don't match expected widths or
                            if input list lengths don't match array size.
        """
        if data_inputs is not None:
            assert (
                len(data_inputs) == self.size
            ), f"Expected {self.size} data inputs, got {len(data_inputs)}"
            for row, data_input in enumerate(data_inputs):
                assert (
                    len(data_input) == self.data_type.bitwidth()
                ), f"Data input {row} width mismatch. Expected {self.data_type.bitwidth()}, got {len(data_input)}"
                self.connect_data_input(row, data_input)

        if weight_inputs is not None:
            assert (
                len(weight_inputs) == self.size
            ), f"Expected {self.size} weight inputs, got {len(weight_inputs)}"
            for col, weight_input in enumerate(weight_inputs):
                assert (
                    len(weight_input) == self.data_type.bitwidth()
                ), f"Weight input {col} width mismatch. Expected {self.data_type.bitwidth()}, got {len(weight_input)}"
                self.connect_weight_input(col, weight_input)

        if enable_input is not None:
            assert len(enable_input) == 1, "Enable input must be 1 bit wide"
            self.connect_enable_input(enable_input)

        if weight_enable is not None:
            assert len(weight_enable) == 1, "Weight enable must be 1 bit wide"
            self.connect_weight_enable(weight_enable)

import copy
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple, Type

import numpy as np
from pyrtl import Input, Output, Simulation, WireVector, reset_working_block

from ..dtypes import *
from ..rtllib import *
from .utils import *


@dataclass
class SimulationState:
    """Stores the state of the systolic array at a given simulation step"""

    inputs: dict[str, Any]
    weights: np.ndarray
    data: np.ndarray
    outputs: np.ndarray
    accumulators: np.ndarray
    step: int

    def __repr__(self) -> str:
        """Pretty print the simulation state at this step"""
        width = 40
        sep = "-" * width

        return (
            f"\nSimulation State - Step {self.step}\n{sep}\n"
            f"Inputs:\n"
            f"  w_en: {self.inputs['w_en']}\n"
            f"  enable: {self.inputs['enable']}\n"
            f"  weights: {np.array2string(self.inputs['weights'], precision=4, suppress_small=True)}\n"
            f"  data: {np.array2string(self.inputs['data'], precision=4, suppress_small=True)}\n"
            f"\nWeights Matrix:\n{np.array2string(self.weights, precision=4, suppress_small=True)}\n"
            f"\nData Matrix:\n{np.array2string(self.data, precision=4, suppress_small=True)}\n"
            f"\nAccumulators:\n{np.array2string(self.accumulators, precision=4, suppress_small=True)}\n"
            f"\nOutputs:\n{np.array2string(self.outputs, precision=4, suppress_small=True)}\n"
            f"{sep}\n"
        )


class SystolicArraySimulator:
    def __init__(
        self,
        size: int,
        data_type: Type[BaseFloat] = BF16,
        accum_type: Type[BaseFloat] = BF16,
        multiplier: Callable[
            [WireVector, WireVector, Type[BaseFloat]], WireVector
        ] = lmul_fast,
        adder: Callable[
            [WireVector, WireVector, Type[BaseFloat]], WireVector
        ] = float_adder,
        pipeline: bool = False,
    ):
        """Initialize systolic array simulator

        Args:
            size: Dimension of systolic array (NxN)
            dtype: Number format to use (e.g. BF16)
            pipeline: Whether to use pipelined PEs
            multiplier: Multiplication implementation
            adder: Addition implementation
        """
        self.size = size
        self.dtype = data_type
        self.accum_type = accum_type
        self.pipeline = pipeline
        self.dwidth = data_type.bitwidth()
        self.accwidth = accum_type.bitwidth()
        self.multiplier = multiplier
        self.adder = adder
        self.history: List[SimulationState] = []

    def _setup(self):
        # Setup PyRTL simulation
        reset_working_block()

        # Initialize hardware
        self.array = SystolicArrayDiP(
            size=self.size,
            data_type=self.dtype,
            accum_type=self.accum_type,
            multiplier=self.multiplier,
            adder=self.adder,
            pipeline=self.pipeline,
        )

        self.w_en = self.array.connect_weight_enable(Input(1, "w_en"))
        self.enable = self.array.connect_enable_input(Input(1, "enable"))

        self.w_ins = [Input(self.dwidth, f"weight_{i}") for i in range(self.size)]
        self.d_ins = [Input(self.dwidth, f"data_{i}") for i in range(self.size)]
        self.acc_outs = [Output(self.dwidth, f"result_{i}") for i in range(self.size)]

        for i in range(self.size):
            self.array.connect_weight_input(i, self.w_ins[i])
            self.array.connect_data_input(i, self.d_ins[i])
            self.array.connect_result_output(i, self.acc_outs[i])

        self.sim = Simulation()
        self.history = []
        self.sim_inputs = {
            w.name: 0 for w in [self.w_en, self.enable, *self.w_ins, *self.d_ins]
        }

    @classmethod
    def matrix_multiply(
        cls,
        weights: np.ndarray,
        activations: np.ndarray,
        dtype: Optional[Type[BaseFloat]] = None,
    ) -> np.ndarray:
        """Perform matrix multiplication using systolic array

        Args:
            weights: Weight matrix
            activations: Activation matrix
            dtype: Optional number format override
            pipeline: Whether to use pipelined PEs

        Returns:
            Tuple of (result matrix, simulation history)
        """
        # Validate inputs
        if weights.shape != activations.shape:
            raise ValueError("Weight and activation matrices must have same shape")
        if weights.shape[0] != weights.shape[1]:
            raise ValueError("Only square matrices supported")

        # Create simulator instance
        size = weights.shape[0]
        dtype = dtype or BF16
        sim = cls(size=size, data_type=dtype)

        return sim.simulate(weights, activations)

    def simulate(self, weights: np.ndarray, activations: np.ndarray) -> np.ndarray:
        """Instance method to run simulation

        Args:
            weights: Weight matrix
            activations: Activation matrix

        Returns:
            Tuple of (result matrix, simulation history)
        """
        # Validate inputs
        if weights.shape != (self.size, self.size) or activations.shape != (
            self.size,
            self.size,
        ):
            raise ValueError(f"Input matrices must be {self.size}x{self.size}")

        # Convert and permutate matrices
        weights = convert_array_dtype(permutate_weight_matrix(weights), self.dtype)
        activations = convert_array_dtype(activations, self.dtype)

        self._setup()

        # Load weights except top row
        self.sim_inputs["w_en"] = 1
        for row in range(self.size - 1):
            for col in range(self.size):
                self.sim_inputs[self.w_ins[col].name] = weights[-row - 1][col]
            self._step()

        # Load top row weights and first data row
        self.sim_inputs["enable"] = 1
        for col in range(self.size):
            self.sim_inputs[self.w_ins[col].name] = weights[0][col]
            self.sim_inputs[self.d_ins[col].name] = activations[-1][col]
        self._step()

        # Disable weight loading
        self.sim_inputs["w_en"] = 0

        # Feed remaining data
        for row in range(1, self.size):
            for col in range(self.size):
                self.sim_inputs[self.d_ins[col].name] = activations[-row - 1][col]
            self._step()

        # Additional step to flush pipeline
        self._step()
        self._reset_inputs()

        # Collect results
        results = []
        for _ in range(self.size + int(self.pipeline)):
            self._step()
            results.insert(0, self.array.inspect_outputs(self.sim, False))

        return np.array(results)

    def _reset_inputs(self):
        """Reset all simulation inputs to 0"""
        for k in self.sim_inputs:
            self.sim_inputs[k] = 0

    def _step(self):
        """Advance simulation one step and record state"""
        self.sim.step(self.sim_inputs)

        # Record simulation state
        state = SimulationState(
            inputs=self.get_readable_inputs(),
            weights=self.array.inspect_weights(self.sim, False),
            data=self.array.inspect_data(self.sim, False),
            outputs=self.array.inspect_outputs(self.sim, False),
            accumulators=self.array.inspect_accumulators(self.sim, False),
            step=len(self.history),
        )
        self.history.append(state)

    def get_readable_inputs(self) -> dict[str, Any]:
        """Convert binary simulation inputs to human-readable floating point values.

        Returns:
            Dictionary with consolidated inputs:
                - 'w_en', 'enable': Binary control signals
                - 'weights': Array of weight input values in floating point
                - 'data': Array of data input values in floating point
        """
        readable = {
            "w_en": self.sim_inputs["w_en"],
            "enable": self.sim_inputs["enable"],
            "weights": np.zeros(len([k for k in self.sim_inputs if "weight_" in k])),
            "data": np.zeros(len([k for k in self.sim_inputs if "data_" in k])),
        }

        # Convert weight inputs to array
        for i, key in enumerate(sorted([k for k in self.sim_inputs if "weight_" in k])):
            binary = self.sim_inputs[key]
            readable["weights"][i] = self.dtype(binint=binary).decimal_approx

        # Convert data inputs to array
        for i, key in enumerate(sorted([k for k in self.sim_inputs if "data_" in k])):
            binary = self.sim_inputs[key]
            readable["data"][i] = self.dtype(binint=binary).decimal_approx

        return readable

    def __repr__(self) -> str:
        """Detailed representation of simulator configuration and state"""
        config = (
            f"SystolicArrayDiPSimulator(\n"
            f"  size: {self.size}x{self.size}\n"
            f"  data_type: {self.dtype.__name__}\n"
            f"  accum_type: {self.accum_type.__name__}\n"
            f"  pipeline: {self.pipeline}\n"
            f"  steps_simulated: {len(self.history)}\n"
        )

        if not self.history:
            return config + "  history: empty\n)"

        # Add summary of simulation history
        last_state = self.history[-1]
        history = (
            f"  Latest State (Step {last_state.step}):\n"
            f"    Weights Shape: {last_state.weights.shape}\n"
            f"    Data Shape: {last_state.data.shape}\n"
            f"    Latest Outputs: {np.array2string(last_state.outputs, precision=3)}\n"
        )

        return config + history + ")"
```

```python
# accumulator.py
from enum import IntEnum
from typing import Callable, Type

from pyrtl import (
    MemBlock,
    Register,
    RomBlock,
    WireVector,
    conditional_assignment,
    otherwise,
)

from ..dtypes.base import BaseFloat


class TiledAccumulatorFSM(IntEnum):
    IDLE = 0
    WRITING = 1


class ReadAccumulatorFSM(IntEnum):
    IDLE = 0
    READING = 1


class TiledAddressGenerator:
    """Hardware control unit for managing tiled memory access patterns.

    Provides dual finite state machines for independent read and write operations,
    with integrated support for accumulate/overwrite modes. Generates addresses
    for accessing data organized in tiles, where each tile contains array_size
    rows of data.

    Features:
    - Separate read/write FSMs for overlapped operation
    - Base address ROM for fast tile address computation
    - Mode control for accumulate vs overwrite operations
    - Row tracking within tiles
    - Status signals for external coordination
    """

    def __init__(self, tile_addr_width: int, array_size: int):
        """Initialize the address generator.

        Args:
            tile_addr_width: Number of bits for addressing tiles. Determines
                number of tiles as 2^tile_addr_width.
            array_size: Number of rows per tile, matching systolic array
                dimension. Also determines address increment pattern.

        The internal address width is computed to accommodate all required
        addresses (num_tiles * array_size locations).
        """
        self.array_size = array_size
        self.num_tiles = 2**tile_addr_width
        self.internal_addr_width = (self.num_tiles * array_size - 1).bit_length()

        # Base address ROM
        base_addrs = [i * array_size for i in range(self.num_tiles)]
        self.base_addr_rom = RomBlock(
            bitwidth=self.internal_addr_width,
            addrwidth=tile_addr_width,
            romdata=base_addrs,
        )

        # ================== Write Interface ==================
        self._tile_addr = WireVector(tile_addr_width)
        self._write_start = WireVector(1)
        self._write_mode = WireVector(1)  # 0=overwrite, 1=accumulate
        self._write_valid = WireVector(1)

        # Write state registers
        self.write_state = Register(1)
        self.write_addr = Register(self.internal_addr_width)
        self.write_row = Register(array_size.bit_length())
        self.write_mode_reg = Register(1)  # Stores mode for current operation

        # ================== Read Interface ==================
        self._read_tile_addr = WireVector(tile_addr_width)
        self._read_start = WireVector(1)

        # Read state registers
        self.read_state = Register(1)
        self.read_addr = Register(self.internal_addr_width)
        self.read_row = Register(array_size.bit_length())

        # Outputs
        self.write_addr_out = WireVector(self.internal_addr_width)
        self.write_enable = WireVector(1)
        self.write_busy = WireVector(1)
        self.write_done = WireVector(1)
        self.write_mode_out = WireVector(1)

        self.read_addr_out = WireVector(self.internal_addr_width)
        self.read_busy = WireVector(1)
        self.read_done = WireVector(1)

        self._implement_write_fsm()
        self._implement_read_fsm()

    def _implement_write_fsm(self):
        write_base = self.base_addr_rom[self._tile_addr]

        # Combinational outputs
        self.write_addr_out <<= self.write_addr
        self.write_enable <<= (
            self.write_state == TiledAccumulatorFSM.WRITING
        ) & self._write_valid
        self.write_busy <<= self.write_state == TiledAccumulatorFSM.WRITING
        self.write_done <<= (self.write_state == TiledAccumulatorFSM.WRITING) & (
            self.write_row == self.array_size
        )
        self.write_mode_out <<= self.write_mode_reg

        with conditional_assignment:
            # IDLE State
            with self.write_state == TiledAccumulatorFSM.IDLE:
                with self._write_start:
                    self.write_state.next |= TiledAccumulatorFSM.WRITING
                    self.write_addr.next |= write_base
                    self.write_row.next |= 0
                    self.write_mode_reg.next |= self._write_mode  # Latch mode

            # WRITING State
            with self.write_state == TiledAccumulatorFSM.WRITING:
                with self._write_valid:
                    with self.write_row == self.array_size - 1:
                        self.write_state.next |= TiledAccumulatorFSM.IDLE
                        self.write_row.next |= 0
                    with otherwise:
                        self.write_addr.next |= self.write_addr + 1
                        self.write_row.next |= self.write_row + 1

    def _implement_read_fsm(self):
        read_base = self.base_addr_rom[self._read_tile_addr]

        self.read_addr_out <<= self.read_addr
        self.read_busy <<= self.read_state == ReadAccumulatorFSM.READING
        self.read_done <<= (self.read_state == ReadAccumulatorFSM.READING) & (
            self.read_row == self.array_size - 1
        )

        with conditional_assignment:
            with self.read_state == ReadAccumulatorFSM.IDLE:
                with self._read_start:
                    self.read_state.next |= ReadAccumulatorFSM.READING
                    self.read_addr.next |= read_base
                    self.read_row.next |= 0

            with self.read_state == ReadAccumulatorFSM.READING:
                with self.read_row == self.array_size - 1:
                    self.read_state.next |= ReadAccumulatorFSM.IDLE
                with otherwise:
                    self.read_addr.next |= self.read_addr + 1
                    self.read_row.next |= self.read_row + 1

    # Write interface methods
    def connect_tile_addr(self, addr: WireVector) -> None:
        self._tile_addr <<= addr

    def connect_write_start(self, start: WireVector) -> None:
        self._write_start <<= start

    def connect_write_mode(self, mode: WireVector) -> None:
        self._write_mode <<= mode

    def connect_write_valid(self, valid: WireVector) -> None:
        self._write_valid <<= valid

    # Read interface methods
    def connect_read_tile_addr(self, addr: WireVector) -> None:
        self._read_tile_addr <<= addr

    def connect_read_start(self, start: WireVector) -> None:
        self._read_start <<= start


class AccumulatorMemoryBank:
    """Integrated memory system for storing and accumulating systolic array outputs.

    Combines an address generator with parallel memory banks to provide a complete
    storage subsystem for matrix multiplication results. Supports both direct
    overwrite and accumulation modes, with independent read and write operations.

    Features:
    - Parallel memory banks (one per systolic array column)
    - Integrated address generation
    - Accumulate/overwrite modes
    - Independent read/write interfaces
    - Status signals for external coordination
    """

    def __init__(
        self,
        tile_addr_width: int,
        array_size: int,
        data_type: Type[BaseFloat],
        adder: Callable[[WireVector, WireVector, Type[BaseFloat]], WireVector],
    ):
        """Initialize the memory bank system.

        Args:
            tile_addr_width: Number of bits for addressing tiles. Determines
                number of tiles as 2^tile_addr_width.
            array_size: Number of parallel memory banks, matching systolic
                array dimension.
            data_type: Number format for stored data (e.g. BF16, Float8).
                Determines memory word width.
            adder: Function implementing addition for the specified data_type.
                Used for accumulation mode.
        """
        self.array_size = array_size
        self.tile_addr_width = tile_addr_width
        self.data_width = data_type.bitwidth()
        self.data_type = data_type
        self.adder = adder

        # Instantiate address generator
        self.addr_gen = TiledAddressGenerator(
            tile_addr_width=tile_addr_width, array_size=array_size
        )

        # Input ports
        self._write_tile_addr = WireVector(self.tile_addr_width)
        self._write_start = WireVector(1)
        self._write_mode = WireVector(1)
        self._write_valid = WireVector(1)
        self._read_tile_addr = WireVector(self.tile_addr_width)
        self._read_start = WireVector(1)
        self._data_in = [WireVector(self.data_width) for i in range(array_size)]

        # Connect address generator
        self.addr_gen.connect_tile_addr(self._write_tile_addr)
        self.addr_gen.connect_write_start(self._write_start)
        self.addr_gen.connect_write_mode(self._write_mode)
        self.addr_gen.connect_write_valid(self._write_valid)
        self.addr_gen.connect_read_tile_addr(self._read_tile_addr)
        self.addr_gen.connect_read_start(self._read_start)

        # Create memory banks
        self.memory_banks = [
            MemBlock(
                bitwidth=self.data_width,
                addrwidth=self.addr_gen.internal_addr_width,
                name=f"bank_{i}",
            )
            for i in range(array_size)
        ]

        # Output ports
        self._data_out = [WireVector(self.data_width) for _ in range(array_size)]
        self.write_busy = self.addr_gen.write_busy
        self.write_done = self.addr_gen.write_done
        self.read_busy = self.addr_gen.read_busy
        self.read_done = self.addr_gen.read_done

        self._implement_memory_logic()

    def _implement_memory_logic(self):
        # Write logic
        for i, mem in enumerate(self.memory_banks):
            current_val = mem[self.addr_gen.write_addr_out]
            sum_result = self.adder(self._data_in[i], current_val, self.data_type)

            with conditional_assignment:
                with self.addr_gen.write_enable:
                    with self.addr_gen.write_mode_out:  # Accumulate mode
                        mem[self.addr_gen.write_addr_out] |= sum_result
                    with otherwise:  # Overwrite mode
                        mem[self.addr_gen.write_addr_out] |= self._data_in[i]

        # Read logic
        for i, mem in enumerate(self.memory_banks):
            self._data_out[i] <<= mem[self.addr_gen.read_addr_out]

    def connect_inputs(
        self,
        write_tile_addr: WireVector | None = None,
        write_start: WireVector | None = None,
        write_mode: WireVector | None = None,
        write_valid: WireVector | None = None,
        read_tile_addr: WireVector | None = None,
        read_start: WireVector | None = None,
        data_in: list[WireVector] | None = None,
    ) -> None:
        """Connect all input control and data wires to the accumulator bank.

        Args:
            write_tile_addr: Address of tile to write to (tile_addr_width bits)
                Used to select which tile receives the incoming data.

            write_start: Start signal for write operation (1 bit)
                Pulses high for one cycle to initiate a new write sequence.

            write_mode: Mode selection for write operation (1 bit)
                0 = overwrite mode: new data replaces existing values
                1 = accumulate mode: new data is added to existing values

            write_valid: Data valid signal for write operation (1 bit)
                High when input data is valid and should be written/accumulated

            read_tile_addr: Address of tile to read from (tile_addr_width bits)
                Used to select which tile's data to output.

            read_start: Start signal for read operation (1 bit)
                Pulses high for one cycle to initiate a new read sequence.

            data_in: List of data input wires (data_width bits each)
                Input data from systolic array, one wire per column.
                Length must match array_size.

        Raises:
            AssertionError: If input wire widths don't match expected widths or
                        if data_in length doesn't match array_size.
        """
        if write_tile_addr is not None:
            assert len(write_tile_addr) == self.tile_addr_width
            self._write_tile_addr <<= write_tile_addr

        if write_start is not None:
            assert len(write_start) == 1
            self._write_start <<= write_start

        if write_mode is not None:
            assert len(write_mode) == 1
            self._write_mode <<= write_mode

        if write_valid is not None:
            assert len(write_valid) == 1
            self._write_valid <<= write_valid

        if read_tile_addr is not None:
            assert len(read_tile_addr) == self.tile_addr_width
            self._read_tile_addr <<= read_tile_addr

        if read_start is not None:
            assert len(read_start) == 1
            self._read_start <<= read_start

        if data_in is not None:
            assert (
                len(data_in) == self.array_size
            ), f"Expected {self.array_size} data inputs, got {len(data_in)}"
            for i, wire in enumerate(data_in):
                assert (
                    len(wire) == self.data_width
                ), f"Data input {i} width mismatch. Expected {self.data_width}, got {len(wire)}"
                self._data_in[i] <<= wire

    @property
    def write_interface(self) -> dict:
        return {
            "tile_addr": self._write_tile_addr,
            "start": self._write_start,
            "mode": self._write_mode,
            "valid": self._write_valid,
            "data": self._data_in,
        }

    @property
    def read_interface(self) -> dict:
        return {
            "tile_addr": self._read_tile_addr,
            "start": self._read_start,
            "data": self._data_out,
        }

    def get_output(self, bank: int) -> WireVector:
        return self._data_out[bank]


from typing import Callable, Optional, Type

import numpy as np
from pyrtl import Input, Simulation, reset_working_block

from ..dtypes import BF16, BaseFloat
from ..rtllib.accumulators import AccumulatorMemoryBank
from ..rtllib.adders import float_adder


class AccumulatorBankSimulator:
    """Simulator for AccumulatorMemoryBank with integrated address generator"""

    def __init__(
        self,
        array_size: int,
        num_tiles: int,
        data_type: Type[BaseFloat] = BF16,
        adder: Callable = float_adder,
    ):
        """Initialize simulator configuration

        Args:
            array_size: Dimension of systolic array (NxN)
            num_tiles: Number of tiles to support
            data_type: Number format for data (default: BF16)
            adder: Floating point adder implementation
        """
        self.array_size = array_size
        self.num_tiles = num_tiles
        self.data_type = data_type
        self.tile_addr_width = (num_tiles - 1).bit_length()

        # Store configuration for setup
        self.config = {
            "array_size": array_size,
            "tile_addr_width": self.tile_addr_width,
            "data_type": data_type,
            "adder": adder,
        }
        self.sim = None

    def setup(self):
        """Initialize PyRTL simulation environment"""
        reset_working_block()

        # Input ports
        self._write_tile_addr = Input(self.tile_addr_width, "write_tile_addr")
        self._write_start = Input(1, "write_start")
        self._write_mode = Input(1, "write_mode")
        self._write_valid = Input(1, "write_valid")
        self._read_tile_addr = Input(self.tile_addr_width, "read_tile_addr")
        self._read_start = Input(1, "read_start")
        self._data_in = [
            Input(self.data_type.bitwidth(), f"data_in_{i}")
            for i in range(self.array_size)
        ]

        # Create accumulator bank
        self.acc_bank = AccumulatorMemoryBank(**self.config)
        self.acc_bank.connect_inputs(
            self._write_tile_addr,
            self._write_start,
            self._write_mode,
            self._write_valid,
            self._read_tile_addr,
            self._read_start,
            self._data_in,  # type: ignore
        )

        # Create simulation
        self.sim = Simulation()

        return self

    def _get_default_inputs(self, updates: dict = {}) -> dict:
        """Get dictionary of default input values with optional updates"""
        defaults = {
            "write_tile_addr": 0,
            "write_start": 0,
            "write_mode": 0,
            "write_valid": 0,
            "read_tile_addr": 0,
            "read_start": 0,
            **{f"data_in_{i}": 0 for i in range(self.array_size)},
        }
        defaults.update(updates)
        return defaults

    def write_tile(
        self,
        tile_addr: int,
        data: np.ndarray,
        accumulate: bool = False,
        check_bounds: bool = True,
    ) -> None:
        """Write data to specified tile

        Args:
            tile_addr: Destination tile address
            data: Input data array (array_size x array_size)
            accumulate: If True, accumulate with existing values
            check_bounds: If True, validate input dimensions
        """
        if self.sim is None:
            raise RuntimeError("Simulator not initialized. Call setup() first")

        if check_bounds:
            if tile_addr >= self.num_tiles or tile_addr < 0:
                raise ValueError(f"Tile address {tile_addr} out of range")
            if data.shape != (self.array_size, self.array_size):
                raise ValueError(f"Data must be {self.array_size}x{self.array_size}")

        # Convert data to binary format
        binary_data = [[self.data_type(val).binint for val in row] for row in data]

        # Start write operation
        self.sim.step(
            self._get_default_inputs(
                {
                    "write_tile_addr": tile_addr,
                    "write_start": 1,
                    "write_mode": int(accumulate),
                }
            )
        )

        # Write each row
        for row in binary_data:
            self.sim.step(
                self._get_default_inputs(
                    {
                        "write_tile_addr": tile_addr,
                        "write_mode": int(accumulate),
                        "write_valid": 1,
                        **{f"data_in_{i}": row[i] for i in range(self.array_size)},
                    }
                )
            )

    def read_tile(self, tile_addr: int) -> np.ndarray:
        """Read data from specified tile

        Args:
            tile_addr: Tile address to read from

        Returns:
            Array containing tile data
        """
        if self.sim is None:
            raise RuntimeError("Simulator not initialized. Call setup() first")

        results = []

        # Start read operation
        self.sim.step(
            self._get_default_inputs({"read_tile_addr": tile_addr, "read_start": 1})
        )

        # Read rows until done
        while True:
            self.sim.step(self._get_default_inputs({"read_tile_addr": tile_addr}))

            # Capture outputs
            row = [
                float(
                    self.data_type(
                        binint=self.sim.inspect(self.acc_bank.get_output(i).name)
                    )
                )
                for i in range(self.array_size)
            ]
            results.append(row)

            if self.sim.inspect(self.acc_bank.read_done.name):
                break

        return np.array(results[-self.array_size :])

    def get_all_tiles(self) -> np.ndarray:
        """Read all tile memories

        Returns:
            3D array of shape (num_tiles, array_size, array_size)
        """
        if self.sim is None:
            raise RuntimeError("Simulator not initialized. Call setup() first")

        mems = self.acc_bank.memory_banks
        result = {}

        # Initialize empty lists for each address
        for addr in range(self.array_size * self.num_tiles):
            result[addr] = []

        # Collect memory contents
        for mem in mems:
            d = self.sim.inspect_mem(mem)
            for addr in range(self.array_size * self.num_tiles):
                result[addr].append(d.get(addr, 0))

        # Convert to numpy array and reshape
        tiles = [
            [float(self.data_type(binint=x)) for x in tile] for tile in result.values()
        ]
        tiles = np.array(tiles)

        # Reshape into tile matrices
        result_3d = []
        for i in range(self.num_tiles):
            start_idx = i * self.array_size
            end_idx = start_idx + self.array_size
            result_3d.append(tiles[start_idx:end_idx])

        return np.array(result_3d)

    def print_state(self, message: Optional[str] = None):
        """Print current state of all tiles"""
        if message:
            print(f"\n{message}")

        tiles = self.get_all_tiles()
        print("\nTile States:")
        print("-" * 50)
        for i, tile in enumerate(tiles):
            print(f"Tile {i}:")
            print(np.array2string(tile, precision=2, suppress_small=True))
        print("-" * 50)
```

```python
# buffer.py
from dataclasses import dataclass
from ..dtypes import *

from typing import List, Type
from pyrtl import (
    MemBlock,
    WireVector,
    Register,
    conditional_assignment,
    otherwise,
    chop,
)


@dataclass
class BufferOutputs:
    """Container for buffer memory output wires.

    Attributes:
        datas_out: List of data output wires, one per array column
        weights_out: List of weight output wires, one per array column
        data_valid: Wire indicating valid data on datas_out
        weight_valid: Wire indicating valid weights on weights_out
    """

    datas_out: List[WireVector]
    """List of data output wires, one per array column"""
    weights_out: List[WireVector]
    """List of weight output wires, one per array column"""
    data_valid: WireVector
    """Wire indicating valid data on datas_out"""
    weight_valid: WireVector
    """Wire indicating valid weights on weights_out"""


class BufferMemory:
    """Dual-bank memory buffer for streaming data and weights to a systolic array.

    This class implements a memory buffer with separate banks for both data and weights,
    designed to feed a systolic array for matrix multiplication. It features:

    - Dual memory banks for both data and weights enabling ping-pong buffering
    - Configurable data widths for both data and weight values
    - Controlled streaming of rows to systolic array
    - Status signals for monitoring buffer operations

    The buffer stores each row of the matrix as a single concatenated entry in memory,
    with the bitwidth scaled by the array size. This enables efficient reading of full
    rows during streaming operations.

    Input Control Wires:
        - data_start_in (1 bit): Initiates data streaming operation
        - data_select_in (1 bit): Selects which data memory bank to read from (0 or 1)
        - weight_start_in (1 bit): Initiates weight streaming operation
        - weight_select_in (1 bit): Selects which weight memory bank to read from (0 or 1)

    Output Status Wires:
        - data_load_busy (1 bit): Indicates data streaming is in progress
        - data_load_done (1 bit): Indicates data streaming has completed
        - weight_load_busy (1 bit): Indicates weight streaming is in progress
        - weight_load_done (1 bit): Indicates weight streaming has completed

    Data Output Wires:
        - datas_out: List of data output wires (data_type.bitwidth() each)
        - weights_out: List of weight output wires (weight_type.bitwidth() each)

    Usage Example:
        buffer = BufferMemory(
            array_size=4,
            data_type=BF16,
            weight_type=BF16
        )

        # Connect control signals
        buffer.connect_inputs(
            data_start=control.data_start,
            data_select=control.data_select,
            weight_start=control.weight_start,
            weight_select=control.weight_select
        )

        # Access outputs
        outputs = buffer.get_outputs()
        systolic_array.connect_data_inputs(outputs.datas_out)
        systolic_array.connect_weight_inputs(outputs.weights_out)
    """

    def __init__(
        self, array_size: int, data_type: Type[BaseFloat], weight_type: Type[BaseFloat]
    ):
        """Initialize the buffer memory with specified dimensions and data types.

        Args:
            array_size: Size N of the NxN systolic array this buffer will feed.
                       Determines the number of parallel output wires and memory organization.

            data_type: Float data type for activation/data values (e.g., BF16, Float8).
                      Determines the bitwidth of data storage and output wires.

            weight_type: Float data type for weight values (e.g., BF16, Float8).
                        Determines the bitwidth of weight storage and output wires.

        Memory Organization:
            - Each memory bank contains array_size entries
            - Each entry stores one full row of the matrix
            - Entry bitwidth = type.bitwidth() * array_size

        The class automatically calculates:
            - Address width based on array_size
            - Memory entry width based on data types and array_size
            - Required control register sizes
        """
        # Configuration parameters
        self.array_size = array_size
        self.addr_width = (array_size - 1).bit_length()
        self.d_width = data_type.bitwidth()
        self.w_width = weight_type.bitwidth()
        self.data_mem_width = self.d_width * array_size
        self.weight_mem_width = self.w_width * array_size

        # Memory Banks
        self.data_mems = [
            MemBlock(bitwidth=self.data_mem_width, addrwidth=self.addr_width)
            for _ in range(2)
        ]
        self.weight_mems = [
            MemBlock(bitwidth=self.weight_mem_width, addrwidth=self.addr_width)
            for _ in range(2)
        ]

        # Control Inputs
        self.data_start = WireVector(1)  # Start data streaming
        self.data_bank = WireVector(1)  # Select data memory bank
        self.weight_start = WireVector(1)  # Start weight streaming
        self.weight_bank = WireVector(1)  # Select weight memory bank

        # State Registers
        self.data_active = Register(1)  # Data streaming in progress
        self.data_addr = Register(self.addr_width)
        self.weight_active = Register(1)  # Weight streaming in progress
        self.weight_addr = Register(self.addr_width)

        # Status Outputs
        self.data_valid = WireVector(1)  # Data output is valid
        self.weight_valid = WireVector(1)  # Weight output is valid

        # Data Outputs
        self.datas_out = [WireVector(self.d_width) for _ in range(array_size)]
        self.weights_out = [WireVector(self.w_width) for _ in range(array_size)]

        # Control Logic
        self._implement_control_logic()

    def _implement_control_logic(self):
        """Implement the control and datapath logic."""
        with conditional_assignment:
            # Data streaming control
            with self.data_start & ~self.data_active:
                self.data_active.next |= 1
                self.data_addr.next |= 0

            with self.data_active:
                # Generate valid signal
                self.data_valid |= 1

                # Stream data from selected memory bank
                with self.data_bank == 0:
                    for out, data in zip(
                        self.datas_out,
                        chop(
                            self.data_mems[0][self.data_addr],
                            *[self.d_width] * self.array_size,
                        ),
                    ):
                        out |= data
                with otherwise:
                    for out, data in zip(
                        self.datas_out,
                        chop(
                            self.data_mems[1][self.data_addr],
                            *[self.d_width] * self.array_size,
                        ),
                    ):
                        out |= data

                # Address counter and completion logic
                with self.data_addr == self.array_size - 1:
                    self.data_active.next |= 0
                with otherwise:
                    self.data_addr.next |= self.data_addr + 1

        with conditional_assignment:
            # Weight streaming control (mirror of data control)
            with self.weight_start & ~self.weight_active:
                self.weight_active.next |= 1
                self.weight_addr.next |= 0

            with self.weight_active:
                self.weight_valid |= 1

                with self.weight_bank == 0:
                    for out, weight in zip(
                        self.weights_out,
                        chop(
                            self.weight_mems[0][self.weight_addr],
                            *[self.w_width] * self.array_size,
                        ),
                    ):
                        out |= weight
                with otherwise:
                    for out, weight in zip(
                        self.weights_out,
                        chop(
                            self.weight_mems[1][self.weight_addr],
                            *[self.w_width] * self.array_size,
                        ),
                    ):
                        out |= weight

                with self.weight_addr == self.array_size - 1:
                    self.weight_active.next |= 0
                with otherwise:
                    self.weight_addr.next |= self.weight_addr + 1

    def connect_inputs(self, data_start, data_bank, weight_start, weight_bank):
        """Connect control signals for the buffer memory.

        Args:
            data_start: Start signal for data streaming (1 bit)
            data_bank: Data memory bank selection (1 bit)
            weight_start: Start signal for weight streaming (1 bit)
            weight_bank: Weight memory bank selection (1 bit)
        """
        if data_start is not None:
            assert len(data_start) == 1
            self.data_start <<= data_start

        if data_bank is not None:
            assert len(data_bank) == 1
            self.data_bank <<= data_bank

        if weight_start is not None:
            assert len(weight_start) == 1
            self.weight_start <<= weight_start

        if weight_bank is not None:
            assert len(weight_bank) == 1
            self.weight_bank <<= weight_bank

    def get_outputs(self) -> BufferOutputs:
        """Get all output wires from the buffer memory.

        Returns:
            BufferOutputs containing:
                - datas_out: List of data output wires [array_size]
                - weights_out: List of weight output wires [array_size]
                - data_valid: Indicates valid data on outputs
                - weight_valid: Indicates valid weights on outputs

        The valid signals should be used to enable downstream components:
        - weight_valid connects to systolic array's weight_enable
        - data_valid indicates when data values are ready to be consumed
        """
        return BufferOutputs(
            datas_out=self.datas_out,
            weights_out=self.weights_out,
            data_valid=self.data_valid,
            weight_valid=self.weight_valid,
        )


from dataclasses import dataclass
from ..dtypes import *

from typing import List, Type
from pyrtl import (
    MemBlock,
    WireVector,
    Register,
    conditional_assignment,
    otherwise,
    chop,
)


@dataclass
class BufferOutputs:
    """Container for buffer memory output wires.

    Attributes:
        datas_out: List of data output wires, one per array column
        weights_out: List of weight output wires, one per array column
        data_valid: Wire indicating valid data on datas_out
        weight_valid: Wire indicating valid weights on weights_out
    """

    datas_out: List[WireVector]
    """List of data output wires, one per array column"""
    weights_out: List[WireVector]
    """List of weight output wires, one per array column"""
    data_valid: WireVector
    """Wire indicating valid data on datas_out"""
    weight_valid: WireVector
    """Wire indicating valid weights on weights_out"""


class BufferMemory:
    """Dual-bank memory buffer for streaming data and weights to a systolic array.

    This class implements a memory buffer with separate banks for both data and weights,
    designed to feed a systolic array for matrix multiplication. It features:

    - Dual memory banks for both data and weights enabling ping-pong buffering
    - Configurable data widths for both data and weight values
    - Controlled streaming of rows to systolic array
    - Status signals for monitoring buffer operations

    The buffer stores each row of the matrix as a single concatenated entry in memory,
    with the bitwidth scaled by the array size. This enables efficient reading of full
    rows during streaming operations.

    Input Control Wires:
        - data_start_in (1 bit): Initiates data streaming operation
        - data_select_in (1 bit): Selects which data memory bank to read from (0 or 1)
        - weight_start_in (1 bit): Initiates weight streaming operation
        - weight_select_in (1 bit): Selects which weight memory bank to read from (0 or 1)

    Output Status Wires:
        - data_load_busy (1 bit): Indicates data streaming is in progress
        - data_load_done (1 bit): Indicates data streaming has completed
        - weight_load_busy (1 bit): Indicates weight streaming is in progress
        - weight_load_done (1 bit): Indicates weight streaming has completed

    Data Output Wires:
        - datas_out: List of data output wires (data_type.bitwidth() each)
        - weights_out: List of weight output wires (weight_type.bitwidth() each)

    Usage Example:
        buffer = BufferMemory(
            array_size=4,
            data_type=BF16,
            weight_type=BF16
        )

        # Connect control signals
        buffer.connect_inputs(
            data_start=control.data_start,
            data_select=control.data_select,
            weight_start=control.weight_start,
            weight_select=control.weight_select
        )

        # Access outputs
        outputs = buffer.get_outputs()
        systolic_array.connect_data_inputs(outputs.datas_out)
        systolic_array.connect_weight_inputs(outputs.weights_out)
    """

    def __init__(
        self, array_size: int, data_type: Type[BaseFloat], weight_type: Type[BaseFloat]
    ):
        """Initialize the buffer memory with specified dimensions and data types.

        Args:
            array_size: Size N of the NxN systolic array this buffer will feed.
                       Determines the number of parallel output wires and memory organization.

            data_type: Float data type for activation/data values (e.g., BF16, Float8).
                      Determines the bitwidth of data storage and output wires.

            weight_type: Float data type for weight values (e.g., BF16, Float8).
                        Determines the bitwidth of weight storage and output wires.

        Memory Organization:
            - Each memory bank contains array_size entries
            - Each entry stores one full row of the matrix
            - Entry bitwidth = type.bitwidth() * array_size

        The class automatically calculates:
            - Address width based on array_size
            - Memory entry width based on data types and array_size
            - Required control register sizes
        """
        # Configuration parameters
        self.array_size = array_size
        self.addr_width = (array_size - 1).bit_length()
        self.d_width = data_type.bitwidth()
        self.w_width = weight_type.bitwidth()
        self.data_mem_width = self.d_width * array_size
        self.weight_mem_width = self.w_width * array_size

        # Memory Banks
        self.data_mems = [
            MemBlock(bitwidth=self.data_mem_width, addrwidth=self.addr_width)
            for _ in range(2)
        ]
        self.weight_mems = [
            MemBlock(bitwidth=self.weight_mem_width, addrwidth=self.addr_width)
            for _ in range(2)
        ]

        # Control Inputs
        self.data_start = WireVector(1)  # Start data streaming
        self.data_bank = WireVector(1)  # Select data memory bank
        self.weight_start = WireVector(1)  # Start weight streaming
        self.weight_bank = WireVector(1)  # Select weight memory bank

        # State Registers
        self.data_active = Register(1)  # Data streaming in progress
        self.data_addr = Register(self.addr_width)
        self.weight_active = Register(1)  # Weight streaming in progress
        self.weight_addr = Register(self.addr_width)

        # Status Outputs
        self.data_valid = WireVector(1)  # Data output is valid
        self.weight_valid = WireVector(1)  # Weight output is valid

        # Data Outputs
        self.datas_out = [WireVector(self.d_width) for _ in range(array_size)]
        self.weights_out = [WireVector(self.w_width) for _ in range(array_size)]

        # Control Logic
        self._implement_control_logic()

    def _implement_control_logic(self):
        """Implement the control and datapath logic."""
        with conditional_assignment:
            # Data streaming control
            with self.data_start & ~self.data_active:
                self.data_active.next |= 1
                self.data_addr.next |= 0

            with self.data_active:
                # Generate valid signal
                self.data_valid |= 1

                # Stream data from selected memory bank
                with self.data_bank == 0:
                    for out, data in zip(
                        self.datas_out,
                        chop(
                            self.data_mems[0][self.data_addr],
                            *[self.d_width] * self.array_size,
                        ),
                    ):
                        out |= data
                with otherwise:
                    for out, data in zip(
                        self.datas_out,
                        chop(
                            self.data_mems[1][self.data_addr],
                            *[self.d_width] * self.array_size,
                        ),
                    ):
                        out |= data

                # Address counter and completion logic
                with self.data_addr == self.array_size - 1:
                    self.data_active.next |= 0
                with otherwise:
                    self.data_addr.next |= self.data_addr + 1

        with conditional_assignment:
            # Weight streaming control (mirror of data control)
            with self.weight_start & ~self.weight_active:
                self.weight_active.next |= 1
                self.weight_addr.next |= 0

            with self.weight_active:
                self.weight_valid |= 1

                with self.weight_bank == 0:
                    for out, weight in zip(
                        self.weights_out,
                        chop(
                            self.weight_mems[0][self.weight_addr],
                            *[self.w_width] * self.array_size,
                        ),
                    ):
                        out |= weight
                with otherwise:
                    for out, weight in zip(
                        self.weights_out,
                        chop(
                            self.weight_mems[1][self.weight_addr],
                            *[self.w_width] * self.array_size,
                        ),
                    ):
                        out |= weight

                with self.weight_addr == self.array_size - 1:
                    self.weight_active.next |= 0
                with otherwise:
                    self.weight_addr.next |= self.weight_addr + 1

    def connect_inputs(self, data_start, data_bank, weight_start, weight_bank):
        """Connect control signals for the buffer memory.

        Args:
            data_start: Start signal for data streaming (1 bit)
            data_bank: Data memory bank selection (1 bit)
            weight_start: Start signal for weight streaming (1 bit)
            weight_bank: Weight memory bank selection (1 bit)
        """
        if data_start is not None:
            assert len(data_start) == 1
            self.data_start <<= data_start

        if data_bank is not None:
            assert len(data_bank) == 1
            self.data_bank <<= data_bank

        if weight_start is not None:
            assert len(weight_start) == 1
            self.weight_start <<= weight_start

        if weight_bank is not None:
            assert len(weight_bank) == 1
            self.weight_bank <<= weight_bank

    def get_outputs(self) -> BufferOutputs:
        """Get all output wires from the buffer memory.

        Returns:
            BufferOutputs containing:
                - datas_out: List of data output wires [array_size]
                - weights_out: List of weight output wires [array_size]
                - data_valid: Indicates valid data on outputs
                - weight_valid: Indicates valid weights on outputs

        The valid signals should be used to enable downstream components:
        - weight_valid connects to systolic array's weight_enable
        - data_valid indicates when data values are ready to be consumed
        """
        return BufferOutputs(
            datas_out=self.datas_out,
            weights_out=self.weights_out,
            data_valid=self.data_valid,
            weight_valid=self.weight_valid,
        )
```

```python
# matrix_engine.py
@dataclass
class AcceleratorConfig:
    """Configuration class for a systolic array accelerator.

    This class defines the parameters and specifications for a systolic array
    accelerator including array dimensions, data types, arithmetic operations,
    and memory configuration.
    """

    array_size: int
    """Dimension of systolic array (always square)"""

    data_type: Type[BaseFloat]
    """Floating point format of input data to systolic array"""

    weight_type: Type[BaseFloat]
    """Floating point format of weight inputs"""

    accum_type: Type[BaseFloat]
    """Floating point format to accumulate values in"""

    pe_adder: Callable[[WireVector, WireVector, Type[BaseFloat]], WireVector]
    """Function to generate adder hardware for the processing elements"""

    accum_adder: Callable[[WireVector, WireVector, Type[BaseFloat]], WireVector]
    """Function to generate adder hardware for the accumulator buffer"""

    pe_multiplier: Callable[[WireVector, WireVector, Type[BaseFloat]], WireVector]
    """Function to generate multiplier hardware for the processing elements"""

    pipeline: bool
    """Whether to add a pipeline stage in processing elements between multiplier and adder"""

    accumulator_tiles: int
    """Number of tiles in the accumulator memory, each tile is equal to the size of the systolic array"""

    @property
    def accum_addr_width(self):
        """Get the width of the accumulator address bus in bits"""
        return (self.accumulator_tiles - 1).bit_length()
    

class MatrixEngine:
    """Hardware implementation of the matrix engine accelerator"""
    
    def __init__(self, config: AcceleratorConfig):
        """Initialize matrix engine hardware with given configuration
        
        Args:
            config: Configuration parameters for the accelerator
        """
        self.config = config
        
        # Create input control signals
        self.data_start = Input(1, "data_start")
        self.data_bank = Input(1, "data_bank") 
        self.weight_start = Input(1, "weight_start")
        self.weight_bank = Input(1, "weight_bank")
        
        self.accum_start = Input(1, "accum_start")
        self.accum_tile_addr = Input(config.accum_addr_width, "accum_tile_addr")
        self.accum_mode = Input(1, "accum_mode")
        
        self.accum_read_start = Input(1, "accum_read_start")
        self.accum_read_tile_addr = Input(config.accum_addr_width, "accum_read_tile_addr")
        
        # Initialize hardware components
        self.buffer = BufferMemory(
            array_size=config.array_size,
            data_type=config.data_type,
            weight_type=config.weight_type
        )
        
        self.systolic_array = SystolicArrayDiP(
            size=config.array_size,
            data_type=config.data_type,
            accum_type=config.accum_type,
            multiplier=config.pe_multiplier,
            adder=config.pe_adder,
            pipeline=config.pipeline
        )
        
        self.accumulator = AccumulatorMemoryBank(
            tile_addr_width=config.accum_addr_width,
            array_size=config.array_size,
            data_type=config.accum_type,
            adder=config.accum_adder
        )
        
        # Connect components
        self._connect_components()
    
    def _connect_components(self):
        """Connect all hardware components together"""
        # Connect buffer control signals
        self.buffer.connect_inputs(
            data_start=self.data_start,
            data_bank=self.data_bank,
            weight_start=self.weight_start,
            weight_bank=self.weight_bank
        )
        
        # Get buffer outputs
        buffer_outputs = self.buffer.get_outputs()
        
        # Connect systolic array to buffer
        self.systolic_array.connect_inputs(
            weight_enable=buffer_outputs.weight_valid,
            weight_inputs=buffer_outputs.weights_out,
            data_inputs=buffer_outputs.datas_out,
            enable_input=buffer_outputs.data_valid
        )
        
        # Connect accumulator
        self.accumulator.connect_inputs(
            data_in=self.systolic_array.results_out,
            write_start=self.accum_start,
            write_tile_addr=self.accum_tile_addr,
            write_mode=self.accum_mode,
            write_valid=self.systolic_array.control_out,
            read_start=self.accum_read_start,
            read_tile_addr=self.accum_read_tile_addr
        )
```


## Instructions


I am now combining all of these components together to construct the full accelerator matrix engine. You can see how I tie all the inputs and outputs together above, as well as the simulation handler classes for all of the individual hardware modules. I would like you to help me design a matrix engine simulation class that exposes methods for performing various actions on the hardware, this should include load weights to a buffer tile, load activations to buffer, matrix multiply (should take inputs data tile num, weight tile num, accumulate mode, and accumulator tile). You should also provide helper methods to inspect the important values stored in the simulation at each step and keep track of the history throughout the simulation. For the matrix engine class itself (not simulation) that constructs the hardware, we should initialize it using a single parameter called accelerator config which is also shown above. 

Use the code examples to build these two classes, be sure to be completely thorough with your simulation implementation and define all the required methods described here. 

All inspect functions should NOT step the simulation ever, only read values, never interact with the simulation. Reading the accumulator tile and inspecting the accumulator memory are two different things. You should provide separate methods for each, and each method that interacts with the simulation should generally avoid calling other methods that interact with the simulation.

Please define these inspect methods for the matrix engine class, they should accept the pyrtl simulation as an argument and use that to inspect memories, weights, registers, etc. This will keep the simulation class more lightweight.

You should design these methods, and the simulation class, so that it makes tracking values over the course of a simulation easy which will help with debugging anything that breaks. Additionally, the load weights method of the sim class should call the permutate weight matrix util function before loading weights into the buffer memory in simulation. This is due to how the systolic array dataflow works internally. 

Additionally, the matrix engine class should not create Input wires on initialization, it should create unnamed normal wirevectors. Do NOT add a name to these wires, those are only for Input and Output wires, naming the regular wirevectors might cause duplicate naming issues. Since there are a lot of wires to manage here, what you should do is create a helper method called connect inputs similar to the one in the systolic array class, each wire should have a docstring for its respective input parameter. You should then make a utility method called create_sim_inputs or something similar that calls the connect inputs method and passes Input() wires with appropriate names. This method doesn't have to return anything if you dont want, but it could be useful to return the names of the wires created for use in the simulation. You could return nothing, a list/tuple of names, or a dict with names as keys and 0 for values. 

If you want to refactor anything else about the matrix engine class to make it work better or optimize it, feel free to do so if it will help in your implementation of the simulation class.

I should also mention that the weight/data loading functions of the simulation should not actually step the simulation, we are going to directly modify the values stored in the simulation 


Here is a high level specification to help you design the matrix engine additional methods and simulation class:

```markdown
## Matrix Engine Simulator Specification

## Overview
The MatrixEngineSimulator class provides a high-level interface for simulating a systolic array-based matrix multiplication accelerator. It manages the simulation of three main hardware components: buffer memory, systolic array, and accumulator memory bank.

## Key Requirements
1. Initialize with an AcceleratorConfig object that defines hardware parameters
2. Provide methods to load and manipulate data without directly exposing hardware signals
3. Record simulation history for debugging and analysis
4. Separate simulation control from state inspection
5. Use inspection methods defined in the MatrixEngine class

## Core Functionality

### Initialization
- Accept AcceleratorConfig parameter
- Create MatrixEngine instance
- Initialize PyRTL simulation
- Set up history tracking

### Data Loading Operations
1. `load_weights(weights: np.ndarray, bank: int)`
   - Load weight matrix into specified buffer bank (0 or 1)
   - Validate matrix dimensions match array size
   - Convert floating-point values to binary representation
   - Handle memory bank loading

2. `load_activations(activations: np.ndarray, bank: int)`
   - Load activation matrix into specified buffer bank (0 or 1)
   - Validate matrix dimensions match array size
   - Convert floating-point values to binary representation
   - Handle memory bank loading

### Computation Operations
1. `matrix_multiply(data_tile: int, weight_tile: int, accum_tile: int, accumulate: bool)`
   - Initiate matrix multiplication using specified tiles
   - Control data streaming from buffer banks
   - Manage accumulator writing
   - Handle proper timing for systolic array operation

2. `read_accumulator_tile(tile: int)`
   - Initiate read operation for specified accumulator tile
   - Handle timing for accumulator read sequence

### Simulation Control
1. `_step(inputs: dict = None)`
   - Advance simulation by one cycle
   - Apply input signals
   - Record simulation state
   - Use MatrixEngine inspection methods to capture state

2. `_get_default_inputs(updates: dict = {})`
   - Provide default values for all control signals
   - Allow selective updates of signals

### State Management
1. `SimulationState` dataclass to capture:
   - Current simulation step
   - Active input signals
   - Buffer memory state
   - Systolic array state
   - Accumulator memory state

2. History tracking:
   - Maintain list of SimulationState objects
   - Allow access to previous states
   - Support debugging and analysis

### Utility Methods
1. `_convert_to_binary(matrix: np.ndarray, dtype: Type[BaseFloat])`
   - Convert floating-point matrices to binary representation
   - Handle proper number format conversion

2. `print_state(step: int = -1)`
   - Display formatted simulation state
   - Support viewing historical states

## MatrixEngine Inspection Methods
The following inspection methods should be defined in the MatrixEngine class:

1. `inspect_buffer_state(sim: Simulation)`
   - Read buffer memory contents
   - Return data and weight bank states

2. `inspect_systolic_array_state(sim: Simulation)`
   - Read weights matrix
   - Read data matrix
   - Read accumulator values
   - Read output values

3. `inspect_accumulator_state(sim: Simulation)`
   - Read all accumulator tile contents
   - Return formatted tile data

4. `get_accumulator_outputs(sim: Simulation)`
   - Read current values on accumulator output ports

## Usage Pattern
```python
# Initialize
sim = MatrixEngineSimulator(config)

# Load data
sim.load_weights(weights, bank=0)
sim.load_activations(activations, bank=0)

# Perform computation
sim.matrix_multiply(data_tile=0, weight_tile=0, accum_tile=0)
sim.read_accumulator_tile(0)

# Access results using engine inspection methods
results = sim.engine.get_accumulator_outputs(sim.sim)

# Debug/analyze
sim.print_state()
```

## Key Design Principles
1. Clear separation between simulation control and state inspection
2. No recursive simulation steps in inspection methods
3. Proper validation of inputs and configurations
4. Comprehensive state tracking for debugging
5. Clean interface hiding hardware complexity
6. Reusable inspection methods in MatrixEngine class

This specification provides a framework for implementing a clean, maintainable simulator that properly manages hardware simulation while providing useful debugging and analysis capabilities.
```

# Testing

In [8]:
from hardware_accelerators.simulation import TiledMatrixEngineSimulator
import hardware_accelerators

In [9]:
# Initialize with accelerator configuration
config = hardware_accelerators.rtllib.TiledAcceleratorConfig(
    array_size=4,
    data_type=BF16,
    weight_type=BF16,
    accum_type=BF16,
    pe_adder=float_adder,
    accum_adder=float_adder,
    pe_multiplier=lmul_fast,
    pipeline=True,
    accumulator_tiles=8,
)

set_debug_mode(False)

sim = TiledMatrixEngineSimulator(config)

# Load data
weights = np.identity(config.array_size)
activations = np.random.randn(config.array_size, config.array_size)
print(f"Original Weights:\n{weights}")
print(f"Original Activations:\n{activations}\n")

sim.load_weights(weights, bank=0)
sim.load_activations(activations, bank=0)

print(f"Weights:\n{sim.weight_banks}")
print(f"Activations:\n{sim.data_banks}")

Original Weights:
[[1. 0. 0. 0.]
 [0. 1. 0. 0.]
 [0. 0. 1. 0.]
 [0. 0. 0. 1.]]
Original Activations:
[[ 0.55018066 -0.5175061   0.24944585  0.16355582]
 [ 1.94820671  0.41952214 -0.26568817  0.45657394]
 [ 0.70380903 -0.05924119 -0.92877048  1.52801445]
 [ 2.18712271  1.69032911 -0.73782007 -0.69169502]]

Weights:
[[[1. 1. 1. 1.]
  [0. 0. 0. 0.]
  [0. 0. 0. 0.]
  [0. 0. 0. 0.]]

 [[0. 0. 0. 0.]
  [0. 0. 0. 0.]
  [0. 0. 0. 0.]
  [0. 0. 0. 0.]]]
Activations:
[[[ 0.546875   -0.515625    0.24902344  0.16308594]
  [ 1.9453125   0.41796875 -0.265625    0.45507812]
  [ 0.703125   -0.05908203 -0.92578125  1.5234375 ]
  [ 2.171875    1.6875     -0.734375   -0.69140625]]

 [[ 0.          0.          0.          0.        ]
  [ 0.          0.          0.          0.        ]
  [ 0.          0.          0.          0.        ]
  [ 0.          0.          0.          0.        ]]]


In [10]:
# Perform computation
sim.matmul(data_bank=0, weight_bank=0, accum_tile=0, accumulate=False)
# Read results
result = sim.read_accumulator_tile(0)
print("Result:\n", result)

# Debug
sim.print_history()

Result:
 [[ 2.296875    1.75       -0.765625   -0.72265625]
 [ 0.734375   -0.06103516 -0.95703125  1.5859375 ]
 [ 2.015625    0.43359375 -0.28125     0.47070312]
 [ 0.578125   -0.546875    0.26367188  0.17089844]]

Simulation Step 0

Input Signals:
--------------------------------------------------------------------------------
  data_start: 0
  data_bank: 0
  weight_start: 1
  weight_bank: 0
  accum_start: 0
  accum_tile_addr: 0
  accum_mode: 0
  accum_read_start: 0
  accum_read_tile_addr: 0

Buffer Memory State:
--------------------------------------------------------------------------------
Data Bank 0:
[[ 0.5469 -0.5156  0.249   0.1631]
 [ 1.9453  0.418  -0.2656  0.4551]
 [ 0.7031 -0.0591 -0.9258  1.5234]
 [ 2.1719  1.6875 -0.7344 -0.6914]]
Data Bank 1:
[[0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]]
Weight Bank 0:
[[1. 1. 1. 1.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]]
Weight Bank 1:
[[0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]]

Systolic Array Sta

In [6]:
2**8

256