In [1]:
from tutorial_utils import magma_to_verilog_string, smt_to_smtlib_string

import hwtypes as hw
import magma as m

import peak
from peak import Peak  # the base class of Peak circuits
from peak import family_closure



In the previous section, we introduced hwtypes and showed how we can use hwtypes to meta-program higher-order functions in Python. In this section, we will introduce PEak, which extends the expression language of hwtypes.

Similar to hwtypes, the high-level goal of PEak is to create a single source of truth which functions as a functional model, a formal specification, and RTL. PEak differs from hwtypes by giving designers a more natural language for describing circuits, specifically as classes. PEak attempts to mimic the behavior of normal Python. To this end, PEak circuits declare subcomponents in their `__init__` method and define their behavior in their `__call__` method. Unlike hwtypes, PEak allows the use of `if` statements to build `ite`s. Further, the PEak compiler reinterprets assignments to registers as register writes and references to registers as register reads.

It should be noted that the formal model generated by PEak cannot be used to verify the RTL generated by PEak. Showing equivalence between the formal model and the RTL would simply show that the PEak compiler's back-ends for SMT and RTL are consistent (of course, such a check could be useful for finding bugs in the PEak compiler, but not for verifying the RTL).

In the following example, we demonstrate the type of program we aim to be able to write. First we define an `ALU` class which performs either an add or a multiply on two data inputs (`in_0`, `in_1`) and is controlled by a single bit `op`. Next, we define a `PE` class which contains 3 `ALU`s. The `PE` has 4 data inputs (`in_0`, ..., `in_3`) and a 3-bit control signal (`ops`), which controls the `ALU`s.

In [2]:
BV = hw.BitVector
DataT = BV[8]
Bit = hw.Bit

class ALU:
    def __call__(self, op: Bit, in_0: DataT, in_1: DataT) -> DataT:
        if op:
            return in_0 + in_1
        else:
            return in_0 * in_1
        
class PE:
    def __init__(self):
        self.alu_0 = ALU()
        self.alu_1 = ALU()
        self.alu_2 = ALU()
    
    def __call__(self, 
                 ops: BV[3],
                 in_0: DataT,
                 in_1: DataT,
                 in_2: DataT,
                 in_3: DataT,
                ) -> DataT:
        res_0 = self.alu_0(ops[0], in_0, in_1)
        res_1 = self.alu_1(ops[1], in_2, in_3)
        return self.alu_2(ops[2], res_0, res_1)

pe = PE()
s =  pe(hw.BitVector[3](0b101), DataT(1), DataT(2), DataT(3), DataT(4))
assert s == (1+2)+(3*4)
print(repr(s))

BitVector[8](15)


The above will work with Python values (as shown); however, as is, this code cannot generate SMT or RTL. This is because it is still fundamentally a hwtypes program and is hence subject to the restrictions in the previous section, i.e., an `if` statement can only be evaluated on constant values, and RTL requires a Magma wrapper circuit. PEak removes these restrictions by compiling `if` statements into `ite`s and automatically generating a wrapper circuit. To invoke the PEak compiler, some boilerplate must be added. Below, we show how to extend the above example using PEak (code points of interest are labeled with comments `# k`).

In [3]:
@family_closure(peak.family) # 1
def closure(family): # 2
    BV = family.BitVector #
    DataT = BV[8]         # 3
    Bit = family.Bit      #
    
    @family.compile(locals(), globals()) # 4
    class ALU(Peak): # 5
        def __call__(self, op: Bit, in_0: DataT, in_1: DataT) -> DataT: # 6
            if op:
                return in_0 + in_1
            else:
                return in_0 * in_1
            
    @family.compile(locals(), globals()) # 4
    class PE(Peak): # 5
        def __init__(self):
            self.alu_0 = ALU()
            self.alu_1 = ALU()
            self.alu_2 = ALU()

        def __call__(self, 
                     ops: BV[3],  #
                     in_0: DataT, #
                     in_1: DataT, #
                     in_2: DataT, # 6
                     in_3: DataT, #
                    ) -> DataT:   # 
            res_0 = self.alu_0(ops[0], in_0, in_1)
            res_1 = self.alu_1(ops[1], in_2, in_3)
            return self.alu_2(ops[2], res_0, res_1)
    
    return PE

For the moment, let us ignore the `family_closure` decorator (`# 1`).  We explain it last, as understanding its behavior and utility is difficult without first understanding the rest of the code. 

The first piece of boilerplate is the construction of a closure (`# 2`) over the various interpretations (e.g., Python, SMT, Magma). This closure takes a single argument, which is a *family* object. A family object provides an implementation of the core `Bit` and `BitVector` types. It must also provide an implementation of a register type and Abstract Data Types (ADTs). We explain the latter two later. Additionally, each family defines a specific compilation flow, as we describe below.

Recall that in Section 1, we defined `bounded_factorial` as follows: 
```Python
PyDataT = hw.BitVector[8]
SmtDataT = hw.SMTBitVector[8]
MagmaDataT = m.Bits[8]

def bounded_factorial(x):
    if not isinstance(x, hw.AbstractBitVector):
        raise TypeError()
    T = type(x)
    ...
```
We needed to dynamically get the type of `x` to allow us to construct constants with the proper type, e.g., `T(1)`. In a PEak program, we access these type constructors through the family object and hence avoid the necessity of such dynamic inspection. In PEak we would write:

```Python
@family_closure(peak.family)
def closure(family):
    T = family.Unsigned[8]
    MAX_UINT = 2**T.size - 1
    
    def inner(x, ctr):
        if ctr == 0:
            return T(1)
        else:
            return (x <= 1).ite(
                T(1),
                x * inner(x - 1, ctr - 1),
            )
        
    def bounded_factorial(x):
        return inner(x, MAX_UINT)
    
    return bounded_factorial
```

This is not a large win in terms of code size, but it is significantly more natural code to write.  Further, one may not have direct access to a value of the desired type, e.g., when casting between signed and unsigned. 

`family.compile` (`# 4`) invokes the PEak compiler, passing the current namespace to the compiler with `locals(), globals()`. Each family defines its own compilation flow. Specialized compilation flows allow the SMT family to rewrite `if` statements to `ite`s, while the Magma family rewrites the `if`s in the same way and additionally wraps the resulting hwtypes program in a circuit. While the full details of the rewrites used by the SMT and Magma implementations are quite complex, they are fairly straightforward for simple examples. For example, the body of the `ALU`'s `__call__` method would be rewritten to a hwtypes program semantically equivalent to the following:
```Python
_cond_0 = op
_return_0 = in_0 + in_1  # value returned when `op` is true
_return_1 = in_0 * in_1  # value returned when `op` is false
return _cond_0.ite(_return_0, _return_1)
```
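To make the rewrite concrete, the following plain-Python sketch compares an `if`-based body with its `ite` form on ordinary integers. The `ite` helper, the 8-bit mask, and the function names are stand-ins introduced here for illustration; they are not part of hwtypes or PEak.

```python
MASK = 0xFF  # model 8-bit wraparound arithmetic with plain ints


def ite(cond: bool, then_val: int, else_val: int) -> int:
    """Stand-in for `Bit.ite`: both branches are already evaluated."""
    return then_val if cond else else_val


def alu_if(op: bool, in_0: int, in_1: int) -> int:
    # Original control flow; only usable when `op` is a concrete value.
    if op:
        return (in_0 + in_1) & MASK
    else:
        return (in_0 * in_1) & MASK


def alu_ite(op: bool, in_0: int, in_1: int) -> int:
    # Rewritten form: compute every return value, then select one.
    _cond_0 = op
    _return_0 = (in_0 + in_1) & MASK
    _return_1 = (in_0 * in_1) & MASK
    return ite(_cond_0, _return_0, _return_1)


# Check that the two forms agree on every 8-bit input pair.
agree = all(
    alu_if(op, a, b) == alu_ite(op, a, b)
    for op in (False, True)
    for a in range(256)
    for b in range(256)
)
print(agree)
```

Because the `ite` form never branches on `op`, it can also be evaluated when `op` is a symbolic value, which is what makes it suitable for SMT and RTL generation.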

PEak circuits should inherit from the `Peak` class (`# 5`).  This enables the *PEak protocol*.  The PEak protocol allows types to define how they behave when being read or written.  We will demonstrate a use of this later when we introduce registers.

It is important to note that the type annotations on the `__call__` method (`# 6`) are *not* optional. The PEak compiler uses the annotations to generate ports in a Magma context.

The notion of family is overloaded in PEak, with the base families for the python, Magma, and SMT implementations being defined in the module `peak.family`.  This module also defines a *family group*; a family group is an object (typically a module) with attributes `PyFamily`, `SMTFamily`, and `MagmaFamily`.  Each of these attributes defines the type of a family within the family group.  The purpose of a family group is to allow uniform access to types whose implementation may differ between interpretations.  While the default family group simply provides type constructors for the primitive types, a family group may provide implementations of complex modules such as memories or floating point units.

The `family_closure` decorator (`# 1`) takes a family group as a parameter.  This parameter associates the decorated closure with a specific family group. This association allows the family closure to provide a shortcut using convenient syntax for calling the closure:

```Python
closure.Py == closure(family_group.PyFamily())
closure.SMT == closure(family_group.SMTFamily())
closure.Magma == closure(family_group.MagmaFamily())
```

This may seem inconsequential, but it is very convenient for programmatic manipulation of PEak programs, as one can invoke the desired interpretation without knowledge of the specific families used.
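The shortcut can be sketched in plain Python. Everything below (`FamilyGroupSketch`, `family_closure_sketch`, and the string-returning `BitVector` methods) is a hypothetical stand-in written for this illustration, not PEak's actual implementation; only the attribute names `PyFamily`, `SMTFamily`, `Py`, `SMT`, and `_family_` are taken from the text.

```python
class PyFamily:
    """Hypothetical family: supplies type constructors for one interpretation."""

    def BitVector(self, width: int) -> str:
        return f"PyBitVector[{width}]"


class SMTFamily:
    """Hypothetical family for a second interpretation."""

    def BitVector(self, width: int) -> str:
        return f"SMTBitVector[{width}]"


class FamilyGroupSketch:
    """A family group exposes one family type per interpretation."""
    PyFamily = PyFamily
    SMTFamily = SMTFamily


class family_closure_sketch:
    """Associates a closure with a family group and adds shortcuts."""

    def __init__(self, family_group, fn):
        self._family_ = family_group
        self._fn = fn

    def __call__(self, family):
        return self._fn(family)

    @property
    def Py(self):
        # Same as closure(family_group.PyFamily())
        return self(self._family_.PyFamily())

    @property
    def SMT(self):
        # Same as closure(family_group.SMTFamily())
        return self(self._family_.SMTFamily())


def data_t(family):
    # The closure body only talks to the family, never to a concrete type.
    return family.BitVector(8)


closure = family_closure_sketch(FamilyGroupSketch, data_t)
print(closure.Py, closure.SMT)
```

A caller holding only `closure` can pick an interpretation through `closure.Py` or `closure.SMT` without importing any family class.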
  
As a convenience when using the base family group (`peak.family`), one may omit the family group parameter, e.g., the above example could have been written as:

```Python
@family_closure # 1
def closure(family): # 2
    ...
```

Prior to the development of families, developers were forced to manually wrap their code in boilerplate to achieve the same behavior in a much less ergonomic way. Below is a small example of what that looked like.

```Python
COMPILE_TARGET = 'magma'

...

if COMPILE_TARGET == 'python':
    BV = hw.BitVector
elif COMPILE_TARGET == 'magma':
    BV = m.Bits
else:
    assert COMPILE_TARGET == 'smt'
    BV = hw.SMTBitVector

class PE:
    def __call__(self):
        ...  # code that uses BV

if COMPILE_TARGET == 'magma':
    PE = magma_compile(PE) # invoke magma compilation flow
elif COMPILE_TARGET == 'smt':
    PE = smt_compile(PE) # invoke smt compilation flow
```

In the above, the first `if` block performs the functional equivalent of `family.BitVector`.   The second `if` block performs `family.compile`.

We will return to the ALU example to demonstrate how encapsulation can be used to extend an existing module:

In [4]:
@family_closure
def data_t_closure(family):
    BV = family.BitVector
    DataT = BV[8]        
    Bit = family.Bit
    return BV, DataT, Bit

@family_closure
def closure(family):
    BV, DataT, Bit = data_t_closure(family)
    
    @family.compile(locals(), globals())
    class ALU(Peak):
        def __call__(self, op: Bit, in_0: DataT, in_1: DataT) -> DataT:
            if op:
                return in_0 + in_1
            else:
                return in_0 * in_1
            
    return ALU

@family_closure
def ext_closure(family):
    BV, DataT, Bit = data_t_closure(family)
    ALU = closure(family)
    
    @family.compile(locals(), globals())
    class ExtALU(Peak):
        def __init__(self):
            self.alu = ALU()
            
        def __call__(self, op: BV[2], in_0: DataT, in_1: DataT) -> DataT:
            if op[0]:
                in_1 = -in_1
            return self.alu(op[1], in_0, in_1)  # call the ALU instance declared in __init__
    return ExtALU



As we mentioned previously, PEak supports ADTs. In the following example we show how enums can be used to build instruction sets in place of raw bitvectors. We believe enums provide a more robust mechanism for defining instruction sets.

In [5]:
class ISA(hw.Enum):
    Add = hw.new_instruction()
    Sub = hw.new_instruction()
    And = hw.new_instruction()
    Or = hw.new_instruction()


@family_closure
def closure(family):
    BV = family.BitVector
    DataT = BV[8]
    Bit = family.Bit

    @family.assemble(locals(), globals())
    class ALU(Peak):
        def __call__(self, op: ISA, in_0: DataT, in_1: DataT) -> DataT:
            if op == ISA.Add:
                return in_0 + in_1
            elif op == ISA.Sub:
                return in_0 - in_1
            elif op == ISA.And:
                return in_0 & in_1
            else:
                return in_0 | in_1

    return ALU

We can compose enum sets by using a `Sum` type. `Sum` types will be explained in more detail in the next section.

In [6]:
class Arith(hw.Enum):
    Add = hw.new_instruction()
    Sub = hw.new_instruction()


class Bitwise(hw.Enum):
    And = hw.new_instruction()
    Or = hw.new_instruction()


ISA = hw.Sum[Arith, Bitwise]


@family_closure
def lu_fc(family):
    BV = family.BitVector
    DataT = BV[8]

    @family.assemble(locals(), globals())
    class LU(Peak):
        def __call__(self, op: Bitwise, in_0: DataT, in_1: DataT) -> DataT:
            if op == Bitwise.And:
                return in_0 & in_1
            else:
                return in_0 | in_1

    return LU


@family_closure
def au_fc(family):
    BV = family.BitVector
    DataT = BV[8]

    @family.assemble(locals(), globals())
    class AU(Peak):
        def __call__(self, op: Arith, in_0: DataT, in_1: DataT) -> DataT:
            if op == Arith.Add:
                return in_0 + in_1
            else:
                return in_0 - in_1

    return AU


@family_closure
def alu_fc(family):
    BV = family.BitVector
    DataT = BV[8]
    LU_t = lu_fc(family)
    AU_t = au_fc(family)

    @family.assemble(locals(), globals())
    class ALU(Peak):
        def __init__(self):
            self.au = AU_t()
            self.lu = LU_t()

        def __call__(self, op: ISA, in_0: DataT, in_1: DataT) -> DataT:
            if op[Arith].match:
                return self.au(op[Arith].value, in_0, in_1)
            else:
                return self.lu(op[Bitwise].value, in_0, in_1)

    return ALU
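The dispatch pattern used in the `ALU` above can be mimicked with ordinary Python objects. In this sketch `TaggedOp` and `Match` are hypothetical stand-ins for a `Sum` value and the result of indexing it; they only imitate the `.match` and `.value` attributes used above and say nothing about how hwtypes implements `Sum`.

```python
from enum import Enum


class Arith(Enum):
    Add = 0
    Sub = 1


class Bitwise(Enum):
    And = 0
    Or = 1


class Match:
    """Result of indexing a tagged value by one of its member types."""

    def __init__(self, match: bool, value):
        self.match = match
        self.value = value


class TaggedOp:
    """Hypothetical stand-in for a value of type Sum[Arith, Bitwise]."""

    def __init__(self, value):
        self._value = value

    def __getitem__(self, member_t):
        is_member = isinstance(self._value, member_t)
        return Match(is_member, self._value if is_member else None)


def alu(op: TaggedOp, in_0: int, in_1: int) -> int:
    # First decide which sub-unit handles the instruction, then dispatch.
    if op[Arith].match:
        arith_op = op[Arith].value
        res = in_0 + in_1 if arith_op is Arith.Add else in_0 - in_1
    else:
        bit_op = op[Bitwise].value
        res = in_0 & in_1 if bit_op is Bitwise.And else in_0 | in_1
    return res & 0xFF  # 8-bit result


print(alu(TaggedOp(Arith.Sub), 3, 5), alu(TaggedOp(Bitwise.Or), 0b1100, 0b1010))
```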

None of the examples demonstrated so far include state. The base families provide two distinct register primitives. The first uses the same call semantics as other PEak circuits.



In [7]:
@family_closure
def closure(family):
    DataT = family.BitVector[4]
    Bit = family.Bit
    DataRegister = family.gen_register(DataT, 0)
    
    @family.assemble(locals(), globals())
    class PipeLinedIncrementor(Peak):
        def __init__(self):
            self.data_reg = DataRegister()
            
        def __call__(self, stall: Bit, i: DataT) -> DataT:
            o = self.data_reg(i+1, ~stall) # enable the register if it is not stalled
            return o
            
    return PipeLinedIncrementor

pipe = closure.SMT()

SmtDataT = closure._family_.SMTFamily().BitVector[4]

data_0 = SmtDataT(name='data_0')
stall_0 = hw.SMTBit(name='stall_0')
data_out = pipe(stall_0, data_0)

print('cycle 0 output:')
print(smt_to_smtlib_string(data_out))

data_1 = SmtDataT(name='data_1')
stall_1 = hw.SMTBit(name='stall_1')
data_out = pipe(stall_1, data_1)

print('\ncycle 1 output:')
print(smt_to_smtlib_string(data_out))

data_2 = SmtDataT(name='data_2')
stall_2 = hw.SMTBit(name='stall_2')
data_out = pipe(stall_2, data_2)

print('\ncycle 2 output:')
print(smt_to_smtlib_string(data_out))
del data_0
del stall_0
del data_1
del stall_1
del data_2
del stall_2

cycle 0 output:
#b0000

cycle 1 output:
(ite (not stall_0) (bvadd data_0 #b0001) #b0000)

cycle 2 output:
(ite (not stall_1) (bvadd data_1 #b0001) (ite (not stall_0) (bvadd data_0 #b0001) #b0000))
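The cycle-by-cycle outputs above can be reproduced with concrete values. The sketch below is a plain-Python model that assumes the register semantics implied by those outputs: a call returns the value stored so far and stores the new value only when the enable is set. `EnableRegisterModel` and `PipelinedIncrementorModel` are names invented for this illustration.

```python
class EnableRegisterModel:
    """Plain-Python model of a call-style register (assumed semantics)."""

    def __init__(self, init: int):
        self.value = init

    def __call__(self, next_value: int, en: bool) -> int:
        out = self.value  # the output is the value stored so far
        if en:
            self.value = next_value  # the new value is visible on the next call
        return out


class PipelinedIncrementorModel:
    def __init__(self):
        self.data_reg = EnableRegisterModel(0)

    def __call__(self, stall: bool, i: int) -> int:
        # 4-bit data; the register is enabled when not stalled
        return self.data_reg((i + 1) & 0xF, not stall)


pipe_model = PipelinedIncrementorModel()
stimulus = [(False, 3), (True, 7), (False, 9)]  # (stall, data) per cycle
outputs = [pipe_model(stall, data) for stall, data in stimulus]
print(outputs)  # [0, 4, 4]
```

Cycle 0 returns the initial value, cycle 1 returns `data_0 + 1`, and cycle 2 returns `data_0 + 1` again because the stall in cycle 1 prevented `data_1 + 1` from being stored.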



The observant reader may notice that this syntax does not allow a register's next state to depend on its current state, as there is no way to probe the register's output without setting its input. PEak provides a second syntax for registers to address this. In this syntax, register reads and writes are performed implicitly by getting or setting the attribute. In this syntax registers do not have an enable and must be set on all paths.

In [8]:
@family_closure
def closure(family):
    DataT = family.BitVector[4]
    Bit = family.Bit
    DataRegister = family.gen_attr_register(DataT, 0)
    max_count = 4
    
    @family.assemble(locals(), globals())
    class Counter(Peak):
        def __init__(self):
            self.data_reg = DataRegister()
            
        def __call__(self, en: Bit) -> DataT:
            prev = self.data_reg
            if en:
                val = prev + 1
                if val < max_count:
                    self.data_reg = val
                else:
                    self.data_reg = DataT(0)
            else:
                self.data_reg = prev
            
            return prev + 1
            
    return Counter


In [9]:
Ctr = closure.SMT
ctr = Ctr()
init_value = hw.SMTBitVector[4](name='initial')
ctr.data_reg = init_value

for i in range(1):
    enable = hw.SMTBit(name=f'enable_{i}')
    out = ctr(enable)
    reg_state = ctr.data_reg
    
    print(f'\ncycle {i} output:')
    print(f'reg_state = {smt_to_smtlib_string(reg_state)}')
    print(f'out = {smt_to_smtlib_string(out)}')
    del enable
    
del init_value


cycle 0 output:
reg_state = (ite enable_0 (ite (bvult (bvadd initial #b0001) #b0100) (bvadd initial #b0001) #b0000) initial)
out = (bvadd initial #b0001)
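The symbolic expressions above can be checked against concrete values. The function below is a plain-integer model of one call of the `Counter`, written for this illustration; `counter_step` is a hypothetical name and the 4-bit mask stands in for `BitVector[4]` arithmetic.

```python
def counter_step(state: int, en: bool, max_count: int = 4) -> tuple[int, int]:
    """Return (next_state, output) for one call of the Counter above."""
    prev = state  # register read
    if en:
        val = (prev + 1) & 0xF  # 4-bit wraparound
        next_state = val if val < max_count else 0
    else:
        next_state = prev  # hold the current value
    return next_state, (prev + 1) & 0xF


state = 0
trace = []
for en in [True, True, True, True, False, True]:
    state, out = counter_step(state, en)
    trace.append((state, out))

print(trace)  # [(1, 1), (2, 2), (3, 3), (0, 4), (0, 1), (1, 1)]
```

Each pair is `(reg_state, out)` after one cycle: the register wraps to 0 once the incremented value reaches `max_count`, and holds its value when `en` is false.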


In [10]:
@family_closure
def closure(family):
    DataT = family.BitVector[4]
    Bit = family.Bit
    DataRegister = family.gen_attr_register(DataT, 0)
    max_count = 4
    
    @family.assemble(locals(), globals())
    class Counter(Peak):
        def __init__(self):
            self.data_reg = DataRegister()
            
        def __call__(self, en: Bit) -> DataT:
            prev = self.data_reg
            if en:
                val = prev + 1
                if val < max_count:
                    self.data_reg = val
                else:
                    self.data_reg = DataT(0)
            else:
                self.data_reg = prev
            
            return prev
            
    return Counter

Again, an observant reader may note that the first syntax can be constructed from the second. The first is provided for legacy reasons and to allow for better synthesis in Magma. The better synthesis results stem from the use of vendor-provided registers with enables instead of a mux and a register.

So far all the demonstrated examples have been static. PEak provides meta-programming in two forms: first, through the use of raw hwtypes, and second, through loop unrolling and `if` inlining. We will demonstrate both over a series of examples. Our goal in these examples is to generate an ALU and its ISA from a dictionary which names each operation, its behavior, and its operands. We will start with a hard-coded implementation which has the desired behavior, then show how to generate that behavior using loop unrolling.

In [11]:
import typing as tp
from hwtypes.adt import Enum, Tuple, Product

import ast_tools
from ast_tools.passes import apply_passes, loop_unroll, if_inline

In [12]:
def validate_arguments(
        op_map: dict[str, tuple[tp.Callable, tuple[int, ...]]],
        num_ports: int,
        datawidth: int) -> None:
    if not op_map:
        raise ValueError('op_map must not be empty')
    if datawidth < 1:
        raise ValueError('datawidth must be positive')
    if num_ports < 1:
        raise ValueError('num_ports must be positive')
        
    for k, (_, op_indices) in op_map.items():
        if type(k) != str:
            raise ValueError('op_map must have str keys')
        for idx in op_indices:
            if idx < 0:
                raise ValueError('op_indices must not be negative')
            elif idx >= num_ports:
                raise ValueError('op_indices must be less than num_ports')

op_map_t = dict[str, tuple[tp.Callable, tuple[int, ...]]]

def alugen(
        op_map: op_map_t, 
        num_ports: int, 
        datawidth: int) -> tuple[tp.Type[Enum], family_closure]:
    validate_arguments(op_map, num_ports, datawidth)
    class Opcode(Enum):
        Add = Enum.Auto()
        Sub = Enum.Auto()
        Neg = Enum.Auto()

    @family_closure
    def ALU_fc(family):
        Word = family.BitVector[16]
        Bit = family.Bit
        T = Tuple[Word, Word, Word]

        @family.assemble(locals(), globals())
        class ALU(Peak):
            def __call__(self, inst: Opcode, inputs: T) -> Word:
                if inst == Opcode.Add:
                    return inputs[0] + inputs[1] + inputs[2]
                elif inst == Opcode.Sub:
                    return inputs[0] - inputs[2]
                else:
                    return -inputs[1]

        return ALU

    return Opcode, ALU_fc


def add3(a, b, c):
    return a + b + c


def sub(a, b):
    return a - b


def neg(a):
    return -a


NUM_INPUTS = 3
DATA_WIDTH = 16

OP_MAP = {
    'Add': (add3, (0, 1, 2)),
    'Sub': (sub, (0, 2)),
    'Neg': (neg, (1, )),
}

Opcode, ALU_fc = alugen(OP_MAP, NUM_INPUTS, DATA_WIDTH)

assert OP_MAP.keys() == Opcode.field_dict.keys()

DataT = Tuple[hw.BitVector[16], hw.BitVector[16], hw.BitVector[16]]

ALU = ALU_fc.Py # ALU_fc(peak.family.PyFamily())

assert issubclass(ALU.input_t, Tuple[Opcode, DataT])

alu = ALU()

for _ in range(256):
    a = hw.BitVector.random(16)
    b = hw.BitVector.random(16)
    c = hw.BitVector.random(16)
    inputs = DataT(a, b, c)

    assert alu(Opcode.Add, inputs) == a + b + c
    assert alu(Opcode.Sub, inputs) == a - c
    assert alu(Opcode.Neg, inputs) == -b

In the above example, the ALU has 3 inputs and supports 3 operations: `Add`, `Sub`, and `Neg`. `Add` returns the sum of all inputs. `Sub` returns the difference of the first and third inputs. `Neg` negates the second input. However, both our implementation and our tests are hard-coded to specific generator parameters. We will first generalize the tests.

In [13]:
import traceback
import sys

import warnings
def test_alugen(alugen_f, op_map, num_ports, datawidth):
    Opcode, ALU_fc = alugen_f(op_map, num_ports, datawidth)
    
    errors = 0
    # warn instead of assert so all errors are displayed
    def assert_equal(a, b, opcode=None):
        nonlocal errors
        if a != b:
            warnings.warn(f'\n\tTest Failed{"" if not opcode else f" for {opcode}"}:\n\t\t{repr(a)} == {repr(b)}', stacklevel=2)
            print(file=sys.stderr)
            errors += 1
        
    assert_equal(op_map.keys(), Opcode.field_dict.keys())
    
    Word = hw.BitVector[datawidth]
    DataT = Tuple[(Word for _ in range(num_ports))]
    
    ALU = ALU_fc.Py
    
    # PEak presents the input type as a Product of its inputs 
    ExpectedInputT = Product.from_fields('Input', {'inst': Opcode, 'inputs': DataT})
    assert_equal(ALU.input_t, ExpectedInputT)
    # PEak always wraps the return type in a Tuple.
    # We ensure that there is a single output of type Word.
    assert_equal(ALU.output_t, Tuple[Word])
    
    alu = ALU()
    
    for _ in range(256):
        inputs = DataT(*(hw.BitVector.random(datawidth) for _ in range(num_ports)))
        
        for opname, opcode in Opcode.field_dict.items():
            impl, op_indices = op_map[opname]
            assert_equal(alu(opcode, inputs), impl(*(inputs[idx] for idx in op_indices)), opname)
            
    return errors

In the above test, the first assert ensures that the generated ISA, `Opcode`, has the specified instructions. Next, we assert that the generated ALU has the expected type. Finally, we test the functionality of the generated ALU by showing that each instruction is equivalent to its specified implementation. Below we see that our hard-coded `alugen` passes the generalized test; however, when we change the bitwidth to 8 and add a `Zero` instruction to the `op_map`, it fails as expected.

In [14]:
if errors := test_alugen(alugen, OP_MAP, 3, 16):
    print(f'Test had {errors} errors, see below for details')
else:
    print('Test passed!')

new_op_map = {k : v for k,v in OP_MAP.items()}
new_op_map['Zero'] = (lambda : 0), ()

if errors := test_alugen(alugen, new_op_map, 3, 8):
    print(f'Test had {errors} errors, see below for details')
else:
    print('Test passed!')


Test passed!
Test had 3 errors, see below for details


	Test Failed:
		dict_keys(['Add', 'Sub', 'Neg', 'Zero']) == dict_keys(['Add', 'Sub', 'Neg'])
  assert_equal(op_map.keys(), Opcode.field_dict.keys())

	Test Failed:
		Input == Input
  assert_equal(ALU.input_t, ExpectedInputT)

	Test Failed:
		Tuple[BitVector[16]] == Tuple[BitVector[8]]
  assert_equal(ALU.output_t, Tuple[Word])



We will now demonstrate how to use metaprogramming to generate the ALU. We will start with metaprogramming the generation of the IO and the ISA, as it is fairly trivial.

In [15]:
def alugen(
        op_map: op_map_t, 
        num_ports: int, 
        datawidth: int) -> tuple[tp.Type[Enum], family_closure]:
    
    validate_arguments(op_map, num_ports, datawidth)
    
    Opcode = Enum.from_fields('Opcode', {k: Enum.Auto() for k in op_map})

    @family_closure
    def ALU_fc(family):
        Word = family.BitVector[datawidth]
        Bit = family.Bit
        T = Tuple[(Word for _ in range(num_ports))]

        @family.assemble(locals(), globals())
        class ALU(Peak):
            def __call__(self, inst: Opcode, inputs: T) -> Word:
                if inst == Opcode.Add:
                    return inputs[0] + inputs[1] + inputs[2]
                elif inst == Opcode.Sub:
                    return inputs[0] - inputs[2]
                else:
                    return -inputs[1]

        return ALU

    return Opcode, ALU_fc

The above differs from the original in 3 ways:
1. `Opcode` is generated programmatically with `from_fields`
2. `Word` is defined to be `BitVector[datawidth]`
3. `T` is defined to be `Tuple[(Word for _ in range(num_ports))]` i.e. a `Tuple` of `num_ports` `Word`s
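
For readers unfamiliar with building enums and tuple types from data, the same idea can be shown with the Python standard library alone. This is only an analogy: it uses the functional API of the standard `enum` module rather than hwtypes' `Enum.from_fields`, and the strings in `port_types` are placeholders for `Word`.

```python
import enum

# Only the keys of the op map matter for generating the ISA.
example_op_map = {"Add": None, "Sub": None, "Neg": None, "Zero": None}
num_ports, datawidth = 3, 8

# Functional API of the standard library: member names come from the keys.
OpcodeSketch = enum.Enum("OpcodeSketch", list(example_op_map))

# One entry per port, mirroring Tuple[(Word for _ in range(num_ports))].
port_types = tuple(f"BitVector[{datawidth}]" for _ in range(num_ports))

print([member.name for member in OpcodeSketch])
print(port_types)
```

Adding a key to `example_op_map` or changing `num_ports` changes the generated enum and tuple without touching the code that builds them, which is the property the generator relies on.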
    

In [16]:
if errors := test_alugen(alugen, OP_MAP, 3, 16):
    print(f'Test had {errors} errors, see below for details')
else:
    print('Test passed!')

if errors := test_alugen(alugen, new_op_map, 3, 8):
    print(f'Test had {errors} errors, see below for details')
else:
    print('Test passed!')

Test passed!
Test had 254 errors, see below for details


	Test Failed for Zero:
		BitVector[8](215) == 0
  assert_equal(alu(opcode, inputs), impl(*(inputs[idx] for idx in op_indices)), opname)

	Test Failed for Zero:
		BitVector[8](136) == 0
  assert_equal(alu(opcode, inputs), impl(*(inputs[idx] for idx in op_indices)), opname)

	Test Failed for Zero:
		BitVector[8](84) == 0
  assert_equal(alu(opcode, inputs), impl(*(inputs[idx] for idx in op_indices)), opname)

	Test Failed for Zero:
		BitVector[8](159) == 0
  assert_equal(alu(opcode, inputs), impl(*(inputs[idx] for idx in op_indices)), opname)

	Test Failed for Zero:
		BitVector[8](92) == 0
  assert_equal(alu(opcode, inputs), impl(*(inputs[idx] for idx in op_indices)), opname)

	Test Failed for Zero:
		BitVector[8](53) == 0
  assert_equal(alu(opcode, inputs), impl(*(inputs[idx] for idx in op_indices)), opname)

	Test Failed for Zero:
		BitVector[8](89) == 0
  assert_equal(alu(opcode, inputs), impl(*(inputs[idx] for idx in op_indices)), opname)

	Test Failed for Zero:
		BitVector[8](135) ==

We see that our new implementation actually raises more errors. This is because it is now being tested on the `Zero` instruction, which it does not implement. The `Zero` instruction is being handled by the `else` case meant to handle the `Neg` instruction. What we want is the following:

In [17]:
def alugen(
        op_map: op_map_t, 
        num_ports: int, 
        datawidth: int) -> tuple[tp.Type[Enum], family_closure]:
    
    validate_arguments(op_map, num_ports, datawidth)
    
    Opcode = Enum.from_fields('Opcode', {k: Enum.Auto() for k in op_map})

    @family_closure
    def ALU_fc(family):
        Word = family.BitVector[datawidth]
        Bit = family.Bit
        T = Tuple[(Word for _ in range(num_ports))]

        @family.assemble(locals(), globals())
        class ALU(Peak):
            def __call__(self, inst: Opcode, inputs: T) -> Word:
                if inst == Opcode.Add:
                    return inputs[0] + inputs[1] + inputs[2]
                elif inst == Opcode.Sub:
                    return inputs[0] - inputs[2]
                elif inst == Opcode.Neg:
                    return -inputs[1]
                else:
                    return Word(0)


        return ALU

    return Opcode, ALU_fc

In [18]:
if errors := test_alugen(alugen, OP_MAP, 3, 16):
    print(f'Test had {errors} errors, see below for details')
else:
    print('Test passed!')

if errors := test_alugen(alugen, new_op_map, 3, 8):
    print(f'Test had {errors} errors, see below for details')
else:
    print('Test passed!')

Test passed!
Test passed!


Of course, the above has not actually gotten us any closer to having a working generator. We will start by showing how to generate the return values of the if/else structure, then how to generate the if/else structure itself.

The code below should be fairly straightforward. For each instruction we get its `impl` and `op_indices` from the `op_map`. These are then passed to `_build_term`, which selects the operands from `inputs` as specified by `op_indices` and applies `impl`. The resulting term is then cast to `Word` if it is an `int` (as is the case for the `Zero` instruction). The cast is necessary because `ite` does not coerce its arguments. Specifically, the `0` returned by `Zero` must be manually converted.

In [19]:
def alugen(
        op_map: op_map_t, 
        num_ports: int, 
        datawidth: int) -> tuple[tp.Type[Enum], family_closure]:
    
    validate_arguments(op_map, num_ports, datawidth)
    
    Opcode = Enum.from_fields('Opcode', {k: Enum.Auto() for k in op_map})

    @family_closure
    def ALU_fc(family):
        Word = family.BitVector[datawidth]
        Bit = family.Bit
        T = Tuple[(Word for _ in range(num_ports))]
        
        def _build_term(impl: tp.Callable[..., Word | int], inputs: T, op_indices: tuple[int, ...]) -> Word:
            term = impl(*(inputs[idx] for idx in op_indices))
            if isinstance(term, int):
                return Word(term)
            else:
                return term
            
        @family.assemble(locals(), globals())
        class ALU(Peak):
            def __call__(self, inst: Opcode, inputs: T) -> Word:
                if inst == Opcode.Add:
                    impl, op_indices = op_map['Add']
                    return _build_term(impl, inputs, op_indices)
                elif inst == Opcode.Sub:
                    impl, op_indices = op_map['Sub']
                    return _build_term(impl, inputs, op_indices)
                elif inst == Opcode.Neg:
                    impl, op_indices = op_map['Neg']
                    return _build_term(impl, inputs, op_indices)
                else:
                    impl, op_indices = op_map['Zero']
                    return _build_term(impl, inputs, op_indices)

        return ALU

    return Opcode, ALU_fc

In [20]:
if errors := test_alugen(alugen, OP_MAP, 3, 16):
    print(f'Test had {errors} errors, see below for details')
else:
    print('Test passed!')

if errors := test_alugen(alugen, new_op_map, 3, 8):
    print(f'Test had {errors} errors, see below for details')
else:
    print('Test passed!')

Test passed!
Test passed!


Now there are multiple approaches to generalizing this. It can be directly implemented in hwtypes using a similar pattern to the match example.

In [21]:
def alugen_hwtypes(
        op_map: op_map_t, 
        num_ports: int, 
        datawidth: int) -> tuple[tp.Type[Enum], family_closure]:
    
    validate_arguments(op_map, num_ports, datawidth)
    
    Opcode = Enum.from_fields('Opcode', {k: Enum.Auto() for k in op_map})

    @family_closure
    def ALU_fc(family):
        Word = family.BitVector[datawidth]
        Bit = family.Bit
        T = Tuple[(Word for _ in range(num_ports))]
        
        def _build_term(impl: tp.Callable[..., Word | int], inputs: T, op_indices: tuple[int, ...]) -> Word:
            term = impl(*(inputs[idx] for idx in op_indices))
            if isinstance(term, int):
                return Word(term)
            else:
                return term
            
        def _build_ite(op_map: op_map_t, inst: Opcode, inputs: T) -> Word:
            iterator = reversed(op_map.items())
            op_name, (f, op_indices) = next(iterator)
            term = _build_term(f, inputs, op_indices)
            
            for op_name, (f, op_indices) in iterator:     
                op_enum = getattr(Opcode, op_name)
                term = Bit(
                    op_enum == inst
                ).ite(_build_term(f, inputs, op_indices), term)
                
            return term

        @family.assemble(locals(), globals())
        class ALU(Peak):
            def __call__(self, inst: Opcode, inputs: T) -> Word:
                return _build_ite(op_map, inst, inputs)

        return ALU

    return Opcode, ALU_fc

In [22]:
if errors := test_alugen(alugen_hwtypes, OP_MAP, 3, 16):
    print(f'Test had {errors} errors, see below for details')
else:
    print('Test passed!')

if errors := test_alugen(alugen_hwtypes, new_op_map, 3, 8):
    print(f'Test had {errors} errors, see below for details')
else:
    print('Test passed!')

Test passed!
Test passed!
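
The core of `_build_ite` above is a fold over the op map in reverse order: the last entry becomes the default term, and every earlier entry wraps the running term in one more `ite`. The following is a minimal plain-Python sketch of that idea only. It uses ordinary ints and strings instead of hwtypes and PEak, and `OPS`, `ite`, and `build_ite` are hypothetical names introduced for illustration.

```python
from typing import Callable

# Hypothetical stand-in for the tutorial's op map:
# name -> (implementation, indices of the inputs it consumes)
OPS: dict[str, tuple[Callable[..., int], tuple[int, ...]]] = {
    'Add': (lambda a, b: a + b, (0, 1)),
    'Sub': (lambda a, b: a - b, (0, 1)),
    'Neg': (lambda a: -a, (0,)),
}


def ite(cond: bool, then_val: int, else_val: int) -> int:
    """Plain-Python analogue of Bit.ite: both arms are already evaluated."""
    return then_val if cond else else_val


def build_ite(ops, inst: str, inputs: tuple[int, ...], width: int = 8) -> int:
    mask = (1 << width) - 1

    def term_for(entry) -> int:
        impl, op_indices = entry
        # Truncate to `width` bits to mimic fixed-width arithmetic.
        return impl(*(inputs[idx] for idx in op_indices)) & mask

    iterator = reversed(ops.items())
    # The last op is the default: it is selected when nothing else matches.
    _, default_entry = next(iterator)
    term = term_for(default_entry)
    # Wrap the running term from the inside out, so the first op ends up
    # as the outermost condition.
    for op_name, entry in iterator:
        term = ite(op_name == inst, term_for(entry), term)
    return term


print(build_ite(OPS, 'Add', (3, 4, 0)))
print(build_ite(OPS, 'Sub', (3, 4, 0)))
```

Note that every arm is computed before the selection happens, which matches how a multiplexer tree behaves in hardware. An instruction that matches no name falls through to the default (last) op.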


However, such a solution leaves much to be desired. The goal of PEak is to allow the use of `if` statements. To this end, we can use loop unrolling and if inlining to build a similar structure.

In [23]:
import ast_tools
from ast_tools.passes import apply_passes, loop_unroll, if_inline

def alugen_asttools(
        op_map: op_map_t, 
        num_ports: int, 
        datawidth: int) -> tuple[tp.Type[Enum], family_closure]:
    
    validate_arguments(op_map, num_ports, datawidth)
    
    op_names = list(op_map)
    Opcode = Enum.from_fields('Opcode', {k: Enum.Auto() for k in op_names})
    
    @family_closure
    def ALU_fc(family):
        Word = family.BitVector[datawidth]
        Bit = family.Bit
        T = Tuple[(Word for _ in range(num_ports))]
        
        def _build_term(impl: tp.Callable[..., Word | int], inputs: T, op_indices: tuple[int, ...]) -> Word:
            term = impl(*(inputs[idx] for idx in op_indices))
            if isinstance(term, int):
                return Word(term)
            else:
                return term

        @family.assemble(locals(), globals())
        class ALU(Peak):
            @apply_passes([loop_unroll()])
            def __call__(self, inst: Opcode, inputs: T) -> Word:
                for i in ast_tools.macros.unroll(range(len(op_names)-1)):
                    op_name = op_names[i]
                    op_enum = getattr(Opcode, op_name)
                    if op_enum == inst:
                        f, op_indices = op_map[op_name]
                        return _build_term(f, inputs, op_indices)
                    
                op_name = op_names[-1]
                f, op_indices = op_map[op_name]
                op_enum = getattr(Opcode, op_name)
                return _build_term(f, inputs, op_indices)
        return ALU

    return Opcode, ALU_fc

In [24]:
if errors := test_alugen(alugen_asttools, OP_MAP, 3, 16):
    print(f'Test had {errors} errors, see below for details')
else:
    print('Test passed!')

if errors := test_alugen(alugen_asttools, new_op_map, 3, 8):
    print(f'Test had {errors} errors, see below for details')
else:
    print('Test passed!')

Test passed!
Test passed!


The previous example uses the `ast_tools` loop unrolling macro. It generates the following code:
```Python
    op_name = op_names[0] # Add
    op_enum = getattr(Opcode, op_name)
    if op_enum == inst:
        f, op_indices = op_map[op_name]
        return _build_term(f, inputs, op_indices)
    op_name = op_names[1] # Sub
    op_enum = getattr(Opcode, op_name)
    if op_enum == inst:
        f, op_indices = op_map[op_name]
        return _build_term(f, inputs, op_indices)
    op_name = op_names[2] # Neg
    op_enum = getattr(Opcode, op_name)
    if op_enum == inst:
        f, op_indices = op_map[op_name]
        return _build_term(f, inputs, op_indices)
    op_name = op_names[-1] # Zero
    f, op_indices = op_map[op_name]
    op_enum = getattr(Opcode, op_name)
    return _build_term(f, inputs, op_indices)
```
Currently, `ast_tools` unrolling is limited to integers. However, other datatypes can be used by following the pattern above (using ints to index into a list). Further, it is important to note that the unrolled body must be syntactically complete. Hence, it is not possible to generate `elif` clauses.
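
To make the "complete syntax" restriction concrete, here is a toy unroller built on the standard library `ast` module. It is not how `ast_tools` is implemented; `toy_unroll` and `_SubstituteName` are hypothetical names, and the sketch only handles top-level `for <name> in range(<int literals>)` loops. Each iteration copies the loop body as a list of whole statements and substitutes the loop variable, which is why a dangling `elif` (not a statement on its own) cannot be produced this way.

```python
import ast
import copy


class _SubstituteName(ast.NodeTransformer):
    """Replace reads of one variable name with an integer constant."""

    def __init__(self, name: str, value: int):
        self.name = name
        self.value = value

    def visit_Name(self, node: ast.Name) -> ast.AST:
        if node.id == self.name and isinstance(node.ctx, ast.Load):
            return ast.copy_location(ast.Constant(value=self.value), node)
        return node


def toy_unroll(src: str) -> str:
    """Unroll top-level `for <name> in range(<int literals>)` loops in src."""
    tree = ast.parse(src)
    new_body = []
    for stmt in tree.body:
        is_const_range = (
            isinstance(stmt, ast.For)
            and isinstance(stmt.target, ast.Name)
            and isinstance(stmt.iter, ast.Call)
            and isinstance(stmt.iter.func, ast.Name)
            and stmt.iter.func.id == 'range'
            and 1 <= len(stmt.iter.args) <= 3
            and all(isinstance(a, ast.Constant) for a in stmt.iter.args)
        )
        if not is_const_range:
            new_body.append(stmt)
            continue
        bounds = [a.value for a in stmt.iter.args]
        for i in range(*bounds):
            # Copy every statement of the body and pin the loop variable to i.
            for inner in stmt.body:
                pinned = _SubstituteName(stmt.target.id, i).visit(copy.deepcopy(inner))
                new_body.append(pinned)
    tree.body = new_body
    return ast.unparse(ast.fix_missing_locations(tree))


SRC = '''
for i in range(3):
    if names[i] == inst:
        result = names[i].lower()
'''
print(toy_unroll(SRC))
```

The input uses the same pattern as the ALU: the loop variable is an int that indexes into a list (`names`), so non-integer data can still drive the unrolled code.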

So far we have used random testing to check the behavior of the generated ALU. However, we can do better: we can formally verify its behavior.

In [25]:
import pysmt
from pysmt import shortcuts as sc

def verify_alugen(alugen_f, op_map, num_ports, datawidth):
    sc.reset_env()
    Opcode, ALU_fc = alugen_f(op_map, num_ports, datawidth)
    
    sf = ALU_fc._family_.SMTFamily()
    
    Word = sf.BitVector[datawidth]
    DataT = Tuple[(Word for _ in range(num_ports))]
    
    ALU = ALU_fc.SMT

    
    alu = ALU()
    input_bvs = [Word(name=f'input_{i}') for i in range(num_ports)]
    inputs = DataT(*input_bvs)

    for opname, opcode in Opcode.field_dict.items():
        asm_op = sf.get_adt_t(Opcode)(opcode)
        impl, op_indices = op_map[opname]
        res = alu(asm_op, inputs)
        gold = impl(*(inputs[idx] for idx in op_indices))
        with sc.Solver('z3') as s:
            s.add_assertion((res != gold).value)
            if s.solve():
                print(f'Counter example found for {opname}')
                for i, bv in enumerate(input_bvs):
                    print(f'input_{i} = {s.get_value(bv.value)}')
                print(f'res = {s.get_value(res.value)}')
                print(f'gold = {s.get_value(gold.value)}')
            else:
                print(f'verified {opname}')

In [26]:
verify_alugen(alugen_asttools, OP_MAP, 3, 16)

verified Add
verified Sub
verified Neg


In [27]:
verify_alugen(alugen_asttools, new_op_map, 3, 8)

verified Add
verified Sub
verified Neg
verified Zero
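
The check in `verify_alugen` asks the solver for any input where `res != gold`; if no such input exists, the op is verified. The same question can be phrased as a brute-force search when the datawidth is tiny. The sketch below is only an analogy in plain Python: it does not use pysmt or PEak, and `gold_sub`, `impl_sub`, `buggy_sub`, and `find_counterexample` are hypothetical names. An SMT solver answers the question symbolically, which is what makes realistic widths tractable.

```python
from itertools import product

WIDTH = 4
MASK = (1 << WIDTH) - 1


def gold_sub(a: int, b: int) -> int:
    """Reference model: subtraction modulo 2**WIDTH."""
    return (a - b) & MASK


def impl_sub(a: int, b: int) -> int:
    """Implementation under test: a + ~b + 1 (two's complement)."""
    return (a + (~b & MASK) + 1) & MASK


def buggy_sub(a: int, b: int) -> int:
    """Same as impl_sub but with the +1 forgotten."""
    return (a + (~b & MASK)) & MASK


def find_counterexample(impl, gold, width: int = WIDTH):
    """Return the first (a, b) where impl and gold disagree, else None."""
    for a, b in product(range(1 << width), repeat=2):
        if impl(a, b) != gold(a, b):
            return (a, b)
    return None


for name, impl in [('impl_sub', impl_sub), ('buggy_sub', buggy_sub)]:
    cex = find_counterexample(impl, gold_sub)
    if cex is None:
        print(f'verified {name}')
    else:
        print(f'Counter example found for {name}: {cex}')
```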


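In the versions above, an invalid instruction falls through to the final `else` clause and silently behaves like the last op. The next version makes that case explicit by returning a valid bit alongside the result.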

In [28]:
import ast_tools
from ast_tools.passes import apply_passes, loop_unroll, if_inline

def alugen_asttools_with_valid_out(
        op_map: op_map_t, 
        num_ports: int, 
        datawidth: int) -> tuple[tp.Type[Enum], family_closure]:
    
    validate_arguments(op_map, num_ports, datawidth)
    
    op_names = list(op_map)
    Opcode = Enum.from_fields('Opcode', {k: Enum.Auto() for k in op_names})
    
    @family_closure
    def ALU_fc(family):
        Word = family.BitVector[datawidth]
        Bit = family.Bit
        T = Tuple[(Word for _ in range(num_ports))]
        
        def _build_term(impl: tp.Callable[..., Word | int], inputs: T, op_indices: tuple[int, ...]) -> Word:
            term = impl(*(inputs[idx] for idx in op_indices))
            if isinstance(term, int):
                return Word(term)
            else:
                return term

        @family.assemble(locals(), globals())
        class ALU(Peak):
            @apply_passes([loop_unroll()])
            def __call__(self, inst: Opcode, inputs: T) -> (Word, Bit):
                for i in ast_tools.macros.unroll(range(len(op_names))):
                    op_name = op_names[i]
                    op_enum = getattr(Opcode, op_name)
                    if op_enum == inst:
                        f, op_indices = op_map[op_name]
                        return _build_term(f, inputs, op_indices), Bit(1)
                
                return Word(0), Bit(0)
        return ALU

    return Opcode, ALU_fc

In [29]:
def verify_alugen_with_valid(alugen_f, op_map, num_ports, datawidth):
    sc.reset_env()
    Opcode, ALU_fc = alugen_f(op_map, num_ports, datawidth)
    
    sf = ALU_fc._family_.SMTFamily()
    
    Word = sf.BitVector[datawidth]
    DataT = Tuple[(Word for _ in range(num_ports))]
    
    ALU = ALU_fc.SMT

    
    alu = ALU()
    input_bvs = [Word(name=f'input_{i}') for i in range(num_ports)]
    inputs = DataT(*input_bvs)

    for opname, opcode in Opcode.field_dict.items():
        asm_op = sf.get_adt_t(Opcode)(opcode)
        impl, op_indices = op_map[opname]
        res, valid = alu(asm_op, inputs)
        gold = impl(*(inputs[idx] for idx in op_indices))
        with sc.Solver('z3') as s:
            s.add_assertion((res != gold).value)
            s.add_assertion(valid.value)
            if s.solve():
                print(f'Counter example found for {opname}')
                for i, bv in enumerate(input_bvs):
                    print(f'input_{i} = {s.get_value(bv.value)}')
                print(f'res = {s.get_value(res.value)}')
                print(f'valid = {s.get_value(valid.value)}')
                print(f'gold = {s.get_value(gold.value)}')
            else:
                print(f'verified {opname}')
    
    AsmOpcode = sf.get_adt_t(Opcode)
    AsmBV = AsmOpcode._bitvector_t_()
    free_op_bv = AsmBV()
    free_op = AsmOpcode(free_op_bv)
    is_valid_op = AsmOpcode._is_valid_(free_op_bv)
    res, valid = alu(free_op, inputs)
    with sc.Solver('z3') as s:
        s.add_assertion(sc.Not((~is_valid_op | valid).value))
        if s.solve():
            print("oops")
        else:
            print("verified valid op -> valid out")
    
    with sc.Solver('z3') as s:
        s.add_assertion(sc.Not((is_valid_op | ~valid).value))
        if s.solve():
            print("oops")
        else:
            print("verified ~valid op -> ~valid out")

        


In [30]:
verify_alugen_with_valid(alugen_asttools_with_valid_out, OP_MAP, 3, 16)

verified Add
verified Sub
verified Neg
verified valid op -> valid out
verified ~valid op -> ~valid out
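
The last two checks encode implications as disjunctions: `a -> b` is written `~a | b`, and the solver is asked to satisfy its negation. Together, the two directions say that the valid output is high exactly when the opcode encoding is valid. The sketch below replays that logic by enumeration in plain Python, under the assumption of a 2-bit encoding with three valid values; all names in it are hypothetical.

```python
# Assumed encoding for illustration: 2 bits, of which 3 values are real ops.
VALID_OPCODES = {0, 1, 2}
ALL_ENCODINGS = range(4)


def is_valid_op(enc: int) -> bool:
    return enc in VALID_OPCODES


def alu_valid(enc: int) -> bool:
    """Model of the unrolled if chain: valid only when some op matched."""
    return any(enc == op for op in sorted(VALID_OPCODES))


def always_valid(enc: int) -> bool:
    """A buggy model that never reports an invalid instruction."""
    return True


def counterexamples(prop) -> list[int]:
    """Encodings that satisfy the negation of prop."""
    return [enc for enc in ALL_ENCODINGS if not prop(enc)]


def check(valid_out):
    # valid op -> valid out, written as ~is_valid_op | valid
    forward = counterexamples(lambda e: (not is_valid_op(e)) or valid_out(e))
    # ~valid op -> ~valid out, written as is_valid_op | ~valid
    backward = counterexamples(lambda e: is_valid_op(e) or (not valid_out(e)))
    return forward, backward


print(check(alu_valid))
print(check(always_valid))
```

An empty list plays the role of an unsatisfiable query. The buggy model passes the forward direction but fails the backward one on the unused encoding, which is why both checks are needed.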


TODO: Weave it all together with if inlining, and optionally add the valid signal.