In [None]:
# Imports
import torch
import numpy as np
import gymnasium as gym
from collections import deque
import pygame
import random
import torch.nn as nn

In [None]:
# DQN model which takes in the state as an input and outputs predicted q values for every possible action
class DQN(torch.nn.Module):
    def __init__(self, state_space, action_space):
        super().__init__()
        # Add your architecture parameters here
        # You can use nn.Functional
        # Remember that the input is of size batch_size x state_space
        # and the output is of size batch_size x action_space (ulta ho sakta hai dekh lo)
        # TODO: Add code here
        self.fc1 = nn.Linear(state_space, 128)
        self.fc2 = nn.Linear(128, 128)
        self.out = nn.Linear(128, action_space)

    def forward(self, input):
        # TODO: Complete based on your implementation
        x = input
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.out(x)


In [None]:
# While training neural networks, we split the data into batches.
# To improve the training, we need to remove the "correlation" between game states
# The buffer starts storing states and once it reaches maximum capacity, it replaces
# states at random which reduces the correlation.
class ExperienceBuffer:
    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        return states, actions, rewards, next_states, dones

    def __len__(self):
        return len(self.buffer)


In [None]:
class SnakeGame(gym.Env):
    metadata = {"render_modes": ["human"], "render_fps": 10}

    def __init__(self, size=10, render_mode=None):
        super().__init__()
        self.size = size
        self.cell_size = 30
        self.screen_size = self.size * self.cell_size
        self.render_mode = render_mode

        self.action_space = gym.spaces.Discrete(4)  # 0: right, 1: up, 2: left, 3: down
        self.observation_space = gym.spaces.Box(0, 2, shape=(self.size, self.size,3), dtype=np.uint8)

        self.screen = None
        self.clock = None
        self.score = 0

        self.snake = deque()
        self.food = None
        self.direction = 0

        self.done = False

        if self.render_mode == "human":
            pygame.init()
            self.screen = pygame.display.set_mode((self.screen_size, self.screen_size))
            self.clock = pygame.time.Clock()

    def reset(self, seed=None, options=None):
        super().reset(seed=seed)
        self.snake.clear()
        mid = self.size // 2
        self.snake.appendleft((mid, mid))
        self.direction = 0
        self._place_food()
        self.done = False
        obs = self._get_obs()

        if self.render_mode == "human":
            self._render_init()

        # print(f"Reset: Snake at {tuple(self.snake)}, Food at {self.food}, Direction: {self.direction}")
        return obs, {}

    def step(self, action):
        # TODO: Change reward schema to avoid the following
        # 1) 180 degree turns
        # 2) Wall collisions
        # 3) Being slow at collecting food

        # print(f"--- Step ---\nAction: {action}, Direction: {self.direction}")
        # print(f"Snake before: {list(self.snake)}")

        if self.done:
            return self._get_obs(), 0, True, {}

        # Prevent 180-degree turns by reverting to previous direction if attempted
        if (action + 2) % 4 == self.direction:
          action = self.direction
        self.direction = action

        # [right, up, left, down]:
        move = [(1, 0), (0, -1), (-1, 0), (0, 1)][self.direction]
        head_x, head_y = self.snake[0]
        new_head = (head_x + move[0], head_y + move[1])


        if (not 0 <= new_head[0] < self.size) or (not 0 <= new_head[1] < self.size) or (new_head in self.snake):
            self.done = True
            return self._get_obs(), -10, True, {}

        self.snake.insert(0, new_head)

        if new_head == self.food:
            self.score += 1
            reward = 10
            self._place_food()
            print(f"Food eaten at {self.food}! Current score: {self.score}, reward: {reward}")
        else:
            self.snake.pop()
            reward = -0.1  # small penalty to encourage faster food collection

        return self._get_obs(), reward, False, {}


    def _get_obs(self):
        # TODO: Return an observation state, take inspiration from the observation_space attribute
        grid = np.zeros((self.size, self.size, 3), dtype=np.uint8)
        # print("hi")
        for x, y in self.snake:
            grid[y, x, 0] = 1  # snake body in red channel
        fx, fy = self.food
        grid[fy, fx, 1] = 1  # food in green channel
        return grid

    def _place_food(self):
        positions = set(tuple(p) for p in self.snake)
        empty = [(x, y) for x in range(self.size) for y in range(self.size) if (x, y) not in positions]
        self.food = tuple(random.choice(empty)) if empty else None

    def render(self):
        if self.screen is None:
            self._render_init()

        self.screen.fill((0, 0, 0))
        for x, y in self.snake:
            pygame.draw.rect(
                self.screen, (0, 255, 0),
                pygame.Rect(x * self.cell_size, y * self.cell_size, self.cell_size, self.cell_size)
            )
        if self.food:
            fx, fy = self.food
            pygame.draw.rect(
                self.screen, (255, 0, 0),
                pygame.Rect(fx * self.cell_size, fy * self.cell_size, self.cell_size, self.cell_size)
            )

        pygame.display.flip()
        self.clock.tick(self.metadata["render_fps"])

    def _render_init(self):
        pygame.init()
        self.screen = pygame.display.set_mode((self.size * self.cell_size, self.size * self.cell_size))
        self.clock = pygame.time.Clock()

    def close(self):
        if self.screen:
            pygame.quit()
            self.screen = None

