TorchScript is an intermediate representation of a PyTorch model that can be run in a high-performance environment such as C++

In [1]:
import torch
print(torch.__version__)

1.11.0


A module consists of three main components:
1. A constructor
2. A set of Parameters and sub-Modules
3. A forward function

In [2]:
class MyCell(torch.nn.Module):
    def __init__(self):
        super(MyCell, self).__init__()
        
    def forward(self, x, h):
        new_h = torch.tanh(x + h)
        return new_h, new_h
    
my_cell = MyCell()
x = torch.rand(3, 4)
h = torch.rand(3, 4)

print(my_cell(x, h))

(tensor([[0.6459, 0.7476, 0.7187, 0.7204],
        [0.8514, 0.8134, 0.7587, 0.7128],
        [0.7636, 0.8745, 0.5874, 0.7938]]), tensor([[0.6459, 0.7476, 0.7187, 0.7204],
        [0.8514, 0.8134, 0.7587, 0.7128],
        [0.7636, 0.8745, 0.5874, 0.7938]]))


### Building a hierarchy of Modules

By using a `torch.nn.Linear` module within the `MyCell` module, we are building a hierarchy of `Module`s.

In [3]:
class MyCell(torch.nn.Module):
    def __init__(self):
        super(MyCell, self).__init__()
        self.linear = torch.nn.Linear(4, 4)
        
    def forward(self, x, h):
        new_h = torch.tanh(self.linear(x) + h)
        return new_h, new_h
    
my_cell = MyCell()
print(my_cell)
print(my_cell(x, h))

MyCell(
  (linear): Linear(in_features=4, out_features=4, bias=True)
)
(tensor([[-0.4871, -0.0639,  0.7874, -0.0934],
        [ 0.5208, -0.5190,  0.7104,  0.2067],
        [ 0.5020, -0.3728,  0.1434, -0.0503]], grad_fn=<TanhBackward0>), tensor([[-0.4871, -0.0639,  0.7874, -0.0934],
        [ 0.5208, -0.5190,  0.7104,  0.2067],
        [ 0.5020, -0.3728,  0.1434, -0.0503]], grad_fn=<TanhBackward0>))


In [4]:
class MyDecisionGate(torch.nn.Module):
    def forward(self, x):
        # utilizing control flow (if statement)
        if x.sum() > 0:
            return x
        else:
            return -x

class MyCell(torch.nn.Module):
    def __init__(self):
        super(MyCell, self).__init__()
        self.dg = MyDecisionGate()
        self.linear = torch.nn.Linear(4, 4)
        
    def forward(self, x, h):
        new_h = torch.tanh(self.dg(self.linear(x)) + h)
        return new_h, new_h
    
my_cell = MyCell()
print(my_cell)
print(my_cell(x, h))

MyCell(
  (dg): MyDecisionGate()
  (linear): Linear(in_features=4, out_features=4, bias=True)
)
(tensor([[ 0.2657,  0.4078,  0.4647,  0.6438],
        [ 0.8248,  0.1254,  0.3423,  0.9201],
        [ 0.8192,  0.2591, -0.5026,  0.9065]], grad_fn=<TanhBackward0>), tensor([[ 0.2657,  0.4078,  0.4647,  0.6438],
        [ 0.8248,  0.1254,  0.3423,  0.9201],
        [ 0.8192,  0.2591, -0.5026,  0.9065]], grad_fn=<TanhBackward0>))


### Gradient tape

PyTorch use a gradient tape to record operations as they occur, and replay them backwards in computing derivatives.

### Basics of TorchScript

TorchScript provides tools to capture the definition of your model using `tracing`

In [5]:
class MyCell(torch.nn.Module):
    def __init__(self):
        super(MyCell, self).__init__()
        self.linear = torch.nn.Linear(4, 4)
        
    def forward(self, x, h):
        new_h = torch.tanh(self.linear(x) + h)
        return new_h, new_h

my_cell = MyCell()
x, h = torch.rand(3, 4), torch.rand(3, 4)
traced_cell = torch.jit.trace(my_cell, (x, h))
print(traced_cell)
traced_cell(x, h)

MyCell(
  original_name=MyCell
  (linear): Linear(original_name=Linear)
)


