<a href="https://colab.research.google.com/github/bsbatusesli/CFU-Playground/blob/main/python/amaranth_cfu/SIMD_Implementation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [5]:
# Install Amaranth 
!pip install --upgrade 'amaranth[builtin-yosys]'

# CFU-Playground library
!git clone https://github.com/google/CFU-Playground.git
import sys
sys.path.append('CFU-Playground/python')

# Imports
from amaranth import *
from amaranth.back import verilog
from amaranth.sim import Delay, Simulator, Tick
from amaranth_cfu import TestBase, SimpleElaboratable, pack_vals, simple_cfu, InstructionBase, CfuTestBase
import re, unittest

# Utility to convert Amaranth to verilog 
def convert_elaboratable(elaboratable):
  """Convert an Amaranth elaboratable to a compact Verilog string.

  The design is emitted as module ``Top`` using ``elaboratable.ports`` as its
  port list. Verilog ``(* ... *)`` attribute annotations are stripped, and
  lines left containing only spaces are collapsed to empty lines.

  Returns:
    The cleaned-up Verilog source as a ``str``.
  """
  v = verilog.convert(elaboratable, name='Top', ports=elaboratable.ports)
  # Non-greedy so that two attribute blocks on one line do not swallow the
  # code between them; `.` does not match newlines, so matches stay on a line.
  v = re.sub(r'\(\*.*?\*\)', '', v)
  return re.sub(r'^ *\n', '\n', v, flags=re.MULTILINE)

def runTests(klazz):
  """Run every test method of the TestCase subclass `klazz` and print a text report."""
  collected = unittest.TestLoader().loadTestsFromTestCase(klazz)
  # TestSuite(iterable) is equivalent to an empty suite followed by addTests(iterable).
  unittest.TextTestRunner().run(unittest.TestSuite(collected))

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
fatal: destination path 'CFU-Playground' already exists and is not an empty directory.


In [8]:
# Codebook storage.
# The codebook has 4 clusters, so codewords are 2 bits wide and select one of
# the four signed 8-bit centroids held by this block.


class SIMD8_StoreCodebook(SimpleElaboratable):
  """Holds four signed 8-bit cluster centroids unpacked from ``in1``.

  Inputs:
    in0:    32-bit operand; not read by this block.
    in1:    32-bit packed codebook; centroid ``i`` lives in bits
            ``[8*i, 8*i + 8)``.
    enable: while high, ``clus0..clus3`` follow the bytes of ``in1``.
    clear:  while high, all centroids are driven to 0. It is assigned last in
            ``elab``, so it overrides ``enable``.

  Outputs:
    clus0..clus3: signed 8-bit centroids.

  NOTE(review): the hold path is a combinational self-assignment
  (``clusN.eq(clusN)`` in ``m.d.comb``), which infers a latch / combinational
  loop rather than a flip-flop. It simulates as "hold", and the accompanying
  test samples the outputs before the clock edge, so it relies on this
  zero-latency behaviour. Moving to ``m.d.sync`` would be the synthesizable
  form but adds one cycle of latency.
  """

  def __init__(self):
    # 2 x 32 bit input lines
    self.in0 = Signal(32)
    self.in1 = Signal(32)
    self.enable = Signal()
    self.clear = Signal()
    self.clus0 = Signal(signed(8))
    self.clus1 = Signal(signed(8))
    self.clus2 = Signal(signed(8))
    self.clus3 = Signal(signed(8))

  def elab(self, m):
    clusters = [self.clus0, self.clus1, self.clus2, self.clus3]

    with m.If(self.enable):
      # Byte i of in1 becomes centroid i. Each slice must be a full 8 bits so
      # that as_signed() sign-extends from the byte's own MSB (for the top byte
      # that is bit 31, hence [24:32], not [24:31]).
      m.d.comb += [clus.eq(self.in1[8 * i:8 * i + 8].as_signed())
                   for i, clus in enumerate(clusters)]
    with m.Else():
      # Hold the current value (see the NOTE(review) in the class docstring).
      m.d.comb += [clus.eq(clus) for clus in clusters]

    # Later comb assignments win in Amaranth, so clear takes priority.
    with m.If(self.clear):
      m.d.comb += [clus.eq(0) for clus in clusters]

