In [None]:
from envs import TicTacToeTrainingEnv
import warnings
warnings.filterwarnings("ignore")
from copy import deepcopy
from sb3_contrib.ppo_mask import MaskablePPO
from utils.terminal_colors import *
from utils.json_utils import save_opponent_stats, load_opponent_stats
from utils.models_utils import should_save_model, get_models, get_last_model_number
from utils.evaluator import evaluate_model_by_opponent
from utils.visualize import defeat_rate_plot
from training.config import *
from sb3_contrib.common.wrappers import ActionMasker
from test.action_mask_ import mask_fn
import json

<h1 style="color:#0b9ed8">TRAINING</h1>

In [None]:
def create_env(opponent_pool):
    """
    Build the action-masked TicTacToe environment used for a training session.

    A TicTacToeTrainingEnv is configured from the TRAINING_DEFAULT_* constants
    together with DEFEAT_PATH and BEST_STATS_PATH, wrapped in ActionMasker
    using mask_fn, and reset once before it is handed back.

    Parameters:
    -----------
    opponent_pool : list
        Opponent identifiers (names such as "random" and/or saved models)
        the agent will be trained against.

    Returns:
    --------
    ActionMasker
        The wrapped environment, already reset.
    """
    env_settings = {
        "board_length": TRAINING_DEFAULT_BOARD_LENGTH,
        "pattern_victory_length": TRAINING_DEFAULT_PATTERN_VICTORY_LENGTH,
        "opponent_pool": opponent_pool,
        "first_play_rate": TRAINING_DEFAULT_FIRST_PLAY_RATE,
        "lost_games_path": DEFEAT_PATH,
        "review_ratio": TRAINING_DEFAULT_REVIEW_RATIO,
        "opponent_statistics_file": BEST_STATS_PATH,
    }
    masked_env = ActionMasker(TicTacToeTrainingEnv(**env_settings), mask_fn)
    masked_env.reset()
    return masked_env

In [None]:
def initialize_model(env, last_model_num, ent_coef, n_steps, batch_size, learning_rate):
    """
    Initialize or load the PPO model with the given training hyperparameters.

    The model is chosen in this order:
      1. ``MODELS_DIR/last_checkpoint.zip`` if that file exists;
      2. a brand-new "MultiInputPolicy" model when ``last_model_num == 0``;
      3. otherwise the last entry returned by ``get_models(MODELS_DIR)``.

    When a saved model is loaded, the hyperparameters are passed to
    ``MaskablePPO.load`` as keyword overrides. ``load`` applies such overrides
    before it calls ``_setup_model()``. As a result, the rollout buffer is
    allocated for the requested ``n_steps`` and the learning-rate schedule is
    rebuilt from ``learning_rate``. Assigning ``model.n_steps`` or
    ``model.learning_rate`` after loading would leave both at their saved
    values.

    Parameters:
    -----------
    env : ActionMasker
        The training environment.
    last_model_num : int
        The number of the last saved model (0 means none saved yet).
    ent_coef : float
        Entropy coefficient for exploration.
    n_steps : int
        Number of steps to run for each environment per update.
    batch_size : int
        Size of minibatches for training.
    learning_rate : float
        Learning rate for training.

    Returns:
    --------
    model : MaskablePPO
        Initialized or loaded PPO model configured with the given hyperparameters.

    Raises:
    -------
    FileNotFoundError
        If ``last_model_num > 0`` but ``get_models(MODELS_DIR)`` returns no model.
    """
    import os

    hyperparams = {
        "ent_coef": ent_coef,
        "n_steps": n_steps,
        "batch_size": batch_size,
        "learning_rate": learning_rate,
    }
    checkpoint_path = os.path.join(MODELS_DIR, "last_checkpoint.zip")

    if os.path.exists(checkpoint_path):
        model = MaskablePPO.load(checkpoint_path, env=env, **hyperparams)
        print("✅ Loaded model from last checkpoint.")
    elif last_model_num == 0:
        model = MaskablePPO(
            "MultiInputPolicy",
            env=env,
            verbose=1,
            gamma=GAMMA,
            gae_lambda=GAE_LAMBDA,
            policy_kwargs=policy_kwargs,
            **hyperparams,
        )
    else:
        saved_models = get_models(MODELS_DIR)
        if not saved_models:
            raise FileNotFoundError(
                f"last_model_num={last_model_num} but no saved model was found in {MODELS_DIR}"
            )
        model = MaskablePPO.load(saved_models[-1], env=env, **hyperparams)

    return model

