In [5]:
import gymnasium as gym 
import numpy as np
import pandas as pd
import typing as tt
import torch  
import torch.nn as nn 
import torch.nn.functional as F
import wandb
from gymnasium.wrappers import NormalizeObservation, NormalizeReward



In [61]:
HIDDEN_LAYER1  = 256

GAMMA = 0.99
LR = 1e-4

N_STEPS = 1
ENTROPY_BETA = 0.01
ENV_ID = 'InvertedPendulum-v5'
N_ENV = 1
BATCH_SIZE = 5


if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu' 
print(f'Using device : {device}')

Using device : mps


In [3]:
env = gym.make(ENV_ID)
eval_env = gym.make(ENV_ID, render_mode='rgb_array')


In [4]:
env.observation_space.shape, env.action_space.shape

((4,), (1,))

In [4]:
env.observation_space.sample(), env.action_space.sample()

(array([ 0.26376065,  1.75201758, -0.5726076 , -1.58489694]),
 array([-0.769586], dtype=float32))

In [7]:
env.observation_space, env.action_space

(Box(-inf, inf, (4,), float32), Box(-3.0, 3.0, (1,), float32))

In [6]:
env = NormalizeObservation(env)
env.observation_space.sample(), env.action_space.sample()

(array([-0.2851903 ,  1.5047672 , -0.98539096, -0.18656994], dtype=float32),
 array([-1.2208982], dtype=float32))

In [8]:
class PolicyNet(nn.Module):
    def __init__(self, input_size, fc, action_dim):
        super().__init__()
        self.input_size = input_size
        self.fc = fc 
        self.action_dim = action_dim
        
        self.net = nn.Sequential(
            nn.Linear(self.input_size, self.fc), 
            nn.ReLU(), 
            
        )
        
        self.mu = nn.Sequential(
            nn.Linear(self.fc, self.action_dim), 
            # nn.Tanh()

        )
        self.log_std = nn.Parameter(torch.zeros(action_dim))

        
        self.critic_head = nn.Linear(self.fc, 1)
        
    def forward(self, x):
        x = self.net(x)
        mu = self.mu(x)
        # Use learned constant std (common in simple continuous control)
        std = torch.exp(self.log_std.clamp(-2, 0.5))  # exp(-2)=0.135, exp(0.5)=1.65
        std = std.expand_as(mu)  # Broadcast to batch size
        
        v = self.critic_head(x)
        return mu, std, v

In [9]:
policy = PolicyNet(
    env.observation_space.shape[0], 
    HIDDEN_LAYER1, 
    env.action_space.shape[0]
).to(device)
policy

PolicyNet(
  (net): Sequential(
    (0): Linear(in_features=4, out_features=256, bias=True)
    (1): ReLU()
  )
  (mu): Sequential(
    (0): Linear(in_features=256, out_features=1, bias=True)
  )
  (critic_head): Linear(in_features=256, out_features=1, bias=True)
)

In [14]:

low = torch.tensor(env.action_space.low, dtype=torch.float32, device=device)
high = torch.tensor(env.action_space.high, dtype=torch.float32, device=device)
state, _ = env.reset()
state_t = torch.tensor(state, dtype=torch.float32, device = device).unsqueeze(0)
mu_v, std_v,  value = policy(state_t)
print('mu', mu_v)
print('var: ', std_v)
# break
# mu = mu_v.data.cpu().numpy()
# print('mu', mu)
# dist = torch.distributions.Normal()
# action = dist.sample().item()
# std = torch.sqrt(var_v)
dist = torch.distributions.Normal(mu_v, std_v)
print(dist)
u = dist.sample()
a = torch.tanh(u)
print(f'a:{a}')
action = low + (a+1) * (high-low)*0.5
print(f'action:{action}')
action_env = action.squeeze(0).detach().cpu().numpy()
action_env


mu tensor([[-0.1050]], device='mps:0', grad_fn=<LinearBackward0>)
var:  tensor([[1.]], device='mps:0', grad_fn=<ExpandBackward0>)
Normal(loc: tensor([[-0.1050]], device='mps:0', grad_fn=<LinearBackward0>), scale: tensor([[1.]], device='mps:0', grad_fn=<ExpandBackward0>))
a:tensor([[-0.2339]], device='mps:0')
action:tensor([[-0.7018]], device='mps:0')


