<a href="https://colab.research.google.com/github/alanvgreen/CFU-Playground/blob/fccm2/proj/fccm_tutorial/Amaranth_for_CFUs.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Amaranth for CFUs

```
Copyright 2022 Google LLC.
SPDX-License-Identifier: Apache-2.0
```
This page shows:

1. Incremental building of an Amaranth CFU
2. Simple examples of Amaranth's language features.

Also see:

* https://github.com/amaranth-lang/amaranth
* Docs: https://amaranth-lang.org/docs/amaranth/latest/

avg@google.com / 2022-04-19



This next cell initialises the libraries and Python path. Execute it before any others.

In [None]:
# Install Amaranth 
!pip install --upgrade 'amaranth[builtin-yosys]'

# CFU-Playground library
!git clone https://github.com/google/CFU-Playground.git
import sys
sys.path.append('CFU-Playground/python')

# Imports
from amaranth import *
from amaranth.back import verilog
from amaranth.sim import Delay, Simulator, Tick
from amaranth_cfu import TestBase, SimpleElaboratable, pack_vals, simple_cfu, InstructionBase, CfuTestBase
import re, unittest

# Utility to convert Amaranth to verilog 
def convert_elaboratable(elaboratable):
  """Convert an Amaranth elaboratable to a Verilog string for display.

  The top-level module is named ``Top`` and its ports are taken from
  ``elaboratable.ports``. Verilog attribute blocks ``(* ... *)`` are removed,
  and lines left containing only spaces are emptied, so the output is easier
  to read.
  """
  v = verilog.convert(elaboratable, name='Top', ports=elaboratable.ports)
  # Non-greedy match: with a greedy `.*`, two attribute blocks on the same
  # line would also delete the Verilog text between them.
  v = re.sub(r'\(\*.*?\*\)', '', v)
  return re.sub(r'^ *\n', '\n', v, flags=re.MULTILINE)

def runTests(klazz):
  """Run every test method of the TestCase class ``klazz`` and print a report.

  Results are written by ``unittest.TextTestRunner``; nothing is returned.
  """
  tests = unittest.TestLoader().loadTestsFromTestCase(klazz)
  unittest.TextTestRunner().run(unittest.TestSuite(tests))

## Four-way Multiply-Accumulate

These cells demonstrate the evolution of a full four-way multiply-accumulate CFU instruction.

### SingleMultiply

Demonstrates a simple calculation: `(a+128)*b`

In [None]:
class SingleMultiply(SimpleElaboratable):
  """Combinationally computes ``result = (a + 128) * b``.

  ``a`` and ``b`` are signed 8-bit inputs; ``result`` is a signed 32-bit output.
  """
  def __init__(self):
    self.a = Signal(signed(8))
    self.b = Signal(signed(8))
    self.result = Signal(signed(32))

  def elab(self, m):
    # Adding 128 moves a signed byte's range [-128, 127] up to [0, 255].
    offset_a = self.a + 128
    m.d.comb += self.result.eq(offset_a * self.b)

class SingleMultiplyTest(TestBase):
  def create_dut(self):
    return SingleMultiply()

  def test(self):
    # Each entry: (a, b, expected result)
    TEST_CASE = [
      (1-128, 1, 1),
      (33-128, -25, 33*-25),
    ]
    dut = self.dut

    def process():
      for a_val, b_val, want in TEST_CASE:
        yield dut.a.eq(a_val)
        yield dut.b.eq(b_val)
        yield Delay(0.1)  # let the combinational result settle
        got = yield dut.result
        self.assertEqual(want, got)
        yield  # step on to the next cycle before the next case
    self.run_sim(process)

runTests(SingleMultiplyTest)

.
----------------------------------------------------------------------
Ran 1 test in 0.012s

OK


### WordMultiplyAdd

Performs four `(a + 128) * b` operations in parallel, and adds the results.

