In [None]:
import gymnasium as gym
import gymnasium_2048
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from collections import deque, defaultdict, Counter
import random

# --- Setup --- 
# create and fully unwrap (removes TimeLimit wrapper)
env = gym.make("gymnasium_2048/TwentyFortyEight-v0")
env = env.unwrapped

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Hyperparameters ---
GAMMA = 0.99
LR = 1e-3
EPSILON_START = 1.0
EPSILON_END = 0.05
EPSILON_DECAY = 0.995
BATCH_SIZE = 64
MEMORY_SIZE = 10000
TARGET_UPDATE_FREQ = 10
ROLLING_WINDOW = 20
CONVERGENCE_SCORE = 2000
MAX_EPISODES = 100000
WARMUP_MEMORY = 1000
DEBUG_EPISODES = 0

ACTION_MAP = {0: "Up", 1: "Right", 2: "Down", 3: "Left"}

# --- DQN Model ---
class DQN(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(16, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 4)
        )
    def forward(self, x):
        return self.net(x)

# --- Preprocess ---
def preprocess(state):
    # same as before: log2 scaling into float32
    if state.ndim == 3:
        exp_board = np.argmax(state, axis=2)
        board = exp_board.astype(np.float32)
    elif state.ndim == 2:
        board = np.where(state == 0, 0, np.log2(state)).astype(np.float32)
    else:
        raise ValueError(f"Unexpected state shape: {state.shape}")
    return torch.tensor(board.flatten(), dtype=torch.float32, device=device)

# --- Decode for debug & int‐casting to avoid uint8 wraparound ---
def decode_obs(obs):
    if obs.ndim == 3:
        exp_board = np.argmax(obs, axis=2)
        # cast to int32 so merges >256 don’t wrap
        return np.where(exp_board == 0, 0, 2 ** exp_board).astype(np.int32)
    # obs may be uint8; cast it
    return obs.astype(np.int32)

# --- Slide logic for valid moves ---
def slide_row(row):
    nonzeros = [int(v) for v in row if v]  # ensure Python ints
    merged = []
    i = 0
    while i < len(nonzeros):
        if i + 1 < len(nonzeros) and nonzeros[i] == nonzeros[i+1]:
            merged.append(nonzeros[i] * 2)
            i += 2
        else:
            merged.append(nonzeros[i])
            i += 1
    merged += [0] * (len(row) - len(merged))
    return merged

def slide(board, action):
    # board is assumed int32 numpy array
    b = board.copy().astype(np.int32)
    if action == 0:  # Up
        b = b.T
        for i in range(4):
            b[i] = slide_row(list(b[i]))
        b = b.T
    elif action == 2:  # Down
        b = b.T
        for i in range(4):
            rev = list(reversed(b[i]))
            slid = slide_row(rev)
            b[i] = list(reversed(slid))
        b = b.T
    elif action == 3:  # Left
        for i in range(4):
            b[i] = slide_row(list(b[i]))
    elif action == 1:  # Right
        for i in range(4):
            rev = list(reversed(b[i]))
            slid = slide_row(rev)
            b[i] = list(reversed(slid))
    return b

# --- Valid actions without env.step ---
def valid_actions(obs):
    board = decode_obs(obs)
    # try each action on integer board
    return [a for a in range(4) if not np.array_equal(board, slide(board, a))]

# --- Select action ---
def sample_action(state, obs, explore=True):
    valids = valid_actions(obs)
    if not valids:
        return None
    if explore and random.random() < epsilon:
        return random.choice(valids)
    with torch.no_grad():
        q_vals = policy_net(state).cpu().numpy()
        return max(valids, key=lambda a: q_vals[a])

# --- Networks & memory ---
policy_net = DQN().to(device)
target_net = DQN().to(device)
target_net.load_state_dict(policy_net.state_dict())
optimizer = optim.Adam(policy_net.parameters(), lr=LR)
memory    = deque(maxlen=MEMORY_SIZE)

