# Accumulator Buffer

The accumulator buffer is needed to store partial results of large matrix multiplications from the outputs of the systolic array.

Each accumulator bank will connect to a column output from the systolic array. We can specify the address width which will determine the number of values we can store. This is equal to `2^len(mem_addr)`. 

In [1]:
from dataclasses import dataclass
import pyrtl
from IPython.display import SVG
from pyrtl import *
import numpy as np
from enum import IntEnum
from typing import List, Type, Callable
from hardware_accelerators import *
from hardware_accelerators.dtypes import BaseFloat
from hardware_accelerators.rtllib import float_adder
from hardware_accelerators.simulation import convert_array_dtype, render_waveform
from hardware_accelerators.simulation.repr_funcs import *

## TPU Accum Example

In [39]:
def accum(size, data_in, waddr, wen, wclear, raddr, lastvec):
    """A single 32-bit accumulator with 2^size 32-bit buffers.
    On wen, writes data_in to the specified address (waddr) if wclear is high;
    otherwise, it performs an accumulate at the specified address (buffer[waddr] += data_in).
    lastvec is a control signal indicating that the operation being stored now is the
    last vector of a matrix multiply instruction (at the final accumulator, this becomes
    a "done" signal).
    """

    mem = MemBlock(bitwidth=16, addrwidth=size, name="MEMORY")

    sum = float_adder(data_in, mem[waddr], BF16)[: mem.bitwidth]

    # Writes
    with conditional_assignment:
        with wen:
            with wclear:
                mem[waddr] |= data_in
            with otherwise:
                mem[waddr] |= sum
    # Read
    data_out = mem[raddr]

    # # Pipeline registers
    # waddrsave = Register(len(waddr))
    # waddrsave.next <<= waddr
    # wensave = Register(1)
    # wensave.next <<= wen
    # wclearsave = Register(1)
    # wclearsave.next <<= wclear
    # lastsave = Register(1)
    # lastsave.next <<= lastvec

    # return data_out, waddrsave, wensave, wclearsave, lastsave


reset_working_block()
# Output(32), Output(3), Output(1), Output(1), Output(1) =
accum(
    3,
    Input(16, "data_input"),
    Input(3, "write_address"),
    Input(1, "write_enable"),
    Input(1, "clear"),
    Input(3, "read_address"),
    Input(1, "last vector"),
)

# Single Accumulator Block

In [30]:
class AccumulatorBlock:
    def __init__(
        self,
        data_type: Type[BaseFloat],
        addr_width: int,
        adder: Callable[[WireVector, WireVector, Type[BaseFloat]], WireVector],
    ):
        """Single accumulator block with memory and floating point add capability

        Args:
            data_type: Number format for accumulation values
            addr_width: Number of address bits for memory
            adder: Floating point adder implementation
        """
        self.data_type = data_type
        self.data_width = data_type.bitwidth()

        # Input wires
        self.write_data = WireVector(self.data_width)
        self.write_addr = WireVector(addr_width)
        self.write_enable = WireVector(1)
        self.write_clear = WireVector(1)  # 1 for overwrite, 0 for accumulate
        self.read_addr = WireVector(addr_width)

        # Memory block
        self.memory = MemBlock(
            bitwidth=self.data_width,
            addrwidth=addr_width,
        )

        # Read data is direct from memory
        self.read_data = self.memory[self.read_addr]
        current = self.memory[self.write_addr]
        sum_result = adder(self.write_data, current, self.data_type)

        # Write logic with floating point add
        with conditional_assignment:
            with self.write_enable:
                with self.write_clear:
                    # Overwrite mode
                    self.memory[self.write_addr] |= self.write_data
                with ~self.write_clear:
                    # Accumulate mode - use floating point add
                    self.memory[self.write_addr] |= sum_result

    def connect_write_data(self, source: WireVector):
        """Connect data input"""
        self.write_data <<= source

    def connect_write_addr(self, addr: WireVector):
        """Connect write address input"""
        self.write_addr <<= addr

    def connect_write_enable(self, enable: WireVector):
        """Connect write enable signal"""
        self.write_enable <<= enable

    def connect_write_clear(self, clear: WireVector):
        """Connect write clear signal (overwrite vs accumulate)"""
        self.write_clear <<= clear

    def connect_read_addr(self, addr: WireVector):
        """Connect read address input"""
        self.read_addr <<= addr

