# Using TorchScript

This tutorial is an introduction to TorchScript, an intermediate representation of a PyTorch model (subclass of nn.Module) that can then be run in a high-performance environment such as C++.

In [3]:
import torch  # This is all you need to use both PyTorch and TorchScript!
import torch.nn as nn
print(torch.__version__)

1.5.0+cu101



This tutorial covers specific methods for converting PyTorch modules to TorchScript, our high-performance deployment runtime:

1. Tracing an existing module
2. Using scripting to directly compile a module
3. How to compose both approaches
4. Saving and loading TorchScript modules


In [0]:
class MyCell(nn.Module):
    def __init__(self):
        super().__init__()
    
    def forward(self, x, h):
        new_h = torch.tanh(x + h)
        return new_h, new_h

In [0]:
my_cell = MyCell()
x = torch.rand(3, 4)
h = torch.rand(3, 4)

In [6]:
print(my_cell(x, h))

(tensor([[0.5610, 0.0762, 0.8869, 0.6510],
        [0.9313, 0.8051, 0.6170, 0.1723],
        [0.7224, 0.7644, 0.7500, 0.6112]]), tensor([[0.5610, 0.0762, 0.8869, 0.6510],
        [0.9313, 0.8051, 0.6170, 0.1723],
        [0.7224, 0.7644, 0.7500, 0.6112]]))


# Simple Net

In [0]:
class MyCell(torch.nn.Module):
    def __init__(self):
        super(MyCell, self).__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.linear(x) + h)
        return new_h, new_h

In [0]:
my_cell = MyCell()
print(my_cell)
print(my_cell(x, h))

# Basics of TorchScript


TorchScript provides two ways of converting Python code to a lower-level representation:

1. Tracing
2. Scripting



Now let’s take our running example and see how we can apply TorchScript.

In short, TorchScript provides tools to capture the definition of your model, even in light of the flexible and dynamic nature of PyTorch. Let’s begin by examining what we call tracing.

# Tracing

In [0]:
class MyCell(torch.nn.Module):
    def __init__(self):
        super(MyCell, self).__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.linear(x) + h)
        return new_h, new_h

my_cell = MyCell()

In [0]:
x, h = torch.rand(3, 4), torch.rand(3, 4)

In [11]:
traced_cell = torch.jit.trace(my_cell, (x, h))
print(traced_cell)

MyCell(
  original_name=MyCell
  (linear): Linear(original_name=Linear)
)


In [12]:
traced_cell(x, h)

(tensor([[ 0.9313, -0.0346,  0.8589,  0.7222],
         [ 0.4903, -0.0333,  0.8580,  0.7414],
         [ 0.8313,  0.1273,  0.8266,  0.1831]], grad_fn=<TanhBackward>),
 tensor([[ 0.9313, -0.0346,  0.8589,  0.7222],
         [ 0.4903, -0.0333,  0.8580,  0.7414],
         [ 0.8313,  0.1273,  0.8266,  0.1831]], grad_fn=<TanhBackward>))

We’ve rewound a bit and taken the second version of our MyCell class. As before, we’ve instantiated it, but this time we’ve called torch.jit.trace, passing in the Module along with example inputs the network might see.

What exactly has this done? It has invoked the Module, recorded the operations that occurred when the Module was run, and created an instance of torch.jit.ScriptModule (of which TracedModule is a subclass).
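As a quick check of that claim, the sketch below traces a small stand-in module and inspects the type of the result. The `nn.Sequential` model and the name `stand_in` are illustrative assumptions, not part of the tutorial's running example.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for MyCell: one linear layer followed by tanh.
stand_in = nn.Sequential(nn.Linear(4, 4), nn.Tanh())
example_input = torch.rand(3, 4)

# Tracing runs the module once on the example input and records the ops.
traced = torch.jit.trace(stand_in, example_input)

# The traced object is a ScriptModule and is still callable like the original.
print(type(traced).__name__)
print(isinstance(traced, torch.jit.ScriptModule))
out = traced(example_input)
print(tuple(out.shape))
```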

TorchScript records its definitions in an Intermediate Representation (or IR), commonly referred to in deep learning as a graph. We can examine the graph with the .graph property:

In [13]:
print(traced_cell.graph)

graph(%self.1 : __torch__.MyCell,
      %input : Float(3, 4),
      %h : Float(3, 4)):
  %19 : __torch__.torch.nn.modules.linear.Linear = prim::GetAttr[name="linear"](%self.1)
  %21 : Tensor = prim::CallMethod[name="forward"](%19, %input)
  %12 : int = prim::Constant[value=1]() # <ipython-input-9-c84ad9de827c>:7:0
  %13 : Float(3, 4) = aten::add(%21, %h, %12) # <ipython-input-9-c84ad9de827c>:7:0
  %14 : Float(3, 4) = aten::tanh(%13) # <ipython-input-9-c84ad9de827c>:7:0
  %15 : (Float(3, 4), Float(3, 4)) = prim::TupleConstruct(%14, %14)
  return (%15)



However, this is a very low-level representation and most of the information contained in the graph is not useful for end users. Instead, we can use the .code property to give a Python-syntax interpretation of the code:

In [14]:
print(traced_cell.code)

def forward(self,
    input: Tensor,
    h: Tensor) -> Tuple[Tensor, Tensor]:
  _0 = torch.add((self.linear).forward(input, ), h, alpha=1)
  _1 = torch.tanh(_0)
  return (_1, _1)
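The graph can also be walked programmatically rather than printed. The following is a minimal sketch, assuming a free function `step` (a hypothetical name) that mirrors the arithmetic in `MyCell.forward` without the linear layer; it lists the operator kind of each node next to the `.code` rendering of the same trace.

```python
import torch

def step(x, h):
    # Same arithmetic as the cell's forward, minus the linear layer.
    return torch.tanh(x + h)

traced_step = torch.jit.trace(step, (torch.rand(3, 4), torch.rand(3, 4)))

# Each node in the IR graph has a "kind" such as aten::add or aten::tanh.
op_kinds = [node.kind() for node in traced_step.graph.nodes()]
print(op_kinds)

# The .code property renders the same graph in Python-like syntax.
print(traced_step.code)
```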



# But Why TorchScript?

So why did we do all this? There are several reasons:

1. TorchScript code can be invoked in its own interpreter, which is basically a restricted Python interpreter. This interpreter does not acquire the Global Interpreter Lock, so many requests can be processed on the same instance simultaneously.
2. This format allows us to save the whole model to disk and load it into another environment, such as a server written in a language other than Python.
3. TorchScript gives us a representation in which we can do compiler optimizations on the code to provide more efficient execution.
4. TorchScript allows us to interface with many backend/device runtimes that require a broader view of the program than individual operators.

We can see that invoking traced_cell produces the same results as the Python module:

In [15]:
print(my_cell(x, h))
print(traced_cell(x, h))

(tensor([[ 0.9313, -0.0346,  0.8589,  0.7222],
        [ 0.4903, -0.0333,  0.8580,  0.7414],
        [ 0.8313,  0.1273,  0.8266,  0.1831]], grad_fn=<TanhBackward>), tensor([[ 0.9313, -0.0346,  0.8589,  0.7222],
        [ 0.4903, -0.0333,  0.8580,  0.7414],
        [ 0.8313,  0.1273,  0.8266,  0.1831]], grad_fn=<TanhBackward>))
(tensor([[ 0.9313, -0.0346,  0.8589,  0.7222],
        [ 0.4903, -0.0333,  0.8580,  0.7414],
        [ 0.8313,  0.1273,  0.8266,  0.1831]], grad_fn=<TanhBackward>), tensor([[ 0.9313, -0.0346,  0.8589,  0.7222],
        [ 0.4903, -0.0333,  0.8580,  0.7414],
        [ 0.8313,  0.1273,  0.8266,  0.1831]], grad_fn=<TanhBackward>))
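Comparing printed tensors by eye gets tedious, so the agreement can be checked numerically instead. This sketch assumes a class `TinyCell` (a hypothetical copy of the second `MyCell`) and uses `torch.allclose`; it also tries a batch size different from the one used for tracing, which works here because none of the recorded ops depend on the batch dimension.

```python
import torch
import torch.nn as nn

class TinyCell(nn.Module):  # hypothetical copy of the second MyCell
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.linear(x) + h)
        return new_h, new_h

cell = TinyCell()
x, h = torch.rand(3, 4), torch.rand(3, 4)
traced = torch.jit.trace(cell, (x, h))

# Same inputs as the trace: eager and traced outputs should agree.
same_inputs_match = torch.allclose(cell(x, h)[0], traced(x, h)[0])

# Fresh inputs with a different batch size.
x2, h2 = torch.rand(5, 4), torch.rand(5, 4)
new_inputs_match = torch.allclose(cell(x2, h2)[0], traced(x2, h2)[0])

print(same_inputs_match, new_inputs_match)
```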


# Scripting

Using Scripting to Convert Modules

There’s a reason we traced the version of our module without control flow. Let’s see what happens when a submodule contains a data-dependent branch:

In [0]:
class MyDecisionGate(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return x
        else:
            return -x

class MyCell(torch.nn.Module):
    def __init__(self, dg):
        super(MyCell, self).__init__()
        self.dg = dg
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.dg(self.linear(x)) + h)
        return new_h, new_h

my_cell = MyCell(MyDecisionGate())

In [19]:
traced_cell = torch.jit.trace(my_cell, (x, h))
# print(traced_cell.code)



In [20]:
print(traced_cell.code)

def forward(self,
    input: Tensor,
    h: Tensor) -> Tuple[Tensor, Tensor]:
  _0 = self.dg
  _1 = (self.linear).forward(input, )
  _2 = (_0).forward(_1, )
  _3 = torch.tanh(torch.add(_1, h, alpha=1))
  return (_3, _3)



Looking at the .code output, we can see that the if-else branch is nowhere to be found! 

Why? Tracing does exactly what we said it would: run the code, record the operations that happen, and construct a ScriptModule that does exactly that. Unfortunately, data-dependent control flow is erased: only the branch taken for the example inputs is recorded.
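The practical consequence can be shown on the gate alone. In this sketch the gate is traced with an example whose sum is negative, so only the negation branch is recorded; the traced gate then negates every input, including ones the eager module would pass through unchanged. The tensor names are illustrative, and the warning filter is there only because tracing a tensor-dependent branch emits a TracerWarning.

```python
import warnings

import torch
import torch.nn as nn

class MyDecisionGate(nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return x
        else:
            return -x

gate = MyDecisionGate()
positive = torch.ones(2, 2)
negative = -torch.ones(2, 2)

with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # silence the TracerWarning for this demo
    # The example input takes the "else" branch, so only negation is recorded.
    traced_gate = torch.jit.trace(gate, negative)

print(traced_gate.code)
print(gate(positive))         # eager: input returned unchanged
print(traced_gate(positive))  # traced: input negated, because the branch is gone
```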

How can we faithfully represent this module in TorchScript? We provide a script compiler, which does direct analysis of your Python source code to transform it into TorchScript. Let’s convert MyDecisionGate using the script compiler:

In [0]:
scripted_gate = torch.jit.script(MyDecisionGate())

In [0]:
my_cell = MyCell(scripted_gate)

In [0]:
scripted_cell = torch.jit.script(my_cell)

In [24]:
print(scripted_cell.code)

def forward(self,
    x: Tensor,
    h: Tensor) -> Tuple[Tensor, Tensor]:
  _0 = (self.dg).forward((self.linear).forward(x, ), )
  new_h = torch.tanh(torch.add(_0, h, alpha=1))
  return (new_h, new_h)



In [25]:
# New inputs
x, h = torch.rand(3, 4), torch.rand(3, 4)
scripted_cell(x, h)

(tensor([[0.1735, 0.3896, 0.1230, 0.9110],
         [0.4432, 0.7596, 0.7057, 0.8592],
         [0.1503, 0.7608, 0.0539, 0.8040]], grad_fn=<TanhBackward>),
 tensor([[0.1735, 0.3896, 0.1230, 0.9110],
         [0.4432, 0.7596, 0.7057, 0.8592],
         [0.1503, 0.7608, 0.0539, 0.8040]], grad_fn=<TanhBackward>))
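To confirm that scripting keeps the branch, the gate can be scripted in isolation and exercised on inputs that take each path. This is a small sketch; the tensors `positive` and `negative` are illustrative names.

```python
import torch
import torch.nn as nn

class MyDecisionGate(nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return x
        else:
            return -x

scripted_gate = torch.jit.script(MyDecisionGate())

# The compiled code still contains the conditional.
print(scripted_gate.code)

positive = torch.ones(2, 2)
negative = -torch.ones(2, 2)

# Both branches behave as in eager mode.
print(scripted_gate(positive))
print(scripted_gate(negative))
```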

# Mixing Scripting and Tracing

Some situations call for using tracing rather than scripting (e.g. a module has many architectural decisions that are made based on constant Python values that we would like to not appear in TorchScript). 

In this case, scripting can be composed with tracing: torch.jit.script will inline the code for a traced module, and tracing will inline the code for a scripted module.
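To illustrate the first situation, here is a sketch of a module whose architecture depends on a plain Python value. `ConfigurableBlock` and its `activation` argument are hypothetical names. Because the condition is an ordinary Python string comparison rather than a tensor operation, tracing records only the chosen activation and the configuration logic does not appear in the result.

```python
import torch
import torch.nn as nn

class ConfigurableBlock(nn.Module):  # hypothetical example module
    def __init__(self, activation):
        super().__init__()
        self.activation = activation  # plain Python string, fixed at construction
        self.linear = nn.Linear(4, 4)

    def forward(self, x):
        out = self.linear(x)
        # Architectural decision based on a constant Python value.
        if self.activation == "relu":
            return torch.relu(out)
        return torch.tanh(out)

traced_block = torch.jit.trace(ConfigurableBlock("relu"), torch.rand(3, 4))

# Only the relu path was executed, so only relu shows up in the traced code.
print(traced_block.code)
```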

In [0]:
class MyRNNLoop(torch.nn.Module):
    def __init__(self):
        super(MyRNNLoop, self).__init__()
        self.cell = torch.jit.trace(MyCell(scripted_gate), (x, h))

    def forward(self, xs):
        h, y = torch.zeros(3, 4), torch.zeros(3, 4)
        for i in range(xs.size(0)):
            y, h = self.cell(xs[i], h)
        return y, h

rnn_loop = torch.jit.script(MyRNNLoop())

In [27]:
print(rnn_loop.code)

def forward(self,
    xs: Tensor) -> Tuple[Tensor, Tensor]:
  h = torch.zeros([3, 4], dtype=None, layout=None, device=None, pin_memory=None)
  y = torch.zeros([3, 4], dtype=None, layout=None, device=None, pin_memory=None)
  y0 = y
  h0 = h
  for i in range(torch.size(xs, 0)):
    _0 = (self.cell).forward(torch.select(xs, 0, i), h0, )
    y1, h1, = _0
    y0, h0 = y1, h1
  return (y0, h0)
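Because the loop is scripted rather than traced, its trip count follows the input instead of being fixed at the length of an example. The sketch below repeats the definitions above so that it runs on its own, then feeds sequences of two different lengths; the example tensors used for tracing the cell are arbitrary.

```python
import torch
import torch.nn as nn

class MyDecisionGate(nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return x
        else:
            return -x

class MyCell(nn.Module):
    def __init__(self, dg):
        super().__init__()
        self.dg = dg
        self.linear = nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.dg(self.linear(x)) + h)
        return new_h, new_h

class MyRNNLoop(nn.Module):
    def __init__(self):
        super().__init__()
        scripted_gate = torch.jit.script(MyDecisionGate())
        example = (torch.rand(3, 4), torch.rand(3, 4))
        # Traced cell wrapping a scripted gate, as in the tutorial.
        self.cell = torch.jit.trace(MyCell(scripted_gate), example)

    def forward(self, xs):
        h, y = torch.zeros(3, 4), torch.zeros(3, 4)
        for i in range(xs.size(0)):
            y, h = self.cell(xs[i], h)
        return y, h

rnn_loop = torch.jit.script(MyRNNLoop())

# The same scripted module handles sequences of different lengths.
shapes = {}
for seq_len in (2, 7):
    y, h = rnn_loop(torch.rand(seq_len, 3, 4))
    shapes[seq_len] = tuple(y.shape)
print(shapes)
```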



In [28]:
class WrapRNN(torch.nn.Module):
    def __init__(self):
        super(WrapRNN, self).__init__()
        self.loop = torch.jit.script(MyRNNLoop())

    def forward(self, xs):
        y, h = self.loop(xs)
        return torch.relu(y)

traced = torch.jit.trace(WrapRNN(), torch.rand(10, 3, 4))
print(traced.code)

def forward(self,
    argument_1: Tensor) -> Tensor:
  _0, h, = (self.loop).forward(argument_1, )
  return torch.relu(h)



This way, scripting and tracing can each be used where the situation calls for it, and the two can be used together.

# Saving and Loading Models

We provide APIs to save and load TorchScript modules to/from disk in an archive format. 

This format includes code, parameters, attributes, and debug information, meaning that the archive is a freestanding representation of the model that can be loaded in an entirely separate process. Let’s save and load our wrapped RNN module:

In [0]:
traced.save('wrapped_rnn.zip')

loaded = torch.jit.load('wrapped_rnn.zip')

print(loaded)
print(loaded.code)
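A round trip can be verified by comparing outputs before and after loading. The following is a minimal sketch, assuming a small scripted stand-in model and a temporary directory so that no file is left behind; the file name `model.pt` is arbitrary, and `map_location="cpu"` simply pins the loaded tensors to the CPU.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Hypothetical stand-in for the wrapped RNN: any ScriptModule saves the same way.
model = torch.jit.script(nn.Sequential(nn.Linear(4, 4), nn.Tanh()))
example = torch.rand(3, 4)

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "model.pt")
    model.save(path)                                    # write the archive
    loaded = torch.jit.load(path, map_location="cpu")   # read it back

# The loaded module carries its own code and parameters.
round_trip_ok = torch.allclose(model(example), loaded(example))
print(round_trip_ok)
```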