In [1]:
import gymnasium as gym
import gymnasium_2048
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical
from collections import defaultdict, Counter

# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
env = gym.make("gymnasium_2048/TwentyFortyEight-v0")

# --- Hyperparameters ---
GAMMA = 0.99            # discount factor
LAMBDA = 0.95           # GAE lambda (mixes the multi-step advantage estimates)
CLIP_EPS = 0.2          # PPO probability-ratio clipping range
LR = 3e-4               # Adam learning rate
PPO_EPOCHS = 4          # update passes over each rollout
ROLLOUT_STEPS = 2048    # environment steps collected per iteration
TOTAL_ITERATIONS = 200  # number of rollout + update cycles

# Discrete action index -> human-readable move name.
ACTION_MAP = dict(enumerate(("Up", "Right", "Down", "Left")))
# Cumulative count of every move chosen during training, keyed by move name.
global_move_counter = defaultdict(int)
# Periodic snapshots of global_move_counter appended by the training loop.
bucket_move_counters = []

# --- Preprocessing ---
def preprocess(state):
    """Flatten an observation into a 1-D float32 tensor placed on ``device``.

    The array is flattened in C order; the values are copied, not shared.
    """
    as_vector = np.ravel(state)
    return torch.tensor(as_vector, dtype=torch.float32, device=device)

# --- Actor-Critic Network ---
class ActorCritic(nn.Module):
    """Actor-critic MLP with a shared trunk and two linear heads.

    Trunk: Linear(256, 128) -> ReLU -> Linear(128, 64) -> ReLU.
    Heads: ``policy`` maps the 64 trunk features to 4 action logits and
    ``value`` maps them to a single state-value estimate.
    """

    def __init__(self):
        super().__init__()
        trunk_layers = [
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
        ]
        self.shared = nn.Sequential(*trunk_layers)
        self.policy = nn.Linear(64, 4)
        self.value = nn.Linear(64, 1)

    def forward(self, x):
        """Return ``(logits, value)`` for ``x``, whose last dimension has 256 features.

        Leading (batch) dimensions are preserved; the last output dimension is
        4 for the logits and 1 for the value.
        """
        features = self.shared(x)
        logits = self.policy(features)
        state_value = self.value(features)
        return logits, state_value

# One network holds both heads; a single Adam optimizer updates all of it.
model = ActorCritic()
model = model.to(device)
optimizer = optim.Adam(params=model.parameters(), lr=LR)

# --- Collect Rollout ---
def collect_rollout():
    """Run the current policy for ROLLOUT_STEPS environment steps.

    A fresh episode is started on entry (``env.reset()``) and again whenever an
    episode terminates or is truncated. A single rollout can therefore span
    several episodes, and its final episode may be cut off mid-game.

    Side effects: steps/resets the module-level ``env`` and increments the
    module-level ``global_move_counter``.

    Returns:
        A 7-tuple ``(states, actions, rewards, dones, log_probs, values,
        local_move_counter)``. The first six are per-step lists of tensors;
        ``dones`` holds 1.0 where the step ended the episode (terminated or
        truncated). ``local_move_counter`` counts this rollout's moves by name.
    """
    states, actions, rewards, dones, log_probs, values = [], [], [], [], [], []
    local_move_counter = defaultdict(int)
    obs, _ = env.reset()
    # Sampling is pure inference. Without no_grad, every forward pass would
    # keep its autograd graph alive (through log_probs/values) for the whole
    # rollout, even though the training loop detaches these tensors anyway.
    with torch.no_grad():
        for _ in range(ROLLOUT_STEPS):
            state = preprocess(obs)
            logits, value = model(state)
            dist = Categorical(logits=logits)
            action = dist.sample()
            log_prob = dist.log_prob(action)
            action_id = action.item()

            next_obs, reward, terminated, truncated, _ = env.step(action_id)
            done = terminated or truncated

            move_name = ACTION_MAP[action_id]
            global_move_counter[move_name] += 1
            local_move_counter[move_name] += 1

            states.append(state)
            actions.append(action)
            rewards.append(torch.tensor(reward, dtype=torch.float32, device=device))
            dones.append(torch.tensor(done, dtype=torch.float32, device=device))
            log_probs.append(log_prob)
            values.append(value.squeeze())

            obs = next_obs if not done else env.reset()[0]
    return states, actions, rewards, dones, log_probs, values, local_move_counter

