# Accumulator Buffer

The accumulator buffer is needed to store partial results of large matrix multiplications from the outputs of the systolic array.

Each accumulator bank will connect to a column output from the systolic array. We can specify the address width which will determine the number of values we can store. This is equal to `2^len(mem_addr)`. 

In [23]:
from dataclasses import dataclass
import pyrtl
from IPython.display import SVG
from pyrtl import *
import numpy as np
from typing import List, Type, Callable
from hardware_accelerators.dtypes import BaseFloat
from hardware_accelerators.rtllib import float_adder
from hardware_accelerators.simulation import convert_array_dtype, render_waveform
from hardware_accelerators.simulation.repr_funcs import *

In [39]:
def accum(size, data_in, waddr, wen, wclear, raddr, lastvec):
    """A single 32-bit accumulator with 2^size 32-bit buffers.
    On wen, writes data_in to the specified address (waddr) if wclear is high;
    otherwise, it performs an accumulate at the specified address (buffer[waddr] += data_in).
    lastvec is a control signal indicating that the operation being stored now is the
    last vector of a matrix multiply instruction (at the final accumulator, this becomes
    a "done" signal).
    """

    mem = MemBlock(bitwidth=16, addrwidth=size, name="MEMORY")

    sum = float_adder(data_in, mem[waddr], BF16)[: mem.bitwidth]

    # Writes
    with conditional_assignment:
        with wen:
            with wclear:
                mem[waddr] |= data_in
            with otherwise:
                mem[waddr] |= sum
    # Read
    data_out = mem[raddr]

    # # Pipeline registers
    # waddrsave = Register(len(waddr))
    # waddrsave.next <<= waddr
    # wensave = Register(1)
    # wensave.next <<= wen
    # wclearsave = Register(1)
    # wclearsave.next <<= wclear
    # lastsave = Register(1)
    # lastsave.next <<= lastvec

    # return data_out, waddrsave, wensave, wclearsave, lastsave


reset_working_block()
# Output(32), Output(3), Output(1), Output(1), Output(1) =
accum(
    3,
    Input(16, "data_input"),
    Input(3, "write_address"),
    Input(1, "write_enable"),
    Input(1, "clear"),
    Input(3, "read_address"),
    Input(1, "last vector"),
)

# Single Accumulator Block

In [None]:
class AccumulatorBlock:
    def __init__(
        self,
        data_type: Type[BaseFloat],
        addr_width: int,
        adder: Callable[[WireVector, WireVector, Type[BaseFloat]], WireVector],
    ):
        """Single accumulator block with memory and floating point add capability
        
        Args:
            data_type: Number format for accumulation values
            addr_width: Number of address bits for memory
            adder: Floating point adder implementation
        """
        self.data_type = data_type
        self.data_width = data_type.bitwidth()
        
        # Input wires
        self.write_data = WireVector(self.data_width)
        self.write_addr = WireVector(addr_width)
        self.write_enable = WireVector(1)
        self.write_clear = WireVector(1)  # 1 for overwrite, 0 for accumulate
        self.read_addr = WireVector(addr_width)
        
        # Memory block
        self.memory = MemBlock(
            bitwidth=self.data_width,
            addrwidth=addr_width,
        )
        
        # Read data is direct from memory
        self.read_data = self.memory[self.read_addr]
        
        # Write logic with floating point add
        with conditional_assignment:
            with self.write_enable:
                with self.write_clear:
                    # Overwrite mode
                    self.memory[self.write_addr] |= self.write_data
                with ~self.write_clear:
                    # Accumulate mode - use floating point add
                    current = self.memory[self.write_addr]
                    sum_result = adder(self.write_data, current, self.data_type)e
                    self.memory[self.write_addr] |= sum_result

    def connect_write_data(self, source: WireVector):
        """Connect data input"""
        self.write_data <<= source
        
    def connect_write_addr(self, addr: WireVector):
        """Connect write address input"""
        self.write_addr <<= addr
        
    def connect_write_enable(self, enable: WireVector):
        """Connect write enable signal"""
        self.write_enable <<= enable
        
    def connect_write_clear(self, clear: WireVector):
        """Connect write clear signal (overwrite vs accumulate)"""
        self.write_clear <<= clear
        
    def connect_read_addr(self, addr: WireVector):
        """Connect read address input"""
        self.read_addr <<= addr

In [16]:
@dataclass
class AccumulatorPorts:
    """Container for accumulator buffer I/O ports"""

    write_addr: WireVector
    write_data: List[WireVector]  # List of data inputs, one per column
    write_enable: WireVector
    write_clear: WireVector  # 1 to overwrite, 0 to accumulate
    read_addr: WireVector
    read_data: List[WireVector]  # List of data outputs, one per column


