In [13]:
from pyrtl import *
from enum import IntEnum
import numpy as np
from typing import List, Type, Callable, Optional
from hardware_accelerators import *
from hardware_accelerators.dtypes import BaseFloat, BF16
from hardware_accelerators.rtllib import float_adder
from hardware_accelerators.simulation.repr_funcs import *

# Address Generator

## Description

The `TiledAddressGenerator` is a control unit that manages memory access patterns for a tiled matrix multiplication accelerator. It provides two independent interfaces—one for writing data and one for reading data—each with its own FSM to handle the sequential access patterns required for tile-based operations.

### Design Theory

The address generator is built around the concept of tiles: logical groupings of matrix data that can be processed independently. Each tile consists of `array_size` rows of data. Each row (one memory address) holds one value from every column of the systolic array, with one value stored in each column's memory bank. The tiles are stored contiguously in memory, with each tile starting at a base address computed as `tile_number * array_size`.

### Key Features:
1. **Dual FSMs**: Separate state machines for read and write operations allow overlapped access
2. **Base Address ROM**: Pre-computed tile base addresses for fast lookup
3. **Mode Support**: Integrated write mode control for accumulate/overwrite operations
4. **Row Tracking**: Maintains current row position within tiles

### Write Interface

The write interface manages storing data from the systolic array into tiles:

- `tile_addr`: Selects destination tile
- `write_start`: Initiates write sequence
- `write_mode`: Controls accumulate (1) vs overwrite (0) behavior
- `write_valid`: Indicates valid data available to write

#### Write FSM Behavior:
1. **IDLE State**:
   - Waits for write_start signal
   - On start: Loads base address, latches mode, transitions to WRITING

2. **WRITING State**:
   - Processes one row per cycle when write_valid is high; holds the current address and row while write_valid is low
   - Increments address and row counter
   - Returns to IDLE after processing array_size rows

### Read Interface

The read interface manages retrieving stored tile data:

- `read_tile_addr`: Selects source tile
- `read_start`: Initiates read sequence

#### Read FSM Behavior:
1. **IDLE State**:
   - Waits for read_start signal
   - On start: Loads base address, transitions to READING

2. **READING State**:
   - Outputs one row address per cycle
   - Increments address and row counter
   - Returns to IDLE after processing array_size rows

## Code

In [8]:
class TiledAccumulatorFSM(IntEnum):
    """States of the write-side FSM in ``TiledAddressGenerator``."""

    # Waiting for a write_start pulse.
    IDLE = 0
    # Stepping through the rows of the selected tile.
    WRITING = 1


class ReadAccumulatorFSM(IntEnum):
    """States of the read-side FSM in ``TiledAddressGenerator``."""

    # Waiting for a read_start pulse.
    IDLE = 0
    # Emitting one row address per cycle for the selected tile.
    READING = 1


