Rust APIs to handle PyTorch graph modules and graphs
This API helps when writing a Python module in Rust with PyO3, if the module needs to handle PyTorch graph modules or graphs.
#[repr(transparent)]
pub struct GraphModule(_);
A wrapper for PyTorch's `GraphModule` class.
The constructors of this type return a shared reference `&GraphModule` instead of an owned value. The returned value is a GIL-bound owning reference into Python's heap.
pub fn new<'py>( py: Python<'py>, nn: &GraphModule, graph: &Graph ) -> PyResult<&'py Self>
Create a new instance of the PyTorch `GraphModule` class with PyTorch's native constructor, without giving `class_name` (so it keeps its default value, `'GraphModule'`). If the new instance is created successfully, returns `Ok` with a shared reference to it. Otherwise, returns `Err` with a `PyErr` explaining the cause of the failure.
pub fn new_with_empty_gm<'py>( py: Python<'py>, graph: &Graph ) -> PyResult<&'py Self>
Create a new instance of the PyTorch `GraphModule` class with PyTorch's native constructor, without giving `class_name` (so it keeps its default value, `'GraphModule'`), and with `root` set to a new `torch.nn.Module` created by `torch.nn.Module()`. If the new instance is created successfully, returns `Ok` with a shared reference to it. Otherwise, returns `Err` with a `PyErr` explaining the cause of the failure.
pub fn extract_parameters(&self) -> PyResult<HashMap<String, &[u8]>>
Collect all parameters of this `GraphModule` into a `HashMap` that maps each parameter name to a byte slice over the underlying storage of the parameter value. If this succeeds, returns `Ok` with the `HashMap`. Otherwise, returns `Err` with a `PyErr` explaining the cause of the failure.
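The slices in the returned `HashMap` are raw bytes, so a caller has to reinterpret them according to the tensor's dtype. The following is a minimal sketch that assumes the parameter is a contiguous `float32` tensor stored in the machine's native byte order. The helper `bytes_to_f32` is hypothetical and not part of this crate.

```rust
/// Reinterpret a raw storage slice as `f32` values.
/// Assumes a contiguous float32 tensor in native byte order.
fn bytes_to_f32(bytes: &[u8]) -> Option<Vec<f32>> {
    // A float32 storage must consist of whole 4-byte elements.
    if bytes.len() % 4 != 0 {
        return None;
    }
    let values = bytes
        .chunks_exact(4)
        .map(|chunk| f32::from_ne_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
        .collect();
    Some(values)
}

fn main() {
    // Stand-in for a slice taken from the map returned by `extract_parameters`.
    let storage: Vec<u8> = [1.0f32, -2.5, 0.125]
        .iter()
        .flat_map(|v| v.to_ne_bytes())
        .collect();
    let weights = bytes_to_f32(&storage).expect("length is a multiple of 4");
    println!("{:?}", weights);
}
```

For other dtypes, the same approach applies with a different element width and conversion function (for example `i64::from_ne_bytes` over 8-byte chunks).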
pub fn extract_buffers(&self) -> PyResult<HashMap<String, &[u8]>>
Collect all buffers of this `GraphModule` into a `HashMap` that maps each buffer name to a byte slice over the underlying storage of the buffer value. If this succeeds, returns `Ok` with the `HashMap`. Otherwise, returns `Err` with a `PyErr` explaining the cause of the failure.
pub fn graph(&self) -> PyResult<&Graph>
Retrieve the `graph` attribute of this `GraphModule`. If the retrieval succeeds, returns `Ok` with a shared reference to the `graph` attribute (`&Graph`). Otherwise, returns `Err` with a `PyErr` explaining the cause of the failure.
pub fn get_parameter( &self, name: &str ) -> PyResult<Option<&[u8]>>
Get the underlying storage of the parameter named `name` in this `GraphModule`. If there is no such parameter, returns `Ok(None)`. If it exists, returns `Ok(Some(_))` with a byte slice over the underlying storage of the parameter value. If the process fails, returns `Err` with a `PyErr` explaining the cause of the failure.
pub fn count_parameters(&self) -> PyResult<usize>
Get the number of parameters of this `GraphModule`. If a Python error occurs during this procedure, returns `Err` with a `PyErr` explaining the error. Otherwise, returns `Ok` with the number of parameters of this `GraphModule`.
pub fn get_buffer( &self, name: &str ) -> PyResult<Option<&[u8]>>
Get the underlying storage of the buffer named `name` in this `GraphModule`. If there is no such buffer, returns `Ok(None)`. If it exists, returns `Ok(Some(_))` with a byte slice over the underlying storage of the buffer value. If the process fails, returns `Err` with a `PyErr` explaining the cause of the failure.
pub fn count_buffers(&self) -> PyResult<usize>
Get the number of buffers of this `GraphModule`. If a Python error occurs during this procedure, returns `Err` with a `PyErr` explaining the error. Otherwise, returns `Ok` with the number of buffers of this `GraphModule`.
pub fn print_readable(&self) -> PyResult<String>
Stringify this `GraphModule`. This does the same as the `print_readable` instance method of the PyTorch `GraphModule` class, with `print_output` given as `True`. If stringifying succeeds, returns `Ok` with the resulting string. Otherwise, returns `Err` with a `PyErr` explaining the cause of the failure.
#[repr(transparent)]
pub struct Graph(_);
A wrapper for PyTorch's `Graph` class.
The constructor of this type returns a shared reference `&Graph` instead of an owned value. The returned value is a GIL-bound owning reference into Python's heap.
pub fn new(py: Python<'_>) -> PyResult<&Self>
Create a new instance of the PyTorch `Graph` class with PyTorch's native constructor. If the new instance is created successfully, returns `Ok` with a shared reference to it. Otherwise, returns `Err` with a `PyErr` explaining the cause of the failure.
pub fn nodes_iterator(&self) -> PyResult<&PyIterator>
Retrieve all `Node`s of this `Graph` as a Python iterator. If the retrieval succeeds, returns `Ok` with a shared reference to the iterator. Otherwise, returns `Err` with a `PyErr` explaining the cause of the failure.
pub fn eliminate_dead_code(&self) -> PyResult<()>
An interface for the `eliminate_dead_code` instance method of the PyTorch `Graph` class. If the method call succeeds, returns `Ok(())`. Otherwise, returns `Err` with a `PyErr` explaining the cause of the failure.
pub fn lint(&self) -> PyResult<()>
An interface for the `lint` instance method of the PyTorch `Graph` class. If the method call succeeds, returns `Ok(())`. Otherwise, returns `Err` with a `PyErr` explaining the cause of the failure.
pub fn create_node<S: AsRef<str>>( &self, op: Op, target: Target, args: impl IntoIterator<IntoIter = impl ExactSizeIterator<Item = Argument>>, kwargs: impl IntoIterator<Item = (String, Argument)>, name: S, meta: Option<HashMap<String, PyObject>>, ) -> PyResult<&Node>
An interface for the `create_node` instance method of the PyTorch `Graph` class, with `type_expr` not given (`None`). In addition, if `meta` is given, the newly created `Node` gets a `meta` attribute whose value is the given `meta`. If the method call succeeds, returns `Ok` with a shared reference to the newly created `Node`. Otherwise, returns `Err` with a `PyErr` explaining the cause of the failure.
pub fn placeholder<S: AsRef<str>>( &self, name: S ) -> PyResult<&Node>
Create a placeholder `Node` and insert it into this `Graph`. A placeholder represents a function input, and `name` is the name of the input value. This does the same as the `placeholder` instance method of the PyTorch `Graph` class, with `type_expr` set to `None` and `default_value` set to `inspect.Signature.empty`. If the `Node` is created and inserted successfully, returns `Ok` with a shared reference to the newly created `Node`. Otherwise, returns `Err` with a `PyErr` explaining the cause of the failure.
pub fn output( &self, args: Argument ) -> PyResult<&Node>
Create an output `Node` and insert it into this `Graph`. `args` is the value that this output node returns and must be an `Argument::NodeTuple`. This does the same as the `output` instance method of the PyTorch `Graph` class, with `type_expr` set to `None`, and the newly created `Node` is named `'output'`. If the `Node` is created and inserted successfully, returns `Ok` with a shared reference to the newly created `Node`. Otherwise, returns `Err` with a `PyErr` explaining the cause of the failure.
pub fn call_custom_function<S: AsRef<str>>( &self, name: S, custom_fn: CustomFn, args: impl IntoIterator<IntoIter = impl ExactSizeIterator<Item = Argument>>, kwargs: impl IntoIterator<Item = (String, Argument)>, ) -> PyResult<&Node>
Create a `call_function` `Node` and insert it into this `Graph`. A `call_function` node represents a call to a Python callable, here specified by `custom_fn`. This does the same as the `call_function` instance method of the PyTorch `Graph` class, except that the `the_function` parameter is renamed to `custom_fn`, `type_expr` is not given (`None`), and `name` gives the name of the new node. `custom_fn` must be a `CustomFn`, a Python callable that actually executes a Rust function. If the `Node` is created and inserted successfully, returns `Ok` with a shared reference to the newly created `Node`. Otherwise, returns `Err` with a `PyErr` explaining the cause of the failure.
pub fn call_python_function<S: AsRef<str>>( &self, name: S, the_function: Py<PyAny>, args: impl IntoIterator<IntoIter = impl ExactSizeIterator<Item = Argument>>, kwargs: impl IntoIterator<Item = (String, Argument)>, ) -> PyResult<&Node>
Create a `call_function` `Node` and insert it into this `Graph`. A `call_function` node represents a call to a Python callable, here specified by `the_function`. This does the same as the `call_function` instance method of the PyTorch `Graph` class, except that `type_expr` is not given (`None`) and `name` gives the name of the new node. If the `Node` is created and inserted successfully, returns `Ok` with a shared reference to the newly created `Node`. Otherwise, returns `Err` with a `PyErr` explaining the cause of the failure.
pub fn node_copy( &self, node: &Node, mapper: Option<&HashMap<String, String>>, ) -> PyResult<&Node>
Copy a `Node` from another `Graph` into this `Graph` (`self`). `node` is the node to copy into `self`, and `mapper` maps the arguments of `node` from its own graph to the graph of `self`. This does the same as the `node_copy` instance method of the PyTorch `Graph` class. If the `Node` is copied and inserted successfully, returns `Ok` with a shared reference to the newly created `Node`. Otherwise, returns `Err` with a `PyErr` explaining the cause of the failure.
pub fn flatten_node_args<S: AsRef<str>>( &self, node_name: S ) -> PyResult<Option<Vec<String>>>
Retrieve the names of the argument `Node`s of the `Node` named `node_name` in this `Graph`. If this graph has no `Node` named `node_name`, returns `Ok(None)`. If it has one, returns `Ok(Some(_))` with a `Vec` of the names of that `Node`'s argument `Node`s. If something fails while looking into this `Graph`, returns `Err` with a `PyErr` explaining the cause of the failure.
pub fn users<S: AsRef<str>>( &self, node_name: S ) -> PyResult<Option<Vec<String>>>
Retrieve the names of the user `Node`s of the `Node` named `node_name` in this `Graph`. If this graph has no `Node` named `node_name`, returns `Ok(None)`. If it has one, returns `Ok(Some(_))` with a `Vec` of the names of that `Node`'s user `Node`s. If something fails while looking into this `Graph`, returns `Err` with a `PyErr` explaining the cause of the failure.
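`flatten_node_args` and `users` are two views of the same def-use relation. Node `u` is a user of node `n` when `n` appears among `u`'s input nodes. The sketch below models that relation with plain Rust data: node names in graph order, each paired with its flattened argument-node names. It assumes the flattened argument list covers every input node. The `users_of` helper is hypothetical, only illustrates the relationship, and is not how the crate computes `users`.

```rust
/// Return the names of nodes that take `target` as an argument,
/// given each node's name and its flattened argument-node names.
fn users_of(nodes: &[(&str, Vec<&str>)], target: &str) -> Option<Vec<String>> {
    // Mirror the `Ok(None)` case: no node with that name exists.
    if !nodes.iter().any(|(name, _)| *name == target) {
        return None;
    }
    let users = nodes
        .iter()
        .filter(|(_, args)| args.contains(&target))
        .map(|(name, _)| name.to_string())
        .collect();
    Some(users)
}

fn main() {
    // x -> relu(x) -> add(relu, x) -> output(add)
    let graph = vec![
        ("x", vec![]),
        ("relu", vec!["x"]),
        ("add", vec!["relu", "x"]),
        ("output", vec!["add"]),
    ];
    println!("{:?}", users_of(&graph, "x"));
}
```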
pub fn graph_to_string( &self, py: Python<'_> ) -> PyResult<String>
Stringify this `Graph`. This does the same as the `__str__` instance method of the PyTorch `Graph` class. If stringifying succeeds, returns `Ok` with the resulting string. Otherwise, returns `Err` with a `PyErr` explaining the cause of the failure.
pub fn extract_named_nodes(&self) -> PyResult<IndexMap<String, &Node>>
Collect all named `Node`s of this `Graph` into an `IndexMap` that maps each `Node`'s name to a shared reference to the `Node` itself, for every `Node` in `self`. If this succeeds, returns `Ok` with the `IndexMap`. Otherwise, returns `Err` with a `PyErr` explaining the cause of the failure.
pub fn lookup_node<S: AsRef<str>>( &self, name: S ) -> PyResult<Option<&Node>>
Look up a `Node` by its name (`name`) in this `Graph`. If there is no `Node` named `name`, returns `Ok(None)`. If such a `Node` exists in this `Graph`, returns `Ok(Some(_))` with a shared reference to it. If the process fails, returns `Err` with a `PyErr` explaining the cause of the failure.
#[repr(transparent)]
pub struct Node(_);
A wrapper for PyTorch's `Node` class.
This type appears as a shared reference `&Node` into Python's heap instead of an owned value.
pub fn flatten_node_args(&self) -> PyResult<Vec<String>>
Retrieve the names of the argument `Node`s of this `Node`. A `Node` can have multiple arguments, and each argument can contain one or more `Node`s. The result nevertheless contains all the argument `Node`s' names in a single flat vector, which is why this method is named `flatten_node_args`. If the retrieval succeeds, returns `Ok` with a `Vec` of the argument nodes' names. Otherwise, returns `Err` with a `PyErr` explaining the cause of the failure.
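To see which argument shapes contribute names to the flat result, the sketch below uses a cut-down, hypothetical copy of the `Argument` enum documented further down, with only a few of its variants. It is an illustration under that assumption, not the crate's implementation, which may differ. For example, it may also handle `OptionalNodeTuple`.

```rust
/// Hypothetical subset of this crate's `Argument` enum.
#[derive(Debug, Clone)]
enum Argument {
    Node(String),
    NodeList(Vec<String>),
    NodeTuple(Vec<String>),
    OptionalNodeList(Vec<Option<String>>),
    Int(i64),
    None,
}

/// Collect every node name referenced by `args` into one flat vector.
fn flatten_node_args(args: &[Argument]) -> Vec<String> {
    let mut names = Vec::new();
    for arg in args {
        match arg {
            Argument::Node(name) => names.push(name.clone()),
            Argument::NodeList(list) | Argument::NodeTuple(list) => {
                names.extend(list.iter().cloned())
            }
            // `None` entries inside an optional list are skipped.
            Argument::OptionalNodeList(list) => names.extend(list.iter().flatten().cloned()),
            // Scalars and `None` carry no node references.
            Argument::Int(_) | Argument::None => {}
        }
    }
    names
}

fn main() {
    let args = vec![
        Argument::Node("x".into()),
        Argument::NodeTuple(vec!["a".into(), "b".into()]),
        Argument::Int(3),
        Argument::OptionalNodeList(vec![Some("c".into()), None]),
        Argument::None,
    ];
    println!("{:?}", flatten_node_args(&args));
}
```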
pub fn args(&self) -> PyResult<Vec<Argument>>
Retrieve the arguments of this `Node`. If the retrieval succeeds, returns `Ok` with a `Vec<Argument>` containing the arguments. Otherwise, returns `Err` with a `PyErr` explaining the cause of the failure.
pub fn name(&self) -> PyResult<String>
Retrieve the name of this `Node`. If the retrieval succeeds, returns `Ok` with the name. Otherwise, returns `Err` with a `PyErr` explaining the cause of the failure.
pub fn op(&self) -> PyResult<Op>
Retrieve the opcode of this `Node`. If the retrieval succeeds, returns `Ok` with the opcode as an `Op`. Otherwise, returns `Err` with a `PyErr` explaining the cause of the failure.
pub fn target(&self) -> PyResult<Target>
Retrieve the target that this `Node` calls. If the retrieval succeeds, returns `Ok` with the target as a `Target`. Otherwise, returns `Err` with a `PyErr` explaining the cause of the failure.
pub fn kwargs(&self) -> PyResult<HashMap<String, Argument>>
Retrieve the keyword arguments (kwargs) passed to the target of this `Node`. If the retrieval succeeds, returns `Ok` with the kwargs as a `HashMap<String, Argument>`. Otherwise, returns `Err` with a `PyErr` explaining the cause of the failure.
pub fn meta(&self) -> PyResult<HashMap<String, PyObject>>
Retrieve the meta of this `Node`. If this `Node` has a `meta` attribute, returns `Ok` with the meta as a `HashMap<String, PyObject>`. Otherwise, returns `Ok(Default::default())`, that is, an empty map. This never returns `Err`.
Wrapper for a Rust function, so that the function can be executed from Python. The wrapped function must therefore take two arguments, `args` as `&PyTuple` and `kwargs` as `Option<&PyDict>`, and return `PyResult<PyObject>`.
#[pyclass]
#[derive(Clone)]
pub struct CustomFn {
pub func_name: String,
/* private fields */
}
A Python callable object that actually executes a Rust function.
pub fn new<S: AsRef<str>>( func_name: S, func: FunctionWrapper ) -> Self
Create a new Python callable object named `func_name` that actually executes the Rust function wrapped in `func`.
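In Rust, a common way to build a named, cloneable function wrapper is to keep the closure behind an `Arc`, so that clones share one function. The stdlib-only sketch below shows that pattern with a hypothetical `NamedFn` type. It uses plain Rust types (`&[f64]` and `Option<&HashMap<String, f64>>`) in place of `&PyTuple` and `Option<&PyDict>`. It does not reproduce the actual definition of `CustomFn` or `FunctionWrapper`.

```rust
use std::collections::HashMap;
use std::sync::Arc;

/// Hypothetical signature: positional args plus optional keyword args,
/// with plain Rust types standing in for `&PyTuple` / `Option<&PyDict>`.
type Wrapped =
    Arc<dyn Fn(&[f64], Option<&HashMap<String, f64>>) -> Result<f64, String> + Send + Sync>;

/// A named, cloneable function wrapper; clones share the same closure.
#[derive(Clone)]
struct NamedFn {
    func_name: String,
    func: Wrapped,
}

impl NamedFn {
    fn new<S: AsRef<str>>(func_name: S, func: Wrapped) -> Self {
        NamedFn { func_name: func_name.as_ref().to_string(), func }
    }

    /// Invoke the wrapped function, as a call from Python would.
    fn call(&self, args: &[f64], kwargs: Option<&HashMap<String, f64>>) -> Result<f64, String> {
        (self.func)(args, kwargs)
    }
}

fn main() {
    // Sum the positional arguments, scaled by an optional `scale` keyword argument.
    let scaled_sum: Wrapped = Arc::new(
        |args: &[f64], kwargs: Option<&HashMap<String, f64>>| -> Result<f64, String> {
            let scale = kwargs.and_then(|kw| kw.get("scale").copied()).unwrap_or(1.0);
            Ok(args.iter().sum::<f64>() * scale)
        },
    );
    let f = NamedFn::new("scaled_sum", scaled_sum);
    let g = f.clone(); // cheap: only the `Arc` pointer is cloned
    let mut kwargs = HashMap::new();
    kwargs.insert("scale".to_string(), 2.0);
    println!("{} -> {:?}", g.func_name, g.call(&[1.0, 2.0], Some(&kwargs)));
    println!("{:?}", f.call(&[1.0, 2.0], None));
}
```

The `Send + Sync` bounds keep the wrapper shareable across threads, which matters for anything handed to another runtime such as Python.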
#[derive(Debug, Clone, FromPyObject)]
pub struct TensorMeta {
pub shape: Vec<usize>,
pub dtype: Dtype,
pub requires_grad: bool,
pub stride: Vec<usize>,
pub memory_format: Option<MemoryFormat>,
pub is_quantized: bool,
pub qparams: HashMap<String, PyObject>,
}
A structure containing pertinent information about a tensor within a PyTorch program.
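Because `TensorMeta` carries both `shape` and `stride`, a caller can check whether a tensor is laid out contiguously (row-major) before interpreting its raw storage. The following is a minimal sketch using a hypothetical stand-in struct, `MetaLayout`, with only those two fields. The real `TensorMeta` is extracted from a Python object through `FromPyObject`. The check is strict: PyTorch itself ignores the strides of size-1 dimensions, so this version can report false negatives for such shapes.

```rust
/// Hypothetical stand-in holding only the layout fields of `TensorMeta`.
struct MetaLayout {
    shape: Vec<usize>,
    stride: Vec<usize>,
}

/// Row-major (C-contiguous) strides, in elements, for `shape`.
fn contiguous_strides(shape: &[usize]) -> Vec<usize> {
    let mut strides = vec![1; shape.len()];
    // Walk from the second-to-last dimension back to the first.
    for i in (0..shape.len().saturating_sub(1)).rev() {
        strides[i] = strides[i + 1] * shape[i + 1];
    }
    strides
}

impl MetaLayout {
    /// Strict check: strides must equal the row-major strides exactly.
    fn is_contiguous(&self) -> bool {
        self.stride == contiguous_strides(&self.shape)
    }
}

fn main() {
    let dense = MetaLayout { shape: vec![2, 3, 4], stride: vec![12, 4, 1] };
    // A transposed view of a [3, 4] tensor: shape [4, 3], strides swapped.
    let transposed = MetaLayout { shape: vec![4, 3], stride: vec![1, 4] };
    println!("{} {}", dense.is_contiguous(), transposed.is_contiguous());
}
```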
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum Op {
Placeholder,
CallFunction,
CallMethod,
CallModule,
GetAttr,
Output,
}
A representation of opcodes for `Node`s.
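Each `Op` variant corresponds to one of the opcode strings that `torch.fx` stores in `Node.op`: `placeholder`, `call_function`, `call_method`, `call_module`, `get_attr`, and `output`. The sketch below shows that correspondence on a local copy of the enum. The `as_str` method and the `FromStr` impl are hypothetical helpers, not part of this crate's documented API.

```rust
use std::str::FromStr;

/// Local copy of the crate's `Op` enum, for illustration only.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
enum Op {
    Placeholder,
    CallFunction,
    CallMethod,
    CallModule,
    GetAttr,
    Output,
}

impl Op {
    /// The opcode string `torch.fx` uses for this variant.
    fn as_str(self) -> &'static str {
        match self {
            Op::Placeholder => "placeholder",
            Op::CallFunction => "call_function",
            Op::CallMethod => "call_method",
            Op::CallModule => "call_module",
            Op::GetAttr => "get_attr",
            Op::Output => "output",
        }
    }
}

impl FromStr for Op {
    type Err = String;

    /// Parse a `torch.fx` opcode string back into an `Op`.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "placeholder" => Ok(Op::Placeholder),
            "call_function" => Ok(Op::CallFunction),
            "call_method" => Ok(Op::CallMethod),
            "call_module" => Ok(Op::CallModule),
            "get_attr" => Ok(Op::GetAttr),
            "output" => Ok(Op::Output),
            other => Err(format!("unknown opcode: {}", other)),
        }
    }
}

fn main() {
    let op: Op = "call_module".parse().unwrap();
    println!("{:?} <-> {}", op, op.as_str());
}
```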
#[derive(Debug, Clone)]
pub enum Target {
Str(String),
TorchOp(String, PyObject),
BuiltinFn(String, PyObject),
Callable(PyObject),
CustomFn(CustomFn),
}
A representation of targets for `Node`s.
#[derive(Debug, Clone)]
pub enum Argument {
Node(String),
NodeList(Vec<String>),
NodeTuple(Vec<String>),
OptionalNodeList(Vec<Option<String>>),
OptionalNodeTuple(Vec<Option<String>>),
NoneList(usize),
NoneTuple(usize),
Bool(bool),
Int(i64),
Float(f64),
VecBool(Vec<bool>),
VecInt(Vec<i64>),
VecFloat(Vec<f64>),
Dtype(Dtype),
Layout(Layout),
Device(Device),
MemoryFormat(MemoryFormat),
Value(PyObject),
EmptyList,
None,
}
A representation of arguments for `Node`s.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum Dtype {
Float32,
Float64,
Complex64,
Complex128,
Float16,
Bfloat16,
Uint8,
Int8,
Int16,
Int32,
Int64,
Bool,
}
An `enum` that represents the data type of a `torch.Tensor`.
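A `Dtype`, combined with the `shape` from a `TensorMeta`, determines how many bytes a contiguous tensor occupies. This can serve as a sanity check on the slices returned by `extract_parameters` or `get_buffer`, assuming the tensor owns its whole storage (a view may share a larger storage). The sketch below works on a local copy of the enum. The `element_size` and `expected_nbytes` helpers are hypothetical, and the sizes are PyTorch's per-element sizes for each dtype.

```rust
/// Local copy of the crate's `Dtype` enum, for illustration only.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum Dtype {
    Float32, Float64, Complex64, Complex128, Float16, Bfloat16,
    Uint8, Int8, Int16, Int32, Int64, Bool,
}

impl Dtype {
    /// Bytes per element, as reported by `torch.Tensor.element_size()`.
    fn element_size(self) -> usize {
        match self {
            Dtype::Bool | Dtype::Uint8 | Dtype::Int8 => 1,
            Dtype::Float16 | Dtype::Bfloat16 | Dtype::Int16 => 2,
            Dtype::Float32 | Dtype::Int32 => 4,
            Dtype::Float64 | Dtype::Int64 | Dtype::Complex64 => 8,
            Dtype::Complex128 => 16,
        }
    }
}

/// Bytes needed by a contiguous tensor with the given shape and dtype.
fn expected_nbytes(shape: &[usize], dtype: Dtype) -> usize {
    // The product over an empty shape is 1: a scalar tensor has one element.
    shape.iter().product::<usize>() * dtype.element_size()
}

fn main() {
    // A [64, 3, 3, 3] float32 convolution weight.
    println!("{}", expected_nbytes(&[64, 3, 3, 3], Dtype::Float32));
}
```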
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum MemoryFormat {
ContiguousFormat,
ChannelsLast,
ChannelsLast3d,
PreserveFormat,
}
An `enum` that represents the memory format in which a `torch.Tensor` is or will be allocated.
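Memory formats differ in the order in which dimensions are laid out in memory. For a 4-D `(N, C, H, W)` tensor, `ContiguousFormat` uses row-major strides. `ChannelsLast` stores the data as if it were `(N, H, W, C)`, which makes `C` the fastest-moving dimension. The sketch below computes the resulting strides in elements. It uses a hypothetical helper, `strides_4d`, and a two-variant local copy of the enum, and covers only 4-D shapes (`ChannelsLast3d` plays the same role for 5-D tensors).

```rust
/// Two-variant local copy of the crate's `MemoryFormat`, for illustration only.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum MemoryFormat {
    ContiguousFormat,
    ChannelsLast,
}

/// Strides (in elements) of a 4-D `[n, c, h, w]` tensor in the given format.
fn strides_4d(shape: [usize; 4], format: MemoryFormat) -> [usize; 4] {
    let [_n, c, h, w] = shape;
    match format {
        // Row-major over (N, C, H, W).
        MemoryFormat::ContiguousFormat => [c * h * w, h * w, w, 1],
        // Row-major over (N, H, W, C), reported in (N, C, H, W) order.
        MemoryFormat::ChannelsLast => [h * w * c, 1, w * c, c],
    }
}

fn main() {
    let shape = [2, 3, 4, 5];
    println!("{:?}", strides_4d(shape, MemoryFormat::ContiguousFormat));
    println!("{:?}", strides_4d(shape, MemoryFormat::ChannelsLast));
}
```

These are the strides PyTorch reports for a freshly allocated tensor of that shape in each format. Comparing them with `TensorMeta::stride` is one way to tell the two layouts apart.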
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum Device {
Cpu(Option<usize>),
Cuda(Option<usize>),
Mps(Option<usize>),
}
An `enum` that represents the device on which a `torch.Tensor` is or will be allocated.
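The `Option<usize>` in each variant holds the optional device index, as in PyTorch device strings such as `"cpu"`, `"cuda:0"`, or `"mps"`. The sketch below parses such strings into a local copy of the enum. The `parse_device` function is hypothetical, and the crate may convert from `torch.device` objects differently.

```rust
/// Local copy of the crate's `Device` enum, for illustration only.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum Device {
    Cpu(Option<usize>),
    Cuda(Option<usize>),
    Mps(Option<usize>),
}

/// Parse strings like "cpu", "cuda", or "cuda:1".
fn parse_device(s: &str) -> Result<Device, String> {
    // Split off an optional ":<index>" suffix.
    let (kind, index) = match s.split_once(':') {
        Some((kind, idx)) => {
            let idx = idx
                .parse::<usize>()
                .map_err(|e| format!("bad device index {:?}: {}", idx, e))?;
            (kind, Some(idx))
        }
        None => (s, None),
    };
    match kind {
        "cpu" => Ok(Device::Cpu(index)),
        "cuda" => Ok(Device::Cuda(index)),
        "mps" => Ok(Device::Mps(index)),
        other => Err(format!("unsupported device type: {}", other)),
    }
}

fn main() {
    for s in ["cpu", "cuda:1", "mps", "tpu"] {
        println!("{} => {:?}", s, parse_device(s));
    }
}
```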
Running the following command builds the documentation for this crate with `cargo doc` and opens it in a browser.
cargo doc --open
Readers may need to consult the more detailed `torch.fx` documentation.