In [1]:
import time
import logging
import os
import random
import csv

import numpy as np
import coloredlogs
from FAdo.conversions import *

from utils.data_loader import *
from utils.heuristics import *

from alpha_zero.Coach import Coach
from alpha_zero.MCTS import MCTS
from alpha_zero.utils import *
from alpha_zero.state_elimination.StateEliminationGame import StateEliminationGame as Game
from alpha_zero.state_elimination.pytorch.NNet import NNetWrapper as nn


In [2]:
# Notebook-wide configuration: logger, AlphaZero hyperparameters (`args`) and
# the benchmark grid that test_heuristics() / test_alpha_zero() iterate over.
log = logging.getLogger(__name__)
coloredlogs.install(level='INFO')
# `args` is handed to MCTS(g, nnet, args) in test_alpha_zero(). Most keys are not
# read directly in this notebook; their meaning below is taken from the existing
# comments or marked as an assumption.
args = dotdict({
    # presumably the number of training iterations — TODO confirm against Coach.
    'numIters': 1000,
    # Number of complete self-play games to simulate during a new iteration.
    'numEps': 100,
    'tempThreshold': 0,        # temperature hyperparameter
    # During arena playoff, the new neural net is accepted if this fraction (or more) of games is won.
    'updateThreshold': 0.6,
    # Number of game examples to train the neural networks.
    'maxlenOfQueue': 200000,
    'numMCTSSims': 25,          # Number of game moves for MCTS to simulate.
    # Number of games to play during arena play to determine if the new net will be accepted.
    'arenaCompare': 40,
    # presumably the MCTS exploration constant — TODO confirm against MCTS.
    'cpuct': 1,
    # Folder passed to nnet.load_checkpoint() in test_alpha_zero().
    'checkpoint': './alpha_zero/models/',
    # test_alpha_zero() only runs the benchmark when this is truthy.
    'load_model': True,
    # Only element [1] (the file name) is read in this notebook, as the checkpoint file name.
    'load_folder_file': ('./alpha_zero/models/', 'best.pth.tar'),
    'numItersForTrainExamplesHistory': 20,
})
# Benchmark grid. Index n in the result arrays corresponds to the value
# n + min_n that is passed to Game.getInitBoard() (presumably the number of
# states — TODO confirm).
min_n = 3
max_n = 6
# Number of distinct n values, inclusive of both ends.
n_range = max_n - min_n + 1
# Values indexed by k and d respectively; alphabet[k] and density[d] are
# passed to Game.getInitBoard().
alphabet = [2]
density = [0.2]
# Number of samples read from load_data() for every (n, k, d) cell.
sample_size = 30


