In [6]:
from aurora_lib.network_simulator.pcc.aurora.schedulers import TestScheduler
from aurora_lib.network_simulator.pcc.aurora.aurora import Aurora
from aurora_lib.network_simulator.pcc.aurora.aurora_environment import AuroraEnvironment
from aurora_lib.trace import generate_trace
from aurora_lib.ppo import PPO
import gym
from torch import nn as nn
import torch

# First dimension of the environment's observation-space shape.
# Starts as None and is assigned by get_model() as a side effect.
NUM_STATES = None

def get_model():
    """Load the checkpointed PPO agent on a dummy trace and return its actor.

    An ``AuroraEnvironment`` is built around a ``TestScheduler`` that wraps a
    single generated trace, the environment is reset, and a ``PPO`` model is
    created on it and restored from ``"models/model"``.

    Side effect:
        Sets the module-level ``NUM_STATES`` to
        ``env.observation_space.shape[0]``.

    Returns:
        The ``actor`` attribute of the restored ``PPO`` model.
    """
    global NUM_STATES

    trace = generate_trace((10, 10), (2, 2), (2, 2), (50, 50), (0, 0), (1, 1), (0, 0), (0, 0))
    environment = AuroraEnvironment(trace_scheduler=TestScheduler(trace))
    environment.reset()  # the value returned by reset() is not used here
    NUM_STATES = environment.observation_space.shape[0]

    agent = PPO(environment, verbose=False)
    agent.load_checkpoint("models/model")
    return agent.actor


def get_params_argmax(input_size):
    """Create the two fixed kernels used by the input convolutions.

    Both tensors have shape ``[1, 1, 1, input_size + 1]`` and are zero except:

    * the first has a 1 in position 0 of the last axis;
    * the second has a 1 in position 0 and a 1 in the last position.

    In this file they are installed as the weights of 1 x (input_size + 1)
    ``Conv2d`` layers. Slid along a row, the first kernel picks out element
    ``j`` and the second yields element ``j`` plus element ``j + input_size``.

    Args:
        input_size: Number of state features; the kernels are one wider.

    Returns:
        Tuple ``(c01, c02)`` of the two kernels described above.
    """
    kernel_shape = (1, 1, 1, input_size + 1)

    # Kernel 1: a single 1 at the leading position.
    leading_only = torch.zeros(kernel_shape)
    leading_only[0, 0, 0, 0] = 1

    # Kernel 2: the same leading 1 plus a 1 at the trailing position.
    leading_plus_trailing = leading_only.clone()
    leading_plus_trailing[0, 0, 0, -1] = 1

    return leading_only, leading_plus_trailing


def get_plain_comparative_aurora():
    """Build a module that runs the Aurora actor on two inputs derived from one tensor.

    The returned module wraps the actor from ``get_model()`` and two fixed
    1 x (NUM_STATES + 1) convolutions whose kernels come from
    ``get_params_argmax``. With an input row of length ``2 * NUM_STATES`` (as
    used by the export cell in this file), the first convolution yields the
    first ``NUM_STATES`` values and the second yields those values plus the
    values ``NUM_STATES`` positions later. The actor is applied to both results
    and the two outputs are concatenated along dim 1.

    Returns:
        A freshly constructed ``MyModel`` instance.
    """
    # nn.Module is the right base here: the wrapper has named submodules and
    # its own forward(); nn.ModuleList is a list container, not a model base.
    class MyModel(nn.Module):
        """Actor wrapper producing ``[actor(first), actor(second)]`` side by side."""

        def __init__(self, device=torch.device("cpu")):
            # `device` is accepted for backward compatibility; it is not used.
            super().__init__()
            # get_model() must run before NUM_STATES is read: it assigns it.
            self.base_model = get_model()
            self.input_size = NUM_STATES
            self.ft = torch.nn.Flatten()
            c01, c02 = get_params_argmax(self.input_size)

            #################
            # Input summation
            #################
            self.input_conv1 = self._fixed_conv(c01)
            self.input_conv2 = self._fixed_conv(c02)

        def _fixed_conv(self, kernel):
            """Return a 1 x (input_size + 1) Conv2d with ``kernel`` as weight and zero bias."""
            conv = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=[1, self.input_size + 1])
            conv.weight = torch.nn.Parameter(kernel, requires_grad=True)
            conv.bias = torch.nn.Parameter(torch.zeros_like(conv.bias, requires_grad=True))
            return conv

        def forward(self, obs):
            """Apply both input convolutions, run the actor on each, and concatenate.

            Args:
                obs: Input tensor; the export cell in this file passes shape
                    ``(1, 2 * NUM_STATES)``.

            Returns:
                The two actor outputs concatenated along dim 1 and passed
                through ``nn.Flatten``.
            """
            # Prepend two singleton dims so Conv2d sees a 4-D tensor.
            obs = torch.unsqueeze(obs, 0)
            obs = torch.unsqueeze(obs, 0)

            input1 = self.input_conv1(obs)
            input2 = self.input_conv2(obs)

            # Drop the two leading singleton dims again.
            input1 = torch.squeeze(input1, 0)
            input1 = torch.squeeze(input1, 0)
            input2 = torch.squeeze(input2, 0)
            input2 = torch.squeeze(input2, 0)

            # The same actor (shared weights) is evaluated on both inputs.
            copy1_logits = self.base_model(input1)
            copy2_logits = self.base_model(input2)

            return self.ft(torch.concat((copy1_logits, copy2_logits), dim=1))

    return MyModel()
    

In [8]:
# Build the wrapper first. Its constructor calls get_model(), which loads the
# checkpoint and records the state size, so a separate get_model() call just
# to populate NUM_STATES is unnecessary (it loaded the checkpoint twice).
model = get_plain_comparative_aurora()

# Dummy tracing input: one row holding two concatenated state-sized vectors.
x = torch.tensor(
    [[0.1] * model.input_size * 2]
)

torch.onnx.export(
    model,      # The model being converted
    x,          # A dummy input for tracing the model
    "models/conv2d_based_onnx/model_conv2d.onnx", # The output file name for the ONNX model
    input_names=['input'],   # Optional: names for the input nodes
    output_names=['output'], # Optional: names for the output nodes
    opset_version=11         # Optional: specify the ONNX opset version
)