(tensor([[ 0.1199,  0.4323,  0.4247,  0.3825],
         [-0.0318,  0.4097,  0.8126,  0.8418],
         [-0.6526,  0.7871,  0.8530,  0.5522]], grad_fn=<TanhBackward0>),
 tensor([[ 0.1199,  0.4323,  0.4247,  0.3825],
         [-0.0318,  0.4097,  0.8126,  0.8418],
         [-0.6526,  0.7871,  0.8530,  0.5522]], grad_fn=<TanhBackward0>))

TorchScript records its definitions in an Intermediate Representation (IR), commonly referred to in Deep learning as a `graph`

In [6]:
print(traced_cell.graph)

graph(%self.1 : __torch__.MyCell,
      %x : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu),
      %h : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu)):
  %linear : __torch__.torch.nn.modules.linear.Linear = prim::GetAttr[name="linear"](%self.1)
  %20 : Tensor = prim::CallMethod[name="forward"](%linear, %x)
  %11 : int = prim::Constant[value=1]() # /var/folders/yh/yw23prz55635x1mx37sy_kp40000gn/T/ipykernel_6336/4016154612.py:7:0
  %12 : Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu) = aten::add(%20, %h, %11) # /var/folders/yh/yw23prz55635x1mx37sy_kp40000gn/T/ipykernel_6336/4016154612.py:7:0
  %13 : Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu) = aten::tanh(%12) # /var/folders/yh/yw23prz55635x1mx37sy_kp40000gn/T/ipykernel_6336/4016154612.py:7:0
  %14 : (Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu), Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu)) = prim::TupleConstruct(%13, %13)
  return (%14)



`graph` is hard to understand. Instead, we can use the `.code` property to give a Python-syntax interpretation of the low-level `graph`

In [7]:
print(traced_cell.code)

def forward(self,
    x: Tensor,
    h: Tensor) -> Tuple[Tensor, Tensor]:
  linear = self.linear
  _0 = torch.tanh(torch.add((linear).forward(x, ), h))
  return (_0, _0)



### Benefits of using TorchScript:
1. TorchScript can be invoked in its own interpreter that does not acquire the Global Interpreter Lock (GIL), and so many requests can be processed on the same instance simultaneously.

2. This format allows us to save the whole model to disk and load it into another environment, such as in a server written in a language other than Python.

3. TorchScript gives us a representation in which we can do compiler optimization.

4. TorchScript allows us to interface with many backend/device runtimes.

In [10]:
# verify that traced cell produces the same results as the Python module
print(my_cell(x, h))
print(traced_cell(x, h))

(tensor([[ 0.1199,  0.4323,  0.4247,  0.3825],
        [-0.0318,  0.4097,  0.8126,  0.8418],
        [-0.6526,  0.7871,  0.8530,  0.5522]], grad_fn=<TanhBackward0>), tensor([[ 0.1199,  0.4323,  0.4247,  0.3825],
        [-0.0318,  0.4097,  0.8126,  0.8418],
        [-0.6526,  0.7871,  0.8530,  0.5522]], grad_fn=<TanhBackward0>))
(tensor([[ 0.1199,  0.4323,  0.4247,  0.3825],
        [-0.0318,  0.4097,  0.8126,  0.8418],
        [-0.6526,  0.7871,  0.8530,  0.5522]], grad_fn=<TanhBackward0>), tensor([[ 0.1199,  0.4323,  0.4247,  0.3825],
        [-0.0318,  0.4097,  0.8126,  0.8418],
        [-0.6526,  0.7871,  0.8530,  0.5522]], grad_fn=<TanhBackward0>))


### Script compiler

Tracing does not include `control flows` such as `if-else statements` and `for-loop`

Use `script compiler` to analyze Python source code to transform it into TorchScript

