<a href="https://colab.research.google.com/github/bsbatusesli/CFU-Playground/blob/main/python/amaranth_cfu/SIMD_Implementation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [28]:
# Install Amaranth 
!pip install --upgrade 'amaranth[builtin-yosys]'

# CFU-Playground library
!git clone https://github.com/google/CFU-Playground.git
import sys
sys.path.append('CFU-Playground/python')

# Imports
from amaranth import *
from amaranth.back import verilog
from amaranth.sim import Delay, Simulator, Tick
from amaranth_cfu import TestBase, SimpleElaboratable, pack_vals, simple_cfu, InstructionBase, CfuTestBase
import re, unittest

# Utility to convert Amaranth to verilog 
def convert_elaboratable(elaboratable):
  """Return Verilog text for `elaboratable` with `(* ... *)` spans removed.

  The design is converted under the top-level name 'Top', using the
  elaboratable's own `ports` attribute as the port list. After conversion,
  every `(* ... *)` span is deleted and lines holding only spaces are
  reduced to bare newlines.
  """
  attribute_span = re.compile(r'\(\*.*\*\)')
  space_only_line = re.compile(r'^ *\n', flags=re.MULTILINE)

  raw = verilog.convert(elaboratable, name='Top', ports=elaboratable.ports)
  return space_only_line.sub('\n', attribute_span.sub('', raw))

def runTests(klazz):
  """Run every test of the unittest.TestCase class `klazz`.

  The tests are collected into a fresh suite and executed with a
  unittest.TextTestRunner. Nothing is returned.
  """
  collected = unittest.TestLoader().loadTestsFromTestCase(klazz)
  suite = unittest.TestSuite()
  suite.addTests(collected)
  unittest.TextTestRunner().run(suite)

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
fatal: destination path 'CFU-Playground' already exists and is not an empty directory.


In [10]:
# Module for storing the codebook.
# The codebook has 4 clusters, so each codeword is 2 bits wide.
from amaranth import C, Module, Signal, signed
from amaranth_cfu import all_words, InstructionBase, InstructionTestBase, pack_vals, simple_cfu
import unittest

class SIMD8_StoreCodebook(SimpleElaboratable):
  """Codebook holder: four 8-bit cluster values packed into 32 bits.

  Signals:
    in0      -- 32-bit input line (declared, not read by elab)
    in1      -- 32-bit input line carrying the four new cluster bytes
    enable   -- when set, `clusters` follows `in1`
    clear    -- when set, `clusters` is driven to 0 (wins over enable)
    clusters -- the four packed cluster bytes
  """

  def __init__(self):
    # 2 x 32 bit input lines
    self.in0 = Signal(32)
    self.in1 = Signal(32)
    # control lines
    self.enable = Signal()
    self.clear = Signal()

    self.clusters = Signal(32)

  def elab(self, m):
    byte_lanes = lambda sig: all_words(sig, 8)

    for stored, incoming in zip(byte_lanes(self.clusters), byte_lanes(self.in1)):
      # Priority per lane: clear, then enable, otherwise keep the lane.
      # NOTE(review): the "keep" branch feeds the lane back to itself in the
      # comb domain rather than using a register -- confirm this is intended
      # outside of simulation.
      with m.If(self.clear):
        m.d.comb += stored.eq(0)
      with m.Elif(self.enable):
        m.d.comb += stored.eq(incoming.as_signed())
      with m.Else():
        m.d.comb += stored.eq(stored)

class SIMD8_StoreCodebookTest(TestBase):
  """Simulation test of SIMD8_StoreCodebook: store, hold, re-store, clear."""

  def create_dut(self):
    return SIMD8_StoreCodebook()

  def test(self):

    def pack(*values, bits=8):
      # Pack values lowest-position first, `bits` bits each. Masking yields
      # the two's-complement encoding for negative values.
      mask = (1 << bits) - 1
      return sum((v & mask) << (i * bits) for i, v in enumerate(values))

    first_book = pack(64, -6, 10, -22)       # 0xEA0AFA40
    corner_book = pack(127, -127, 0, -17)
    ignored_input = pack(54, -44, 3, -12)

    # (enable, clear, in1, expected clusters)
    vectors = [
        (1, 0, first_book, first_book),        # store the codebook
        (0, 0, ignored_input, first_book),     # enable == 0: output must not change
        (1, 0, corner_book, corner_book),      # store a new codebook (corner cases)
        (0, 1, ignored_input, pack(0, 0, 0, 0)),  # clear the codebook
    ]

    def process():
      for enable, clear, in1, expected in vectors:
        yield self.dut.enable.eq(enable)
        yield self.dut.clear.eq(clear)
        yield self.dut.in1.eq(in1)
        yield Delay(0.1)

        self.assertEqual(expected, (yield self.dut.clusters))
        yield

    self.run_sim(process)

runTests(SIMD8_StoreCodebookTest)

.
----------------------------------------------------------------------
Ran 1 test in 0.013s

OK


In [None]:
class SIMD8_StoreWeights(SimpleElaboratable):
  """Holds 64 bits of weight codes and exposes one 16-bit slot of them.

  Signals:
    in0, in1     -- 32-bit inputs; concatenated (in0 in the low half) into
                    `weights` while `store_en` is set
    store_en     -- load enable
    control      -- 2-bit slot selector
    weight_codes -- the selected 16-bit slot of `weights`
    weights      -- the 64 stored bits
  """

  def __init__(self):
    # inputs
    self.in0 = Signal(32)
    self.in1 = Signal(32)
    self.store_en = Signal(1)
    self.control = Signal(2)

    # outputs
    self.weight_codes = Signal(16)

    # storage
    self.weights = Signal(64)

  def elab(self, m):
    slot_bits = 16

    # NOTE(review): the hold branch is a comb self-assignment, not a
    # register -- confirm this is intended outside of simulation.
    with m.If(self.store_en):
      m.d.comb += self.weights.eq(Cat(self.in0, self.in1))
    with m.Else():
      m.d.comb += self.weights.eq(self.weights)

    # control == k selects bits [16k, 16k + 16) of the stored weights.
    with m.Switch(self.control):
      for slot in range(4):
        with m.Case(slot):
          low = slot * slot_bits
          m.d.comb += self.weight_codes.eq(self.weights[low:low + slot_bits])

class SIMD8_StoreWeightsTests(TestBase):
  """Simulation test of SIMD8_StoreWeights: load once, then read each slot."""

  def create_dut(self):
    return SIMD8_StoreWeights()

  def test(self):
    low_word, high_word = 0x01FFFA32, 0x12ED547A
    all_ones = 0xFFFFFFFF

    # (store_en, control, in0, in1, expected weight_codes)
    vectors = [
        (1, 0b00, low_word, high_word, 0xFA32),  # load weights, check first slot
        (0, 0b01, low_word, high_word, 0x01FF),  # second slot
        (0, 0b10, low_word, high_word, 0x547A),  # third slot
        (0, 0b11, low_word, high_word, 0x12ED),  # fourth slot
        (0, 0b11, all_ones, all_ones, 0x12ED),   # store_en == 0: inputs are ignored
    ]

    def process():
      for store_en, control, in0, in1, expected in vectors:
        yield self.dut.store_en.eq(store_en)
        yield self.dut.control.eq(control)
        yield self.dut.in0.eq(in0)
        yield self.dut.in1.eq(in1)
        yield Delay(0.1)

        self.assertEqual(expected, (yield self.dut.weight_codes))
        yield

    self.run_sim(process)

runTests(SIMD8_StoreWeightsTests)
      


.
----------------------------------------------------------------------
Ran 1 test in 0.009s

OK


