In [155]:
from manim import *
from typing import List
from dataclasses import dataclass
from enum import Enum

import torch

config.media_embed = True

In [156]:
import numpy as np


class Direction(Enum):
    LEFT = 0
    UP = 1
    RIGHT = 2
    DOWN = 3


class Game2048Env:
    def __init__(self):
        self.grid_size = 4
        self.reset()

    def reset(self):
        self.board = np.zeros((self.grid_size, self.grid_size), dtype=int)
        self.spawn_tile()
        self.spawn_tile()
        self.score = 0
        return self.board.copy()

    def spawn_tile(self):
        empty = list(zip(*np.where(self.board == 0)))
        if empty:
            x, y = empty[np.random.randint(len(empty))]
            self.board[x, y] = 2 if np.random.random() < 0.9 else 4

    def step(self, action: Direction):
        moved, reward = self.move(action.value)
        if moved:
            self.spawn_tile()
        else:
            # Stop if invalid move
            return self.board.copy(), reward, True, {}
        done = not self.can_move()
        self.score += reward
        return self.board.copy(), reward, done, {}

    def move(self, direction):
        board = np.copy(self.board)
        reward = 0
        moved = False

        # Rotate board so all moves are left-moves
        for _ in range(direction):
            board = np.rot90(board)

        for i in range(self.grid_size):
            tiles = board[i][board[i] != 0]  # Extract non-zero
            merged = []
            j = 0
            while j < len(tiles):
                if j + 1 < len(tiles) and tiles[j] == tiles[j + 1]:
                    merged_val = tiles[j] * 2
                    reward += 10
                    merged.append(merged_val)
                    j += 2  # Skip next
                    moved = True
                else:
                    merged.append(tiles[j])
                    reward += 1
                    j += 1
            # Pad with zeros to the right
            merged += [0] * (self.grid_size - len(merged))
            # Detect if move or merge happened
            if not np.array_equal(board[i], merged):
                moved = True
            board[i] = merged

        # Restore original orientation
        for _ in range((4 - direction) % 4):
            board = np.rot90(board)

        if moved:
            self.board = board

        return moved, reward

    def can_move(self):
        for direction in range(4):
            temp_board = self.board.copy()
            moved, _ = self.move(direction)
            self.board = temp_board  # Restore original
            if moved:
                return True
        return False

In [157]:
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F


# Determine the best available device
def get_device():
    if torch.cuda.is_available():
        return torch.device("cuda")
    elif torch.backends.mps.is_available():
        return torch.device("mps")
    else:
        return torch.device("cpu")


# DEVICE = get_device()
DEVICE = torch.device("cpu")
print(f"Using device: {DEVICE}")


class SimpleNeuralNetwork(nn.Module):
    """Simple feedforward neural network using PyTorch"""

    def __init__(
        self,
        input_size: int = 16,
        hidden_layers: List[int] = [256],
        output_size: int = 4,
        empty: bool = False,
    ):
        super().__init__()

        if empty:
            return

        # Build layers using PyTorch modules
        layers = []
        prev_size = input_size

        # Add hidden layers
        for hidden_size in hidden_layers:
            layers.append(nn.Linear(prev_size, hidden_size))
            # layers.append(nn.Tanh())
            layers.append(nn.ReLU())
            prev_size = hidden_size

        # Add output layer (no activation)
        layers.append(nn.Linear(prev_size, output_size))

        self.network = nn.Sequential(*layers)

        # Initialize weights using He initialization
        self._initialize_weights()

        # Move to device
        self.to(DEVICE)

    def _initialize_weights(self):
        """Initialize weights using He initialization"""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.kaiming_normal_(module.weight, nonlinearity="tanh")
                nn.init.zeros_(module.bias)

    def forward(self, x):
        """Forward pass through the network"""
        # Convert numpy array to tensor if needed and move to device
        if isinstance(x, np.ndarray):
            x = torch.from_numpy(x).float().to(DEVICE)
        elif isinstance(x, torch.Tensor):
            x = x.to(DEVICE)

        return self.network(x)

    def mutate(self, mutation_rate: float = 0.1, mutation_strength: float = 0.5):
        """Mutate the network's weights and biases"""
        with torch.no_grad():
            for param in self.parameters():
                if torch.rand(1).item() < mutation_rate:
                    mutation = torch.randn_like(param) * mutation_strength
                    param.add_(mutation)

