## Model definition

In [1]:
import math
import numpy as np
import os
import torch
import torch.nn as nn
from torch.nn.modules.module import Module

class testLinear(Module):
    """Minimal module wrapping a single bias-free ``nn.Linear`` layer.

    Args:
        in_feat: size of the last dimension of the input tensor.
        out_feat: size of the last dimension of the output tensor.
    """

    def __init__(self, in_feat, out_feat):
        super().__init__()
        # Plain integer attribute; nothing in the visible code reads it.
        self.var = 10
        # NOTE: despite its name, `weight` holds the Linear sub-module itself;
        # the actual parameter tensor lives at `self.weight.weight`.
        self.weight = nn.Linear(in_feat, out_feat, bias=False)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the linear layer's weight matrix in place (Xavier uniform)."""
        nn.init.xavier_uniform_(self.weight.weight)

    def forward(self, input):
        """Apply the linear layer to ``input`` and return the result."""
        return self.weight(input)

In [2]:
# Flip to True to run the demo on a CUDA device; CPU is used otherwise.
is_cuda = False

# Build the model first, then draw the random input, so the order of random
# number consumption stays the same as before.
lin_model = testLinear(in_feat=2, out_feat=3)
sample_input = torch.rand(1, 17, 2)

if is_cuda:
    # Model and input must live on the same device.
    lin_model = lin_model.cuda()
    sample_input = sample_input.cuda()

print("Pytorch output: \n", lin_model(sample_input))

Pytorch output: 
 tensor([[[ 0.4949,  0.0679, -0.6061],
         [ 0.5487,  0.1210, -0.6902],
         [ 0.1842, -0.0236, -0.2061],
         [ 0.1809,  0.0805, -0.2437],
         [ 0.2799,  0.1334, -0.3807],
         [ 0.5195, -0.0482, -0.5886],
         [ 0.2026,  0.0080, -0.2402],
         [ 0.3953,  0.0241, -0.4721],
         [ 0.2655,  0.1053, -0.3527],
         [ 0.4343,  0.1261, -0.5584],
         [ 0.0960,  0.0283, -0.1237],
         [ 0.3603,  0.1267, -0.4721],
         [ 0.3559, -0.0649, -0.3905],
         [ 0.4506, -0.0422, -0.5104],
         [ 0.2300,  0.2480, -0.3680],
         [ 0.1335,  0.0672, -0.1830],
         [ 0.3961, -0.0230, -0.4542]]], grad_fn=<UnsafeViewBackward>)


In [3]:
# torch.jit.script compiles the module from its Python source, so it needs no
# example inputs. The second positional parameter of torch.jit.script is the
# deprecated `optimize` flag, so a sample tensor must not be passed there.
# (Example inputs belong to the tracing API instead, e.g.
# torch.jit.trace(lin_model, sample_input), or a tuple for several inputs.)
module_ts = torch.jit.script(lin_model)

# pretty-print the forward function code
print("Code: ")
print(module_ts.code)


Code: 
def forward(self,
    input: Tensor) -> Tensor:
  return (self.weight).forward(input, )



In [4]:

# save scriptmodule
model_path = os.path.join('out/','sample_model_traced.pt')
# Create the target folder first so that a fresh checkout without `out/`
# does not fail when the module is written to disk.
os.makedirs(os.path.dirname(model_path), exist_ok=True)
module_ts.save(model_path)

# compare scriptmodule output with that of pytorch module
output = module_ts(sample_input)
# Silent on success: the scripted module must reproduce the eager output.
assert torch.allclose(output, lin_model(sample_input)), \
    "TorchScript output differs from the eager PyTorch module"
print("torchscript output: ", output.size())


torchscript output:  torch.Size([1, 17, 3])
