In [1]:
import gymnasium as gym 
import torch 
import torch.nn as nn
# Create the CartPole-v1 environment; render_mode="human" asks gymnasium to
# display the simulation while it runs.
env = gym.make("CartPole-v1", render_mode="human")
# Initial reset, passing seed=42 to the environment.
obs, info = env.reset(seed=42)
# Convert with an explicit float32 dtype, the same way the training loop
# converts every observation before feeding it to the policy.
obs = torch.tensor(obs, dtype=torch.float32)

class Policy(nn.Module):
    """Fully connected network: 4 inputs -> 256 -> 256 -> 128 -> 128 -> 2 outputs.

    A ReLU follows every linear layer except the last, so the two outputs
    are raw (unnormalized) scores.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        widths = (4, 256, 256, 128, 128, 2)
        stack = []
        for n_in, n_out in zip(widths, widths[1:]):
            # Put a ReLU between consecutive linear layers, never after the last.
            if stack:
                stack.append(nn.ReLU())
            stack.append(nn.Linear(n_in, n_out))
        self.core = nn.Sequential(*stack)

    def forward(self, x):
        """Run ``x`` through the layer stack and return the two output scores."""
        return self.core(x)
    
# Training hyperparameters: number of episodes to run and the discount factor.
episodes = 1000
gamma = 0.99

# The network being trained, and the Adam optimizer over all of its parameters.
policy = Policy()
optimizer = torch.optim.Adam(params=policy.parameters(), lr=1e-4)



# REINFORCE training loop: each iteration plays one full episode with the
# current policy, then takes a single gradient step on that episode's data.
for episode in range(episodes):
    obs, info = env.reset()
    obs = torch.tensor(obs, dtype=torch.float32)
    done = False
    observations = []
    actions = []
    rewards = []

    # --- Rollout: sample actions from the policy until the episode ends. ---
    while not done:
        # No gradients are needed while sampling; the log-probabilities used
        # for the loss are recomputed in one batched pass after the episode.
        with torch.no_grad():
            probs = torch.softmax(policy(obs), dim=-1)
        action = torch.multinomial(probs, num_samples=1).item()
        observations.append(obs)
        actions.append(action)

        obs, reward, terminated, truncated, info = env.step(action)
        obs = torch.tensor(obs, dtype=torch.float32)
        rewards.append(reward)
        done = terminated or truncated

    # --- Discounted returns: G_t = r_t + gamma * G_{t+1}, built back to front. ---
    Gs = []
    G = 0
    for reward in reversed(rewards):
        G = reward + gamma * G
        Gs.append(G)
    Gs.reverse()
    Gs = torch.tensor(Gs, dtype=torch.float32)
    # Standardize the returns within the episode; the epsilon avoids a
    # division by zero when every return is identical (e.g. a 1-step episode).
    Gs = (Gs - Gs.mean()) / (Gs.std(unbiased=False) + 1e-8)

    # --- Policy-gradient loss, computed for the whole episode at once. ---
    batch_obs = torch.stack(observations)
    batch_actions = torch.tensor(actions, dtype=torch.int64)
    # log_softmax is used instead of log(softmax(...)) so a probability that
    # underflows to zero cannot produce -inf / NaN in the loss.
    log_probs = torch.log_softmax(policy(batch_obs), dim=-1)
    # Pick out the log-probability of the action actually taken at each step.
    taken_log_probs = log_probs.gather(1, batch_actions.unsqueeze(1)).squeeze(1)
    loss = -(taken_log_probs * Gs).sum()

    # The printed value is the summed (not yet averaged) episode loss.
    print(f"Episode {episode} , Loss: {loss.item()}")
    loss = loss / len(observations)

    optimizer.zero_grad()
    loss.backward()
    # Clip the global gradient norm to 1.0 before applying the update.
    torch.nn.utils.clip_grad_norm_(policy.parameters(), max_norm=1.0)
    optimizer.step()
   
        
        




        

Episode 0 , Loss: -0.046221017837524414
Episode 1 , Loss: 0.0021431446075439453
Episode 2 , Loss: -0.135903000831604
Episode 3 , Loss: -0.04564619064331055
Episode 4 , Loss: 0.03138303756713867
Episode 5 , Loss: -0.0523756742477417
Episode 6 , Loss: 0.07188558578491211
Episode 7 , Loss: 0.08414971828460693
Episode 8 , Loss: -0.0907902717590332
Episode 9 , Loss: 0.12313699722290039
Episode 10 , Loss: -0.04556405544281006
Episode 11 , Loss: 0.08943355083465576
Episode 12 , Loss: -0.07942783832550049
Episode 13 , Loss: 0.09399676322937012
Episode 14 , Loss: 0.03040444850921631
Episode 15 , Loss: -0.28500354290008545
Episode 16 , Loss: 0.06311023235321045
Episode 17 , Loss: -0.023213982582092285
Episode 18 , Loss: 0.04444706439971924
Episode 19 , Loss: 0.007896065711975098
Episode 20 , Loss: -0.14670145511627197
Episode 21 , Loss: -0.07906579971313477
Episode 22 , Loss: -0.11578667163848877
Episode 23 , Loss: -0.005194425582885742
Episode 24 , Loss: 0.0058133602142333984
Episode 25 , Loss:

KeyboardInterrupt: 