In [462]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from config import Configuration
from noisy_net import NoisyLinear, NoisyFactorizedLinear
from utils import state_pre_processing
from OneHotEncode import OneHotEncode

class BranchingQNetwork(nn.Module):
    def __init__(self, observation_space, action_space, hidden_dim, exploration_method="Dueling", architecture="Dueling"):
        super().__init__()
        self.architecture = architecture
        self.model = nn.ModuleList([nn.Sequential(
            nn.Linear(62, hidden_dim*4),
            nn.ReLU(),
            nn.Linear(hidden_dim*4, hidden_dim*2),
            nn.ReLU(),
            nn.Linear(hidden_dim*2, hidden_dim),
            nn.ReLU()
        ) for i in range(12)])
        if self.architecture == "Dueling":
            self.value_head = nn.ModuleList([nn.Linear(hidden_dim, 1) for i in range(12)])
            self.adv_heads = nn.ModuleList([nn.Linear(hidden_dim, 11) for i in range(12)])
        else:
            self.out = nn.ModuleList([nn.Linear(hidden_dim, 11) for i in range(12)])

    def forward(self, x):
        processed_x = self.state_processing(x)
        layer1 = torch.stack([self.model[i](processed_x[i]) for i, _ in enumerate(processed_x)])
        if self.architecture == "Dueling":
            value = torch.stack([self.value_head[i](layer1[i]) for i, _ in enumerate(layer1)])
            advs = torch.stack([self.adv_heads[i](layer1[i]) for i, _ in enumerate(layer1)])
            q_val = value + advs - advs.mean()
        else:
            q_val = torch.stack([self.out[i](layer1[i]) for i, _ in enumerate(layer1)])
            
        return q_val
    
    def state_processing(self, obs):
        node_info = obs[:45]
        groups_info = obs[45:]
        partitions = [17 for i in range(12)]
        groups = torch.split(groups_info, partitions)
        groups_final = [torch.cat((node_info, groups[i])) for i in range(len(groups))]
        return groups_final
   

In [463]:
config = Configuration("configs/config.json")
torch.cuda.init()
device = torch.device(
    config.device if torch.cuda.is_available() else "cpu")
torch.autograd.set_detect_anomaly(True)


<torch.autograd.anomaly_mode.set_detect_anomaly at 0x1c6396d3e50>

In [464]:
obs = [  64.,    0.,    0.,  500.,    0.,    0.,    1.,  100.,    8.,
          0.,    0.,  100.,    0.,    1.,    0.,  100.,    0.,    0.,
          0.,  -88.,    0.,    0.,    0., -100.,   16.,    0.,    0.,
        100.,   12.,    0.,    1., -100.,   16.,    0.,    0.,  -50.,
          7.,    1.,    0.,  -13.,    8.,    0.,    0., -500.,   32.,
          3.,    1.,   94.,    1.,    8.,   10.,    2.,   56.,    1.,
          7.,    8.,    0.,   93.,    0.,    8.,    6.,    1.,   58.,
          0.,    8.,    8.,    2.,    0.,    0.,    0.,    4.,    0.,
        100.,    0.,    8.,    7.,    1.,   90.,    0.,    8.,    2.,
          2.,  100.,    0.,    8.,    7.,    0.,  100.,    0.,    8.,
          2.,    1.,   15.,    0.,    7.,    9.,    2.,   70.,    0.,
          8.,    2.,    0.,   91.,    1.,   12.]

In [465]:
new_obs = OneHotEncode(obs)

In [466]:
new_obs = torch.tensor(new_obs).float()


In [467]:
model = BranchingQNetwork(249,11,128)


In [481]:
out = model(new_obs)
out_max = out.max(1)
out_max_7 = out_max.values
out_max_7.sort(reverse=True)

TypeError: sort() received an invalid combination of arguments - got (reverse=bool, ), but expected one of:
 * (name dim, bool descending)
 * (int dim, bool descending)


In [473]:
torch.argmax(out, dim=1)

tensor([ 6,  0,  8,  1, 10,  7,  6, 10,  2,  5,  5,  0])