In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F


# 1) Implement tiling and communication



In [4]:
def tmul(tileA, tileB, t):
    """Multiply two tiles after asserting that they fit into a ``t``-sized tile.

    Args:
        tileA: left operand; both of its dimensions must be at most ``t``.
        tileB: right operand; its second dimension must be at most ``t``.
        t: tile size limit.

    Returns:
        The matrix product ``tileA @ tileB``.

    Raises:
        AssertionError: if one of the checked dimensions exceeds ``t``.
    """
    # Checked in the same order as listed: rows of A, inner dim, columns of B.
    for tile, dim in ((tileA, 0), (tileA, 1), (tileB, 1)):
        assert tile.size(dim) <= t

    return tileA.matmul(tileB)

# 2) Implement custom network

In [5]:
def conv2d(inputs, weights, padding, tile_size, bias=None, sim=None):
    """Stride-1 2-D convolution computed as a lowered (im2col) tiled matmul.

    The weights are flattened to ``(o_chn, i_chn*k_h*k_w)``, the zero-padded
    inputs are unrolled to ``(i_chn*k_h*k_w, bs*out_h*out_w)``, the two are
    multiplied with ``mmul_tiling`` and the product is reshaped back to
    ``(bs, o_chn, out_h, out_w)``.

    The output size is derived from the padding and kernel size
    (``out = in + 2*pad - k + 1`` per axis), so non-square inputs, non-square
    kernels and padding other than ``(k - 1) / 2`` are handled as well.

    Args:
        inputs: tensor of size ``(bs, i_chn, H, W)``. It is copied into an
            int8 buffer, exactly like the lowered matrix.
        weights: tensor of size ``(o_chn, i_chn, k_h, k_w)``.
        padding: ``(pad_h, pad_w)`` pair, or a single int used for both axes.
        tile_size: tile size forwarded to ``mmul_tiling``.
        bias: optional tensor with ``o_chn`` elements, added in place per
            output channel.
        sim: optional simulator object forwarded to ``mmul_tiling``.

    Returns:
        Tensor of size ``(bs, o_chn, out_h, out_w)``.

    Raises:
        ValueError: if the channel counts of ``inputs`` and ``weights``
            disagree or the kernel does not fit into the padded input.
    """
    o_chn, i_chn, k_h, k_w = weights.size()
    bs, in_chn, in_h, in_w = inputs.size()
    if in_chn != i_chn:
        raise ValueError(
            f"inputs have {in_chn} channels but weights expect {i_chn}")

    if isinstance(padding, int):
        pad_h = pad_w = padding
    else:
        pad_h, pad_w = padding

    # Stride-1 output resolution; equals the input resolution only for
    # "same" padding.
    out_h = in_h + 2 * pad_h - k_h + 1
    out_w = in_w + 2 * pad_w - k_w + 1
    if out_h <= 0 or out_w <= 0:
        raise ValueError(
            f"kernel {k_h}x{k_w} does not fit into the padded input "
            f"({in_h + 2 * pad_h}x{in_w + 2 * pad_w})")

    def weight_lowering():
        # One row per output channel; column index is c*k_h*k_w + a*k_w + b.
        return weights.reshape(o_chn, i_chn * k_h * k_w)

    def inputs_lowering():
        # Zero padding
        inputs_padded = torch.zeros(
            bs, i_chn, in_h + 2 * pad_h, in_w + 2 * pad_w, dtype=torch.int8)
        inputs_padded[..., pad_h:pad_h + in_h, pad_w:pad_w + in_w] = inputs

        # One slab per kernel offset (a, b): the shifted window of the padded
        # input that this kernel element is multiplied with.
        lowered_inputs = torch.zeros(
            k_h * k_w, i_chn, bs, out_h * out_w, dtype=torch.int8)
        for a in range(k_h):
            for b in range(k_w):
                window = inputs_padded[..., a:a + out_h, b:b + out_w]
                lowered_inputs[a * k_w + b] = (
                    window.transpose(0, 1).reshape(i_chn, bs, -1))

        # Channel-major rows so that the row order matches weight_lowering().
        lowered_inputs = lowered_inputs.transpose(0, 1)
        return lowered_inputs.reshape(i_chn * k_h * k_w, bs * out_h * out_w)

    def outputs_lifting(lowered_outputs):
        return lowered_outputs.reshape(o_chn, bs, out_h, out_w).transpose(0, 1)

    # Lower Weights
    weights_transformed = weight_lowering()

    # Lower Inputs
    inputs_transformed = inputs_lowering()

    # Compute Outputs
    lowered_outputs = mmul_tiling(
        weights_transformed, inputs_transformed, tile_size, sim)

    # Lift Outputs
    outputs = outputs_lifting(lowered_outputs)

    if bias is not None:
        outputs += bias.view(1, o_chn, 1, 1)

    return outputs

