In [1]:
import torch.nn as nn
import torch.nn.functional as F
import torch
import numpy as np
from torch.distributions import Normal
import torch.optim as optim

In [2]:
def initialize_uniformly(layer: nn.Linear, init_w: float = 3e-3):
    """Initialize the weights and bias in [-init_w, init_w]."""
    layer.weight.data.uniform_(-init_w, init_w)
    layer.bias.data.uniform_(-init_w, init_w)


class Actor_a2c(nn.Module):
    def __init__(self, in_dim: int, out_dim: int,seed: int):
        """Initialize."""
        super(Actor_a2c, self).__init__()
        self.seed = torch.manual_seed(seed)
        self.hidden1 = nn.Linear(in_dim, 128)
        self.mu_layer = nn.Linear(128, out_dim)
        self.log_std_layer = nn.Linear(128, out_dim)

        initialize_uniformly(self.mu_layer)
        initialize_uniformly(self.log_std_layer)

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        """Forward method implementation."""
        x = F.relu(self.hidden1(state))

        mu = torch.tanh(self.mu_layer(x))
        log_std = F.softplus(self.log_std_layer(x))
        std = torch.exp(log_std)

        dist = Normal(mu, std)
        action = dist.sample()

        return action, dist


class Critic_a2c(nn.Module):
    def __init__(self, in_dim: int, seed: int):
        """Initialize."""
        super(Critic_a2c, self).__init__()

        self.hidden1 = nn.Linear(in_dim, 128)
        self.out = nn.Linear(128, 1)
        self.seed = torch.manual_seed(seed)

        initialize_uniformly(self.out)

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        """Forward method implementation."""
        x = F.relu(self.hidden1(state))
        value = self.out(x)

        return value

In [3]:
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
DEVICE

device(type='cuda', index=0)

In [4]:
state_size= 33
action_size =4
GAMMA = 0.99  # discount factor
TAU = 1e-3  # for soft update of target parameters
LR_ACTOR = 1e-4  # learning rate of the actor
LR_CRITIC = 1e-4
random_seed = 4
WEIGHT_DECAY = 0  
# Actor Network
actor_local = Actor_a2c(state_size, action_size, random_seed).to(DEVICE)
# initialize with its own Learning Rate
actor_optimizer = optim.Adam(actor_local.parameters(), lr=LR_ACTOR)

In [5]:
 # Critic Network
critic_local = Critic_a2c(state_size, random_seed).to(DEVICE)
# initialize with its own Learning Rate
critic_optimizer = optim.Adam(critic_local.parameters(), lr=LR_CRITIC,
                                   weight_decay=WEIGHT_DECAY)

