In [None]:
import copy
import random
import math
import numpy as np
from collections import defaultdict
import pickle

import numpy as np
import random
import gym
from gym import spaces
import matplotlib.pyplot as plt
from numba import jit, njit

# File the approximator is pickled to during training (written by td_learning).
save_path = "ntuple_approximator.pkl"

# Tile value -> background colour (hex string) used by Game2048Env.render;
# key 0 is the colour of an empty cell.
COLOR_MAP = {
    0: "#cdc1b4", 2: "#eee4da", 4: "#ede0c8", 8: "#f2b179",
    16: "#f59563", 32: "#f67c5f", 64: "#f65e3b", 128: "#edcf72",
    256: "#edcc61", 512: "#edc850", 1024: "#edc53f", 2048: "#edc22e",
    4096: "#3c3a32", 8192: "#3c3a32", 16384: "#3c3a32", 32768: "#3c3a32"
}
# Tile value -> colour (hex string) of the number drawn on the tile by
# Game2048Env.render; values missing here fall back to "white" there.
TEXT_COLOR = {
    2: "#776e65", 4: "#776e65", 8: "#f9f6f2", 16: "#f9f6f2",
    32: "#f9f6f2", 64: "#f9f6f2", 128: "#f9f6f2", 256: "#f9f6f2",
    512: "#f9f6f2", 1024: "#f9f6f2", 2048: "#f9f6f2", 4096: "#f9f6f2"
}

@jit(nopython=True)
def compress_and_merge(row, score):
    """Slide one 4-cell line towards index 0 and merge equal neighbours.

    Args:
        row: 1-D sequence of 4 tile values, 0 meaning an empty cell.
        score: running score before the move.

    Returns:
        (result, score): a new int32 array of length 4 holding the line
        after the move, and the score increased by the value of every
        tile created by a merge.
    """
    size = 4

    # Pass 1: pack the non-zero tiles to the front, keeping their order.
    packed = np.zeros(size, dtype=np.int32)
    count = 0
    for idx in range(size):
        value = row[idx]
        if value != 0:
            packed[count] = value
            count += 1

    # Pass 2: merge equal adjacent tiles; a merged pair is consumed as a
    # whole, so a freshly created tile cannot merge again in this move.
    merged = np.zeros(size, dtype=np.int32)
    out = 0
    src = 0
    while src < count:
        current = packed[src]
        if src + 1 < count and current == packed[src + 1]:
            doubled = current * 2
            merged[out] = doubled
            score += doubled
            src += 2
        else:
            merged[out] = current
            src += 1
        out += 1

    return merged, score

@jit(nopython=True)
def move_board(board, direction, score):
    """Apply one move to a copy of a 4x4 board.

    Args:
        board: 2-D array of tile values; it is copied, not modified.
        direction: 0 = up, 1 = down, 2 = left, 3 = right. Any other value
            leaves the copy unchanged.
        score: running score before the move.

    Returns:
        (new_board, moved, score): the board after the move, whether any
        line changed, and the score including this move's merges.
    """
    grid = board.copy()
    changed = False
    if direction == 0:  # up: each column slides towards row 0
        for c in range(4):
            line, score = compress_and_merge(grid[:, c], score)
            if not np.array_equal(line, grid[:, c]):
                changed = True
            grid[:, c] = line
    elif direction == 1:  # down: same as up on the vertically reversed column
        for c in range(4):
            line, score = compress_and_merge(grid[::-1, c], score)
            if not np.array_equal(line, grid[::-1, c]):
                changed = True
            grid[::-1, c] = line
    elif direction == 2:  # left: each row slides towards column 0
        for r in range(4):
            line, score = compress_and_merge(grid[r], score)
            if not np.array_equal(line, grid[r]):
                changed = True
            grid[r] = line
    elif direction == 3:  # right: same as left on the horizontally reversed row
        for r in range(4):
            line, score = compress_and_merge(grid[r, ::-1], score)
            if not np.array_equal(line, grid[r, ::-1]):
                changed = True
            grid[r, ::-1] = line
    return grid, changed, score