Using device: cpu


In [158]:
import pickle


def save_network(network: SimpleNeuralNetwork, filename: str):
    torch.save(network.state_dict(), filename)


def load_network(filename: str, hidden_layers: List[int]) -> SimpleNeuralNetwork:
    network = SimpleNeuralNetwork(hidden_layers=hidden_layers)
    network.load_state_dict(torch.load(filename, map_location=DEVICE))
    network.to(DEVICE)
    return network


def save_population(population: List[SimpleNeuralNetwork], filename: str):
    with open(filename, "wb") as f:
        pickle.dump(population, f)


def load_population(filename: str) -> List[SimpleNeuralNetwork]:
    with open(filename, "rb") as f:
        population = pickle.load(f)
    return population

In [159]:
@dataclass
class GameResult:
    score: int
    max_tile: int
    moves: int


class Player:
    def __init__(self, network: SimpleNeuralNetwork):
        self.network = network

    def play(self, env: Game2048Env, max_steps: int = 100) -> GameResult:
        state = env.reset()
        total_reward = 0
        done = False
        steps = 0

        while not done and steps < max_steps:
            action = self.next_move(state)

            state, reward, done, _ = env.step(action)
            total_reward += reward
            steps += 1

        return GameResult(score=total_reward, max_tile=np.max(state), moves=steps)

    def next_move(self, state: np.ndarray) -> Direction:
        self.network.eval()  # Set to evaluation mode
        with torch.no_grad():
            flat_state = state.flatten() / 2048.0  # Normalize input
            q_values = self.network.forward(flat_state)
            # Move back to CPU for numpy conversion
            q_values_cpu = q_values.cpu()
            action = Direction(q_values_cpu.numpy().argmax())
            return action

In [160]:
import pickle


def save_network(network: SimpleNeuralNetwork, filename: str):
    torch.save(network.state_dict(), filename)


def load_network(filename: str, hidden_layers: List[int]) -> SimpleNeuralNetwork:
    network = SimpleNeuralNetwork(hidden_layers=hidden_layers)
    network.load_state_dict(torch.load(filename, map_location=DEVICE))
    network.to(DEVICE)
    return network


def save_population(population: List[SimpleNeuralNetwork], filename: str):
    with open(filename, "wb") as f:
        pickle.dump(population, f)


def load_population(filename: str) -> List[SimpleNeuralNetwork]:
    with open(filename, "rb") as f:
        population = pickle.load(f)
    return population

In [None]:
def find_latest_network(folder: str = "./networks/") -> str:
    import os

    files = [
        f
        for f in os.listdir(folder)
        if f.startswith("population_gen_") and f.endswith(".pkl")
    ]
    if not files:
        return None
    latest_file = max(files, key=lambda x: int(x.split("_")[2].split(".")[0]))
    return os.path.join(folder, latest_file)


hidden_layers = [64, 32]
layers_str = "_".join(map(str, hidden_layers))
latest_checkpoint = find_latest_network(f"./networks/{layers_str}")
# latest_checkpoint = './networks/256_128_64/population_gen_3351.pkl'
print(latest_checkpoint)
population = load_population(latest_checkpoint)

best_block = 0
best_score = 0
best_network = None

games_per_player = 100

for i, net in enumerate(population):
    player = Player(net.to(DEVICE))

    average_score = 0
    for _ in range(games_per_player):
        env = Game2048Env()
        score = 0
        while True:
            board = env.board
            # print_board(board)
            action = player.next_move(board)
            state, reward, done, _ = env.step(action)
            score += reward
            if done:
                average_score += score
                break

        average_score /= games_per_player
        if average_score > best_score:
            best_network = net
            best_score = average_score
            print(f"New Best Network Found! Tile: {np.max(state)} | Score: {average_score}")


player = Player(best_network)

del population