array([-0.7017853], dtype=float32)

In [17]:
# sanity check for action range.
with torch.no_grad():
    samples = dist.sample((10000000,))          # [10000, batch, act_dim]
    a = torch.tanh(samples)
    actions = low + (a + 1) * (high - low) * 0.5
    # actions = torch.clamp(u, low, high)
    print(actions.min().item(), actions.max().item())

-3.0 2.9996337890625


In [63]:
def experience_generator(env, policy, gamma, n_steps):
    while True: 
        state_list = []
        raw_action_list = []
        reward_list = []
        return_list = []
        done_list = []
        last_state_list = []
        
        done = False
        ep_rew = 0
        state, _ = env.reset()
        while not done:
            state_t = torch.tensor(state, dtype=torch.float32, device = device).unsqueeze(0)
            with torch.no_grad():
                mu_v, std_v,  value = policy(state_t)
            dist = torch.distributions.Normal(mu_v, std_v)
            # print(dist)
            u = dist.sample()
            a = torch.tanh(u)
            # print(f'a:{a}')
            action = low + (a+1) * (high-low)*0.5
            # print(f'action:{action}')
            # print(action)
            action_env = action.squeeze(0).detach().cpu().numpy()
            new_state, rew, term, trunc, info = env.step(action_env)
            done = term or trunc
            ep_rew += rew
            state_list.append(state_t)
            raw_action_list.append(u)
            reward_list.append(rew)
            done_list.append(done)
            
            last_state_list.append(new_state)
                
            if len(reward_list)>=n_steps:
                ret = sum([reward_list[i]* (gamma**i) for i in range(n_steps)])
                
                yield { 
                    'state':state_list[0], 
                    'raw_action':raw_action_list[0],
                    'ret':ret,
                    'done':done,
                    'last_state':last_state_list[n_steps-1] if not done else None, 
                    'ep_reward': ep_rew if done_list[0] else None, 
                    'reward_list':reward_list,
                }
                
                state_list.pop(0)
                raw_action_list.pop(0)
                reward_list.pop(0)
                done_list.pop(0)
                last_state_list.pop(0)
                
            state = new_state
                
        else:
            while len(reward_list)>0:
                ret = sum([reward_list[i]* (gamma**i) for i in range(len(reward_list))])
                
                yield { 
                    'state':state_list[0], 
                    'raw_action':raw_action_list[0],
                    'ret':ret,
                    'done':done,
                    'last_state': None, 
                    'ep_reward': ep_rew if done_list[0] else None,
                    'reward_list':reward_list,
                }
                
                state_list.pop(0)
                raw_action_list.pop(0)
                reward_list.pop(0)
                done_list.pop(0)
                last_state_list.pop(0)
                
def record_video(env, policy, device, max_steps=500):
    """Record a single episode and return frames + reward"""
    frames = []
    state, _ = env.reset()
    done = False
    total_reward = 0
    steps = 0
    low = torch.tensor(env.action_space.low, dtype=torch.float32, device=device)
    high = torch.tensor(env.action_space.high, dtype=torch.float32, device=device)
    
    while not done and steps < max_steps:
        frame = env.render()
        frames.append(frame)        
        state_tensor = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
        
        with torch.no_grad():
            mu, std, val = policy(state_tensor)
        dist = torch.distributions.Normal(mu, std)
        u = dist.sample()
        a = torch.tanh(u)
        action = low + (a+1) * (high-low)*0.5
            
        action_env = action.squeeze(0).detach().cpu().numpy()
        state, reward, terminated, truncated, _ = env.step(action_env)
        total_reward += reward
        done = terminated or truncated
        steps += 1
        
    return frames, total_reward, steps

def smooth(old: tt.Optional[float], val: float, alpha: float = 0.95) -> float:
    if old is None:
        return val
    return old * alpha + (1-alpha)*val

In [64]:


policy = PolicyNet(
    env.observation_space.shape[0], 
    HIDDEN_LAYER1, 
    env.action_space.shape[0]
).to(device)
optimizer = torch.optim.Adam(policy.parameters(),lr=LR, )



