# Detect environment

In [None]:
import os

# The COLAB_GPU environment variable is used as the marker for a Colab runtime; when it is
# absent we assume a local kernel.
running_in_colab = "COLAB_GPU" in os.environ
print(
    "Detected as running on Colab, setting up accordingly"
    if running_in_colab
    else "Detected as running locally, setting up accordingly"
)


# Set up on Colab

In [None]:
if running_in_colab:

    from google.colab import drive
    from pathlib import Path

    # Only setup the environment on the first run
    try:
        runtime_has_been_setup
    
    except NameError:

        # Checkout code and setup Python runtime
        !git clone https://github.com/chris838/alpha-zero-general.git
        %cd '/content/alpha-zero-general'
        !pip install coloredlogs ipympl  # numpy imgaug folium torchvision tqdm scipy tqdm pandas scikit-learn

        # Mount google drive for loading/storing progress
        drive.mount('/content/gdrive')

        # Define load and checkpoint paths        
        checkpoint_folder = Path('/content/gdrive/MyDrive/colab/alpha-zero-general/temp/checkpoints')

        runtime_has_been_setup = True

# Set up locally

In [None]:
if not running_in_colab:
    from pathlib import Path

    # Locally, checkpoints go in a folder relative to the kernel's working directory.
    checkpoint_folder = Path("temp", "checkpoints")


# Imports

In [None]:
from datetime import datetime
import shutil
import logging
import coloredlogs
from Coach import Coach
from utils import dotdict
from santorini.keras.NNet import NNetWrapper
from santorini.SantoriniGame import SantoriniGame

In [None]:
import Arena
from MCTS import MCTS

from santorini.SantoriniPlayers import (
    RandomPlayer,
    HumanSantoriniPlayer,
    GreedySantoriniPlayer,
)

import numpy as np
from utils import *

# Initialise & restore progress

In [None]:
log = logging.getLogger(__name__)
coloredlogs.install(level="INFO")  # Change this to DEBUG to see more info.

# Argument passed to SantoriniGame — presumably the board side length (Santorini is played on
# a 5x5 board); confirm against SantoriniGame.
BOARD_SIZE = 5

checkpoint_folder.mkdir(parents=True, exist_ok=True)

# Training hyper-parameters handed to Coach. backup_and_resume_progress later adds
# "load_model" / "load_folder_file" to this same dict before resuming.
args = dotdict(
    {
        "numIters": 5,  # Presumably: training iterations per coach.learn() call — see Coach.learn.
        "numEps": 200,  # Number of complete self-play games to simulate during a new iteration.
        "arenaCompare": 30,  # Number of games to play during arena play to determine if new net will be accepted.
        "tempThreshold": 15,  # Presumably: number of self-play moves sampled with exploratory temperature before switching to greedy play (temp=0) — confirm in Coach.
        "updateThreshold": 0.6,  # During arena playoff, new neural net will be accepted if threshold or more of games are won.
        "maxlenOfQueue": 200000,  # Number of game examples to train the neural networks.
        "numMCTSSims": 25,  # Number of MCTS simulations to run (presumably per move).
        "cpuct": 1,  # MCTS exploration constant (presumably the c_puct term of the PUCT formula).
        "checkpoint": str(checkpoint_folder),  # Folder Coach presumably writes weights (*.h5) and examples (*.examples) to.
        "numItersForTrainExamplesHistory": 20,  # Presumably: how many iterations' worth of examples are kept for training.
    }
)

game = SantoriniGame(BOARD_SIZE)
nnet = NNetWrapper(game)
coach = Coach(game, nnet, args)

In [None]:
def _examples_sort_key(examples_path):
    """Sort key that ranks '*.examples' files by the iteration number N in 'checkpoint_<N>.…' names.

    Returns ``(N, file name)``. Names without a trailing ``_<digits>`` before their first dot get
    N = -1, so they rank below numbered files and are ordered among themselves by name.
    """
    # e.g. 'checkpoint_10' from 'checkpoint_10.pth.tar.examples'
    prefix = examples_path.name.split(".", 1)[0]
    digits = prefix.rpartition("_")[2]
    return (int(digits) if digits.isdigit() else -1, examples_path.name)


