In [17]:
import torch
import torch.nn as nn
from torch.distributions import MultivariateNormal, Categorical

################################## set device ##################################
print("=" * 92)
# Pick the compute device once: CUDA when PyTorch reports it as available,
# otherwise the CPU. `device` is read by the classes defined below.
if torch.cuda.is_available():
    device = torch.device('cuda')
    # Release cached allocator blocks before anything is allocated.
    torch.cuda.empty_cache()
    print("Device set to :", torch.cuda.get_device_name(device))
else:
    device = torch.device('cpu')
    print("Device set to : cpu")
print("=" * 92)

################################## PPO Policy ##################################
class RolloutBuffer:
    """Container for the per-step data collected between two policy updates.

    Every attribute is a plain Python list. All lists start empty and are
    replaced by brand-new empty lists whenever ``clear`` is called.
    """

    # Attribute names, in the order they are created on the instance.
    _FIELD_NAMES = (
        "actions",
        "states",
        "logprobs",
        "rewards",
        "state_values",
        "is_terminals",
    )

    def __init__(self):
        self.clear()

    def clear(self):
        """Discard all stored data by rebinding each field to a fresh list."""
        for field_name in self._FIELD_NAMES:
            setattr(self, field_name, [])


class ActorCritic(nn.Module):
    """Separate actor and critic multilayer perceptrons over the same state.

    Both networks stack three ``nn.Linear`` layers with 64 hidden units and
    ``Tanh`` between them. The actor ends in ``Tanh`` for continuous action
    spaces (its output is used as the mean of a Gaussian whose diagonal
    covariance is the fixed ``action_var``) and in ``Softmax`` for discrete
    ones (its output is used as categorical probabilities). The critic ends
    in a single linear unit.

    Args:
        state_dim: Size of the last dimension of the state input.
        action_dim: Number of action dimensions (continuous) or number of
            discrete actions.
        has_continuous_action_space: Selects the Gaussian or categorical head.
        action_std_init: Initial standard deviation; only used in the
            continuous case, where its square fills ``action_var``.
    """

    def __init__(self, state_dim, action_dim, has_continuous_action_space, action_std_init):
        super().__init__()
        self.has_continuous_action_space = has_continuous_action_space

        if has_continuous_action_space:
            self.action_dim = action_dim
            # Plain tensor attribute (not a parameter): one variance per action dim.
            self.action_var = torch.full((action_dim,), action_std_init**2).to(device)

        hidden_dim = 64
        policy_head = nn.Tanh() if has_continuous_action_space else nn.Softmax(dim=-1)

        # Actor is built before the critic so parameter creation order is fixed.
        self.actor = self._make_mlp(state_dim, hidden_dim, action_dim, policy_head)
        self.critic = self._make_mlp(state_dim, hidden_dim, 1)

    @staticmethod
    def _make_mlp(in_dim, hidden_dim, out_dim, output_activation=None):
        """Build Linear-Tanh-Linear-Tanh-Linear, optionally followed by an activation."""
        layers = [
            nn.Linear(in_dim, hidden_dim), nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim), nn.Tanh(),
            nn.Linear(hidden_dim, out_dim),
        ]
        if output_activation is not None:
            layers.append(output_activation)
        return nn.Sequential(*layers)

    def set_action_std(self, new_action_std):
        """Replace ``action_var`` with ``new_action_std**2``; no-op for discrete actions."""
        if not self.has_continuous_action_space:
            return
        self.action_var = torch.full((self.action_dim,), new_action_std**2).to(device)

    def act(self, state):
        """Sample an action for ``state``.

        Returns:
            A tuple ``(action, log_prob, state_value)``, all detached from the
            autograd graph.
        """
        actor_out = self.actor(state)
        if self.has_continuous_action_space:
            # unsqueeze(0) gives the covariance a leading batch dimension of 1.
            covariance = torch.diag(self.action_var).unsqueeze(0)
            dist = MultivariateNormal(actor_out, covariance)
        else:
            dist = Categorical(actor_out)

        sampled = dist.sample()
        sampled_logprob = dist.log_prob(sampled)
        value = self.critic(state)
        return sampled.detach(), sampled_logprob.detach(), value.detach()

    def evaluate(self, state, action):
        """Score given actions under the current policy.

        Returns:
            A tuple ``(log_prob, state_value, entropy)``; gradients are kept.
        """
        actor_out = self.actor(state)
        if not self.has_continuous_action_space:
            dist = Categorical(actor_out)
        else:
            # One diagonal covariance matrix per row of the actor output.
            per_row_var = self.action_var.expand_as(actor_out)
            dist = MultivariateNormal(actor_out, torch.diag_embed(per_row_var))
            if self.action_dim == 1:
                # Single-dimension actions are given an explicit trailing axis.
                action = action.reshape(-1, self.action_dim)

        return dist.log_prob(action), self.critic(state), dist.entropy()


