In [1]:
import os
import warnings
from copy import deepcopy

warnings.filterwarnings("ignore")

import wandb
from stable_baselines3 import PPO
from stable_baselines3.common.utils import get_schedule_fn

from envs import TicTacToeTrainingEnv
from utils.terminal_colors import *
from utils.json_utils import save_opponent_stats, load_opponent_stats
from utils.models_utils import should_save_model, get_models
from utils.evaluator import evaluate_model_by_opponent
from configs.config import *




NameError: name 'DEFAULT_BOARD_LENGTH' is not defined

<h1 style="color:#0b9ed8">TRAINING</h1>

In [None]:

# Main training loop: for each model index, train PPO in CHECKPOINT_INTERVAL-sized
# segments against a pool of scripted and previously saved opponents, evaluate after
# every segment, and save the checkpoint whenever `should_save_model` accepts it.


def _segment_hyperparams(progress):
    """Return ``(n_steps, batch_size, ent_coef)`` scheduled for one training segment.

    Args:
        progress: fraction of ``TOTAL_STEPS`` already trained (0 at the first segment).
    """
    n_steps = int(2048 + (4096 - 2048) * progress**0.8)   # grows a bit more gradually
    batch_size = int(512 + (2048 - 512) * progress**1.0)  # rises more slowly
    ent_coef = max(0.001, 0.015 * (1 - progress**0.6))    # decays more slowly
    return n_steps, batch_size, ent_coef


def _make_training_env(opponent_pool):
    """Build a fresh training env that plays against ``opponent_pool``.

    The env is re-created every segment because ``save_opponent_stats`` rewrites
    ``STATS_PATH`` between segments (presumably the env reads that file when it is
    constructed — TODO confirm in TicTacToeTrainingEnv).
    """
    return TicTacToeTrainingEnv(
        opponent_pool=opponent_pool,
        first_play_rate=0.5,
        lost_games_path="defeated_games.json",
        review_ratio=0.0,
        opponent_statistics_file=STATS_PATH,
    )


def _create_or_load_model(model_index, env, previous_models, n_steps, batch_size,
                          ent_coef, learning_rate):
    """Create a fresh PPO for model 1; otherwise resume from the latest saved model.

    Raises:
        FileNotFoundError: if ``model_index != 1`` and ``previous_models`` is empty.
    """
    if model_index == 1:
        return PPO(
            "MultiInputPolicy",
            env=env,
            verbose=1,
            gamma=GAMMA,
            gae_lambda=GAE_LAMBDA,
            ent_coef=ent_coef,
            n_steps=n_steps,
            batch_size=batch_size,
            learning_rate=learning_rate,
            policy_kwargs=POLICY_KWARGS,
        )
    if not previous_models:
        raise FileNotFoundError(
            f"No saved model in {MODELS_DIR!r} to resume model {model_index} from"
        )
    # Keyword arguments given to PPO.load are applied *before* the model is set up,
    # so the rollout buffer is sized with `n_steps` and `lr_schedule` is rebuilt from
    # `learning_rate`. Assigning `model.learning_rate` after loading has no effect,
    # because SB3 updates the optimizer from `lr_schedule`.
    # NOTE(review): assumes get_models() lists models oldest -> newest.
    return PPO.load(
        previous_models[-1],
        env=env,
        ent_coef=ent_coef,
        n_steps=n_steps,
        batch_size=batch_size,
        learning_rate=learning_rate,
    )


def _apply_segment_hyperparams(model, ent_coef, batch_size, learning_rate):
    """Push one segment's schedule into an already set-up PPO model.

    ``n_steps`` is deliberately left unchanged: the rollout buffer is allocated with
    ``n_steps`` entries when the model is set up, so changing it afterwards would make
    rollout collection overrun or under-fill that buffer.
    """
    model.ent_coef = ent_coef
    model.batch_size = batch_size
    model.learning_rate = learning_rate
    # SB3 reads the learning rate from `lr_schedule`, so rebuild it as well.
    model.lr_schedule = get_schedule_fn(learning_rate)


