In [7]:
import torch
import numpy as np
from torch import nn, optim
import itertools
import gymnasium as gym

# Run networks and minibatch tensors on the GPU when one is available.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [8]:
from stable_baselines3 import PPO

In [9]:
class Rollout:
    """Fixed-capacity buffer of on-policy transitions used by PPO.

    Storage is allocated lazily on the first ``add`` (the observation width is
    taken from that first observation). ``self.i`` counts filled slots; only
    indices ``[0, self.i)`` hold real data, the tail stays zero.
    """

    def __init__(self, size) -> None:
        self.obss = None
        self.size = size
        self.i = 0

    def init(self, obs_size):
        """Allocate zero-filled arrays for ``self.size`` transitions."""
        size = self.size
        self.obss = np.zeros((size, obs_size))
        self.rewards = np.zeros(size)
        self.actions = np.zeros(size)
        self.prev_log_prob = np.zeros(size)
        self.advantages = np.zeros(size)
        self.returns = np.zeros(size)
        self.values = np.zeros(size)
        # episode_starts[t] holds the ``done`` flag passed to ``add`` for step t,
        # i.e. whether the episode ended *after* transition t.
        self.episode_starts = np.zeros(size)
        self.i = 0

    @staticmethod
    def _as_float(x):
        """Convert a Python scalar or one-element array/tensor to a float.

        ``.item()`` also works for CUDA tensors, which numpy cannot convert.
        """
        return float(x.item()) if hasattr(x, "item") else float(x)

    def add(self, obs, reward, done, action, log_prob, value=0.0):
        """Append one transition.

        Args:
            obs: observation the action was chosen from (1-D array-like).
            reward: scalar reward for this step.
            done: whether the episode ended after this step.
            action: scalar action (Python number, numpy scalar or 1-element tensor).
            log_prob: log-probability of ``action`` under the behaviour policy.
            value: critic estimate V(obs); defaults to 0.0, which reproduces the
                previous behaviour of leaving ``self.values`` at zero.

        Raises:
            IndexError: if the buffer is already full.
        """
        if self.obss is None:
            self.init(obs.shape[0])
        self.obss[self.i] = obs
        self.rewards[self.i] = self._as_float(reward)
        self.actions[self.i] = self._as_float(action)
        self.episode_starts[self.i] = self._as_float(done)
        self.prev_log_prob[self.i] = self._as_float(log_prob)
        self.values[self.i] = self._as_float(value)
        self.i += 1

    def get(self, batch_size):
        """Yield shuffled minibatches drawn from the filled part of the buffer.

        Each item is a tuple of float32 tensors on ``device``:
        ``(obs, returns, actions, prev_log_prob, advantages, rewards)``.
        The last batch may be shorter than ``batch_size``. Nothing is yielded
        for an empty buffer.
        """
        size = self.i
        if self.obss is None or size == 0:
            return
        indices = np.random.permutation(size)
        columns = tuple(
            arr[indices]
            for arr in (self.obss, self.returns, self.actions,
                        self.prev_log_prob, self.advantages, self.rewards)
        )
        for start in range(0, size, batch_size):
            # Materialise tensors now: a lazy ``map`` closing over the offset would
            # read the *final* offset if batches were stored before being unpacked.
            yield tuple(
                torch.tensor(col[start:start + batch_size],
                             dtype=torch.float32, device=device)
                for col in columns
            )

    def calc_advantage(self, gamma=0.99, gae_lambda=0.97):
        """Compute GAE advantages and value targets in place.

        https://arxiv.org/pdf/1506.02438

        Only the filled slots ``[0, self.i)`` are processed. The value after the
        last filled step is taken as 0, and bootstrapping is cut after any step
        whose ``done`` flag is set, so advantages never leak across episodes.
        Sets ``self.advantages`` (unfilled tail = 0) and
        ``self.returns = self.advantages + self.values``.
        """
        n = self.i
        if self.obss is None or n == 0:
            return
        not_done = 1.0 - self.episode_starts[:n]
        next_values = np.append(self.values[1:n], 0.0)
        deltas = self.rewards[:n] + gamma * next_values * not_done - self.values[:n]
        advantages = np.zeros(self.size)
        running = 0.0
        for t in reversed(range(n)):
            running = deltas[t] + gamma * gae_lambda * not_done[t] * running
            advantages[t] = running
        self.advantages = advantages
        self.returns = self.advantages + self.values

    def calculate_advantages(self, last_values: np.ndarray, dones: np.ndarray,
                             gamma=0.99, gae_lambda=0.95):
        """GAE with an explicit bootstrap value for the final stored step.

        https://arxiv.org/pdf/1506.02438

        Args:
            last_values: critic estimate for the observation that follows the
                last stored step.
            dones: whether the last stored step ended its episode (bool / 0-1).
            gamma: discount factor.
            gae_lambda: GAE smoothing factor.
        """
        if self.obss is None:
            return
        n = self.i
        last_gae_lam = 0.0
        for step in reversed(range(n)):
            if step == n - 1:
                next_non_terminal = 1.0 - np.asarray(dones, dtype=np.float32)
                next_values = last_values
            else:
                # episode_starts[step] is the done flag of *this* step (see add()).
                next_non_terminal = 1.0 - self.episode_starts[step]
                next_values = self.values[step + 1]
            delta = (self.rewards[step] + gamma * next_values * next_non_terminal
                     - self.values[step])
            last_gae_lam = delta + gamma * gae_lambda * next_non_terminal * last_gae_lam
            self.advantages[step] = last_gae_lam
        self.returns = self.advantages + self.values

