In [None]:
import os
import pyrtl
from pyrtl import *
from typing import Callable, Type
from hardware_accelerators.dtypes import *
from hardware_accelerators.rtllib import *
from hardware_accelerators.rtllib.utils.adder_utils import *
from hardware_accelerators.rtllib.utils.multiplier_utils import *
from hardware_accelerators.rtllib.utils.common import *

# Generating Hardware Blocks for Analysis


In [11]:
# Float formats covered by this notebook, in the order of the widths used as
# keys in dtype_map below.
dtype_list = [Float8, BF16, Float32]

# Look a format up by the integer width key (8, 16 or 32).
dtype_map = dict(zip((8, 16, 32), dtype_list))

# Every (w, a) width pair with w <= a, enumerated with w as the outer loop.
# Presumably w/a stand for weight/activation widths — confirm against the users
# of w_a_dtypes.
w_a_pairs = [(w, a) for w in dtype_map for a in dtype_map if w <= a]

# The same pairs, expressed as (dtype, dtype) tuples.
w_a_dtypes = [tuple(dtype_map[bits] for bits in pair) for pair in w_a_pairs]

In [12]:
def create_basic_hardware_block(
    fn: Callable[[WireVector, WireVector, Type[BaseFloat]], WireVector],
    dtype: Type[BaseFloat],
    **kwargs,
):
    """Wrap a two-operand hardware function with named I/O in the working block.

    Creates inputs named "a" and "b" and an output named "out", all
    ``dtype.bitwidth()`` bits wide, and drives the output with
    ``fn(a, b, dtype, **kwargs)``.

    Args:
        fn: Builder called with the two input wires and ``dtype``.
        dtype: Float format; only ``bitwidth()`` is read here.
        **kwargs: Forwarded unchanged to ``fn``.
    """
    width = dtype.bitwidth()
    operand_a = pyrtl.Input(width, "a")
    operand_b = pyrtl.Input(width, "b")
    result_port = pyrtl.Output(width, "out")
    result_port <<= fn(operand_a, operand_b, dtype, **kwargs)

In [53]:
def analyze(block: Block | None = None) -> None:
    """Synthesize and optimize a design, then print delay, frequency and area.

    The whole analysis runs inside a ``set_working_block`` context, so the
    working block that was active before the call is active again afterwards.
    Without that scoping the synthesized netlist would stay installed as the
    global working block and leak into whatever cell runs next.

    Args:
        block: Design to analyze. ``None`` analyzes the current working block.
    """
    target = pyrtl.working_block() if block is None else block
    with pyrtl.set_working_block(target):
        # synthesize() swaps in a new working block and optimize() rewrites it,
        # so every estimate has to be taken before the context exits.
        pyrtl.synthesize()
        pyrtl.optimize()
        timing = pyrtl.TimingAnalysis()
        delay = timing.max_length()
        print(f"\nest. max delay: {delay:.2f} ps")
        print(f"est. max freq: {timing.max_freq():.2f} MHz")
        print(f"est. area: {pyrtl.area_estimation()}\n\n")

## Adders


In [108]:
def create_inputs(*bitwidths):
    """Create one unnamed ``pyrtl.Input`` per requested bit width.

    The inputs are built eagerly and returned as a tuple. A lazy generator
    would postpone each ``pyrtl.Input(...)`` call until the result is iterated,
    which may happen after the working block has changed, and would yield
    nothing if iterated a second time.

    Args:
        *bitwidths: Bit width of each input, in order.

    Returns:
        Tuple of the created inputs, in the same order as ``bitwidths``
        (empty when no widths are given).
    """
    return tuple(pyrtl.Input(bitwidth) for bitwidth in bitwidths)


def create_outputs(*args):
    """Attach an unnamed ``pyrtl.Output`` to each wire passed in.

    Each output is created with width ``len(wire)`` and is driven by that wire.

    Args:
        *args: Wires to expose as outputs.
    """
    for source in args:
        port = pyrtl.Output(len(source))
        port <<= source


def create_pipelined_adder(dtype: Type[BaseFloat], fast: bool = False):
    """Instantiate ``FloatAdderPipelined`` with named I/O in the working block.

    Creates inputs "a" and "b" (``dtype.bitwidth()`` bits each), a 1-bit input
    "w_en", and an output "out" driven by the adder's ``result``.

    Args:
        dtype: Float format of the operands and result.
        fast: Forwarded to ``FloatAdderPipelined``.
    """
    width = dtype.bitwidth()
    operand_a = pyrtl.Input(width, "a")
    operand_b = pyrtl.Input(width, "b")
    w_en = pyrtl.Input(1, "w_en")
    result_port = pyrtl.Output(width, "out")
    pipelined = FloatAdderPipelined(operand_a, operand_b, w_en, dtype, fast)
    result_port <<= pipelined.result


