In [9]:
import gymnasium as gym
import gymnasium_2048
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical
from collections import defaultdict, Counter, deque

# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
env = gym.make("gymnasium_2048/TwentyFortyEight-v0")

# --- Hyperparameters ---
GAMMA = 0.99                   # Discount factor
LAMBDA = 0.95                  # GAE lambda
CLIP_EPS = 0.2                 # PPO clip epsilon
LR = 3e-4                      # Learning rate
PPO_EPOCHS = 4                 # PPO epochs per update
ROLLOUT_STEPS = 2048           # Steps per rollout
MAX_ITERATIONS = 10000         # Max training iterations
ROLLING_AVG_WINDOW = 100       # Window size for rolling average
ROLLING_AVG_THRESHOLD = 2000   # Threshold to stop early
DEBUG_TRAIN_ITERS = 0          # Number of debug iterations

ACTION_MAP = {0: "Up", 1: "Right", 2: "Down", 3: "Left"}

# --- Preprocessing ---
def preprocess(state):
    flat = np.reshape(state, -1)
    return torch.tensor(flat, dtype=torch.float32).to(device)

# --- Board Decoding ---
def decode_board(obs):
    """Convert observation to 4x4 board of tile values."""
    if obs.ndim == 3:
        idxs = np.argmax(obs, axis=-1)
        mask = (obs.sum(axis=-1) == 1)
        board = (2 ** idxs) * mask
    else:
        board = obs
    return board

# --- Actor-Critic Network ---
class ActorCritic(nn.Module):
    def __init__(self):
        super().__init__()
        self.shared = nn.Sequential(
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU()
        )
        self.policy = nn.Linear(64, 4)
        self.value = nn.Linear(64, 1)

    def forward(self, x):
        x = self.shared(x)
        return self.policy(x), self.value(x)

# Initialize model and optimizer
model = ActorCritic().to(device)
optimizer = optim.Adam(model.parameters(), lr=LR)

# --- Trackers ---
global_move_counter = defaultdict(int)
bucket_move_counters = []
score_history = deque(maxlen=ROLLING_AVG_WINDOW)

# --- Collect Rollout ---
def collect_rollout(debug=False):
    states, actions, rewards, dones, log_probs, values = [], [], [], [], [], []
    local_move_counter = defaultdict(int)
    episode_scores = []
    episode_debug_info = []

    obs, _ = env.reset()
    current_score = 0.0
    current_episode_steps = []

    for _ in range(ROLLOUT_STEPS):
        if debug:
            current_episode_steps.append({
                'state': obs.copy(),
                'score_before': current_score
            })

        # act
        state = preprocess(obs)
        logits, value = model(state)
        dist = Categorical(logits=logits)
        action = dist.sample()
        log_prob = dist.log_prob(action)

        # step
        next_obs, step_reward, terminated, truncated, _ = env.step(action.item())
        done = terminated or truncated
        current_score += step_reward

        move_name = ACTION_MAP[action.item()]
        local_move_counter[move_name] += 1

        # store
        states.append(state)
        actions.append(action)
        # **use step reward here**:
        rewards.append(torch.tensor(step_reward, dtype=torch.float32))
        dones.append(torch.tensor(done, dtype=torch.float32))
        log_probs.append(log_prob)
        values.append(value.squeeze())

        if debug:
            info = current_episode_steps[-1]
            info.update({
                'action': move_name,
                'step_reward': step_reward,
                'state_after': next_obs.copy(),
                'score_after': current_score
            })

        # episode termination
        if done:
            episode_scores.append(current_score)
            if debug and not episode_debug_info:
                episode_debug_info = current_episode_steps.copy()
            obs, _ = env.reset()
            current_score = 0.0
            current_episode_steps = []
        else:
            obs = next_obs

    # if no episode finished in this rollout, still record the partial score
    if not episode_scores:
        episode_scores.append(current_score)

    return states, actions, rewards, dones, log_probs, values, local_move_counter, episode_scores, episode_debug_info