# --- Compute GAE and Returns ---
def compute_advantages(rewards, values, dones, last_value=None, gamma=None, lam=None):
    """Compute GAE advantages and value targets for one rollout.

    Args:
        rewards: Per-step reward tensors.
        values: Per-step critic estimates V(s_t), aligned with ``rewards``.
        dones: Per-step 1.0/0.0 tensors. A 1.0 stops bootstrapping and GAE
            accumulation across that step.
        last_value: Critic estimate for the state following the final step,
            used to bootstrap when the rollout ends mid-episode. Defaults to 0,
            which treats the rollout cut-off as if the episode had ended.
        gamma: Discount factor. Defaults to the module-level ``GAMMA``.
        lam: GAE lambda. Defaults to the module-level ``LAMBDA``.

    Returns:
        ``(advantages, returns)``: two lists aligned with ``rewards``, where
        ``returns[t] = advantages[t] + values[t]``.
    """
    if not rewards:
        return [], []
    gamma = GAMMA if gamma is None else gamma
    lam = LAMBDA if lam is None else lam

    # Build the bootstrap value with the same dtype/device as the critic outputs.
    ref = torch.as_tensor(values[0])
    next_value = torch.as_tensor(
        0.0 if last_value is None else last_value, dtype=ref.dtype, device=ref.device
    )

    advantages = []
    gae = 0.0
    # Walk backwards so each step reuses the discounted GAE of its successor.
    for t in reversed(range(len(rewards))):
        not_done = 1 - dones[t]
        delta = rewards[t] + gamma * next_value * not_done - values[t]
        gae = delta + gamma * lam * not_done * gae
        advantages.append(gae)
        next_value = values[t]
    # Append-then-reverse is O(n); repeated list.insert(0, ...) is O(n^2).
    advantages.reverse()

    returns = [adv + val for adv, val in zip(advantages, values)]
    return advantages, returns

