Skip to content

Commit 34e5533

Browse files
committed
[OM] Add evaluator support for StringConcatOp
Implement string concatenation evaluation in the OM evaluator. The evaluator extracts string attributes from both operands and concatenates them to produce a new string attribute. This enables runtime evaluation of string concatenation operations in the OM dialect. Signed-off-by: Schuyler Eldridge <schuyler.eldridge@sifive.com>
1 parent ae00359 commit 34e5533

File tree

3 files changed

+111
-1
lines changed

3 files changed

+111
-1
lines changed

include/circt/Dialect/OM/Evaluator/Evaluator.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -440,6 +440,9 @@ class Evaluator {
440440
FailureOr<EvaluatorValuePtr> evaluateListConcat(ListConcatOp op,
441441
ActualParameters actualParams,
442442
Location loc);
443+
FailureOr<EvaluatorValuePtr>
444+
evaluateStringConcat(StringConcatOp op, ActualParameters actualParams,
445+
Location loc);
443446
FailureOr<evaluator::EvaluatorValuePtr>
444447
evaluateBasePathCreate(FrozenBasePathCreateOp op,
445448
ActualParameters actualParams, Location loc);

lib/Dialect/OM/Evaluator/Evaluator.cpp

Lines changed: 67 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,11 @@ circt::om::Evaluator::getPartiallyEvaluatedValue(Type type, Location loc) {
9090

9191
return success(result);
9292
})
93+
.Case([&](circt::om::StringType type) {
94+
evaluator::EvaluatorValuePtr result =
95+
evaluator::AttributeValue::get(type, loc);
96+
return success(result);
97+
})
9398
.Default([&](auto type) { return failure(); });
9499
}
95100

@@ -155,7 +160,8 @@ FailureOr<evaluator::EvaluatorValuePtr> circt::om::Evaluator::getOrCreateValue(
155160
evaluator::PathValue::getEmptyPath(loc));
156161
return success(result);
157162
})
158-
.Case<ListCreateOp, ListConcatOp, ObjectFieldOp>([&](auto op) {
163+
.Case<ListCreateOp, ListConcatOp, StringConcatOp,
164+
ObjectFieldOp>([&](auto op) {
159165
return getPartiallyEvaluatedValue(op.getType(), loc);
160166
})
161167
.Case<ObjectOp>([&](auto op) {
@@ -378,6 +384,9 @@ circt::om::Evaluator::evaluateValue(Value value, ActualParameters actualParams,
378384
.Case([&](ListConcatOp op) {
379385
return evaluateListConcat(op, actualParams, loc);
380386
})
387+
.Case([&](StringConcatOp op) {
388+
return evaluateStringConcat(op, actualParams, loc);
389+
})
381390
.Case([&](AnyCastOp op) {
382391
return evaluateValue(op.getInput(), actualParams, loc);
383392
})
@@ -680,6 +689,63 @@ circt::om::Evaluator::evaluateListConcat(ListConcatOp op,
680689
return list;
681690
}
682691

692+
/// Evaluator dispatch function for String concatenation.
693+
FailureOr<evaluator::EvaluatorValuePtr>
694+
circt::om::Evaluator::evaluateStringConcat(StringConcatOp op,
695+
ActualParameters actualParams,
696+
Location loc) {
697+
// Get the op's EvaluatorValue handle, in case it hasn't been evaluated yet.
698+
auto handle = getOrCreateValue(op.getResult(), actualParams, loc);
699+
if (failed(handle))
700+
return handle;
701+
702+
// If it's fully evaluated, we can return it.
703+
if (handle.value()->isFullyEvaluated())
704+
return handle;
705+
706+
// Extract the string attributes, handling both AttributeValue and
707+
// ReferenceValue cases.
708+
auto extractAttr = [](evaluator::EvaluatorValue *value) -> StringAttr {
709+
return llvm::TypeSwitch<evaluator::EvaluatorValue *, StringAttr>(value)
710+
.Case([](evaluator::AttributeValue *val) {
711+
return val->getAs<StringAttr>();
712+
})
713+
.Case([](evaluator::ReferenceValue *val) {
714+
return cast<evaluator::AttributeValue>(val->getStrippedValue()->get())
715+
->getAs<StringAttr>();
716+
});
717+
};
718+
719+
// Evaluate all operands and concatenate them.
720+
std::string result;
721+
for (auto operand : op.getOperands()) {
722+
auto operandResult = evaluateValue(operand, actualParams, loc);
723+
if (failed(operandResult))
724+
return operandResult;
725+
if (!operandResult.value()->isFullyEvaluated())
726+
return handle;
727+
728+
StringAttr str = extractAttr(operandResult.value().get());
729+
assert(str && "expected StringAttr for StringConcatOp operand");
730+
result += str.getValue().str();
731+
}
732+
733+
// Create the concatenated string attribute.
734+
auto resultStr = StringAttr::get(result, op.getResult().getType());
735+
736+
// Finalize the op result value.
737+
auto *handleValue = cast<evaluator::AttributeValue>(handle.value().get());
738+
auto resultStatus = handleValue->setAttr(resultStr);
739+
if (failed(resultStatus))
740+
return resultStatus;
741+
742+
auto finalizeStatus = handleValue->finalize();
743+
if (failed(finalizeStatus))
744+
return finalizeStatus;
745+
746+
return handle;
747+
}
748+
683749
FailureOr<evaluator::EvaluatorValuePtr>
684750
circt::om::Evaluator::evaluateBasePathCreate(FrozenBasePathCreateOp op,
685751
ActualParameters actualParams,

unittests/Dialect/OM/Evaluator/EvaluatorTests.cpp

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1661,4 +1661,45 @@ om.class @Foo(
16611661
ASSERT_TRUE(object->getField("b").value()->isUnknown());
16621662
}
16631663

1664+
TEST(EvaluatorTests, StringConcat) {
1665+
const char *mod = R"MLIR(
1666+
module {
1667+
om.class @Test() -> (result: !om.string) {
1668+
%0 = om.constant "Hello, " : !om.string
1669+
%1 = om.constant "World!" : !om.string
1670+
%2 = om.string.concat %0, %1 : !om.string
1671+
om.class.fields %2 : !om.string
1672+
}
1673+
}
1674+
)MLIR";
1675+
1676+
DialectRegistry registry;
1677+
registry.insert<OMDialect>();
1678+
1679+
MLIRContext context(registry);
1680+
context.getOrLoadDialect<OMDialect>();
1681+
1682+
OwningOpRef<ModuleOp> owning =
1683+
parseSourceString<ModuleOp>(mod, ParserConfig(&context));
1684+
1685+
Evaluator evaluator(owning.release());
1686+
1687+
context.getDiagEngine().registerHandler([&](Diagnostic &diag) {
1688+
llvm::errs() << "Diagnostic: " << diag << "\n";
1689+
});
1690+
1691+
auto result = evaluator.instantiate(StringAttr::get(&context, "Test"), {});
1692+
1693+
ASSERT_TRUE(succeeded(result));
1694+
1695+
auto fieldValue = llvm::cast<evaluator::ObjectValue>(result.value().get())
1696+
->getField("result")
1697+
.value();
1698+
1699+
ASSERT_EQ("Hello, World!",
1700+
llvm::cast<evaluator::AttributeValue>(fieldValue.get())
1701+
->getAs<StringAttr>()
1702+
.getValue());
1703+
}
1704+
16641705
} // namespace

0 commit comments

Comments
 (0)