<a href="https://colab.research.google.com/github/finardi/tutos/blob/master/Jit_Example.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

ref: https://spell.ml/blog/pytorch-jit-YBmYuBEAACgAiv71

In [14]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import time


# ====================================================
# Poderia usar torch.nn.Conv2d, mas é fins didáticos
# ====================================================

class Conv2d(nn.Module):
    def __init__(
        self, n_channels, out_channels, kernel_size, dilation=1, padding=0, stride=1
    ):
        super().__init__()

        self.kernel_size = kernel_size
        self.kernel_size_number = kernel_size * kernel_size
        self.out_channels = out_channels
        self.padding = padding
        self.dilation = dilation
        self.stride = stride
        self.n_channels = n_channels
        self.weights = nn.Parameter(
            torch.Tensor(self.out_channels, self.n_channels, self.kernel_size**2)
        )

    def __repr__(self):
        return (
            f"Conv2d(n_channels={self.n_channels}, out_channels={self.out_channels}, "
            f"kernel_size={self.kernel_size})"
        )
    
    def forward(self, x):
        width = self.calculate_new_width(x)
        height = self.calculate_new_height(x)
        windows = self.calculate_windows(x)
        
        result = torch.zeros(
            [x.shape[0] * self.out_channels, width, height],
            dtype=torch.float32, device=x.device
        )

        for channel in range(x.shape[1]):
            for i_conv_n in range(self.out_channels):
                xx = torch.matmul(windows[channel], self.weights[i_conv_n][channel]) 
                xx = xx.view((-1, width, height))
                
                xx_stride = slice(i_conv_n * xx.shape[0], (i_conv_n + 1) * xx.shape[0])
                result[xx_stride] += xx

        result = result.view((x.shape[0], self.out_channels, width, height))
        return result  

    def calculate_windows(self, x):
        windows = F.unfold(
            x,
            kernel_size=(self.kernel_size, self.kernel_size),
            padding=(self.padding, self.padding),
            dilation=(self.dilation, self.dilation),
            stride=(self.stride, self.stride)
        )

        windows = (windows
            .transpose(1, 2)
            .contiguous().view((-1, x.shape[1], int(self.kernel_size**2)))
            .transpose(0, 1)
        )
        return windows

    def calculate_new_width(self, x):
        return (
            (x.shape[2] + 2 * self.padding - self.dilation * (self.kernel_size - 1) - 1)
            // self.stride
        ) + 1

    def calculate_new_height(self, x):
        return (
            (x.shape[3] + 2 * self.padding - self.dilation * (self.kernel_size - 1) - 1)
            // self.stride
        ) + 1

In [18]:
start = time.time()
x = torch.randint(0, 255, (1, 3, 512, 512), device='cuda') / 255
conv = Conv2d(3, 16, 3)
conv.cuda()

out = conv(x)
out.mean().backward()

end = time.time()
elapsed = end-start

print(f'elapsed time: {elapsed:.3}s')

elapsed time: 0.0259s


In [16]:
class Conv2d(torch.jit.ScriptModule): # <- OLD was class Conv2d(nn.Module)
    def __init__(
        self, n_channels, out_channels, kernel_size, dilation=1, padding=0, stride=1
    ):
        super().__init__()

        self.kernel_size = kernel_size
        self.kernel_size_number = kernel_size * kernel_size
        self.out_channels = out_channels
        self.padding = padding
        self.dilation = dilation
        self.stride = stride
        self.n_channels = n_channels
        self.weights = nn.Parameter(
            torch.Tensor(self.out_channels, self.n_channels, self.kernel_size**2)
        )

    def __repr__(self):
        return (
            f"Conv2d(n_channels={self.n_channels}, out_channels={self.out_channels}, "
            f"kernel_size={self.kernel_size})"
        )
    
    @torch.jit.script_method # <- insert decorator
    def forward(self, x):
        width = self.calculate_new_width(x)
        height = self.calculate_new_height(x)
        windows = self.calculate_windows(x)
        
        result = torch.zeros(
            [x.shape[0] * self.out_channels, width, height],
            dtype=torch.float32, device=x.device
        )

        for channel in range(x.shape[1]):
            for i_conv_n in range(self.out_channels):
                xx = torch.matmul(windows[channel], self.weights[i_conv_n][channel]) 
                xx = xx.view((-1, width, height))
                
                xx_stride = slice(i_conv_n * xx.shape[0], (i_conv_n + 1) * xx.shape[0])
                result[xx_stride] += xx

        result = result.view((x.shape[0], self.out_channels, width, height))
        return result  

    def calculate_windows(self, x):
        windows = F.unfold(
            x,
            kernel_size=(self.kernel_size, self.kernel_size),
            padding=(self.padding, self.padding),
            dilation=(self.dilation, self.dilation),
            stride=(self.stride, self.stride)
        )

        windows = (windows
            .transpose(1, 2)
            .contiguous().view((-1, x.shape[1], int(self.kernel_size**2)))
            .transpose(0, 1)
        )
        return windows

    def calculate_new_width(self, x):
        return (
            (x.shape[2] + 2 * self.padding - self.dilation * (self.kernel_size - 1) - 1)
            // self.stride
        ) + 1

    def calculate_new_height(self, x):
        return (
            (x.shape[3] + 2 * self.padding - self.dilation * (self.kernel_size - 1) - 1)
            // self.stride
        ) + 1