In [31]:
def test_accumulator_block():
    """Test single accumulator block"""
    from pyrtl import reset_working_block, Input, Output, Simulation

    reset_working_block()

    # Test parameters
    ADDR_BITS = 3  # 8 addresses
    dtype = BF16  # Use BFloat16 format

    # Create input/output wires
    w_data = Input(dtype.bitwidth(), "write_data")
    w_addr = Input(ADDR_BITS, "write_addr")
    w_en = Input(1, "write_enable")
    w_clear = Input(1, "write_clear")
    r_addr = Input(ADDR_BITS, "read_addr")
    r_data = Output(dtype.bitwidth(), "read_data")

    # Create accumulator block
    accum = AccumulatorBlock(data_type=dtype, addr_width=ADDR_BITS, adder=float_adder)

    # Connect ports
    accum.connect_write_data(w_data)
    accum.connect_write_addr(w_addr)
    accum.connect_write_enable(w_en)
    accum.connect_write_clear(w_clear)
    accum.connect_read_addr(r_addr)
    r_data <<= accum.read_data

    # Create simulation
    sim = Simulation()

    # Test sequence
    sim_steps = [
        # 1. Write 1.5 to address 0
        {
            "write_data": dtype(1.5).binint,
            "write_addr": 0,
            "write_enable": 1,
            "write_clear": 1,
            "read_addr": 0,
        },
        # 2. Write 2.5 to address 1
        {
            "write_data": dtype(2.5).binint,
            "write_addr": 1,
            "write_enable": 1,
            "write_clear": 1,
            "read_addr": 0,
        },
        # 3. Accumulate 1.0 to address 0
        {
            "write_data": dtype(1.0).binint,
            "write_addr": 0,
            "write_enable": 1,
            "write_clear": 0,
            "read_addr": 0,
        },
        # 4. Read address 0 (should be 2.5)
        {
            "write_data": 0,
            "write_addr": 0,
            "write_enable": 0,
            "write_clear": 0,
            "read_addr": 0,
        },
        # 5. Read address 1 (should be 2.5)
        {
            "write_data": 0,
            "write_addr": 0,
            "write_enable": 0,
            "write_clear": 0,
            "read_addr": 1,
        },
    ]

    # Run simulation
    results = []
    for step in sim_steps:
        sim.step(step)
        results.append(dtype(binint=sim.inspect("read_data")).decimal_approx)

    print("\nSimulation Results:")
    print(f"Step 1 output: {results[0]}")
    print(f"Step 2 output: {results[1]}")
    print(f"Step 3 output: {results[2]}")
    print(f"Address 0 final value: {results[3]} (should be 2.5)")
    print(f"Address 1 final value: {results[4]} (should be 2.5)")

    return sim


# Run test
if __name__ == "__main__":
    sim = test_accumulator_block()


Simulation Results:
Step 1 output: 0.0
Step 2 output: 1.5
Step 3 output: 1.5
Address 0 final value: 2.5 (should be 2.5)
Address 1 final value: 2.5 (should be 2.5)


In [33]:
logic_area, mem_area = area_estimation()
logic_len, mem_len = logic_area**0.5, mem_area**0.5
logic_len, mem_len

(0.07729611633193481, 0.08170256081103433)

In [34]:
import tkinter as tk

root = tk.Tk()
dpi = root.winfo_screenwidth() / root.winfo_screenmmwidth()
dpi

TclError: no display name and no $DISPLAY environment variable

# Creating an array of accumulators

In [70]:
class AccumulatorArray:
    def __init__(
        self,
        size: int,
        data_type: Type[BaseFloat],
        addr_width: int,
        adder: Callable[[WireVector, WireVector, Type[BaseFloat]], WireVector],
    ):
        """Array of accumulator blocks that connect to systolic array outputs

        Args:
            size: Number of accumulator blocks (matches systolic array width)
            data_type: Number format for accumulation
            addr_width: Number of address bits for each accumulator memory
            adder: Floating point adder implementation
        """
        self.size = size
        self.data_type = data_type

        # Create array of accumulator blocks
        self.accumulators = [
            AccumulatorBlock(data_type, addr_width, adder) for _ in range(size)
        ]

        # Shared control signals
        self.write_addr = WireVector(addr_width)
        self.write_enable = WireVector(1)
        self.write_clear = WireVector(1)
        self.read_addr = WireVector(addr_width)

        # Connect shared controls to all accumulators
        for acc in self.accumulators:
            acc.connect_write_addr(self.write_addr)
            acc.connect_write_enable(self.write_enable)
            acc.connect_write_clear(self.write_clear)
            acc.connect_read_addr(self.read_addr)

    def connect_write_data(self, col: int, source: WireVector):
        """Connect data input for specified column"""
        assert 0 <= col < self.size
        self.accumulators[col].connect_write_data(source)

    def connect_write_addr(self, addr: WireVector):
        """Connect shared write address"""
        self.write_addr <<= addr

    def connect_write_enable(self, enable: WireVector):
        """Connect shared write enable"""
        self.write_enable <<= enable

    def connect_write_clear(self, clear: WireVector):
        """Connect shared write clear"""
        self.write_clear <<= clear

    def connect_read_addr(self, addr: WireVector):
        """Connect shared read address"""
        self.read_addr <<= addr

## Matrix Multiply Unit

In [71]:
from hardware_accelerators.rtllib.systolic import SystolicArrayDiP