In [None]:
class SIMD8_Weights(SimpleElaboratable):
  """Stores 32 two-bit weight codes and decodes eight of them per slot.

  `store_en` and `control` pass through a sync register before use, so they
  take effect one clock after being applied.

  Signals:
    in0, in1        -- 32-bit inputs holding the codes (in0 = w1..w16,
                       in1 = w17..w32, two bits per code)
    store_en        -- load enable (registered)
    control         -- 2-bit slot selector (registered)
    clusters        -- four packed 8-bit codebook entries
    decoded_weights -- eight 8-bit weights decoded from the selected slot
    weight_codes    -- the selected 16-bit slot of `weights`
    weights         -- the 64 stored code bits
  """

  def __init__(self):
    # inputs
    self.in0 = Signal(32)
    self.in1 = Signal(32)
    self.store_en = Signal(1)
    self.control = Signal(2)

    self.clusters = Signal(32)

    # outputs
    self.decoded_weights = Signal(64)

    # registered copies of the control inputs
    self.control_next = Signal(2)
    self.control_reg = Signal(2)

    self.store_en_next = Signal(1)
    self.store_en_reg = Signal(1)

    self.weight_codes = Signal(16)
    self.weights = Signal(64)

  def elab(self, m):
    # *_next mirrors the input combinationally; *_reg captures it on sync.
    m.d.comb += self.control_next.eq(self.control)
    m.d.comb += self.store_en_next.eq(self.store_en)
    m.d.sync += self.control_reg.eq(self.control_next)
    m.d.sync += self.store_en_reg.eq(self.store_en_next)

    # While the registered store enable is set, weights = {in1, in0}:
    #   weights(64b) = MSB (w32, w31, ..., w2, w1) LSB, two bits per weight.
    # NOTE(review): the hold branch is a comb self-assignment, not a
    # register -- confirm this is intended outside of simulation.
    with m.If(self.store_en_reg):
      m.d.comb += self.weights.eq(Cat(self.in0, self.in1))
    with m.Else():
      m.d.comb += self.weights.eq(self.weights)

    # Slot selection (LSB first within a slot):
    #   control=0: w1..w8     control=1: w9..w16
    #   control=2: w17..w24   control=3: w25..w32
    slot_bits = 16
    with m.Switch(self.control_reg):
      for slot in range(4):
        with m.Case(slot):
          low = slot * slot_bits
          m.d.comb += self.weight_codes.eq(self.weights[low:low + slot_bits])

    # Decode: each 2-bit code indexes one byte of `clusters`.
    cluster_bytes = [self.clusters[8 * k:8 * (k + 1)] for k in range(4)]
    code_lanes = all_words(self.weight_codes, 2)
    weight_lanes = all_words(self.decoded_weights, 8)
    for code, weight in zip(code_lanes, weight_lanes):
      with m.Switch(code):
        for index, cluster in enumerate(cluster_bytes):
          with m.Case(index):
            m.d.comb += weight.eq(cluster)


class SIMD8_WeightsTests(TestBase):
  """Simulation test of SIMD8_Weights: load codes, then decode each slot."""

  def create_dut(self):
    return SIMD8_Weights()

  def test(self):
    low_word, high_word = 0x01FFFA32, 0x12ED547A
    all_ones = 0xFFFFFFFF
    book = 0xAABBCCDD

    # (store_en, control, in0, in1, clusters, expected weight_codes, expected decoded_weights)
    vectors = [
        # 0xFA32 = 11_11_10_10_00_11_00_10
        (1, 0b00, low_word, high_word, book, 0xFA32, 0xAAAABBBBDDAADDBB),
        # second slot: 0x01FF = 00_00_00_01_11_11_11_11
        (0, 0b01, low_word, high_word, book, 0x01FF, 0xDDDDDDCCAAAAAAAA),
        # third slot: 0x547A = 01_01_01_00_01_11_10_10
        (0, 0b10, low_word, high_word, book, 0x547A, 0xCCCCCCDDCCAABBBB),
        # fourth slot: 0x12ED = 00_01_00_10_11_10_11_01
        (0, 0b11, low_word, high_word, book, 0x12ED, 0xDDCCDDBBAABBAACC),
        # store_en == 0: new inputs must be ignored
        (0, 0b11, all_ones, all_ones, book, 0x12ED, 0xDDCCDDBBAABBAACC),
    ]

    def process():
      for store_en, control, in0, in1, clusters, want_codes, want_weights in vectors:
        yield self.dut.store_en.eq(store_en)
        yield self.dut.control.eq(control)
        yield self.dut.in0.eq(in0)
        yield self.dut.in1.eq(in1)
        yield self.dut.clusters.eq(clusters)
        yield Delay(0.1)
        # Two bare yields before sampling, leaving time for the registered
        # control inputs to take effect.
        yield
        yield
        self.assertEqual(want_codes, (yield self.dut.weight_codes))
        self.assertEqual(want_weights, (yield self.dut.decoded_weights))
        yield

    self.run_sim(process, write_trace= False)


  
    
runTests(SIMD8_WeightsTests)


.
----------------------------------------------------------------------
Ran 1 test in 0.021s

OK


In [None]:
from amaranth import C, Module, Signal, signed
from amaranth_cfu import all_words, InstructionBase, InstructionTestBase, pack_vals, simple_cfu
import unittest

class SIMD8_MAC(SimpleElaboratable):
  """Eight-lane signed 8-bit multiply with a 32-bit accumulator.

  Signals:
    in0, in1        -- 32-bit inputs, four 8-bit values each
    decoded_weights -- eight 8-bit weights; the low 32 bits pair with in0,
                       the high 32 bits pair with in1
    accumulate_en   -- add the sum of all eight products to `output`
    reset           -- set `output` to 0 (only when accumulate_en is clear)
    output          -- the accumulator
  """

  def __init__(self):
    # inputs
    self.in0 = Signal(32)
    self.in1 = Signal(32)
    self.accumulate_en = Signal(1)
    self.reset = Signal(1)
    self.decoded_weights = Signal(64)

    # outputs
    self.output = Signal(32)

  def elab(self, m):
    # SIMD multiply step: four signed products per input line.
    self.prods_0 = [Signal(signed(32)) for _ in range(4)]  # products of input line 0
    self.prods_1 = [Signal(signed(32)) for _ in range(4)]  # products of input line 1

    lane_groups = (
        (self.prods_0, self.decoded_weights[0:32], self.in0),
        (self.prods_1, self.decoded_weights[32:64], self.in1),
    )
    for products, weight_word, feature_word in lane_groups:
      lanes = zip(products, all_words(weight_word, 8), all_words(feature_word, 8))
      for product, weight, feature in lanes:
        m.d.comb += product.eq(weight.as_signed() * feature.as_signed())

    # Accumulate step. accumulate_en takes precedence over reset when both
    # are asserted in the same cycle.
    with m.If(self.accumulate_en):
      m.d.sync += self.output.eq(self.output + sum(self.prods_0) + sum(self.prods_1))
    with m.Elif(self.reset):
      m.d.sync += self.output.eq(0)


class SIMD8_MAC_Test(TestBase):
  """Simulation test of SIMD8_MAC: reset, single-lane and full SIMD sums."""

  def create_dut(self):
    return SIMD8_MAC()

  def test(self):

    def signed_pack(*values, bits=8):
      # Pack values lowest-position first, `bits` bits each. Masking yields
      # the two's-complement encoding for negative values.
      mask = (1 << bits) - 1
      return sum((v & mask) << (i * bits) for i, v in enumerate(values))

    weights = signed_pack(5, -2, 10, 8, 5, -2, 10, 8)

    # (reset, accumulate_en, in0, in1, decoded_weights, expected accumulator)
    vectors = [
        (1, 0, 0x00000005, 0, 0x2, 0),           # reset accumulator
        (0, 1, 0x00000005, 0, 0x2, 10),          # in0 byte 0 = 5, weight 0 = 2 -> 10
        (0, 1, 0x00000000, 1, 0x300000000, 13),  # in1 byte 0 = 1, weight 4 = 3 -> 10 + 3
        (1, 0, 0, 0, 0, 0),                      # reset accumulator
        # in0 = 2 5 9 4, in1 = 5 3 2 4 (expected value computed by hand)
        (0, 1, 0x04090502, 0x04020305, weights, 193),
        # in0 = -5 3 2 56, in1 = 12 2 4 -9
        (0, 1, signed_pack(-5, 3, 2, 56), signed_pack(12, 2, 4, -9), weights, 193+461),
    ]

    def process():
      for reset, accumulate_en, in0, in1, decoded_weights, expected in vectors:
        yield self.dut.reset.eq(reset)
        yield self.dut.accumulate_en.eq(accumulate_en)
        yield self.dut.in0.eq(in0)
        yield self.dut.in1.eq(in1)
        yield self.dut.decoded_weights.eq(decoded_weights)
        yield Delay(0.001)
        yield
        # Drop the enable so the vector is accumulated only once.
        yield self.dut.accumulate_en.eq(0)
        yield
        self.assertEqual(expected, (yield self.dut.output))

    self.run_sim(process, write_trace=False)

