In [1]:
from manim import *
from typing import List
from dataclasses import dataclass
from enum import Enum

import torch

config.media_embed = True

In [2]:
from Game import Game2048Env, Direction

In [3]:
from SimpleNeuralNetwork import SimpleNeuralNetwork, DEVICE

In [4]:
import pickle


def save_network(network: SimpleNeuralNetwork, filename: str):
    torch.save(network.state_dict(), filename)


def load_network(filename: str, hidden_layers: List[int]) -> SimpleNeuralNetwork:
    network = SimpleNeuralNetwork(hidden_layers=hidden_layers)
    network.load_state_dict(torch.load(filename, map_location=DEVICE))
    network.to(DEVICE)
    return network


def save_population(population: List[SimpleNeuralNetwork], filename: str):
    with open(filename, "wb") as f:
        pickle.dump(population, f)


def load_population(filename: str) -> List[SimpleNeuralNetwork]:
    with open(filename, "rb") as f:
        population = pickle.load(f)
    return population

In [5]:
from Player import Player

In [6]:
import pickle


def save_network(network: SimpleNeuralNetwork, filename: str):
    torch.save(network.state_dict(), filename)


def load_network(filename: str, hidden_layers: List[int]) -> SimpleNeuralNetwork:
    network = SimpleNeuralNetwork(hidden_layers=hidden_layers)
    network.load_state_dict(torch.load(filename, map_location=DEVICE))
    network.to(DEVICE)
    return network


def save_population(population: List[SimpleNeuralNetwork], filename: str):
    with open(filename, "wb") as f:
        pickle.dump(population, f)


def load_population(filename: str) -> List[SimpleNeuralNetwork]:
    with open(filename, "rb") as f:
        population = pickle.load(f)
    return population

In [7]:
def find_latest_network(folder: str = "./networks/") -> str:
    import os

    files = [
        f
        for f in os.listdir(folder)
        if f.startswith("population_gen_") and f.endswith(".pkl")
    ]
    if not files:
        return None
    latest_file = max(files, key=lambda x: int(x.split("_")[2].split(".")[0]))
    return os.path.join(folder, latest_file)


hidden_layers = [128]
layers_str = "_".join(map(str, hidden_layers))
latest_checkpoint = find_latest_network(f"./networks/{layers_str}")
# latest_checkpoint = './networks/128_128/population_gen_631.pkl'
print(latest_checkpoint)
population = load_population(latest_checkpoint)

best_block = 0
best_score = 0
best_network = None

games_per_player = 10

random_seed = 42

for i, net in enumerate(population):
    print(f"Evaluating Network {i + 1}/{len(population)}")
    player = Player(net.to(DEVICE))

    average_score = 0
    for game_index in range(games_per_player):
        game_random_seed = random_seed + game_index
        env = Game2048Env(random_seed=game_random_seed)
        score = 0
        while True:
            board = env.board
            # print_board(board)
            actions = player.next_move(board)
            state, reward, done, _ = env.step(actions)
            score += reward
            if done:
                average_score += score
                break

        print(f"  Game {game_index + 1}/{games_per_player} Score: {score}")

        average_score /= games_per_player
    if average_score > best_score:
        best_network = net
        best_score = average_score
        print(f"New Best Network Found! Tile: {np.max(state)} | Score: {average_score}")


player = Player(best_network)

del population

./networks/128/population_gen_1000.pkl
Evaluating Network 1/200
  Game 1/10 Score: 52
  Game 2/10 Score: 148
  Game 3/10 Score: 128
  Game 4/10 Score: 168
  Game 5/10 Score: 64
  Game 6/10 Score: 80
  Game 7/10 Score: 168
  Game 8/10 Score: 288
  Game 9/10 Score: 356
  Game 10/10 Score: 408
New Best Network Found! Tile: 64 | Score: 44.6656822332
Evaluating Network 2/200
  Game 1/10 Score: 52
  Game 2/10 Score: 148
  Game 3/10 Score: 128
  Game 4/10 Score: 168
  Game 5/10 Score: 64
  Game 6/10 Score: 80
  Game 7/10 Score: 168
  Game 8/10 Score: 288
  Game 9/10 Score: 356
  Game 10/10 Score: 408
Evaluating Network 3/200
  Game 1/10 Score: 52
  Game 2/10 Score: 148
  Game 3/10 Score: 128
  Game 4/10 Score: 168
  Game 5/10 Score: 64
  Game 6/10 Score: 80
  Game 7/10 Score: 168
  Game 8/10 Score: 288
  Game 9/10 Score: 356
  Game 10/10 Score: 408