class MatrixMultiplyUnit:
    def __init__(
        self,
        size: int,
        data_type: Type[BaseFloat],
        accum_type: Type[BaseFloat],
        accum_depth: int,
        multiplier: Callable[[WireVector, WireVector, Type[BaseFloat]], WireVector],
        adder: Callable[[WireVector, WireVector, Type[BaseFloat]], WireVector],
    ):
        """Top level matrix multiply unit combining systolic array and accumulator array

        Args:
            size: Dimension of systolic array (NxN)
            data_type: Number format for inputs
            accum_type: Number format for accumulation
            accum_depth: Number of addresses in accumulator memory (2^addr_width)
            multiplier: Multiplier implementation
            adder: Adder implementation
        """
        # Create systolic array
        self.array = SystolicArrayDiP(
            size=size,
            data_type=data_type,
            accum_type=accum_type,
            multiplier=multiplier,
            adder=adder,
        )

        # Create accumulator array
        self.accumulators = AccumulatorArray(
            size=size,
            data_type=accum_type,
            addr_width=accum_depth.bit_length(),
            adder=adder,
        )

        # Connect systolic outputs to accumulator inputs
        for i in range(size):
            self.accumulators.connect_write_data(i, self.array.results_out[i])

        # Control signals
        self.start = WireVector(1)  # Start new matrix multiply
        self.accum_addr = WireVector(
            accum_depth.bit_length()
        )  # Base accumulator address
        self.num_vectors = WireVector(16)  # Number of vectors to process
        self.overwrite = WireVector(1)  # Overwrite vs accumulate mode

        # Control state
        self.busy = Register(1)
        self.vector_count = Register(16)  # Counter for vectors processed
        self.current_addr = Register(accum_depth.bit_length())

        # Control logic
        with conditional_assignment:
            with self.start:
                # Start new operation
                self.busy.next |= 1
                self.vector_count.next |= self.num_vectors
                self.current_addr.next |= self.accum_addr
                # Enable systolic array
                self.array.connect_enable_input(1)

            with self.busy:
                # Continue processing vectors
                self.vector_count.next |= self.vector_count - 1
                self.current_addr.next |= self.current_addr + 1

                with self.vector_count == 1:
                    # Last vector
                    self.busy.next |= 0
                    self.array.connect_enable_input(Const(0))

        # Connect accumulator control signals
        self.accumulators.connect_write_addr(self.current_addr)
        self.accumulators.connect_write_enable(self.busy)
        self.accumulators.connect_write_clear(self.overwrite)

## Testing utils

In [72]:
def test_matrix_multiply():
    from pyrtl import reset_working_block, Input, Output, Simulation

    reset_working_block()

    # Parameters
    SIZE = 4
    ACCUM_DEPTH = 16
    dtype = BF16

    # Create MMU
    mmu = MatrixMultiplyUnit(
        size=SIZE,
        data_type=dtype,
        accum_type=dtype,
        accum_depth=ACCUM_DEPTH,
        multiplier=lmul_fast,
        adder=float_adder,
    )

    # Create control inputs
    start = Input(1, "start")
    accum_addr = Input(ACCUM_DEPTH.bit_length(), "accum_addr")
    num_vectors = Input(16, "num_vectors")
    overwrite = Input(1, "overwrite")
    read_addr = Input(ACCUM_DEPTH.bit_length(), "read_addr")

    # Create data inputs/outputs
    weight_en = Input(1, "weight_en")
    weights = [Input(dtype.bitwidth(), f"weight_{i}") for i in range(SIZE)]
    data = [Input(dtype.bitwidth(), f"data_{i}") for i in range(SIZE)]
    results = [Output(dtype.bitwidth(), f"result_{i}") for i in range(SIZE)]

    # Connect control signals
    mmu.start <<= start
    mmu.accum_addr <<= accum_addr
    mmu.num_vectors <<= num_vectors
    mmu.overwrite <<= overwrite
    mmu.accumulators.connect_read_addr(read_addr)

    # Connect array inputs/outputs
    mmu.array.connect_weight_enable(weight_en)
    for i in range(SIZE):
        mmu.array.connect_weight_input(i, weights[i])
        mmu.array.connect_data_input(i, data[i])
        results[i] <<= mmu.accumulators.accumulators[i].read_data

    # Create simulation
    sim = Simulation()

    # Test matrices
    W1 = np.array(
        [  # First weight tile
            [1.0, 0.5, 0.0, 0.0],
            [0.0, 1.0, 0.5, 0.0],
            [0.0, 0.0, 1.0, 0.5],
            [0.5, 0.0, 0.0, 1.0],
        ]
    )

    W2 = np.array(
        [  # Second weight tile
            [0.5, 0.0, 0.0, 0.0],
            [0.0, 0.5, 0.0, 0.0],
            [0.0, 0.0, 0.5, 0.0],
            [0.0, 0.0, 0.0, 0.5],
        ]
    )

    A = np.array(
        [  # Input matrix
            [1.0, 2.0, 3.0, 4.0],
            [2.0, 3.0, 4.0, 1.0],
            [3.0, 4.0, 1.0, 2.0],
            [4.0, 1.0, 2.0, 3.0],
        ]
    )

    def make_input_dict(
        w_en=0,
        weights=None,
        data=None,
        start=0,
        acc_addr=0,
        num_vecs=0,
        overwrite=0,
        read_addr=0,
    ):
        """Helper to create simulation input dictionary"""
        inputs = {
            "weight_en": w_en,
            "start": start,
            "accum_addr": acc_addr,
            "num_vectors": num_vecs,
            "overwrite": overwrite,
            "read_addr": read_addr,
        }

        if weights is not None:
            for i, w in enumerate(weights):
                inputs[f"weight_{i}"] = dtype(w).binint

        if data is not None:
            for i, d in enumerate(data):
                inputs[f"data_{i}"] = dtype(d).binint

        return inputs

    def read_results():
        """Helper to read current accumulator outputs"""
        return np.array(
            [
                dtype(binint=sim.inspect(f"result_{i}")).decimal_approx
                for i in range(SIZE)
            ]
        )

    # Simulation steps

    # Step 1: Load first weight tile
    print("\nLoading first weight tile...")
    sim.step(make_input_dict(w_en=1, weights=W1[0]))
    for row in W1[1:]:
        sim.step(make_input_dict(w_en=1, weights=row))

    # Step 2: Process first matrix multiply (overwrite mode)
    print("Processing first matrix multiply...")
    sim.step(
        make_input_dict(start=1, acc_addr=0, num_vecs=SIZE, overwrite=1, data=A[0])
    )

    for row in A[1:]:
        sim.step(make_input_dict(data=row))

    # Additional steps to flush pipeline
    for _ in range(SIZE):
        sim.step(make_input_dict())

    # Read results
    print("\nIntermediate results:")
    for i in range(SIZE):
        sim.step(make_input_dict(read_addr=i))
        print(f"Row {i}: {read_results()}")

    # Step 3: Load second weight tile
    print("\nLoading second weight tile...")
    sim.step(make_input_dict(w_en=1, weights=W2[0]))
    for row in W2[1:]:
        sim.step(make_input_dict(w_en=1, weights=row))

    # Step 4: Process second matrix multiply (accumulate mode)
    print("Processing second matrix multiply...")
    sim.step(
        make_input_dict(
            start=1,
            acc_addr=0,
            num_vecs=SIZE,
            overwrite=0,  # Accumulate mode
            data=A[0],
        )
    )

    for row in A[1:]:
        sim.step(make_input_dict(data=row))

    # Additional steps to flush pipeline
    for _ in range(SIZE):
        sim.step(make_input_dict())

    # Read final results
    print("\nFinal results:")
    for i in range(SIZE):
        sim.step(make_input_dict(read_addr=i))
        print(f"Row {i}: {read_results()}")

    return sim


