In [142]:
import time
import logging
import os
import random
import csv

import numpy as np
import coloredlogs
from FAdo.conversions import *

from utils.data_loader import *
from utils.heuristics import *

from alpha_zero.Coach import Coach
from alpha_zero.MCTS import MCTS
from alpha_zero.utils import *
from alpha_zero.state_elimination.StateEliminationGame import StateEliminationGame as Game
from alpha_zero.state_elimination.pytorch.NNet import NNetWrapper as nn


In [143]:
log = logging.getLogger(__name__)
coloredlogs.install(level='INFO')

# Settings bundle handed to MCTS and consulted when loading the checkpoint.
args = dotdict({
    'numIters': 1000,
    'numEps': 100,              # Complete self-play games simulated per iteration.
    'tempThreshold': 0,         # Temperature hyperparameter.
    'updateThreshold': 0.6,     # Arena playoff: accept the new net at this win fraction or above.
    'maxlenOfQueue': 200000,    # Number of game examples used to train the networks.
    'numMCTSSims': 25,          # Number of game moves for MCTS to simulate.
    'arenaCompare': 40,         # Arena games played to decide whether the new net is accepted.
    'cpuct': 1,
    'checkpoint': './alpha_zero/models/',
    'load_model': True,
    'load_folder_file': ('./alpha_zero/models/', 'best.pth.tar'),
    'numItersForTrainExamplesHistory': 20,
})

# Experiment grid: sizes min_n..max_n (inclusive), indexed 0..n_range-1,
# crossed with every entry of `alphabet` and `density`.
min_n, max_n = 3, 8
n_range = (max_n + 1) - min_n
alphabet, density = [2], [0.2]
# Number of samples evaluated per (size, alphabet, density) cell.
sample_size = 30


In [144]:
def test_heuristics(recompute=False):
    """Time six elimination strategies on the sample data and cache the totals.

    The result is stored in './result/heuristics_experiment_result.pkl'. When
    that file already exists it is loaded and returned instead of rerunning
    the experiment, unless ``recompute`` is true.

    Args:
        recompute: When True, ignore any cached file, rerun the experiment
            and overwrite the cache.

    Returns:
        A nested list ``exp`` where ``exp[c][n][k][d]`` is
        ``[total_time, total_size]`` for strategy index ``c``, size index
        ``n`` (actual size ``n + min_n``), alphabet index ``k`` and density
        index ``d``. Both totals are sums over ``sample_size`` samples:
        elapsed ``time.time()`` seconds and ``treeLength()`` of each result.
    """
    result_path = './result/heuristics_experiment_result.pkl'

    # Reuse the cache only when it is actually there.
    # NOTE(review): `load`/`dump` are presumably pickle's, pulled in by one of
    # the star imports -- only load result files you produced yourself.
    if not recompute and os.path.isfile(result_path):
        with open(result_path, 'rb') as fp:
            return load(fp)

    # Strategy order fixes the first index of `exp`. It also fixes the order
    # in which the strategies consume the per-sample seeded `random` stream.
    strategies = (
        eliminate_randomly,
        lambda gfa: decompose(gfa, False, False),
        eliminate_by_state_weight_heuristic,
        lambda gfa: decompose(gfa, True, False),
        eliminate_by_repeated_state_weight_heuristic,
        lambda gfa: decompose(gfa, True, True),
    )

    data = load_data()
    exp = [[[[[0, 0] for d in range(len(density))] for k in range(
        len(alphabet))] for n in range(n_range)] for c in range(len(strategies))]
    for n in range(n_range):
        for k in range(len(alphabet)):
            for d in range(len(density)):
                for i in range(sample_size):
                    # Seed once per sample so reruns see the same random choices.
                    random.seed(i)
                    # NOTE(review): the label hard-codes alphabet sizes 2/5/10
                    # and an s/d density tag rather than reading `alphabet`
                    # and `density`.
                    print('n' + str(n + min_n) + 'k' + ('2' if not k else ('5' if k == 1 else '10')) + (
                        's' if not d else 'd') + '\'s ' + str(i + 1) + ' sample')
                    for c, run in enumerate(strategies):
                        # Each strategy gets its own copy of the sample.
                        gfa = data[n][k][d][i].dup()
                        start_time = time.time()
                        result = run(gfa)
                        end_time = time.time()
                        exp[c][n][k][d][0] += end_time - start_time
                        # Measuring the result is deliberately outside the timed span.
                        exp[c][n][k][d][1] += result.treeLength()

    # Make sure the target directory exists so a long run is not lost on save.
    os.makedirs(os.path.dirname(result_path), exist_ok=True)
    with open(result_path, 'wb') as fp:
        dump(exp, fp)
    return exp


