In [None]:
import gym
import numpy as np
import mjx
from mjx.agents import RandomAgent, ShantenAgent
from ppo_agent import PPOAgent, GymEnv
import torch

In [None]:
# --- Environment ---------------------------------------------------------
# The learning agent is seated against three ShantenAgent opponents.
opponents = [ShantenAgent() for _ in range(3)]
env = GymEnv(opponent_agents=opponents)

# --- Network dimensions --------------------------------------------------
# A single reset yields both sizes: the length of the flattened observation
# and the length of the action mask (one entry per action).
obs, info = env.reset()
obs_shape = obs.flatten().shape[0]
action_dim = len(info["action_mask"])

# --- PPO agent -----------------------------------------------------------
# Hyperparameters are gathered in one mapping so they can be read at a glance.
ppo_kwargs = dict(
    input_dim=obs_shape,
    hidden_dim=128,
    output_dim=action_dim,
    lr=1e-4,             # small learning rate
    entropy_coef=0.001,  # small entropy coefficient: small curiosity
)
agent = PPOAgent(**ppo_kwargs)


In [5]:
import matplotlib.pyplot as plt

def plot_rewards(rewards, path="logs/reward_curve.png"):
    """Plot a reward curve and save it to an image file.

    Args:
        rewards: Sequence of reward values; each value is plotted against
            its position in the sequence.
        path: Destination file for the figure. Any missing parent
            directories are created before saving.
    """
    import os  # local import so this cell does not depend on another cell's imports

    # savefig does not create directories. Without this, a fresh checkout
    # (no logs/ folder yet) would fail only after the caller's work is done.
    out_dir = os.path.dirname(path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    fig, ax = plt.subplots()
    try:
        ax.plot(rewards)
        ax.set_xlabel("Episode")
        ax.set_ylabel("Reward")
        ax.set_title("Training Reward Curve")
        fig.savefig(path)
    finally:
        # Release the figure even when saving raises, so repeated calls do
        # not accumulate open figures in the kernel.
        plt.close(fig)


In [None]:
import json
import os

# training parameters
num_episodes = 30000
log_interval = 100  # record every 100 episodes

# Output locations, defined once so the checkpoint, log and plot stay together.
# NOTE(review): these file names embed "lr3e-4"/"lr_3e-4", but the agent in
# this notebook is constructed with lr=1e-4 -- confirm which value is intended
# before relying on the file names to identify the run.
log_dir = "logs/ppo_cr"
model_path = "logs/ppo_cr/ppo_cr_lr_3e-4_30000.pt"
log_path = 'logs/ppo_cr/training_log_lr3e-4_30000.json'

# Create the output directory up front: otherwise the first torch.save/open
# would fail only after log_interval episodes have already been played.
os.makedirs(log_dir, exist_ok=True)

rolling_rewards = []  # one average per log interval (what gets plotted)

all_rewards = []      # total reward of every episode
all_actor_loss = []
all_value_loss = []
best_reward = -float("inf")

for episode in range(1, num_episodes + 1):
    print(f"Episode {episode}/{num_episodes}")
    obs, info = env.reset()
    total_reward = 0
    done = False

    # Play one full episode, recording every reward in the agent's buffer.
    while not done:
        action = agent.act(obs, info["action_mask"])
        next_obs, reward, done, info = env.step(action)
        agent.store_reward(reward)
        obs = next_obs
        total_reward += reward

    # One agent update per finished episode.
    stats = agent.update()
    all_rewards.append(total_reward)
    all_actor_loss.append(stats['actor_loss'])
    all_value_loss.append(stats['value_loss'])
    print(f"Episode {episode} finished with total reward: {total_reward:.3f}")

    # print logs
    if episode % log_interval == 0:
        # np.mean returns a NumPy scalar; convert to a plain float so that
        # json.dump accepts it whatever dtype the rewards have.
        avg_reward = float(np.mean(all_rewards[-log_interval:]))
        rolling_rewards.append(avg_reward)
        print(f"Episode {episode}/{num_episodes}, avg reward: {avg_reward:.3f}")
        # Update the best reward and checkpoint the model that achieved it
        if avg_reward > best_reward:
            best_reward = avg_reward
            torch.save(agent.model.state_dict(), model_path)
            print(f"Best model saved with reward: {best_reward:.3f}")

        # The stats logged here come from the most recent update only.
        log_data = {
            'episode': episode,
            'avg_reward': avg_reward,
            'actor_loss': stats['actor_loss'],
            'value_loss': stats['value_loss'],
            'entropy': stats['entropy'],
            'total_loss': stats['total_loss']
        }
        # Append mode with one JSON object per line: re-running this cell
        # adds to the existing file rather than replacing it.
        with open(log_path, 'a') as f:
            json.dump(log_data, f)
            f.write('\n')

# plot the reward curve
# num_episodes equals the loop variable after a completed loop and, unlike
# it, is still defined when the loop body never runs (num_episodes == 0).
plot_rewards(rolling_rewards, path=f"logs/ppo_cr/reward_curve_{num_episodes}_lr3e-4_30000.png")




Episode 1/30000
Episode 1 finished with total reward: 129.000
Episode 2/30000
Episode 2 finished with total reward: -100.000
Episode 3/30000
Episode 3 finished with total reward: 118.000
Episode 4/30000
Episode 4 finished with total reward: -93.000
Episode 5/30000
Episode 5 finished with total reward: 97.000
Episode 6/30000
Episode 6 finished with total reward: 34.000
Episode 7/30000
Episode 7 finished with total reward: 199.000
Episode 8/30000
Episode 8 finished with total reward: 100.000
Episode 9/30000
Episode 9 finished with total reward: 26.000
Episode 10/30000
Episode 10 finished with total reward: 157.000
Episode 11/30000
Episode 11 finished with total reward: 334.000
Episode 12/30000
Episode 12 finished with total reward: -9.000
Episode 13/30000
Episode 13 finished with total reward: -72.000
Episode 14/30000
Episode 14 finished with total reward: 46.000
Episode 15/30000
Episode 15 finished with total reward: 215.000
Episode 16/30000
Episode 16 finished with total reward: 84.000

lr = 0.01, time = 351m 46.9s

In [7]:
# Display the highest log-interval average reward recorded by the training loop.
best_reward

747.91