# Run test
if __name__ == "__main__":
    sim = test_matrix_multiply()

PyrtlError: Wire "tmp9792/1W" has multiple drivers: [tmp9792/1W & \leftarrow w \, - & const\_3096\_1/1C  \\] and [tmp9792/1W & \leftarrow w \, - & const\_3104\_0/1C  \\] (check for multiple assignments with "<<=" or accidental mixing of "|=" and "<<=")

---

# Address Generator Design (FSM)

In [63]:
class TiledAccumulatorFSM(IntEnum):
    """FSM States for Accumulator Control

    State Transitions:
    IDLE -> WRITING: When start signal received
    WRITING -> WRITING: While processing rows within tile
    WRITING -> IDLE: Immediate transition when last row of tile processed
    """

    IDLE = 0
    """IDLE: Waiting for new tile operation"""
    WRITING = 1
    """WRITING: Processing rows of systolic array output"""


class TiledAddressGenerator:
    """Generates addresses and control signals for accumulator bank memory

    This module manages the storage of systolic array outputs into tile-organized
    memory. It automatically handles address generation and increments within tiles,
    abstracting away the internal memory organization from the higher-level control.
    """

    def __init__(self, tile_addr_width: int, array_size: int):
        """Initialize address generator

        Args:
            tile_addr_width: Number of bits for addressing tiles
            array_size: Dimension of systolic array (NxN)
        """
        # Configuration parameters
        self.array_size = array_size
        self.num_tiles = 2**tile_addr_width

        # Calculate required address width based on total storage needed
        # (num_tiles * rows_per_tile)
        self.internal_addr_width = (self.num_tiles * array_size - 1).bit_length()

        # Create base address lookup ROM
        # For example, with 4x4 array and 4 tiles:
        # tile 0 -> base addr 0
        # tile 1 -> base addr 4
        # tile 2 -> base addr 8
        # tile 3 -> base addr 12
        base_addrs = [i * array_size for i in range(self.num_tiles)]
        self.base_addr_rom = RomBlock(
            bitwidth=self.internal_addr_width,
            addrwidth=tile_addr_width,
            romdata=base_addrs,
        )

        # Input signals
        self._tile_addr = WireVector(tile_addr_width)
        self._start = WireVector(1)
        self._write_valid = WireVector(1)

        # State registers
        self.state = Register(1)  # Current FSM state
        self.internal_addr = Register(self.internal_addr_width)
        self.current_row = Register(array_size.bit_length())

        # Output signals
        self.internal_write_addr = WireVector(self.internal_addr_width)
        self.internal_write_enable = WireVector(1)
        self.busy = WireVector(1)
        self.tile_complete = WireVector(1)

        # Implement FSM logic
        self._implement_fsm()

    def _implement_fsm(self):
        """Implements the FSM logic using conditional assignments"""
        # Get base address from ROM using tile address
        tile_base = self.base_addr_rom[self._tile_addr]

        # Set output signals (combinational, outside conditional)
        self.internal_write_addr <<= self.internal_addr
        self.internal_write_enable <<= (
            self.state == TiledAccumulatorFSM.WRITING
        ) & self._write_valid
        self.busy <<= self.state == TiledAccumulatorFSM.WRITING
        self.tile_complete <<= (
            (self.state == TiledAccumulatorFSM.WRITING)
            & self._write_valid
            & (self.current_row == self.array_size)
        )

        # FSM Logic
        with conditional_assignment:
            # IDLE: Wait for start signal
            with self.state == TiledAccumulatorFSM.IDLE:
                with self._start:
                    self.state.next |= TiledAccumulatorFSM.WRITING
                    self.internal_addr.next |= tile_base
                    self.current_row.next |= 0

            # WRITING: Process rows until tile complete
            with self.state == TiledAccumulatorFSM.WRITING:
                with self._write_valid:
                    # Only return to IDLE after last row is processed
                    with self.current_row == self.array_size - 1:
                        self.state.next |= TiledAccumulatorFSM.IDLE
                        self.current_row.next |= 0
                    with otherwise:
                        self.internal_addr.next |= self.internal_addr + 1
                        self.current_row.next |= self.current_row + 1

    # Connection methods
    def connect_tile_addr(self, addr: WireVector) -> None:
        """Connect tile address input

        Args:
            addr: Address of tile to write to/read from (tile_addr_width bits)
        """
        self._tile_addr <<= addr

    def connect_start(self, start: WireVector) -> None:
        """Connect start signal

        Args:
            start: Signal to begin processing new tile (1 bit)
        """
        self._start <<= start

    def connect_write_valid(self, valid: WireVector) -> None:
        """Connect write valid signal from systolic array

        Args:
            valid: Signal indicating valid output from systolic array (1 bit)
        """
        self._write_valid <<= valid

    def get_write_addr(self) -> WireVector:
        """Get current write address output

        Returns:
            WireVector of internal_addr_width bits
        """
        return self.internal_write_addr

    def get_write_enable(self) -> WireVector:
        """Get write enable output

        Returns:
            1-bit WireVector, high when writing should occur
        """
        return self.internal_write_enable

    def get_busy(self) -> WireVector:
        """Get busy status

        Returns:
            1-bit WireVector, high when processing a tile
        """
        return self.busy

    def get_tile_complete(self) -> WireVector:
        """Get tile completion signal

        Returns:
            1-bit WireVector, pulses high when tile is complete
        """
        return self.tile_complete

