In [1]:
%reload_ext dotenv
%dotenv

# ML libraries
import torch

# Local imports
from env import Env
from agents.random_agent import RandomAgent
from agents.minimax_agent import MinimaxAgent
from agents.minimax_agent_old import OldMinimaxAgent
from agents.deep_q_agent import DeepQAgent
from agents.cql_agent import CQLAgent
import utils
from trainer import Trainer

In [3]:
##################
# INITIALIZATION #
##################

# Fix random seed
utils.seed_everything(42, deterministic=False)

# Use GPU if available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Define player and opponent IDs
AGENT = 1
OPPONENT = 2

# Define agent and opponent
agent = CQLAgent(epsilon_max=1, epsilon_min=0.1, epsilon_decay=0.9999, device=device)
opponent = MinimaxAgent(depth=3, epsilon=0.5, player=OPPONENT)
replacement_agent = MinimaxAgent(depth=3, epsilon=0.5, player=OPPONENT)

# Define options for training
options = {
           'UPDATE_OPPONENT': True,                         # Whether to enable self-play or not
           'OPPONENT_UPDATE_FREQUENCY': 100,                # After how many episodes the opponent will be replaced by the current agent
           'BOOTSTRAP_EPISODES': 7500,                      # During this time, the agent will not be replaced by itself
           'DECAY_RANDOMNESS_OPPONENT': True,               # Decay randomness of the opponent. Use only if the opponent acts with some randomness
           'DECAY_RANDOMNESS_FREQUENCY': 1000,              # Frequency of randomness decay
           'REPLACE_FOR_EVALUATION': True,                  # Whether to replace the training model at the end with another evaluation model
           'REPLACE_FOR_EVALUATION_BY': replacement_agent,  # Evalutation model to replace training model by
           'AUTOSAVE': True,                                # Whether to save the model at certain intervals
           'AUTOSAVE_TYPE': 'NUM_EPISODES',                 # One of ["NUM_OPTIMIZATIONS", "NUM_EPISODES"]
           'AUTOSAVE_PERIOD': 1000,                         # After how many _ to save the model
           }

# Initialize trainer
trainer = Trainer(env=Env(), agent=agent, opponent=opponent, num_episodes={'TRAIN': 10000, 'EVAL': 100}, agent_id=AGENT, opponent_id=OPPONENT, device=device, verbose=True, options=options)

In [3]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
type(device)


torch.device

In [22]:
############
# TRAINING #
############

# Train agent
trainer.train()

TRAIN: Running episode 34 of 10000. Current win ratio of AGENT is 5.88%.

KeyboardInterrupt: 

In [8]:
####################
# QUICK EVALUATION #
####################

# Fix random seed
utils.seed_everything(42, deterministic=False)

# Load agent from save and set to eval mode
new_agent = CQLAgent()
new_agent.load_model('./saved_models/CQLAgent_46515_no_replacement')
new_agent.eval_mode()

# Define opponent
new_opponent = MinimaxAgent(depth=3, epsilon=0.5, player=OPPONENT)

# Run evaluation
trainer.eval(trainer.agent, new_opponent, agent_start=None, print_last_n_games=5)

EVAL: Running episode 96 of 100. Ratios are [WINS: 21.88% | LOSSES: 77.08% | TIES: 1.04%]]

_______________
|0|0|0|0|0|0|0|
|0|0|0|0|0|0|0|
|0|0|0|0|0|0|0|
|0|0|0|0|0|0|0|
|0|0|0|0|0|0|0|
|1|0|0|0|0|0|0|
AGENT action was 0


_______________
|0|0|0|0|0|0|0|
|0|0|0|0|0|0|0|
|0|0|0|0|0|0|0|
|0|0|0|0|0|0|0|
|0|0|0|0|0|0|0|
|1|2|0|0|0|0|0|
OPPONENT action was 1


_______________
|0|0|0|0|0|0|0|
|0|0|0|0|0|0|0|
|0|0|0|0|0|0|0|
|0|0|0|0|0|0|0|
|0|1|0|0|0|0|0|
|1|2|0|0|0|0|0|
AGENT action was 1


_______________
|0|0|0|0|0|0|0|
|0|0|0|0|0|0|0|
|0|0|0|0|0|0|0|
|0|0|0|0|0|0|0|
|0|1|0|0|0|0|0|
|1|2|2|0|0|0|0|
OPPONENT action was 2


_______________
|0|0|0|0|0|0|0|
|0|0|0|0|0|0|0|
|0|0|0|0|0|0|0|
|0|1|0|0|0|0|0|
|0|1|0|0|0|0|0|
|1|2|2|0|0|0|0|
AGENT action was 1


_______________
|0|0|0|0|0|0|0|
|0|0|0|0|0|0|0|
|0|0|0|0|0|0|0|
|0|1|0|0|0|0|0|
|0|1|0|0|0|0|0|
|1|2|2|0|0|2|0|
OPPONENT action was 5


_______________
|0|0|0|0|0|0|0|
|0|0|0|0|0|0|0|
|0|1|0|0|0|0|0|
|0|1|0|0|0|0|0|
|0|1|0|0|0|0|0|
|1|2|