class SIMD8_StoreCodebookTest(TestBase):
  """Checks SIMD8_StoreCodebook: load, hold while disabled, and clear."""

  def create_dut(self):
    return SIMD8_StoreCodebook()

  def test(self):

    def pack(*values, bits=8):
      """Pack ints into one word, value i at bits [i*bits, (i+1)*bits).

      Negative values are stored in two's complement; Python's `&` on a
      negative int already yields the two's-complement bit pattern.
      """
      mask = (1 << bits) - 1
      return sum((v & mask) << (i * bits) for i, v in enumerate(values))

    # Rows run in order: the hold row relies on the value stored by the row
    # before it.
    TEST_CASE = [
        # (enable, clear, in1, (clus0, clus1, clus2, clus3))
        (1, 0, pack(64, -6, 10, -22), (64, -6, 10, -22)),  # store the codebook (in1 = 0xEA0AFA40)
        (0, 0, pack(54, -44, 3, -12), (64, -6, 10, -22)),  # enable == 0: outputs must not change

        (1, 0, pack(127, -127, 0, -17), (127, -127, 0, -17)),  # store a new codebook with corner values
        (0, 1, pack(54, -44, 3, -12), (0, 0, 0, 0)),  # clear the codebook
    ]

    def process():
      for (en, rst, b, exp_clusters) in TEST_CASE:
        yield self.dut.enable.eq(en)
        yield self.dut.clear.eq(rst)
        yield self.dut.in1.eq(b)
        # Let the combinational logic settle before sampling the outputs.
        yield Delay(0.1)

        self.assertEqual(exp_clusters, ((yield self.dut.clus0), (yield self.dut.clus1), (yield self.dut.clus2), (yield self.dut.clus3)))
        yield
    self.run_sim(process)

runTests(SIMD8_StoreCodebookTest)

.
----------------------------------------------------------------------
Ran 1 test in 0.012s

OK


In [46]:
class SIMD8_StoreWeights(SimpleElaboratable):
  """Stores a 64-bit packed weight word and exposes one 16-bit quarter of it.

  ``in0`` supplies the low 32 bits of ``weights`` and ``in1`` the high 32 bits
  (loaded while ``store_en`` is high). ``control`` selects which 16-bit
  quarter drives ``weight_codes``: "00" -> [0:16], "01" -> [16:32],
  "10" -> [32:48], "11" -> [48:64].

  NOTE(review): ``weights`` is driven in ``m.d.comb`` and held by assigning it
  to itself, so it behaves as a latch rather than a clocked register.
  """

  def __init__(self):
    # Inputs
    self.in0 = Signal(32)
    self.in1 = Signal(32)
    self.store_en = Signal(1)
    self.control = Signal(2)

    # Output: the selected 16-bit slice of the stored weights
    self.weight_codes = Signal(16)

    # Storage for the packed 64-bit weight word
    self.weights = Signal(64)

  def elab(self, m):
    with m.If(self.store_en):
      m.d.comb += self.weights.eq(Cat(self.in0, self.in1))
    with m.Else():
      m.d.comb += self.weights.eq(self.weights)

    # One case per 16-bit quarter, in slot order.
    with m.Switch(self.control):
      for slot, pattern in enumerate(("00", "01", "10", "11")):
        low = slot * 16
        with m.Case(pattern):
          m.d.comb += self.weight_codes.eq(self.weights[low:low + 16])

class SIMD8_StoreWeightsTests(TestBase):
  """Checks SIMD8_StoreWeights: loading, slot selection, and hold while store_en is low."""

  def create_dut(self):
    return SIMD8_StoreWeights()

  def test(self):
    # Rows run in order: the first row loads the weights, and every later row
    # reads that stored value while store_en is low.
    # (store_en, control, in0, in1, expected_weight_codes)
    cases = [
        (1, 0b00, 0x01FFFA32, 0x12ED547A, 0xFA32),  # load weights, read the first slot
        (0, 0b01, 0x01FFFA32, 0x12ED547A, 0x01FF),  # second slot
        (0, 0b10, 0x01FFFA32, 0x12ED547A, 0x547A),  # third slot
        (0, 0b11, 0x01FFFA32, 0x12ED547A, 0x12ED),  # fourth slot
        (0, 0b11, 0xFFFFFFFF, 0xFFFFFFFF, 0x12ED),  # store_en low: new inputs are ignored
    ]

    def process():
      dut = self.dut
      for store_en, control, in0, in1, expected in cases:
        yield dut.store_en.eq(store_en)
        yield dut.control.eq(control)
        yield dut.in0.eq(in0)
        yield dut.in1.eq(in1)
        # Let the combinational logic settle before sampling.
        yield Delay(0.1)
        actual = yield dut.weight_codes
        self.assertEqual(expected, actual)
        yield

    self.run_sim(process)

runTests(SIMD8_StoreWeightsTests)
      


.
----------------------------------------------------------------------
Ran 1 test in 0.008s

OK


'0x9c'