In [660]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from config import Configuration
from noisy_net import NoisyLinear, NoisyFactorizedLinear
from utils import state_pre_processing
from OneHotEncode import OneHotEncode

class BranchingQNetwork(nn.Module):
    def __init__(self, observation_space, action_space, hidden_dim, exploration_method="Dueling", architecture="DQN"):
        super().__init__()
        self.architecture = architecture
        self.model = nn.ModuleList([nn.Sequential(
            nn.Linear(62, hidden_dim*4),
            nn.ReLU(),
            nn.Linear(hidden_dim*4, hidden_dim*2),
            nn.ReLU(),
            nn.Linear(hidden_dim*2, hidden_dim),
            nn.ReLU()
        ) for i in range(12)])
        if self.architecture == "Dueling":
            self.value_head = nn.ModuleList([nn.Linear(hidden_dim, 1) for i in range(12)])
            self.adv_heads = nn.ModuleList([nn.Linear(hidden_dim, 11) for i in range(12)])
        else:
            self.out = nn.ModuleList([nn.Linear(hidden_dim, 11) for i in range(12)])

    def forward(self, x):
        processed_x = self.state_processing(x)
        print(processed_x)
        layer1 = torch.stack([self.model[i](processed_x[i]) for i, _ in enumerate(processed_x)])
        if self.architecture == "Dueling":
            value = torch.stack([self.value_head[i](layer1[i]) for i, _ in enumerate(layer1)])
            advs = torch.stack([self.adv_heads[i](layer1[i]) for i, _ in enumerate(layer1)])
            q_val = value + advs - advs.mean()
        else:
            q_val = torch.stack([self.out[i](layer1[i]) for i, _ in enumerate(layer1)])
            
        return q_val
    
    def state_processing(self, obs):
        node_info = obs[:45]
        groups_info = obs[45:]
        partitions = [17 for i in range(12)]
        groups = torch.split(groups_info, partitions)
        groups_final = [torch.cat((node_info, groups[i])) for i in range(len(groups))]
        return groups_final
   

In [661]:
config = Configuration("configs/config.json")
torch.cuda.init()
device = torch.device(
    config.device if torch.cuda.is_available() else "cpu")
torch.autograd.set_detect_anomaly(True)


<torch.autograd.anomaly_mode.set_detect_anomaly at 0x1c626dca190>

In [662]:
obs = [  64.,    0.,    0.,  500.,    0.,    0.,    1.,  100.,    8.,
          0.,    0.,  100.,    0.,    1.,    0.,  100.,    0.,    0.,
          0.,  -88.,    0.,    0.,    0., -100.,   16.,    0.,    0.,
        100.,   12.,    0.,    1., -100.,   16.,    0.,    0.,  -50.,
          7.,    1.,    0.,  -13.,    8.,    0.,    0., -500.,   32.,
          3.,    1.,   94.,    1.,    8.,   10.,    2.,   56.,    1.,
          7.,    8.,    0.,   93.,    0.,    8.,    6.,    1.,   58.,
          0.,    8.,    8.,    2.,    0.,    0.,    0.,    4.,    0.,
        100.,    0.,    8.,    7.,    1.,   90.,    0.,    8.,    2.,
          2.,  100.,    0.,    8.,    7.,    0.,  100.,    0.,    8.,
          2.,    1.,   15.,    0.,    7.,    9.,    2.,   70.,    0.,
          8.,    2.,    0.,   91.,    1.,   12.]

In [663]:
new_obs = OneHotEncode(obs)

In [664]:
new_obs = torch.tensor(new_obs).float().to(device)
new_obs

