In [2]:
import itertools
import os
from pathlib import Path

import torch
import torch.nn as nn
from cmars_lib.config_mappol import get_config
import cmars_lib.mdp_config as mdp_config
from cmars_lib.util import *
from cmars_lib.cnn import *
from cmars_lib.act import *
from cmars_lib.mlp import *
from cmars_lib.distributions import *
from gym import spaces
  
# Extra command-line switches registered on top of the parser returned by
# get_config().  Each entry is (flag, argparse action, default), listed in the
# same order the flags were originally registered.
_EXTRA_FLAGS = (
    ("--add_move_state", 'store_true', False),
    ("--add_local_obs", 'store_true', False),
    ("--add_distance_state", 'store_true', False),
    ("--add_enemy_action_state", 'store_true', False),
    ("--add_agent_id", 'store_true', False),
    ("--add_visible_state", 'store_true', False),
    ("--add_xy_state", 'store_true', False),
    ("--use_state_agent", 'store_true', False),
    ("--use_mustalive", 'store_false', True),
    ("--add_center_xy", 'store_true', False),
    ("--use_single_network", 'store_true', False),
)

parser = get_config()
for _flag, _action, _default in _EXTRA_FLAGS:
    parser.add_argument(_flag, action=_action, default=_default)

# parse_known_args returns (namespace, unrecognized_args) instead of erroring
# on unknown arguments; only the namespace is kept.
all_args = parser.parse_known_args()[0]

class R_Actor(nn.Module):
    """Actor network rebuilt from the config so a trained checkpoint can be loaded into it.

    Args:
        args: parsed config namespace.  Reads ``hidden_size``, ``gain``,
            ``use_orthogonal``, ``use_policy_active_masks``,
            ``use_naive_recurrent_policy``, ``use_recurrent_policy`` and
            ``recurrent_N``, and is also passed to the base and action layers.
        obs_space: observation space.  A 3-D observation shape selects
            ``CNNBase``; any other shape selects ``MLPBase``.
        action_space: action space handed to ``ACTLayer``.
        device: device the module is moved to.
    """

    def __init__(self, args, obs_space, action_space, device=torch.device("cpu")):
        super(R_Actor, self).__init__()
        self.hidden_size = args.hidden_size
        self._gain = args.gain
        self._use_orthogonal = args.use_orthogonal
        self._use_policy_active_masks = args.use_policy_active_masks
        self._use_naive_recurrent_policy = args.use_naive_recurrent_policy
        self._use_recurrent_policy = args.use_recurrent_policy
        self._recurrent_N = args.recurrent_N

        obs_shape = get_shape_from_obs_space(obs_space)
        base = CNNBase if len(obs_shape) == 3 else MLPBase
        self.base = base(args, obs_shape)

        self.act = ACTLayer(action_space, self.hidden_size, self._use_orthogonal, self._gain, args)

        self.to(device)

        # Cast inputs to the dtype the layers were actually built with.  The
        # previous hard-coded float64 produced a dtype mismatch whenever the
        # layers used the default float32 dtype.
        param_dtype = next(self.parameters()).dtype
        self.tpdv = dict(dtype=param_dtype, device=device)

    def forward(self, obs, rnn_states, masks, available_actions=None, deterministic=False):
        """Run ``obs`` through the base network and the action layer.

        ``rnn_states`` and ``masks`` are accepted for signature compatibility
        but are not used: this actor has no recurrent path.

        Returns:
            Whatever ``ACTLayer`` returns for the features.
            NOTE(review): the exact contents (logits / actions / log-probs) are
            defined in cmars_lib's ``ACTLayer``; confirm there.
        """
        # check() comes from cmars_lib.util; presumably it converts array-likes
        # to tensors. TODO confirm.
        obs = check(obs).to(**self.tpdv)

        actor_features = self.base(obs)
        return self.act(actor_features, available_actions, deterministic)