# --- Compute GAE and Returns ---
def compute_advantages(rewards, values, dones):
    advantages = []
    gae = 0
    values = values + [torch.tensor(0.0).to(device)]
    for t in reversed(range(len(rewards))):
        delta = rewards[t] + GAMMA * values[t+1] * (1 - dones[t]) - values[t]
        gae = delta + GAMMA * LAMBDA * (1 - dones[t]) * gae
        advantages.insert(0, gae)
    returns = [adv + val for adv, val in zip(advantages, values[:-1])]
    return advantages, returns

# --- Training Loop ---
for iteration in range(MAX_ITERATIONS):
    debug = iteration < DEBUG_TRAIN_ITERS
    data = collect_rollout(debug)
    states, actions, rewards, dones, old_log_probs, values, local_move_counter, episode_scores, debug_info = data

    # Update move counters
    for move, cnt in local_move_counter.items():
        global_move_counter[move] += cnt
    if (iteration + 1) % (MAX_ITERATIONS // 10) == 0:
        bucket_move_counters.append(local_move_counter.copy())

    # Compute advantages and returns
    advantages, returns = compute_advantages(rewards, values, dones)
    states = torch.stack(states)
    actions = torch.stack(actions)
    old_log_probs = torch.stack(old_log_probs).detach()
    advantages = torch.stack(advantages).detach()
    returns = torch.stack(returns).detach()
    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

    # PPO updates
    for _ in range(PPO_EPOCHS):
        logits, value = model(states)
        dist = Categorical(logits=logits)
        new_log_probs = dist.log_prob(actions)
        entropy = dist.entropy().mean()
        ratio = (new_log_probs - old_log_probs).exp()
        surr1 = ratio * advantages
        surr2 = torch.clamp(ratio, 1 - CLIP_EPS, 1 + CLIP_EPS) * advantages
        policy_loss = -torch.min(surr1, surr2).mean()
        value_loss = (returns - value.squeeze()).pow(2).mean()
        loss = policy_loss + 0.5 * value_loss - 0.01 * entropy
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    last_score = episode_scores[-1] if episode_scores else 0.0
    score_history.append(last_score)
    print(f"Iter {iteration+1}/{MAX_ITERATIONS}, Last Score: {last_score:.2f}")

    # Early stopping
    if len(score_history) == ROLLING_AVG_WINDOW:
        rolling_avg = sum(score_history) / ROLLING_AVG_WINDOW
        if rolling_avg >= ROLLING_AVG_THRESHOLD:
            print(f"Stopping early at iter {iteration+1}: rolling avg = {rolling_avg:.2f}")
            break

# --- Final Move Summary ---
print("\n=== Final Move Frequency (Total) ===")
for move in ACTION_MAP.values():
    print(f"{move}: {global_move_counter[move]}")

print("\n=== Move Distribution Per 10% of Training ===")
for i, snap in enumerate(bucket_move_counters):
    start = i * (MAX_ITERATIONS // 10) + 1
    end = (i + 1) * (MAX_ITERATIONS // 10)
    print(f"\n--- Iterations {start} to {end} ---")
    prev = bucket_move_counters[i-1] if i>0 else None
    diff = {move: snap.get(move,0) - (prev.get(move,0) if prev else 0) for move in ACTION_MAP.values()}
    for move, cnt in diff.items():
        print(f"{move}: {cnt}")

# --- Evaluation ---
print("\n🔍 Evaluating over 100 episodes...")
eval_scores, eval_move_counter, tile_counts = [], defaultdict(int), Counter()
for ep in range(100):
    obs, _ = env.reset()
    total_reward = 0
    done = False
    while not done:
        state = preprocess(obs)
        logits, _ = model(state)
        action = Categorical(logits=logits).probs.argmax().item()
        move = ACTION_MAP[action]
        eval_move_counter[move] += 1
        obs, reward, terminated, truncated, _ = env.step(action)
        total_reward += reward
        done = terminated or truncated
    eval_scores.append(total_reward)
    max_tile = decode_board(obs).max()
    tile_counts[max_tile] += 1

print("\n✅ Eval Results:")
print(f"Avg Score: {np.mean(eval_scores):.2f}, Max Score: {np.max(eval_scores):.2f}")
print("\nMax Tile Frequencies:")
for tile, cnt in sorted(tile_counts.items(), reverse=True):
    print(f"{tile}: {cnt} times")
print("\nEval Move Distribution:")
for move in ACTION_MAP.values():
    print(f"{move}: {eval_move_counter[move]}")


Iter 1/10000, Last Score: 1084.00
Iter 2/10000, Last Score: 1168.00
Iter 3/10000, Last Score: 464.00
Iter 4/10000, Last Score: 1436.00
Iter 5/10000, Last Score: 640.00
Iter 6/10000, Last Score: 180.00
Iter 7/10000, Last Score: 300.00
Iter 8/10000, Last Score: 596.00
Iter 9/10000, Last Score: 992.00
Iter 10/10000, Last Score: 1276.00
Iter 11/10000, Last Score: 840.00
Iter 12/10000, Last Score: 1344.00
Iter 13/10000, Last Score: 580.00
Iter 14/10000, Last Score: 920.00
Iter 15/10000, Last Score: 1016.00
Iter 16/10000, Last Score: 1468.00
Iter 17/10000, Last Score: 1340.00
Iter 18/10000, Last Score: 780.00
Iter 19/10000, Last Score: 1140.00
Iter 20/10000, Last Score: 1356.00
Iter 21/10000, Last Score: 824.00
Iter 22/10000, Last Score: 1360.00
Iter 23/10000, Last Score: 744.00
Iter 24/10000, Last Score: 584.00
Iter 25/10000, Last Score: 1308.00
Iter 26/10000, Last Score: 612.00
Iter 27/10000, Last Score: 1440.00
Iter 28/10000, Last Score: 568.00
Iter 29/10000, Last Score: 636.00
Iter 30/10

  score += 2 ** (board[row, col] + 1)


Iter 197/10000, Last Score: 1480.00
Iter 198/10000, Last Score: 1356.00
Iter 199/10000, Last Score: 1084.00
Iter 200/10000, Last Score: 1668.00
Iter 201/10000, Last Score: 732.00
Iter 202/10000, Last Score: 1432.00
Iter 203/10000, Last Score: 1516.00
Iter 204/10000, Last Score: 1748.00
Iter 205/10000, Last Score: 1960.00
Iter 206/10000, Last Score: 2808.00
Iter 207/10000, Last Score: 1028.00
Iter 208/10000, Last Score: 2960.00
Iter 209/10000, Last Score: 2196.00
Iter 210/10000, Last Score: 3024.00
Iter 211/10000, Last Score: 416.00
Iter 212/10000, Last Score: 1240.00
Iter 213/10000, Last Score: 2200.00
Iter 214/10000, Last Score: 4276.00
Iter 215/10000, Last Score: 1156.00
Iter 216/10000, Last Score: 872.00
Iter 217/10000, Last Score: 1652.00
Iter 218/10000, Last Score: 1252.00
Iter 219/10000, Last Score: 2452.00
Iter 220/10000, Last Score: 1068.00
Iter 221/10000, Last Score: 408.00
Iter 222/10000, Last Score: 1332.00
Iter 223/10000, Last Score: 1328.00
Iter 224/10000, Last Score: 1124

KeyboardInterrupt: 

In [10]:
import gymnasium as gym
import gymnasium_2048
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical
from collections import defaultdict, Counter, deque

# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
env = gym.make("gymnasium_2048/TwentyFortyEight-v0")

# --- Hyperparameters ---
GAMMA = 0.99                   # Discount factor
LAMBDA = 0.95                  # GAE lambda
CLIP_EPS = 0.2                 # PPO clip epsilon
LR = 3e-4                      # Learning rate
PPO_EPOCHS = 4                 # PPO epochs per update
ROLLOUT_STEPS = 2048           # Steps per rollout
MAX_ITERATIONS = 10000         # Max training iterations
ROLLING_AVG_WINDOW = 100       # Window size for rolling average
ROLLING_AVG_THRESHOLD = 10000   # Threshold to stop early
DEBUG_TRAIN_ITERS = 0          # Number of debug iterations
EVAL_ITERS = 50

ACTION_MAP = {0: "Up", 1: "Right", 2: "Down", 3: "Left"}

# --- Preprocessing ---
def preprocess(state):
    flat = np.reshape(state, -1)
    return torch.tensor(flat, dtype=torch.float32).to(device)

# --- Board Decoding ---
def decode_board(obs):
    """Convert observation to 4x4 board of tile values."""
    if obs.ndim == 3:
        idxs = np.argmax(obs, axis=-1)
        mask = (obs.sum(axis=-1) == 1)
        board = (2 ** idxs) * mask
    else:
        board = obs
    return board

# --- Actor-Critic Network ---
class ActorCritic(nn.Module):
    def __init__(self):
        super().__init__()
        self.shared = nn.Sequential(
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU()
        )
        self.policy = nn.Linear(64, 4)
        self.value = nn.Linear(64, 1)

    def forward(self, x):
        x = self.shared(x)
        return self.policy(x), self.value(x)

# Initialize model and optimizer
model = ActorCritic().to(device)
optimizer = optim.Adam(model.parameters(), lr=LR)

# --- Trackers ---
global_move_counter = defaultdict(int)
bucket_move_counters = []
score_history = deque(maxlen=ROLLING_AVG_WINDOW)

# --- Collect Rollout ---
def collect_rollout(debug=False):
    states, actions, rewards, dones, log_probs, values = [], [], [], [], [], []
    local_move_counter = defaultdict(int)
    episode_scores = []
    episode_debug_info = []

    obs, _ = env.reset()
    current_score = 0.0
    current_episode_steps = []

    for _ in range(ROLLOUT_STEPS):
        if debug:
            current_episode_steps.append({
                'state': obs.copy(),
                'score_before': current_score
            })

        # act
        state = preprocess(obs)
        logits, value = model(state)
        dist = Categorical(logits=logits)
        action = dist.sample()
        log_prob = dist.log_prob(action)

        # step
        next_obs, step_reward, terminated, truncated, _ = env.step(action.item())
        done = terminated or truncated
        current_score += step_reward

        move_name = ACTION_MAP[action.item()]
        local_move_counter[move_name] += 1

        # store
        states.append(state)
        actions.append(action)
        # **use step reward here**:
        rewards.append(torch.tensor(step_reward, dtype=torch.float32))
        dones.append(torch.tensor(done, dtype=torch.float32))
        log_probs.append(log_prob)
        values.append(value.squeeze())

        if debug:
            info = current_episode_steps[-1]
            info.update({
                'action': move_name,
                'step_reward': step_reward,
                'state_after': next_obs.copy(),
                'score_after': current_score
            })

        # episode termination
        if done:
            episode_scores.append(current_score)
            if debug and not episode_debug_info:
                episode_debug_info = current_episode_steps.copy()
            obs, _ = env.reset()
            current_score = 0.0
            current_episode_steps = []
        else:
            obs = next_obs

    # if no episode finished in this rollout, still record the partial score
    if not episode_scores:
        episode_scores.append(current_score)

    return states, actions, rewards, dones, log_probs, values, local_move_counter, episode_scores, episode_debug_info


# --- Compute GAE and Returns ---
def compute_advantages(rewards, values, dones):
    advantages = []
    gae = 0
    values = values + [torch.tensor(0.0).to(device)]
    for t in reversed(range(len(rewards))):
        delta = rewards[t] + GAMMA * values[t+1] * (1 - dones[t]) - values[t]
        gae = delta + GAMMA * LAMBDA * (1 - dones[t]) * gae
        advantages.insert(0, gae)
    returns = [adv + val for adv, val in zip(advantages, values[:-1])]
    return advantages, returns

# --- Training Loop ---
for iteration in range(MAX_ITERATIONS):
    debug = iteration < DEBUG_TRAIN_ITERS
    data = collect_rollout(debug)
    states, actions, rewards, dones, old_log_probs, values, local_move_counter, episode_scores, debug_info = data

    # Update move counters
    for move, cnt in local_move_counter.items():
        global_move_counter[move] += cnt
    if (iteration + 1) % (MAX_ITERATIONS // 10) == 0:
        bucket_move_counters.append(local_move_counter.copy())

    # Compute advantages and returns
    advantages, returns = compute_advantages(rewards, values, dones)
    states = torch.stack(states)
    actions = torch.stack(actions)
    old_log_probs = torch.stack(old_log_probs).detach()
    advantages = torch.stack(advantages).detach()
    returns = torch.stack(returns).detach()
    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

    # PPO updates
    for _ in range(PPO_EPOCHS):
        logits, value = model(states)
        dist = Categorical(logits=logits)
        new_log_probs = dist.log_prob(actions)
        entropy = dist.entropy().mean()
        ratio = (new_log_probs - old_log_probs).exp()
        surr1 = ratio * advantages
        surr2 = torch.clamp(ratio, 1 - CLIP_EPS, 1 + CLIP_EPS) * advantages
        policy_loss = -torch.min(surr1, surr2).mean()
        value_loss = (returns - value.squeeze()).pow(2).mean()
        loss = policy_loss + 0.5 * value_loss - 0.01 * entropy
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    last_score = episode_scores[-1] if episode_scores else 0.0
    score_history.append(last_score)
    print(f"Iter {iteration+1}/{MAX_ITERATIONS}, Last Score: {last_score:.2f}")

    # Early stopping
    if len(score_history) == ROLLING_AVG_WINDOW:
        rolling_avg = sum(score_history) / ROLLING_AVG_WINDOW
        if rolling_avg >= ROLLING_AVG_THRESHOLD:
            print(f"Stopping early at iter {iteration+1}: rolling avg = {rolling_avg:.2f}")
            break




  self.total_score += self.step_score


Iter 1/10000, Last Score: 936.00
Iter 2/10000, Last Score: 1320.00
Iter 3/10000, Last Score: 924.00
Iter 4/10000, Last Score: 1416.00
Iter 5/10000, Last Score: 1368.00
Iter 6/10000, Last Score: 2640.00
Iter 7/10000, Last Score: 820.00
Iter 8/10000, Last Score: 684.00
Iter 9/10000, Last Score: 904.00
Iter 10/10000, Last Score: 636.00
Iter 11/10000, Last Score: 1040.00
Iter 12/10000, Last Score: 2176.00
Iter 13/10000, Last Score: 1352.00
Iter 14/10000, Last Score: 1004.00
Iter 15/10000, Last Score: 924.00
Iter 16/10000, Last Score: 708.00
Iter 17/10000, Last Score: 1024.00
Iter 18/10000, Last Score: 276.00
Iter 19/10000, Last Score: 1108.00
Iter 20/10000, Last Score: 2672.00
Iter 21/10000, Last Score: 1080.00
Iter 22/10000, Last Score: 1324.00
Iter 23/10000, Last Score: 812.00
Iter 24/10000, Last Score: 532.00
Iter 25/10000, Last Score: 580.00
Iter 26/10000, Last Score: 868.00
Iter 27/10000, Last Score: 652.00
Iter 28/10000, Last Score: 520.00
Iter 29/10000, Last Score: 2120.00
Iter 30/1

  score += 2 ** (board[row, col] + 1)


Iter 2633/10000, Last Score: 4760.00
Iter 2634/10000, Last Score: 756.00
Iter 2635/10000, Last Score: 3396.00
Iter 2636/10000, Last Score: 1664.00
Iter 2637/10000, Last Score: 3212.00
Iter 2638/10000, Last Score: 1760.00
Iter 2639/10000, Last Score: 1420.00
Iter 2640/10000, Last Score: 1692.00
Iter 2641/10000, Last Score: 1340.00
Iter 2642/10000, Last Score: 3672.00
Iter 2643/10000, Last Score: 1492.00
Iter 2644/10000, Last Score: 1748.00
Iter 2645/10000, Last Score: 2480.00
Iter 2646/10000, Last Score: 2252.00
Iter 2647/10000, Last Score: 860.00
Iter 2648/10000, Last Score: 1592.00
Iter 2649/10000, Last Score: 2092.00
Iter 2650/10000, Last Score: 2492.00
Iter 2651/10000, Last Score: 496.00
Iter 2652/10000, Last Score: 2656.00
Iter 2653/10000, Last Score: 3080.00
Iter 2654/10000, Last Score: 2824.00
Iter 2655/10000, Last Score: 1148.00
Iter 2656/10000, Last Score: 2836.00
Iter 2657/10000, Last Score: 2856.00
Iter 2658/10000, Last Score: 2572.00
Iter 2659/10000, Last Score: 612.00
Iter 

KeyboardInterrupt: 

In [None]:
# --- Final Move Summary ---
print("\n=== Final Move Frequency (Total) ===")
for move in ACTION_MAP.values():
    print(f"{move}: {global_move_counter[move]}")

print("\n=== Move Distribution Per 10% of Training ===")
for i, snap in enumerate(bucket_move_counters):
    start = i * (MAX_ITERATIONS // 10) + 1
    end = (i + 1) * (MAX_ITERATIONS // 10)
    print(f"\n--- Iterations {start} to {end} ---")
    prev = bucket_move_counters[i-1] if i>0 else None
    diff = {move: snap.get(move,0) - (prev.get(move,0) if prev else 0) for move in ACTION_MAP.values()}
    for move, cnt in diff.items():
        print(f"{move}: {cnt}")

# --- Evaluation ---
print(f"\n🔍 Evaluating over {EVAL_ITERS} episodes…")
eval_scores = []
eval_move_counter = defaultdict(int)
tile_counts = Counter()

for ep in range(EVAL_ITERS):
    obs, _ = env.reset()
    total_reward = 0
    done = False

    while not done:
        state = preprocess(obs)
        logits, _ = model(state)
        dist = Categorical(logits=logits)

        # get probabilities and sort actions by descending prob
        probs = dist.probs.detach().cpu().numpy()
        candidates = list(np.argsort(probs)[::-1])

        # try each candidate at most once
        step_taken = False
        for action in candidates:
            next_obs, reward, terminated, truncated, info = env.step(action)
            if not info.get('invalid_move', False):
                step_taken = True
                break
            # invalid move → env state unchanged, so safe to try next
        if not step_taken:
            # fallback to one sampled action
            action = int(dist.sample().item())
            next_obs, reward, terminated, truncated, info = env.step(action)

        # record results
        move = ACTION_MAP[action]
        eval_move_counter[move] += 1
        total_reward += reward
        done = terminated or truncated
        obs = next_obs

    eval_scores.append(total_reward)
    final_board = decode_board(obs)
    max_tile = final_board.max()
    tile_counts[max_tile] += 1

print("\n✅ Eval Results:")
print(f"Avg Score: {np.mean(eval_scores):.2f}, Max Score: {np.max(eval_scores):.2f}")
print("\nMax Tile Frequencies:")
for tile, cnt in sorted(tile_counts.items(), reverse=True):
    print(f"{tile}: {cnt} times")   
print("\nEval Move Distribution:")
for move in ACTION_MAP.values():
    print(f"{move}: {eval_move_counter[move]}")


=== Final Move Frequency (Total) ===
Up: 5562560
Right: 9344025
Down: 5459288
Left: 114127

=== Move Distribution Per 10% of Training ===

--- Iterations 1 to 1000 ---
Up: 547
Right: 1001
Down: 494
Left: 6

--- Iterations 1001 to 2000 ---
Up: -148
Right: 197
Down: -45
Left: -4

--- Iterations 2001 to 3000 ---
Up: 141
Right: -40
Down: -100
Left: -1

--- Iterations 3001 to 4000 ---
Up: 71
Right: -219
Down: 149
Left: -1

--- Iterations 4001 to 5000 ---
Up: 176
Right: -146
Down: -30
Left: 0

--- Iterations 5001 to 6000 ---
Up: -99
Right: 37
Down: 62
Left: 0

--- Iterations 6001 to 7000 ---
Up: -30
Right: -29
Down: 59
Left: 0

--- Iterations 7001 to 8000 ---
Up: -227
Right: 57
Down: 169
Left: 1

--- Iterations 8001 to 9000 ---
Up: 478
Right: -304
Down: -182
Left: 8

--- Iterations 9001 to 10000 ---
Up: -231
Right: 316
Down: -84
Left: -1

🔍 Evaluating over 50 episodes…
