In [18]:
import numpy as np
from IPython.display import clear_output
import time
import random
import concurrent.futures

In [19]:
def update_board(board_temp,color,column):
    # this is a function that takes the current board status, a color, and a column and outputs the new board status
    # columns 0 - 6 are for putting a checker on the board: if column is full just return the current board...this should be forbidden by the player
    
    # the color input should be either 'plus' or 'minus'
    
    board = board_temp.copy()
    ncol = board.shape[1]
    nrow = board.shape[0]
    
    # this seems silly, but actually faster to run than using sum because of overhead! 
    colsum = abs(board[0,column])+abs(board[1,column])+abs(board[2,column])+abs(board[3,column])+abs(board[4,column])+abs(board[5,column])
    row = int(5-colsum)
    if row > -0.5:
        if color == 'plus':
            board[row,column] = 1
        else:
            board[row,column] = -1
    return board
    
# in this code the board is a 6x7 numpy array.  Each entry is +1, -1 or 0.  You WILL be able to do a better
# job training your neural network if you rearrange this to be a 6x7x2 numpy array.  If the i'th row and j'th
# column is +1, this can be represented by board[i,j,0]=1.  If it is -1, this can be represented by
# board[i,j,1]=1. It's up to you how you represent your board.


In [20]:
def check_for_win(board,col):
    # this code is faster than the above code, but it requires knowing where the last checker was dropped
    # it may seem extreme, but in MCTS this function is called more than anything and actually makes up
    # a large portion of total time spent finding a good move.  So every microsecond is worth saving!
    nrow = 6
    ncol = 7
    # take advantage of knowing what column was last played in...need to check way fewer possibilities
    colsum = abs(board[0,col])+abs(board[1,col])+abs(board[2,col])+abs(board[3,col])+abs(board[4,col])+abs(board[5,col])
    row = int(6-colsum)
    if row+3<6:
        vert = board[row,col] + board[row+1,col] + board[row+2,col] + board[row+3,col]
        if vert == 4:
            return 'v-plus'
        elif vert == -4:
            return 'v-minus'
    if col+3<7:
        hor = board[row,col] + board[row,col+1] + board[row,col+2] + board[row,col+3]
        if hor == 4:
            return 'h-plus'
        elif hor == -4:
            return 'h-minus'
    if col-1>=0 and col+2<7:
        hor = board[row,col-1] + board[row,col] + board[row,col+1] + board[row,col+2]
        if hor == 4:
            return 'h-plus'
        elif hor == -4:
            return 'h-minus'
    if col-2>=0 and col+1<7:
        hor = board[row,col-2] + board[row,col-1] + board[row,col] + board[row,col+1]
        if hor == 4:
            return 'h-plus'
        elif hor == -4:
            return 'h-minus'
    if col-3>=0:
        hor = board[row,col-3] + board[row,col-2] + board[row,col-1] + board[row,col]
        if hor == 4:
            return 'h-plus'
        elif hor == -4:
            return 'h-minus'
    if row < 3 and col < 4:
        DR = board[row,col] + board[row+1,col+1] + board[row+2,col+2] + board[row+3,col+3]
        if DR == 4:
            return 'd-plus'
        elif DR == -4:
            return 'd-minus'
    if row-1>=0 and col-1>=0 and row+2<6 and col+2<7:
        DR = board[row-1,col-1] + board[row,col] + board[row+1,col+1] + board[row+2,col+2]
        if DR == 4:
            return 'd-plus'
        elif DR == -4:
            return 'd-minus'
    if row-2>=0 and col-2>=0 and row+1<6 and col+1<7:
        DR = board[row-2,col-2] + board[row-1,col-1] + board[row,col] + board[row+1,col+1]
        if DR == 4:
            return 'd-plus'
        elif DR == -4:
            return 'd-minus'
    if row-3>=0 and col-3>=0:
        DR = board[row-3,col-3] + board[row-2,col-2] + board[row-1,col-1] + board[row,col]
        if DR == 4:
            return 'd-plus'
        elif DR == -4:
            return 'd-minus'
    if row+3<6 and col-3>=0:
        DL = board[row,col] + board[row+1,col-1] + board[row+2,col-2] + board[row+3,col-3]
        if DL == 4:
            return 'd-plus'
        elif DL == -4:
            return 'd-minus'
    if row-1 >= 0 and col+1 < 7 and row+2<6 and col-2>=0:
        DL = board[row-1,col+1] + board[row,col] + board[row+1,col-1] + board[row+2,col-2]
        if DL == 4:
            return 'd-plus'
        elif DL == -4:
            return 'd-minus'
    if row-2 >=0 and col+2<7 and row+1<6 and col-1>=0:
        DL = board[row-2,col+2] + board[row-1,col+1] + board[row,col] + board[row+1,col-1]
        if DL == 4:
            return 'd-plus'
        elif DL == -4:
            return 'd-minus'
    if row-3>=0 and col+3<7:
        DL = board[row-3,col+3] + board[row-2,col+2] + board[row-1,col+1] + board[row,col]
        if DL == 4:
            return 'd-plus'
        elif DL == -4:
            return 'd-minus'
    if abs(board[0,0]) + abs(board[0,1]) + abs(board[0,2]) + abs(board[0,3]) + abs(board[0,4]) + abs(board[0,5]) + abs(board[0,6]) > 6.5:
        return 'tie'
    return 'nobody'