runTests(SIMD8_MAC_Test)

.
----------------------------------------------------------------------
Ran 1 test in 0.023s

OK


In [27]:
# Module for storing the codebook.
# The codebook has 4 clusters, so each codeword is 2 bits wide.
from amaranth import C, Module, Signal, signed
from amaranth_cfu import all_words, InstructionBase, InstructionTestBase, pack_vals, simple_cfu
import unittest

class SIMD8_StoreCodebook(SimpleElaboratable):
  """Codebook holder: four 8-bit cluster values packed into 32 bits.

  Signals:
    in0      -- 32-bit input line (declared, not read by elab)
    in1      -- 32-bit input line carrying the four new cluster bytes
    enable   -- when set, `clusters` follows `in1`
    clear    -- when set, `clusters` is driven to 0 (wins over enable)
    clusters -- the four packed cluster bytes
  """

  def __init__(self):
    # 2 x 32 bit input lines
    self.in0 = Signal(32)
    self.in1 = Signal(32)
    # control lines
    self.enable = Signal()
    self.clear = Signal()

    self.clusters = Signal(32)

  def elab(self, m):
    byte_lanes = lambda sig: all_words(sig, 8)

    for stored, incoming in zip(byte_lanes(self.clusters), byte_lanes(self.in1)):
      # Priority per lane: clear, then enable, otherwise keep the lane.
      # NOTE(review): the "keep" branch feeds the lane back to itself in the
      # comb domain rather than using a register -- confirm this is intended
      # outside of simulation.
      with m.If(self.clear):
        m.d.comb += stored.eq(0)
      with m.Elif(self.enable):
        m.d.comb += stored.eq(incoming.as_signed())
      with m.Else():
        m.d.comb += stored.eq(stored)

class SIMD8_StoreCodebookTest(TestBase):
  """Simulation test of SIMD8_StoreCodebook: store, hold, re-store, clear."""

  def create_dut(self):
    return SIMD8_StoreCodebook()

  def test(self):

    def pack(*values, bits=8):
      # Pack values lowest-position first, `bits` bits each. Masking yields
      # the two's-complement encoding for negative values.
      mask = (1 << bits) - 1
      return sum((v & mask) << (i * bits) for i, v in enumerate(values))

    first_book = pack(64, -6, 10, -22)       # 0xEA0AFA40
    corner_book = pack(127, -127, 0, -17)
    ignored_input = pack(54, -44, 3, -12)

    # (enable, clear, in1, expected clusters)
    vectors = [
        (1, 0, first_book, first_book),        # store the codebook
        (0, 0, ignored_input, first_book),     # enable == 0: output must not change
        (1, 0, corner_book, corner_book),      # store a new codebook (corner cases)
        (0, 1, ignored_input, pack(0, 0, 0, 0)),  # clear the codebook
    ]

    def process():
      for enable, clear, in1, expected in vectors:
        yield self.dut.enable.eq(enable)
        yield self.dut.clear.eq(clear)
        yield self.dut.in1.eq(in1)
        yield Delay(0.1)

        self.assertEqual(expected, (yield self.dut.clusters))
        yield

    self.run_sim(process)

#runTests(SIMD8_StoreCodebookTest)

class SIMD8_StoreCodebook_Instruction(InstructionBase):
  """Instruction wrapper that drives a SIMD8_StoreCodebook submodule.

  The operands are forwarded to the submodule; its enable is raised only
  while `start` is asserted, and `done` is asserted in that same cycle.
  `output` mirrors the submodule's `clusters` (kept for testing).
  """

  def __init__(self):
    super().__init__()

  def elab(self, m):
    # Build the submodule.
    book = SIMD8_StoreCodebook()
    m.submodules.storeCodebook = book

    # Operand plumbing and defaults.
    m.d.comb += book.in0.eq(self.in0)
    m.d.comb += book.in1.eq(self.in1)
    m.d.comb += book.enable.eq(0)              # disabled unless started
    m.d.comb += self.output.eq(book.clusters)  # for testing purposes

    with m.If(self.start):
      # Storing is enabled while start is high, and control can be
      # returned to the CPU immediately.
      m.d.comb += book.enable.eq(1)
      m.d.comb += self.done.eq(1)



class SIMD8_StoreCodebook_Instruction_Test(InstructionTestBase):
  """Checks that the instruction's output reflects the codebook on in1."""

  def create_dut(self):
    return SIMD8_StoreCodebook_Instruction()

  def test(self):
    codebook = 0x12FAD87E

    # (in0, in1, expected output); None is used where the source left the
    # expectation unspecified.
    vectors = [
        (0, 0, None),  # reset
        (0, codebook, codebook),
    ]
    self.verify(vectors, trace=True)

#runTests(SIMD8_StoreCodebook_Instruction_Test)


def make_cfu():
  """Build a CFU with SIMD8_StoreCodebook_Instruction registered under key 0."""
  instructions = {0: SIMD8_StoreCodebook_Instruction()}
  return simple_cfu(instructions)


class CfuTest(CfuTestBase):
  """Checks that the CFU from make_cfu() routes operands to the codebook store."""

  def create_dut(self):
    return make_cfu()

  def test(self):
    codebook = 0x12FAD87E

    # ((function_id, in0, in1), expected output)
    operations = [
        ((0, 0, 0), None),  # reset
        ((0, 0, codebook), codebook),
    ]
    self.run_ops(operations)


runTests(CfuTest)

.
----------------------------------------------------------------------
Ran 1 test in 0.075s

OK


In [None]:
class SIMD8_Weights(SimpleElaboratable):
  """Stores 32 two-bit weight codes and decodes eight of them per slot.

  Signals:
    in0, in1        -- 32-bit inputs holding the codes (in0 = w1..w16,
                       in1 = w17..w32, two bits per code)
    store_en        -- load enable
    control         -- 2-bit slot selector
    clusters        -- four packed 8-bit codebook entries
    decoded_weights -- eight 8-bit weights decoded from the selected slot
    weight_codes    -- the selected 16-bit slot of `weights`
    weights         -- the 64 stored code bits
  """

  def __init__(self):
    # inputs
    self.in0 = Signal(32)
    self.in1 = Signal(32)
    self.store_en = Signal(1)
    self.control = Signal(2)

    self.clusters = Signal(32)

    # outputs
    self.decoded_weights = Signal(64)

    self.weight_codes = Signal(16)
    self.weights = Signal(64)

  def elab(self, m):
    # While store_en is set, weights = {in1, in0}:
    #   weights(64b) = MSB (w32, w31, ..., w2, w1) LSB, two bits per weight.
    # NOTE(review): the hold branch is a comb self-assignment, not a
    # register -- confirm this is intended outside of simulation.
    with m.If(self.store_en):
      m.d.comb += self.weights.eq(Cat(self.in0, self.in1))
    with m.Else():
      m.d.comb += self.weights.eq(self.weights)

    # Slot selection (LSB first within a slot):
    #   control=0: w1..w8     control=1: w9..w16
    #   control=2: w17..w24   control=3: w25..w32
    slot_bits = 16
    with m.Switch(self.control):
      for slot in range(4):
        with m.Case(slot):
          low = slot * slot_bits
          m.d.comb += self.weight_codes.eq(self.weights[low:low + slot_bits])

    # Decode: each 2-bit code indexes one byte of `clusters`.
    cluster_bytes = [self.clusters[8 * k:8 * (k + 1)] for k in range(4)]
    code_lanes = all_words(self.weight_codes, 2)
    weight_lanes = all_words(self.decoded_weights, 8)
    for code, weight in zip(code_lanes, weight_lanes):
      with m.Switch(code):
        for index, cluster in enumerate(cluster_bytes):
          with m.Case(index):
            m.d.comb += weight.eq(cluster)