epsilon = EPSILON_START
score_history = []
move_counter_global = defaultdict(int)
episode_bucket_counts = []

# --- Optimize ---
def optimize():
    if len(memory) < BATCH_SIZE:
        return
    batch = random.sample(memory, BATCH_SIZE)
    states, actions, rewards, next_states, dones = zip(*batch)
    states      = torch.stack(states)
    actions     = torch.tensor(actions).unsqueeze(1).to(device)
    rewards     = torch.tensor(rewards).float().to(device)
    next_states = torch.stack(next_states)
    dones       = torch.tensor(dones).float().to(device)

    q      = policy_net(states).gather(1, actions).squeeze()
    q_next = target_net(next_states).max(1)[0].detach()
    target = rewards + GAMMA * q_next * (1 - dones)

    loss = nn.functional.mse_loss(q, target)
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(policy_net.parameters(), 1.0)
    optimizer.step()

# --- Training Loop ---
episode = 0
while episode < MAX_EPISODES:
    obs, _ = env.reset()
    state   = preprocess(obs)
    running_score = 0
    move_counter_episode = defaultdict(int)

    step = 0
    while True:
        action = sample_action(state, obs)
        if action is None:
            break

        move_counter_episode[ACTION_MAP[action]] += 1
        move_counter_global[ACTION_MAP[action]] += 1

        next_obs, reward, terminated, truncated, info = env.step(action)
        done = terminated or truncated
        next_state = preprocess(next_obs)

        memory.append((state, action, reward, next_state, done))
        state, obs = next_state, next_obs

        # force Python int accumulation
        running_score = int(running_score) + int(reward)

        if len(memory) > WARMUP_MEMORY:
            optimize()
        if done:
            break
        step += 1

    # Episode summary
    print(f"Episode {episode}, Score:{running_score}, Epsilon:{epsilon:.3f}")

    # Target network update
    if episode % TARGET_UPDATE_FREQ == 0:
        target_net.load_state_dict(policy_net.state_dict())

    epsilon = max(EPSILON_END, epsilon * EPSILON_DECAY)
    score_history.append(running_score)

    # Check convergence
    if (episode >= ROLLING_WINDOW and
        np.mean(score_history[-ROLLING_WINDOW:]) >= CONVERGENCE_SCORE):
        print(f"\n✅ Converged at episode {episode}")
        break

    episode += 1

# --- Final Stats ---
print("\n=== Move Frequencies Overall ===")
for m in ACTION_MAP.values():
    print(f"{m}: {move_counter_global[m]}")

# --- Evaluation (greedy, uncapped) ---
eval_scores, tile_counts, move_counter_eval = [], Counter(), defaultdict(int)
for _ in range(100):
    obs, _ = env.reset()
    state = preprocess(obs)
    total = 0
    while True:
        action = sample_action(state, obs, explore=False)
        if action is None:
            break
        move_counter_eval[ACTION_MAP[action]] += 1

        old_board = decode_obs(obs)
        obs, r, term, trunc, _ = env.step(action)
        state = preprocess(obs)
        total = int(total) + int(r)
        # stop only on true terminal
        if term or trunc or np.array_equal(decode_obs(obs), old_board):
            break

    eval_scores.append(int(total))
    tile_counts[2**int(np.max(decode_obs(obs)))] += 1

print(f"Eval Avg: {np.mean(eval_scores):.2f}, Max: {np.max(eval_scores)}")
print("Eval Tiles:", tile_counts)
print("Eval Moves:", move_counter_eval)