In [21]:
def find_legal(board):
    legal = [i for i in range(7) if abs(board[0,i]) < 0.1]
    return legal

In [22]:
def look_for_win(board_,color):
    board_ = board_.copy()
    legal = find_legal(board_)
    winner = -1
    for m in legal:
        bt = update_board(board_.copy(),color,m)
        wi = check_for_win(bt,m)
        if wi[2:] == color:
            winner = m
            break
    return winner

In [23]:
def find_all_nonlosers(board,color):
    if color == 'plus':
        opp = 'minus'
    else:
        opp = 'plus'
    legal = find_legal(board)
    poss_boards = [update_board(board,color,l) for l in legal]
    poss_legal = [find_legal(b) for b in poss_boards]
    allowed = []
    for i in range(len(legal)):
        wins = [j for j in poss_legal[i] if check_for_win(update_board(poss_boards[i],opp,j),j) != 'nobody']
        if len(wins) == 0:
            allowed.append(legal[i])
    return allowed

In [24]:
def back_prop(winner,path,color0,md):
    for i in range(len(path)):
        board_temp = path[i]
        
        md[board_temp][0]+=1
        if winner[2]==color0[0]:
            if i % 2 == 1:
                md[board_temp][1] += 1
            else:
                md[board_temp][1] -= 1
        elif winner[2]=='e': # tie
            # md[board_temp][1] += 0
            pass
        else:
            if i % 2 == 1:
                md[board_temp][1] -= 1
            else:
                md[board_temp][1] += 1

In [25]:
# def rollout(board,next_player):
#     winner = 'nobody'
#     player = next_player
#     while winner == 'nobody':
#         legal = find_legal(board)
#         if len(legal) == 0:
#             winner = 'tie'
#             return winner
#         move = random.choice(legal)
#         board = update_board(board,player,move)
#         winner = check_for_win(board,move)
        
#         if player == 'plus':
#             player = 'minus'
#         else:
#             player = 'plus'
#     return winner

In [26]:
def rollout(board, next_player):

    winner = 'nobody'
    player = next_player

    while winner == 'nobody':
        # Find all legal moves
        legal = find_legal(board)

        # If no legal moves, it's a tie
        if len(legal) == 0:
            winner = 'tie'
            return winner

        # Select a random legal move
        move = random.choice(legal)
        board = update_board(board, player, move)

        # Check for a winner immediately after the move
        winner = check_for_win(board, move)
        if winner != 'nobody':
            return winner

        # Switch players
        player = 'minus' if player == 'plus' else 'plus'

    return winner