class TiledAddressGenerator:
    """Enhanced address generator with write mode support.

    Tile ``t`` occupies the ``array_size`` consecutive addresses starting at
    ``t * array_size``. The base addresses are stored in a ROM indexed by tile
    number.

    Write path: a ``write_start`` pulse while idle latches the selected tile's
    base address and the requested ``write_mode``. After that, the address
    advances by one on every cycle in which ``write_valid`` is high. The FSM
    returns to idle after ``array_size`` valid rows.

    Read path: a ``read_start`` pulse while idle latches the selected tile's
    base address. The FSM then emits ``array_size`` consecutive addresses, one
    per cycle, and returns to idle.

    Args:
        tile_addr_width: Width in bits of the tile-select inputs. The generator
            provides ``2**tile_addr_width`` tiles.
        array_size: Number of rows per tile.

    Output wires:
        write_addr_out: Current write address.
        write_enable: High while writing and ``write_valid`` is high.
        write_busy: High while the write FSM is in WRITING.
        write_done: High on the cycle in which the last row of the tile is
            written (``write_enable`` high on the final row).
        write_mode_out: Mode latched at ``write_start`` (0=overwrite,
            1=accumulate).
        read_addr_out: Current read address.
        read_busy: High while the read FSM is in READING.
        read_done: High on the cycle that presents the tile's last row address.
    """

    def __init__(self, tile_addr_width: int, array_size: int):
        self.array_size = array_size
        self.num_tiles = 2**tile_addr_width
        # Enough bits to address every row of every tile.
        self.internal_addr_width = (self.num_tiles * array_size - 1).bit_length()

        # Base address ROM: entry t holds t * array_size.
        base_addrs = [i * array_size for i in range(self.num_tiles)]
        self.base_addr_rom = RomBlock(
            bitwidth=self.internal_addr_width,
            addrwidth=tile_addr_width,
            romdata=base_addrs,
        )

        # ================== Write Interface ==================
        self._tile_addr = WireVector(tile_addr_width)
        self._write_start = WireVector(1)
        self._write_mode = WireVector(1)  # 0=overwrite, 1=accumulate
        self._write_valid = WireVector(1)

        # Write state registers
        self.write_state = Register(1)
        self.write_addr = Register(self.internal_addr_width)
        # Counts rows 0 .. array_size - 1 within the current tile.
        self.write_row = Register(array_size.bit_length())
        self.write_mode_reg = Register(1)  # Stores mode for current operation

        # ================== Read Interface ==================
        self._read_tile_addr = WireVector(tile_addr_width)
        self._read_start = WireVector(1)

        # Read state registers
        self.read_state = Register(1)
        self.read_addr = Register(self.internal_addr_width)
        self.read_row = Register(array_size.bit_length())

        # Outputs
        self.write_addr_out = WireVector(self.internal_addr_width)
        self.write_enable = WireVector(1)
        self.write_busy = WireVector(1)
        self.write_done = WireVector(1)
        self.write_mode_out = WireVector(1)

        self.read_addr_out = WireVector(self.internal_addr_width)
        self.read_busy = WireVector(1)
        self.read_done = WireVector(1)

        self._implement_write_fsm()
        self._implement_read_fsm()

    def _implement_write_fsm(self):
        """Build the write FSM and its combinational outputs."""
        write_base = self.base_addr_rom[self._tile_addr]
        is_writing = self.write_state == TiledAccumulatorFSM.WRITING
        is_last_row = self.write_row == self.array_size - 1

        # Combinational outputs
        self.write_addr_out <<= self.write_addr
        self.write_enable <<= is_writing & self._write_valid
        self.write_busy <<= is_writing
        # write_row wraps to 0 on the final row and never reaches array_size,
        # so "done" must be detected on the last valid row (array_size - 1),
        # consistent with read_done.
        self.write_done <<= is_writing & self._write_valid & is_last_row
        self.write_mode_out <<= self.write_mode_reg

        with conditional_assignment:
            # IDLE State
            with self.write_state == TiledAccumulatorFSM.IDLE:
                with self._write_start:
                    self.write_state.next |= TiledAccumulatorFSM.WRITING
                    self.write_addr.next |= write_base
                    self.write_row.next |= 0
                    self.write_mode_reg.next |= self._write_mode  # Latch mode

            # WRITING State: advance only on cycles with valid data.
            with is_writing:
                with self._write_valid:
                    with is_last_row:
                        self.write_state.next |= TiledAccumulatorFSM.IDLE
                        self.write_row.next |= 0
                    with otherwise:
                        self.write_addr.next |= self.write_addr + 1
                        self.write_row.next |= self.write_row + 1

    def _implement_read_fsm(self):
        """Build the read FSM and its combinational outputs."""
        read_base = self.base_addr_rom[self._read_tile_addr]
        is_reading = self.read_state == ReadAccumulatorFSM.READING
        is_last_row = self.read_row == self.array_size - 1

        self.read_addr_out <<= self.read_addr
        self.read_busy <<= is_reading
        self.read_done <<= is_reading & is_last_row

        with conditional_assignment:
            with self.read_state == ReadAccumulatorFSM.IDLE:
                with self._read_start:
                    self.read_state.next |= ReadAccumulatorFSM.READING
                    self.read_addr.next |= read_base
                    self.read_row.next |= 0

            # READING State: one row address per cycle, unconditionally.
            with is_reading:
                with is_last_row:
                    self.read_state.next |= ReadAccumulatorFSM.IDLE
                with otherwise:
                    self.read_addr.next |= self.read_addr + 1
                    self.read_row.next |= self.read_row + 1

    # Write interface methods
    def connect_tile_addr(self, addr: WireVector) -> None:
        """Drive the write-side tile select."""
        self._tile_addr <<= addr

    def connect_write_start(self, start: WireVector) -> None:
        """Drive the write start pulse."""
        self._write_start <<= start

    def connect_write_mode(self, mode: WireVector) -> None:
        """Drive the write mode (0=overwrite, 1=accumulate)."""
        self._write_mode <<= mode

    def connect_write_valid(self, valid: WireVector) -> None:
        """Drive the write data-valid signal."""
        self._write_valid <<= valid

    # Read interface methods
    def connect_read_tile_addr(self, addr: WireVector) -> None:
        """Drive the read-side tile select."""
        self._read_tile_addr <<= addr

    def connect_read_start(self, start: WireVector) -> None:
        """Drive the read start pulse."""
        self._read_start <<= start

