diff --git a/.scalafmt.conf b/.scalafmt.conf index 74d57207..b4b1ea9b 100644 --- a/.scalafmt.conf +++ b/.scalafmt.conf @@ -1,2 +1,3 @@ version = "3.8.0" align.tokens = [] +runner.dialect = scala3 diff --git a/2mm.fuse b/2mm.fuse new file mode 100644 index 00000000..4697a3f7 --- /dev/null +++ b/2mm.fuse @@ -0,0 +1,84 @@ +// BEGIN macro definition + + + + + + + + + + +// END macro definition + +decl alpha_int: ubit<32>[1]; +decl beta_int: ubit<32>[1]; +decl tmp_int: ubit<32>[8][8]; +decl A_int: ubit<32>[8][8]; +decl B_int: ubit<32>[8][8]; +decl C_int: ubit<32>[8][8]; +decl D_int: ubit<32>[8][8]; + +let tmp: ubit<32>[8][8]; +let A: ubit<32>[8][8]; +let B: ubit<32>[8][8]; +let C: ubit<32>[8][8]; +let D: ubit<32>[8][8]; + +view A_sh = A[_: bank 1][_: bank 1]; +view B_sh = B[_: bank 1][_: bank 1]; +view C_sh = C[_: bank 1][_: bank 1]; +view D_sh = D[_: bank 1][_: bank 1]; +view tmp_sh = tmp[_: bank 1][_: bank 1]; + +for (let i0: ubit<4> = 0..8) { + for (let j0: ubit<4> = 0..8) { + A_sh[i0][j0] := A_int[i0][j0]; + B_sh[i0][j0] := B_int[i0][j0]; + C_sh[i0][j0] := C_int[i0][j0]; + D_sh[i0][j0] := D_int[i0][j0]; + tmp_sh[i0][j0] := tmp_int[i0][j0]; + } +} + +--- + +for (let i: ubit<4> = 0..8) { + for (let j: ubit<4> = 0..8) { + tmp[i][j] := 0; + --- + for (let k: ubit<4> = 0..8) { + let v: ubit<32> = alpha_int[0] * A[i][k] * B[k][j]; + } combine { + tmp[i][j] += v; + } + } +} + +--- + +for (let i1: ubit<4> = 0..8) { + for (let j1: ubit<4> = 0..8) { + let d_tmp: ubit<32> = D[i1][j1]; + --- + D[i1][j1] := beta_int[0] * d_tmp; + --- + for (let k1: ubit<4> = 0..8) { + let v1: ubit<32> = tmp[i1][k1] * C[k1][j1]; + } combine { + D[i1][j1] += v1; + } + } +} + +--- + +for (let i0: ubit<4> = 0..8) { + for (let j0: ubit<4> = 0..8) { + A_int[i0][j0] := A_sh[i0][j0]; + B_int[i0][j0] := B_sh[i0][j0]; + C_int[i0][j0] := C_sh[i0][j0]; + D_int[i0][j0] := D_sh[i0][j0]; + tmp_int[i0][j0] := tmp_sh[i0][j0]; + } +} diff --git a/2mm.futil b/2mm.futil new file mode 100644 index 00000000..4f14f76b 
--- /dev/null +++ b/2mm.futil @@ -0,0 +1,687 @@ +// git.status = dirty, build.date = Sun Mar 31 20:54:20 EDT 2024, git.hash = 7041748 +import "primitives/core.futil"; +import "primitives/memories/seq.futil"; +import "primitives/binary_operators.futil"; +component main() -> () { + cells { + A0_0 = seq_mem_d2(32,8,8,4,4); + @external(1) A_int = seq_mem_d2(32,8,8,4,4); + A_int_read0_0 = std_reg(32); + A_read0_0 = std_reg(32); + A_sh_read0_0 = std_reg(32); + B0_0 = seq_mem_d2(32,8,8,4,4); + @external(1) B_int = seq_mem_d2(32,8,8,4,4); + B_int_read0_0 = std_reg(32); + B_read0_0 = std_reg(32); + B_sh_read0_0 = std_reg(32); + C0_0 = seq_mem_d2(32,8,8,4,4); + @external(1) C_int = seq_mem_d2(32,8,8,4,4); + C_int_read0_0 = std_reg(32); + C_read0_0 = std_reg(32); + C_sh_read0_0 = std_reg(32); + D0_0 = seq_mem_d2(32,8,8,4,4); + @external(1) D_int = seq_mem_d2(32,8,8,4,4); + D_int_read0_0 = std_reg(32); + D_sh_read0_0 = std_reg(32); + add0 = std_add(4); + add1 = std_add(4); + add10 = std_add(4); + add11 = std_add(4); + add2 = std_add(32); + add3 = std_add(4); + add4 = std_add(4); + add5 = std_add(4); + add6 = std_add(32); + add7 = std_add(4); + add8 = std_add(4); + add9 = std_add(4); + @external(1) alpha_int = seq_mem_d1(32,1,1); + alpha_int_read0_0 = std_reg(32); + @external(1) beta_int = seq_mem_d1(32,1,1); + beta_int_read0_0 = std_reg(32); + bin_read0_0 = std_reg(32); + bin_read1_0 = std_reg(32); + bin_read2_0 = std_reg(32); + bin_read3_0 = std_reg(32); + const0 = std_const(4,0); + const1 = std_const(4,0); + const10 = std_const(4,1); + const11 = std_const(4,1); + const12 = std_const(4,0); + const13 = std_const(4,0); + const14 = std_const(1,0); + const15 = std_const(4,0); + const16 = std_const(4,1); + const17 = std_const(4,1); + const18 = std_const(4,1); + const19 = std_const(4,0); + const2 = std_const(4,1); + const20 = std_const(4,0); + const21 = std_const(4,1); + const22 = std_const(4,1); + const3 = std_const(4,1); + const4 = std_const(4,0); + const5 = std_const(4,0); + 
const6 = std_const(32,0); + const7 = std_const(4,0); + const8 = std_const(1,0); + const9 = std_const(4,1); + d_tmp_0 = std_reg(32); + i0 = std_reg(4); + i00 = std_reg(4); + i01 = std_reg(4); + i10 = std_reg(4); + j0 = std_reg(4); + j00 = std_reg(4); + j01 = std_reg(4); + j10 = std_reg(4); + k0 = std_reg(4); + k10 = std_reg(4); + mult_pipe0 = std_mult_pipe(32); + mult_pipe1 = std_mult_pipe(32); + mult_pipe2 = std_mult_pipe(32); + mult_pipe3 = std_mult_pipe(32); + red_read00 = std_reg(32); + red_read10 = std_reg(32); + tmp0_0 = seq_mem_d2(32,8,8,4,4); + @external(1) tmp_int = seq_mem_d2(32,8,8,4,4); + tmp_int_read0_0 = std_reg(32); + tmp_read0_0 = std_reg(32); + tmp_sh_read0_0 = std_reg(32); + v1_0 = std_reg(32); + v_0 = std_reg(32); + } + wires { + group let0<"promotable"=1> { + i00.in = const0.out; + i00.write_en = 1'd1; + let0[done] = i00.done; + } + group let1<"promotable"=1> { + j00.in = const1.out; + j00.write_en = 1'd1; + let1[done] = j00.done; + } + group let10<"promotable"=2> { + alpha_int_read0_0.in = alpha_int.read_data; + alpha_int_read0_0.write_en = alpha_int.done; + let10[done] = alpha_int_read0_0.done; + alpha_int.content_en = 1'd1; + alpha_int.addr0 = const8.out; + } + group let11<"promotable"=4> { + bin_read0_0.in = mult_pipe0.out; + bin_read0_0.write_en = mult_pipe0.done; + let11[done] = bin_read0_0.done; + mult_pipe0.left = alpha_int_read0_0.out; + mult_pipe0.right = A_read0_0.out; + mult_pipe0.go = !mult_pipe0.done ? 1'd1; + } + group let12<"promotable"=4> { + bin_read1_0.in = mult_pipe1.out; + bin_read1_0.write_en = mult_pipe1.done; + let12[done] = bin_read1_0.done; + mult_pipe1.left = bin_read0_0.out; + mult_pipe1.right = B_read0_0.out; + mult_pipe1.go = !mult_pipe1.done ? 
1'd1; + } + group let13<"promotable"=1> { + v_0.in = bin_read1_0.out; + v_0.write_en = 1'd1; + let13[done] = v_0.done; + } + group let14<"promotable"=2> { + red_read00.in = tmp0_0.read_data; + red_read00.write_en = tmp0_0.done; + let14[done] = red_read00.done; + tmp0_0.content_en = 1'd1; + tmp0_0.addr1 = j0.out; + tmp0_0.addr0 = i0.out; + } + group let15<"promotable"=1> { + i10.in = const12.out; + i10.write_en = 1'd1; + let15[done] = i10.done; + } + group let16<"promotable"=1> { + j10.in = const13.out; + j10.write_en = 1'd1; + let16[done] = j10.done; + } + group let17<"promotable"=2> { + beta_int_read0_0.in = beta_int.read_data; + beta_int_read0_0.write_en = beta_int.done; + let17[done] = beta_int_read0_0.done; + beta_int.content_en = 1'd1; + beta_int.addr0 = const14.out; + } + group let18<"promotable"=4> { + bin_read2_0.in = mult_pipe2.out; + bin_read2_0.write_en = mult_pipe2.done; + let18[done] = bin_read2_0.done; + mult_pipe2.left = beta_int_read0_0.out; + mult_pipe2.right = d_tmp_0.out; + mult_pipe2.go = !mult_pipe2.done ? 1'd1; + } + group let19<"promotable"=1> { + k10.in = const15.out; + k10.write_en = 1'd1; + let19[done] = k10.done; + } + group let2<"promotable"=2> { + A_int_read0_0.in = A_int.read_data; + A_int_read0_0.write_en = A_int.done; + let2[done] = A_int_read0_0.done; + A_int.content_en = 1'd1; + A_int.addr1 = j00.out; + A_int.addr0 = i00.out; + } + group let20<"promotable"=4> { + bin_read3_0.in = mult_pipe3.out; + bin_read3_0.write_en = mult_pipe3.done; + let20[done] = bin_read3_0.done; + mult_pipe3.left = tmp_read0_0.out; + mult_pipe3.right = C_read0_0.out; + mult_pipe3.go = !mult_pipe3.done ? 
1'd1; + } + group let21<"promotable"=1> { + v1_0.in = bin_read3_0.out; + v1_0.write_en = 1'd1; + let21[done] = v1_0.done; + } + group let22<"promotable"=2> { + red_read10.in = D0_0.read_data; + red_read10.write_en = D0_0.done; + let22[done] = red_read10.done; + D0_0.content_en = 1'd1; + D0_0.addr1 = j10.out; + D0_0.addr0 = i10.out; + } + group let23<"promotable"=1> { + i01.in = const19.out; + i01.write_en = 1'd1; + let23[done] = i01.done; + } + group let24<"promotable"=1> { + j01.in = const20.out; + j01.write_en = 1'd1; + let24[done] = j01.done; + } + group let3<"promotable"=2> { + B_int_read0_0.in = B_int.read_data; + B_int_read0_0.write_en = B_int.done; + let3[done] = B_int_read0_0.done; + B_int.content_en = 1'd1; + B_int.addr1 = j00.out; + B_int.addr0 = i00.out; + } + group let4<"promotable"=2> { + C_int_read0_0.in = C_int.read_data; + C_int_read0_0.write_en = C_int.done; + let4[done] = C_int_read0_0.done; + C_int.content_en = 1'd1; + C_int.addr1 = j00.out; + C_int.addr0 = i00.out; + } + group let5<"promotable"=2> { + D_int_read0_0.in = D_int.read_data; + D_int_read0_0.write_en = D_int.done; + let5[done] = D_int_read0_0.done; + D_int.content_en = 1'd1; + D_int.addr1 = j00.out; + D_int.addr0 = i00.out; + } + group let6<"promotable"=2> { + tmp_int_read0_0.in = tmp_int.read_data; + tmp_int_read0_0.write_en = tmp_int.done; + let6[done] = tmp_int_read0_0.done; + tmp_int.content_en = 1'd1; + tmp_int.addr1 = j00.out; + tmp_int.addr0 = i00.out; + } + group let7<"promotable"=1> { + i0.in = const4.out; + i0.write_en = 1'd1; + let7[done] = i0.done; + } + group let8<"promotable"=1> { + j0.in = const5.out; + j0.write_en = 1'd1; + let8[done] = j0.done; + } + group let9<"promotable"=1> { + k0.in = const7.out; + k0.write_en = 1'd1; + let9[done] = k0.done; + } + group upd0<"promotable"=1> { + A0_0.content_en = 1'd1; + A0_0.addr1 = j00.out; + A0_0.addr0 = i00.out; + A0_0.write_en = 1'd1; + A0_0.write_data = A_int_read0_0.out; + upd0[done] = A0_0.done; + } + group 
upd1<"promotable"=1> { + B0_0.content_en = 1'd1; + B0_0.addr1 = j00.out; + B0_0.addr0 = i00.out; + B0_0.write_en = 1'd1; + B0_0.write_data = B_int_read0_0.out; + upd1[done] = B0_0.done; + } + group upd10<"promotable"=1> { + tmp0_0.content_en = 1'd1; + tmp0_0.addr1 = j0.out; + tmp0_0.addr0 = i0.out; + tmp0_0.write_en = 1'd1; + add2.left = red_read00.out; + add2.right = v_0.out; + tmp0_0.write_data = add2.out; + upd10[done] = tmp0_0.done; + } + group upd11<"promotable"=1> { + k0.write_en = 1'd1; + add3.left = k0.out; + add3.right = const9.out; + k0.in = add3.out; + upd11[done] = k0.done; + } + group upd12<"promotable"=1> { + j0.write_en = 1'd1; + add4.left = j0.out; + add4.right = const10.out; + j0.in = add4.out; + upd12[done] = j0.done; + } + group upd13<"promotable"=1> { + i0.write_en = 1'd1; + add5.left = i0.out; + add5.right = const11.out; + i0.in = add5.out; + upd13[done] = i0.done; + } + group upd14<"promotable"=2> { + d_tmp_0.write_en = D0_0.done; + D0_0.content_en = 1'd1; + D0_0.addr1 = j10.out; + D0_0.addr0 = i10.out; + d_tmp_0.in = D0_0.read_data; + upd14[done] = d_tmp_0.done; + } + group upd15<"promotable"=1> { + D0_0.content_en = 1'd1; + D0_0.addr1 = j10.out; + D0_0.addr0 = i10.out; + D0_0.write_en = 1'd1; + D0_0.write_data = bin_read2_0.out; + upd15[done] = D0_0.done; + } + group upd16<"promotable"=2> { + tmp_read0_0.write_en = tmp0_0.done; + tmp0_0.content_en = 1'd1; + tmp0_0.addr1 = k10.out; + tmp0_0.addr0 = i10.out; + tmp_read0_0.in = tmp0_0.read_data; + upd16[done] = tmp_read0_0.done; + } + group upd17<"promotable"=2> { + C_read0_0.write_en = C0_0.done; + C0_0.content_en = 1'd1; + C0_0.addr1 = j10.out; + C0_0.addr0 = k10.out; + C_read0_0.in = C0_0.read_data; + upd17[done] = C_read0_0.done; + } + group upd18<"promotable"=1> { + D0_0.content_en = 1'd1; + D0_0.addr1 = j10.out; + D0_0.addr0 = i10.out; + D0_0.write_en = 1'd1; + add6.left = red_read10.out; + add6.right = v1_0.out; + D0_0.write_data = add6.out; + upd18[done] = D0_0.done; + } + group 
upd19<"promotable"=1> { + k10.write_en = 1'd1; + add7.left = k10.out; + add7.right = const16.out; + k10.in = add7.out; + upd19[done] = k10.done; + } + group upd2<"promotable"=1> { + C0_0.content_en = 1'd1; + C0_0.addr1 = j00.out; + C0_0.addr0 = i00.out; + C0_0.write_en = 1'd1; + C0_0.write_data = C_int_read0_0.out; + upd2[done] = C0_0.done; + } + group upd20<"promotable"=1> { + j10.write_en = 1'd1; + add8.left = j10.out; + add8.right = const17.out; + j10.in = add8.out; + upd20[done] = j10.done; + } + group upd21<"promotable"=1> { + i10.write_en = 1'd1; + add9.left = i10.out; + add9.right = const18.out; + i10.in = add9.out; + upd21[done] = i10.done; + } + group upd22<"promotable"=2> { + A_sh_read0_0.write_en = A0_0.done; + A0_0.content_en = 1'd1; + A0_0.addr1 = j01.out; + A0_0.addr0 = i01.out; + A_sh_read0_0.in = A0_0.read_data; + upd22[done] = A_sh_read0_0.done; + } + group upd23<"promotable"=1> { + A_int.content_en = 1'd1; + A_int.addr1 = j01.out; + A_int.addr0 = i01.out; + A_int.write_en = 1'd1; + A_int.write_data = A_sh_read0_0.out; + upd23[done] = A_int.done; + } + group upd24<"promotable"=2> { + B_sh_read0_0.write_en = B0_0.done; + B0_0.content_en = 1'd1; + B0_0.addr1 = j01.out; + B0_0.addr0 = i01.out; + B_sh_read0_0.in = B0_0.read_data; + upd24[done] = B_sh_read0_0.done; + } + group upd25<"promotable"=1> { + B_int.content_en = 1'd1; + B_int.addr1 = j01.out; + B_int.addr0 = i01.out; + B_int.write_en = 1'd1; + B_int.write_data = B_sh_read0_0.out; + upd25[done] = B_int.done; + } + group upd26<"promotable"=2> { + C_sh_read0_0.write_en = C0_0.done; + C0_0.content_en = 1'd1; + C0_0.addr1 = j01.out; + C0_0.addr0 = i01.out; + C_sh_read0_0.in = C0_0.read_data; + upd26[done] = C_sh_read0_0.done; + } + group upd27<"promotable"=1> { + C_int.content_en = 1'd1; + C_int.addr1 = j01.out; + C_int.addr0 = i01.out; + C_int.write_en = 1'd1; + C_int.write_data = C_sh_read0_0.out; + upd27[done] = C_int.done; + } + group upd28<"promotable"=2> { + D_sh_read0_0.write_en = D0_0.done; 
+ D0_0.content_en = 1'd1; + D0_0.addr1 = j01.out; + D0_0.addr0 = i01.out; + D_sh_read0_0.in = D0_0.read_data; + upd28[done] = D_sh_read0_0.done; + } + group upd29<"promotable"=1> { + D_int.content_en = 1'd1; + D_int.addr1 = j01.out; + D_int.addr0 = i01.out; + D_int.write_en = 1'd1; + D_int.write_data = D_sh_read0_0.out; + upd29[done] = D_int.done; + } + group upd3<"promotable"=1> { + D0_0.content_en = 1'd1; + D0_0.addr1 = j00.out; + D0_0.addr0 = i00.out; + D0_0.write_en = 1'd1; + D0_0.write_data = D_int_read0_0.out; + upd3[done] = D0_0.done; + } + group upd30<"promotable"=2> { + tmp_sh_read0_0.write_en = tmp0_0.done; + tmp0_0.content_en = 1'd1; + tmp0_0.addr1 = j01.out; + tmp0_0.addr0 = i01.out; + tmp_sh_read0_0.in = tmp0_0.read_data; + upd30[done] = tmp_sh_read0_0.done; + } + group upd31<"promotable"=1> { + tmp_int.content_en = 1'd1; + tmp_int.addr1 = j01.out; + tmp_int.addr0 = i01.out; + tmp_int.write_en = 1'd1; + tmp_int.write_data = tmp_sh_read0_0.out; + upd31[done] = tmp_int.done; + } + group upd32<"promotable"=1> { + j01.write_en = 1'd1; + add10.left = j01.out; + add10.right = const21.out; + j01.in = add10.out; + upd32[done] = j01.done; + } + group upd33<"promotable"=1> { + i01.write_en = 1'd1; + add11.left = i01.out; + add11.right = const22.out; + i01.in = add11.out; + upd33[done] = i01.done; + } + group upd4<"promotable"=1> { + tmp0_0.content_en = 1'd1; + tmp0_0.addr1 = j00.out; + tmp0_0.addr0 = i00.out; + tmp0_0.write_en = 1'd1; + tmp0_0.write_data = tmp_int_read0_0.out; + upd4[done] = tmp0_0.done; + } + group upd5<"promotable"=1> { + j00.write_en = 1'd1; + add0.left = j00.out; + add0.right = const2.out; + j00.in = add0.out; + upd5[done] = j00.done; + } + group upd6<"promotable"=1> { + i00.write_en = 1'd1; + add1.left = i00.out; + add1.right = const3.out; + i00.in = add1.out; + upd6[done] = i00.done; + } + group upd7<"promotable"=1> { + tmp0_0.content_en = 1'd1; + tmp0_0.addr1 = j0.out; + tmp0_0.addr0 = i0.out; + tmp0_0.write_en = 1'd1; + tmp0_0.write_data 
= const6.out; + upd7[done] = tmp0_0.done; + } + group upd8<"promotable"=2> { + A_read0_0.write_en = A0_0.done; + A0_0.content_en = 1'd1; + A0_0.addr1 = k0.out; + A0_0.addr0 = i0.out; + A_read0_0.in = A0_0.read_data; + upd8[done] = A_read0_0.done; + } + group upd9<"promotable"=2> { + B_read0_0.write_en = B0_0.done; + B0_0.content_en = 1'd1; + B0_0.addr1 = j0.out; + B0_0.addr0 = k0.out; + B_read0_0.in = B0_0.read_data; + upd9[done] = B_read0_0.done; + } + } + control { + seq { + @pos(0) let0; + repeat 8 { + seq { + @pos(1) let1; + repeat 8 { + seq { + @pos(2) let2; + par { + @pos(3) upd0; + @pos(4) let3; + } + par { + @pos(5) upd1; + @pos(6) let4; + } + par { + @pos(7) upd2; + @pos(8) let5; + } + par { + @pos(9) upd3; + @pos(10) let6; + } + @pos(11) upd4; + @pos(1) upd5; + } + } + @pos(0) upd6; + } + } + @pos(12) let7; + repeat 8 { + seq { + @pos(13) let8; + repeat 8 { + seq { + @pos(14) upd7; + @pos(15) let9; + repeat 8 { + seq { + par { + @pos(16) let10; + @pos(17) upd8; + @pos(18) upd9; + } + let11; + let12; + let13; + let14; + upd10; + @pos(15) upd11; + } + } + @pos(13) upd12; + } + } + @pos(12) upd13; + } + } + @pos(19) let15; + repeat 8 { + seq { + @pos(20) let16; + repeat 8 { + seq { + @pos(21) upd14; + @pos(22) let17; + let18; + upd15; + @pos(23) let19; + repeat 8 { + seq { + par { + @pos(24) upd16; + @pos(25) upd17; + } + let20; + let21; + let22; + upd18; + @pos(23) upd19; + } + } + @pos(20) upd20; + } + } + @pos(19) upd21; + } + } + @pos(26) let23; + repeat 8 { + seq { + @pos(27) let24; + repeat 8 { + seq { + @pos(28) upd22; + par { + @pos(29) upd23; + @pos(30) upd24; + } + par { + @pos(31) upd25; + @pos(32) upd26; + } + par { + @pos(33) upd27; + @pos(34) upd28; + } + par { + @pos(35) upd29; + @pos(36) upd30; + } + @pos(37) upd31; + @pos(27) upd32; + } + } + @pos(26) upd33; + } + } + } + } +} +metadata #{ + 0: for (let i0: ubit<4> = 0..8) { + 1: for (let j0: ubit<4> = 0..8) { + 2: A_sh[i0][j0] := A_int[i0][j0]; + 3: A_sh[i0][j0] := A_int[i0][j0]; + 4: 
B_sh[i0][j0] := B_int[i0][j0]; + 5: B_sh[i0][j0] := B_int[i0][j0]; + 6: C_sh[i0][j0] := C_int[i0][j0]; + 7: C_sh[i0][j0] := C_int[i0][j0]; + 8: D_sh[i0][j0] := D_int[i0][j0]; + 9: D_sh[i0][j0] := D_int[i0][j0]; + 10: tmp_sh[i0][j0] := tmp_int[i0][j0]; + 11: tmp_sh[i0][j0] := tmp_int[i0][j0]; + 12: for (let i: ubit<4> = 0..8) { + 13: for (let j: ubit<4> = 0..8) { + 14: tmp[i][j] := 0; + 15: for (let k: ubit<4> = 0..8) { + 16: let v: ubit<32> = alpha_int[0] * A[i][k] * B[k][j]; + 17: let v: ubit<32> = alpha_int[0] * A[i][k] * B[k][j]; + 18: let v: ubit<32> = alpha_int[0] * A[i][k] * B[k][j]; + 19: for (let i1: ubit<4> = 0..8) { + 20: for (let j1: ubit<4> = 0..8) { + 21: let d_tmp: ubit<32> = D[i1][j1]; + 22: D[i1][j1] := beta_int[0] * d_tmp; + 23: for (let k1: ubit<4> = 0..8) { + 24: let v1: ubit<32> = tmp[i1][k1] * C[k1][j1]; + 25: let v1: ubit<32> = tmp[i1][k1] * C[k1][j1]; + 26: for (let i0: ubit<4> = 0..8) { + 27: for (let j0: ubit<4> = 0..8) { + 28: A_int[i0][j0] := A_sh[i0][j0]; + 29: A_int[i0][j0] := A_sh[i0][j0]; + 30: B_int[i0][j0] := B_sh[i0][j0]; + 31: B_int[i0][j0] := B_sh[i0][j0]; + 32: C_int[i0][j0] := C_sh[i0][j0]; + 33: C_int[i0][j0] := C_sh[i0][j0]; + 34: D_int[i0][j0] := D_sh[i0][j0]; + 35: D_int[i0][j0] := D_sh[i0][j0]; + 36: tmp_int[i0][j0] := tmp_sh[i0][j0]; + 37: tmp_int[i0][j0] := tmp_sh[i0][j0]; +}# diff --git a/src/main/scala/backends/calyx/Ast.scala b/src/main/scala/backends/calyx/Ast.scala index da19bdc1..563ba715 100644 --- a/src/main/scala/backends/calyx/Ast.scala +++ b/src/main/scala/backends/calyx/Ast.scala @@ -27,26 +27,24 @@ object Calyx: vsep( this.map.toSeq .sortBy(_._2) - .map({ - case (pos, c) => - text(c.toString()) <> text(":") <+> text( - pos.longString.split("\n")(0) - ) + .map({ case (pos, c) => + text(c.toString()) <> text(":") <+> text( + pos.longString.split("\n")(0) + ) }) ), left = text("#") <> lbrace, right = rbrace <> text("#") ) - private def emitPos(pos: Position, @annotation.unused span: Int)( - implicit meta: 
Metadata + private def emitPos(pos: Position, @annotation.unused span: Int)(implicit + meta: Metadata ): Doc = // Add position information to the metadata. if pos.line != 0 && pos.column != 0 then val count = meta.addPos(pos) text("@pos") <> parens(text(count.toString)) <> space - else - emptyDoc + else emptyDoc /* (if (pos.line == 0 && pos.column == 0) { emptyDoc } else { @@ -74,7 +72,7 @@ object Calyx: def doc(): Doc def emit(): String = this.doc().pretty - /** A variable representing the name of a component. **/ + /** A variable representing the name of a component. * */ case class CompVar(name: String) extends Emitable with Ordered[CompVar]: override def doc(): Doc = text(name) def port(port: String): CompPort = CompPort(this, port) @@ -87,12 +85,13 @@ object Calyx: attrs: List[(String, Int)] = List() ) extends Emitable: override def doc(): Doc = - val attrDoc = hsep(attrs.map({ - case (attr, v) => text(s"@${attr}") <> parens(text(v.toString())) + val attrDoc = hsep(attrs.map({ case (attr, v) => + text(s"@${attr}") <> parens(text(v.toString())) })) <> (if attrs.isEmpty then emptyDoc else space) attrDoc <> id.doc() <> colon <+> value(width) - /**** definition statements *****/ + /** ** definition statements **** + */ case class Namespace(name: String, comps: List[NamespaceStatement]): def doc(implicit meta: Metadata): Doc = vsep(comps.map(_.doc)) @@ -123,7 +122,8 @@ object Calyx: control: Control ) extends NamespaceStatement - /***** structure *****/ + /** *** structure **** + */ sealed trait Port extends Emitable with Ordered[Port]: override def doc(): Doc = this match case CompPort(id, name) => @@ -166,9 +166,8 @@ object Calyx: val attrDoc = hsep( attrs - .map({ - case (attr, v) => - text("@") <> text(attr) <> parens(text(v.toString())) + .map({ case (attr, v) => + text("@") <> text(attr) <> parens(text(v.toString())) }) ) <> (if attrs.isEmpty then emptyDoc else space) @@ -183,7 +182,9 @@ object Calyx: (if comb then text("comb ") else emptyDoc) <> 
text("group") <+> id.doc() <> (if delay.isDefined then - angles(text("\"promotable\"") <> equal <> text(delay.get.toString())) + angles( + text("\"promotable\"") <> equal <> text(delay.get.toString()) + ) else emptyDoc) <+> scope(vsep(conns.map(_.doc()))) @@ -194,10 +195,8 @@ object Calyx: case (Group(thisId, _, _, _), Group(thatId, _, _, _)) => thisId.compare(thatId) case (Assign(thisSrc, thisDest, _), Assign(thatSrc, thatDest, _)) => { - if thisSrc.compare(thatSrc) == 0 then - thisDest.compare(thatDest) - else - thisSrc.compare(thatSrc) + if thisSrc.compare(thatSrc) == 0 then thisDest.compare(thatDest) + else thisSrc.compare(thatSrc) } case (_: Cell, _) => -1 case (_, _: Cell) => 1 @@ -263,7 +262,8 @@ object Calyx: case class Not(inner: GuardExpr) extends GuardExpr case object True extends GuardExpr - /***** control *****/ + /** *** control **** + */ sealed trait Control: var attributes = Map[String, Int]() @@ -284,12 +284,10 @@ object Calyx: case _ => ParComp(List(this, c)) def attributesDoc(): Doc = - if this.attributes.isEmpty then - emptyDoc + if this.attributes.isEmpty then emptyDoc else - hsep(attributes.map({ - case (attr, v) => - text(s"@$attr") <> parens(text(v.toString())) + hsep(attributes.map({ case (attr, v) => + text(s"@$attr") <> parens(text(v.toString())) })) <> space def doc(implicit meta: Metadata): Doc = @@ -302,18 +300,16 @@ object Calyx: text("if") <+> port.doc() <+> text("with") <+> cond.doc() <+> scope(trueBr.doc) <> ( - if falseBr == Empty then - emptyDoc - else - space <> text("else") <+> scope(falseBr.doc) - ) + if falseBr == Empty then emptyDoc + else space <> text("else") <+> scope(falseBr.doc) + ) case While(port, cond, body) => { text("while") <+> port.doc() <+> text("with") <+> cond.doc() <+> scope(body.doc(meta)) } - case While(count, body) => { - text("repeat") <+> count.toString <+> + case Repeat(count, body) => { + text("repeat") <+> text(count.toString) <+> scope(body.doc(meta)) } case e @ Enable(id) => { @@ -321,17 +317,16 @@ 
object Calyx: } case i @ Invoke(id, refCells, inConnects, outConnects) => { val cells = - if refCells.isEmpty then - emptyDoc + if refCells.isEmpty then emptyDoc else - brackets(commaSep(refCells.map({ - case (param, cell) => text(param) <> equal <> cell.doc() + brackets(commaSep(refCells.map({ case (param, cell) => + text(param) <> equal <> cell.doc() }))) - val inputDefs = inConnects.map({ - case (param, arg) => text(param) <> equal <> arg.doc() + val inputDefs = inConnects.map({ case (param, arg) => + text(param) <> equal <> arg.doc() }) - val outputDefs = outConnects.map({ - case (param, arg) => text(param) <> equal <> arg.doc() + val outputDefs = outConnects.map({ case (param, arg) => + text(param) <> equal <> arg.doc() }) emitPos(i.pos, i.span) <> text("invoke") <+> id.doc() <> cells <> diff --git a/src/main/scala/backends/calyx/Backend.scala b/src/main/scala/backends/calyx/Backend.scala index 3088bcff..b836f59c 100644 --- a/src/main/scala/backends/calyx/Backend.scala +++ b/src/main/scala/backends/calyx/Backend.scala @@ -12,21 +12,20 @@ import fuselang.common.{Configuration => C} import Helpers._ -/** - * Helper class that gives names to the fields of the output of `emitExpr` and - * `emitBinop`. - * - `port` holds either an input or output port that represents how data - * flows between expressions. - * - `done` holds the port that signals when the writing or reading from `port` - * is done. - * - `structure` represents additional structure involved in computing the - * expression. - * - `delay` is the static delay required to complete the structure within - * the emitted output. - * - `multiCycleInfo` is the variable and delay of the of the op that requires - * multiple cycles to complete. This is necessary for the case when a - * `write_en` signal should not be high until the op is `done`. If this is - * None, then the emitted output has no multi-cycle ops. +/** Helper class that gives names to the fields of the output of `emitExpr` and + * `emitBinop`. 
+ * - `port` holds either an input or output port that represents how data + * flows between expressions. + * - `done` holds the port that signals when the writing or reading from + * `port` is done. + * - `structure` represents additional structure involved in computing the + * expression. + * - `delay` is the static delay required to complete the structure within + * the emitted output. + * - `multiCycleInfo` is the variable and delay of the op that + * requires multiple cycles to complete. This is necessary for the case + * when a `write_en` signal should not be high until the op is `done`. If + * this is None, then the emitted output has no multi-cycle ops. */ private case class EmitOutput( // The port that contains the output from the operation @@ -41,31 +40,29 @@ private case class EmitOutput( val multiCycleInfo: Option[(Port, Option[Int])] ) -/** - * CALLING CONVENTION: - * The backend supports functions using Calyx's component definitions. - * For a function: +/** CALLING CONVENTION: The backend supports functions using Calyx's component + * definitions. For a function: * ``` * def id(x: ubit<32>): ubit<32> = { * let out: ubit<32> = x; * return out; * } * ``` - * The function generates a port named unique (`out`). - * Values returned by the method are carried on this port. + * The function generates a port named unique (`out`). Values returned by the + * method are carried on this port. * * Calls are transformed into `invoke` statements in Calyx. Uses of the - * returned value from the function are assumed to available after the - * `invoke` statement. + * returned value from the function are assumed to be available after the `invoke` + * statement.
* * The `out` port is marked using the "stable" attributed which is verified - * by the Calyx compiler to enable such uses: + * The `out` port is marked using the "stable" attribute, which is verified by + * the Calyx compiler to enable such uses: * https://github.com/cucapra/Calyx/issues/304 */ private class CalyxBackendHelper { - /** A list of function IDs that require width arguments - * in their SystemVerilog module definition. + /** A list of function IDs that require width arguments in their SystemVerilog + * module definition. */ val requiresWidthArguments = List("fp_sqrt", "sqrt") @@ -80,15 +77,15 @@ private class CalyxBackendHelper { CompVar(s"$base${idx(base)}") } - /** A Calyx variable will either be a - * local variable (LocalVar) or - * a function parameter (ParameterVar). */ + /** A Calyx variable will either be a local variable (LocalVar) or a function + * parameter (ParameterVar). + */ sealed trait VType case object LocalVar extends VType case object ParameterVar extends VType - /** Store mappings from Dahlia variables to - * generated Calyx variables. */ + /** Store mappings from Dahlia variables to generated Calyx variables. + */ type Store = Map[CompVar, (CompVar, VType)] /** Mappings from Function Id to Function Definition. */ @@ -148,9 +145,9 @@ private class CalyxBackendHelper { } } - /** Returns a list of tuples (name, width) for each address port - in a memory. For example, a D1 Memory declared as (32, 1, 1) - would return List[("addr0", 1)]. + /** Returns a list of tuples (name, width) for each address port in a memory. + * For example, a D1 Memory declared as (32, 1, 1) would return + * List[("addr0", 1)]. */ def getAddrPortToWidths(typ: TArray, id: Id): List[(String, BigInt)] = { // Emit the array to determine the port widths.
@@ -169,23 +166,24 @@ private class CalyxBackendHelper { } val addressIndices = (dims + 1 to dims << 1).toList - addressIndices.zipWithIndex.map({ - case (n: Int, i: Int) => (s"addr${i}", arrayArgs(n)) + addressIndices.zipWithIndex.map({ case (n: Int, i: Int) => + (s"addr${i}", arrayArgs(n)) }) } /** Returns the width argument(s) of a given function, based on the return * type of the function. This is necessary because some components may - * require certain module parameters in SystemVerilog. For example, - * `foo` with SystemVerilog module definition: + * require certain module parameters in SystemVerilog. For example, `foo` + * with SystemVerilog module definition: * ``` * module foo #( * parameter WIDTH * ) ( ... ); * ``` - * Requires that a `WIDTH` be provided. Currently, the functions that - * do require these parameters must be manually added to the list - * `requiresWidthArguments`. */ + * Requires that a `WIDTH` be provided. Currently, the functions that do + * require these parameters must be manually added to the list + * `requiresWidthArguments`. + */ def getCompInstArgs( funcId: Id )(implicit id2FuncDef: FunctionMapping): List[BigInt] = { @@ -203,9 +201,11 @@ private class CalyxBackendHelper { } } - /** `emitInvokeDecl` computes the necessary structure and control for Syntax.EApp. */ - def emitInvokeDecl(app: EApp)( - implicit store: Store, + /** `emitInvokeDecl` computes the necessary structure and control for + * Syntax.EApp. 
+ */ + def emitInvokeDecl(app: EApp)(implicit + store: Store, id2FuncDef: FunctionMapping ): (Cell, Seq[Structure], Control) = { val functionName = app.func.toString() @@ -225,14 +225,13 @@ private class CalyxBackendHelper { val (refCells, inConnects) = id2FuncDef(app.func).args .zip(argPorts) - .partitionMap({ - case (param, v) => - param.typ match { - case _: TArray => { - Left((param.id.v, CompVar(getPortName(v)))) - } - case _ => Right((param.id.v, v)) + .partitionMap({ case (param, v) => + param.typ match { + case _: TArray => { + Left((param.id.v, CompVar(getPortName(v)))) } + case _ => Right((param.id.v, v)) + } }) ( @@ -242,8 +241,8 @@ private class CalyxBackendHelper { ) } - /** `emitDecl(d)` computes the structure that is needed to - * represent the declaration `d`. Simply returns a `List[Structure]`. + /** `emitDecl(d)` computes the structure that is needed to represent the + * declaration `d`. Simply returns a `Structure`. */ def emitDecl(d: Decl): Structure = d.typ match { case tarr: TArray => emitArrayDecl(tarr, d.id, List("external" -> 1)) @@ -253,11 +252,11 @@ private class CalyxBackendHelper { case x => throw NotImplemented(s"Type $x not implemented for decls.", x.pos) } - /** `emitBinop` is a helper function to generate the structure - * for `e1 binop e2`. The return type is described in `emitExpr`. + /** `emitBinop` is a helper function to generate the structure for `e1 binop + * e2`. The return type is described in `emitExpr`. 
*/ - def emitBinop(compName: String, e1: Expr, e2: Expr)( - implicit store: Store + def emitBinop(compName: String, e1: Expr, e2: Expr)(implicit + store: Store ): EmitOutput = { val e1Out = emitExpr(e1) val e2Out = emitExpr(e2) @@ -300,14 +299,14 @@ private class CalyxBackendHelper { None, struct ++ e1Out.structure ++ e2Out.structure, for d1 <- e1Out.delay; d2 <- e2Out.delay - yield d1 + d2, + yield d1 + d2, None ) } // if there is additional information about the integer bit, // use fixed point binary operation case (e1Bits, Some(intBit1)) => { - val (e2Bits, Some(intBit2)) = bitsForType(e2.typ, e2.pos) : @unchecked + val (e2Bits, Some(intBit2)) = bitsForType(e2.typ, e2.pos): @unchecked val fracBit1 = e1Bits - intBit1 val fracBit2 = e2Bits - intBit2 val isSigned = signed(e1.typ) @@ -337,7 +336,7 @@ private class CalyxBackendHelper { None, struct ++ e1Out.structure ++ e2Out.structure, for d1 <- e1Out.delay; d2 <- e2Out.delay - yield d1 + d2, + yield d1 + d2, None ) } @@ -350,8 +349,8 @@ private class CalyxBackendHelper { e2: Expr, outPort: String, delay: Option[Int] - )( - implicit store: Store + )(implicit + store: Store ): EmitOutput = { val e1Out = emitExpr(e1) val e2Out = emitExpr(e2) @@ -416,19 +415,19 @@ private class CalyxBackendHelper { Some(comp.name.port("done")), struct ++ e1Out.structure ++ e2Out.structure, for d1 <- e1Out.delay; d2 <- e2Out.delay; d3 <- delay - yield d1 + d2 + d3, + yield d1 + d2 + d3, Some((comp.name.port("done"), delay)) ) } - /** `emitExpr(expr, rhsInfo)(implicit store)` calculates the necessary structure - * to compute `expr`. - * - If rhsInfo is defined then this expression is an LHS. rhsInfo contains - * (done, delay) information for the RHS being written to this LHS. - * - Otherwise, this is an RHS expression. + /** `emitExpr(expr, rhsInfo)(implicit store)` calculates the necessary + * structure to compute `expr`. + * - If rhsInfo is defined then this expression is an LHS. 
rhsInfo contains + * (done, delay) information for the RHS being written to this LHS. + * - Otherwise, this is an RHS expression. */ - def emitExpr(expr: Expr, rhsInfo: Option[(Port, Option[Int])] = None)( - implicit store: Store + def emitExpr(expr: Expr, rhsInfo: Option[(Port, Option[Int])] = None)(implicit + store: Store ): EmitOutput = expr match { case _: EInt => { @@ -560,7 +559,8 @@ private class CalyxBackendHelper { // Cast ERational to Fixed Point. case ECast(ERational(value), typ) => { val _ = rhsInfo - val (width, Some(intWidth)) = bitsForType(Some(typ), expr.pos) : @unchecked + val (width, Some(intWidth)) = + bitsForType(Some(typ), expr.pos): @unchecked val fracWidth = width - intWidth // Interpret as an integer. val isNegative = value.startsWith("-") @@ -683,13 +683,13 @@ private class CalyxBackendHelper { }) // set ContentEn to 1'd1 - val contentEnStruct = List(Assign(ConstantPort(1,1), contentEnPort)) + val contentEnStruct = List(Assign(ConstantPort(1, 1), contentEnPort)) // Set write_en to 1'd0 for reads, to port for writes. val writeEnStruct = rhsInfo match { case Some((port, _)) => List(Assign(port, writeEnPort)) - case None => List(Assign(ConstantPort(1,0), writeEnPort)) + case None => List(Assign(ConstantPort(1, 0), writeEnPort)) } val delay = (rhsInfo) match { @@ -700,7 +700,9 @@ private class CalyxBackendHelper { EmitOutput( accessPort, Some(donePort), - contentEnStruct ++ (indexing ++ (if rhsInfo.isDefined then writeEnStruct else List())), + contentEnStruct ++ (indexing ++ (if rhsInfo.isDefined then + writeEnStruct + else List())), delay, Some((donePort, delay)) ) @@ -724,8 +726,8 @@ private class CalyxBackendHelper { def emitCmd( c: Command - )( - implicit store: Store, + )(implicit + store: Store, id2FuncDef: FunctionMapping ): (List[Structure], Control, Store) = { c match { @@ -769,7 +771,8 @@ private class CalyxBackendHelper { // into a register. 
val reg = Stdlib.register(genName(s"$id"), typ_b) val groupName = genName("let") - val doneHole = Assign(reg.name.port("done"), HolePort(groupName, "done")) + val doneHole = + Assign(reg.name.port("done"), HolePort(groupName, "done")) val struct = List( @@ -936,28 +939,31 @@ private class CalyxBackendHelper { } case CEmpty => (List(), Empty, store) case wh @ CWhile(cond, _, body) => { - if wh.attributes.contains("bound") then { - val (bodyStruct, bodyCon, st) = emitCmd(body) - val control = Repeat(10, bodyCon) - (bodyStruct, control, st) - } else { - val condOut = emitExpr(cond) - val groupName = genName("cond") - assertOrThrow( - !condOut.done.isDefined, - BackendError("Loop condition is non-combinational") - ) - val (condGroup, condDefs) = - Group.fromStructure( - groupName, - condOut.structure, - condOut.delay, - true + wh.attributes.get("bound") match { + case Some(count) => { + val (bodyStruct, bodyCon, st) = emitCmd(body) + val control = Repeat(count, bodyCon) + (bodyStruct, control, st) + } + case None => { + val condOut = emitExpr(cond) + val groupName = genName("cond") + assertOrThrow( + !condOut.done.isDefined, + BackendError("Loop condition is non-combinational") ) - val (bodyStruct, bodyCon, st) = emitCmd(body) - val control = While(condOut.port, condGroup.id, bodyCon) - control.attributes = wh.attributes - (condGroup :: bodyStruct ++ condDefs, control, st) + val (condGroup, condDefs) = + Group.fromStructure( + groupName, + condOut.structure, + condOut.delay, + true + ) + val (bodyStruct, bodyCon, st) = emitCmd(body) + val control = While(condOut.port, condGroup.id, bodyCon) + control.attributes = wh.attributes + (condGroup :: bodyStruct ++ condDefs, control, st) + } } } case _: CFor => @@ -998,7 +1004,7 @@ private class CalyxBackendHelper { ) } } - case CReturn(expr:EVar) => { + case CReturn(expr: EVar) => { // Hooks the output port of the emitted `expr` to PortDef `out` of the component. 
val condOut = emitExpr(expr) val outPort = ThisPort(CompVar("out")) @@ -1008,7 +1014,7 @@ private class CalyxBackendHelper { case CReturn(e) => { throw NotImplemented( s"Only allowed to return variables. Store the return expression in a variable to return it." - + s"e.g. `let _tmp = ${Pretty.emitExpr(e)(false).pretty}; return _tmp`", + + s"e.g. `let _tmp = ${Pretty.emitExpr(e)(false).pretty}; return _tmp`", e.pos ) } @@ -1047,7 +1053,9 @@ private class CalyxBackendHelper { ) val functionDefinitions: List[Component] = - for ( case (id, FuncDef(_, params, retType, Some(body))) <- id2FuncDef.toList ) + for ( + case (id, FuncDef(_, params, retType, Some(body))) <- id2FuncDef.toList + ) yield { val (refCells, inputPorts) = params.partitionMap(param => param.typ match { @@ -1084,7 +1092,8 @@ private class CalyxBackendHelper { else List( PortDef(CompVar("out"), outputBitWidth, List(("stable" -> 1))) - ), + ) + , refCells.toList ++ cmdStructure.sorted, controls ) @@ -1092,7 +1101,7 @@ private class CalyxBackendHelper { val imports = Import("primitives/core.futil") :: - Import("primitives/memories/seq.futil") :: + Import("primitives/memories/seq.futil") :: Import("primitives/binary_operators.futil") :: p.includes.flatMap(_.backends.get(C.Calyx)).map(i => Import(i)).toList