<a href="https://colab.research.google.com/github/manikanta-eng/Reinforcement-learning/blob/main/lab_09_rml.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import gymnasium as gym
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from collections import deque

env_name = "CartPole-v1"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class ActorCritic(nn.Module):
    def __init__(self, obs_dim, act_dim):
        super().__init__()
        self.shared = nn.Sequential(
            nn.Linear(obs_dim, 64), nn.Tanh(),
            nn.Linear(64, 64), nn.Tanh()
        )
        self.policy = nn.Sequential(nn.Linear(64, act_dim), nn.Softmax(dim=-1))
        self.value = nn.Linear(64, 1)

    def forward(self, x):
        h = self.shared(x)
        return self.policy(h), self.value(h)


def collect_trajectories(env, net, steps, gamma, lam):
    obs, _ = env.reset()
    obs_buf, act_buf, rew_buf, val_buf, logp_buf = [], [], [], [], []
    ep_rews, ep_len = [], []

    for _ in range(steps):
        obs_t = torch.as_tensor(obs, dtype=torch.float32).to(device)
        pi, v = net(obs_t)
        dist = torch.distributions.Categorical(pi)
        a = dist.sample().cpu().numpy()
        logp = dist.log_prob(torch.as_tensor(a)).cpu().item()

        obs_buf.append(obs.copy())
        act_buf.append(a)
        val_buf.append(v.cpu().item())
        logp_buf.append(logp)

        next_obs, r, terminated, truncated, _ = env.step(int(a))
        done = terminated or truncated
        rew_buf.append(r)
        ep_rews.append(r)

        obs = next_obs
        if done:
            obs, _ = env.reset()
            ep_len.append(len(ep_rews))
            ep_rews = []

    obs_buf, act_buf = np.array(obs_buf), np.array(act_buf)
    rew_buf, val_buf, logp_buf = np.array(rew_buf), np.array(val_buf), np.array(logp_buf)

    last_val = net(torch.as_tensor(obs, dtype=torch.float32).to(device))[1].cpu().item()
    adv_buf = np.zeros_like(rew_buf)
    lastgaelam = 0

    for t in reversed(range(len(rew_buf))):
        if t == len(rew_buf) - 1:
            nextnonterminal = 1.0
            nextvalues = last_val
        else:
            nextnonterminal = 1.0
            nextvalues = val_buf[t + 1]
        delta = rew_buf[t] + gamma * nextvalues * nextnonterminal - val_buf[t]
        lastgaelam = delta + gamma * lam * nextnonterminal * lastgaelam
        adv_buf[t] = lastgaelam

    ret_buf = adv_buf + val_buf
    return obs_buf, act_buf, logp_buf, adv_buf, ret_buf


def ppo_train(env_name="CartPole-v1", total_steps=20000, batch_steps=1024, epochs=10, minibatch_size=64,
              gamma=0.99, lam=0.95, clip=0.2, pi_lr=3e-4):
    env = gym.make(env_name)
    obs_dim = env.observation_space.shape[0]
    act_dim = env.action_space.n
    net = ActorCritic(obs_dim, act_dim).to(device)
    optimizer = optim.Adam(net.parameters(), lr=pi_lr)
    steps = 0

    while steps < total_steps:
        obs_buf, act_buf, logp_buf, adv_buf, ret_buf = collect_trajectories(env, net, batch_steps, gamma, lam)
        steps += batch_steps
        adv_buf = (adv_buf - adv_buf.mean()) / (adv_buf.std() + 1e-8)
        inds = np.arange(batch_steps)

        for _ in range(epochs):
            np.random.shuffle(inds)
            for start in range(0, batch_steps, minibatch_size):
                mb = inds[start:start + minibatch_size]
                obs_mb = torch.as_tensor(obs_buf[mb], dtype=torch.float32).to(device)
                act_mb = torch.as_tensor(act_buf[mb], dtype=torch.int64).to(device)
                old_logp_mb = torch.as_tensor(logp_buf[mb], dtype=torch.float32).to(device)
                adv_mb = torch.as_tensor(adv_buf[mb], dtype=torch.float32).to(device)
                ret_mb = torch.as_tensor(ret_buf[mb], dtype=torch.float32).to(device)

                pi, v = net(obs_mb)
                dist = torch.distributions.Categorical(pi)
                logp = dist.log_prob(act_mb)
                ratio = torch.exp(logp - old_logp_mb)
                surr1 = ratio * adv_mb
                surr2 = torch.clamp(ratio, 1 - clip, 1 + clip) * adv_mb
                policy_loss = -torch.min(surr1, surr2).mean()
                value_loss = ((v.squeeze(-1) - ret_mb) ** 2).mean()
                entropy = dist.entropy().mean()
                loss = policy_loss + 0.5 * value_loss - 0.01 * entropy

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        print(f"Steps: {steps}\tLoss: {loss.item():.3f}\tPolicy: {policy_loss.item():.3f}\tValue: {value_loss.item():.3f}")
    env.close()


if __name__ == "__main__":
    ppo_train(env_name)


Steps: 1024	Loss: 66.367	Policy: -0.136	Value: 133.018
Steps: 2048	Loss: 49.991	Policy: -0.009	Value: 100.012
Steps: 3072	Loss: 54.685	Policy: 0.300	Value: 108.781
Steps: 4096	Loss: 54.997	Policy: -0.055	Value: 110.114
Steps: 5120	Loss: 49.834	Policy: 0.219	Value: 99.242
Steps: 6144	Loss: 48.390	Policy: -0.165	Value: 97.120
Steps: 7168	Loss: 44.366	Policy: -0.134	Value: 89.010
Steps: 8192	Loss: 39.677	Policy: 0.079	Value: 79.204
Steps: 9216	Loss: 36.186	Policy: 0.111	Value: 72.159
Steps: 10240	Loss: 33.485	Policy: -0.115	Value: 67.208
Steps: 11264	Loss: 30.381	Policy: -0.137	Value: 61.043
Steps: 12288	Loss: 26.940	Policy: -0.010	Value: 53.909
Steps: 13312	Loss: 24.406	Policy: -0.096	Value: 49.012
Steps: 14336	Loss: 21.502	Policy: 0.030	Value: 42.954
Steps: 15360	Loss: 19.208	Policy: 0.043	Value: 38.341
Steps: 16384	Loss: 17.001	Policy: 0.027	Value: 33.960
Steps: 17408	Loss: 15.132	Policy: 0.070	Value: 30.133
Steps: 18432	Loss: 13.254	Policy: -0.034	Value: 26.588
Steps: 19456	Loss: 11.6