class PPO:
    """Proximal Policy Optimization agent using the clipped surrogate objective.

    The agent keeps two copies of the network: ``policy`` is optimized during
    ``update`` while ``policy_old`` is used to collect experience and is
    synchronized with ``policy`` at the end of every update.

    Args:
        state_dim: Size of the state vector passed to the networks.
        action_dim: Number of action dimensions / discrete actions.
        lr_actor: Adam learning rate for the actor parameters.
        lr_critic: Adam learning rate for the critic parameters.
        gamma: Discount factor for the return computation.
        K_epochs: Number of optimization passes over each collected batch.
        eps_clip: Clipping range for the probability ratio.
        has_continuous_action_space: Selects Gaussian vs. categorical policy.
        action_std_init: Initial action standard deviation (continuous only).
    """

    def __init__(self, state_dim, action_dim, lr_actor, lr_critic, gamma, K_epochs, eps_clip,
                 has_continuous_action_space, action_std_init=0.6):
        self.gamma = gamma
        self.eps_clip = eps_clip
        self.K_epochs = K_epochs
        self.has_continuous_action_space = has_continuous_action_space
        self.action_std = action_std_init

        self.buffer = RolloutBuffer()

        self.policy = ActorCritic(state_dim, action_dim, has_continuous_action_space, action_std_init).to(device)
        self.policy_old = ActorCritic(state_dim, action_dim, has_continuous_action_space, action_std_init).to(device)
        # Start both copies from identical weights.
        self.policy_old.load_state_dict(self.policy.state_dict())

        # One optimizer, two parameter groups with independent learning rates.
        self.optimizer = torch.optim.Adam([
            {'params': self.policy.actor.parameters(), 'lr': lr_actor},
            {'params': self.policy.critic.parameters(), 'lr': lr_critic}
        ])

        self.MseLoss = nn.MSELoss()

    def set_action_std(self, new_action_std):
        """Set the action std on both policy copies; no-op for discrete actions."""
        if self.has_continuous_action_space:
            self.action_std = new_action_std
            self.policy.set_action_std(new_action_std)
            self.policy_old.set_action_std(new_action_std)

    def decay_action_std(self, action_std_decay_rate, min_action_std):
        """Lower the action std by ``action_std_decay_rate``, floored at ``min_action_std``."""
        # Rounded to 4 decimals so repeated subtraction does not accumulate float noise.
        self.action_std = max(min_action_std, round(self.action_std - action_std_decay_rate, 4))
        self.set_action_std(self.action_std)

    def select_action(self, state):
        """Sample an action from ``policy_old`` and record the step in the buffer.

        Stores the state tensor, action, log-probability and value estimate;
        the caller is expected to append the matching reward and terminal flag.

        Returns:
            A flat NumPy array for continuous actions, a Python scalar for
            discrete ones.
        """
        with torch.no_grad():
            state = torch.FloatTensor(state).to(device)
            action, logprob, value = self.policy_old.act(state)

        self.buffer.states.append(state)
        self.buffer.actions.append(action)
        self.buffer.logprobs.append(logprob)
        self.buffer.state_values.append(value)

        return action.cpu().numpy().flatten() if self.has_continuous_action_space else action.item()

    @staticmethod
    def _discounted_returns(rewards, is_terminals, gamma):
        """Return the discounted return for every step, resetting at terminals."""
        returns = []
        running_return = 0
        for reward, is_terminal in zip(reversed(rewards), reversed(is_terminals)):
            running_return = reward + gamma * running_return * (1 - int(is_terminal))
            returns.append(running_return)
        # Built back-to-front with O(1) appends, then flipped once.
        returns.reverse()
        return returns

    def update(self, time_step):
        """Run ``K_epochs`` optimization passes over the buffer, then clear it.

        Does nothing when no rewards have been recorded.

        Args:
            time_step: Global step counter; losses are printed when it is a
                multiple of 10000.
        """
        if not self.buffer.rewards:
            return

        returns = self._discounted_returns(self.buffer.rewards, self.buffer.is_terminals, self.gamma)
        rewards = torch.tensor(returns, dtype=torch.float32).to(device)
        # std() of a single sample is NaN, which would poison every weight.
        if rewards.numel() > 1:
            rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-7)

        batch_size = len(self.buffer.states)
        old_states = torch.stack(self.buffer.states).detach().to(device)
        old_actions = torch.stack(self.buffer.actions).detach().to(device)
        # Flatten to (N,): the continuous policy's `act` yields log-probs with a
        # leading batch axis of 1, and a stacked (N, 1) tensor would broadcast
        # against the (N,) outputs of `evaluate` into an (N, N) matrix.
        old_logprobs = torch.stack(self.buffer.logprobs).detach().reshape(-1).to(device)
        old_state_values = torch.stack(self.buffer.state_values).detach().reshape(-1).to(device)
        if self.has_continuous_action_space:
            # Same reason: drop the extra batch axis so actions are (N, action_dim).
            old_actions = old_actions.reshape(batch_size, -1)

        advantages = rewards - old_state_values

        surrogate = None
        value_loss = None
        for _ in range(self.K_epochs):
            logprobs, state_values, dist_entropy = self.policy.evaluate(old_states, old_actions)
            state_values = state_values.reshape(-1)
            ratios = torch.exp(logprobs - old_logprobs)
            surr1 = ratios * advantages
            surr2 = torch.clamp(ratios, 1 - self.eps_clip, 1 + self.eps_clip) * advantages
            surrogate = torch.min(surr1, surr2)
            value_loss = 0.5 * self.MseLoss(state_values, rewards)
            loss = -surrogate + value_loss - 0.01 * dist_entropy

            self.optimizer.zero_grad()
            loss.mean().backward()
            self.optimizer.step()

        # Values come from the last epoch; skipped when no epoch ran.
        # Note the "Actor loss" figure is the un-negated surrogate objective.
        if time_step % 10000 == 0 and surrogate is not None:
            print(f"Actor loss: {surrogate.mean().item()}")
            print(f"Critic loss: {value_loss.item()}")

        self.policy_old.load_state_dict(self.policy.state_dict())
        self.buffer.clear()

    def save(self, path):
        """Write the ``policy_old`` weights to ``path``."""
        torch.save(self.policy_old.state_dict(), path)

    def load(self, path):
        """Load weights from ``path`` into both policy copies."""
        # Deserialize once; load_state_dict copies values into each module.
        state_dict = torch.load(path, map_location=device)
        self.policy_old.load_state_dict(state_dict)
        self.policy.load_state_dict(state_dict)