## Address Generator Testing

### Lets start with a very basic test

In [85]:
def test_address_generator_simple():
    """Simple test to diagnose address generator behavior"""
    from pyrtl import reset_working_block, Input, Output, Simulation

    reset_working_block()

    # Test parameters
    ARRAY_SIZE = 2
    TILE_ADDR_WIDTH = 1

    # Create minimal test setup for 2x2 array, 2 tiles
    tile_addr = Input(TILE_ADDR_WIDTH, "tile_addr")
    start = Input(1, "start")
    write_valid = Input(1, "write_valid")

    # Create address generator
    addr_gen = TiledAddressGenerator(
        tile_addr_width=TILE_ADDR_WIDTH, array_size=ARRAY_SIZE
    )

    # Print ROM contents to verify initialization
    print("\nROM Contents:")
    print(f"ROM data: {addr_gen.base_addr_rom.data}")
    print(f"ROM width: {addr_gen.base_addr_rom.bitwidth}")
    print(f"ROM addr width: {addr_gen.base_addr_rom.addrwidth}")

    # Connect signals
    addr_gen.connect_tile_addr(tile_addr)
    addr_gen.connect_start(start)
    addr_gen.connect_write_valid(write_valid)

    # Create simulation
    sim = Simulation()

    # Test sequence - just handle one tile completely
    steps = [
        # Step 0: Initial state
        {"tile_addr": 0, "start": 0, "write_valid": 0},
        # Step 1: Start tile 0
        {"tile_addr": 0, "start": 1, "write_valid": 0},
        # Step 2: First write
        {"tile_addr": 0, "start": 0, "write_valid": 1},
        # Step 3: Second write
        {"tile_addr": 0, "start": 0, "write_valid": 1},
        # Step 4: Idle
        {"tile_addr": 0, "start": 0, "write_valid": 0},
        {"tile_addr": 0, "start": 0, "write_valid": 0},
    ]

    print("\nCycle by cycle analysis:")
    print("-" * 50)
    print("Cycle | Inputs      | State    | Row | Addr | WE |")
    print("-" * 50)

    for i, step in enumerate(steps):
        sim.step(step)

        # Inspect all relevant signals
        state = TiledAccumulatorFSM(sim.inspect(addr_gen.state.name))
        row = sim.inspect(addr_gen.current_row.name)
        addr = sim.inspect(addr_gen.internal_addr.name)
        we = sim.inspect(addr_gen.internal_write_enable.name)

        print(
            f"{i:5d} | "
            f"t={step['tile_addr']} "
            f"s={step['start']} "
            f"v={step['write_valid']} | "
            f"{state.name:8s} | "
            f"{row:3d} | "
            f"{addr:4d} | "
            f"{we:2d} | "
        )

    return sim


# Run simple test
test_address_generator_simple()


ROM Contents:
ROM data: [0, 2]
ROM width: 2
ROM addr width: 1

Cycle by cycle analysis:
--------------------------------------------------
Cycle | Inputs      | State    | Row | Addr | WE |
--------------------------------------------------
    0 | t=0 s=0 v=0 | IDLE     |   0 |    0 |  0 | 
    1 | t=0 s=1 v=0 | IDLE     |   0 |    0 |  0 | 
    2 | t=0 s=0 v=1 | WRITING  |   0 |    0 |  1 | 
    3 | t=0 s=0 v=1 | WRITING  |   1 |    1 |  1 | 
    4 | t=0 s=0 v=0 | IDLE     |   0 |    1 |  0 | 
    5 | t=0 s=0 v=0 | IDLE     |   0 |    1 |  0 | 


