In [1]:
import torch
import parl
from torch import nn
import torch.nn.functional as F
import numpy as np

class baseModel(nn.Module):
    """Two-layer MLP trunk shared by a scalar value head and several policy heads.

    The network is device-agnostic: all layers are created on the default
    device, and inputs are moved to whatever device the parameters currently
    live on. Call ``.to(device)`` on the module to place it on a GPU.
    """

    def __init__(self, obs_shape, act_shape):
        """
        obs_shape: sequence whose first entry is the observation feature size.
        act_shape: sequence of ints; entry i is the number of logits produced
            by policy head i.
        """
        super(baseModel, self).__init__()
        self.fc1 = nn.Linear(obs_shape[0], 64)
        self.fc2 = nn.Linear(64, 64)
        # Heads are built on the same device as the trunk (no hard-coded CUDA),
        # so the module also works on CPU-only machines. int() accepts numpy
        # integer sizes as well as plain Python ints.
        self.fc_pi = nn.ModuleList([nn.Linear(64, int(n_act)) for n_act in act_shape])
        self.fc_v = nn.Linear(64, 1)

    def _features(self, obs):
        """Run the shared trunk: cast obs to float32 on the module's device, then fc1/fc2 with ReLU."""
        obs = torch.as_tensor(obs, dtype=torch.float32, device=self.fc1.weight.device)
        hidden = F.relu(self.fc1(obs))
        return F.relu(self.fc2(hidden))

    def value(self, obs):
        """Return the value head output flattened to a 1-D tensor."""
        return self.fc_v(self._features(obs)).reshape(-1)

    def policy(self, obs):
        """Return a list with one logits tensor per policy head.

        Note: the result is laid out as (n_action, batch_size, n_act), i.e. a
        list of ``len(act_shape)`` tensors, where tensor i has ``act_shape[i]``
        entries in its last dimension.
        """
        features = self._features(obs)
        return [head(features) for head in self.fc_pi]

class uavModel(parl.Model):
    """Holds one independent ``baseModel`` network per cluster."""

    def __init__(self, obs_space, act_space, n_clusters):
        """
        obs_space: (obs_n,)
        act_space: (n, n, n, n, ...)
        n_clusters: number of per-cluster networks to create
        """
        super(uavModel, self).__init__()
        # Use the GPU when one is present, otherwise fall back to the CPU so
        # that constructing the model does not fail on CPU-only hosts.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.net = nn.ModuleList([baseModel(obs_space, act_space) for _ in range(n_clusters)])
        # Moving the ModuleList moves every contained network in place.
        self.net.to(device)
        self.n_clusters = n_clusters
        self.n_act = len(act_space)

    # When calling the two methods below, the input should be (n_clusters, xx);
    # each obs[i] is reshaped here into a single-row batch before it is passed
    # to network i.
    def value(self, obs):
        """Return a list with the value output of each cluster's network for obs[i]."""
        return [net.value(obs[i].reshape(1, -1)) for i, net in enumerate(self.net)]

    def policy(self, obs):
        """Return a list with the policy output of each cluster's network for obs[i]."""
        return [net.policy(obs[i].reshape(1, -1)) for i, net in enumerate(self.net)]

In [2]:
# Observation shape tuple passed to the model (first entry = feature size).
obs = (35,)
# Size of each of the 12 action heads: the pattern 6,6,6,4,4,4 repeated twice.
act = np.array([6, 6, 6, 4, 4, 4] * 2)

In [3]:
# Build the multi-cluster model with three per-cluster networks.
model = uavModel(obs_space=obs, act_space=act, n_clusters=3)

In [6]:
# Display the parameters of the first cluster's network (the cell's last
# expression is rendered as its output).
model.net[0].state_dict()

OrderedDict([('fc1.weight',
              tensor([[-0.1077, -0.0089, -0.0905,  ..., -0.0822,  0.1372, -0.1501],
                      [-0.0097, -0.1203,  0.0060,  ..., -0.0916,  0.1094,  0.0533],
                      [ 0.0011, -0.1277,  0.1088,  ..., -0.0065,  0.1509,  0.0495],
                      ...,
                      [ 0.0783, -0.0857,  0.0740,  ..., -0.0723,  0.0905, -0.1087],
                      [-0.1059,  0.1274, -0.1358,  ...,  0.1092,  0.1186, -0.0143],
                      [-0.0900,  0.0162, -0.1343,  ...,  0.1296, -0.0020,  0.1462]],
                     device='cuda:0')),
             ('fc1.bias',
              tensor([ 0.0618, -0.1184, -0.1389,  0.1021,  0.0281, -0.0048, -0.0881, -0.0051,
                      -0.0961, -0.1277, -0.0804,  0.1579, -0.1221,  0.0613,  0.1598, -0.0374,
                       0.0688,  0.1383,  0.0607,  0.0741, -0.1537,  0.1508,  0.0905, -0.0382,
                      -0.0183,  0.0434, -0.0189, -0.0651, -0.1283, -0.0799, -0.1106,  0.1401