Skip to content

Commit

Permalink
Merge pull request #30 from alanvgreen/ft
Browse files Browse the repository at this point in the history
Ft
  • Loading branch information
tcal-x committed Mar 17, 2021
2 parents bb59d68 + d51b8bd commit c2555a9
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 24 deletions.
24 changes: 15 additions & 9 deletions proj/mnv2_first/gateware/macc.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from nmigen import Signal, signed

from nmigen_cfu import SimpleElaboratable
from nmigen_cfu import Sequencer

from .registerfile import Xetter

Expand All @@ -35,18 +35,24 @@ def __init__(self):
self.input_offset = Signal(signed(9))

def elab(self, m):

muls = []
for n in range(4):
tmp = Signal(signed(9))
inval = self.in0.word_select(n, 8).as_signed()
fval = self.in1.word_select(n, 8).as_signed()
mul = Signal(signed(32))
mul = Signal(signed(17)) # 8bits * 9 bits = 17 bits
m.d.comb += tmp.eq(inval + self.input_offset)
m.d.sync += mul.eq(tmp * fval)
muls.append(mul)
sum_muls = Signal(signed(19)) # 4 lots of 17 bits = 19 bits
m.d.comb += sum_muls.eq(sum(muls))

# Use a sequencer to count one cycle between start and end
m.submodules['seq'] = seq = Sequencer(1)
m.d.comb += seq.inp.eq(self.start)
with m.If(seq.sequence[-1]):
m.d.comb += [
tmp.eq(inval + self.input_offset),
mul.eq(tmp * fval)
self.done.eq(1),
self.output.eq(sum_muls),
]
muls.append(mul)
m.d.comb += [
self.output.eq(sum(muls)),
self.done.eq(1),
]
23 changes: 14 additions & 9 deletions proj/mnv2_first/gateware/post_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,10 +110,12 @@ def rounding_divide_by_pot(x, exponent):

class RoundingDividebyPOTInstruction(InstructionBase):
def elab(self, m):
m.d.comb += [
self.output.eq(rounding_divide_by_pot(self.in0s, self.in1[:5])),
self.done.eq(1),
]
m.d.sync += self.done.eq(0)
with m.If(self.start):
m.d.sync += [
self.output.eq(rounding_divide_by_pot(self.in0s, self.in1[:5])),
self.done.eq(1),
]


def clamped(value, min_bound, max_bound):
Expand All @@ -124,7 +126,7 @@ def clamped(value, min_bound, max_bound):
class PostProcessor(SimpleElaboratable):
"""Does post-processing of an accumulator value.
This is a pipeline: place values at inputs and outputs appear 3 cycles later.
This is a pipeline: place values at inputs and outputs appear 4 cycles later.
It is capable of producing one result per cycle.
The function being implemented is:
Expand Down Expand Up @@ -161,6 +163,9 @@ class PostProcessor(SimpleElaboratable):
The post processed result
"""

# TODO: see if we can make this 3 cycles by bringing SRDHM down to 2 cycles
PIPELINE_CYCLES = 4

def __init__(self):
self.accumulator = Signal(signed(32))
self.bias = Signal(signed(32))
Expand Down Expand Up @@ -200,11 +205,11 @@ def elab(self, m):
]

# Output from SRDHM appears several cycles later
# Logic is then combinational to output
right_shifted = Signal(signed(32))
m.d.comb += right_shifted.eq(
m.d.sync += right_shifted.eq(
rounding_divide_by_pot(srdhm.result, right_sr[-1]))

# This logic is combinational to output
# acc += reg_output_offset
# if (acc < reg_activation_min) {
# acc = reg_activation_min
Expand Down Expand Up @@ -279,7 +284,7 @@ def elab(self, m):
]

# Use a sequencer to count down to processing end
m.submodules['seq'] = seq = Sequencer(4)
m.submodules['seq'] = seq = Sequencer(PostProcessor.PIPELINE_CYCLES)
m.d.comb += seq.inp.eq(self.start)

# Other control signal outputs - set *_next to indicate values used
Expand All @@ -288,5 +293,5 @@ def elab(self, m):
self.bias_next.eq(self.start),
self.multiplier_next.eq(self.start),
self.shift_next.eq(self.start),
self.done.eq(seq.sequence[3]),
self.done.eq(seq.sequence[-1]),
]
2 changes: 1 addition & 1 deletion proj/mnv2_first/gateware/test_macc.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def process():
yield self.dut.in0.eq(input_value)
yield self.dut.in1.eq(filter_value)
yield self.dut.start.eq(1)
yield Delay(0.25)
yield
yield self.dut.start.eq(0)
while not (yield self.dut.done):
yield
Expand Down
11 changes: 7 additions & 4 deletions proj/mnv2_first/gateware/test_post_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,30 +56,33 @@ def test(self):
((0, 0, 0, 0, -128, -128, 127), None),
((0, 0, 0, 0, -128, -128, 127), None),
((0, 0, 0, 0, -128, -128, 127), None),
((0, 0, 0, 0, -128, -128, 127), None),

((11714, -12447, 1795228372, -7, -5, -128, 127), -10),
((12571, -12447, 1795228372, -7, -5, -128, 127), -4),
((0, 0, 0, 0, -5, -128, 127), None),
((0, 0, 0, 0, -5, -128, 127), None),
((0, 0, 0, 0, -5, -128, 127), None),
((0, 0, 0, 0, -5, -128, 127), None),

((-42127, 22500, 1279992908, -10, -4, -128, 127), -15),
((0, 0, 0, 0, -4, -128, 127), None),
((0, 0, 0, 0, -4, -128, 127), None),
((0, 0, 0, 0, -4, -128, 127), None),

((0, 0, 0, 0, -4, -128, 127), None),

((-18706, 12439, 1493407068, -9, 10, -128, 127), 1),
((0, 0, 0, 0, 10, -128, 127), None),
((0, 0, 0, 0, 10, -128, 127), None),
((0, 0, 0, 0, 10, -128, 127), None),
((0, 0, 0, 0, 10, -128, 127), None),

]

def process():
# Outputs are delayed by 3 cycles, so put in this data structure
# with indices shifted by 3
expected_outputs = [None, None, None] + [o for (_, o) in DATA]
# Outputs are delayed by 4 cycles, so put in this data structure
# with indices shifted by 4
expected_outputs = [None, None, None, None] + [o for (_, o) in DATA]

# Iterate through inputs as usual
for n, (inputs, _) in enumerate(DATA):
Expand Down
2 changes: 1 addition & 1 deletion proj/mnv2_first/src/proj_menu.c
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ static void do_mbqm_tests() {
cpp_math_mul_by_quantized_mul_gateware1(x, quantized_multiplier, shift);
int32_t gw2 =
cpp_math_mul_by_quantized_mul_gateware2(x, quantized_multiplier, shift);
if (gw1 != sw || gw2 != sw || (i % 128 == 0)) {
if (gw1 != sw || gw2 != sw || (i % (1024 * 8) == 0)) {
printf("mbqm(0x%08lx, 0x%08lx, %3d) = ", x, quantized_multiplier, shift);
printf("0x%08lx", sw);
if (gw1 != sw || gw2 != sw) {
Expand Down

0 comments on commit c2555a9

Please sign in to comment.