<pyrtl.simulation.Simulation at 0xffff71f5f6b0>

### A bigger test now

In [None]:
def test_address_generator_comprehensive():
    """Comprehensive test of address generator behavior"""
    from pyrtl import reset_working_block, Input, Output, Simulation

    reset_working_block()

    # Test parameters
    ARRAY_SIZE = 2
    TILE_ADDR_WIDTH = 1

    # Create inputs/outputs
    tile_addr = Input(TILE_ADDR_WIDTH, "tile_addr")
    start = Input(1, "start")
    write_valid = Input(1, "write_valid")

    # Create address generator
    addr_gen = TiledAddressGenerator(
        tile_addr_width=TILE_ADDR_WIDTH, array_size=ARRAY_SIZE
    )

    # Connect signals
    addr_gen.connect_tile_addr(tile_addr)
    addr_gen.connect_start(start)
    addr_gen.connect_write_valid(write_valid)

    # Create simulation
    sim = Simulation()

    # Test sequence
    steps = [
        # Initial state
        {"tile_addr": 0, "start": 0, "write_valid": 0},
        # Process tile 0
        {"tile_addr": 0, "start": 1, "write_valid": 0},  # Start tile 0
        {"tile_addr": 0, "start": 0, "write_valid": 1},  # Write row 0
        {"tile_addr": 0, "start": 0, "write_valid": 0},  # Gap in writes
        {"tile_addr": 0, "start": 0, "write_valid": 1},  # Write row 1
        # Immediate start of tile 1
        {"tile_addr": 1, "start": 1, "write_valid": 0},  # Start tile 1
        {"tile_addr": 1, "start": 0, "write_valid": 1},  # Write row 0
        {"tile_addr": 1, "start": 0, "write_valid": 1},  # Write row 1
        # Try to start new tile while busy (should be ignored)
        {"tile_addr": 0, "start": 1, "write_valid": 1},
        # Return to idle
        {"tile_addr": 0, "start": 0, "write_valid": 0},
    ]

    print("\nCycle by cycle analysis:")
    print("-" * 70)
    print("Cycle | Inputs      | State    | Row | Addr | WE | Notes")
    print("-" * 70)

    expected_states = []
    expected_addrs = []
    expected_rows = []

    for i, step in enumerate(steps):
        sim.step(step)

        # Inspect signals
        state = TiledAccumulatorFSM(sim.inspect(addr_gen.state.name))
        row = sim.inspect(addr_gen.current_row.name)
        addr = sim.inspect(addr_gen.internal_addr.name)
        we = sim.inspect(addr_gen.internal_write_enable.name)

        # Verify write enable is only high during WRITING state with valid input
        we = sim.inspect(addr_gen.internal_write_enable.name)
        assert we == (
            state == TiledAccumulatorFSM.WRITING and step["write_valid"]
        ), f"Write enable incorrect at cycle {i}. Expected {state == TiledAccumulatorFSM.WRITING and step['write_valid']}, got {we}"

        # Add notes about what should be happening
        if step["start"] and state == TiledAccumulatorFSM.IDLE:
            note = "Starting new tile"
        elif state == TiledAccumulatorFSM.WRITING and step["write_valid"]:
            note = f"Writing to tile {step['tile_addr']}"
        elif step["start"] and state == TiledAccumulatorFSM.WRITING:
            note = "Start ignored (busy)"
        else:
            note = ""

        print(
            f"{i:5d} | "
            f"t={step['tile_addr']} "
            f"s={step['start']} "
            f"v={step['write_valid']} | "
            f"{state.name:8s} | "
            f"{row:3d} | "
            f"{addr:4d} | "
            f"{we:2d} | "
            f"{note}"
        )

        # Store values for verification
        expected_states.append(state)
        expected_addrs.append(addr)
        expected_rows.append(row)

    # Verify behavior
    assert expected_addrs[2] == 0, "First write should be to address 0"
    assert expected_addrs[4] == 1, "Second write should be to address 1"
    assert expected_addrs[6] == 2, "First write of tile 1 should be to address 2"
    assert expected_addrs[7] == 3, "Second write of tile 1 should be to address 3"

    assert all(
        row < ARRAY_SIZE for row in expected_rows
    ), "Row counter should never exceed array size"

    return sim


# Run test
test_address_generator_comprehensive()


Cycle by cycle analysis:
----------------------------------------------------------------------
Cycle | Inputs      | State    | Row | Addr | WE | Notes
----------------------------------------------------------------------
    0 | t=0 s=0 v=0 | IDLE     |   0 |    0 |  0 | 
    1 | t=0 s=1 v=0 | IDLE     |   0 |    0 |  0 | Starting new tile
    2 | t=0 s=0 v=1 | WRITING  |   0 |    0 |  1 | Writing to tile 0
    3 | t=0 s=0 v=0 | WRITING  |   1 |    1 |  0 | 
    4 | t=0 s=0 v=1 | WRITING  |   1 |    1 |  1 | Writing to tile 0
    5 | t=1 s=1 v=0 | IDLE     |   0 |    1 |  0 | Starting new tile
    6 | t=1 s=0 v=1 | WRITING  |   0 |    2 |  1 | Writing to tile 1
    7 | t=1 s=0 v=1 | WRITING  |   1 |    3 |  1 | Writing to tile 1
    8 | t=0 s=1 v=1 | IDLE     |   0 |    3 |  0 | Starting new tile
    9 | t=0 s=0 v=0 | WRITING  |   0 |    0 |  0 | 