In [145]:
# exp = test_heuristics()

In [146]:
import pandas as pd

In [147]:
def _export_c7_csv(exp):
    """Write one column of cached sizes to './result/c7.csv'.

    Rows come from size indices 2..7 at alphabet index 1 and density index 0,
    each total size divided by 100.
    NOTE(review): these indices and the divisor are hard-coded; they look like
    they belong to a run with a larger grid and 100 samples -- confirm.
    """
    try:
        # Build every row before touching the file so a shape mismatch cannot
        # leave a truncated CSV behind.
        rows = [[exp[n][1][0][1] / 100] for n in range(5 - 3, 11 - 3)]
    except IndexError:
        print('Cached result does not cover the indices needed for c7.csv; '
              'skipping CSV export')
        return
    with open('./result/c7.csv', 'w', newline='') as fp:
        csv.writer(fp).writerows(rows)


def test_alpha_zero(recompute=False):
    """Run the trained AlphaZero player on the sample data and cache the totals.

    The result is stored in './result/alpha_zero_experiment_result.pkl'. When
    that file already exists it is loaded, the c7.csv export is attempted and
    the cached result is returned, unless ``recompute`` is true.

    Args:
        recompute: When True, ignore any cached file, rerun the experiment
            and overwrite the cache.

    Returns:
        A nested list ``exp`` where ``exp[n][k][d]`` is
        ``[total_time, total_size]`` for size index ``n`` (actual size
        ``n + min_n``), alphabet index ``k`` and density index ``d``, summed
        over ``sample_size`` samples.
    """
    result_path = './result/alpha_zero_experiment_result.pkl'

    # Reuse the cache only when it is actually there.
    # NOTE(review): `load`/`dump` are presumably pickle's, pulled in by one of
    # the star imports -- only load result files you produced yourself.
    if not recompute and os.path.isfile(result_path):
        with open(result_path, 'rb') as fp:
            exp = load(fp)
        _export_c7_csv(exp)
        return exp

    data = load_data()
    exp = [[[[0, 0] for d in range(len(density))] for k in range(
        len(alphabet))] for n in range(n_range)]
    g = Game()
    nnet = nn(g)
    mcts = MCTS(g, nnet, args)

    def player(x):
        # temp=0 is passed to getActionProb; argmax picks the top-rated action.
        return np.argmax(mcts.getActionProb(x, temp=0))

    # NOTE(review): curPlayer is set once and never reset between samples;
    # whatever getNextState returns carries over -- confirm this is intended.
    curPlayer = 1
    if args.load_model:
        nnet.load_checkpoint(args.checkpoint, args.load_folder_file[1])
    else:
        print("Can't test without pre-trained model")
        exit()

    for n in range(n_range):
        for k in range(len(alphabet)):
            for d in range(len(density)):
                for i in range(sample_size):
                    gfa = data[n][k][d][i].dup()
                    board = g.getInitBoard(
                        gfa, n + min_n, alphabet[k], density[d])
                    order = []
                    start_time = time.time()
                    # -1 from getGameEnded means the game is still running.
                    while g.getGameEnded(board, curPlayer) == -1:
                        action = player(
                            g.getCanonicalForm(board, curPlayer))
                        valids = g.getValidMoves(
                            g.getCanonicalForm(board, curPlayer), 1)
                        assert valids[action] != 0, 'player chose an invalid action'
                        board, curPlayer = g.getNextState(
                            board, curPlayer, action)
                        order.append(action)

                    # Measuring the final expression is inside the timed span.
                    result = g.gfaToBoard(board)[0][n + min_n + 1].treeLength()
                    end_time = time.time()
                    # Replays the chosen order on the copy; only the disabled
                    # cross-check below consumes the outcome.
                    gfa.eliminateAll(order)
                    # Disabled cross-check:
                    # if (result != gfa.delta[0][n + min_n + 1].treeLength()):
                    #     print('order', order)
                    #     print('result length', result)
                    #     print('valid length',
                    #           gfa.delta[0][n + min_n + 1].treeLength())
                    #     print('Something is wrong')
                    #     exit()
                    exp[n][k][d][0] += end_time - start_time
                    exp[n][k][d][1] += result

    # Make sure the target directory exists so a long run is not lost on save.
    os.makedirs(os.path.dirname(result_path), exist_ok=True)
    with open(result_path, 'wb') as fp:
        dump(exp, fp)
    return exp