In [6]:
class MMConv2d(nn.Conv2d):
    """``nn.Conv2d`` whose forward pass runs through the tiled-matmul ``conv2d``.

    Args:
        in_channels, out_channels, kernel_size, padding, stride: forwarded to
            ``nn.Conv2d``. The custom ``conv2d`` used by this class takes no
            stride argument, so ``stride`` only affects the parent module's
            bookkeeping.
        sim: optional simulator object handed to ``conv2d``.
        bias: whether the layer owns a bias parameter.
    """

    def __init__(self, in_channels, out_channels, kernel_size,sim=None, padding=1, stride=1, bias=False):
        # `bias` must be forwarded explicitly; otherwise nn.Conv2d falls back
        # to its own default and creates a bias even when bias=False was asked.
        super(MMConv2d, self).__init__(
            in_channels, out_channels, kernel_size,
            padding=padding, stride=stride, bias=bias)
        self.lowering = False
        self.tiling = False
        self.tile_size = -1
        self.sim = sim
        self.quant = torch.quantization.QuantStub()

    def forward(self, inputs):
        """Run the lowered/tiled convolution with the layer's own parameters."""
        # NOTE(review): conv2d lowers the inputs into int8 buffers, so this
        # presumably expects weights that can be multiplied with int8 data -
        # confirm before using forward() with the float parameters.
        return conv2d(inputs, self.weight, padding=self.padding, tile_size=self.tile_size, bias=self.bias, sim= self.sim)

    def set_tilesize(self, tile_size=-1):
        """Set the tile size used by subsequent ``conv2d`` calls."""
        self.tile_size = tile_size

    def simulation_test(self, inputs):
        """Compare the tiled ``conv2d`` against ``F.conv2d`` and print the verdict.

        Weights (and the bias, when the layer has one) are scaled by 100 and
        cast to int8 before both computations.
        """
        # Compute Output
        q_weight = (self.weight.data*100).type(torch.int8)
        q_bias = None
        if self.bias is not None:
            q_bias = (self.bias.data*100).type(torch.int8)
        print("Input size: \t", inputs.size())
        print("Weight size: \t", self.weight.size())
        pred_outputs = conv2d(inputs, q_weight, padding=self.padding, tile_size=self.tile_size, bias=q_bias,sim = self.sim)
        print("Output size: \t", pred_outputs.size())
        print("=============================================")

        # Evaluation
        true_outputs = F.conv2d(inputs, q_weight, padding=self.padding, bias=q_bias)
        # Exact comparison: subtracting int8 tensors wraps around, so a
        # difference-based check can report a match for values 128 apart.
        correct = torch.equal(pred_outputs, true_outputs)
        print("Correctness: \t", correct, '\n')

# 3) Test tiling and network

In [7]:
# Configuration of the layer under test and of its input tensor
BS = 1            # batch size
RES_X = 32        # input width
RES_Y = 32        # input height
I_CHN = 4         # input channels
O_CHN = 8         # output channels
KERNEL_SIZE = 3
PADDING = 1
BIAS = False

layer = MMConv2d(
    in_channels=I_CHN,
    out_channels=O_CHN,
    kernel_size=KERNEL_SIZE,
    padding=PADDING,
    bias=BIAS,
)

In [11]:
# Software-only check: the lowered matmul is split into tiles of size 4.
# NOTE(review): conv2d calls mmul_tiling, which this notebook only defines in a
# later cell - presumably that cell has to be executed before this one.
tile_size = 4
layer.set_tilesize(tile_size=tile_size)

# Random int8 activations (standard normal scaled by 20, then cast).
inputs = torch.randn(BS, I_CHN, RES_Y, RES_X).mul(20).to(torch.int8)
layer.simulation_test(inputs)

