In [None]:
import rlcard
import torch
import os
from rlcard.agents import RandomAgent, DQNAgent
from rlcard.utils import tournament, Logger

# 1. Configuration
# These constants are the single source of truth for the evaluation run;
# everything below reads them instead of repeating literal values.
ENV_NAME = 'limit-holdem' # Make sure this matches your training env!
NUM_PLAYERS = 2
NUM_HANDS = 10000         # How many hands to play for the test
MODEL_PATH = 'checkpoints/dqn_2players.pt'

# 1. Setup the environment to get the correct dimensions
# RLCard's game-level settings use the 'game_' prefix: the player count is
# configured through 'game_num_players'. A plain 'num_players' key is not a
# recognised setting, so NUM_PLAYERS would never reach the game.
env = rlcard.make(ENV_NAME, config={'game_num_players': NUM_PLAYERS})

# 2. Read the .pt file as a dictionary
# map_location='cpu' ensures it loads safely even if trained on 'mps' or 'cuda'
# SECURITY NOTE: weights_only=False lets torch.load unpickle arbitrary Python
# objects, which can execute code. Only load checkpoints you produced yourself
# or otherwise trust.
checkpoint_dict = torch.load(MODEL_PATH, map_location='cpu', weights_only=False)

# 3. Initialize the agent with the same architecture it was trained on
dqn_agent = DQNAgent(
    num_actions=env.num_actions,
    state_shape=env.state_shape[0],  # observation shape of seat 0
    mlp_layers=[512, 256, 128], # This MUST match the saved model
    device='cpu'
)

# 4. Restore the weights and states from the dictionary
dqn_agent.from_checkpoint(checkpoint_dict)

print("Agent successfully restored from dictionary checkpoint!")

# 4. Create Random Opponents
# Seat 0 is reserved for the DQN agent; each remaining seat (1 .. NUM_PLAYERS-1)
# is filled with a RandomAgent.
random_agents = []
for _ in range(1, NUM_PLAYERS):
    random_agents.append(RandomAgent(num_actions=env.num_actions))

# 5. Bind Agents to Environment
# The DQN agent goes first in the list so it plays as Player 0; the random
# agents follow as the remaining players.
agents = [dqn_agent, *random_agents]
env.set_agents(agents)

print(f"Starting tournament: DQN vs {NUM_PLAYERS-1} Random Bots")
print(f"Playing {NUM_HANDS} hands...")

# 6. Run Tournament
# tournament() returns a list of average payoffs, one entry per agent.
# Keep the whole list: index 0 is the DQN agent, the rest are the random bots.
payoffs = tournament(env, NUM_HANDS)

# 7. Print Results
dqn_payoff = payoffs[0]
random_payoffs = payoffs[1:]

print("\n--------------------------------")
print(f"Final Results ({NUM_HANDS} hands)")
print("--------------------------------")
print(f"DQN Agent Payoff (Avg): {dqn_payoff:.4f} big blinds/hand")

if NUM_PLAYERS > 2:
    # Label each bot with its seat number; seat 0 is the DQN agent and is
    # already reported above, so numbering starts at 1.
    for seat, bot_payoff in enumerate(random_payoffs, start=1):
        print(f"Random Bot {seat} Payoff: {bot_payoff:.4f}")
else:
    # Heads-up: report the single opponent's own payoff (seat 1), not the
    # DQN agent's.
    print(f"Random Bot Payoff: {random_payoffs[0]:.4f}")

print("--------------------------------")

# A positive average payoff means the DQN agent is winning chips on balance.
if dqn_payoff > 0:
    print("✅ SUCCESS: The DQN is beating the random players!")
else:
    print("❌ WARNING: The DQN is losing to random players.")


INFO - Restoring model from checkpoint...
Agent successfully restored from dictionary checkpoint!
Starting tournament: DQN vs 1 Random Bots
Playing 10000 hands...

--------------------------------
Final Results (10000 hands)
--------------------------------
DQN Agent Payoff (Avg): 1.6314 big blinds/hand
Random Bot 0 Payoff: 1.6314
--------------------------------
✅ SUCCESS: The DQN is beating the random players!
