# Setup on Colab

In [None]:
# Fetch the project sources into the current working directory.
!git clone https://github.com/chris838/alpha-zero-general.git

In [None]:
# Work from inside the cloned repository so that the relative paths used by
# later cells (e.g. docker/requirements.txt) resolve.
%cd '/content/alpha-zero-general'

In [None]:
# Install the dependencies listed in the repository's requirements file.
!pip install -r docker/requirements.txt

In [None]:
# Upgrade numpy only: everything after the inline '#' is a shell comment, so
# the remaining package names on this line are NOT installed.
!pip install --upgrade numpy # imgaug folium torchvision tqdm scipy tqdm pandas scikit-learn

In [None]:
# Mount Google Drive at /content/gdrive; the Colab checkpoint folders defined
# in the next cell live underneath this mount point.
from google.colab import drive
drive.mount('/content/gdrive')

In [None]:
# Colab only: keep working checkpoints and the backed-up models on the mounted
# Google Drive.
# `Path` is imported here because this cell runs before the notebook's main
# import cell; without it a fresh kernel fails with NameError.
from pathlib import Path

checkpoint_folder = Path('/content/gdrive/MyDrive/colab/alpha-zero-general/temp')
load_folder = Path('/content/gdrive/MyDrive/colab/alpha-zero-general/pretrained_models/santorini/keras/5x5')

# Setup locally

In [None]:
# Local runs only: use folders relative to the repository root.
# NOTE: this cell rebinds the same two names as the Colab cell above, so skip
# it when running on Colab.
# `Path` is imported here because this cell runs before the notebook's main
# import cell; without it a fresh kernel fails with NameError.
from pathlib import Path

checkpoint_folder = Path('temp')
load_folder = Path('pretrained_models/santorini/keras/5x5')

# Train AlphaZero

In [None]:
from datetime import datetime
from pathlib import Path
import shutil
import logging
import coloredlogs
from Coach import Coach
from utils import dotdict
from santorini.keras.NNet import NNetWrapper
from santorini.SantoriniGame import SantoriniGame

In [None]:
import Arena
from MCTS import MCTS

from santorini.SantoriniPlayers import (
    RandomPlayer,
    HumanSantoriniPlayer,
    GreedySantoriniPlayer,
)

import numpy as np
from utils import *

In [None]:
# Logger for this notebook's own messages; coloredlogs.install() configures
# log output at the given level.
log = logging.getLogger(__name__)
coloredlogs.install(level='INFO')  # Change this to DEBUG to see more info.

In [None]:
# Make sure both working folders exist (creating missing parents as needed);
# folders that are already present are left untouched.
for folder in (checkpoint_folder, load_folder):
    folder.mkdir(parents=True, exist_ok=True)

In [None]:
# Training configuration handed to Coach. A later cell lowers numIters, numEps
# and arenaCompare so the notebook can run end to end quickly.
args = dotdict({
    'numIters': 1000,           # NOTE(review): presumably the number of training iterations - confirm in Coach.
    'numEps': 100,              # Number of complete self-play games to simulate during a new iteration.
    'tempThreshold': 15,        # NOTE(review): presumably the number of opening moves played with exploratory temperature before move selection becomes greedy - confirm in Coach.
    'updateThreshold': 0.6,     # During arena playoff, new neural net will be accepted if threshold or more of games are won.
    'maxlenOfQueue': 200000,    # Number of game examples to train the neural networks.
    'numMCTSSims': 25,          # Number of game moves for MCTS to simulate.
    'arenaCompare': 40,         # Number of games to play during arena play to determine if new net will be accepted.
    'cpuct': 1,                 # NOTE(review): presumably the exploration constant used by MCTS - confirm in MCTS.
    'checkpoint': str(checkpoint_folder),  # Checkpoint folder path, passed as a plain string.
    'numItersForTrainExamplesHistory': 20,  # NOTE(review): presumably how many iterations of self-play examples are retained - confirm in Coach.
})

In [None]:
# Build the game, the network wrapper for it, and the Coach that ties them to
# the training configuration.
game = SantoriniGame(5)
nnet = NNetWrapper(game)

# Set very low iterations to let this notebook run in its entirety.
# In reality, training a model, even for a game as small as this Santorini setup, can take several hours or days.
args['numIters'] = 10
args['numEps'] = 10
args['arenaCompare'] = 3

# Coach receives the already-overridden args object.
coach = Coach(game, nnet, args)

In [None]:
# Check if there are any checkpoints to load. Resuming needs BOTH the model
# weights (*.h5) and the training samples (*.examples); indexing the samples
# list while only checking for weights would raise IndexError when the folder
# holds a model without samples.
saved_models = sorted(load_folder.glob('*.h5'))
saved_examples = sorted(load_folder.glob('*.examples'))

if saved_models and saved_examples:

    print("Resuming training from saved model and samples")

    # Find the latest model/examples. The backup cell in this notebook names
    # its files by timestamp, so the last name in sorted order is the newest.
    last_checkpoint_name = saved_examples[-1].stem
    model_file = last_checkpoint_name  # .pth.tar substituted for h5 in the load method
    checkpoint_file = last_checkpoint_name  # 'examples' suffix appended in the load method

    # Load model weights
    print(f"Loading model weights from {model_file.replace('.pth.tar', '.h5')}")
    nnet.load_checkpoint(folder=str(load_folder), filename=model_file)

    # Load training examples
    print(f"Loading training samples from {checkpoint_file}.examples")
    args['load_model'] = True
    args['load_folder_file'] = (str(load_folder), checkpoint_file)
    coach.loadTrainExamples()

else:
    args['load_model'] = False
    if saved_models or saved_examples:
        # Only one half of a checkpoint is present; say so instead of crashing.
        print("Found model weights or training samples, but not both; ignoring the incomplete checkpoint")
    print("No progress to load, training from scratch")

In [None]:
# Learn: run coach.learn() with the configuration set up above.
# %time reports how long the call took.
%time coach.learn()

In [None]:
# Save progress by backing up the latest checkpoint, to be loaded on next run
example_files = list(checkpoint_folder.glob('*.examples'))
best_model = checkpoint_folder / 'best.h5'

# Validate both inputs before copying anything, so a failed run never leaves a
# samples file in load_folder without its matching model weights.
if not example_files:
    raise FileNotFoundError(f"No '*.examples' files found in {checkpoint_folder}")
if not best_model.exists():
    raise FileNotFoundError(f"No best model found at {best_model}")

# Pick the most recently written samples file. Sorting the names as text is
# unreliable once numeric suffixes reach two digits ('..._10' sorts before
# '..._9'), so use the modification time instead.
latest_samples = max(example_files, key=lambda path: path.stat().st_mtime)

timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')

# Both copies share one timestamp so the resume cell can pair them up.
shutil.copy(latest_samples, load_folder / f'{timestamp}.pth.tar.examples')
shutil.copy(best_model, load_folder / f'{timestamp}.h5')

In [None]:
# Test performance against random bots

%matplotlib widget

game = SantoriniGame(5)

random_player = RandomPlayer(game).play
greedy_player = GreedySantoriniPlayer(game).play

args = dotdict({'numMCTSSims': 25, 'cpuct': 1.0})
mcts = MCTS(game, nnet, args)
alphago_player = lambda x: np.argmax(mcts.getActionProb(x, temp=0))

arena = Arena.Arena(alphago_player, greedy_player, game, display=game.display_3d)

%time oneWon, twoWon, draws = arena.playGames(5, verbose=False)
print("\AlphaGo won {} games, Greedy Player won {} games".format(oneWon, twoWon))