In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from tutorial_utils import magma_to_verilog_string, smt_to_smtlib_string

import hwtypes as hw
import magma as m

The core of PEak is the expression language of hwtypes. Hwtypes defines abstract interfaces and type constructors for a number of types and kinds. These include fixed-width bitvectors, algebraic data types (ADTs), and a bit (bool) type.

A `BitVector` of length `n` is constructed with `BitVector[n]`.  The construction of ADTs will be discussed later.
`BitVector` has two key subtypes `Unsigned` and `Signed`.

Bitvectors provide an interface equivalent to the SMT-LIB standard.

Further, as a convenience, these methods are also available as operators where applicable.

For example:
`x.bvadd(y)` can be invoked with `x + y`

The semantics of sign dependent operators are defined by their type.
For example `x < y` invokes `x.bvslt(y)` for signed `x` and `x.bvult(y)` for unsigned `x`.
For legacy reasons the base `BitVector` has the behavior of an `Unsigned`.
In general mixing of signed and unsigned operands has undefined behavior for legacy reasons.

Most operators and methods attempt to coerce their arguments.  This was done to allow for
the use of python integer constants e.g. `bv + 1`.

Hwtypes provides two main implementations of the bitvector and bit types.  The first implementation provides a pure python functional model over constant values.  The second wraps pysmt to generate SMT expressions.

Magma provides a third implementation which allows for the definition of circuits.

This uniform interface allows the same hwtypes program to be interpreted in multiple ways: the pure python implementation can be used to simulate a circuit, the SMT implementation can be used to generate a formal model, and the magma implementation can be used to generate actual RTL.  This single source of truth is powerful for designers, as they need not implement the same thing multiple times.

Hwtypes is an expression language only; all statements are executed in pure python following typical python semantics.

In the following we see how the same function can be invoked with constant python values, symbolic smt values, or to build a magma circuit:

In [3]:
def add(x, y):
    '''Return `x + y`, using whatever `+` means for the operand types.'''
    total = x + y
    return total

In [4]:
# One 8-bit data type per backend: constant python values, symbolic SMT
# values, and magma circuit wires.
PyDataT = hw.BitVector[8]
SmtDataT = hw.SMTBitVector[8]
MagmaDataT = m.Bits[8]

# 1) Pure python: evaluate `add` on concrete constants.
x = PyDataT(1)
y = PyDataT(2)
results = add(x, y) 
print(repr(results))
print('---')

# 2) SMT: the same function applied to named symbolic variables; the result
#    is printed as an SMT-LIB string.
x = SmtDataT(name='x')
y = SmtDataT(name='y')
results = add(x, y)
print(smt_to_smtlib_string(results))
print('---')
# del because jupyter seems to keep references alive longer than it should which breaks SMT variables
del x
del y

# 3) Magma: the same function wired between the circuit's two input ports
#    and its output port; the circuit is then printed as verilog.
class Adder(m.Circuit):
    io = m.IO(x=m.In(MagmaDataT), y=m.In(MagmaDataT), results=m.Out(MagmaDataT))
    io.results @= add(io.x, io.y)
    
print(magma_to_verilog_string(Adder))
print('---')

BitVector[8](3)
---
(bvadd x y)
---
/tmp/tmpxg6qyudk/out.json
module coreir_add #(
    parameter width = 1
) (
    input [width-1:0] in0,
    input [width-1:0] in1,
    output [width-1:0] out
);
  assign out = in0 + in1;
endmodule

module Adder (
    input [7:0] x,
    input [7:0] y,
    output [7:0] results
);
wire [7:0] magma_Bits_8_add_inst0_out;
coreir_add #(
    .width(8)
) magma_Bits_8_add_inst0 (
    .in0(x),
    .in1(y),
    .out(magma_Bits_8_add_inst0_out)
);
assign results = magma_Bits_8_add_inst0_out;
endmodule


---


The real power of hwtypes comes from being embedded in python, which facilitates the generation of more complex formulas. For example, we can generalize `add` to build an adder tree over any number of inputs with the use of a recursive function:

In [5]:
def add_n(*args):
    '''
    Sum any number of operands with a balanced tree of `+`.

    No operands yields the plain integer 0; a single operand is returned
    unchanged; otherwise the operands are split at the midpoint and the
    sums of the two halves are added together.
    '''
    if not args:
        return 0
    if len(args) == 1:
        return args[0]
    mid = len(args) // 2
    left, right = args[:mid], args[mid:]
    return add_n(*left) + add_n(*right)

In [6]:
# 1) Pure python: sum three concrete constants with add_n.
x = PyDataT(1)
y = PyDataT(2)
z = PyDataT(3)
results = add_n(x, y, z) 
print(repr(results))
print('---')

# 2) SMT: the same call on three named symbolic variables.
x = SmtDataT(name='x')
y = SmtDataT(name='y')
z = SmtDataT(name='z')
results = add_n(x, y, z) 
print(smt_to_smtlib_string(results))
print('---')
# Drop the SMT variables before moving on (same clean-up as the earlier cell).
del x
del y
del z

# 3) Magma: a three-input circuit whose output is driven by add_n of its inputs.
class Adder3(m.Circuit):
    io = m.IO(x=m.In(MagmaDataT), 
              y=m.In(MagmaDataT), 
              z=m.In(MagmaDataT), 
              results=m.Out(MagmaDataT))
    io.results @= add_n(io.x, io.y, io.z)
    
print(magma_to_verilog_string(Adder3))
print('---')

BitVector[8](6)
---
(bvadd x (bvadd y z))
---
/tmp/tmpjr0uutio/out.json
module coreir_add #(
    parameter width = 1
) (
    input [width-1:0] in0,
    input [width-1:0] in1,
    output [width-1:0] out
);
  assign out = in0 + in1;
endmodule

module Adder3 (
    input [7:0] x,
    input [7:0] y,
    input [7:0] z,
    output [7:0] results
);
wire [7:0] magma_Bits_8_add_inst0_out;
wire [7:0] magma_Bits_8_add_inst1_out;
coreir_add #(
    .width(8)
) magma_Bits_8_add_inst0 (
    .in0(y),
    .in1(z),
    .out(magma_Bits_8_add_inst0_out)
);
coreir_add #(
    .width(8)
) magma_Bits_8_add_inst1 (
    .in0(x),
    .in1(magma_Bits_8_add_inst0_out),
    .out(magma_Bits_8_add_inst1_out)
);
assign results = magma_Bits_8_add_inst1_out;
endmodule


---


We can easily further generalize this to build reduction trees over any function. 

In [7]:
# Unique marker meaning "no identity was supplied"; compared by identity,
# so any real value (including None) can be passed as `ident`.
_MISSING = object()


def gen_tree_reducer(f, ident=_MISSING):
    '''
    Build a variadic reducer that combines its arguments with `f` in a
    balanced binary tree.

    f :: T -> T -> T
    ident :: Optional[T]   (returned when the reducer gets no arguments)
    '''
    def reducer(*args):
        '''
        *args :: List[T]

        Raises ValueError on an empty argument list when no `ident` was given.
        '''
        if not args:
            if ident is not _MISSING:
                return ident
            raise ValueError('cannot reduce empty list')
        if len(args) == 1:
            return args[0]
        # Split at the midpoint; the left half is reduced before the right.
        mid = len(args) // 2
        lhs = reducer(*args[:mid])
        rhs = reducer(*args[mid:])
        return f(lhs, rhs)
    return reducer

Using built-in higher-order functions (e.g. `map`) is also supported.

In [8]:
# Rebind add_n to the generic tree reducer, built from `add` with identity 0.
add_n = gen_tree_reducer(add, 0)


def sum_of_sq(*args):
    '''Return the sum (via add_n) of each argument multiplied by itself.'''
    def square(v):
        return v*v
    return add_n(*map(square, args))

In [9]:
# 1) Pure python: sum of squares of three concrete constants.
x = PyDataT(1)
y = PyDataT(2)
z = PyDataT(3)
results = sum_of_sq(x, y, z) 
print(repr(results))
print('---')

# 2) SMT: the same call on three named symbolic variables.
x = SmtDataT(name='x')
y = SmtDataT(name='y')
z = SmtDataT(name='z')
results = sum_of_sq(x, y, z) 
print(smt_to_smtlib_string(results))
print('---')
# Drop the SMT variables before moving on (same clean-up as the earlier cells).
del x
del y
del z

# 3) Magma: a three-input circuit whose output is driven by sum_of_sq of its inputs.
class SumOfSq3(m.Circuit):
    io = m.IO(x=m.In(MagmaDataT), 
              y=m.In(MagmaDataT), 
              z=m.In(MagmaDataT), 
              results=m.Out(MagmaDataT))
    io.results @= sum_of_sq(io.x, io.y, io.z)
    
print(magma_to_verilog_string(SumOfSq3))
print('---')

BitVector[8](14)
---
(bvadd (bvmul x x) (bvadd (bvmul y y) (bvmul z z)))
---
/tmp/tmp43tfp6pm/out.json
module coreir_mul #(
    parameter width = 1
) (
    input [width-1:0] in0,
    input [width-1:0] in1,
    output [width-1:0] out
);
  assign out = in0 * in1;
endmodule

module coreir_add #(
    parameter width = 1
) (
    input [width-1:0] in0,
    input [width-1:0] in1,
    output [width-1:0] out
);
  assign out = in0 + in1;
endmodule

module SumOfSq3 (
    input [7:0] x,
    input [7:0] y,
    input [7:0] z,
    output [7:0] results
);
wire [7:0] magma_Bits_8_add_inst0_out;
wire [7:0] magma_Bits_8_add_inst1_out;
wire [7:0] magma_Bits_8_mul_inst0_out;
wire [7:0] magma_Bits_8_mul_inst1_out;
wire [7:0] magma_Bits_8_mul_inst2_out;
coreir_add #(
    .width(8)
) magma_Bits_8_add_inst0 (
    .in0(magma_Bits_8_mul_inst1_out),
    .in1(magma_Bits_8_mul_inst2_out),
    .out(magma_Bits_8_add_inst0_out)
);
coreir_add #(
    .width(8)
) magma_Bits_8_add_inst1 (
    .in0(magma_Bits_8_mul_inst0_o

Note how all `if`s are resolved without regard to the values of the data. The hwtypes expression language allows for data-dependent conditionals using the `ite` method on the bit type.  For example, one might write an absolute value function as follows:

In [10]:
def abs(x):
    '''
    Absolute value expressed as a select: `-x` when `x < 0`, else `x`.

    NOTE(review): the name `abs` shadows the Python builtin; it is kept
    because the following cells call this function by that name.
    '''
    is_negative = x < 0
    return is_negative.ite(-x, x)

In [11]:
# Rebind the three backend data types to their 8-bit SInt variants
# (SIntVector / SMTSIntVector / m.SInt) for this example.
PyDataT = hw.SIntVector[8]
SmtDataT = hw.SMTSIntVector[8]
MagmaDataT = m.SInt[8]

# 1) Pure python: abs of the constant -1.
x = PyDataT(-1)
results = abs(x) 
print(repr(results))
print('---')

# 2) SMT: abs of a named symbolic variable.
x = SmtDataT(name='x')
results = abs(x) 
print(smt_to_smtlib_string(results))
print('---')
# Drop the SMT variable before moving on (same clean-up as the earlier cells).
del x
# 3) Magma: a one-input circuit whose output is driven by abs of its input.
class Abs(m.Circuit):
    io = m.IO(x=m.In(MagmaDataT), 
              results=m.Out(MagmaDataT))
    io.results @= abs(io.x)
    
print(magma_to_verilog_string(Abs))
print('---')

SIntVector[8](1)
---
(ite (bvslt x #b00000000) (bvneg x) x)
---
/tmp/tmp87oycvza/out.json
module coreir_slt #(
    parameter width = 1
) (
    input [width-1:0] in0,
    input [width-1:0] in1,
    output out
);
  assign out = $signed(in0) < $signed(in1);
endmodule

module coreir_neg #(
    parameter width = 1
) (
    input [width-1:0] in,
    output [width-1:0] out
);
  assign out = -in;
endmodule

module coreir_mux #(
    parameter width = 1
) (
    input [width-1:0] in0,
    input [width-1:0] in1,
    input sel,
    output [width-1:0] out
);
  assign out = sel ? in1 : in0;
endmodule

module coreir_const #(
    parameter width = 1,
    parameter value = 1
) (
    output [width-1:0] out
);
  assign out = value;
endmodule

module commonlib_muxn__N2__width8 (
    input [7:0] in_data [1:0],
    input [0:0] in_sel,
    output [7:0] out
);
wire [7:0] _join_out;
coreir_mux #(
    .width(8)
) _join (
    .in0(in_data[0]),
    .in1(in_data[1]),
    .sel(in_sel[0]),
    .out(_join_out)
);
assig

It is important to note that only constantly bounded recursion is possible. This ensures all hwtypes programs may be compiled to some finite circuit (and correspondingly some finite formula in first order logic). This restriction is easy to enforce as the `ite` method behaves like any other python method call. In particular its arguments will be evaluated eagerly, meaning unbounded data dependent recursion will recurse infinitely:

In [12]:
def factorial(x):
    '''
    Naive recursive factorial, shown as an example of what does NOT work.

    `ite` is an ordinary method call, so both of its arguments are evaluated
    before it runs. The recursive call is therefore made on every
    invocation regardless of the value of `x`, and the function never
    terminates normally: calling it raises RecursionError.

    The branches are given in (then, else) order: 1 when `x == 1`,
    otherwise `x * factorial(x - 1)`.
    '''
    return (x == 1).ite(
        type(x)(1),  # base case: cast 1 to the type of x
        x*factorial(x - 1), # factorial(x - 1) will always be evaluated leading to infinite recursion
    )

# Call the naive factorial on a constant. A RecursionError is caught and its
# message printed; the `else` branch only runs if the call returns normally.
try:
    x = factorial(PyDataT(5))
except RecursionError as e:
    print(f'Error: {e}')
else:
    print(f'5! = {x}')

Error: maximum recursion depth exceeded in comparison


Now the above program could be unrolled explicitly, which results in a quite large but finite circuit.  

In [13]:
# Rebind the pure python data type to hw.UIntVector[8] for this example.
PyDataT = hw.UIntVector[8]

def bounded_factorial(x):
    '''
    Factorial-style product unrolled to a fixed, data-independent depth.

    x must be an instance of hw.AbstractBitVector (TypeError otherwise).
    The recursion is cut off by a plain python counter starting at
    2**x.size - 1, so building the expression always terminates even though
    both `ite` branches are evaluated at every level.
    '''
    if not isinstance(x, hw.AbstractBitVector):
        raise TypeError()
    vec_t = type(x)
    depth_limit = 2**x.size - 1

    def unroll(val, remaining):
        # `remaining` is a python int: this test is resolved while the
        # expression is being built and does not depend on the data.
        if remaining == 0:
            return vec_t(1)
        at_base = val == 1
        return at_base.ite(
            vec_t(1),
            val*unroll(val - 1, remaining - 1),
        )

    return unroll(x, depth_limit)

# Run the bounded version on the constant 9. The RecursionError handler
# mirrors the earlier factorial cell; the `else` branch prints the result
# when the call returns normally.
x = PyDataT(9)
try:
    y = bounded_factorial(x)
except RecursionError as e:
    print(f'Error: {e}')
else:
    print(f'{x}! = {y}')

9! = 128


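Note that this circuit does not in fact compute the factorial function. Because the arithmetic wraps around at the bit width, and the base case in the code is `x == 1`, it instead computes:
$$
\begin{aligned}
f(1) &= 1 \\
f(x) &= x \cdot f(x-1) \bmod 2^{\text{bitwidth}(x)}
\end{aligned}
$$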

In [14]:
# NOTE(review): neither `family_closure` nor `Peak` is imported in any cell
# visible here, and the recorded output of this cell is
# "NameError: name 'family_closure' is not defined". Presumably both come
# from the PEak package -- TODO confirm and add the import.
@family_closure
def closure(family):
    # All types are taken from `family`, so the class below is built from
    # whichever family object is passed in.
    DataT = family.BitVector[8]
    # Not referenced by name below, but still part of locals() handed to
    # family.assemble -- do not remove or rename without checking.
    Bit = family.Bit
    # presumably a register type holding a DataT with initial value 0 -- TODO confirm
    Reg = family.gen_register(DataT, 0)
    # assemble receives this function's local and global namespaces.
    @family.assemble(locals(), globals())
    class Ctr(Peak):
        def __init__(self):
            self.reg = Reg()
            
        # The argument `i` is not used in the body.
        def __call__(self, i: DataT) -> DataT:
            # Read self.reg, assign `prev + 1` back to it, and return the
            # value that was read.
            prev = self.reg
            self.reg = prev + 1
            return prev
    return Ctr

NameError: name 'family_closure' is not defined

In [None]:
# class T(enum): ...

# class Ext(enum, extends=T): ... # copy elements of T into Ext
    
# T < Ext

#from hwtypes import Sum
#T = Sum[hw.Bit, hw.BitVector[4]]
#S = Sum[hw.Bit, hw.BitVector[4], hw.BitVector[6]]
#issubclass(S, T)