In [3]:
from firm_lib.ppo import PPO
from firm_lib.serverless_env import SimEnvironment
from firm_lib.util import *
from torch import nn as nn
import torch

def get_model():
    """Load the PPO checkpoint and return its actor as a standalone MLP.

    The function builds a ``SimEnvironment`` from the bundled CSV trace,
    creates a ``PPO`` agent for the environment's function name, loads
    ``model/ppo.pth.tar`` into it and wraps ``agent.actor`` in a plain
    three-layer ReLU network.

    Returns:
        nn.Module: an ``ActorNetworkWrapper`` whose ``fc1``/``fc2``/``fc3``
        layers use the weight and bias tensors of the loaded actor.
    """
    class ActorNetworkWrapper(nn.Module):
        """Linear -> ReLU -> Linear -> ReLU -> Linear, reusing ``base_model``'s layers.

        The size defaults come from names made available by the
        ``firm_lib.util`` star import (presumably; verify there).
        """

        def __init__(self, input_size=NUM_STATES, hidden_size=HIDDEN_SIZE,
                output_size=NUM_ACTIONS, base_model=None):
            super().__init__()

            # Fail early with a clear message instead of an AttributeError on None.
            if base_model is None:
                raise ValueError(
                    "base_model is required: it must expose fc1, fc2 and fc3 layers")

            self.fc1 = nn.Linear(input_size, hidden_size)
            self.fc2 = nn.Linear(hidden_size, hidden_size)
            self.fc3 = nn.Linear(hidden_size, output_size)

            # Point every layer at the corresponding tensors of the base model.
            # The tensors are shared, not copied, so the wrapper takes on the
            # base model's shapes and values.
            for layer_name in ("fc1", "fc2", "fc3"):
                source = getattr(base_model, layer_name)
                target = getattr(self, layer_name)
                target.weight.data, target.bias.data = \
                    source.weight.data, source.bias.data

            self.relu = nn.ReLU()

        def forward(self, input_):
            """Return the raw output of the last linear layer for tensor ``input_``."""
            hidden = self.relu(self.fc1(input_))
            hidden = self.relu(self.fc2(hidden))
            return self.fc3(hidden)


    env = SimEnvironment("firm_lib/data/readfile_sleep_imageresize_output.csv")
    function_name = env.get_function_name()
    # The environment is still reset as before; its return value is not needed here.
    env.reset(function_name)
    folder_path = "firm_lib/model/" + str(function_name)
    agent = PPO(env, function_name, folder_path)
    agent.load_checkpoint("model/ppo.pth.tar")
    return ActorNetworkWrapper(base_model=agent.actor)

def get_params_argmax(input_size):
    """Create the two fixed convolution kernels used to select inputs.

    Args:
        input_size: each kernel has ``input_size + 1`` entries along its
            last dimension.

    Returns:
        tuple: ``(c01, c02)``, both float tensors of shape
        ``[1, 1, 1, input_size + 1]``. ``c01`` is 1 at the first position and
        0 elsewhere; ``c02`` is 1 at the first and the last position and
        0 elsewhere.
    """
    kernel_shape = (1, 1, 1, input_size + 1)

    # Single tap: keeps only the first element of a window.
    first_only = torch.zeros(kernel_shape)
    first_only[0, 0, 0, 0] = 1

    # Two taps: adds the first and the last element of a window.
    first_plus_last = torch.zeros(kernel_shape)
    first_plus_last[0, 0, 0, 0] = 1
    first_plus_last[0, 0, 0, -1] = 1

    return first_only, first_plus_last

