In [17]:
import pygame
from pygame.locals import K_w, K_UP, K_s, K_DOWN, QUIT, K_ESCAPE, KEYDOWN, KEYUP
import numpy as np
import torchvision.transforms as transforms
import torchvision
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import os
import shutil
from PIL import Image
from datetime import datetime


In [18]:
class pong_ai_game:
    def __init__(self, width_screen = 80, height_screen = 60):
        self.width_screen = width_screen
        self.height_screen = height_screen
        self.score = 0
        self.rect_ball_distance = [0.0,0.0]
        self.screen_color = (35, 35, 35)
        self.object_color = (251, 248, 243)
        self.left_rect_color = (0,255,0)
        self.right_rect_color = (0,255,0)
        self.game_screen = pygame.display.set_mode((self.width_screen, self.height_screen))
        self.clock = pygame.time.Clock()
        self.rect_speed = 3
        self.render_game = False
        self.reset()

    def reset(self):
        self.score = 0
        self.rect_ball_distance = [0.0,0.0]
        self.left_rect = pygame.Rect(0, 7*self.height_screen//16, 2, self.height_screen//8)
        self.right_rect = pygame.Rect(self.width_screen - 2, 7*self.height_screen//16, 2, self.height_screen//8)
        self.ball_rect = pygame.Rect(self.width_screen//2, self.height_screen//2, 2,2)
        self.ball_speed = [pow(-1, np.random.randint(0,2)), pow(-1, np.random.randint(0,2))]
        self.render_game = False


    def play_step(self, action):
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                pygame.quit()

        paddle, move = action

        #Move Up
        if move == 0:
            if paddle == 1 and self.left_rect.top > 0:
                # Move rectangle 1 up by rectangle speed
                self.left_rect = self.left_rect.move(0, -self.rect_speed)
                if self.left_rect.top < 0:
                    self.left_rect.top = 0
            elif paddle == 2 and self.right_rect.top > 0:
                # Move rectangle 2 up by rectangle speed
                self.right_rect = self.right_rect.move(0, -self.rect_speed)
                if self.right_rect.top < 0:
                    self.right_rect.top = 0
        #Move Down
        elif move == 1:
            if paddle == 1 and self.left_rect.bottom < 600:
                # Move rectangle 1 down by rectangle speed
                self.left_rect = self.left_rect.move(0, self.rect_speed)
                if self.left_rect.bottom > self.height_screen:
                    self.left_rect.bottom = self.height_screen
            elif paddle == 2 and self.right_rect.bottom < 600:
                # Move rectangle 2 down by rectangle speed
                self.right_rect = self.right_rect.move(0, self.rect_speed)
                if self.right_rect.bottom > self.height_screen:
                    self.right_rect.bottom = self.height_screen
        
        #Get vertical distance between ball and rect
        if self.ball_speed[0] < 0:
            self.rect_ball_distance[0] = self.ball_rect.centery - self.left_rect.centery
        elif self.ball_speed[0] > 0:
            self.rect_ball_distance[1] = self.ball_rect.centery - self.right_rect.centery

        self.game_screen.fill(self.screen_color)
        self.ball_rect = self.ball_rect.move(self.ball_speed[0],self.ball_speed[1])

        self.left_rectangle = pygame.draw.rect(self.game_screen, self.left_rect_color,self.left_rect)
        self.right_rectangle = pygame.draw.rect(self.game_screen, self.right_rect_color,self.right_rect)
        self.ball_rectangle = pygame.draw.rect(self.game_screen, self.object_color,self.ball_rect)

        if self.render_game:
            pygame.display.flip()
            self.clock.tick(60)

        self.ball_collision()

        if self.illegal_ball():
            return self.rect_ball_distance, True, self.score, self.get_state()

        return self.rect_ball_distance, False, self.score, self.get_state()

    def illegal_ball(self):
        if self.ball_rect.left <= 0 or self.ball_rect.right >= self.width_screen:
            return True
    def ball_collision(self):
        if self.left_rect.right == self.ball_rect.left:
            if (self.left_rect.bottom >= self.ball_rect.centery >= self.left_rect.top):
                self.ball_speed[0] = 1
                self.score += 1
        elif self.right_rect.left == self.ball_rect.right :
            if (self.right_rect.bottom >= self.ball_rect.centery >= self.right_rect.top):
                self.ball_speed[0] = -1
                self.score += 1
        if self.ball_rect.top <= 0 or self.ball_rect.bottom >= self.height_screen:
            self.ball_speed[1] *= -1



    def get_reward(self, agent, action, agent_obj):
        # agent 1 is the left paddle, agent 2 is the right paddle
        paddle_rect = self.left_rect if agent == 1 else self.right_rect
        ball_y = self.ball_rect.centery

        hit_ball_reward = 0
        if agent == 1 and self.ball_rect.left <= self.left_rect.right and \
        self.left_rect.top <= self.ball_rect.centery <= self.left_rect.bottom:
            hit_ball_reward = 2*self.score
        elif agent == 2 and self.ball_rect.right >= self.right_rect.left and \
            self.right_rect.top <= self.ball_rect.centery <= self.right_rect.bottom:
            hit_ball_reward = 2*self.score

        # Penalize the agent for missing the ball (ball out of bounds)
        miss_penalty = -5 if self.illegal_ball() else 0

        # Action-specific reward: reward for appropriate paddle movement
        action_reward = 0
        wrong_move = False  # New flag

        if action == 0:  # Move up
            if ball_y < paddle_rect.centery:
                agent_obj.right_action_count += 1
                reward_multiplier = 1.15 ** agent_obj.right_action_count
                action_reward = 0.2 * reward_multiplier
                agent_obj.wrong_action_count = 0  # Reset on good action
            else:
                wrong_move = True

        elif action == 1:  # Move down
            if ball_y > paddle_rect.centery:
                agent_obj.right_action_count += 1
                reward_multiplier = 1.15 ** agent_obj.right_action_count
                action_reward = 0.2 * reward_multiplier
                agent_obj.wrong_action_count = 0
            else:
                wrong_move = True

        # Apply exponential penalty if wrong action
        if wrong_move:
            agent_obj.wrong_action_count += 1
            penalty_multiplier = 1.2 ** agent_obj.wrong_action_count  # decay factor
            action_reward = -0.1 * penalty_multiplier
            agent_obj.right_action_count = 0

        # Combine all the rewards and penalties
        total_reward = hit_ball_reward + miss_penalty + action_reward
        return total_reward



    def get_state(self):
        screen = pygame.surfarray.array3d(self.game_screen)
        preprocess = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((224,224)),
            transforms.ToTensor()])
        screen = preprocess(screen)
        return screen

In [19]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torch

# Check if GPU is available, otherwise fall back to CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Training on {device}")


class MobileAgent:
    def __init__(self, lr, rectangle, action_space):
        self.action_space = action_space
        self.model = torchvision.models.mobilenet_v2(weights=torchvision.models.MobileNet_V2_Weights.DEFAULT)
        self.model.classifier[1] = nn.Linear(1280, self.action_space)
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)
        self.loss_fn = nn.CrossEntropyLoss()
        self.rectangle = rectangle  # 1 = left, 2 = right
        self.model.to(device)
        self.wrong_action_count = 0
        self.right_action_count = 0


    def act(self, state):
        with torch.no_grad():
            state = state.unsqueeze(0)  # Add batch dimension
            logits = self.model(state)
            action = torch.argmax(logits).item()
        return action
    def train_step(self, state, action, reward):
        self.model.train()
        state = state.unsqueeze(0).to(device)
        action = torch.tensor([action]).to(device)

        logits = self.model(state)
        log_probs = torch.log_softmax(logits, dim=1)
        selected_log_prob = log_probs[0, action.item()]  # select log prob of the taken action

        # Policy gradient loss: maximize reward -> minimize -log_prob * reward
        loss = -selected_log_prob * reward

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss.item()