class SIMD8_WeightsTests(TestBase):
  """Simulation test of SIMD8_Weights: load codes, then decode each slot."""

  def create_dut(self):
    return SIMD8_Weights()

  def test(self):
    low_word, high_word = 0x01FFFA32, 0x12ED547A
    all_ones = 0xFFFFFFFF
    book = 0xAABBCCDD

    # (store_en, control, in0, in1, clusters, expected weight_codes, expected decoded_weights)
    vectors = [
        # 0xFA32 = 11_11_10_10_00_11_00_10
        (1, 0b00, low_word, high_word, book, 0xFA32, 0xAAAABBBBDDAADDBB),
        # second slot: 0x01FF = 00_00_00_01_11_11_11_11
        (0, 0b01, low_word, high_word, book, 0x01FF, 0xDDDDDDCCAAAAAAAA),
        # third slot: 0x547A = 01_01_01_00_01_11_10_10
        (0, 0b10, low_word, high_word, book, 0x547A, 0xCCCCCCDDCCAABBBB),
        # fourth slot: 0x12ED = 00_01_00_10_11_10_11_01
        (0, 0b11, low_word, high_word, book, 0x12ED, 0xDDCCDDBBAABBAACC),
        # store_en == 0: new inputs must be ignored
        (0, 0b11, all_ones, all_ones, book, 0x12ED, 0xDDCCDDBBAABBAACC),
    ]

    def process():
      for store_en, control, in0, in1, clusters, want_codes, want_weights in vectors:
        yield self.dut.store_en.eq(store_en)
        yield self.dut.control.eq(control)
        yield self.dut.in0.eq(in0)
        yield self.dut.in1.eq(in1)
        yield self.dut.clusters.eq(clusters)
        yield Delay(0.1)
        yield
        yield
        self.assertEqual(want_codes, (yield self.dut.weight_codes))
        self.assertEqual(want_weights, (yield self.dut.decoded_weights))
        yield

    self.run_sim(process, write_trace= False)


  
    
#runTests(SIMD8_WeightsTests)

class SIMD8_Weights_Instruction(InstructionBase):
  """Instruction for managing weights.

  Wraps a SIMD8_Weights submodule (storage for 32 two-bit weight codes plus
  the decoder) and a SIMD8_StoreCodebook submodule whose `clusters` feed the
  decoder. While `start` is asserted, funct7 selects the operation:

    funct7 = 0  -> store the weight codes from in0/in1
    funct7 = 1  -> select the first set of weights   (w1  -> w8)
    funct7 = 2  -> select the second set of weights  (w9  -> w16)
    funct7 = 3  -> select the third set of weights   (w17 -> w24)
    funct7 = 4  -> select the fourth set of weights  (w25 -> w32)

  NOTE(review): the codebook submodule's inputs are not driven in this
  class; only its `clusters` output is read -- confirm this is intended.
  """

  def __init__(self):
    super().__init__()

  def elab(self, m):
    # Build submodules.
    weights = SIMD8_Weights()
    m.submodules.weights = weights
    storeCodebook = SIMD8_StoreCodebook()
    m.submodules.storeCodebook = storeCodebook

    # Connections and defaults.
    m.d.comb += weights.in0.eq(self.in0)
    m.d.comb += weights.in1.eq(self.in1)
    m.d.comb += weights.clusters.eq(storeCodebook.clusters)
    m.d.comb += weights.store_en.eq(0)  # storing stays off unless funct7 == 0
    m.d.comb += self.output.eq(weights.decoded_weights[0:32])  # for testing purposes

    with m.If(self.start):
      m.d.comb += self.done.eq(1)  # return control to the CPU on next cycle

      with m.Switch(self.funct7):
        with m.Case(0):
          m.d.comb += weights.store_en.eq(1)  # enable storing
        # funct7 values 1..4 map onto slot selectors 0..3.
        for funct in range(1, 5):
          with m.Case(funct):
            m.d.comb += weights.control.eq(funct - 1)

class SIMD8_Weights_Instruction_Test(InstructionTestBase):
  """Feeds a store followed by slot selections to SIMD8_Weights_Instruction.

  NOTE(review): each row carries four fields (funct7, in0, in1, expected);
  confirm this matches the row shape InstructionTestBase.verify expects.
  """

  def create_dut(self):
    return SIMD8_Weights_Instruction()

  def test(self):
    low_word, high_word = 0x01FFFA32, 0x12ED547A

    # (funct7, in0, in1, expected output)
    #
    # Weight codes being stored:
    #   first slot:  0xFA32 = 11_11_10_10_00_11_00_10
    #   second slot: 0x01FF = 00_00_00_01_11_11_11_11
    #   third slot:  0x547A = 01_01_01_00_01_11_10_10
    #   fourth slot: 0x12ED = 00_01_00_10_11_10_11_01
    vectors = [
        (0, low_word, high_word, None),
        (1, low_word, high_word, None),
        (2, low_word, high_word, None),
        (3, low_word, high_word, None),
    ]
    self.verify(vectors, trace=True)

#runTests(SIMD8_Weights_Instruction_Test)



class SIMD8_CFU(Cfu):
  """CFU exposing the codebook-store (key 0) and weights (key 1) instructions.

  NOTE(review): `Cfu` is not among the imports visible before this cell --
  confirm it is in scope when the cell runs.
  NOTE(review): the `weights` and `storeCodebook` submodules created here
  are only connected to each other; the returned instruction objects are
  constructed separately -- confirm this is intended.
  """

  def elab_instructions(self, m):
    weights = SIMD8_Weights()
    m.submodules.weights = weights
    storeCodebook = SIMD8_StoreCodebook()
    m.submodules.storeCodebook = storeCodebook

    m.d.comb += weights.clusters.eq(storeCodebook.clusters)

    instructions = {}
    instructions[0] = SIMD8_StoreCodebook_Instruction()
    instructions[1] = SIMD8_Weights_Instruction()
    return instructions


class CfuTest(CfuTestBase):
  """Runs a codebook load and weight-code operations through SIMD8_CFU."""

  def create_dut(self):
    return SIMD8_CFU()

  def test(self):
    low_word, high_word = 0x01FFFA32, 0x12ED547A

    # ((function_id, func7, in0, in1), expected output)
    #
    # Weight codes being stored:
    #   first slot:  0xFA32 = 11_11_10_10_00_11_00_10
    #   second slot: 0x01FF = 00_00_00_01_11_11_11_11
    #   third slot:  0x547A = 01_01_01_00_01_11_10_10
    #   fourth slot: 0x12ED = 00_01_00_10_11_10_11_01
    operations = [
        ((0, 0, 0, 0xAABBCCDD), None),          # load the codebook
        ((1, 0, low_word, high_word), None),    # store weight codes
        ((1, 1, low_word, high_word), None),
        ((1, 2, low_word, high_word), None),
        ((1, 3, low_word, high_word), None),
        ((1, 4, low_word, high_word), None),
    ]
    self.run_ops(operations, write_trace= True)


runTests(CfuTest)

In [89]:
from amaranth import C, Module, Signal, signed
from amaranth_cfu import all_words, InstructionBase, InstructionTestBase, pack_vals, simple_cfu, Cfu, CfuTestBase
import unittest

class SIMD8_StoreCodebook_Instruction(InstructionBase):
  """Instruction that latches in1 into a 32-bit `clusters` register.

  While `start` is asserted the four bytes of in1 are written to `clusters`
  on the sync domain, and `done` and `output` are both driven to 1.
  `reset_acc` is declared but not used by elab.
  """

  def __init__(self):
    super().__init__()
    self.clusters = Signal(32)
    self.reset_acc = Signal(1)

  def elab(self, m):
    # Default: the register keeps its value.
    m.d.sync += self.clusters.eq(self.clusters)

    with m.If(self.start):
      stored_bytes = all_words(self.clusters, 8)
      incoming_bytes = all_words(self.in1, 8)
      for stored, incoming in zip(stored_bytes, incoming_bytes):
        m.d.sync += stored.eq(incoming)

      m.d.comb += self.done.eq(1)
      m.d.comb += self.output.eq(1)

    return m
      