class Game2048Env(gym.Env):
    """Gym-style environment for the game 2048 on a 4x4 board.

    Actions are 0 = up, 1 = down, 2 = left, 3 = right. ``step`` returns the
    cumulative score (not the per-move gain) in the reward slot.
    """

    def __init__(self):
        super().__init__()

        self.size = 4
        self.board = np.zeros((self.size, self.size), dtype=int)
        self.score = 0

        # Action space: 0: up, 1: down, 2: left, 3: right
        self.action_space = spaces.Discrete(4)
        self.actions = ["up", "down", "left", "right"]

        # Whether the most recent step() actually changed the board.
        self.last_move_valid = True

        self.reset()

    def reset(self):
        """Start a new game: empty board, zero score, two random tiles.

        Returns the board array itself (not a copy).
        """
        self.board = np.zeros((self.size, self.size), dtype=int)
        self.score = 0
        for _ in range(2):
            self.add_random_tile()
        return self.board

    def add_random_tile(self):
        """Put a 2 (90%) or a 4 (10%) on a random empty cell, if any exists."""
        rows, cols = np.where(self.board == 0)
        n_empty = len(rows)
        if n_empty == 0:
            return
        pick = random.randint(0, n_empty - 1)
        self.board[rows[pick], cols[pick]] = 2 if random.random() < 0.9 else 4

    def compress(self, row):
        """Return a new line with the non-zero tiles of ``row`` packed first."""
        tiles = row[row != 0]
        return np.pad(tiles, (0, self.size - len(tiles)), mode="constant")

    def merge(self, row):
        """Merge equal adjacent tiles of ``row`` in place and add to the score.

        The right-hand tile of each merged pair becomes 0; the same array
        is returned.
        """
        for left in range(len(row) - 1):
            right = left + 1
            if row[left] != 0 and row[left] == row[right]:
                row[left] *= 2
                row[right] = 0
                self.score += row[left]
        return row

    def _slide_line(self, line):
        """Compress, merge, then compress again; returns a new array."""
        return self.compress(self.merge(self.compress(line)))

    def move_left(self):
        """Slide every row to the left; return True if the board changed."""
        moved = False
        for r in range(self.size):
            before = self.board[r].copy()
            self.board[r] = self._slide_line(self.board[r])
            moved = moved or not np.array_equal(before, self.board[r])
        return moved

    def move_right(self):
        """Slide every row to the right; return True if the board changed."""
        moved = False
        for r in range(self.size):
            before = self.board[r].copy()
            # Work on the reversed row, then flip the result back.
            self.board[r] = self._slide_line(self.board[r][::-1])[::-1]
            moved = moved or not np.array_equal(before, self.board[r])
        return moved

    def move_up(self):
        """Slide every column upwards; return True if the board changed."""
        moved = False
        for c in range(self.size):
            before = self.board[:, c].copy()
            self.board[:, c] = self._slide_line(self.board[:, c])
            moved = moved or not np.array_equal(before, self.board[:, c])
        return moved

    def move_down(self):
        """Slide every column downwards; return True if the board changed."""
        moved = False
        for c in range(self.size):
            before = self.board[:, c].copy()
            # Work on the reversed column, then flip the result back.
            self.board[:, c] = self._slide_line(self.board[:, c][::-1])[::-1]
            moved = moved or not np.array_equal(before, self.board[:, c])
        return moved

    def is_game_over(self):
        """Return True when no cell is empty and no two neighbours are equal."""
        grid = self.board
        if np.any(grid == 0):
            return False
        # Any equal horizontal neighbours?
        if np.any(grid[:, :-1] == grid[:, 1:]):
            return False
        # Any equal vertical neighbours?
        if np.any(grid[:-1, :] == grid[1:, :]):
            return False
        return True

    def step(self, action):
        """Play ``action``; a random tile is added only if the board changed.

        Returns (board, cumulative score, done, {}).
        """
        assert self.action_space.contains(action), "Invalid action"

        outcome = move_board(self.board, action, self.score)
        self.board, self.last_move_valid, self.score = outcome

        if self.last_move_valid:
            self.add_random_tile()

        return self.board, self.score, self.is_game_over(), {}

    def render(self, mode="human", action=None):
        """Draw the board in a new matplotlib figure and show it.

        ``mode`` is accepted but not used. When ``action`` is given, its
        name is appended to the figure title.
        """
        fig, ax = plt.subplots(figsize=(4, 4))
        ax.set_xticks([])
        ax.set_yticks([])
        upper = self.size - 0.5
        ax.set_xlim(-0.5, upper)
        ax.set_ylim(-0.5, upper)

        for r in range(self.size):
            for c in range(self.size):
                value = self.board[r, c]
                ax.add_patch(
                    plt.Rectangle(
                        (c - 0.5, r - 0.5), 1, 1,
                        facecolor=COLOR_MAP.get(value, "#3c3a32"),
                        edgecolor="black",
                    )
                )
                if value == 0:
                    continue
                ax.text(c, r, str(value), ha='center', va='center',
                        fontsize=16, fontweight='bold',
                        color=TEXT_COLOR.get(value, "white"))

        suffix = "" if action is None else f" | action: {self.actions[action]}"
        plt.title(f"score: {self.score}" + suffix)
        # Flip the y axis so that row 0 is drawn at the top.
        plt.gca().invert_yaxis()
        plt.show()

    def simulate_row_move(self, row):
        """Return ``row`` after a leftward move, without touching the score."""
        packed = self.compress(row)
        for left in range(len(packed) - 1):
            right = left + 1
            if packed[left] != 0 and packed[left] == packed[right]:
                packed[left] *= 2
                packed[right] = 0
        return self.compress(packed)

    def is_move_legal(self, action):
        """Return True if ``action`` would change the current board."""
        _, moved, _ = move_board(self.board.copy(), action, self.score)
        return moved


