In [6]:
# === Cell 1: All Definitions & Instantiation ===

# 1) Standard imports
import gymnasium as gym
import gymnasium_2048
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from collections import deque, defaultdict, Counter
import random

# 2) Hyperparameters & device/env
GAMMA               = 0.99
LR                  = 1e-3
EPSILON_START       = 1.0
EPSILON_END         = 0.05
EPSILON_DECAY       = 0.9995
BATCH_SIZE          = 128
MEMORY_SIZE         = 100_000
TARGET_UPDATE_FREQ  = 10
ROLLING_WINDOW      = 20
CONVERGENCE_SCORE   = 2000
MAX_EPISODES        = 10000
WARMUP_MEMORY       = 10_000
ACTION_MAP          = {0:"Up",1:"Right",2:"Down",3:"Left"}

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
env    = gym.make("gymnasium_2048/TwentyFortyEight-v0").unwrapped

# 3) DQN network definition
# class DQN(nn.Module):
#     def __init__(self):
#         super().__init__()
#         self.net = nn.Sequential(
#             nn.Linear(16,256),
#             nn.ReLU(),
#             nn.Linear(256,256),
#             nn.ReLU(),
#             nn.Linear(256,4)
#         )
#     def forward(self, x):
#         return self.net(x)
class DuelingDQN(nn.Module):
    def __init__(self):
        super().__init__()
        # shared feature extractor
        self.feature = nn.Sequential(
            nn.Linear(16, 256),
            nn.ReLU(),
        )
        # advantage branch
        self.adv = nn.Sequential(
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 4),
        )
        # value branch
        self.val = nn.Sequential(
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 1),
        )

    def forward(self, x):
        f = self.feature(x)
        A = self.adv(f)                   # shape [B,4]
        V = self.val(f).expand_as(A)      # shape [B,4]
        # combine into Q
        return V + (A - A.mean(dim=1, keepdim=True))

# 4) Preprocess & decode helpers
def preprocess(state):
    if state.ndim == 3:
        exp = np.argmax(state,axis=2)
        board = exp.astype(np.float32)
    else:
        board = np.where(state==0,0,np.log2(state)).astype(np.float32)
    return torch.tensor(board.flatten(),dtype=torch.float32,device=device)

def decode_obs(obs):
    if obs.ndim==3:
        exp = np.argmax(obs,axis=2)
        return np.where(exp==0,0,2**exp).astype(np.int32)
    return obs.astype(np.int32)

# 5) Slide & valid‐actions logic
def slide_row(row):
    nz, merged, i = [int(v) for v in row if v], [], 0
    while i < len(nz):
        if i+1<len(nz) and nz[i]==nz[i+1]:
            merged.append(nz[i]*2); i+=2
        else:
            merged.append(nz[i]); i+=1
    merged += [0]*(len(row)-len(merged))
    return merged

def slide(board, action):
    b = board.copy().astype(np.int32)
    if action == 0:            # Up
        b = b.T
        for i in range(4): b[i] = slide_row(b[i])
        b = b.T
    elif action == 2:          # Down
        b = b.T
        for i in range(4):
            rev = list(reversed(b[i]))
            b[i] = list(reversed(slide_row(rev)))
        b = b.T
    elif action == 3:          # Left
        for i in range(4): b[i] = slide_row(b[i])
    else:                       # Right
        for i in range(4):
            rev = list(reversed(b[i]))
            b[i] = list(reversed(slide_row(rev)))
    return b                  # <-- return the new board!

def valid_actions(obs):
    board = decode_obs(obs)
    return [a for a in range(4) if not np.array_equal(board, slide(board, a))]

# 6) Epsilon‐greedy action selector
def sample_action(state, obs, explore=True):
    valids = valid_actions(obs)
    if not valids:
        return None

    if explore and random.random() < epsilon:
        return random.choice(valids)

    with torch.no_grad():
        # make it [1,16] instead of [16]
        state_batched = state.unsqueeze(0)                # shape [1,16]
        q_vals = policy_net(state_batched)                # shape [1,4]
        q_vals = q_vals.cpu().numpy()[0]                  # shape [4]
    return max(valids, key=lambda a: q_vals[a])

# 7) Symmetry replay‐buffer
def transform_action(a: int, k: int, flipped: bool) -> int:
    # first rotate k×90°
    a_rot = (a + k) % 4
    # then if we flipped horizontally, Left<->Right swap
    if flipped:
        if a_rot == 1:   a_rot = 3
        elif a_rot == 3: a_rot = 1
    return a_rot