<pyrtl.simulation.Simulation at 0xffff71ee1f10>

### An even bigger test!

In [87]:
def test_address_generator_large():
    """Test address generator with 16x16 array and 8 tiles"""
    from pyrtl import reset_working_block, Input, Output, Simulation

    reset_working_block()

    # Test parameters
    ARRAY_SIZE = 16
    NUM_TILES = 8
    TILE_ADDR_WIDTH = (NUM_TILES - 1).bit_length()  # 3 bits for 8 tiles

    # Expected base addresses for each tile (ARRAY_SIZE * tile_number)
    TILE_BASE_ADDRS = [ARRAY_SIZE * i for i in range(NUM_TILES)]
    # Tile 0: 0
    # Tile 1: 16
    # Tile 2: 32
    # Tile 3: 48
    # Tile 4: 64
    # Tile 5: 80
    # Tile 6: 96
    # Tile 7: 112

    # Create inputs
    tile_addr = Input(TILE_ADDR_WIDTH, "tile_addr")
    start = Input(1, "start")
    write_valid = Input(1, "write_valid")

    # Create address generator
    addr_gen = TiledAddressGenerator(
        tile_addr_width=TILE_ADDR_WIDTH, array_size=ARRAY_SIZE
    )

    # Connect signals
    addr_gen.connect_tile_addr(tile_addr)
    addr_gen.connect_start(start)
    addr_gen.connect_write_valid(write_valid)

    # Create simulation
    sim = Simulation()

    # Test sequence - we'll test:
    # 1. Complete write of tile 0
    # 2. Partial write of tile 3 with gaps
    # 3. Complete write of tile 7 (last tile)
    steps = []

    # Tile 0 complete write
    steps.append({"tile_addr": 0, "start": 1, "write_valid": 0})  # Start
    for _ in range(ARRAY_SIZE):
        steps.append({"tile_addr": 0, "start": 0, "write_valid": 1})  # Write rows
    steps.append({"tile_addr": 0, "start": 0, "write_valid": 0})  # Idle cycle

    # Tile 3 partial write with gaps
    steps.append({"tile_addr": 3, "start": 1, "write_valid": 0})  # Start
    for i in range(8):  # Only write 8 rows with gaps
        steps.append(
            {"tile_addr": 3, "start": 0, "write_valid": i % 2}
        )  # Alternating valid
    steps.append({"tile_addr": 3, "start": 0, "write_valid": 0})  # Idle cycle

    # Tile 7 complete write
    steps.append({"tile_addr": 7, "start": 1, "write_valid": 0})  # Start
    for _ in range(ARRAY_SIZE):
        steps.append({"tile_addr": 7, "start": 0, "write_valid": 1})  # Write rows
    steps.append({"tile_addr": 7, "start": 0, "write_valid": 0})  # Idle cycle

    print("\nCycle by cycle analysis:")
    print("-" * 70)
    print("Cycle | Inputs      | State    | Row | Addr | WE | Expected Addr")
    print("-" * 70)

    for i, step in enumerate(steps):
        # Calculate expected values before simulation step
        curr_tile = step["tile_addr"]
        base_addr = TILE_BASE_ADDRS[curr_tile]

        # Model FSM and address generation in software
        if i == 0:
            expected_addr = 0
            expected_state = TiledAccumulatorFSM.IDLE
            expected_row = 0
        else:
            prev_step = steps[i - 1]

            # Model state transitions
            if expected_state == TiledAccumulatorFSM.IDLE:
                if prev_step["start"]:
                    expected_state = TiledAccumulatorFSM.WRITING
                    expected_addr = TILE_BASE_ADDRS[prev_step["tile_addr"]]
                    expected_row = 0
            elif expected_state == TiledAccumulatorFSM.WRITING:
                if prev_step["write_valid"]:
                    if expected_row == ARRAY_SIZE - 1:
                        expected_state = TiledAccumulatorFSM.IDLE
                        expected_row = 0
                    else:
                        expected_row += 1
                        expected_addr += 1

        # Simulate one step
        sim.step(step)

        # Get actual values
        state = TiledAccumulatorFSM(sim.inspect(addr_gen.state.name))
        row = sim.inspect(addr_gen.current_row.name)
        addr = sim.inspect(addr_gen.internal_addr.name)
        we = sim.inspect(addr_gen.internal_write_enable.name)

        assert (
            state == expected_state
        ), f"State mismatch at cycle {i}. Expected {expected_state}, got {state}"
        assert (
            addr == expected_addr
        ), f"Address mismatch at cycle {i}. Expected {expected_addr}, got {addr}"
        assert (
            row == expected_row
        ), f"Row mismatch at cycle {i}. Expected {expected_row}, got {row}"

        # Verify write enable
        assert we == (
            state == TiledAccumulatorFSM.WRITING and step["write_valid"]
        ), f"Write enable incorrect at cycle {i}"

        # Verify address
        # assert addr == expected_addr, \
        if addr != expected_addr:
            print(
                f"Address mismatch at cycle {i}. Expected {expected_addr}, got {addr}"
            )

        # Verify row counter
        assert row < ARRAY_SIZE, f"Row counter exceeded array size at cycle {i}"

        print(
            f"{i:5d} | "
            f"t={step['tile_addr']} "
            f"s={step['start']} "
            f"v={step['write_valid']} | "
            f"{state.name:8s} | "
            f"{row:3d} | "
            f"{addr:4d} | "
            f"{we:2d} | "
            f"{expected_addr:4d}"
        )

    return sim


