In [5]:
import mjx
from mjx.agents import RandomAgent, ShantenAgent

from ppo_agent import PPOAgent, GymEnv
from tqdm import tqdm



In [6]:
# Maps player_0's end-of-game reward from mjx (env.rewards()) to a finishing place (1 = first).
# Any reward not listed here raises KeyError at the lookup site.
RANK_DICT = {
    90: 1,
    45: 2,
    0: 3,
    -135: 4,
}


def eval(records):
    """Print the average score and average rank over a list of per-game records.

    Note: this name shadows Python's builtin ``eval`` inside the notebook.

    Args:
        records: non-empty sequence of dicts, one per finished game. Each dict
            must hold a score and a rank, under either the ``"score"``/``"rank"``
            keys or the ``"my_score"``/``"my_rank"`` spelling.

    Raises:
        ValueError: if ``records`` is empty (an average is undefined).
        KeyError: if a record carries neither spelling of a required field.
    """
    if not records:
        raise ValueError("eval() needs at least one record to average")

    def field(record, name):
        # Accept both key conventions so records built with a "my_" prefix
        # are summarised instead of crashing with KeyError.
        if name in record:
            return record[name]
        return record[f"my_{name}"]

    n = len(records)
    avg_score = sum(field(r, "score") for r in records) / n
    avg_rank = sum(field(r, "rank") for r in records) / n

    print(f"Average score: {avg_score:.2f}")
    print(f"Average rank: {avg_rank:.2f}")
    

## Test Base Model: Base Model (RandomAgent) vs. Base Model (RandomAgent)

In [None]:
# Baseline: every seat is driven by the same RandomAgent, so player_0's numbers
# show what an untrained policy achieves.
base_agent = RandomAgent()
env = mjx.MjxEnv()

N = 1  # number of games; a single game is only a smoke test, raise N for a real baseline
results = []

for _ in tqdm(range(N)):
    obs_dict = env.reset()
    while not env.done():
        # Every player that has an observation this step acts via the shared agent.
        actions = {
            player_id: base_agent.act(player_obs)
            for player_id, player_obs in obs_dict.items()
        }
        obs_dict = env.step(actions)

    # Evaluate from player_0's seat, the same seat the PPO evaluation uses.
    my_index = obs_dict["player_0"].who()
    # Points relative to 25000 (presumably the starting stack — TODO confirm),
    # matching how the PPO evaluation cell reports score so the averages compare.
    score = obs_dict["player_0"].tens()[my_index] - 25000
    my_reward = env.rewards()["player_0"]
    my_rank = RANK_DICT[my_reward]

    # Keep per-game results so the baseline is summarised exactly like the PPO run.
    results.append({"score": score, "rank": my_rank})

eval(results)

    
    




## Test PPO Model: PPO Model vs. Base Model (three RandomAgent opponents)

In [3]:

# Network shape for the PPO policy.
# NOTE(review): these presumably have to match the architecture the checkpoint was
# trained with, and 544 / 181 presumably are GymEnv's observation-feature and action
# sizes — confirm against ppo_agent.py.
PPO_INPUT_DIM = 544
PPO_HIDDEN_DIM = 128
PPO_OUTPUT_DIM = 181
PPO_CHECKPOINT = "rl_models/ppo_random_opponent_model_3e-4.pt"

# NOTE(review): the captured warning output of this cell also shows a torch.load of
# "logs/ppo_cr_cl/best_model_ppo5_stage_1.pt" — confirm which weights PPOAgent
# actually ends up using before trusting the evaluation numbers.
ppo_agent = PPOAgent(
    input_dim=PPO_INPUT_DIM,
    hidden_dim=PPO_HIDDEN_DIM,
    output_dim=PPO_OUTPUT_DIM,
    pretrained_model=PPO_CHECKPOINT,  # pretrained weights to load, if available
)


  state_dict = torch.load(pretrained_model)
  state_dict = torch.load("logs/ppo_cr_cl/best_model_ppo5_stage_1.pt")


Loaded pretrained model from rl_models/ppo_random_opponent_model_3e-4.pt


In [7]:
# Evaluate the PPO agent (seat player_0) against three RandomAgent opponents.
# At ~3.8 s per game (see the progress output), N = 1000 takes roughly an hour.
env = GymEnv(opponent_agents=[RandomAgent(), RandomAgent(), RandomAgent()], info_type="default")
N = 1000

records = []
for _ in tqdm(range(N)):
    obs, info = env.reset()
    done = False
    while not done:
        action_mask = info["action_mask"]
        action = ppo_agent.act(obs, action_mask)

        # env.step only takes the PPO agent's action; the env handles the other agents internally.
        obs, _reward, done, info = env.step(action)

    my_index = env.curr_obs_dict['player_0'].who()
    # Points relative to 25000 (presumably the starting stack — TODO confirm).
    my_score = env.curr_obs_dict['player_0'].tens()[my_index] - 25000
    my_reward = env.mjx_env.rewards()['player_0']
    my_rank = RANK_DICT[my_reward]

    # Keys must be "score"/"rank": those are the fields eval() averages.
    records.append({
        "score": my_score,
        "rank": my_rank,
    })

eval(records)

  0%|          | 0/1000 [00:00<?, ?it/s]

 91%|█████████▏| 914/1000 [53:51<05:25,  3.79s/it]  

: 