Skip to content

Commit 5b29d26

Browse files
committed
Refactor PyPrintAccumulatorm, PyFileAccumulator, and PySinglePartStringAccumulator from IRModules.cpp to PybindUtils.h
These are reusable utilities across bindings. Differential Revision: https://reviews.llvm.org/D90737
1 parent bf5c862 commit 5b29d26

File tree

2 files changed

+90
-88
lines changed

2 files changed

+90
-88
lines changed

mlir/lib/Bindings/Python/IRModules.cpp

Lines changed: 0 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -123,92 +123,6 @@ position in the argument list. If the value is an operation result, this is
123123
equivalent to printing the operation that produced it.
124124
)";
125125

126-
//------------------------------------------------------------------------------
127-
// Conversion utilities.
128-
//------------------------------------------------------------------------------
129-
130-
namespace {
131-
132-
/// Accumulates into a python string from a method that accepts an
133-
/// MlirStringCallback.
134-
struct PyPrintAccumulator {
135-
py::list parts;
136-
137-
void *getUserData() { return this; }
138-
139-
MlirStringCallback getCallback() {
140-
return [](const char *part, intptr_t size, void *userData) {
141-
PyPrintAccumulator *printAccum =
142-
static_cast<PyPrintAccumulator *>(userData);
143-
py::str pyPart(part, size); // Decodes as UTF-8 by default.
144-
printAccum->parts.append(std::move(pyPart));
145-
};
146-
}
147-
148-
py::str join() {
149-
py::str delim("", 0);
150-
return delim.attr("join")(parts);
151-
}
152-
};
153-
154-
/// Accumulates int a python file-like object, either writing text (default)
155-
/// or binary.
156-
class PyFileAccumulator {
157-
public:
158-
PyFileAccumulator(py::object fileObject, bool binary)
159-
: pyWriteFunction(fileObject.attr("write")), binary(binary) {}
160-
161-
void *getUserData() { return this; }
162-
163-
MlirStringCallback getCallback() {
164-
return [](const char *part, intptr_t size, void *userData) {
165-
py::gil_scoped_acquire();
166-
PyFileAccumulator *accum = static_cast<PyFileAccumulator *>(userData);
167-
if (accum->binary) {
168-
// Note: Still has to copy and not avoidable with this API.
169-
py::bytes pyBytes(part, size);
170-
accum->pyWriteFunction(pyBytes);
171-
} else {
172-
py::str pyStr(part, size); // Decodes as UTF-8 by default.
173-
accum->pyWriteFunction(pyStr);
174-
}
175-
};
176-
}
177-
178-
private:
179-
py::object pyWriteFunction;
180-
bool binary;
181-
};
182-
183-
/// Accumulates into a python string from a method that is expected to make
184-
/// one (no more, no less) call to the callback (asserts internally on
185-
/// violation).
186-
struct PySinglePartStringAccumulator {
187-
void *getUserData() { return this; }
188-
189-
MlirStringCallback getCallback() {
190-
return [](const char *part, intptr_t size, void *userData) {
191-
PySinglePartStringAccumulator *accum =
192-
static_cast<PySinglePartStringAccumulator *>(userData);
193-
assert(!accum->invoked &&
194-
"PySinglePartStringAccumulator called back multiple times");
195-
accum->invoked = true;
196-
accum->value = py::str(part, size);
197-
};
198-
}
199-
200-
py::str takeValue() {
201-
assert(invoked && "PySinglePartStringAccumulator not called back");
202-
return std::move(value);
203-
}
204-
205-
private:
206-
py::str value;
207-
bool invoked = false;
208-
};
209-
210-
} // namespace
211-
212126
//------------------------------------------------------------------------------
213127
// Utilities.
214128
//------------------------------------------------------------------------------

mlir/lib/Bindings/Python/PybindUtils.h

Lines changed: 90 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,13 @@
99
#ifndef MLIR_BINDINGS_PYTHON_PYBINDUTILS_H
1010
#define MLIR_BINDINGS_PYTHON_PYBINDUTILS_H
1111