# -------------------------------
# TODO: Define transformation functions (rotation and reflection), i.e., rot90, rot180, ..., etc.
# -------------------------------
def rot90(pattern, board_size):
    """Map each (i, j) of ``pattern`` to (j, board_size - 1 - i).

    Returns a new list of coordinate tuples in the same order.
    """
    rotated = []
    for row, col in pattern:
        rotated.append((col, board_size - 1 - row))
    return rotated

def rot180(pattern, board_size):
    """Map each (i, j) of ``pattern`` to (board_size - 1 - i, board_size - 1 - j).

    Returns a new list of coordinate tuples in the same order.
    """
    rotated = []
    for row, col in pattern:
        flipped_row = board_size - 1 - row
        flipped_col = board_size - 1 - col
        rotated.append((flipped_row, flipped_col))
    return rotated

def rot270(pattern, board_size):
    """Map each (i, j) of ``pattern`` to (board_size - 1 - j, i).

    Returns a new list of coordinate tuples in the same order.
    """
    rotated = []
    for row, col in pattern:
        rotated.append((board_size - 1 - col, row))
    return rotated

def reflect(pattern, board_size):
    """Mirror ``pattern`` left-right: (i, j) becomes (i, board_size - 1 - j).

    Returns a new list of coordinate tuples in the same order.
    """
    mirrored = []
    for row, col in pattern:
        mirrored.append((row, board_size - 1 - col))
    return mirrored