# --- SumTree helper for PER ---
class SumTree:
    def __init__(self, capacity: int):
        self.capacity = capacity
        self.tree = np.zeros(2*capacity - 1)
        self.data = [None]*capacity
        self.write = 0
        self.n_entries = 0

    def _propagate(self, idx, change):
        parent = (idx - 1)//2
        self.tree[parent] += change
        if parent != 0:
            self._propagate(parent, change)

    def update(self, idx, priority):
        change = priority - self.tree[idx]
        self.tree[idx] = priority
        self._propagate(idx, change)

    def add(self, priority, data):
        idx = self.write + self.capacity - 1
        self.data[self.write] = data
        self.update(idx, priority)
        self.write = (self.write + 1) % self.capacity
        self.n_entries = min(self.n_entries + 1, self.capacity)

    def get(self, s: float):
        idx = 0
        while True:
            left = 2*idx + 1
            right = left + 1
            if left >= len(self.tree):
                leaf = idx
                break
            if s <= self.tree[left]:
                idx = left
            else:
                s -= self.tree[left]
                idx = right
        data_idx = leaf - self.capacity + 1
        return leaf, self.tree[leaf], self.data[data_idx]

    @property
    def total(self):
        return self.tree[0]


# --- Prioritized Replay Buffer ---
class PrioritizedReplayBuffer:
    def __init__(self, capacity, alpha=0.4, beta_start=0.4, beta_frames=500_000):
        self.tree = SumTree(capacity)
        self.alpha = alpha
        self.beta_start = beta_start
        self.beta_frames = beta_frames
        self.frame = 1
        self.eps = 1e-6

    def add(self, state, action, reward, next_state, done):
        max_p = np.max(self.tree.tree[-self.tree.capacity:])
        if max_p == 0: max_p = 1.0
        self.tree.add(max_p, (state, action, reward, next_state, done))

    def sample(self, batch_size):
        batch, idxs, priors = [], [], []
        segment = self.tree.total / batch_size
        for i in range(batch_size):
            a, b = segment*i, segment*(i+1)
            s = random.uniform(a, b)
            idx, p, data = self.tree.get(s)
            batch.append(data)
            idxs.append(idx)
            priors.append(p)

        probs = np.array(priors) / self.tree.total
        beta = min(1.0, self.beta_start + (1.0 - self.beta_start)*(self.frame/self.beta_frames))
        self.frame += 1
        weights = (self.tree.n_entries * probs)**(-beta)
        weights /= weights.max()

        S, A, R, S2, D = zip(*batch)
        return (
            torch.stack([preprocess(s) for s in S]),
            torch.tensor(A, dtype=torch.long,   device=device),
            torch.tensor(R, dtype=torch.float,  device=device),
            torch.stack([preprocess(s2) for s2 in S2]),
            torch.tensor(D, dtype=torch.float,  device=device),
            idxs,
            torch.tensor(weights, dtype=torch.float, device=device)
        )

    def update_priorities(self, idxs, td_errors):
        for idx, err in zip(idxs, td_errors):
            p = (abs(err) + self.eps)**self.alpha
            self.tree.update(idx, p)

    def __len__(self):
        return self.tree.n_entries

# --- Now subclass to combine symmetry + PER ---
class SymmetricPERBuffer(PrioritizedReplayBuffer):
    def add(self, state, action, reward, next_state, done):
        # state & next_state are raw board (4×4 int arrays)
        for k in range(4):
            rot_s  = np.rot90(state,  k)
            rot_ns = np.rot90(next_state, k)
            a_rot  = transform_action(action, k, False)
            # append the k-rotation
            super().add(rot_s, a_rot, reward, rot_ns, done)

            # now mirror horizontally
            f_s  = np.fliplr(rot_s)
            f_ns = np.fliplr(rot_ns)
            a_fl = transform_action(action, k, True)
            super().add(f_s, a_fl, reward, f_ns, done)

# --- N-Step Wrapper ---
class NStepWrapper:
    def __init__(self, buffer, n=3, gamma=0.99):
        self.buf, self.n, self.gamma = buffer, n, gamma
        self.nq = deque()

    def add(self, s, a, r, s2, done):
        self.nq.append((s, a, r))
        if len(self.nq) < self.n:
            return
        R = sum(self.nq[i][2] * (self.gamma**i) for i in range(self.n))
        s0, a0, _ = self.nq[0]
        self.buf.add(s0, a0, R, s2, done)
        self.nq.popleft()

    def sample(self, *args): 
        return self.buf.sample(*args)

    def __len__(self):
        return len(self.buf)
    
    def update_priorities(self, idxs, td_errors):
        # delegate priority updates
        self.buf.update_priorities(idxs, td_errors)