In [27]:
def mcts(board_temp,color0,nsteps):
    # nsteps is a parameter that determines the skill (and slowness) of the player
    # bigger values of nsteps means the player is better, but also slower to figure out a move.
    board = board_temp.copy()
    ##############################################
    winColumn = look_for_win(board,color0) # check to find a winning column
    if winColumn > -0.5:
        return winColumn # if there is one - play that!
    legal0 = find_all_nonlosers(board,color0) # find all moves that won't immediately lead to your opponent winning
    if len(legal0) == 0: # if you can't block your opponent - just find the 'best' losing move
        legal0 = find_legal(board)
    ##############################################
    # the code above, in between the hash rows, is not part of traditional MCTS
    # but it makes it better and faster - so I included it!
    # MCTS occasionally makes stupid mistakes
    # like not dropping the checker on a winning column, or not blocking an obvious opponent win
    # this avoids a little bit of that stupidity!
    # we could also add this logic to the rest of the MCTS and rollout functions - I just haven't done that yet...
    # feel free to experiment!
    mcts_dict = {tuple(board.ravel()):[0,0]}
    for ijk in range(nsteps):
        color = color0
        winner = 'nobody'
        board_mcts = board.copy()
        path = [tuple(board_mcts.ravel())]
        while winner == 'nobody':
            legal = find_legal(board_mcts)
            if len(legal) == 0:
                winner = 'tie'
                back_prop(winner,path,color0,mcts_dict)
                #back_prop(('tie', 'tie', 'e'), path, color0, mcts_dict)
                break
            board_list = []
            for col in legal:
                board_list.append(tuple(update_board(board_mcts,color,col).ravel()))
            for bl in board_list:
                if bl not in mcts_dict.keys():
                    mcts_dict[bl] = [0,0]
            ucb1 = np.zeros(len(legal))
            for i in range(len(legal)):
                num_denom = mcts_dict[board_list[i]]
                if num_denom[0] == 0:
                    ucb1[i] = 10*nsteps
                else:
                    ucb1[i] = num_denom[1]/num_denom[0] + 2*np.sqrt(np.log(mcts_dict[path[-1]][0])/mcts_dict[board_list[i]][0])
            chosen = np.argmax(ucb1)
            
            board_mcts = update_board(board_mcts,color,legal[chosen])
            path.append(tuple(board_mcts.ravel()))
            winner = check_for_win(board_mcts,legal[chosen])
            if winner[2]==color[0]:
                back_prop(winner,path,color0,mcts_dict)
                break
            if color == 'plus':
                color = 'minus'
            else:
                color = 'plus' 
            if mcts_dict[tuple(board_mcts.ravel())][0] == 0:
                winner = rollout(board_mcts,color)
                back_prop(winner,path,color0,mcts_dict)
                break    
    maxval = -np.inf
    best_col = -1
    for col in legal0:
        board_temp = tuple(update_board(board,color0,col).ravel())
        num_denom = mcts_dict[board_temp]
        if num_denom[0] == 0:
            compare = -np.inf
        else:
            compare = num_denom[1] / num_denom[0]   
        if compare > maxval:
            maxval = compare
            best_col = col
    return (best_col)
            

In [28]:
def display_board(board):
    # this function displays the board as ascii using X for +1 and O for -1
    # For the project, this should be a better picture of the board...
    clear_output()
    horizontal_line = '-'*(7*5+8)
    blank_line = '|'+' '*5
    blank_line *= 7
    blank_line += '|'
    print('   0     1     2     3     4     5     6')
    print(horizontal_line)
    for row in range(6):
        print(blank_line)
        this_line = '|'
        for col in range(7):
            if board[row,col] == 0:
                this_line += ' '*5 + '|'
            elif board[row,col] == 1:
                this_line += '  X  |'
            else:
                this_line += '  O  |'
        print(this_line)
        print(blank_line)
        print(horizontal_line)
    print('   0     1     2     3     4     5     6')

            