class NTupleApproximator:
    """N-tuple value function for a square 2048 board.

    Every pattern is a list of (row, col) cells. Each pattern owns one
    weight table that is shared by the 8 symmetric placements of the
    pattern (4 rotations, each also mirrored).
    """

    def __init__(self, board_size, patterns):
        """
        Initializes the N-Tuple approximator.

        Args:
            board_size: side length of the board.
            patterns: list of patterns, each a list of (row, col) tuples.
        """
        self.board_size = board_size
        self.patterns = patterns
        # One weight table per pattern, shared within its symmetry group.
        self.weights = [defaultdict(float) for _ in patterns]
        # symmetry_groups[i] holds the 8 placements of patterns[i];
        # symmetry_patterns is the same placements as one flat list.
        self.symmetry_patterns = []
        self.symmetry_groups = []
        for pattern in self.patterns:
            syms = self.generate_symmetries(pattern)
            self.symmetry_groups.append(syms)
            self.symmetry_patterns.extend(syms)

    def generate_symmetries(self, pattern):
        """Return the 8 placements of ``pattern``: 4 rotations, then their mirrors.

        Duplicates are kept when a pattern maps onto itself.
        """
        board_size = self.board_size
        rotations = [
            pattern,
            rot90(pattern, board_size),
            rot180(pattern, board_size),
            rot270(pattern, board_size),
        ]
        return rotations + [reflect(sym, board_size) for sym in rotations]

    def tile_to_index(self, tile):
        """
        Converts tile values to an index for the lookup table.

        Returns 0 for an empty cell, otherwise floor(log2(tile)), so 2 -> 1,
        4 -> 2, and so on.
        """
        if tile == 0:
            return 0
        # Exact integer log2; truncating a floating-point math.log(tile, 2)
        # is not guaranteed to be exact.
        return int(tile).bit_length() - 1

    def get_feature(self, board, coords):
        """Return the tuple of tile indices found at ``coords`` on ``board``."""
        return tuple(self.tile_to_index(board[i, j]) for (i, j) in coords)

    def value(self, board):
        """Estimate the value of ``board``.

        For each pattern the weights of its symmetric placements are
        averaged; the result is the sum of those averages.
        """
        total_value = 0.0
        for table, syms in zip(self.weights, self.symmetry_groups):
            group_value = 0.0
            for pattern in syms:
                # .get() reads without inserting: indexing the defaultdict
                # would add a zero entry for every feature merely evaluated.
                group_value += table.get(self.get_feature(board, pattern), 0.0)
            total_value += group_value / len(syms)
        return total_value

    def update(self, board, delta, alpha):
        """Move the weights of every feature active on ``board`` by the TD error.

        Each placement's weight receives ``alpha * delta`` divided by the
        number of placements in its symmetry group.
        """
        for table, syms in zip(self.weights, self.symmetry_groups):
            update_value = alpha * delta / len(syms)
            for pattern in syms:
                table[self.get_feature(board, pattern)] += update_value

def _legal_afterstates(env, approximator):
    """List (action, afterstate, value) for every move that changes env.board.

    The afterstate is the board right after the move and before any random
    tile is added; neither the environment nor the RNG is touched.
    """
    candidates = []
    for action in range(4):
        afterstate, moved, _ = move_board(env.board, action, env.score)
        if moved:
            candidates.append((action, afterstate, approximator.value(afterstate)))
    return candidates


def td_learning(env, approximator, num_episodes=50000, alpha=0.01, gamma=0.99, epsilon=0.1):
    """
    Trains the 2048 agent using TD-Learning with afterstate updates.

    Args:
        env: environment exposing ``board``, ``score``, ``reset()`` and
            ``step(action)``, where ``step`` returns the cumulative score.
        approximator: object with ``value(board)`` and
            ``update(board, delta, alpha)``.
        num_episodes: number of games to play.
        alpha: learning rate passed to ``approximator.update``.
        gamma: discount applied to the best next afterstate value.
        epsilon: probability of a random legal move; exploration is only
            active during the first 80% of the episodes.

    Returns:
        List with the final score of every episode.

    Side effects:
        Every 100 episodes a progress line is printed and the approximator
        is pickled to ``save_path``.
    """
    final_scores = []
    success_flags = []

    for episode in range(num_episodes):
        state = env.reset()
        previous_score = 0
        done = False
        max_tile = np.max(state)

        while not done:
            # Evaluate the true afterstate of every legal move. Simulating
            # with a copied env's step() would add a random tile that the
            # real environment will not reproduce.
            candidates = _legal_afterstates(env, approximator)
            if not candidates:
                break

            # Select action based on afterstate values (epsilon-greedy).
            if random.random() < epsilon and episode < num_episodes * 0.8:
                idx = random.randrange(len(candidates))
            else:
                idx = int(np.argmax([value for _, _, value in candidates]))
            action, selected_afterstate, selected_value = candidates[idx]

            # Take the action in the real environment.
            next_state, new_score, done, _ = env.step(action)
            # step() reports the cumulative score; the reward is its increase.
            incremental_reward = new_score - previous_score
            previous_score = new_score
            max_tile = max(max_tile, np.max(next_state))

            # Target: immediate reward plus the discounted value of the best
            # afterstate reachable from the new state (0 when none exists).
            if done:
                future_value = 0.0
            else:
                future_value = max(
                    (value for _, _, value in _legal_afterstates(env, approximator)),
                    default=0.0,
                )
            target = incremental_reward + gamma * future_value

            # Update the value function for the selected afterstate.
            approximator.update(selected_afterstate, target - selected_value, alpha)

        final_scores.append(env.score)
        success_flags.append(1 if max_tile >= 2048 else 0)

        if (episode + 1) % 100 == 0:
            avg_score = np.mean(final_scores[-100:])
            success_rate = np.sum(success_flags[-100:]) / 100
            print(f"Episode {episode+1}/{num_episodes} | Avg Score: {avg_score:.2f} | Success Rate: {success_rate:.2f}", flush=True)
            # Checkpoint so that a long run can be resumed or inspected.
            with open(save_path, "wb") as f:
                pickle.dump(approximator, f)

    return final_scores


