In [1]:
import torch
import numpy as np

## JIT: just-in-time compilation with TorchScript

Workflow: develop in **eager mode** (prototype, debug, train, experiment) → convert with **tracing** or **scripting** → run in **script mode** (optimization, use from other languages, deployment).

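### Tracing

`torch.jit.trace` runs the function once on example inputs and records the tensor operations that were actually executed. Python control flow is not captured: a data-dependent `if` is frozen to whichever branch the example input took, and PyTorch warns about the offending lines while tracing.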

In [2]:
def my_function(x):
    """Pick a constant tensor based on the values stored in ``x``.

    Returns ``tensor(1.)`` when ``x.mean()`` is strictly greater than 1.0,
    and ``tensor(2.)`` otherwise. Because the branch depends on tensor *data*,
    this function is a handy probe for comparing tracing against scripting.
    """
    if x.mean() > 1.0:
        return torch.tensor(1.0)
    return torch.tensor(2.0)

In [3]:
# Tracing runs my_function once on the example input and records only the tensor ops it
# executed. example_inputs is a tuple: `(t)` is just `t`, so the trailing comma is needed.
ftrace = torch.jit.trace(my_function, (torch.ones(2, 2),))

  if x.mean() > 1.0:
  r = torch.tensor(2.0)


In [4]:
ftrace.graph  # no prim::If: only the branch taken for the example input was recorded (constant 2)

graph(%x : Float(2, 2, strides=[2, 1], requires_grad=0, device=cpu)):
  %5 : Float(requires_grad=0, device=cpu) = prim::Constant[value={2}]() # /tmp/ipykernel_173664/1534898225.py:3:0
  %6 : Device = prim::Constant[value="cpu"]() # /tmp/ipykernel_173664/1534898225.py:3:0
  %7 : int = prim::Constant[value=6]() # /tmp/ipykernel_173664/1534898225.py:3:0
  %8 : bool = prim::Constant[value=0]() # /tmp/ipykernel_173664/1534898225.py:3:0
  %9 : bool = prim::Constant[value=0]() # /tmp/ipykernel_173664/1534898225.py:3:0
  %10 : NoneType = prim::Constant()
  %11 : Float(requires_grad=0, device=cpu) = aten::to(%5, %6, %7, %8, %9, %10) # /tmp/ipykernel_173664/1534898225.py:3:0
  %12 : Float(requires_grad=0, device=cpu) = aten::detach(%11) # /tmp/ipykernel_173664/1534898225.py:3:0
  return (%12)

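To call the JIT’ed function, call it like a regular Python function. A traced *function* is a `torch.jit.ScriptFunction`, which has no `forward()` method (only traced or scripted `nn.Module`s do):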

In [5]:
x = torch.ones((2, 2))  # mean is exactly 1.0, so `x.mean() > 1.0` is False

In [7]:
ftrace(x)  # a traced function (ScriptFunction) is called directly; it has no .forward()

tensor(2.)

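### Scripting

`torch.jit.script` compiles the function's Python source with the TorchScript compiler instead of recording one execution. Data-dependent control flow is therefore preserved, and it shows up as a `prim::If` node in the graph.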

In [8]:
@torch.jit.script
def my_function(x: torch.Tensor) -> torch.Tensor:
    """Scripted twin of the eager ``my_function`` defined earlier in this notebook.

    Returns ``tensor(1.)`` when ``x.mean() > 1.0`` and ``tensor(2.)`` otherwise.
    ``torch.jit.script`` compiles this source, so the data-dependent ``if`` is
    kept as a ``prim::If`` node instead of being frozen to one branch as tracing does.

    Note: this rebinds the name ``my_function``; ``ftrace`` still holds the traced
    eager version.
    """
    if x.mean() > 1.0:
        r = torch.tensor(1.0)
    else:
        r = torch.tensor(2.0)
    return r

In [9]:
my_function.graph  # the data-dependent branch survives as a prim::If node with two blocks

graph(%x.1 : Tensor):
  %10 : bool = prim::Constant[value=0]()
  %2 : NoneType = prim::Constant()
  %4 : float = prim::Constant[value=1.]() # /tmp/ipykernel_173664/1734323466.py:3:18
  %12 : float = prim::Constant[value=2.]() # /tmp/ipykernel_173664/1734323466.py:6:25
  %3 : Tensor = aten::mean(%x.1, %2) # /tmp/ipykernel_173664/1734323466.py:3:7
  %5 : Tensor = aten::gt(%3, %4) # /tmp/ipykernel_173664/1734323466.py:3:7
  %7 : bool = aten::Bool(%5) # /tmp/ipykernel_173664/1734323466.py:3:7
  %r : Tensor = prim::If(%7) # /tmp/ipykernel_173664/1734323466.py:3:4
    block0():
      %r.1 : Tensor = aten::tensor(%4, %2, %2, %10) # /tmp/ipykernel_173664/1734323466.py:4:12
      -> (%r.1)
    block1():
      %r.3 : Tensor = aten::tensor(%12, %2, %2, %10) # /tmp/ipykernel_173664/1734323466.py:6:12
      -> (%r.3)
  return (%r)

In [10]:
type(my_function)  # @torch.jit.script turned the Python function into a ScriptFunction

torch.jit.ScriptFunction

In [11]:
x = torch.ones((2, 2))  # mean is exactly 1.0 -> scripted function takes the else branch

In [12]:
my_function(x)  # mean == 1.0 is not > 1.0 -> else branch -> tensor(2.)

tensor(2.)

In [13]:
x = torch.full((2, 2), 2.0)  # every element is 2.0, so the mean now exceeds 1.0

In [14]:
my_function(x)  # mean == 2.0 > 1.0 -> tensor(1.): control flow was preserved (the trace would still return tensor(2.))

tensor(1.)