def get_model(layer_count, hidden_size, action_count):
    """Load a trained eMBB actor checkpoint and rebuild it as plain Linear/ReLU layers.

    The checkpoint at
    ``models/output_{action_count}/h{hidden_size}/N{layer_count}/actor_type_embb.pt``
    is loaded into an ``R_Actor``.  Its weights are then copied into
    ``CMARS_Actor_Wrapper``, a minimal module whose ``forward`` returns the raw
    output of the final linear layer (the action layer's distribution is not
    applied).

    Args:
        layer_count: number of hidden-to-hidden layers (``all_args.layer_N``), 1-8.
        hidden_size: hidden width, one of 32, 64, 128.
        action_count: number of discrete actions, one of 15, 30, 80, 140.

    Returns:
        ``CMARS_Actor_Wrapper`` mapping inputs of width
        ``EMBB_LOCAL_OBS_VAR_COUNT + AUG_LOCAL_STATE_VAR_COUNT`` to
        ``action_count`` logits.

    Raises:
        AssertionError: if an argument is outside the supported sets.
        ValueError: if a checkpoint tensor's shape does not match the wrapper layer
            it is copied into (previously such a mismatch was silently accepted).

    Side effect: overwrites ``all_args.layer_N`` and ``all_args.hidden_size``
    on the module-level config.
    """
    # set configs
    assert layer_count in [1, 2, 3, 4, 5, 6, 7, 8]
    assert hidden_size in [32, 64, 128]
    assert action_count in [15, 30, 80, 140]

    all_args.layer_N = layer_count
    all_args.hidden_size = hidden_size
    n_prbs = action_count
    models_path = f"models/output_{action_count}/h{hidden_size}/N{layer_count}/actor_type_embb.pt"

    act_space = spaces.Discrete(n_prbs)
    obs_dim = mdp_config.EMBB_LOCAL_OBS_VAR_COUNT + mdp_config.AUG_LOCAL_STATE_VAR_COUNT
    obs_space = spaces.Box(low=0, high=10e6, shape=(obs_dim,))

    # base policy
    device = torch.device('cpu')
    model = R_Actor(all_args, obs_space, act_space)
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    model.load_state_dict(torch.load(models_path, map_location=device))

    def _copy_linear(dst, src):
        """Copy weight and bias of Linear ``src`` into ``dst``, requiring identical shapes."""
        for name in ("weight", "bias"):
            dst_param, src_param = getattr(dst, name), getattr(src, name)
            # copy_ would broadcast compatible shapes, so check explicitly.
            if dst_param.shape != src_param.shape:
                raise ValueError(
                    f"checkpoint {name} shape {tuple(src_param.shape)} does not match "
                    f"wrapper layer shape {tuple(dst_param.shape)}"
                )
        # copy_ (unlike assigning .data) gives the wrapper its own storage.
        with torch.no_grad():
            dst.weight.copy_(src.weight)
            dst.bias.copy_(src.bias)

    # remove softmax
    class CMARS_Actor_Wrapper(nn.Module):
        """Plain ``lin1 -> ReLU -> (lin_k -> ReLU) * layer_count -> out`` copy of the actor.

        NOTE(review): this assumes each ``fc1`` / ``fc2[i]`` block of cmars_lib's
        MLP is a ``Linear`` followed only by a ReLU-equivalent activation.  Any
        other layer in those blocks (e.g. a LayerNorm, or Tanh instead of ReLU)
        is not reproduced here.  Confirm the wrapper's logits match ``R_Actor``.
        """

        def __init__(self, model, device=torch.device("cpu")):
            super(CMARS_Actor_Wrapper, self).__init__()

            self.af = nn.ReLU()
            self.lin1 = nn.Linear(obs_dim, hidden_size)
            _copy_linear(self.lin1, model.base.mlp.fc1[0])

            # Plain list used for iteration in forward.  Each layer is registered
            # as a submodule through its lin{k} attribute, so parameter names
            # stay lin2.*, lin3.*, ...
            self.midlayers = []
            for i in range(layer_count):
                layer = nn.Linear(hidden_size, hidden_size)
                _copy_linear(layer, model.base.mlp.fc2[i][0])
                self.midlayers.append(layer)
                setattr(self, f"lin{i + 2}", layer)

            self.out = nn.Linear(hidden_size, n_prbs)
            _copy_linear(self.out, model.act.action_out.linear)

            # Must run after every layer exists; calling it before they were
            # created made it a no-op.
            self.to(device)

        def forward(self, obs):
            hidden = self.af(self.lin1(obs))
            for layer in self.midlayers:
                hidden = self.af(layer(hidden))
            return self.out(hidden)

    embb_cmars_wrapper = CMARS_Actor_Wrapper(model)

    return embb_cmars_wrapper

