In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from noisy_net import NoisyLinear, NoisyFactorizedLinear
from player import PlayerHelper

In [8]:
class BranchingQNetwork(nn.Module):
    """Fully connected Q-network with an optional dueling output.

    A three-layer ReLU trunk maps an observation to ``hidden_dim`` features.
    With ``architecture == "Dueling"`` the features feed a scalar value head
    and an advantage head and are combined as ``Q = V + A - mean(A)``; with
    any other value a single linear layer produces the Q-values directly.

    Args:
        observation_space: Size of the last dimension of the input.
        action_space: Accepted for interface compatibility; not used by the
            layers built here.
        action_bins: Number of Q-values produced per observation.
        hidden_dim: Width of the last trunk layer (earlier layers use 4x and
            2x this width).
        exploration_method: Stored on the module; it does not change the
            layers, which are always plain ``nn.Linear``.
        architecture: ``"Dueling"`` selects the value/advantage heads;
            anything else selects the single output layer.
    """

    def __init__(self, observation_space, action_space, action_bins, hidden_dim, exploration_method, architecture = "DQN"):
        super().__init__()
        self.exploration_method = exploration_method
        self.architecture = architecture
        self.model = nn.Sequential(
            nn.Linear(observation_space, hidden_dim*4),
            nn.ReLU(),
            nn.Linear(hidden_dim*4, hidden_dim*2),
            nn.ReLU(),
            nn.Linear(hidden_dim*2, hidden_dim),
            nn.ReLU()
        )
        if architecture == "Dueling":
            self.value_head = nn.Linear(hidden_dim, 1)
            self.adv_heads = nn.Linear(hidden_dim, action_bins)
        else:
            self.out = nn.Linear(hidden_dim, action_bins)

    def forward(self, x):
        """Return Q-values of shape ``(..., action_bins)`` for input ``x``.

        ``x`` may be a single observation or a batch; only its last dimension
        has to equal ``observation_space``.
        """
        features = self.model(x)
        if self.architecture == "Dueling":
            # Both heads read the trunk features; `self.out` does not exist in
            # this mode.
            value = self.value_head(features)
            advs = self.adv_heads(features)
            # Centre the advantages per sample (last dim) so a batch does not
            # mix the advantages of different observations.
            return value + advs - advs.mean(dim=-1, keepdim=True)
        return self.out(features)



In [9]:
observation = [  64.,    0.,    0.,  500.,    0.,    0.,    1.,  100.,    8.,
      0.,    0.,  100.,    0.,    1.,    0.,  100.,    0.,    0.,
      0.,  -88.,    0.,    0.,    0., -100.,   16.,    0.,    0.,
    100.,   12.,    0.,    1., -100.,   16.,    0.,    0.,  -50.,
      7.,    1.,    0.,  -13.,    8.,    0.,    0., -500.,   32.,
      3.,    1.,   94.,    1.,    8.,   10.,    2.,   56.,    1.,
      7.,    8.,    0.,   93.,    0.,    8.,    6.,    1.,   58.,
      0.,    8.,    8.,    2.,    0.,    0.,    0.,    4.,    0.,
    100.,    0.,    8.,    7.,    1.,   90.,    0.,    8.,    2.,
      2.,  100.,    0.,    8.,    7.,    0.,  100.,    0.,    8.,
      2.,    1.,   15.,    0.,    7.,    9.,    2.,   70.,    0.,
      8.,    2.,    0.,   91.,    1.,   12.]

In [10]:
player_helper = PlayerHelper(7,1, "../config/DemoMap.json")

In [11]:
legal_moves = player_helper.legal_moves(observation)

In [12]:
legal_moves