def backup_and_resume_progress():
    """Resume from ``checkpoint_folder`` if it holds model weights, then archive it and start afresh.

    If the folder contains any ``*.h5`` file:
      * loads weights into ``nnet`` — ``best.h5`` when present, otherwise ``temp.h5``;
      * sets ``args.load_model`` / ``args.load_folder_file`` to the highest-numbered
        ``*.examples`` file and calls ``coach.loadTrainExamples()``;
      * moves the whole folder to ``<name>_<timestamp>`` beside it and recreates it empty, so the
        next ``coach.learn()`` writes into a clean directory.
    Otherwise it only sets ``args.load_model = False``.

    Reads and mutates the notebook globals ``checkpoint_folder``, ``nnet``, ``args`` and ``coach``.

    Raises:
        FileNotFoundError: weights exist but there is no ``*.examples`` file, or neither
            ``best.h5`` nor ``temp.h5`` is present. Nothing is loaded or moved in that case.
    """
    # Check if there are any checkpoints to load
    if not any(checkpoint_folder.glob("*.h5")):
        args["load_model"] = False
        print("No progress to load, training from scratch")
        return

    print("Resuming training from saved model and samples")

    # Find the latest examples by numeric iteration: a plain string sort would rank
    # 'checkpoint_10' before 'checkpoint_9'.
    examples_files = list(checkpoint_folder.glob("*.examples"))
    if not examples_files:
        raise FileNotFoundError(f"Model weights found but no *.examples file in {checkpoint_folder}")
    last_checkpoint_name = max(examples_files, key=_examples_sort_key).stem

    # Pick the model: prefer 'best' and fall back to 'temp'.
    # NOTE(review): assumes Coach writes 'best' only when it accepts a new net, while 'temp' is a
    # pre-training snapshot that can be older than an accepted net — confirm against Coach.learn.
    # .pth.tar substituted for h5 in the load method
    model_file = next(
        (
            candidate
            for candidate in ("best.pth.tar", "temp.pth.tar")
            if (checkpoint_folder / candidate.replace(".pth.tar", ".h5")).exists()
        ),
        None,
    )
    if model_file is None:
        raise FileNotFoundError(f"Neither best.h5 nor temp.h5 found in {checkpoint_folder}")
    model_file_actual = model_file.replace(".pth.tar", ".h5")

    # 'examples' suffix appended in the load method
    checkpoint_file = last_checkpoint_name
    checkpoint_file_actual = f"{checkpoint_file}.examples"

    # Load model weights
    print(f"Loading model weights from {model_file_actual}")
    nnet.load_checkpoint(folder=str(checkpoint_folder), filename=model_file)

    # Load training examples
    print(f"Loading training samples from {checkpoint_file_actual}")
    args["load_model"] = True
    args["load_folder_file"] = (str(checkpoint_folder), checkpoint_file)
    coach.loadTrainExamples()

    # Clean the entire checkpoint directory, by archiving the previous one and recreating
    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    shutil.move(checkpoint_folder, checkpoint_folder.parent / f"{checkpoint_folder.name}_{timestamp}")
    checkpoint_folder.mkdir(parents=True, exist_ok=True)

In [None]:
def assess_performance(num_trials=200, num_mcts_sims=25):
    """Play the current network against the greedy Santorini bot and print the score.

    The network player picks the ``argmax`` of the MCTS action probabilities computed at ``temp=0``.

    Args:
        num_trials: number of games passed to ``Arena.playGames``.
        num_mcts_sims: ``numMCTSSims`` for the network player's MCTS (default 25, the value that
            used to be hard-coded).

    Returns:
        The tuple returned by ``Arena.playGames``: network wins, greedy wins, and a third count
        (presumably draws).
    """
    # Test performance against greedy bot
    inference_args = dotdict({"numMCTSSims": num_mcts_sims, "cpuct": 1.0})
    inference_mcts = MCTS(game, nnet, inference_args)

    def alphago_player(board):
        return np.argmax(inference_mcts.getActionProb(board, temp=0))

    greedy_player = GreedySantoriniPlayer(game).play
    arena = Arena.Arena(alphago_player, greedy_player, game, display=game.display)
    one_won, two_won, third = arena.playGames(num_trials, verbose=False)

    # '\n' starts the summary on a fresh line; the previous '\A' was an invalid escape that
    # printed a literal backslash-A.
    print("\nAlphaGo won {} games, Greedy Player won {} games".format(one_won, two_won))
    return one_won, two_won, third

# Training

In [None]:
while True:

    backup_and_resume_progress()

    # Learn
    %time coach.learn()

    assess_performance()