In [None]:
def _segment_hyperparams(progress):
    """
    Compute the scheduled hyperparameters for one training segment.

    Parameters:
    -----------
    progress : float
        Fraction of TOTAL_STEPS already completed when the segment starts.

    Returns:
    --------
    tuple
        (n_steps, batch_size, ent_coef, learning_rate).
        - n_steps grows from 2048 towards 4096.
        - batch_size grows from 512 and is capped at 1024.
        - ent_coef is fixed at 0.001.
        - learning_rate is LR_SCHEDULE(progress).
    """
    n_steps = int(2048 + (4096 - 2048) * progress**0.8)
    batch_size = min(1024, int(512 + (2048 - 512) * progress))
    ent_coef = 0.001
    learning_rate = LR_SCHEDULE(progress)
    return n_steps, batch_size, ent_coef, learning_rate


def _apply_live_hyperparams(model, ent_coef, batch_size, learning_rate):
    """
    Push per-segment hyperparameters into an already-built PPO model.

    PPO reads ``ent_coef`` and ``batch_size`` on every ``train()`` call, so
    plain assignment is enough for them. The optimizer's learning rate comes
    from ``model.lr_schedule``, which is built once from ``learning_rate``
    during model setup, so the schedule is replaced as well. ``learning_rate``
    is kept in sync so that a later save/load rebuilds the same schedule.

    ``n_steps`` is deliberately not touched. The rollout buffer is allocated
    with that size at setup and cannot be resized in place.
    """
    model.ent_coef = ent_coef
    model.batch_size = batch_size
    model.learning_rate = learning_rate
    if callable(learning_rate):
        model.lr_schedule = learning_rate
    else:
        # Bind the value as a default argument to avoid late-binding surprises.
        model.lr_schedule = lambda _progress_remaining, lr=learning_rate: lr


def _record_checkpoint_stats(results, opponent_pool, episodes_per_side):
    """
    Append one checkpoint's per-opponent defeat rates to ALL_STATS_PATH.

    The file holds a JSON object keyed ``checkpoint_<k>``. The new entry uses
    ``k = number of existing entries + 1``, so numbering continues across runs.
    """
    import os

    all_stats_data = {}
    if os.path.exists(ALL_STATS_PATH):
        with open(ALL_STATS_PATH, "r") as f:
            all_stats_data = json.load(f)

    next_checkpoint = len(all_stats_data) + 1
    all_stats_data[f"checkpoint_{next_checkpoint}"] = {
        opp: {
            "overall_defeat_rate": results[opp]["defeat_rate"],
            "first_player_defeat_rate": results[opp]["losses_play_first"] / episodes_per_side,
            "second_player_defeat_rate": results[opp]["losses_play_second"] / episodes_per_side,
        }
        for opp in opponent_pool
    }

    with open(ALL_STATS_PATH, "w") as f:
        json.dump(all_stats_data, f, indent=4)