array([False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False,  True,
       False, False,  True,  True, False,  True, False, False,  True,
       False, False,  True, False, False,  True, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False,  True, False,  True,  True, False, False,  True, False,
       False, False, False, False, False,  True,  True, False, False,
        True, False,  True,  True, False,  True,  True,  True, False,
        True, False, False, False, False, False, False, False, False,
        True,  True, False, False,  True, False,  True,  True, False,
        True,  True,  True, False,  True, False, False, False, False,
       False, False, False, False, False, False,  True,  True,  True,
        True,  True,  True, False, False, False, False, False, False,
       False, False,

In [13]:
model= BranchingQNetwork(105, 7, 132, 128, "eps")

In [14]:
model

BranchingQNetwork(
  (model): Sequential(
    (0): Linear(in_features=105, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=256, bias=True)
    (3): ReLU()
    (4): Linear(in_features=256, out_features=128, bias=True)
    (5): ReLU()
  )
  (out): Linear(in_features=128, out_features=132, bias=True)
)

In [15]:
obs = torch.tensor(observation).float()

In [16]:
preds = model(obs)

In [127]:
legal = [i for i, x in enumerate(legal_moves) if x]

In [128]:
legal

[26,
 29,
 30,
 32,
 35,
 38,
 41,
 55,
 57,
 58,
 61,
 68,
 69,
 72,
 74,
 75,
 77,
 78,
 79,
 81,
 90,
 91,
 94,
 96,
 97,
 99,
 100,
 101,
 103,
 114,
 115,
 116,
 117,
 118,
 119]

In [19]:
preds

tensor([ 0.7539,  2.1589, -0.4548,  0.2044,  1.6336,  0.3973, -3.4324, -1.2586,
        -0.2446,  1.3898,  6.5970,  0.9921,  2.6295,  0.1013,  1.9477,  1.0572,
        -2.1258, -2.5155, -1.0820, -2.6839,  0.3637, -0.0635, -2.8531, -3.8592,
        -5.1137, -4.7280, -3.6250,  6.9713,  0.0910,  0.2531,  3.2860,  4.6918,
         3.9574,  1.8980,  2.8215,  1.4723, -2.8976, -4.6101,  2.5463, -2.0740,
         2.8114, -0.8338,  5.6938,  0.2623, -0.8113,  2.7417,  1.3199, -6.5692,
         2.4388,  2.6458, -2.4127,  4.7794,  2.7811,  4.1658, -2.6666, -0.7878,
        -3.8838,  0.6434,  0.9333, -1.7638, -0.7097, -1.2664,  3.9241, -1.6029,
         1.3296,  0.0436,  1.9175,  1.3615, -5.7806,  3.4057, -1.1605, -4.2725,
        -2.6051,  3.5573,  1.1878,  4.9655, -2.5979,  4.6477,  4.1446, -0.8245,
         0.9166,  0.3600,  1.5055, -1.0515,  2.1722,  5.9279,  7.2527, -1.9157,
         3.7786, 10.2110,  5.8694, -4.6711, -0.3526, -2.3502,  2.9701, -1.1664,
         1.1294, -0.6310, -1.8825,  1.24

In [20]:
legal_pred = preds.gather(0, torch.LongTensor(legal))

In [21]:
preds[118]

tensor(-4.9212, grad_fn=<SelectBackward>)

In [22]:
legal_pred

tensor([-3.6250,  0.2531,  3.2860,  3.9574,  1.4723,  2.5463, -0.8338, -0.7878,
         0.6434,  0.9333, -1.2664, -5.7806,  3.4057, -2.6051,  1.1878,  4.9655,
         4.6477,  4.1446, -0.8245,  0.3600,  5.8694, -4.6711,  2.9701,  1.1294,
        -0.6310,  1.2451, -1.6302, -0.3674,  0.4117,  2.5364,  0.2867, -5.0402,
         4.9207, -4.9212,  3.6582], grad_fn=<GatherBackward>)

In [23]:
act_q = legal_pred.sort(descending=True)

In [24]:
act_q

torch.return_types.sort(
values=tensor([ 5.8694,  4.9655,  4.9207,  4.6477,  4.1446,  3.9574,  3.6582,  3.4057,
         3.2860,  2.9701,  2.5463,  2.5364,  1.4723,  1.2451,  1.1878,  1.1294,
         0.9333,  0.6434,  0.4117,  0.3600,  0.2867,  0.2531, -0.3674, -0.6310,
        -0.7878, -0.8245, -0.8338, -1.2664, -1.6302, -2.6051, -3.6250, -4.6711,
        -4.9212, -5.0402, -5.7806], grad_fn=<SortBackward>),
indices=tensor([20, 15, 32, 16, 17,  3, 34, 12,  2, 22,  5, 29,  4, 25, 14, 23,  9,  8,
        28, 19, 30,  1, 27, 24,  7, 18,  6, 10, 26, 13,  0, 21, 33, 31, 11]))

In [25]:
chosen_q = act_q.values[:7]
chosen_q_idx = act_q.indices[:7]

In [26]:
chosen_q

tensor([5.8694, 4.9655, 4.9207, 4.6477, 4.1446, 3.9574, 3.6582],
       grad_fn=<SliceBackward>)

In [27]:
chosen_q_idx

tensor([20, 15, 32, 16, 17,  3, 34])

In [28]:
legal = torch.LongTensor(legal)

In [29]:
legal

tensor([ 26,  29,  30,  32,  35,  38,  41,  55,  57,  58,  61,  68,  69,  72,
         74,  75,  77,  78,  79,  81,  90,  91,  94,  96,  97,  99, 100, 101,
        103, 114, 115, 116, 117, 118, 119])

In [30]:
legal.gather(0, chosen_q_idx)

tensor([ 90,  75, 117,  77,  78,  32, 119])

In [31]:
# Initialise CUDA only when it is actually usable: torch.cuda.init() raises on
# machines without a working CUDA setup, which would defeat the CPU fallback
# chosen on the next statement.
if torch.cuda.is_available():
    torch.cuda.init()
device = torch.device(
    "cuda:0" if torch.cuda.is_available() else "cpu")
# Debugging aid: turn on autograd anomaly detection for the rest of the session.
torch.autograd.set_detect_anomaly(True)

<torch.autograd.anomaly_mode.set_detect_anomaly at 0x1344f2ec820>

In [36]:
player_helper.get_action_choices(132)

ValueError: setting an array element with a sequence.

In [140]:
# Rows to search in, and the rows to look up.
l = np.array([[2, 5], [3, 7], [1, 5], [3, 12]])
k = np.array([[3, 7], [3, 12]])

# Each row of k is broadcast against l; np.where then returns the (row, column)
# index arrays of every *elementwise* match, not only of fully equal rows.
a = [np.where(l == row) for row in k]
a

[(array([1, 1, 3], dtype=int64), array([0, 1, 0], dtype=int64)),
 (array([1, 3, 3], dtype=int64), array([0, 0, 1], dtype=int64))]

In [41]:
from utils import build_action_table

In [53]:
actions = build_action_table()

In [54]:
actions

array([[ 0.,  1.],
       [ 0.,  2.],
       [ 0.,  3.],
       [ 0.,  4.],
       [ 0.,  5.],
       [ 0.,  6.],
       [ 0.,  7.],
       [ 0.,  8.],
       [ 0.,  9.],
       [ 0., 10.],
       [ 0., 11.],
       [ 1.,  1.],
       [ 1.,  2.],
       [ 1.,  3.],
       [ 1.,  4.],
       [ 1.,  5.],
       [ 1.,  6.],
       [ 1.,  7.],
       [ 1.,  8.],
       [ 1.,  9.],
       [ 1., 10.],
       [ 1., 11.],
       [ 2.,  1.],
       [ 2.,  2.],
       [ 2.,  3.],
       [ 2.,  4.],
       [ 2.,  5.],
       [ 2.,  6.],
       [ 2.,  7.],
       [ 2.,  8.],
       [ 2.,  9.],
       [ 2., 10.],
       [ 2., 11.],
       [ 3.,  1.],
       [ 3.,  2.],
       [ 3.,  3.],
       [ 3.,  4.],
       [ 3.,  5.],
       [ 3.,  6.],
       [ 3.,  7.],
       [ 3.,  8.],
       [ 3.,  9.],
       [ 3., 10.],
       [ 3., 11.],
       [ 4.,  1.],
       [ 4.,  2.],
       [ 4.,  3.],
       [ 4.,  4.],
       [ 4.,  5.],
       [ 4.,  6.],
       [ 4.,  7.],
       [ 4.,  8.],
       [ 4.,

In [55]:
action_idx = legal.gather(0, chosen_q_idx)

In [60]:
action_idx.numpy()

array([ 90,  75, 117,  77,  78,  32, 119], dtype=int64)

In [64]:
np.take(actions, action_idx, 0)

array([[ 8.,  3.],
       [ 6., 10.],
       [10.,  8.],
       [ 7.,  1.],
       [ 7.,  2.],
       [ 2., 11.],
       [10., 10.]])

In [68]:
actions[119]

array([10., 10.])

array([ 6., 10.])

In [82]:
q = torch.rand(16, 132)

In [83]:
q

tensor([[0.2337, 0.2023, 0.6979,  ..., 0.9916, 0.7788, 0.1082],
        [0.1911, 0.5513, 0.4918,  ..., 0.6103, 0.6163, 0.7330],
        [0.1390, 0.3027, 0.9134,  ..., 0.3006, 0.0406, 0.7331],
        ...,
        [0.7446, 0.2522, 0.2925,  ..., 0.4003, 0.0278, 0.6248],
        [0.4931, 0.7993, 0.7787,  ..., 0.5293, 0.7227, 0.3850],
        [0.2986, 0.5375, 0.0621,  ..., 0.4885, 0.8534, 0.7899]])

In [87]:
# NOTE: this call fails with the RuntimeError recorded below because uniform
# sampling is not implemented for the Long dtype; integer samples would need
# something like torch.randint instead.
torch.rand(16,7,2, dtype=torch.long)

RuntimeError: "check_uniform_bounds" not implemented for 'Long'

In [99]:
g = torch.tensor([[7,5,1,3,4,6,4], [4,5,4,3,2,1,2]])

In [100]:
t = torch.rand(2,132)

In [101]:
t

tensor([[0.8769, 0.7713, 0.5623, 0.6054, 0.9913, 0.3439, 0.8936, 0.4777, 0.4417,
         0.6755, 0.0067, 0.8930, 0.3669, 0.8032, 0.5370, 0.5559, 0.4430, 0.3516,
         0.1542, 0.4206, 0.5969, 0.8665, 0.3750, 0.2019, 0.5501, 0.3274, 0.8157,
         0.0607, 0.5639, 0.9297, 0.4288, 0.9111, 0.0692, 0.1258, 0.1742, 0.2223,
         0.7769, 0.9685, 0.9321, 0.1690, 0.1611, 0.3566, 0.2821, 0.3076, 0.2235,
         0.9809, 0.1576, 0.6205, 0.0492, 0.5403, 0.8366, 0.9951, 0.5693, 0.9298,
         0.9073, 0.8854, 0.8172, 0.5441, 0.2569, 0.0388, 0.4351, 0.4532, 0.4570,
         0.1793, 0.3458, 0.5107, 0.0199, 0.8273, 0.7753, 0.7950, 0.8339, 0.8221,
         0.1465, 0.8349, 0.1712, 0.5736, 0.1137, 0.1839, 0.2120, 0.6139, 0.9545,
         0.3113, 0.7454, 0.5375, 0.2901, 0.3855, 0.9291, 0.7756, 0.9309, 0.3437,
         0.3393, 0.7878, 0.9354, 0.4164, 0.0565, 0.0440, 0.8819, 0.0226, 0.0838,
         0.9989, 0.0185, 0.8449, 0.4004, 0.7864, 0.8366, 0.5975, 0.1008, 0.3554,
         0.4819, 0.9273, 0.7

In [104]:
t.gather(1, g)

tensor([[0.4777, 0.3439, 0.7713, 0.6054, 0.9913, 0.8936, 0.9913],
        [0.1804, 0.0768, 0.1804, 0.6486, 0.5583, 0.1285, 0.5583]])

In [105]:
t

tensor([[0.8769, 0.7713, 0.5623, 0.6054, 0.9913, 0.3439, 0.8936, 0.4777, 0.4417,
         0.6755, 0.0067, 0.8930, 0.3669, 0.8032, 0.5370, 0.5559, 0.4430, 0.3516,
         0.1542, 0.4206, 0.5969, 0.8665, 0.3750, 0.2019, 0.5501, 0.3274, 0.8157,
         0.0607, 0.5639, 0.9297, 0.4288, 0.9111, 0.0692, 0.1258, 0.1742, 0.2223,
         0.7769, 0.9685, 0.9321, 0.1690, 0.1611, 0.3566, 0.2821, 0.3076, 0.2235,
         0.9809, 0.1576, 0.6205, 0.0492, 0.5403, 0.8366, 0.9951, 0.5693, 0.9298,
         0.9073, 0.8854, 0.8172, 0.5441, 0.2569, 0.0388, 0.4351, 0.4532, 0.4570,
         0.1793, 0.3458, 0.5107, 0.0199, 0.8273, 0.7753, 0.7950, 0.8339, 0.8221,
         0.1465, 0.8349, 0.1712, 0.5736, 0.1137, 0.1839, 0.2120, 0.6139, 0.9545,
         0.3113, 0.7454, 0.5375, 0.2901, 0.3855, 0.9291, 0.7756, 0.9309, 0.3437,
         0.3393, 0.7878, 0.9354, 0.4164, 0.0565, 0.0440, 0.8819, 0.0226, 0.0838,
         0.9989, 0.0185, 0.8449, 0.4004, 0.7864, 0.8366, 0.5975, 0.1008, 0.3554,
         0.4819, 0.9273, 0.7

In [114]:
y = t.sort(1, descending=True)

In [115]:
# Keep the first seven entries of every row of the sorted values.
nextq = [row[:7] for row in y.values]
nextq

[tensor([0.9989, 0.9951, 0.9913, 0.9809, 0.9685, 0.9557, 0.9545]),
 tensor([0.9995, 0.9871, 0.9864, 0.9840, 0.9683, 0.9655, 0.9469])]

In [116]:
torch.stack(nextq)

tensor([[0.9989, 0.9951, 0.9913, 0.9809, 0.9685, 0.9557, 0.9545],
        [0.9995, 0.9871, 0.9864, 0.9840, 0.9683, 0.9655, 0.9469]])

In [119]:
legal = legal.numpy()

In [120]:
import random

In [129]:
legal = random.sample(legal, 7)

In [130]:
legal

[32, 79, 30, 90, 99, 117, 116]

In [131]:
actions

array([[ 0.,  1.],
       [ 0.,  2.],
       [ 0.,  3.],
       [ 0.,  4.],
       [ 0.,  5.],
       [ 0.,  6.],
       [ 0.,  7.],
       [ 0.,  8.],
       [ 0.,  9.],
       [ 0., 10.],
       [ 0., 11.],
       [ 1.,  1.],
       [ 1.,  2.],
       [ 1.,  3.],
       [ 1.,  4.],
       [ 1.,  5.],
       [ 1.,  6.],
       [ 1.,  7.],
       [ 1.,  8.],
       [ 1.,  9.],
       [ 1., 10.],
       [ 1., 11.],
       [ 2.,  1.],
       [ 2.,  2.],
       [ 2.,  3.],
       [ 2.,  4.],
       [ 2.,  5.],
       [ 2.,  6.],
       [ 2.,  7.],
       [ 2.,  8.],
       [ 2.,  9.],
       [ 2., 10.],
       [ 2., 11.],
       [ 3.,  1.],
       [ 3.,  2.],
       [ 3.,  3.],
       [ 3.,  4.],
       [ 3.,  5.],
       [ 3.,  6.],
       [ 3.,  7.],
       [ 3.,  8.],
       [ 3.,  9.],
       [ 3., 10.],
       [ 3., 11.],
       [ 4.,  1.],
       [ 4.,  2.],
       [ 4.,  3.],
       [ 4.,  4.],
       [ 4.,  5.],
       [ 4.,  6.],
       [ 4.,  7.],
       [ 4.,  8.],
       [ 4.,

In [133]:
np.take(actions, legal, 0)

array([[ 2., 11.],
       [ 7.,  3.],
       [ 2.,  9.],
       [ 8.,  3.],
       [ 9.,  1.],
       [10.,  8.],
       [10.,  7.]])

In [134]:
actions[32]

array([ 2., 11.])

In [135]:
predict = torch.tensor([ 0.0000,  2.2634,  0.0000,  0.0000, 14.5552,  0.0000,  4.0787, 22.1282,
         8.6315,  0.0000,  0.0000,  2.2967,  0.0000, 11.2417,  0.0000,  0.0000,
         2.6133,  0.0000,  0.0000,  0.0000,  0.0000, 10.4586,  5.7030,  0.0000,
         0.0000, 12.8066,  0.0000,  0.0000,  0.0000,  3.9369,  0.0000,  2.3094,
         0.0000,  2.2453,  0.0000, 10.3566,  0.0000,  0.0000,  0.0000,  8.4870,
         0.0000,  0.0000,  1.9676,  7.8532,  4.9158, 12.1242,  0.0000,  0.0000,
         5.5951,  0.0000, 11.2550,  0.0000, 18.0321,  0.0000,  0.0000,  0.0000,
         7.2463,  0.0000,  0.0000,  0.1663,  0.0000,  0.0000,  0.0000,  4.2565,
         0.0000,  8.6758,  7.2351,  0.0000, 13.7532,  0.5862,  0.0000, 10.1941,
         0.0000,  8.6476,  7.2139,  2.8377,  0.0000,  0.0000,  3.3754,  0.0000,
         1.9132,  0.5925,  5.3133,  0.0000,  4.2074,  0.0000,  3.5816,  0.0000,
         1.8329, 18.0533,  4.7783,  0.0000,  8.2180,  8.0269,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  2.1318,  7.4333,  0.0000,  5.6233, 10.2923,  7.0338,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  7.2451,  0.3393,
         0.0000, 14.0444,  0.2286,  0.0000,  0.0000,  0.0000,  4.0939,  0.0000],
       device='cuda:0')

In [136]:
legal_id = torch.tensor([  2,   5,   8,  35,  36,  39,  41,  42,  55,  56,  57,  59,  70,  71,
         72,  73,  74,  75,  81,  84,  85,  87,  88,  90,  91,  94, 123, 126,
        129], device='cuda:0')

In [137]:
# NOTE(review): `predict` above appears to hold 128 values (16 rows of 8) while
# `legal_id` contains the index 129, so this gather looks out of range; that
# would explain the CUDA device-side assert recorded below. Confirm the
# intended length of `predict`.
predict.gather(0, legal_id)

RuntimeError: CUDA error: device-side assert triggered

In [142]:
torch.rand(10).unsqueeze(1)

tensor([[0.5594],
        [0.8244],
        [0.5368],
        [0.4656],
        [0.1332],
        [0.6059],
        [0.6987],
        [0.1715],
        [0.4309],
        [0.9509]])

In [145]:
actions[45]

array([4., 2.])

In [146]:
actions[47]

array([4., 4.])

In [147]:
actions[65]

array([ 5., 11.])

In [148]:
actions[70]

array([6., 5.])

In [150]:
# NOTE: `legal_moves` was assigned the array returned by
# player_helper.legal_moves(observation) in an earlier cell, so calling the
# name raises the TypeError recorded below; the callable is the method on
# `player_helper`.
legal_moves(observation)

TypeError: 'numpy.ndarray' object is not callable

In [161]:
actions = [[68, 16, 121, 3, 41, 105, 89], [22, 23, 25, -1, -1, -1, -1], [18, 56, 112, 6, 33, 30, 98], [7, 63, 126, 88, 35, 108, 48], [38, 103, 120, 22, 46, 91, 122], [19, 68, 81, 98, 40, 116, 130], [69, 41, 99, 116, 4, 49, 32], [111, 29, 15, 68, 107, 91, 5], [13, 38, 102, 0, 30, 61, 131], [90, 37, 78, 130, 14, 72, 32]]

In [162]:
preds

tensor([[0.6413, 0.4148, 0.7435,  ..., 0.0866, 0.2352, 0.9963],
        [0.1508, 0.5930, 0.3212,  ..., 0.1832, 0.0636, 0.1680],
        [0.8078, 0.7295, 0.4701,  ..., 0.6412, 0.2963, 0.0509],
        ...,
        [0.9721, 0.3197, 0.7892,  ..., 0.5335, 0.1455, 0.0773],
        [0.8056, 0.4343, 0.9186,  ..., 0.5955, 0.6109, 0.3155],
        [0.4338, 0.2608, 0.6878,  ..., 0.0492, 0.5824, 0.8812]])

In [163]:
preds = torch.rand(10,132)

In [177]:
# For every sample, pick the predictions at its action indices; a negative
# index marks a padded slot and contributes a zero instead of a prediction.
new_current_Q = []
for idx, index_group in enumerate(actions):
    # The padding zero is created from `preds` so it shares dtype and device
    # with the picked entries. torch.tensor(0) is an int64 CPU scalar, which
    # makes torch.stack rely on type promotion and fail on a device mismatch.
    x = [preds[idx][i] if i > -1 else preds.new_zeros(()) for i in index_group]
    new_current_Q.append(torch.stack(x))
new_current_Q

    

[tensor([0.8963, 0.1302, 0.1358, 0.7783, 0.0163, 0.8492, 0.6282]),
 tensor([0.2719, 0.2574, 0.4111, 0.0000, 0.0000, 0.0000, 0.0000]),
 tensor([0.3639, 0.1463, 0.6224, 0.7311, 0.8257, 0.5302, 0.8953]),
 tensor([0.7730, 0.5497, 0.8220, 0.0928, 0.4209, 0.7917, 0.1120]),
 tensor([0.7865, 0.4682, 0.4233, 0.0581, 0.5598, 0.6818, 0.8211]),
 tensor([0.4206, 0.4238, 0.0345, 0.4168, 0.0360, 0.5273, 0.3891]),
 tensor([0.4979, 0.2559, 0.3107, 0.0808, 0.3374, 0.6515, 0.6854]),
 tensor([0.9183, 0.6732, 0.5718, 0.1405, 0.3952, 0.8155, 0.8593]),
 tensor([0.3791, 0.9320, 0.8433, 0.9964, 0.2736, 0.5116, 0.4739]),
 tensor([0.9892, 0.3174, 0.7957, 0.8887, 0.2609, 0.0754, 0.9617])]

In [180]:
obs = torch.rand(10,105)

In [181]:
preds = model(obs)

In [183]:
preds.shape

torch.Size([10, 132])

In [192]:
# For every sample, pick the predictions at its action indices; a negative
# index marks a padded slot and contributes a zero instead of a prediction.
new_current_Q = []
for idx, index_group in enumerate(actions):
    # The padding zero is created from `preds` so it shares dtype and device
    # with the picked entries. torch.tensor(0) is an int64 CPU scalar, which
    # makes torch.stack rely on type promotion and fail on a device mismatch.
    x = [preds[idx][i] if i > -1 else preds.new_zeros(()) for i in index_group]
    new_current_Q.append(torch.stack(x))
# Combine the per-sample rows into one (samples, picks) tensor.
new_current_Q = torch.stack(new_current_Q)
new_current_Q

tensor([[-0.1499, -0.0017, -0.0488,  0.0565,  0.0969, -0.0118,  0.0965],
        [-0.0175, -0.0862, -0.1245,  0.0000,  0.0000,  0.0000,  0.0000],
        [-0.0292,  0.0371, -0.0444,  0.0344, -0.0383, -0.1055, -0.0031],
        [-0.0654,  0.0490, -0.0196, -0.0004,  0.0155,  0.0110,  0.0505],
        [-0.0663, -0.0183, -0.0440, -0.0229,  0.0291, -0.1092, -0.0227],
        [ 0.0147, -0.1338, -0.0519, -0.0036, -0.0708, -0.0262, -0.0893],
        [ 0.0266,  0.0864,  0.0687, -0.0225, -0.0595,  0.0372,  0.0373],
        [ 0.0331,  0.0439,  0.0463, -0.1258, -0.0217, -0.0968,  0.0181],
        [-0.0286, -0.0662,  0.0425,  0.0499, -0.0778, -0.0522, -0.0493],
        [ 0.0258, -0.0171,  0.0629, -0.0868, -0.0998,  0.0228,  0.0634]],
       grad_fn=<StackBackward>)

In [190]:
new_current_Q[1][3]

tensor(0., grad_fn=<SelectBackward>)

In [195]:
# Right-pad the index tensor with two copies of the one-element zero tensor.
b = torch.tensor([0])
a = torch.cat([torch.tensor([4, 5, 7])] + [b] * 2)
a

tensor([4, 5, 7, 0, 0])

In [196]:
list1 = np.array([])

In [198]:
list1.shape

(0,)

In [199]:
list1.reshape(0,2)

array([], shape=(0, 2), dtype=float64)