In [1]:
import torch
import torch.nn as nn
from torch.distributions.categorical import Categorical

In [2]:
from collections import namedtuple
import numpy as np

StepExp = namedtuple('StepExperience', ['obs', 'obs_n', 'act', 'logp', 'val', 'rew', 'done'])
Experience = namedtuple('Experience', ['obs', 'r2g', 'adv', 'logp', 'act'])

class ExpBuffer:
    def __init__(self, env, capacity):
        obs_dim = env.observation_space.shape
        act_dim = int(env.action_space.n)
        self.start = 0
        self.idx = 0
        self.size = 0
        self.capacity = capacity

        self.obs = torch.zeros((self.capacity, *obs_dim), requires_grad=False).float()
        self.obs_n = torch.zeros((self.capacity, *obs_dim), requires_grad=False).float()
        self.act = torch.zeros((self.capacity, act_dim), requires_grad=False).long()
        self.logp = torch.zeros((self.capacity, 1), requires_grad=False).float()
        self.val = torch.zeros((self.capacity, 1), requires_grad=False).float()
        self.rew = torch.zeros((self.capacity, 1), requires_grad=False).float()
        self.done = torch.zeros((self.capacity, 1), requires_grad=False).float()
        self.r2g = torch.zeros((self.capacity, 1), requires_grad=False).float()
        self.adv = torch.zeros((self.capacity, 1), requires_grad=False).float()

    def store(self, exp):
        self.obs[self.idx] = exp.obs
        self.obs_n[self.idx] = exp.obs_n
        self.act[self.idx] = exp.act
        self.logp[self.idx] = exp.logp
        self.val[self.idx] = exp.val
        self.done[self.idx] = exp.done
        self.rew[self.idx] = exp.rew

        self.size = min(self.size+1, self.capacity)
        self.idx = (self.idx+1) % self.capacity

    def store_finished(self, r2g, adv):
        self.r2g[self.start:self.size] = r2g[:, None]
        self.adv[self.start:self.size] = adv[:, None]
        self.start = self.idx

    def reset(self):
        self.idx = 0
        self.size = 0
        self.start = 0

    def sample(self, bs):
        assert self.size >= bs
        idxs = np.random.choice(self.size, bs, replace=False)
        adv = self.adv[idxs]
        ''' normalize advantage
        std = adv.std()
        mean = adv.mean()
        adv = (adv - mean) / std
        '''
        return Experience(
            self.obs[idxs],
            self.r2g[idxs],
            adv,
            self.logp[idxs],
            self.act[idxs]
        )


In [9]:
def get_action(policy, obs):
    logits = policy(obs)
    dist = Categorical(logits=logits)
    action = dist.sample()
    log_prob = dist.log_prob(action)
    if action.shape == ():
        return action.item(), log_prob.item()
    return action, log_prob

def reward_to_go(rewards, gamma=0.99):
    n = len(rewards)
    r2g = torch.zeros(n, requires_grad=False)
    for i in reversed(range(n)):
        r2g[i] = rewards[i] + gamma*(0 if i+1 >= n else r2g[i+1])

    return r2g

def advantage(rewards, vals, gamma=0.99, lam=0.95):
   n = len(rewards)
   diff = [rewards[i] + gamma*vals[i+1] - vals[i] for i in range(n-1)]
   adv = torch.zeros(n-1, requires_grad=False)
   for i in reversed(range(n-1)):
      adv[i] = diff[i] + gamma*lam*(0 if i+1 >= n-1 else adv[i+1])

   return adv
   
def step(policy, value, obs, env):
    act, logp = get_action(policy, obs)
    val = value(obs)
    obs_n, rew, term, trunc, _ = env.step(act)
    obs_n = torch.from_numpy(obs_n)
    done = term or trunc

    return StepExp(
        obs,
        obs_n,
        act, 
        logp,
        val,
        rew,
        done,
    )

def policy_loss(policy, exp, eps):
    act, logp = get_action(policy, exp.obs)
    ratio = torch.exp(logp[:,None] - exp.logp)
    clip_adv = torch.clamp(ratio, 1-eps, 1+eps)*exp.adv
    loss = -(torch.min(ratio*exp.adv, clip_adv)).mean()

    return loss

def value_loss(value, exp):
    diff = (value(exp.obs) - exp.r2g)**2
    mse = diff.mean()

    return mse

def fit_policy(policy, exp, opt, epochs):
    losses = []
    for i in range(epochs):
        opt.zero_grad()
        loss = policy_loss(policy, exp, eps=0.2)
        loss.backward()
        opt.step()
        losses.append(loss.item())

    return losses

def fit_value(value, exp, opt, epochs):
    losses = []
    for i in range(epochs):
        opt.zero_grad()
        loss = value_loss(value, exp)
        loss.backward()
        opt.step()
        losses.append(loss.item())

    return losses

def average(arr):
    return (sum(arr) / len(arr))

In [53]:
from torch.optim import Adam
import gymnasium as gym

env = gym.make('CartPole-v1')

obs_dim = env.observation_space.shape[0]
act_dim = env.action_space.n
hid_dim = 256

env = gym.make('CartPole-v1')

policy = nn.Sequential(*[
    nn.Linear(obs_dim, hid_dim),
    nn.ReLU(),
    nn.Linear(hid_dim, hid_dim),
    nn.ReLU(),
    nn.Linear(hid_dim, act_dim)
])

