From 77d9edd2c94817be26bc73e9c125aad565deedbd Mon Sep 17 00:00:00 2001 From: Niklas Gustafsson Date: Mon, 3 Oct 2022 12:54:42 -0700 Subject: [PATCH 01/11] Update release notes. --- RELEASENOTES.md | 11 ++++++-- docfx/articles/modules.md | 56 +++++++++++++++++++++------------------ 2 files changed, 39 insertions(+), 28 deletions(-) diff --git a/RELEASENOTES.md b/RELEASENOTES.md index deb0f26e7..fe55eb04b 100644 --- a/RELEASENOTES.md +++ b/RELEASENOTES.md @@ -14,20 +14,27 @@ Doing so (rather than qualifying names with 'TorchSharp.') was already recommend __Loss functions are now aligned with the PyTorch APIs.__ This is a major change and the reason for incrementing the minor version number. The most direct consequence is that losses are modules rather than delegates, which means you need to call .forward() to actually compute the loss. Also, the factories are in torch.nn rather than torch.nn.functional and have the same Pascal-case names as the corresponding types. The members of the torch.nn.functional static class are now proper immediate loss functions, whereas the previous ones returned a loss delegate. +__Generic Module base class.__ The second major change is that Module is made type-safe with respect to the `forward()` function. Module is now an abstract base class, and interfaces `IModule`, `IModule`,... are introduced to define the signature of the `forward()` function. For most custom modules, this means that the base class has to be changed to `Module`, but some modules may require more significant changes. + +ScriptModule follows this pattern, but this version introduces `ScriptModule` base classes, with corresponding `torch.jit.load()` static factory methods. + __Fixed Bugs:__ +#323 forward() should take a variable-length list of arguments
#558 Fix deviation from the Pytorch loss function/module APIs
#742 Ease of use: Module.to method should be generic T -> T
#743 Ease of use: module factories should have dtype and device
+#745 Executing a TorchScript that returns multiple values, throws an exception
#744 Some of functions with inconsistent argument names
#749 functional.linear is wrong
-#761 Stateful optimizers should have support for save/load from disk. +#761 Stateful optimizers should have support for save/load from disk.
__API Changes__: Module.to(), cpu(), and cuda() were redone as extension methods. The virtual methods to override, if necessary, are now named '_to'. A need to do so should be extremely rare.
Support for saving and restoring hyperparameters and state of optimizers
- +Loss functions are now Modules rather than delegates.
+Custom modules should now use generic versions as base classes.
## NuGet Version 0.97.6 diff --git a/docfx/articles/modules.md b/docfx/articles/modules.md index cc194aa64..f5ee6fe6e 100644 --- a/docfx/articles/modules.md +++ b/docfx/articles/modules.md @@ -4,12 +4,26 @@ Unfortunatly, the word 'Module' is one of the most overloaded terms in software. In the context of TorchSharp, it means the same as in PyTorch: the fundamental building block of all models is the 'Module' class. All neural network layers are derived from Module and the way to create a model for training and inference in your code is to create a new Module. Without it, back-propagation will not work. +## Sequential + +For the most basic network architecture, which actually covers a surprising breadth, there is a simple way to create modules -- the 'Sequential' class. This class is created from a list of Modules (i.e. model components). When data is passed to the model in the form of a tensor, the Sequential instance will invoke each submodule in the order they were passed when the Sequential instance was created, passing the output from each layer to the next. The output of the final submodule will be the output of the Sequential instance. + +```C# +var seq = Sequential(("lin1", Linear(100, 10)), ("lin2", Linear(10, 5))); +... +seq.forward(some_data); +``` + +There is no real performance reason to use Sequential instead of rolling your own custom module, but there is much less code to write. That said, if using a custom module (up next) is your preference, for whatever reason, that's what you should do. It can be useful to create a custom module first, debug it, and then convert to using Sequential. + +Sequential can only string together modules that take a single tensor and return a single tensor -- anything else needs to be a custom module. + ## Custom Modules -A custom module is created by deriving a subclass from torch.nn.Module. One that is equivalent to the previous example looks like this: +A custom module is created by deriving a subclass from torch.nn.Module. The first generic parameters denote the input types of the module's forward() method: ```C# - private class TestModule1 : Module + private class TestModule1 : Module { public TestModule1() : base("TestModule1") @@ -36,9 +50,11 @@ Custom modules should always call `RegisterComponents()` once all submodules and The `forward()` method contains the computation of the module. It can contain a mix of TorchSharp primitives, layers, as well as any .NET code. Note, however, that only TorchSharp APIs are capable of operating on data residing in CUDA memory. Therefore, if performance is of the essence, expressing all computation in terms of TorchSharp APIs is essential. Non-TorchSharp APIs should be limited to things that aren't related to the tensor data, things like logging, for example. -In PyTorch, the `forward()` method takes an arbitrary number of arguments, of any type, and supports using default arguments. Currently, TorchSharp only defines two versions of `forward()` -- taking one or two tensors, and returning a single tensor. The TorchSharp implementation of Sequential assumes that it is passing only a single tensor along between its layers. Therefore, any model that needs to pass multiple arguments between layers will have to be custom. +In PyTorch, the `forward()` method takes an arbitrary number of arguments, of any type, and supports using default arguments. -Note that the local variable 'x' was declared in a using statement. This is important in order to deallocate the native memory associated with it as soon as it is no longer needed. For that reason, it is important to pull out temporaries like this into local variables, especially when the code is running on a GPU. (More on this at: [Dispose vs. GC in TorchSharp](memory.md)) +In TorchSharp, the signature of `forward()` is determined by the type parameter signature of the Module<> base class. In most cases, the input and output are both `torch.Tensor`. As noted above, the Sequential module collection class assumes that it is passing only a single tensor along between its layers. Therefore, any model that needs to pass multiple arguments between layers will have to be custom. + +Going back to the code inside the `forward()` method -- please note that the local variable 'x' was declared in a using statement. This is important in order to deallocate the native memory associated with it as soon as it is no longer needed. For that reason, it is important to pull out temporaries like this into local variables, especially when the code is running on a GPU. (More on this at: [Dispose vs. GC in TorchSharp](memory.md)) In other words, the following code is less memory efficient, because it delays reclaiming native memory until the next time GC is run: @@ -48,18 +64,6 @@ In other words, the following code is less memory efficient, because it delays r return lin2.forward(lin1.forward(input)); } ``` -## Sequential - -For the simplest network architecture, which actually covers a surprising breadth, there is a simplified way to create modules -- the 'Sequential' class. This class is created from a list of Modules (i.e. model components). When data is passed to the model, the Sequential instance will invoke each submodule in the order they were passed when the Sequential instance was created, passing the output from each layer to the next. The output of the final submodule will be the output of the Sequential instance. - -```C# -var seq = Sequential(("lin1", Linear(100, 10)), ("lin2", Linear(10, 5))); -... -seq.forward(some_data); -``` - -There is no real performance reason to use Sequential instead of rolling your own custom module, but there is much less code to write. That said, if using a custom module (up next) is your preference, for whatever reason, that's what you should do. It can be useful to create a custom module first, debug it, and then convert to using Sequential. - ## Using Sequential Inside A Custome Module @@ -131,12 +135,12 @@ To illustrate, this is the code for MobileNet from the TorchSharp examples: ## ModuleList -In some circumstances, it's useful to define a dynamic number of modules in a custom module. It could be because you want to parameterize the network architecture, or dynamically choose which layers to run, or just that its tedious to define so many fields. This may be addressed by using a ModuleList to contain the submodules. Unlike Sequential, ModuleList itself does not suffice -- its `forward()` method will throw an exception if invoked. +In some circumstances, it's useful to define a dynamic number of modules in a custom module. It could be because you want to parameterize the network architecture, or dynamically choose which layers to run, or just that its tedious to define so many fields. This may be addressed by using a ModuleList to contain the submodules. Unlike Sequential, ModuleList itself does not suffice -- its `forward()` method will throw an exception if invoked, and you must iterate through the modules of the list directly in your `forward()` implementation. The purpose is simply to provide a list implementation that automatically registers the submodules when components are registered. You have to iterate through the list in the `forward()` method: ```C# - private class TestModule1 : Module + private class TestModule1 : Module { public TestModule1() : base("TestModule1") @@ -157,12 +161,12 @@ The purpose is simply to provide a list implementation that automatically regist ## ModuleDict -In some circumstances, it's useful to define a dynamic number of modules in a custom module. It could be because you want to parameterize the network architecture, or dynamically choose which layers to run, or just that its tedious to define so many fields. This may be addressed by using a ModuleList to contain the submodules. Unlike Sequential, ModuleList itself does not suffice -- its `forward()` method will throw an exception if invoked. +In some circumstances, it's useful to define a dynamic number of modules in a custom module and use a dictionary to hold them. Like with `ModuleList`, it could be because you want to parameterize the network architecture, or dynamically choose which layers to run, or just that its tedious to define so many fields. This may be addressed by using a ModuleDict to contain the submodules. The purpose is simply to provide a list implementation that automatically registers the submodules when components are registered. You have to iterate through the list in the `forward()` method: ```C# - private class TestModule1 : Module + private class TestModule1 : Module { public TestModule1() : base("TestModule1") @@ -191,7 +195,7 @@ Many modules are just compositions of existing modules, but sometimes it will im For example, a re-implementation of 'Linear' would look something like: ```C# - private class MyLinear : Module + private class MyLinear : Module { public MyLinear(long input_size, long output_size) : base("MyLinear") @@ -212,9 +216,9 @@ For example, a re-implementation of 'Linear' would look something like: } ``` -In this case, we're not relying on 'using' in the `forward()` method, because the temporary is reused as the target by the `add_()` function. +In this case, we're not relying on 'using' in the `forward()` method, because the "temporary" is reused as the target by the `add_()` function. -Parameter's dirty little secret is that it will clean out the tensor that is given to its constructor. So, `Parameter()` is preferrably used with another tensor factory (such as in the example above), or a cloned tensor. +Parameter's dirty little secret is that it will clean out the tensor that is given to its constructor. So, `Parameter()` is preferrably used with another tensor factory (such as in the example above), or a cloned tensor. Once you have passes a tensor to the Parameter constructor, the original tensor is invalidated. ## ParameterList @@ -226,7 +230,7 @@ Much like ModuleDict, ParameterDict is a dictionary of Parameter instances, whic ## Buffers -Sometimes, a module needs to allocate tensor that are not trainable, i.e. their values are not modified during back-propagation. An example is a random dropout mask. These are referred to as 'buffers' as opposed to 'parameters' and are treated differently by `RegisterComponents()` -- even though they are not trainable, the native runtime still wants to know about them for other purposes, so it is important to declare them in the module. +Sometimes, a module needs to allocate tensor that are not trainable, i.e. their values are not modified during back-propagation. An example is a random dropout mask. These are referred to as 'buffers' as opposed to 'parameters' and are treated differently by `RegisterComponents()` -- even though they are not trainable, the native runtime still wants to know about them for other purposes, such as storing them to disk, so it is important to declare them in the module. Each buffer should be declared as a field of type 'Tensor' (not 'Parameter'). This will ensure that the buffer is registered properly when `RegisterComponents()` is called. @@ -238,7 +242,7 @@ It is sometimes necessary to create a new model from an existing one and discard So, for example: ```C# - private class TestModule1 : Module + private class TestModule1 : Module { public TestModule1() : base("TestModule1") @@ -259,7 +263,7 @@ So, for example: private Module lin2; } - private class TestModule2 : Module + private class TestModule2 : Module { public TestModule2() : base("TestModule2") From c20540aa57959b15fca9389999c13fab56fda2f4 Mon Sep 17 00:00:00 2001 From: Niklas Gustafsson Date: Mon, 3 Oct 2022 15:29:15 -0700 Subject: [PATCH 02/11] Adding support for ScriptModules: 1. Calling methods other than 'forward' 2. Compiling Python scripts containing functions. --- src/Native/LibTorchSharp/THSJIT.cpp | 63 +++++-- src/Native/LibTorchSharp/THSJIT.h | 5 + src/Native/LibTorchSharp/Utils.h | 3 +- src/TorchSharp/JIT/CompilationUnit.cs | 176 ++++++++++++++++++ src/TorchSharp/JIT/ScriptModule.cs | 91 +++++++++ .../TorchSharpTest.WithCudaBinaries.csproj | 3 + test/TorchSharpTest/TestJIT.cs | 107 +++++++++-- test/TorchSharpTest/TorchSharpTest.csproj | 3 + test/TorchSharpTest/exported.method.dat | Bin 0 -> 1846 bytes 9 files changed, 420 insertions(+), 31 deletions(-) create mode 100644 src/TorchSharp/JIT/CompilationUnit.cs create mode 100644 test/TorchSharpTest/exported.method.dat diff --git a/src/Native/LibTorchSharp/THSJIT.cpp b/src/Native/LibTorchSharp/THSJIT.cpp index e1c2d7e87..5dc33273b 100644 --- a/src/Native/LibTorchSharp/THSJIT.cpp +++ b/src/Native/LibTorchSharp/THSJIT.cpp @@ -5,13 +5,23 @@ JITModule THSJIT_load(const char* filename) { CATCH( auto res = torch::jit::load(filename); - auto copy = new torch::jit::Module(res); - return new std::shared_ptr(copy); + auto copy = new torch::jit::Module(res); + return new std::shared_ptr(copy); ); return nullptr; } +JITCompilationUnit THSJIT_compile(const char* script) +{ + //CATCH( + auto res = torch::jit::compile(script); + return new std::shared_ptr(res); + //); + + return nullptr; +} + void THSJIT_save(JITModule module, const char* filename) { CATCH( @@ -145,19 +155,14 @@ JITMethod THSJIT_Module_get_method(const JITModule module, const char* name) return new std::shared_ptr(copy); } -void THSJIT_Module_forward(const JITModule module, const Tensor* tensorPtrs, const int length, Tensor* (*allocator)(size_t length), int8_t* typeCode) +void ReturnHelper(c10::IValue result, Tensor* (*allocator)(size_t length), int8_t* typeCode) { - *typeCode = 0; - - CATCH( - auto result = (*module)->forward(toTensors((torch::Tensor**)tensorPtrs, length)); - // TypeCode: - // - // 0 -- Not supported - // 1 -- Single tensor - // 2 -- Tuple of tensors - // 3 -- List of tensors +// +// 0 -- Not supported +// 1 -- Single tensor +// 2 -- Tuple of tensors +// 3 -- List of tensors if (result.isTensor()) { Tensor* output = allocator(1); @@ -184,6 +189,38 @@ void THSJIT_Module_forward(const JITModule module, const Tensor* tensorPtrs, con output[i] = ResultTensor(list[i].toTensor()); return; } +} + +void THSJIT_Module_forward(const JITModule module, const Tensor* tensorPtrs, const int length, Tensor* (*allocator)(size_t length), int8_t* typeCode) +{ + *typeCode = 0; + + CATCH( + auto result = (*module)->forward(toTensors((torch::Tensor**)tensorPtrs, length)); + ReturnHelper(result, allocator, typeCode); + ) +} + +void THSJIT_Module_invoke(const JITModule module, const char* name, const Tensor* tensorPtrs, const int length, Tensor* (*allocator)(size_t length), int8_t* typeCode) +{ + *typeCode = 0; + + CATCH( + auto method = (*module)->get_method(name); + auto result = method(toTensors((torch::Tensor**)tensorPtrs, length)); + ReturnHelper(result, allocator, typeCode); + ) +} + +void THSJIT_CompilationUnit_Invoke(const JITCompilationUnit module, const char* method, const Tensor* tensorPtrs, const int length, Tensor* (*allocator)(size_t length), int8_t* typeCode) +{ + *typeCode = 0; + + CATCH( + auto args = toTensors((torch::Tensor**)tensorPtrs, length); + auto func = (*module)->find_function(method); + auto result = (*func)(args); + ReturnHelper(result, allocator, typeCode); ) } diff --git a/src/Native/LibTorchSharp/THSJIT.h b/src/Native/LibTorchSharp/THSJIT.h index 0b3d0c5b9..6e53dc291 100644 --- a/src/Native/LibTorchSharp/THSJIT.h +++ b/src/Native/LibTorchSharp/THSJIT.h @@ -19,13 +19,18 @@ enum TypeKind : int8_t { EXPORT_API(JITModule) THSJIT_load(const char* filename); EXPORT_API(void) THSJIT_save(JITModule module, const char* filename); +EXPORT_API(JITCompilationUnit) THSJIT_compile(const char* script); EXPORT_API(void) THSJIT_Module_dispose(const JITModule module); +EXPORT_API(void) THSJIT_CompilationUnit_dispose(const JITCompilationUnit module); EXPORT_API(int) THSJIT_Module_num_inputs(const JITModule method); EXPORT_API(int) THSJIT_Module_num_outputs(const JITModule method); EXPORT_API(void) THSJIT_Module_forward(const JITModule module, const Tensor* tensorPtrs, const int length, Tensor* (*allocator)(size_t length), int8_t* typeCode); +EXPORT_API(void) THSJIT_Module_invoke(const JITModule module, const char* name, const Tensor* tensorPtrs, const int length, Tensor* (*allocator)(size_t length), int8_t* typeCode); + +EXPORT_API(void) THSJIT_CompilationUnit_Invoke(const JITCompilationUnit module, const char* method, const Tensor* tensorPtrs, const int length, Tensor* (*allocator)(size_t length), int8_t* typeCode); EXPORT_API(int) THSJIT_Module_is_training(JITModule module); EXPORT_API(void) THSJIT_Module_train(JITModule module, bool on); diff --git a/src/Native/LibTorchSharp/Utils.h b/src/Native/LibTorchSharp/Utils.h index bbad85e41..adefa5aca 100644 --- a/src/Native/LibTorchSharp/Utils.h +++ b/src/Native/LibTorchSharp/Utils.h @@ -16,7 +16,8 @@ typedef torch::nn::utils::rnn::PackedSequence* PackedSequence; typedef std::shared_ptr * NNModule; typedef std::shared_ptr * NNAnyModule; typedef std::shared_ptr * Optimizer; -typedef std::shared_ptr * JITModule; +typedef std::shared_ptr * JITCompilationUnit; +typedef std::shared_ptr* JITModule; typedef std::shared_ptr* JITMethod; typedef std::shared_ptr * JITFunction; typedef std::shared_ptr * JITType; diff --git a/src/TorchSharp/JIT/CompilationUnit.cs b/src/TorchSharp/JIT/CompilationUnit.cs new file mode 100644 index 000000000..cf32f03cf --- /dev/null +++ b/src/TorchSharp/JIT/CompilationUnit.cs @@ -0,0 +1,176 @@ +// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information. +using System; +using System.Linq; +using System.Collections.Generic; +using System.Diagnostics; +using System.Reflection; +using System.Runtime.InteropServices; +using static TorchSharp.torch; +using System.Net; +using static TorchSharp.torch.nn; + +namespace TorchSharp +{ + public static partial class torch + { + public static partial class jit + { + /// + /// Represents a TorchScript compilation unit, i.e. a Python script file. + /// + /// + /// var cu = torch.jit.compile(@" + /// def relu_script(a, b): + /// return torch.relu(a + b) + /// "); + /// + /// var y = cu.invoke("relu_script", torch.randn(10)); + /// + /// + /// Currently, scripts are limited to defining functions. Classes will be ignored. + /// + public class CompilationUnit : IDisposable + { + internal CompilationUnit(IntPtr handle) + { + this.handle = handle; + } + + ~CompilationUnit() => Dispose(false); + + /// + /// Releases the storage. + /// + public void Dispose() + { + GC.SuppressFinalize(this); + Dispose(true); + } + + /// + /// Implements the .NET Dispose pattern. + /// + protected virtual void Dispose(bool disposing) + { + if (disposing && handle != IntPtr.Zero) { + handle = IntPtr.Zero; + } + } + + [DllImport("LibTorchSharp")] + private static extern void THSJIT_CompilationUnit_dispose(IntPtr handle); + + internal IntPtr handle; + + [DllImport("LibTorchSharp")] + private static extern void THSJIT_CompilationUnit_Invoke(IntPtr module, string name, IntPtr tensors, int length, AllocatePinnedArray allocator, out sbyte typeCode); + + /// + /// Invoke a function from the compilation unit. + /// + /// The name of the function. + /// Function arguments. + public object invoke(string name, params object[] objs) + { + if (String.IsNullOrEmpty(name)) throw new ArgumentNullException("method name"); + + if (!objs.All(o => typeof(Tensor).IsAssignableFrom(o.GetType()))) { + throw new NotImplementedException("ScriptModule.forward() taking non-tensors as input arguments"); + } + + IntPtr[] ptrArray = null; + sbyte typeCode = 0; + + using (var parray = new PinnedArray()) { + + var tensors = objs.Select(o => (Tensor)o).ToArray(); + var count = tensors.Length; + var tensorRefs = new IntPtr[count]; + for (var i = 0; i < tensors.Length; i++) tensorRefs[i] = tensors[i].Handle; + + THSJIT_CompilationUnit_Invoke(handle, name, parray.CreateArray(tensorRefs), count, parray.CreateArray, out typeCode); + torch.CheckForErrors(); + ptrArray = parray.Array; + } + + + switch (typeCode) { + default: + // Nothing. + throw new NotImplementedException("ScriptModule.forward() returning something else than a tensor, a tuple of tensors, or list of tensors."); + case 1: + // Tensor + return new Tensor(ptrArray[0]); + case 2: + // Tuple + switch (ptrArray.Length) { + case 1: + return new Tensor(ptrArray[0]); + case 2: + return (new Tensor(ptrArray[0]), new Tensor(ptrArray[1])); + case 3: + return (new Tensor(ptrArray[0]), new Tensor(ptrArray[1]), new Tensor(ptrArray[2])); + case 4: + return (new Tensor(ptrArray[0]), new Tensor(ptrArray[1]), new Tensor(ptrArray[2]), new Tensor(ptrArray[3])); + case 5: + return (new Tensor(ptrArray[0]), new Tensor(ptrArray[1]), new Tensor(ptrArray[2]), new Tensor(ptrArray[3]), new Tensor(ptrArray[4])); + default: { + // Too long a tuple, return as a list, instead. + var result = new Tensor[ptrArray.Length]; + for (var i = 0; i < ptrArray.Length; i++) { + result[i] = new Tensor(ptrArray[i]); + } + return result; + } + } + case 3: { + // List of tensors + var result = new Tensor[ptrArray.Length]; + for (var i = 0; i < ptrArray.Length; i++) { + result[i] = new Tensor(ptrArray[i]); + } + return result; + } + } + + } + + /// + /// Invoke a function from the compilation unit. + /// + /// The return type of the TorchScript function. + /// The name of the function. + /// Function arguments. + public TResult invoke(string name, params object[] inputs) => (TResult)invoke(name, inputs); + + /// + /// Invoke a function from the compilation unit. + /// + /// The type of all function arguments. + /// The return type of the TorchScript function. + /// The name of the function. + /// Function arguments. + public TResult invoke(string name, params T[] inputs) => (TResult)invoke(name, inputs); + } + + [DllImport("LibTorchSharp")] + private static extern IntPtr THSJIT_compile(string script); + + /// + /// Create a TorchScript compilation unit containing TorchScript-compliant Python from a string. + /// + /// A string with Python code expressing a set of TorchScript functions. + /// + public static CompilationUnit compile(string script) + { + if (String.IsNullOrEmpty(script)) + throw new ArgumentNullException("empty script"); + + var result = THSJIT_compile(script); + if (result == IntPtr.Zero) + CheckForErrors(); + return new CompilationUnit(result); + } + } + } +} diff --git a/src/TorchSharp/JIT/ScriptModule.cs b/src/TorchSharp/JIT/ScriptModule.cs index 7835e5258..9fb99705b 100644 --- a/src/TorchSharp/JIT/ScriptModule.cs +++ b/src/TorchSharp/JIT/ScriptModule.cs @@ -7,6 +7,7 @@ using System.Runtime.InteropServices; using static TorchSharp.torch; using System.Net; +using static TorchSharp.torch.nn; namespace TorchSharp { @@ -332,6 +333,96 @@ public object forward(params object[] objs) } } } + + [DllImport("LibTorchSharp")] + private static extern void THSJIT_Module_invoke(HType module, string name, IntPtr tensors, int length, AllocatePinnedArray allocator, out sbyte typeCode); + + /// + /// Invoke a function from the script module. + /// + /// The name of the function. + /// Function arguments. + public object invoke(string name, params object[] objs) + { + if (String.IsNullOrEmpty(name)) throw new ArgumentNullException("method name"); + + if (!objs.All(o => typeof(Tensor).IsAssignableFrom(o.GetType()))) { + throw new NotImplementedException("ScriptModule.forward() taking non-tensors as input arguments"); + } + + IntPtr[] ptrArray = null; + sbyte typeCode = 0; + + using (var parray = new PinnedArray()) { + + var tensors = objs.Select(o => (Tensor)o).ToArray(); + var count = tensors.Length; + var tensorRefs = new IntPtr[count]; + for (var i = 0; i < tensors.Length; i++) tensorRefs[i] = tensors[i].Handle; + + THSJIT_Module_invoke(handle, name, parray.CreateArray(tensorRefs), count, parray.CreateArray, out typeCode); + torch.CheckForErrors(); + ptrArray = parray.Array; + } + + + switch (typeCode) { + default: + // Nothing. + throw new NotImplementedException("ScriptModule.forward() returning something else than a tensor, a tuple of tensors, or list of tensors."); + case 1: + // Tensor + return new Tensor(ptrArray[0]); + case 2: + // Tuple + switch (ptrArray.Length) { + case 1: + return new Tensor(ptrArray[0]); + case 2: + return (new Tensor(ptrArray[0]), new Tensor(ptrArray[1])); + case 3: + return (new Tensor(ptrArray[0]), new Tensor(ptrArray[1]), new Tensor(ptrArray[2])); + case 4: + return (new Tensor(ptrArray[0]), new Tensor(ptrArray[1]), new Tensor(ptrArray[2]), new Tensor(ptrArray[3])); + case 5: + return (new Tensor(ptrArray[0]), new Tensor(ptrArray[1]), new Tensor(ptrArray[2]), new Tensor(ptrArray[3]), new Tensor(ptrArray[4])); + default: { + // Too long a tuple, return as a list, instead. + var result = new Tensor[ptrArray.Length]; + for (var i = 0; i < ptrArray.Length; i++) { + result[i] = new Tensor(ptrArray[i]); + } + return result; + } + } + case 3: { + // List of tensors + var result = new Tensor[ptrArray.Length]; + for (var i = 0; i < ptrArray.Length; i++) { + result[i] = new Tensor(ptrArray[i]); + } + return result; + } + } + + } + + /// + /// Invoke a function from the script module. + /// + /// The return type of the TorchScript function. + /// The name of the function. + /// Function arguments. + public TResult invoke(string name, params object[] inputs) => (TResult)invoke(name, inputs); + + /// + /// Invoke a function from the script module. + /// + /// The type of all function arguments. + /// The return type of the TorchScript function. + /// The name of the function. + /// Function arguments. + public TResult invoke(string name, params T[] inputs) => (TResult)invoke(name, inputs); } /// diff --git a/test/TorchSharpTest.WithCudaBinaries/TorchSharpTest.WithCudaBinaries.csproj b/test/TorchSharpTest.WithCudaBinaries/TorchSharpTest.WithCudaBinaries.csproj index 923178a4a..2f1d36bb7 100644 --- a/test/TorchSharpTest.WithCudaBinaries/TorchSharpTest.WithCudaBinaries.csproj +++ b/test/TorchSharpTest.WithCudaBinaries/TorchSharpTest.WithCudaBinaries.csproj @@ -58,6 +58,9 @@ + + PreserveNewest + PreserveNewest diff --git a/test/TorchSharpTest/TestJIT.cs b/test/TorchSharpTest/TestJIT.cs index a00d011fa..c5fa1890a 100644 --- a/test/TorchSharpTest/TestJIT.cs +++ b/test/TorchSharpTest/TestJIT.cs @@ -2,14 +2,9 @@ using System; using System.IO; using System.Linq; -using TorchSharp.Modules; +using static TorchSharp.torch; using static TorchSharp.torch.nn; using Xunit; -using Google.Protobuf; -using Tensorboard; -using static TorchSharp.torch.utils.tensorboard; -using ICSharpCode.SharpZipLib; -using System.Collections.Generic; #nullable enable @@ -26,7 +21,7 @@ public class TestJIT public void TestLoadJIT_Func() { // One linear layer followed by ReLU. - using var m = torch.jit.load(@"func.script.dat"); + using var m = torch.jit.load(@"func.script.dat"); var sms = m.named_modules().ToArray(); Assert.Empty(sms); @@ -45,7 +40,7 @@ public void TestLoadJIT_Func() public void TestLoadJIT_1() { // One linear layer followed by ReLU. - using var m = torch.jit.load(@"linrelu.script.dat"); + using var m = torch.jit.load(@"linrelu.script.dat"); var t = m.forward(torch.ones(10)); Assert.Equal(new long[] { 6 }, t.shape); @@ -62,10 +57,10 @@ public void TestSaveJIT() try { // One linear layer followed by ReLU. - using var m1 = torch.jit.load(@"linrelu.script.dat"); + using var m1 = torch.jit.load(@"linrelu.script.dat"); torch.jit.save(m1, location); - using var m2 = torch.jit.load(location); + using var m2 = torch.jit.load(location); var t = m2.forward(torch.ones(10)); @@ -82,7 +77,7 @@ public void TestSaveJIT() public void TestLoadJIT_2() { // One linear layer followed by ReLU. - using var m = torch.jit.load(@"scripted.script.dat"); + using var m = torch.jit.load(@"scripted.script.dat"); var t = m.forward(torch.ones(6)); Assert.Equal(new long[] { 6 }, t.shape); @@ -94,7 +89,7 @@ public void TestLoadJIT_2() public void TestLoadJIT_3() { // Two linear layers, nested Sequential, ReLU in between. - using var m = torch.jit.load(@"l1000_100_10.script.dat"); + using var m = torch.jit.load(@"l1000_100_10.script.dat"); var sms = m.named_modules().ToArray(); Assert.Equal(4, sms.Length); @@ -123,7 +118,7 @@ public void TestSaveLoadJITCUDA() { if (torch.cuda.is_available()) { - using var m = torch.jit.load(@"linrelu.script.dat"); + using var m = torch.jit.load(@"linrelu.script.dat"); m.to(DeviceType.CUDA); var params0 = m.parameters().ToArray(); @@ -144,7 +139,7 @@ public void TestJIT_TupleOut() // def a(x, y): // return x + y, x - y // - using var m = torch.jit.load<(torch.Tensor, torch.Tensor)>(@"tuple_out.dat"); + using var m = torch.jit.load<(Tensor, Tensor)>(@"tuple_out.dat"); var x = torch.rand(3, 4); var y = torch.rand(3, 4); @@ -164,7 +159,7 @@ public void TestJIT_TupleOutError() // def a(x, y): // return x + y, x - y // - using var m = torch.jit.load< (torch.Tensor, torch.Tensor)>(@"func.script.dat"); + using var m = torch.jit.load<(Tensor, Tensor)>(@"func.script.dat"); var x = torch.rand(3, 4); var y = torch.rand(3, 4); @@ -177,7 +172,7 @@ public void TestJIT_ListOut() // def a(x, y): // return [x + y, x - y] // - using var m = torch.jit.load(@"list_out.dat"); + using var m = torch.jit.load(@"list_out.dat"); var x = torch.rand(3, 4); var y = torch.rand(3, 4); @@ -197,11 +192,89 @@ public void TestJIT_ListOutError() // def a(x, y): // return x + y, x - y // - using var m = torch.jit.load(@"func.script.dat"); + using var m = torch.jit.load(@"func.script.dat"); var x = torch.rand(3, 4); var y = torch.rand(3, 4); Assert.Throws(() => m.forward(x, y)); } + + + + [Fact] + public void TestLoadJIT_Methods() + { + // class MyModule(nn.Module): + // def __init__(self): + // super().__init__() + // self.p = nn.Parameter(torch.rand(10)) + // def forward(self, x: Tensor, y: Tensor) -> Tuple[Tensor, Tensor]: + // return x + y, x - y + // + // @torch.jit.export + // def predict(self, x: Tensor) -> Tensor: + // return x + self.p + + using var m = new TestScriptModule(@"exported.method.dat"); + + var x = torch.rand(3, 4); + var y = torch.rand(3, 4); + var output = m.forward(x, y); + + Assert.Multiple( + () => Assert.Equal(x.shape, output.Item1.shape), + () => Assert.Equal(x.shape, output.Item2.shape), + () => Assert.Equal(x + y, output.Item1), + () => Assert.Equal(x - y, output.Item2) + ); + + var predict = m.predict(x); + + Assert.Multiple( + () => Assert.NotEqual(x, predict) + ); + } + + internal class TestScriptModule : Module + { + internal TestScriptModule(string filename) : base(nameof(TestScriptModule)) + { + m = torch.jit.load<(Tensor, Tensor)> (filename); + } + + public override (Tensor, Tensor) forward(Tensor input1, Tensor input2) + { + return m.forward(input1, input2); + } + + public Tensor predict(Tensor input) + { + return m.invoke("predict", input); + } + + private torch.jit.ScriptModule<(Tensor, Tensor)> m; + } + + [Fact] + public void TestJITCompile() + { + string script = @" + def relu_script(a, b): + return torch.relu(a + b) + def relu6_script(a, b): + return torch.relu6(a + b) +"; + + var cu = torch.jit.compile(script); + + Assert.NotNull(cu); + + var x = torch.randn(3, 4); + var y = torch.randn(3, 4); + var z = (Tensor)cu.invoke("relu_script", x, y); + Assert.Equal(torch.nn.functional.relu(x + y), z); + z = cu.invoke("relu6_script", x, y); + Assert.Equal(torch.nn.functional.relu6(x + y), z); + } } } diff --git a/test/TorchSharpTest/TorchSharpTest.csproj b/test/TorchSharpTest/TorchSharpTest.csproj index 516d0eb95..77e10469b 100644 --- a/test/TorchSharpTest/TorchSharpTest.csproj +++ b/test/TorchSharpTest/TorchSharpTest.csproj @@ -29,6 +29,9 @@ PreserveNewest + + PreserveNewest + PreserveNewest diff --git a/test/TorchSharpTest/exported.method.dat b/test/TorchSharpTest/exported.method.dat new file mode 100644 index 0000000000000000000000000000000000000000..d8686b69b546b93096d196de5c86384b62096d84 GIT binary patch literal 1846 zcmWIWW@cev;NW1u0AdV045<|b`9&qEDSEl7B^miC`YDMeiTVa^P8$NpB& z3OkSY`gW`HE$p>|jO}0iJ!yM{)5EUdjDp<(&tUtE+p%_U`i<<=;sWghyxBR<-;&fa z0NMw_0XXfKVBo@Pzg|Ihj++x3)CCPp$?@?e`9;YY@$p=~mA?5Yr8%iwg$$aFwHgtO zKnn^A8I!^Cdhw+tnK{K=@kOagrI|S?@g=Ew#rZ|?Wky_uOd1*y>_ElEK$VHeEVsYR(NE}6+CT!kE(K_x+joDrblE-6aP%*)J6FXZZo5Cih#GmGQX zN^^4JlM<7&%M*)I;xqE|vkSTXN|OqC^g!X(e(GfB3!q9628EjlFy1gDPLY8VBixen zQ&RQe;iOkk>E^@=3BcX}@4Ukb0(YN_EMC*|yt(8Ig z$HlRN`8mg}Hw%b!E?u#Fvfuap>+jd51&d6T=xGtvn4z+4!k^uX9QCs9>T`wz3nB&(#$6_eGnivoQkBxzi0i+(H768AC9J>Gj literal 0 HcmV?d00001 From 05e735fbc6d0ea80025eb193f09333b3a3b7f748 Mon Sep 17 00:00:00 2001 From: Niklas Gustafsson Date: Mon, 3 Oct 2022 15:42:53 -0700 Subject: [PATCH 03/11] Minor updates to the ScriptModule additions. --- src/TorchSharp/JIT/CompilationUnit.cs | 43 ++--------------------- src/TorchSharp/JIT/ScriptModule.cs | 49 ++++----------------------- 2 files changed, 9 insertions(+), 83 deletions(-) diff --git a/src/TorchSharp/JIT/CompilationUnit.cs b/src/TorchSharp/JIT/CompilationUnit.cs index cf32f03cf..e69d582fa 100644 --- a/src/TorchSharp/JIT/CompilationUnit.cs +++ b/src/TorchSharp/JIT/CompilationUnit.cs @@ -75,7 +75,7 @@ public object invoke(string name, params object[] objs) if (String.IsNullOrEmpty(name)) throw new ArgumentNullException("method name"); if (!objs.All(o => typeof(Tensor).IsAssignableFrom(o.GetType()))) { - throw new NotImplementedException("ScriptModule.forward() taking non-tensors as input arguments"); + throw new NotImplementedException($"CompilationUnit.{name}() is not yet taking non-tensors as input arguments"); } IntPtr[] ptrArray = null; @@ -93,46 +93,7 @@ public object invoke(string name, params object[] objs) ptrArray = parray.Array; } - - switch (typeCode) { - default: - // Nothing. - throw new NotImplementedException("ScriptModule.forward() returning something else than a tensor, a tuple of tensors, or list of tensors."); - case 1: - // Tensor - return new Tensor(ptrArray[0]); - case 2: - // Tuple - switch (ptrArray.Length) { - case 1: - return new Tensor(ptrArray[0]); - case 2: - return (new Tensor(ptrArray[0]), new Tensor(ptrArray[1])); - case 3: - return (new Tensor(ptrArray[0]), new Tensor(ptrArray[1]), new Tensor(ptrArray[2])); - case 4: - return (new Tensor(ptrArray[0]), new Tensor(ptrArray[1]), new Tensor(ptrArray[2]), new Tensor(ptrArray[3])); - case 5: - return (new Tensor(ptrArray[0]), new Tensor(ptrArray[1]), new Tensor(ptrArray[2]), new Tensor(ptrArray[3]), new Tensor(ptrArray[4])); - default: { - // Too long a tuple, return as a list, instead. - var result = new Tensor[ptrArray.Length]; - for (var i = 0; i < ptrArray.Length; i++) { - result[i] = new Tensor(ptrArray[i]); - } - return result; - } - } - case 3: { - // List of tensors - var result = new Tensor[ptrArray.Length]; - for (var i = 0; i < ptrArray.Length; i++) { - result[i] = new Tensor(ptrArray[i]); - } - return result; - } - } - + return torch.jit.ScriptModule.ProcessReturnValue(name, ptrArray, typeCode); } /// diff --git a/src/TorchSharp/JIT/ScriptModule.cs b/src/TorchSharp/JIT/ScriptModule.cs index 9fb99705b..4b0c8f5ef 100644 --- a/src/TorchSharp/JIT/ScriptModule.cs +++ b/src/TorchSharp/JIT/ScriptModule.cs @@ -293,11 +293,15 @@ public object forward(params object[] objs) ptrArray = parray.Array; } + return ProcessReturnValue("forward", ptrArray, typeCode); + } + internal static object ProcessReturnValue(string name, IntPtr[] ptrArray, sbyte typeCode) + { switch (typeCode) { default: // Nothing. - throw new NotImplementedException("ScriptModule.forward() returning something else than a tensor, a tuple of tensors, or list of tensors."); + throw new NotImplementedException($"ScriptModule.{name}() returning something else than a tensor, a tuple of tensors, or list of tensors."); case 1: // Tensor return new Tensor(ptrArray[0]); @@ -347,7 +351,7 @@ public object invoke(string name, params object[] objs) if (String.IsNullOrEmpty(name)) throw new ArgumentNullException("method name"); if (!objs.All(o => typeof(Tensor).IsAssignableFrom(o.GetType()))) { - throw new NotImplementedException("ScriptModule.forward() taking non-tensors as input arguments"); + throw new NotImplementedException($"ScriptModule.{name}() is not yet taking non-tensors as input arguments"); } IntPtr[] ptrArray = null; @@ -365,46 +369,7 @@ public object invoke(string name, params object[] objs) ptrArray = parray.Array; } - - switch (typeCode) { - default: - // Nothing. - throw new NotImplementedException("ScriptModule.forward() returning something else than a tensor, a tuple of tensors, or list of tensors."); - case 1: - // Tensor - return new Tensor(ptrArray[0]); - case 2: - // Tuple - switch (ptrArray.Length) { - case 1: - return new Tensor(ptrArray[0]); - case 2: - return (new Tensor(ptrArray[0]), new Tensor(ptrArray[1])); - case 3: - return (new Tensor(ptrArray[0]), new Tensor(ptrArray[1]), new Tensor(ptrArray[2])); - case 4: - return (new Tensor(ptrArray[0]), new Tensor(ptrArray[1]), new Tensor(ptrArray[2]), new Tensor(ptrArray[3])); - case 5: - return (new Tensor(ptrArray[0]), new Tensor(ptrArray[1]), new Tensor(ptrArray[2]), new Tensor(ptrArray[3]), new Tensor(ptrArray[4])); - default: { - // Too long a tuple, return as a list, instead. - var result = new Tensor[ptrArray.Length]; - for (var i = 0; i < ptrArray.Length; i++) { - result[i] = new Tensor(ptrArray[i]); - } - return result; - } - } - case 3: { - // List of tensors - var result = new Tensor[ptrArray.Length]; - for (var i = 0; i < ptrArray.Length; i++) { - result[i] = new Tensor(ptrArray[i]); - } - return result; - } - } - + return ProcessReturnValue(name, ptrArray, typeCode); } /// From b7f8351254a87a32dbae77e7db75f88e170af9ee Mon Sep 17 00:00:00 2001 From: Niklas Gustafsson Date: Mon, 3 Oct 2022 15:47:39 -0700 Subject: [PATCH 04/11] Removed commented code in C++. --- src/Native/LibTorchSharp/THSJIT.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Native/LibTorchSharp/THSJIT.cpp b/src/Native/LibTorchSharp/THSJIT.cpp index 5dc33273b..dac961f81 100644 --- a/src/Native/LibTorchSharp/THSJIT.cpp +++ b/src/Native/LibTorchSharp/THSJIT.cpp @@ -14,10 +14,10 @@ JITModule THSJIT_load(const char* filename) JITCompilationUnit THSJIT_compile(const char* script) { - //CATCH( + CATCH( auto res = torch::jit::compile(script); return new std::shared_ptr(res); - //); + ); return nullptr; } From 01e5319ef10859248d457a228cafc92e3b7a87b3 Mon Sep 17 00:00:00 2001 From: Niklas Gustafsson Date: Mon, 3 Oct 2022 16:55:06 -0700 Subject: [PATCH 05/11] Simplified code. --- src/TorchSharp/JIT/CompilationUnit.cs | 5 ++--- src/TorchSharp/JIT/ScriptModule.cs | 10 ++++------ 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/src/TorchSharp/JIT/CompilationUnit.cs b/src/TorchSharp/JIT/CompilationUnit.cs index e69d582fa..85171e4d0 100644 --- a/src/TorchSharp/JIT/CompilationUnit.cs +++ b/src/TorchSharp/JIT/CompilationUnit.cs @@ -83,10 +83,9 @@ public object invoke(string name, params object[] objs) using (var parray = new PinnedArray()) { - var tensors = objs.Select(o => (Tensor)o).ToArray(); - var count = tensors.Length; + var count = objs.Length; var tensorRefs = new IntPtr[count]; - for (var i = 0; i < tensors.Length; i++) tensorRefs[i] = tensors[i].Handle; + for (var i = 0; i < objs.Length; i++) tensorRefs[i] = ((Tensor)objs[i]).Handle; THSJIT_CompilationUnit_Invoke(handle, name, parray.CreateArray(tensorRefs), count, parray.CreateArray, out typeCode); torch.CheckForErrors(); diff --git a/src/TorchSharp/JIT/ScriptModule.cs b/src/TorchSharp/JIT/ScriptModule.cs index 4b0c8f5ef..251490fc9 100644 --- a/src/TorchSharp/JIT/ScriptModule.cs +++ b/src/TorchSharp/JIT/ScriptModule.cs @@ -283,10 +283,9 @@ public object forward(params object[] objs) using (var parray = new PinnedArray()) { - var tensors = objs.Select(o => (Tensor)o).ToArray(); - var count = tensors.Length; + var count = objs.Length; var tensorRefs = new IntPtr[count]; - for (var i = 0; i < tensors.Length; i++) tensorRefs[i] = tensors[i].Handle; + for (var i = 0; i < objs.Length; i++) tensorRefs[i] = ((Tensor)objs[i]).Handle; THSJIT_Module_forward(handle, parray.CreateArray(tensorRefs), count, parray.CreateArray, out typeCode); torch.CheckForErrors(); @@ -359,10 +358,9 @@ public object invoke(string name, params object[] objs) using (var parray = new PinnedArray()) { - var tensors = objs.Select(o => (Tensor)o).ToArray(); - var count = tensors.Length; + var count = objs.Length; var tensorRefs = new IntPtr[count]; - for (var i = 0; i < tensors.Length; i++) tensorRefs[i] = tensors[i].Handle; + for (var i = 0; i < objs.Length; i++) tensorRefs[i] = ((Tensor)objs[i]).Handle; THSJIT_Module_invoke(handle, name, parray.CreateArray(tensorRefs), count, parray.CreateArray, out typeCode); torch.CheckForErrors(); From b244ff0ee3ba95ad635dd367ebe621158594d4e2 Mon Sep 17 00:00:00 2001 From: Niklas Gustafsson Date: Tue, 4 Oct 2022 08:50:02 -0700 Subject: [PATCH 06/11] Support limited additional types in TorchScrip methods. --- src/Native/LibTorchSharp/THSJIT.cpp | 158 +++++++++--- src/Native/LibTorchSharp/THSJIT.h | 12 +- src/TorchSharp/JIT/CompilationUnit.cs | 17 +- src/TorchSharp/JIT/ScriptModule.cs | 306 +++++++++++++++++++----- test/TorchSharpTest/TestJIT.cs | 44 +++- test/TorchSharpTest/exported.method.dat | Bin 1846 -> 1974 bytes 6 files changed, 438 insertions(+), 99 deletions(-) diff --git a/src/Native/LibTorchSharp/THSJIT.cpp b/src/Native/LibTorchSharp/THSJIT.cpp index dac961f81..1d4dcf772 100644 --- a/src/Native/LibTorchSharp/THSJIT.cpp +++ b/src/Native/LibTorchSharp/THSJIT.cpp @@ -15,7 +15,7 @@ JITModule THSJIT_load(const char* filename) JITCompilationUnit THSJIT_compile(const char* script) { CATCH( - auto res = torch::jit::compile(script); + auto res = torch::jit::compile(script); return new std::shared_ptr(res); ); @@ -155,7 +155,7 @@ JITMethod THSJIT_Module_get_method(const JITModule module, const char* name) return new std::shared_ptr(copy); } -void ReturnHelper(c10::IValue result, Tensor* (*allocator)(size_t length), int8_t* typeCode) +void ReturnHelper(c10::IValue result, TensorOrScalar* (*allocator)(size_t length), int8_t* typeCode) { // TypeCode: // @@ -163,64 +163,166 @@ void ReturnHelper(c10::IValue result, Tensor* (*allocator)(size_t length), int8_ // 1 -- Single tensor // 2 -- Tuple of tensors // 3 -- List of tensors +// 4 -- Single scalar +// 5 -- Scalar tuple +// 6 -- List of scalars +// 7 -- List of scalars and tensors + + if (result.isScalar()) + { + TensorOrScalar* output = allocator(1); + output[0] = { 0, (ptrdiff_t)new torch::Scalar(result.toScalar()) }; + *typeCode = 4; + return; + } if (result.isTensor()) { - Tensor* output = allocator(1); - output[0] = ResultTensor(result.toTensor()); + TensorOrScalar* output = allocator(1); + output[0] = { 0, (ptrdiff_t)ResultTensor(result.toTensor()) }; *typeCode = 1; return; } + if (result.isTensorList()) { auto list = result.toTensorList(); *typeCode = 3; - Tensor* output = allocator(list.size()); + TensorOrScalar* output = allocator(list.size()); for (size_t i = 0; i < list.size(); i++) - output[i] = ResultTensor(list[i]); + output[i] = { 0, (ptrdiff_t)ResultTensor(list[i]) }; return; } + + if (result.isList()) + { + int foundTensor = 0; + int foundScalar = 0; + + auto& list = result.toList(); + TensorOrScalar* output = allocator(list.size()); + + for (int i = 0; i < list.size(); ++i) + { + output[i].Handle = -1; + c10::IValue value = list[i]; + + if (value.isTensor()) + { + output[i] = { 0, (ptrdiff_t)ResultTensor(value.toTensor()) }; + foundTensor += 1; + continue; + } + if (value.isScalar()) + { + output[i] = { 4, (ptrdiff_t)new torch::Scalar(value.toScalar()) }; + foundScalar += 1; + continue; + } + *typeCode = 0; + return; + } + + *typeCode = 7; + if (foundScalar == 0) + *typeCode = 3; + if (foundTensor == 0) + *typeCode = 6; + } + if (result.isTuple()) { - auto tuple = result.toTuple(); - auto list = tuple->elements(); - auto sz = list.size(); - *typeCode = 2; - Tensor* output = allocator(list.size()); - for (size_t i = 0; i < list.size(); i++) - // Assuming that all elements are tensors. - output[i] = ResultTensor(list[i].toTensor()); - return; + int foundTensor = 0; + int foundScalar = 0; + + auto& list = result.toTuple()->elements(); + TensorOrScalar* output = allocator(list.size()); + + for (int i = 0; i < list.size(); ++i) + { + output[i].Handle = -1; + c10::IValue value = list[i]; + + if (value.isTensor()) + { + output[i] = { 0, (ptrdiff_t)ResultTensor(value.toTensor()) }; + foundTensor += 1; + continue; + } + if (value.isScalar()) + { + output[i] = { 4, (ptrdiff_t)new torch::Scalar(value.toScalar()) }; + foundScalar += 1; + continue; + } + *typeCode = 0; + return; + } + + *typeCode = 7; + if (foundScalar == 0) + *typeCode = 2; + if (foundTensor == 0) + *typeCode = 5; } } -void THSJIT_Module_forward(const JITModule module, const Tensor* tensorPtrs, const int length, Tensor* (*allocator)(size_t length), int8_t* typeCode) +std::vector toIValue(const TensorOrScalar* tensorPtrs, const int length) +{ + std::vector tensors; + + if (tensorPtrs != nullptr) { + for (int i = 0; i < length; i++) + { + switch (tensorPtrs[i].TypeCode) { + case 0: + tensors.push_back(*(torch::Tensor*)(tensorPtrs[i].Handle)); + break; + case 1: + tensors.push_back(*(torch::Scalar*)(tensorPtrs[i].Handle)); + break; + case 2: + tensors.push_back(tensorPtrs[i].Handle != 0); + break; + case 3: + tensors.push_back((int)tensorPtrs[i].Handle); + break; + case 4: + tensors.push_back((long)tensorPtrs[i].Handle); + break; + } + } + } + return tensors; +} + +void THSJIT_Module_forward(const JITModule module, const TensorOrScalar* tensorPtrs, const int length, TensorOrScalar* (*allocator)(size_t length), int8_t* typeCode) { *typeCode = 0; CATCH( - auto result = (*module)->forward(toTensors((torch::Tensor**)tensorPtrs, length)); + auto result = (*module)->forward(toIValue(tensorPtrs, length)); ReturnHelper(result, allocator, typeCode); ) } -void THSJIT_Module_invoke(const JITModule module, const char* name, const Tensor* tensorPtrs, const int length, Tensor* (*allocator)(size_t length), int8_t* typeCode) +void THSJIT_Module_invoke(const JITModule module, const char* name, const TensorOrScalar* tensorPtrs, const int length, TensorOrScalar* (*allocator)(size_t length), int8_t* typeCode) { *typeCode = 0; - CATCH( - auto method = (*module)->get_method(name); - auto result = method(toTensors((torch::Tensor**)tensorPtrs, length)); - ReturnHelper(result, allocator, typeCode); - ) + //CATCH( + auto method = (*module)->get_method(name); + auto result = method(toIValue(tensorPtrs, length)); + ReturnHelper(result, allocator, typeCode); + //) } -void THSJIT_CompilationUnit_Invoke(const JITCompilationUnit module, const char* method, const Tensor* tensorPtrs, const int length, Tensor* (*allocator)(size_t length), int8_t* typeCode) +void THSJIT_CompilationUnit_Invoke(const JITCompilationUnit module, const char* method, const TensorOrScalar* tensorPtrs, const int length, TensorOrScalar* (*allocator)(size_t length), int8_t* typeCode) { *typeCode = 0; CATCH( - auto args = toTensors((torch::Tensor**)tensorPtrs, length); - auto func = (*module)->find_function(method); - auto result = (*func)(args); - ReturnHelper(result, allocator, typeCode); + auto args = toIValue(tensorPtrs, length); + auto func = (*module)->find_function(method); + auto result = (*func)(args); + ReturnHelper(result, allocator, typeCode); ) } diff --git a/src/Native/LibTorchSharp/THSJIT.h b/src/Native/LibTorchSharp/THSJIT.h index 6e53dc291..906ba6ad7 100644 --- a/src/Native/LibTorchSharp/THSJIT.h +++ b/src/Native/LibTorchSharp/THSJIT.h @@ -16,6 +16,12 @@ enum TypeKind : int8_t { // API. +struct TensorOrScalar +{ + int64_t TypeCode; + ptrdiff_t Handle; +}; + EXPORT_API(JITModule) THSJIT_load(const char* filename); EXPORT_API(void) THSJIT_save(JITModule module, const char* filename); @@ -27,10 +33,10 @@ EXPORT_API(void) THSJIT_CompilationUnit_dispose(const JITCompilationUnit module) EXPORT_API(int) THSJIT_Module_num_inputs(const JITModule method); EXPORT_API(int) THSJIT_Module_num_outputs(const JITModule method); -EXPORT_API(void) THSJIT_Module_forward(const JITModule module, const Tensor* tensorPtrs, const int length, Tensor* (*allocator)(size_t length), int8_t* typeCode); -EXPORT_API(void) THSJIT_Module_invoke(const JITModule module, const char* name, const Tensor* tensorPtrs, const int length, Tensor* (*allocator)(size_t length), int8_t* typeCode); +EXPORT_API(void) THSJIT_Module_forward(const JITModule module, const TensorOrScalar* tensorPtrs, const int length, TensorOrScalar* (*allocator)(size_t length), int8_t* typeCode); +EXPORT_API(void) THSJIT_Module_invoke(const JITModule module, const char* name, const TensorOrScalar* tensorPtrs, const int length, TensorOrScalar* (*allocator)(size_t length), int8_t* typeCode); -EXPORT_API(void) THSJIT_CompilationUnit_Invoke(const JITCompilationUnit module, const char* method, const Tensor* tensorPtrs, const int length, Tensor* (*allocator)(size_t length), int8_t* typeCode); +EXPORT_API(void) THSJIT_CompilationUnit_Invoke(const JITCompilationUnit module, const char* method, const TensorOrScalar* tensorPtrs, const int length, TensorOrScalar* (*allocator)(size_t length), int8_t* typeCode); EXPORT_API(int) THSJIT_Module_is_training(JITModule module); EXPORT_API(void) THSJIT_Module_train(JITModule module, bool on); diff --git a/src/TorchSharp/JIT/CompilationUnit.cs b/src/TorchSharp/JIT/CompilationUnit.cs index 85171e4d0..87111677a 100644 --- a/src/TorchSharp/JIT/CompilationUnit.cs +++ b/src/TorchSharp/JIT/CompilationUnit.cs @@ -8,6 +8,7 @@ using static TorchSharp.torch; using System.Net; using static TorchSharp.torch.nn; +using static TorchSharp.torch.jit.ScriptModule; namespace TorchSharp { @@ -74,25 +75,23 @@ public object invoke(string name, params object[] objs) { if (String.IsNullOrEmpty(name)) throw new ArgumentNullException("method name"); - if (!objs.All(o => typeof(Tensor).IsAssignableFrom(o.GetType()))) { - throw new NotImplementedException($"CompilationUnit.{name}() is not yet taking non-tensors as input arguments"); - } + //if (!objs.All(o => typeof(Tensor).IsAssignableFrom(o.GetType()))) { + // throw new NotImplementedException($"CompilationUnit.{name}() is not yet taking non-tensors as input arguments"); + //} - IntPtr[] ptrArray = null; + TensorOrScalar[] ptrArray = null; sbyte typeCode = 0; - using (var parray = new PinnedArray()) { + using (var parray = new PinnedArray()) { - var count = objs.Length; - var tensorRefs = new IntPtr[count]; - for (var i = 0; i < objs.Length; i++) tensorRefs[i] = ((Tensor)objs[i]).Handle; + ScriptModule.DetermineArgumentTypeRefs(objs, out int count, out TensorOrScalar[] tensorRefs); THSJIT_CompilationUnit_Invoke(handle, name, parray.CreateArray(tensorRefs), count, parray.CreateArray, out typeCode); torch.CheckForErrors(); ptrArray = parray.Array; } - return torch.jit.ScriptModule.ProcessReturnValue(name, ptrArray, typeCode); + return ScriptModule.ProcessReturnValue(name, ptrArray, typeCode); } /// diff --git a/src/TorchSharp/JIT/ScriptModule.cs b/src/TorchSharp/JIT/ScriptModule.cs index 251490fc9..a4fad5e84 100644 --- a/src/TorchSharp/JIT/ScriptModule.cs +++ b/src/TorchSharp/JIT/ScriptModule.cs @@ -15,6 +15,9 @@ public static partial class torch { public static partial class jit { + /// + /// This class represents a TorchScript module. + /// public class ScriptModule : torch.nn.Module { internal ScriptModule(IntPtr handle) : base(new HType(handle, true, THSJIT_Module_dispose), null) @@ -272,30 +275,140 @@ private Type GetType(Type type) [DllImport("LibTorchSharp")] private static extern void THSJIT_Module_forward(HType module, IntPtr tensors, int length, AllocatePinnedArray allocator, out sbyte typeCode); + /// + /// Invoke the 'forward' function of the script with any number of arguments. + /// + /// + /// + /// + /// Only certain types can currently be passed: + /// 1. Tensor + /// 2. Scalar + /// 3. int/long + /// 4. double/float + /// 5. bool + /// + /// Only certain types can currently be returned: + /// 1. Tensor / Scalar + /// 2. Tuple of Tensor / Scalar + /// 3. Array (Python list) of Tensor / Scalar + /// + /// For returned types, if the number of values returned in a tuple is greaterh than 5, it is returned as an array, instead. + /// If a tuple contains both tensors and scalars, it is returned as an object[]. + /// + /// public object forward(params object[] objs) { - if (!objs.All(o => typeof(Tensor).IsAssignableFrom(o.GetType()))) { - throw new NotImplementedException("ScriptModule.forward() taking non-tensors as input arguments"); + TensorOrScalar[] ptrArray = null; + sbyte typeCode = 0; + + using (var parray = new PinnedArray()) { + + DetermineArgumentTypeRefs(objs, out int count, out TensorOrScalar[] tensorRefs); + + THSJIT_Module_forward(handle, parray.CreateArray(tensorRefs), count, parray.CreateArray, out typeCode); + torch.CheckForErrors(); + ptrArray = parray.Array; } - IntPtr[] ptrArray = null; + return ProcessReturnValue(name, ptrArray, typeCode); + } + + [DllImport("LibTorchSharp")] + private static extern void THSJIT_Module_invoke(HType module, string name, IntPtr tensors, int length, AllocatePinnedArray allocator, out sbyte typeCode); + + [StructLayout(LayoutKind.Sequential)] + internal struct TensorOrScalar + { + public long TypeCode; + public IntPtr Handle; + } + + /// + /// Invoke a function from the script module. + /// + /// The name of the function. + /// Function arguments. + /// + /// Only certain types can currently be passed: + /// 1. Tensor + /// 2. Scalar + /// 3. int/long + /// 4. double/float + /// 5. bool + /// + /// Only certain types can currently be returned: + /// 1. Tensor / Scalar + /// 2. Tuple of Tensor / Scalar + /// 3. Array (Python list) of Tensor / Scalar + /// + /// For returned types, if the number of values returned in a tuple is greaterh than 5, it is returned as an array, instead. + /// If a tuple contains both tensors and scalars, it is returned as an object[]. + /// + public object invoke(string name, params object[] objs) + { + if (String.IsNullOrEmpty(name)) throw new ArgumentNullException("method name"); + + //if (!objs.All(o => typeof(Tensor).IsAssignableFrom(o.GetType()))) { + // throw new NotImplementedException($"ScriptModule.{name}() is not yet taking non-tensors as input arguments"); + //} + + TensorOrScalar[] ptrArray = null; sbyte typeCode = 0; - using (var parray = new PinnedArray()) { + using (var parray = new PinnedArray()) { - var count = objs.Length; - var tensorRefs = new IntPtr[count]; - for (var i = 0; i < objs.Length; i++) tensorRefs[i] = ((Tensor)objs[i]).Handle; + DetermineArgumentTypeRefs(objs, out int count, out TensorOrScalar[] tensorRefs); - THSJIT_Module_forward(handle, parray.CreateArray(tensorRefs), count, parray.CreateArray, out typeCode); + THSJIT_Module_invoke(handle, name, parray.CreateArray(tensorRefs), count, parray.CreateArray, out typeCode); torch.CheckForErrors(); ptrArray = parray.Array; } - return ProcessReturnValue("forward", ptrArray, typeCode); + return ProcessReturnValue(name, ptrArray, typeCode); + } + + internal static void DetermineArgumentTypeRefs(object[] objs, out int count, out TensorOrScalar[] tensorRefs) + { + count = objs.Length; + tensorRefs = new TensorOrScalar[count]; + for (var idx = 0; idx < objs.Length; idx++) { + switch (objs[idx]) { + case Tensor t: + tensorRefs[idx].Handle = t.Handle; + tensorRefs[idx].TypeCode = 0; + break; + case Scalar s: + tensorRefs[idx].Handle = s.Handle; + tensorRefs[idx].TypeCode = 1; + break; + case float f: + tensorRefs[idx].Handle = ((Scalar)f).Handle; + tensorRefs[idx].TypeCode = 1; + break; + case double d: + tensorRefs[idx].Handle = ((Scalar)d).Handle; + tensorRefs[idx].TypeCode = 1; + break; + case bool i: + tensorRefs[idx].Handle = (IntPtr)(i ? 1L : 0L); + tensorRefs[idx].TypeCode = 2; + break; + case int i: + tensorRefs[idx].Handle = (IntPtr)i; + tensorRefs[idx].TypeCode = 3; + break; + case long l: + tensorRefs[idx].Handle = (IntPtr)l; + tensorRefs[idx].TypeCode = 4; + break; + default: + throw new NotImplementedException($"Passing arguments of type {objs[idx].GetType().Name} to TorchScript."); + } + } } - internal static object ProcessReturnValue(string name, IntPtr[] ptrArray, sbyte typeCode) + internal static object ProcessReturnValue(string name, TensorOrScalar[] ptrArray, sbyte typeCode) { switch (typeCode) { default: @@ -303,25 +416,25 @@ internal static object ProcessReturnValue(string name, IntPtr[] ptrArray, sbyte throw new NotImplementedException($"ScriptModule.{name}() returning something else than a tensor, a tuple of tensors, or list of tensors."); case 1: // Tensor - return new Tensor(ptrArray[0]); + return new Tensor(ptrArray[0].Handle); case 2: // Tuple switch (ptrArray.Length) { case 1: - return new Tensor(ptrArray[0]); + return new Tensor(ptrArray[0].Handle); case 2: - return (new Tensor(ptrArray[0]), new Tensor(ptrArray[1])); + return (new Tensor(ptrArray[0].Handle), new Tensor(ptrArray[1].Handle)); case 3: - return (new Tensor(ptrArray[0]), new Tensor(ptrArray[1]), new Tensor(ptrArray[2])); + return (new Tensor(ptrArray[0].Handle), new Tensor(ptrArray[1].Handle), new Tensor(ptrArray[2].Handle)); case 4: - return (new Tensor(ptrArray[0]), new Tensor(ptrArray[1]), new Tensor(ptrArray[2]), new Tensor(ptrArray[3])); + return (new Tensor(ptrArray[0].Handle), new Tensor(ptrArray[1].Handle), new Tensor(ptrArray[2].Handle), new Tensor(ptrArray[3].Handle)); case 5: - return (new Tensor(ptrArray[0]), new Tensor(ptrArray[1]), new Tensor(ptrArray[2]), new Tensor(ptrArray[3]), new Tensor(ptrArray[4])); + return (new Tensor(ptrArray[0].Handle), new Tensor(ptrArray[1].Handle), new Tensor(ptrArray[2].Handle), new Tensor(ptrArray[3].Handle), new Tensor(ptrArray[4].Handle)); default: { // Too long a tuple, return as a list, instead. var result = new Tensor[ptrArray.Length]; for (var i = 0; i < ptrArray.Length; i++) { - result[i] = new Tensor(ptrArray[i]); + result[i] = new Tensor(ptrArray[i].Handle); } return result; } @@ -330,52 +443,76 @@ internal static object ProcessReturnValue(string name, IntPtr[] ptrArray, sbyte // List of tensors var result = new Tensor[ptrArray.Length]; for (var i = 0; i < ptrArray.Length; i++) { - result[i] = new Tensor(ptrArray[i]); + result[i] = new Tensor(ptrArray[i].Handle); + } + return result; + } + case 4: + // Scalar + return new Scalar(ptrArray[0].Handle); + case 5: + // Scalar tuple + switch (ptrArray.Length) { + case 1: + return new Scalar(ptrArray[0].Handle); + case 2: + return (new Scalar(ptrArray[0].Handle), new Scalar(ptrArray[1].Handle)); + case 3: + return (new Scalar(ptrArray[0].Handle), new Scalar(ptrArray[1].Handle), new Scalar(ptrArray[2].Handle)); + case 4: + return (new Scalar(ptrArray[0].Handle), new Scalar(ptrArray[1].Handle), new Scalar(ptrArray[2].Handle), new Scalar(ptrArray[3].Handle)); + case 5: + return (new Scalar(ptrArray[0].Handle), new Scalar(ptrArray[1].Handle), new Scalar(ptrArray[2].Handle), new Scalar(ptrArray[3].Handle), new Scalar(ptrArray[4].Handle)); + default: { + // Too long a tuple, return as a list, instead. + var result = new Scalar[ptrArray.Length]; + for (var i = 0; i < ptrArray.Length; i++) { + result[i] = new Scalar(ptrArray[i].Handle); + } + return result; + } + } + case 6: { + // List of scalars + var result = new Scalar[ptrArray.Length]; + for (var i = 0; i < ptrArray.Length; i++) { + result[i] = new Scalar(ptrArray[i].Handle); + } + return result; + } + case 7: { + // List of scalars and tensors + var result = new object[ptrArray.Length]; + for (var i = 0; i < ptrArray.Length; i++) { + result[i] = ptrArray[i].TypeCode == 0 ? new Tensor(ptrArray[i].Handle) : new Scalar(ptrArray[i].Handle); } return result; } } } - [DllImport("LibTorchSharp")] - private static extern void THSJIT_Module_invoke(HType module, string name, IntPtr tensors, int length, AllocatePinnedArray allocator, out sbyte typeCode); - - /// - /// Invoke a function from the script module. - /// - /// The name of the function. - /// Function arguments. - public object invoke(string name, params object[] objs) - { - if (String.IsNullOrEmpty(name)) throw new ArgumentNullException("method name"); - - if (!objs.All(o => typeof(Tensor).IsAssignableFrom(o.GetType()))) { - throw new NotImplementedException($"ScriptModule.{name}() is not yet taking non-tensors as input arguments"); - } - - IntPtr[] ptrArray = null; - sbyte typeCode = 0; - - using (var parray = new PinnedArray()) { - - var count = objs.Length; - var tensorRefs = new IntPtr[count]; - for (var i = 0; i < objs.Length; i++) tensorRefs[i] = ((Tensor)objs[i]).Handle; - - THSJIT_Module_invoke(handle, name, parray.CreateArray(tensorRefs), count, parray.CreateArray, out typeCode); - torch.CheckForErrors(); - ptrArray = parray.Array; - } - - return ProcessReturnValue(name, ptrArray, typeCode); - } - /// /// Invoke a function from the script module. /// /// The return type of the TorchScript function. /// The name of the function. /// Function arguments. + /// + /// Only certain types can currently be passed: + /// 1. Tensor + /// 2. Scalar + /// 3. int/long + /// 4. double/float + /// 5. bool + /// + /// Only certain types can currently be returned: + /// 1. Tensor / Scalar + /// 2. Tuple of Tensor / Scalar + /// 3. Array (Python list) of Tensor / Scalar + /// + /// For returned types, if the number of values returned in a tuple is greaterh than 5, it is returned as an array, instead. + /// If a tuple contains both tensors and scalars, it is returned as an object[]. + /// public TResult invoke(string name, params object[] inputs) => (TResult)invoke(name, inputs); /// @@ -385,6 +522,22 @@ public object invoke(string name, params object[] objs) /// The return type of the TorchScript function. /// The name of the function. /// Function arguments. + /// + /// Only certain types can currently be passed: + /// 1. Tensor + /// 2. Scalar + /// 3. int/long + /// 4. double/float + /// 5. bool + /// + /// Only certain types can currently be returned: + /// 1. Tensor / Scalar + /// 2. Tuple of Tensor / Scalar + /// 3. Array (Python list) of Tensor / Scalar + /// + /// For returned types, if the number of values returned in a tuple is greaterh than 5, it is returned as an array, instead. + /// If a tuple contains both tensors and scalars, it is returned as an object[]. + /// public TResult invoke(string name, params T[] inputs) => (TResult)invoke(name, inputs); } @@ -399,7 +552,22 @@ internal ScriptModule(IntPtr handle) : base(handle) { } /// /// Invoke the 'forward' function of the script with one tensor as its argument /// - /// + /// + /// Only certain types can currently be passed: + /// 1. Tensor + /// 2. Scalar + /// 3. int/long + /// 4. double/float + /// 5. bool + /// + /// Only certain types can currently be returned: + /// 1. Tensor / Scalar + /// 2. Tuple of Tensor / Scalar + /// 3. Array (Python list) of Tensor / Scalar + /// + /// For returned types, if the number of values returned in a tuple is greaterh than 5, it is returned as an array, instead. + /// If a tuple contains both tensors and scalars, it is returned as an object[]. + /// public TResult forward(params Tensor[] tensor) { return (TResult)base.forward(tensor); @@ -418,7 +586,22 @@ internal ScriptModule(IntPtr handle) : base(handle) { } /// /// Invoke the 'forward' function of the script with one tensor as its argument /// - /// + /// + /// Only certain types can currently be passed: + /// 1. Tensor + /// 2. Scalar + /// 3. int/long + /// 4. double/float + /// 5. bool + /// + /// Only certain types can currently be returned: + /// 1. Tensor / Scalar + /// 2. Tuple of Tensor / Scalar + /// 3. Array (Python list) of Tensor / Scalar + /// + /// For returned types, if the number of values returned in a tuple is greaterh than 5, it is returned as an array, instead. + /// If a tuple contains both tensors and scalars, it is returned as an object[]. + /// public TResult forward(T tensor) { return (TResult)base.forward(tensor); @@ -438,7 +621,22 @@ internal ScriptModule(IntPtr handle) : base(handle) { } /// /// Invoke the 'forward' function of the script with one tensor as its argument /// - /// + /// + /// Only certain types can currently be passed: + /// 1. Tensor + /// 2. Scalar + /// 3. int/long + /// 4. double/float + /// 5. bool + /// + /// Only certain types can currently be returned: + /// 1. Tensor / Scalar + /// 2. Tuple of Tensor / Scalar + /// 3. Array (Python list) of Tensor / Scalar + /// + /// For returned types, if the number of values returned in a tuple is greaterh than 5, it is returned as an array, instead. + /// If a tuple contains both tensors and scalars, it is returned as an object[]. + /// public TResult forward(T1 input1, T2 input2) { return (TResult)base.forward(input1, input2); diff --git a/test/TorchSharpTest/TestJIT.cs b/test/TorchSharpTest/TestJIT.cs index c5fa1890a..12671379a 100644 --- a/test/TorchSharpTest/TestJIT.cs +++ b/test/TorchSharpTest/TestJIT.cs @@ -5,6 +5,7 @@ using static TorchSharp.torch; using static TorchSharp.torch.nn; using Xunit; +using System.Security.Cryptography; #nullable enable @@ -205,15 +206,18 @@ public void TestJIT_ListOutError() public void TestLoadJIT_Methods() { // class MyModule(nn.Module): - // def __init__(self): + // def __init__(self): // super().__init__() // self.p = nn.Parameter(torch.rand(10)) // def forward(self, x: Tensor, y: Tensor) -> Tuple[Tensor, Tensor]: - // return x + y, x - y + // return x + y, x - y // // @torch.jit.export // def predict(self, x: Tensor) -> Tensor: - // return x + self.p + // return x + self.p + // @torch.jit.export + // def add_scalar(self, x: Tensor, i: int) -> Tensor: + // return x + i using var m = new TestScriptModule(@"exported.method.dat"); @@ -228,10 +232,15 @@ public void TestLoadJIT_Methods() () => Assert.Equal(x - y, output.Item2) ); - var predict = m.predict(x); + var ones = m.add_scalar(torch.zeros(10), 1); + + Assert.Equal(torch.ones(10), ones); + + var a = torch.rand(10); + var predict = m.predict(a); Assert.Multiple( - () => Assert.NotEqual(x, predict) + () => Assert.NotEqual(a, predict) ); } @@ -252,6 +261,11 @@ public Tensor predict(Tensor input) return m.invoke("predict", input); } + public Tensor add_scalar(Tensor input, int i) + { + return m.invoke("add_scalar", input, i); + } + private torch.jit.ScriptModule<(Tensor, Tensor)> m; } @@ -263,6 +277,12 @@ def relu_script(a, b): return torch.relu(a + b) def relu6_script(a, b): return torch.relu6(a + b) + def add_i(x: Tensor, i: int) -> Tensor: + return x + i + def add_d(x: Tensor, i: float) -> Tensor: + return x + i + def add_ii(x: int, i: int) -> Tuple[int,int]: + return (x + i,x-i) "; var cu = torch.jit.compile(script); @@ -271,10 +291,24 @@ def relu6_script(a, b): var x = torch.randn(3, 4); var y = torch.randn(3, 4); + + var zeros = torch.zeros(3, 4); + var ones = torch.ones(3, 4); + var z = (Tensor)cu.invoke("relu_script", x, y); Assert.Equal(torch.nn.functional.relu(x + y), z); z = cu.invoke("relu6_script", x, y); Assert.Equal(torch.nn.functional.relu6(x + y), z); + z = cu.invoke("add_i", zeros, 1); + Assert.Equal(ones, z); + z = cu.invoke("add_d", zeros, 1.0); + Assert.Equal(ones, z); + + var ss = cu.invoke<(Scalar,Scalar)>("add_ii", 3, 1); + Assert.Multiple( + () => Assert.Equal(4, ss.Item1.ToInt32()), + () => Assert.Equal(2, ss.Item2.ToInt32()) + ); } } } diff --git a/test/TorchSharpTest/exported.method.dat b/test/TorchSharpTest/exported.method.dat index d8686b69b546b93096d196de5c86384b62096d84..310139cb1d69b7ac2b7280180a920e88b2ea344a 100644 GIT binary patch delta 950 zcmdnSw~c>-gDFqy7P}k^E_<%CtL@4&xb08r2->riEwgJlG|T$V;+=M@CMnq;iQ8ls z;LXkq8S8y|lH)cR!dRz`V7-pPe*!jtc?^4GgLal$o5O`YhQb=W|l^}FTK2~&C3 z@hfEH3+y_1j>Yhrzw#olB$bOlzD?EHE#S2Lr-1{T`lwo?@arPA-XUrc+Sd)Hv%>m1$rfp_*z zs8`uvRj>W8+T)n@mSEm%&Ow5_xr}-)6#Q_ zH)P%Vb9ToDQxC1wq6=%o-KT8Y&A;Ibvx(xfot%HKEmNrw$(T4#J4HfbderRDx&!tX z&Yt61vPAIZ{sr;JoqzW4vTn1FTxegz3<^$R)|+Z`UI7E;D9w+!8O^6jf=(2$&z96MHcWAb15_f!^M4&c-+k(iN~HsiaUy73mL(=A(!xOFG-{J!)& ze}_?VgUgZ$VXu^fPOcGjwQ9fC>@dZ5>9Y&Dx3eSeIBi(Rz46YwwqLIb4HmYC`)ny)Y9jw+u}Wq3qZJ)j_SL_0jCI~xF-egIqvS%F7ao+VC3oG@BEs*Bpa4WI0J=I{1ONa4 delta 872 zcmdnSzm0E#gYAw+9s657E9^Yp>)Wl)x3Jd=GPZy5_oVF+P7k|+GYWPGJcI2sZpYfa z={K@diwm?5@Mh;Ye@jx!fPsO*V6q^qEVBdy*W~@Iq8xe!**R`bYzz~3wF&T)1QxP; zGiZ1-d$W{ic(Y9QVT|EyKXtP61yI|I$r~Bv8966kWK^kVWrzZS-T?2s!wLd-pNlMB z)AQsP!?at{E_YMXT{)yP0yK=J9D4KP4a%;*+tu`aW~Kk;3Q^^lYuvBi_+FPYoX4<4 z<>;%+Du)BTDq3IqJeSLxZIYf{|HDhLr@&NENqj=Y&IgCwAEvp)tH(?+tJY87YO_xy zyVAWz+LBKt32u=4R$*6aH zOK-llwr{S`x^wuC;57E_^`@&Y?<@SaDZuriuf-49FCQ2RZ*UiPbZvT7wPW@5{M-wA zL494BO{Uv=CV8vAl~1Zve#>(IqG4LhyE)HQM{SCg6Dc9P`rMoX_R0 zKJ~ZOugLVUb`>`}9=*)F=9p;mxu@4V6fR8t#|R3ZX*wU)-35k4958q$NHNPYNKF37 zCdcCDB*HK`kXd9B>mp7>a$C)6Ee%gaKur>mP=t~VlNs5xWs$>M5k-CnBLg_3v)Ke> zk*(E6QT2=oqG}hL38U`h_iSFwQ&<=#8?s7F4q-Pi5@LlY4)A7U2eGcOLHIz|17nvJ P%wuCkp9fT#rkvV1=! From 7cc7d9b45b057d4aa7a405b76fa219dd26cd7dd3 Mon Sep 17 00:00:00 2001 From: Niklas Gustafsson Date: Tue, 4 Oct 2022 09:02:30 -0700 Subject: [PATCH 07/11] Fixed a C++ problem for Clang --- src/Native/LibTorchSharp/THSJIT.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Native/LibTorchSharp/THSJIT.cpp b/src/Native/LibTorchSharp/THSJIT.cpp index 1d4dcf772..120ad67e8 100644 --- a/src/Native/LibTorchSharp/THSJIT.cpp +++ b/src/Native/LibTorchSharp/THSJIT.cpp @@ -197,7 +197,7 @@ void ReturnHelper(c10::IValue result, TensorOrScalar* (*allocator)(size_t length int foundTensor = 0; int foundScalar = 0; - auto& list = result.toList(); + auto list = result.toList(); TensorOrScalar* output = allocator(list.size()); for (int i = 0; i < list.size(); ++i) From c9c8e81a4343d7714cbfe8ef85e501cd3407d83d Mon Sep 17 00:00:00 2001 From: Niklas Gustafsson Date: Tue, 4 Oct 2022 09:19:42 -0700 Subject: [PATCH 08/11] Trying to make Clang on MacOS happy. --- src/Native/LibTorchSharp/THSJIT.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Native/LibTorchSharp/THSJIT.cpp b/src/Native/LibTorchSharp/THSJIT.cpp index 120ad67e8..b0b8e1c0c 100644 --- a/src/Native/LibTorchSharp/THSJIT.cpp +++ b/src/Native/LibTorchSharp/THSJIT.cpp @@ -285,7 +285,7 @@ std::vector toIValue(const TensorOrScalar* tensorPtrs, const int le tensors.push_back((int)tensorPtrs[i].Handle); break; case 4: - tensors.push_back((long)tensorPtrs[i].Handle); + tensors.push_back(c10::IValue(tensorPtrs[i].Handle)); break; } } From 19e3142eaa0f55af36ac7f534528d31c8e2181bd Mon Sep 17 00:00:00 2001 From: Niklas Gustafsson Date: Tue, 4 Oct 2022 09:38:13 -0700 Subject: [PATCH 09/11] More work to make Clang/MacOS happy. --- src/Native/LibTorchSharp/THSJIT.cpp | 6 +++--- src/TorchSharp/JIT/ScriptModule.cs | 7 +++++-- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/src/Native/LibTorchSharp/THSJIT.cpp b/src/Native/LibTorchSharp/THSJIT.cpp index b0b8e1c0c..3b65faf81 100644 --- a/src/Native/LibTorchSharp/THSJIT.cpp +++ b/src/Native/LibTorchSharp/THSJIT.cpp @@ -284,9 +284,9 @@ std::vector toIValue(const TensorOrScalar* tensorPtrs, const int le case 3: tensors.push_back((int)tensorPtrs[i].Handle); break; - case 4: - tensors.push_back(c10::IValue(tensorPtrs[i].Handle)); - break; + //case 4: + // tensors.push_back(c10::IValue(tensorPtrs[i].Handle)); // Clang on MacOS doesn't like. Pass as Scalar from .NET. + // break; } } } diff --git a/src/TorchSharp/JIT/ScriptModule.cs b/src/TorchSharp/JIT/ScriptModule.cs index a4fad5e84..d52ddc277 100644 --- a/src/TorchSharp/JIT/ScriptModule.cs +++ b/src/TorchSharp/JIT/ScriptModule.cs @@ -399,8 +399,11 @@ internal static void DetermineArgumentTypeRefs(object[] objs, out int count, out tensorRefs[idx].TypeCode = 3; break; case long l: - tensorRefs[idx].Handle = (IntPtr)l; - tensorRefs[idx].TypeCode = 4; + tensorRefs[idx].Handle = ((Scalar)l).Handle; + tensorRefs[idx].TypeCode = 1; + // The MacOS version of Clang doesn't like the use of int64_t, so pass as a Scalar instance, instead. + //tensorRefs[idx].Handle = (IntPtr)l; + //tensorRefs[idx].TypeCode = 4; break; default: throw new NotImplementedException($"Passing arguments of type {objs[idx].GetType().Name} to TorchScript."); From e8ad1e863a4b54d5f9028c441175316937a85e23 Mon Sep 17 00:00:00 2001 From: Niklas Gustafsson Date: Tue, 4 Oct 2022 10:26:50 -0700 Subject: [PATCH 10/11] PR feedback accounted for. Readying a new release. --- RELEASENOTES.md | 1 + build/BranchInfo.props | 4 +-- src/Native/LibTorchSharp/THSJIT.cpp | 36 ++++++--------------------- src/TorchSharp/JIT/CompilationUnit.cs | 4 --- src/TorchSharp/JIT/ScriptModule.cs | 4 --- test/TorchSharpTest/TestJIT.cs | 2 +- 6 files changed, 11 insertions(+), 40 deletions(-) diff --git a/RELEASENOTES.md b/RELEASENOTES.md index fe55eb04b..3b2d7e68c 100644 --- a/RELEASENOTES.md +++ b/RELEASENOTES.md @@ -28,6 +28,7 @@ __Fixed Bugs:__ #744 Some of functions with inconsistent argument names
#749 functional.linear is wrong
#761 Stateful optimizers should have support for save/load from disk.
+#771 Support more types for ScriptModule
__API Changes__: diff --git a/build/BranchInfo.props b/build/BranchInfo.props index bb95578ae..d598e1c84 100644 --- a/build/BranchInfo.props +++ b/build/BranchInfo.props @@ -1,8 +1,8 @@ 0 - 97 - 6 + 98 + 0 diff --git a/src/Native/LibTorchSharp/THSJIT.cpp b/src/Native/LibTorchSharp/THSJIT.cpp index 3b65faf81..5db131b9f 100644 --- a/src/Native/LibTorchSharp/THSJIT.cpp +++ b/src/Native/LibTorchSharp/THSJIT.cpp @@ -307,11 +307,11 @@ void THSJIT_Module_invoke(const JITModule module, const char* name, const Tensor { *typeCode = 0; - //CATCH( + CATCH( auto method = (*module)->get_method(name); auto result = method(toIValue(tensorPtrs, length)); ReturnHelper(result, allocator, typeCode); - //) + ) } void THSJIT_CompilationUnit_Invoke(const JITCompilationUnit module, const char* method, const TensorOrScalar* tensorPtrs, const int length, TensorOrScalar* (*allocator)(size_t length), int8_t* typeCode) @@ -387,14 +387,17 @@ void THSJIT_TensorType_dispose(const JITTensorType type) delete type; } +void THSJIT_CompilationUnit_dispose(const JITCompilationUnit module) +{ + delete module; +} + void* THSJIT_Type_cast(const JITType type) { switch ((*type)->kind()) { case c10::TypeKind::TensorType: return new std::shared_ptr((*type)->cast()); - //case c10::TypeKind::DimensionedTensorType: - // return new std::shared_ptr((*type)->cast()); default: return NULL; } @@ -433,8 +436,6 @@ int8_t THSJIT_Type_kind(const JITType type) { case c10::TypeKind::TensorType: return (int8_t)TypeKind::TensorType; - //case c10::TypeKind::DimensionedTensorType: - // return (int8_t)TypeKind::DimensionedTensorType; default: return -1; } @@ -448,29 +449,6 @@ JITType THSJIT_Module_getInputType(JITModule module, int8_t index) return new std::shared_ptr(schema.arguments()[1 + index].type()->cast()); } -//int8_t THSJIT_getScalarFromDimensionedTensorType(const JITDimensionedTensorType type) -//{ -// return (int8_t)(*type)->scalarType(); -//} -// -//int THSJIT_getDimensionedTensorTypeDimensions(const JITDimensionedTensorType type) -//{ -// return (*type)->dim(); -//} -// -//const char* THSJIT_getDimensionedTensorDevice(const JITDimensionedTensorType type) -//{ -// auto device = (*type)->device(); -// -// auto device_type = DeviceTypeName(device.type()); -// -// std::transform(device_type.begin(), device_type.end(), device_type.begin(), ::tolower); -// -// return make_sharable_string(device_type); -//} - - - void THSJIT_typeDispose(const JITType type) { delete type; diff --git a/src/TorchSharp/JIT/CompilationUnit.cs b/src/TorchSharp/JIT/CompilationUnit.cs index 87111677a..10aae0667 100644 --- a/src/TorchSharp/JIT/CompilationUnit.cs +++ b/src/TorchSharp/JIT/CompilationUnit.cs @@ -75,10 +75,6 @@ public object invoke(string name, params object[] objs) { if (String.IsNullOrEmpty(name)) throw new ArgumentNullException("method name"); - //if (!objs.All(o => typeof(Tensor).IsAssignableFrom(o.GetType()))) { - // throw new NotImplementedException($"CompilationUnit.{name}() is not yet taking non-tensors as input arguments"); - //} - TensorOrScalar[] ptrArray = null; sbyte typeCode = 0; diff --git a/src/TorchSharp/JIT/ScriptModule.cs b/src/TorchSharp/JIT/ScriptModule.cs index d52ddc277..0a683859d 100644 --- a/src/TorchSharp/JIT/ScriptModule.cs +++ b/src/TorchSharp/JIT/ScriptModule.cs @@ -349,10 +349,6 @@ public object invoke(string name, params object[] objs) { if (String.IsNullOrEmpty(name)) throw new ArgumentNullException("method name"); - //if (!objs.All(o => typeof(Tensor).IsAssignableFrom(o.GetType()))) { - // throw new NotImplementedException($"ScriptModule.{name}() is not yet taking non-tensors as input arguments"); - //} - TensorOrScalar[] ptrArray = null; sbyte typeCode = 0; diff --git a/test/TorchSharpTest/TestJIT.cs b/test/TorchSharpTest/TestJIT.cs index 12671379a..e138b068e 100644 --- a/test/TorchSharpTest/TestJIT.cs +++ b/test/TorchSharpTest/TestJIT.cs @@ -285,7 +285,7 @@ def add_ii(x: int, i: int) -> Tuple[int,int]: return (x + i,x-i) "; - var cu = torch.jit.compile(script); + using var cu = torch.jit.compile(script); Assert.NotNull(cu); From 3412c2e417197a600ff638459e99f624c96af5ea Mon Sep 17 00:00:00 2001 From: Niklas Gustafsson Date: Tue, 4 Oct 2022 10:29:36 -0700 Subject: [PATCH 11/11] Adding a couple of comments to release notes. --- RELEASENOTES.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/RELEASENOTES.md b/RELEASENOTES.md index 3b2d7e68c..3ecc76dce 100644 --- a/RELEASENOTES.md +++ b/RELEASENOTES.md @@ -36,6 +36,8 @@ Module.to(), cpu(), and cuda() were redone as extension methods. The virtual met Support for saving and restoring hyperparameters and state of optimizers
Loss functions are now Modules rather than delegates.
Custom modules should now use generic versions as base classes.
+ScriptModule supports calling methods other than forward()
+Added torch.jit.compile().
## NuGet Version 0.97.6