Evaluating Network 4/200
  Game 1/10 Score: 52
  Game 2/10 Score: 148
  Game 3/10 Score: 128
  Game 4/10 Score: 168
  Game 5/10 Score: 64
  Game 6

In [8]:
%%manim -v WARNING --progress_bar None -qm --disable_caching Game2048AI

class Game2048AI(Scene):
    def construct(self):
        # Initialize game and player
        env = Game2048Env()
        player = Player(best_network)
        
        # Color scheme for tiles
        tile_colors = {
            0: "#cdc1b4",
            2: "#eee4da",
            4: "#ede0c8",
            8: "#f2b179",
            16: "#f59563",
            32: "#f67c5f",
            64: "#f65e3b",
            128: "#edcf72",
            256: "#edcc61",
            512: "#edc850",
            1024: "#edc53f",
            2048: "#edc22e",
        }
        
        # Create title
        title = Text("2048 AI Player", font_size=48).to_edge(UP, buff=0.5)
        self.play(Write(title))
        
        # Create score display
        score_text = Text("Score: 0", font_size=36).next_to(title, DOWN, buff=0.3)
        self.add(score_text)
        
        # Create the game board
        def create_board_visual(board):
            board_group = VGroup()
            cell_size = 1.2
            
            for i in range(4):
                for j in range(4):
                    value = int(board[i, j])
                    
                    # Create cell background
                    cell = Square(side_length=cell_size)
                    cell.set_fill(
                        color=tile_colors.get(value, "#3c3a32"),
                        opacity=1
                    )
                    cell.set_stroke(color="#bbada0", width=3)
                    
                    # Position cell
                    x_pos = (j - 1.5) * cell_size
                    y_pos = -(i - 1.5) * cell_size - 1
                    cell.move_to([x_pos, y_pos, 0])
                    
                    board_group.add(cell)
                    
                    # Add number text if not empty
                    if value > 0:
                        text_color = "#776e65" if value <= 4 else "#f9f6f2"
                        num_text = Text(
                            str(int(value)),
                            font_size=36 if value < 128 else 28,
                            color=text_color,
                            weight=BOLD
                        )
                        num_text.move_to(cell.get_center())
                        board_group.add(num_text)
            
            return board_group
        
        # Initial board
        board_visual = create_board_visual(env.board)
        self.play(FadeIn(board_visual))
        
        # Play the game
        total_score = 0
        move_count = 0
        max_moves = 10000  # Limit moves for reasonable video length
        
        while move_count < max_moves:
            board = env.board
            actions = player.next_move(board)  # Get list of actions
            
            # Execute move with list of actions
            state, reward, done, meta = env.step(actions)
            actual_action = meta.get("direction", None)
            total_score += reward
            
            # Show move direction (the actual action that was executed)
            direction_map = {
                Direction.UP: "↑",
                Direction.DOWN: "↓",
                Direction.LEFT: "←",
                Direction.RIGHT: "→"
            }
            
            move_text = Text(
                f"{direction_map.get(actual_action, '?')}",
                font_size=128
            ).to_edge(RIGHT, buff=2)
            
            # Update board visual - fade out old, fade in new
            new_board_visual = create_board_visual(env.board)
            
            # Update score
            new_score_text = Text(
                f"Score: {int(total_score)}",
                font_size=36
            ).next_to(title, DOWN, buff=0.3)
            
            # Animate the transition - remove old board and add new one
            self.play(
                FadeOut(board_visual),
                FadeIn(new_board_visual),
                Transform(score_text, new_score_text),
                FadeIn(move_text),
                run_time=0.5
            )
            board_visual = new_board_visual  # Update reference
            self.play(FadeOut(move_text), run_time=0.1)
            
            move_count += 1
            
            if done:
                break
        
        # Show final result
        max_tile = int(np.max(env.board))
        final_text = VGroup(
            Text("Game Over!", font_size=48, color=RED),
            Text(f"Max Tile: {max_tile}", font_size=36),
            Text(f"Final Score: {int(total_score)}", font_size=36),
            Text(f"Moves: {move_count}", font_size=36)
        ).arrange(DOWN, buff=0.3).move_to(ORIGIN)
        
        self.play(
            board_visual.animate.scale(0.7).to_edge(LEFT),
            FadeIn(final_text)
        )
        self.wait(2)