In [20]:
start = time.time()
x = torch.randint(0, 255, (1, 3, 512, 512), device='cuda') / 255
conv = Conv2d(3, 16, 3)
conv.cuda()

out = conv(x)
out.mean().backward()

end = time.time()
elapsed_jit = end-start

print(f'elapsed time: {elapsed_jit:.3}s')

elapsed time: 0.0172s


In [23]:
print(f'SpeedUp: {elapsed/elapsed_jit:.3}x')

SpeedUp: 1.51x


# Pytorch IR Graph

In [24]:
@torch.jit.script
def foo(len):
    # type: (int) -> torch.Tensor
    rv = torch.zeros(3, 4)
    for i in range(len):
        if i < 10:
            rv = rv - 1.0
        else:
            rv = rv + 1.0
    return rv

print(foo.graph)

graph(%len.1 : int):
  %21 : int = prim::Constant[value=1]()
  %13 : bool = prim::Constant[value=1]() # <ipython-input-24-8a34e03747f9>:5:4
  %5 : NoneType = prim::Constant()
  %1 : int = prim::Constant[value=3]() # <ipython-input-24-8a34e03747f9>:4:21
  %2 : int = prim::Constant[value=4]() # <ipython-input-24-8a34e03747f9>:4:24
  %16 : int = prim::Constant[value=10]() # <ipython-input-24-8a34e03747f9>:6:15
  %20 : float = prim::Constant[value=1.]() # <ipython-input-24-8a34e03747f9>:7:22
  %4 : int[] = prim::ListConstruct(%1, %2)
  %rv.1 : Tensor = aten::zeros(%4, %5, %5, %5, %5) # <ipython-input-24-8a34e03747f9>:4:9
  %rv : Tensor = prim::Loop(%len.1, %13, %rv.1) # <ipython-input-24-8a34e03747f9>:5:4
    block0(%i.1 : int, %rv.29 : Tensor):
      %17 : bool = aten::lt(%i.1, %16) # <ipython-input-24-8a34e03747f9>:6:11
      %rv.27 : Tensor = prim::If(%17) # <ipython-input-24-8a34e03747f9>:6:8
        block0():
          %rv.5 : Tensor = aten::sub(%rv.29, %20, %21) # <ipython-input-24-8

#### %rv.1 : Tensor means we assign the output to a (unique) value named rv.1, that value is of Tensor type and that we do not know its concrete shape.

#### aten::zeros is the operator (equivalent to torch.zeros) and the input list (%4, %6, %6, %10, %12) specifies which values in scope should be passed as inputs. The schema for built-in functions like aten::zeros can be found at Builtin Functions.



# Trace Example

In [35]:
def fill_row_zero(x):
    x[0] = torch.rand(*x.shape[1:2])
    return x

traced = torch.jit.trace(fill_row_zero, (torch.rand(3, 4),))
print(traced.graph)