# 8) Instantiate nets, optimizer & buffer
# policy_net = DQN().to(device)
# target_net = DQN().to(device)
# target_net.load_state_dict(policy_net.state_dict())
policy_net = DuelingDQN().to(device)
target_net = DuelingDQN().to(device)
target_net.load_state_dict(policy_net.state_dict())
optimizer   = optim.Adam(policy_net.parameters(), lr=LR)
base_buf = SymmetricPERBuffer(capacity=MEMORY_SIZE, alpha=0.6, beta_start=0.3, beta_frames=500_000)
memory   = NStepWrapper(base_buf, n=3, gamma=GAMMA)

# 9) Initialize training state
epsilon             = EPSILON_START
score_history       = []
move_counter_global = defaultdict(int)


In [7]:
# --- 3) Training Loop (with enhanced eval & smoothing) ---
global_max_tile = 0  # track the single largest tile we’ve ever seen (either during training or eval)

for episode in range(MAX_EPISODES):
    obs, _ = env.reset()
    state = preprocess(obs)
    running_score = 0

    # reset per-episode max
    max_tile_in_episode = int(np.max(decode_obs(obs)))

    # rolling‐window train‐score smoothing
    if len(score_history) >= 50:
        ma50 = np.mean(score_history[-50:])
    else:
        ma50 = np.mean(score_history) if score_history else 0

    while True:
        # select action
        action = sample_action(state, obs, explore=True)
        if action is None:
            break

        move_counter_global[ACTION_MAP[action]] += 1

        # step env
        next_obs, reward, term, trunc, info = env.step(action)
        done = term or trunc
        next_state = preprocess(next_obs)

        # let the buffer do its 8‐way augmentation
        memory.add(decode_obs(obs), action, reward, decode_obs(next_obs), done)

        state, obs = next_state, next_obs
        running_score += int(reward)

        # update episode max
        t = int(np.max(decode_obs(next_obs)))
        max_tile_in_episode = max(max_tile_in_episode, t)

        # also update the global max
        global_max_tile = max(global_max_tile, max_tile_in_episode)

        # optimize
        if len(memory) > WARMUP_MEMORY:
            s_b, a_b, r_b, ns_b, d_b, idxs, is_weights = memory.sample(BATCH_SIZE)

            # current Q and next Q
            q = policy_net(s_b).gather(1, a_b.unsqueeze(1)).squeeze()
            
            # 1) select next‐actions via the *policy* net
            next_actions = policy_net(ns_b).argmax(dim=1, keepdim=True)      # shape [B,1]
            # 2) evaluate those actions under the *target* net
            q_next = target_net(ns_b).gather(1, next_actions).squeeze()     # shape [B]
            # 3) standard TD target
            target = r_b + GAMMA * q_next * (1 - d_b)

            # compute TD-errors for priority update
            td_errors = (q - target).detach().abs().cpu().numpy()

            # weighted MSE loss
            loss = (is_weights * (q - target)**2).mean()

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(policy_net.parameters(), 1.0)
            optimizer.step()

            # update priorities in the tree
            memory.update_priorities(idxs, td_errors)

        if done:
            break

    # end‐of‐episode housekeeping
    if episode % TARGET_UPDATE_FREQ == 0:
        target_net.load_state_dict(policy_net.state_dict())
    epsilon = max(EPSILON_END, epsilon * EPSILON_DECAY)
    score_history.append(running_score)

    # 1) Training print + 50‐episode MA
    print(f"Episode {episode:4d} | Score {running_score:5d} | MaxTile {max_tile_in_episode:4d} | MA50 {ma50:6.1f} | ε {epsilon:.3f}")

    # 2) Enhanced periodic evaluation every 100 episodes
    if episode > 0 and episode % 100 == 0:
        eval_scores, tiles, move_counter_eval = [], [], defaultdict(int)
        for _ in range(100):  # run 100 greedy games
            o, _ = env.reset()
            s = preprocess(o)
            total = 0
            while True:
                a = sample_action(s, o, explore=False)
                if a is None: break
                o, r, term, trunc, _ = env.step(a)
                s = preprocess(o)
                total += int(r)
                if term or trunc: break
            eval_scores.append(total)
            tile = int(np.max(decode_obs(o)))  # true max tile
            tiles.append(tile)
            move_counter_eval[ACTION_MAP[a]] += 1

        # compute percentiles
        p25, p50, p75 = np.percentile(tiles, [25, 50, 75])

        print(f"  → Eval @ ep {episode}: Avg {np.mean(eval_scores):.1f}, Max {np.max(eval_scores)}")
        print(f"     Tile p25/p50/p75: {p25}/{p50}/{p75}")
        # optional full histogram
        hist = {t: tiles.count(t) for t in sorted(set(tiles))}
        print(f"     Tiles hist: {hist}")

    # 3) Convergence check
    if (episode >= ROLLING_WINDOW and 
        np.mean(score_history[-ROLLING_WINDOW:]) >= CONVERGENCE_SCORE):
        print(f"✅ Converged at ep {episode}")
        break