In [None]:
class WordMultiplyAdd(SimpleElaboratable):
  """Sums four parallel ``(a + 128) * b`` products over packed bytes.

  ``a_word`` and ``b_word`` each hold four bytes (byte 0 in bits 7:0), each
  interpreted as signed. ``result`` is the signed 32-bit sum of the products.
  """
  def __init__(self):
    self.a_word = Signal(32)
    self.b_word = Signal(32)
    self.result = Signal(signed(32))

  def elab(self, m):
    products = []
    for lsb in range(0, 32, 8):
      a_byte = self.a_word[lsb:lsb + 8].as_signed()
      b_byte = self.b_word[lsb:lsb + 8].as_signed()
      products.append((a_byte + 128) * b_byte)
    m.d.comb += self.result.eq(sum(products))


class WordMultiplyAddTest(TestBase):
  def create_dut(self):
    return WordMultiplyAdd()

  def test(self):
    # `a` bytes are listed as their (a + 128) values and packed with
    # offset=-128; `b` bytes are packed unchanged. The expected sums can
    # therefore be computed directly from the numbers written below.
    def a(w, x, y, z): return pack_vals(w, x, y, z, offset=-128)
    def b(w, x, y, z): return pack_vals(w, x, y, z, offset=0)
    TEST_CASE = [
        (a(99, 22, 2, 1), b(-2, 6, 7, 111), 59),
        (a(63, 161, 15, 0), b(29, 13, 62, -38), 4850),
    ]
    dut = self.dut

    def process():
      for a_word, b_word, expected in TEST_CASE:
        yield dut.a_word.eq(a_word)
        yield dut.b_word.eq(b_word)
        yield Delay(0.1)  # let the combinational result settle
        actual = yield dut.result
        self.assertEqual(expected, actual)
        yield
    self.run_sim(process)

runTests(WordMultiplyAddTest)

.
----------------------------------------------------------------------
Ran 1 test in 0.007s

OK


### WordMultiplyAccumulate

Adds an accumulator to the four-way multiply and add operation.

Includes an `enable` signal to control when accumulation takes place and a `clear` signal to reset the accumulator. If both are asserted in the same cycle, `clear` takes priority.

In [None]:
class WordMultiplyAccumulate(SimpleElaboratable):
  """Accumulates the sum of four ``(a + 128) * b`` byte products.

  On each clock edge of the sync domain: if ``clear`` is high the accumulator
  is set to zero; otherwise, if ``enable`` is high, the four-way sum of the
  current ``a_word``/``b_word`` bytes is added to it.
  """
  def __init__(self):
    self.a_word = Signal(32)
    self.b_word = Signal(32)
    self.accumulator = Signal(signed(32))
    self.enable = Signal()
    self.clear = Signal()

  def elab(self, m):
    summed = sum(
        (self.a_word[i:i + 8].as_signed() + 128) * self.b_word[i:i + 8].as_signed()
        for i in range(0, 32, 8))
    # clear has priority over enable.
    with m.If(self.clear):
      m.d.sync += self.accumulator.eq(0)
    with m.Elif(self.enable):
      m.d.sync += self.accumulator.eq(self.accumulator + summed)


