241 changes: 0 additions & 241 deletions mlir/docs/DebugActions.md

This file was deleted.

93 changes: 93 additions & 0 deletions mlir/include/mlir/Debug/CLOptionsSetup.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
//===- CLOptionsSetup.h - Helpers to setup debug CL options -----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DEBUG_CLOPTIONSSETUP_H
#define MLIR_DEBUG_CLOPTIONSSETUP_H

#include "mlir/Debug/BreakpointManagers/FileLineColLocBreakpointManager.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/StringRef.h"

#include <memory>

namespace mlir {
class MLIRContext;
namespace tracing {
class BreakpointManager;

class DebugConfig {
public:
/// Register the options as global LLVM command line options.
static void registerCLOptions();

/// Create a new config with the default set from the CL options.
static DebugConfig createFromCLOptions();

///
/// Options.
///

/// Enable the Debugger action hook: it makes a debugger (like gdb or lldb)
/// able to intercept MLIR Actions.
void enableDebuggerActionHook(bool enabled = true) {
enableDebuggerActionHookFlag = enabled;
}

/// Return true if the debugger action hook is enabled.
bool isDebuggerActionHookEnabled() const {
return enableDebuggerActionHookFlag;
}

/// Set the filename to use for logging actions, use "-" for stdout.
DebugConfig &logActionsTo(StringRef filename) {
logActionsToFlag = filename;
return *this;
}
/// Get the filename to use for logging actions.
StringRef getLogActionsTo() const { return logActionsToFlag; }

/// Set a location breakpoint manager to filter out action logging based on
/// the attached IR location in the Action context. Ownership stays with the
/// caller.
void addLogActionLocFilter(tracing::BreakpointManager *breakpointManager) {
logActionLocationFilter.push_back(breakpointManager);
}

/// Get the location breakpoint managers to use to filter out action logging.
ArrayRef<tracing::BreakpointManager *> getLogActionsLocFilters() const {
return logActionLocationFilter;
}

protected:
/// Enable the Debugger action hook: a debugger (like gdb or lldb) can
/// intercept MLIR Actions.
bool enableDebuggerActionHookFlag = false;

/// Log action execution to the given file (or "-" for stdout)
std::string logActionsToFlag;

/// Location Breakpoints to filter the action logging.
std::vector<tracing::BreakpointManager *> logActionLocationFilter;
};

/// This is a RAII class that installs the debug handlers on the context
/// based on the provided configuration.
class InstallDebugHandler {
public:
InstallDebugHandler(MLIRContext &context, const DebugConfig &config);
~InstallDebugHandler();

private:
class Impl;
std::unique_ptr<Impl> impl;
};

} // namespace tracing
} // namespace mlir

#endif // MLIR_DEBUG_CLOPTIONSSETUP_H
96 changes: 96 additions & 0 deletions mlir/include/mlir/Debug/DebuggerExecutionContextHook.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
//===- DebuggerExecutionContextHook.h - Debugger Support --------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains a set of C API functions that are used by the debugger to
// interact with the ExecutionContext.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_SUPPORT_DEBUGGEREXECUTIONCONTEXTHOOK_H
#define MLIR_SUPPORT_DEBUGGEREXECUTIONCONTEXTHOOK_H

#include "mlir-c/IR.h"
#include "mlir/Debug/ExecutionContext.h"
#include "llvm/Support/Compiler.h"

extern "C" {
struct MLIRBreakpoint;
struct MLIRIRunit;
typedef struct MLIRBreakpoint *BreakpointHandle;
typedef struct MLIRIRunit *irunitHandle;

/// This is used by the debugger to control what to do after a breakpoint is
/// hit. See tracing::ExecutionContext::Control for more information.
void mlirDebuggerSetControl(int controlOption);

/// Print the available context for the current Action.
void mlirDebuggerPrintContext();

/// Print the current action backtrace.
void mlirDebuggerPrintActionBacktrace(bool withContext);

//===----------------------------------------------------------------------===//
// Cursor Management: The cursor is used to select an IRUnit from the context
// and to navigate through the IRUnit hierarchy.
//===----------------------------------------------------------------------===//

/// Print the current IR unit cursor.
void mlirDebuggerCursorPrint(bool withRegion);

/// Select the IR unit from the current context by ID.
void mlirDebuggerCursorSelectIRUnitFromContext(int index);

/// Select the parent IR unit of the provided IR unit, or print an error if the
/// IR unit has no parent.
void mlirDebuggerCursorSelectParentIRUnit();

/// Select the child IR unit at the provided index, print an error if the index
/// is out of bound. For example if the irunit is an operation, the children IR
/// units will be the operation's regions.
void mlirDebuggerCursorSelectChildIRUnit(int index);

/// Return the next IR unit logically in the IR. For example if the irunit is a
/// Region the next IR unit will be the next region in the parent operation or
/// nullptr if there is no next region.
void mlirDebuggerCursorSelectPreviousIRUnit();

/// Return the previous IR unit logically in the IR. For example if the irunit
/// is a Region, the previous IR unit will be the previous region in the parent
/// operation or nullptr if there is no previous region.
void mlirDebuggerCursorSelectNextIRUnit();

//===----------------------------------------------------------------------===//
// Breakpoint Management
//===----------------------------------------------------------------------===//

/// Enable the provided breakpoint.
void mlirDebuggerEnableBreakpoint(BreakpointHandle breakpoint);

/// Disable the provided breakpoint.
void mlirDebuggerDisableBreakpoint(BreakpointHandle breakpoint);

/// Add a breakpoint matching exactly the provided tag.
BreakpointHandle mlirDebuggerAddTagBreakpoint(const char *tag);

/// Add a breakpoint matching a pattern by name.
void mlirDebuggerAddRewritePatternBreakpoint(const char *patternNameInfo);

/// Add a breakpoint matching a file, line and column.
void mlirDebuggerAddFileLineColLocBreakpoint(const char *file, int line,
int col);

} // extern "C"

