In [55]:
def check_win(board, player):
    """
    Check if the player can win in the next move and return all such moves.
    """
    winning_moves = []
    winning_combinations = [(0, 1, 2), (3, 4, 5), (6, 7, 8),  # rows
                            (0, 3, 6), (1, 4, 7), (2, 5, 8),  # columns
                            (0, 4, 8), (2, 4, 6)]  # diagonals
    for a, b, c in winning_combinations:
        if board[a] == board[b] == player and board[c] == ' ':
            winning_moves.append(c)
        if board[a] == board[c] == player and board[b] == ' ':
            winning_moves.append(b)
        if board[b] == board[c] == player and board[a] == ' ':
            winning_moves.append(a)
    return winning_moves

def check_block(board, player):
    """
    Check if the opponent can win in the next move and return all blocking moves.
    """
    opponent = 'O' if player == 'X' else 'X'
    return check_win(board, opponent)

def is_unblocked_line(board, player, a, b, c):
    """
    Check if a line (defined by indices a, b, c) is unblocked for the given player.
    """
    return ((board[a] == board[b] == player and board[c] == ' ') or
            (board[a] == board[c] == player and board[b] == ' ') or
            (board[b] == board[c] == player and board[a] == ' '))

def check_fork(board, player):
    """
    Check if the player can create a fork and return all such moves.
    """
    fork_moves = []
    winning_combinations = [(0, 1, 2), (3, 4, 5), (6, 7, 8),  # rows
                            (0, 3, 6), (1, 4, 7), (2, 5, 8),  # columns
                            (0, 4, 8), (2, 4, 6)]  # diagonals
    for i in range(9):
        if board[i] == ' ':
            board[i] = player  # Temporarily place the player's marker
            unblocked_lines = 0  # Count of unblocked lines
            for a, b, c in winning_combinations:
                if is_unblocked_line(board, player, a, b, c):
                    unblocked_lines += 1
            if unblocked_lines >= 2:
                fork_moves.append(i)
            board[i] = ' '  # Reset the board
    return fork_moves

def is_two_in_a_row(board, player, a, b, c):
    return ((board[a] == player and board[b] == player and board[c] == ' ') or
            (board[a] == player and board[c] == player and board[b] == ' ') or
            (board[b] == player and board[c] == player and board[a] == ' '))

# Update the check_block_fork function to include the additional logic
def check_block_fork(board, player):
    opponent = 'O' if player == 'X' else 'X'
    forks = check_fork(board, opponent)

    # Check if opponent has two opposite corners
    if board[0] == board[8] == opponent:
        if board[2] == ' ' or board[6] == ' ':
            to_rtn = []
            for i in [1, 3, 5, 7]:
                if board[i] == ' ': to_rtn.append(i)
            return to_rtn
    if board[2] == board[6] == opponent:
        if board[0] == ' ' or board[8] == ' ':
            to_rtn = []
            for i in [1, 3, 5, 7]:
                if board[i] == ' ': to_rtn.append(i)
            return to_rtn

    if len(forks) > 1:
        two_in_a_row_moves = []
        winning_combinations = [(0, 1, 2), (3, 4, 5), (6, 7, 8),  # rows
                                (0, 3, 6), (1, 4, 7), (2, 5, 8),  # columns
                                (0, 4, 8), (2, 4, 6)]  # diagonals
        for i in range(9):
            if board[i] == ' ':
                board[i] = player  # Temporarily place the player's marker
                for a, b, c in winning_combinations:
                    if is_two_in_a_row(board, player, a, b, c):
                        two_in_a_row_moves.append(i)
                        break  # No need to check further combinations for this move
                board[i] = ' '  # Reset the board
        # Return intersection of two_in_a_row_moves and forks
        return list(set(two_in_a_row_moves) & set(forks))
    if len(forks) == 1:
        return forks
    return []

def check_center(board):
    """
    Check if the center is free and return it as a move if it is.
    """
    return [4] if board[4] == ' ' else []

def check_opposite_corner(board, player):
    """
    Check if the opponent is in the corner, and the opposite corner is free, and return all such moves.
    """
    opponent = 'O' if player == 'X' else 'X'
    opposite_corners = [(0, 8), (2, 6), (6, 2), (8, 0)]
    moves = []
    for a, b in opposite_corners:
        if board[a] == opponent and board[b] == ' ':
            moves.append(b)
        if board[b] == opponent and board[a] == ' ':
            moves.append(a)
    return moves

def check_empty_corner(board):
    """
    Check for any empty corners and return all such moves.
    """
    corners = [0, 2, 6, 8]
    return [corner for corner in corners if board[corner] == ' ']

def check_empty_side(board):
    """
    Check for any empty sides and return all such moves.
    """
    sides = [1, 3, 5, 7]
    return [side for side in sides if board[side] == ' ']

# Update Main Function