class SIMD8_StoreCodebook_Instruction_Test(InstructionTestBase):
  """Drives start/in1 directly and checks done, output and clusters."""

  def create_dut(self):
    return SIMD8_StoreCodebook_Instruction()

  def test(self):
    codebook = 0x12FAD87E

    # ((start, in1), (done, output, clusters))
    vectors = [
        ((1, 0), (1, 1, 0)),                      # store all-zero clusters
        ((1, codebook), (1, 1, codebook)),        # store cluster values
        ((0, 0xFFFFFFFF), (0, 0, codebook)),      # start == 0: clusters unchanged
        ((0, 0xFFFADBCD), (0, 0, codebook)),
    ]

    def process():
      for (start, in1), (done, output, clusters) in vectors:
        yield self.dut.start.eq(start)
        yield self.dut.in1.eq(in1)
        yield
        yield
        self.assertEqual((yield self.dut.done), done)
        self.assertEqual((yield self.dut.output), output)
        self.assertEqual((yield self.dut.clusters), clusters)

    self.run_sim(process, False)
    

#runTests(SIMD8_StoreCodebook_Instruction_Test)


class SIMD8_Weights_Instruction(InstructionBase):
  """Instruction for managing weights.

  While `start` is asserted:
    * `weights` is loaded with {in1, in0} (64 bits = 32 two-bit codes,
      MSB (w32, ..., w1) LSB; in1 = w32..w17, in0 = w16..w1);
    * `control` selects which 16-bit slot of `weights` becomes
      `weight_codes` (0: w1..w8, 1: w9..w16, 2: w17..w24, 3: w25..w32);
    * each 2-bit code in `weight_codes` selects one byte of `clusters`,
      giving the eight bytes of `decoded_weights`.
  While `start` is low every internal signal is assigned to itself.

  funct7 is not read by this version; the slot is chosen by `control` only.

  NOTE(review): the hold behaviour relies on comb-domain self-assignment
  rather than registers -- confirm this is intended outside of simulation.
  """

  def __init__(self):
    super().__init__()
    self.clusters = Signal(32)
    self.decoded_weights = Signal(64)
    self.weight_codes = Signal(16)
    self.weights = Signal(64)
    self.control = Signal(3)

  def elab(self, m):
    codes = lambda s: all_words(s, 2)
    weights = lambda s: all_words(s, 8)

    # Defaults: keep every internal value; output mirrors the low half of
    # the decoded weights.
    m.d.comb += [ self.weights.eq(self.weights),
                  self.weight_codes.eq(self.weight_codes),
                  self.decoded_weights.eq(self.decoded_weights),
                  self.control.eq(self.control),
                  self.output.eq(self.decoded_weights[0:32]) # for testing purposes !!! DELETE after
    ]

    with m.If(self.start):
        m.d.comb += self.done.eq(1)  # return control to the CPU on next cycle
        # (A bare `self.output.eq(...)` expression used to sit here; it was
        # never added to a domain, so it had no effect and is removed. The
        # default assignment above already drives `output`.)

        # Load the weight codes from the two operands.
        m.d.comb += self.weights.eq(Cat([self.in0, self.in1]))

        # Pick the 16-bit slot of codes selected by `control`.
        with m.Switch(self.control):
          with m.Case(0):
            m.d.comb += self.weight_codes.eq(self.weights[0:16])
          with m.Case(1):
            m.d.comb += self.weight_codes.eq(self.weights[16:32])
          with m.Case(2):
            m.d.comb += self.weight_codes.eq(self.weights[32:48])
          with m.Case(3):
            m.d.comb += self.weight_codes.eq(self.weights[48:64])

        # For each 2-bit weight code, look up its cluster value.
        for code, weight in zip(codes(self.weight_codes), weights(self.decoded_weights)):
          with m.Switch(code):
            with m.Case("00"):
              m.d.comb += weight.eq(self.clusters[0:8])
            with m.Case("01"):
              m.d.comb += weight.eq(self.clusters[8:16])
            with m.Case("10"):
              m.d.comb += weight.eq(self.clusters[16:24])
            with m.Case("11"):
              m.d.comb += weight.eq(self.clusters[24:32])

class SIMD8_Weights_Instruction_Test(InstructionTestBase):
  """Drives SIMD8_Weights_Instruction directly and checks storage and decode."""

  def create_dut(self):
    return SIMD8_Weights_Instruction()

  def test(self):
    low_word, high_word = 0x01FFFA32, 0x12ED547A
    all_ones = 0xFFFFFFFF
    book = 0xAABBCCDD
    stored = 0x12ED547A01FFFA32

    # ((start, control, in0, in1, clusters), (done, weights, weight_codes, decoded_weights))
    #
    # Weight codes being stored:
    #   first slot:  0xFA32 = 11_11_10_10_00_11_00_10
    #   second slot: 0x01FF = 00_00_00_01_11_11_11_11
    #   third slot:  0x547A = 01_01_01_00_01_11_10_10
    #   fourth slot: 0x12ED = 00_01_00_10_11_10_11_01
    vectors = [
        ((1, 0, low_word, high_word, book), (1, stored, None, None)),
        # outputs must stay the same while start is 0
        ((0, 0, all_ones, all_ones, book), (0, stored, None, None)),
        # check each slot and its decoded value
        ((1, 0, low_word, high_word, book), (1, stored, 0xFA32, 0xAAAABBBBDDAADDBB)),
        ((1, 1, low_word, high_word, book), (1, stored, 0x01FF, 0xDDDDDDCCAAAAAAAA)),
        ((1, 2, low_word, high_word, book), (1, stored, 0x547A, 0xCCCCCCDDCCAABBBB)),
        ((1, 3, low_word, high_word, book), (1, stored, 0x12ED, 0xDDCCDDBBAABBAACC)),
    ]

    def process():
      for stimulus, expected in vectors:
        start, control, in0, in1, clusters = stimulus
        done, weights, weight_codes, decoded_weights = expected
        yield self.dut.start.eq(start)
        yield self.dut.in0.eq(in0)
        yield self.dut.in1.eq(in1)
        yield self.dut.clusters.eq(clusters)
        yield self.dut.control.eq(control)
        yield
        self.assertEqual((yield self.dut.done), done)
        self.assertEqual((yield self.dut.weights), weights)
        # None means "do not check this field for this vector".
        if weight_codes is not None:
          self.assertEqual((yield self.dut.weight_codes), weight_codes)
        if decoded_weights is not None:
          self.assertEqual((yield self.dut.decoded_weights), decoded_weights)

    self.run_sim(process, False)

#runTests(SIMD8_Weights_Instruction_Test)


class SIMD8_MAC_Instruction(InstructionBase):
  """Accumulates products of decoded weights and incoming input bytes.

  While `start` is asserted, funct7 selects the operation:
    funct7 == 0     -> reset the accumulator
    funct7 == 1..4  -> set `control` to funct7 - 1 and add the sum of all
                       eight products to the accumulator
    funct7 == 5     -> leave the accumulator unchanged (read)
  `output` always mirrors the accumulator.
  """

  def __init__(self):
    super().__init__()
    self.decoded_weights = Signal(64)
    self.acc = Signal(signed(32))
    self.acc_en = Signal(1)
    self.control = Signal(3)

  def elab(self, m):
    # SIMD multiply step: four signed products per input line.
    self.prods_0 = [Signal(signed(32)) for _ in range(4)]  # products of input line 0
    self.prods_1 = [Signal(signed(32)) for _ in range(4)]  # products of input line 1

    # Defaults: output follows the accumulator; control keeps its value.
    # NOTE(review): `control` is held through a comb self-assignment --
    # confirm this is intended outside of simulation.
    m.d.comb += self.output.eq(self.acc)
    m.d.comb += self.control.eq(self.control)

    lane_groups = (
        (self.prods_0, self.decoded_weights[0:32], self.in0),
        (self.prods_1, self.decoded_weights[32:64], self.in1),
    )
    for products, weight_word, feature_word in lane_groups:
      lanes = zip(products, all_words(weight_word, 8), all_words(feature_word, 8))
      for product, weight, feature in lanes:
        m.d.comb += product.eq(weight.as_signed() * feature.as_signed())

    with m.If(self.start):
      m.d.comb += self.done.eq(1)  # return control to the CPU on next cycle

      with m.Switch(self.funct7):
        with m.Case(0):
          m.d.sync += self.acc.eq(0)  # reset accumulator
        # funct7 values 1..4: accumulate, with control = funct7 - 1.
        for funct in range(1, 5):
          with m.Case(funct):
            m.d.comb += self.control.eq(funct - 1)
            m.d.sync += self.acc.eq(self.acc + sum(self.prods_0) + sum(self.prods_1))
        with m.Case(5):
          m.d.sync += self.acc.eq(self.acc)  # read accumulator value

    with m.Else():
      m.d.sync += self.acc.eq(self.acc)




