In [1]:
import gymnasium as gym 
import numpy as np
import pandas as pd
import typing as tt
import torch  
import torch.nn as nn 
import torch.nn.functional as F
import wandb
from collections import deque
import os

from gymnasium.wrappers import NormalizeObservation, NormalizeReward



In [2]:
# # GAE already computed
# advantages = gae_advantages.detach()

# # Value target
# value_target = advantages + values.detach()

# # Policy loss
# policy_loss = -(log_probs * advantages).mean()

# # Value loss
# value_loss = 0.5 * (values - value_target).pow(2).mean()

# # Entropy bonus
# entropy_loss = -entropy.mean()

# # Total loss
# loss = policy_loss + value_coef * value_loss + entropy_coef * entropy_loss

In [3]:
HIDDEN_LAYER1  = 128
# ALPHA = 0.95
GAMMA = 0.9 # DISCOUNT FACTOR
LAMBDA = 0.95 # FOR GAE
LR = 1e-3
# N_STEPS = 20
ENV_ID = 'InvertedPendulum-v5'
N_ENV = 1
BATCH_SIZE = 64

ENTROPY_BETA = 0.01
ENTROPY_BETA_MIN = 0.001
entropy_smoothing_factor = 0.05

if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu' 
print(f'Using device : {device}')

Using device : mps


In [4]:
env = gym.make(ENV_ID)

In [5]:
class PolicyNet(nn.Module):
    def __init__(self, input_size, fc, action_dim, log_std_min, log_std_max):
        super().__init__()
        self.input_size = input_size
        self.fc = fc
        self.action_dim = action_dim
        self.log_std_min = log_std_min
        self.log_std_max = log_std_max
        
        self.net = nn.Sequential(
            nn.Linear(self.input_size, self.fc), 
            nn.ReLU(), 
            nn.Linear(self.fc, self.fc), 
            nn.ReLU()
        )
        
        self.mu = nn.Linear(self.fc, self.action_dim)
        
        self.log_std = nn.Parameter(torch.zeros(self.action_dim))
        
        self.critic_head = nn.Linear(self.fc, 1)
        
    def forward(self, x):
        x = self.net(x)
        
        mu = self.mu(x)
        std = torch.exp(torch.clamp(self.log_std, self.log_std_min, self.log_std_max))
        std = std.expand_as(mu)
        
        val = self.critic_head(x)
        return mu, std, val
    
class BetaScheduler:
    def __init__(self, target_reward, beta_start, beta_min=1e-4, smoothing_factor=0.01):
        self.target = target_reward
        self.start = beta_start
        self.min = beta_min
        self.alpha = smoothing_factor
        self.ema_reward = None  # Exponential Moving Average of Reward
        self.current_beta = beta_start

    def update(self, reward):
        # 1. Update EMA of Reward
        if self.ema_reward is None:
            self.ema_reward = reward
        else:
            self.ema_reward = (self.ema_reward * (1 - self.alpha)) + (reward * self.alpha)
        
        # 2. Calculate Progress (0.0 to 1.0) based on EMA
        # If ema_reward is negative, treat progress as 0
        progress = max(0.0, min(1.0, self.ema_reward / self.target))
        
        # 3. Decay Beta linearly with progress
        self.current_beta = self.start * (1.0 - progress)
        
        # 4. Clamp to minimum
        self.current_beta = max(self.current_beta, self.min)
        
        return self.current_beta

In [6]:
def compute_gae(deltas, dones, gamma, lam):
    deltas_t = torch.tensor(deltas, dtype=torch.float32)
    dones_t = torch.tensor(dones, dtype=torch.float32)

    mask = 1.0 - dones_t

    T = deltas_t.shape[0]
    adv = torch.zeros_like(deltas_t)
    gae = 0.0

    for t in reversed(range(T)):
        gae = deltas_t[t] + gamma * lam * mask[t] * gae
        adv[t] = gae

    return adv

In [14]:
from numpy import dtype


