Skip to content

Commit

Permalink
Add quir.declare_parameter, quir.load_parameter and generate quir.cir…
Browse files Browse the repository at this point in the history
…cuits (#100)

This PR adds two new operations to the QCS dialect, `declare_parameter`
and `parameter_load`. These operations are used to support input
parameters, specifically those defined in QASM as
```
input type name = value;
````

The PR also updates the `QUIRGenQASM3Visitor` to handle input parameters
as well as put real time operations into `quir.circuits`.

A command line option `--enable-parameters` is used as a feature flag to
gate use of the new features.

New tests have been added to check the new features. All existing tests
should pass without the new feature flag. Enabling support for the new
feature throughout the stack and removing the feature flag is reserved
for a future PR.

This PR does not change the handling of gate functions. That is
`quir.circuits` are not inserted into the functions generated by QASM3
`gate h q { }` statements.

---------

Co-authored-by: Thomas Alexander <thomasalexander2718@gmail.com>
  • Loading branch information
bcdonovan and taalexander committed May 2, 2023
1 parent 90f688b commit 128c937
Show file tree
Hide file tree
Showing 13 changed files with 739 additions and 18 deletions.
2 changes: 2 additions & 0 deletions include/Dialect/QCS/IR/QCSOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@

#include "Dialect/QCS/IR/QCSTypes.h"

#include "mlir/IR/SymbolTable.h"

#define GET_OP_CLASSES
#include "Dialect/QCS/IR/QCSOps.h.inc"

Expand Down
70 changes: 70 additions & 0 deletions include/Dialect/QCS/IR/QCSOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ include "Dialect/QUIR/IR/QUIRTypeConstraints.td"
include "Dialect/QCS/IR/QCSBase.td"

include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
include "mlir/IR/SymbolInterfaces.td"

// Define a side effect that identifies an operation as not dead while not
// interfering with memory operations (e.g., allows store-forwarding across
Expand Down Expand Up @@ -225,4 +226,73 @@ def QCS_ShotInitOp : QCS_Op<"shot_init", [IsolatedFromAbove]> {
}];
}

def QCS_DeclareParameterOp : QCS_Op<"declare_parameter", [Symbol]> {
let summary = "system input parameter subject to post compilation updates";
let description = [{
The `qcs.declare_parameter` operation adds a symbol defining an input parameter
which may be modified after compilation before/during program invocation.
The value of the input parameter
may be obtained using the qcs.use_input_parameter operation.

Example:

```
// quir.angle input parameter
qcs.declare_parameter "theta" : !quir.angle<64> = 3.14159
```
}];

let arguments = (ins
SymbolNameAttr:$sym_name,
TypeAttr:$type,
OptionalAttr<AnyAttr>:$initial_value
);

let results = (outs);

let assemblyFormat = [{
attr-dict $sym_name `:` $type (`=` $initial_value^)?
}];

let builders = [
OpBuilder<(ins "::llvm::StringRef":$sym_name, "::mlir::TypeAttr":$type), [{
$_state.addAttribute("sym_name", $_builder.getStringAttr(sym_name));
$_state.addAttribute("type", type);
}]>,
OpBuilder<(ins "::llvm::StringRef":$sym_name, "::mlir::TypeAttr":$type, "Attribute":$value), [{
$_state.addAttribute("sym_name", $_builder.getStringAttr(sym_name));
$_state.addAttribute("type", type);
$_state.addAttribute("initial_value", value);
}]>,
];
}

def QCS_ParameterLoadOp : QCS_Op<"parameter_load",
[DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
let summary = "Use the current value of a parameter";
let description = [{
The operation `qcs.parameter_load` returns the current value of the
classical parameter with the given name.

Example:

```mlir
%2 = qcs.parameter_load "a" : !quir.angle<64>
```
}];

let arguments = (ins
FlatSymbolRefAttr:$parameter_name
);

let results = (outs AnyClassical:$res);

let assemblyFormat = [{
$parameter_name `:` type($res) attr-dict
}];
// op is verified by its traits
let verifier = ?;
}


#endif // QCS_OPS
7 changes: 7 additions & 0 deletions include/Dialect/QUIR/IR/QUIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -609,6 +609,13 @@ def QUIR_CircuitOp : QUIR_Op<"circuit", [
CArg<"ArrayRef<DictionaryAttr>", "{}">:$argAttrs
)>];
let extraClassDeclaration = [{
static CircuitOp create(Location location, StringRef name, FunctionType type,
ArrayRef<NamedAttribute> attrs = {});
static CircuitOp create(Location location, StringRef name, FunctionType type,
Operation::dialect_attr_range attrs);
static CircuitOp create(Location location, StringRef name, FunctionType type,
ArrayRef<NamedAttribute> attrs,
ArrayRef<DictionaryAttr> argAttrs);

/// Create a deep copy of this circuit and all of its blocks, remapping any
/// operands that use values outside of the circuit using the map that is
Expand Down
19 changes: 15 additions & 4 deletions include/Frontend/OpenQASM3/QUIRGenQASM3Visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,16 @@ class QUIRGenQASM3Visitor : public BaseQASM3Visitor {
// References to MLIR single static assignment Values
// (TODO needs to be refactored)
std::unordered_map<std::string, mlir::Value> ssaValues;
std::vector<mlir::Value> ssaOtherValues;
mlir::OpBuilder builder;
mlir::OpBuilder topLevelBuilder;
mlir::OpBuilder circuitParentBuilder;
mlir::ModuleOp &newModule;
mlir::quir::CircuitOp currentCircuitOp;
std::string filename;
bool hasFailed = false;
bool hasFailed{false};
bool buildingInCircuit{false};
uint circuitCount{0};

mlir::Location getLocation(const QASM::ASTBase *);
bool assign(mlir::Value &, const std::string &);
Expand Down Expand Up @@ -85,15 +90,21 @@ class QUIRGenQASM3Visitor : public BaseQASM3Visitor {
QUIRGenQASM3Visitor(QASM::ASTStatementList *sList, mlir::OpBuilder b,
mlir::ModuleOp &newModule, std::string f)
: BaseQASM3Visitor(sList), builder(b), topLevelBuilder(b),
newModule(newModule), filename(std::move(f)), varHandler(builder) {}
circuitParentBuilder(b), newModule(newModule), filename(std::move(f)),
varHandler(builder) {}

QUIRGenQASM3Visitor(mlir::OpBuilder b, mlir::ModuleOp &newModule,
std::string filename)
: builder(b), topLevelBuilder(b), newModule(newModule),
filename(std::move(filename)), varHandler(builder) {}
: builder(b), topLevelBuilder(b), circuitParentBuilder(b),
newModule(newModule), filename(std::move(filename)),
varHandler(builder) {}

void initialize(uint numShots, const std::string &shotDelay);

void startCircuit(mlir::Location location);
void finishCircuit();
void switchCircuit(bool buildInCircuit, mlir::Location location);

void setInputFile(std::string);

mlir::LogicalResult walkAST();
Expand Down
41 changes: 40 additions & 1 deletion include/Frontend/OpenQASM3/QUIRVariableBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ namespace qssc::frontend::openqasm3 {

class QUIRVariableBuilder {
public:
QUIRVariableBuilder(mlir::OpBuilder &builder) : builder(builder) {}
QUIRVariableBuilder(mlir::OpBuilder &builder)
: builder(builder), classicalBuilder(builder) {}

/// Generate code for declaring a variable (at the builder's current insertion
/// point).
Expand All @@ -57,6 +58,19 @@ class QUIRVariableBuilder {
bool isInputVariable = false,
bool isOutputVariable = false);

/// Generate code for declaring a input parameter (at the builder's current
/// insertion point).
///
/// @param location source location related to the generated code.
/// @param variableName name of the variable. (_parameter will be added)
/// @param type type of the variable.
void generateParameterDeclaration(mlir::Location location,
llvm::StringRef variableName,
mlir::Type type, mlir::Value assignedValue);

mlir::Value generateParameterLoad(mlir::Location location,
llvm::StringRef variableName);

/// Generate code for declaring an array (at the builder's current insertion
/// point).
///
Expand Down Expand Up @@ -201,9 +215,34 @@ class QUIRVariableBuilder {

mlir::Type resolveQUIRVariableType(const QASM::ASTResultNode *node);

void setClassicalBuilder(mlir::OpBuilder b) {
classicalBuilder = b;
useClassicalBuilder = true;
}
void disableClassicalBuilder() { useClassicalBuilder = false; }

private:
// default builder - reference from QUIRGenQASM3Vistor class
mlir::OpBuilder &builder;

// classical builder - used by QUIRGenQASM3Vistor class when
// building inside a quir.circuit operation.
//
// the classical builder is used to insert classical operations such as a
// oq3.variable_load operation which are added to support a real time
// operation such as a quir.call_gate. The classical builder maintains the
// insertion point for the supporting classical operations which should be
// inserted at the same scope as the `quir.call_circuit` corresponding to the
// currently being inserted `quir.circuit`
//
// see also: QUIRGenQASM3Visitor::switchCircuit
mlir::OpBuilder classicalBuilder;
bool useClassicalBuilder{false};

mlir::OpBuilder getClassicalBuilder() {
return (useClassicalBuilder) ? classicalBuilder : builder;
}

std::unordered_map<std::string, mlir::Type> variables;

std::unordered_map<mlir::Operation *, mlir::Operation *> lastDeclaration;
Expand Down
45 changes: 45 additions & 0 deletions lib/Dialect/QCS/IR/QCSOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,54 @@
#include "Dialect/QCS/IR/QCSTypes.h"

#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include <mlir/IR/BuiltinOps.h>
#include <mlir/IR/SymbolTable.h>

using namespace mlir;
using namespace mlir::qcs;

#define GET_OP_CLASSES
#include "Dialect/QCS/IR/QCSOps.cpp.inc"

static LogicalResult
verifyQCSParameterOpSymbolUses(SymbolTableCollection &symbolTable,
mlir::Operation *op,
bool operandMustMatchSymbolType = false) {
assert(op);

// Check that op has attribute variable_name
auto paramRefAttr = op->getAttrOfType<FlatSymbolRefAttr>("parameter_name");
if (!paramRefAttr)
return op->emitOpError(
"requires a symbol reference attribute 'parameter_name'");

// Check that symbol reference resolves to a parameter declaration
auto declOp =
symbolTable.lookupNearestSymbolFrom<DeclareParameterOp>(op, paramRefAttr);
if (!declOp)
return op->emitOpError() << "no valid reference to a parameter '"
<< paramRefAttr.getValue() << "'";

assert(op->getNumResults() <= 1 && "assume none or single result");

// Check that type of variables matches result type of this Op
if (op->getNumResults() == 1) {
if (op->getResult(0).getType() != declOp.type())
return op->emitOpError(
"type mismatch between variable declaration and variable use");
}

if (op->getNumOperands() > 0 && operandMustMatchSymbolType) {
assert(op->getNumOperands() == 1 &&
"type check only supported for a single operand");
if (op->getOperand(0).getType() != declOp.type())
return op->emitOpError(
"type mismatch between variable declaration and variable assignment");
}
return success();
}

LogicalResult
ParameterLoadOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
return verifyQCSParameterOpSymbolUses(symbolTable, getOperation(), true);
}
37 changes: 37 additions & 0 deletions lib/Dialect/QUIR/IR/QUIROps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,43 @@ static LogicalResult verify(CircuitOp op) {
return success();
}

CircuitOp CircuitOp::create(Location location, StringRef name,
FunctionType type, ArrayRef<NamedAttribute> attrs) {
OpBuilder builder(location->getContext());
OperationState state(location, getOperationName());
CircuitOp::build(builder, state, name, type, attrs);
return cast<CircuitOp>(Operation::create(state));
}
CircuitOp CircuitOp::create(Location location, StringRef name,
FunctionType type,
Operation::dialect_attr_range attrs) {
SmallVector<NamedAttribute, 8> attrRef(attrs);
return create(location, name, type, llvm::makeArrayRef(attrRef));
}
CircuitOp CircuitOp::create(Location location, StringRef name,
FunctionType type, ArrayRef<NamedAttribute> attrs,
ArrayRef<DictionaryAttr> argAttrs) {
CircuitOp circ = create(location, name, type, attrs);
circ.setAllArgAttrs(argAttrs);
return circ;
}

void CircuitOp::build(OpBuilder &builder, OperationState &state, StringRef name,
FunctionType type, ArrayRef<NamedAttribute> attrs,
ArrayRef<DictionaryAttr> argAttrs) {
state.addAttribute(SymbolTable::getSymbolAttrName(),
builder.getStringAttr(name));
state.addAttribute(getTypeAttrName(), TypeAttr::get(type));
state.attributes.append(attrs.begin(), attrs.end());
state.addRegion();

if (argAttrs.empty())
return;
assert(type.getNumInputs() == argAttrs.size());
function_interface_impl::addArgAndResultAttrs(builder, state, argAttrs,
/*resultAttrs=*/llvm::None);
}

/// Clone the internal blocks and attributes from this circuit to the
/// destination circuit.
void CircuitOp::cloneInto(CircuitOp dest, BlockAndValueMapping &mapper) {
Expand Down
2 changes: 2 additions & 0 deletions lib/Frontend/OpenQASM3/OpenQASM3Frontend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,8 @@ llvm::Error qssc::frontend::openqasm3::parse(
if (failed(visitor.walkAST()))
return llvm::createStringError(llvm::inconvertibleErrorCode(),
"Failed to emit QUIR");
// make sure to finish the in progress quir.circuit
visitor.finishCircuit();
if (mlir::failed(mlir::verify(newModule))) {
newModule.dump();

Expand Down
Loading

0 comments on commit 128c937

Please sign in to comment.