Training on cuda


In [None]:
import pygame
import numpy as np
import torchvision.transforms as transforms

# Initialize game
pygame.init()
game = pong_ai_game(120,90)
left_agent = MobileAgent(lr=1e-4, rectangle=1, action_space=2)
right_agent = MobileAgent(lr=1e-4, rectangle=2, action_space=2)

# left_agent.model.load_state_dict(torch.load(r"model\2025_06_19_00_55_28_model.pth"))
# Optionally share initial weights
right_agent.model.load_state_dict(left_agent.model.state_dict())

n_episodes = 5000
for episode in range(n_episodes):
    game.reset()
    done = False
    training = False
    total_reward = [0, 0]  # For left and right agents
    left_agent_loss = 0.0
    right_agent_loss = 0.0
    if episode%50 == 0:
        game.render_game = True

    while not done:
        state = game.get_state().to(device)
        if game.score >=10 and training:
            print("EVALUATING")
            training = False
            left_agent.model.eval()
            right_agent.model.eval()
        left_action = left_agent.act(state)
        right_action = right_agent.act(state)

        # Send actions to game: (paddle_id, move)
        game.play_step((1, left_action))  # Left paddle
        distances, done, score, state_ = game.play_step((2, right_action))  # Right paddle


        reward_left = game.get_reward(1, left_action, left_agent)
        reward_right = game.get_reward(2, right_action, right_agent)
        total_reward[0] += reward_left
        total_reward[1] += reward_right

        # Train both agents
        if training:
            left_loss = left_agent.train_step(state, left_action, reward_left)
            right_loss = right_agent.train_step(state, right_action, reward_right)

            # Track and accumulate losses
            left_agent_loss += left_loss
            right_agent_loss += right_loss

        if left_agent.wrong_action_count > 0:
            game.left_rect_color = (255,0,0)
        else:
            game.left_rect_color = (0,255,0)
        if right_agent.wrong_action_count > 0:
            game.right_rect_color = (255,0,0)
        else:
            game.right_rect_color = (0,255,0)


    if episode % 50 == 0 and episode != 0:
        rn = datetime.today().strftime('%Y_%m_%d_%H_%M_%S')
        # Use os.path.join to create a valid file path
        file_path = os.path.join("model", f"{rn}_model.pth")
        
        # Save the model's state_dict (weights) at that point
        torch.save(left_agent.model.state_dict(), file_path)
        right_agent.model.load_state_dict(left_agent.model.state_dict())
        print(f"Model saved at {file_path}")
    game.render_game = False
    # Print average loss per episode for both agents
    print(f"Episode {episode}, Score: {score}, Left Agent Loss: {left_agent_loss:.4f}, Right Agent Loss: {right_agent_loss:.4f},  Left Reward: {total_reward[0]:.4f},  Right Reward: {total_reward[1]:.4f}")

  left_agent.model.load_state_dict(torch.load(r"model\2025_06_19_00_55_28_model.pth"))