In [3]:
def test_heuristics():
    if os.path.isfile('./result/heuristics_experiment_result.pkl'):
        with open('./result/heuristics_experiment_result.pkl', 'rb') as fp:
            exp = load(fp)
            return exp
    else:
        data = load_data()
        exp = [[[[[0, 0] for d in range(len(density))] for k in range(
            len(alphabet))] for n in range(n_range)] for c in range(6)]
        for n in range(n_range):
            for k in range(len(alphabet)):
                for d in range(len(density)):
                    for i in range(sample_size):
                        random.seed(i)
                        print('n' + str(n + min_n) + 'k' + ('2' if not k else ('5' if k == 1 else '10')) + (
                            's' if not d else 'd') + '\'s ' + str(i + 1) + ' sample')
                        # eliminate_randomly
                        gfa = data[n][k][d][i].dup()
                        start_time = time.time()
                        result = eliminate_randomly(gfa)
                        end_time = time.time()
                        result_time = end_time - start_time
                        result_size = result.treeLength()
                        exp[0][n][k][d][0] += result_time
                        exp[0][n][k][d][1] += result_size

                        # decompose with eliminate_randomly
                        gfa = data[n][k][d][i].dup()
                        start_time = time.time()
                        result = decompose(gfa, False, False)
                        end_time = time.time()
                        result_time = end_time - start_time
                        result_size = result.treeLength()
                        exp[1][n][k][d][0] += result_time
                        exp[1][n][k][d][1] += result_size

                        # eliminate_by_state_weight_heuristic
                        gfa = data[n][k][d][i].dup()
                        start_time = time.time()
                        result = eliminate_by_state_weight_heuristic(gfa)
                        end_time = time.time()
                        result_time = end_time - start_time
                        result_size = result.treeLength()
                        exp[2][n][k][d][0] += result_time
                        exp[2][n][k][d][1] += result_size

                        # decompose + eliminate_by_state_weight_heuristic
                        gfa = data[n][k][d][i].dup()
                        start_time = time.time()
                        result = decompose(gfa, True, False)
                        end_time = time.time()
                        result_time = end_time - start_time
                        result_size = result.treeLength()
                        exp[3][n][k][d][0] += result_time
                        exp[3][n][k][d][1] += result_size

                        # eliminate_by_repeated_state_weight_heuristic
                        gfa = data[n][k][d][i].dup()
                        start_time = time.time()
                        result = eliminate_by_repeated_state_weight_heuristic(
                            gfa)
                        end_time = time.time()
                        result_time = end_time - start_time
                        result_size = result.treeLength()
                        exp[4][n][k][d][0] += result_time
                        exp[4][n][k][d][1] += result_size

                        # decompose + eliminate_by_repeated_state_weight_heuristic
                        gfa = data[n][k][d][i].dup()
                        start_time = time.time()
                        result = decompose(gfa, True, True)
                        end_time = time.time()
                        result_time = end_time - start_time
                        result_size = result.treeLength()
                        exp[5][n][k][d][0] += result_time
                        exp[5][n][k][d][1] += result_size
        with open('./result/heuristics_experiment_result.pkl', 'wb') as fp:
            dump(exp, fp)


In [4]:
# Runs the heuristic experiment, or loads
# ./result/heuristics_experiment_result.pkl if that file already exists.
exp = test_heuristics()

n3k2s's 1 sample
n3k2s's 2 sample
n3k2s's 3 sample
n3k2s's 4 sample
n3k2s's 5 sample
n3k2s's 6 sample
n3k2s's 7 sample
n3k2s's 8 sample
n3k2s's 9 sample
n3k2s's 10 sample
n3k2s's 11 sample
n3k2s's 12 sample
n3k2s's 13 sample
n3k2s's 14 sample
n3k2s's 15 sample
n3k2s's 16 sample
n3k2s's 17 sample
n3k2s's 18 sample
n3k2s's 19 sample
n3k2s's 20 sample
n3k2s's 21 sample
n3k2s's 22 sample
n3k2s's 23 sample
n3k2s's 24 sample
n3k2s's 25 sample
n3k2s's 26 sample
n3k2s's 27 sample
n3k2s's 28 sample
n3k2s's 29 sample
n3k2s's 30 sample
n4k2s's 1 sample
n4k2s's 2 sample
n4k2s's 3 sample
n4k2s's 4 sample
n4k2s's 5 sample
n4k2s's 6 sample
n4k2s's 7 sample
n4k2s's 8 sample
n4k2s's 9 sample
n4k2s's 10 sample
n4k2s's 11 sample
n4k2s's 12 sample
n4k2s's 13 sample
n4k2s's 14 sample
n4k2s's 15 sample
n4k2s's 16 sample
n4k2s's 17 sample
n4k2s's 18 sample
n4k2s's 19 sample
n4k2s's 20 sample
n4k2s's 21 sample
n4k2s's 22 sample
n4k2s's 23 sample
n4k2s's 24 sample
n4k2s's 25 sample
n4k2s's 26 sample
n4k2s's 27

In [5]:
import pandas as pd

2023-03-15 19:39:49 ksk numexpr.utils[1335862] INFO Note: NumExpr detected 16 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
2023-03-15 19:39:49 ksk numexpr.utils[1335862] INFO NumExpr defaulting to 8 threads.