class WordMultiplyAccumulateTest(TestBase):
  def create_dut(self):
    return WordMultiplyAccumulate()

  def test(self):
    def a(w, x, y, z): return pack_vals(w, x, y, z, offset=-128)
    def b(w, x, y, z): return pack_vals(w, x, y, z, offset=0)
    # Each row: ((a_word, b_word, enable, clear), expected accumulator).
    # The accumulator is checked before the clock tick, so each expected value
    # reflects the inputs applied on the previous row.
    DATA = [
        ((a(0, 0, 0, 0),  b(0, 0, 0, 0), 0, 0), 0),

        # Simple tests: only the first byte is used
        ((a(10, 0, 0, 0), b(3, 0, 0, 0),  1, 0),   0),
        ((a(11, 0, 0, 0), b(-4, 0, 0, 0), 1, 0),  30),
        ((a(11, 0, 0, 0), b(-4, 0, 0, 0), 0, 0), -14),
        # The previous row was not enabled, so the accumulator does not change
        ((a(11, 0, 0, 0), b(-4, 0, 0, 0), 1, 0), -14),
        # The previous row was enabled, so the accumulator changes
        ((a(11, 0, 0, 0), b(-4, 0, 0, 0), 0, 1), -58),
        # The previous row cleared the accumulator
        ((a(11, 0, 0, 0), b(-4, 0, 0, 0), 0, 0),  0),

        # All bytes used (values calculated on a spreadsheet)
        ((a(99, 22, 2, 1),      b(-2, 6, 7, 111), 1, 0),             0),
        ((a(2, 45, 79, 22),     b(-33, 6, -97, -22), 1, 0),         59),
        ((a(23, 34, 45, 56),    b(-128, -121, 119, 117), 1, 0),  -7884),
        ((a(188, 34, 236, 246), b(-87, 56, 52, -117), 1, 0),     -3035),
        ((a(131, 92, 21, 83),   b(-114, -72, -31, -44), 1, 0),  -33997),
        ((a(74, 68, 170, 39),   b(102, 12, 53, -128), 1, 0),    -59858),
        ((a(16, 63, 1, 198),    b(29, 36, 106, 62), 1, 0),      -47476),
        ((a(0, 0, 0, 0),        b(0, 0, 0, 0), 0, 1),           -32362),

        # Regression case (originally labelled "Interesting bug"): (a + 128)
        # values of 128 and above, which do not fit in a signed byte.
        ((a(128, 0, 0, 0), b(-104, 0, 0, 0), 1, 0), 0),
        ((a(0, 51, 0, 0), b(0, 43, 0, 0), 1, 0), -13312),
        ((a(0, 0, 97, 0), b(0, 0, -82, 0), 1, 0), -11119),
        ((a(0, 0, 0, 156), b(0, 0, 0, -83), 1, 0), -19073),
        ((a(0, 0, 0, 0), b(0, 0, 0, 0), 1, 0), -32021),
    ]

    dut = self.dut

    def process():
      for inputs, expected in DATA:
        a_word, b_word, enable, clear = inputs
        yield dut.a_word.eq(a_word)
        yield dut.b_word.eq(b_word)
        yield dut.enable.eq(enable)
        yield dut.clear.eq(clear)
        yield Delay(0.1)  # let the input values settle

        # Accumulator value as registered on the previous clock edge
        actual = yield dut.accumulator
        self.assertEqual(expected, actual)
        yield Tick()
    self.run_sim(process)

runTests(WordMultiplyAccumulateTest)

.
----------------------------------------------------------------------
Ran 1 test in 0.017s

OK


### CFU Wrapper

Wraps the preceding logic in a CFU. Uses funct7 to determine what function the WordMultiplyAccumulate unit should perform.

In [None]:
class Macc4Instruction(InstructionBase):
    """CFU instruction giving access to a WordMultiplyAccumulate.

    funct7 selects the operation:
        * 0: Reset accumulator
        * 1: 4-way multiply accumulate.
        * 2: Read accumulator
    """

    def elab(self, m):
        macc4 = WordMultiplyAccumulate()
        m.submodules.macc4 = macc4

        # Operands go straight to the unit. The accumulator is always driven
        # onto the output; only function 2 treats it as a meaningful result.
        m.d.comb += [
            macc4.a_word.eq(self.in0),
            macc4.b_word.eq(self.in1),
            self.output.eq(macc4.accumulator),
        ]

        with m.If(self.start):
            # Every function hands control back to the CPU on the next cycle.
            m.d.comb += self.done.eq(1)
            # funct7 == 0 clears, funct7 == 1 accumulates.
            m.d.comb += macc4.clear.eq(self.funct7 == 0)
            m.d.comb += macc4.enable.eq(self.funct7 == 1)


def make_cfu():
    """Build a CFU with Macc4Instruction registered under key 0.

    The tests below address this instruction with funct3 == 0.
    """
    instructions = {0: Macc4Instruction()}
    return simple_cfu(instructions)