./networks/64_32/population_gen_8411.pkl
New Best Network Found! Tile: 32 | Score: 3.09
New Best Network Found! Tile: 32 | Score: 4.4658090471166165
New Best Network Found! Tile: 32 | Score: 7.936106157811021
New Best Network Found! Tile: 64 | Score: 8.477335257423928
New Best Network Found! Tile: 64 | Score: 9.151445321789092
New Best Network Found! Tile: 64 | Score: 9.997203817968163


In [None]:
%%manim -v WARNING --progress_bar None -ql --disable_caching Game2048AI

class Game2048AI(Scene):
    def construct(self):
        self.camera.background_color = "#faf8ef"
        
        # Title
        # title = Text("2048 AI Player", font_size=40, color="#776e65", weight=BOLD)
        # title.to_edge(UP, buff=0.3)
        # self.add(title)
        
        # Create the game board
        self.board_size = 4
        self.tile_size = 1.2
        self.board_group = VGroup()
        self.tile_objects = {}
        
        # Initialize board visuals
        self.create_board()
        
        # Score display
        self.score_text = Text("Score: 0", font_size=28, color="#776e65")
        self.score_text.next_to(self.board_group, UP, buff=.5)
        self.add(self.score_text)
        
        # Max tile display
        self.max_tile_text = Text("Max Tile: 0", font_size=28, color="#776e65")
        self.max_tile_text.next_to(self.score_text, RIGHT, buff=.5)
        self.add(self.max_tile_text)
        
        # Play the game
        state = env.reset()
        self.update_board(state)
        self.wait(0.5)
        
        done = False
        steps = 0
        max_steps = 200
        
        while not done and steps < max_steps:
            # Get action from neural network
            action = player.next_move(state)
            
            # Show action arrow
            arrow = self.create_action_arrow(action)
            self.play(FadeIn(arrow), run_time=0.15)
            
            # Store previous state for animation
            prev_state = state.copy()
            
            # Execute action
            state, reward, done, _ = env.step(action)
            
            # Update visualization with sliding animation
            self.play(FadeOut(arrow), run_time=0.2)
            self.update_board_with_slide(prev_state, state)
            self.update_score(env.score)
            self.update_max_tile(np.max(state))
            
            steps += 1
            self.wait(.5)
        
        # Game over
        if done:
            game_over_text = Text("Game Over!", font_size=48, color="#776e65", weight=BOLD)
            game_over_text.move_to(self.board_group.get_center())
            bg_rect = Rectangle(
                width=6, height=1.5,
                fill_color="#faf8ef",
                fill_opacity=0.95,
                stroke_width=0
            )
            bg_rect.move_to(game_over_text.get_center())
            self.play(
                FadeIn(bg_rect),
                Write(game_over_text),
                run_time=0.25
            )
            self.wait(5)
    
    def create_board(self):
        """Create the initial board grid"""
        self.board_group = VGroup()
        
        for i in range(self.board_size):
            for j in range(self.board_size):
                # Background tile
                tile_bg = RoundedRectangle(
                    width=self.tile_size,
                    height=self.tile_size,
                    corner_radius=0.1,
                    fill_color="#cdc1b4",
                    fill_opacity=1,
                    stroke_width=0
                )
                tile_bg.move_to([
                    (j - 1.5) * (self.tile_size + 0.15),
                    (1.5 - i) * (self.tile_size + 0.15),
                    0
                ])
                self.board_group.add(tile_bg)
        
        self.board_group.move_to(ORIGIN)
        self.add(self.board_group)
    
    def get_tile_position(self, i, j):
        """Get the screen position for a tile at grid position (i, j)"""
        return np.array([
            (j - 1.5) * (self.tile_size + 0.15),
            (1.5 - i) * (self.tile_size + 0.15),
            0
        ])
    
    def get_tile_color(self, value):
        """Get color for tile based on value"""
        colors = {
            0: "#cdc1b4",
            2: "#eee4da",
            4: "#ede0c8",
            8: "#f2b179",
            16: "#f59563",
            32: "#f67c5f",
            64: "#f65e3b",
            128: "#edcf72",
            256: "#edcc61",
            512: "#edc850",
            1024: "#edc53f",
            2048: "#edc22e",
        }
        return colors.get(value, "#3c3a32")
    
    def get_text_color(self, value):
        """Get text color based on tile value"""
        return "#776e65" if value <= 4 else "#f9f6f2"
    
    def create_tile(self, value, i, j):
        """Create a tile visual at position (i, j)"""
        tile_rect = RoundedRectangle(
            width=self.tile_size,
            height=self.tile_size,
            corner_radius=0.1,
            fill_color=self.get_tile_color(value),
            fill_opacity=1,
            stroke_width=0
        )
        
        tile_text = Text(
            str(value),
            font_size=32 if value < 1000 else 24,
            color=self.get_text_color(value),
            weight=BOLD
        )
        
        tile_group = VGroup(tile_rect, tile_text)
        tile_group.move_to(self.get_tile_position(i, j))
        
        return tile_group
    
    def update_board_with_slide(self, prev_state, new_state):
        """Update board with sliding animation by tracking actual tile movements"""
        animations = []
        new_tiles = {}
        tiles_to_remove = set()
        
        # Create a mapping of which tiles moved where by simulating the move
        tile_movements = self.compute_tile_movements(prev_state, new_state)
        
        # Handle each type of change
        for (old_pos, new_pos, value, action_type) in tile_movements:
            if action_type == 'stay':
                # Tile stayed in place
                if old_pos in self.tile_objects:
                    new_tiles[new_pos] = self.tile_objects[old_pos]
                
            elif action_type == 'move':
                # Tile moved to new position
                if old_pos in self.tile_objects:
                    tile_obj = self.tile_objects[old_pos]
                    target_pos = self.get_tile_position(new_pos[0], new_pos[1])
                    animations.append(tile_obj.animate.move_to(target_pos))
                    new_tiles[new_pos] = tile_obj
                    
            elif action_type == 'merge':
                # Tile merged - fade out old, scale in new
                if old_pos in self.tile_objects:
                    tile_obj = self.tile_objects[old_pos]
                    tiles_to_remove.add(tile_obj)
                    animations.append(FadeOut(tile_obj))
                
                if new_pos not in new_tiles:
                    new_tile = self.create_tile(value, new_pos[0], new_pos[1])
                    new_tile.scale(0.1)
                    self.add(new_tile)
                    animations.append(new_tile.animate.scale(10))
                    new_tiles[new_pos] = new_tile
                    
            elif action_type == 'spawn':
                # New tile spawned
                new_tile = self.create_tile(value, new_pos[0], new_pos[1])
                new_tile.scale(0.1)
                self.add(new_tile)
                animations.append(new_tile.animate.scale(10))
                new_tiles[new_pos] = new_tile
        
        # Remove any tiles that are no longer on the board
        for pos, tile_obj in self.tile_objects.items():
            if tile_obj not in new_tiles.values() and tile_obj not in tiles_to_remove:
                tiles_to_remove.add(tile_obj)
                animations.append(FadeOut(tile_obj))
        
        # Play all animations
        if animations:
            self.play(*animations, run_time=0.35)
        
        # Clean up removed tiles
        for tile_obj in tiles_to_remove:
            self.remove(tile_obj)
        
        self.tile_objects = new_tiles

    def compute_tile_movements(self, prev_state, new_state):
        """Compute which tiles moved where by analyzing state changes"""
        movements = []
        prev_tiles = {}
        new_tiles = {}
        
        # Build lists of tile positions with values
        for i in range(self.board_size):
            for j in range(self.board_size):
                if prev_state[i][j] > 0:
                    if prev_state[i][j] not in prev_tiles:
                        prev_tiles[prev_state[i][j]] = []
                    prev_tiles[prev_state[i][j]].append((i, j))
                    
                if new_state[i][j] > 0:
                    if new_state[i][j] not in new_tiles:
                        new_tiles[new_state[i][j]] = []
                    new_tiles[new_state[i][j]].append((i, j))
        
        used_old = set()
        used_new = set()
        
        # First pass: tiles that stayed in same position
        for i in range(self.board_size):
            for j in range(self.board_size):
                if (prev_state[i][j] == new_state[i][j] and 
                    prev_state[i][j] > 0):
                    movements.append(((i, j), (i, j), prev_state[i][j], 'stay'))
                    used_old.add((i, j))
                    used_new.add((i, j))
        
        # Second pass: detect merges (doubled values)
        for value in new_tiles:
            if value > 2 and value // 2 in prev_tiles:
                # This could be a merge
                for new_pos in new_tiles[value]:
                    if new_pos in used_new:
                        continue
                        
                    # Check if there were two tiles of half-value that could merge here
                    half_val = value // 2
                    potential_sources = [p for p in prev_tiles[half_val] if p not in used_old]
                    
                    if len(potential_sources) >= 2:
                        # Find the two closest tiles that would merge to this position
                        # Mark both as merged
                        for k in range(min(2, len(potential_sources))):
                            src = potential_sources[k]
                            movements.append((src, new_pos, value, 'merge'))
                            used_old.add(src)
                        used_new.add(new_pos)
                        break
        
        # Third pass: tiles that moved (same value, different position)
        for value in prev_tiles:
            old_positions = [p for p in prev_tiles[value] if p not in used_old]
            if value in new_tiles:
                new_positions = [p for p in new_tiles[value] if p not in used_new]
                
                # Match old to new by proximity (closest first)
                for old_pos in old_positions:
                    if new_positions:
                        # Find closest new position
                        closest = min(new_positions, 
                                    key=lambda np: abs(np[0]-old_pos[0]) + abs(np[1]-old_pos[1]))
                        movements.append((old_pos, closest, value, 'move'))
                        used_old.add(old_pos)
                        used_new.add(closest)
                        new_positions.remove(closest)
        
        # Fourth pass: newly spawned tiles
        for value in new_tiles:
            for new_pos in new_tiles[value]:
                if new_pos not in used_new:
                    movements.append((None, new_pos, value, 'spawn'))
                    used_new.add(new_pos)
        
        return movements

    def update_board(self, state, animate=False):
        """Update the board display (without sliding)"""
        new_tiles = {}
        animations = []
        
        for i in range(self.board_size):
            for j in range(self.board_size):
                value = state[i][j]
                pos = (i, j)
                
                if value > 0:
                    tile_group = self.create_tile(value, i, j)
                    new_tiles[pos] = tile_group
                    
                    if pos not in self.tile_objects:
                        if animate:
                            tile_group.scale(0.1)
                            animations.append(tile_group.animate.scale(10))
                        self.add(tile_group)
        
        # Remove old tiles
        for pos, tile in self.tile_objects.items():
            if pos not in new_tiles:
                if animate:
                    animations.append(FadeOut(tile))
                else:
                    self.remove(tile)
        
        if animate and animations:
            self.play(*animations, run_time=0.5)
        
        self.tile_objects = new_tiles
    
    def update_score(self, score):
        """Update score display"""
        new_text = Text(f"Score: {score}", font_size=28, color="#776e65")
        new_text.move_to(self.score_text.get_center())
        self.remove(self.score_text)
        self.score_text = new_text
        self.add(self.score_text)
    
    def update_max_tile(self, max_tile):
        """Update max tile display"""
        new_text = Text(f"Max Tile: {int(max_tile)}", font_size=28, color="#776e65")
        new_text.move_to(self.max_tile_text.get_center())
        self.remove(self.max_tile_text)
        self.max_tile_text = new_text
        self.add(self.max_tile_text)
    
    def create_action_arrow(self, action: Direction):
        """Create an arrow showing the action"""
        arrow_configs = {
            Direction.UP: (DOWN, UP),
            Direction.DOWN: (UP, DOWN),
            Direction.LEFT: (RIGHT, LEFT),
            Direction.RIGHT: (LEFT, RIGHT),
        }
        
        start_dir, end_dir = arrow_configs[action]
        arrow = Arrow(
            start=self.board_group.get_edge_center(start_dir) + start_dir * 0.5,
            end=self.board_group.get_edge_center(end_dir) + end_dir * 0.3,
            color="#8f7a66",
            stroke_width=8,
            max_tip_length_to_length_ratio=0.15
        )
        return arrow