def get_plain_comparative_firm():
    """Build a model that runs the loaded actor on two views of one observation.

    Two 1 x (NUM_STATES + 1) convolutions with fixed kernels slide over the
    observation: the first keeps the leading element of every window, the
    second adds the leading and trailing element of every window. Both
    results are fed through the same actor network and the two outputs are
    concatenated.

    Returns:
        nn.Module: a freshly constructed ``MyModel``.
    """
    class MyModel(nn.Module):
        """Actor evaluated on two convolution-derived inputs, outputs concatenated."""

        def __init__(self, device=torch.device("cpu")):
            # ``device`` is accepted for interface compatibility but is not used.
            super().__init__()

            self.input_size = NUM_STATES
            c01, c02 = get_params_argmax(self.input_size)
            self.ft = torch.nn.Flatten()

            # Shared network applied to both derived inputs.
            self.base_model = get_model()

            # Fixed-kernel convolutions that derive the two actor inputs.
            self.input_conv1 = self._selection_conv(c01)
            self.input_conv2 = self._selection_conv(c02)

        def _selection_conv(self, kernel):
            """Return a 1-channel Conv2d whose weight is ``kernel`` and whose bias is zero."""
            conv = nn.Conv2d(in_channels=1, out_channels=1,
                             kernel_size=[1, self.input_size + 1])
            conv.weight = torch.nn.Parameter(kernel, requires_grad=True)
            conv.bias = torch.nn.Parameter(
                torch.zeros_like(conv.bias, requires_grad=True))
            return conv

        def forward(self, obs):
            """Return the actor outputs for both derived inputs, joined along dim 1.

            Args:
                obs: observation tensor; assumed to be ``[batch, 2 * NUM_STATES]``
                    so that each convolution yields ``NUM_STATES`` features per
                    row — TODO confirm against callers.
            """
            # Add two leading singleton dimensions so Conv2d accepts the input.
            obs = torch.unsqueeze(torch.unsqueeze(obs, 0), 0)

            # Apply each fixed kernel, then drop the two leading dimensions again.
            input1 = torch.squeeze(torch.squeeze(self.input_conv1(obs), 0), 0)
            input2 = torch.squeeze(torch.squeeze(self.input_conv2(obs), 0), 0)

            # The same actor scores both inputs.
            copy1_logits = self.base_model(input1)
            copy2_logits = self.base_model(input2)

            return self.ft(torch.cat((copy1_logits, copy2_logits), dim=1))

    return MyModel()


In [6]:
# Dummy observation used only for tracing: a single row of 14 values, all 0.1.
x = torch.full((1, 14), 0.1)

model = get_plain_comparative_firm()

# Trace the model on the dummy input and write the graph to an ONNX file.
torch.onnx.export(
    model,
    x,
    "model/conv2d_based_onnx/model.onnx",  # destination file
    input_names=["input"],    # name of the graph input
    output_names=["output"],  # name of the graph output
    opset_version=11,         # ONNX operator-set version to target
)


Loading checkpoint...
Checkpoint successfully loaded!


# Testing: sanity check of the `Conv2d` selection kernel

In [34]:
def get_params_argmax(input_size):
    """Create the two fixed convolution kernels used to select inputs.

    NOTE(review): a function with the same name and the same behavior is
    defined earlier in this notebook; this definition shadows it.

    Args:
        input_size: each kernel has ``input_size + 1`` entries along its
            last dimension.

    Returns:
        tuple: ``(c01, c02)``, both float tensors of shape
        ``[1, 1, 1, input_size + 1]``. ``c01`` is 1 at the first position and
        0 elsewhere; ``c02`` is 1 at the first and the last position and
        0 elsewhere.
    """
    kernel_shape = (1, 1, 1, input_size + 1)

    # Single tap: keeps only the first element of a window.
    first_only = torch.zeros(kernel_shape)
    first_only[0, 0, 0, 0] = 1

    # Two taps: adds the first and the last element of a window.
    first_plus_last = torch.zeros(kernel_shape)
    first_plus_last[0, 0, 0, 0] = 1
    first_plus_last[0, 0, 0, -1] = 1

    return first_only, first_plus_last

# Kernels for input_size=7 (width 8); only the two-tap kernel c02 is exercised below.
c01, c02 = get_params_argmax(7)

# Convolution with a 1x8 window whose weight is the two-tap kernel and whose bias is zero.
layer = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=[1, 8])
layer.weight = torch.nn.Parameter(c02)
layer.bias = torch.nn.Parameter(torch.zeros_like(layer.bias))

# Fourteen identical values laid out as a [1, 1, 14] tensor.
x = torch.tensor([[[0.1] * 14]])

# Every window adds its first and last element, so each output is 0.1 + 0.1.
print(layer.weight.shape, layer.stride, layer.padding)
print(layer(x))


torch.Size([1, 1, 1, 8]) (1, 1) (0, 0)
tensor([[[0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]],
       grad_fn=<SqueezeBackward1>)