class SIMD8_MAC_Instruction_Test(InstructionTestBase):
  """Simulation test for SIMD8_MAC_Instruction driven by funct7 opcodes 0-5."""

  def create_dut(self):
    return SIMD8_MAC_Instruction()

  def test(self):

    def signed_pack(*values, bits=8):
        # Pack ints into consecutive `bits`-wide two's-complement fields,
        # first value in the least significant field.
        mask = (1 << bits) - 1
        return sum((v & mask) << (i * bits) for i, v in enumerate(values))

    # Each row: ( (start, func7, in0, in1, decoded_weights) , (done, control, acc, output) )
    # A None in the expected tuple skips that check.
    DATA = [
      ((1, 0, 0x5, 0, 0x2), (1, None, 0, 0)),
      ((0, 0, 0x5, 0, 0x2), (1, None, 0, 0)),
      ((1, 1, 0x5, 0, 0x2), (1, 0, 10, 10)),     # feed values for accumulation 5x2
      ((1, 3, 0x3, 0, 0x1), (1, 2, 13, 13)),     # second values 3x1
      ((1, 0, 0x3, 0, 0x1), (1, None, 0, 0)),    # reset
      ((1, 1, 0x4, 0, 0x3), (1, 0, 12, 12)),
      ((1, 5, 0x45, 0, 0x3), (1, None, 12, 12)), # read
      ((1, 2, 0x4, 0, 0x3), (1, 1, 24, 24)),
      ((1, 1, signed_pack(-5,3,2,56), signed_pack(12,2,4,-9), signed_pack(5,-2,10,8,5,-2,10,8)),
       (1, 0, 24+461, 24+461)),
      ((1, 5, 0, 0, signed_pack(5,-2,10,8,5,-2,10,8)), (1, None, 24+461, 24+461)),
    ]

    def process():
      dut = self.dut
      for stimulus, expected in DATA:
        start, funct7, in0, in1, decoded_weights = stimulus
        _done, control, acc, output = expected  # `done` is listed but not checked

        drive = (
          (dut.start, start),
          (dut.in0, in0),
          (dut.in1, in1),
          (dut.decoded_weights, decoded_weights),
          (dut.funct7, funct7),
        )
        for signal, value in drive:
          yield signal.eq(value)
        yield
        # Drop start before the second clock so the operation is applied once.
        yield dut.start.eq(0)
        yield

        for signal, want in ((dut.acc, acc), (dut.control, control), (dut.output, output)):
          if want is not None:
            self.assertEqual((yield signal), want)

    self.run_sim(process, True)

#runTests(SIMD8_MAC_Instruction_Test)
    



class SIMD8_CFU(Cfu):
  """CFU combining codebook storage (id 0), weight decoding (id 1) and MAC (id 2)."""

  def elab_instructions(self,m):
    codebook = SIMD8_StoreCodebook_Instruction()
    decoder = SIMD8_Weights_Instruction()
    mac = SIMD8_MAC_Instruction()

    m.submodules["store_codebook"] = codebook
    m.submodules["weights"] = decoder
    m.submodules["macc"] = mac

    # Codebook -> decoder -> MAC data path.
    m.d.comb += decoder.clusters.eq(codebook.clusters)
    m.d.comb += mac.decoded_weights.eq(decoder.decoded_weights)
    # NOTE(review): the recorded run of this cell warns
    # "DriverConflict: Signal '(sig control)' is driven from multiple fragments:
    # top.dut, top.dut.weights". This assignment is the top-level driver;
    # presumably the weights instruction also drives its own `control` - confirm.
    m.d.comb += decoder.control.eq(mac.control)

    # Map function ids to the instructions that implement them.
    return {0: codebook, 1: decoder, 2: mac}

class CfuTest(CfuTestBase):
  """End-to-end test running a short instruction sequence through SIMD8_CFU."""

  def create_dut(self):
    return SIMD8_CFU()

  def test(self):
    def signed_pack(*values, bits=8):
        # Pack ints into consecutive `bits`-wide two's-complement fields,
        # first value in the least significant field.
        mask = (1 << bits) - 1
        return sum((v & mask) << (i * bits) for i, v in enumerate(values))

    # Each row: ( (function_id, func7, in0, in1) , expected output )
    # An expected output of None means the result is not checked.
    DATA = [
      # load codebook
      ((0, 0, 0, 0xAABBCCDD), None),

      # store weight codes
      # First slot : 0xFA32 = 11_11_10_10_00_11_00_10 |||| Second slot:   0x01FF = 00_00_00_01_11_11_11_11
      # Third slot:  0x547A = 01_01_01_00_01_11_10_10 |||| Fourth slot :  0x12ED = 00_01_00_10_11_10_11_01
      # The CFU output is only 32 bits wide, so only the first 32 bits of the
      # decoded weights (the first half of each slot) can be observed.
      ((1, 0, 0x01FFFA32, 0x12ED547A), None),

      # reset accumulator
      ((2, 0, 0, 0), None),
      # load a simple codebook
      ((0, 0, 0, 0x04030201), None),
      # check accumulator is zero
      ((2, 5, 0, 0), 0),
      # accumulate, then read back
      ((2, 1, 0x00010001, 0x00020002), None),
      ((2, 5, 0x00010001, 0x00020002), 3+4+6+8),
    ]
    self.run_ops(DATA, write_trace= True)


# Execute the end-to-end CFU test case defined in this cell.
runTests(CfuTest)

<ipython-input-89-cda941bd54e3>:81: DriverConflict: Signal '(sig control)' is driven from multiple fragments: top.dut, top.dut.weights; hierarchy will be flattened
  self.control = Signal(3)
F
FAIL: test (__main__.CfuTest)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "<ipython-input-89-cda941bd54e3>", line 355, in test
    self.run_ops(DATA, write_trace= True)
  File "/content/CFU-Playground/python/amaranth_cfu/cfu.py", line 354, in run_ops
    self.run_sim(process, write_trace)
  File "/content/CFU-Playground/python/amaranth_cfu/util.py", line 142, in run_sim
    self.sim.run()
  File "/usr/local/lib/python3.10/dist-packages/amaranth/sim/core.py", line 175, in run
    while self.advance():
  File "/usr/local/lib/python3.10/dist-packages/amaranth/sim/core.py", line 166, in advance
    return self._engine.advance()
  File "/usr/local/lib/python3.10/dist-packages/amaranth/sim/pysim.py", line 319, in advance
    self._ste

In [24]:
# Scratch cell for trying out funct7 bit masks.

# Isolate bit 3 and shift it down to bit 0 (the value is not stored).
(0b0000011 & 0b0001000) >> 3

# A non-zero AND result is truthy, so any shared set bit prints 'success'.
print('success' if (0b0001010 & 0b0000011) else 'fail')

# Binary string form of the mask; the trailing expression echoes it in the notebook.
x = bin(0b0001010)
x[:]

success


'0b101'

In [91]:
from amaranth import C, Module, Signal, signed
from amaranth_cfu import all_words, InstructionBase, InstructionTestBase, pack_vals, simple_cfu, Cfu, CfuTestBase
import unittest

class SIMD8_StoreCodebook_Instruction(InstructionBase):
  """Instruction that stores the codebook.

  funct7 is not read. While `start` is high the 32-bit value on `in1` is
  copied, one 8-bit word at a time, into `clusters` (word 0 is in1[0:8]);
  `done` and `output` are both driven to 1 for that cycle.
  """

  def __init__(self):
    super().__init__()
    self.clusters = Signal(32)

  def elab(self, m):
    def bytes_of(sig):
      return all_words(sig, 8)

    # Explicit hold; the assignments under `start` below take priority.
    m.d.sync += self.clusters.eq(self.clusters)

    with m.If(self.start):
      m.d.sync += [
        stored.eq(incoming.as_signed())
        for stored, incoming in zip(bytes_of(self.clusters), bytes_of(self.in1))
      ]
      m.d.comb += self.done.eq(1)
      m.d.comb += self.output.eq(1)

    return m
      