# --- Training Loop ---
# Snapshot the cumulative move counts every 10% of training. The interval is
# clamped to 1 so TOTAL_ITERATIONS < 10 does not raise ZeroDivisionError in
# the modulo below.
snapshot_every = max(1, TOTAL_ITERATIONS // 10)

for iteration in range(TOTAL_ITERATIONS):
    # Gather one on-policy batch and turn it into GAE advantages / value targets.
    states, actions, rewards, dones, old_log_probs, values, local_move_counter = collect_rollout()
    advantages, returns = compute_advantages(rewards, values, dones)

    states = torch.stack(states)
    actions = torch.stack(actions)
    # Rollout-time quantities are fixed targets during the update epochs.
    old_log_probs = torch.stack(old_log_probs).detach()
    advantages = torch.stack(advantages).detach()
    returns = torch.stack(returns).detach()

    # Standardize advantages over the batch; the epsilon guards against a zero std.
    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

    # Full-batch PPO update, repeated PPO_EPOCHS times on the same rollout.
    for _ in range(PPO_EPOCHS):
        logits, value = model(states)
        dist = Categorical(logits=logits)
        new_log_probs = dist.log_prob(actions)
        entropy = dist.entropy().mean()

        # Clipped surrogate objective.
        ratio = (new_log_probs - old_log_probs).exp()
        surr1 = ratio * advantages
        surr2 = torch.clamp(ratio, 1 - CLIP_EPS, 1 + CLIP_EPS) * advantages
        policy_loss = -torch.min(surr1, surr2).mean()
        value_loss = (returns - value.squeeze()).pow(2).mean()

        # Value loss weighted by 0.5; the entropy term is subtracted, so
        # higher policy entropy lowers the loss.
        loss = policy_loss + 0.5 * value_loss - 0.01 * entropy

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Mean per-step reward over this rollout (not a per-episode score).
    avg_reward = torch.stack(rewards).mean()
    print(f"Iter {iteration}, Avg Reward: {avg_reward.item():.2f}")

    if (iteration + 1) % snapshot_every == 0:
        # Record every move explicitly (0 if never chosen) so later diffs
        # never hit a missing key.
        bucket_move_counters.append(
            {move: global_move_counter.get(move, 0) for move in ACTION_MAP.values()}
        )

# --- Final Move Summary ---
print("\n=== Final Move Frequency (Total) ===")
for move in ACTION_MAP.values():
    print(f"{move}: {global_move_counter[move]}")

print("\n=== Move Distribution Per 10% of Training ===")
# Same snapshot interval as the training loop, clamped so the labels stay
# meaningful when TOTAL_ITERATIONS < 10.
snapshot_every = max(1, TOTAL_ITERATIONS // 10)
# Snapshots are cumulative, so each bucket is the difference from the previous
# one (the first bucket is diffed against an empty snapshot). A move missing
# from a snapshot (never chosen yet) counts as 0 instead of raising KeyError.
previous_snapshot = {}
for i, snapshot in enumerate(bucket_move_counters):
    print(f"\n--- Iterations {i * snapshot_every + 1} to {(i + 1) * snapshot_every} ---")
    for move in ACTION_MAP.values():
        print(f"{move}: {snapshot.get(move, 0) - previous_snapshot.get(move, 0)}")
    previous_snapshot = snapshot

# --- Evaluation ---
# Greedy (argmax) evaluation of the trained policy.
EVAL_EPISODES = 100
print(f"\n🔍 Evaluating agent over {EVAL_EPISODES} episodes...")
eval_scores = []
eval_move_counter = defaultdict(int)
tile_counts = Counter()
# Episodes stopped early because the greedy move left the observation unchanged.
stuck_episodes = 0

with torch.no_grad():  # inference only; no autograd graph per step
    for _ in range(EVAL_EPISODES):
        obs, _ = env.reset()
        total_reward = 0
        done = False
        while not done:
            state = preprocess(obs)
            logits, _ = model(state)
            # Softmax is monotonic, so the argmax of the logits is the most
            # probable action.
            action = logits.argmax().item()

            move_name = ACTION_MAP[action]
            eval_move_counter[move_name] += 1

            next_obs, reward, terminated, truncated, _ = env.step(action)
            total_reward += reward
            done = terminated or truncated

            # The greedy policy is deterministic. If a move leaves the
            # observation unchanged, the next step sees the same input and
            # picks the same move. Assuming the env's state is just the board,
            # that repeats forever unless the env ends the episode itself,
            # so stop here instead of spinning.
            if not done and np.array_equal(next_obs, obs):
                stuck_episodes += 1
                done = True
            obs = next_obs

        eval_scores.append(total_reward)
        # Assumes obs is one-hot with the tile exponent on axis 2
        # (index k -> tile 2**k). TODO: confirm against the env's observation space.
        final_board = np.argmax(obs, axis=2)
        tile_counts[2 ** np.max(final_board)] += 1

print("\n✅ Evaluation Results:")
print(f"Average Score: {np.mean(eval_scores):.2f}")
print(f"Max Score: {np.max(eval_scores):.2f}")
if stuck_episodes:
    print(f"Episodes stopped early (greedy move left the board unchanged): {stuck_episodes}")
print("\nMax Tile Frequencies:")
for tile, count in sorted(tile_counts.items(), reverse=True):
    print(f"{tile}: {count} times")

print("\nEvaluation Move Distribution:")
for move in ACTION_MAP.values():
    print(f"{move}: {eval_move_counter[move]}")


  self.total_score += self.step_score


Iter 0, Avg Reward: 7.16
Iter 1, Avg Reward: 7.64
Iter 2, Avg Reward: 7.55
Iter 3, Avg Reward: 7.62
Iter 4, Avg Reward: 7.98
Iter 5, Avg Reward: 7.25
Iter 6, Avg Reward: 7.88
Iter 7, Avg Reward: 8.13
Iter 8, Avg Reward: 7.45
Iter 9, Avg Reward: 7.48
Iter 10, Avg Reward: 7.62
Iter 11, Avg Reward: 8.09
Iter 12, Avg Reward: 7.45
Iter 13, Avg Reward: 7.89
Iter 14, Avg Reward: 7.37
Iter 15, Avg Reward: 7.79
Iter 16, Avg Reward: 7.52
Iter 17, Avg Reward: 8.11
Iter 18, Avg Reward: 7.70
Iter 19, Avg Reward: 7.90
Iter 20, Avg Reward: 7.58
Iter 21, Avg Reward: 7.14
Iter 22, Avg Reward: 7.80
Iter 23, Avg Reward: 7.84
Iter 24, Avg Reward: 7.63
Iter 25, Avg Reward: 7.54
Iter 26, Avg Reward: 7.28
Iter 27, Avg Reward: 6.71
Iter 28, Avg Reward: 7.11
Iter 29, Avg Reward: 7.28
Iter 30, Avg Reward: 7.50
Iter 31, Avg Reward: 7.64
Iter 32, Avg Reward: 7.34
Iter 33, Avg Reward: 6.91
Iter 34, Avg Reward: 7.32
Iter 35, Avg Reward: 7.07
Iter 36, Avg Reward: 7.83
Iter 37, Avg Reward: 7.45
Iter 38, Avg Reward: 7

KeyboardInterrupt: 