# Memory Bank

## Description

The `AccumulatorMemoryBank` integrates the address generator with actual storage elements to create a complete memory subsystem for the matrix accelerator. It provides parallel memory banks that can either accumulate or overwrite incoming data based on the current mode.

### Design Theory

The memory bank is organized as N parallel memories (where N = array_size), each storing partial sums from one column of the systolic array. This organization allows:
1. Parallel write/accumulate of all columns
2. Independent access to each column's data
3. Efficient tiled storage and retrieval

#### Key Components:
1. **Address Generator**: Controls memory access patterns
2. **Memory Banks**: One per systolic array column
3. **Accumulator Logic**: Per-bank adders for accumulation
4. **Mode Control**: Integrated accumulate/overwrite switching

### Operation Modes

#### Write Operations:
1. **Overwrite Mode** (`write_mode = 0`):
   - New data directly replaces existing values
   - Used for initial tile writes

2. **Accumulate Mode** (`write_mode = 1`):
   - New data is added to existing values
   - Used for partial sum accumulation

#### Read Operations:
- Reads occur independently of writes
- Returns data from all banks in parallel
- Addressed by tile and automatically sequences through rows

### Interface Design

The memory bank exposes separate write and read interfaces, each carrying both control and data signals:

#### Write Interface:
- Control signals for tile selection and mode
- Data inputs matching systolic array width
- Status signals (busy, done)

#### Read Interface:
- Tile selection and control
- Parallel data outputs
- Status signals

## Code