batch_states = []
batch_returns = []
batch_raw_actions = []
batch_values = []
done_list = []
last_state_list = []
total_rewards = []
adv_smoothed = l_entropy = l_policy = l_value = l_total = None
episode_idx = 0
# BATCH_SIZE = 1 * REWARD_STEP  # n_env * reward_steps

for step_idx, exp in enumerate(experience_generator(env, policy, GAMMA, N_STEPS)):
    batch_states.append(exp['state']) 
    batch_raw_actions.append(exp['raw_action'])

    print(f'\nstate:{exp['state']}')
    print(f'actions : {exp['raw_action']}')
    
    print(f'done:{exp['done']}')
    print(f'last_state:{exp['last_state']}')
    ## bootstrapping if the episode is not completed withing REWARD_STEP
    if exp['last_state'] is not None:
        last_state = exp['last_state']
        last_state_t = torch.tensor(last_state, dtype=torch.float32, device=device).unsqueeze(0)
        with torch.no_grad():
            _, _, bs_val = policy(last_state_t)
        bs_val = bs_val.item()
        ret = exp['ret'] +  (bs_val) * (GAMMA**N_STEPS) 
        batch_returns.append(ret)
    else:
        batch_returns.append(exp['ret'])
        ret =exp['ret']
        bs_val = 0
        
    print(f'bs_val;{bs_val}')
    print(f'ret:{exp['ret']}')
    print(f'ep_reward:{exp['ep_reward']}')
    print(f'reward_list:{exp['reward_list']}')
    print(f"return after bootstrap: {ret}")
        
    if exp['ep_reward'] is not None:
        episode_reward = exp['ep_reward']
        total_rewards.append(episode_reward)
        mean_reward = float(np.mean(total_rewards[-100:]))
        print(f"episode : {episode_idx} | step: {step_idx} | episode reward : {episode_reward} | mean reward/100 eps : {mean_reward}")
        # wandb.log({
        #     "episode_reward": episode_reward, 
        #     "mean_reward_100": mean_reward,  
        #     'episode_number': episode_idx,   
        #     "steps_per_episode": step_idx / max(episode_idx, 1)
        # }, step=step_idx)
        episode_idx += 1
        break
        if mean_reward>950:
            print(f"Solved! Mean reward > 450 at episode {episode_idx}")
            break
        
        
        
    if len(batch_states) < BATCH_SIZE:
        continue
    print(f"batch_actions: {batch_raw_actions}")
    batch_states_t = torch.cat(batch_states, dim=0)
    batch_actions_t = torch.cat(batch_raw_actions, dim=0).to(device).float()  # each element in batch_raw_actions is [1, act_dim]
    batch_returns_t = torch.tensor(batch_returns, dtype=torch.float32, device=device)
    print(f"batch_states_t: {batch_states_t}")
    print(f"batch_actions_t: {batch_actions_t}")
    # print(f"batch_actions_t sum: {batch_actions_t.sum(dim=-1)}")
    print(f"batch_returns_t: {batch_returns_t}")
    mu, std, value_t = policy(batch_states_t)
    value_t = value_t.squeeze(-1)
    # std = std.clamp(1e-3, 2.0)

    
    
    dist_t = torch.distributions.Normal(mu, std)
    
    # u_t = batch_actions_t                         # pre-tanh actions, [B, act_dim]
    logp_u = dist_t.log_prob(batch_actions_t).sum(dim=-1)     # [B]

    a_t = torch.tanh(batch_actions_t)
    log_prob_correction = torch.log(1.0 - a_t.pow(2) + 1e-6).sum(dim=-1)  # [B]
    logp = logp_u - log_prob_correction                      # [B]

    print(f"logp_u:{logp_u}")
    print(f"log_prob_correction:{log_prob_correction}")
    print(f"logp(after correction):{logp}")
    
    
    
    adv_t = (batch_returns_t - value_t).detach()
    loss_policy = - (logp * adv_t).mean()

    returns = adv_t + value_t.detach()
    print(f"Returns:{returns}")
    print(f"value_t:{value_t}")
    loss_value = F.mse_loss(value_t, batch_returns_t.detach())
    
    entropy = dist_t.entropy().mean()
    loss_entropy = -ENTROPY_BETA*entropy
    
    
    print(f"dis_t: {dist_t}")
    print(f'actions_prob_t:{logp}')
    
    print(f"loss_value: {loss_value}")

    print(f"adv_t: {adv_t}")
    print(f"value_t: {value_t}")
    print(f"loss_policy: {loss_policy}")
    print(f"entropy: {entropy}")
    print(f"loss_entropy: {loss_entropy}")
    # break

    loss_total = loss_value + loss_policy + loss_entropy
    
    optimizer.zero_grad()
    loss_total.backward()
    optimizer.step()
    
    
    # with torch.no_grad():
    #     mu_t, std_t, v_t = policy(batch_states_t)
    #     new_dist_t = torch.distributions.Normal(mu_t, std_t)
        
    #     kl_div = torch.distributions.kl_divergence(dist_t, new_dist_t).mean()
        
    # grad_max = 0.0
    # grad_means = 0.0
    # grad_count = 0
    # for p in policy.parameters():
    #     grad_max = max(grad_max, p.grad.abs().max().item())
    #     grad_means += (p.grad ** 2).mean().sqrt().item()
    #     grad_count += 1
        
        
    # adv_smoothed = smooth(
    #                 adv_smoothed,
    #                 float(np.mean(adv_t.mean().item()))
    #             )
    # l_entropy = smooth(l_entropy, loss_entropy.item())
    # l_policy = smooth(l_policy, loss_policy.item())
    # l_value = smooth(l_value, loss_value.item())
    # l_total = smooth(l_total, loss_total.item())
    
    
    
    # # break

    # # wandb.log({
    # #     # 'baseline':baseline,
    # #     'advantage':adv_smoothed,
    # #     'entropy':entropy,
    # #     'loss_policy':l_policy,
    # #     'loss_value':l_value,
    # #     'loss_entropy': l_entropy, 
    # #     'loss_total': l_total,
    # #     'kl div': kl_div.item(),
    # #     'grad_l2':grad_means/grad_count,
    # #     'grad_max':grad_max,
    # #     'batch_scales': batch_returns,
    # #     "current_episode": episode_idx
    # # }, step = step_idx)
    
    batch_raw_actions.clear()
    batch_returns.clear()
    batch_states.clear()
    




