Skip to content

Commit

Permalink
[ARC][Seq] Switch arc over to use seq.clock (#6054)
Browse files Browse the repository at this point in the history
  • Loading branch information
nandor authored Sep 6, 2023
1 parent 326e507 commit d345d1f
Show file tree
Hide file tree
Showing 23 changed files with 138 additions and 114 deletions.
1 change: 1 addition & 0 deletions include/circt/Dialect/Arc/ArcOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

#include "circt/Dialect/Arc/ArcDialect.h"
#include "circt/Dialect/Arc/ArcTypes.h"
#include "circt/Dialect/Seq/SeqTypes.h"

#include "circt/Dialect/Arc/ArcInterfaces.h.inc"

Expand Down
13 changes: 7 additions & 6 deletions include/circt/Dialect/Arc/ArcOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

include "circt/Dialect/Arc/ArcDialect.td"
include "circt/Dialect/Arc/ArcTypes.td"
include "circt/Dialect/Seq/SeqTypes.td"
include "circt/Dialect/Arc/ArcInterfaces.td"
include "mlir/IR/EnumAttr.td"
include "mlir/IR/FunctionInterfaces.td"
Expand Down Expand Up @@ -53,7 +54,7 @@ def DefineOp : ArcOp<"define", [
build($_builder, $_state, sym_name, function_type, mlir::ArrayAttr(),
mlir::ArrayAttr());
}]>,
OpBuilder<(ins "mlir::StringRef":$sym_name,
OpBuilder<(ins "mlir::StringRef":$sym_name,
"mlir::FunctionType":$function_type), [{
build($_builder, $_state, sym_name, function_type, mlir::ArrayAttr(),
mlir::ArrayAttr());
Expand Down Expand Up @@ -141,7 +142,7 @@ def StateOp : ArcOp<"state", [

let arguments = (ins
FlatSymbolRefAttr:$arc,
Optional<I1>:$clock,
Optional<ClockType>:$clock,
Optional<I1>:$enable,
Optional<I1>:$reset,
I32Attr:$latency,
Expand Down Expand Up @@ -265,8 +266,8 @@ def CallOp : ArcOp<"call", [

def ClockGateOp : ArcOp<"clock_gate", [Pure]> {
let summary = "Clock gate";
let arguments = (ins I1:$input, I1:$enable);
let results = (outs I1:$output);
let arguments = (ins ClockType:$input, I1:$enable);
let results = (outs ClockType:$output);
let assemblyFormat = [{
$input `,` $enable attr-dict
}];
Expand Down Expand Up @@ -322,7 +323,7 @@ def MemoryWritePortOp : ArcOp<"memory_write_port", [
MemoryType:$memory,
FlatSymbolRefAttr:$arc,
Variadic<AnyType>:$inputs,
Optional<I1>:$clock,
Optional<ClockType>:$clock,
UnitAttr:$enable,
UnitAttr:$mask,
DefaultValuedAttr<I32Attr, "1">:$latency
Expand Down Expand Up @@ -412,7 +413,7 @@ def ClockDomainOp : ArcOp<"clock_domain", [
]> {
let summary = "a clock domain";

let arguments = (ins Variadic<AnyType>:$inputs, I1:$clock);
let arguments = (ins Variadic<AnyType>:$inputs, ClockType:$clock);
let results = (outs Variadic<AnyType>:$outputs);
let regions = (region SizedRegion<1>:$body);

Expand Down
20 changes: 16 additions & 4 deletions lib/Conversion/SeqToSV/FirMemLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ FirMemLowering::createMemoryModule(FirMemConfig &mem,
SmallVector<hw::PortInfo> ports;

// Common types used for memory ports.
Type clkType = ClockType::get(context);
Type bitType = IntegerType::get(context, 1);
Type dataType = IntegerType::get(context, std::max((size_t)1, mem.dataWidth));
Type maskType = IntegerType::get(context, mem.maskBits);
Expand Down Expand Up @@ -226,7 +227,7 @@ FirMemLowering::createMemoryModule(FirMemConfig &mem,
auto addCommonPorts = [&](StringRef prefix, size_t idx) {
addInput(prefix, idx, "_addr", addrType);
addInput(prefix, idx, "_en", bitType);
addInput(prefix, idx, "_clk", bitType);
addInput(prefix, idx, "_clk", clkType);
};

// Add the read ports.
Expand Down Expand Up @@ -317,6 +318,17 @@ void FirMemLowering::lowerMemoriesInModule(
};
auto valueOrOne = [&](Value value) { return value ? value : constOne(); };

DenseMap<Value, Value> clocks;
auto mapClock = [&](Value clock) {
auto it = clocks.try_emplace(clock, Value{});
if (it.second) {
ImplicitLocOpBuilder builder(clock.getLoc(), clock.getContext());
builder.setInsertionPointAfterValue(clock);
it.first->second = builder.createOrFold<seq::ToClockOp>(clock);
}
return it.first->second;
};

for (auto [config, genOp, memOp] : mems) {
LLVM_DEBUG(llvm::dbgs() << "- Lowering " << memOp.getName() << "\n");
SmallVector<Value> inputs;
Expand All @@ -332,7 +344,7 @@ void FirMemLowering::lowerMemoriesInModule(
continue;
addInput(port.getAddress());
addInput(valueOrOne(port.getEnable()));
addInput(port.getClock());
addInput(mapClock(port.getClock()));
addOutput(port.getData());
}

Expand All @@ -343,7 +355,7 @@ void FirMemLowering::lowerMemoriesInModule(
continue;
addInput(port.getAddress());
addInput(valueOrOne(port.getEnable()));
addInput(port.getClock());
addInput(mapClock(port.getClock()));
addInput(port.getMode());
addInput(port.getWriteData());
addOutput(port.getReadData());
Expand All @@ -358,7 +370,7 @@ void FirMemLowering::lowerMemoriesInModule(
continue;
addInput(port.getAddress());
addInput(valueOrOne(port.getEnable()));
addInput(port.getClock());
addInput(mapClock(port.getClock()));
addInput(port.getData());
if (config->maskBits > 1)
addInput(valueOrOne(port.getMask()));
Expand Down
1 change: 1 addition & 0 deletions lib/Dialect/Arc/ArcTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include "circt/Dialect/Arc/ArcTypes.h"
#include "circt/Dialect/Arc/ArcDialect.h"
#include "circt/Dialect/Seq/SeqTypes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "llvm/ADT/TypeSwitch.h"
Expand Down
1 change: 1 addition & 0 deletions lib/Dialect/Arc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ add_circt_dialect_library(CIRCTArc

LINK_LIBS PUBLIC
CIRCTHW
CIRCTSeq
MLIRIR
MLIRInferTypeOpInterface
MLIRSideEffectInterfaces
Expand Down
13 changes: 10 additions & 3 deletions lib/Dialect/Arc/Transforms/LowerState.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -307,12 +307,18 @@ LogicalResult ModuleLowering::lowerPrimaryInputs() {
if (blockArg == storageArg)
continue;
auto name = moduleOp.getArgName(blockArg.getArgNumber());
auto intType = blockArg.getType().dyn_cast<IntegerType>();
if (!intType)
auto argTy = blockArg.getType();
IntegerType innerTy;
if (argTy.isa<seq::ClockType>()) {
innerTy = IntegerType::get(context, 1);
} else if (auto intType = argTy.dyn_cast<IntegerType>()) {
innerTy = intType;
} else {
return mlir::emitError(blockArg.getLoc(), "input ")
<< name << " is of non-integer type " << blockArg.getType();
}
auto state = builder.create<RootInputOp>(
blockArg.getLoc(), StateType::get(intType), name, storageArg);
blockArg.getLoc(), StateType::get(innerTy), name, storageArg);
Value readOp = replaceValueWithStateRead(blockArg, state);
// Presently all clocks must be arguments, so they can be resolved here.
if (auto it = clockLowerings.find(blockArg); it != clockLowerings.end())
Expand Down Expand Up @@ -799,6 +805,7 @@ LogicalResult LowerStatePass::runOnModule(HWModuleOp moduleOp,
builder.create<ModelOp>(moduleOp.getLoc(), moduleOp.getModuleNameAttr());
modelOp.getBody().takeBody(moduleOp.getBody());
moduleOp->erase();

return success();
}

Expand Down
38 changes: 19 additions & 19 deletions test/Conversion/ConvertToArcs/convert-to-arcs.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -70,15 +70,15 @@ hw.module @SplitAtConstants() -> (z: i4) {
// CHECK-NEXT: }

// CHECK-LABEL: hw.module @Pipeline
hw.module @Pipeline(%clock: i1, %i0: i4, %i1: i4) -> (z: i4) {
hw.module @Pipeline(%clock: !seq.clock, %i0: i4, %i1: i4) -> (z: i4) {
// CHECK-NEXT: [[S0:%.+]] = arc.state @Pipeline_arc(%i0, %i1) clock %clock lat 1
// CHECK-NEXT: [[S1:%.+]] = arc.state @Pipeline_arc_0([[S0]], %i0) clock %clock lat 1
// CHECK-NEXT: [[S2:%.+]] = arc.state @Pipeline_arc_1([[S1]], %i1) lat 0
// CHECK-NEXT: hw.output [[S2]]
%0 = comb.add %i0, %i1 : i4
%1 = seq.compreg %0, %clock : i4
%1 = seq.compreg %0, %clock : i4, !seq.clock
%2 = comb.xor %1, %i0 : i4
%3 = seq.compreg %2, %clock : i4
%3 = seq.compreg %2, %clock : i4, !seq.clock
%4 = comb.mul %3, %i1 : i4
hw.output %4 : i4
}
Expand All @@ -94,16 +94,16 @@ hw.module @Pipeline(%clock: i1, %i0: i4, %i1: i4) -> (z: i4) {
// CHECK-NEXT: }

// CHECK-LABEL: hw.module @Reshuffling
hw.module @Reshuffling(%clockA: i1, %clockB: i1) -> (z0: i4, z1: i4, z2: i4, z3: i4) {
hw.module @Reshuffling(%clockA: !seq.clock, %clockB: !seq.clock) -> (z0: i4, z1: i4, z2: i4, z3: i4) {
// CHECK-NEXT: hw.instance "x" @Reshuffling2()
// CHECK-NEXT: arc.state @Reshuffling_arc(%x.z0, %x.z1) clock %clockA lat 1
// CHECK-NEXT: arc.state @Reshuffling_arc_0(%x.z2, %x.z3) clock %clockB lat 1
// CHECK-NEXT: hw.output
%x.z0, %x.z1, %x.z2, %x.z3 = hw.instance "x" @Reshuffling2() -> (z0: i4, z1: i4, z2: i4, z3: i4)
%4 = seq.compreg %x.z0, %clockA : i4
%5 = seq.compreg %x.z1, %clockA : i4
%6 = seq.compreg %x.z2, %clockB : i4
%7 = seq.compreg %x.z3, %clockB : i4
%4 = seq.compreg %x.z0, %clockA : i4, !seq.clock
%5 = seq.compreg %x.z1, %clockA : i4, !seq.clock
%6 = seq.compreg %x.z2, %clockB : i4, !seq.clock
%7 = seq.compreg %x.z3, %clockB : i4, !seq.clock
hw.output %4, %5, %6, %7 : i4, i4, i4, i4
}
// CHECK-NEXT: }
Expand All @@ -127,15 +127,15 @@ hw.module.extern private @Reshuffling2() -> (z0: i4, z1: i4, z2: i4, z3: i4)
// CHECK-NEXT: }

// CHECK-LABEL: hw.module @FactorOutCommonOps
hw.module @FactorOutCommonOps(%clock: i1, %i0: i4, %i1: i4) -> (o0: i4, o1: i4) {
hw.module @FactorOutCommonOps(%clock: !seq.clock, %i0: i4, %i1: i4) -> (o0: i4, o1: i4) {
// CHECK-DAG: [[T0:%.+]] = arc.state @FactorOutCommonOps_arc_1(%i0, %i1) lat 0
%0 = comb.add %i0, %i1 : i4
// CHECK-DAG: [[T1:%.+]] = arc.state @FactorOutCommonOps_arc([[T0]], %i0) clock %clock lat 1
// CHECK-DAG: [[T2:%.+]] = arc.state @FactorOutCommonOps_arc_0([[T0]], %i1) clock %clock lat 1
%1 = comb.xor %0, %i0 : i4
%2 = comb.mul %0, %i1 : i4
%3 = seq.compreg %1, %clock : i4
%4 = seq.compreg %2, %clock : i4
%3 = seq.compreg %1, %clock : i4, !seq.clock
%4 = seq.compreg %2, %clock : i4, !seq.clock
// CHECK-NEXT: hw.output [[T1]], [[T2]]
hw.output %3, %4 : i4, i4
}
Expand Down Expand Up @@ -169,14 +169,14 @@ hw.module.extern private @SplitAtInstance2(%a: i4) -> (z: i4)


// CHECK-LABEL: hw.module @AbsorbNames
hw.module @AbsorbNames(%clock: i1) -> () {
hw.module @AbsorbNames(%clock: !seq.clock) -> () {
// CHECK-NEXT: %x.z0, %x.z1 = hw.instance "x" @AbsorbNames2()
// CHECK-NEXT: arc.state @AbsorbNames_arc(%x.z0, %x.z1) clock %clock lat 1
// CHECK-SAME: {names = ["myRegA", "myRegB"]}
// CHECK-NEXT: hw.output
%x.z0, %x.z1 = hw.instance "x" @AbsorbNames2() -> (z0: i4, z1: i4)
%myRegA = seq.compreg %x.z0, %clock : i4
%myRegB = seq.compreg %x.z1, %clock : i4
%myRegA = seq.compreg %x.z0, %clock : i4, !seq.clock
%myRegB = seq.compreg %x.z1, %clock : i4, !seq.clock
}
// CHECK-NEXT: }

Expand All @@ -187,11 +187,11 @@ hw.module.extern @AbsorbNames2() -> (z0: i4, z1: i4)
// CHECK-NEXT: }

// CHECK-LABEL: hw.module @Trivial(
hw.module @Trivial(%clock: i1, %i0: i4, %reset: i1) -> (out: i4) {
hw.module @Trivial(%clock: !seq.clock, %i0: i4, %reset: i1) -> (out: i4) {
// CHECK: [[RES0:%.+]] = arc.state @[[TRIVIAL_ARC]](%i0) clock %clock reset %reset lat 1 {names = ["foo"]
// CHECK-NEXT: hw.output [[RES0:%.+]]
%0 = hw.constant 0 : i4
%foo = seq.compreg %i0, %clock, %reset, %0 : i4
%foo = seq.compreg %i0, %clock, %reset, %0 : i4, !seq.clock
hw.output %foo : i4
}
// CHECK-NEXT: }
Expand All @@ -205,13 +205,13 @@ hw.module @Trivial(%clock: i1, %i0: i4, %reset: i1) -> (out: i4) {
// CHECK-NEXT: }

// CHECK-LABEL: hw.module @NonTrivial(
hw.module @NonTrivial(%clock: i1, %i0: i4, %reset1: i1, %reset2: i1) -> (out1: i4, out2: i4) {
hw.module @NonTrivial(%clock: !seq.clock, %i0: i4, %reset1: i1, %reset2: i1) -> (out1: i4, out2: i4) {
// CHECK: [[RES2:%.+]] = arc.state @[[NONTRIVIAL_ARC_0]](%i0) clock %clock reset %reset1 lat 1 {names = ["foo"]
// CHECK-NEXT: [[RES3:%.+]] = arc.state @[[NONTRIVIAL_ARC_1]](%i0) clock %clock reset %reset2 lat 1 {names = ["bar"]
// CHECK-NEXT: hw.output [[RES2]], [[RES3]]
%0 = hw.constant 0 : i4
%foo = seq.compreg %i0, %clock, %reset1, %0 : i4
%bar = seq.compreg %i0, %clock, %reset2, %0 : i4
%foo = seq.compreg %i0, %clock, %reset1, %0 : i4, !seq.clock
%bar = seq.compreg %i0, %clock, %reset2, %0 : i4, !seq.clock
hw.output %foo, %bar : i4, i4
}
// CHECK-NEXT: }
2 changes: 1 addition & 1 deletion test/Dialect/Arc/Reduction/state-elimination.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
// RUN: circt-reduce %s --test /usr/bin/env --test-arg grep --test-arg -q --test-arg "DummyArc(%arg0)" --keep-best=0 --include arc-state-elimination | FileCheck %s

// CHECK-LABEL: hw.module @Foo
hw.module @Foo(%clk: i1, %en: i1, %rst: i1, %arg0: i32) -> (out: i32) {
hw.module @Foo(%clk: !seq.clock, %en: i1, %rst: i1, %arg0: i32) -> (out: i32) {
// CHECK-NEXT: [[V0:%.+]] = arc.call @DummyArc(%arg0) : (i32) -> i32
%0 = arc.state @DummyArc(%arg0) clock %clk enable %en reset %rst lat 1 {name="reg1"} : (i32) -> (i32)
// CHECK-NEXT: hw.output [[V0]]
Expand Down
10 changes: 5 additions & 5 deletions test/Dialect/Arc/arc-canonicalizer.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,16 @@
//===----------------------------------------------------------------------===//

// CHECK-LABEL: hw.module @passthoughChecks
hw.module @passthoughChecks(%in0: i1, %in1: i1) -> (out0: i1, out1: i1, out2: i1, out3: i1, out4: i1, out5: i1, out6: i1, out7: i1, out8: i1, out9: i1) {
hw.module @passthoughChecks(%clock: !seq.clock, %in0: i1, %in1: i1) -> (out0: i1, out1: i1, out2: i1, out3: i1, out4: i1, out5: i1, out6: i1, out7: i1, out8: i1, out9: i1) {
%0:2 = arc.call @passthrough(%in0, %in1) : (i1, i1) -> (i1, i1)
%1:2 = arc.call @noPassthrough(%in0, %in1) : (i1, i1) -> (i1, i1)
%2:2 = arc.state @passthrough(%in0, %in1) lat 0 : (i1, i1) -> (i1, i1)
%3:2 = arc.state @noPassthrough(%in0, %in1) lat 0 : (i1, i1) -> (i1, i1)
%4:2 = arc.state @passthrough(%in0, %in1) clock %in0 lat 1 : (i1, i1) -> (i1, i1)
%4:2 = arc.state @passthrough(%in0, %in1) clock %clock lat 1 : (i1, i1) -> (i1, i1)
hw.output %0#0, %0#1, %1#0, %1#1, %2#0, %2#1, %3#0, %3#1, %4#0, %4#1 : i1, i1, i1, i1, i1, i1, i1, i1, i1, i1
// CHECK-NEXT: [[V0:%.+]]:2 = arc.call @noPassthrough(%in0, %in1) :
// CHECK-NEXT: [[V1:%.+]]:2 = arc.state @noPassthrough(%in0, %in1) lat 0 :
// CHECK-NEXT: [[V2:%.+]]:2 = arc.state @passthrough(%in0, %in1) clock %in0 lat 1 :
// CHECK-NEXT: [[V2:%.+]]:2 = arc.state @passthrough(%in0, %in1) clock %clock lat 1 :
// CHECK-NEXT: hw.output %in0, %in1, [[V0]]#0, [[V0]]#1, %in0, %in1, [[V1]]#0, [[V1]]#1, [[V2]]#0, [[V2]]#1 :
}
arc.define @passthrough(%arg0: i1, %arg1: i1) -> (i1, i1) {
Expand All @@ -38,7 +38,7 @@ arc.define @memArcTrue(%arg0: i1, %arg1: i32) -> (i1, i32, i1) {
}

// CHECK-LABEL: hw.module @memoryWritePortCanonicalizations
hw.module @memoryWritePortCanonicalizations(%clk: i1, %addr: i1, %data: i32) {
hw.module @memoryWritePortCanonicalizations(%clk: !seq.clock, %addr: i1, %data: i32) {
// CHECK-NEXT: [[MEM:%.+]] = arc.memory <2 x i32, i1>
%mem = arc.memory <2 x i32, i1>
arc.memory_write_port %mem, @memArcFalse(%addr, %data) clock %clk enable lat 1 : <2 x i32, i1>, i1, i32
Expand Down Expand Up @@ -150,7 +150,7 @@ arc.define @OneOfThreeUsed(%arg0: i1, %arg1: i1, %arg2: i1) -> i1 {
}

// CHECK: @test1
hw.module @test1 (%arg0: i1, %arg1: i1, %arg2: i1, %clock: i1) -> (out0: i1, out1: i1) {
hw.module @test1 (%arg0: i1, %arg1: i1, %arg2: i1, %clock: !seq.clock) -> (out0: i1, out1: i1) {
// CHECK-NEXT: arc.state @OneOfThreeUsed(%arg1) clock %clock lat 1 : (i1) -> i1
%0 = arc.state @OneOfThreeUsed(%arg0, %arg1, %arg2) clock %clock lat 1 : (i1, i1, i1) -> i1
// CHECK-NEXT: arc.state @NestedCall(%arg1)
Expand Down
Loading

0 comments on commit d345d1f

Please sign in to comment.