Device set to : Tesla T4


In [None]:
import os
import time
from datetime import datetime

import torch
import numpy as np

import gymnasium as gym  # Updated from gym

def train():
    """Train a PPO agent on ``CartPole-v1`` and log progress.

    All hyperparameters are defined as locals at the top of the function.
    Average episode rewards are written to a CSV file under ``PPO_logs/``,
    printed periodically, and the agent's weights are checkpointed under
    ``PPO_preTrained/``. The log file and the environment are released even
    if training raises.
    """
    print("=" * 92)

    env_name = "CartPole-v1"
    has_continuous_action_space = False
    max_ep_len = 1000
    max_training_timesteps = int(3e6)

    print_freq = max_ep_len * 10
    log_freq = max_ep_len * 2
    save_model_freq = int(1e5)

    action_std = 0.6
    action_std_decay_rate = 0.05
    min_action_std = 0.1
    action_std_decay_freq = int(2.5e5)

    update_timestep = max_ep_len * 4
    K_epochs = 80
    eps_clip = 0.2
    gamma = 0.99

    lr_actor = 0.0003
    lr_critic = 0.001

    # A falsy seed (0) means "do not seed"; see the `if random_seed:` check below.
    random_seed = 0

    print("training environment name :", env_name)

    env = gym.make(env_name)

    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0] if has_continuous_action_space else env.action_space.n

    log_dir = f"PPO_logs/{env_name}/"
    os.makedirs(log_dir, exist_ok=True)

    # The run number is the count of files already present in the log directory.
    run_num = len(next(os.walk(log_dir))[2])
    # os.path.join avoids the doubled slash that "{log_dir}/..." produced.
    log_f_name = os.path.join(log_dir, f"PPO_{env_name}_log_{run_num}.csv")

    print("current logging run number for", env_name, ":", run_num)
    print("logging at:", log_f_name)

    run_num_pretrained = 0
    checkpoint_dir = f"PPO_preTrained/{env_name}/"
    os.makedirs(checkpoint_dir, exist_ok=True)
    checkpoint_path = os.path.join(checkpoint_dir, f"PPO_{env_name}_{random_seed}_{run_num_pretrained}.pth")

    print("save checkpoint path:", checkpoint_path)
    print("-" * 92)
    print("max training timesteps:", max_training_timesteps)
    print("max timesteps per episode:", max_ep_len)
    print("model saving frequency:", save_model_freq)
    print("log frequency:", log_freq)
    print("printing avg reward every:", print_freq)
    print("state space dimension:", state_dim)
    print("action space dimension:", action_dim)
    if has_continuous_action_space:
        print("starting std:", action_std)
        print("decay rate:", action_std_decay_rate)
        print("min std:", min_action_std)
        print("decay freq:", action_std_decay_freq)
    else:
        print("discrete action space")
    print("PPO update freq:", update_timestep)
    print("K epochs:", K_epochs)
    print("eps clip:", eps_clip)
    print("gamma:", gamma)
    print("lr actor:", lr_actor)
    print("lr critic:", lr_critic)
    if random_seed:
        print("setting random seed to", random_seed)
        torch.manual_seed(random_seed)
        np.random.seed(random_seed)
        env.reset(seed=random_seed)
    print("=" * 92)

    ppo_agent = PPO(state_dim, action_dim, lr_actor, lr_critic, gamma, K_epochs, eps_clip, has_continuous_action_space, action_std)

    start_time = datetime.now().replace(microsecond=0)
    print("Started training at (GMT):", start_time)
    print("=" * 92)

    print_running_reward = 0
    print_running_episodes = 0
    log_running_reward = 0
    log_running_episodes = 0
    time_step = 0
    i_episode = 0

    try:
        # Context manager closes the log even if training raises.
        with open(log_f_name, "w+") as log_f:
            log_f.write('episode,timestep,reward\n')

            while time_step <= max_training_timesteps:
                state, _ = env.reset()
                current_ep_reward = 0

                for t in range(1, max_ep_len + 1):
                    action = ppo_agent.select_action(state)
                    state, reward, terminated, truncated, _ = env.step(action)
                    done = terminated or truncated

                    ppo_agent.buffer.rewards.append(reward)
                    # Hitting max_ep_len also ends the episode (env.reset() follows),
                    # so flag it; otherwise the discounted return would leak across
                    # the reset into the previous episode's steps.
                    ppo_agent.buffer.is_terminals.append(done or t == max_ep_len)

                    time_step += 1
                    current_ep_reward += reward

                    if time_step % update_timestep == 0:
                        ppo_agent.update(time_step)

                    if has_continuous_action_space and time_step % action_std_decay_freq == 0:
                        ppo_agent.decay_action_std(action_std_decay_rate, min_action_std)

                    # Averages are over episodes finished since the last tick; skip
                    # the tick when none finished instead of dividing by zero.
                    if time_step % log_freq == 0 and log_running_episodes > 0:
                        log_avg_reward = log_running_reward / log_running_episodes
                        log_f.write(f"{i_episode},{time_step},{round(log_avg_reward, 4)}\n")
                        log_f.flush()
                        log_running_reward = 0
                        log_running_episodes = 0

                    if time_step % print_freq == 0 and print_running_episodes > 0:
                        print_avg_reward = print_running_reward / print_running_episodes
                        print(f"Episode: {i_episode}\tTimestep: {time_step}\tAverage Reward: {round(print_avg_reward, 2)}")
                        print_running_reward = 0
                        print_running_episodes = 0

                    if time_step % save_model_freq == 0:
                        print("-" * 92)
                        print("saving model at:", checkpoint_path)
                        ppo_agent.save(checkpoint_path)
                        print("Elapsed Time:", datetime.now().replace(microsecond=0) - start_time)
                        print("-" * 92)

                    if done:
                        break

                print_running_reward += current_ep_reward
                print_running_episodes += 1
                log_running_reward += current_ep_reward
                log_running_episodes += 1
                i_episode += 1
    finally:
        env.close()

    end_time = datetime.now().replace(microsecond=0)
    print("=" * 92)
    print("Started training at (GMT):", start_time)
    print("Finished training at (GMT):", end_time)
    print("Total training time:", end_time - start_time)
    print("=" * 92)

# Start training only when this code is executed directly, not when imported.
if __name__ == '__main__':
    train()


training environment name : CartPole-v1
current logging run number for CartPole-v1 : 9
logging at: PPO_logs/CartPole-v1//PPO_CartPole-v1_log_9.csv
save checkpoint path: PPO_preTrained/CartPole-v1//PPO_CartPole-v1_0_0.pth
--------------------------------------------------------------------------------------------
max training timesteps: 3000000
max timesteps per episode: 1000
model saving frequency: 100000
log frequency: 2000
printing avg reward every: 10000
state space dimension: 4
action space dimension: 2
discrete action space
PPO update freq: 4000
K epochs: 80
eps clip: 0.2
gamma: 0.99
lr actor: 0.0003
lr critic: 0.001
Started training at (GMT): 2025-06-03 14:51:44
Episode: 323	Timestep: 10000	Average Reward: 30.81
Actor loss: 0.026752743870019913
Critic loss: 0.1461910456418991
Episode: 423	Timestep: 20000	Average Reward: 100.13
Episode: 468	Timestep: 30000	Average Reward: 222.67
Actor loss: -0.28854280710220337
Critic loss: 0.4015570282936096
Episode: 494	Timestep: 40000	Average R