class AccumulatorBuffer:
    def __init__(
        self,
        data_type: Type[BaseFloat],
        addr_width: int,
        num_columns: int,
        adder: Callable[[WireVector, WireVector, Type[BaseFloat]], WireVector],
    ):
        """Initialize accumulator buffer hardware

        Args:
            data_type: Number format for accumulation
            addr_width: Number of address bits
            num_columns: Number of parallel accumulator columns
            adder: Floating point adder implementation
        """
        self.data_type = data_type
        self.data_width = data_type.bitwidth()
        self.addr_width = addr_width
        self.num_columns = num_columns
        self.adder = adder

        # Create memory blocks - one per column
        self.memories = [
            MemBlock(
                bitwidth=self.data_width,
                addrwidth=self.addr_width,
                name=f"accum_mem_{i}",
            )
            for i in range(num_columns)
        ]

        # Create I/O ports
        self.write_addr = WireVector(addr_width)
        self.write_data = [WireVector(self.data_width) for _ in range(num_columns)]
        self.write_enable = WireVector(1)
        self.write_clear = WireVector(1)
        self.read_addr = WireVector(addr_width)
        self.read_data = [WireVector(self.data_width) for _ in range(num_columns)]

        # Connect memory logic
        for i in range(num_columns):
            with conditional_assignment:
                with self.write_enable:
                    with self.write_clear:
                        # Overwrite mode
                        self.memories[i][self.write_addr] |= self.write_data[i]
                    with pyrtl.otherwise:
                        # Accumulate mode - use floating point add
                        self.memories[i][self.write_addr] |= self.adder(
                            self.write_data[i],
                            self.memories[i][self.write_addr],
                            self.data_type,
                        )

            # Connect read port
            self.read_data[i] <<= self.memories[i][self.read_addr]

        # Package ports for external access
        self.ports = AccumulatorPorts(
            write_addr=self.write_addr,
            write_data=self.write_data,
            write_enable=self.write_enable,
            write_clear=self.write_clear,
            read_addr=self.read_addr,
            read_data=self.read_data,
        )

In [None]:
def test_accumulator():
    """Test accumulator buffer with sample data"""
    reset_working_block()

    # Test parameters
    SIZE = 4
    ADDR_BITS = 3  # 8 addresses
    dtype = BF16  # Use BFloat16 format

    # Create test data matrix
    test_data = np.array(
        [
            [1.5, 2.0, -1.0, 0.5],
            [0.25, -2.0, 3.0, 1.0],
            [-1.5, 0.5, 2.0, -0.5],
            [3.0, -1.0, 0.25, 2.0],
        ]
    )

    # Convert to binary format
    binary_data = convert_array_dtype(test_data, dtype)

    # Create input/output wires
    w_addr = Input(ADDR_BITS, "write_addr")
    w_en = Input(1, "write_enable")
    w_clear = Input(1, "write_clear")
    r_addr = Input(ADDR_BITS, "read_addr")

    w_data = [Input(dtype.bitwidth(), f"write_data_{i}") for i in range(SIZE)]
    r_data = [Output(dtype.bitwidth(), f"read_data_{i}") for i in range(SIZE)]

    # Create accumulator
    accum = AccumulatorBuffer(
        data_type=dtype, addr_width=ADDR_BITS, num_columns=SIZE, adder=float_adder
    )

    # Connect ports
    accum.connect_write_addr(w_addr)
    accum.connect_write_enable(w_en)
    accum.connect_write_clear(w_clear)
    accum.connect_read_addr(r_addr)

    for i in range(SIZE):
        accum.connect_write_data(i, w_data[i])
        accum.connect_read_data(i, r_data[i])

    # Create simulation
    sim = Simulation()

    # Helper function to create input dictionary
    def make_inputs(addr, data_row=None, enable=0, clear=0, read_addr=0):
        inputs = {
            "write_addr": addr,
            "write_enable": enable,
            "write_clear": clear,
            "read_addr": read_addr,
        }
        if data_row is not None:
            for i in range(SIZE):
                inputs[f"write_data_{i}"] = int(data_row[i])
        return inputs

    # Simulation steps

    # Step 1: Write first row with clear
    sim.step(make_inputs(0, binary_data[0], enable=1, clear=1))

    # Step 2: Write second row with clear
    sim.step(make_inputs(1, binary_data[1], enable=1, clear=1))

    # Step 3: Read first row
    sim.step(make_inputs(0, read_addr=0))

    # Step 4: Read second row
    sim.step(make_inputs(0, read_addr=1))

    # Helper function to convert output to float
    def get_output_row(sim_trace):
        return [
            dtype(binint=sim_trace[f"read_data_{i}"]).decimal_approx
            for i in range(SIZE)
        ]

    # Print results
    print("First row read:")
    print(get_output_row(sim))
    print("\nSecond row read:")
    print(get_output_row(sim))

    return sim


# Run test
if __name__ == "__main__":
    sim = test_accumulator()

PyrtlError: no nesting of conditional assignments allowed