class SIMD8_StoreCodebook_Instruction_Test(InstructionTestBase):
  """Simulation test: clusters follow in1 only while start is high."""

  def create_dut(self):
    return SIMD8_StoreCodebook_Instruction()

  def test(self):
    # Each row: ( (start, in1) , (done, output, clusters) )
    DATA = [
        # store an all-zero codebook
        ((1, 0), (1, 1, 0)),
        # store cluster values
        ((1, 0x12FAD87E), (1, 1, 0x12FAD87E)),
        # with start low, clusters must keep their previous value
        ((0, 0xFFFFFFFF), (0, 0, 0x12FAD87E)),
        ((0, 0xFFFADBCD), (0, 0, 0x12FAD87E)),
    ]

    def process():
      dut = self.dut
      for (start, in1), (done, output, clusters) in DATA:
        yield dut.start.eq(start)
        yield dut.in1.eq(in1)
        # Two clock edges pass before the checks are made.
        yield
        yield
        for signal, want in ((dut.done, done), (dut.output, output), (dut.clusters, clusters)):
          self.assertEqual((yield signal), want)

    self.run_sim(process, False)
    




class SIMD8_Weights_Instruction(InstructionBase):
  """Stores 32 two-bit weight codes and decodes eight of them at a time.

  funct7 layout:
    |_|_|enableAccumulator|resetAccumulator|storeWeights|control1|control0|

  * storeWeights (mask 0b100): while `start` is high, the 64-bit value
    {in1, in0} (in0 in the low half) is written into the `weights` register
    on the next clock edge.
  * control (funct7[0:2], values 0-3): selects which 16-bit slot of `weights`
    appears on `weight_codes`. The selection does not depend on `start`.

  Every 2-bit code in `weight_codes` is looked up in `clusters` (code k picks
  clusters[8k:8k+8]) and the eight resulting bytes form `decoded_weights`.
  """

  def __init__(self):
    super().__init__()
    self.clusters = Signal(32)         # codebook: four 8-bit cluster values
    self.decoded_weights = Signal(64)  # eight decoded 8-bit weights
    self.weight_codes = Signal(16)     # the currently selected eight 2-bit codes
    self.weights = Signal(64)          # register holding all 32 two-bit codes

  def elab(self, m):
    def codes_of(sig):
      return all_words(sig, 2)

    def bytes_of(sig):
      return all_words(sig, 8)

    # `weight_codes` and `decoded_weights` are purely combinational and are
    # fully assigned by the exhaustive switches below, so they need no default.
    # `weights` is real storage and is therefore written in the sync domain;
    # it used to be "held" with a combinational self-assignment, which is a
    # feedback loop rather than a register.

    # funct7[0:2] selects one of the four 16-bit code slots.
    with m.Switch(self.funct7[:2]):
      for slot in range(4):
        with m.Case(slot):
          m.d.comb += self.weight_codes.eq(self.weights[16 * slot:16 * (slot + 1)])

    # Each 2-bit code selects the matching 8-bit cluster value.
    for code, weight in zip(codes_of(self.weight_codes), bytes_of(self.decoded_weights)):
      with m.Switch(code):
        for index in range(4):
          with m.Case(index):
            m.d.comb += weight.eq(self.clusters[8 * index:8 * (index + 1)])

    with m.If(self.start):
      m.d.comb += self.done.eq(1)  # return control to the CPU on next cycle

      # storeWeights bit set: latch new codes, MSB in1 | in0 LSB.
      with m.If(self.funct7 & 0b100):
        m.d.sync += self.weights.eq(Cat(self.in0, self.in1))




class SIMD8_Weights_Instruction_Test(InstructionTestBase):
  """Simulation test for storing, slot selection and decoding of weight codes."""

  def create_dut(self):
    return SIMD8_Weights_Instruction()

  def test(self):
    # Each row: ( (start, func7, in0, in1, clusters) , (done, weights, weight_codes, decoded_weights) )
    # A None for weight_codes / decoded_weights skips that check.
    DATA = [
      # store weight codes
      # First slot : 0xFA32 = 11_11_10_10_00_11_00_10 |||| Second slot:   0x01FF = 00_00_00_01_11_11_11_11
      # Third slot:  0x547A = 01_01_01_00_01_11_10_10 |||| Fourth slot :  0x12ED = 00_01_00_10_11_10_11_01
      ((1, 0b0000100, 0x01FFFA32, 0x12ED547A, 0xAABBCCDD),(1, 0x12ED547A01FFFA32, None, None)),
      # stored codes stay the same when start is 0 or storeWeights is clear
      ((0, 0b0000100, 0xFFFFFFFF, 0xFFFFFFFF, 0xAABBCCDD),(0, 0x12ED547A01FFFA32, None, None)),
      ((1, 0b0011011, 0xFFFFFFFF, 0xFFFFFFFF, 0xAABBCCDD),(1, 0x12ED547A01FFFA32, None, None)),
      ((1, 0b1111000, 0xFFFFFFFF, 0xFFFFFFFF, 0xAABBCCDD),(1, 0x12ED547A01FFFA32, None, None)),
      # check slot selection and decoded values
      ((1, 0b0000000, 0x01FFFA32, 0x12ED547A, 0xAABBCCDD),(1, 0x12ED547A01FFFA32, 0xFA32, 0xAAAABBBBDDAADDBB)),
      ((1, 0b0000001, 0x01FFFA32, 0x12ED547A, 0xAABBCCDD),(1, 0x12ED547A01FFFA32, 0x01FF, 0xDDDDDDCCAAAAAAAA)),
      ((1, 0b0000010, 0x01FFFA32, 0x12ED547A, 0xAABBCCDD),(1, 0x12ED547A01FFFA32, 0x547A, 0xCCCCCCDDCCAABBBB)),
      ((1, 0b0000011, 0x01FFFA32, 0x12ED547A, 0xAABBCCDD),(1, 0x12ED547A01FFFA32, 0x12ED, 0xDDCCDDBBAABBAACC)),
    ]

    def process():
      dut = self.dut
      for inputs, outputs in DATA:
        start, funct7, in0, in1, clusters = inputs
        done, weights, weight_codes, decoded_weights = outputs
        yield dut.start.eq(start)
        yield dut.in0.eq(in0)
        yield dut.in1.eq(in1)
        yield dut.clusters.eq(clusters)
        yield dut.funct7.eq(funct7)
        # `weights` is a register: the first edge captures a store request and
        # the value is observable from the testbench after the second edge.
        yield
        yield
        self.assertEqual((yield dut.done), done)
        self.assertEqual((yield dut.weights), weights)
        if weight_codes is not None:
          self.assertEqual((yield dut.weight_codes), weight_codes)
        if decoded_weights is not None:
          self.assertEqual((yield dut.decoded_weights), decoded_weights)

    self.run_sim(process, False)




