In [1]:
%load_ext autoreload
%autoreload 2
%env CIRCT_HOME=/home/donovick/src/circt/

env: CIRCT_HOME=/home/donovick/src/circt/


In [2]:
from tutorial_utils import magma_to_verilog_string, smt_to_smtlib_string

import hwtypes as hw
import magma as m



The core of PEak is the python-embedded expression language of hwtypes. hwtypes defines abstract interfaces and type constructors for a number of types and kinds. This includes fixed-width bitvector types, arbitrary-precision floating-point types, algebraic data types, and a bit (bool) type. In this section, we focus on the bit and bitvector types; the others are covered later.

A `BitVector` type of length `n` is constructed with `BitVector[n]`. `BitVector` has two key subtypes: `Unsigned` and `Signed`. The author of this tutorial would have preferred for the mixing of signed and unsigned bitvector types to universally raise an error. However, in order to support legacy code, such mixing instead has undefined behavior which may raise errors. Bitvector types have numerous constructors (also for legacy reasons). While the set of types accepted is not specifically defined by the abstract interface, bitvector implementations minimally support construction from: `BitVector`, `Bit`, `int`, types which define `__int__`, and sequences of objects that can be used to construct `Bit`.

Similar to `BitVector`, the full set of types that can be used to construct `Bit` types is implementation-dependent but minimally includes: `Bit`, `bool`, types which define `__bool__`, and the integer constants `0` and `1`. `Bit` types are required to implement the standard bitwise operators: and `&`, or `|`, xor `^`, and not `~`; as well as equals `==`, not equals `!=`, and an if-then-else (`ite`) method which we will describe later in this section.  

Note that at a high level, objects which have an `__int__` method can be thought of as objects which are "castable" to `int`. Readers should refer to the relevant python documentation for more details (https://docs.python.org/3/library/functions.html#int). Similarly, the `__bool__` method is used to define how objects are "cast" to `bool`; however, *all* objects are "castable" to `bool` unless they raise an error in `__bool__` or `__len__`. See: https://docs.python.org/3/library/functions.html#bool and https://docs.python.org/3/library/stdtypes.html#truth.
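
The casting hooks described above can be seen in plain python, independent of hwtypes. The sketch below uses two hypothetical classes (`Nibble` and `EmptyBag` are invented for illustration and are not part of any library) to show `int()` calling `__int__`, and `bool()` falling back to `__len__` or to the default of `True`.

```python
class Nibble:
    """Hypothetical 4-bit value, used only to illustrate the casting hooks."""

    def __init__(self, value):
        self.value = value % 16

    def __int__(self):
        # int(obj) calls this method
        return self.value


class EmptyBag:
    """Hypothetical container that defines __len__ but not __bool__."""

    def __len__(self):
        return 0


print(int(Nibble(19)))   # 3, because 19 % 16 == 3
print(bool(Nibble(0)))   # True: neither __bool__ nor __len__ is defined
print(bool(EmptyBag()))  # False: bool() falls back to __len__() == 0
```

Note that `Nibble(0)` is truthy even though its integer value is zero, since truthiness and integer conversion are separate hooks.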
    
The SMT-LIB standard (http://smtlib.cs.uiowa.edu/index.shtml) defines a large set of arithmetic and bitwise functions on bitvectors. These functions are defined in both the base theory (`FixedSizeBitVectors`) and in its associated logics (`BV` and `QF_BV`). The hwtypes bitvector interface defines a method for each of these functions with the exception of `bv2nat`. For instance, the equivalent of the SMT term `(bvadd x y)` where `x` and `y` are terms of the sort `(_ BitVec n)` would be the hwtypes expression `x.bvadd(y)` where `x` and `y` are of the type `BitVector[n]`. More generally, if `f` is a function defined by the SMT-LIB standard on bitvectors then there is an equivalent method named `f` on the hwtypes `BitVector` type.
   
As a convenience, these methods are defined as operators where applicable.  For example: `x.bvadd(y)` can be invoked with `x + y`. The semantics of sign-dependent operators are defined by their type. For example, `x < y` invokes `x.bvslt(y)` for signed `x` and `x.bvult(y)` for unsigned `x`.  Most operators and methods attempt to coerce their arguments. Any object that can be used to construct a bitvector will typically be coerced. This was done to allow the use of python integer constants, e.g., `bv + 1`.
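
To make the sign-dependent dispatch and argument coercion concrete, here is a minimal pure-python sketch. `ToyUnsigned` and `ToySigned` are hypothetical toy classes written for this illustration; they are not the hwtypes implementation and only mimic the behavior described above for `<` and `+`.

```python
class ToyUnsigned:
    """Toy 8-bit value; NOT the hwtypes implementation."""
    width = 8

    def __init__(self, value):
        # Accept a plain int or another toy value (anything with __int__).
        self.bits = int(value) % (1 << self.width)

    def __int__(self):
        return self.bits

    def as_int(self):
        # Unsigned interpretation of the stored bits.
        return self.bits

    def __add__(self, other):
        other = type(self)(other)  # coerce, so that `bv + 1` works
        return type(self)(self.bits + other.bits)

    def __lt__(self, other):
        other = type(self)(other)
        return self.as_int() < other.as_int()


class ToySigned(ToyUnsigned):
    def as_int(self):
        # Two's-complement interpretation of the stored bits.
        if self.bits >= 1 << (self.width - 1):
            return self.bits - (1 << self.width)
        return self.bits


print(ToyUnsigned(0xFF) < 1)      # False: 255 < 1 (unsigned comparison)
print(ToySigned(0xFF) < 1)        # True: -1 < 1 (signed comparison)
print(int(ToyUnsigned(255) + 1))  # 0: addition wraps modulo 2**8
```

The same bit pattern `0xFF` compares differently depending only on the type of the left operand, which is the behavior the text attributes to `bvult` and `bvslt`.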

hwtypes provides two main implementations of the bitvector and bit types. The first implementation is a pure python functional model over constant values. The second wraps pySMT (described below) to generate SMT terms. Finally, Magma provides a third implementation which allows for the definition of circuits. This uniform interface allows for the same hwtypes programs to be interpreted multiple ways. The pure python implementation can be used to simulate a circuit, the SMT implementation can be used to generate a formal model, and the Magma implementation can be used to generate actual RTL. This single source of truth paradigm is powerful for designers, as they need not implement the same thing multiple times, and the different implementations are guaranteed to be consistent with each other. Code written using the hwtypes bitvector type can be thought of as a formal and executable specification. When specifications are written in a non-formal, non-executable way, there is typically some room for interpretation, which can lead to difficulties when it comes time to test the design. Often, the functional model, formal model, or RTL differ in some way, and we are left wondering where the error is. Did all teams interpret the specification similarly and correctly? If not, whose implementation is wrong? With a single source of truth, such mismatches are avoided by construction.
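
The mechanism behind "one program, several interpretations" can be sketched with ordinary operator overloading. The two classes below (`Val` and `Sym`) are hypothetical stand-ins invented for this illustration, not hwtypes or pySMT types: one computes 8-bit constants, the other records the operations as an SMT-LIB-like string.

```python
class Val:
    """Toy 8-bit constant value (stand-in for a functional model)."""

    def __init__(self, value):
        self.value = value % 256

    def __add__(self, other):
        return Val(self.value + other.value)


class Sym:
    """Toy symbolic term that records operations as an S-expression."""

    def __init__(self, text):
        self.text = text

    def __add__(self, other):
        return Sym(f"(bvadd {self.text} {other.text})")


def add3(a, b, c):
    # Written once; its meaning depends on the types passed in.
    return a + (b + c)


print(add3(Val(200), Val(100), Val(1)).value)   # 45, i.e. 301 % 256
print(add3(Sym("x"), Sym("y"), Sym("z")).text)  # (bvadd x (bvadd y z))
```

Because `add3` never inspects its arguments, nothing in it needs to change when a new interpretation is added; only a new type implementing `+` is required.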

pySMT is a solver-independent python API for constructing SMT formulas. pySMT provides a uniform interface while allowing the user to use their SMT solver of choice under the hood. Importantly, pySMT can also construct SMT formulas without a solver. This allows the constructors of the different BitVector implementations to be uniform. In contrast, SMT-Switch (a similar project) requires a solver object to build terms. This would be quite inconvenient for hwtypes as either a reference to a solver would need to be passed to each bitvector (changing the constructor interface) or there would need to be an implicit global context which holds a reference to a solver (which would make working with multiple solvers difficult).
 
hwtypes is an expression language only; all statements are executed in pure python (as opposed to some sort of AST-rewriting-based approach) following typical python semantics. In Section 2, we will introduce the language PEak, which breaks away from the semantics of pure python and reinterprets the meaning of assignment statements and `if` statements.
    
    

Below, we show how the same function can be invoked to do calculations with constant python values, produce symbolic SMT values, or build a Magma circuit:

In [3]:
def add(x, y):
    return x + y

In [4]:
PyDataT = hw.BitVector[8]
SmtDataT = hw.SMTBitVector[8]
MagmaDataT = m.Bits[8]

x = PyDataT(1)
y = PyDataT(2)
results = add(x, y)
print(repr(results))
print('---')

x = SmtDataT(name='x')
y = SmtDataT(name='y')
results = add(x, y)
print(smt_to_smtlib_string(results))
print('---')
# del because jupyter seems to keep references alive longer than it should which breaks SMT variables
del x
del y


class Adder(m.Circuit):
    io = m.IO(
        x=m.In(MagmaDataT), y=m.In(MagmaDataT), results=m.Out(MagmaDataT)
    )
    io.results @= add(io.x, io.y)


print(magma_to_verilog_string(Adder))
print('---')

BitVector[8](3)
---
(bvadd x y)
---
module coreir_add #(
    parameter width = 1
) (
    input [width-1:0] in0,
    input [width-1:0] in1,
    output [width-1:0] out
);
  assign out = in0 + in1;
endmodule

module Adder (
    input [7:0] x,
    input [7:0] y,
    output [7:0] results
);
wire [7:0] magma_Bits_8_add_inst0_out;
coreir_add #(
    .width(8)
) magma_Bits_8_add_inst0 (
    .in0(x),
    .in1(y),
    .out(magma_Bits_8_add_inst0_out)
);
assign results = magma_Bits_8_add_inst0_out;
endmodule


---


The real power of hwtypes comes from its embedding in python which facilitates the generation of more complex formulas. For example, we can generalize add to build an adder tree over any number of inputs with the use of a recursive function:

In [5]:
def add_n(*args):
    n = len(args)
    if n == 0:
        return 0
    elif n == 1:
        return args[0]
    else:
        return add_n(*args[:n // 2]) + add_n(*args[n // 2:])

In [6]:
x = PyDataT(1)
y = PyDataT(2)
z = PyDataT(3)
results = add_n(x, y, z)
print(repr(results))
print('---')

x = SmtDataT(name='x')
y = SmtDataT(name='y')
z = SmtDataT(name='z')
results = add_n(x, y, z)
print(smt_to_smtlib_string(results))
print('---')
del x
del y
del z


class Adder3(m.Circuit):
    io = m.IO(
        x=m.In(MagmaDataT),
        y=m.In(MagmaDataT),
        z=m.In(MagmaDataT),
        results=m.Out(MagmaDataT)
    )
    io.results @= add_n(io.x, io.y, io.z)


print(magma_to_verilog_string(Adder3))
print('---')

BitVector[8](6)
---
(bvadd x (bvadd y z))
---
module coreir_add #(
    parameter width = 1
) (
    input [width-1:0] in0,
    input [width-1:0] in1,
    output [width-1:0] out
);
  assign out = in0 + in1;
endmodule

module Adder3 (
    input [7:0] x,
    input [7:0] y,
    input [7:0] z,
    output [7:0] results
);
wire [7:0] magma_Bits_8_add_inst0_out;
wire [7:0] magma_Bits_8_add_inst1_out;
coreir_add #(
    .width(8)
) magma_Bits_8_add_inst0 (
    .in0(y),
    .in1(z),
    .out(magma_Bits_8_add_inst0_out)
);
coreir_add #(
    .width(8)
) magma_Bits_8_add_inst1 (
    .in0(x),
    .in1(magma_Bits_8_add_inst0_out),
    .out(magma_Bits_8_add_inst1_out)
);
assign results = magma_Bits_8_add_inst1_out;
endmodule


---


We can easily further generalize this to build reduction trees over any function. 

<mark>Are you assuming f is commutative and associative here?  What assumptions are on f?</mark>

<mark>Caleb - I am not making any assumptions beyond the function signature.  I am not claiming the generated reduction tree is equivalent to a lfold/rfold (if I were, I would need to assume associativity).  </mark>

<mark>Ok got it - in that case, maybe rename ident to something else - base_case or empty or something?</mark>

<mark>I don't think a function needs to be associative to have an identity.  We typically talk about identities with regard to operators and not functions; however, there really isn't a difference between the two. Let $M$ be a magma (a set which is closed under a binary operator which I will call $+$). $M$ has a left identity iff $\exists l \in M\ \forall x \in M:\ l + x = x$; similarly, $M$ has a right identity iff $\exists r \in M\ \forall x \in M:\ x + r = x$. We say $M$ has an identity (or more verbosely a two-sided identity) if it has an element $i$ which is both a left and a right identity, i.e. $\exists i \in M\ \forall x \in M:\ i + x = x \land x + i = x$. Similarly we can say a function `f :: T -> T -> T` has an identity `i :: T` iff `f(x, i) == f(i, x) == x` forall `x :: T`.  Now nothing about the following code requires that `f` actually have an identity or be associative; however, it's unlikely to be useful without those properties.</mark>

<mark>Fun fact: a magma can have many distinct left or right identities; however, if it has both left and right identities then it has a unique two-sided identity.  The proof of this is trivial: let $l_0, l_1, ... \in M$ be left identities and $r_0, r_1, ... \in M$ be right identities. Now $l_i = l_i + r_j = r_j$ for all $i$ and $j$ and hence by transitivity there is a unique identity element.</mark>

In [7]:
_MISSING = object()  # a sentinel object

def gen_tree_reducer(f, ident=_MISSING):
    '''
    f :: T -> T -> T
    ident :: Optional[T]
    '''
    def reducer(*args):
        '''
         *args :: List[T]
        '''
        n = len(args)
        if n == 0:
            if ident is _MISSING:
                raise ValueError('cannot reduce empty list')
            return ident
        elif n == 1:
            return args[0]
        else:
            return f(reducer(*args[:n // 2]), reducer(*args[n // 2:]))

    return reducer
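
To illustrate why the balanced tree is not claimed to be equivalent to a fold, the sketch below compares the two on a non-associative function. `tree_reduce` is a hypothetical, simplified restatement of the splitting rule used by `reducer` (non-empty input only), written here so the example is self-contained; `functools.reduce` is the standard left fold.

```python
from functools import reduce


def tree_reduce(f, args):
    """Balanced reduction using the same n // 2 split as `reducer`."""
    n = len(args)
    if n == 1:
        return args[0]
    return f(tree_reduce(f, args[:n // 2]), tree_reduce(f, args[n // 2:]))


def show(a, b):
    # Build a string so the shape of the reduction is visible.
    return f"({a} - {b})"


def sub(a, b):
    return a - b


names = ["a", "b", "c", "d", "e"]
print(tree_reduce(show, names))  # ((a - b) - (c - (d - e)))
print(reduce(show, names))       # ((((a - b) - c) - d) - e)

values = [10, 1, 2, 3, 4]
print(tree_reduce(sub, values))  # 6
print(reduce(sub, values))       # 0
```

For an associative function such as addition the two shapes compute the same value, but the tree has logarithmic rather than linear depth.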

The use of built-in higher-order functions (e.g., `map`) is also supported:

In [8]:
add_n = gen_tree_reducer(add, 0)


def sum_of_sq(*args):
    return add_n(*map(lambda x: x * x, args))

In [9]:
x = PyDataT(1)
y = PyDataT(2)
z = PyDataT(3)
results = sum_of_sq(x, y, z)
print(repr(results))
print('---')

x = SmtDataT(name='x')
y = SmtDataT(name='y')
z = SmtDataT(name='z')
results = sum_of_sq(x, y, z)
print(smt_to_smtlib_string(results))
print('---')
del x
del y
del z


class SumOfSq3(m.Circuit):
    io = m.IO(
        x=m.In(MagmaDataT),
        y=m.In(MagmaDataT),
        z=m.In(MagmaDataT),
        results=m.Out(MagmaDataT)
    )
    io.results @= sum_of_sq(io.x, io.y, io.z)


print(magma_to_verilog_string(SumOfSq3))
print('---')

BitVector[8](14)
---
(bvadd (bvmul x x) (bvadd (bvmul y y) (bvmul z z)))
---
module coreir_mul #(
    parameter width = 1
) (
    input [width-1:0] in0,
    input [width-1:0] in1,
    output [width-1:0] out
);
  assign out = in0 * in1;
endmodule

module coreir_add #(
    parameter width = 1
) (
    input [width-1:0] in0,
    input [width-1:0] in1,
    output [width-1:0] out
);
  assign out = in0 + in1;
endmodule

module SumOfSq3 (
    input [7:0] x,
    input [7:0] y,
    input [7:0] z,
    output [7:0] results
);
wire [7:0] magma_Bits_8_add_inst0_out;
wire [7:0] magma_Bits_8_add_inst1_out;
wire [7:0] magma_Bits_8_mul_inst0_out;
wire [7:0] magma_Bits_8_mul_inst1_out;
wire [7:0] magma_Bits_8_mul_inst2_out;
coreir_add #(
    .width(8)
) magma_Bits_8_add_inst0 (
    .in0(magma_Bits_8_mul_inst1_out),
    .in1(magma_Bits_8_mul_inst2_out),
    .out(magma_Bits_8_add_inst0_out)
);
coreir_add #(
    .width(8)
) magma_Bits_8_add_inst1 (
    .in0(magma_Bits_8_mul_inst0_out),
    .in1(magma_Bits_8

Note how all `if`s in `gen_tree_reducer` are resolved without accessing the values of the data. The hwtypes expression language allows for conditionals using the `ite` method on the bit type. For example, one might write an absolute value function as follows:

In [10]:
def abs(x):
    return (x < 0).ite(-x, x)

In [11]:
PyDataT = hw.SIntVector[8]
SmtDataT = hw.SMTSIntVector[8]
MagmaDataT = m.SInt[8]

x = PyDataT(-1)
results = abs(x)
print(repr(results))
print('---')

x = SmtDataT(name='x')
results = abs(x)
print(smt_to_smtlib_string(results))
print('---')
del x


class Abs(m.Circuit):
    io = m.IO(x=m.In(MagmaDataT), results=m.Out(MagmaDataT))
    io.results @= abs(io.x)


print(magma_to_verilog_string(Abs))
print('---')

SIntVector[8](1)
---
(ite (bvslt x #b00000000) (bvneg x) x)
---
module coreir_slt #(
    parameter width = 1
) (
    input [width-1:0] in0,
    input [width-1:0] in1,
    output out
);
  assign out = $signed(in0) < $signed(in1);
endmodule

module coreir_neg #(
    parameter width = 1
) (
    input [width-1:0] in,
    output [width-1:0] out
);
  assign out = -in;
endmodule

module coreir_mux #(
    parameter width = 1
) (
    input [width-1:0] in0,
    input [width-1:0] in1,
    input sel,
    output [width-1:0] out
);
  assign out = sel ? in1 : in0;
endmodule

module coreir_const #(
    parameter width = 1,
    parameter value = 1
) (
    output [width-1:0] out
);
  assign out = value;
endmodule

module commonlib_muxn__N2__width8 (
    input [7:0] in_data [1:0],
    input [0:0] in_sel,
    output [7:0] out
);
wire [7:0] _join_out;
coreir_mux #(
    .width(8)
) _join (
    .in0(in_data[0]),
    .in1(in_data[1]),
    .sel(in_sel[0]),
    .out(_join_out)
);
assign out = _join_out;
endmodu

It is important to note that only constantly bounded recursion is possible. This ensures all hwtypes programs may be compiled to some finite circuit (and correspondingly some finite formula in first-order logic). This restriction is easy to enforce as the `ite` method behaves like any other python method call. In particular, its arguments are evaluated eagerly, meaning unbounded data-dependent recursion will recurse until python raises a `RecursionError`:


In [12]:
def factorial(x):
    return (x != 1).ite(
            x * factorial(x - 1),  # factorial(x - 1) will always be evaluated leading to infinite recursion
            type(x)(1),  # cast 1 to the type of x
        )


try:
    x = factorial(PyDataT(5))
except RecursionError as e:
    print(f'Error: {e}')
else:
    print(f'5! = {x}')

Error: maximum recursion depth exceeded in comparison
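
The eager evaluation responsible for this error can be observed in isolation. The sketch below uses a hypothetical `ToyBit` class (a stand-in written for this illustration, not the hwtypes `Bit`) and records which branch expressions get evaluated before `ite` is even called.

```python
calls = []


class ToyBit:
    """Toy stand-in for a bit type; NOT the hwtypes implementation."""

    def __init__(self, value):
        self.value = bool(value)

    def ite(self, t_branch, f_branch):
        # By the time this body runs, both arguments are already computed.
        return t_branch if self.value else f_branch


def branch(name):
    calls.append(name)  # record that this branch expression was evaluated
    return name


result = ToyBit(True).ite(branch("then"), branch("else"))
print(result)  # then
print(calls)   # ['then', 'else']
```

Both branches are evaluated even though only one is selected, which is why a recursive call placed inside an `ite` argument never reaches a base case.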


The `factorial` function above can instead be unrolled explicitly to a fixed depth, which results in a quite large but finite circuit.

In [13]:
PyDataT = hw.UIntVector[8]

def bounded_factorial(x):
    if not isinstance(x, hw.AbstractBitVector):
        raise TypeError()
    T = type(x)
    MAX_INT = 2**x.size - 1

    def inner(x, ctr):
        if ctr == 0:
            return T(1)
        else:
            return (x <= 1).ite(
                T(1),
                x * inner(x - 1, ctr - 1),
            )

    return inner(x, MAX_INT)


x = PyDataT(9)
try:
    y = bounded_factorial(x)
except RecursionError as e:
    print(f'Error: {e}')
else:
    print(f'{x}! = {y}')

9! = 128


Note that this circuit does not in fact perform the factorial function. Instead, it performs: 
$$
\begin{aligned}
f(0) &= 1 \\
f(x) &= x \cdot f(x-1) \bmod 2^{\text{bitwidth}(x)}
\end{aligned}
$$
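
The printed result `9! = 128` can be checked with plain integer arithmetic. The helper names below (`factorial_mod`, `wrapped_factorial`) are hypothetical and exist only for this check; the second one wraps after every multiplication, as a fixed-width multiplier would.

```python
import math


def factorial_mod(x, width):
    """Exact factorial, reduced modulo 2**width at the end."""
    return math.factorial(x) % (1 << width)


def wrapped_factorial(x, width):
    """Factorial where every product is truncated to `width` bits."""
    acc = 1
    for k in range(2, x + 1):
        acc = (acc * k) % (1 << width)
    return acc


print(math.factorial(9))        # 362880
print(factorial_mod(9, 8))      # 128
print(wrapped_factorial(9, 8))  # 128
print(factorial_mod(10, 8))     # 0
```

Since $10!$ is divisible by $2^8$, every 8-bit input of 10 or more yields 0 under this definition.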

In the next example we will show a significantly more involved metaprogramming example. We will develop a Verilog-style case statement where a value is matched against a pattern in which `x`s mark don't-care bits. First we will show the desired syntax and an implementation using python control flow which works for python values. Then we will show how the same functionality can be achieved in hwtypes for all implementations.

In [14]:
DataT = hw.BitVector[3]
Bit = hw.Bit

def match(
        value: DataT, 
        case_dict: dict[str, DataT],
        default: DataT) -> DataT:
    for pattern, case_value in case_dict.items():
        if matches_string(value, pattern):
            return case_value
    return default

def matches_string(value: DataT, pattern: str) -> Bit:
    if len(value) != len(pattern):
        raise ValueError('pattern and value must be the same length')

    for v, c in zip(value, reversed(pattern)):
        if c == 'x':
            continue
        elif c == '0' or c == '1':
            if v != Bit(int(c)):
                return Bit(False)
        else:
            raise ValueError('invalid pattern')
            
    return Bit(True)
    
x = DataT(0b011)


y = match(
    x, 
    {
        '1x0': DataT(0), 
        '11x': DataT(1),
    },
    default=DataT(2)
)

print(y)

2


The behavior of the `match` function should be fairly clear.  It iterates through the `case_dict` and tests whether the value matches the pattern with the `matches_string` function. If no matching pattern is found the `default` value is returned.

The `matches_string` function co-iterates over the bits of the value (`v`) and the characters of the pattern (`c`). The pattern is reversed because python strings are MSB 0 (the first character of the pattern is its most significant bit), whereas hwtypes uses LSB 0. If any `v` does not match its corresponding `c`, the value does not match the pattern; otherwise it matches.
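
The effect of reversing the pattern can be seen with plain integers. In the sketch below, `bits_lsb_first` and `matches_int` are hypothetical helpers that model iteration over a bitvector as a list of bits starting from the least significant one, which is the LSB 0 convention described above.

```python
def bits_lsb_first(value, width):
    """Return the bits of `value`, least significant bit first."""
    return [(value >> i) & 1 for i in range(width)]


def matches_int(value, pattern):
    pairs = zip(bits_lsb_first(value, len(pattern)), reversed(pattern))
    return all(c == 'x' or v == int(c) for v, c in pairs)


pattern = "1x0"
print(list(zip(bits_lsb_first(0b110, 3), reversed(pattern))))
# [(0, '0'), (1, 'x'), (1, '1')]
print(matches_int(0b110, pattern))  # True
print(matches_int(0b011, pattern))  # False: the LSB is 1 but the pattern wants 0
```

Without the `reversed` call, the leading `'1'` of the pattern would be compared against the least significant bit instead of the most significant one.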

The above code has two data-dependent `if` statements which must be removed to allow the code to work on all hwtypes implementations. Specifically, `if matches_string(value, pattern):` and `if v != Bit(int(c)):`. We will now demonstrate how to eliminate these `if` statements to achieve the same functionality in pure hwtypes. We will start by rewriting `matches_string` as it is the simpler case.

In [15]:
def matches_string(value: DataT, pattern: str) -> Bit:
    if len(value) != len(pattern):
        raise ValueError('pattern and value must be the same length')

    matches = Bit(True)
    for v, c in zip(value, reversed(pattern)):
        if c == 'x':
            continue
        elif c == '1':
            matches &= v
        elif c == '0':
            matches &= ~v
        else:
            raise ValueError('invalid pattern')

    return matches

This `matches_string` function operates similarly to the previous one, except that instead of directly comparing the bits to characters, the bits (or their negations) are and-reduced into the `matches` value. At first glance it might seem like there are still data-dependent `if` statements; however, this is subtly incorrect: they depend only on python string values, not on hwtypes values.
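
One way to see that no `if` depends on the data is to run the same loop over purely symbolic bits. The sketch below uses a hypothetical `SymBit` class that only records `&` and `~` as text (it is not the hwtypes SMT bit type); `matches_expr` repeats the loop above on a list of such bits.

```python
class SymBit:
    """Toy symbolic bit that records operations as text."""

    def __init__(self, text):
        self.text = text

    def __and__(self, other):
        return SymBit(f"(and {self.text} {other.text})")

    def __invert__(self):
        return SymBit(f"(not {self.text})")


def matches_expr(bits, pattern):
    matches = SymBit("true")
    for v, c in zip(bits, reversed(pattern)):
        if c == 'x':      # branches on a python character only
            continue
        elif c == '1':
            matches &= v
        elif c == '0':
            matches &= ~v
        else:
            raise ValueError('invalid pattern')
    return matches


bits = [SymBit(f"v{i}") for i in range(3)]  # v0 is the least significant bit
print(matches_expr(bits, "1x0").text)
# (and (and true (not v0)) v2)
```

The loop completes without ever asking a `SymBit` for a truth value, and the don't-care bit `v1` does not appear in the resulting expression.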

Next we will show how to build the `match` function by working bottom-up, building an ite-expression.

In [16]:
def match(
        value: DataT, 
        case_dict: dict[str, DataT],
        default: DataT) -> DataT:
    matched_value = default
    for pattern, case_value in reversed(case_dict.items()):
        matched_value = matches_string(value, pattern).ite(
                            case_value,
                            matched_value
                        )
    return matched_value

The above starts by assigning `default` to `matched_value`, then iterates over the `case_dict` in reverse. When unrolled, this generates the moral equivalent of:
```Python
def match(...):
    if matches_string(value, pattern_0):
        return case_value_0
    else:
        if matches_string(value, pattern_1):
            return case_value_1
        else:
            ...
                else:
                    return default
```
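
The nesting order produced by the reverse iteration can be made visible by building the chain as text. `build_chain` below is a hypothetical helper for this illustration; the condition and value names (`m0`, `v0`, ...) are placeholders, not hwtypes objects.

```python
def build_chain(cases, default):
    """Fold the cases, last to first, into a nested ite expression string."""
    expr = default
    for cond, val in reversed(list(cases.items())):
        expr = f"(ite {cond} {val} {expr})"
    return expr


print(build_chain({"m0": "v0", "m1": "v1"}, "d"))
# (ite m0 v0 (ite m1 v1 d))
```

Because the last case is wrapped first, the first entry of the dictionary ends up outermost and therefore takes priority, matching the order of the early returns in the original `match`.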

In [17]:
x = DataT(0b011)


y = match(
    x, 
    {
        '1x0': DataT(0), 
        '11x': DataT(1),
    },
    default=DataT(2)
)

print(y)

2
