In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from noisy_net import NoisyLinear, NoisyFactorizedLinear
from player import PlayerHelper

In [3]:
class BranchingQNetwork(nn.Module):
    def __init__(self, observation_space, action_space, action_bins, hidden_dim, exploration_method, architecture = "DQN"):
        super().__init__()
        self.exploration_method = exploration_method
        # if self.exploration_method == "Noisy":
        #     self.model = nn.Sequential(
        #         NoisyLinear(observation_space, hidden_dim*4),
        #         nn.ReLU(),
        #         NoisyLinear(hidden_dim*4, hidden_dim*2),
        #         nn.ReLU(),
        #         NoisyLinear(hidden_dim*2, hidden_dim),
        #         nn.ReLU()
        #     )
        #     self.value_head = NoisyLinear(hidden_dim, 1)
        #     self.adv_heads = nn.ModuleList(
        #         [NoisyLinear(hidden_dim, action_bins) for i in range(action_space)])
        # else:
        self.architecture = architecture
        self.model = nn.Sequential(
            nn.Linear(observation_space, hidden_dim*4),
            nn.ReLU(),
            nn.Linear(hidden_dim*4, hidden_dim*2),
            nn.ReLU(),
            nn.Linear(hidden_dim*2, hidden_dim),
            nn.ReLU()
        )
        if architecture == "Dueling":
            self.value_head = nn.Linear(hidden_dim, 1)
            self.adv_heads = nn.Linear(hidden_dim, action_bins)
        else:
            self.out = nn.Linear(hidden_dim, action_bins)
            
    def forward(self, x):
        first_layer = self.model(x)
        out = self.out(first_layer)
        q_val = out
        if self.architecture == "Dueling":
            value = self.value_head(out)
            advs = self.advs_heads(out)
            q_val = value + advs - advs.mean()
        # if value.shape[0] == 1:
        #     advs = torch.stack([l(out) for l in self.adv_heads], dim=0)
        #     q_val = value + advs - advs.mean(1, keepdim=True)
        # else:
        #     advs = torch.stack([l(out) for l in self.adv_heads], dim=1)
        #     q_val = value.unsqueeze(1) + advs - advs.mean(2, keepdim=True)
        return q_val



In [4]:
observation = [  64.,    0.,    0.,  500.,    0.,    0.,    1.,  100.,    8.,
      0.,    0.,  100.,    0.,    1.,    0.,  100.,    0.,    0.,
      0.,  -88.,    0.,    0.,    0., -100.,   16.,    0.,    0.,
    100.,   12.,    0.,    1., -100.,   16.,    0.,    0.,  -50.,
      7.,    1.,    0.,  -13.,    8.,    0.,    0., -500.,   32.,
      3.,    1.,   94.,    1.,    8.,   10.,    2.,   56.,    1.,
      7.,    8.,    0.,   93.,    0.,    8.,    6.,    1.,   58.,
      0.,    8.,    8.,    2.,    0.,    0.,    0.,    4.,    0.,
    100.,    0.,    8.,    7.,    1.,   90.,    0.,    8.,    2.,
      2.,  100.,    0.,    8.,    7.,    0.,  100.,    0.,    8.,
      2.,    1.,   15.,    0.,    7.,    9.,    2.,   70.,    0.,
      8.,    2.,    0.,   91.,    1.,   12.]

In [5]:
player_helper = PlayerHelper(7,1, "../config/DemoMap.json")

In [6]:
player_helper.legal_moves(observation)