class NStepCollector:
    def __init__(self, env, policy, gamma, lam, batch_size, device):
        # super().__init__(self,)
        self.env = env
        self.policy = policy
        self.gamma = gamma
        self.lam = lam
        self.batch_size = batch_size
        self.device = device
        
        self.ep_reward = 0
        
        self.state, _ = env.reset()
        action_low = torch.tensor(env.action_space.low, dtype=torch.float32, device=device)
        action_high = torch.tensor(env.action_space.high, dtype=torch.float32, device=device)
        self.action_bias = (action_high + action_low) / 2
        self.action_scale = (action_high - action_low) / 2
                
        self.states = deque(maxlen=batch_size)
        self.rawactions = deque(maxlen=batch_size)
        self.rewards = deque(maxlen=batch_size)
        self.terms = deque(maxlen=batch_size)
        self.next_states = deque(maxlen=batch_size)
        self.deltas = deque(maxlen=batch_size)
        self.values = deque(maxlen=batch_size)
        # sel

            
    def rollout(self):
        while True:
            # print(f"state: {self.state}")
            
            state_t = torch.tensor(self.state, dtype=torch.float32, device=device).unsqueeze(0)
            
            with torch.no_grad():
                mu, std, _ = self.policy(state_t)
            
            # print('mu', mu)
            # print('std', std)
            dist = torch.distributions.Normal(mu,std)
            u = dist.sample()
            a = torch.tanh(u)
            action = a*self.action_scale + self.action_bias
            action_env = action.squeeze(0).detach().cpu().numpy()
            
            next_state, rew, term, trunc, info = self.env.step(action_env)
            # self.next_state_t = torch.Tensor(next_state, dtype=torch.float32, device=device).unsqueeze(0)
            self.ep_reward += rew
            done = term or trunc
            next_state_t = torch.tensor(next_state, dtype=torch.float32, device=device).unsqueeze(0)

            if not term:
                with torch.no_grad():
                    _, _ , v_t = self.policy(state_t)
                    _, _, v_t1 = self.policy(next_state_t)
                v_t = v_t.item()
                v_t1 = v_t1.item()
            else:
                with torch.no_grad():
                    _, _ , v_t = self.policy(state_t)
                v_t = v_t.item()
                v_t1 = 0
            
            
            delta = rew + self.gamma*v_t1 - v_t
            
            self.states.append(state_t)
            self.rewards.append(rew)
            self.terms.append(term)
            self.next_states.append(next_state)
            self.deltas.append(delta)
            self.rawactions.append(u)
            self.values.append(v_t)
            if len(self.states)>=self.batch_size:
                yield {
                        'states':list(self.states), 
                        'actions':list(self.rawactions), 
                        'dones':list(self.terms), 
                        'deltas':list(self.deltas),
                        'ep_reward': self.ep_reward if done else None,
                        'values':list(self.values)
                }
            
            else: 
                yield None
                
            self.state = next_state
            if term or trunc:
                # print("reset")
                self.state, _ = self.env.reset()
                self.ep_reward = 0
            
        
        

In [15]:
# batch_states = []
# batch_actions = []
# batch_gae = []
# batch_values = []


policy = PolicyNet(
    env.observation_space.shape[0], 
    HIDDEN_LAYER1, 
    env.action_space.shape[0], 
    log_std_min=-20, 
    log_std_max=1,
).to(device)

total_updates = 20000
optimizer = torch.optim.Adam(policy.parameters(), lr = LR)
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lr_lambda=lambda upd: 1.0 - min(upd, total_updates) / total_updates
)

current_beta = ENTROPY_BETA
beta_scheduler = BetaScheduler(
    target_reward=950, 
    beta_start=ENTROPY_BETA, 
    beta_min=ENTROPY_BETA_MIN, 
    smoothing_factor=entropy_smoothing_factor
)


