In [1]:
import numpy as np
import torch
from torch import nn
import torch.autograd.functional as F
STATELEN = 10
ACTLEN = 10
STEP_SIZE = 4

In [2]:
class Dynamics(nn.Module):

    def __init__(self, input_size, hidden_size, output_size):
        super(Dynamics, self).__init__()
        # self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, output_size),
            nn.Tanh()
        )
    def forward(self, input_element):
        output = self.linear_relu_stack(input_element)
        return output

In [3]:
class reward(nn.Module):

    def __init__(self, input_size, hidden_size, output_size = 1):
        super(reward, self).__init__()
        # self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, output_size),
            nn.Tanh()
        )

    def forward(self, input_element):
        output = self.linear_relu_stack(input_element)
        return output

In [4]:
def forward_step(state_col, actions, K_arr, k_arr, time_step):
    state = state_col[0]
    state_collection = torch.zeros((time_step, 1, STATELEN))
    reward_collection = torch.zeros((time_step, 1, 1))
    act_collection = torch.zeros((time_step, 1, ACTLEN))
    
    i = 0
    while i < time_step:
        state_collection[i] = state
        actions[i] = (torch.matmul(state_col[i] - state,torch.transpose((K_arr[i]),0,1)) + 
                      k_arr[i] + actions[i]
                     )
        
        act_collection[i] = actions[i]
        sa_input = torch.cat((state,actions[i]),dim = 1)
        #sa_input shape = [1,state_size + action_size]
        state = my_Dyna(sa_input)
        #state shape = [1,state_size]
        reward = my_reward(sa_input)
        reward_collection[i] = reward
        i = i + 1
    
    return state_collection, act_collection, reward_collection

def backward(state_col,action_col,reward_col,step_size,ifconv):

    i = step_size -1
    K_arr = torch.zeros(step_size,ACTLEN,STATELEN)
    k_arr = torch.zeros(step_size,1,ACTLEN)
    
    while i > -1:
        temp_input = torch.cat((state_col[i],act_col[i]),dim = 1)
        C_t = F.hessian(my_reward, temp_input.view(-1))
        #shape = [state+action, state+action]
        #print(torch.sum(C_t))
        F_t = F.jacobian(my_Dyna, temp_input.view(-1))
        #shape = [state, state+action]
        #print(torch.sum(F_t))
        c_t = F.jacobian(my_reward, temp_input.view(-1))
        #shape = [1, state+action]
        #print(torch.sum(c_t))
        
        if i == step_size-1:
            Q_t = C_t
            q_t = c_t
        else:
            Q_t = C_t + torch.matmul(torch.matmul(torch.transpose(F_t,0,1),V_t),F_t)
            q_t = c_t + torch.matmul(v_t,F_t)
            
        Q_pre1 = torch.split(Q_t,[STATELEN,ACTLEN])[0]
        Q_pre2 = torch.split(Q_t,[STATELEN,ACTLEN])[1]
        Q_xx = torch.split(Q_pre1,[STATELEN,ACTLEN],dim = 1)[0]
        Q_xu = torch.split(Q_pre1,[STATELEN,ACTLEN],dim = 1)[1]
        Q_uu = torch.split(Q_pre2,[STATELEN,ACTLEN],dim = 1)[1]
        
        try:
            invQuu = torch.linalg.inv(Q_uu - torch.eye(ACTLEN)) #regularize term
        except:
            invQuu = torch.linalg.inv(Q_uu + torch.eye(ACTLEN)*0.01)
            ifconv = 1
            
        K_t = -torch.matmul(invQuu, torch.transpose(Q_xu,0,1))
        transK_t = torch.transpose(K_t,0,1)
        #K_t shape = [actlen, statelen]
        q_t = torch.split(q_t, [STATELEN, ACTLEN], dim = 1)
        k_t = -torch.matmul(q_t[1],invQuu)
        #k_t shape = [1,actlen]
        
        V_t = (Q_xx + torch.matmul(Q_xu,K_t)*2 + 
               torch.matmul(torch.matmul(transK_t,Q_uu),K_t)
              )
        #V_t shape = [statelen, statelen]

        v_t = (q_t[0] + torch.matmul(k_t,torch.transpose(Q_xu,0,1)) + 
               torch.matmul(q_t[1],K_t) + 
               torch.matmul(k_t, torch.matmul(Q_uu,K_t)) 
              )
        #v_t shape = [1, statelen]
        K_arr[i] = K_t
        k_arr[i] = k_t
        i = i - 1
        
    return K_arr, k_arr, ifconv

