2048 DQN CNN

In [1]:
import numpy as np  # for matrix and numerical operations
import random  # for epsilon-greedy random action selection
import torch  # PyTorch core library
import torch.nn as nn  # for building neural networks
import torch.optim as optim  # for optimizer (e.g. Adam)
import torch.nn.functional as F  # for activation functions like ReLU
from collections import deque  # not used here, but useful for experience buffers
import gymnasium as gym  # for interacting with RL environments
import gymnasium_2048.envs  # registers the 2048 environment
import csv

  from pkg_resources import resource_stream, resource_exists


In [3]:
class DQN_CNN(nn.Module):
    def __init__(self, input_channels=16, output_dim=4):
        super().__init__()
        # input channels, output channels, kernel_size
        self.conv1 = nn.Conv2d(input_channels, 128, kernel_size=(1,2))
        self.conv2 = nn.Conv2d(input_channels, 128, kernel_size=(2,1))

        self.conv11 = nn.Conv2d(128, 128, kernel_size=(1,2))
        self.conv12 = nn.Conv2d(128, 128, kernel_size=(2,1))
        self.conv21 = nn.Conv2d(128, 128, kernel_size=(1,2))
        self.conv22 = nn.Conv2d(128, 128, kernel_size=(2,1))

        self.fc1 = nn.Linear(128*3*4*2+128*3*3*2+128*4*2*2, 256)
        self.out = nn.Linear(256, output_dim)

    def forward(self, x):
        x1 = F.relu(self.conv1(x))
        x2 = F.relu(self.conv2(x))

        x11 = F.relu(self.conv11(x1))
        x12 = F.relu(self.conv12(x1))
        x21 = F.relu(self.conv21(x2))
        x22 = F.relu(self.conv22(x2))

        x1_flat = x1.flatten(start_dim=1)
        x2_flat = x2.flatten(start_dim=1)
        x11_flat = x11.flatten(start_dim=1)
        x12_flat = x12.flatten(start_dim=1)
        x21_flat = x21.flatten(start_dim=1)
        x22_flat = x22.flatten(start_dim=1)

        x = torch.cat([x1_flat, x2_flat, x11_flat, x12_flat, x21_flat, x22_flat], dim=1)

        x = F.relu(self.fc1(x))
        return self.out(x)

In [None]:
class PrioritizedReplayBuffer:
    def __init__(self, capacity, alpha=0.7):
        self.capacity = capacity
        self.buffer = []  # stores experiences (s, a, r, s', done)
        self.pos = 0
        self.priorities = np.zeros((capacity,), dtype=np.float32)
        self.alpha = alpha

    def add(self, transition, td_error):
        max_prio = max(self.priorities.max(), td_error + 1e-6)
        if len(self.buffer) < self.capacity:
            self.buffer.append(transition)
        else:
            self.buffer[self.pos] = transition
        self.priorities[self.pos] = max_prio
        self.pos = (self.pos + 1) % self.capacity

    def sample(self, batch_size, beta=0.4):
        prios = self.priorities if len(self.buffer) == self.capacity else self.priorities[:self.pos]
        probs = prios ** self.alpha
        probs /= probs.sum()

        indices = np.random.choice(len(self.buffer), batch_size, p=probs)
        samples = [self.buffer[i] for i in indices]

        total = len(self.buffer)
        weights = (total * probs[indices]) ** (-beta)
        weights /= weights.max()
        weights = torch.tensor(weights, dtype=torch.float32)

        return samples, weights, indices

    def update_priorities(self, indices, td_errors):
        for idx, err in zip(indices, td_errors):
            self.priorities[idx] = abs(err.item()) + 1e-6


In [None]:
# Utility functions

def find_high_tile(state):
    state = np.argmax(state, axis=2)
    state = 2 ** state * (state > 0)
    return np.max(state)

def preprocess(state): # 4, 4, 16
    state = torch.tensor(state)
    transposed = state.permute(2, 0, 1)   # (16, 4, 4)
    return transposed