In [11]:
class PPO:
    """Minimal PPO-Clip agent for environments with a discrete action space.

    https://arxiv.org/pdf/1707.06347

    Separate actor and critic MLPs (obs -> 64 tanh -> ``outc`` tanh) feed linear
    heads producing action logits (``action_net``) and a scalar state value
    (``value_net``). All four modules share one Adam optimizer (lr 4e-3).

    NOTE(review): this class shadows ``stable_baselines3.PPO`` imported in an
    earlier cell.
    """

    def __init__(self, env: gym.Env, outc) -> None:
        """Build the networks and optimizer for ``env``.

        Args:
            env: gymnasium environment. ``observation_space.shape`` sizes the input
                layer and ``action_space.n`` the number of logits, so the action
                space must be discrete. Observations are fed unflattened, so they
                are assumed to be 1-D.
            outc: width of the second hidden layer of both MLPs.
        """
        obs_dim = int(np.prod(env.observation_space.shape))
        self.policy = nn.Sequential(
            nn.Linear(obs_dim, 64), nn.Tanh(),
            nn.Linear(64, outc), nn.Tanh(),
        ).to(device)
        self.value = nn.Sequential(
            nn.Linear(obs_dim, 64), nn.Tanh(),
            nn.Linear(64, outc), nn.Tanh(),
        ).to(device)
        self.value_net = nn.Linear(outc, 1).to(device)
        self.action_net = nn.Linear(outc, env.action_space.n).to(device)

        self.optimizer = optim.Adam(itertools.chain(
            self.policy.parameters(),
            self.value.parameters(),
            self.value_net.parameters(),
            self.action_net.parameters(),
        ), lr=4e-3)
        self.env = env
        # Orthogonal initialisation (init_weight) is available but not applied.
        # self.init_weight()

    def init_weight(self):
        """Orthogonally initialise every Linear layer and zero its bias.

        Gains: sqrt(2) for the two MLP bodies, 0.01 for the action head,
        1 for the value head.
        """
        module_gains = {
            self.policy: np.sqrt(2),
            self.value: np.sqrt(2),
            self.action_net: 0.01,
            self.value_net: 1,
        }
        import functools

        def init(module, gain):
            if isinstance(module, nn.Linear):
                nn.init.orthogonal_(module.weight, gain=gain)
                if module.bias is not None:
                    module.bias.data.fill_(0.0)

        for module, gain in module_gains.items():
            module.apply(functools.partial(init, gain=gain))

    def train(self, timesteps, timesteps_per_rollout, epochs=2, batch_size=64):
        """Alternate rollout collection and PPO updates for ``timesteps`` env steps.

        Each rollout is optimised for up to ``epochs`` passes over shuffled
        minibatches. Optimisation of that rollout stops early once ``train_once``
        reports the KL threshold was reached. Ctrl-C stops training and closes
        ``self.env``; any other exception closes the env and is re-raised.

        Raises:
            ValueError: if ``timesteps_per_rollout`` < 1 (training would never progress).
        """
        if timesteps_per_rollout < 1:
            raise ValueError(
                f"timesteps_per_rollout must be >= 1, got {timesteps_per_rollout}")
        try:
            import tqdm
            with tqdm.tqdm(total=timesteps) as pbar:
                steps_done = 0
                while steps_done < timesteps:
                    rollout = self.collect_rollout(timesteps_per_rollout)
                    self._update_from_rollout(rollout, epochs, batch_size)
                    steps_done += rollout.i
                    pbar.update(rollout.i)
        except KeyboardInterrupt:
            # Ctrl-C ends training early but keeps the learned weights.
            self.env.close()
        except Exception:
            self.env.close()
            raise

    def _update_from_rollout(self, rollout, epochs, batch_size):
        """Run up to ``epochs`` minibatch passes; return as soon as KL is too high."""
        for _ in range(epochs):
            for batch in rollout.get(batch_size):
                if self.train_once(batch):
                    return

    def train_once(self, batch: list[torch.Tensor]):
        """Take one clipped-surrogate gradient step on a minibatch.

        Args:
            batch: ``(obs, returns, actions, prev_log_prob, advantages, rewards)``
                tensors, in the order yielded by ``Rollout.get``.

        Returns:
            bool: True when the mean of ``prev_log_prob - log_prob`` on this batch
            is >= 0.02, telling the caller to stop optimising this rollout.
        """
        c1 = 0.5        # value-loss coefficient
        clip_eps = 0.2  # PPO clip range
        obs, returns, actions, prev_log_prob, advantages, _rewards = batch
        # Re-evaluate the stored actions under the current policy.
        _, log_prob, values = self.forward(obs, actions=actions)
        ratio = torch.exp(log_prob - prev_log_prob)
        value_loss = nn.functional.mse_loss(values, returns)
        policy_loss = -torch.min(
            ratio * advantages,
            torch.clamp(ratio, 1 - clip_eps, 1 + clip_eps) * advantages,
        ).mean()
        loss = policy_loss + c1 * value_loss  # minimised
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        with torch.no_grad():
            approx_kl = (prev_log_prob - log_prob).mean().item()
        return approx_kl >= 0.02

    def collect_rollout(self, size):
        """Play one episode (at most ``size`` steps) with the current policy.

        Each stored transition pairs the observation the action was chosen from
        with the resulting reward/done, the behaviour log-prob, and the critic
        value V(obs). Collection stops at the first episode end, so a rollout
        holds at most one episode.

        Returns:
            Rollout: the filled buffer with advantages/returns computed.
        """
        env = self.env
        obs, _ = env.reset()
        rollout = Rollout(size)
        for _ in range(size):
            action, log_prob, value = self.take_action(obs)
            next_obs, reward, terminated, truncated, _ = env.step(action.item())
            done = terminated or truncated
            # Store the observation the action was taken in, not the next one;
            # .item() moves values off the device before they reach numpy.
            rollout.add(obs, reward, done, action.item(), log_prob.item())
            rollout.values[rollout.i - 1] = value.item()
            if done:
                break
            obs = next_obs
        # NOTE(review): a rollout cut by ``size`` or by truncation bootstraps with 0
        # in calc_advantage rather than with V(next_obs).
        rollout.calc_advantage()
        return rollout

    def forward(self, obs, actions=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Evaluate policy and critic on ``obs``.

        Args:
            obs: tensor of shape (obs_dim,) or (batch, obs_dim); a 1-D input is
                treated as a batch of one.
            actions: optional actions to score; sampled from the policy if None.

        Returns:
            (actions, log_prob(actions), values), each of shape (batch,).
        """
        if obs.ndim == 1:
            obs = obs[None]
        latent_act = self.policy(obs)
        latent_val = self.value(obs)
        action_logits = self.action_net(latent_act)
        values = self.value_net(latent_val)
        dist = torch.distributions.Categorical(logits=action_logits)
        if actions is None:
            actions = dist.sample()  # one action per batch row
        # squeeze(-1), not squeeze(): keep the batch dim even when batch == 1.
        return actions, dist.log_prob(actions), values.squeeze(-1)

    def take_action(self, obs):
        """Sample an action for ``obs`` without tracking gradients.

        Args:
            obs: array-like observation, converted to a float32 tensor on ``device``.

        Returns:
            (action, log_prob, value) tensors; for a 1-D observation each has shape (1,).
        """
        with torch.no_grad():
            obs = torch.tensor(obs, dtype=torch.float32, device=device)
            return self.forward(obs)

    def play(self, env: gym.Env):
        """Render episodes with the current policy (~30 FPS) until Ctrl-C.

        Prints each episode's length. ``env`` is closed on exit, whether by
        Ctrl-C or by any other exception.
        """
        import time
        try:
            while True:
                obs, _ = env.reset()
                done = False
                episode_len = 0
                while not done:
                    episode_len += 1
                    action, *_ = self.take_action(obs)
                    obs, _, terminated, truncated, _ = env.step(action.item())
                    done = terminated or truncated
                    env.render()
                    time.sleep(1 / 30)
                print(f"Episode length: {episode_len}")
        except KeyboardInterrupt:
            pass
        finally:
            env.close()
        
# Environment to train on (an earlier run used "Acrobot-v1").
game_name = "CartPole-v1"
SEED = 0  # seeds torch (weights, action sampling), numpy (minibatch shuffling) and the env

torch.manual_seed(SEED)
np.random.seed(SEED)
train_env = gym.make(game_name)
train_env.reset(seed=SEED)  # later unseeded resets continue from this seeded RNG

ppo = PPO(train_env, 64)
ppo.train(10_000, 2048, epochs=2)

  7%|▋         | 689/10000 [00:02<00:29, 315.93it/s]


In [None]:
# Watch the trained agent; interrupt the kernel to stop.
render_env = gym.make(game_name, render_mode="human")
ppo.play(render_env)