In [9]:
class AccumulatorMemoryBank:
    """Integrated memory bank with address generator and accumulation control.

    Holds ``array_size`` parallel memories, one per systolic-array column. On
    each enabled write cycle, every memory stores one value at the address
    produced by the write FSM. The value either replaces the stored value
    (write_mode=0) or is added to it with ``adder`` (write_mode=1). Reads
    return one value per memory at the address produced by the read FSM.

    Args:
        tile_addr_width: Width in bits of the tile-select inputs.
        array_size: Number of memories (and values per row).
        data_type: Number format class; ``data_type.bitwidth()`` sets the data
            width of every memory and data port.
        adder: Called as ``adder(new_value, stored_value, data_type)`` to form
            the accumulated value.
    """

    def __init__(
        self,
        tile_addr_width: int,
        array_size: int,
        data_type: Type[BaseFloat],
        adder: Callable[[WireVector, WireVector, Type[BaseFloat]], WireVector],
    ):
        self.array_size = array_size
        self.tile_addr_width = tile_addr_width
        self.data_width = data_type.bitwidth()
        self.data_type = data_type
        self.adder = adder

        # Instantiate address generator
        self.addr_gen = TiledAddressGenerator(
            tile_addr_width=tile_addr_width, array_size=array_size
        )

        # Input ports
        self._write_tile_addr = WireVector(self.tile_addr_width)
        self._write_start = WireVector(1)
        self._write_mode = WireVector(1)
        self._write_valid = WireVector(1)
        self._read_tile_addr = WireVector(self.tile_addr_width)
        self._read_start = WireVector(1)
        self._data_in = [WireVector(self.data_width) for _ in range(array_size)]

        # Connect address generator
        self.addr_gen.connect_tile_addr(self._write_tile_addr)
        self.addr_gen.connect_write_start(self._write_start)
        self.addr_gen.connect_write_mode(self._write_mode)
        self.addr_gen.connect_write_valid(self._write_valid)
        self.addr_gen.connect_read_tile_addr(self._read_tile_addr)
        self.addr_gen.connect_read_start(self._read_start)

        # Create memory banks (one per column)
        self.memory_banks = [
            MemBlock(
                bitwidth=self.data_width,
                addrwidth=self.addr_gen.internal_addr_width,
                name=f"bank_{i}",
            )
            for i in range(array_size)
        ]

        # Output ports
        self._data_out = [WireVector(self.data_width) for _ in range(array_size)]
        self.write_busy = self.addr_gen.write_busy
        self.write_done = self.addr_gen.write_done
        self.read_busy = self.addr_gen.read_busy
        self.read_done = self.addr_gen.read_done

        self._implement_memory_logic()

    def _implement_memory_logic(self):
        """Wire the write (overwrite/accumulate) and read paths of every bank."""
        write_addr = self.addr_gen.write_addr_out
        read_addr = self.addr_gen.read_addr_out

        for mem, data_in, data_out in zip(
            self.memory_banks, self._data_in, self._data_out
        ):
            # Read-modify-write: the value stored at the write address feeds the
            # adder in the same cycle it is written back.
            # NOTE(review): this assumes `adder` produces its result in the
            # same cycle; a pipelined adder would need extra alignment - confirm.
            current_val = mem[write_addr]
            sum_result = self.adder(data_in, current_val, self.data_type)

            with conditional_assignment:
                with self.addr_gen.write_enable:
                    with self.addr_gen.write_mode_out:  # Accumulate mode
                        mem[write_addr] |= sum_result
                    with otherwise:  # Overwrite mode
                        mem[write_addr] |= data_in

            # Read path: a second read port on the same memory.
            data_out <<= mem[read_addr]

    def connect_inputs(
        self,
        write_tile_addr: WireVector | None = None,
        write_start: WireVector | None = None,
        write_mode: WireVector | None = None,
        write_valid: WireVector | None = None,
        read_tile_addr: WireVector | None = None,
        read_start: WireVector | None = None,
        data_in: list[WireVector] | None = None,
    ) -> None:
        """Connect all input control and data wires to the accumulator bank.

        Any argument left as ``None`` is not connected.

        Args:
            write_tile_addr: Address of tile to write to (tile_addr_width bits)
                Used to select which tile receives the incoming data.

            write_start: Start signal for write operation (1 bit)
                Pulses high for one cycle to initiate a new write sequence.

            write_mode: Mode selection for write operation (1 bit)
                0 = overwrite mode: new data replaces existing values
                1 = accumulate mode: new data is added to existing values

            write_valid: Data valid signal for write operation (1 bit)
                High when input data is valid and should be written/accumulated

            read_tile_addr: Address of tile to read from (tile_addr_width bits)
                Used to select which tile's data to output.

            read_start: Start signal for read operation (1 bit)
                Pulses high for one cycle to initiate a new read sequence.

            data_in: List of data input wires (data_width bits each)
                Input data from systolic array, one wire per column.
                Length must match array_size.

        Raises:
            AssertionError: If input wire widths don't match expected widths or
                        if data_in length doesn't match array_size.
        """
        # (argument name, supplied wire, internal port), in connection order.
        control_ports = (
            ("write_tile_addr", write_tile_addr, self._write_tile_addr),
            ("write_start", write_start, self._write_start),
            ("write_mode", write_mode, self._write_mode),
            ("write_valid", write_valid, self._write_valid),
            ("read_tile_addr", read_tile_addr, self._read_tile_addr),
            ("read_start", read_start, self._read_start),
        )
        for name, wire, port in control_ports:
            if wire is None:
                continue
            assert (
                len(wire) == len(port)
            ), f"{name} width mismatch. Expected {len(port)}, got {len(wire)}"
            port <<= wire

        if data_in is not None:
            assert (
                len(data_in) == self.array_size
            ), f"Expected {self.array_size} data inputs, got {len(data_in)}"
            for i, wire in enumerate(data_in):
                assert (
                    len(wire) == self.data_width
                ), f"Data input {i} width mismatch. Expected {self.data_width}, got {len(wire)}"
                self._data_in[i] <<= wire

    @property
    def write_interface(self) -> dict:
        """Internal write-side ports keyed by role."""
        return {
            "tile_addr": self._write_tile_addr,
            "start": self._write_start,
            "mode": self._write_mode,
            "valid": self._write_valid,
            "data": self._data_in,
        }

    @property
    def read_interface(self) -> dict:
        """Internal read-side ports keyed by role."""
        return {
            "tile_addr": self._read_tile_addr,
            "start": self._read_start,
            "data": self._data_out,
        }

    def get_output(self, bank: int) -> WireVector:
        """Return the read-data output wire of memory bank ``bank``."""
        return self._data_out[bank]

