In [None]:
import json
from pathlib import Path

import gym
import matplotlib.pyplot as plt
import mjx
import numpy as np
import torch
from mjx.agents import RandomAgent, ShantenAgent

from ppo_agent import PPOAgent, GymEnv

In [None]:
# Build the environment with three ShantenAgent opponents.
num_opponents = 3
opponents = [ShantenAgent() for _ in range(num_opponents)]
env = GymEnv(opponent_agents=opponents)

# Reset once so the network sizes can be read off the first observation
# and the action mask that comes back with it.
obs, info = env.reset()
obs_shape = len(obs.flatten())          # length of the flattened observation
action_dim = len(info["action_mask"])   # number of entries in the action mask

# Collect the PPO settings in one place, then construct the agent from them.
ppo_kwargs = dict(
    input_dim=obs_shape,
    hidden_dim=128,
    output_dim=action_dim,
    lr=1e-4,             # small learning rate
    entropy_coef=0.001,  # small entropy coefficient: little curiosity
)
agent = PPOAgent(**ppo_kwargs)


In [5]:
def plot_rewards(rewards, path="logs/reward_curve.png"):
    """Plot a reward series as a line chart and save it to ``path``.

    Args:
        rewards: Sequence of reward values; each value is plotted against
            its index in the sequence.
        path: Destination of the saved figure. When it is a ``str`` or
            ``Path``, missing parent directories are created first.
    """
    # savefig does not create directories, so a fresh checkout without the
    # target folder would otherwise fail with FileNotFoundError.
    if isinstance(path, (str, Path)):
        Path(path).parent.mkdir(parents=True, exist_ok=True)

    # Explicit figure/axes handles so we save and close exactly this figure,
    # regardless of what the pyplot "current figure" is.
    fig, ax = plt.subplots()
    ax.plot(rewards)
    ax.set_xlabel("Episode")
    ax.set_ylabel("Reward")
    ax.set_title("Training Reward Curve")
    fig.savefig(path)
    plt.close(fig)  # release the figure; nothing is displayed inline


In [None]:
# Training hyperparameters
num_episodes = 3000
log_interval = 100  # average, print and log once every 100 episodes

# Output files.
# NOTE(review): the "10000" tag in these names does not match num_episodes
# above. The names are left unchanged in case other cells or scripts load
# them by path -- confirm and rename if the tag is simply stale.
log_dir = "logs/ppo"
model_path = "logs/ppo/ppo_shanten_model_1e-4_10000.pt"
training_log_path = 'logs/ppo/ppo_shanten_training_log_1e-4_10000.json'
# Uses num_episodes rather than the loop variable: same value after a full
# run, but still defined when the loop body never executes.
curve_path = f"logs/ppo/ppo_shanten_reward_curve_{num_episodes}_1e-4_10000.png"

# torch.save and open() do not create directories; make sure the target exists
# before training so the first checkpoint does not fail.
Path(log_dir).mkdir(parents=True, exist_ok=True)

rolling_rewards = []  # one mean reward per log_interval window

all_rewards = []
all_actor_loss = []
all_value_loss = []
best_reward = -float("inf")

for episode in range(1, num_episodes + 1):
    obs, info = env.reset()
    total_reward = 0
    done = False

    # Play one full episode, handing every step's reward to the agent.
    while not done:
        action = agent.act(obs, info["action_mask"])
        next_obs, reward, done, info = env.step(action)
        agent.store_reward(reward)
        obs = next_obs
        total_reward += reward

    # One agent update per finished episode.
    stats = agent.update()
    all_rewards.append(total_reward)
    all_actor_loss.append(stats['actor_loss'])
    all_value_loss.append(stats['value_loss'])

    # logs
    if episode % log_interval == 0:
        avg_reward = np.mean(all_rewards[-log_interval:])
        rolling_rewards.append(avg_reward)
        print(f"Episode {episode}/{num_episodes}, avg reward: {avg_reward:.3f}")
        # update the best model (judged by the mean reward of the last window)
        if avg_reward > best_reward:
            best_reward = avg_reward
            torch.save(agent.model.state_dict(), model_path)
            print(f"Best model saved with reward: {best_reward:.3f}")

        # The loss/entropy values are those of the most recent update only,
        # not an average over the window.
        log_data = {
            'episode': episode,
            'avg_reward': avg_reward,
            'actor_loss': stats['actor_loss'],
            'value_loss': stats['value_loss'],
            'entropy': stats['entropy'],
            'total_loss': stats['total_loss']
        }
        # Append mode, one JSON object per line: re-running this cell adds to
        # the existing log instead of replacing it.
        with open(training_log_path, 'a') as f:
            json.dump(log_data, f)
            f.write('\n')

# plot the reward curve
# NOTE(review): rolling_rewards holds one point per log_interval episodes,
# while plot_rewards labels the x-axis "Episode".
plot_rewards(rolling_rewards, path=curve_path)




Episode 100/10000, avg reward: -128.250
Best model saved with reward: -128.250
Episode 200/10000, avg reward: -126.900
Best model saved with reward: -126.900
Episode 300/10000, avg reward: -120.150
Best model saved with reward: -120.150
Episode 400/10000, avg reward: -124.200
Episode 500/10000, avg reward: -128.250
Episode 600/10000, avg reward: -124.200
Episode 700/10000, avg reward: -126.900
Episode 800/10000, avg reward: -126.450
Episode 900/10000, avg reward: -122.850
Episode 1000/10000, avg reward: -121.500
Episode 1100/10000, avg reward: -126.900
Episode 1200/10000, avg reward: -120.150
Episode 1300/10000, avg reward: -132.300
Episode 1400/10000, avg reward: -119.700
Best model saved with reward: -119.700
Episode 1500/10000, avg reward: -119.700
Episode 1600/10000, avg reward: -121.500
Episode 1700/10000, avg reward: -124.200
Episode 1800/10000, avg reward: -121.500
Episode 1900/10000, avg reward: -122.850
Episode 2000/10000, avg reward: -124.650
Episode 2100/10000, avg reward: -