def train_one_model():
    """
    Train a single PPO model, evaluate against opponents, and save stats continuously.

    Training is split into TOTAL_STEPS // CHECKPOINT_INTERVAL segments. Before
    each segment, the scheduled hyperparameters that can change on a live model
    (ent_coef, batch_size, learning rate) are applied. After each segment:
      - the model is evaluated against every opponent in the pool;
      - it is saved as ``<BASE_MODELS_NAME>_<n>.zip`` whenever
        ``should_save_model`` accepts the new stats;
      - the per-opponent defeat rates are appended to ALL_STATS_PATH.

    Training stops early once every defeat rate is 0. The final model is
    always written to ``MODELS_DIR/last_checkpoint.zip``.

    Returns:
    --------
    improvement : bool
        True if at least one checkpoint met the improvement criteria.

    Raises:
    -------
    ValueError
        If TOTAL_STEPS < CHECKPOINT_INTERVAL. In that case no segment would run
        and there would be no model to save.
    """
    import os

    n_checks = TOTAL_STEPS // CHECKPOINT_INTERVAL
    if n_checks <= 0:
        raise ValueError(
            f"TOTAL_STEPS ({TOTAL_STEPS}) must be >= CHECKPOINT_INTERVAL "
            f"({CHECKPOINT_INTERVAL}); no training segment would run."
        )

    n_eval_episodes = 2000
    # Per-seat loss counts are normalized by half of the evaluation episodes.
    # This assumes evaluate_model_by_opponent splits first/second play evenly — TODO confirm.
    episodes_per_side = n_eval_episodes / 2

    last_model_num = get_last_model_number(MODELS_DIR)
    model_path = os.path.join(MODELS_DIR, f"{BASE_MODELS_NAME}_{last_model_num + 1}.zip")

    opponent_pool = ["random", "smart_random"] + get_models(MODELS_DIR)
    best_stats = load_opponent_stats(opponent_pool)
    improvement = False

    env = create_env(opponent_pool)

    # Build/load the model once with the progress-0 schedule. Later segments
    # only adjust the hyperparameters that can change on a live model.
    n_steps, batch_size, ent_coef, learning_rate = _segment_hyperparams(0.0)
    model = initialize_model(env, last_model_num, ent_coef, n_steps, batch_size, learning_rate)

    for check in range(n_checks):
        if check > 0:
            progress = (check * CHECKPOINT_INTERVAL) / TOTAL_STEPS
            _, batch_size, ent_coef, learning_rate = _segment_hyperparams(progress)
            _apply_live_hyperparams(model, ent_coef, batch_size, learning_rate)

        print(f"\n{YELLOW}=== Training segment {check+1}/{n_checks} ===")
        print(f"Steps: {check*CHECKPOINT_INTERVAL}-{(check+1)*CHECKPOINT_INTERVAL}")
        # n_steps is read from the model because the rollout buffer size is fixed at setup.
        print(f"Params: n_steps={model.n_steps}, batch={batch_size}, ent_coef={ent_coef:.4f}")
        print(f"Opponents: {opponent_pool}{RESET}\n")

        # Train model
        model.learn(total_timesteps=CHECKPOINT_INTERVAL)

        # Evaluate model
        results = evaluate_model_by_opponent(model, opponent_pool, n_episodes=n_eval_episodes)
        current_stats = {
            opp: {"defeat_rate": res["defeat_rate"], "victory_rate": res["victory_rate"]}
            for opp, res in results.items()
        }

        # Save improvement if criteria met
        if should_save_model(current_stats, best_stats, IMPROVEMENT_THRESHOLD):
            print(f"{GREEN}Saved new best model at checkpoint {check}{RESET}")
            improvement = True
            best_stats = deepcopy(current_stats)
            model.save(model_path)
            save_opponent_stats(best_stats, BEST_STATS_PATH)

        # Save all stats to JSON file continuously
        _record_checkpoint_stats(results, opponent_pool, episodes_per_side)

        # Early stopping if all defeat rates are zero
        if all(stats["defeat_rate"] == 0.0 for stats in current_stats.values()):
            print(f"{GREEN}=== All defeat rates are 0. Early stopping triggered. ==={RESET}")
            break

    model.save(os.path.join(MODELS_DIR, "last_checkpoint.zip"))
    return improvement

In [None]:
def main_training_loop(nb_models_to_train=2):
    """
    Run train_one_model() repeatedly, once per requested model.

    The loop always advances to the next model, whether or not the current
    one met the improvement criteria. A success message or a warning is
    printed for each run.

    Parameters:
    -----------
    nb_models_to_train : int
        How many consecutive training runs to perform (default 2).

    Returns:
    --------
    None
    """
    runs_done = 0
    while runs_done < nb_models_to_train:
        saved_new_best = train_one_model()
        runs_done += 1
        message = (
            f"{GREEN}Training completed for model {runs_done}. Best model saved.{RESET}"
            if saved_new_best
            else f"{RED}Warning: No model met improvement criteria{RESET}"
        )
        print(message)


In [None]:
 # Run the training loop
main_training_loop(nb_models_to_train=2)

In [None]:
# Plot defeat rates with utils.visualize.defeat_rate_plot. It takes no arguments,
# so it presumably reads its own data source (e.g. the stats JSON written during training) — TODO confirm.
defeat_rate_plot()