# Basic Testing

In [10]:
def simulate_accumulator():
    """Comprehensive test of accumulator bank with integrated address generator.

    Builds a 3x3 accumulator bank with 4 tiles, then:

    1. overwrites tiles 0 and 2 with the same data;
    2. accumulates that data into tile 0 a second time;
    3. prints the memory contents after each step;
    4. checks the read-back values.

    Raises:
        AssertionError: If read-back data does not match the expected tiles.
        RuntimeError: If a read sequence never asserts ``read_done``.
    """
    reset_working_block()

    # Configuration
    ARRAY_SIZE = 3
    NUM_TILES = 4
    TILE_ADDR_WIDTH = (NUM_TILES - 1).bit_length()
    DATA_TYPE = BF16

    # Instantiate accumulator bank
    acc_bank = AccumulatorMemoryBank(
        tile_addr_width=TILE_ADDR_WIDTH,
        array_size=ARRAY_SIZE,
        data_type=DATA_TYPE,
        adder=float_adder,
    )
    write_tile_addr = Input(TILE_ADDR_WIDTH, "write_tile_addr")
    write_start = Input(1, "write_start")
    write_mode = Input(1, "write_mode")
    write_valid = Input(1, "write_valid")
    read_tile_addr = Input(TILE_ADDR_WIDTH, "read_tile_addr")
    read_start = Input(1, "read_start")
    data_in = [Input(DATA_TYPE.bitwidth(), f"data_in_{i}") for i in range(ARRAY_SIZE)]

    acc_bank.connect_inputs(
        write_tile_addr,
        write_start,
        write_mode,
        write_valid,
        read_tile_addr,
        read_start,
        data_in,  # type: ignore
    )

    # Create simulation
    sim = Simulation()

    def get_inputs(update: Optional[dict[str, int]] = None) -> dict[str, int]:
        """Return a value for every input (all zero), overridden by ``update``."""
        inputs = {
            "write_tile_addr": 0,
            "write_start": 0,
            "write_mode": 0,
            "write_valid": 0,
            "read_tile_addr": 0,
            "read_start": 0,
            **{f"data_in_{i}": 0 for i in range(ARRAY_SIZE)},
        }
        if update:
            inputs.update(update)
        return inputs

    def write_tile(tile_num: int, data: List[List[float]], accumulate: bool = False):
        """Write (or accumulate) ``data`` into tile ``tile_num``, one row per cycle."""
        binary_data = [[DATA_TYPE(val).binint for val in row] for row in data]
        mode = int(accumulate)

        # Pulse write_start for one cycle to latch the tile base and mode.
        sim.step(
            get_inputs(
                {"write_tile_addr": tile_num, "write_start": 1, "write_mode": mode}
            )
        )

        # Present one row per cycle with write_valid asserted.
        for row in binary_data:
            sim.step(
                get_inputs(
                    {
                        "write_tile_addr": tile_num,
                        "write_mode": mode,
                        "write_valid": 1,
                        **{f"data_in_{i}": row[i] for i in range(ARRAY_SIZE)},
                    }
                )
            )

    def read_tile(tile_num: int) -> np.ndarray:
        """Read tile ``tile_num`` back as an (ARRAY_SIZE, ARRAY_SIZE) array."""
        results = []

        # Start read operation
        sim.step(get_inputs({"read_tile_addr": tile_num, "read_start": 1}))

        # Bound the loop so a missing read_done fails loudly instead of hanging.
        max_cycles = 2 * ARRAY_SIZE
        for _ in range(max_cycles):
            sim.step(get_inputs({"read_tile_addr": tile_num}))

            # Capture outputs
            row = [
                float(DATA_TYPE(binint=sim.inspect(acc_bank.get_output(i).name)))
                for i in range(ARRAY_SIZE)
            ]
            results.append(row)

            if sim.inspect(acc_bank.read_done.name):
                break
        else:
            raise RuntimeError(
                f"read_done not asserted within {max_cycles} cycles (tile {tile_num})"
            )

        return np.array(results[-ARRAY_SIZE:])  # Return only valid data

    def inspect_memories():
        """Return all memory contents as a (NUM_TILES, ARRAY_SIZE, ARRAY_SIZE) array."""
        # One snapshot per bank; addresses never written read as 0.
        snapshots = [sim.inspect_mem(mem) for mem in acc_bank.memory_banks]
        rows = [
            [float(DATA_TYPE(binint=snap.get(addr, 0))) for snap in snapshots]
            for addr in range(ARRAY_SIZE * NUM_TILES)
        ]
        return np.array(rows).reshape(NUM_TILES, ARRAY_SIZE, ARRAY_SIZE)

    def print_memory_state(*args):
        """Print current state of all tiles"""
        tiles = inspect_memories()
        print(*args)
        print("\nCurrent Tile States:")
        print("-" * 50)
        for i, tile in enumerate(tiles):
            print(f"Tile {i}:")
            print(np.array2string(tile, precision=2, suppress_small=True))
        print("-" * 50)

    # Test data
    test_data = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]

    # Initial state
    print_memory_state("Initial State")

    # Test sequence
    # 1. Write to tile 0 (overwrite mode)
    write_tile(0, test_data)
    print_memory_state("After First Write (Tile 0)")

    # 2. Write to tile 2 (overwrite mode)
    write_tile(2, test_data)
    print_memory_state("After Second Write (Tile 2)")

    # 3. Accumulate into tile 0
    write_tile(0, test_data, accumulate=True)
    print_memory_state("After Accumulation (Tile 0)")

    # 4. Validate final read outputs
    print("\nFinal Validation:")
    expected_tile2 = np.array(test_data)
    expected_tile0 = 2 * expected_tile2  # overwrite followed by one accumulate

    # Read tile 0
    tile0_data = read_tile(0)
    assert np.allclose(
        tile0_data, expected_tile0
    ), f"Tile 0 mismatch:\nExpected:\n{expected_tile0}\nGot:\n{tile0_data}"

    # Read tile 2
    tile2_data = read_tile(2)
    assert np.allclose(
        tile2_data, expected_tile2
    ), f"Tile 2 mismatch:\nExpected:\n{expected_tile2}\nGot:\n{tile2_data}"

    print("All assertions passed!")


