In [74]:
import gymnasium as gym
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import torch.optim as optim

In [47]:
import gymnasium as gym
import numpy as np
from collections import deque

class AtariEnvWrapper:
    """Gymnasium Atari environment with grayscale preprocessing and frame stacking.

    Observations returned by ``reset`` and ``step`` are ``uint8`` arrays of shape
    ``(stack_size, 84, 84)`` holding the most recent preprocessed frames, oldest
    first.

    Args:
        env_name: Environment id passed to ``gym.make``.
        render_mode: Render mode passed to ``gym.make``.
        stack_size: Number of consecutive frames kept in each observation.
    """

    def __init__(self, env_name="Breakout-v5", render_mode="rgb_array", stack_size=4):
        # NOTE(review): the cells in this notebook pass "ALE/Breakout-v5"; confirm
        # that the bare default id resolves in the installed Gymnasium/ALE version.
        self.env = gym.make(env_name, render_mode=render_mode)
        self.stack_size = stack_size
        self.frames = deque(maxlen=stack_size)

    def preprocess(self, frame):
        """Convert one RGB frame into an 84x84 grayscale ``uint8`` image."""
        import cv2  # kept local: OpenCV is only needed for frame preprocessing

        gray = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
        # cv2.resize takes (width, height), so the result has 110 rows and 84 columns.
        resized = cv2.resize(gray, (84, 110), interpolation=cv2.INTER_AREA)
        # Keep rows 18..101 to end up with a square 84x84 image.
        cropped = resized[18:102, :]
        return cropped.astype(np.uint8)  # (84, 84)

    def reset(self, seed=None):
        """Reset the environment and return the initial stacked observation.

        Args:
            seed: Optional seed forwarded to the underlying ``env.reset``.

        Returns:
            Array of shape ``(stack_size, 84, 84)`` made of ``stack_size`` copies
            of the first preprocessed frame.
        """
        obs, _ = self.env.reset(seed=seed)
        frame = self.preprocess(obs)
        self.frames = deque([frame] * self.stack_size, maxlen=self.stack_size)
        return np.stack(self.frames, axis=0)  # (stack_size, 84, 84)

    def step(self, action):
        """Advance the environment by one action.

        Returns:
            ``(stacked_obs, reward, done, info)`` where ``done`` is true when the
            underlying environment reports either ``terminated`` or ``truncated``.
        """
        obs, reward, terminated, truncated, info = self.env.step(action)
        frame = self.preprocess(obs)

        missing = self.stack_size - len(self.frames)
        if missing > 0:
            # The stack is not full yet (e.g. step() on a wrapper that was never
            # reset): pad with the current frame so the observation always has
            # stack_size frames instead of a shorter, wrongly shaped array.
            self.frames.extend([frame] * missing)
        else:
            self.frames.append(frame)

        stacked_obs = np.stack(self.frames, axis=0)  # (stack_size, 84, 84)
        done = terminated or truncated
        return stacked_obs, reward, done, info

    def close(self):
        """Close the underlying Gymnasium environment."""
        self.env.close()


In [49]:
# Smoke test: build the wrapped Breakout environment, reset it, and print the
# shape of the stacked observation returned by reset().
env = AtariEnvWrapper("ALE/Breakout-v5")
obs = env.reset()
print(obs.shape)  # expected: (4, 84, 84)

(4, 84, 84)


In [51]:
# Take a single random step. The wrapper does not expose action_space itself,
# so it is read from the wrapped Gymnasium env stored on `env.env`.
action = env.env.action_space.sample()   # env.env to access underlying Gym env
obs, reward, done, info = env.step(action)


In [53]:


# Build a fresh environment and reset it. This rebinds `env` and `obs`; the
# instance created in the earlier cell is replaced without being closed here.
env = AtariEnvWrapper("ALE/Breakout-v5")
obs = env.reset()

In [66]:
class DQN(nn.Module):
    """Convolutional Q-network: two conv layers followed by two linear layers.

    The first linear layer expects ``32 * 9 * 9`` features, which is what the
    two convolutions produce for 84x84 inputs.

    Args:
        in_channels: Number of input channels (stacked frames). Defaults to 4.
        num_actions: Number of Q-values produced per input. Defaults to 4.
    """

    def __init__(self, in_channels=4, num_actions=4):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 16, kernel_size=8, stride=4)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=4, stride=2)
        # 84x84 input -> 20x20 after conv1 -> 9x9 after conv2, with 32 channels.
        self.linear1 = nn.Linear(32 * 9 * 9, 256)
        self.output = nn.Linear(256, num_actions)

    def forward(self, x):
        """Return Q-values of shape ``(batch, num_actions)`` for a batch ``x``.

        Args:
            x: Float tensor of shape ``(batch, in_channels, 84, 84)``.
        """
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        # Flatten everything except the batch dimension.
        x = x.flatten(start_dim=1)
        x = F.relu(self.linear1(x))
        return self.output(x)
        

In [67]:
class ReplayBuffer:
    """Fixed-capacity store of ``(state, action, reward, next_state, done)`` tuples.

    Once ``capacity`` transitions are stored, each new transition overwrites the
    oldest one in place, so ``push`` stays O(1) instead of shifting the list.

    Args:
        capacity: Maximum number of transitions kept; must be positive.

    Raises:
        ValueError: If ``capacity`` is not positive.
    """

    def __init__(self, capacity):
        if capacity <= 0:
            raise ValueError("capacity must be a positive integer")
        self.capacity = capacity
        self.buffer = []
        # Slot that the next push will overwrite once the buffer is full.
        self._next_idx = 0

    def push(self, state, action, reward, next_state, done):
        """Store one transition, replacing the oldest one when the buffer is full."""
        transition = (state, action, reward, next_state, done)
        if len(self.buffer) < self.capacity:
            self.buffer.append(transition)
        else:
            self.buffer[self._next_idx] = transition
        self._next_idx = (self._next_idx + 1) % self.capacity

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct transitions uniformly at random.

        Returns:
            ``(states, actions, rewards, next_states, dones)`` as tensors.
            ``states``, ``next_states``, ``rewards`` and ``dones`` are float32;
            ``actions`` is int64.

        Raises:
            ValueError: If ``batch_size`` exceeds the number of stored
                transitions (raised by ``random.sample``).
        """
        batch = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)

        # Stack in NumPy first: one conversion per batch instead of one tensor
        # per transition.
        states = torch.as_tensor(np.stack(states), dtype=torch.float32)
        actions = torch.tensor(actions, dtype=torch.long)
        rewards = torch.tensor(rewards, dtype=torch.float32)
        next_states = torch.as_tensor(np.stack(next_states), dtype=torch.float32)
        dones = torch.tensor(dones, dtype=torch.float32)

        return states, actions, rewards, next_states, dones

    def __len__(self):
        """Return the number of transitions currently stored."""
        return len(self.buffer)
    

In [68]:
# Replay storage later handed to the agent; keeps at most 10,000 transitions.
replay_buffer = ReplayBuffer(capacity=10000)

In [69]:
# Instantiate the Q-network defined by the DQN class.
model = DQN()


In [70]:
# Training-length settings (presumably episode count and per-episode step cap).
# NOTE(review): no other cell visible in this notebook reads these names —
# confirm they are still needed.
num_episodes = 1000
num_steps = 1000

In [71]:
class DQNAgent:
    """Epsilon-greedy Q-learning agent built around a single Q-network.

    Args:
        model: Network mapping a batch of states to per-action Q-values.
        optimizer: Optimizer over ``model``'s parameters.
        replay_buffer: Object providing ``push``, ``sample`` and ``__len__``.
        num_actions: Number of discrete actions to choose from.
        device: Device the model and the sampled batches are moved to.
        gamma: Discount factor used in the bootstrap target.
        epsilon_start: Initial exploration rate.
        epsilon_end: Lower bound for the exploration rate.
        epsilon_decay: Epsilon is reduced by ``1 / epsilon_decay`` on every
            ``select_action`` call until it reaches ``epsilon_end``.
    """

    def __init__(self, model, optimizer, replay_buffer, num_actions, device, gamma = 0.99, epsilon_start = 1.0, epsilon_end = 0.1, epsilon_decay=1000000):
        # Batches are moved to `device` in select_action/learn, so the network
        # has to live there too. Module.to() moves parameters in place, so an
        # optimizer built beforehand keeps pointing at the same parameters.
        self.model = model.to(device)
        self.optimizer = optimizer
        self.replay_buffer = replay_buffer
        self.num_actions = num_actions
        self.device = device

        self.gamma = gamma

        self.epsilon = epsilon_start
        self.epsilon_end = epsilon_end
        self.epsilon_decay = epsilon_decay
        self.step_count = 0

    def select_action(self, state):
        """Pick an action for ``state`` with an epsilon-greedy policy.

        Each call increments ``step_count`` and lowers ``epsilon`` by
        ``1 / epsilon_decay`` (never below ``epsilon_end``) before choosing.

        Returns:
            An int action index in ``[0, num_actions)``.
        """
        self.step_count += 1
        self.epsilon = max(self.epsilon_end, self.epsilon - (1 / self.epsilon_decay))

        if random.random() < self.epsilon:
            return random.randint(0, self.num_actions - 1)

        state = torch.tensor(state, dtype=torch.float32).unsqueeze(0).to(self.device)
        with torch.no_grad():
            q_values = self.model(state)
        return q_values.argmax(dim=1).item()

    def store(self, state, action, reward, next_state, done):
        """Add one transition to the replay buffer."""
        self.replay_buffer.push(state, action, reward, next_state, done)

    def learn(self, batch_size):
        """Run one gradient step on a batch sampled from the replay buffer.

        Returns:
            The MSE loss of the step as a float, or ``None`` when the buffer
            holds fewer than ``batch_size`` transitions and no update is made.
        """
        if len(self.replay_buffer) < batch_size:
            return None

        states, actions, rewards, next_states, dones = self.replay_buffer.sample(batch_size)

        states = states.to(self.device)
        actions = actions.to(self.device)
        rewards = rewards.to(self.device)
        next_states = next_states.to(self.device)
        dones = dones.to(self.device)

        # Q-value of the action that was actually taken in each transition.
        q_values = self.model(states)
        current_q = q_values.gather(1, actions.unsqueeze(1)).squeeze(1)

        with torch.no_grad():
            # Bootstrap from the best next-state action. Reducing over the
            # action dimension gives a (batch,) tensor that lines up with
            # rewards and dones; the unreduced (batch, num_actions) tensor
            # cannot be combined with them correctly.
            next_q = self.model(next_states).max(dim=1).values
            target_q = rewards + self.gamma * next_q * (1 - dones)

        loss = F.mse_loss(current_q, target_q)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.item()

In [76]:
# Choose the device first and move the network onto it before building the
# optimizer: the agent moves every batch to `device`, so a model left on the CPU
# would fail with a device mismatch on CUDA machines.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)  # nn.Module.to moves the parameters in place
optimizer = optim.SGD(model.parameters(), lr=0.001)
agent = DQNAgent(model, optimizer, replay_buffer, num_actions=4, device=device)

In [78]:
# agent = DQNAgent(model, optimizer, replay_buffer, num_actions=4, device=device)

# 
# action = agent.select_action(state)
# next_state, reward, done, _ = env.step(action)
# agent.store(state, action, reward, next_state, done)
# agent.learn(batch_size=32)


In [79]:
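# NOTE: disabled training loop. It calls `trange`, which no visible cell imports;
# add e.g. `from tqdm import trange` before re-enabling it.
# def train_dqn(agent, env, num_episodes=1000, batch_size=32, max_steps=10000, log_every=10):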
#     episode_rewards = []

#     for episode in trange(num_episodes):
#         state = env.reset()
#         total_reward = 0

#         for t in range(max_steps):
#             action = agent.select_action(state)
#             next_state, reward, done, _ = env.step(action)

#             agent.store(state, action, reward, next_state, done)
#             agent.learn(batch_size)

#             state = next_state
#             total_reward += reward

#             if done:
#                 break

#         episode_rewards.append(total_reward)

#         if (episode + 1) % log_every == 0:
#             avg_reward = sum(episode_rewards[-log_every:]) / log_every
#             print(f"Episode {episode+1}, Avg Reward (last {log_every}): {avg_reward:.2f}, Epsilon: {agent.epsilon:.3f}")

#     return episode_rewards