In [1]:
import torch
from games import TicTacToe, CartPole
from models import MuZero
from utils import KaggleAgent, GymAgent, evaluateKaggle, evaluateGym

# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)

In [3]:
# Evaluate a saved MuZero checkpoint on Tic-Tac-Toe against the "reaction"
# opponent, once for each seating order.
game = TicTacToe()

# Network configuration handed to MuZero.
# NOTE(review): presumably these must match the settings the checkpoint was
# trained with, otherwise load_state_dict below would fail -- confirm against
# the training notebook/script.
model_args = {
    'dynamicsFunction': {'num_resBlocks': 4, 'hidden_planes': 128},
    'predictionFunction': {'num_resBlocks': 4, 'hidden_planes': 128},
    'representationFunction': {'num_resBlocks': 3, 'hidden_planes': 64},
    'cheatDynamicsFunction': True,
    'cheatRepresentationFunction': True,
}
model = MuZero(game, args=model_args).to(device)

# The checkpoint location is built from the game's string form.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source.
checkpoint_path = f'Models/{game}/model.pt'
state_dict = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(state_dict)
model.eval()  # put the network in evaluation mode before playing

# Settings for the agent that wraps the model.
# NOTE(review): the exact meaning of each key is defined in utils.KaggleAgent,
# which is not visible here -- the names suggest MCTS search parameters and
# "cheat" switches; verify there before tuning.
agent_args = {
    'search': True,
    'temperature': 1,
    'num_mcts_runs': 100,
    'c': 1,
    'dirichlet_alpha': 0.3,
    'dirichlet_epsilon': 0.0,
    'cheatDynamicsFunction': True,
    'cheatRepresentationFunction': True,
    'cheatAvailableActions': False,
    'cheatTerminalState': False,
}
player = KaggleAgent(model, game, args=agent_args)

# Run 100 games per seating: first with the agent listed second, then with the
# agent listed first.
num_games = 100
for lineup in (["reaction", player.run], [player.run, "reaction"]):
    evaluateKaggle("tictactoe", lineup, num_iterations=num_games)


Player 1 | Wins: 47 | Draws: 44 | Losses: 9
Player 2 | Wins: 9 | Draws: 44 | Losses: 47
    

Player 1 | Wins: 30 | Draws: 53 | Losses: 17
Player 2 | Wins: 17 | Draws: 53 | Losses: 30
    