Input size: 	 torch.Size([1, 4, 32, 32])
Weight size: 	 torch.Size([8, 4, 3, 3])
Output size: 	 torch.Size([1, 8, 32, 32])
Correctness: 	 True 



# 4) Bring Amaranth hardware designs

In [12]:
from amaranth import *
from enum import IntEnum
import math
from amaranth.lib.fifo import SyncFIFOBuffered

In [46]:
# Amaranth hardware designs from previous Lab

class MAC(Elaboratable):
    """Multiply-accumulate unit with a sticky overflow flag.

    While both ``in_a_valid`` and ``in_b_valid`` are high, every clock edge
    stores ``out_d + in_a * in_b`` back into ``out_d`` and raises
    ``out_d_valid``. ``out_ovf`` is OR-ed with the overflow of the current
    addition, so it stays set once raised.

    ``in_rst`` is only declared here; ``elaborate`` never reads it.
    """

    def __init__(self, num_bits, acc_bits, signed=True):
        self.num_bits = num_bits
        self.acc_bits = acc_bits
        self.signed = signed

        operand_shape = Shape(num_bits, signed=signed)
        acc_shape = Shape(acc_bits, signed=signed)

        # Operands and their valid strobes
        self.in_a = Signal(operand_shape)
        self.in_a_valid = Signal(1)
        self.in_b = Signal(operand_shape)
        self.in_b_valid = Signal(1)

        self.in_rst = Signal(1, reset_less=True)

        # Accumulator outputs
        self.out_d = Signal(acc_shape)
        self.out_d_valid = Signal(1)
        self.out_ovf = Signal(1)

        # Combinational intermediates: product, next sum and its carry bit
        self.tmp_prod = Signal(acc_shape)
        self.tmp_d = Signal(acc_shape)
        self.tmp_ovf = Signal(1)

    def elaborate(self, platform):
        m = Module()

        # Product of the operands, then accumulator + product; the bit that
        # does not fit into tmp_d lands in tmp_ovf.
        m.d.comb += self.tmp_prod.eq(self.in_a * self.in_b)
        m.d.comb += Cat(self.tmp_d, self.tmp_ovf).eq(
            self.out_d + self.tmp_prod)

        if self.signed:
            # Signed overflow: both addends share a sign bit that the sum
            # does not have.
            prod_sign = self.tmp_prod[-1]
            acc_sign = self.out_d[-1]
            sum_sign = self.tmp_d[-1]
            overflow_now = (
                (prod_sign & acc_sign & ~sum_sign) |
                (~prod_sign & ~acc_sign & sum_sign))
        else:
            overflow_now = self.tmp_ovf

        # No explicit reset code is written in this module.
        with m.If(self.in_a_valid & self.in_b_valid):
            m.d.sync += [
                self.out_ovf.eq(self.out_ovf | overflow_now),
                self.out_d.eq(self.tmp_d),
                self.out_d_valid.eq(1),
            ]
        return m


