In [1]:
# Raw game files. The *_SIZE constants are used as line counts when the files
# are split into batches of BATCH_SIZE lines (TEST_SIZE is presumably the same
# kind of count for the second file -- TODO confirm).
TRAIN_DATA_PATH = './../data/train-1.renju'
TEST_DATA_PATH = './../data/train-2.renju'
TRAIN_SIZE = 1984695
TEST_SIZE = 147230
BATCH_SIZE = 1000

# Letters used in the move notation, in index order, and the reverse lookup
# from a letter to its zero-based index.
POS_TO_LETTER = 'abcdefghijklmno'
LETTER_TO_POS = {letter: index for index, letter in enumerate(POS_TO_LETTER)}

In [2]:
def parse_data_file(data_path, ind):
    """Read one batch of games from a game-record text file.

    Lines are counted from 0; batch ``ind`` covers lines
    ``ind * BATCH_SIZE`` up to, but not including, ``(ind + 1) * BATCH_SIZE``.
    Every line is split on whitespace. The first token is kept verbatim as the
    game's label and the remaining tokens are parsed as moves written as one
    letter followed by a 1-based number (for example ``h8``).

    Progress is printed to stdout: the path, then a ``>(line_number)`` marker
    for every hundredth line that belongs to the batch.

    Args:
        data_path: Path of the text file to read.
        ind: Zero-based index of the batch to extract.

    Returns:
        A list with one ``(label, moves)`` tuple per non-blank line of the
        batch. ``moves`` is a list of ``(number - 1, LETTER_TO_POS[letter])``
        tuples. Tokens that cannot be parsed, or whose number falls outside
        the board, are dropped.
    """
    # The board is assumed to be square, with as many rows as there are
    # column letters.
    board_size = len(LETTER_TO_POS)

    def convert_move_to_pos(move):
        """Turn a token such as ``'h8'`` into ``(7, 7)``; None if invalid."""
        move = move.strip()
        try:
            row = int(move[1:]) - 1
            col = LETTER_TO_POS[move[0]]
        except (IndexError, ValueError, KeyError):
            # Empty token, non-numeric row, or unknown column letter.
            return None
        if not 0 <= row < board_size:
            # Row 0 would otherwise become index -1 and silently wrap around
            # when used to index an array; rows past the edge would overflow.
            return None
        return row, col

    def normalize_game_info(game_info):
        """Split the tokens of one line into ``(label, parsed moves)``."""
        positions = (convert_move_to_pos(move) for move in game_info[1:])
        return game_info[0], [pos for pos in positions if pos is not None]

    first_line = ind * BATCH_SIZE
    end_line = first_line + BATCH_SIZE

    raw_data = []
    print(data_path)
    with open(data_path) as ff:
        for cnt, line in enumerate(ff):
            if cnt < first_line:
                continue
            if cnt >= end_line:
                # Nothing after the batch is needed; stop instead of reading
                # the remainder of the file.
                break
            raw_data.append(line.strip('\n'))
            if cnt % 100 == 0:
                print('>({})'.format(str(cnt)), end='')
    print()

    games = []
    for raw_line in raw_data:
        tokens = raw_line.split()
        if tokens:  # a blank line carries no label and no moves
            games.append(normalize_game_info(tokens))
    return games

In [3]:
from board import Board
from config import WHITE, BLACK, change_color

import numpy as np
from copy import deepcopy
from random import shuffle
import json
import os

In [4]:
# Build the list of samples used for training or testing.
def generate_data(games):
    """Expand parsed games into shuffled training samples.

    Each game is replayed on a fresh ``Board``. For every move that is on the
    board and accepted by ``Board.execute_move``, the position *before* the
    move is emitted four times, rotated by 0, 90, 180 and 270 degrees with
    ``np.rot90``. A game is cut short at its first off-board or rejected move.

    Progress is printed to stdout: one ``-`` per hundred games processed.

    Args:
        games: Iterable of ``(winner, moves)`` pairs. ``winner`` is compared
            against ``'black'`` and ``'white'``; ``moves`` is a sequence of
            ``(row, col)`` index tuples.

    Returns:
        A list, in random order, of ``(board, label, value)`` tuples:
          * ``board`` -- ``Board.get_board()`` converted to nested lists
            (assumed to be a 15x15 array -- TODO confirm against ``Board``);
          * ``label`` -- flat list of 225 floats, 1.0 at the move played;
          * ``value`` -- one-element list holding an ``np.float64``. It is
            +1 for ``'black'`` and -1 for ``'white'`` on the first move, 0
            for any other winner label, and changes sign after every move.
    """
    board_size = 15

    def generate_boards(game_info):
        """Return the samples produced by replaying a single game."""
        winner, game = game_info[0], game_info[1]
        samples = []
        board = Board()
        color = BLACK
        is_win_state = {'black': 1, 'white': -1}.get(winner, 0)

        for move in game:
            if not all(0 <= coord < board_size for coord in move):
                # An off-board coordinate would raise on the label array, or
                # wrap around silently when negative; treat it like an illegal
                # move and stop replaying this game.
                break

            # Snapshot the position before the move is applied.
            state = np.array(board.get_board())
            try:
                board.execute_move(move, color)
            except Exception:
                # The move was rejected; the exception type raised by Board is
                # not visible here, so every ordinary error ends the game.
                # KeyboardInterrupt/SystemExit are deliberately not caught.
                break

            label = np.zeros((board_size, board_size))
            label[move] = 1
            for quarter_turns in range(4):
                samples.append((
                    np.rot90(state, quarter_turns).tolist(),
                    np.rot90(label, quarter_turns).ravel().tolist(),
                    [np.float64(is_win_state)]
                ))

            color = change_color(color)
            is_win_state = -is_win_state
        return samples

    data = []
    for cnt, game in enumerate(games, 1):
        data.extend(generate_boards(game))
        if cnt % 100 == 0:
            print('-', end='')
    print()
    shuffle(data)
    return data