In [None]:
def train_snake(episodes=300):
    env = SnakeGame(size=10)
    obs_shape = env.observation_space.shape  # (10, 10, 3)
    state_dim = np.prod(obs_shape)           # Flattened input: 300
    action_dim = env.action_space.n          # 4 actions

    model = DQN(state_dim, action_dim)
    target_model = DQN(state_dim, action_dim)
    target_model.load_state_dict(model.state_dict())

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    buffer = ExperienceBuffer(5000)

    batch_size = 32
    gamma = 0.99
    epsilon = 1.0
    epsilon_decay = 0.995
    min_epsilon = 0.05

    for ep in range(episodes):
        obs, _ = env.reset()
        state = obs.flatten().astype(np.float32)
        total_reward = 0
        done = False

        while not done:
            if random.random() < epsilon:
                action = random.randint(0, action_dim - 1)
            else:
                with torch.no_grad():
                    q_vals = model(torch.tensor(state, dtype=torch.float32).unsqueeze(0))
                    action = torch.argmax(q_vals).item()

            next_obs, reward, done, _ = env.step(action)
            next_state = next_obs.flatten().astype(np.float32)

            buffer.push(state.copy(), action, reward, next_state.copy(), done)

            state = next_state
            total_reward += reward

            if len(buffer) >= batch_size:
                states, actions, rewards, next_states, dones = buffer.sample(batch_size)

                states = torch.tensor(states, dtype=torch.float32)
                actions = torch.tensor(actions, dtype=torch.int64).unsqueeze(1)
                rewards = torch.tensor(rewards, dtype=torch.float32).unsqueeze(1)
                next_states = torch.tensor(next_states, dtype=torch.float32)
                dones = torch.tensor(dones, dtype=torch.float32).unsqueeze(1)

                q_values = model(states).gather(1, actions)
                next_q = target_model(next_states).max(1, keepdim=True)[0].detach()
                target = rewards + gamma * next_q * (1 - dones)

                loss = nn.MSELoss()(q_values, target)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        if ep % 10 == 0:
            target_model.load_state_dict(model.state_dict())
            print(f"Episode {ep}: Total Reward = {total_reward}")

        epsilon = max(min_epsilon, epsilon * epsilon_decay)
        # print(epsilon)

    env.close()
    return model


In [None]:
def evaluate_snake_model(model, size=10, episodes=10, render=True):
    env = SnakeGame(size=size, render_mode="human" if render else None)
    model.eval()

    rewards = []

    for episode in range(episodes):
        obs, _ = env.reset()
        total_reward = 0
        done = False

        while not done:
            state = torch.tensor(obs.flatten(), dtype=torch.float32).unsqueeze(0)
            with torch.no_grad():
                q_values = model(state)
                action = torch.argmax(q_values, dim=1).item()

            obs, reward, done, _ = env.step(action)
            total_reward += reward

            if render:
                env.render()

        rewards.append(total_reward)
        print(f"Episode {episode + 1}: Reward = {total_reward}")

    env.close()
    avg_reward = sum(rewards) / episodes

    print(f"Average reward over {episodes} episodes: {avg_reward}")

In [None]:
# TODO: Run evaluation for Snake Game here

torch.cuda.empty_cache()  # if using GPU
model = train_snake()
evaluate_snake_model(model)

Episode 0: Total Reward = -10.7


  states = torch.tensor(states, dtype=torch.float32)


Food eaten at (0, 7)! Current score: 1, reward: 10
Food eaten at (1, 0)! Current score: 2, reward: 10
Episode 10: Total Reward = -10.6
Food eaten at (0, 9)! Current score: 3, reward: 10
Food eaten at (1, 3)! Current score: 4, reward: 10
Episode 20: Total Reward = -11.9
Episode 30: Total Reward = -11.1
Food eaten at (5, 4)! Current score: 5, reward: 10
Food eaten at (4, 7)! Current score: 6, reward: 10
Episode 40: Total Reward = -11.1
Food eaten at (0, 5)! Current score: 7, reward: 10
Episode 50: Total Reward = -10.7
Food eaten at (8, 7)! Current score: 8, reward: 10
Episode 60: Total Reward = -10.7
Food eaten at (7, 4)! Current score: 9, reward: 10
Episode 70: Total Reward = -10.7
Episode 80: Total Reward = -10.5
Episode 90: Total Reward = -15.299999999999997
Food eaten at (3, 0)! Current score: 10, reward: 10
Food eaten at (2, 8)! Current score: 11, reward: 10
Episode 100: Total Reward = -10.5
Food eaten at (3, 0)! Current score: 12, reward: 10
Episode 110: Total Reward = -11.1
Food e