14 changes: 12 additions & 2 deletions RELEASENOTES.md
@@ -14,20 +14,30 @@ Doing so (rather than qualifying names with 'TorchSharp.') was already recommended

__Loss functions are now aligned with the PyTorch APIs.__ This is a major change and the reason for incrementing the minor version number. The most direct consequence is that losses are modules rather than delegates, which means you need to call .forward() to actually compute the loss. Also, the factories are in torch.nn rather than torch.nn.functional and have the same Pascal-case names as the corresponding types. The members of the torch.nn.functional static class are now proper immediate loss functions, whereas the previous ones returned a loss delegate.
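
A hedged sketch of the new pattern, using cross-entropy as the example (`prediction` and `target` stand in for tensors from your training loop, and the functional member is assumed to be named `cross_entropy`, mirroring PyTorch):

```C#
// The factory lives in torch.nn and uses the Pascal-case name of the loss type.
var loss_fn = torch.nn.CrossEntropyLoss();

// Losses are modules now, so the loss value is computed by calling forward().
var loss = loss_fn.forward(prediction, target);

// Members of torch.nn.functional compute the loss immediately rather than returning a delegate.
var same = torch.nn.functional.cross_entropy(prediction, target);
```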

__Generic Module base class.__ The second major change is that Module is made type-safe with respect to the `forward()` function. Module is now an abstract base class, and interfaces `IModule<T,TResult>`, `IModule<T1,T2,TResult>`,... are introduced to define the signature of the `forward()` function. For most custom modules, this means that the base class has to be changed to `Module<Tensor,Tensor>`, but some modules may require more significant changes.
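
For a typical single-tensor module, the migration is mostly a matter of changing the base class; a minimal sketch, assuming the usual `using TorchSharp;`, `using static TorchSharp.torch;` and `using static TorchSharp.torch.nn;` directives:

```C#
private class MyModule : Module<Tensor, Tensor>   // previously: Module
{
    private Module<Tensor, Tensor> lin = Linear(10, 5);

    public MyModule() : base("MyModule")
    {
        RegisterComponents();
    }

    public override Tensor forward(Tensor input) => lin.forward(input);
}
```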

ScriptModule follows this pattern, but this version introduces `ScriptModule<T...,TResult>` base classes, with corresponding `torch.jit.load<T...,TResult>()` static factory methods.
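
For example, a TorchScript module whose `forward()` takes and returns a single tensor might now be loaded like this (the file name is illustrative):

```C#
var script = torch.jit.load<Tensor, Tensor>("model.script.pt");
var output = script.forward(torch.randn(16, 100));
```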

__Fixed Bugs:__

#323 forward() should take a variable-length list of arguments<br/>
#558 Fix deviation from the Pytorch loss function/module APIs<br/>
#742 Ease of use: Module.to method should be generic T -> T<br/>
#743 Ease of use: module factories should have dtype and device<br/>
#745 Executing a TorchScript that returns multiple values, throws an exception<br/>
#744 Some of functions with inconsistent argument names<br/>
#749 functional.linear is wrong<br/>
#761 Stateful optimizers should have support for save/load from disk.
#761 Stateful optimizers should have support for save/load from disk.<br/>
#771 Support more types for ScriptModule<br/>

__API Changes__:

Module.to(), cpu(), and cuda() were redone as extension methods. The virtual methods to override, if necessary, are now named '_to'. A need to do so should be extremely rare.<br/>
Support for saving and restoring hyperparameters and state of optimizers<br/>
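
A small sketch of what the generic `to()` extension in the first item enables (`MyModule` stands in for any custom module type):

```C#
// The extension method returns the same static type it was invoked on,
// so no cast is needed after moving the module to a device.
MyModule model = new MyModule().to(torch.CUDA);
```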

Loss functions are now Modules rather than delegates.<br/>
Custom modules should now use generic versions as base classes.<br/>
ScriptModule supports calling methods other than forward()<br/>
Added torch.jit.compile().<br/>

## NuGet Version 0.97.6

4 changes: 2 additions & 2 deletions build/BranchInfo.props
@@ -1,8 +1,8 @@
<Project>
<PropertyGroup>
<MajorVersion>0</MajorVersion>
<MinorVersion>97</MinorVersion>
<PatchVersion>6</PatchVersion>
<MinorVersion>98</MinorVersion>
<PatchVersion>0</PatchVersion>
</PropertyGroup>

</Project>
56 changes: 30 additions & 26 deletions docfx/articles/modules.md
@@ -4,12 +4,26 @@ Unfortunately, the word 'Module' is one of the most overloaded terms in software.

In the context of TorchSharp, it means the same as in PyTorch: the fundamental building block of all models is the 'Module' class. All neural network layers are derived from Module and the way to create a model for training and inference in your code is to create a new Module. Without it, back-propagation will not work.

## Sequential

For the most basic network architecture, which actually covers a surprising breadth, there is a simple way to create modules -- the 'Sequential' class. This class is created from a list of Modules (i.e. model components). When data is passed to the model in the form of a tensor, the Sequential instance will invoke each submodule in the order they were passed when the Sequential instance was created, passing the output from each layer to the next. The output of the final submodule will be the output of the Sequential instance.

```C#
var seq = Sequential(("lin1", Linear(100, 10)), ("lin2", Linear(10, 5)));
...
seq.forward(some_data);
```

There is no real performance reason to use Sequential instead of rolling your own custom module, but there is much less code to write. That said, if using a custom module (up next) is your preference, for whatever reason, that's what you should do. It can be useful to create a custom module first, debug it, and then convert to using Sequential.

Sequential can only string together modules that take a single tensor and return a single tensor -- anything else needs to be a custom module.

## Custom Modules

A custom module is created by deriving a subclass from torch.nn.Module. One that is equivalent to the previous example looks like this:
A custom module is created by deriving a subclass from torch.nn.Module<T...,TResult>. The leading generic parameters denote the input types of the module's `forward()` method, and the last one denotes its return type:

```C#
private class TestModule1 : Module
private class TestModule1 : Module<Tensor,Tensor>
{
public TestModule1()
: base("TestModule1")
@@ -36,9 +50,11 @@ Custom modules should always call `RegisterComponents()` once all submodules and

The `forward()` method contains the computation of the module. It can contain a mix of TorchSharp primitives, layers, as well as any .NET code. Note, however, that only TorchSharp APIs are capable of operating on data residing in CUDA memory. Therefore, if performance is of the essence, expressing all computation in terms of TorchSharp APIs is essential. Non-TorchSharp APIs should be limited to things that aren't related to the tensor data, such as logging.

In PyTorch, the `forward()` method takes an arbitrary number of arguments, of any type, and supports using default arguments. Currently, TorchSharp only defines two versions of `forward()` -- taking one or two tensors, and returning a single tensor. The TorchSharp implementation of Sequential assumes that it is passing only a single tensor along between its layers. Therefore, any model that needs to pass multiple arguments between layers will have to be custom.
In PyTorch, the `forward()` method takes an arbitrary number of arguments, of any type, and supports using default arguments.

Note that the local variable 'x' was declared in a using statement. This is important in order to deallocate the native memory associated with it as soon as it is no longer needed. For that reason, it is important to pull out temporaries like this into local variables, especially when the code is running on a GPU. (More on this at: [Dispose vs. GC in TorchSharp](memory.md))
In TorchSharp, the signature of `forward()` is determined by the type parameter signature of the Module<> base class. In most cases, the input and output are both `torch.Tensor`. As noted above, the Sequential module collection class assumes that it is passing only a single tensor along between its layers. Therefore, any model that needs to pass multiple arguments between layers will have to be custom.

Going back to the code inside the `forward()` method -- please note that the local variable 'x' was declared in a using statement. This is important in order to deallocate the native memory associated with it as soon as it is no longer needed. For that reason, it is important to pull out temporaries like this into local variables, especially when the code is running on a GPU. (More on this at: [Dispose vs. GC in TorchSharp](memory.md))
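
For reference, a minimal sketch of that pattern (the layer names are illustrative, not the exact code from the collapsed example above):

```C#
public override Tensor forward(Tensor input)
{
    // Dispose of the intermediate tensor as soon as it has been consumed,
    // rather than leaving it for the garbage collector to find later.
    using var x = lin1.forward(input);
    return lin2.forward(x);
}
```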

In other words, the following code is less memory efficient, because it delays reclaiming native memory until the next time GC is run:

@@ -48,18 +64,6 @@ In other words, the following code is less memory efficient, because it delays r
return lin2.forward(lin1.forward(input));
}
```
## Sequential

For the simplest network architecture, which actually covers a surprising breadth, there is a simplified way to create modules -- the 'Sequential' class. This class is created from a list of Modules (i.e. model components). When data is passed to the model, the Sequential instance will invoke each submodule in the order they were passed when the Sequential instance was created, passing the output from each layer to the next. The output of the final submodule will be the output of the Sequential instance.

```C#
var seq = Sequential(("lin1", Linear(100, 10)), ("lin2", Linear(10, 5)));
...
seq.forward(some_data);
```

There is no real performance reason to use Sequential instead of rolling your own custom module, but there is much less code to write. That said, if using a custom module (up next) is your preference, for whatever reason, that's what you should do. It can be useful to create a custom module first, debug it, and then convert to using Sequential.


## Using Sequential Inside A Custom Module

@@ -131,12 +135,12 @@ To illustrate, this is the code for MobileNet from the TorchSharp examples:

## ModuleList

In some circumstances, it's useful to define a dynamic number of modules in a custom module. It could be because you want to parameterize the network architecture, or dynamically choose which layers to run, or just that it's tedious to define so many fields. This may be addressed by using a ModuleList to contain the submodules. Unlike Sequential, ModuleList itself does not suffice -- its `forward()` method will throw an exception if invoked.
In some circumstances, it's useful to define a dynamic number of modules in a custom module. It could be because you want to parameterize the network architecture, or dynamically choose which layers to run, or just that it's tedious to define so many fields. This may be addressed by using a ModuleList to contain the submodules. Unlike Sequential, ModuleList itself does not suffice -- its `forward()` method will throw an exception if invoked, and you must iterate through the modules of the list directly in your `forward()` implementation.

The purpose is simply to provide a list implementation that automatically registers the submodules when components are registered. You have to iterate through the list in the `forward()` method:

```C#
private class TestModule1 : Module
private class TestModule1 : Module<Tensor,Tensor>
{
public TestModule1()
: base("TestModule1")
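    // The rest of this example is collapsed in the diff above. What follows is an
    // illustrative sketch, not the upstream code: it assumes ModuleList can be built
    // from a params list of modules and that elements may need a cast to Module<Tensor,Tensor>.
    {
        RegisterComponents();
    }

    public override Tensor forward(Tensor input)
    {
        // ModuleList does not forward on its own; iterate its members explicitly.
        foreach (var layer in submodules)
            input = ((Module<Tensor, Tensor>)layer).forward(input);
        return input;
    }

    private ModuleList submodules = new ModuleList(Linear(100, 10), Linear(10, 5));
}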
@@ -157,12 +161,12 @@ The purpose is simply to provide a list implementation that automatically regist

## ModuleDict

In some circumstances, it's useful to define a dynamic number of modules in a custom module. It could be because you want to parameterize the network architecture, or dynamically choose which layers to run, or just that it's tedious to define so many fields. This may be addressed by using a ModuleList to contain the submodules. Unlike Sequential, ModuleList itself does not suffice -- its `forward()` method will throw an exception if invoked.
In some circumstances, it's useful to define a dynamic number of modules in a custom module and use a dictionary to hold them. Like with `ModuleList`, it could be because you want to parameterize the network architecture, or dynamically choose which layers to run, or just that it's tedious to define so many fields. This may be addressed by using a ModuleDict to contain the submodules.

The purpose is simply to provide a dictionary implementation that automatically registers the submodules when components are registered. You have to look the submodules up in the dictionary in the `forward()` method:

```C#
private class TestModule1 : Module
private class TestModule1 : Module<Tensor,Tensor>
{
public TestModule1()
: base("TestModule1")
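    // The rest of this example is collapsed in the diff above. What follows is an
    // illustrative sketch, not the upstream code: it assumes ModuleDict registers
    // submodules by name via Add() and exposes a string indexer.
    {
        dict.Add("lin1", Linear(100, 10));
        dict.Add("lin2", Linear(10, 5));
        RegisterComponents();
    }

    public override Tensor forward(Tensor input)
    {
        // Look the submodules up by name and invoke them explicitly.
        using var x = ((Module<Tensor, Tensor>)dict["lin1"]).forward(input);
        return ((Module<Tensor, Tensor>)dict["lin2"]).forward(x);
    }

    private ModuleDict dict = new ModuleDict();
}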
@@ -191,7 +195,7 @@ Many modules are just compositions of existing modules, but sometimes it will im
For example, a re-implementation of 'Linear' would look something like:

```C#
private class MyLinear : Module
private class MyLinear : Module<Tensor,Tensor>
{
public MyLinear(long input_size, long output_size)
: base("MyLinear")
@@ -212,9 +216,9 @@
}
```

In this case, we're not relying on 'using' in the `forward()` method, because the temporary is reused as the target by the `add_()` function.
In this case, we're not relying on 'using' in the `forward()` method, because the "temporary" is reused as the target by the `add_()` function.

Parameter's dirty little secret is that it will clean out the tensor that is given to its constructor. So, `Parameter()` is preferably used with another tensor factory (such as in the example above), or a cloned tensor.
Parameter's dirty little secret is that it will clean out the tensor that is given to its constructor. So, `Parameter()` is preferably used with another tensor factory (such as in the example above), or a cloned tensor. Once you have passed a tensor to the Parameter constructor, the original tensor is invalidated.
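
To make the hazard concrete, a hedged illustration (the names are just for the example):

```C#
// Safe: the tensor comes straight from a factory, so nothing else refers to it
// once the Parameter takes ownership.
var weight = new Parameter(torch.randn(5, 100));

// Safe: pass a clone when the original tensor is still needed elsewhere,
// because the tensor handed to the constructor is invalidated.
var init = torch.zeros(5, 100);
var bias = new Parameter(init.clone());
```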

## ParameterList

@@ -226,7 +230,7 @@ Much like ModuleDict, ParameterDict is a dictionary of Parameter instances, whic

## Buffers

Sometimes, a module needs to allocate tensors that are not trainable, i.e. their values are not modified during back-propagation. An example is a random dropout mask. These are referred to as 'buffers' as opposed to 'parameters' and are treated differently by `RegisterComponents()` -- even though they are not trainable, the native runtime still wants to know about them for other purposes, so it is important to declare them in the module.
Sometimes, a module needs to allocate tensors that are not trainable, i.e. their values are not modified during back-propagation. An example is a random dropout mask. These are referred to as 'buffers' as opposed to 'parameters' and are treated differently by `RegisterComponents()` -- even though they are not trainable, the native runtime still wants to know about them for other purposes, such as storing them to disk, so it is important to declare them in the module.

Each buffer should be declared as a field of type 'Tensor' (not 'Parameter'). This will ensure that the buffer is registered properly when `RegisterComponents()` is called.
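
A minimal sketch of the distinction (the field names are illustrative):

```C#
// Trainable: a Parameter field is registered as a parameter by RegisterComponents().
private Parameter weight = new Parameter(torch.randn(10, 10));

// Not trainable: a plain Tensor field is registered as a buffer instead.
private Tensor running_stats = torch.zeros(10);
```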

@@ -238,7 +242,7 @@ It is sometimes necessary to create a new model from an existing one and discard
So, for example:

```C#
private class TestModule1 : Module
private class TestModule1 : Module<Tensor,Tensor>
{
public TestModule1()
: base("TestModule1")
@@ -259,7 +263,7 @@ So, for example:
private Module lin2;
}

private class TestModule2 : Module
private class TestModule2 : Module<Tensor,Tensor>
{
public TestModule2()
: base("TestModule2")