# In a Jupyter kernel __name__ is "__main__", so this runs when the cell executes.
if __name__ == "__main__":
    simulate_accumulator()

Initial State

Current Tile States:
--------------------------------------------------
Tile 0:
[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
Tile 1:
[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
Tile 2:
[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
Tile 3:
[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
--------------------------------------------------
After First Write (Tile 0)

Current Tile States:
--------------------------------------------------
Tile 0:
[[1. 2. 3.]
 [4. 5. 6.]
 [7. 8. 9.]]
Tile 1:
[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
Tile 2:
[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
Tile 3:
[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
--------------------------------------------------
After Second Write (Tile 2)

Current Tile States:
--------------------------------------------------
Tile 0:
[[1. 2. 3.]
 [4. 5. 6.]
 [7. 8. 9.]]
Tile 1:
[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
Tile 2:
[[1. 2. 3.]
 [4. 5. 6.]
 [7. 8. 9.]]
Tile 3:
[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
--------------------------------------------------
After Accumulation 

# Simulation Handler Class

In [11]:
class AccumulatorBankSimulator:
    """Simulator for AccumulatorMemoryBank with integrated address generator.

    Typical use: ``sim = AccumulatorBankSimulator(3, 4).setup()`` followed by
    ``write_tile`` / ``read_tile`` / ``get_all_tiles`` calls. ``setup()``
    resets the PyRTL working block, so only one simulator should be active at
    a time.
    """

    def __init__(
        self,
        array_size: int,
        num_tiles: int,
        data_type: Type[BaseFloat] = BF16,
        adder: Callable = float_adder,
    ):
        """Initialize simulator configuration

        Args:
            array_size: Dimension of systolic array (NxN)
            num_tiles: Number of tiles to support. The hardware rounds this up
                to a power of two; only the first ``num_tiles`` are inspected.
            data_type: Number format for data (default: BF16)
            adder: Floating point adder implementation
        """
        self.array_size = array_size
        self.num_tiles = num_tiles
        self.data_type = data_type
        self.tile_addr_width = (num_tiles - 1).bit_length()

        # Store configuration for setup
        self.config = {
            "array_size": array_size,
            "tile_addr_width": self.tile_addr_width,
            "data_type": data_type,
            "adder": adder,
        }
        self.sim = None

    def setup(self):
        """Initialize PyRTL simulation environment (resets the working block)."""
        reset_working_block()

        # Input ports
        self._write_tile_addr = Input(self.tile_addr_width, "write_tile_addr")
        self._write_start = Input(1, "write_start")
        self._write_mode = Input(1, "write_mode")
        self._write_valid = Input(1, "write_valid")
        self._read_tile_addr = Input(self.tile_addr_width, "read_tile_addr")
        self._read_start = Input(1, "read_start")
        self._data_in = [
            Input(self.data_type.bitwidth(), f"data_in_{i}")
            for i in range(self.array_size)
        ]

        # Create accumulator bank
        self.acc_bank = AccumulatorMemoryBank(**self.config)
        self.acc_bank.connect_inputs(
            self._write_tile_addr,
            self._write_start,
            self._write_mode,
            self._write_valid,
            self._read_tile_addr,
            self._read_start,
            self._data_in,  # type: ignore
        )

        # Create simulation
        self.sim = Simulation()

        return self

    def _require_sim(self):
        """Return the active Simulation, or raise if ``setup()`` was not called."""
        if self.sim is None:
            raise RuntimeError("Simulator not initialized. Call setup() first")
        return self.sim

    def _get_default_inputs(self, updates: Optional[dict] = None) -> dict:
        """Get dictionary of default input values with optional updates"""
        defaults = {
            "write_tile_addr": 0,
            "write_start": 0,
            "write_mode": 0,
            "write_valid": 0,
            "read_tile_addr": 0,
            "read_start": 0,
            **{f"data_in_{i}": 0 for i in range(self.array_size)},
        }
        if updates:
            defaults.update(updates)
        return defaults

    def write_tile(
        self,
        tile_addr: int,
        data: np.ndarray,
        accumulate: bool = False,
        check_bounds: bool = True,
    ) -> None:
        """Write data to specified tile

        Args:
            tile_addr: Destination tile address
            data: Input data (array_size x array_size); an ndarray or nested
                sequence of numbers.
            accumulate: If True, accumulate with existing values
            check_bounds: If True, validate input dimensions

        Raises:
            RuntimeError: If ``setup()`` has not been called.
            ValueError: If bounds checking is on and the tile address or data
                shape is invalid.
        """
        sim = self._require_sim()

        if check_bounds:
            if tile_addr >= self.num_tiles or tile_addr < 0:
                raise ValueError(f"Tile address {tile_addr} out of range")
            # np.shape also works for nested lists, unlike `data.shape`.
            if np.shape(data) != (self.array_size, self.array_size):
                raise ValueError(f"Data must be {self.array_size}x{self.array_size}")

        # Convert data to binary format
        binary_data = [[self.data_type(val).binint for val in row] for row in data]
        mode = int(accumulate)

        # Pulse write_start for one cycle to latch tile base address and mode.
        sim.step(
            self._get_default_inputs(
                {"write_tile_addr": tile_addr, "write_start": 1, "write_mode": mode}
            )
        )

        # Write each row, one per cycle with write_valid asserted.
        for row in binary_data:
            sim.step(
                self._get_default_inputs(
                    {
                        "write_tile_addr": tile_addr,
                        "write_mode": mode,
                        "write_valid": 1,
                        **{f"data_in_{i}": row[i] for i in range(self.array_size)},
                    }
                )
            )

    def read_tile(self, tile_addr: int, max_cycles: Optional[int] = None) -> np.ndarray:
        """Read data from specified tile

        Args:
            tile_addr: Tile address to read from
            max_cycles: Upper bound on read cycles to wait for ``read_done``
                (default: ``2 * array_size``).

        Returns:
            Array containing the last ``array_size`` rows read (the tile data).

        Raises:
            RuntimeError: If ``setup()`` has not been called, or ``read_done``
                is not asserted within ``max_cycles`` cycles.
        """
        sim = self._require_sim()
        if max_cycles is None:
            max_cycles = 2 * self.array_size

        results = []

        # Start read operation
        sim.step(
            self._get_default_inputs({"read_tile_addr": tile_addr, "read_start": 1})
        )

        # Read rows until done, bounded so a stuck FSM cannot hang the caller.
        for _ in range(max_cycles):
            sim.step(self._get_default_inputs({"read_tile_addr": tile_addr}))

            # Capture outputs
            row = [
                float(
                    self.data_type(
                        binint=sim.inspect(self.acc_bank.get_output(i).name)
                    )
                )
                for i in range(self.array_size)
            ]
            results.append(row)

            if sim.inspect(self.acc_bank.read_done.name):
                break
        else:
            raise RuntimeError(
                f"read_done not asserted within {max_cycles} cycles (tile {tile_addr})"
            )

        return np.array(results[-self.array_size :])

    def get_all_tiles(self) -> np.ndarray:
        """Read all tile memories

        Returns:
            3D array of shape (num_tiles, array_size, array_size). Addresses
            never written read as 0.
        """
        sim = self._require_sim()

        # One snapshot per bank (column); rows are memory addresses.
        snapshots = [sim.inspect_mem(mem) for mem in self.acc_bank.memory_banks]
        num_rows = self.array_size * self.num_tiles
        rows = [
            [float(self.data_type(binint=snap.get(addr, 0))) for snap in snapshots]
            for addr in range(num_rows)
        ]

        return np.array(rows).reshape(self.num_tiles, self.array_size, self.array_size)

    def print_state(self, message: Optional[str] = None):
        """Print current state of all tiles"""
        if message:
            print(f"\n{message}")

        tiles = self.get_all_tiles()
        print("\nTile States:")
        print("-" * 50)
        for i, tile in enumerate(tiles):
            print(f"Tile {i}:")
            print(np.array2string(tile, precision=2, suppress_small=True))
        print("-" * 50)

In [12]:
def test_accumulator():
    """End-to-end check of AccumulatorBankSimulator.

    Overwrites tiles 0 and 2 with the same 3x3 matrix, accumulates that matrix
    into tile 0 once more, then verifies both tiles by reading them back.
    """
    simulator = AccumulatorBankSimulator(array_size=3, num_tiles=4).setup()

    # [[1, 2, 3], [4, 5, 6], [7, 8, 9]] as floats
    tile_values = np.arange(1.0, 10.0).reshape(3, 3)

    simulator.print_state("Initial State")

    # Plain (overwrite) writes to tiles 0 and 2
    for target in (0, 2):
        simulator.write_tile(target, tile_values)
        simulator.print_state(f"After Write to Tile {target}")

    # A second pass into tile 0 in accumulate mode doubles it
    simulator.write_tile(0, tile_values, accumulate=True)
    simulator.print_state("After Accumulation to Tile 0")

    observed_tile0 = simulator.read_tile(0)
    observed_tile2 = simulator.read_tile(2)

    np.testing.assert_allclose(observed_tile0, 2 * tile_values)
    np.testing.assert_allclose(observed_tile2, tile_values)
    print("All tests passed!")


# In a Jupyter kernel __name__ is "__main__", so this runs when the cell executes.
if __name__ == "__main__":
    test_accumulator()


Initial State

Tile States:
--------------------------------------------------
Tile 0:
[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
Tile 1:
[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
Tile 2:
[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
Tile 3:
[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
--------------------------------------------------

After Write to Tile 0

Tile States:
--------------------------------------------------
Tile 0:
[[1. 2. 3.]
 [4. 5. 6.]
 [7. 8. 9.]]
Tile 1:
[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
Tile 2:
[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
Tile 3:
[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
--------------------------------------------------

After Write to Tile 2

Tile States:
--------------------------------------------------
Tile 0:
[[1. 2. 3.]
 [4. 5. 6.]
 [7. 8. 9.]]
Tile 1:
[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
Tile 2:
[[1. 2. 3.]
 [4. 5. 6.]
 [7. 8. 9.]]
Tile 3:
[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
--------------------------------------------------

After Accumulation to Tile 0

Tile States:
-------

---