class PE(Elaboratable):
    """Processing element: a MAC plus a cycle counter driven by a two-state FSM.

    A non-zero ``in_init`` seen in the INIT state is latched as the number of
    EXEC cycles. During EXEC the MAC's valid inputs are held high, so it
    accumulates ``in_a * in_b`` each clock; afterwards the FSM returns to
    INIT and ``out_d_valid`` can go high.
    """

    def __init__(self, num_bits, acc_bits, cnt_bits, signed=True):
        self.num_bits = num_bits
        self.acc_bits = acc_bits
        self.cnt_bits = cnt_bits
        self.signed = signed

        # Number of cycles to execute; zero means "no request".
        self.in_init = Signal(cnt_bits)
        # OR-ed into the MAC reset in elaborate(); not used elsewhere in this class.
        self.in_rst = Signal(1, reset_less=True)

        self.in_a = Signal(Shape(num_bits, signed=signed))
        self.in_b = Signal(Shape(num_bits, signed=signed))

        self.out_d = Signal(Shape(acc_bits, signed=signed))
        self.out_d_valid = Signal(1)
        self.out_ovf = Signal(1)

        # Cycle counter state; cnt_next is one bit wider to hold the carry of cnt + 1.
        self.cnt_target = Signal(cnt_bits)
        self.cnt = Signal(cnt_bits)
        self.cnt_ovf = Signal(1)
        self.cnt_next = Signal(cnt_bits + 1)

        self.mac = MAC(num_bits=num_bits, acc_bits=acc_bits, signed=signed)
        # High while the FSM is in EXEC (set when leaving INIT, cleared when leaving EXEC).
        self.is_exec = Signal(1)

    def elaborate(self, platform):
        m = Module()

        # The MAC contains no reset code of its own; wrap it so that its
        # in_rst signal acts as its reset.
        m.submodules.mac = mac = ResetInserter(self.mac.in_rst)(self.mac)

        m.d.comb += [
            mac.in_a.eq(self.in_a),
            mac.in_a_valid.eq(self.is_exec),
            mac.in_b.eq(self.in_b),
            mac.in_b_valid.eq(self.is_exec),
            # Reset the MAC when idle and a new request arrives, or on external reset.
            mac.in_rst.eq(~self.is_exec & self.in_init.any() | self.in_rst),
            self.out_d.eq(mac.out_d),
            # The result is only flagged valid while not executing.
            self.out_d_valid.eq(mac.out_d_valid & ~self.is_exec),
            self.out_ovf.eq(mac.out_ovf),
            self.cnt_next.eq(self.cnt + 1),
        ]

        with m.FSM(reset='INIT'):
            with m.State('INIT'):
                # Wait for a non-zero in_init, latch it and start executing.
                with m.If(self.in_init):
                    m.d.sync += [
                        self.cnt_target.eq(self.in_init),
                        self.cnt.eq(0),
                        self.cnt_ovf.eq(0),
                        self.is_exec.eq(1),
                    ]
                    m.next = 'EXEC'
            with m.State('EXEC'):
                # Count one cycle; the extra bit of cnt_next goes to cnt_ovf.
                m.d.sync += [
                    Cat(self.cnt, self.cnt_ovf).eq(self.cnt_next)
                ]
                # Stop once the incremented count (without its carry bit) hits the target.
                with m.If(self.cnt_next[:-1] == self.cnt_target):
                    m.next = 'INIT'
                    m.d.sync += [
                        self.is_exec.eq(0),
                    ]

        return m

        