state:tensor([[-0.1718,  0.2611, -0.3080,  0.2984]], device='mps:0')
actions : tensor([[-0.5496]], device='mps:0')
done:False
last_state:[-0.34773558  0.44511846 -0.9187471   0.92713106]
bs_val;0.18055394291877747
ret:1.0
ep_reward:None
reward_list:[1]
return after bootstrap: 1.1787484034895896

state:tensor([[-0.3477,  0.4451, -0.9187,  0.9271]], device='mps:0')
actions : tensor([[0.3344]], device='mps:0')
done:False
last_state:[-0.58407205  0.68568045 -0.51969343  0.5064329 ]
bs_val;0.11198025941848755
ret:1.0
ep_reward:None
reward_list:[1]
return after bootstrap: 1.1108604568243026

state:tensor([[-0.5841,  0.6857, -0.5197,  0.5064]], device='mps:0')
actions : tensor([[0.4569]], device='mps:0')
done:False
last_state:[-0.55482924  0.654547   -0.00333064 -0.00120364]
bs_val;0.03157982975244522
ret:1.0
ep_reward:None
reward_list:[1]
return after bootstrap: 1.0312640314549208

state:tensor([[-0.5548,  0.6545, -0.0033, -0.0012]], device='mps:0')
actions : tensor([[-0.0277]], device='mps

In [71]:

batch_states_t = torch.cat(batch_states, dim=0)
batch_actions_t = torch.cat(batch_raw_actions).float()
batch_returns_t = torch.tensor(batch_returns, dtype=torch.float32, device=device)

mu, std, value_t = policy(batch_states_t)
value_t = value_t.squeeze(-1)
dist_t = torch.distributions.Normal(mu, std)
actions_prob_t = dist_t.log_prob(batch_actions_t).sum(dim=-1)

loss_value = F.mse_loss(value_t, batch_returns_t.detach())


adv_t = (batch_returns_t - value_t).detach()
loss_policy = - (actions_prob_t * adv_t).mean()



entropy = dist_t.entropy().mean()
loss_entropy = -ENTROPY_BETA*entropy
print(f"batch_states_t: {batch_states_t}")
print(f"batch_actions_t: {batch_actions_t}")
# print(f"batch_actions_t sum: {batch_actions_t.sum(dim=-1)}")
print(f"batch_returns_t: {batch_returns_t}")

print(f"dis_t: {dist_t}")
print(f'actions_prob_t:{actions_prob_t}')

print(f"loss_value: {loss_value}")

print(f"adv_t: {adv_t}")
print(f"value_t: {value_t}")
print(f"loss_policy: {loss_policy}")
print(f"entropy: {entropy}")
print(f"loss_entropy: {loss_entropy}")

batch_states_t: tensor([[-8.6029e-03, -8.4917e-03, -1.2738e-04, -1.9363e-03],
        [-1.2055e-02, -7.2918e-04, -1.7203e-01,  3.8616e-01],
        [-1.5743e-02,  6.9793e-03, -1.3122e-02,  6.1377e-03],
        [-3.1883e-02,  4.3436e-02, -7.9160e-01,  1.7981e+00],
        [-5.8900e-02,  1.0370e-01, -5.6137e-01,  1.2355e+00],
        [-9.5712e-02,  1.8669e-01, -1.2763e+00,  2.8993e+00],
        [-9.0386e-03, -1.0421e-03,  6.5123e-03, -9.4464e-03],
        [-7.6496e-03, -4.0059e-03,  6.2797e-02, -1.3754e-01],
        [-1.4319e-02,  1.1792e-02, -3.9493e-01,  9.1576e-01],
        [-3.4847e-02,  5.8686e-02, -6.3142e-01,  1.4306e+00],
        [-5.0815e-02,  9.4200e-02, -1.6951e-01,  3.6809e-01],
        [-5.9134e-02,  1.1383e-01, -2.4642e-01,  6.1380e-01],
        [-8.3554e-02,  1.7322e-01, -9.7174e-01,  2.3377e+00],
        [ 7.0420e-03,  6.5816e-03, -1.0610e-03, -5.2243e-03],
        [ 2.2962e-02, -3.0341e-02,  7.9487e-01, -1.8224e+00],
        [ 3.8894e-02, -6.5366e-02,  5.5118e-03,  3.702

In [None]:
# state:tensor([[-0.0619,  0.1597, -0.2782,  0.7257]], device='mps:0')
# actions : tensor([[0.6015]], device='mps:0')


#single loop outputs ( batchsize = 32)

# batch_actions: [tensor([[1.2787]], device='mps:0'), tensor([[-0.5051]], device='mps:0'), tensor([[0.9226]], device='mps:0'), tensor([[-0.4800]], device='mps:0'), tensor([[-0.1513]], device='mps:0'), tensor([[-0.3652]], device='mps:0'), tensor([[-0.6029]], device='mps:0'), tensor([[0.1439]], device='mps:0'), tensor([[-0.1041]], device='mps:0'), tensor([[-0.4831]], device='mps:0'), tensor([[-0.1976]], device='mps:0'), tensor([[-0.3689]], device='mps:0'), tensor([[1.4857]], device='mps:0'), tensor([[-0.0416]], device='mps:0'), tensor([[0.3282]], device='mps:0'), tensor([[-0.3393]], device='mps:0'), tensor([[0.2818]], device='mps:0'), tensor([[-0.7668]], device='mps:0'), tensor([[1.3477]], device='mps:0'), tensor([[-0.0385]], device='mps:0'), tensor([[0.2682]], device='mps:0'), tensor([[0.4369]], device='mps:0'), tensor([[-0.4063]], device='mps:0'), tensor([[0.4942]], device='mps:0'), tensor([[0.1169]], device='mps:0'), tensor([[-1.5994]], device='mps:0'), tensor([[0.8888]], device='mps:0'), tensor([[1.2380]], device='mps:0'), tensor([[0.5268]], device='mps:0'), tensor([[0.9113]], device='mps:0'), tensor([[-0.2788]], device='mps:0'), tensor([[-0.7627]], device='mps:0'), tensor([[0.2360]], device='mps:0'), tensor([[1.4677]], device='mps:0'), tensor([[0.4888]], device='mps:0'), tensor([[0.0271]], device='mps:0'), tensor([[-0.5957]], device='mps:0'), tensor([[0.0911]], device='mps:0'), tensor([[1.0240]], device='mps:0'), tensor([[1.6481]], device='mps:0'), tensor([[-0.3191]], device='mps:0'), tensor([[0.0115]], device='mps:0'), tensor([[-1.0483]], device='mps:0'), tensor([[-0.6342]], device='mps:0'), tensor([[1.1283]], device='mps:0'), tensor([[-0.3149]], device='mps:0'), tensor([[-0.8459]], device='mps:0'), tensor([[0.2149]], device='mps:0'), tensor([[0.2619]], device='mps:0'), tensor([[0.6015]], device='mps:0')]
# batch_states_t: tensor([[ 4.5612e-03, -5.9440e-03, -4.8425e-03,  5.3283e-03],
#         [ 2.1461e-02, -4.5292e-02,  8.4738e-01, -1.9524e+00],
#         [ 4.5854e-02, -1.0036e-01,  3.7533e-01, -8.2890e-01],
#         [ 7.5390e-02, -1.6780e-01,  1.0988e+00, -2.5269e+00],
#         [-3.5143e-03, -5.7780e-03,  8.2755e-03, -5.4098e-03],
#         [-6.1724e-03,  8.4030e-04, -1.4078e-01,  3.3289e-01],
#         [-1.8738e-02,  2.9888e-02, -4.8678e-01,  1.1141e+00],
#         [-4.8847e-02,  9.8466e-02, -1.0175e+00,  2.3092e+00],
#         [-8.6524e-02,  1.8344e-01, -8.6848e-01,  1.9619e+00],
#         [-9.9881e-03,  8.2343e-03,  4.9069e-03,  1.2848e-03],
#         [-1.8769e-02,  2.9196e-02, -4.4270e-01,  1.0363e+00],
#         [-4.0273e-02,  7.8945e-02, -6.3264e-01,  1.4545e+00],
#         [-7.2532e-02,  1.5312e-01, -9.7968e-01,  2.2549e+00],
#         [ 8.1176e-04, -5.4549e-03, -2.0898e-03,  9.0584e-03],
#         [-9.1878e-05, -3.2760e-03, -4.2993e-02,  9.9050e-02],
#         [ 4.5314e-03, -1.4091e-02,  2.7323e-01, -6.3168e-01],
#         [ 8.8704e-03, -2.3743e-02, -5.4890e-02,  1.3633e-01],
#         [ 1.2224e-02, -3.1638e-02,  2.2169e-01, -5.2333e-01],
#         [ 8.2108e-03, -2.2864e-02, -4.2026e-01,  9.4330e-01],
#         [ 9.0227e-03, -2.7142e-02,  4.5771e-01, -1.1289e+00],
#         [ 2.6448e-02, -6.9611e-02,  4.1456e-01, -1.0047e+00],
#         [ 4.8237e-02, -1.2202e-01,  6.7462e-01, -1.6164e+00],
#         [-3.7639e-03,  8.2444e-03, -5.4942e-03, -2.0568e-03],
#         [-1.1695e-02,  2.6153e-02, -3.8999e-01,  8.8843e-01],
#         [-1.8094e-02,  4.0083e-02,  6.8090e-02, -1.7408e-01],
#         [-1.3169e-02,  2.8968e-02,  1.7802e-01, -3.8096e-01],
#         [-2.4560e-02,  5.7393e-02, -7.4454e-01,  1.7766e+00],
#         [-4.0037e-02,  9.4836e-02, -3.2908e-02,  1.2778e-01],
#         [-2.4730e-02,  6.3336e-02,  7.9642e-01, -1.6853e+00],
#         [ 1.6366e-02, -2.2557e-02,  1.2584e+00, -2.6133e+00],
#         [ 8.0660e-02, -1.5706e-01,  1.9538e+00, -4.1080e+00],
#         [ 2.8991e-03,  1.2733e-03, -8.0541e-04,  7.9827e-03],
#         [-9.9628e-03,  3.1270e-02, -6.4051e-01,  1.4769e+00],
#         [-3.0804e-02,  7.8464e-02, -4.0343e-01,  9.0071e-01],
#         [-2.9077e-02,  7.3944e-02,  4.8683e-01, -1.0998e+00],
#         [-8.9494e-04,  1.2327e-02,  9.2215e-01, -1.9816e+00],
#         [ 3.6189e-02, -6.5191e-02,  9.3342e-01, -1.9101e+00],
#         [ 6.2722e-02, -1.1600e-01,  3.9649e-01, -6.5983e-01],
#         [ 8.0529e-02, -1.4827e-01,  4.9392e-01, -9.5565e-01],
#         [-2.4337e-03, -8.6594e-03,  6.9423e-04,  3.7023e-03],
#         [ 1.6138e-02, -5.1468e-02,  9.2517e-01, -2.1218e+00],
#         [ 4.6761e-02, -1.2045e-01,  6.0870e-01, -1.3530e+00],
#         [ 7.1371e-02, -1.7598e-01,  6.2259e-01, -1.4331e+00],
#         [ 8.6990e-03, -6.3233e-03,  4.1285e-04,  7.0600e-03],
#         [-2.4667e-03,  1.9695e-02, -5.5719e-01,  1.2810e+00],
#         [-8.4385e-03,  3.2398e-02,  2.5537e-01, -6.1651e-01],
#         [-4.4765e-03,  2.3398e-02, -5.5927e-02,  1.5399e-01],
#         [-2.0489e-02,  6.1659e-02, -7.4270e-01,  1.7433e+00],
#         [-4.5841e-02,  1.2080e-01, -5.2696e-01,  1.2333e+00],
#         [-6.1909e-02,  1.5966e-01, -2.7817e-01,  7.2573e-01]], device='mps:0')
# batch_actions_t: tensor([[ 1.2787],
#         [-0.5051],
#         [ 0.9226],
#         [-0.4800],
#         [-0.1513],
#         [-0.3652],
#         [-0.6029],
#         [ 0.1439],
#         [-0.1041],
#         [-0.4831],
#         [-0.1976],
#         [-0.3689],
#         [ 1.4857],
#         [-0.0416],
#         [ 0.3282],
#         [-0.3393],
#         [ 0.2818],
#         [-0.7668],
#         [ 1.3477],
#         [-0.0385],
#         [ 0.2682],
#         [ 0.4369],
#         [-0.4063],
#         [ 0.4942],
#         [ 0.1169],
#         [-1.5994],
#         [ 0.8888],
#         [ 1.2380],
#         [ 0.5268],
#         [ 0.9113],
#         [-0.2788],
#         [-0.7627],
#         [ 0.2360],
#         [ 1.4677],
#         [ 0.4888],
#         [ 0.0271],
#         [-0.5957],
#         [ 0.0911],
#         [ 1.0240],
#         [ 1.6481],
#         [-0.3191],
#         [ 0.0115],
#         [-1.0483],
#         [-0.6342],
#         [ 1.1283],
#         [-0.3149],
#         [-0.8459],
#         [ 0.2149],
#         [ 0.2619],
#         [ 0.6015]], device='mps:0')
# batch_actions_t sum: tensor([ 1.2787, -0.5051,  0.9226, -0.4800, -0.1513, -0.3652, -0.6029,  0.1439,
#         -0.1041, -0.4831, -0.1976, -0.3689,  1.4857, -0.0416,  0.3282, -0.3393,
#          0.2818, -0.7668,  1.3477, -0.0385,  0.2682,  0.4369, -0.4063,  0.4942,
#          0.1169, -1.5994,  0.8888,  1.2380,  0.5268,  0.9113, -0.2788, -0.7627,
#          0.2360,  1.4677,  0.4888,  0.0271, -0.5957,  0.0911,  1.0240,  1.6481,
#         -0.3191,  0.0115, -1.0483, -0.6342,  1.1283, -0.3149, -0.8459,  0.2149,
#          0.2619,  0.6015], device='mps:0')
# batch_returns_t: tensor([ 2.9701,  1.9900,  1.0000,  0.0000,  3.9404,  2.9701,  1.9900,  1.0000,
#          0.0000,  2.9701,  1.9900,  1.0000,  0.0000,  7.7255,  6.7935,  5.8520,
#          4.9010,  3.9404,  2.9701,  1.9900,  1.0000,  0.0000,  7.7255,  6.7935,
#          5.8520,  4.9010,  3.9404,  2.9701,  1.9900,  1.0000,  0.0000,  6.7935,
#          5.8520,  4.9010,  3.9404,  2.9701,  1.9900,  1.0000,  0.0000,  2.9701,
#          1.9900,  1.0000,  0.0000, 18.2093, 17.3831, 16.5486, 15.7057, 14.8542,
#         13.9942, 13.1254], device='mps:0')
# dis_t: Normal(loc: torch.Size([50, 1]), scale: torch.Size([50, 1]))
# actions_prob_t:tensor([-0.4574, -0.7943, -0.4546, -0.7709, -0.6558, -0.8042, -1.0410, -0.7419,
#         -0.8810, -0.7865, -0.8240, -0.9640,  0.1293, -0.6181, -0.4856, -0.6520,
#         -0.5033, -0.8499, -0.1618, -0.5619, -0.4774, -0.4338, -0.7566, -0.4353,
#         -0.5437, -1.8201, -0.1695, -0.4007, -0.3870, -0.3706, -0.6982, -0.9345,
#         -0.6105, -0.2013, -0.4091, -0.5548, -0.8383, -0.5283, -0.4903, -0.7511,
#         -0.7024, -0.5511, -1.1298, -0.8620, -0.1132, -0.6405, -1.0252, -0.6416,
#         -0.5891, -0.3968], device='mps:0', grad_fn=<SubBackward0>)
# loss_value: 48.21304702758789
# adv_t: tensor([ 3.1246e+00,  2.3859e+00,  1.2986e+00,  4.4293e-01,  4.0973e+00,
#          3.1171e+00,  2.0845e+00,  9.4994e-01, -5.3820e-03,  3.1267e+00,
#          2.0919e+00,  1.0555e+00, -4.6113e-02,  7.8807e+00,  6.9455e+00,
#          6.1175e+00,  5.0521e+00,  4.1826e+00,  3.0799e+00,  2.3132e+00,
#          1.3146e+00,  3.7432e-01,  7.8808e+00,  6.9078e+00,  6.0290e+00,
#          5.1133e+00,  3.9548e+00,  3.1183e+00,  2.3629e+00,  1.4416e+00,
#          5.4254e-01,  6.9476e+00,  5.9029e+00,  5.0145e+00,  4.2586e+00,
#          3.3650e+00,  2.3852e+00,  1.2755e+00,  3.1532e-01,  3.1266e+00,
#          2.3989e+00,  1.3518e+00,  3.6137e-01,  1.8363e+01,  1.7457e+01,
#          1.6806e+01,  1.5854e+01,  1.4873e+01,  1.4075e+01,  1.3249e+01],
#        device='mps:0')
# value_t: tensor([-0.1545, -0.3959, -0.2986, -0.4429, -0.1569, -0.1470, -0.0945,  0.0501,
#          0.0054, -0.1566, -0.1019, -0.0555,  0.0461, -0.1552, -0.1521, -0.2655,
#         -0.1511, -0.2422, -0.1098, -0.3232, -0.3146, -0.3743, -0.1552, -0.1143,
#         -0.1771, -0.2123, -0.0144, -0.1482, -0.3729, -0.4416, -0.5425, -0.1541,
#         -0.0509, -0.1135, -0.3182, -0.3949, -0.3952, -0.2755, -0.3153, -0.1565,
#         -0.4089, -0.3518, -0.3614, -0.1536, -0.0741, -0.2573, -0.1479, -0.0183,
#         -0.0805, -0.1232], device='mps:0', grad_fn=<SqueezeBackward1>)
# loss_policy: 3.0877444744110107
# entropy: 1.0807201862335205
# loss_entropy: -0.010807201266288757

In [None]:
# bs val: tensor([[0.1607]], device='mps:0', grad_fn=<LinearBackward0>)
# ret tensor([[9.7072]], device='mps:0', grad_fn=<AddBackward0>)

# # with .item(). ( correct)
# bs val: 0.12476468086242676
# ret 9.78975201174128

