In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from config import Configuration
from noisy_net import NoisyLinear, NoisyFactorizedLinear
from utils import state_pre_processing
from OneHotEncode import OneHotEncode

class BranchingQNetwork(nn.Module):
    def __init__(self, observation_space, action_space, hidden_dim, exploration_method="Dueling", architecture="DQN"):
        super().__init__()
        self.architecture = architecture
        self.model = nn.ModuleList([nn.Sequential(
            nn.Linear(62, hidden_dim*4),
            nn.ReLU(),
            nn.Linear(hidden_dim*4, hidden_dim*2),
            nn.ReLU(),
            nn.Linear(hidden_dim*2, hidden_dim),
            nn.ReLU()
        ) for i in range(12)])
        if self.architecture == "Dueling":
            self.value_head = nn.ModuleList([nn.Linear(hidden_dim, 1) for i in range(12)])
            self.adv_heads = nn.ModuleList([nn.Linear(hidden_dim, 11) for i in range(12)])
        else:
            self.out = nn.ModuleList([nn.Linear(hidden_dim, 11) for i in range(12)])

    def forward(self, x):
        processed_x = self.state_processing(x)
        #print(processed_x)
        layer1 = torch.stack([self.model[i](processed_x[i]) for i, _ in enumerate(processed_x)])
        if self.architecture == "Dueling":
            value = torch.stack([self.value_head[i](layer1[i]) for i, _ in enumerate(layer1)])
            advs = torch.stack([self.adv_heads[i](layer1[i]) for i, _ in enumerate(layer1)])
            q_val = value + advs - advs.mean()
        else:
            q_val = torch.stack([self.out[i](layer1[i]) for i, _ in enumerate(layer1)])
            
        return q_val
    
    def state_processing(self, obs):
        node_info = obs[:45]
        groups_info = obs[45:]
        partitions = [17 for i in range(12)]
        groups = torch.split(groups_info, partitions)
        groups_final = [torch.cat((node_info, groups[i])) for i in range(len(groups))]
        return groups_final
   

In [3]:
config = Configuration("configs/config.json")
torch.cuda.init()
device = torch.device(
    config.device if torch.cuda.is_available() else "cpu")
torch.autograd.set_detect_anomaly(True)


<torch.autograd.anomaly_mode.set_detect_anomaly at 0x1f9e64d27c0>

In [4]:
obs = [  64.,    0.,    0.,  500.,    0.,    0.,    1.,  100.,    8.,
          0.,    0.,  100.,    0.,    1.,    0.,  100.,    0.,    0.,
          0.,  -88.,    0.,    0.,    0., -100.,   16.,    0.,    0.,
        100.,   12.,    0.,    1., -100.,   16.,    0.,    0.,  -50.,
          7.,    1.,    0.,  -13.,    8.,    0.,    0., -500.,   32.,
          3.,    1.,   94.,    1.,    8.,   10.,    2.,   56.,    1.,
          7.,    8.,    0.,   93.,    0.,    8.,    6.,    1.,   58.,
          0.,    8.,    8.,    2.,    0.,    0.,    0.,    4.,    0.,
        100.,    0.,    8.,    7.,    1.,   90.,    0.,    8.,    2.,
          2.,  100.,    0.,    8.,    7.,    0.,  100.,    0.,    8.,
          2.,    1.,   15.,    0.,    7.,    9.,    2.,   70.,    0.,
          8.,    2.,    0.,   91.,    1.,   12.]

In [5]:
new_obs = OneHotEncode(obs)

In [6]:
new_obs = torch.tensor(new_obs).float().to(device)
new_obs