In [5]:
my_Dyna = Dynamics(STATELEN + ACTLEN, STATELEN, STATELEN)
my_reward = reward(STATELEN + ACTLEN, STATELEN , 1)

In [6]:
state_col = torch.zeros((STEP_SIZE, 1, STATELEN))
state_col[0] = torch.rand((1, STATELEN))
act_col = torch.rand((STEP_SIZE, 1, ACTLEN))
K_arr = torch.zeros(STEP_SIZE, ACTLEN, STATELEN)
k_arr = torch.zeros(STEP_SIZE, 1, ACTLEN)
ifconv = 0

i = 0
while i < 1000:
    print(i)
    state_col, act_col ,reward_col = forward_step(state_col, act_col, K_arr, k_arr, STEP_SIZE)
    K_arr, k_arr, ifconv = backward(state_col, act_col, reward_col, STEP_SIZE, ifconv)
    print(act_col[0])
    if ifconv == 1:
        break
    
    i = i + 1


0
tensor([[0.0623, 0.8762, 0.3597, 0.4830, 0.6265, 0.7074, 0.1983, 0.2911, 0.4916,
         0.7870]], grad_fn=<SelectBackward0>)
1
tensor([[0.1125, 0.8789, 0.3399, 0.4532, 0.6035, 0.7323, 0.1395, 0.3091, 0.5292,
         0.8205]], grad_fn=<SelectBackward0>)
2
tensor([[0.1619, 0.8806, 0.3199, 0.4239, 0.5800, 0.7568, 0.0817, 0.3261, 0.5664,
         0.8539]], grad_fn=<SelectBackward0>)
3
tensor([[0.2105, 0.8813, 0.2998, 0.3950, 0.5561, 0.7810, 0.0250, 0.3421, 0.6030,
         0.8872]], grad_fn=<SelectBackward0>)
4
tensor([[ 0.2582,  0.8810,  0.2796,  0.3667,  0.5319,  0.8049, -0.0306,  0.3570,
          0.6390,  0.9203]], grad_fn=<SelectBackward0>)
5
tensor([[ 0.3050,  0.8798,  0.2594,  0.3390,  0.5073,  0.8284, -0.0850,  0.3708,
          0.6743,  0.9532]], grad_fn=<SelectBackward0>)
6
tensor([[ 0.3503,  0.8773,  0.2390,  0.3120,  0.4825,  0.8510, -0.1380,  0.3837,
          0.7089,  0.9853]], grad_fn=<SelectBackward0>)
