In [1]:
from myhdl import block, delay, always_seq, instance, always, Signal, ResetSignal, traceSignals, now
from dataclasses import dataclass
from itertools import tee, product
from typing import Callable, Generator
from copy import deepcopy
from abc import ABC, abstractmethod, abstractproperty
import logging
from sys import version
logging.basicConfig(level=logging.INFO)
logging.info(version)

INFO:root:3.9.5 (default, Aug 29 2021, 19:01:31) 
[GCC 9.3.0]


In [2]:
@ block
def counter(clk, enable, reset, count):
    @always_seq(clk.posedge, reset=reset)
    def increment():
        if enable:
            count.next = count.val + 1
    return increment

@block
def clk_driver(clk, enable, period=20):
    lowTime = int(period / 2)
    highTime = period - lowTime

    @instance
    def drive_clk():
        while True:
            if not enable: 
                yield enable
            yield delay(lowTime)
            clk.next = 1
            yield delay(highTime)
            clk.next = 0

    return drive_clk

In [3]:
class AddressStreamDescriptor(ABC):
    def __init__(self):
        self._done = False
    
    @abstractmethod
    def __post_init__(self):
        pass 
    
    @abstractmethod
    def reset(self):
        pass
    
    @abstractmethod
    def __iter__(self):
        pass
    
    @abstractmethod
    def __next__(self):
        pass
    
    @property
    def done(self):
        return self._done
    
    @done.setter
    def done(self, val):
        if val:
            logging.debug("{} has concluded @T={}".format(self, now()))
        else:
            logging.debug("{} initialized @T={}".format(self, now()))
        self._done = val

@dataclass
class HighLevelAddressStreamDescriptor(AddressStreamDescriptor):
    index_generator_fn: Generator

    def __post_init__(self):
        super().__init__()
        self.done = False
        self.index_generator_fn, self.initial_index_generator_fn = tee(self.index_generator_fn)

    def reset(self):
        self.done = False
        self.index_generator_fn = self.initial_index_generator_fn
        self.index_generator_fn, self.initial_index_generator_fn = tee(self.index_generator_fn)

    def __iter__(self):
        return self

    def __next__(self):
        try:
            next_index = next(self.index_generator_fn)
        except StopIteration:
            next_index = 0
            self.done = True
        return next_index 

@dataclass
class LowLevelAddressStreamDescriptor(AddressStreamDescriptor):
    iteration_domain: Generator
    access_map: Callable
    condition: Callable

    def __post_init__(self):
        super().__init__()
        self.done = False
        self.initial_iteration_domain = deepcopy(self.iteration_domain)

    def reset(self):
        self.done = False
        self.iteration_domain = deepcopy(self.initial_iteration_domain)

    def __iter__(self):
        return self

    def __next__(self):
        try:
            next_iteration_vector = next(self.iteration_domain)
            print(next_iteration_vector)
            if self.condition(next_iteration_vector):
                next_index = self.access_map(next_iteration_vector)
            else:
                next_index = 0
        except StopIteration:
            next_index = 0
            self.done = True
        return next_index   


In [4]:
@ block
def stream_generator(clk, enable, reset, stream, global_counter, start_offset, stream_out):
    @always(clk.posedge, reset.posedge)
    def generate():
        if not reset and enable:
            if not stream.done:
                if global_counter >= start_offset:
                    next_idx = next(stream)
                    stream_out.next = next_idx
                    # print("offset: {}, stream_out.next: {}".format(start_offset, next_idx))
        elif reset:
            stream.reset()
            stream_out.next = 0
    return generate

In [5]:
def chain_arch_pe_access_pattern(c_ub, i_ub, j_ub, access_fn, c_lb = 0, c_step = 1, i_lb = 0, i_step = 1, j_lb = 0, j_step = 1, condition = lambda c,i,j : True):
    for c in range(c_lb, c_ub, c_step):
        for i in range(i_lb, i_ub, i_step):
            for j in range(j_lb, j_ub, j_step):
                if (condition(c, i, j)):
                    yield access_fn(c, i, j)

def baseline_access_fn(pe_channel, pe_group , pe, ifmap_dim):
    pe_start_index_offset = pe_channel*(ifmap_dim**2)
    pe_start_index_offset += pe_group*ifmap_dim+pe
    return lambda _, i, j: i*ifmap_dim+j+pe_start_index_offset+1

# Layer Config
ifmap_dim = 224
kernel = 1
ofmap_dim = ifmap_dim-kernel+1
channel_count = 27

# Arch. Config For Full Channel Parallelism
pe_count = (kernel**2)*channel_count
pes_per_group = kernel
pes_per_channel = kernel**2
groups_per_channel = int(pes_per_channel/pes_per_group)
channel_chain_length = int(pe_count/pes_per_channel)

@block
def top():
    clk = Signal(bool(0))
    enable = Signal(bool(0))
    global_counter = Signal(0)
    reset = ResetSignal(bool(0), active=1, isasync=True)
    counter_inst = counter(clk, enable, reset, global_counter)
    clk_driver_inst = clk_driver(clk, enable, period=10)

    stream_out_list = [Signal(0) for _ in range(pe_count)]
            
    stream_generator_list = []
    for pe_channel in range(channel_chain_length):
        for pe_group in range(groups_per_channel):
            for pe in range(pes_per_group):
                pe_idx = pe_channel*pes_per_channel + pe_group*pes_per_group + pe
                stream_descriptor = HighLevelAddressStreamDescriptor(
                    chain_arch_pe_access_pattern(1, ofmap_dim, ofmap_dim, baseline_access_fn(pe_channel, pe_group, pe, ifmap_dim)))
                stream_generator_list.append(stream_generator(
                    clk, enable, reset, stream_descriptor, global_counter, pe_idx, stream_out_list[pe_idx]))

    @instance
    def start_sim():
        # reset cycle
        enable.next = 0
        reset.next = 1
        yield delay(10)
        enable.next = 1
        reset.next = 0

    return clk_driver_inst, counter_inst, start_sim, stream_generator_list


In [6]:
traceSignals.filename = 'Top'
traceSignals.tracebackup = False

In [7]:
dut = top()
inst = traceSignals(dut)
inst.run_sim(1200)
inst.quit_sim()

<class 'myhdl._SuspendSimulation'>: Simulated 1200 timesteps


In [8]:
utilization = ((pe_count+1)/2+((ofmap_dim**2)-(pe_count-1))+(pe_count+1)/2)/((pe_count-1)+ofmap_dim**2)
print(utilization)

0.980058204261802