def reward_shaping(state):
    state = np.argmax(state, axis=2)
    grid = 2 ** state * (state > 0)
    empty = np.count_nonzero(grid == 0)
    return empty

In [None]:
# Main training loop

if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    env = gym.make("gymnasium_2048/TwentyFortyEight-v0")

    main_net = DQN_CNN().to(device)
    target_net = DQN_CNN().to(device)
    target_net.load_state_dict(main_net.state_dict())
    optimizer = optim.Adam(main_net.parameters(), lr=1e-4)
    buffer = PrioritizedReplayBuffer(10000)

    batch_size = 128
    gamma = 0.99
    epsilon = 1.0
    epsilon_decay = 0.992
    epsilon_min = 0.0005
    target_update_freq = 20
    num_episodes = 4000
    total_steps = 0

    with open('run.csv', 'w', newline='') as file:
            writer = csv.writer(file)
            writer.writerow(["episode", "high_tile", "steps"])

    for episode in range(num_episodes):
        state, info = env.reset()
        total_reward = 0.0
        high_tile = 0
        done = False
        steps = 0

        while not done:
            if random.random() < epsilon:
                actions = [0, 1, 3]
                arr = np.array(actions)
                np.random.shuffle(arr)
                np.append(arr, 2)
                for action in actions:
                    next_state, reward, terminated, truncated, _ = env.step(action)
                    if not np.array_equal(state, next_state):
                        break
            else:
                state_tensor = torch.tensor(np.array([preprocess(state)]), dtype=torch.float32).to(device)
                with torch.no_grad():
                    q_vals = main_net(state_tensor).squeeze(0)

                    actions = torch.argsort(q_vals, descending=True).tolist()
                    for action in actions:
                        if action == 2:
                            continue
                        next_state, reward, terminated, truncated, _ = env.step(action) # 1 is right, 0 is up, 3 is left, 2 is down
                        if not np.array_equal(state, next_state):
                            break
                    else:
                        next_state, reward, terminated, truncated, _ = env.step(2)
                        action = 2 # BUG

            steps += 1
            shaped_reward = reward_shaping(state) # merge reward
            done = terminated or truncated

            with torch.no_grad():
                s_tensor = torch.tensor(np.array([preprocess(state)]), dtype=torch.float32).to(device)
                ns_tensor = torch.tensor(np.array([preprocess(next_state)]), dtype=torch.float32).to(device)

                q_sa = main_net(s_tensor)[0, action]
                next_action = main_net(ns_tensor).argmax().item()
                q_next = target_net(ns_tensor)[0, next_action]

                target = shaped_reward + gamma * q_next * (0.0 if done else 1.0)
                td_error = target

            buffer.add((state, action, shaped_reward, next_state, done), td_error)

            state = next_state
            total_reward += float(reward)

            if len(buffer.buffer) >= batch_size:
                samples, weights, indices = buffer.sample(batch_size)
                batch_states, batch_actions, batch_rewards, batch_next_states, batch_dones = zip(*samples)

                batch_states = torch.tensor(np.array([preprocess(s) for s in batch_states]), dtype=torch.float32).to(device)
                batch_actions = torch.tensor(batch_actions).to(device)
                batch_rewards = torch.tensor(batch_rewards, dtype=torch.float32).to(device)
                batch_next_states = torch.tensor(np.array([preprocess(s) for s in batch_next_states]), dtype=torch.float32).to(device)
                batch_dones = torch.tensor(batch_dones, dtype=torch.float32).to(device)
                weights = weights.to(device)

                q_values = main_net(batch_states).gather(1, batch_actions.unsqueeze(1)).squeeze(1)

                # Double DQN: action from main_net, value from target_net
                next_actions = main_net(batch_next_states).argmax(1, keepdim=True)
                next_q_values = target_net(batch_next_states).gather(1, next_actions).squeeze(1)

                target_q_values = batch_rewards + gamma * next_q_values * (1 - batch_dones)

                loss = ((q_values - target_q_values.detach()) ** 2 * weights).mean()
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                td_errors = torch.abs(q_values - target_q_values.detach())
                buffer.update_priorities(indices, td_errors)

            total_steps += 1
            if total_steps % target_update_freq == 0:
                target_net.load_state_dict(main_net.state_dict())


        epsilon = max(epsilon * epsilon_decay, epsilon_min)

        with open('run.csv', 'a', newline='') as file:
            writer = csv.writer(file)
            writer.writerow([episode, find_high_tile(state), steps])

        if episode % 50 == 0:
            print(f"Episode {episode:03d} | High tile: {find_high_tile(state)} | Steps: {steps} | Epsilon: {epsilon:.4f}")

        if episode == 2000:
            torch.save(main_net.state_dict(), "dqn_2048_2000.pth")
        if episode == 2500:
            torch.save(main_net.state_dict(), "dqn_2048_2500.pth")
        if episode == 3000:
            torch.save(main_net.state_dict(), "dqn_2048_3000.pth")
        if episode == 3500:
            torch.save(main_net.state_dict(), "dqn_2048_3500.pth")

    env.close()
    torch.save(main_net.state_dict(), "dqn_2048.pth")
    print("Model saved!")