# --- 4) Final Evaluation & Logs ---
print("\n=== Move Frequencies Overall ===")
for m in ACTION_MAP.values():
    print(f"{m}: {move_counter_global[m]}")

eval_scores, tile_hist, move_counter_eval = [], Counter(), defaultdict(int)
for _ in range(100):
    obs, _ = env.reset()
    state = preprocess(obs)
    total = 0
    while True:
        a = sample_action(state, obs, explore=False)
        if a is None:
            break
        move_counter_eval[ACTION_MAP[a]] += 1
        old_board = decode_obs(obs)
        obs, r, term, trunc, _ = env.step(a)
        state = preprocess(obs)
        total += int(r)
        if term or trunc or np.array_equal(decode_obs(obs), old_board):
            break
    eval_scores.append(total)
    tile_hist[int(np.max(decode_obs(obs)))] += 1

# you already have tile_hist mapping each game’s max-tile to counts
eval_max_tile = max(tile_hist.keys()) if tile_hist else 0
global_max_tile = max(global_max_tile, eval_max_tile)

print(f"\n=== FINAL ===")
print(f"Highest tile over everything: {global_max_tile}")
print(f"Eval Avg: {np.mean(eval_scores):.2f}  |  Max Score: {np.max(eval_scores)}")
print("Eval Tiles:", tile_hist)
print("Eval Moves:", move_counter_eval)

  self.total_score += self.step_score


Episode    0 | Score   532 | MaxTile   64 | MA50    0.0 | ε 1.000
Episode    1 | Score   404 | MaxTile   32 | MA50  532.0 | ε 0.999
Episode    2 | Score  1052 | MaxTile  128 | MA50  468.0 | ε 0.999
Episode    3 | Score  1312 | MaxTile  128 | MA50  662.7 | ε 0.998
Episode    4 | Score   976 | MaxTile   64 | MA50  825.0 | ε 0.998
Episode    5 | Score  1904 | MaxTile  128 | MA50  855.2 | ε 0.997
Episode    6 | Score  1308 | MaxTile  128 | MA50 1030.0 | ε 0.997
Episode    7 | Score  1144 | MaxTile  128 | MA50 1069.7 | ε 0.996
Episode    8 | Score   536 | MaxTile   64 | MA50 1079.0 | ε 0.996
Episode    9 | Score   652 | MaxTile   64 | MA50 1018.7 | ε 0.995
Episode   10 | Score   644 | MaxTile   64 | MA50  982.0 | ε 0.995


  board = np.where(state==0,0,np.log2(state)).astype(np.float32)


Episode   11 | Score  1416 | MaxTile  128 | MA50  951.3 | ε 0.994
Episode   12 | Score  1460 | MaxTile  128 | MA50  990.0 | ε 0.994
Episode   13 | Score   680 | MaxTile   64 | MA50 1026.2 | ε 0.993
Episode   14 | Score   668 | MaxTile   64 | MA50 1001.4 | ε 0.993
Episode   15 | Score  3240 | MaxTile  256 | MA50  979.2 | ε 0.992
Episode   16 | Score   676 | MaxTile   64 | MA50 1120.5 | ε 0.992
Episode   17 | Score  1280 | MaxTile  128 | MA50 1094.4 | ε 0.991
Episode   18 | Score   880 | MaxTile   64 | MA50 1104.7 | ε 0.991
Episode   19 | Score  1072 | MaxTile  128 | MA50 1092.8 | ε 0.990
Episode   20 | Score  2028 | MaxTile  128 | MA50 1091.8 | ε 0.990
Episode   21 | Score   892 | MaxTile   64 | MA50 1136.4 | ε 0.989
Episode   22 | Score  2040 | MaxTile  256 | MA50 1125.3 | ε 0.989
Episode   23 | Score  1320 | MaxTile  128 | MA50 1165.0 | ε 0.988
Episode   24 | Score   636 | MaxTile   64 | MA50 1171.5 | ε 0.988
Episode   25 | Score   656 | MaxTile   64 | MA50 1150.1 | ε 0.987
Episode   

In [None]:
du_ma = np.convolve(dueling_scores, np.ones(100)/100, mode='valid')
print(du_ma.max())