Episode 0, Score: 9, Left Agent Loss: 0.0000, Right Agent Loss: 0.0000,  Left Reward: 35.5439,  Right Reward: 47.9336
Episode 1, Score: 4, Left Agent Loss: 0.0000, Right Agent Loss: 0.0000,  Left Reward: 12.1478,  Right Reward: -4.2997
Episode 2, Score: 13, Left Agent Loss: 0.0000, Right Agent Loss: 0.0000,  Left Reward: 40.3436,  Right Reward: 72.3049
Episode 3, Score: 9, Left Agent Loss: 0.0000, Right Agent Loss: 0.0000,  Left Reward: -4.4986,  Right Reward: 1.2346
Episode 4, Score: 9, Left Agent Loss: 0.0000, Right Agent Loss: 0.0000,  Left Reward: 7.4033,  Right Reward: 19.7211
Episode 5, Score: 9, Left Agent Loss: 0.0000, Right Agent Loss: 0.0000,  Left Reward: 40.7310,  Right Reward: 47.6513
Episode 6, Score: 2, Left Agent Loss: 0.0000, Right Agent Loss: 0.0000,  Left Reward: 0.9325,  Right Reward: -2.1062
Episode 7, Score: 1, Left Agent Loss: 0.0000, Right Agent Loss: 0.0000,  Left Reward: 1.2561,  Right Reward: -0.1011
Episode 8, Score: 3, Left Agent Loss: 0.0000, Right Agent L

KeyboardInterrupt: 