def create_adder_blocks(dtype: Type[BaseFloat], fast: bool = False) -> dict[str, Block]:
    """Build the float adder, and each of its stages, in separate PyRTL blocks.

    Every design is constructed inside its own ``pyrtl.Block`` so it can be
    analyzed on its own. The stage designs are fed by fresh unnamed inputs whose
    widths are derived from ``dtype``'s exponent and mantissa widths.

    Args:
        dtype: Float format; ``bitwidth()``, ``exponent_bits()`` and
            ``mantissa_bits()`` are read from it.
        fast: Forwarded to the adder helpers that accept it.

    Returns:
        Mapping from design name ("combinational", "adder", "stage_2" ...
        "stage_5", in that order) to the block that holds it.
    """
    total_bits = dtype.bitwidth()
    e_bits = dtype.exponent_bits()
    m_bits = dtype.mantissa_bits()

    # One empty block per design; insertion order is the order callers see.
    blocks = {
        label: pyrtl.Block()
        for label in (
            "combinational",
            "adder",
            "stage_2",
            "stage_3",
            "stage_4",
            "stage_5",
        )
    }

    # Combinational design
    with set_working_block(blocks["combinational"]):
        operand_a, operand_b = create_inputs(total_bits, total_bits)
        create_outputs(*float_adder(operand_a, operand_b, dtype=dtype, fast=fast))

    # Complete pipelined design
    with set_working_block(blocks["adder"]):
        create_pipelined_adder(dtype, fast)

    # Stages 1 & 2: component extraction feeding adder_stage_2
    with set_working_block(blocks["stage_2"]):
        operand_a, operand_b = create_inputs(total_bits, total_bits)
        components = extract_float_components(
            operand_a, operand_b, e_bits=e_bits, m_bits=m_bits
        )
        create_outputs(*adder_stage_2(*components, e_bits, m_bits, fast))

    # Stage 3: alignment and SGR bit generation
    with set_working_block(blocks["stage_3"]):
        stage_3_inputs = create_inputs(m_bits + 1, e_bits)
        create_outputs(*adder_stage_3(*stage_3_inputs, e_bits=e_bits, m_bits=m_bits))

    # Stage 4: mantissa addition and leading zero detection
    with set_working_block(blocks["stage_4"]):
        stage_4_inputs = create_inputs(m_bits + 1, m_bits + 1, 1)
        create_outputs(*adder_stage_4(*stage_4_inputs, m_bits=m_bits, fast=fast))

    # Stage 5: normalization, rounding, and final assembly
    with set_working_block(blocks["stage_5"]):
        stage_5_widths = (
            m_bits + 2,  # abs_mantissa
            1,  # sticky_bit
            1,  # guard_bit
            1,  # round_bit
            # NOTE(review): lzc is fixed at 4 bits regardless of dtype — confirm
            # this matches the lzc width adder_stage_4 produces for every format.
            4,  # lzc
            e_bits,  # exp_larger
            1,  # sign_a
            1,  # sign_b
            e_bits + 1,  # exp_diff
            1,  # is_neg
        )
        create_outputs(
            *adder_stage_5(
                *create_inputs(*stage_5_widths), e_bits=e_bits, m_bits=m_bits
            )
        )

    return blocks

In [109]:
# Timing/area estimate for add_sub_mantissas alone, using BF16 mantissa widths.
# temp_working_block() is used as a context manager so the previous working
# block is restored when the cell finishes; calling it bare would leave the
# throwaway design installed as the global working block for later cells.
with temp_working_block():
    m_bits = BF16.mantissa_bits()

    create_outputs(
        *add_sub_mantissas(
            *create_inputs(m_bits + 1, m_bits + 1, 1), m_bits=m_bits, fast=True
        )
    )
    analyze()


est. max delay: 1944.98 ps
est. max freq: 429.56 MHz
est. area: (0.001515888, 0)




In [None]:
# Timing/area estimate for adder_stage_4 alone, using BF16 mantissa widths.
# temp_working_block() is used as a context manager so the previous working
# block is restored when the cell finishes; calling it bare would leave the
# throwaway design installed as the global working block for later cells.
with temp_working_block():
    m_bits = BF16.mantissa_bits()

    create_outputs(
        *adder_stage_4(
            *create_inputs(m_bits + 1, m_bits + 1, 1), m_bits=m_bits, fast=True
        )
    )
    analyze()


est. max delay: 3862.15 ps
est. max freq: 235.56 MHz
est. area: (0.002212848, 0)




In [110]:
# Build every BF16 adder design with fast=True, then print the delay, frequency
# and area estimates for each block in the order the dict returns them.
adder_blocks = create_adder_blocks(BF16, fast=True)

for name, block in adder_blocks.items():
    print(f"Analyzing {name} block:")
    # analyze() makes `block` the working block before synthesizing it.
    analyze(block)

Analyzing combinational block:

est. max delay: 10181.66 ps
est. max freq: 94.66 MHz
est. area: (0.007849512, 0)


Analyzing adder block:

est. max delay: 3862.15 ps
est. max freq: 235.56 MHz
est. area: (0.0127561104, 0)


Analyzing stage_2 block:

est. max delay: 1619.71 ps
est. max freq: 499.32 MHz
est. area: (0.001942776, 0)


