In [13]:
import os

import onnx
import torch
import torch.nn as nn
from onnx2pytorch import ConvertModel

def get_model(size="small"):
    """Load a stored Pensieve ONNX graph and convert it to a PyTorch module.

    Args:
        size: Variant to load; must be "small", "mid" or "big". The value is
            interpolated into ``model/onnx/pensieve_{size}_simple.onnx``.

    Returns:
        The ``ConvertModel`` instance built from the loaded ONNX graph.

    Raises:
        AssertionError: If ``size`` is not one of the three supported names.
    """
    supported_sizes = ("small", "mid", "big")
    assert size in supported_sizes

    # Load the graph from disk and wrap it in a single step.
    return ConvertModel(onnx.load(f"model/onnx/pensieve_{size}_simple.onnx"))

def get_params_argmax(input_size):
    """Build two fixed kernels of shape ``[1, 1, 1, input_size + 1]``.

    Args:
        input_size: The kernels are one element wider than this value.

    Returns:
        A tuple ``(c01, c02)`` of zero-initialised tensors where
        ``c01`` has a single 1 at the first position of the last axis, and
        ``c02`` has a 1 at both the first and the last position of that axis.
    """
    kernel_shape = (1, 1, 1, input_size + 1)

    # Used as a convolution kernel, this keeps the first element of a window.
    first_only = torch.zeros(kernel_shape)
    first_only[..., 0] = 1

    # Used as a convolution kernel, this adds a window's first and last element.
    first_plus_last = torch.zeros(kernel_shape)
    first_plus_last[..., 0] = 1
    first_plus_last[..., -1] = 1

    return first_only, first_plus_last

def get_plain_comparative_pensieve():
    """Build a network that runs the Pensieve base model on two views of one input.

    The input is passed through two fixed-weight ``Conv2d`` layers whose
    kernels come from ``get_params_argmax(48)``. Each result is fed to the
    same base model (``get_model()`` with its default size) and the two
    outputs are concatenated along dimension 1.

    Returns:
        A newly constructed ``MyModel`` instance.
    """

    # A regular nn.Module: the submodules are named attributes and the class
    # defines its own forward(), so the nn.ModuleList container is not the
    # right base class.
    class MyModel(nn.Module):
        """Two 1 x (input_size + 1) convolutions in front of a shared base model."""

        def __init__(self, device=torch.device("cpu")):
            # NOTE: `device` is kept in the signature for compatibility, but
            # nothing in this constructor reads it.
            super().__init__()

            input_size = 48
            self.input_size = input_size
            c01, c02 = get_params_argmax(input_size)

            # Not used by forward(); kept so the attribute stays available.
            self.ft = torch.nn.Flatten()

            #################
            # Model
            #################
            self.base_model = get_model()

            #################
            # Input summation
            #################
            def fixed_conv(kernel):
                # 1 -> 1 channel convolution with the given kernel as its
                # weight and a zero bias. Both stay trainable parameters.
                conv = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=[1, input_size + 1])
                conv.weight = torch.nn.Parameter(kernel, requires_grad=True)
                conv.bias = torch.nn.Parameter(torch.zeros_like(conv.bias), requires_grad=True)
                return conv

            # c01 has a single 1 at the start of the window, so position j of
            # the output is obs[..., j].
            self.input_conv1 = fixed_conv(c01)
            # c02 has a 1 at both ends of the window, so position j of the
            # output is obs[..., j] + obs[..., j + input_size].
            self.input_conv2 = fixed_conv(c02)

        def forward(self, obs):
            """Run the base model on both convolved views of ``obs``.

            Args:
                obs: Tensor accepted by the two ``Conv2d`` layers (one channel).

            Returns:
                The two base-model outputs concatenated along dimension 1.
            """
            input1 = self.input_conv1(obs)
            input2 = self.input_conv2(obs)

            # The same base model (shared weights) scores both views.
            copy1_logits = self.base_model(input1)
            copy2_logits = self.base_model(input2)

            return torch.cat((copy1_logits, copy2_logits), dim=1)

    return MyModel()

In [15]:
# Dummy input used only to trace the model during export: shape (1, 1, 96),
# every entry 0.1.
x = torch.tensor(
    [[[0.1] * 96]]
)

# Plain string literal (there is nothing to interpolate).
onnx_export_path = "model/conv2d_based_onnx/model.onnx"
# Create the target folder up front so the export does not fail on a fresh
# checkout where it is missing.
os.makedirs(os.path.dirname(onnx_export_path), exist_ok=True)

model = get_plain_comparative_pensieve()
torch.onnx.export(
    model,      # The model being converted
    x,          # A dummy input for tracing the model
    onnx_export_path, # The output file name for the ONNX model
    input_names=['input'],   # Optional: names for the input nodes
    output_names=['output'], # Optional: names for the output nodes
    opset_version=11         # Optional: specify the ONNX opset version
)


# testing

In [34]:
def get_params_argmax(input_size):
    """Build two fixed kernels of shape ``[1, 1, 1, input_size + 1]``.

    NOTE: this repeats the definition from the model-building cell with the
    same behaviour, so this cell can be run without it.

    Args:
        input_size: The kernels are one element wider than this value.

    Returns:
        A tuple ``(c01, c02)`` of zero-initialised tensors where
        ``c01`` has a single 1 at the first position of the last axis, and
        ``c02`` has a 1 at both the first and the last position of that axis.
    """
    kernel_shape = (1, 1, 1, input_size + 1)

    # Used as a convolution kernel, this keeps the first element of a window.
    first_only = torch.zeros(kernel_shape)
    first_only[..., 0] = 1

    # Used as a convolution kernel, this adds a window's first and last element.
    first_plus_last = torch.zeros(kernel_shape)
    first_plus_last[..., 0] = 1
    first_plus_last[..., -1] = 1

    return first_only, first_plus_last

# Toy check of the summing kernel c02. The kernel width and the input width
# are both derived from one size so they cannot drift apart.
toy_input_size = 7

c01, c02 = get_params_argmax(toy_input_size)

# Constant 0.1 input of shape (1, 1, 2 * toy_input_size).
x = torch.tensor([[[0.1] * (2 * toy_input_size)]])

layer = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=[1, toy_input_size + 1])

# Replace the random initialisation with the fixed kernel and a zero bias.
layer.weight = torch.nn.Parameter(c02)
layer.bias = torch.nn.Parameter(torch.zeros_like(layer.bias))

toy_output = layer(x)

print(layer.weight.shape, layer.stride, layer.padding)
print(toy_output)

# Position j of the output is x[..., j] + x[..., j + toy_input_size],
# i.e. 0.1 + 0.1 at every one of the toy_input_size positions.
assert toy_output.shape == (1, 1, toy_input_size)
assert torch.allclose(toy_output, torch.full_like(toy_output, 0.2))


torch.Size([1, 1, 1, 8]) (1, 1) (0, 0)
tensor([[[0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]],
       grad_fn=<SqueezeBackward1>)