def get_params_argmax(input_size):
    """Return two fixed Conv2d kernels that split a row of ``2 * input_size`` values.

    Both kernels have shape ``(1, 1, 1, input_size + 1)``.  Slid with stride 1
    and no padding across a row of length ``2 * input_size``, each produces
    ``input_size`` outputs:

    * ``c01`` has a single 1 at tap 0, so output ``i`` is ``x[i]``
      (the first half of the row);
    * ``c02`` has 1s at tap 0 and at the last tap (``input_size``), so output
      ``i`` is ``x[i] + x[i + input_size]`` (first half plus second half).
    """
    width = input_size + 1

    first_only = torch.zeros(1, 1, 1, width)
    first_only[0, 0, 0, 0] = 1

    first_and_last = first_only.clone()
    first_and_last[0, 0, 0, -1] = 1

    return first_only, first_and_last

def get_plain_comparative_cmars(layer_count=1, hidden_size=32, action_count=15):
    """Wrap the eMBB actor so one forward pass evaluates it on two inputs derived from one row.

    The returned module takes ``obs`` of shape ``(batch, 2 * n)``, where ``n`` is
    the actor's input width (``lin1.in_features`` of the model from
    ``get_model``).  Two fixed ``Conv2d`` kernels from ``get_params_argmax``
    turn each row into

    * ``obs[:, :n]``, fed to the first call of the actor, and
    * ``obs[:, :n] + obs[:, n:]``, fed to the second call,

    and the two logit tensors are concatenated along dim 1.

    Args:
        layer_count, hidden_size, action_count: forwarded to ``get_model``.

    Returns:
        ``nn.Module`` mapping ``(batch, 2 * n)`` to ``(batch, 2 * action_count)``.
        It exposes the per-copy width as ``input_size``.
    """

    def _fixed_conv(kernel):
        """Single-channel Conv2d whose weight is ``kernel`` and whose bias is zero."""
        conv = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=[1, kernel.shape[-1]])
        conv.weight = torch.nn.Parameter(kernel, requires_grad=True)
        conv.bias = torch.nn.Parameter(torch.zeros_like(conv.bias, requires_grad=True))
        return conv

    class MyModel(nn.Module):
        def __init__(self, device=torch.device("cpu")):
            super(MyModel, self).__init__()

            self.ft = torch.nn.Flatten()  # NOTE(review): not used by forward

            #################
            # Model
            #################
            self.base_model = get_model(layer_count, hidden_size, action_count)

            # Derive the per-copy width from the actor's first layer instead of
            # hard-coding 19, so the split always matches the checkpoint.
            input_size = self.base_model.lin1.in_features
            self.input_size = input_size
            c01, c02 = get_params_argmax(input_size)

            #################
            # Input split (see get_params_argmax)
            #################
            self.input_conv1 = _fixed_conv(c01)
            self.input_conv2 = _fixed_conv(c02)

        def forward(self, obs):
            # (batch, 2n) -> (1, 1, batch, 2n): one single-channel "image" whose
            # rows are the samples.
            obs = torch.unsqueeze(torch.unsqueeze(obs, 0), 0)

            # Each (1, n + 1) kernel slides along the feature axis, giving
            # (1, 1, batch, n), which is squeezed back to (batch, n).
            input1 = torch.squeeze(torch.squeeze(self.input_conv1(obs), 0), 0)
            input2 = torch.squeeze(torch.squeeze(self.input_conv2(obs), 0), 0)

            copy1_logits = self.base_model(input1)
            copy2_logits = self.base_model(input2)

            return torch.concat((copy1_logits, copy2_logits), dim=1)

    return MyModel()

In [3]:
# Export the comparative model for every (layer_count, hidden_size, action_count)
# combination to ONNX.
ONNX_OUT_DIR = Path("models/conv2d_based_onnx")
# torch.onnx.export does not create missing directories.
ONNX_OUT_DIR.mkdir(parents=True, exist_ok=True)