In [17]:
exp_collector = NStepCollector(env, policy, GAMMA, LAMBDA, BATCH_SIZE, device)
total_rewards = []
episode_idx = 0
for step_idx, exp in enumerate(exp_collector.rollout()):
    # exp = exp_collector.rollout()
    # print(exp)
    if exp is None:
        continue
    
    if exp['ep_reward'] is not None:
        # --- NEW: Update Beta when episode finishes ---
        current_beta = beta_scheduler.update(exp['ep_reward'])
        episode_reward = exp['ep_reward']
        total_rewards.append(episode_reward)
        mean_reward = float(np.mean(total_rewards[-100:]))
        print(f"episode : {episode_idx} | step: {step_idx} | episode reward : {episode_reward} | mean reward/100 eps : {mean_reward}")
        # wandb.log({
        #     "episode_reward": episode_reward, 
        #     "mean_reward_100": mean_reward,  
        #     "entropy_beta": current_beta,  # Log this to track decay!
        #     'episode_number': episode_idx,   
        #     "steps_per_episode": exp['ep_steps']
        # }, step=step_idx)
        episode_idx += 1
        
        if mean_reward>950:
            # save_path = os.path.join(wandb.run.dir, "policy_best.pt")
            # torch.save(policy.state_dict(), save_path)
            # wandb.log({"best_policy_path": save_path}, step=step_idx)
            print(f"Solved! Mean reward > 450 at episode {episode_idx}")
            break
    
    states_list = exp['states']
    rawactions_list = exp['actions']
    dones_list = exp['dones']
    deltas_list = exp['deltas']
    values_list = exp['values']
    adv_list = compute_gae(deltas_list, dones_list, GAMMA, LAMBDA)
    
    batch_states = torch.cat(states_list, dim =0)
    batch_actions = torch.cat(rawactions_list, dim=0)
    batch_adv = torch.tensor(adv_list, dtype = torch.float32, device=device)
    batch_value = torch.tensor(values_list,dtype = torch.float32, device=device)
    
    
    mu, std, value = policy(batch_states)
    value_t = value.squeeze(dim=1)

    returns = batch_adv + batch_value
    loss_critic = F.mse_loss(value_t, returns.detach())
    
    dist = torch.distributions.Normal(mu, std)
    logp_u = dist.log_prob(batch_actions).sum(dim=-1)
    a_t = torch.tanh(batch_actions)
    logp_correction = torch.log(( 1 - a_t.pow(2))+1e-6).sum(dim=-1)
    logp = logp_u - logp_correction
    
    
    # logp_u1 = dist.log_prob(batch_actions).sum(dim=-1)
    # a_t = torch.tanh(batch_actions)
    # logp_correction1 = torch.log(( 1 - a_t.pow(2))+1e-6).sum(dim=-1)
    # logp1 = logp_u1 - logp_correction1
    
    loss_policy = -(logp * batch_adv.detach()).mean()
    
    entropy = dist.entropy().sum(dim=-1).mean()
    
    total_loss = loss_critic + loss_policy - ENTROPY_BETA*entropy
    
    optimizer.zero_grad()
    total_loss.backward()
    torch.nn.utils.clip_grad_norm_(policy.parameters(), max_norm=0.5)
    optimizer.step()
    scheduler.step()
    
    # print(f"batch_states: {batch_states}")
    # print(f'batch_rawactions:{batch_actions}')
    # print(f"dones: {dones_list}")
    # print(f"deltas: {deltas_list}")
    # print(f"batch_adv: {batch_adv}")
    # print(f"batch_values: {value}, {value.shape}")
    # # print(f"batch_values, dim0: {value.squeeze(dim=0)}")
    # # print(f"batch_values, dim1: {value.squeeze(dim=1)}")
    # print(f"value_t :{value_t}")
    # print(f"dist_t: {dist}")
    # print(f"logp_u: {logp_u}")
    
    # print(f"logp_correction: {logp_correction}")
    # print(f"logp: {logp}")
    # print(f'loss_critic: {loss_critic}')
    # print(f'loss policy : {loss_policy}')
    # print(f'entropy: {entropy}')
    # print(f"entropy mean:{entropy.mean()}")
    # print(f"entropy sum:{entropy.sum(dim=-1)}")
    # print(f"entropy sum mean (used):{entropy.sum(dim=-1).mean()}")
    # print(f"log_p: {logp_correction.sum(dim=-1)}")
    # print(f"total_loss: {total_loss}")
    
    
    # returns = 
    
    
    # loss_critic = 
    
    # break
    if step_idx > 200000:
        break
    


episode : 0 | step: 63 | episode reward : 6 | mean reward/100 eps : 6.0
episode : 1 | step: 68 | episode reward : 4 | mean reward/100 eps : 5.0
episode : 2 | step: 72 | episode reward : 3 | mean reward/100 eps : 4.333333333333333
episode : 3 | step: 75 | episode reward : 2 | mean reward/100 eps : 3.75
episode : 4 | step: 80 | episode reward : 4 | mean reward/100 eps : 3.8
episode : 5 | step: 83 | episode reward : 2 | mean reward/100 eps : 3.5


  batch_adv = torch.tensor(adv_list, dtype = torch.float32, device=device)