Episode 0, Score:1668, Epsilon:1.000
Episode 1, Score:1016, Epsilon:0.995
Episode 2, Score:576, Epsilon:0.990
Episode 3, Score:1996, Epsilon:0.985
Episode 4, Score:760, Epsilon:0.980
Episode 5, Score:580, Epsilon:0.975
Episode 6, Score:820, Epsilon:0.970
Episode 7, Score:544, Epsilon:0.966
Episode 8, Score:1164, Epsilon:0.961
Episode 9, Score:1652, Epsilon:0.956
Episode 10, Score:364, Epsilon:0.951
Episode 11, Score:1212, Epsilon:0.946
Episode 12, Score:1652, Epsilon:0.942
Episode 13, Score:1484, Epsilon:0.937
Episode 14, Score:2064, Epsilon:0.932
Episode 15, Score:1356, Epsilon:0.928
Episode 16, Score:1472, Epsilon:0.923
Episode 17, Score:728, Epsilon:0.918
Episode 18, Score:1124, Epsilon:0.914
Episode 19, Score:892, Epsilon:0.909
Episode 20, Score:340, Epsilon:0.905
Episode 21, Score:1516, Epsilon:0.900
Episode 22, Score:1380, Epsilon:0.896
Episode 23, Score:660, Epsilon:0.891
Episode 24, Score:624, Epsilon:0.887
Episode 25, Score:1032, Epsilon:0.882
Episode 26, Score:1344, Epsilon:0

  score += 2 ** (board[row, col] + 1)


Episode 3358, Score:1696, Epsilon:0.050
Episode 3359, Score:1904, Epsilon:0.050
Episode 3360, Score:612, Epsilon:0.050
Episode 3361, Score:2448, Epsilon:0.050
Episode 3362, Score:1432, Epsilon:0.050
Episode 3363, Score:1148, Epsilon:0.050
Episode 3364, Score:580, Epsilon:0.050
Episode 3365, Score:588, Epsilon:0.050
Episode 3366, Score:2828, Epsilon:0.050
Episode 3367, Score:752, Epsilon:0.050
Episode 3368, Score:1528, Epsilon:0.050
Episode 3369, Score:576, Epsilon:0.050
Episode 3370, Score:728, Epsilon:0.050
Episode 3371, Score:1944, Epsilon:0.050
Episode 3372, Score:1088, Epsilon:0.050
Episode 3373, Score:556, Epsilon:0.050
Episode 3374, Score:1704, Epsilon:0.050
Episode 3375, Score:1336, Epsilon:0.050
Episode 3376, Score:1804, Epsilon:0.050
Episode 3377, Score:1000, Epsilon:0.050
Episode 3378, Score:744, Epsilon:0.050
Episode 3379, Score:2444, Epsilon:0.050
Episode 3380, Score:1352, Epsilon:0.050
Episode 3381, Score:1096, Epsilon:0.050
Episode 3382, Score:600, Epsilon:0.050
Episode 3

In [20]:
# --- Evaluation (greedy, uncapped) ---
eval_scores, tile_counts, move_counter_eval = [], Counter(), defaultdict(int)
for _ in range(100):
    obs, _ = env.reset()
    state = preprocess(obs)
    total = 0
    while True:
        action = sample_action(state, obs, explore=False)
        if action is None:
            break
        move_counter_eval[ACTION_MAP[action]] += 1

        old_board = decode_obs(obs)
        obs, r, term, trunc, _ = env.step(action)
        state = preprocess(obs)
        total = int(total) + int(r)
        # stop only on true terminal
        if term or trunc or np.array_equal(decode_obs(obs), old_board):
            break

    eval_scores.append(int(total))
    tile_counts[2**int(np.max(decode_obs(obs)))] += 1

print(f"Eval Avg: {np.mean(eval_scores):.2f}, Max: {np.max(eval_scores)}")
print("Eval Tiles:", tile_counts)
print("Eval Moves:", move_counter_eval)

  self.total_score += self.step_score


Eval Avg: 1851.00, Max: 4696
Eval Tiles: Counter({340282366920938463463374607431768211456: 37, 115792089237316195423570985008687907853269984665640564039457584007913129639936: 36, 18446744073709551616: 19, 4294967296: 5, 13407807929942597099574024998205846127479365820592393377723561443721764030073546976801874298166903427690031858186486050853753882811946569946433649006084096: 3})
Eval Moves: defaultdict(<class 'int'>, {'Up': 903, 'Down': 7527, 'Left': 6380, 'Right': 3501})