tensor([  64.,    0.,    0.,  500.,    0.,    0.,    1.,  100.,    8.,    0.,
           0.,  100.,    0.,    1.,    0.,  100.,    0.,    0.,    0.,  -88.,
           0.,    0.,    0., -100.,   16.,    0.,    0.,  100.,   12.,    0.,
           1., -100.,   16.,    0.,    0.,  -50.,    7.,    1.,    0.,  -13.,
           8.,    0.,    0., -500.,   32.,    0.,    0.,    0.,    1.,    0.,
           0.,    0.,    0.,    0.,    0.,    0.,    0.,    1.,    0.,   94.,
           1.,    8.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,
           0.,    0.,    1.,    0.,    0.,    1.,   56.,    1.,    7.,    0.,
           0.,    0.,    0.,    0.,    0.,    0.,    0.,    1.,    0.,    0.,
           1.,    0.,    0.,   93.,    0.,    8.,    0.,    0.,    0.,    0.,
           0.,    0.,    1.,    0.,    0.,    0.,    0.,    0.,    1.,    0.,
          58.,    0.,    8.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,
           0.,    1.,    0.,    0.,    0.,    0.,    1.,    0., 

In [665]:
model = BranchingQNetwork(249,11,128)
model.to(device)

BranchingQNetwork(
  (model): ModuleList(
    (0): Sequential(
      (0): Linear(in_features=62, out_features=512, bias=True)
      (1): ReLU()
      (2): Linear(in_features=512, out_features=256, bias=True)
      (3): ReLU()
      (4): Linear(in_features=256, out_features=128, bias=True)
      (5): ReLU()
    )
    (1): Sequential(
      (0): Linear(in_features=62, out_features=512, bias=True)
      (1): ReLU()
      (2): Linear(in_features=512, out_features=256, bias=True)
      (3): ReLU()
      (4): Linear(in_features=256, out_features=128, bias=True)
      (5): ReLU()
    )
    (2): Sequential(
      (0): Linear(in_features=62, out_features=512, bias=True)
      (1): ReLU()
      (2): Linear(in_features=512, out_features=256, bias=True)
      (3): ReLU()
      (4): Linear(in_features=256, out_features=128, bias=True)
      (5): ReLU()
    )
    (3): Sequential(
      (0): Linear(in_features=62, out_features=512, bias=True)
      (1): ReLU()
      (2): Linear(in_features=512, out_f

In [666]:
out = model(new_obs)
print(out)
out_max = out.max(1)
print(out_max)
out_max_sorted = out_max.values.sort(descending=True)
print(out_max_sorted)
chosen_group = out_max_sorted.indices[:7]
chosen_location = torch.stack([out_max.indices[i] for i in chosen_group]) + 1
action = torch.stack([chosen_group, chosen_location], dim = 1)
action.detach().cpu().numpy()

[tensor([  64.,    0.,    0.,  500.,    0.,    0.,    1.,  100.,    8.,    0.,
           0.,  100.,    0.,    1.,    0.,  100.,    0.,    0.,    0.,  -88.,
           0.,    0.,    0., -100.,   16.,    0.,    0.,  100.,   12.,    0.,
           1., -100.,   16.,    0.,    0.,  -50.,    7.,    1.,    0.,  -13.,
           8.,    0.,    0., -500.,   32.,    0.,    0.,    0.,    1.,    0.,
           0.,    0.,    0.,    0.,    0.,    0.,    0.,    1.,    0.,   94.,
           1.,    8.], device='cuda:0'), tensor([  64.,    0.,    0.,  500.,    0.,    0.,    1.,  100.,    8.,    0.,
           0.,  100.,    0.,    1.,    0.,  100.,    0.,    0.,    0.,  -88.,
           0.,    0.,    0., -100.,   16.,    0.,    0.,  100.,   12.,    0.,
           1., -100.,   16.,    0.,    0.,  -50.,    7.,    1.,    0.,  -13.,
           8.,    0.,    0., -500.,   32.,    0.,    0.,    0.,    0.,    0.,
           0.,    0.,    0.,    0.,    0.,    1.,    0.,    0.,    1.,   56.,
           1.,    7.],

array([[10,  3],
       [ 5,  3],
       [ 7,  7],
       [ 3,  6],
       [ 8,  2],
       [ 4,  2],
       [11,  8]], dtype=int64)

In [667]:
torch.argmax(out,1)

tensor([8, 2, 5, 5, 1, 2, 5, 6, 1, 3, 2, 7], device='cuda:0')

In [649]:
q = out.max(1, keepdim=True)
q 

torch.return_types.max(
values=tensor([[7.6681],
        [4.3280],
        [3.6627],
        [3.3111],
        [3.3614],
        [6.7938],
        [5.9988],
        [7.9785],
        [1.4474],
        [3.6140],
        [4.0060],
        [4.4806]], device='cuda:0', grad_fn=<MaxBackward0>),
indices=tensor([[ 4],
        [ 8],
        [10],
        [ 3],
        [ 7],
        [ 2],
        [ 4],
        [ 2],
        [ 9],
        [ 0],
        [ 0],
        [ 2]], device='cuda:0'))

In [659]:
g = torch.tensor([5])
for i in range(5):
    g = torch.stack((g,torch.tensor([5])),1)
g

RuntimeError: stack expects each tensor to be equal size, but got [1, 2] at entry 0 and [1] at entry 1

In [842]:
a = torch.rand(16)
a = a.reshape(16,1,1)
a.shape

torch.Size([16, 1, 1])

In [843]:
a

tensor([[[0.7653]],

        [[0.8968]],

        [[0.1491]],

        [[0.5569]],

        [[0.8454]],

        [[0.4360]],

        [[0.4331]],

        [[0.7424]],

        [[0.9555]],

        [[0.2917]],

        [[0.2802]],

        [[0.6783]],

        [[0.3563]],

        [[0.4468]],

        [[0.7109]],

        [[0.7604]]])

In [844]:
b = torch.rand(16,1,11)
b

tensor([[[0.1998, 0.5693, 0.4208, 0.8383, 0.5744, 0.5706, 0.3723, 0.5833,
          0.7852, 0.2789, 0.1731]],

        [[0.4356, 0.5773, 0.2484, 0.5187, 0.4292, 0.9952, 0.2208, 0.2657,
          0.8074, 0.4296, 0.1393]],

        [[0.2098, 0.0830, 0.8153, 0.7137, 0.1827, 0.0033, 0.0401, 0.6855,
          0.0067, 0.5045, 0.0874]],

        [[0.3700, 0.1646, 0.5977, 0.2931, 0.1988, 0.7953, 0.9809, 0.4848,
          0.4476, 0.6978, 0.7244]],

        [[0.0876, 0.1468, 0.3667, 0.8249, 0.3232, 0.0641, 0.0850, 0.4794,
          0.7879, 0.5578, 0.1258]],

        [[0.0020, 0.6110, 0.0785, 0.9579, 0.8489, 0.5749, 0.9852, 0.1104,
          0.4834, 0.4429, 0.8309]],

        [[0.4971, 0.1100, 0.4329, 0.3472, 0.5093, 0.8260, 0.9937, 0.5619,
          0.5753, 0.2851, 0.7011]],

        [[0.5411, 0.7743, 0.8090, 0.8835, 0.7705, 0.5820, 0.1611, 0.2945,
          0.5745, 0.2372, 0.4310]],

        [[0.1132, 0.9176, 0.7924, 0.5598, 0.9643, 0.6029, 0.2676, 0.3994,
          0.9076, 0.3200, 0.2417]],

 

In [845]:
a+b

tensor([[[0.9651, 1.3346, 1.1861, 1.6036, 1.3397, 1.3359, 1.1375, 1.3486,
          1.5505, 1.0442, 0.9384]],

        [[1.3324, 1.4741, 1.1452, 1.4155, 1.3259, 1.8920, 1.1175, 1.1625,
          1.7041, 1.3264, 1.0360]],

        [[0.3588, 0.2320, 0.9644, 0.8628, 0.3318, 0.1524, 0.1891, 0.8345,
          0.1558, 0.6536, 0.2364]],

        [[0.9269, 0.7215, 1.1545, 0.8500, 0.7557, 1.3522, 1.5378, 1.0417,
          1.0044, 1.2546, 1.2813]],

        [[0.9329, 0.9922, 1.2121, 1.6703, 1.1686, 0.9095, 0.9304, 1.3248,
          1.6333, 1.4032, 0.9711]],

        [[0.4381, 1.0471, 0.5145, 1.3939, 1.2849, 1.0109, 1.4212, 0.5464,
          0.9194, 0.8789, 1.2669]],

        [[0.9301, 0.5431, 0.8659, 0.7803, 0.9424, 1.2591, 1.4267, 0.9950,
          1.0083, 0.7181, 1.1342]],

        [[1.2835, 1.5167, 1.5514, 1.6259, 1.5128, 1.3244, 0.9035, 1.0369,
          1.3169, 0.9796, 1.1733]],

        [[1.0686, 1.8731, 1.7478, 1.5152, 1.9198, 1.5584, 1.2231, 1.3549,
          1.8630, 1.2755, 1.1972]],

 

In [849]:
for p in model.parameters():
    print(p.shape)

torch.Size([512, 62])
torch.Size([512])
torch.Size([256, 512])
torch.Size([256])
torch.Size([128, 256])
torch.Size([128])
torch.Size([512, 62])
torch.Size([512])
torch.Size([256, 512])
torch.Size([256])
torch.Size([128, 256])
torch.Size([128])
torch.Size([512, 62])
torch.Size([512])
torch.Size([256, 512])
torch.Size([256])
torch.Size([128, 256])
torch.Size([128])
torch.Size([512, 62])
torch.Size([512])
torch.Size([256, 512])
torch.Size([256])
torch.Size([128, 256])
torch.Size([128])
torch.Size([512, 62])
torch.Size([512])
torch.Size([256, 512])
torch.Size([256])
torch.Size([128, 256])
torch.Size([128])
torch.Size([512, 62])
torch.Size([512])
torch.Size([256, 512])
torch.Size([256])
torch.Size([128, 256])
torch.Size([128])
torch.Size([512, 62])
torch.Size([512])
torch.Size([256, 512])
torch.Size([256])
torch.Size([128, 256])
torch.Size([128])
torch.Size([512, 62])
torch.Size([512])
torch.Size([256, 512])
torch.Size([256])
torch.Size([128, 256])
torch.Size([128])
torch.Size([512, 62])
to

In [850]:
st = torch.rand(16,12,62)
st

tensor([[[0.2847, 0.9113, 0.0964,  ..., 0.1876, 0.4767, 0.4001],
         [0.1018, 0.7022, 0.9712,  ..., 0.7624, 0.3157, 0.5787],
         [0.2565, 0.5069, 0.9032,  ..., 0.7857, 0.7244, 0.2263],
         ...,
         [0.6802, 0.4281, 0.1326,  ..., 0.0887, 0.3073, 0.8021],
         [0.3660, 0.0319, 0.3379,  ..., 0.7866, 0.9094, 0.3821],
         [0.5083, 0.8241, 0.8084,  ..., 0.0287, 0.8379, 0.8938]],

        [[0.4168, 0.7862, 0.3454,  ..., 0.0983, 0.8514, 0.7829],
         [0.3767, 0.4316, 0.2368,  ..., 0.3526, 0.8806, 0.5231],
         [0.9336, 0.6520, 0.1217,  ..., 0.2220, 0.9908, 0.3037],
         ...,
         [0.7812, 0.5533, 0.1702,  ..., 0.4085, 0.2313, 0.0341],
         [0.8024, 0.1569, 0.9195,  ..., 0.6120, 0.5205, 0.9044],
         [0.9678, 0.4548, 0.7806,  ..., 0.4290, 0.3402, 0.7311]],

        [[0.6870, 0.9280, 0.4295,  ..., 0.6931, 0.3435, 0.3322],
         [0.8803, 0.5843, 0.6914,  ..., 0.1593, 0.0413, 0.3944],
         [0.9179, 0.0275, 0.4062,  ..., 0.6539, 0.0363, 0.

In [865]:
torch.

TypeError: stack(): argument 'tensors' (position 1) must be tuple of Tensors, not Tensor