In [148]:
# Run the AlphaZero experiment; it writes ./result/alpha_zero_experiment_result.pkl,
# which the next cell reads back.
test_alpha_zero()

In [None]:
# Read both cached experiment results back in for inspection.
# NOTE(review): `load` is presumably pickle's loader from a star import --
# only open result files you produced yourself.
with open('./result/alpha_zero_experiment_result.pkl', 'rb') as alpha_fp:
    exp_alpha = load(alpha_fp)
with open('./result/heuristics_experiment_result.pkl', 'rb') as heuristic_fp:
    exp_heuristic = load(heuristic_fp)

In [None]:
# Strategy index 5 is the decompose(gfa, True, True) run in test_heuristics.
# Take the first six size indices at alphabet index 0 / density index 0; each
# row is the accumulated [time, size] pair. Going through np.array turns both
# columns into floats.
np.array(exp_heuristic[5][:6])[:, 0, 0].tolist()

[[0.015770673751831055, 372.0],
 [0.02125382423400879, 752.0],
 [0.025895357131958008, 1839.0],
 [0.0336301326751709, 3093.0],
 [0.05150318145751953, 8606.0],
 [0.08024716377258301, 19352.0]]

In [None]:
# Accumulated [time, size] per size index for the AlphaZero player.
exp_alpha

[[[[0.546905517578125, 74]]],
 [[[1.4073143005371094, 177]]],
 [[[3.773301839828491, 595]]],
 [[[9.00179934501648, 1385]]],
 [[[28.578489065170288, 5166]]],
 [[[120.01312184333801, 21606]]]]

In [105]:
# NOTE(review): duplicate of the cell above; its execution count (105) is out
# of sequence and the output has four rows rather than six, so it looks like
# a leftover from an earlier run -- confirm and remove.
exp_alpha


[[[[0.5762872695922852, 76]]],
 [[[1.3665142059326172, 173]]],
 [[[3.897165060043335, 595]]],
 [[[10.418294429779053, 1636]]]]

In [22]:
# NOTE(review): execution count (22) is out of sequence, and the output has
# three alphabet entries and two density entries per size, which does not match
# the single-entry `alphabet`/`density` configured above -- likely from an
# earlier configuration; confirm.
exp_heuristic[5]

[[[[0.015770673751831055, 372], [0.020870208740234375, 1408]],
  [[0.025353431701660156, 1350], [0.04005026817321777, 3863]],
  [[0.03676748275756836, 3326], [0.08379054069519043, 8032]]],
 [[[0.02125382423400879, 752], [0.044689178466796875, 5316]],
  [[0.04623866081237793, 4206], [0.08870840072631836, 17065]],
  [[0.0786600112915039, 12673], [0.2022247314453125, 38010]]],
 [[[0.025895357131958008, 1839], [0.07112789154052734, 18355]],
  [[0.06750774383544922, 13243], [1.5291893482208252, 68571]],
  [[0.1848142147064209, 47462], [1.8073883056640625, 136901]]],
 [[[0.0336301326751709, 3093], [0.16398978233337402, 60595]],
  [[0.13657927513122559, 40205], [3.4006781578063965, 293525]],
  [[0.4982781410217285, 152565], [4.538684368133545, 639854]]],
 [[[0.05150318145751953, 8606], [3.042722463607788, 217070]],
  [[2.996006965637207, 139377], [12.025941610336304, 943944]],
  [[7.127958536148071, 632047], [24.530405282974243, 2350448]]],
 [[[0.08024716377258301, 19352], [5.416334390640259,