In [48]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        # torch.jit.trace produces a ScriptModule's conv1 and conv2
        self.conv1 = torch.jit.trace(nn.Conv2d(1, 20, 5), torch.rand(1, 1, 16, 16))
        self.conv2 = torch.jit.trace(nn.Conv2d(20, 20, 5), torch.rand(1, 20, 16, 16))

    def forward(self, input):
        input = F.relu(self.conv1(input))
        input = F.relu(self.conv2(input))
        return input

scripted_module = torch.jit.script(MyModule())
print(scripted_module)

RecursiveScriptModule(
  original_name=MyModule
  (conv1): Conv2d(original_name=Conv2d)
  (conv2): Conv2d(original_name=Conv2d)
)


RecursiveScriptModule(
  original_name=MyModule
  (conv1): Conv2d(original_name=Conv2d)
  (conv2): Conv2d(original_name=Conv2d)
)

In [43]:
import torch
import torch.nn as nn

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()

    @torch.jit.export
    def some_entry_point(self, input):
        return input + 10

    @torch.jit.ignore
    def python_only_fn(self, input):
        # This function won't be compiled, so any
        # Python APIs can be used
        import pdb
        pdb.set_trace()

    def forward(self, input):
        if self.training:
            self.python_only_fn(input)
        return input * 99

scripted_module = torch.jit.script(MyModule())
print(scripted_module.some_entry_point(torch.randn(2, 2)))
print(scripted_module(torch.randn(2, 2)))

tensor([[ 7.5513,  8.1656],
        [ 8.9144, 11.2090]])