12+
#include "mlir-c/Support.h"
13+
#include "llvm/ADT/Optional.h"
14+
#include "llvm/ADT/Twine.h"
15+
1216
#include <pybind11/pybind11.h>
1317
#include <pybind11/stl.h>
1418

15-
#include "llvm/ADT/Optional.h"
16-
#include "llvm/ADT/Twine.h"
1719

1820
namespace mlir {
1921
namespace python {
@@ -99,4 +101,90 @@ struct type_caster<llvm::Optional<T>> : optional_caster<llvm::Optional<T>> {};
99101
} // namespace detail
100102
} // namespace pybind11
101103

104+
//------------------------------------------------------------------------------
105+
// Conversion utilities.
106+
//------------------------------------------------------------------------------
107+
108+
namespace mlir {
109+
110+
/// Accumulates into a python string from a method that accepts an
111+
/// MlirStringCallback.
112+
struct PyPrintAccumulator {
113+
pybind11::list parts;
114+
115+
void *getUserData() { return this; }
116+
117+
MlirStringCallback getCallback() {
118+
return [](const char *part, intptr_t size, void *userData) {
119+
PyPrintAccumulator *printAccum =
120+
static_cast<PyPrintAccumulator *>(userData);
121+
pybind11::str pyPart(part, size); // Decodes as UTF-8 by default.
122+
printAccum->parts.append(std::move(pyPart));
123+
};
124+
}
125+
126+
pybind11::str join() {
127+
pybind11::str delim("", 0);
128+
return delim.attr("join")(parts);
129+
}
130+
};
131+
132+
/// Accumulates int a python file-like object, either writing text (default)
133+
/// or binary.
134+
class PyFileAccumulator {
135+
public:
136+
PyFileAccumulator(pybind11::object fileObject, bool binary)
137+
: pyWriteFunction(fileObject.attr("write")), binary(binary) {}
138+
139+
void *getUserData() { return this; }
140+
141+
MlirStringCallback getCallback() {
142+
return [](const char *part, intptr_t size, void *userData) {
143+
pybind11::gil_scoped_acquire();
144+
PyFileAccumulator *accum = static_cast<PyFileAccumulator *>(userData);
145+
if (accum->binary) {
146+
// Note: Still has to copy and not avoidable with this API.
147+
pybind11::bytes pyBytes(part, size);
148+
accum->pyWriteFunction(pyBytes);
149+
} else {
150+
pybind11::str pyStr(part, size); // Decodes as UTF-8 by default.
151+
accum->pyWriteFunction(pyStr);
152+
}
153+
};
154+
}
155+
156+
private:
157+
pybind11::object pyWriteFunction;
158+
bool binary;
159+
};
160+
161+
/// Accumulates into a python string from a method that is expected to make
162+
/// one (no more, no less) call to the callback (asserts internally on
163+
/// violation).
164+
struct PySinglePartStringAccumulator {
165+
void *getUserData() { return this; }
166+
167+
MlirStringCallback getCallback() {
168+
return [](const char *part, intptr_t size, void *userData) {
169+
PySinglePartStringAccumulator *accum =
170+
static_cast<PySinglePartStringAccumulator *>(userData);
171+
assert(!accum->invoked &&
172+
"PySinglePartStringAccumulator called back multiple times");
173+
accum->invoked = true;
174+
accum->value = pybind11::str(part, size);
175+
};
176+
}
177+
178+
pybind11::str takeValue() {
179+
assert(invoked && "PySinglePartStringAccumulator not called back");
180+
return std::move(value);
181+
}
182+
183+
private:
184+
pybind11::str value;
185+
bool invoked = false;
186+
};
187+
188+
} // namespace mlir
189+
102190
#endif // MLIR_BINDINGS_PYTHON_PYBINDUTILS_H

0 commit comments

Comments
 (0)