In [29]:
def display_board_2(board_6x7x2):
    """
    Displays a 6x7x2 board in ASCII format using X for +1 and O for -1.
    board_6x7x2: numpy array of shape (6, 7, 2), where:
        - board_6x7x2[i, j, 0] = 1 indicates a 'plus' piece at (i, j)
        - board_6x7x2[i, j, 1] = 1 indicates a 'minus' piece at (i, j)
        - Both are 0 for an empty spot
    """
    clear_output()
    horizontal_line = '-' * (7 * 5 + 8)
    blank_line = '|' + ' ' * 5
    blank_line *= 7
    blank_line += '|'
    
    print('   0     1     2     3     4     5     6')
    print(horizontal_line)
    
    for row in range(6):
        print(blank_line)
        this_line = '|'
        for col in range(7):
            if board_6x7x2[row, col, 0] == 1:  # 'plus' piece
                this_line += '  X  |'
            elif board_6x7x2[row, col, 1] == 1:  # 'minus' piece
                this_line += '  O  |'
            else:  # Empty spot
                this_line += '     |'
        print(this_line)
        print(blank_line)
        print(horizontal_line)
    
    print('   0     1     2     3     4     5     6')


In [30]:
# convert board to 6x7x2
def convert_board(board):
    board_6x7x2 = np.zeros((6, 7, 2), dtype=int)
    board_6x7x2[:, :, 0] = (board == 1).astype(int)  # channel for 'plus'
    board_6x7x2[:, :, 1] = (board == -1).astype(int)  # channel for 'minus'
    return board_6x7x2

In [31]:
def introduce_randomness(board):

    random_moves = np.random.randint(0, 3)  # randomly decide how many moves will be played
    player = 'plus' if np.random.rand() < 0.5 else 'minus'
    
    for _ in range(random_moves):
        legal_cols = find_legal(board)
        if legal_cols:
            random_col = np.random.choice(legal_cols)
            board = update_board(board, player, random_col)
            player = 'minus' if player == 'plus' else 'plus'

    return board, player

In [32]:
dataset=[]
dataset_length = 0
unique_boards = set()
games = 500

for i in range(games): #loop through number of games
    print('game: ', i)
    board = np.zeros((6,7)) #initialize board
    board, player = introduce_randomness(board)  # Add random moves

    winner = 'nobody' #start winner as nobody

    while winner == 'nobody':
        recommended_col = mcts(board, player, 5000)
        board_save = convert_board(board)

        if recommended_col == -1:
            winner = 'tie'
            break

        # prevent duplicates
        board_hash = tuple(board_save.flatten())

        player_store = '1' if player == 'plus' else '-1'
        
        if board_hash not in unique_boards:
            unique_boards.add(board_hash)
            dataset.append({
                'board': board_save,
                'recommended_column': recommended_col,
                'player': player_store
            })
            dataset_length += 1
            print('len: ', dataset_length)

        board = update_board(board, player, recommended_col)
        winner = check_for_win(board,recommended_col)
        player = 'minus' if player == 'plus' else 'plus'


game:  0
len:  1
len:  2
len:  3
len:  4
len:  5
len:  6
len:  7
len:  8
len:  9
len:  10
len:  11
len:  12
len:  13
len:  14
len:  15
len:  16
len:  17
len:  18
len:  19
len:  20
len:  21
len:  22
len:  23
len:  24
game:  1
len:  25
len:  26
len:  27
len:  28
len:  29
len:  30
len:  31
len:  32
len:  33
len:  34
len:  35
len:  36
len:  37
len:  38
len:  39
len:  40
len:  41
len:  42
len:  43
len:  44
len:  45
len:  46
len:  47
len:  48
len:  49
len:  50
len:  51
len:  52
len:  53
len:  54
len:  55
game:  2
len:  56
len:  57
len:  58
len:  59
len:  60
len:  61
len:  62
len:  63
len:  64
len:  65
len:  66
len:  67
len:  68
len:  69
len:  70
len:  71
len:  72
len:  73
len:  74
len:  75
len:  76
len:  77
len:  78
len:  79
len:  80
len:  81
len:  82
len:  83
len:  84
len:  85
len:  86
len:  87
len:  88
len:  89
game:  3
len:  90
len:  91
len:  92
len:  93
len:  94
len:  95
len:  96
len:  97
len:  98
len:  99
len:  100
len:  101
len:  102
len:  103
len:  104
len:  105
len:  106
len:  107
le

In [33]:
print('Dataset Length: ', dataset_length)

Dataset Length:  11937


In [34]:
import pickle
# Save the dataset to a pickle file
with open('dataset4.pkl', 'wb') as f:
    pickle.dump(dataset, f)

print("Dataset generation complete!")

Dataset generation complete!