In [13]:
class MyDecisionGate(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return x
        else:
            return -x
        
class MyCell(torch.nn.Module):
    def __init__(self, dg):
        super(MyCell, self).__init__()
        self.dg = dg
        self.linear = torch.nn.Linear(4, 4)
        
    def forward(self, x, h):
        new_h = torch.tanh(self.dg(self.linear(x)) + h)
        return new_h, new_h
    
my_cell = MyCell(MyDecisionGate())
traced_cell = torch.jit.trace(my_cell, (x, h))

print(traced_cell.dg.code)
print(traced_cell.code)

def forward(self,
    argument_1: Tensor) -> NoneType:
  return None

def forward(self,
    x: Tensor,
    h: Tensor) -> Tuple[Tensor, Tensor]:
  dg = self.dg
  linear = self.linear
  _0 = (linear).forward(x, )
  _1 = (dg).forward(_0, )
  _2 = torch.tanh(torch.add(_0, h))
  return (_2, _2)



  if x.sum() > 0:


Use `script compiler` to convert `MyDecisionGate`

In [19]:
scripted_gate = torch.jit.script(MyDecisionGate())

my_cell = MyCell(scripted_gate)
scripted_cell = torch.jit.script(my_cell)

print(scripted_gate.code)
print(scripted_cell.code)

def forward(self,
    x: Tensor) -> Tensor:
  if bool(torch.gt(torch.sum(x), 0)):
    _0 = x
  else:
    _0 = torch.neg(x)
  return _0

def forward(self,
    x: Tensor,
    h: Tensor) -> Tuple[Tensor, Tensor]:
  dg = self.dg
  linear = self.linear
  _0 = torch.add((dg).forward((linear).forward(x, ), ), h)
  new_h = torch.tanh(_0)
  return (new_h, new_h)



In [15]:
# Test the TorchScript
x, h = torch.rand(3, 4), torch.rand(3, 4)
traced_cell(x, h)

(tensor([[-3.8196e-01,  7.5139e-01,  6.2731e-01,  8.3418e-01],
         [ 8.4364e-04,  6.2280e-01,  6.5460e-01,  8.2321e-01],
         [-3.1970e-01,  6.8387e-01,  8.3628e-01,  8.7024e-01]],
        grad_fn=<TanhBackward0>),
 tensor([[-3.8196e-01,  7.5139e-01,  6.2731e-01,  8.3418e-01],
         [ 8.4364e-04,  6.2280e-01,  6.5460e-01,  8.2321e-01],
         [-3.1970e-01,  6.8387e-01,  8.3628e-01,  8.7024e-01]],
        grad_fn=<TanhBackward0>))

### Mixing Scripting and Tracing

Scripting can be composed with tracing as the following two cases:
- `torch.jit.script` will inline the code for a traced module
- `torch.jit.trace` will inline the code for a scripted module

In [24]:
# first case
class MyRNNLoop(torch.nn.Module):
    def __init__(self):
        super(MyRNNLoop, self).__init__()
        self.cell = torch.jit.trace(MyCell(scripted_gate), (x, h))
        
    def forward(self, xs):
        h, y = torch.zeros(3, 4), torch.zeros(3, 4)
        for i in range(xs.size(0)):
            y, h = self.cell(xs[i], h)
        return y, h

rnn_loop = torch.jit.script(MyRNNLoop())
print(rnn_loop.code)

def forward(self,
    xs: Tensor) -> Tuple[Tensor, Tensor]:
  h = torch.zeros([3, 4])
  y = torch.zeros([3, 4])
  y0 = y
  h0 = h
  for i in range(torch.size(xs, 0)):
    cell = self.cell
    _0 = (cell).forward(torch.select(xs, 0, i), h0, )
    y1, h1, = _0
    y0, h0 = y1, h1
  return (y0, h0)



In [25]:
# second case
class WrapRNN(torch.nn.Module):
    def __init__(self):
        super(WrapRNN, self).__init__()
        self.loop = torch.jit.script(MyRNNLoop())

    def forward(self, xs):
        y, h = self.loop(xs)
        return torch.relu(y)
    
traced = torch.jit.trace(
    WrapRNN(),
    (torch.rand(10, 3, 4))
)

print(traced.code)

def forward(self,
    xs: Tensor) -> Tensor:
  loop = self.loop
  _0, y, = (loop).forward(xs, )
  return torch.relu(y)



### Saving and Loading models

The TorchScript format includes code, parameters, attributes, and debug information. It can be loaded into an entirely separate process.

In [26]:
traced.save('wrapped_rnn.pt')

loaded = torch.jit.load('wrapped_rnn.pt')

print(loaded)
print(loaded.code)

RecursiveScriptModule(
  original_name=WrapRNN
  (loop): RecursiveScriptModule(
    original_name=MyRNNLoop
    (cell): RecursiveScriptModule(
      original_name=MyCell
      (dg): RecursiveScriptModule(original_name=MyDecisionGate)
      (linear): RecursiveScriptModule(original_name=Linear)
    )
  )
)
def forward(self,
    xs: Tensor) -> Tensor:
  loop = self.loop
  _0, y, = (loop).forward(xs, )
  return torch.relu(y)

