## Model definition

In [1]:
import math
import numpy as np
import os
import torch
import torch.nn as nn
from torch.nn.modules.module import Module

class testLinear(Module):
    """Bias-free linear projection ``y = input @ W.T`` with Xavier-uniform weights.

    The projection lives in a wrapped ``nn.Linear`` stored on ``self.weight``,
    so its parameter shows up in ``state_dict`` as ``weight.weight``.

    Args:
        in_feat: size of the last dimension of the input.
        out_feat: size of the last dimension of the output.
    """

    def __init__(self, in_feat, out_feat):
        super().__init__()
        # No bias term: the layer is a pure matrix multiply.
        self.weight = nn.Linear(in_feat, out_feat, bias=False)
        # Replace whatever initialisation nn.Linear applied with Xavier-uniform.
        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw the projection matrix from a Xavier (Glorot) uniform distribution."""
        nn.init.xavier_uniform_(self.weight.weight)

    def forward(self, input):
        """Project the last dimension of ``input`` from ``in_feat`` to ``out_feat``."""
        return self.weight(input)

In [6]:
# Use the GPU when one is available so the notebook also runs on CPU-only machines
# (a bare .cuda() raises there).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Seed before building the model so both the Xavier weights and the sample input
# are reproducible across Restart & Run All.
torch.manual_seed(0)

lin_model = testLinear(in_feat=2, out_feat=3).to(device)
# Draw on the CPU generator and then move, so the values are identical
# whichever device is selected. The leading dim of 1 is an extra batch
# dimension; the last dim (2) must equal in_feat.
sample_input = torch.rand(1, 17, 2).to(device)

print("Pytorch output: \n", lin_model(sample_input))

Pytorch output: 
 tensor([[[ 0.7506, -0.6191, -0.9389],
         [ 0.7067, -0.4748, -0.9219],
         [ 0.6418, -0.5606, -0.7919],
         [ 0.1280, -0.1105, -0.1584],
         [ 1.0030, -0.3163, -1.4341],
         [ 0.5764,  0.0532, -0.9067],
         [ 1.0830, -0.5608, -1.4715],
         [ 0.7507, -0.5899, -0.9494],
         [ 1.0004, -0.4530, -1.3820],
         [ 0.3665,  0.0194, -0.5714],
         [ 0.4789, -0.3924, -0.5999],
         [ 0.5356,  0.0552, -0.8445],
         [ 0.7626, -0.3181, -1.0630],
         [ 0.6264, -0.3697, -0.8351],
         [ 0.8573, -0.1959, -1.2520],
         [ 0.4973, -0.2085, -0.6929],
         [ 1.0588, -0.4674, -1.4670]]], device='cuda:0',
       grad_fn=<UnsafeViewBackward>)


## Convert to ScriptModule
Read more about `torch.jit.ScriptModule` [here](https://pytorch.org/docs/stable/generated/torch.jit.ScriptModule.html#scriptmodule).

In [8]:
# Compile the eager model to TorchScript. torch.jit.script compiles the Python
# source of forward() and takes no example inputs. Its second positional
# parameter is the deprecated `optimize` flag, so an example tensor must not be
# passed there.
module_ts = torch.jit.script(lin_model)

# Models with several inputs are scripted the same way. Example inputs are only
# needed for *tracing*, e.g.:
# module_ts = torch.jit.trace(lin_model, (input1, input2))

# pretty-print the compiled forward function code
print(module_ts.code)

# Save the ScriptModule. The file name is kept for compatibility with existing
# consumers, even though this module is scripted rather than traced.
out_dir = 'out'
os.makedirs(out_dir, exist_ok=True)  # saving fails if the target directory is missing
model_path = os.path.join(out_dir, 'sample_model_traced.pt')
module_ts.save(model_path)

# Compare ScriptModule output with that of the eager PyTorch module.
output = module_ts(sample_input)
assert torch.allclose(output, lin_model(sample_input)), \
    "TorchScript output differs from the eager PyTorch output"

# Confirm the saved artifact reloads and reproduces the same result.
reloaded = torch.jit.load(model_path, map_location=sample_input.device)
assert torch.allclose(reloaded(sample_input), output), \
    "reloaded TorchScript module differs from the in-memory one"

print("torchscript output: ", output.size())


import __torch__.___torch_mangle_3
import __torch__.torch.nn.modules.linear.___torch_mangle_4
def forward(self,
    input: Tensor) -> Tensor:
  _0 = self.weight.weight
  if torch.eq(torch.dim(input), 2):
    _1 = torch.__isnot__(None, None)
  else:
    _1 = False
  if _1:
    bias = ops.prim.unchecked_unwrap_optional(annotate(Optional[Tensor], None))
    x = torch.addmm(bias, input, torch.t(_0), beta=1, alpha=1)
  else:
    output = torch.matmul(input, torch.t(_0))
    if torch.__isnot__(None, None):
      bias0 = ops.prim.unchecked_unwrap_optional(annotate(Optional[Tensor], None))
      output0 = torch.add_(output, bias0, alpha=1)
    else:
      output0 = output
    x = output0
  return x

torchscript output:  torch.Size([1, 17, 3])