7
tensor([[ 0.3946,  0.8739,  0.2186,  0.2856,  0.4575,  0.8732, -0

tensor([[ 1.7399,  0.2665, -0.7065, -0.3776, -0.6973,  1.6064, -1.7270,  0.3956,
          1.8312,  1.9456]], grad_fn=<SelectBackward0>)
61
tensor([[ 1.7542,  0.2531, -0.7202, -0.3824, -0.7129,  1.6166, -1.7427,  0.3933,
          1.8415,  1.9540]], grad_fn=<SelectBackward0>)
62
tensor([[ 1.7682,  0.2398, -0.7339, -0.3870, -0.7284,  1.6267, -1.7581,  0.3910,
          1.8515,  1.9621]], grad_fn=<SelectBackward0>)
63
tensor([[ 1.7821,  0.2264, -0.7475, -0.3914, -0.7437,  1.6367, -1.7732,  0.3888,
          1.8614,  1.9701]], grad_fn=<SelectBackward0>)
64
tensor([[ 1.7957,  0.2131, -0.7611, -0.3958, -0.7588,  1.6468, -1.7881,  0.3866,
          1.8711,  1.9779]], grad_fn=<SelectBackward0>)
65
tensor([[ 1.8083,  0.1997, -0.7753, -0.3999, -0.7744,  1.6569, -1.8028,  0.3835,
          1.8806,  1.9856]], grad_fn=<SelectBackward0>)
66
tensor([[ 1.8206,  0.1863, -0.7895, -0.4040, -0.7898,  1.6671, -1.8174,  0.3805,
          1.8900,  1.9930]], grad_fn=<SelectBackward0>)
67
tensor([[ 1.8328,  0

tensor([[ 2.2581, -0.4807, -1.4302, -0.5105, -1.3832,  2.1617, -2.3792,  0.2874,
          2.1827,  2.2431]], grad_fn=<SelectBackward0>)
120
tensor([[ 2.2635, -0.4924, -1.4406, -0.5113, -1.3917,  2.1705, -2.3872,  0.2868,
          2.1858,  2.2460]], grad_fn=<SelectBackward0>)
121
tensor([[ 2.2690, -0.5042, -1.4510, -0.5122, -1.4001,  2.1792, -2.3952,  0.2862,
          2.1888,  2.2489]], grad_fn=<SelectBackward0>)
122
tensor([[ 2.2743, -0.5158, -1.4613, -0.5130, -1.4084,  2.1879, -2.4031,  0.2857,
          2.1918,  2.2517]], grad_fn=<SelectBackward0>)
123
tensor([[ 2.2796, -0.5275, -1.4715, -0.5138, -1.4166,  2.1966, -2.4109,  0.2852,
          2.1948,  2.2544]], grad_fn=<SelectBackward0>)
124
tensor([[ 2.2848, -0.5391, -1.4817, -0.5146, -1.4248,  2.2053, -2.4187,  0.2847,
          2.1977,  2.2572]], grad_fn=<SelectBackward0>)
125
tensor([[ 2.2899, -0.5507, -1.4919, -0.5154, -1.4330,  2.2139, -2.4264,  0.2843,
          2.2005,  2.2599]], grad_fn=<SelectBackward0>)
126
tensor([[ 2.2

tensor([[ 2.4882, -1.1421, -1.9549, -0.5395, -1.7992,  2.6884, -2.7977,  0.3188,
          2.2857,  2.3731]], grad_fn=<SelectBackward0>)
182
tensor([[ 2.4905, -1.1519, -1.9619, -0.5400, -1.8050,  2.6963, -2.8033,  0.3203,
          2.2867,  2.3743]], grad_fn=<SelectBackward0>)
183
tensor([[ 2.4928, -1.1617, -1.9687, -0.5405, -1.8108,  2.7041, -2.8089,  0.3218,
          2.2878,  2.3756]], grad_fn=<SelectBackward0>)
184
tensor([[ 2.4951, -1.1715, -1.9756, -0.5410, -1.8166,  2.7120, -2.8145,  0.3234,
          2.2888,  2.3768]], grad_fn=<SelectBackward0>)
185
tensor([[ 2.4974, -1.1812, -1.9824, -0.5415, -1.8224,  2.7198, -2.8200,  0.3249,
          2.2898,  2.3780]], grad_fn=<SelectBackward0>)
186
tensor([[ 2.4996, -1.1910, -1.9891, -0.5420, -1.8281,  2.7276, -2.8255,  0.3266,
          2.2909,  2.3791]], grad_fn=<SelectBackward0>)
187
tensor([[ 2.5018, -1.2007, -1.9959, -0.5425, -1.8339,  2.7354, -2.8310,  0.3282,
          2.2919,  2.3803]], grad_fn=<SelectBackward0>)
188
tensor([[ 2.5

tensor([[ 2.5794, -1.7076, -2.3089, -0.5844, -2.1380,  3.1322, -3.0925,  0.4563,
          2.3502,  2.4066]], grad_fn=<SelectBackward0>)
243
tensor([[ 2.5802, -1.7164, -2.3136, -0.5854, -2.1434,  3.1389, -3.0966,  0.4592,
          2.3513,  2.4065]], grad_fn=<SelectBackward0>)
244
tensor([[ 2.5809, -1.7252, -2.3184, -0.5864, -2.1488,  3.1456, -3.1007,  0.4622,
          2.3525,  2.4063]], grad_fn=<SelectBackward0>)
245
tensor([[ 2.5816, -1.7339, -2.3231, -0.5874, -2.1542,  3.1523, -3.1048,  0.4652,
          2.3536,  2.4061]], grad_fn=<SelectBackward0>)
246
tensor([[ 2.5822, -1.7427, -2.3278, -0.5884, -2.1596,  3.1590, -3.1088,  0.4682,
          2.3548,  2.4059]], grad_fn=<SelectBackward0>)
247
tensor([[ 2.5829, -1.7514, -2.3324, -0.5894, -2.1650,  3.1656, -3.1129,  0.4712,
          2.3560,  2.4057]], grad_fn=<SelectBackward0>)
248
tensor([[ 2.5835, -1.7602, -2.3370, -0.5904, -2.1704,  3.1723, -3.1169,  0.4742,
          2.3572,  2.4054]], grad_fn=<SelectBackward0>)
249
tensor([[ 2.5

tensor([[ 2.6064, -2.2662, -2.5820, -0.6604, -2.4948,  3.4902, -3.2887,  0.6981,
          2.4677,  2.3235]], grad_fn=<SelectBackward0>)
305
tensor([[ 2.6063, -2.2750, -2.5859, -0.6617, -2.5006,  3.4955, -3.2914,  0.7025,
          2.4699,  2.3215]], grad_fn=<SelectBackward0>)
306
tensor([[ 2.6063, -2.2839, -2.5897, -0.6631, -2.5063,  3.5008, -3.2940,  0.7070,
          2.4720,  2.3195]], grad_fn=<SelectBackward0>)
307
tensor([[ 2.6062, -2.2927, -2.5935, -0.6644, -2.5121,  3.5060, -3.2967,  0.7114,
          2.4741,  2.3175]], grad_fn=<SelectBackward0>)
308
tensor([[ 2.6061, -2.3015, -2.5974, -0.6658, -2.5178,  3.5113, -3.2993,  0.7159,
          2.4763,  2.3155]], grad_fn=<SelectBackward0>)
309
tensor([[ 2.6060, -2.3103, -2.6011, -0.6672, -2.5235,  3.5166, -3.3019,  0.7204,
          2.4784,  2.3134]], grad_fn=<SelectBackward0>)
310
tensor([[ 2.6059, -2.3192, -2.6049, -0.6685, -2.5293,  3.5218, -3.3045,  0.7249,
          2.4805,  2.3114]], grad_fn=<SelectBackward0>)
311
tensor([[ 2.6

tensor([[ 2.5841, -2.7783, -2.7897, -0.7433, -2.8318,  3.7935, -3.4362,  0.9765,
          2.5954,  2.1880]], grad_fn=<SelectBackward0>)
365
tensor([[ 2.5835, -2.7864, -2.7929, -0.7447, -2.8372,  3.7984, -3.4385,  0.9813,
          2.5974,  2.1856]], grad_fn=<SelectBackward0>)
366
tensor([[ 2.5828, -2.7946, -2.7961, -0.7461, -2.8427,  3.8032, -3.4409,  0.9860,
          2.5995,  2.1831]], grad_fn=<SelectBackward0>)
367
tensor([[ 2.5822, -2.8027, -2.7992, -0.7474, -2.8481,  3.8080, -3.4432,  0.9907,
          2.6016,  2.1807]], grad_fn=<SelectBackward0>)
368
tensor([[ 2.5815, -2.8108, -2.8023, -0.7488, -2.8535,  3.8128, -3.4456,  0.9955,
          2.6037,  2.1783]], grad_fn=<SelectBackward0>)
369
tensor([[ 2.5809, -2.8189, -2.8055, -0.7502, -2.8590,  3.8175, -3.4479,  1.0002,
          2.6057,  2.1758]], grad_fn=<SelectBackward0>)
370
tensor([[ 2.5802, -2.8270, -2.8086, -0.7516, -2.8644,  3.8223, -3.4503,  1.0050,
          2.6078,  2.1734]], grad_fn=<SelectBackward0>)
371
tensor([[ 2.5

tensor([[ 2.5394, -3.2482, -2.9735, -0.8255, -3.1514,  4.0733, -3.5803,  1.2613,
          2.7150,  2.0394]], grad_fn=<SelectBackward0>)
426
tensor([[ 2.5386, -3.2554, -2.9764, -0.8268, -3.1564,  4.0777, -3.5828,  1.2658,
          2.7169,  2.0370]], grad_fn=<SelectBackward0>)
427
tensor([[ 2.5378, -3.2626, -2.9793, -0.8281, -3.1614,  4.0820, -3.5852,  1.2703,
          2.7187,  2.0346]], grad_fn=<SelectBackward0>)
428
tensor([[ 2.5371, -3.2698, -2.9823, -0.8294, -3.1664,  4.0864, -3.5876,  1.2748,
          2.7205,  2.0322]], grad_fn=<SelectBackward0>)
429
tensor([[ 2.5363, -3.2769, -2.9852, -0.8307, -3.1714,  4.0907, -3.5900,  1.2793,
          2.7223,  2.0299]], grad_fn=<SelectBackward0>)
430
tensor([[ 2.5355, -3.2841, -2.9881, -0.8319, -3.1763,  4.0950, -3.5925,  1.2838,
          2.7241,  2.0275]], grad_fn=<SelectBackward0>)
431
tensor([[ 2.5347, -3.2912, -2.9910, -0.8332, -3.1813,  4.0994, -3.5949,  1.2883,
          2.7258,  2.0252]], grad_fn=<SelectBackward0>)
432
tensor([[ 2.5

tensor([[ 2.4922, -3.6563, -3.1505, -0.9035, -3.4421,  4.3360, -3.7440,  1.5232,
          2.8133,  1.9094]], grad_fn=<SelectBackward0>)
488
tensor([[ 2.4915, -3.6621, -3.1533, -0.9048, -3.4462,  4.3404, -3.7471,  1.5269,
          2.8145,  1.9080]], grad_fn=<SelectBackward0>)
489
tensor([[ 2.4907, -3.6678, -3.1560, -0.9061, -3.4504,  4.3447, -3.7502,  1.5306,
          2.8157,  1.9065]], grad_fn=<SelectBackward0>)
490
tensor([[ 2.4899, -3.6735, -3.1587, -0.9073, -3.4545,  4.3491, -3.7534,  1.5343,
          2.8169,  1.9050]], grad_fn=<SelectBackward0>)
491
tensor([[ 2.4891, -3.6791, -3.1614, -0.9086, -3.4587,  4.3534, -3.7565,  1.5380,
          2.8181,  1.9036]], grad_fn=<SelectBackward0>)
492
tensor([[ 2.4884, -3.6848, -3.1641, -0.9099, -3.4628,  4.3577, -3.7597,  1.5416,
          2.8193,  1.9021]], grad_fn=<SelectBackward0>)
493
tensor([[ 2.4876, -3.6905, -3.1668, -0.9112, -3.4669,  4.3620, -3.7628,  1.5453,
          2.8205,  1.9006]], grad_fn=<SelectBackward0>)
494
tensor([[ 2.4

tensor([[ 2.4480, -3.9854, -3.3151, -0.9785, -3.6903,  4.5927, -3.9373,  1.7396,
          2.8814,  1.8207]], grad_fn=<SelectBackward0>)
550
tensor([[ 2.4474, -3.9903, -3.3177, -0.9796, -3.6942,  4.5967, -3.9404,  1.7429,
          2.8823,  1.8194]], grad_fn=<SelectBackward0>)
551
tensor([[ 2.4467, -3.9952, -3.3203, -0.9807, -3.6980,  4.6006, -3.9435,  1.7462,
          2.8833,  1.8180]], grad_fn=<SelectBackward0>)
552
tensor([[ 2.4461, -4.0001, -3.3229, -0.9819, -3.7019,  4.6045, -3.9466,  1.7494,
          2.8843,  1.8166]], grad_fn=<SelectBackward0>)
553
tensor([[ 2.4455, -4.0049, -3.3255, -0.9830, -3.7057,  4.6084, -3.9497,  1.7527,
          2.8853,  1.8152]], grad_fn=<SelectBackward0>)
554
tensor([[ 2.4448, -4.0098, -3.3280, -0.9841, -3.7095,  4.6123, -3.9528,  1.7559,
          2.8863,  1.8139]], grad_fn=<SelectBackward0>)
555
tensor([[ 2.4442, -4.0146, -3.3306, -0.9852, -3.7134,  4.6162, -3.9559,  1.7592,
          2.8872,  1.8125]], grad_fn=<SelectBackward0>)
556
tensor([[ 2.4

tensor([[ 2.4133, -4.2631, -3.4702, -1.0435, -3.9175,  4.8212, -4.1245,  1.9279,
          2.9356,  1.7397]], grad_fn=<SelectBackward0>)
611
tensor([[ 2.4129, -4.2674, -3.4727, -1.0445, -3.9211,  4.8248, -4.1275,  1.9308,
          2.9364,  1.7384]], grad_fn=<SelectBackward0>)
612
tensor([[ 2.4124, -4.2716, -3.4751, -1.0455, -3.9247,  4.8283, -4.1305,  1.9336,
          2.9372,  1.7371]], grad_fn=<SelectBackward0>)
613
tensor([[ 2.4119, -4.2758, -3.4776, -1.0465, -3.9283,  4.8319, -4.1336,  1.9365,
          2.9380,  1.7359]], grad_fn=<SelectBackward0>)
614
tensor([[ 2.4114, -4.2800, -3.4801, -1.0475, -3.9319,  4.8354, -4.1366,  1.9394,
          2.9388,  1.7346]], grad_fn=<SelectBackward0>)
615
tensor([[ 2.4109, -4.2841, -3.4826, -1.0485, -3.9355,  4.8390, -4.1396,  1.9423,
          2.9396,  1.7333]], grad_fn=<SelectBackward0>)
616
tensor([[ 2.4104, -4.2883, -3.4851, -1.0495, -3.9391,  4.8425, -4.1426,  1.9452,
          2.9404,  1.7320]], grad_fn=<SelectBackward0>)
617
tensor([[ 2.4

tensor([[ 2.3879, -4.4961, -3.6134, -1.0990, -4.1238,  5.0229, -4.3014,  2.0896,
          2.9782,  1.6671]], grad_fn=<SelectBackward0>)
670
tensor([[ 2.3876, -4.4997, -3.6158, -1.0999, -4.1272,  5.0262, -4.3044,  2.0922,
          2.9789,  1.6659]], grad_fn=<SelectBackward0>)
671
tensor([[ 2.3872, -4.5034, -3.6182, -1.1007, -4.1306,  5.0295, -4.3073,  2.0948,
          2.9795,  1.6647]], grad_fn=<SelectBackward0>)
672
tensor([[ 2.3868, -4.5071, -3.6205, -1.1016, -4.1340,  5.0328, -4.3103,  2.0974,
          2.9801,  1.6635]], grad_fn=<SelectBackward0>)
673
tensor([[ 2.3865, -4.5107, -3.6229, -1.1025, -4.1374,  5.0360, -4.3132,  2.0999,
          2.9808,  1.6623]], grad_fn=<SelectBackward0>)
674
tensor([[ 2.3861, -4.5144, -3.6253, -1.1034, -4.1408,  5.0393, -4.3162,  2.1025,
          2.9814,  1.6611]], grad_fn=<SelectBackward0>)
675
tensor([[ 2.3858, -4.5180, -3.6276, -1.1042, -4.1441,  5.0425, -4.3192,  2.1051,
          2.9821,  1.6600]], grad_fn=<SelectBackward0>)
676
tensor([[ 2.3

tensor([[ 2.3688, -4.7103, -3.7563, -1.1501, -4.3287,  5.2179, -4.4820,  2.2416,
          3.0144,  1.5960]], grad_fn=<SelectBackward0>)
732
tensor([[ 2.3685, -4.7136, -3.7585, -1.1508, -4.3320,  5.2209, -4.4849,  2.2439,
          3.0150,  1.5949]], grad_fn=<SelectBackward0>)
733
tensor([[ 2.3683, -4.7168, -3.7608, -1.1516, -4.3352,  5.2239, -4.4877,  2.2462,
          3.0155,  1.5938]], grad_fn=<SelectBackward0>)
734
tensor([[ 2.3680, -4.7200, -3.7630, -1.1524, -4.3384,  5.2269, -4.4906,  2.2485,
          3.0160,  1.5927]], grad_fn=<SelectBackward0>)
735
tensor([[ 2.3678, -4.7232, -3.7652, -1.1531, -4.3416,  5.2299, -4.4934,  2.2508,
          3.0165,  1.5916]], grad_fn=<SelectBackward0>)
736
tensor([[ 2.3675, -4.7265, -3.7675, -1.1539, -4.3448,  5.2329, -4.4963,  2.2531,
          3.0170,  1.5905]], grad_fn=<SelectBackward0>)
737
tensor([[ 2.3673, -4.7297, -3.7697, -1.1547, -4.3480,  5.2359, -4.4992,  2.2554,
          3.0175,  1.5894]], grad_fn=<SelectBackward0>)
738
tensor([[ 2.3

tensor([[ 2.3563, -4.8967, -3.8888, -1.1938, -4.5204,  5.3953, -4.6536,  2.3762,
          3.0433,  1.5305]], grad_fn=<SelectBackward0>)
793
tensor([[ 2.3561, -4.8996, -3.8909, -1.1945, -4.5234,  5.3981, -4.6563,  2.3783,
          3.0437,  1.5295]], grad_fn=<SelectBackward0>)
794
tensor([[ 2.3559, -4.9025, -3.8930, -1.1952, -4.5265,  5.4009, -4.6591,  2.3804,
          3.0441,  1.5285]], grad_fn=<SelectBackward0>)
795
tensor([[ 2.3558, -4.9053, -3.8951, -1.1958, -4.5296,  5.4037, -4.6619,  2.3825,
          3.0445,  1.5274]], grad_fn=<SelectBackward0>)
796
tensor([[ 2.3556, -4.9082, -3.8972, -1.1965, -4.5326,  5.4064, -4.6646,  2.3846,
          3.0450,  1.5264]], grad_fn=<SelectBackward0>)
797
tensor([[ 2.3555, -4.9111, -3.8993, -1.1972, -4.5357,  5.4092, -4.6674,  2.3866,
          3.0454,  1.5254]], grad_fn=<SelectBackward0>)
798
tensor([[ 2.3553, -4.9139, -3.9014, -1.1978, -4.5387,  5.4120, -4.6701,  2.3887,
          3.0458,  1.5243]], grad_fn=<SelectBackward0>)
799
tensor([[ 2.3

tensor([[ 2.3490, -5.0610, -4.0113, -1.2312, -4.7005,  5.5580, -4.8162,  2.4967,
          3.0663,  1.4698]], grad_fn=<SelectBackward0>)
853
tensor([[ 2.3490, -5.0636, -4.0133, -1.2318, -4.7035,  5.5606, -4.8188,  2.4986,
          3.0667,  1.4688]], grad_fn=<SelectBackward0>)
854
tensor([[ 2.3489, -5.0661, -4.0153, -1.2323, -4.7064,  5.5632, -4.8215,  2.5005,
          3.0670,  1.4678]], grad_fn=<SelectBackward0>)
855
tensor([[ 2.3488, -5.0687, -4.0172, -1.2329, -4.7093,  5.5659, -4.8242,  2.5024,
          3.0674,  1.4668]], grad_fn=<SelectBackward0>)
856
tensor([[ 2.3487, -5.0713, -4.0192, -1.2335, -4.7123,  5.5685, -4.8268,  2.5044,
          3.0677,  1.4659]], grad_fn=<SelectBackward0>)
857
tensor([[ 2.3486, -5.0739, -4.0212, -1.2341, -4.7152,  5.5711, -4.8295,  2.5063,
          3.0681,  1.4649]], grad_fn=<SelectBackward0>)
858
tensor([[ 2.3486, -5.0765, -4.0232, -1.2346, -4.7181,  5.5737, -4.8321,  2.5082,
          3.0684,  1.4639]], grad_fn=<SelectBackward0>)
859
tensor([[ 2.3

tensor([[ 2.3460, -5.2098, -4.1262, -1.2634, -4.8733,  5.7108, -4.9728,  2.6074,
          3.0851,  1.4123]], grad_fn=<SelectBackward0>)
913
tensor([[ 2.3460, -5.2122, -4.1281, -1.2639, -4.8762,  5.7132, -4.9753,  2.6091,
          3.0854,  1.4113]], grad_fn=<SelectBackward0>)
914
tensor([[ 2.3460, -5.2146, -4.1299, -1.2644, -4.8790,  5.7157, -4.9779,  2.6109,
          3.0857,  1.4104]], grad_fn=<SelectBackward0>)
915
tensor([[ 2.3459, -5.2169, -4.1318, -1.2649, -4.8818,  5.7182, -4.9804,  2.6127,
          3.0860,  1.4095]], grad_fn=<SelectBackward0>)
916
tensor([[ 2.3459, -5.2193, -4.1336, -1.2654, -4.8846,  5.7206, -4.9830,  2.6144,
          3.0863,  1.4085]], grad_fn=<SelectBackward0>)
917
tensor([[ 2.3459, -5.2216, -4.1355, -1.2659, -4.8874,  5.7231, -4.9855,  2.6162,
          3.0865,  1.4076]], grad_fn=<SelectBackward0>)
918
tensor([[ 2.3459, -5.2240, -4.1373, -1.2664, -4.8903,  5.7255, -4.9881,  2.6179,
          3.0868,  1.4067]], grad_fn=<SelectBackward0>)
919
tensor([[ 2.3

tensor([[ 2.3464, -5.3504, -4.2373, -1.2919, -5.0449,  5.8596, -5.1285,  2.7130,
          3.1009,  1.3559]], grad_fn=<SelectBackward0>)
975
tensor([[ 2.3465, -5.3526, -4.2390, -1.2923, -5.0476,  5.8619, -5.1310,  2.7146,
          3.1011,  1.3550]], grad_fn=<SelectBackward0>)
976
tensor([[ 2.3465, -5.3548, -4.2408, -1.2927, -5.0503,  5.8642, -5.1334,  2.7162,
          3.1013,  1.3541]], grad_fn=<SelectBackward0>)
977
tensor([[ 2.3465, -5.3569, -4.2425, -1.2931, -5.0531,  5.8665, -5.1359,  2.7179,
          3.1016,  1.3533]], grad_fn=<SelectBackward0>)
978
tensor([[ 2.3466, -5.3591, -4.2442, -1.2935, -5.0558,  5.8689, -5.1384,  2.7195,
          3.1018,  1.3524]], grad_fn=<SelectBackward0>)
979
tensor([[ 2.3466, -5.3612, -4.2459, -1.2940, -5.0585,  5.8712, -5.1408,  2.7211,
          3.1020,  1.3515]], grad_fn=<SelectBackward0>)
980
tensor([[ 2.3466, -5.3634, -4.2476, -1.2944, -5.0612,  5.8735, -5.1433,  2.7228,
          3.1022,  1.3506]], grad_fn=<SelectBackward0>)
981
tensor([[ 2.3

In [7]:
#for param in rew.parameters():
#    print(param)