Analyzing stage_3 block:

est. max delay: 3090.40 ps
est. max freq: 287.90 MHz
est. area: (0.001724976, 0)


Analyzing stage_4 block:

est. max delay: 3862.15 ps
est. max freq: 235.56 MHz
est. area: (0.002212848, 0)


Analyzing stage_5 block:

est. max delay: 2499.24 ps
est. max freq: 346.95 MHz
est. area: (0.001986336, 0)




In [107]:
# Start from a fresh working block, build the pipelined BF16 adder with
# fast=True in it, and analyze the working block (no block argument is passed).
reset_working_block()

create_pipelined_adder(BF16, True)
analyze()


est. max delay: 3862.15 ps
est. max freq: 235.56 MHz
est. area: (0.0127561104, 0)




## Multipliers


In [113]:
def create_multiplier_blocks(
    dtype: Type[BaseFloat], fast: bool = False
) -> dict[str, Block]:
    """Build the float multiplier, and each of its stages, in separate PyRTL blocks.

    Every design is constructed inside its own ``pyrtl.Block`` so it can be
    analyzed on its own. The stage designs are fed by fresh unnamed inputs whose
    widths are derived from ``dtype``'s exponent and mantissa widths.

    Args:
        dtype: Float format; ``bitwidth()``, ``exponent_bits()`` and
            ``mantissa_bits()`` are read from it.
        fast: Forwarded to the multiplier helpers.

    Returns:
        Mapping from design name ("combinational", "multiplier", "stage_2",
        "stage_3", "stage_4", in that order) to the block that holds it.
    """
    total_bits = dtype.bitwidth()
    e_bits = dtype.exponent_bits()
    m_bits = dtype.mantissa_bits()
    product_bits = 2 * m_bits + 2  # width used for the mantissa_product inputs

    # One empty block per design; insertion order is the order callers see.
    blocks = {
        label: pyrtl.Block()
        for label in ("combinational", "multiplier", "stage_2", "stage_3", "stage_4")
    }

    # Combinational design
    with set_working_block(blocks["combinational"]):
        operand_a, operand_b = create_inputs(total_bits, total_bits)
        # The float_multiplier result is passed on as one wire (not unpacked).
        create_outputs(
            float_multiplier(operand_a, operand_b, dtype=dtype, fast=fast)
        )

    # Complete pipelined design
    with set_working_block(blocks["multiplier"]):
        operand_a, operand_b = create_inputs(total_bits, total_bits)
        pipelined = FloatMultiplierPipelined(
            operand_a, operand_b, dtype=dtype, fast=fast
        )
        # Reads the private `_result` attribute, as the original code does.
        create_outputs(pipelined._result)

    # Stage 1 & 2: Extract components and calculate sign, exponent sum, mantissa product
    with set_working_block(blocks["stage_2"]):
        operand_a, operand_b = create_inputs(total_bits, total_bits)
        components = extract_float_components(
            operand_a, operand_b, e_bits=e_bits, m_bits=m_bits
        )
        create_outputs(*multiplier_stage_2(*components, m_bits, fast))

    # Stage 3: Leading zero detection and exponent adjustment
    with set_working_block(blocks["stage_3"]):
        stage_3_inputs = create_inputs(
            e_bits + 1,  # exp_sum
            product_bits,  # mantissa_product
        )
        create_outputs(
            *multiplier_stage_3(
                *stage_3_inputs, e_bits=e_bits, m_bits=m_bits, fast=fast
            )
        )

    # Stage 4: Normalization, rounding, and final assembly
    with set_working_block(blocks["stage_4"]):
        stage_4_inputs = create_inputs(
            e_bits,  # unbiased_exp
            e_bits,  # leading_zeros
            product_bits,  # mantissa_product
        )
        create_outputs(
            *multiplier_stage_4(
                *stage_4_inputs, m_bits=m_bits, e_bits=e_bits, fast=fast
            )
        )

    return blocks

In [117]:
# Build every Float8 multiplier design with fast=True, then print the delay,
# frequency and area estimates for each block in the order the dict returns them.
multiplier_blocks = create_multiplier_blocks(Float8, fast=True)

for name, block in multiplier_blocks.items():
    print(f"Analyzing {name} block:")
    # analyze() makes `block` the working block before synthesizing it.
    analyze(block)

Analyzing combinational block:

est. max delay: 4906.51 ps
est. max freq: 189.05 MHz
est. area: (0.003023064, 0)


Analyzing multiplier block:

est. max delay: 1828.47 ps
est. max freq: 452.19 MHz
est. area: (0.0038001744, 0)


Analyzing stage_2 block:

est. max delay: 1394.44 ps
est. max freq: 562.61 MHz
est. area: (0.001062864, 0)


Analyzing stage_3 block:

est. max delay: 1585.10 ps
est. max freq: 508.10 MHz
est. area: (0.000670824, 0)


Analyzing stage_4 block:

est. max delay: 1828.47 ps
est. max freq: 452.19 MHz
est. area: (0.001210968, 0)


