In [34]:
import pygame
import numpy as np
import random

# Define PongGame class
class PongGame:
    """Two-paddle Pong game rendered with pygame, used as an RL environment.

    The left paddle is driven by the scripted ``track_ball_with_left_paddle``;
    the right paddle is controlled through ``step(action)``. Frames are returned
    as (HEIGHT, WIDTH, 3) RGB arrays (pygame surfarray output, transposed).
    """

    def __init__(self):
        self.WIDTH = 700
        self.HEIGHT = 700
        self.PADDLE_WIDTH = 20
        self.PADDLE_HEIGHT = 100
        self.BALL_RADIUS = 7
        self.WINNING_SCORE = 20
        self.FPS = 60
        self.WHITE = (255, 255, 255)
        self.BLACK = (0, 0, 0)
        self.left_score = 0
        self.right_score = 0
        self.left_vel = 10
        self.right_vel = 20
        self.reward = 0
        # Built lazily in draw() so the same font object is reused every frame.
        self._score_font = None

        pygame.init()
        self.screen = pygame.display.set_mode((self.WIDTH, self.HEIGHT))
        self.clock = pygame.time.Clock()

        self.left_paddle = self.create_paddle(10, self.HEIGHT // 2 - self.PADDLE_HEIGHT // 2, self.left_vel)
        self.right_paddle = self.create_paddle(self.WIDTH - 30, self.HEIGHT // 2 - self.PADDLE_HEIGHT // 2, self.right_vel)
        self.ball = self.create_ball(self.WIDTH // 2, self.HEIGHT // 2)

    def create_paddle(self, x, y, vel):
        """Return a paddle dict at (x, y) moving ``vel`` pixels per step; direction starts at 0."""
        return {
            "x": x,
            "y": y,
            "width": self.PADDLE_WIDTH,
            "height": self.PADDLE_HEIGHT,
            "vel": vel,
            "direction": 0,
        }

    def create_ball(self, x, y):
        """Return a ball dict centred at (x, y) moving right (x_vel=8) with no vertical speed."""
        return {
            "x": x,
            "y": y,
            "radius": self.BALL_RADIUS,
            "x_vel": 8,
            "y_vel": 0,
        }

    def draw(self, render=False):
        """Paint paddles, ball, centre line and scores onto ``self.screen``.

        Args:
            render: if True, also call ``pygame.display.update()`` so the frame
                shows up in the window; otherwise only the surface is redrawn.
        """
        self.screen.fill(self.BLACK)

        # Draw paddles
        for paddle in (self.left_paddle, self.right_paddle):
            pygame.draw.rect(self.screen, self.WHITE,
                             (paddle["x"], paddle["y"], paddle["width"], paddle["height"]))

        # Draw ball
        pygame.draw.circle(self.screen, self.WHITE,
                           (int(self.ball["x"]), int(self.ball["y"])), self.ball["radius"])

        # Draw middle separating line (odd offsets are skipped, giving a dashed line)
        for i in range(10, self.HEIGHT, self.HEIGHT // 20):
            if i % 2 == 1:
                continue
            pygame.draw.rect(self.screen, self.WHITE,
                             (self.WIDTH // 2 - 5, i, 10, self.HEIGHT // 20))

        # Draw scores
        if getattr(self, "_score_font", None) is None:
            self._score_font = pygame.font.SysFont("comicsans", 50)
        left_score_text = self._score_font.render(f"{self.left_score}", 1, self.WHITE)
        right_score_text = self._score_font.render(f"{self.right_score}", 1, self.WHITE)
        self.screen.blit(left_score_text, (self.WIDTH // 4 - left_score_text.get_width() // 2, 20))
        self.screen.blit(right_score_text,
                         (self.WIDTH * 3 // 4 - right_score_text.get_width() // 2, 20))
        if render:
            pygame.display.update()  # Refresh the screen

    def move_ball(self):
        """Advance the ball one tick and resolve wall and paddle collisions.

        Sets ``self.reward = 5`` when the ball bounces off the right paddle.
        """
        self.ball["x"] += self.ball["x_vel"]
        self.ball["y"] += self.ball["y_vel"]

        # Ball collisions with top and bottom walls. Point the vertical velocity
        # back into the field rather than flipping its sign: a plain flip makes a
        # ball that is still past the wall (but already returning) flip again
        # every frame and get stuck.
        if self.ball["y"] <= 0:
            self.ball["y_vel"] = abs(self.ball["y_vel"])
        elif self.ball["y"] >= self.HEIGHT:
            self.ball["y_vel"] = -abs(self.ball["y_vel"])

        # Ball collision with the left paddle
        left = self.left_paddle
        if left["y"] <= self.ball["y"] <= left["y"] + left["height"]:
            if self.ball["x"] - self.ball["radius"] <= left["x"] + left["width"]:
                self.ball["x"] = left["x"] + left["width"] + self.ball["radius"]
                self._deflect_off_paddle(left)

        # Ball collision with the right paddle
        right = self.right_paddle
        if right["y"] <= self.ball["y"] <= right["y"] + right["height"]:
            if self.ball["x"] + self.ball["radius"] >= right["x"]:
                self.ball["x"] = right["x"] - self.ball["radius"]
                self._deflect_off_paddle(right)
                self.reward = 5  # right paddle returned the ball

    def _deflect_off_paddle(self, paddle):
        """Reverse the ball horizontally and set its vertical speed from the hit position on ``paddle``."""
        self.ball["x_vel"] *= -1

        # +1 if the ball was travelling up (negative y_vel), -1 if down, 0 if flat.
        if self.ball["y_vel"] < 0:
            ball_direction = 1
        elif self.ball["y_vel"] > 0:
            ball_direction = -1
        else:
            ball_direction = 0

        middle_y = paddle["y"] + paddle["height"] / 2
        difference_in_y = middle_y - self.ball["y"]
        # Use the speed magnitude: x_vel has just been reversed, so its sign differs
        # between the two paddles and would mirror the deflection on the right
        # paddle (hits above its centre would send the ball downward).
        reduction_factor = (paddle["height"] / 2) / abs(self.ball["x_vel"])
        red = difference_in_y / reduction_factor
        if paddle["direction"] != 0:
            self.ball["y_vel"] += -1 * ball_direction * paddle["direction"] * red
        else:
            self.ball["y_vel"] = -1 * red

    def track_ball_with_left_paddle(self):
        """Move the scripted left paddle one step toward the ball and record its direction."""
        # If the ball is above the paddle, move up
        if self.ball["y"] < self.left_paddle["y"]:
            self.left_paddle["y"] -= self.left_paddle["vel"]
            self.left_paddle["direction"] = 1
        # If the ball is below the paddle, move down
        elif self.ball["y"] > self.left_paddle["y"] + self.left_paddle["height"]:
            self.left_paddle["y"] += self.left_paddle["vel"]
            self.left_paddle["direction"] = -1
        else:
            self.left_paddle["direction"] = 0

    def step(self, action, render=False):
        """Apply ``action`` to the right paddle, advance the ball and return the new frame.

        Args:
            action: 0 moves the right paddle up, 1 moves it down; any other value,
                or a move blocked by the screen edge, leaves it still.
            render: forwarded to ``draw``; when True the window is refreshed.

        Returns:
            (frame, reward, done). ``frame`` is the (HEIGHT, WIDTH, 3) RGB screen.
            ``reward`` is reset to 0, set to 5 on a right-paddle hit, and
            overwritten with +10 / -10 when the ball leaves on the left / right.
            ``done`` is True once the scoring side reaches WINNING_SCORE.
        """
        paddle = self.right_paddle
        max_y = self.HEIGHT - self.PADDLE_HEIGHT
        # Clamp so the paddle never ends a move partly off-screen.
        if action == 0 and paddle["y"] > 0:
            paddle["y"] = max(0, paddle["y"] - paddle["vel"])
            paddle["direction"] = 1
        elif action == 1 and paddle["y"] < max_y:
            paddle["y"] = min(max_y, paddle["y"] + paddle["vel"])
            paddle["direction"] = -1
        else:
            paddle["direction"] = 0

        self.reward = 0
        self.move_ball()  # Move the ball

        # Check if the ball missed the paddles
        done = False
        if self.ball["x"] < 0:  # Left side
            self.right_score += 1
            self.reset_ball()
            self.reward = 10  # Positive reward when the ball goes left
            done = self.right_score >= self.WINNING_SCORE
        elif self.ball["x"] > self.WIDTH:  # Right side
            self.left_score += 1
            self.reset_ball()
            self.reward = -10
            done = self.left_score >= self.WINNING_SCORE

        self.draw(render)
        return self._capture_frame(), self.reward, done

    def reset(self):
        """Start a new match: recentre paddles, relaunch the ball, zero scores and reward.

        Returns:
            The initial (HEIGHT, WIDTH, 3) RGB frame.
        """
        self.left_paddle = self.create_paddle(10, self.HEIGHT // 2 - self.PADDLE_HEIGHT // 2, self.left_vel)
        self.right_paddle = self.create_paddle(self.WIDTH - 30, self.HEIGHT // 2 - self.PADDLE_HEIGHT // 2, self.right_vel)
        self.reset_ball()
        self.left_score = 0
        self.right_score = 0
        self.reward = 0
        return self.get_frame()  # Return the initial frame

    def reset_ball(self):
        """Recentre the ball and launch it left or right at a random angle in [-45, 45] degrees."""
        self.ball = self.create_ball(self.WIDTH // 2, self.HEIGHT // 2)
        if not random.choice([True, False]):
            self.ball["x_vel"] = -self.ball["x_vel"]
        angle = random.uniform(-45, 45)  # Random launch angle
        self.ball["y_vel"] = self.ball["x_vel"] * np.tan(np.radians(angle))  # Calculate y-velocity

    def get_frame(self):
        """Redraw the game (without refreshing the window) and return the RGB frame."""
        self.draw()
        return self._capture_frame()

    def _capture_frame(self):
        """Return ``self.screen`` as a (HEIGHT, WIDTH, 3) array (surfarray is (W, H, 3))."""
        current_frame = pygame.surfarray.array3d(self.screen)
        return np.transpose(current_frame, (1, 0, 2))


In [35]:
import pygame
import numpy as np
import cv2
import collections
import torch as T
import random
import os
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

# Module-level constants mirroring PongGame's defaults. PongGame itself reads its
# own self.* attributes, not these, so keep the two in sync by hand.
WIDTH, HEIGHT = 700, 700
FPS = 60
PADDLE_WIDTH, PADDLE_HEIGHT = 20, 100
BALL_RADIUS = 7
WHITE = (255, 255, 255)
BLACK = (0, 0, 0)
WINNING_SCORE = 20  # matches PongGame.WINNING_SCORE

# Initialize Pygame (PongGame.__init__ also calls pygame.init(); calling it more than once is allowed).
pygame.init()

# Function to preprocess Pygame frame for DQN
def preprocess_frame(frame):
    """Turn an RGB game frame into a single-channel 84x84 array scaled by 1/255.

    Args:
        frame: RGB image array (assumed uint8, so the result lies in [0, 1]).

    Returns:
        Array of shape (1, 84, 84); the leading axis is the channel axis used by
        the CNN and squeezed away by the frame stacker.
    """
    gray = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
    small = cv2.resize(gray, (84, 84), interpolation=cv2.INTER_AREA)
    scaled = small / 255.0
    return scaled[np.newaxis, ...]

# Frame stacker to stack consecutive frames
class FrameStacker:
    """Keep the latest ``stack_size`` frames and hand them back as one array.

    Frames are squeezed before storage, so (1, H, W) inputs are kept as (H, W)
    and a stacked state is returned with shape (stack_size, H, W), oldest first.
    """

    def __init__(self, stack_size):
        self.stack_size = stack_size
        self.stack = collections.deque(maxlen=stack_size)

    def reset(self, initial_frame):
        """Fill every slot with ``initial_frame`` and return the stacked array."""
        first = initial_frame.squeeze()
        filler = (first for _ in range(self.stack_size))
        self.stack = collections.deque(filler, maxlen=self.stack_size)
        return np.array(self.stack)

    def add_frame(self, frame):
        """Append ``frame`` (the deque drops the oldest when full) and return the stacked array."""
        self.stack.append(frame.squeeze())
        return np.array(self.stack)

# Replay buffer to store transitions
class ReplayBuffer:
    """Fixed-capacity circular store of (state, action, reward, next_state, done) transitions.

    After ``max_size`` insertions, new transitions overwrite the oldest slot.
    ``n_actions`` is accepted for interface compatibility but not used.
    """

    def __init__(self, max_size, input_shape, n_actions):
        self.mem_size = max_size
        self.mem_cntr = 0
        state_shape = (self.mem_size, *input_shape)
        self.state_memory = np.zeros(state_shape, dtype=np.float32)
        self.new_state_memory = np.zeros(state_shape, dtype=np.float32)
        self.action_memory = np.zeros(self.mem_size, dtype=np.int64)
        self.reward_memory = np.zeros(self.mem_size, dtype=np.float32)
        self.terminal_memory = np.zeros(self.mem_size, dtype=bool)

    def store_transition(self, state, action, reward, state_, done):
        """Write one transition into the next circular slot."""
        slot = self.mem_cntr % self.mem_size
        self.state_memory[slot] = state
        self.new_state_memory[slot] = state_
        self.action_memory[slot] = action
        self.reward_memory[slot] = reward
        self.terminal_memory[slot] = done
        self.mem_cntr += 1

    def sample_transitions(self, batch_size):
        """Draw ``batch_size`` distinct stored transitions uniformly at random.

        Returns:
            Tuple (states, actions, rewards, next_states, dones) of numpy arrays.
        """
        filled = min(self.mem_cntr, self.mem_size)
        picked = np.random.choice(filled, batch_size, replace=False)
        memories = (
            self.state_memory,
            self.action_memory,
            self.reward_memory,
            self.new_state_memory,
            self.terminal_memory,
        )
        return tuple(memory[picked] for memory in memories)

# Deep Q-Network (DQN) definition
class DeepQNetwork(nn.Module):
    """Convolutional Q-network: three conv layers followed by two fully connected layers.

    Args:
        lr: learning rate of the RMSprop optimizer stored as ``self.optimizer``.
        n_actions: number of Q-values produced per state.
        input_dims: (channels, height, width) of a single state.
        name: checkpoint file name inside ``chkpt_dir``.
        chkpt_dir: directory used by ``save_checkpoint`` / ``load_checkpoint``;
            created on first save if it does not exist.
    """

    def __init__(self, lr, n_actions, input_dims, name, chkpt_dir):
        super(DeepQNetwork, self).__init__()
        self.checkpoint_dir = chkpt_dir
        self.checkpoint_file = os.path.join(self.checkpoint_dir, name)
        self.conv1 = nn.Conv2d(input_dims[0], 32, 8, stride=4)
        self.conv2 = nn.Conv2d(32, 64, 4, stride=2)
        self.conv3 = nn.Conv2d(64, 64, 3, stride=1)
        fc_input_dims = self.calculate_conv_output_dims(input_dims)
        self.fc1 = nn.Linear(fc_input_dims, 512)
        self.fc2 = nn.Linear(512, n_actions)

        self.optimizer = optim.RMSprop(self.parameters(), lr=lr)
        self.loss = nn.MSELoss()
        self.device = T.device("cuda:0" if T.cuda.is_available() else "cpu")
        self.to(self.device)

    def calculate_conv_output_dims(self, input_dims):
        """Return the flattened size of the conv stack's output for one input of shape ``input_dims``."""
        # Shape probe only: no autograd graph is needed.
        with T.no_grad():
            state = T.zeros(1, *input_dims)
            dims = self.conv3(self.conv2(self.conv1(state)))
        return int(np.prod(dims.size()))

    def forward(self, state):
        """Map a (batch, C, H, W) tensor to (batch, n_actions) Q-values."""
        x = F.relu(self.conv1(state))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = x.view(x.size(0), -1)  # Flatten
        x = F.relu(self.fc1(x))
        actions = self.fc2(x)
        return actions

    def save_checkpoint(self):
        """Save the state dict to ``self.checkpoint_file``, creating its directory if needed."""
        print('... saving checkpoint ...')
        # T.save does not create missing directories (e.g. a fresh 'models/').
        if self.checkpoint_dir:
            os.makedirs(self.checkpoint_dir, exist_ok=True)
        T.save(self.state_dict(), self.checkpoint_file)

    def load_checkpoint(self):
        """Load the state dict from ``self.checkpoint_file`` onto this network's device."""
        print('... loading checkpoint ...')
        self.load_state_dict(T.load(self.checkpoint_file, map_location=T.device(self.device)))

# DQN agent to interact with the game
class DQNAgent:
    """Epsilon-greedy DQN agent with an online network and a periodically synced target network.

    Args:
        gamma: discount factor for the bootstrapped target.
        epsilon: initial exploration rate; each ``learn`` step lowers it by
            ``eps_dec`` down to ``eps_min``.
        lr: learning rate passed to both networks.
        n_actions: number of discrete actions.
        input_dims: shape of one state, e.g. (4, 84, 84).
        mem_size: replay-buffer capacity.
        batch_size: transitions sampled per ``learn`` step.
        name: checkpoint name prefix ('_q_eval' / '_q_next' are appended);
            'dqn' is used when None.
        chkpt_dir: checkpoint directory; the current directory when None.
    """

    def __init__(self, gamma, epsilon, lr, n_actions, input_dims, mem_size, batch_size, eps_min=0.01, eps_dec=5e-7, name = None, chkpt_dir = None):
        # The None defaults previously crashed (None + '_q_eval', os.path.join(None, ...)).
        if name is None:
            name = 'dqn'
        if chkpt_dir is None:
            chkpt_dir = '.'
        self.gamma = gamma
        self.epsilon = epsilon
        self.lr = lr
        self.n_actions = n_actions
        self.input_dims = input_dims
        self.batch_size = batch_size
        self.eps_min = eps_min
        self.eps_dec = eps_dec
        self.learn_step_counter = 0
        self.memory = ReplayBuffer(mem_size, input_dims, n_actions)
        self.q_eval = DeepQNetwork(lr, n_actions, input_dims, name+'_q_eval', chkpt_dir)
        self.q_next = DeepQNetwork(lr, n_actions, input_dims, name+'_q_next', chkpt_dir)

    def choose_action(self, observation):
        """Pick an action for one state: greedy with probability 1 - epsilon, uniform random otherwise.

        ``observation`` is expected to have shape ``input_dims`` (no batch axis).
        """
        if np.random.random() > self.epsilon:
            # Inference only: no_grad skips building an autograd graph every step, and
            # as_tensor + unsqueeze avoids the slow list-of-ndarrays tensor path.
            state = T.as_tensor(np.asarray(observation), dtype=T.float32,
                                device=self.q_eval.device).unsqueeze(0)
            with T.no_grad():
                actions = self.q_eval(state)
            action = T.argmax(actions).item()
        else:
            action = np.random.choice(self.n_actions)  # e.g. 0: up, 1: down
        return action

    def store_transition(self, state, action, reward, state_, done):
        """Forward one transition to the replay buffer."""
        self.memory.store_transition(state, action, reward, state_, done)

    def save_models(self):
        """Save both the online and the target network checkpoints."""
        self.q_eval.save_checkpoint()
        self.q_next.save_checkpoint()

    def load_models(self):
        """Load both the online and the target network checkpoints."""
        self.q_eval.load_checkpoint()
        self.q_next.load_checkpoint()

    def learn(self):
        """Run one DQN update on a sampled minibatch (no-op until ``batch_size`` transitions are stored)."""
        if self.memory.mem_cntr < self.batch_size:
            return

        self.q_eval.optimizer.zero_grad()
        self.replace_target_network()

        states, actions, rewards, states_, dones = self.memory.sample_transitions(self.batch_size)
        device = self.q_eval.device
        states = T.tensor(states, dtype=T.float).to(device)
        actions = T.tensor(actions, dtype=T.long).to(device)
        rewards = T.tensor(rewards, dtype=T.float).to(device)
        states_ = T.tensor(states_, dtype=T.float).to(device)
        dones = T.tensor(dones, dtype=T.bool).to(device)
        indices = T.arange(self.batch_size, device=device)

        q_pred = self.q_eval(states)[indices, actions]

        # The target is a constant for this update: computing it without autograd
        # stops gradients from being accumulated into the never-stepped target net.
        with T.no_grad():
            q_next = self.q_next(states_).max(dim=1)[0]
            q_next[dones] = 0.0  # No future value if the game is over
            q_target = rewards + self.gamma * q_next

        loss = self.q_eval.loss(q_pred, q_target)
        loss.backward()
        self.q_eval.optimizer.step()

        self.learn_step_counter += 1
        self.epsilon = max(self.epsilon - self.eps_dec, self.eps_min)

    def replace_target_network(self):
        """Copy the online weights into the target network every 1000 learn steps (including step 0)."""
        if self.learn_step_counter % 1000 == 0:
            self.q_next.load_state_dict(self.q_eval.state_dict())




In [36]:
pong_game = PongGame()

# Agent trained in the next cell: exploration starts at 1.0 and decays each learn() step.
agent = DQNAgent(
    name='DQNAgent_final',
    chkpt_dir='models/',
    n_actions=2,  # Two actions: move up or down
    input_dims=(4, 84, 84),  # Stacking four frames
    gamma=0.99,
    lr=0.0001,
    epsilon=1.0,
    eps_min=0.1,
    eps_dec=25e-7,
    mem_size=20000,
    batch_size=32,
)

# Training bookkeeping read and updated by the training loop.
scores = []
best_score = -np.inf
frame_stacker = FrameStacker(4)


In [94]:

n_episodes = 200
# Resume where `scores` left off: 0 on a fresh kernel, or the next episode
# index when this cell is re-run after an interruption (instead of a
# hard-coded start that only matches one particular kernel state).
start_episode = len(scores)
for episode in range(start_episode, n_episodes):
    done = False
    ep_reward = 0
    frame = preprocess_frame(pong_game.reset())  # Preprocess the first frame
    stacked_frames = frame_stacker.reset(frame)  # Initialize frame stack

    while not done:
        pong_game.track_ball_with_left_paddle()
        action = agent.choose_action(stacked_frames)  # Get the action
        next_frame, reward, done = pong_game.step(action)  # Get next state and reward
        next_frame = preprocess_frame(next_frame)  # Preprocess the next frame
        stacked_frames_ = frame_stacker.add_frame(next_frame)  # Stack frames

        agent.store_transition(stacked_frames, action, reward, stacked_frames_, done)  # Store transitions
        agent.learn()  # Train the DQN agent on stored transitions

        stacked_frames = stacked_frames_  # Update current state
        ep_reward += reward  # Add to the score

    scores.append(pong_game.right_score)
    avg_score = np.mean(scores[-100:])
    print(f"Episode {episode}, Reward: {ep_reward}, Score: {pong_game.right_score}, Average score: {avg_score} , Epsilon: {agent.epsilon}")  # Display the episode score
    # Checkpoint whenever the running average (last 100 episodes) improves.
    if avg_score > best_score:
        agent.save_models()
        best_score = avg_score

pygame.quit()  # Clean up Pygame resources


Episode 87, Reward: 450, Score: 20, Average score: 10.363636363636363 , Epsilon: 0.1
... saving checkpoint ...
... saving checkpoint ...
Episode 88, Reward: 465, Score: 20, Average score: 10.47191011235955 , Epsilon: 0.1
... saving checkpoint ...
... saving checkpoint ...
Episode 89, Reward: 460, Score: 20, Average score: 10.577777777777778 , Epsilon: 0.1
... saving checkpoint ...
... saving checkpoint ...
Episode 90, Reward: 460, Score: 20, Average score: 10.68131868131868 , Epsilon: 0.1
... saving checkpoint ...
... saving checkpoint ...
Episode 91, Reward: 405, Score: 20, Average score: 10.782608695652174 , Epsilon: 0.1
... saving checkpoint ...
... saving checkpoint ...
Episode 92, Reward: 435, Score: 20, Average score: 10.881720430107526 , Epsilon: 0.1
... saving checkpoint ...
... saving checkpoint ...
Episode 93, Reward: 515, Score: 20, Average score: 10.97872340425532 , Epsilon: 0.1
... saving checkpoint ...
... saving checkpoint ...
Episode 94, Reward: 520, Score: 20, Average 

KeyboardInterrupt: 

In [52]:
# Display the online Q-network's layer summary (module repr).
agent.q_eval

DeepQNetwork(
  (conv1): Conv2d(4, 32, kernel_size=(8, 8), stride=(4, 4))
  (conv2): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2))
  (conv3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
  (fc1): Linear(in_features=3136, out_features=512, bias=True)
  (fc2): Linear(in_features=512, out_features=2, bias=True)
  (loss): MSELoss()
)

In [None]:
pygame.quit()  # Shut down pygame and its window; constructing PongGame() calls pygame.init() again.

In [12]:
cv2.destroyAllWindows()  # Close the OpenCV windows opened with cv2.imshow.

In [95]:
agent.save_models()  # Write both q_eval and q_next weights to their checkpoint files.

... saving checkpoint ...
... saving checkpoint ...


In [40]:
# Greedy evaluation agent: here it only loads weights and picks actions, so it
# needs neither exploration nor a real replay memory.
agent2 = DQNAgent(
    gamma=0.99,
    epsilon=0,
    lr=0.0001,
    n_actions=2,  # Two actions: move up or down
    input_dims=(4, 84, 84),  # Stacking four frames
    # ReplayBuffer preallocates two float32 arrays of mem_size states; 20000
    # stacked 4x84x84 states would reserve ~4.5 GB that evaluation never uses.
    mem_size=32,
    batch_size=32,
    eps_min=0.0,  # keep epsilon at 0 even if learn() were called (0.1 would raise it)
    eps_dec=1e-4,
    chkpt_dir='models/',
    name='DQNAgent_final'
)

In [41]:
agent2.load_models()  # Load the 'DQNAgent_final' q_eval/q_next checkpoints saved during training.

... loading checkpoint ...
... loading checkpoint ...


In [42]:
def generate_saliency_map(model, input_image, action, device=None):
    """Gradient saliency of ``model``'s Q-value for ``action`` w.r.t. the last input channel.

    Args:
        model: network mapping a (1, C, H, W) float tensor to (1, n_actions) outputs.
        input_image: state of shape (C, H, W), e.g. a stack of frames.
        action: index of the output to explain.
        device: device to compute on. ``None`` (default) uses the device the
            model's parameters already live on, so the model is not moved; an
            explicit device moves the model there.

    Returns:
        (H, W) numpy array of |d output[action] / d input| for the last channel
        (the newest frame when the stack comes from FrameStacker).
    """
    if device is None:
        device = next(model.parameters()).device
    else:
        model.to(device)

    was_training = model.training
    model.eval()
    try:
        image = T.tensor(input_image, dtype=T.float32, device=device).unsqueeze(0)
        image.requires_grad_(True)
        q_value = model(image)[0, action]
        # Differentiate w.r.t. the input only; the model's parameter .grad stays untouched.
        (grad,) = T.autograd.grad(q_value, image)
    finally:
        model.train(was_training)

    # Pick the channel before dropping the batch axis so a single-channel input
    # still yields an (H, W) map rather than one row.
    return grad[0, -1].abs().cpu().numpy()

In [96]:
pong_game = PongGame()

try:
    for episode in range(1, 11):
        done = False
        score = 0
        frame = preprocess_frame(pong_game.reset())  # Preprocess the first frame
        stacked_frames = frame_stacker.reset(frame)  # Initialize frame stack

        while not done:
            pong_game.clock.tick(60)
            # Process pygame's event queue so the rendered window stays responsive.
            pygame.event.pump()
            pong_game.track_ball_with_left_paddle()
            action = agent2.choose_action(stacked_frames)
            next_frame, reward, done = pong_game.step(action, True)
            score += reward
            # PongGame frames are RGB; OpenCV windows expect BGR channel order.
            cv2.imshow("Original Image", cv2.cvtColor(next_frame, cv2.COLOR_RGB2BGR))

            next_frame = preprocess_frame(next_frame)  # Preprocess the next frame
            stacked_frames = frame_stacker.add_frame(next_frame)

            saliency_map = generate_saliency_map(agent2.q_eval, stacked_frames, action)
            resized_saliency = cv2.resize(saliency_map, (pong_game.WIDTH, pong_game.HEIGHT),
                                          interpolation=cv2.INTER_NEAREST)
            # imshow maps float images in [0, 1] to black..white; raw gradient
            # magnitudes are not on that scale, so rescale by the frame's peak.
            peak = resized_saliency.max()
            if peak > 0:
                resized_saliency = resized_saliency / peak
            cv2.imshow("Saliency map2", resized_saliency)
            cv2.waitKey(1)

        print(f"Episode {episode}, Score {score}")
finally:
    # Close the preview windows even when the loop is interrupted.
    cv2.destroyAllWindows()

KeyboardInterrupt: 