# Run test
test_address_generator_large()


Cycle by cycle analysis:
----------------------------------------------------------------------
Cycle | Inputs      | State    | Row | Addr | WE | Expected Addr
----------------------------------------------------------------------
    0 | t=0 s=1 v=0 | IDLE     |   0 |    0 |  0 |    0
    1 | t=0 s=0 v=1 | WRITING  |   0 |    0 |  1 |    0
    2 | t=0 s=0 v=1 | WRITING  |   1 |    1 |  1 |    1
    3 | t=0 s=0 v=1 | WRITING  |   2 |    2 |  1 |    2
    4 | t=0 s=0 v=1 | WRITING  |   3 |    3 |  1 |    3
    5 | t=0 s=0 v=1 | WRITING  |   4 |    4 |  1 |    4
    6 | t=0 s=0 v=1 | WRITING  |   5 |    5 |  1 |    5
    7 | t=0 s=0 v=1 | WRITING  |   6 |    6 |  1 |    6
    8 | t=0 s=0 v=1 | WRITING  |   7 |    7 |  1 |    7
    9 | t=0 s=0 v=1 | WRITING  |   8 |    8 |  1 |    8
   10 | t=0 s=0 v=1 | WRITING  |   9 |    9 |  1 |    9
   11 | t=0 s=0 v=1 | WRITING  |  10 |   10 |  1 |   10
   12 | t=0 s=0 v=1 | WRITING  |  11 |   11 |  1 |   11
   13 | t=0 s=0 v=1 | WRITING  |  12 | 

<pyrtl.simulation.Simulation at 0xffff71d992e0>

### IDK what this test is but it passes

In [88]:
def test_address_generator():
    """Test address generator with ROM-based base addresses"""
    from pyrtl import reset_working_block, Input, Output, Simulation

    reset_working_block()

    # Test parameters
    TILE_ADDR_WIDTH = 2  # 4 tiles
    ARRAY_SIZE = 4  # 4x4 systolic array

    # Expected base addresses for verification
    EXPECTED_BASE_ADDRS = {
        0: 0,  # Tile 0 starts at 0
        1: 4,  # Tile 1 starts at 4
        2: 8,  # Tile 2 starts at 8
        3: 12,  # Tile 3 starts at 12
    }

    # Create inputs
    tile_addr = Input(TILE_ADDR_WIDTH, "tile_addr")
    start = Input(1, "start")
    write_valid = Input(1, "write_valid")

    # Create outputs to monitor
    write_addr = Output(4, "write_addr")  # 4 bits can address 16 locations
    write_enable = Output(1, "write_enable")
    current_row = Output(2, "current_row")  # 2 bits for 0-3

    # Create address generator
    addr_gen = TiledAddressGenerator(TILE_ADDR_WIDTH, ARRAY_SIZE)

    # Connect signals
    addr_gen.connect_tile_addr(tile_addr)
    addr_gen.connect_start(start)
    addr_gen.connect_write_valid(write_valid)

    write_addr <<= addr_gen.internal_write_addr
    write_enable <<= addr_gen.internal_write_enable
    current_row <<= addr_gen.current_row

    # Create simulation
    sim = Simulation()

    # Test sequence
    print("\nTesting Address Generation:")

    # Test each tile
    for tile in range(4):
        print(f"\nTile {tile}:")
        # Start new tile
        sim.step({"tile_addr": tile, "start": 1, "write_valid": 0})

        # Write all rows
        for row in range(4):
            sim.step({"tile_addr": tile, "start": 0, "write_valid": 1})
            addr = sim.inspect("write_addr")
            print(
                f"Row {row}: Address = {addr} "
                f"(Expected {EXPECTED_BASE_ADDRS[tile] + row})"
            )

    return sim


# Run test
if __name__ == "__main__":
    sim = test_address_generator()


Testing Address Generation:

Tile 0:
Row 0: Address = 0 (Expected 0)
Row 1: Address = 1 (Expected 1)
Row 2: Address = 2 (Expected 2)
Row 3: Address = 3 (Expected 3)

Tile 1:
Row 0: Address = 4 (Expected 4)
Row 1: Address = 5 (Expected 5)
Row 2: Address = 6 (Expected 6)
Row 3: Address = 7 (Expected 7)

Tile 2:
Row 0: Address = 8 (Expected 8)
Row 1: Address = 9 (Expected 9)
Row 2: Address = 10 (Expected 10)
Row 3: Address = 11 (Expected 11)

Tile 3:
Row 0: Address = 12 (Expected 12)
Row 1: Address = 13 (Expected 13)
Row 2: Address = 14 (Expected 14)
Row 3: Address = 15 (Expected 15)


# Next section ...