Episode 000 | High tile: 64 | Steps: 98 | Epsilon: 0.9920
Episode 050 | High tile: 128 | Steps: 208 | Epsilon: 0.6639
Episode 100 | High tile: 128 | Steps: 137 | Epsilon: 0.4443
Episode 150 | High tile: 256 | Steps: 358 | Epsilon: 0.2973
Episode 200 | High tile: 256 | Steps: 221 | Epsilon: 0.1990
Episode 250 | High tile: 512 | Steps: 430 | Epsilon: 0.1332
Episode 300 | High tile: 512 | Steps: 472 | Epsilon: 0.0891
Episode 350 | High tile: 512 | Steps: 458 | Epsilon: 0.0596
Episode 400 | High tile: 512 | Steps: 428 | Epsilon: 0.0399
Episode 450 | High tile: 512 | Steps: 598 | Epsilon: 0.0267
Episode 500 | High tile: 256 | Steps: 307 | Epsilon: 0.0179
Episode 550 | High tile: 256 | Steps: 352 | Epsilon: 0.0120


  score += 2 ** (board[row, col] + 1)


Episode 600 | High tile: 256 | Steps: 249 | Epsilon: 0.0080
Episode 650 | High tile: 128 | Steps: 166 | Epsilon: 0.0054
Episode 700 | High tile: 512 | Steps: 611 | Epsilon: 0.0036
Episode 750 | High tile: 128 | Steps: 179 | Epsilon: 0.0024
Episode 800 | High tile: 256 | Steps: 398 | Epsilon: 0.0016
Episode 850 | High tile: 256 | Steps: 380 | Epsilon: 0.0011
Episode 900 | High tile: 512 | Steps: 408 | Epsilon: 0.0007
Episode 950 | High tile: 128 | Steps: 164 | Epsilon: 0.0005
Episode 1000 | High tile: 256 | Steps: 241 | Epsilon: 0.0005
Episode 1050 | High tile: 512 | Steps: 520 | Epsilon: 0.0005
Episode 1100 | High tile: 512 | Steps: 415 | Epsilon: 0.0005
Episode 1150 | High tile: 128 | Steps: 169 | Epsilon: 0.0005
Episode 1200 | High tile: 512 | Steps: 555 | Epsilon: 0.0005
Episode 1250 | High tile: 256 | Steps: 246 | Epsilon: 0.0005
Episode 1300 | High tile: 256 | Steps: 287 | Epsilon: 0.0005
Episode 1350 | High tile: 512 | Steps: 422 | Epsilon: 0.0005
Episode 1400 | High tile: 512 | 