In [6]:
def test_alpha_zero():
    if not os.path.isfile('./result/alpha_zero_experiment_result.pkl'):
        with open('./result/alpha_zero_experiment_result.pkl', 'rb') as fp:
            exp = load(fp)
        with open('./result/c7.csv', 'w', newline='') as fp:
            writer = csv.writer(fp)
            for n in range(5 - 3, 11 - 3):
                size_value = exp[n][1][0][1] / 100
                writer.writerow([size_value])
    else:
        data = load_data()
        exp = [[[[0, 0] for d in range(len(density))] for k in range(
            len(alphabet))] for n in range(n_range)]
        g = Game()
        nnet = nn(g)
        mcts = MCTS(g, nnet, args)
        def player(x): return np.argmax(mcts.getActionProb(x, temp=0))
        curPlayer = 1
        if args.load_model:
            nnet.load_checkpoint(args.checkpoint, args.load_folder_file[1])
        else:
            print("Can't test without pre-trained model")
            exit()
        for n in range(n_range):
            for k in range(len(alphabet)):
                for d in range(len(density)):
                    for i in range(sample_size):
                        #print('n' + str(n + min_n) + 'k' + ('2' if not k else ('5' if k == 1 else '10')) + (
                        #    's' if not d else 'd') + '\'s ' + str(i + 1) + ' sample')
                        gfa = data[n][k][d][i].dup()
                        board = g.getInitBoard(
                            gfa, n + min_n, alphabet[k], density[d])
                        order = []
                        start_time = time.time()
                        while g.getGameEnded(board, curPlayer) == -1:
                            action = player(
                                g.getCanonicalForm(board, curPlayer))
                            valids = g.getValidMoves(
                                g.getCanonicalForm(board, curPlayer), 1)
                            if valids[action] == 0:
                                assert valids[action] > 0
                            board, curPlayer = g.getNextState(
                                board, curPlayer, action)
                            order.append(action)
                            
                        result = g.gfaToBoard(board)[0][n + min_n + 1].treeLength()
                        end_time = time.time()
                        gfa.eliminateAll(order)
                        '''
                        if (result != gfa.delta[0][n + min_n + 1].treeLength()):
                            print('order', order)
                            print('result length', result)
                            print('valid length',
                                  gfa.delta[0][n + min_n + 1].treeLength())
                            print('Something is wrong')
                            exit()
                        '''
                        result_time = end_time - start_time
                        exp[n][k][d][0] += result_time
                        exp[n][k][d][1] += result
        with open('./result/alpha_zero_experiment_result.pkl', 'wb') as fp:
            dump(exp, fp)

In [7]:
# Runs the AlphaZero experiment; its results are exchanged through files under
# ./result/ (the call's return value is not used here).
test_alpha_zero()

In [None]:
# Read both experiment results back from the files that test_alpha_zero() and
# test_heuristics() write, so the comparison below works from disk.
# NOTE(review): `load` comes from a star import (presumably pickle.load) —
# only open trusted files.
with open('./result/alpha_zero_experiment_result.pkl', 'rb') as fp:
    exp_alpha = load(fp)

with open('./result/heuristics_experiment_result.pkl', 'rb') as fp:
    exp_heuristic = load(fp)

In [None]:
# Side-by-side [total_time, total_size] pairs for the first four n values at
# k index 0 and d index 0: heuristic strategy index 5 (decompose(gfa, True, True)
# in test_heuristics) first, then the AlphaZero agent. Values are sums over the
# samples, not averages.
np.array(exp_heuristic[5][:4])[:, 0, 0].tolist(), np.array(exp_alpha[:4])[:, 0, 0].tolist()


([[0.016735553741455078, 372.0],
  [0.018989086151123047, 752.0],
  [0.026637554168701172, 1839.0],
  [0.0336604118347168, 3093.0]],
 [[1.4725992679595947, 217.0],
  [4.289000034332275, 587.0],
  [13.468000173568726, 2179.0],
  [35.0204656124115, 5617.0]])