class CfuTest(CfuTestBase):
    """Drives the full CFU through its instruction interface."""

    def create_dut(self):
        return make_cfu()

    def test(self):
        "Tests CFU plumbs to Macc4 correctly"
        # `a` bytes are listed as their (a + 128) values and packed with
        # offset=-128; `b` bytes are packed unchanged.
        def a(a, b, c, d): return pack_vals(a, b, c, d, offset=-128)
        def b(a, b, c, d): return pack_vals(a, b, c, d, offset=0)
        # These values were calculated with a spreadsheet.
        # Each row: ((funct3, funct7, op1, op2), expected result); a result of
        # None presumably means the output is not checked -- see run_ops.
        DATA = [
            ((0, 0, 0, 0), None),  # reset
            ((0, 1, a(130, 7, 76, 47), b(104, -14, -24, 71)), None),  # calculate
            ((0, 1, a(84, 90, 36, 191), b(109, 57, -50, -1)), None),
            ((0, 1, a(203, 246, 89, 178), b(-87, 26, 77, 71)), None),
            ((0, 1, a(43, 27, 78, 167), b(-24, -8, 65, 124)), None),
            ((0, 2, 0, 0), 59986),  # read result

            ((0, 0, 0, 0), None),  # reset
            ((0, 1, a(67, 81, 184, 130), b(81, 38, -116, 65)), None),
            ((0, 1, a(208, 175, 180, 198), b(-120, -70, 8, 11)), None),
            ((0, 1, a(185, 81, 101, 108), b(90, 6, -92, 83)), None),
            ((0, 1, a(219, 216, 114, 236), b(-116, -9, -109, -16)), None),
            ((0, 2, 0, 0), -64723),  # read result

            ((0, 0, 0, 0), None),  # reset
            ((0, 1, a(128, 0, 0, 0), b(-104, 0, 0, 0)), None),
            ((0, 1, a(0, 51, 0, 0),  b(0, 43, 0, 0)), None),
            ((0, 1, a(0, 0, 97, 0),  b(0, 0, -82, 0)), None),
            ((0, 1, a(0, 0, 0, 156), b(0, 0, 0, -83)), None),
            ((0, 2, a(0, 0, 0, 0),   b(0, 0, 0, 0)), -32021),
        ]
        self.run_ops(DATA)

runTests(CfuTest)

.
----------------------------------------------------------------------
Ran 1 test in 0.063s

OK


## Amaranth to Verilog Examples

These examples show Amaranth and the Verilog it is translated into.

### SyncAndComb

Demonstrates synchronous and combinational logic with a simple component that outputs the high bit of a 12-bit counter.

In [None]:
class SyncAndComb(Elaboratable):
  """Outputs the most significant bit of a free-running 12-bit counter."""
  def __init__(self):
    self.out = Signal(1)
    self.ports = [self.out]

  def elaborate(self, platform):
    m = Module()
    counter = Signal(12)
    # Synchronous: the counter increments on every clock edge.
    m.d.sync += counter.eq(counter + 1)
    # Combinational: `out` follows the counter's top bit (bit 11).
    m.d.comb += self.out.eq(counter[11])
    return m

print(convert_elaboratable(SyncAndComb()))

/* Generated by Amaranth Yosys 0.10.0 (PyPI ver 0.10.0.dev46, git sha1 dca8fb54a) */




module Top(clk, rst, out);
  reg \initial  = 0;

  wire [12:0] \$1 ;

  wire [12:0] \$2 ;

  input clk;

  reg [11:0] counter = 12'h000;

  reg [11:0] \counter$next ;

  output out;

  input rst;
  assign \$2  = counter +  1'h1;
  always @(posedge clk)
    counter <= \counter$next ;
  always @* begin
    if (\initial ) begin end
    \counter$next  = \$2 [11:0];

    casez (rst)
      1'h1:
          \counter$next  = 12'h000;
    endcase
  end
  assign \$1  = \$2 ;
  assign out = counter[11];
endmodule



### Conditional Enable

Demonstrates Amaranth's equivalent of Verilog's `if` statement. A five-bit counter is incremented when input signal `up` is high, or decremented when `down` is high; if both are high, `up` takes priority.