graph(%x : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu)):
  %4 : int = prim::Constant[value=1]() # <ipython-input-35-9083ec5950df>:2:0
  %5 : int = aten::size(%x, %4) # <ipython-input-35-9083ec5950df>:2:0
  %6 : Long(device=cpu) = prim::NumToTensor(%5)
  %7 : int = aten::Int(%6)
  %8 : int[] = prim::ListConstruct(%7)
  %9 : int = prim::Constant[value=6]() # <ipython-input-35-9083ec5950df>:2:0
  %10 : NoneType = prim::Constant()
  %11 : Device = prim::Constant[value="cpu"]() # <ipython-input-35-9083ec5950df>:2:0
  %12 : bool = prim::Constant[value=0]() # <ipython-input-35-9083ec5950df>:2:0
  %13 : Float(4, strides=[1], requires_grad=0, device=cpu) = aten::rand(%8, %9, %10, %11, %12) # <ipython-input-35-9083ec5950df>:2:0
  %14 : int = prim::Constant[value=0]() # <ipython-input-35-9083ec5950df>:2:0
  %15 : int = prim::Constant[value=0]() # <ipython-input-35-9083ec5950df>:2:0
  %16 : Float(4, strides=[1], requires_grad=0, device=cpu) = aten::select(%x, %14, %15) # <ipython-inpu

	%13 : Float(4, strides=[1], requires_grad=0, device=cpu) = aten::rand(%8, %9, %10, %11, %12) # <ipython-input-35-9083ec5950df>:2:0
This may cause errors in trace checking. To disable trace checking, pass check_trace=False to torch.jit.trace()
  _module_class,
Tensor-likes are not close!

Mismatched elements: 4 / 12 (33.3%)
Greatest absolute difference: 0.8664248585700989 at index (0, 0) (up to 1e-05 allowed)
Greatest relative difference: 23.759565154853515 at index (0, 0) (up to 1e-05 allowed)
  _module_class,


## Fixing Warnings

In [37]:
def fill_row_zero(x):
    x = torch.cat((torch.rand(1, *x.shape[1:2]), x[1:2]), dim=0)
    return x

traced = torch.jit.trace(fill_row_zero, (torch.rand(3, 4),))
print(traced.graph)

graph(%x : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu)):
  %4 : int = prim::Constant[value=1]() # <ipython-input-37-a0ec8869dac9>:2:0
  %5 : int = aten::size(%x, %4) # <ipython-input-37-a0ec8869dac9>:2:0
  %6 : Long(device=cpu) = prim::NumToTensor(%5)
  %7 : int = aten::Int(%6)
  %8 : int = prim::Constant[value=1]() # <ipython-input-37-a0ec8869dac9>:2:0
  %9 : int[] = prim::ListConstruct(%8, %7)
  %10 : int = prim::Constant[value=6]() # <ipython-input-37-a0ec8869dac9>:2:0
  %11 : NoneType = prim::Constant()
  %12 : Device = prim::Constant[value="cpu"]() # <ipython-input-37-a0ec8869dac9>:2:0
  %13 : bool = prim::Constant[value=0]() # <ipython-input-37-a0ec8869dac9>:2:0
  %14 : Float(1, 4, strides=[4, 1], requires_grad=0, device=cpu) = aten::rand(%9, %10, %11, %12, %13) # <ipython-input-37-a0ec8869dac9>:2:0
  %15 : int = prim::Constant[value=0]() # <ipython-input-37-a0ec8869dac9>:2:0
  %16 : int = prim::Constant[value=1]() # <ipython-input-37-a0ec8869dac9>:2:0
  %17 : int = 

	%14 : Float(1, 4, strides=[4, 1], requires_grad=0, device=cpu) = aten::rand(%9, %10, %11, %12, %13) # <ipython-input-37-a0ec8869dac9>:2:0
This may cause errors in trace checking. To disable trace checking, pass check_trace=False to torch.jit.trace()
  _module_class,
Tensor-likes are not close!

Mismatched elements: 4 / 8 (50.0%)
Greatest absolute difference: 0.7539458274841309 at index (0, 0) (up to 1e-05 allowed)
Greatest relative difference: 15.360134189773595 at index (0, 2) (up to 1e-05 allowed)
  _module_class,



## [Jit documentation](https://pytorch.org/docs/stable/jit.html)