class AdderTree(Elaboratable):
    """Combinational binary adder tree over ``fan_in`` inputs.

    A node with ``fan_in > 2`` feeds the lower half of its inputs into a left
    subtree and the upper half into a right subtree, then adds the two
    subtree results. A node with ``fan_in == 2`` adds its two inputs directly.
    ``out_valid`` is the AND of the operands' valid flags and ``out_ovf`` the
    OR of the operands' overflow flags and this node's own addition overflow.
    """

    def __init__(self, acc_bits, fan_in, signed=True):
        self.acc_bits = acc_bits
        self.fan_in = fan_in
        self.signed = signed
        assert is_power_of_two(fan_in)
        assert fan_in >= 2

        data_shape = Shape(acc_bits, signed=signed)

        # Signals carry explicit names that include the fan-in of their node.
        self.in_data = Array([
            Signal(data_shape, name=f'in_data_{fan_in}_{i}')
            for i in range(fan_in)])
        self.in_ovf = Array([
            Signal(1, name=f'in_ovf_{fan_in}_{i}')
            for i in range(fan_in)])
        self.in_valid = Array([
            Signal(1, name=f'in_valid_{fan_in}_{i}')
            for i in range(fan_in)])
        self.out_d = Signal(data_shape)
        self.out_ovf = Signal(1)
        self.out_valid = Signal(1)

        # Bit of the sum that does not fit into out_d
        self.tmp_ovf = Signal(1)

        # Subtrees exist only above the two-input leaf level.
        if fan_in > 2:
            self.tree_l = AdderTree(acc_bits, fan_in // 2, signed=signed)
            self.tree_r = AdderTree(acc_bits, fan_in // 2, signed=signed)
        else:
            self.tree_l = None
            self.tree_r = None

    def elaborate(self, platform):
        m = Module()

        if self.fan_in > 2:
            m.submodules.tree_l = left = self.tree_l
            m.submodules.tree_r = right = self.tree_r

            # Lower half of the inputs goes left, upper half goes right.
            half = self.fan_in // 2
            for idx in range(self.fan_in):
                child, slot = (left, idx) if idx < half else (right, idx - half)
                m.d.comb += [
                    child.in_data[slot].eq(self.in_data[idx]),
                    child.in_ovf[slot].eq(self.in_ovf[idx]),
                    child.in_valid[slot].eq(self.in_valid[idx]),
                ]

            lhs, rhs = left.out_d, right.out_d
            lhs_ovf, rhs_ovf = left.out_ovf, right.out_ovf
            lhs_valid, rhs_valid = left.out_valid, right.out_valid
        else:
            lhs, rhs = self.in_data[0], self.in_data[1]
            lhs_ovf, rhs_ovf = self.in_ovf[0], self.in_ovf[1]
            lhs_valid, rhs_valid = self.in_valid[0], self.in_valid[1]

        # The addition itself is the same at every level of the tree.
        m.d.comb += [
            Cat(self.out_d, self.tmp_ovf).eq(lhs + rhs),
            self.out_valid.eq(lhs_valid & rhs_valid),
        ]

        if self.signed:
            # Signed overflow: operands share a sign bit that the sum lacks.
            sign_flip = (
                (~lhs[-1] & ~rhs[-1] & self.out_d[-1]) |
                (lhs[-1] & rhs[-1] & ~self.out_d[-1]))
            m.d.comb += self.out_ovf.eq(lhs_ovf | rhs_ovf | sign_flip)
        else:
            m.d.comb += self.out_ovf.eq(self.tmp_ovf | lhs_ovf | rhs_ovf)

        return m

def is_power_of_two(x):
    """Return True if ``x`` is a positive power of two (1, 2, 4, 8, ...).

    Zero and negative numbers are rejected; the bit trick ``x & (x - 1)``
    alone would report 0 as a power of two.
    """
    return x > 0 and (x & (x - 1)) == 0


class ACTCODE(IntEnum):
    """Integer codes for the activation selector (compared against ``in_act`` in PEStack)."""
    NONE = 0  # no activation
    RELU = 1  # rectified linear unit


class PEStack(Elaboratable):
    """Row of PEs whose accumulators are summed by an adder tree.

    The ``width``-bit inputs ``in_a`` / ``in_b`` are cut into
    ``width // num_bits`` lanes of ``num_bits`` bits; lane ``i`` feeds PE
    ``i``. All PEs share ``in_init`` and ``in_rst``. The PE results are
    reduced by an ``AdderTree`` and, for signed stacks with
    ``in_act == ACTCODE.RELU``, negative sums are clamped to zero.

    Args:
        num_bits: bit width of one operand lane (also used as accumulator width).
        width: total input bit width; must be 32, 64 or 128 and a multiple
            of ``num_bits``.
        cnt_bits: width of the PEs' cycle counters.
        signed: whether operands and accumulators are signed.
    """

    def __init__(self, num_bits, width, cnt_bits, signed=True):
        self.width = width  # input bitwidth
        self.acc_bits = num_bits
        self.num_stack = width // num_bits
        self.num_bits = num_bits
        self.cnt_bits = cnt_bits
        self.signed = signed

        assert width in [32, 64, 128]
        assert width % num_bits == 0
        assert is_power_of_two(self.num_stack)

        self.adder_tree = AdderTree(
            acc_bits=self.acc_bits, fan_in=self.num_stack, signed=signed)

        self.pe_arr = [
            PE(num_bits=num_bits, acc_bits=self.acc_bits,
               cnt_bits=cnt_bits, signed=signed)
            for _ in range(self.num_stack)]

        self.in_rst = Signal(1, reset_less=True)
        self.in_init = Signal(cnt_bits)
        self.in_a = Signal(width)
        self.in_b = Signal(width)
        self.in_act = Signal(1)

        # NOTE(review): out_d is declared signed regardless of `signed`,
        # unlike the other signals - confirm this is intended.
        self.out_d = Signal(Shape(self.acc_bits, signed=True))
        self.out_ready = Signal(1)
        self.out_ovf = Signal(1)

    def elaborate(self, platform):
        m = Module()

        m.submodules.adder_tree = adder_tree = self.adder_tree

        # out_d must have exactly one driver per condition: an additional
        # unconditional assignment added afterwards would take priority and
        # bypass the activation.
        if self.signed:
            with m.If(self.in_act == ACTCODE.RELU):
                m.d.comb += [
                    self.out_d.eq(
                        Mux(adder_tree.out_d >= 0, adder_tree.out_d, 0)),
                ]
            with m.Else():  # NONE
                m.d.comb += [
                    self.out_d.eq(adder_tree.out_d),
                ]
        else:
            # ReLU is only applied to signed stacks.
            m.d.comb += [
                self.out_d.eq(adder_tree.out_d),
            ]

        m.d.comb += [
            self.out_ready.eq(adder_tree.out_valid),
            self.out_ovf.eq(adder_tree.out_ovf),
        ]

        for i, pe in enumerate(self.pe_arr):
            m.submodules += pe

            m.d.comb += [
                # Lane i of the packed operands
                pe.in_a.eq(
                    self.in_a[i*self.num_bits: (i+1)*self.num_bits]),
                pe.in_b.eq(
                    self.in_b[i*self.num_bits: (i+1)*self.num_bits]),
                pe.in_init.eq(self.in_init),
                pe.in_rst.eq(self.in_rst),
                adder_tree.in_data[i].eq(pe.out_d),
                adder_tree.in_valid[i].eq(pe.out_d_valid),
                adder_tree.in_ovf[i].eq(pe.out_ovf),
            ]
        return m

In [47]:
def mmul_tiling(matA, matB, t, simulator):
    """Multiply ``matA (a x c)`` by ``matB (c x b)`` tile by tile into an int8 result.

    Without a simulator the product is accumulated from ``t x t`` tiles via
    ``tmul``. With a simulator every output element is computed as one dot
    product: the row of ``matA`` and the column of ``matB`` are handed to
    ``simulator.set_input``, the simulation is run, and ``simulator.output``
    is added to the result.

    Args:
        matA: 2-D tensor of size ``(a, c)``.
        matB: 2-D tensor of size ``(c, b)``.
        t: tile size, a positive integer. The simulator path requires 1,
            because it produces a single scalar per simulation run.
        simulator: object providing ``set_input``, ``bench``, ``sim`` and
            ``output``, or ``None`` for the pure-PyTorch path.

    Returns:
        int8 tensor of size ``(a, b)``.

    Raises:
        ValueError: for a non-positive tile size, mismatching inner
            dimensions, or a tile size other than 1 on the simulator path.
    """
    if t <= 0:
        # A negative tile size would yield empty loops and an all-zero result.
        raise ValueError(f"tile size must be a positive integer, got {t}")

    a, c = matA.size()
    c_b, b = matB.size()
    if c != c_b:
        raise ValueError(
            f"inner dimensions do not match: {a}x{c} @ {c_b}x{b}")

    matC = torch.zeros(a, b, dtype=torch.int8)

    row_tiles = (a + t - 1) // t
    col_tiles = (b + t - 1) // t

    if simulator is not None:
        if t != 1:
            # One simulation run yields one scalar, which can only fill a
            # 1x1 output tile.
            raise ValueError(
                f"the simulator path only supports tile size 1, got {t}")

        for j in range(col_tiles):
            for i in range(row_tiles):
                # Hand one full row of A and one full column of B to the
                # simulator, flattened to 1-D.
                simulator.set_input(
                    matA[i*t:(i+1)*t, :].reshape(-1),
                    matB[:, j*t:(j+1)*t].reshape(-1))

                simulator.sim.add_sync_process(simulator.bench)
                simulator.sim.run()
                # tileC is a view, so the in-place add updates matC.
                tileC = matC[i*t:(i+1)*t, j*t:(j+1)*t]
                tileC += simulator.output

    else:
        inner_tiles = (c + t - 1) // t
        for i in range(row_tiles):
            for j in range(col_tiles):
                tileC = matC[i*t:(i+1)*t, j*t:(j+1)*t]
                for k in range(inner_tiles):
                    tileA = matA[i*t:(i+1)*t, k*t:(k+1)*t]
                    tileB = matB[k*t:(k+1)*t, j*t:(j+1)*t]
                    tileC += tmul(tileA, tileB, t)

    return matC

# 5) Make simulator class for communication

In [48]:
from amaranth.sim import Simulator
import numpy as np
from collections import deque
from pathlib import Path

In [49]:
class ComunicationSimulator():
    """Feeds packed operand words into a simulated ``PEStack`` and reads its result.

    ``set_input`` packs two equally long 1-D sequences of integers into
    ``width``-bit words of ``width // num_bits`` lanes each; ``bench`` is the
    simulation process that streams those words into the design and stores
    the sampled ``out_d`` in ``self.output``.

    Args:
        width: bit width of one packed input word.
        num_bits: bit width of one lane (one operand).
    """

    def __init__(self, width=32, num_bits = 8):
        self.output = 0

        self.width = width
        self.num_bits = num_bits
        signed = True
        cnt_bits = 5
        # Kept so that set_input can check the word count against the
        # counter width of the design.
        self.cnt_bits = cnt_bits

        self.dut = PEStack(self.num_bits, self.width,
                           cnt_bits=cnt_bits, signed=signed)
        self.dut = ResetInserter(self.dut.in_rst)(self.dut)

        # make amaranth simulator as attribute of our simulator
        self.sim = Simulator(self.dut)
        self.sim.add_clock(1e-6)

        self.i_stack = []
        self.j_stack = []
        self.count = 0

    def _pack_words(self, values):
        """Pack integers into ``width``-bit words, first element in the top lane.

        Negative values are stored in two's complement within their lane; a
        trailing partial word is completed with zero lanes.
        """
        lanes = self.width // self.num_bits
        lane_mask = (1 << self.num_bits) - 1

        flat = [int(v) for v in values]
        # Zero lanes contribute nothing to a dot product.
        flat.extend([0] * (-len(flat) % lanes))

        words = []
        for start in range(0, len(flat), lanes):
            word = 0
            for value in flat[start:start + lanes]:
                word = (word << self.num_bits) | (value & lane_mask)
            words.append(word)
        return words

    def set_input(self, input_a, input_b):
        """Pack the two operand vectors and store them for the next ``bench`` run.

        Args:
            input_a, input_b: 1-D sequences (e.g. tensors) of integers that
                fit into ``num_bits`` bits; both must have the same length.

        Raises:
            ValueError: if the lengths differ or the number of packed words
                cannot be represented in ``cnt_bits`` bits.
        """
        words_a = self._pack_words(input_a)
        words_b = self._pack_words(input_b)
        if len(words_a) != len(words_b):
            raise ValueError(
                "input_a and input_b must have the same number of elements")

        max_words = (1 << self.cnt_bits) - 1
        if len(words_a) > max_words:
            raise ValueError(
                f"{len(words_a)} words exceed the {self.cnt_bits}-bit "
                f"cycle counter (max {max_words})")

        self.i_stack = words_a
        self.j_stack = words_b
        self.count = len(words_a)

    # run single clock cycle
    def test_case(self, dut, in_a, in_b, in_init):
        """Drive the inputs, advance one clock and return the sampled ``out_d``."""
        yield dut.in_a.eq(in_a)
        yield dut.in_b.eq(in_b)
        yield dut.in_init.eq(in_init)
        yield
        out_data = yield dut.out_d
        return out_data

    def bench(self):
        """Simulation process: announce the word count, stream the words, read the result."""
        # initialize: request `count` cycles with zero operands
        yield from self.test_case(self.dut, 0, 0, self.count)
        # feed one packed word pair per clock
        for i in range(self.count):
            yield from self.test_case(self.dut, self.i_stack[i], self.j_stack[i], 0)
        # get output after one more clock with idle inputs
        self.output = yield from self.test_case(self.dut, 0, 0, 0)


# 6) Test PyTorch-to-Amaranth communication

In [50]:
np.random.seed(42)
# The layer weights and the inputs below are drawn from torch's RNG, which
# np.random.seed does not affect; seed it as well so the run is reproducible.
torch.manual_seed(42)
simul = ComunicationSimulator()
# initialize layer with simulator
layer = MMConv2d(I_CHN, O_CHN, KERNEL_SIZE, sim = simul, padding=PADDING, bias=BIAS)

# fixed to 1 in this practice
tile_size = 1

layer.set_tilesize(tile_size)

inputs = (torch.randn(BS, I_CHN, RES_Y, RES_X)*20).type(torch.int8)
layer.simulation_test(inputs)

Input size: 	 torch.Size([1, 4, 32, 32])
Weight size: 	 torch.Size([8, 4, 3, 3])
Output size: 	 torch.Size([1, 8, 32, 32])
Correctness: 	 True 