for layer_count, hidden_size, action_count in itertools.product([1, 2], [32, 64], [15, 30]):
    model = get_plain_comparative_cmars(layer_count, hidden_size, action_count)

    # Dummy tracing input: one row of 2 * input_size features (38 for a 19-wide
    # actor input), derived from the model instead of a hard-coded width.
    x = torch.full((1, 2 * model.input_size), 0.1)

    onnx_path = ONNX_OUT_DIR / f"model_l{layer_count}_h{hidden_size}_a{action_count}.onnx"
    torch.onnx.export(
        model,                    # The model being converted
        x,                        # A dummy input for tracing the model
        str(onnx_path),           # The output file name for the ONNX model
        input_names=['input'],    # Names for the input nodes
        output_names=['output'],  # Names for the output nodes
        opset_version=11,         # ONNX opset version
    )


self._use_feature_normalization:  False
self._use_feature_normalization:  False
self._use_feature_normalization:  False
self._use_feature_normalization:  False
self._use_feature_normalization:  False
self._use_feature_normalization:  False
self._use_feature_normalization:  False
self._use_feature_normalization:  False


In [108]:
def get_params_argmax(input_size):
    """Build the two ``(1, 1, 1, input_size + 1)`` Conv2d split kernels.

    This is the same helper as in the first cell, redefined so this cell can be
    run on its own.  ``c01`` selects tap 0 (the first half of a
    ``2 * input_size`` row).  ``c02`` selects taps 0 and ``input_size``
    (first half plus second half).
    """
    kernel_shape = (1, 1, 1, input_size + 1)
    taps_first = torch.zeros(kernel_shape)
    taps_both = torch.zeros(kernel_shape)

    # Both kernels read tap 0; only the second also reads the last tap.
    for kernel in (taps_first, taps_both):
        kernel[..., 0] = 1
    taps_both[..., -1] = 1

    return taps_first, taps_both


class MyModel(nn.Module):
    """Toy stand-in for the comparative CMARS model, used to check the conv-based split.

    Expects ``obs`` of shape ``(batch, 2 * input_size)`` with ``input_size = 3``.
    ``input_conv1`` extracts ``obs[:, :3]`` and ``input_conv2`` computes
    ``obs[:, :3] + obs[:, 3:]`` (kernels from ``get_params_argmax``).  Both pass
    through the same identity ``Linear`` and are concatenated to ``(batch, 6)``.
    """

    def __init__(self, device=torch.device("cpu")):
        super(MyModel, self).__init__()

        input_size = 3
        self.input_size = input_size
        c01, c02 = get_params_argmax(input_size)

        self.ft = torch.nn.Flatten()  # NOTE(review): not used by forward

        self.input_conv1 = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=[1, input_size + 1])
        self.input_conv1.weight = torch.nn.Parameter(c01, requires_grad=True)
        self.input_conv1.bias = torch.nn.Parameter(torch.zeros_like(self.input_conv1.bias))

        self.input_conv2 = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=[1, input_size + 1])
        self.input_conv2.weight = torch.nn.Parameter(c02, requires_grad=True)
        self.input_conv2.bias = torch.nn.Parameter(torch.zeros_like(self.input_conv2.bias))

        # Identity map sized from input_size (previously a hard-coded 3x3).
        self.model = nn.Linear(input_size, input_size)
        self.model.weight = torch.nn.Parameter(torch.eye(input_size))
        self.model.bias = torch.nn.Parameter(torch.zeros_like(self.model.bias))

    def forward(self, obs):
        # (batch, 2n) -> (1, 1, batch, 2n) so the convs slide along the feature axis.
        obs = torch.unsqueeze(obs, 0)
        obs = torch.unsqueeze(obs, 0)

        input1 = self.input_conv1(obs)
        input2 = self.input_conv2(obs)

        # (1, 1, batch, n) -> (batch, n)
        input1 = torch.squeeze(input1, 0)
        input1 = torch.squeeze(input1, 0)
        input2 = torch.squeeze(input2, 0)
        input2 = torch.squeeze(input2, 0)

        out1 = self.model(input1)
        out2 = self.model(input2)

        return torch.concat((out1, out2), dim=1)