In [None]:
# Convert the training games into JSON files, one file per batch of
# BATCH_SIZE games. A batch whose output file already exists is skipped, so
# the cell can be re-run to resume an interrupted export.
FIRST_BATCH = 147  # batches below this index are never generated by this cell
BATCH_FILE_TEMPLATE = './../renju_nn_train_data/renju_nn_train_data_{}.json'

for ind in range(FIRST_BATCH, (TRAIN_SIZE + BATCH_SIZE - 1) // BATCH_SIZE):
    batch_path = BATCH_FILE_TEMPLATE.format(str(ind))
    if os.path.isfile(batch_path):
        continue
    print('STAGE: parse data file {}'.format(ind))
    train_game_data = parse_data_file(TRAIN_DATA_PATH, ind)
    print('STAGE: generate data')
    train_data = generate_data(train_game_data)

    # Write to a temporary name and rename once the dump has finished.
    # Writing in place would leave a truncated file behind if the run were
    # interrupted, and the existence check above would then skip that batch
    # on every later run.
    os.makedirs(os.path.dirname(batch_path), exist_ok=True)
    tmp_path = batch_path + '.tmp'
    with open(tmp_path, 'w') as f:
        json.dump(train_data, f)
    os.replace(tmp_path, batch_path)

    print('_____________________________________________________________________')
    print('Loaded batch {}'.format(str(ind)))

STAGE: parse data file 308
./../data/train-1.renju
>(308000)>(308100)>(308200)>(308300)>(308400)>(308500)>(308600)>(308700)>(308800)>(308900)
STAGE: generate data
----------
_____________________________________________________________________
Loaded batch 308
STAGE: parse data file 309
./../data/train-1.renju
>(309000)>(309100)>(309200)>(309300)>(309400)>(309500)>(309600)>(309700)>(309800)>(309900)
STAGE: generate data
----------
_____________________________________________________________________
Loaded batch 309
STAGE: parse data file 310
./../data/train-1.renju
>(310000)>(310100)>(310200)>(310300)>(310400)>(310500)>(310600)>(310700)>(310800)>(310900)
STAGE: generate data
----------
_____________________________________________________________________
Loaded batch 310
STAGE: parse data file 311
./../data/train-1.renju
>(311000)>(311100)>(311200)>(311300)>(311400)>(311500)>(311600)>(311700)>(311800)>(311900)
STAGE: generate data
----------
___________________________________________

>(339000)>(339100)>(339200)>(339300)>(339400)>(339500)>(339600)>(339700)>(339800)>(339900)
STAGE: generate data
----------
_____________________________________________________________________
Loaded batch 339
STAGE: parse data file 340
./../data/train-1.renju
>(340000)>(340100)>(340200)>(340300)>(340400)>(340500)>(340600)>(340700)>(340800)>(340900)
STAGE: generate data
----------
_____________________________________________________________________
Loaded batch 340
STAGE: parse data file 341
./../data/train-1.renju
>(341000)>(341100)>(341200)>(341300)>(341400)>(341500)>(341600)>(341700)>(341800)>(341900)
STAGE: generate data
----------
_____________________________________________________________________
Loaded batch 341
STAGE: parse data file 342
./../data/train-1.renju
>(342000)>(342100)>(342200)>(342300)>(342400)>(342500)>(342600)>(342700)>(342800)>(342900)
STAGE: generate data
----------
_____________________________________________________________________
Loaded batch 342
STAGE: 

----------
_____________________________________________________________________
Loaded batch 370
STAGE: parse data file 371
./../data/train-1.renju
>(371000)>(371100)>(371200)>(371300)>(371400)>(371500)>(371600)>(371700)>(371800)>(371900)
STAGE: generate data
----------
_____________________________________________________________________
Loaded batch 371
STAGE: parse data file 372
./../data/train-1.renju
>(372000)>(372100)>(372200)>(372300)>(372400)>(372500)>(372600)>(372700)>(372800)>(372900)
STAGE: generate data
----------
_____________________________________________________________________
Loaded batch 372
STAGE: parse data file 373
./../data/train-1.renju
>(373000)>(373100)>(373200)>(373300)>(373400)>(373500)>(373600)>(373700)>(373800)>(373900)
STAGE: generate data
----------
_____________________________________________________________________
Loaded batch 373
STAGE: parse data file 374
./../data/train-1.renju
>(374000)>(374100)>(374200)>(374300)>(374400)>(374500)>(374600)>(374

>(402000)>(402100)>(402200)>(402300)>(402400)>(402500)>(402600)>(402700)>(402800)>(402900)
STAGE: generate data
----------
_____________________________________________________________________
Loaded batch 402
STAGE: parse data file 403
./../data/train-1.renju
>(403000)>(403100)>(403200)>(403300)>(403400)>(403500)>(403600)>(403700)>(403800)>(403900)
STAGE: generate data
----------
_____________________________________________________________________
Loaded batch 403
STAGE: parse data file 404
./../data/train-1.renju
>(404000)>(404100)>(404200)>(404300)>(404400)>(404500)>(404600)>(404700)>(404800)>(404900)
STAGE: generate data
----------
_____________________________________________________________________
Loaded batch 404
STAGE: parse data file 405
./../data/train-1.renju
>(405000)>(405100)>(405200)>(405300)>(405400)>(405500)>(405600)>(405700)>(405800)>(405900)
STAGE: generate data
----------
_____________________________________________________________________
Loaded batch 405
STAGE: 