episode : 6 | step: 94 | episode reward : 10 | mean reward/100 eps : 4.428571428571429
episode : 7 | step: 124 | episode reward : 29 | mean reward/100 eps : 7.5
episode : 8 | step: 158 | episode reward : 33 | mean reward/100 eps : 10.333333333333334
episode : 9 | step: 194 | episode reward : 35 | mean reward/100 eps : 12.8
episode : 10 | step: 220 | episode reward : 25 | mean reward/100 eps : 13.909090909090908
episode : 11 | step: 225 | episode reward : 4 | mean reward/100 eps : 13.083333333333334
episode : 12 | step: 230 | episode reward : 4 | mean reward/100 eps : 12.384615384615385
episode : 13 | step: 234 | episode reward : 3 | mean reward/100 eps : 11.714285714285714
episode : 14 | step: 237 | episode reward : 2 | mean reward/100 eps : 11.066666666666666
episode : 15 | step: 240 | episode reward : 2 | mean reward/100 eps : 10.5
episode : 16 | step: 243 | episode reward : 2 | mean reward/100 eps : 10.0
episode : 17 | step: 246 | episode reward : 2 | mean reward/100 eps : 9.5555555

KeyboardInterrupt: 

In [None]:
env.close()

In [None]:
# state: [-0.00128835  0.00534891 -0.31165499  0.69753022]
# mu tensor([[-0.1096]], device='mps:0')
# std tensor([[1.]], device='mps:0')
# {'states': [tensor([[ 0.0089, -0.0047,  0.0056,  0.0059]], device='mps:0'), tensor([[-0.0085,  0.0360, -0.8698,  2.0097]], device='mps:0'), tensor([[-0.0615,  0.1570, -1.7760,  4.0253]], device='mps:0'), tensor([[ 0.0050, -0.0086, -0.0011, -0.0086]], device='mps:0'), tensor([[-0.0013,  0.0053, -0.3117,  0.6975]], device='mps:0')], 'actions': [tensor([[-1.3766]], device='mps:0'), tensor([[-1.6407]], device='mps:0'), tensor([[-1.1694]], device='mps:0'), tensor([[-0.3236]], device='mps:0'), tensor([[0.6305]], device='mps:0')], 'dones': [False, False, True, False, False], 'deltas': [1.0049174800515175, 1.0309211984276772, 0.08663583546876907, 1.0279532670974731, 0.93100406229496]}
# batch_states: tensor([[ 8.8630e-03, -4.7227e-03,  5.5838e-03,  5.9092e-03],
#         [-8.4699e-03,  3.6001e-02, -8.6976e-01,  2.0097e+00],
#         [-6.1460e-02,  1.5701e-01, -1.7760e+00,  4.0253e+00],
#         [ 4.9838e-03, -8.5719e-03, -1.1198e-03, -8.6451e-03],
#         [-1.2883e-03,  5.3489e-03, -3.1165e-01,  6.9753e-01]], device='mps:0')
# batch_rawactions:tensor([[-1.3766],
#         [-1.6407],
#         [-1.1694],
#         [-0.3236],
#         [ 0.6305]], device='mps:0')
# dones: [False, False, True, False, False]
# deltas: [1.0049174800515175, 1.0309211984276772, 0.08663583546876907, 1.0279532670974731, 0.93100406229496]
# batch_adv: tensor([1.0049, 1.0309, 0.9655, 1.0280, 0.9310], device='mps:0')
# batch_values: tensor([[-0.1225],
#         [-0.1176],
#         [-0.0866],
#         [-0.1232],
#         [-0.0953]], device='mps:0', grad_fn=<LinearBackward0>), torch.Size([5, 1])
# value_t :tensor([-0.1225, -0.1176, -0.0866, -0.1232, -0.0953], device='mps:0',
#        grad_fn=<SqueezeBackward1>)
# dist_t: Normal(loc: torch.Size([5, 1]), scale: torch.Size([5, 1]))
# logp_u: tensor([-1.7717, -1.9905, -1.3410, -0.9512, -1.1928], device='mps:0',
#        grad_fn=<SumBackward1>)
# logp_correction: tensor([-1.4904, -1.9689, -1.1366, -0.1030, -0.3737], device='mps:0')
# logp: tensor([-0.2813, -0.0216, -0.2044, -0.8482, -0.8192], device='mps:0',
#        grad_fn=<SubBackward0>)
# loss_critic: 0.012116288766264915
# loss policy : 0.42738208174705505
# entropy: tensor([[1.4189],
#         [1.4189],
#         [1.4189],
#         [1.4189],
#         [1.4189]], device='mps:0', grad_fn=<AddBackward0>)
# entropy mean:1.4189385175704956
# entropy sum:tensor([1.4189, 1.4189, 1.4189, 1.4189, 1.4189], device='mps:0',
#        grad_fn=<SumBackward1>)
# entropy sum mean (used):1.418938398361206