value = nn.Sequential(*[
    nn.Linear(obs_dim, hid_dim),
    nn.ReLU(),
    nn.Linear(hid_dim, hid_dim),
    nn.ReLU(),
    nn.Linear(hid_dim, 1)
])

num_epochs = 100
num_steps = 4096

policy_opt = Adam(policy.parameters(), lr=3e-4)
policy_epochs = 15

value_opt = Adam(value.parameters(), lr=3e-4)
value_epochs = 15

experience = ExpBuffer(env, num_steps)

for i in range(num_epochs):
    experience.reset()

    obs, _ = env.reset()
    obs = torch.from_numpy(obs)

    eps_rews = []
    rewards = []
    vals = []

    # collect experience
    policy.eval()
    value.eval()
    with torch.no_grad():
        for j in range(num_steps):
            step_exp = step(policy, value, obs, env)

            experience.store(step_exp)
            rewards.append(step_exp.rew)
            vals.append(step_exp.val.item())
            obs = step_exp.obs_n
            if step_exp.done or j == num_steps-1:
                r2g = reward_to_go(rewards)
                if not step_exp.done:
                    pred = step(policy, value, obs, env)
                    rewards.append(pred.val.item())
                    vals.append(pred.val.item())
                else:
                    rewards.append(0)
                    vals.append(0)
                adv = advantage(rewards, vals).detach()
                experience.store_finished(r2g, adv)

                # reset
                eps_rews.append(sum(rewards))
                rewards = []
                vals = []
                obs, _ = env.reset()
                obs = torch.from_numpy(obs)


    # update policy
    policy.train()
    value.train()
    sampled_exp = experience.sample(32)
    p_losses = fit_policy(policy, sampled_exp, policy_opt, policy_epochs)
    v_losses = fit_value(value, sampled_exp, value_opt, value_epochs)

    print(i, '--', f'rew: {sum(eps_rews) / len(eps_rews)}', f'policy_loss: {sum(p_losses) / len(p_losses)}', f'value_loss: {sum(v_losses) / len(v_losses)}')

0 -- rew: 23.011850717362393 policy_loss: -8.358571370442709 value_loss: 282.9984598795573
1 -- rew: 18.45877056562149 policy_loss: -6.445404084523519 value_loss: 159.71985880533853
2 -- rew: 19.32873544265639 policy_loss: -6.508675829569499 value_loss: 97.06512908935547
3 -- rew: 20.701354135166515 policy_loss: -5.657056458791097 value_loss: 235.46366373697916
4 -- rew: 20.5018705868721 policy_loss: -5.321012051900228 value_loss: 160.00516764322916
5 -- rew: 20.827799978595095 policy_loss: -4.540285237630209 value_loss: 93.8729258219401
6 -- rew: 21.860456192747076 policy_loss: -2.2959343592325845 value_loss: 83.25439198811848
7 -- rew: 21.492119639331758 policy_loss: 0.09864840656518936 value_loss: 46.71092300415039
8 -- rew: 21.275663222673643 policy_loss: -1.0858686208724975 value_loss: 64.11043217976888
9 -- rew: 21.160258607766064 policy_loss: -2.4761810938517255 value_loss: 83.94046020507812
10 -- rew: 19.552003515334356 policy_loss: -0.985777489344279 value_loss: 22.92707227071

In [5]:
rews = []

for i in range(1000):
    obs, _ = env.reset()
    obs = torch.from_numpy(obs)
    done = False
    rewards = []
    while not done:
        exp = step(policy, value, obs, env)
        rewards.append(exp.rew)
        done = exp.done
        obs = exp.obs_n
    
    rews.append(sum(rewards))

print(sum(rews) / len(rews))

21.616


In [6]:
rand_policy = nn.Sequential(*[
    nn.Linear(obs_dim, hid_dim),
    nn.ReLU(),
    nn.Linear(hid_dim, act_dim)
])

rand_value = nn.Sequential(*[
    nn.Linear(obs_dim, hid_dim),
    nn.ReLU(),
    nn.Linear(hid_dim, 1)
])

rews = []

for i in range(1000):
    obs, _ = env.reset()
    obs = torch.from_numpy(obs)
    done = False
    rewards = []
    while not done:
        exp = step(rand_policy, rand_value, obs, env)
        rewards.append(exp.rew)
        done = exp.done
        obs = exp.obs_n
    
    rews.append(sum(rewards))

print(sum(rews) / len(rews))

19.908


In [47]:
ret = experience.sample(32)

policy(ret.obs).shape

act, logp = get_action(policy, ret.obs)

print(ret.adv.shape)

ratio_1 = torch.exp(logp[:] - ret.logp)
print('r1',ratio_1.shape)
clip_adv = torch.clamp(ratio_1, 1-0.2, 1+0.2)*ret.adv
print('c1', clip_adv.shape)
loss = -(torch.min(ratio_1*ret.adv, clip_adv)).mean()
print(loss.item())

ratio_1 = torch.exp(logp[:,None] - ret.logp)
print('r2', ratio_1.shape)
clip_adv = torch.clamp(ratio_1, 1-0.2, 1+0.2)*ret.adv
print('c2', clip_adv.shape)
loss = -(torch.min(ratio_1*ret.adv, clip_adv)).mean()
print(loss.item())



torch.Size([32, 1])
r1 torch.Size([32, 32])
c1 torch.Size([32, 32])
-0.7566856145858765
r2 torch.Size([32, 1])
c2 torch.Size([32, 1])
-0.6808624863624573


In [50]:
value(ret.obs).shape

torch.Size([32, 1])