In [1]:
import pygame
import torch
import numpy as np
import time
import torch.nn as nn
import torch.optim as optim
import random
from collections import deque
import matplotlib.pyplot as plt
from IPython.display import clear_output

pygame 2.6.1 (SDL 2.28.4, Python 3.12.8)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [2]:
class DQN(nn.Module):
    def __init__(self, input_size, output_size):
        super(DQN, self).__init__()
        self.network = nn.Sequential(
            nn.Linear(input_size, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, output_size)
        )
    
    def forward(self, x):
        return self.network(x)

In [3]:
class HumanVsAIPong:
    def __init__(self):
        # Initialize game parameters
        self.width = 400  # Match original training width
        self.height = 400  # Match original training height
        self.paddle_width = 10  # Match training paddle width
        self.paddle_height = 60  # Match training paddle height
        self.ball_size = 10
        self.paddle_speed = 5  # Match training speed
        self.ball_speed = 5    # Match training speed
        
         # AI performance parameters
        self.ai_update_frequency = 1  # Update every frame for better performance
        self.prediction_threshold = 0.8  # Confidence threshold for actions
        self.frame_count = 0
        
        # Initialize positions
        self.human_paddle_pos = self.height // 2
        self.ai_paddle_pos = self.height // 2
        self.ball_pos = [self.width // 2, self.height // 2]
        self.ball_direction = [1, 1]
        
        # Initialize scores
        self.human_score = 0
        self.ai_score = 0
        
        # Initialize Pygame
        pygame.init()
        self.screen = pygame.display.set_mode((self.width, self.height))
        pygame.display.set_caption("Human vs AI Pong")
        self.clock = pygame.time.Clock()
        
        # Initialize font
        self.font = pygame.font.Font(None, 74)
        
    def get_ai_state(self):
        # Convert game state to AI input format
        return np.array([
            self.ai_paddle_pos / self.height,
            self.ball_pos[1] / self.height,
            (self.width - self.ball_pos[0]) / self.width,  # Flip x-coordinate for AI
            -self.ball_direction[0],  # Flip x-direction for AI
            self.ball_direction[1]
        ])
    
    def reset_ball(self, direction):
        self.ball_pos = [self.width // 2, self.height // 2]
        self.ball_direction = [direction, random.uniform(-1, 1)]
        self.normalize_ball_direction()
    
    def normalize_ball_direction(self):
        # Normalize ball direction vector
        length = np.sqrt(self.ball_direction[0]**2 + self.ball_direction[1]**2)
        self.ball_direction = [self.ball_direction[0]/length, self.ball_direction[1]/length]
    
    def update(self, ai_model):
        # Handle events and human input
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                return False
            
        keys = pygame.key.get_pressed()
        if keys[pygame.K_w]:
            self.human_paddle_pos = max(self.paddle_height // 2,
                                      self.human_paddle_pos - self.paddle_speed)
        if keys[pygame.K_s]:
            self.human_paddle_pos = min(self.height - self.paddle_height // 2,
                                      self.human_paddle_pos + self.paddle_speed)
        
        # Update AI paddle
        if self.frame_count % self.ai_update_frequency == 0:
            state = torch.FloatTensor(self.get_ai_state()).unsqueeze(0)
            with torch.no_grad():
                action = ai_model(state).max(1)[1].item()
                
            if action == 1:  # Move up
                self.ai_paddle_pos = max(self.paddle_height // 2,
                                       self.ai_paddle_pos - self.paddle_speed)
            elif action == 2:  # Move down
                self.ai_paddle_pos = min(self.height - self.paddle_height // 2,
                                       self.ai_paddle_pos + self.paddle_speed)
        
        # Update ball position
        self.ball_pos[0] += self.ball_speed * self.ball_direction[0]
        self.ball_pos[1] += self.ball_speed * self.ball_direction[1]
        
        # Ball collision with top and bottom
        if self.ball_pos[1] <= 0 or self.ball_pos[1] >= self.height:
            self.ball_direction[1] *= -1
        
        # Ball collision with paddles
        # Human paddle
        if (self.ball_pos[0] <= self.paddle_width + self.ball_size//2 and
            abs(self.ball_pos[1] - self.human_paddle_pos) < self.paddle_height // 2):
            self.ball_direction[0] *= -1
            # Add some randomness to bounce
            self.ball_direction[1] += random.uniform(-0.2, 0.2)
            self.normalize_ball_direction()
        
        # AI paddle
        if (self.ball_pos[0] >= self.width - self.paddle_width - self.ball_size//2 and
            abs(self.ball_pos[1] - self.ai_paddle_pos) < self.paddle_height // 2):
            self.ball_direction[0] *= -1
            # Add some randomness to bounce
            self.ball_direction[1] += random.uniform(-0.2, 0.2)
            self.normalize_ball_direction()
        
        # Score points
        if self.ball_pos[0] <= 0:
            self.ai_score += 1
            self.reset_ball(1)  # Ball moves towards human
        elif self.ball_pos[0] >= self.width:
            self.human_score += 1
            self.reset_ball(-1)  # Ball moves towards AI
        
        self.frame_count += 1
        return True
    
    def render(self):
        # Clear screen
        self.screen.fill((0, 0, 0))
        
        # Draw paddles
        pygame.draw.rect(self.screen, (255, 255, 255),
                        (0, self.human_paddle_pos - self.paddle_height // 2,
                         self.paddle_width, self.paddle_height))
        pygame.draw.rect(self.screen, (255, 255, 255),
                        (self.width - self.paddle_width,
                         self.ai_paddle_pos - self.paddle_height // 2,
                         self.paddle_width, self.paddle_height))
        
        # Draw ball
        pygame.draw.circle(self.screen, (255, 255, 255),
                         (int(self.ball_pos[0]), int(self.ball_pos[1])),
                         self.ball_size // 2)
        
        # Draw center line
        pygame.draw.line(self.screen, (255, 255, 255),
                        (self.width // 2, 0),
                        (self.width // 2, self.height),
                        2)
        
        # Draw scores
        human_text = self.font.render(str(self.human_score), True, (255, 255, 255))
        ai_text = self.font.render(str(self.ai_score), True, (255, 255, 255))
        self.screen.blit(human_text, (self.width // 4, 20))
        self.screen.blit(ai_text, (3 * self.width // 4, 20))
        
        pygame.display.flip()
        self.clock.tick(60)
    
    def close(self):
        pygame.quit()

In [4]:
def retrain_model(base_model_path, episodes=100, learning_rate=0.0001):
    """Retrain an existing model to improve its performance"""
    # Load the existing model
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = DQN(5, 3).to(device)
    model.load_state_dict(torch.load(base_model_path))
    
    # Create optimizer with lower learning rate for fine-tuning
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    memory = ReplayBuffer(10000)
    
    # Create environment
    env = PongEnv()
    
    # Training parameters
    gamma = 0.99
    batch_size = 64
    epsilon = 0.1  # Lower epsilon for more exploitation
    
    try:
        scores = []
        for episode in range(episodes):
            state = env.reset()
            score = 0
            done = False
            
            while not done:
                # Select action with lower exploration
                if random.random() > epsilon:
                    with torch.no_grad():
                        state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
                        action = model(state_tensor).max(1)[1].item()
                else:
                    action = random.randrange(3)
                
                # Take action
                next_state, reward, done = env.step(action)
                score += reward
                
                # Store transition
                memory.push(state, action, reward, next_state, done)
                state = next_state
                
                # Train if enough samples
                if len(memory) >= batch_size:
                    states, actions, rewards, next_states, dones = memory.sample(batch_size)
                    
                    states = torch.FloatTensor(states).to(device)
                    actions = torch.LongTensor(actions).to(device)
                    rewards = torch.FloatTensor(rewards).to(device)
                    next_states = torch.FloatTensor(next_states).to(device)
                    dones = torch.FloatTensor(dones).to(device)
                    
                    # Compute Q values
                    current_q = model(states).gather(1, actions.unsqueeze(1))
                    next_q = model(next_states).max(1)[0].detach()
                    target_q = rewards + gamma * next_q * (1 - dones)
                    
                    # Compute loss and update
                    loss = nn.MSELoss()(current_q.squeeze(), target_q)
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
            
            scores.append(score)
            if episode % 10 == 0:
                print(f"Episode: {episode}, Score: {score}")
    
    except KeyboardInterrupt:
        print("\nTraining interrupted")
    finally:
        env.close()
    
    # Save the retrained model with a new name
    new_model_path = base_model_path.replace('.pth', '_retrained.pth')
    torch.save(model.state_dict(), new_model_path)
    return new_model_path, scores


In [5]:
def play_vs_ai(model_path):
    # Load the AI model
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    ai_model = DQN(5, 3).to(device)
    ai_model.load_state_dict(torch.load(model_path))
    ai_model.eval()
    
    # Create game environment
    env = HumanVsAIPong()
    
    print("Controls:")
    print("W - Move paddle up")
    print("S - Move paddle down")
    print("Close window to quit")
    
    try:
        running = True
        while running:
            running = env.update(ai_model)
            env.render()
    except (KeyboardInterrupt, SystemExit):
        print("\nGame ended")
    finally:
        env.close()

In [7]:
play_vs_ai('pong_model.pth')

  ai_model.load_state_dict(torch.load(model_path))


Controls:
W - Move paddle up
S - Move paddle down
Close window to quit