In [110]:
# Sanity-check the conv-based split on the toy model.  For an all-0.1 row of
# width 6, the first copy is x[:3] (0.1) and the second is x[:3] + x[3:] (0.2).
x = torch.full((1, 6), 0.1)

model = MyModel()
toy_output = model(x)
print(toy_output)
torch.testing.assert_close(
    toy_output.detach(),
    torch.tensor([[0.1, 0.1, 0.1, 0.2, 0.2, 0.2]]),
)

# The export directory comes from SYS_RL_VERIF_DIR; the default is the path
# that used to be hard-coded here.
TOY_ONNX_PATH = Path(os.environ.get("SYS_RL_VERIF_DIR", "/home/mzi/sys-rl-verif")) / "model_2.onnx"
TOY_ONNX_PATH.parent.mkdir(parents=True, exist_ok=True)

torch.onnx.export(
    model,
    x,
    str(TOY_ONNX_PATH),
    input_names=['input'],
    output_names=['output'],
    opset_version=11,
)


tensor([[0.1000, 0.1000, 0.1000, 0.2000, 0.2000, 0.2000]],
       grad_fn=<CatBackward0>)


In [32]:
def get_params_argmax(input_size):
    """Conv1d variant of the split kernels, each of shape ``(1, 1, input_size + 1)``.

    Over a length-``2 * input_size`` signal (stride 1, no padding):

    * ``c01`` (1 at tap 0) yields ``x[i]``;
    * ``c02`` (1 at tap 0 and at the last tap) yields ``x[i] + x[i + input_size]``.
    """
    c01 = torch.zeros(1, 1, input_size + 1)
    c02 = torch.zeros_like(c01)

    c01[0, 0, 0] = c02[0, 0, 0] = 1
    c02[0, 0, -1] = 1

    return c01, c02

# Conv1d demo: an 8-wide kernel (input_size + 1 with input_size = 7) over a
# 14-wide signal (2 * input_size) yields 14 - 8 + 1 = 7 outputs.
c01, c02 = get_params_argmax(7)

x = torch.full((1, 1, 14), 0.1)  # (batch, channels, length)

layer = nn.Conv1d(in_channels=1, out_channels=1, kernel_size=8)

# Replace the random init with the fixed "tap 0 only" kernel and a zero bias.
layer.weight, layer.bias = (
    torch.nn.Parameter(c01),
    torch.nn.Parameter(torch.zeros_like(layer.bias)),
)

print(layer.weight.shape, layer.stride, layer.padding)
print(layer(x))

torch.Size([1, 1, 8]) (1,) (0,)
tensor([[[0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000]]],
       grad_fn=<ConvolutionBackward0>)


In [34]:
def get_params_argmax(input_size):
    """Build the ``(1, 1, 1, input_size + 1)`` Conv2d split kernels.

    ``c01`` has a 1 at tap 0 only.  ``c02`` has 1s at tap 0 and at the last tap,
    so over a ``2 * input_size`` row it sums ``x[i]`` and ``x[i + input_size]``.
    """
    kernel_len = input_size + 1
    c01 = torch.zeros([1, 1, 1, kernel_len])
    c02 = torch.zeros([1, 1, 1, kernel_len])

    c01[0, 0, 0, 0] = 1
    for tap in (0, -1):
        c02[0, 0, 0, tap] = 1

    return c01, c02

# Conv2d demo with the "tap 0 + last tap" kernel: each of the 7 outputs is
# x[i] + x[i + 7], i.e. 0.1 + 0.1.
c01, c02 = get_params_argmax(7)

# A 3-D input is treated by Conv2d as one unbatched (C, H, W) = (1, 1, 14) sample.
x = torch.full((1, 1, 14), 0.1)

layer = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=[1, 8])

# Replace the random init with the fixed kernel and a zero bias.
layer.weight, layer.bias = (
    torch.nn.Parameter(c02),
    torch.nn.Parameter(torch.zeros_like(layer.bias)),
)

print(layer.weight.shape, layer.stride, layer.padding)
print(layer(x))


torch.Size([1, 1, 1, 8]) (1, 1) (0, 0)
tensor([[[0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]],
       grad_fn=<SqueezeBackward1>)