namespace mlir {
// Setup the debugger hooks as a callback on the provided ExecutionContext.
void setupDebuggerExecutionContextHook(
tracing::ExecutionContext &executionContext);

} // namespace mlir

#endif // MLIR_SUPPORT_DEBUGGEREXECUTIONCONTEXTHOOK_H
12 changes: 11 additions & 1 deletion mlir/include/mlir/Debug/ExecutionContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,16 @@ struct ActionActiveStack {
const ActionActiveStack *getParent() const { return parent; }
const Action &getAction() const { return action; }
int getDepth() const { return depth; }
void print(raw_ostream &os, bool withContext) const;
void dump() const {
print(llvm::errs(), /*withContext=*/true);
llvm::errs() << "\n";
}
Breakpoint *getBreakpoint() const { return breakpoint; }
void setBreakpoint(Breakpoint *breakpoint) { this->breakpoint = breakpoint; }

private:
Breakpoint *breakpoint = nullptr;
const ActionActiveStack *parent;
const Action &action;
int depth;
Expand Down Expand Up @@ -69,7 +77,9 @@ class ExecutionContext {
ExecutionContext() = default;

/// Set the callback that is used to control the execution.
void setCallback(CallbackTy callback);
void setCallback(CallbackTy callback) {
onBreakpointControlExecutionCallback = callback;
}

/// This abstract class defines the interface used to observe an Action
/// execution. It allows to be notified before and after the callback is
Expand Down
2 changes: 1 addition & 1 deletion mlir/include/mlir/IR/OperationSupport.h
Original file line number Diff line number Diff line change
Expand Up @@ -835,7 +835,7 @@ class OpPrintingFlags {
OpPrintingFlags &printGenericOpForm();

/// Skip printing regions.
OpPrintingFlags &skipRegions();
OpPrintingFlags &skipRegions(bool skip = true);

/// Do not verify the operation when using custom operation printers.
OpPrintingFlags &assumeVerified();
Expand Down
3 changes: 1 addition & 2 deletions mlir/include/mlir/Rewrite/PatternApplicator.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,7 @@ class ApplyPatternAction : public tracing::ActionImpl<ApplyPatternAction> {
"Encapsulate the application of rewrite patterns";

void print(raw_ostream &os) const override {
os << "`" << tag << "`\n"
<< " pattern: " << pattern.getDebugName() << '\n';
os << "`" << tag << " pattern: " << pattern.getDebugName();
}

private:
Expand Down
42 changes: 15 additions & 27 deletions mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
#ifndef MLIR_TOOLS_MLIROPT_MLIROPTMAIN_H
#define MLIR_TOOLS_MLIROPT_MLIROPTMAIN_H

#include "mlir/Debug/BreakpointManagers/FileLineColLocBreakpointManager.h"
#include "mlir/Debug/CLOptionsSetup.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/StringRef.h"

Expand All @@ -30,9 +30,6 @@ namespace mlir {
class DialectRegistry;
class PassPipelineCLParser;
class PassManager;
namespace tracing {
class FileLineColLocBreakpointManager;
}

/// Configuration options for the mlir-opt tool.
/// This is intended to help building tools like mlir-opt by collecting the
Expand Down Expand Up @@ -64,6 +61,14 @@ class MlirOptMainConfig {
return allowUnregisteredDialectsFlag;
}

/// Set the debug configuration to use.
MlirOptMainConfig &setDebugConfig(tracing::DebugConfig config) {
debugConfig = std::move(config);
return *this;
}
tracing::DebugConfig &getDebugConfig() { return debugConfig; }
const tracing::DebugConfig &getDebugConfig() const { return debugConfig; }

/// Print the pass-pipeline as text before executing.
MlirOptMainConfig &dumpPassPipeline(bool dump) {
dumpPassPipelineFlag = dump;
Expand All @@ -85,26 +90,6 @@ class MlirOptMainConfig {
}
StringRef getIrdlFile() const { return irdlFileFlag; }

/// Set the filename to use for logging actions, use "-" for stdout.
MlirOptMainConfig &logActionsTo(StringRef filename) {
logActionsToFlag = filename;
return *this;
}
/// Get the filename to use for logging actions.
StringRef getLogActionsTo() const { return logActionsToFlag; }

/// Set a location breakpoint manager to filter out action logging based on
/// the attached IR location in the Action context. Ownership stays with the
/// caller.
void addLogActionLocFilter(tracing::BreakpointManager *breakpointManager) {
logActionLocationFilter.push_back(breakpointManager);
}

/// Get the location breakpoint managers to use to filter out action logging.
ArrayRef<tracing::BreakpointManager *> getLogActionsLocFilters() const {
return logActionLocationFilter;
}

/// Set the callback to populate the pass manager.
MlirOptMainConfig &
setPassPipelineSetupFn(std::function<LogicalResult(PassManager &)> callback) {
Expand Down Expand Up @@ -174,18 +159,21 @@ class MlirOptMainConfig {
/// general.
bool allowUnregisteredDialectsFlag = false;

/// Configuration for the debugging hooks.
tracing::DebugConfig debugConfig;

/// Print the pipeline that will be run.
bool dumpPassPipelineFlag = false;

/// Emit bytecode instead of textual assembly when generating output.
bool emitBytecodeFlag = false;

/// Enable the Debugger action hook: Debugger can intercept MLIR Actions.
bool enableDebuggerActionHookFlag = false;

/// IRDL file to register before processing the input.
std::string irdlFileFlag = "";

/// Log action execution to the given file (or "-" for stdout)
std::string logActionsToFlag;

/// Location Breakpoints to filter the action logging.
std::vector<tracing::BreakpointManager *> logActionLocationFilter;

Expand Down
120 changes: 120 additions & 0 deletions mlir/lib/Debug/CLOptionsSetup.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
//===- CLOptionsSetup.cpp - Helpers to setup debug CL options ---*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Debug/CLOptionsSetup.h"

#include "mlir/Debug/Counter.h"
#include "mlir/Debug/DebuggerDebugExecutionContextHook.h"
#include "mlir/Debug/ExecutionContext.h"
#include "mlir/Debug/Observers/ActionLogging.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Support/FileUtilities.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/ToolOutputFile.h"

using namespace mlir;
using namespace mlir::tracing;
using namespace llvm;

namespace {
struct DebugConfigCLOptions : public DebugConfig {
DebugConfigCLOptions() {
static cl::opt<std::string, /*ExternalStorage=*/true> logActionsTo{
"log-actions-to",
cl::desc("Log action execution to a file, or stderr if "
" '-' is passed"),
cl::location(logActionsToFlag)};

static cl::list<std::string> logActionLocationFilter(
"log-mlir-actions-filter",
cl::desc(
"Comma separated list of locations to filter actions from logging"),
cl::CommaSeparated,
cl::cb<void, std::string>([&](const std::string &location) {
static bool register_once = [&] {
addLogActionLocFilter(&locBreakpointManager);
return true;
}();
(void)register_once;
static std::vector<std::string> locations;
locations.push_back(location);
StringRef locStr = locations.back();

// Parse the individual location filters and set the breakpoints.
auto diag = [](Twine msg) { llvm::errs() << msg << "\n"; };
auto locBreakpoint =
tracing::FileLineColLocBreakpoint::parseFromString(locStr, diag);
if (failed(locBreakpoint)) {
llvm::errs() << "Invalid location filter: " << locStr << "\n";
exit(1);
}
auto [file, line, col] = *locBreakpoint;
locBreakpointManager.addBreakpoint(file, line, col);
}));
}
tracing::FileLineColLocBreakpointManager locBreakpointManager;
};

} // namespace

static ManagedStatic<DebugConfigCLOptions> clOptionsConfig;
void DebugConfig::registerCLOptions() { *clOptionsConfig; }

DebugConfig DebugConfig::createFromCLOptions() { return *clOptionsConfig; }

class InstallDebugHandler::Impl {
public:
Impl(MLIRContext &context, const DebugConfig &config) {
if (config.getLogActionsTo().empty() &&
!config.isDebuggerActionHookEnabled()) {
if (tracing::DebugCounter::isActivated())
context.registerActionHandler(tracing::DebugCounter());
return;
}
errs() << "ExecutionContext registered on the context";
if (tracing::DebugCounter::isActivated())
emitError(UnknownLoc::get(&context),
"Debug counters are incompatible with --log-actions-to and "
"--mlir-enable-debugger-hook options and are disabled");
if (!config.getLogActionsTo().empty()) {
std::string errorMessage;
logActionsFile = openOutputFile(config.getLogActionsTo(), &errorMessage);
if (!logActionsFile) {
emitError(UnknownLoc::get(&context),
"Opening file for --log-actions-to failed: ")
<< errorMessage << "\n";
return;
}
logActionsFile->keep();
raw_fd_ostream &logActionsStream = logActionsFile->os();
actionLogger = std::make_unique<tracing::ActionLogger>(logActionsStream);
for (const auto *locationBreakpoint : config.getLogActionsLocFilters())
actionLogger->addBreakpointManager(locationBreakpoint);
executionContext.registerObserver(actionLogger.get());
}
if (config.isDebuggerActionHookEnabled()) {
errs() << " (with Debugger hook)";
setupDebuggerDebugExecutionContextHook(executionContext);
}
errs() << "\n";
context.registerActionHandler(executionContext);
}

private:
std::unique_ptr<ToolOutputFile> logActionsFile;
tracing::ExecutionContext executionContext;
std::unique_ptr<tracing::ActionLogger> actionLogger;
std::vector<std::unique_ptr<tracing::FileLineColLocBreakpoint>>
locationBreakpoints;
};

InstallDebugHandler::InstallDebugHandler(MLIRContext &context,
const DebugConfig &config)
: impl(std::make_unique<Impl>(context, config)) {}

InstallDebugHandler::~InstallDebugHandler() = default;
2 changes: 2 additions & 0 deletions mlir/lib/Debug/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
add_subdirectory(Observers)

add_mlir_library(MLIRDebug
CLOptionsSetup.cpp
DebugCounter.cpp
ExecutionContext.cpp
BreakpointManagers/FileLineColLocBreakpointManager.cpp
DebuggerExecutionContextHook.cpp

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Debug
Expand Down
369 changes: 369 additions & 0 deletions mlir/lib/Debug/DebuggerExecutionContextHook.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,369 @@
//===- DebuggerExecutionContextHook.cpp - Debugger Support ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Debug/DebuggerExecutionContextHook.h"

#include "mlir/Debug/BreakpointManagers/FileLineColLocBreakpointManager.h"
#include "mlir/Debug/BreakpointManagers/TagBreakpointManager.h"

using namespace mlir;
using namespace mlir::tracing;

namespace {
/// This structure tracks the state of the interactive debugger.
struct DebuggerState {
/// This variable keeps track of the current control option. This is set by
/// the debugger when control is handed over to it.
ExecutionContext::Control debuggerControl = ExecutionContext::Apply;

/// The breakpoint manager that allows the debugger to set breakpoints on
/// action tags.
TagBreakpointManager tagBreakpointManager;

/// The breakpoint manager that allows the debugger to set breakpoints on
/// FileLineColLoc locations.
FileLineColLocBreakpointManager fileLineColLocBreakpointManager;

/// Map of breakpoint IDs to breakpoint objects.
DenseMap<unsigned, Breakpoint *> breakpointIdsMap;

/// The current stack of actiive actions.
const tracing::ActionActiveStack *actionActiveStack;

/// This is a "cursor" in the IR, it is used for the debugger to navigate the
/// IR associated to the actions.
IRUnit cursor;
};
} // namespace

static DebuggerState &getGlobalDebuggerState() {
static LLVM_THREAD_LOCAL DebuggerState debuggerState;
return debuggerState;
}

extern "C" {
void mlirDebuggerSetControl(int controlOption) {
getGlobalDebuggerState().debuggerControl =
static_cast<ExecutionContext::Control>(controlOption);
}

void mlirDebuggerPrintContext() {
DebuggerState &state = getGlobalDebuggerState();
if (!state.actionActiveStack) {
llvm::outs() << "No active action.\n";
return;
}
const ArrayRef<IRUnit> &units =
state.actionActiveStack->getAction().getContextIRUnits();
llvm::outs() << units.size() << " available IRUnits:\n";
for (const IRUnit &unit : units) {
llvm::outs() << " - ";
unit.print(
llvm::outs(),
OpPrintingFlags().useLocalScope().skipRegions().enableDebugInfo());
llvm::outs() << "\n";
}
}

void mlirDebuggerPrintActionBacktrace(bool withContext) {
DebuggerState &state = getGlobalDebuggerState();
if (!state.actionActiveStack) {
llvm::outs() << "No active action.\n";
return;
}
state.actionActiveStack->print(llvm::outs(), withContext);
}

//===----------------------------------------------------------------------===//
// Cursor Management
//===----------------------------------------------------------------------===//

void mlirDebuggerCursorPrint(bool withRegion) {
auto &state = getGlobalDebuggerState();
if (!state.cursor) {
llvm::outs() << "No active MLIR cursor, select from the context first\n";
return;
}
state.cursor.print(llvm::outs(), OpPrintingFlags()
.skipRegions(!withRegion)
.useLocalScope()
.enableDebugInfo());
llvm::outs() << "\n";
}

void mlirDebuggerCursorSelectIRUnitFromContext(int index) {
auto &state = getGlobalDebuggerState();
if (!state.actionActiveStack) {
llvm::outs() << "No active MLIR Action stack\n";
return;
}
ArrayRef<IRUnit> units =
state.actionActiveStack->getAction().getContextIRUnits();
if (index < 0 || index >= static_cast<int>(units.size())) {
llvm::outs() << "Index invalid, bounds: [0, " << units.size()
<< "] but got " << index << "\n";
return;
}
state.cursor = units[index];
state.cursor.print(llvm::outs());
llvm::outs() << "\n";
}

void mlirDebuggerCursorSelectParentIRUnit() {
auto &state = getGlobalDebuggerState();
if (!state.cursor) {
llvm::outs() << "No active MLIR cursor, select from the context first\n";
return;
}
IRUnit *unit = &state.cursor;
if (auto *op = unit->dyn_cast<Operation *>()) {
state.cursor = op->getBlock();
} else if (auto *region = unit->dyn_cast<Region *>()) {
state.cursor = region->getParentOp();
} else if (auto *block = unit->dyn_cast<Block *>()) {
state.cursor = block->getParent();
} else {
llvm::outs() << "Current cursor is not a valid IRUnit";
return;
}
state.cursor.print(llvm::outs());
llvm::outs() << "\n";
}

void mlirDebuggerCursorSelectChildIRUnit(int index) {
auto &state = getGlobalDebuggerState();
if (!state.cursor) {
llvm::outs() << "No active MLIR cursor, select from the context first\n";
return;
}
IRUnit *unit = &state.cursor;
if (auto *op = unit->dyn_cast<Operation *>()) {
if (index < 0 || index >= static_cast<int>(op->getNumRegions())) {
llvm::outs() << "Index invalid, op has " << op->getNumRegions()
<< " but got " << index << "\n";
return;
}
state.cursor = &op->getRegion(index);
} else if (auto *region = unit->dyn_cast<Region *>()) {
auto block = region->begin();
int count = 0;
while (block != region->end() && count != index) {
++block;
++count;
}

if (block == region->end()) {
llvm::outs() << "Index invalid, region has " << count << " block but got "
<< index << "\n";
return;
}
state.cursor = &*block;
} else if (auto *block = unit->dyn_cast<Block *>()) {
auto op = block->begin();
int count = 0;
while (op != block->end() && count != index) {
++op;
++count;
}

if (op == block->end()) {
llvm::outs() << "Index invalid, block has " << count
<< "operations but got " << index << "\n";
return;
}
state.cursor = &*op;
} else {
llvm::outs() << "Current cursor is not a valid IRUnit";
return;
}
state.cursor.print(llvm::outs());
llvm::outs() << "\n";
}

void mlirDebuggerCursorSelectPreviousIRUnit() {
auto &state = getGlobalDebuggerState();
if (!state.cursor) {
llvm::outs() << "No active MLIR cursor, select from the context first\n";
return;
}
IRUnit *unit = &state.cursor;
if (auto *op = unit->dyn_cast<Operation *>()) {
Operation *previous = op->getPrevNode();
if (!previous) {
llvm::outs() << "No previous operation in the current block\n";
return;
}
state.cursor = previous;
} else if (auto *region = unit->dyn_cast<Region *>()) {
llvm::outs() << "Has region\n";
Operation *parent = region->getParentOp();
if (!parent) {
llvm::outs() << "No parent operation for the current region\n";
return;
}
if (region->getRegionNumber() == 0) {
llvm::outs() << "No previous region in the current operation\n";
return;
}
state.cursor =
&region->getParentOp()->getRegion(region->getRegionNumber() - 1);
} else if (auto *block = unit->dyn_cast<Block *>()) {
Block *previous = block->getPrevNode();
if (!previous) {
llvm::outs() << "No previous block in the current region\n";
return;
}
state.cursor = previous;
} else {
llvm::outs() << "Current cursor is not a valid IRUnit";
return;
}
state.cursor.print(llvm::outs());
llvm::outs() << "\n";
}

void mlirDebuggerCursorSelectNextIRUnit() {
auto &state = getGlobalDebuggerState();
if (!state.cursor) {
llvm::outs() << "No active MLIR cursor, select from the context first\n";
return;
}
IRUnit *unit = &state.cursor;
if (auto *op = unit->dyn_cast<Operation *>()) {
Operation *next = op->getNextNode();
if (!next) {
llvm::outs() << "No next operation in the current block\n";
return;
}
state.cursor = next;
} else if (auto *region = unit->dyn_cast<Region *>()) {
Operation *parent = region->getParentOp();
if (!parent) {
llvm::outs() << "No parent operation for the current region\n";
return;
}
if (region->getRegionNumber() == parent->getNumRegions() - 1) {
llvm::outs() << "No next region in the current operation\n";
return;
}
state.cursor =
&region->getParentOp()->getRegion(region->getRegionNumber() + 1);
} else if (auto *block = unit->dyn_cast<Block *>()) {
Block *next = block->getNextNode();
if (!next) {
llvm::outs() << "No next block in the current region\n";
return;
}
state.cursor = next;
} else {
llvm::outs() << "Current cursor is not a valid IRUnit";
return;
}
state.cursor.print(llvm::outs());
llvm::outs() << "\n";
}

//===----------------------------------------------------------------------===//
// Breakpoint Management
//===----------------------------------------------------------------------===//

void mlirDebuggerEnableBreakpoint(BreakpointHandle breakpoint) {
reinterpret_cast<Breakpoint *>(breakpoint)->enable();
}

void mlirDebuggerDisableBreakpoint(BreakpointHandle breakpoint) {
reinterpret_cast<Breakpoint *>(breakpoint)->disable();
}

BreakpointHandle mlirDebuggerAddTagBreakpoint(const char *tag) {
DebuggerState &state = getGlobalDebuggerState();
Breakpoint *breakpoint =
state.tagBreakpointManager.addBreakpoint(StringRef(tag, strlen(tag)));
int breakpointId = state.breakpointIdsMap.size() + 1;
state.breakpointIdsMap[breakpointId] = breakpoint;
return reinterpret_cast<BreakpointHandle>(breakpoint);
}

void mlirDebuggerAddRewritePatternBreakpoint(const char *patternNameInfo) {}

void mlirDebuggerAddFileLineColLocBreakpoint(const char *file, int line,
int col) {
getGlobalDebuggerState().fileLineColLocBreakpointManager.addBreakpoint(
StringRef(file, strlen(file)), line, col);
}

} // extern "C"

LLVM_ATTRIBUTE_NOINLINE void mlirDebuggerBreakpointHook() {
static LLVM_THREAD_LOCAL void *volatile sink;
sink = (void *)&sink;
}

static void preventLinkerDeadCodeElim() {
static void *volatile sink;
static bool initialized = [&]() {
sink = (void *)mlirDebuggerSetControl;
sink = (void *)mlirDebuggerEnableBreakpoint;
sink = (void *)mlirDebuggerDisableBreakpoint;
sink = (void *)mlirDebuggerPrintContext;
sink = (void *)mlirDebuggerPrintActionBacktrace;
sink = (void *)mlirDebuggerCursorPrint;
sink = (void *)mlirDebuggerCursorSelectIRUnitFromContext;
sink = (void *)mlirDebuggerCursorSelectParentIRUnit;
sink = (void *)mlirDebuggerCursorSelectChildIRUnit;
sink = (void *)mlirDebuggerCursorSelectPreviousIRUnit;
sink = (void *)mlirDebuggerCursorSelectNextIRUnit;
sink = (void *)mlirDebuggerAddTagBreakpoint;
sink = (void *)mlirDebuggerAddRewritePatternBreakpoint;
sink = (void *)mlirDebuggerAddFileLineColLocBreakpoint;
sink = (void *)&sink;
return true;
}();
(void)initialized;
}

static tracing::ExecutionContext::Control
debuggerCallBackFunction(const tracing::ActionActiveStack *actionStack) {
preventLinkerDeadCodeElim();
// Invoke the breakpoint hook, the debugger is supposed to trap this.
// The debugger controls the execution from there by invoking
// `mlirDebuggerSetControl()`.
auto &state = getGlobalDebuggerState();
state.actionActiveStack = actionStack;
getGlobalDebuggerState().debuggerControl = ExecutionContext::Apply;
actionStack->getAction().print(llvm::outs());
llvm::outs() << "\n";
mlirDebuggerBreakpointHook();
return getGlobalDebuggerState().debuggerControl;
}

namespace {
/// Manage the stack of actions that are currently active.
class DebuggerObserver : public ExecutionContext::Observer {
void beforeExecute(const ActionActiveStack *action, Breakpoint *breakpoint,
bool willExecute) override {
auto &state = getGlobalDebuggerState();
state.actionActiveStack = action;
}
void afterExecute(const ActionActiveStack *action) override {
auto &state = getGlobalDebuggerState();
state.actionActiveStack = action->getParent();
state.cursor = nullptr;
}
};
} // namespace

void mlir::setupDebuggerExecutionContextHook(
tracing::ExecutionContext &executionContext) {
executionContext.setCallback(debuggerCallBackFunction);
DebuggerState &state = getGlobalDebuggerState();
static DebuggerObserver observer;
executionContext.registerObserver(&observer);
executionContext.addBreakpointManager(&state.fileLineColLocBreakpointManager);
executionContext.addBreakpointManager(&state.tagBreakpointManager);
}
36 changes: 31 additions & 5 deletions mlir/lib/Debug/ExecutionContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,22 +9,47 @@
#include "mlir/Debug/ExecutionContext.h"

#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/FormatVariadic.h"

#include <cstddef>

using namespace mlir;
using namespace mlir::tracing;

//===----------------------------------------------------------------------===//
// ExecutionContext
// ActionActiveStack
//===----------------------------------------------------------------------===//

static const thread_local ActionActiveStack *actionStack = nullptr;

void ExecutionContext::setCallback(CallbackTy callback) {
onBreakpointControlExecutionCallback = callback;
void ActionActiveStack::print(raw_ostream &os, bool withContext) const {
os << "ActionActiveStack depth " << getDepth() << "\n";
const ActionActiveStack *current = this;
int count = 0;
while (current) {
llvm::errs() << llvm::formatv("#{0,3}: ", count++);
current->action.print(llvm::errs());
llvm::errs() << "\n";
ArrayRef<IRUnit> context = current->action.getContextIRUnits();
if (withContext && !context.empty()) {
llvm::errs() << "Context:\n";
llvm::interleave(
current->action.getContextIRUnits(),
[&](const IRUnit &unit) {
llvm::errs() << " - ";
unit.print(llvm::errs());
},
[&]() { llvm::errs() << "\n"; });
llvm::errs() << "\n";
}
current = current->parent;
}
}

//===----------------------------------------------------------------------===//
// ExecutionContext
//===----------------------------------------------------------------------===//

static const LLVM_THREAD_LOCAL ActionActiveStack *actionStack = nullptr;

void ExecutionContext::registerObserver(Observer *observer) {
observers.push_back(observer);
}
Expand Down Expand Up @@ -72,6 +97,7 @@ void ExecutionContext::operator()(llvm::function_ref<void()> transform,
if (breakpoint)
break;
}
info.setBreakpoint(breakpoint);

bool shouldExecuteAction = true;
// If we have a breakpoint, or if `depthToBreak` was previously set and the
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/IR/AsmPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -225,8 +225,8 @@ OpPrintingFlags &OpPrintingFlags::printGenericOpForm() {
}

/// Always skip Regions.
OpPrintingFlags &OpPrintingFlags::skipRegions() {
skipRegionsFlag = true;
OpPrintingFlags &OpPrintingFlags::skipRegions(bool skip) {
skipRegionsFlag = skip;
return *this;
}

Expand Down
87 changes: 11 additions & 76 deletions mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@

#include "mlir/Tools/mlir-opt/MlirOptMain.h"
#include "mlir/Bytecode/BytecodeWriter.h"
#include "mlir/Debug/CLOptionsSetup.h"
#include "mlir/Debug/Counter.h"
#include "mlir/Debug/DebuggerExecutionContextHook.h"
#include "mlir/Debug/ExecutionContext.h"
#include "mlir/Debug/Observers/ActionLogging.h"
#include "mlir/Dialect/IRDL/IR/IRDL.h"
Expand Down Expand Up @@ -77,45 +79,17 @@ struct MlirOptMainConfigCLOptions : public MlirOptMainConfig {
cl::desc("IRDL file to register before processing the input"),
cl::location(irdlFileFlag), cl::init(""), cl::value_desc("filename"));

static cl::opt<bool, /*ExternalStorage=*/true> enableDebuggerHook(
"mlir-enable-debugger-hook",
cl::desc("Enable Debugger hook for debugging MLIR Actions"),
cl::location(enableDebuggerActionHookFlag), cl::init(false));

static cl::opt<bool, /*ExternalStorage=*/true> explicitModule(
"no-implicit-module",
cl::desc("Disable implicit addition of a top-level module op during "
"parsing"),
cl::location(useExplicitModuleFlag), cl::init(false));

static cl::opt<std::string, /*ExternalStorage=*/true> logActionsTo{
"log-actions-to",
cl::desc("Log action execution to a file, or stderr if "
" '-' is passed"),
cl::location(logActionsToFlag)};

static cl::list<std::string> logActionLocationFilter(
"log-mlir-actions-filter",
cl::desc(
"Comma separated list of locations to filter actions from logging"),
cl::CommaSeparated,
cl::cb<void, std::string>([&](const std::string &location) {
static bool register_once = [&] {
addLogActionLocFilter(&locBreakpointManager);
return true;
}();
(void)register_once;
static std::vector<std::string> locations;
locations.push_back(location);
StringRef locStr = locations.back();

// Parse the individual location filters and set the breakpoints.
auto diag = [](Twine msg) { llvm::errs() << msg << "\n"; };
auto locBreakpoint =
tracing::FileLineColLocBreakpoint::parseFromString(locStr, diag);
if (failed(locBreakpoint)) {
llvm::errs() << "Invalid location filter: " << locStr << "\n";
exit(1);
}
auto [file, line, col] = *locBreakpoint;
locBreakpointManager.addBreakpoint(file, line, col);
}));

static cl::opt<bool, /*ExternalStorage=*/true> showDialects(
"show-dialects",
cl::desc("Print the list of registered dialects and exit"),
Expand Down Expand Up @@ -165,19 +139,18 @@ struct MlirOptMainConfigCLOptions : public MlirOptMainConfig {
/// Pointer to static dialectPlugins variable in constructor, needed by
/// setDialectPluginsCallback(DialectRegistry&).
cl::list<std::string> *dialectPlugins = nullptr;

/// The breakpoint manager for the log action location filter.
tracing::FileLineColLocBreakpointManager locBreakpointManager;
};
} // namespace

ManagedStatic<MlirOptMainConfigCLOptions> clOptionsConfig;

void MlirOptMainConfig::registerCLOptions(DialectRegistry &registry) {
clOptionsConfig->setDialectPluginsCallback(registry);
tracing::DebugConfig::registerCLOptions();
}

MlirOptMainConfig MlirOptMainConfig::createFromCLOptions() {
clOptionsConfig->setDebugConfig(tracing::DebugConfig::createFromCLOptions());
return *clOptionsConfig;
}

Expand Down Expand Up @@ -213,45 +186,6 @@ void MlirOptMainConfigCLOptions::setDialectPluginsCallback(
});
}

/// Set the ExecutionContext on the context and handle the observers.
class InstallDebugHandler {
public:
InstallDebugHandler(MLIRContext &context, const MlirOptMainConfig &config) {
if (config.getLogActionsTo().empty()) {
if (tracing::DebugCounter::isActivated())
context.registerActionHandler(tracing::DebugCounter());
return;
}
if (tracing::DebugCounter::isActivated())
emitError(UnknownLoc::get(&context),
"Debug counters are incompatible with --log-actions-to option "
"and are disabled");
std::string errorMessage;
logActionsFile = openOutputFile(config.getLogActionsTo(), &errorMessage);
if (!logActionsFile) {
emitError(UnknownLoc::get(&context),
"Opening file for --log-actions-to failed: ")
<< errorMessage << "\n";
return;
}
logActionsFile->keep();
raw_fd_ostream &logActionsStream = logActionsFile->os();
actionLogger = std::make_unique<tracing::ActionLogger>(logActionsStream);
for (const auto *locationBreakpoint : config.getLogActionsLocFilters())
actionLogger->addBreakpointManager(locationBreakpoint);

executionContext.registerObserver(actionLogger.get());
context.registerActionHandler(executionContext);
}

private:
std::unique_ptr<llvm::ToolOutputFile> logActionsFile;
std::unique_ptr<tracing::ActionLogger> actionLogger;
std::vector<std::unique_ptr<tracing::FileLineColLocBreakpoint>>
locationBreakpoints;
tracing::ExecutionContext executionContext;
};

/// Perform the actions on the input file indicated by the command line flags
/// within the specified context.
///
Expand Down Expand Up @@ -372,7 +306,8 @@ static LogicalResult processBuffer(raw_ostream &os,
if (config.shouldVerifyDiagnostics())
context.printOpOnDiagnostic(false);

InstallDebugHandler installDebugHandler(context, config);
tracing::InstallDebugHandler installDebugHandler(context,
config.getDebugConfig());

// If we are in verify diagnostics mode then we have a lot of work to do,
// otherwise just perform the actions without worrying about it.
Expand Down
15 changes: 15 additions & 0 deletions mlir/test/mlir-opt/debugcounter.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// This test exercise the example in docs/ActionTracing.md ; changes here should
// probably be reflected there.

// RUN: mlir-opt %s -mlir-debug-counter=unique-tag-for-my-action-skip=-1 -mlir-print-debug-counter --pass-pipeline="builtin.module(func.func(canonicalize))" --mlir-disable-threading 2>&1 | FileCheck %s --check-prefix=CHECK-UKNOWN-TAG
// RUN: mlir-opt %s -mlir-debug-counter=pass-execution-skip=1 -mlir-print-debug-counter --pass-pipeline="builtin.module(func.func(canonicalize))" --mlir-disable-threading 2>&1 | FileCheck %s --check-prefix=CHECK-PASS

func.func @foo() {
return
}

// CHECK-UKNOWN-TAG: DebugCounter counters:
// CHECK-UKNOWN-TAG: unique-tag-for-my-action : {0,-1,-1}

// CHECK-PASS: DebugCounter counters:
// CHECK-PASS: pass-execution : {1,1,-1}
568 changes: 568 additions & 0 deletions mlir/utils/lldb-scripts/action_debugging.py

Large diffs are not rendered by default.