--Return--
None
> [1;32mc:\users\zuolu\appdata\local\temp\ipykernel_25144\2145007978.py[0m(17)[0;36mpython_only_fn[1;34m()[0m



RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
  File "C:\Users\zuolu\AppData\Local\Temp\ipykernel_25144\2145007978.py", line 21, in forward
    def forward(self, input):
        if self.training:
            self.python_only_fn(input)
            ~~~~~~~~~~~~~~~~~~~ <--- HERE
        return input * 99
RuntimeError: BdbQuit: <EMPTY MESSAGE>

At:
  d:\miniconda3\envs\mlc\lib\bdb.py(154): dispatch_return
  d:\miniconda3\envs\mlc\lib\bdb.py(92): trace_dispatch
  C:\Users\zuolu\AppData\Local\Temp\ipykernel_25144\2145007978.py(17): python_only_fn
  d:\miniconda3\envs\mlc\lib\site-packages\torch\jit\_recursive.py(1069): lazy_binding_method
  d:\miniconda3\envs\mlc\lib\site-packages\torch\nn\modules\module.py(1541): _call_impl
  d:\miniconda3\envs\mlc\lib\site-packages\torch\nn\modules\module.py(1532): _wrapped_call_impl
  C:\Users\zuolu\AppData\Local\Temp\ipykernel_25144\2145007978.py(26): <module>
  C:\Users\zuolu\AppData\Roaming\Python\Python38\site-packages\IPython\core\interactiveshell.py(3508): run_code
  C:\Users\zuolu\AppData\Roaming\Python\Python38\site-packages\IPython\core\interactiveshell.py(3448): run_ast_nodes
  C:\Users\zuolu\AppData\Roaming\Python\Python38\site-packages\IPython\core\interactiveshell.py(3269): run_cell_async
  C:\Users\zuolu\AppData\Roaming\Python\Python38\site-packages\IPython\core\async_helpers.py(129): _pseudo_sync_runner
  C:\Users\zuolu\AppData\Roaming\Python\Python38\site-packages\IPython\core\interactiveshell.py(3064): _run_cell
  C:\Users\zuolu\AppData\Roaming\Python\Python38\site-packages\IPython\core\interactiveshell.py(3009): run_cell
  C:\Users\zuolu\AppData\Roaming\Python\Python38\site-packages\ipykernel\zmqshell.py(549): run_cell
  C:\Users\zuolu\AppData\Roaming\Python\Python38\site-packages\ipykernel\ipkernel.py(449): do_execute
  C:\Users\zuolu\AppData\Roaming\Python\Python38\site-packages\ipykernel\kernelbase.py(778): execute_request
  C:\Users\zuolu\AppData\Roaming\Python\Python38\site-packages\ipykernel\ipkernel.py(362): execute_request
  C:\Users\zuolu\AppData\Roaming\Python\Python38\site-packages\ipykernel\kernelbase.py(437): dispatch_shell
  C:\Users\zuolu\AppData\Roaming\Python\Python38\site-packages\ipykernel\kernelbase.py(534): process_one
  C:\Users\zuolu\AppData\Roaming\Python\Python38\site-packages\ipykernel\kernelbase.py(545): dispatch_queue
  d:\miniconda3\envs\mlc\lib\asyncio\events.py(81): _run
  d:\miniconda3\envs\mlc\lib\asyncio\base_events.py(1859): _run_once
  d:\miniconda3\envs\mlc\lib\asyncio\base_events.py(570): run_forever
  C:\Users\zuolu\AppData\Roaming\Python\Python38\site-packages\tornado\platform\asyncio.py(205): start
  C:\Users\zuolu\AppData\Roaming\Python\Python38\site-packages\ipykernel\kernelapp.py(739): start
  d:\miniconda3\envs\mlc\lib\site-packages\traitlets\config\application.py(985): launch_instance
  C:\Users\zuolu\AppData\Roaming\Python\Python38\site-packages\ipykernel_launcher.py(18): <module>
  d:\miniconda3\envs\mlc\lib\runpy.py(87): _run_code
  d:\miniconda3\envs\mlc\lib\runpy.py(194): _run_module_as_main



In [44]:
import torch

@torch.jit.script
def foo(x, y):
    if x.max() > y.max():
        r = x
    else:
        r = y
    return r

print(type(foo))  # torch.jit.ScriptFunction

# See the compiled graph as Python code
print(foo.code)

# Call the function using the TorchScript interpreter
foo(torch.ones(2, 2), torch.ones(2, 2))

<class 'torch.jit.ScriptFunction'>
def foo(x: Tensor,
    y: Tensor) -> Tensor:
  _0 = bool(torch.gt(torch.max(x), torch.max(y)))
  if _0:
    r = x
  else:
    r = y
  return r



tensor([[1., 1.],
        [1., 1.]])

In [40]:
import torch

def test_sum(a, b):
    return a + b

# Annotate the arguments to be int
scripted_fn = torch.jit.script(test_sum, example_inputs=[(3, 4)])

print(type(scripted_fn))  # torch.jit.ScriptFunction

# See the compiled graph as Python code
print(scripted_fn.code)

# Call the function using the TorchScript interpreter
a= torch.rand(1,2)
b= torch.rand(1,2)
# scripted_fn(20, 100)
scripted_fn(a,b)

<class 'torch.jit.ScriptFunction'>
def test_sum(a: Tensor,
    b: Tensor) -> Tensor:
  return torch.add(a, b)



tensor([[1.1167, 1.2970]])

In [41]:
import torch

class MyModule(torch.nn.Module):
    def __init__(self, N, M):
        super().__init__()
        # This parameter will be copied to the new ScriptModule
        self.weight = torch.nn.Parameter(torch.rand(N, M))

        # When this submodule is used, it will be compiled
        self.linear = torch.nn.Linear(N, M)

    def forward(self, input):
        output = self.weight.mv(input)

        # This calls the `forward` method of the `nn.Linear` module, which will
        # cause the `self.linear` submodule to be compiled to a `ScriptModule` here
        output = self.linear(output)
        return output

scripted_module = torch.jit.script(MyModule(2, 3))

In [None]:
print(scripted_module)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        # torch.jit.trace produces a ScriptModule's conv1 and conv2
        self.conv1 = torch.jit.trace(nn.Conv2d(1, 20, 5), torch.rand(1, 1, 16, 16))
        self.conv2 = torch.jit.trace(nn.Conv2d(20, 20, 5), torch.rand(1, 20, 16, 16))

    def forward(self, input):
        input = F.relu(self.conv1(input))
        input = F.relu(self.conv2(input))
        return input

scripted_module = torch.jit.script(MyModule())

In [None]:
print(scripted_module)

In [None]:
# torch.jit.trace for functions
import torch

def foo(x, y):
    return 2 * x + y

traced_foo = torch.jit.trace(foo, (torch.rand(3), torch.rand(3)))

# torch.jit.trace for modules
import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv = nn.Conv2d(1, 1, 3)

    def forward(self, x):
        return self.conv(x)

n = Net()
example_weight = torch.rand(1, 1, 3, 3)
example_forward_input = torch.rand(1, 1, 3, 3)
traced_module = torch.jit.trace(n, example_forward_input)

print(traced_module)

In [51]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        # torch.jit.trace produces a ScriptModule's conv1 and conv2
        self.conv1 = torch.jit.trace(nn.Conv2d(1, 20, 5), torch.rand(1, 1, 16, 16))
        self.conv2 = torch.jit.trace(nn.Conv2d(20, 20, 5), torch.rand(1, 20, 16, 16))

    def forward(self, input):
        input = F.relu(self.conv1(input))
        input = F.relu(self.conv2(input))
        return input

scripted_module = torch.jit.script(MyModule())

print(scripted_module)

RecursiveScriptModule(
  original_name=MyModule
  (conv1): Conv2d(original_name=Conv2d)
  (conv2): Conv2d(original_name=Conv2d)
)


In [53]:
!pip  install tabulate


Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple, https://pypi.ngc.nvidia.com
Collecting tabulate
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/40/44/4a5f08c96eb108af5cb50b41f76142f0afa346dfa99d5296fe7202a11854/tabulate-0.9.0-py3-none-any.whl (35 kB)
Installing collected packages: tabulate
Successfully installed tabulate-0.9.0


In [54]:
from typing import List
import torch

def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    print(">>> my_compiler() invoked:")
    print(">>> FX graph:")
    gm.graph.print_tabular()
    print(f">>> Code:\n{gm.code}")
    return gm.forward  # return a python callable

@torch.compile(backend=my_compiler)
def foo(x, y):
    return (x + y) * x

if __name__ == "__main__":
    a, b = torch.randn(10), torch.ones(10)
    foo(a, b)

>>> my_compiler() invoked:
>>> FX graph:
opcode         name    target                   args          kwargs
-------------  ------  -----------------------  ------------  --------
placeholder    l_x_    L_x_                     ()            {}
placeholder    l_y_    L_y_                     ()            {}
call_function  add     <built-in function add>  (l_x_, l_y_)  {}
call_function  mul     <built-in function mul>  (add, l_x_)   {}
output         output  output                   ((mul,),)     {}
>>> Code:



def forward(self, L_x_ : torch.Tensor, L_y_ : torch.Tensor):
    l_x_ = L_x_
    l_y_ = L_y_
    add = l_x_ + l_y_;  l_y_ = None
    mul = add * l_x_;  add = l_x_ = None
    return (mul,)
    


In [55]:
import dis

def hello():
    print("Hello, world!")

for k in ["co_names", "co_varnames", "co_consts"]:
    print(k, getattr(hello.__code__, k))
print(dis.dis(hello))

co_names ('print',)
co_varnames ()
co_consts (None, 'Hello, world!')
  4           0 LOAD_GLOBAL              0 (print)
              2 LOAD_CONST               1 ('Hello, world!')
              4 CALL_FUNCTION            1
              6 POP_TOP
              8 LOAD_CONST               0 (None)
             10 RETURN_VALUE
None


In [56]:
@torch.compile(backend=my_compiler)
def toy_example(x):
    x = x / (torch.abs(x) + 1)
    return x

def test():
    x = torch.randn(10)
    toy_example(x)
    x = torch.randn(20)
    toy_example(x)

In [62]:
import torch
import torchvision.models as models
import torch._dynamo
torch._dynamo.config.suppress_errors = True
torch.set_float32_matmul_precision('high')

model = models.resnet18().cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)


# compiled_model = torch. compile (model)   # 关键一行

# reduce-overhead: optimizes to reduce the framework overhead
#                and uses some extra memory. Helps speed up small models
# torch.compile(model, mode="reduce-overhead")

# max-autotune: optimizes to produce the fastest model,
#               but takes a very long time to compile
compiled_model=torch.compile(model, mode="max-autotune")


x = torch.randn(16, 3, 224, 224).cuda()
optimizer.zero_grad()
out = compiled_model(x)
out.sum().backward()

optimizer.step()



exported_model = torch._dynamo.export(model, input)
torch.save(exported_model, "foo.pt")

In [None]:
torch.save(model.state_dict(), "foo.pt")
# both these lines of code do the same thing
torch.save(model.state_dict(), "foo.pt")