def get_optimal_moves(board, player):
    """
    Get all optimal moves for the given board and player.
    """
    for check in [check_win, check_block, check_fork, check_block_fork]:
        moves = check(board, player)
        if moves:
            return moves

    for check in [check_center, check_empty_corner, check_empty_side]:
        moves = check(board)
        if moves:
            return moves

    for check in [check_opposite_corner]:
        moves = check(board, player)
        if moves:
            return moves

    return []  # Should never reach this point in a valid game of Tic-Tac-Toe

In [58]:
# Initialize list to store sequences of all finished games
finished_games = []

def is_winner(board, player):
    """
    Check if the player has won on the current board.
    """
    winning_combinations = [(0, 1, 2), (3, 4, 5), (6, 7, 8),  # rows
                            (0, 3, 6), (1, 4, 7), (2, 5, 8),  # columns
                            (0, 4, 8), (2, 4, 6)]  # diagonals
    for a, b, c in winning_combinations:
        if board[a] == board[b] == board[c] == player:
            return True
    return False

def simulate_game(board, move_sequence, next_player):
    """
    Simulate a game of Tic-Tac-Toe recursively.
    
    Parameters:
        board (list): The current game board.
        move_sequence (list): The sequence of moves made so far.
        next_player (str): The player to move next ('X' or 'O').
    """
    
    # Check for game over conditions (win or draw)
    if is_winner(board, 'X') or is_winner(board, 'O'):
        finished_games.append(move_sequence[:])
        return
    if ' ' not in board:
        finished_games.append(move_sequence[:])
        return
    
    # Optimal player's move
    if next_player == 'X':
        optimal_moves = get_optimal_moves(board, next_player)
        for move in optimal_moves:
            board[move] = next_player  # Make the move
            move_sequence.append(move)  # Record the move
            simulate_game(board, move_sequence, 'O')  # Recursive call
            board[move] = ' '  # Undo the move
            move_sequence.pop()  # Remove the last move from the sequence
    
    # All moves for the non-optimal player
    else:
        for move in range(9):
            if board[move] == ' ':
                board[move] = next_player  # Make the move
                move_sequence.append(move)  # Record the move
                simulate_game(board, move_sequence, 'X')  # Recursive call
                board[move] = ' '  # Undo the move
                move_sequence.pop()  # Remove the last move from the sequence

# Initialize board and move_sequence
initial_board = [' ' for _ in range(9)]
initial_move_sequence = []

# Start the simulation with 'X' going first
simulate_game(initial_board, initial_move_sequence, 'X')

# Show some of the finished games to verify correctness
finished_games[:10], len(finished_games)

([[4, 0, 2, 1, 6],
  [4, 0, 2, 3, 6],
  [4, 0, 2, 5, 6],
  [4, 0, 2, 6, 3, 1, 5],
  [4, 0, 2, 6, 3, 5, 8, 1, 7],
  [4, 0, 2, 6, 3, 5, 8, 7, 1],
  [4, 0, 2, 6, 3, 7, 5],
  [4, 0, 2, 6, 3, 8, 5],
  [4, 0, 2, 7, 6],
  [4, 0, 2, 8, 6]],
 488)

In [60]:
import pickle

# Save the list of lists to appropriate file
with open('finished_games.pkl', 'wb') as f:
    pickle.dump(finished_games, f)

# Load the list of lists from appropriate file
with open('finished_games.pkl', 'rb') as f:
    finished_games = pickle.load(f)

In [61]:
finished_games