In [None]:
class ConditionalEnable(Elaboratable):
  """Five-bit up/down counter; ``up`` takes priority over ``down``."""
  def __init__(self):
    self.up = Signal()
    self.down = Signal()
    self.value = Signal(5)
    self.ports = [self.value, self.up, self.down]

  def elaborate(self, platform):
    m = Module()
    incremented = self.value + 1
    decremented = self.value - 1
    with m.If(self.up):
      m.d.sync += self.value.eq(incremented)
    with m.Elif(self.down):
      m.d.sync += self.value.eq(decremented)
    # With neither input high, `value` keeps its current contents.
    return m

print(convert_elaboratable(ConditionalEnable()))
    

/* Generated by Amaranth Yosys 0.10.0 (PyPI ver 0.10.0.dev46, git sha1 dca8fb54a) */




module Top(up, down, clk, rst, value);
  reg \initial  = 0;

  wire [5:0] \$1 ;

  wire [5:0] \$2 ;

  wire [5:0] \$4 ;

  wire [5:0] \$5 ;

  input clk;

  input down;

  input rst;

  input up;

  output [4:0] value;
  reg [4:0] value = 5'h00;

  reg [4:0] \value$next ;
  assign \$2  = value +  1'h1;
  assign \$5  = value -  1'h1;
  always @(posedge clk)
    value <= \value$next ;
  always @* begin
    if (\initial ) begin end
    \value$next  = value;

    casez ({ down, up })
      /* src = "<ipython-input-5-5a5e0372ea90>:10" */
      2'b?1:
          \value$next  = \$2 [4:0];
      /* src = "<ipython-input-5-5a5e0372ea90>:12" */
      2'b1?:
          \value$next  = \$5 [4:0];
    endcase

    casez (rst)
      1'h1:
          \value$next  = 5'h00;
    endcase
  end
  assign \$1  = \$2 ;
  assign \$4  = \$5 ;
endmodule



### EdgeDetector

Simple edge detector, along with a test case.

In [None]:
class EdgeDetector(SimpleElaboratable):
  """Detects low-high transitions in a signal.

  ``detected`` is high while ``input`` is high and was low on the previous
  clock cycle.
  """
  def __init__(self):
    self.input = Signal()
    self.detected = Signal()
    self.ports = [self.input, self.detected]

  def elab(self, m):
    last = Signal()
    # Register the input so the previous cycle's value is available.
    m.d.sync += last.eq(self.input)
    rising = self.input & ~last
    m.d.comb += self.detected.eq(rising)
    
class EdgeDetectorTestCase(TestBase):
  """Table-driven check of EdgeDetector's rising-edge output."""
  def create_dut(self):
    return EdgeDetector()

  def test_with_table(self):
    # Each entry: (input value, expected `detected` value)
    TEST_CASE = [
      (0, 0),
      (1, 1),
      (0, 0),
      (0, 0),
      (1, 1),
      (1, 0),
      (0, 0),
    ]
    def process():
      # `value` rather than `input`, to avoid shadowing the builtin.
      for (value, expected) in TEST_CASE:
        # Set input
        yield self.dut.input.eq(value)
        # Allow some time for signals to propagate
        yield Delay(0.1)
        self.assertEqual(expected, (yield self.dut.detected))
        yield
    self.run_sim(process)

runTests(EdgeDetectorTestCase)
print(convert_elaboratable(EdgeDetector()))

.
----------------------------------------------------------------------
Ran 1 test in 0.006s

OK


/* Generated by Amaranth Yosys 0.10.0 (PyPI ver 0.10.0.dev46, git sha1 dca8fb54a) */




module Top(detected, clk, rst, \input );
  reg \initial  = 0;

  wire \$1 ;

  wire \$3 ;

  input clk;

  output detected;

  input \input ;

  reg last = 1'h0;

  reg \last$next ;

  input rst;
  assign \$1  = ~  last;
  assign \$3  = \input  &  \$1 ;
  always @(posedge clk)
    last <= \last$next ;
  always @* begin
    if (\initial ) begin end
    \last$next  = \input ;

    casez (rst)
      1'h1:
          \last$next  = 1'h0;
    endcase
  end
  assign detected = \$3 ;
endmodule