array([False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False,  True,
       False, False,  True,  True, False,  True, False, False,  True,
       False, False,  True, False, False,  True, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False,  True, False,  True,  True, False, False,  True, False,
       False, False, False, False, False,  True,  True, False, False,
        True, False,  True,  True, False,  True,  True,  True, False,
        True, False, False, False, False, False, False, False, False,
        True,  True, False, False,  True, False,  True,  True, False,
        True,  True,  True, False,  True, False, False, False, False,
       False, False, False, False, False, False,  True,  True,  True,
        True,  True,  True, False, False, False, False, False, False,
       False, False,

In [7]:
legal_moves

NameError: name 'legal_moves' is not defined

In [8]:
actions = build_action_table().tolist()

NameError: name 'build_action_table' is not defined

In [9]:
actions.index([3,6])

NameError: name 'actions' is not defined

In [10]:
model= BranchingQNetwork(105, 7, 132, 128, "eps")

In [11]:
model

BranchingQNetwork(
  (model): Sequential(
    (0): Linear(in_features=105, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=256, bias=True)
    (3): ReLU()
    (4): Linear(in_features=256, out_features=128, bias=True)
    (5): ReLU()
  )
  (out): Linear(in_features=128, out_features=132, bias=True)
)

In [12]:
obs = torch.tensor(observation).float()

In [13]:
preds = model(obs)

In [14]:
legal = [i for i, x in enumerate(legal_moves) if x]

NameError: name 'legal_moves' is not defined

In [15]:
legal

NameError: name 'legal' is not defined

In [16]:
preds

tensor([ 6.6712, -2.1252,  1.9800, -8.5756, -1.6728,  3.4628,  8.2426,  2.2607,
        -3.0697,  5.8913,  2.9173, -0.6005, -0.8185,  2.5592,  2.0071, -0.4846,
        -1.6680,  0.7731, -2.6719, -2.6306,  2.3806,  0.2284,  0.1485, -2.6063,
        -0.6727, -1.0807, -5.9536, -3.1568,  0.1139, -2.5109, -0.6564, -1.1901,
         1.0462, -1.4271, -0.2501, -3.8880, -0.8491,  2.7232,  3.3158,  0.1621,
        -3.1153,  0.5378, -0.3714, -3.1410,  6.6479, -6.0671,  1.7127, -1.6592,
         9.7451, -0.7874, -0.4999, -1.2743, -3.1192,  0.3740,  0.7207,  2.2177,
         2.8771, -6.4476, -0.7109, -5.2483,  0.7600,  1.9684, -2.8296, -1.3688,
        -8.8673,  1.6873, -5.6360,  1.6794,  6.7481,  4.5762,  3.9550,  0.7512,
        -0.2830,  0.3564,  2.1869,  0.1820,  5.2436,  1.5922, -5.7179,  0.4404,
        -2.8175, -5.3680, -1.9342, -3.1608, -1.7578, -1.9718,  0.1256,  7.7262,
         0.9141, -3.7243,  3.4835,  1.7513, -1.8877,  0.5901, -1.7303,  2.0544,
         2.7319, -5.2023, -2.6598, -5.31

In [17]:
legal_pred = preds.gather(0, torch.LongTensor(legal))

NameError: name 'legal' is not defined

In [18]:
preds[118]

tensor(-3.7390, grad_fn=<SelectBackward>)

In [19]:
legal_pred

NameError: name 'legal_pred' is not defined

In [20]:
act_q = legal_pred.sort(descending=True)

NameError: name 'legal_pred' is not defined

In [21]:
act_q

NameError: name 'act_q' is not defined

In [22]:
chosen_q = act_q.values[:7]
chosen_q_idx = act_q.indices[:7]

NameError: name 'act_q' is not defined

In [23]:
chosen_q

NameError: name 'chosen_q' is not defined

In [24]:
chosen_q_idx

NameError: name 'chosen_q_idx' is not defined

In [25]:
legal = torch.LongTensor(legal)

NameError: name 'legal' is not defined

In [26]:
legal

NameError: name 'legal' is not defined

In [27]:
legal.gather(0, chosen_q_idx)

NameError: name 'legal' is not defined

In [28]:
torch.cuda.init()
device = torch.device(
    "cuda:0" if torch.cuda.is_available() else "cpu")
torch.autograd.set_detect_anomaly(True)

<torch.autograd.anomaly_mode.set_detect_anomaly at 0x1b126cf63a0>

In [29]:
player_helper.get_action_choices(132)

ValueError: setting an array element with a sequence.

In [30]:
l = np.array([[2,5],[3,7],[1,5],[3,12]])
k = np.array([[3,7],[3,12]])
a = []

for i in k:
    a.append(np.where(l == i))
a

[(array([1, 1, 3], dtype=int64), array([0, 1, 0], dtype=int64)),
 (array([1, 3, 3], dtype=int64), array([0, 0, 1], dtype=int64))]

In [31]:
from utils import build_action_table

In [32]:
actions = build_action_table()

In [33]:
actions

array([[ 0.,  1.],
       [ 0.,  2.],
       [ 0.,  3.],
       [ 0.,  4.],
       [ 0.,  5.],
       [ 0.,  6.],
       [ 0.,  7.],
       [ 0.,  8.],
       [ 0.,  9.],
       [ 0., 10.],
       [ 0., 11.],
       [ 1.,  1.],
       [ 1.,  2.],
       [ 1.,  3.],
       [ 1.,  4.],
       [ 1.,  5.],
       [ 1.,  6.],
       [ 1.,  7.],
       [ 1.,  8.],
       [ 1.,  9.],
       [ 1., 10.],
       [ 1., 11.],
       [ 2.,  1.],
       [ 2.,  2.],
       [ 2.,  3.],
       [ 2.,  4.],
       [ 2.,  5.],
       [ 2.,  6.],
       [ 2.,  7.],
       [ 2.,  8.],
       [ 2.,  9.],
       [ 2., 10.],
       [ 2., 11.],
       [ 3.,  1.],
       [ 3.,  2.],
       [ 3.,  3.],
       [ 3.,  4.],
       [ 3.,  5.],
       [ 3.,  6.],
       [ 3.,  7.],
       [ 3.,  8.],
       [ 3.,  9.],
       [ 3., 10.],
       [ 3., 11.],
       [ 4.,  1.],
       [ 4.,  2.],
       [ 4.,  3.],
       [ 4.,  4.],
       [ 4.,  5.],
       [ 4.,  6.],
       [ 4.,  7.],
       [ 4.,  8.],
       [ 4.,

In [34]:
action_idx = legal.gather(0, chosen_q_idx)

NameError: name 'legal' is not defined

In [None]:
action_idx.numpy()

In [None]:
np.take(actions, action_idx, 0)

In [None]:
actions[119]

In [None]:
q = torch.rand(16, 132)

In [None]:
q

In [None]:
torch.rand(16,7,2, dtype=torch.long)

In [None]:
g = torch.tensor([[7,5,1,3,4,6,4], [4,5,4,3,2,1,2]])

In [None]:
t = torch.rand(2,132)

In [None]:
t

In [None]:
t.gather(1, g)

In [None]:
t

In [None]:
y = t.sort(1, descending=True)

In [None]:
nextq = []
for x in y.values:
    nextq.append(x[:7])
nextq

In [None]:
torch.stack(nextq)

In [None]:
legal = legal.numpy()

In [None]:
import random

In [None]:
legal = random.sample(legal, 7)

In [None]:
legal

In [None]:
actions

In [None]:
np.take(actions, legal, 0)

In [None]:
actions[32]

In [None]:
predict = torch.tensor([ 0.0000,  2.2634,  0.0000,  0.0000, 14.5552,  0.0000,  4.0787, 22.1282,
         8.6315,  0.0000,  0.0000,  2.2967,  0.0000, 11.2417,  0.0000,  0.0000,
         2.6133,  0.0000,  0.0000,  0.0000,  0.0000, 10.4586,  5.7030,  0.0000,
         0.0000, 12.8066,  0.0000,  0.0000,  0.0000,  3.9369,  0.0000,  2.3094,
         0.0000,  2.2453,  0.0000, 10.3566,  0.0000,  0.0000,  0.0000,  8.4870,
         0.0000,  0.0000,  1.9676,  7.8532,  4.9158, 12.1242,  0.0000,  0.0000,
         5.5951,  0.0000, 11.2550,  0.0000, 18.0321,  0.0000,  0.0000,  0.0000,
         7.2463,  0.0000,  0.0000,  0.1663,  0.0000,  0.0000,  0.0000,  4.2565,
         0.0000,  8.6758,  7.2351,  0.0000, 13.7532,  0.5862,  0.0000, 10.1941,
         0.0000,  8.6476,  7.2139,  2.8377,  0.0000,  0.0000,  3.3754,  0.0000,
         1.9132,  0.5925,  5.3133,  0.0000,  4.2074,  0.0000,  3.5816,  0.0000,
         1.8329, 18.0533,  4.7783,  0.0000,  8.2180,  8.0269,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  2.1318,  7.4333,  0.0000,  5.6233, 10.2923,  7.0338,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  7.2451,  0.3393,
         0.0000, 14.0444,  0.2286,  0.0000,  0.0000,  0.0000,  4.0939,  0.0000],
       device='cuda:0')

In [None]:
legal_id = torch.tensor([  2,   5,   8,  35,  36,  39,  41,  42,  55,  56,  57,  59,  70,  71,
         72,  73,  74,  75,  81,  84,  85,  87,  88,  90,  91,  94, 123, 126,
        129], device='cuda:0')

In [35]:
predict.gather(0, legal_id)

NameError: name 'predict' is not defined

In [36]:
torch.rand(10).unsqueeze(1)

tensor([[0.6771],
        [0.6634],
        [0.8403],
        [0.2491],
        [0.7572],
        [0.3958],
        [0.1591],
        [0.2635],
        [0.7663],
        [0.3702]])

In [37]:
actions[45]

array([4., 2.])

In [38]:
actions[47]

array([4., 4.])

In [39]:
actions[65]

array([ 5., 11.])

In [40]:
actions[70]

array([6., 5.])

In [41]:
legal_moves(observation)

NameError: name 'legal_moves' is not defined

In [42]:
actions = [[68, 16, 121, 3, 41, 105, 89], [22, 23, 25, -1, -1, -1, -1], [18, 56, 112, 6, 33, 30, 98], [7, 63, 126, 88, 35, 108, 48], [38, 103, 120, 22, 46, 91, 122], [19, 68, 81, 98, 40, 116, 130], [69, 41, 99, 116, 4, 49, 32], [111, 29, 15, 68, 107, 91, 5], [13, 38, 102, 0, 30, 61, 131], [90, 37, 78, 130, 14, 72, 32]]

In [43]:
preds

tensor([ 6.6712, -2.1252,  1.9800, -8.5756, -1.6728,  3.4628,  8.2426,  2.2607,
        -3.0697,  5.8913,  2.9173, -0.6005, -0.8185,  2.5592,  2.0071, -0.4846,
        -1.6680,  0.7731, -2.6719, -2.6306,  2.3806,  0.2284,  0.1485, -2.6063,
        -0.6727, -1.0807, -5.9536, -3.1568,  0.1139, -2.5109, -0.6564, -1.1901,
         1.0462, -1.4271, -0.2501, -3.8880, -0.8491,  2.7232,  3.3158,  0.1621,
        -3.1153,  0.5378, -0.3714, -3.1410,  6.6479, -6.0671,  1.7127, -1.6592,
         9.7451, -0.7874, -0.4999, -1.2743, -3.1192,  0.3740,  0.7207,  2.2177,
         2.8771, -6.4476, -0.7109, -5.2483,  0.7600,  1.9684, -2.8296, -1.3688,
        -8.8673,  1.6873, -5.6360,  1.6794,  6.7481,  4.5762,  3.9550,  0.7512,
        -0.2830,  0.3564,  2.1869,  0.1820,  5.2436,  1.5922, -5.7179,  0.4404,
        -2.8175, -5.3680, -1.9342, -3.1608, -1.7578, -1.9718,  0.1256,  7.7262,
         0.9141, -3.7243,  3.4835,  1.7513, -1.8877,  0.5901, -1.7303,  2.0544,
         2.7319, -5.2023, -2.6598, -5.31

In [44]:
preds = torch.rand(10,132)

In [45]:
new_current_Q = []
for idx, index_group in enumerate(actions):
    x = []
    for i in index_group:
        if i > -1:
            x.append(preds[idx][i])
        else:
            x.append(torch.tensor(0))
    new_current_Q.append(torch.stack(x))
new_current_Q

    

[tensor([0.3406, 0.7215, 0.7992, 0.3807, 0.4082, 0.5046, 0.3035]),
 tensor([0.3836, 0.2497, 0.3342, 0.0000, 0.0000, 0.0000, 0.0000]),
 tensor([0.0796, 0.6635, 0.0091, 0.4434, 0.7191, 0.1378, 0.9796]),
 tensor([0.3331, 0.2090, 0.2295, 0.8775, 0.7449, 0.0631, 0.9682]),
 tensor([0.8528, 0.8580, 0.8564, 0.7082, 0.9093, 0.4903, 0.2400]),
 tensor([0.1370, 0.9587, 0.0673, 0.6383, 0.7883, 0.6530, 0.6766]),
 tensor([0.4813, 0.4954, 0.4905, 0.2674, 0.0587, 0.9321, 0.6382]),
 tensor([0.1798, 0.9170, 0.4409, 0.7148, 0.1730, 0.8328, 0.3747]),
 tensor([0.8291, 0.0076, 0.5517, 0.3215, 0.0893, 0.9856, 0.5740]),
 tensor([0.1630, 0.1990, 0.9798, 0.6894, 0.0028, 0.6563, 0.9425])]

In [46]:
obs = torch.rand(10,105)

In [47]:
preds = model(obs)

In [48]:
preds.shape

torch.Size([10, 132])

In [49]:
new_current_Q = []
for idx, index_group in enumerate(actions):
    x = []
    for i in index_group:
        if i > -1:
            x.append(preds[idx][i])
        else:
            x.append(torch.tensor(0))
    new_current_Q.append(torch.stack(x))
new_current_Q = torch.stack(new_current_Q)
new_current_Q

tensor([[-0.0257, -0.0249, -0.0507, -0.0445, -0.0090, -0.0403,  0.0145],
        [-0.0699,  0.0003, -0.0562,  0.0000,  0.0000,  0.0000,  0.0000],
        [-0.0801, -0.0088, -0.0073,  0.0954,  0.0134,  0.0679, -0.0829],
        [ 0.0807,  0.0630, -0.0118,  0.0202, -0.1011, -0.0443, -0.0605],
        [ 0.0118, -0.1081, -0.0657, -0.0531,  0.0321, -0.0507,  0.0229],
        [-0.0601, -0.0259,  0.0825, -0.1044, -0.0894, -0.1092, -0.1278],
        [ 0.1076, -0.0054,  0.0495, -0.1270, -0.0167, -0.0034,  0.0624],
        [ 0.0470, -0.0101,  0.0081, -0.0373,  0.0103, -0.0480, -0.0738],
        [ 0.0571,  0.0045, -0.0602, -0.0172,  0.0610, -0.0025,  0.0574],
        [-0.0266,  0.0293,  0.0080, -0.1144, -0.0080,  0.1253,  0.0655]],
       grad_fn=<StackBackward>)

In [50]:
new_current_Q[1][3]

tensor(0., grad_fn=<SelectBackward>)

In [51]:
a = torch.tensor([4,5,7])
b = torch.tensor([0])
for i in range(2):
    a = torch.cat((a,b))
a

tensor([4, 5, 7, 0, 0])

In [52]:
list1 = np.array([])

In [53]:
list1.shape

(0,)

In [54]:
list1.reshape(0,2)

array([], shape=(0, 2), dtype=float64)

In [55]:
observation = [  64.,    0.,    0.,  500.,    0.,    0.,    1.,  100.,    8.,
      0.,    0.,  100.,    0.,    1.,    0.,  100.,    0.,    0.,
      0.,  -88.,    0.,    0.,    0., -100.,   16.,    0.,    0.,
    100.,   12.,    0.,    1., -100.,   16.,    0.,    0.,  -50.,
      7.,    1.,    0.,  -13.,    8.,    0.,    0., -500.,   32.,
      3.,    1.,   94.,    1.,    8.,   10.,    2.,   56.,    1.,
      7.,    8.,    0.,   93.,    0.,    8.,    6.,    1.,   58.,
      0.,    8.,    8.,    2.,    0.,    0.,    0.,    4.,    0.,
    100.,    0.,    8.,    7.,    1.,   90.,    0.,    8.,    2.,
      2.,  100.,    0.,    8.,    7.,    0.,  100.,    0.,    8.,
      2.,    1.,   15.,    0.,    7.,    9.,    2.,   70.,    0.,
      8.,    2.,    0.,   91.,    1.,   12.]

In [56]:
player_helper = PlayerHelper(7,1, "../config/DemoMap.json")

In [57]:
legal_obs = player_helper.legal_moves(observation).reshape(12,11)

In [58]:
from utils import build_action_table

In [59]:
table = build_action_table()

In [60]:
table = table.reshape(12,11,2)

In [61]:
legal_obs

array([[False, False, False, False, False, False, False, False, False,
        False, False],
       [False, False, False, False, False, False, False, False, False,
        False, False],
       [False, False, False, False,  True, False, False,  True,  True,
        False,  True],
       [False, False,  True, False, False,  True, False, False,  True,
        False, False],
       [False, False, False, False, False, False, False, False, False,
        False, False],
       [ True, False,  True,  True, False, False,  True, False, False,
        False, False],
       [False, False,  True,  True, False, False,  True, False,  True,
         True, False],
       [ True,  True,  True, False,  True, False, False, False, False,
        False, False],
       [False, False,  True,  True, False, False,  True, False,  True,
         True, False],
       [ True,  True,  True, False,  True, False, False, False, False,
        False, False],
       [False, False, False, False,  True,  True,  True,  Tr

In [62]:
table

array([[[ 0.,  1.],
        [ 0.,  2.],
        [ 0.,  3.],
        [ 0.,  4.],
        [ 0.,  5.],
        [ 0.,  6.],
        [ 0.,  7.],
        [ 0.,  8.],
        [ 0.,  9.],
        [ 0., 10.],
        [ 0., 11.]],

       [[ 1.,  1.],
        [ 1.,  2.],
        [ 1.,  3.],
        [ 1.,  4.],
        [ 1.,  5.],
        [ 1.,  6.],
        [ 1.,  7.],
        [ 1.,  8.],
        [ 1.,  9.],
        [ 1., 10.],
        [ 1., 11.]],

       [[ 2.,  1.],
        [ 2.,  2.],
        [ 2.,  3.],
        [ 2.,  4.],
        [ 2.,  5.],
        [ 2.,  6.],
        [ 2.,  7.],
        [ 2.,  8.],
        [ 2.,  9.],
        [ 2., 10.],
        [ 2., 11.]],

       [[ 3.,  1.],
        [ 3.,  2.],
        [ 3.,  3.],
        [ 3.,  4.],
        [ 3.,  5.],
        [ 3.,  6.],
        [ 3.,  7.],
        [ 3.,  8.],
        [ 3.,  9.],
        [ 3., 10.],
        [ 3., 11.]],

       [[ 4.,  1.],
        [ 4.,  2.],
        [ 4.,  3.],
        [ 4.,  4.],
        [ 4.,  5.],
        [ 4.

In [63]:
model= BranchingQNetwork(105, 7, 132, 128, "eps")

In [64]:
obs = torch.tensor(observation).float()

In [65]:
obs

tensor([  64.,    0.,    0.,  500.,    0.,    0.,    1.,  100.,    8.,    0.,
           0.,  100.,    0.,    1.,    0.,  100.,    0.,    0.,    0.,  -88.,
           0.,    0.,    0., -100.,   16.,    0.,    0.,  100.,   12.,    0.,
           1., -100.,   16.,    0.,    0.,  -50.,    7.,    1.,    0.,  -13.,
           8.,    0.,    0., -500.,   32.,    3.,    1.,   94.,    1.,    8.,
          10.,    2.,   56.,    1.,    7.,    8.,    0.,   93.,    0.,    8.,
           6.,    1.,   58.,    0.,    8.,    8.,    2.,    0.,    0.,    0.,
           4.,    0.,  100.,    0.,    8.,    7.,    1.,   90.,    0.,    8.,
           2.,    2.,  100.,    0.,    8.,    7.,    0.,  100.,    0.,    8.,
           2.,    1.,   15.,    0.,    7.,    9.,    2.,   70.,    0.,    8.,
           2.,    0.,   91.,    1.,   12.])

In [66]:
preds = model(obs)

In [67]:
preds

tensor([-1.1966, -2.6933,  0.6060,  1.6348,  3.6760, -0.6401, -2.4070,  4.0443,
        -1.3758,  3.6266, -2.8390, -3.3912, -0.4062, -0.3077,  3.9133, -4.9548,
        -1.9529,  0.3084,  3.9895, -1.1424, -6.4819,  0.1007,  0.3193, -4.5960,
        -1.8689, -1.6001, -2.2641, -1.5221,  5.1791, -3.8906, -5.1401,  2.5782,
        -1.2943,  0.2346,  0.1916,  5.0937,  5.8389, -0.2643, -2.3624, -5.4610,
        -0.0470, -5.9185,  4.4464, -3.9921, -1.4261, -0.6219,  1.0511,  3.0031,
         3.4777, -1.4364,  0.3361,  0.4129,  0.3662, -3.9045, -4.2403, -7.7173,
        -4.4181,  0.4239, -3.1763,  0.7257, -1.4770,  3.3168,  1.3692, -5.0368,
        -3.4762,  5.1628,  5.0945, -0.3278, -4.9620,  0.4313, -1.1806,  6.5422,
         5.6312, -1.7364,  0.0743,  3.1601, -0.0183,  2.1697, -0.0870,  6.1700,
         0.0304,  0.1304,  0.0123,  1.6201,  0.8869, -0.7071, -2.0701,  0.4372,
         3.1688,  5.6347, -0.8657, -2.3009,  3.4616,  6.3314,  4.1655,  1.8888,
        -2.1762, -0.0448, -4.8028, -1.31

In [68]:
preds = preds.reshape(12, 11)

In [69]:
preds

tensor([[-1.1966, -2.6933,  0.6060,  1.6348,  3.6760, -0.6401, -2.4070,  4.0443,
         -1.3758,  3.6266, -2.8390],
        [-3.3912, -0.4062, -0.3077,  3.9133, -4.9548, -1.9529,  0.3084,  3.9895,
         -1.1424, -6.4819,  0.1007],
        [ 0.3193, -4.5960, -1.8689, -1.6001, -2.2641, -1.5221,  5.1791, -3.8906,
         -5.1401,  2.5782, -1.2943],
        [ 0.2346,  0.1916,  5.0937,  5.8389, -0.2643, -2.3624, -5.4610, -0.0470,
         -5.9185,  4.4464, -3.9921],
        [-1.4261, -0.6219,  1.0511,  3.0031,  3.4777, -1.4364,  0.3361,  0.4129,
          0.3662, -3.9045, -4.2403],
        [-7.7173, -4.4181,  0.4239, -3.1763,  0.7257, -1.4770,  3.3168,  1.3692,
         -5.0368, -3.4762,  5.1628],
        [ 5.0945, -0.3278, -4.9620,  0.4313, -1.1806,  6.5422,  5.6312, -1.7364,
          0.0743,  3.1601, -0.0183],
        [ 2.1697, -0.0870,  6.1700,  0.0304,  0.1304,  0.0123,  1.6201,  0.8869,
         -0.7071, -2.0701,  0.4372],
        [ 3.1688,  5.6347, -0.8657, -2.3009,  3.4616,  6

In [70]:
legal_obs


array([[False, False, False, False, False, False, False, False, False,
        False, False],
       [False, False, False, False, False, False, False, False, False,
        False, False],
       [False, False, False, False,  True, False, False,  True,  True,
        False,  True],
       [False, False,  True, False, False,  True, False, False,  True,
        False, False],
       [False, False, False, False, False, False, False, False, False,
        False, False],
       [ True, False,  True,  True, False, False,  True, False, False,
        False, False],
       [False, False,  True,  True, False, False,  True, False,  True,
         True, False],
       [ True,  True,  True, False,  True, False, False, False, False,
        False, False],
       [False, False,  True,  True, False, False,  True, False,  True,
         True, False],
       [ True,  True,  True, False,  True, False, False, False, False,
        False, False],
       [False, False, False, False,  True,  True,  True,  Tr

In [74]:
for i in range(len(legal_obs)):
    for j in range(len(legal_obs[0])):
        if legal_obs[i][j] == False:
            preds[i][j] = -math.inf

In [73]:
import math

In [75]:
preds

tensor([[   -inf,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,
            -inf,    -inf,    -inf],
        [   -inf,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,
            -inf,    -inf,    -inf],
        [   -inf,    -inf,    -inf,    -inf, -2.2641,    -inf,    -inf, -3.8906,
         -5.1401,    -inf, -1.2943],
        [   -inf,    -inf,  5.0937,    -inf,    -inf, -2.3624,    -inf,    -inf,
         -5.9185,    -inf,    -inf],
        [   -inf,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,
            -inf,    -inf,    -inf],
        [-7.7173,    -inf,  0.4239, -3.1763,    -inf,    -inf,  3.3168,    -inf,
            -inf,    -inf,    -inf],
        [   -inf,    -inf, -4.9620,  0.4313,    -inf,    -inf,  5.6312,    -inf,
          0.0743,  3.1601,    -inf],
        [ 2.1697, -0.0870,  6.1700,    -inf,  0.1304,    -inf,    -inf,    -inf,
            -inf,    -inf,    -inf],
        [   -inf,    -inf, -0.8657, -2.3009,    -inf,   

In [80]:
preds.max(1)

torch.return_types.max(
values=tensor([   -inf,    -inf, -1.2943,  5.0937,    -inf,  3.3168,  5.6312,  6.1700,
         4.1655, -1.3178,  4.9609,    -inf], grad_fn=<MaxBackward0>),
indices=tensor([ 0,  0, 10,  2,  0,  6,  6,  2,  6,  0,  5,  0]))

In [82]:
actions=preds.max(1)

In [91]:
actions_max = actions.values.sort(descending=True)

In [92]:
actions_max_values = actions_max.values[:7]

In [93]:
actions_max_idx = actions_max.indices[:7]

In [102]:
location = actions.indices.gather(0, actions_max_idx)

In [96]:
actions_max

torch.return_types.sort(
values=tensor([ 6.1700,  5.6312,  5.0937,  4.9609,  4.1655,  3.3168, -1.2943, -1.3178,
           -inf,    -inf,    -inf,    -inf], grad_fn=<SortBackward>),
indices=tensor([ 7,  6,  3, 10,  8,  5,  2,  9,  0,  1,  4, 11]))

In [103]:
group = actions_max_idx

In [106]:
location = location + 1

In [118]:
action = [[g, l] for g, l in zip(group.numpy(), location.numpy())]

In [119]:
action

[[7, 3], [6, 7], [3, 3], [10, 6], [8, 7], [5, 7], [2, 11]]

In [113]:
group 


tensor([ 7,  6,  3, 10,  8,  5,  2])

In [114]:
location

tensor([ 3,  7,  3,  6,  7,  7, 11])

In [120]:
infinite = -math.inf

In [152]:
ex = torch.tensor([6.1700,  5.6312,  infinite,  infinite,  4.1655, infinite, infinite])

In [137]:
re = torch.tensor([0,-1,1,0,1,0,-1])

In [139]:
re + ex * 0.92

tensor([5.6764, 4.1807, 5.6862, 4.5640, 4.8323,   -inf,   -inf])

In [143]:
action

[[7, 3], [6, 7], [3, 3], [10, 6], [8, 7], [5, 7], [2, 11]]

In [153]:
for i in range(len(action)):
    if actions_max_values[i] == -math.inf:
        action[i] = [0,0]

In [154]:
action

[[7, 3], [6, 7], [0, 0], [0, 0], [8, 7], [0, 0], [0, 0]]

In [155]:
ex.mean()

tensor(-inf)

In [156]:
action = torch.tensor(action)

In [157]:
action

tensor([[7, 3],
        [6, 7],
        [0, 0],
        [0, 0],
        [8, 7],
        [0, 0],
        [0, 0]])

In [159]:
action.transpose(0,1)

tensor([[7, 6, 0, 0, 8, 0, 0],
        [3, 7, 0, 0, 7, 0, 0]])

In [160]:
from utils import build_action_table

In [161]:
table = build_action_table()

In [165]:
table[94]

array([8., 7.])

In [166]:
action

tensor([[7, 3],
        [6, 7],
        [0, 0],
        [0, 0],
        [8, 7],
        [0, 0],
        [0, 0]])