[[4, 0, 2, 1, 6],
 [4, 0, 2, 3, 6],
 [4, 0, 2, 5, 6],
 [4, 0, 2, 6, 3, 1, 5],
 [4, 0, 2, 6, 3, 5, 8, 1, 7],
 [4, 0, 2, 6, 3, 5, 8, 7, 1],
 [4, 0, 2, 6, 3, 7, 5],
 [4, 0, 2, 6, 3, 8, 5],
 [4, 0, 2, 7, 6],
 [4, 0, 2, 8, 6],
 [4, 0, 6, 1, 2],
 [4, 0, 6, 2, 1, 3, 7],
 [4, 0, 6, 2, 1, 5, 7],
 [4, 0, 6, 2, 1, 7, 8, 3, 5],
 [4, 0, 6, 2, 1, 7, 8, 5, 3],
 [4, 0, 6, 2, 1, 8, 7],
 [4, 0, 6, 3, 2],
 [4, 0, 6, 5, 2],
 [4, 0, 6, 7, 2],
 [4, 0, 6, 8, 2],
 [4, 0, 8, 1, 2, 3, 5],
 [4, 0, 8, 1, 2, 3, 6],
 [4, 0, 8, 1, 2, 5, 6],
 [4, 0, 8, 1, 2, 6, 5],
 [4, 0, 8, 1, 2, 7, 5],
 [4, 0, 8, 1, 2, 7, 6],
 [4, 0, 8, 2, 1, 3, 7],
 [4, 0, 8, 2, 1, 5, 7],
 [4, 0, 8, 2, 1, 6, 7],
 [4, 0, 8, 2, 1, 7, 6, 3, 5],
 [4, 0, 8, 2, 1, 7, 6, 5, 3],
 [4, 0, 8, 3, 6, 1, 7],
 [4, 0, 8, 3, 6, 1, 2],
 [4, 0, 8, 3, 6, 2, 7],
 [4, 0, 8, 3, 6, 5, 7],
 [4, 0, 8, 3, 6, 5, 2],
 [4, 0, 8, 3, 6, 7, 2],
 [4, 0, 8, 5, 6, 1, 7],
 [4, 0, 8, 5, 6, 1, 2],
 [4, 0, 8, 5, 6, 2, 7],
 [4, 0, 8, 5, 6, 3, 7],
 [4, 0, 8, 5, 6, 3, 2],
 [4, 0, 8, 5, 6,

In [57]:
# Initialize list to store sequences of all finished games
finished_games_O_first = []
o_wins = 0

def simulate_game_O_first(board, move_sequence):
    """
    Simulate a game of Tic-Tac-Toe recursively where 'O' goes first.
    
    Parameters:
        board (list): The current game board.
        move_sequence (list): The sequence of moves made so far.
    """
    
    # Check for game over conditions (win or draw)
    if is_winner(board, 'X') or is_winner(board, 'O'):
        # if is_winner(board, 'X'): print("X won!")
        # if is_winner(board, 'O'): 
        #     print("O won!")
        finished_games_O_first.append(move_sequence[:])
        return
    if ' ' not in board:
        finished_games_O_first.append(move_sequence[:])
        return
    
    # All moves for the non-optimal player
    for move in range(9):
        if board[move] == ' ':
            board[move] = 'O'  # Make the move
            move_sequence.append(move)  # Record the move

            # Optimal player's move
            optimal_moves = get_optimal_moves(board, 'X')
            for x_move in optimal_moves:
                board[x_move] = 'X'  # Make the move
                move_sequence.append(x_move)  # Record the move 
                simulate_game_O_first(board, move_sequence)  # Recursive call
                board[x_move] = ' '  # Undo the move
                move_sequence.pop()  # Remove the last move from the sequence
                
            board[move] = ' '  # Undo the move
            move_sequence.pop()  # Remove the last move from the sequence

# Initialize board and move_sequence
initial_board = [' ' for _ in range(9)]
initial_move_sequence = []

# Start the simulation with 'O' going first
simulate_game_O_first(initial_board, initial_move_sequence)

# Show some of the finished games to verify correctness
finished_games_O_first[:10], len(finished_games_O_first)

([[0, 4, 1, 2, 3, 6],
  [0, 4, 1, 2, 5, 6],
  [0, 4, 1, 2, 6, 3, 7, 5],
  [0, 4, 1, 2, 6, 3, 8, 5],
  [0, 4, 1, 2, 7, 6],
  [0, 4, 1, 2, 8, 6],
  [0, 4, 2, 1, 3, 7],
  [0, 4, 2, 1, 5, 7],
  [0, 4, 2, 1, 6, 7],
  [0, 4, 2, 1, 8, 7]],
 656)

In [34]:
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader
import torch

# Revised TicTacToeDataset class
class TicTacToeDataset(Dataset):
    def __init__(self, finished_games_X_first, finished_games_O_first):
        self.data = []
        
        # For games where X went first
        for game in finished_games_X_first:
            for i in range(1, len(game), 2):  # Only X's moves
                sub_seq = [9] + game[:i]
                self.data.append((sub_seq, game[i]))
        
        # For games where X went second
        for game in finished_games_O_first:
            for i in range(0, len(game), 2):  # Only X's moves
                sub_seq = [9] + game[:i+1]
                self.data.append((sub_seq, game[i+1]))
                
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        sub_seq, target = self.data[idx]
        # Pad the sub_seq to have length 9
        padded_sub_seq = sub_seq + [10] * (10 - len(sub_seq))
        return torch.tensor(padded_sub_seq, dtype=torch.long).view(-1), torch.tensor(target, dtype=torch.long).view(-1)

# Create Dataset
dataset = TicTacToeDataset(finished_games, finished_games_O_first)

# Create DataLoader
dataloader = DataLoader(dataset, batch_size=4, shuffle=False)

# Fetch a batch to check
j = 800
for i, (X, y) in enumerate(dataloader):
    if i == j:
        print("X:", X)
        print("y:", y)
        break

X: tensor([[ 9,  7,  4,  1, 10, 10, 10, 10, 10, 10],
        [ 9,  7,  4,  1,  0,  3, 10, 10, 10, 10],
        [ 9,  7, 10, 10, 10, 10, 10, 10, 10, 10],
        [ 9,  7,  4,  1, 10, 10, 10, 10, 10, 10]])
y: tensor([[0],
        [8],
        [4],
        [0]])
