In [1]:
import torch.nn as nn 
import torch 
from copy import deepcopy
class Policy(nn.Module):
    def __init__(self):
        super().__init__()
        self.core = nn.Sequential(
            nn.Linear(4,256),
            nn.ReLU(),
            nn.Linear(256,256),
            nn.ReLU(),
            nn.Linear(256,128),
            nn.ReLU(),
            nn.Linear(128,128),
            nn.ReLU(),
            nn.Linear(128,2)
        )
        
    def forward(self,x):
        return self.core(x)


class Value(nn.Module):
    def __init__(self):
        super().__init__()
        self.core = nn.Sequential(
            nn.Linear(4,256),
            nn.ReLU(),
            nn.Linear(256,256),
            nn.ReLU(),
            nn.Linear(256,128),
            nn.ReLU(),
            nn.Linear(128,128),
            nn.ReLU(),
            nn.Linear(128,1)
        )
        
    def forward(self,x):
        return self.core(x)


policyModel = Policy()

oldPolicyModel = deepcopy(policyModel)
def update_old_policy():
    global oldPolicyModel
    oldPolicyModel = deepcopy(policyModel)
    for param in oldPolicyModel.parameters():
        param.requires_grad = False

valueModel = Value()


In [2]:
import gymnasium as gym 



env = gym.make("CartPole-v1")
ppo_epochs = 4

In [3]:
def is_terminal(observation):
    x,v,theta , omega = observation 
    x_terminal = x <= -4.8 or x>=4.8 
    theta_terminal = theta <= -0.209*7.5 or theta >= 0.209*7.5
    return x_terminal or theta_terminal


episodes = 10000

value_optimizer = torch.optim.Adam(valueModel.parameters(), lr=1e-3)
policy_optimizer = torch.optim.Adam(policyModel.parameters(),lr=1e-3)
valueLoss = torch.nn.functional.mse_loss
gamma = 0.99
epsilon =0.3
update_old_step = 10
update_old_policy() 
for episode in range(episodes):
    done = False
    obs, info = env.reset()
    obs = torch.tensor(obs)
    observations = []
    rewards = []
    actions = []
    confidences= []
    
    with torch.no_grad():
        while not done:
            
            observations.append(obs)
            logits = policyModel(obs)
            probs = torch.softmax(logits,dim=-1)

            action = torch.multinomial(probs, num_samples=1).item()
            confidences.append(probs[action]) 
            actions.append(action) 
            obs, reward, terminated, truncated , info = env.step(action)
            rewards.append(reward)
            obs = torch.tensor(obs)
            done = is_terminal(obs)
        
        
        
        discontinued_rewards = []
        reward = 0 
        observation_tensor = torch.stack(observations) 
        for i in range(-1,-len(rewards) -1,-1):
            reward = rewards[i] + gamma * reward
            discontinued_rewards.append(reward)


        discontinued_rewards.reverse()
        discontinued_rewards = torch.tensor(discontinued_rewards,dtype=torch.float32)
    
    
    reward_preds = valueModel(observation_tensor)   
    loss = valueLoss(reward_preds.squeeze(-1),discontinued_rewards)
    value_optimizer.zero_grad()
    loss.backward() 
    value_optimizer.step()
    reward_preds = reward_preds.detach()
    
    total_loss = 0
    advantages = []
    for i in range(len(observations)):
        predicted_reward = reward_preds[i]
        reward = discontinued_rewards[i] 
        advantage = reward - predicted_reward  
        advantages.append(advantage)
    advantages = torch.stack(advantages) 
    advantages = (advantages - advantages.mean()) / (advantages.std(unbiased=False) + 1e-8)

    for i in range(len(observations)):
        advantage = advantages[i]  
        current_logits = policyModel(observations[i])
        current_log_probs = torch.log_softmax(current_logits,dim=-1)

        current_confidence = current_log_probs[actions[i]] 

        old_logits = oldPolicyModel(observations[i])
       
        
        old_log_probs = torch.log_softmax(old_logits,dim=-1)
        old_confidence = old_log_probs[actions[i]]

        confidence_ratio = torch.exp(current_confidence - old_confidence)
        loss = confidence_ratio*advantage   
        loss2 =torch.clip( (confidence_ratio) , 1-epsilon , 1+epsilon)*advantage 
        loss = -torch.min(loss, loss2)
        
        total_loss += loss 
    total_loss /= len(observations)
    print('episode: ',episode + 1, ' loss: ',total_loss)
    policy_optimizer.zero_grad()
    total_loss.backward()
    policy_optimizer.step()
    if  (episode + 1)  % update_old_step == 0:
        update_old_policy()
    







    

    

  logger.warn(


episode:  1  loss:  tensor([2.1674e-08], grad_fn=<DivBackward0>)
episode:  2  loss:  tensor([-0.0029], grad_fn=<DivBackward0>)
episode:  3  loss:  tensor([0.0117], grad_fn=<DivBackward0>)
episode:  4  loss:  tensor([-0.0011], grad_fn=<DivBackward0>)
episode:  5  loss:  tensor([-4.1977e-05], grad_fn=<DivBackward0>)
episode:  6  loss:  tensor([-0.0043], grad_fn=<DivBackward0>)
episode:  7  loss:  tensor([-0.0055], grad_fn=<DivBackward0>)
episode:  8  loss:  tensor([0.0235], grad_fn=<DivBackward0>)
episode:  9  loss:  tensor([0.0057], grad_fn=<DivBackward0>)
episode:  10  loss:  tensor([0.0232], grad_fn=<DivBackward0>)
episode:  11  loss:  tensor([1.9463e-08], grad_fn=<DivBackward0>)
episode:  12  loss:  tensor([0.0152], grad_fn=<DivBackward0>)
episode:  13  loss:  tensor([0.0014], grad_fn=<DivBackward0>)
episode:  14  loss:  tensor([-0.0010], grad_fn=<DivBackward0>)
episode:  15  loss:  tensor([0.0006], grad_fn=<DivBackward0>)
episode:  16  loss:  tensor([-0.0009], grad_fn=<DivBackward0>

KeyboardInterrupt: 

In [4]:
import gymnasium as gym 



env = gym.make("CartPole-v1",render_mode='human')

In [5]:


for episode in range(episodes):
    done = False
    obs, info = env.reset(seed=42)
    obs = torch.tensor(obs)
    observations = []
    rewards = []
    actions = []
    confidences= []
    with torch.no_grad():
        while not done:
            
            observations.append(obs)
            logits = policyModel(obs)
            probs = torch.softmax(logits,dim=-1)

            action = torch.multinomial(probs, num_samples=1).item()
            confidences.append(probs[action])
            actions.append(action)
            obs, reward, terminated, truncated , info = env.step(action)
            rewards.append(reward)
            obs = torch.tensor(obs)
            done = is_terminal(obs)
        print('episode: ',episode + 1)
        

    

    

    

  logger.warn(


episode:  1


KeyboardInterrupt: 