tensor([  64.,    0.,    0.,  500.,    0.,    0.,    1.,  100.,    8.,    0.,
           0.,  100.,    0.,    1.,    0.,  100.,    0.,    0.,    0.,  -88.,
           0.,    0.,    0., -100.,   16.,    0.,    0.,  100.,   12.,    0.,
           1., -100.,   16.,    0.,    0.,  -50.,    7.,    1.,    0.,  -13.,
           8.,    0.,    0., -500.,   32.,    0.,    0.,    0.,    1.,    0.,
           0.,    0.,    0.,    0.,    0.,    0.,    0.,    1.,    0.,   94.,
           1.,    8.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,
           0.,    0.,    1.,    0.,    0.,    1.,   56.,    1.,    7.,    0.,
           0.,    0.,    0.,    0.,    0.,    0.,    0.,    1.,    0.,    0.,
           1.,    0.,    0.,   93.,    0.,    8.,    0.,    0.,    0.,    0.,
           0.,    0.,    1.,    0.,    0.,    0.,    0.,    0.,    1.,    0.,
          58.,    0.,    8.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,
           0.,    1.,    0.,    0.,    0.,    0.,    1.,    0., 

In [7]:
model = BranchingQNetwork(249,11,128)
model.to(device)

BranchingQNetwork(
  (model): ModuleList(
    (0): Sequential(
      (0): Linear(in_features=62, out_features=512, bias=True)
      (1): ReLU()
      (2): Linear(in_features=512, out_features=256, bias=True)
      (3): ReLU()
      (4): Linear(in_features=256, out_features=128, bias=True)
      (5): ReLU()
    )
    (1): Sequential(
      (0): Linear(in_features=62, out_features=512, bias=True)
      (1): ReLU()
      (2): Linear(in_features=512, out_features=256, bias=True)
      (3): ReLU()
      (4): Linear(in_features=256, out_features=128, bias=True)
      (5): ReLU()
    )
    (2): Sequential(
      (0): Linear(in_features=62, out_features=512, bias=True)
      (1): ReLU()
      (2): Linear(in_features=512, out_features=256, bias=True)
      (3): ReLU()
      (4): Linear(in_features=256, out_features=128, bias=True)
      (5): ReLU()
    )
    (3): Sequential(
      (0): Linear(in_features=62, out_features=512, bias=True)
      (1): ReLU()
      (2): Linear(in_features=512, out_f

In [8]:
out = model(new_obs)
print(out)
out_max = out.max(1)
print(out_max)
out_max_sorted = out_max.values.sort(descending=True)
print(out_max_sorted)
chosen_group = out_max_sorted.indices[:7]
chosen_location = torch.stack([out_max.indices[i] for i in chosen_group]) + 1
action = torch.stack([chosen_group, chosen_location], dim = 1)
action.detach().cpu().numpy()

tensor([[ 2.9020, -2.6777, -0.8610, -1.3483, -0.7833, -8.3360, -4.6109,  2.2436,
          1.3109,  5.1803, -0.7092],
        [ 1.4493,  3.9153,  6.1884, -1.7723,  1.5198,  6.0655,  4.6400,  2.4675,
          2.3097,  3.5910, -6.4122],
        [ 0.5694,  0.8096, -0.0370,  1.8791,  3.9980, -4.0703, -1.9475, -2.6174,
          0.2954,  0.2050,  0.1179],
        [ 7.0691,  1.5045,  1.5875,  0.6541,  5.8861,  0.5254,  3.2723,  2.5873,
          4.3878,  5.8561,  1.8414],
        [-2.5419, -4.6274, -5.4703,  3.0516, -4.0267,  1.8106, -5.2737, -1.8081,
          1.4377,  0.1265, -4.3325],
        [-0.6096,  4.3651, -0.2477,  2.2767,  2.0451, -3.4724, -2.3883,  1.2757,
          5.6066, -2.4234, -2.2584],
        [ 2.8746,  5.3597,  4.5418,  5.9028, -3.7635, -0.2702, -1.5322,  0.7782,
          2.7711,  2.9967,  3.3838],
        [-5.4440, -2.2268, -3.6689, -0.3324, -3.8233,  2.7718, -2.4710,  1.5707,
          2.9060,  0.8373, -3.5100],
        [ 5.0995, -3.0545, -0.7478,  1.5264,  2.2042,  2

array([[ 9,  4],
       [10,  8],
       [ 3,  1],
       [11,  7],
       [ 1,  3],
       [ 6,  4],
       [ 5,  9]], dtype=int64)

In [9]:
torch.argmax(out,1)

tensor([9, 2, 4, 0, 3, 8, 3, 8, 0, 3, 7, 6], device='cuda:0')

In [10]:
q = out.max(1, keepdim=True)
q 

torch.return_types.max(
values=tensor([[5.1803],
        [6.1884],
        [3.9980],
        [7.0691],
        [3.0516],
        [5.6066],
        [5.9028],
        [2.9060],
        [5.0995],
        [9.7555],
        [8.6733],
        [6.3694]], device='cuda:0', grad_fn=<MaxBackward0>),
indices=tensor([[9],
        [2],
        [4],
        [0],
        [3],
        [8],
        [3],
        [8],
        [0],
        [3],
        [7],
        [6]], device='cuda:0'))

In [11]:
g = torch.tensor([5])
for i in range(5):
    g = torch.stack((g,torch.tensor([5])),1)
g

RuntimeError: stack expects each tensor to be equal size, but got [1, 2] at entry 0 and [1] at entry 1

In [12]:
a = torch.rand(16)
a = a.reshape(16,1,1)
a.shape

torch.Size([16, 1, 1])

In [13]:
a

tensor([[[0.2276]],

        [[0.4942]],

        [[0.9266]],

        [[0.0940]],

        [[0.2169]],

        [[0.5311]],

        [[0.6452]],

        [[0.2567]],

        [[0.4689]],

        [[0.6318]],

        [[0.2457]],

        [[0.1645]],

        [[0.1003]],

        [[0.1968]],

        [[0.0496]],

        [[0.2462]]])

In [14]:
b = torch.rand(16,1,11)
b

tensor([[[0.1409, 0.9135, 0.1956, 0.5351, 0.8767, 0.9059, 0.9668, 0.5254,
          0.8901, 0.8311, 0.1479]],

        [[0.5783, 0.3645, 0.5010, 0.9720, 0.0453, 0.0646, 0.7823, 0.3192,
          0.6586, 0.6066, 0.0733]],

        [[0.0016, 0.4700, 0.8568, 0.4734, 0.2790, 0.8545, 0.7776, 0.6796,
          0.0137, 0.6813, 0.0604]],

        [[0.1228, 0.6575, 0.6607, 0.7180, 0.0462, 0.6034, 0.6933, 0.0272,
          0.0079, 0.1705, 0.3592]],

        [[0.1513, 0.1571, 0.9631, 0.4855, 0.3229, 0.9497, 0.4082, 0.8003,
          0.3726, 0.6891, 0.5113]],

        [[0.3264, 0.3929, 0.2047, 0.5997, 0.6581, 0.2868, 0.6920, 0.2254,
          0.6807, 0.5510, 0.1217]],

        [[0.1221, 0.7400, 0.6622, 0.1505, 0.9949, 0.0253, 0.0817, 0.2848,
          0.8047, 0.3968, 0.5452]],

        [[0.8734, 0.0869, 0.5105, 0.0908, 0.6852, 0.2858, 0.8747, 0.4453,
          0.8012, 0.3103, 0.6997]],

        [[0.6785, 0.7149, 0.3171, 0.4268, 0.2233, 0.7391, 0.7505, 0.2933,
          0.1239, 0.4385, 0.5474]],

 

In [15]:
a+b

tensor([[[0.3685, 1.1410, 0.4232, 0.7627, 1.1043, 1.1335, 1.1944, 0.7529,
          1.1176, 1.0586, 0.3754]],

        [[1.0725, 0.8587, 0.9952, 1.4662, 0.5395, 0.5588, 1.2765, 0.8133,
          1.1527, 1.1007, 0.5675]],

        [[0.9282, 1.3966, 1.7835, 1.4000, 1.2056, 1.7812, 1.7042, 1.6062,
          0.9403, 1.6079, 0.9871]],

        [[0.2168, 0.7515, 0.7546, 0.8120, 0.1401, 0.6974, 0.7873, 0.1212,
          0.1019, 0.2645, 0.4532]],

        [[0.3683, 0.3740, 1.1800, 0.7024, 0.5399, 1.1666, 0.6251, 1.0172,
          0.5895, 0.9061, 0.7282]],

        [[0.8575, 0.9240, 0.7359, 1.1309, 1.1892, 0.8179, 1.2232, 0.7565,
          1.2118, 1.0821, 0.6529]],

        [[0.7672, 1.3851, 1.3073, 0.7957, 1.6401, 0.6705, 0.7269, 0.9299,
          1.4499, 1.0419, 1.1904]],

        [[1.1301, 0.3437, 0.7672, 0.3475, 0.9420, 0.5425, 1.1314, 0.7020,
          1.0579, 0.5670, 0.9564]],

        [[1.1473, 1.1837, 0.7860, 0.8956, 0.6922, 1.2079, 1.2194, 0.7622,
          0.5927, 0.9074, 1.0163]],

 

In [16]:
for p in model.parameters():
    print(p.shape)

torch.Size([512, 62])
torch.Size([512])
torch.Size([256, 512])
torch.Size([256])
torch.Size([128, 256])
torch.Size([128])
torch.Size([512, 62])
torch.Size([512])
torch.Size([256, 512])
torch.Size([256])
torch.Size([128, 256])
torch.Size([128])
torch.Size([512, 62])
torch.Size([512])
torch.Size([256, 512])
torch.Size([256])
torch.Size([128, 256])
torch.Size([128])
torch.Size([512, 62])
torch.Size([512])
torch.Size([256, 512])
torch.Size([256])
torch.Size([128, 256])
torch.Size([128])
torch.Size([512, 62])
torch.Size([512])
torch.Size([256, 512])
torch.Size([256])
torch.Size([128, 256])
torch.Size([128])
torch.Size([512, 62])
torch.Size([512])
torch.Size([256, 512])
torch.Size([256])
torch.Size([128, 256])
torch.Size([128])
torch.Size([512, 62])
torch.Size([512])
torch.Size([256, 512])
torch.Size([256])
torch.Size([128, 256])
torch.Size([128])
torch.Size([512, 62])
torch.Size([512])
torch.Size([256, 512])
torch.Size([256])
torch.Size([128, 256])
torch.Size([128])
torch.Size([512, 62])
to

In [17]:
st = torch.rand(16,12,62)
st

tensor([[[9.0634e-01, 3.6819e-01, 1.6183e-01,  ..., 7.5799e-01,
          8.4424e-01, 4.2412e-01],
         [6.0448e-01, 5.4335e-01, 7.4203e-01,  ..., 7.0699e-01,
          2.7163e-01, 5.6318e-01],
         [4.8601e-01, 8.6404e-01, 8.9060e-01,  ..., 1.6247e-01,
          4.6965e-01, 7.0303e-01],
         ...,
         [4.9919e-02, 9.8319e-01, 8.8738e-01,  ..., 4.9448e-01,
          7.5336e-01, 4.9672e-01],
         [1.3585e-01, 7.6376e-01, 7.1237e-01,  ..., 5.7086e-01,
          1.8680e-01, 8.0998e-01],
         [7.3872e-01, 9.8578e-01, 8.4271e-01,  ..., 9.4244e-01,
          4.2971e-01, 3.4733e-01]],

        [[9.4890e-01, 8.7925e-01, 7.4226e-01,  ..., 2.3118e-01,
          6.8572e-01, 8.9810e-01],
         [1.3254e-01, 9.7431e-01, 3.0210e-01,  ..., 6.2608e-01,
          9.6138e-01, 4.9180e-01],
         [7.0091e-01, 2.5027e-01, 2.3231e-01,  ..., 5.8622e-01,
          8.2803e-01, 1.6477e-01],
         ...,
         [6.6064e-01, 5.3339e-01, 7.8149e-01,  ..., 9.4813e-01,
          2.132

In [18]:
idx = [[n for x in range(12)]
# for n in range(1):
#     idx.append([[x for x in range(62)] for k in range(12)])
# idx = torch.LongTensor(idx)
# idx

SyntaxError: unexpected EOF while parsing (<ipython-input-18-d425bf9ff9b9>, line 5)

In [19]:
torch.stack([st[0][0], st[1][0]])

tensor([[9.0634e-01, 3.6819e-01, 1.6183e-01, 6.1783e-02, 7.5748e-01, 5.3435e-01,
         9.0489e-01, 4.2322e-01, 2.1452e-02, 3.4708e-01, 4.6065e-01, 9.6804e-04,
         4.4811e-01, 2.5523e-01, 5.1219e-01, 6.0646e-01, 8.9461e-01, 3.6300e-01,
         5.1894e-01, 1.5556e-01, 3.5789e-01, 4.6786e-01, 9.4625e-02, 4.5425e-01,
         3.3776e-01, 9.0104e-01, 4.4940e-01, 4.1516e-01, 3.5882e-01, 2.1876e-01,
         8.6680e-01, 8.4617e-01, 2.0557e-01, 5.8066e-01, 5.9418e-01, 3.7971e-02,
         3.7494e-01, 6.6955e-01, 1.2190e-01, 7.4618e-01, 4.1675e-01, 6.7977e-01,
         3.6369e-01, 7.9110e-02, 2.5585e-01, 6.1226e-01, 5.1925e-01, 3.1409e-02,
         5.4210e-02, 8.5982e-01, 8.9860e-01, 3.9133e-01, 8.1209e-01, 2.8309e-01,
         6.5712e-01, 2.8152e-01, 4.7462e-01, 8.4409e-01, 4.1831e-01, 7.5799e-01,
         8.4424e-01, 4.2412e-01],
        [9.4890e-01, 8.7925e-01, 7.4226e-01, 5.0210e-02, 7.8545e-01, 4.0067e-01,
         6.7424e-01, 4.2624e-01, 9.8032e-02, 2.4611e-01, 2.4971e-01, 3.1845

In [20]:
st.transpose(0,1).shape

torch.Size([12, 16, 62])

In [21]:
st[0]

tensor([[9.0634e-01, 3.6819e-01, 1.6183e-01, 6.1783e-02, 7.5748e-01, 5.3435e-01,
         9.0489e-01, 4.2322e-01, 2.1452e-02, 3.4708e-01, 4.6065e-01, 9.6804e-04,
         4.4811e-01, 2.5523e-01, 5.1219e-01, 6.0646e-01, 8.9461e-01, 3.6300e-01,
         5.1894e-01, 1.5556e-01, 3.5789e-01, 4.6786e-01, 9.4625e-02, 4.5425e-01,
         3.3776e-01, 9.0104e-01, 4.4940e-01, 4.1516e-01, 3.5882e-01, 2.1876e-01,
         8.6680e-01, 8.4617e-01, 2.0557e-01, 5.8066e-01, 5.9418e-01, 3.7971e-02,
         3.7494e-01, 6.6955e-01, 1.2190e-01, 7.4618e-01, 4.1675e-01, 6.7977e-01,
         3.6369e-01, 7.9110e-02, 2.5585e-01, 6.1226e-01, 5.1925e-01, 3.1409e-02,
         5.4210e-02, 8.5982e-01, 8.9860e-01, 3.9133e-01, 8.1209e-01, 2.8309e-01,
         6.5712e-01, 2.8152e-01, 4.7462e-01, 8.4409e-01, 4.1831e-01, 7.5799e-01,
         8.4424e-01, 4.2412e-01],
        [6.0448e-01, 5.4335e-01, 7.4203e-01, 5.7446e-01, 6.5888e-01, 6.4708e-01,
         5.6514e-01, 8.0336e-01, 6.5059e-01, 3.9673e-01, 6.3339e-01, 5.4990

In [22]:
st[15][0]

tensor([0.5965, 0.3781, 0.1513, 0.1158, 0.6886, 0.8067, 0.6663, 0.5890, 0.6116,
        0.2014, 0.5506, 0.6796, 0.4135, 0.6969, 0.2634, 0.0686, 0.5530, 0.6165,
        0.3308, 0.3186, 0.9145, 0.2548, 0.1336, 0.4622, 0.1730, 0.8474, 0.4329,
        0.7840, 0.1798, 0.1501, 0.1951, 0.3263, 0.5000, 0.9071, 0.5895, 0.7884,
        0.9110, 0.7678, 0.2392, 0.0749, 0.3178, 0.3750, 0.0865, 0.0119, 0.7775,
        0.8067, 0.7773, 0.1547, 0.7555, 0.2847, 0.8508, 0.3132, 0.1638, 0.2782,
        0.1863, 0.8209, 0.3369, 0.6736, 0.1242, 0.7732, 0.4156, 0.9000])

In [23]:
r = torch.rand(12,16,11)
r

tensor([[[0.9104, 0.3321, 0.1306,  ..., 0.6653, 0.6279, 0.8036],
         [0.8869, 0.3657, 0.0304,  ..., 0.4979, 0.1830, 0.4993],
         [0.6165, 0.2715, 0.8923,  ..., 0.6384, 0.1819, 0.5126],
         ...,
         [0.2963, 0.0012, 0.9134,  ..., 0.9108, 0.8808, 0.3757],
         [0.0814, 0.1206, 0.6598,  ..., 0.7502, 0.9866, 0.4247],
         [0.6967, 0.8306, 0.0088,  ..., 0.7491, 0.5935, 0.8041]],

        [[0.2160, 0.3457, 0.9077,  ..., 0.9212, 0.5426, 0.7474],
         [0.8171, 0.7326, 0.5667,  ..., 0.3648, 0.7438, 0.2154],
         [0.5241, 0.8059, 0.3340,  ..., 0.0410, 0.5450, 0.7202],
         ...,
         [0.5238, 0.0643, 0.9787,  ..., 0.8499, 0.6198, 0.0088],
         [0.7577, 0.6400, 0.7444,  ..., 0.6953, 0.2089, 0.6021],
         [0.7468, 0.4580, 0.5644,  ..., 0.5501, 0.0760, 0.1754]],

        [[0.4732, 0.0541, 0.8589,  ..., 0.3244, 0.0657, 0.6186],
         [0.5070, 0.3726, 0.9182,  ..., 0.3969, 0.9317, 0.8615],
         [0.9636, 0.5885, 0.7955,  ..., 0.2294, 0.0740, 0.

In [24]:
a = torch.rand(16,7,2).long()

In [25]:
a.gather(0, torch.LongTensor([16,7,0]))

RuntimeError: Index tensor must have the same number of dimensions as input tensor

In [147]:
a = torch.tensor([[[ 2,  5],
         [11, 11],
         [ 8,  2],
         [10,  6],
         [ 3,  9],
         [ 4,  1],
         [ 5,  8]],

        [[ 1,  5],
         [ 5,  7],
         [ 8,  2],
         [ 0, 10],
         [10,  6],
         [11,  1],
         [ 7,  4]],

        [[ 0,  7],
         [10,  2],
         [ 7,  9],
         [ 6,  5],
         [ 1, 11],
         [ 3, 10],
         [ 8,  6]],

        [[ 1,  9],
         [ 5,  3],
         [ 4, 10],
         [ 3, 11],
         [ 8,  6],
         [11,  5],
         [ 7,  8]],

        [[ 7,  2],
         [11,  7],
         [ 3,  8],
         [10, 10],
         [ 6,  9],
         [ 4,  6],
         [ 8,  1]],

        [[ 9,  5],
         [10,  5],
         [11,  5],
         [ 0,  6],
         [ 1,  6],
         [ 2,  6],
         [ 3,  6]],

        [[ 0,  1],
         [11,  4],
         [ 2,  8],
         [ 5, 11],
         [ 9, 10],
         [ 4,  7],
         [ 3,  3]],

        [[ 8, 10],
         [ 0,  1],
         [ 3,  6],
         [10,  5],
         [ 6, 11],
         [11,  3],
         [ 7,  4]],

        [[10,  2],
         [ 1,  5],
         [11,  9],
         [ 4,  1],
         [ 3,  4],
         [ 5,  6],
         [ 7,  3]],

        [[ 5,  1],
         [ 1,  9],
         [ 2, 11],
         [ 3,  2],
         [ 6,  8],
         [11,  7],
         [ 9,  4]],

        [[10,  5],
         [ 1,  4],
         [11,  3],
         [ 4,  2],
         [ 9,  9],
         [ 3,  8],
         [ 8,  6]],

        [[ 3,  5],
         [ 7,  6],
         [ 4,  9],
         [ 1,  7],
         [ 5,  1],
         [ 6,  8],
         [ 0,  4]],

        [[ 7,  6],
         [ 9,  4],
         [10, 10],
         [ 1,  9],
         [ 5,  3],
         [ 2, 11],
         [ 4,  8]],

        [[10,  6],
         [ 2,  7],
         [ 4, 11],
         [ 9,  8],
         [ 1,  3],
         [ 7,  5],
         [ 0,  2]],

        [[ 4,  2],
         [ 6,  1],
         [ 2,  6],
         [ 1, 10],
         [ 8,  5],
         [ 0, 11],
         [ 7,  4]],

        [[ 5,  7],
         [ 2,  7],
         [ 8,  7],
         [ 1, 10],
         [10, 10],
         [11,  7],
         [ 4, 10]]])

In [148]:
a.shape

torch.Size([16, 7, 2])

In [149]:
a = a.transpose(1,2)
a

tensor([[[ 2, 11,  8, 10,  3,  4,  5],
         [ 5, 11,  2,  6,  9,  1,  8]],

        [[ 1,  5,  8,  0, 10, 11,  7],
         [ 5,  7,  2, 10,  6,  1,  4]],

        [[ 0, 10,  7,  6,  1,  3,  8],
         [ 7,  2,  9,  5, 11, 10,  6]],

        [[ 1,  5,  4,  3,  8, 11,  7],
         [ 9,  3, 10, 11,  6,  5,  8]],

        [[ 7, 11,  3, 10,  6,  4,  8],
         [ 2,  7,  8, 10,  9,  6,  1]],

        [[ 9, 10, 11,  0,  1,  2,  3],
         [ 5,  5,  5,  6,  6,  6,  6]],

        [[ 0, 11,  2,  5,  9,  4,  3],
         [ 1,  4,  8, 11, 10,  7,  3]],

        [[ 8,  0,  3, 10,  6, 11,  7],
         [10,  1,  6,  5, 11,  3,  4]],

        [[10,  1, 11,  4,  3,  5,  7],
         [ 2,  5,  9,  1,  4,  6,  3]],

        [[ 5,  1,  2,  3,  6, 11,  9],
         [ 1,  9, 11,  2,  8,  7,  4]],

        [[10,  1, 11,  4,  9,  3,  8],
         [ 5,  4,  3,  2,  9,  8,  6]],

        [[ 3,  7,  4,  1,  5,  6,  0],
         [ 5,  6,  9,  7,  1,  8,  4]],

        [[ 7,  9, 10,  1,  5,  2,  4],
 

In [150]:
n = a[:,0]
l = a[:,1]

In [151]:
n

tensor([[ 2, 11,  8, 10,  3,  4,  5],
        [ 1,  5,  8,  0, 10, 11,  7],
        [ 0, 10,  7,  6,  1,  3,  8],
        [ 1,  5,  4,  3,  8, 11,  7],
        [ 7, 11,  3, 10,  6,  4,  8],
        [ 9, 10, 11,  0,  1,  2,  3],
        [ 0, 11,  2,  5,  9,  4,  3],
        [ 8,  0,  3, 10,  6, 11,  7],
        [10,  1, 11,  4,  3,  5,  7],
        [ 5,  1,  2,  3,  6, 11,  9],
        [10,  1, 11,  4,  9,  3,  8],
        [ 3,  7,  4,  1,  5,  6,  0],
        [ 7,  9, 10,  1,  5,  2,  4],
        [10,  2,  4,  9,  1,  7,  0],
        [ 4,  6,  2,  1,  8,  0,  7],
        [ 5,  2,  8,  1, 10, 11,  4]])

In [152]:
l 

tensor([[ 5, 11,  2,  6,  9,  1,  8],
        [ 5,  7,  2, 10,  6,  1,  4],
        [ 7,  2,  9,  5, 11, 10,  6],
        [ 9,  3, 10, 11,  6,  5,  8],
        [ 2,  7,  8, 10,  9,  6,  1],
        [ 5,  5,  5,  6,  6,  6,  6],
        [ 1,  4,  8, 11, 10,  7,  3],
        [10,  1,  6,  5, 11,  3,  4],
        [ 2,  5,  9,  1,  4,  6,  3],
        [ 1,  9, 11,  2,  8,  7,  4],
        [ 5,  4,  3,  2,  9,  8,  6],
        [ 5,  6,  9,  7,  1,  8,  4],
        [ 6,  4, 10,  9,  3, 11,  8],
        [ 6,  7, 11,  8,  3,  5,  2],
        [ 2,  1,  6, 10,  5, 11,  4],
        [ 7,  7,  7, 10, 10,  7, 10]])

In [153]:
l = l - 1

In [154]:
a[:,1] = a[:,1] - 1


In [155]:
a

tensor([[[ 2, 11,  8, 10,  3,  4,  5],
         [ 4, 10,  1,  5,  8,  0,  7]],

        [[ 1,  5,  8,  0, 10, 11,  7],
         [ 4,  6,  1,  9,  5,  0,  3]],

        [[ 0, 10,  7,  6,  1,  3,  8],
         [ 6,  1,  8,  4, 10,  9,  5]],

        [[ 1,  5,  4,  3,  8, 11,  7],
         [ 8,  2,  9, 10,  5,  4,  7]],

        [[ 7, 11,  3, 10,  6,  4,  8],
         [ 1,  6,  7,  9,  8,  5,  0]],

        [[ 9, 10, 11,  0,  1,  2,  3],
         [ 4,  4,  4,  5,  5,  5,  5]],

        [[ 0, 11,  2,  5,  9,  4,  3],
         [ 0,  3,  7, 10,  9,  6,  2]],

        [[ 8,  0,  3, 10,  6, 11,  7],
         [ 9,  0,  5,  4, 10,  2,  3]],

        [[10,  1, 11,  4,  3,  5,  7],
         [ 1,  4,  8,  0,  3,  5,  2]],

        [[ 5,  1,  2,  3,  6, 11,  9],
         [ 0,  8, 10,  1,  7,  6,  3]],

        [[10,  1, 11,  4,  9,  3,  8],
         [ 4,  3,  2,  1,  8,  7,  5]],

        [[ 3,  7,  4,  1,  5,  6,  0],
         [ 4,  5,  8,  6,  0,  7,  3]],

        [[ 7,  9, 10,  1,  5,  2,  4],
 

In [156]:
r[2].shape

torch.Size([16, 11])

In [157]:
r[2].transpose(0,1)[7]

tensor([0.4852, 0.6873, 0.7841, 0.8713, 0.2902, 0.7676, 0.0334, 0.0519, 0.3088,
        0.8429, 0.1501, 0.0591, 0.9100, 0.6772, 0.8635, 0.9863])

In [158]:
r[2]

tensor([[0.4732, 0.0541, 0.8589, 0.2897, 0.1120, 0.5677, 0.7250, 0.4852, 0.3244,
         0.0657, 0.6186],
        [0.5070, 0.3726, 0.9182, 0.7436, 0.3511, 0.2861, 0.5697, 0.6873, 0.3969,
         0.9317, 0.8615],
        [0.9636, 0.5885, 0.7955, 0.8832, 0.2987, 0.6666, 0.0843, 0.7841, 0.2294,
         0.0740, 0.4479],
        [0.2490, 0.2594, 0.6460, 0.0783, 0.9437, 0.7117, 0.3293, 0.8713, 0.4490,
         0.7669, 0.6650],
        [0.6362, 0.2256, 0.5302, 0.9728, 0.6515, 0.0619, 0.2034, 0.2902, 0.8987,
         0.6412, 0.5201],
        [0.7204, 0.4928, 0.9359, 0.6027, 0.6712, 0.2136, 0.0554, 0.7676, 0.6041,
         0.8705, 0.3043],
        [0.3892, 0.6285, 0.3813, 0.3496, 0.9412, 0.4949, 0.0850, 0.0334, 0.6270,
         0.9636, 0.7937],
        [0.8418, 0.8817, 0.8408, 0.0415, 0.3436, 0.7089, 0.1910, 0.0519, 0.5899,
         0.3183, 0.8764],
        [0.6120, 0.3214, 0.3186, 0.1770, 0.1759, 0.0188, 0.9392, 0.3088, 0.3144,
         0.2818, 0.4104],
        [0.4419, 0.0804, 0.5903, 0.82

In [159]:
r

tensor([[[0.9104, 0.3321, 0.1306,  ..., 0.6653, 0.6279, 0.8036],
         [0.8869, 0.3657, 0.0304,  ..., 0.4979, 0.1830, 0.4993],
         [0.6165, 0.2715, 0.8923,  ..., 0.6384, 0.1819, 0.5126],
         ...,
         [0.2963, 0.0012, 0.9134,  ..., 0.9108, 0.8808, 0.3757],
         [0.0814, 0.1206, 0.6598,  ..., 0.7502, 0.9866, 0.4247],
         [0.6967, 0.8306, 0.0088,  ..., 0.7491, 0.5935, 0.8041]],

        [[0.2160, 0.3457, 0.9077,  ..., 0.9212, 0.5426, 0.7474],
         [0.8171, 0.7326, 0.5667,  ..., 0.3648, 0.7438, 0.2154],
         [0.5241, 0.8059, 0.3340,  ..., 0.0410, 0.5450, 0.7202],
         ...,
         [0.5238, 0.0643, 0.9787,  ..., 0.8499, 0.6198, 0.0088],
         [0.7577, 0.6400, 0.7444,  ..., 0.6953, 0.2089, 0.6021],
         [0.7468, 0.4580, 0.5644,  ..., 0.5501, 0.0760, 0.1754]],

        [[0.4732, 0.0541, 0.8589,  ..., 0.3244, 0.0657, 0.6186],
         [0.5070, 0.3726, 0.9182,  ..., 0.3969, 0.9317, 0.8615],
         [0.9636, 0.5885, 0.7955,  ..., 0.2294, 0.0740, 0.

In [160]:
a

tensor([[[ 2, 11,  8, 10,  3,  4,  5],
         [ 4, 10,  1,  5,  8,  0,  7]],

        [[ 1,  5,  8,  0, 10, 11,  7],
         [ 4,  6,  1,  9,  5,  0,  3]],

        [[ 0, 10,  7,  6,  1,  3,  8],
         [ 6,  1,  8,  4, 10,  9,  5]],

        [[ 1,  5,  4,  3,  8, 11,  7],
         [ 8,  2,  9, 10,  5,  4,  7]],

        [[ 7, 11,  3, 10,  6,  4,  8],
         [ 1,  6,  7,  9,  8,  5,  0]],

        [[ 9, 10, 11,  0,  1,  2,  3],
         [ 4,  4,  4,  5,  5,  5,  5]],

        [[ 0, 11,  2,  5,  9,  4,  3],
         [ 0,  3,  7, 10,  9,  6,  2]],

        [[ 8,  0,  3, 10,  6, 11,  7],
         [ 9,  0,  5,  4, 10,  2,  3]],

        [[10,  1, 11,  4,  3,  5,  7],
         [ 1,  4,  8,  0,  3,  5,  2]],

        [[ 5,  1,  2,  3,  6, 11,  9],
         [ 0,  8, 10,  1,  7,  6,  3]],

        [[10,  1, 11,  4,  9,  3,  8],
         [ 4,  3,  2,  1,  8,  7,  5]],

        [[ 3,  7,  4,  1,  5,  6,  0],
         [ 4,  5,  8,  6,  0,  7,  3]],

        [[ 7,  9, 10,  1,  5,  2,  4],
 

In [165]:
a = a.transpose(1,2)

In [166]:
a

tensor([[[ 2,  4],
         [11, 10],
         [ 8,  1],
         [10,  5],
         [ 3,  8],
         [ 4,  0],
         [ 5,  7]],

        [[ 1,  4],
         [ 5,  6],
         [ 8,  1],
         [ 0,  9],
         [10,  5],
         [11,  0],
         [ 7,  3]],

        [[ 0,  6],
         [10,  1],
         [ 7,  8],
         [ 6,  4],
         [ 1, 10],
         [ 3,  9],
         [ 8,  5]],

        [[ 1,  8],
         [ 5,  2],
         [ 4,  9],
         [ 3, 10],
         [ 8,  5],
         [11,  4],
         [ 7,  7]],

        [[ 7,  1],
         [11,  6],
         [ 3,  7],
         [10,  9],
         [ 6,  8],
         [ 4,  5],
         [ 8,  0]],

        [[ 9,  4],
         [10,  4],
         [11,  4],
         [ 0,  5],
         [ 1,  5],
         [ 2,  5],
         [ 3,  5]],

        [[ 0,  0],
         [11,  3],
         [ 2,  7],
         [ 5, 10],
         [ 9,  9],
         [ 4,  6],
         [ 3,  2]],

        [[ 8,  9],
         [ 0,  0],
         [ 3,  5]

In [170]:
r.shape

torch.Size([12, 16, 11])

In [171]:
r = r.transpose(1,2)

In [173]:
r.shape

torch.Size([12, 11, 16])

In [186]:
q = []
for act in a:
    batch = []
    for sub_act in act:
        batch.append(r[sub_act[0]][sub_act[1]])
    q.append(torch.stack(batch))
torch.stack(q).shape

torch.Size([16, 7, 16])

In [188]:
q[0][2]

tensor([0.8126, 0.4204, 0.9744, 0.4650, 0.3760, 0.8659, 0.1643, 0.4925, 0.5820,
        0.6886, 0.9918, 0.9626, 0.0045, 0.0376, 0.3990, 0.2539])

In [191]:
r[2]

tensor([[0.4732, 0.5070, 0.9636, 0.2490, 0.6362, 0.7204, 0.3892, 0.8418, 0.6120,
         0.4419, 0.5984, 0.1233, 0.2805, 0.3527, 0.4064, 0.0729],
        [0.0541, 0.3726, 0.5885, 0.2594, 0.2256, 0.4928, 0.6285, 0.8817, 0.3214,
         0.0804, 0.2417, 0.2847, 0.5816, 0.0747, 0.6608, 0.4212],
        [0.8589, 0.9182, 0.7955, 0.6460, 0.5302, 0.9359, 0.3813, 0.8408, 0.3186,
         0.5903, 0.0779, 0.4079, 0.4995, 0.8152, 0.2005, 0.7901],
        [0.2897, 0.7436, 0.8832, 0.0783, 0.9728, 0.6027, 0.3496, 0.0415, 0.1770,
         0.8253, 0.6632, 0.2593, 0.8494, 0.3972, 0.7324, 0.7599],
        [0.1120, 0.3511, 0.2987, 0.9437, 0.6515, 0.6712, 0.9412, 0.3436, 0.1759,
         0.0791, 0.5274, 0.3213, 0.8699, 0.8899, 0.3556, 0.2853],
        [0.5677, 0.2861, 0.6666, 0.7117, 0.0619, 0.2136, 0.4949, 0.7089, 0.0188,
         0.1645, 0.7114, 0.1970, 0.7854, 0.2025, 0.7806, 0.8581],
        [0.7250, 0.5697, 0.0843, 0.3293, 0.2034, 0.0554, 0.0850, 0.1910, 0.9392,
         0.7560, 0.6709, 0.1958, 0.24

In [198]:
q[2]

tensor([[0.1284, 0.4609, 0.6474, 0.1225, 0.7377, 0.7278, 0.6058, 0.1690, 0.9432,
         0.1028, 0.1685, 0.1053, 0.1504, 0.7632, 0.9397, 0.9874],
        [0.1142, 0.4097, 0.0415, 0.0466, 0.4397, 0.4605, 0.6003, 0.3397, 0.7229,
         0.0350, 0.3080, 0.3023, 0.6102, 0.4022, 0.7245, 0.6958],
        [0.8432, 0.9468, 0.5246, 0.0426, 0.2209, 0.9575, 0.4756, 0.2982, 0.8288,
         0.4203, 0.2345, 0.0801, 0.6127, 0.2470, 0.2947, 0.7514],
        [0.3764, 0.4112, 0.0183, 0.0548, 0.0918, 0.3340, 0.6501, 0.0254, 0.7362,
         0.1834, 0.0801, 0.4031, 0.4761, 0.8685, 0.9882, 0.4145],
        [0.7474, 0.2154, 0.7202, 0.3769, 0.0413, 0.0203, 0.8938, 0.5025, 0.4482,
         0.7700, 0.7667, 0.3090, 0.0185, 0.0088, 0.6021, 0.1754],
        [0.3486, 0.0917, 0.7565, 0.7979, 0.3406, 0.5527, 0.6249, 0.2442, 0.4527,
         0.9543, 0.6371, 0.0965, 0.8205, 0.0014, 0.2247, 0.5695],
        [0.4242, 0.5441, 0.8830, 0.0074, 0.8514, 0.8322, 0.9511, 0.3351, 0.2896,
         0.6086, 0.1001, 0.6268, 0.46

In [195]:
a

tensor([[[ 2,  4],
         [11, 10],
         [ 8,  1],
         [10,  5],
         [ 3,  8],
         [ 4,  0],
         [ 5,  7]],

        [[ 1,  4],
         [ 5,  6],
         [ 8,  1],
         [ 0,  9],
         [10,  5],
         [11,  0],
         [ 7,  3]],

        [[ 0,  6],
         [10,  1],
         [ 7,  8],
         [ 6,  4],
         [ 1, 10],
         [ 3,  9],
         [ 8,  5]],

        [[ 1,  8],
         [ 5,  2],
         [ 4,  9],
         [ 3, 10],
         [ 8,  5],
         [11,  4],
         [ 7,  7]],

        [[ 7,  1],
         [11,  6],
         [ 3,  7],
         [10,  9],
         [ 6,  8],
         [ 4,  5],
         [ 8,  0]],

        [[ 9,  4],
         [10,  4],
         [11,  4],
         [ 0,  5],
         [ 1,  5],
         [ 2,  5],
         [ 3,  5]],

        [[ 0,  0],
         [11,  3],
         [ 2,  7],
         [ 5, 10],
         [ 9,  9],
         [ 4,  6],
         [ 3,  2]],

        [[ 8,  9],
         [ 0,  0],
         [ 3,  5]

In [199]:
r[0][6]

tensor([0.1284, 0.4609, 0.6474, 0.1225, 0.7377, 0.7278, 0.6058, 0.1690, 0.9432,
        0.1028, 0.1685, 0.1053, 0.1504, 0.7632, 0.9397, 0.9874])

In [204]:
q = torch.stack(q)

In [205]:
q 

tensor([[[0.1120, 0.3511, 0.2987,  ..., 0.8899, 0.3556, 0.2853],
         [0.0671, 0.3123, 0.4280,  ..., 0.8033, 0.1922, 0.2846],
         [0.8126, 0.4204, 0.9744,  ..., 0.0376, 0.3990, 0.2539],
         ...,
         [0.8382, 0.1435, 0.4875,  ..., 0.9609, 0.0997, 0.2042],
         [0.4452, 0.7048, 0.0387,  ..., 0.2980, 0.6012, 0.6348],
         [0.2810, 0.8880, 0.5160,  ..., 0.6879, 0.0503, 0.8594]],

        [[0.0492, 0.4358, 0.4979,  ..., 0.7077, 0.5163, 0.7262],
         [0.3250, 0.5925, 0.6502,  ..., 0.0340, 0.1978, 0.5048],
         [0.8126, 0.4204, 0.9744,  ..., 0.0376, 0.3990, 0.2539],
         ...,
         [0.3723, 0.8729, 0.9098,  ..., 0.0397, 0.0191, 0.9223],
         [0.3865, 0.7580, 0.3982,  ..., 0.5463, 0.5886, 0.2097],
         [0.9739, 0.3567, 0.9348,  ..., 0.2981, 0.7180, 0.9402]],

        [[0.1284, 0.4609, 0.6474,  ..., 0.7632, 0.9397, 0.9874],
         [0.1142, 0.4097, 0.0415,  ..., 0.4022, 0.7245, 0.6958],
         [0.8432, 0.9468, 0.5246,  ..., 0.2470, 0.2947, 0.

In [207]:
r = r.transpose(1,2)

In [217]:
r

tensor([[[0.9104, 0.3321, 0.1306,  ..., 0.6653, 0.6279, 0.8036],
         [0.8869, 0.3657, 0.0304,  ..., 0.4979, 0.1830, 0.4993],
         [0.6165, 0.2715, 0.8923,  ..., 0.6384, 0.1819, 0.5126],
         ...,
         [0.2963, 0.0012, 0.9134,  ..., 0.9108, 0.8808, 0.3757],
         [0.0814, 0.1206, 0.6598,  ..., 0.7502, 0.9866, 0.4247],
         [0.6967, 0.8306, 0.0088,  ..., 0.7491, 0.5935, 0.8041]],

        [[0.2160, 0.3457, 0.9077,  ..., 0.9212, 0.5426, 0.7474],
         [0.8171, 0.7326, 0.5667,  ..., 0.3648, 0.7438, 0.2154],
         [0.5241, 0.8059, 0.3340,  ..., 0.0410, 0.5450, 0.7202],
         ...,
         [0.5238, 0.0643, 0.9787,  ..., 0.8499, 0.6198, 0.0088],
         [0.7577, 0.6400, 0.7444,  ..., 0.6953, 0.2089, 0.6021],
         [0.7468, 0.4580, 0.5644,  ..., 0.5501, 0.0760, 0.1754]],

        [[0.4732, 0.0541, 0.8589,  ..., 0.3244, 0.0657, 0.6186],
         [0.5070, 0.3726, 0.9182,  ..., 0.3969, 0.9317, 0.8615],
         [0.9636, 0.5885, 0.7955,  ..., 0.2294, 0.0740, 0.

In [225]:
r.max(2).values.shape

torch.Size([12, 16])

In [222]:
r[0]

tensor([[0.9104, 0.3321, 0.1306, 0.1701, 0.0246, 0.2860, 0.1284, 0.9790, 0.6653,
         0.6279, 0.8036],
        [0.8869, 0.3657, 0.0304, 0.3983, 0.1003, 0.3791, 0.4609, 0.6941, 0.4979,
         0.1830, 0.4993],
        [0.6165, 0.2715, 0.8923, 0.9748, 0.0347, 0.0860, 0.6474, 0.8152, 0.6384,
         0.1819, 0.5126],
        [0.5080, 0.0428, 0.8240, 0.4820, 0.9513, 0.5244, 0.1225, 0.8729, 0.7272,
         0.8866, 0.9049],
        [0.8807, 0.8260, 0.1791, 0.1689, 0.6944, 0.1191, 0.7377, 0.9903, 0.5398,
         0.0991, 0.3536],
        [0.5095, 0.7136, 0.1346, 0.9863, 0.4646, 0.0387, 0.7278, 0.3336, 0.8501,
         0.8577, 0.1676],
        [0.5220, 0.3972, 0.2466, 0.6627, 0.2329, 0.1582, 0.6058, 0.8346, 0.5003,
         0.5355, 0.6593],
        [0.7413, 0.6915, 0.5133, 0.1083, 0.1208, 0.0350, 0.1690, 0.8238, 0.2513,
         0.9098, 0.6602],
        [0.7625, 0.0722, 0.0677, 0.8797, 0.5389, 0.5140, 0.9432, 0.4410, 0.8175,
         0.9784, 0.2953],
        [0.0304, 0.9568, 0.7030, 0.98

In [229]:
r.max(2).values[0]

tensor([0.9790, 0.8869, 0.9748, 0.9513, 0.9903, 0.9863, 0.8346, 0.9098, 0.9784,
        0.9898, 0.9880, 0.9895, 0.8012, 0.9134, 0.9866, 0.9874])

In [230]:
q.shape

torch.Size([16, 7, 16])

In [233]:
nextq = r.max(2).values

In [234]:
nextq.shape

torch.Size([12, 16])

In [254]:
r = torch.rand(12,16,11)

In [255]:
r_t = r.transpose(0,1)

In [258]:
r[0].shape

torch.Size([16, 11])

In [259]:
r_t[0].shape

torch.Size([12, 11])

In [276]:
m = r_t.max(2).values

In [277]:
m.shape

torch.Size([16, 12])

In [278]:
curQ = torch.rand(12,16,11)

In [279]:
nextQ = torch.rand(12,16,11)

In [282]:
curQT = curQ.transpose(0,1)
nextQT = nextQ.transpose(0,1)

In [305]:
a

tensor([[[ 2,  4],
         [11, 10],
         [ 8,  1],
         [10,  5],
         [ 3,  8],
         [ 4,  0],
         [ 5,  7]],

        [[ 1,  4],
         [ 5,  6],
         [ 8,  1],
         [ 0,  9],
         [10,  5],
         [11,  0],
         [ 7,  3]],

        [[ 0,  6],
         [10,  1],
         [ 7,  8],
         [ 6,  4],
         [ 1, 10],
         [ 3,  9],
         [ 8,  5]],

        [[ 1,  8],
         [ 5,  2],
         [ 4,  9],
         [ 3, 10],
         [ 8,  5],
         [11,  4],
         [ 7,  7]],

        [[ 7,  1],
         [11,  6],
         [ 3,  7],
         [10,  9],
         [ 6,  8],
         [ 4,  5],
         [ 8,  0]],

        [[ 9,  4],
         [10,  4],
         [11,  4],
         [ 0,  5],
         [ 1,  5],
         [ 2,  5],
         [ 3,  5]],

        [[ 0,  0],
         [11,  3],
         [ 2,  7],
         [ 5, 10],
         [ 9,  9],
         [ 4,  6],
         [ 3,  2]],

        [[ 8,  9],
         [ 0,  0],
         [ 3,  5]

In [327]:
for idx, group in enumerate(curQT):
    new_curQT = []
    for sub_act in a[idx]:
        new_curQT.append(group[sub_act[0]][sub_act[1]])
    print(idx, new_curQT)
    

tensor([2, 4])
tensor([11, 10])
tensor([8, 1])
tensor([10,  5])
tensor([3, 8])
tensor([4, 0])
tensor([5, 7])
0 [tensor(0.7771), tensor(0.5944), tensor(0.4784), tensor(0.7221), tensor(0.8025), tensor(0.3783), tensor(0.8618)]
tensor([1, 4])
tensor([5, 6])
tensor([8, 1])
tensor([0, 9])
tensor([10,  5])
tensor([11,  0])
tensor([7, 3])
1 [tensor(0.8534), tensor(0.9557), tensor(0.2724), tensor(0.2299), tensor(0.9522), tensor(0.4979), tensor(0.2348)]
tensor([0, 6])
tensor([10,  1])
tensor([7, 8])
tensor([6, 4])
tensor([ 1, 10])
tensor([3, 9])
tensor([8, 5])
2 [tensor(0.5527), tensor(0.3458), tensor(0.1226), tensor(0.6754), tensor(0.6724), tensor(0.6883), tensor(0.8497)]
tensor([1, 8])
tensor([5, 2])
tensor([4, 9])
tensor([ 3, 10])
tensor([8, 5])
tensor([11,  4])
tensor([7, 7])
3 [tensor(0.8964), tensor(0.8530), tensor(0.2013), tensor(0.3569), tensor(0.5445), tensor(0.3820), tensor(0.1739)]
tensor([7, 1])
tensor([11,  6])
tensor([3, 7])
tensor([10,  9])
tensor([6, 8])
tensor([4, 5])
tensor([8,