# torch.jit.script
These notes work through the examples in the official `torch.jit.script` documentation: https://pytorch.org/docs/stable/generated/torch.jit.script.html#torch.jit.script

## Scripting a function

In [1]:
import torch


def test_sum(a, b):
    # Return the sum of `a` and `b`.
    # The parameters are intentionally left unannotated: the torch.jit.script
    # call in this cell passes `example_inputs` so TorchScript can infer their
    # types from a sample call (the printed signature shows `int` for both)
    # rather than treating them as Tensor, as it does for unannotated arguments.
    return a + b


# Profile-directed typing: rather than annotating `test_sum`, hand
# torch.jit.script a sample call, (3, 4), and let it infer the argument types
# from that call (both come out as `int`, as the printed signature below shows).
# NOTE(review): per the torch.jit.script docs this inference relies on the
# MonkeyType package being installed; without it the arguments presumably fall
# back to Tensor and the int call at the end of this cell fails — confirm.
scripted_fn = torch.jit.script(test_sum, example_inputs=[(3, 4)])

# Expected: <class 'torch.jit.ScriptFunction'>
print(type(scripted_fn))

# TorchScript source regenerated from the compiled graph
print(scripted_fn.code)

# Execute through the TorchScript interpreter; being the cell's final
# expression, its value is what the notebook displays.
scripted_fn(20, 100)

<class 'torch.jit.ScriptFunction'>
def test_sum(a: int,
    b: int) -> int:
  return torch.add(a, b)



120

## Scripting a module

In [None]:
import torch
from torch.jit import ScriptModule


# Example (scripting a simple module with a Parameter):
class MyModule(torch.nn.Module):
    """Toy module: an (N, M) learnable matrix applied with ``mv``, then ``nn.Linear(N, M)``.

    Used to show that ``torch.jit.script`` copies parameters and recursively
    compiles submodules reached from ``forward``.
    """

    def __init__(self, N: int, M: int) -> None:
        super().__init__()
        # This parameter will be copied to the new ScriptModule
        self.weight = torch.nn.Parameter(torch.rand(N, M))

        # When this submodule is used, it will be compiled
        self.linear = torch.nn.Linear(N, M)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # `self.weight` is (N, M), so torch.mv expects a 1-D `input` of length M
        # and yields a length-N vector, matching `self.linear`'s N input
        # features; the final result has length M.
        output = self.weight.mv(input)

        # This calls the `forward` method of the `nn.Linear` module, which will
        # cause the `self.linear` submodule to be compiled to a `ScriptModule` here
        output = self.linear(output)
        return output


# Build the eager module with N=2, M=3, then compile it. Scripting recurses
# into submodules used by `forward`, so `linear` is compiled along with it.
model = MyModule(N=2, M=3)
scripted_module: ScriptModule = torch.jit.script(model)

# TorchScript source regenerated for the compiled `forward` method
print(scripted_module.code)

def forward(self,
    input: Tensor) -> Tensor:
  weight = self.weight
  output = torch.mv(weight, input)
  linear = self.linear
  return (linear).forward(output, )