In [6]:
reward=[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
GAMMA = 0.99 

In [7]:
next_state = [[ 6.2222e-02, -3.9996e+00,  6.9542e-03,  9.9997e-01,  7.7391e-03,
         -5.6170e-06,  8.6469e-04, -3.4869e-02,  3.1775e-04,  3.0815e-01,
          1.2387e+00,  1.1688e-02,  1.4016e-01, -5.6076e-04, -9.9975e+00,
          5.8327e-02,  9.9980e-01, -1.8202e-02,  2.0578e-04,  7.6992e-03,
         -3.0750e-01, -1.1807e-02, -7.2549e-01, -1.0685e+00,  1.0386e-01,
          1.6131e+00,  7.8506e+00, -1.0000e+00,  1.5387e+00,  0.0000e+00,
          1.0000e+00,  0.0000e+00, -5.2221e-01],
        [ 1.3153e-01, -3.9938e+00, -2.0449e-01,  9.9954e-01,  1.6356e-02,
          3.5127e-04, -2.5442e-02,  1.0088e+00, -1.9979e-02,  6.4862e-01,
          2.6086e+00,  1.7644e-01, -4.0517e+00,  7.2647e-02, -9.9762e+00,
         -8.0799e-02,  9.9859e-01, -2.6092e-02,  2.9818e-03,  4.6097e-02,
         -1.8316e+00, -1.7008e-01, -1.0295e+00, -6.6513e-02,  9.0792e-01,
          1.0482e+00,  7.9880e+00, -1.0000e+00, -4.3836e-01,  0.0000e+00,
          1.0000e+00,  0.0000e+00, -7.1473e-01],
        [ 2.8244e-02, -3.9999e+00,  5.1689e-03,  9.9999e-01,  3.5130e-03,
         -1.9085e-06,  6.4285e-04, -2.5168e-02,  1.0930e-04,  1.3947e-01,
          5.6067e-01,  2.4559e-03,  1.0117e-01, -5.4199e-02, -9.9979e+00,
         -5.6906e-02,  9.9979e-01, -1.7253e-02, -2.2676e-04, -1.0991e-02,
          4.3872e-01,  1.3033e-02, -6.8806e-01, -2.0860e+00,  9.3708e-02,
         -1.7760e+00, -7.9914e+00, -1.0000e+00,  3.7097e-01,  0.0000e+00,
          1.0000e+00,  0.0000e+00,  6.6445e-01],
        [ 7.0786e-02, -3.9994e+00, -2.6165e-02,  9.9996e-01,  8.8044e-03,
          2.4116e-05, -3.2545e-03,  1.2939e-01, -1.3774e-03,  3.5111e-01,
          1.4114e+00,  1.6985e-02, -5.2005e-01,  6.5794e-02, -9.9978e+00,
          3.0809e-02,  9.9987e-01, -9.6340e-03,  3.3956e-04,  1.2752e-02,
         -5.0870e-01, -1.9435e-02, -3.8344e-01,  7.5552e-01,  8.4229e-02,
          1.3566e+00, -3.6579e+00, -1.0000e+00,  7.1147e+00,  0.0000e+00,
          1.0000e+00,  0.0000e+00, -5.2300e-02],
        [-1.6085e-01, -3.9929e+00,  2.0398e-01,  9.9948e-01, -2.0002e-02,
          4.2910e-04,  2.5383e-02, -1.0050e+00, -2.4352e-02, -7.9112e-01,
         -3.1815e+00,  2.0144e-01,  4.0353e+00, -4.7662e-02, -9.9709e+00,
          8.0528e-02,  9.9818e-01,  3.8763e-02,  3.9107e-03, -4.6042e-02,
          1.8299e+00, -2.2298e-01,  1.5293e+00,  1.2994e+00,  1.1263e+00,
         -1.0431e+00,  7.9999e+00, -1.0000e+00,  2.9213e-02,  0.0000e+00,
          1.0000e+00,  0.0000e+00,  9.4769e-01],
        [ 3.6064e-02, -3.9998e+00,  2.8208e-02,  9.9998e-01,  4.4855e-03,
         -1.3264e-05,  3.5084e-03, -1.3919e-01,  7.5481e-04,  1.7815e-01,
          7.1615e-01,  6.2463e-03,  5.5950e-01, -4.8218e-02, -9.9967e+00,
         -5.4222e-02,  9.9968e-01, -1.8529e-02, -4.2315e-04, -1.7252e-02,
          6.8827e-01,  2.4299e-02, -7.3855e-01, -2.0404e+00,  1.4787e-01,
         -2.0854e+00, -7.9692e+00, -1.0000e+00, -7.0146e-01,  0.0000e+00,
          1.0000e+00,  0.0000e+00, -7.5632e-03],
        [-1.0742e-01, -3.9944e+00, -2.0444e-01,  9.9959e-01, -1.3358e-02,
         -2.8698e-04, -2.5435e-02,  1.0082e+00,  1.6275e-02, -5.2841e-01,
         -2.1253e+00,  1.5917e-01, -4.0499e+00, -6.6528e-03, -9.9762e+00,
         -8.0773e-02,  9.9848e-01,  3.0080e-02, -2.7739e-03,  4.6092e-02,
         -1.8314e+00,  1.5831e-01,  1.1888e+00,  1.6153e+00,  9.2556e-01,
          1.0475e+00,  2.5177e+00, -1.0000e+00, -7.5935e+00,  0.0000e+00,
          1.0000e+00,  0.0000e+00,  1.6410e-01],
        [ 2.0372e-01, -3.9913e+00, -2.0352e-01,  9.9936e-01,  2.5335e-02,
          5.4278e-04, -2.5330e-02,  1.0018e+00, -3.0802e-02,  1.0013e+00,
          4.0263e+00,  2.4760e-01, -4.0205e+00,  8.0639e-02, -9.9653e+00,
         -8.0280e-02,  9.9788e-01, -4.5721e-02,  4.8223e-03,  4.5991e-02,
         -1.8284e+00, -2.7473e-01, -1.8010e+00, -1.0429e+00,  1.3318e+00,
          1.0387e+00, -3.0417e+00, -1.0000e+00,  7.3992e+00,  0.0000e+00,
          1.0000e+00,  0.0000e+00, -8.3678e-01],
        [ 2.0501e-01, -3.9955e+00, -2.1870e-02,  9.9967e-01,  2.5501e-02,
          5.8436e-05, -2.7216e-03,  1.0876e-01, -3.3085e-03,  1.0124e+00,
          4.0684e+00,  1.2683e-01, -4.3665e-01,  8.1081e-02, -9.9823e+00,
         -6.0038e-02,  9.9893e-01, -4.6153e-02, -1.8018e-04, -3.6465e-03,
          1.4636e-01,  1.0308e-02, -1.8332e+00, -1.0532e+00,  6.7962e-01,
         -1.4111e+00,  3.4057e+00, -1.0000e+00,  7.2389e+00,  0.0000e+00,
          1.0000e+00,  0.0000e+00, -7.0105e-01],
        [ 2.8219e-02, -3.9998e+00, -2.8214e-02,  9.9999e-01,  3.5100e-03,
          1.0382e-05, -3.5092e-03,  1.3924e-01, -5.9013e-04,  1.3928e-01,
          5.5990e-01,  4.7442e-03, -5.5974e-01, -5.4207e-02, -9.9970e+00,
          5.4218e-02,  9.9970e-01, -1.7244e-02,  3.7058e-04,  1.7252e-02,
         -6.8827e-01, -2.1290e-02, -6.8747e-01, -2.0856e+00,  1.3543e-01,
          2.0855e+00, -2.9978e+00, -1.0000e+00, -7.4171e+00,  0.0000e+00,
          1.0000e+00,  0.0000e+00,  1.7878e-03],
        [-2.0400e-01, -3.9924e+00, -1.7599e-01,  9.9944e-01, -2.5371e-02,
         -4.6977e-04, -2.1903e-02,  8.6675e-01,  2.6679e-02, -1.0038e+00,
         -4.0354e+00,  2.1671e-01, -3.4789e+00, -8.0738e-02, -9.9690e+00,
         -5.9113e-02,  9.9808e-01,  4.5811e-02, -4.3146e-03,  4.1492e-02,
         -1.6510e+00,  2.4600e-01,  1.8077e+00,  1.0451e+00,  1.1954e+00,
          1.2047e+00, -7.9985e+00, -1.0000e+00,  1.5353e-01,  0.0000e+00,
          1.0000e+00,  0.0000e+00,  2.7490e-01],
        [-1.3566e-01, -3.9981e+00,  1.3542e-03,  9.9986e-01, -1.6874e-02,
          2.4475e-06,  1.6862e-04, -6.1603e-03, -1.4658e-04, -6.7110e-01,
         -2.6974e+00,  5.4900e-02,  2.4749e-02, -3.6972e-02, -9.9911e+00,
         -5.3452e-02,  9.9940e-01,  3.3321e-02,  5.2862e-04, -9.3145e-03,
          3.7215e-01, -3.0278e-02,  1.3260e+00,  1.1944e+00,  3.5167e-01,
         -1.6081e+00,  6.7237e-01, -1.0000e+00, -7.9717e+00,  0.0000e+00,
          1.0000e+00,  0.0000e+00, -2.0530e-01],
        [-1.0857e-01, -3.9987e+00,  2.8101e-02,  9.9990e-01, -1.3505e-02,
          3.9813e-05,  3.4957e-03, -1.3842e-01, -2.2725e-03, -5.3697e-01,
         -2.1584e+00,  3.7511e-02,  5.5623e-01, -7.2060e-03, -9.9921e+00,
         -5.4274e-02,  9.9939e-01,  3.0391e-02,  9.0904e-04, -1.7242e-02,
          6.8819e-01, -5.2080e-02,  1.2096e+00,  1.6203e+00,  3.2642e-01,
         -2.0838e+00,  2.0698e+00, -1.0000e+00,  7.7276e+00,  0.0000e+00,
          1.0000e+00,  0.0000e+00,  2.4865e-01],
        [ 8.8901e-02, -3.9987e+00, -6.8527e-02,  9.9990e-01,  1.1057e-02,
          7.9297e-05, -8.5244e-03,  3.3989e-01, -4.5194e-03,  4.4069e-01,
          1.7716e+00,  3.7581e-02, -1.3660e+00,  6.7897e-02, -9.9964e+00,
         -6.5510e-02,  9.9985e-01, -1.4555e-02,  4.3640e-04,  9.0306e-03,
         -3.5959e-01, -2.4971e-02, -5.7938e-01,  5.1186e-01,  1.2829e-01,
         -7.8565e-01,  7.6161e+00, -1.0000e+00, -2.4484e+00,  0.0000e+00,
          1.0000e+00,  0.0000e+00, -5.4470e-01],
        [ 3.5040e-03, -3.9999e+00, -2.8240e-02,  9.9999e-01,  4.3582e-04,
          1.2959e-06, -3.5125e-03,  1.3945e-01, -7.1696e-05,  1.6916e-02,
          6.8003e-02,  2.4124e-03, -5.6060e-01, -5.7096e-02, -9.9980e+00,
          5.4205e-02,  9.9980e-01, -1.0534e-02,  1.6667e-04,  1.7255e-02,
         -6.8831e-01, -9.6019e-03, -4.2023e-01, -1.7536e+00,  9.1545e-02,
          2.0859e+00,  4.1844e+00, -1.0000e+00,  6.8184e+00,  0.0000e+00,
          1.0000e+00,  0.0000e+00, -8.6566e-01],
        [-1.3314e-01, -3.9938e+00,  2.0424e-01,  9.9954e-01, -1.6556e-02,
          3.5544e-04,  2.5411e-02, -1.0067e+00, -2.0167e-02, -6.5496e-01,
         -2.6341e+00,  1.7741e-01,  4.0434e+00, -2.6382e-02, -9.9739e+00,
          8.0664e-02,  9.9835e-01,  3.4261e-02,  3.3212e-03, -4.6070e-02,
          1.8307e+00, -1.8946e-01,  1.3529e+00,  1.4637e+00,  1.0145e+00,
         -1.0455e+00, -4.1927e+00, -1.0000e+00, -6.8133e+00,  0.0000e+00,
          1.0000e+00,  0.0000e+00, -8.4818e-01],
        [-1.3215e-01, -3.9981e+00, -2.8118e-02,  9.9986e-01, -1.6438e-02,
         -4.8478e-05, -3.4983e-03,  1.3859e-01,  2.7700e-03, -6.5436e-01,
         -2.6302e+00,  5.4454e-02, -5.5681e-01, -7.2889e-02, -9.9921e+00,
          5.4264e-02,  9.9950e-01,  2.6306e-02, -9.5290e-04,  1.7244e-02,
         -6.8821e-01,  5.4546e-02,  1.0463e+00,  7.0399e-02,  3.0649e-01,
          2.0841e+00,  2.4005e+00, -1.0000e+00,  7.6313e+00,  0.0000e+00,
          1.0000e+00,  0.0000e+00, -6.3462e-01],
        [-1.9245e-01, -3.9941e+00,  1.3745e-01,  9.9957e-01, -2.3936e-02,
          3.4555e-04,  1.7105e-02, -6.7879e-01, -1.9636e-02, -9.4909e-01,
         -3.8152e+00,  1.6692e-01,  2.7251e+00, -7.9578e-02, -9.9778e+00,
          7.3139e-02,  9.9869e-01,  4.2701e-02,  2.8163e-03, -2.7887e-02,
          1.1106e+00, -1.6084e-01,  1.6913e+00,  8.8614e-01,  8.4748e-01,
         -1.4505e-01, -3.7147e+00, -1.0000e+00,  7.0853e+00,  0.0000e+00,
          1.0000e+00,  0.0000e+00,  3.3321e-01],
        [-2.8104e-02, -3.9981e+00, -1.3008e-01,  9.9986e-01, -3.4951e-03,
         -4.7685e-05, -1.6181e-02,  6.4407e-01,  2.6959e-03, -1.3826e-01,
         -5.5594e-01,  5.2829e-02, -2.5887e+00,  5.4243e-02, -9.9924e+00,
         -7.2620e-02,  9.9952e-01,  1.7211e-02, -6.1695e-04,  2.5764e-02,
         -1.0256e+00,  3.5391e-02,  6.8504e-01,  2.0841e+00,  2.9736e-01,
          4.2520e-02,  5.3760e+00, -1.0000e+00, -5.9244e+00,  0.0000e+00,
          1.0000e+00,  0.0000e+00, -5.5375e-02],
        [-1.3222e-01, -3.9938e+00,  2.0448e-01,  9.9954e-01, -1.6441e-02,
          3.5310e-04,  2.5442e-02, -1.0087e+00, -2.0083e-02, -6.5199e-01,
         -2.6221e+00,  1.7698e-01,  4.0515e+00, -7.2720e-02, -9.9761e+00,
          8.0795e-02,  9.9859e-01,  2.6278e-02,  2.9993e-03, -4.6097e-02,
          1.8315e+00, -1.7107e-01,  1.0368e+00,  7.5784e-02,  9.1093e-01,
         -1.0482e+00,  7.5233e+00, -1.0000e+00,  2.7204e+00,  0.0000e+00,
          1.0000e+00,  0.0000e+00, -2.1990e-01]]

In [44]:
next_state.shape

torch.Size([20, 33])

In [8]:
state=[[ 0.0000e+00, -4.0000e+00,  0.0000e+00,  1.0000e+00, -0.0000e+00,
         -0.0000e+00, -4.3711e-08,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00, -1.0000e+01,
          0.0000e+00,  1.0000e+00, -0.0000e+00, -0.0000e+00, -4.3711e-08,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  7.9015e+00, -1.0000e+00,  1.2515e+00,  0.0000e+00,
          1.0000e+00,  0.0000e+00, -5.2221e-01],
        [ 0.0000e+00, -4.0000e+00,  0.0000e+00,  1.0000e+00, -0.0000e+00,
         -0.0000e+00, -4.3711e-08,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00, -1.0000e+01,
          0.0000e+00,  1.0000e+00, -0.0000e+00, -0.0000e+00, -4.3711e-08,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  7.9562e+00, -1.0000e+00, -8.3623e-01,  0.0000e+00,
          1.0000e+00,  0.0000e+00, -7.1473e-01],
        [ 0.0000e+00, -4.0000e+00,  0.0000e+00,  1.0000e+00, -0.0000e+00,
         -0.0000e+00, -4.3711e-08,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00, -1.0000e+01,
          0.0000e+00,  1.0000e+00, -0.0000e+00, -0.0000e+00, -4.3711e-08,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00, -8.0000e+00, -1.0000e+00,  0.0000e+00,  0.0000e+00,
          1.0000e+00,  0.0000e+00,  6.6445e-01],
        [ 0.0000e+00, -4.0000e+00,  0.0000e+00,  1.0000e+00, -0.0000e+00,
         -0.0000e+00, -4.3711e-08,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00, -1.0000e+01,
          0.0000e+00,  1.0000e+00, -0.0000e+00, -0.0000e+00, -4.3711e-08,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00, -3.6319e+00, -1.0000e+00,  7.1281e+00,  0.0000e+00,
          1.0000e+00,  0.0000e+00, -5.2300e-02],
        [ 0.0000e+00, -4.0000e+00,  0.0000e+00,  1.0000e+00, -0.0000e+00,
         -0.0000e+00, -4.3711e-08,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00, -1.0000e+01,
          0.0000e+00,  1.0000e+00, -0.0000e+00, -0.0000e+00, -4.3711e-08,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  7.9805e+00, -1.0000e+00,  5.5805e-01,  0.0000e+00,
          1.0000e+00,  0.0000e+00,  9.4769e-01],
        [ 0.0000e+00, -4.0000e+00,  0.0000e+00,  1.0000e+00, -0.0000e+00,
         -0.0000e+00, -4.3711e-08,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00, -1.0000e+01,
          0.0000e+00,  1.0000e+00, -0.0000e+00, -0.0000e+00, -4.3711e-08,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00, -7.9696e+00, -1.0000e+00, -6.9725e-01,  0.0000e+00,
          1.0000e+00,  0.0000e+00, -7.5632e-03],
        [ 0.0000e+00, -4.0000e+00,  0.0000e+00,  1.0000e+00, -0.0000e+00,
         -0.0000e+00, -4.3711e-08,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00, -1.0000e+01,
          0.0000e+00,  1.0000e+00, -0.0000e+00, -0.0000e+00, -4.3711e-08,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  2.6045e+00, -1.0000e+00, -7.5641e+00,  0.0000e+00,
          1.0000e+00,  0.0000e+00,  1.6410e-01],
        [ 0.0000e+00, -4.0000e+00,  0.0000e+00,  1.0000e+00, -0.0000e+00,
         -0.0000e+00, -4.3711e-08,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00, -1.0000e+01,
          0.0000e+00,  1.0000e+00, -0.0000e+00, -0.0000e+00, -4.3711e-08,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00, -2.6045e+00, -1.0000e+00,  7.5641e+00,  0.0000e+00,
          1.0000e+00,  0.0000e+00, -8.3678e-01],
        [ 0.0000e+00, -4.0000e+00,  0.0000e+00,  1.0000e+00, -0.0000e+00,
         -0.0000e+00, -4.3711e-08,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00, -1.0000e+01,
          0.0000e+00,  1.0000e+00, -0.0000e+00, -0.0000e+00, -4.3711e-08,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  3.7558e+00, -1.0000e+00,  7.0636e+00,  0.0000e+00,
          1.0000e+00,  0.0000e+00, -7.0105e-01],
        [ 0.0000e+00, -4.0000e+00,  0.0000e+00,  1.0000e+00, -0.0000e+00,
         -0.0000e+00, -4.3711e-08,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00, -1.0000e+01,
          0.0000e+00,  1.0000e+00, -0.0000e+00, -0.0000e+00, -4.3711e-08,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00, -2.9969e+00, -1.0000e+00, -7.4175e+00,  0.0000e+00,
          1.0000e+00,  0.0000e+00,  1.7878e-03],
        [ 0.0000e+00, -4.0000e+00,  0.0000e+00,  1.0000e+00, -0.0000e+00,
         -0.0000e+00, -4.3711e-08,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00, -1.0000e+01,
          0.0000e+00,  1.0000e+00, -0.0000e+00, -0.0000e+00, -4.3711e-08,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00, -8.0000e+00, -1.0000e+00,  0.0000e+00,  0.0000e+00,
          1.0000e+00,  0.0000e+00,  2.7490e-01],
        [ 0.0000e+00, -4.0000e+00,  0.0000e+00,  1.0000e+00, -0.0000e+00,
         -0.0000e+00, -4.3711e-08,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00, -1.0000e+01,
          0.0000e+00,  1.0000e+00, -0.0000e+00, -0.0000e+00, -4.3711e-08,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  5.5805e-01, -1.0000e+00, -7.9805e+00,  0.0000e+00,
          1.0000e+00,  0.0000e+00, -2.0530e-01],
        [ 0.0000e+00, -4.0000e+00,  0.0000e+00,  1.0000e+00, -0.0000e+00,
         -0.0000e+00, -4.3711e-08,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00, -1.0000e+01,
          0.0000e+00,  1.0000e+00, -0.0000e+00, -0.0000e+00, -4.3711e-08,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  1.9354e+00, -1.0000e+00,  7.7624e+00,  0.0000e+00,
          1.0000e+00,  0.0000e+00,  2.4865e-01],
        [ 0.0000e+00, -4.0000e+00,  0.0000e+00,  1.0000e+00, -0.0000e+00,
         -0.0000e+00, -4.3711e-08,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00, -1.0000e+01,
          0.0000e+00,  1.0000e+00, -0.0000e+00, -0.0000e+00, -4.3711e-08,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  7.5175e+00, -1.0000e+00, -2.7362e+00,  0.0000e+00,
          1.0000e+00,  0.0000e+00, -5.4470e-01],
        [ 0.0000e+00, -4.0000e+00,  0.0000e+00,  1.0000e+00, -0.0000e+00,
         -0.0000e+00, -4.3711e-08,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00, -1.0000e+01,
          0.0000e+00,  1.0000e+00, -0.0000e+00, -0.0000e+00, -4.3711e-08,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  4.5886e+00, -1.0000e+00,  6.5532e+00,  0.0000e+00,
          1.0000e+00,  0.0000e+00, -8.6566e-01],
        [ 0.0000e+00, -4.0000e+00,  0.0000e+00,  1.0000e+00, -0.0000e+00,
         -0.0000e+00, -4.3711e-08,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00, -1.0000e+01,
          0.0000e+00,  1.0000e+00, -0.0000e+00, -0.0000e+00, -4.3711e-08,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00, -4.5886e+00, -1.0000e+00, -6.5532e+00,  0.0000e+00,
          1.0000e+00,  0.0000e+00, -8.4818e-01],
        [ 0.0000e+00, -4.0000e+00,  0.0000e+00,  1.0000e+00, -0.0000e+00,
         -0.0000e+00, -4.3711e-08,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00, -1.0000e+01,
          0.0000e+00,  1.0000e+00, -0.0000e+00, -0.0000e+00, -4.3711e-08,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  2.7362e+00, -1.0000e+00,  7.5175e+00,  0.0000e+00,
          1.0000e+00,  0.0000e+00, -6.3462e-01],
        [ 0.0000e+00, -4.0000e+00,  0.0000e+00,  1.0000e+00, -0.0000e+00,
         -0.0000e+00, -4.3711e-08,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00, -1.0000e+01,
          0.0000e+00,  1.0000e+00, -0.0000e+00, -0.0000e+00, -4.3711e-08,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00, -3.8785e+00, -1.0000e+00,  6.9970e+00,  0.0000e+00,
          1.0000e+00,  0.0000e+00,  3.3321e-01],
        [ 0.0000e+00, -4.0000e+00,  0.0000e+00,  1.0000e+00, -0.0000e+00,
         -0.0000e+00, -4.3711e-08,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00, -1.0000e+01,
          0.0000e+00,  1.0000e+00, -0.0000e+00, -0.0000e+00, -4.3711e-08,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  5.3530e+00, -1.0000e+00, -5.9452e+00,  0.0000e+00,
          1.0000e+00,  0.0000e+00, -5.5375e-02],
        [ 0.0000e+00, -4.0000e+00,  0.0000e+00,  1.0000e+00, -0.0000e+00,
         -0.0000e+00, -4.3711e-08,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00, -1.0000e+01,
          0.0000e+00,  1.0000e+00, -0.0000e+00, -0.0000e+00, -4.3711e-08,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  7.5641e+00, -1.0000e+00,  2.6045e+00,  0.0000e+00,
          1.0000e+00,  0.0000e+00, -2.1990e-01]]

In [9]:
mask= [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]

In [10]:
next_state = torch.FloatTensor(next_state).to(DEVICE)

In [11]:
state = torch.FloatTensor(state).to(DEVICE)

In [12]:
mask = torch.from_numpy(np.array(mask)).to(DEVICE)

In [13]:
reward = torch.from_numpy(np.array(reward)).to(DEVICE)

In [14]:
nnext_state = critic_local(next_state)

In [15]:
pred_value = critic_local(state)
pred_value2 = torch.clone(pred_value)

In [16]:
pred_value.size()

torch.Size([20, 1])

In [17]:
targ_value =(nnext_state * GAMMA * mask).mean(dim=1) + reward
targ_value2 =torch.clone(targ_value)

In [18]:
targ_value = targ_value.resize(20,1)



In [19]:
targ_value

tensor([[ 0.0126],
        [ 0.0116],
        [ 0.0104],
        [ 0.0166],
        [ 0.0336],
        [ 0.0111],
        [ 0.0093],
        [ 0.0156],
        [ 0.0134],
        [-0.0035],
        [ 0.0151],
        [ 0.0131],
        [ 0.0235],
        [ 0.0165],
        [ 0.0144],
        [ 0.0142],
        [ 0.0179],
        [ 0.0232],
        [ 0.0157],
        [ 0.0306]], device='cuda:0', dtype=torch.float64,
       grad_fn=<ResizeBackward>)

In [20]:
targ_value.size()

torch.Size([20, 1])

In [21]:
value_loss = F.smooth_l1_loss(pred_value, targ_value)

In [22]:
print(value_loss)

tensor(2.0816e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<SmoothL1LossBackward0>)


In [23]:
# update value
critic_optimizer.zero_grad()
value_loss.backward(retain_graph=True)
critic_optimizer.step()

In [24]:
pred_value2 = pred_value2.resize(20)

In [25]:
targ_value2.size(), pred_value2.size()

(torch.Size([20]), torch.Size([20]))

In [26]:
torch.sub(targ_value2 , pred_value2)

tensor([-0.0062, -0.0074, -0.0014, -0.0029,  0.0142, -0.0006, -0.0014, -0.0042,
        -0.0029, -0.0057,  0.0026,  0.0055,  0.0090, -0.0019, -0.0027,  0.0134,
         0.0020,  0.0041,  0.0002,  0.0118], device='cuda:0',
       dtype=torch.float64, grad_fn=<SubBackward0>)

In [27]:
advantage = (targ_value2 - pred_value2).detach()
advantage

tensor([-0.0062, -0.0074, -0.0014, -0.0029,  0.0142, -0.0006, -0.0014, -0.0042,
        -0.0029, -0.0057,  0.0026,  0.0055,  0.0090, -0.0019, -0.0027,  0.0134,
         0.0020,  0.0041,  0.0002,  0.0118], device='cuda:0',
       dtype=torch.float64)

In [28]:
advantage.size()

torch.Size([20])

In [29]:
log_prob=[ -8.4512,  -8.7970,  -7.6534,  -7.4301,  -8.7383,  -7.5909,  -7.4831,
        -10.3112,  -8.1226,  -9.3186, -10.3247,  -6.7616,  -9.1123,  -8.9978,
        -11.7907,  -9.8782,  -9.3485, -10.6365, -10.4545,  -8.5417]

In [30]:
log_prob = torch.from_numpy(np.array(log_prob)).to(DEVICE)

In [31]:
from torch.autograd import Variable

In [32]:
policy_loss = Variable((-advantage * log_prob), requires_grad=True)
policy_loss

tensor([-0.0527, -0.0649, -0.0108, -0.0215,  0.1237, -0.0046, -0.0105, -0.0436,
        -0.0237, -0.0528,  0.0270,  0.0372,  0.0816, -0.0171, -0.0317,  0.1320,
         0.0186,  0.0440,  0.0024,  0.1010], device='cuda:0',
       dtype=torch.float64, requires_grad=True)

In [33]:
entropy_weight = 1e-2

In [34]:
policy_loss = Variable(policy_loss+ (entropy_weight * -log_prob), requires_grad=True) 
policy_loss

tensor([0.0318, 0.0230, 0.0658, 0.0528, 0.2111, 0.0713, 0.0643, 0.0595, 0.0575,
        0.0404, 0.1302, 0.1048, 0.1727, 0.0728, 0.0862, 0.2307, 0.1121, 0.1504,
        0.1069, 0.1864], device='cuda:0', dtype=torch.float64,
       requires_grad=True)

In [35]:
policy_loss2=Variable(policy_loss.mean(), requires_grad=True) 

In [41]:
policy_loss2

tensor(0.1016, device='cuda:0', dtype=torch.float64, requires_grad=True)

In [37]:
actor_optimizer.zero_grad()

In [38]:
policy_loss2.backward()

In [40]:
actor_optimizer.step()

In [43]:
policy_loss2.item()

0.10155591568686749