for i in range(START_MODEL_INDEX, MAX_MODELS + 1):

    model_name = f"model_{i}.zip"
    model_path = os.path.join(MODELS_DIR, model_name)

    opponent_models = get_models(MODELS_DIR)
    opponent_pool = ["random", "smart_random"] + opponent_models

    improvement = False
    best_stats = {}
    n_checks = TOTAL_STEPS // CHECKPOINT_INTERVAL
    env = None  # env of the current segment; closed when replaced and at the end

    # WandB initialization
    # NOTE(review): the project name says CONNECTX while this trains TicTacToeTrainingEnv.
    wandb.init(
        project=f"{MODELS_DIR}_CONNECTX_KAGGLE",
        name=f"{MODELS_DIR}-run_model_{i}",
        config={
            "model_index": i,
            "gamma": GAMMA,
            "gae_lambda": GAE_LAMBDA,
            "ent_coef_start": START_ENT_COEF,
            "checkpoint_interval": CHECKPOINT_INTERVAL,
            "improvement_threshold": IMPROVEMENT_THRESHOLD
        },
        reinit=True
    )

    try:
        for check in range(n_checks):
            current_progress = (check * CHECKPOINT_INTERVAL) / TOTAL_STEPS

            # Dynamic parameters
            n_steps, batch_size, ent_coef = _segment_hyperparams(current_progress)
            # NOTE(review): SB3-style schedules are called with progress *remaining*
            # (1 -> 0); confirm LR_SCHEDULE expects progress done (0 -> 1).
            learning_rate = LR_SCHEDULE(current_progress)

            # NOTE(review): the returned stats are never used; the call is kept in case
            # load_opponent_stats initializes state on disk — remove it if it does not.
            opponent_stats = load_opponent_stats(opponent_pool)

            previous_env = env
            env = _make_training_env(opponent_pool)

            # Model initialization/loading
            if check == 0:
                model = _create_or_load_model(
                    i, env, opponent_models, n_steps, batch_size, ent_coef, learning_rate
                )
            else:
                # Attach the fresh env and apply this segment's schedule; previously the
                # new env was never used and the recomputed values were only logged.
                model.set_env(env)
                _apply_segment_hyperparams(model, ent_coef, batch_size, learning_rate)
            if previous_env is not None:
                previous_env.close()

            # Report the values the model actually uses (n_steps is fixed after setup).
            print(f"\n{YELLOW}=== Training segment {check+1}/{n_checks} ===")
            print(f"Steps: {check*CHECKPOINT_INTERVAL}-{(check+1)*CHECKPOINT_INTERVAL}")
            print(f"Params: n_steps={model.n_steps}, batch={model.batch_size}, ent_coef={model.ent_coef:.4f}")
            print(f"Opponents: {opponent_pool}{RESET}\n")

            # Model training
            model.learn(total_timesteps=CHECKPOINT_INTERVAL)

            # Evaluation
            results = evaluate_model_by_opponent(model, opponent_pool, n_episodes=200, stats_path=STATS_PATH)

            current_stats = {
                opponent: {
                    "defeat_rate": rates["defeat_rate"],
                    "victory_rate": rates["victory_rate"],
                }
                for opponent, rates in results.items()
            }

            # Logging
            wandb.log({
                **{f"metrics/{k}_win_rate": v["victory_rate"] for k, v in results.items()},
                **{f"metrics/{k}_defeat_rate": v["defeat_rate"] for k, v in results.items()},
                "hyperparams/n_steps": model.n_steps,
                "hyperparams/batch_size": model.batch_size,
                "hyperparams/ent_coef": model.ent_coef,
                "progress/current": current_progress,
                "progress/checkpoint": check
            })

            # Saving: the file on disk is the snapshot; `model` keeps training afterwards.
            if should_save_model(current_stats, best_stats, IMPROVEMENT_THRESHOLD):

                print(f"{GREEN}Saved new best model at checkpoint {check}{RESET}")
                print(f"{RED}Old best stats -> {best_stats}{RESET}")
                print(f"{YELLOW}New best stats -> {current_stats}{RESET}")

                improvement = True
                best_stats = deepcopy(current_stats)
                model.save(model_path)
                save_opponent_stats(best_stats, STATS_PATH)
    finally:
        # Release the last env and close the WandB run even if training raised.
        if env is not None:
            env.close()
        wandb.finish()

    if improvement:
        print(f"{GREEN}Training completed for model {i}. Best model saved.{RESET}")
    else:
        print(f"{RED}Warning: No model met improvement criteria{RESET}")