# N-tuple patterns: each entry is a list of (row, col) board cells that form
# one feature; NTupleApproximator expands every pattern into its symmetric
# placements.
patterns = [
    [(0,0)],
    [(0,1)],
    [(1,0)],
    [(1,1)],
    [(0,0), (0,1)],
    [(1,0), (1,1)],
    [(0,0), (0,1), (1,0)],
    [(0,0), (1,1), (2,2)],
    [(0,0), (0,1), (0,2), (0,3)],
    [(1,0), (1,1), (1,2), (1,3)],
    [(0,0), (0,1), (0,2), (1,0), (2,0)],
    [(1,1), (1,2), (1,3), (2,1), (3,1)],
    [(0,0), (0,1), (0,2), (1,0), (1,1), (1,2)],
]

approximator = NTupleApproximator(board_size=4, patterns=patterns)
# To continue from a saved checkpoint instead of starting fresh, replace the
# approximator with the one pickled at save_path:
# with open(save_path, "rb") as f:
#     approximator = pickle.load(f)

env = Game2048Env()
state = env.reset()

# Run TD-Learning training.
# Note: To achieve significantly better performance, you will likely need to train for over 100,000 episodes.
# To quickly verify that the implementation works, start with about 1,000 episodes before scaling up.
final_scores = td_learning(env, approximator, num_episodes=int(1e5), alpha=0.1, gamma=0.99, epsilon=0.1)

Episode 100/100000 | Avg Score: 4452.48 | Success Rate: 0.00
Episode 200/100000 | Avg Score: 6253.72 | Success Rate: 0.00
Episode 300/100000 | Avg Score: 7066.56 | Success Rate: 0.00
Episode 400/100000 | Avg Score: 6512.84 | Success Rate: 0.00
Episode 500/100000 | Avg Score: 8000.68 | Success Rate: 0.00
Episode 600/100000 | Avg Score: 7640.88 | Success Rate: 0.01
Episode 700/100000 | Avg Score: 8435.84 | Success Rate: 0.00
Episode 800/100000 | Avg Score: 8042.32 | Success Rate: 0.00
Episode 900/100000 | Avg Score: 8191.36 | Success Rate: 0.01
Episode 1000/100000 | Avg Score: 8538.16 | Success Rate: 0.01
Episode 1100/100000 | Avg Score: 8046.88 | Success Rate: 0.00
Episode 1200/100000 | Avg Score: 8495.32 | Success Rate: 0.00
Episode 1300/100000 | Avg Score: 9026.76 | Success Rate: 0.00
Episode 1400/100000 | Avg Score: 9216.36 | Success Rate: 0.02
Episode 1500/100000 | Avg Score: 8752.48 | Success Rate: 0.00
Episode 1600/100000 | Avg Score: 9086.92 | Success Rate: 0.00
Episode 1700/1000