class SIMD8_MAC_Instruction(InstructionBase):
  """Multiply-accumulate of eight decoded weights with eight 8-bit inputs.

  in0 and in1 are each split into four signed 8-bit words and multiplied with
  the matching bytes of `decoded_weights` (low 32 bits pair with in0, high
  32 bits with in1), so eight MACs are computed per call. The accumulator is
  registered, so a new sum is visible on `output` in the next clock cycle.

  funct7 layout:
    |_|_|enableAccumulator|resetAccumulator|storeWeights|control1|control0|
  Masks read by this instruction:
    enableAccumulator = 0b10000
    resetAccumulator  = 0b1000

  The control bits are not read here; per the design notes they must be set
  correctly by the caller and select which weights are pushed to the pipeline:
    00 => weights 1-8
    01 => weights 9-16
    10 => weights 17-24
    11 => weights 25-32
  """

  def __init__(self):
    super().__init__()
    # Inputs
    self.decoded_weights = Signal(64)
    # State
    self.acc = Signal(signed(32))

  def elab(self, m):
    def bytes_of(sig):
      return all_words(sig, 8)

    # SIMD multiply step: one product per 8-bit lane.
    self.prods_0 = [Signal(signed(32)) for _ in range(4)]  # products connected to input line 0
    self.prods_1 = [Signal(signed(32)) for _ in range(4)]  # products connected to input line 1

    m.d.comb += self.output.eq(self.acc)
    # Explicit hold; the reset / accumulate assignments below take priority.
    m.d.sync += self.acc.eq(self.acc)

    weights_lo = bytes_of(self.decoded_weights[0:32])
    weights_hi = bytes_of(self.decoded_weights[32:64])
    lanes = zip(self.prods_0, self.prods_1, weights_lo, weights_hi,
                bytes_of(self.in0), bytes_of(self.in1))
    for p0, p1, w_lo, w_hi, x0, x1 in lanes:
      m.d.comb += p0.eq(w_lo.as_signed() * x0.as_signed())
      m.d.comb += p1.eq(w_hi.as_signed() * x1.as_signed())

    # Reset accumulator.
    # NOTE(review): this sits outside the `start` guard, so it applies on any
    # cycle in which funct7 bit 3 is set - confirm that is intended.
    with m.If(self.funct7 & 0b1000):
      m.d.sync += self.acc.eq(0)

    with m.If(self.start):
      m.d.comb += self.done.eq(1)  # return control to the CPU on next cycle

      # Accumulate step.
      with m.If(self.funct7 & 0b10000):
        next_acc = self.acc + sum(self.prods_0) + sum(self.prods_1)
        m.d.sync += self.acc.eq(next_acc)





class SIMD8_MAC_Instruction_Test(InstructionTestBase):
  """Simulation test for the funct7 bit-mask controlled SIMD8_MAC_Instruction."""

  def create_dut(self):
    return SIMD8_MAC_Instruction()

  def test(self):

    def signed_pack(*values, bits=8):
        # Pack ints into consecutive `bits`-wide two's-complement fields,
        # first value in the least significant field.
        mask = (1 << bits) - 1
        return sum((v & mask) << (i * bits) for i, v in enumerate(values))

    packed_weights = signed_pack(5,-2,10,8,5,-2,10,8)

    # Each row: ( (start, func7, in0, in1, decoded_weights) , (done, acc, output) )
    # A None for acc / output skips that check.
    DATA = [
      # reset accumulator
      ((1, 0b1000, 0x5, 0, 0x2), (1, 0, 0)),
      # start low: nothing is accumulated
      ((0, 0b10000, 0x5, 0, 0x2), (1, 0, 0)),
      # feed values for accumulation 5x2
      ((1, 0b10000, 0x5, 0, 0x2), (1, 10, 10)),
      # second values 3x1
      ((1, 0b10000, 0x3, 0, 0x1), (1, 13, 13)),
      # reset accumulator
      ((1, 0b1000, 0x3, 0, 0x1), (1, 0, 0)),
      # enable accumulator again
      ((1, 0b10000, 0x4, 0, 0x3), (1, 12, 12)),
      # read values
      ((1, 0b0001, 0x45, 0, 0x3), (1, 12, 12)),
      # enable accumulator
      ((1, 0b10000, 0x4, 0, 0x3), (1, 24, 24)),
      ((1, 0b10000, signed_pack(-5,3,2,56), signed_pack(12,2,4,-9), packed_weights),
       (1, 24+461, 24+461)),
      # read values
      ((1, 0b0001, 0, 0, packed_weights), (1, 24+461, 24+461)),
    ]

    def process():
      dut = self.dut
      for stimulus, expected in DATA:
        start, funct7, in0, in1, decoded_weights = stimulus
        _done, acc, output = expected  # `done` is listed but not checked

        drive = (
          (dut.start, start),
          (dut.in0, in0),
          (dut.in1, in1),
          (dut.decoded_weights, decoded_weights),
          (dut.funct7, funct7),
        )
        for signal, value in drive:
          yield signal.eq(value)
        yield
        # Drop start before the second clock so the operation is applied once.
        yield dut.start.eq(0)
        yield

        for signal, want in ((dut.acc, acc), (dut.output, output)):
          if want is not None:
            self.assertEqual((yield signal), want)

    self.run_sim(process, False)


    



class SIMD8_CFU(Cfu):
  """CFU combining codebook storage (id 0), weight decoding (id 1) and MAC (id 2)."""

  def elab_instructions(self,m):
    codebook = SIMD8_StoreCodebook_Instruction()
    decoder = SIMD8_Weights_Instruction()
    mac = SIMD8_MAC_Instruction()

    m.submodules["store_codebook"] = codebook
    m.submodules["weights"] = decoder
    m.submodules["macc"] = mac

    # Codebook -> decoder -> MAC data path.
    m.d.comb += decoder.clusters.eq(codebook.clusters)
    m.d.comb += mac.decoded_weights.eq(decoder.decoded_weights)

    # Map function ids to the instructions that implement them.
    return {0: codebook, 1: decoder, 2: mac}

class CfuTest(CfuTestBase):
  """End-to-end test: load codebook, store codes, then accumulate two weight sets."""

  def create_dut(self):
    return SIMD8_CFU()

  def test(self):
    def signed_pack(*values, bits=8):
        # Pack ints into consecutive `bits`-wide two's-complement fields,
        # first value in the least significant field.
        mask = (1 << bits) - 1
        return sum((v & mask) << (i * bits) for i, v in enumerate(values))

    # funct7 encodings for the operations used below.
    functionList = dict([
      ("storeCodebook", 0),
      ("storeWeights", 0b100),
      ("resetAccumulator", 0b1000),
      ("enableAccumulator_weightSet1", 0b10000),
      ("enableAccumulator_weightSet2", 0b10001),
      ("enableAccumulator_weightSet3", 0b10010),
      ("enableAccumulator_weightSet4", 0b10011),
      ("readAccumulator", 0),
    ])

    STORE_CODEBOOK, WEIGHTS, MAC = 0, 1, 2  # function ids registered by SIMD8_CFU
    ones = 0x01010101                       # four input bytes, each equal to 1

    # Each row: ( (function_id, func7, in0, in1) , expected output )
    # An expected output of None means the result is not checked.
    DATA = [
      # load codebook
      ((STORE_CODEBOOK, functionList["storeCodebook"], 0, 0x04030201), None),

      # store weight codes
      # First slot : 0xFA32 = 11_11_10_10_00_11_00_10 |||| Second slot:   0x01FF = 00_00_00_01_11_11_11_11
      # Third slot:  0x547A = 01_01_01_00_01_11_10_10 |||| Fourth slot :  0x12ED = 00_01_00_10_11_10_11_01
      ((WEIGHTS, functionList["storeWeights"], 0x01FFFA32, 0x12ED547A), None),

      # reset accumulator
      ((MAC, functionList["resetAccumulator"], 0, 0), None),

      # check accumulator is zero
      ((MAC, 0, 0, 0), 0),

      # accumulate weight set 1 and read
      ((MAC, functionList["enableAccumulator_weightSet1"], ones, ones), None),
      ((MAC, functionList["readAccumulator"], ones, ones), 3+1+4+1+3+3+4+4),

      # accumulate weight set 2 and read
      ((MAC, functionList["enableAccumulator_weightSet2"], ones, ones), None),
      ((MAC, functionList["readAccumulator"], ones, ones), 3+1+4+1+3+3+4+4+ 1+1+1+2+4+4+4+4),
    ]
    self.run_ops(DATA, write_trace= False)



# Run each instruction's unit test, then the end-to-end CFU test.
runTests(SIMD8_Weights_Instruction_Test)
runTests(SIMD8_StoreCodebook_Instruction_Test)
runTests(SIMD8_MAC_Instruction_Test)
runTests(CfuTest)


.
----------------------------------------------------------------------
Ran 1 test in 0.021s

OK
.
----------------------------------------------------------------------
Ran 1 test in 0.007s

OK
.
----------------------------------------------------------------------
Ran 1 test in 0.019s

OK
.
----------------------------------------------------------------------
Ran 1 test in 0.056s

OK
