In [1]:
import pygame
from pygame.locals import K_w, K_UP, K_s, K_DOWN, QUIT, K_ESCAPE, KEYDOWN, KEYUP
import numpy as np
import torchvision.transforms as transforms
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import os
import shutil
from PIL import Image

pygame 2.6.0 (SDL 2.28.4, Python 3.11.9)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [None]:
class pong_ai_game:
    def __init__(self, width_screen = 80, height_screen = 60):
        self.width_screen = width_screen
        self.height_screen = height_screen
        self.score = 0
        self.reward = [0.0,0.0]
        self.screen_color = (35, 35, 35)
        self.object_color = (251, 248, 243)
        self.game_screen = pygame.display.set_mode((self.width_screen, self.height_screen))
        self.clock = pygame.time.Clock()
        self.rect_speed = 2
        self.render_game = False
        self.reset()

    def reset(self):
        self.score = 0
        self.reward = [0.0,0.0]
        self.left_rect = pygame.Rect(0, 7*self.height_screen//16, 2, self.height_screen//8)
        self.right_rect = pygame.Rect(self.width_screen - 2, 7*self.height_screen//16, 2, self.height_screen//8)
        self.ball_rect = pygame.Rect(self.width_screen//2, self.height_screen//2, 2,2)
        self.ball_speed = [pow(-1, np.random.randint(0,2)), pow(-1, np.random.randint(0,2))]
        self.render_game = False


    def play_step(self, action):
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                pygame.quit()

        paddle, move = action

        #Move Up
        if move == 0:
            if paddle == 1 and self.left_rect.top > 0:
                # Move rectangle 1 up by rectangle speed
                self.left_rect = self.left_rect.move(0, -self.rect_speed)
                if self.left_rect.top < 0:
                    self.left_rect.top = 0
            elif paddle == 2 and self.right_rect.top > 0:
                # Move rectangle 2 up by rectangle speed
                self.right_rect = self.right_rect.move(0, -self.rect_speed)
                if self.right_rect.top < 0:
                    self.right_rect.top = 0
        #Move Down
        elif move == 1:
            if paddle == 1 and self.left_rect.bottom < 600:
                # Move rectangle 1 down by rectangle speed
                self.left_rect = self.left_rect.move(0, self.rect_speed)
                if self.left_rect.bottom > self.height_screen:
                    self.left_rect.bottom = self.height_screen
            elif paddle == 2 and self.right_rect.bottom < 600:
                # Move rectangle 2 down by rectangle speed
                self.right_rect = self.right_rect.move(0, self.rect_speed)
                if self.right_rect.bottom > self.height_screen:
                    self.right_rect.bottom = self.height_screen

        self.game_screen.fill(self.screen_color)
        self.ball_rect = self.ball_rect.move(self.ball_speed[0],self.ball_speed[1])

        self.left_rectangle = pygame.draw.rect(self.game_screen, self.object_color,self.left_rect)
        self.right_rectangle = pygame.draw.rect(self.game_screen, self.object_color,self.right_rect)
        self.ball_rectangle = pygame.draw.rect(self.game_screen, self.object_color,self.ball_rect)

        if self.render_game:
            pygame.display.flip()
            self.clock.tick(60)

        self.ball_collision()

        if self.illegal_ball():
            return self.reward, True, self.score, self.get_state()

        return self.reward, False, self.score, self.get_state()

    def illegal_ball(self):
        if self.ball_rect.left <= 0 or self.ball_rect.right >= self.width_screen:
            if self.ball_rect.left <= 0: self.reward[0] = -10
            else: self.reward[1] = -10
            return True
    def ball_collision(self):

        if self.left_rect.right == self.ball_rect.left:
            if (self.left_rect.bottom >= self.ball_rect.centery >= self.left_rect.top):
                self.ball_speed[0] = 1
                self.score += 1
                self.reward[0] = 10
        elif self.right_rect.left == self.ball_rect.right :
            if (self.right_rect.bottom >= self.ball_rect.centery >= self.right_rect.top):
                self.ball_speed[0] = -1
                self.score += 1
                self.reward[1] = 10
        if self.ball_rect.top <= 0 or self.ball_rect.bottom >= self.height_screen:
            self.ball_speed[1] *= -1

        if (self.left_rect.bottom >= self.ball_rect.centery >= self.left_rect.top) and self.ball_speed[0] < 0:
            self.reward[0] = 1
        elif (self.right_rect.bottom >= self.ball_rect.centery >= self.right_rect.top) and self.ball_speed[0] > 0:
            self.reward[1] = 1
        else:
            if self.ball_speed[0] < 0:
                self.reward[0] = -0.1
            else:
                self.reward[1] = -0.1

    def get_state(self):
        screen = pygame.surfarray.array3d(self.game_screen)
        preprocess = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((40,30)),
            transforms.Grayscale(),
            transforms.ToTensor()])
        screen = preprocess(screen)
        return screen

In [None]:
class Linear_QNet(nn.Module):
    def __init__(self, lr, input_dims_channel, n_actions):
        super(Linear_QNet, self).__init__()
        self.relu = nn.ReLU()
        self.conv1 = nn.Conv2d(4,32,3,1)
        self.conv2 = nn.Conv2d(32,64,3,1)
        self.conv3 = nn.Conv2d(64,64,3,1)
        self.maxPool = nn.MaxPool2d(3, 1)

        # Flatten layer to transition from conv to linear layers
        self.flatten = nn.Flatten()

        # Calculate the flattened size after convolutional layers
        flattened_size = self._get_flattened_size(input_dims_channel)

        # Define linear layers
        self.linear1 = nn.Linear(flattened_size, 256)
        self.linear2 = nn.Linear(256, n_actions)

        # Initialize optimizer and loss function
        self.optimizer = optim.Adam(self.parameters(), lr=lr)
        self.criterion = nn.MSELoss()

        # Specify device
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.to(self.device)

    def _get_flattened_size(self, input_dims_channel):
        # Create a dummy input to determine the size after conv layers
        dummy_input = torch.zeros(1, 4, 40, 30)  # Assuming input image size is 32x32
        dummy_output = self.conv1(dummy_input)
        dummy_output = self.maxPool(self.relu(dummy_output))
        dummy_output = self.conv2(dummy_output)
        dummy_output = self.relu(dummy_output)
        dummy_output = self.conv3(dummy_output)
        dummy_output = self.relu(dummy_output)
        dummy_output = self.flatten(dummy_output)
        return dummy_output.numel()

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.maxPool(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.conv3(x)
        x = self.relu(x)
        x = self.flatten(x)
        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear2(x)
        return x

    def save(self, file_name = 'model.pth'):
        model_folder_path = './model'

        if not os.path.exists(model_folder_path):
            os.makedirs(model_folder_path)

        file_name = os.path.join(model_folder_path, file_name)
        torch.save(self.state_dict(), file_name)

In [None]:
class Agent():
    def __init__(self, lr, gamma, epsilon, batch_size, rectangle, action_space, max_mem=100_000,
                 eps_end=0.01, eps_dec =  0.9995):
        self.lr = lr
        self.gamma = gamma
        self.epsilon = epsilon
        self.batch_size = batch_size
        self.action_space = action_space
        self.max_mem = max_mem
        self.mem_counter = 0
        self.epsilon_end = eps_end
        self.epsilon_decay = eps_dec
        self.short_model = Linear_QNet(self.lr, 4, len(self.action_space))
        self.short_model_weight_counter = 0
        self.long_model = Linear_QNet(self.lr, 4, len(self.action_space))
        self.long_model_weight_counter = 500
        self.state_memory = torch.rand((self.max_mem, 4,40,30), dtype=torch.float32)
        self.action_memory = torch.zeros((self.max_mem), dtype=torch.int64)
        self.reward_memory = torch.zeros((self.max_mem), dtype=torch.float32)
        self.new_state_memory = torch.rand((self.max_mem, 4,40,30), dtype=torch.float32)
        self.terminal_memory = torch.zeros((self.max_mem), dtype=torch.int64)
        self.show = False
        pass


    def store_transition(self, state, action,reward, new_state, done):
        index = self.mem_counter % self.max_mem
        self.state_memory[index] = state
        self.action_memory[index] = action
        self.reward_memory[index] = reward
        self.new_state_memory[index] = new_state
        self.terminal_memory[index] = done
        self.mem_counter += 1
        self.mem_counter %= self.max_mem

    def get_action(self, observation):
        if np.random.random() > self.epsilon:
            state = observation.unsqueeze(0).to(self.short_model.device)
            actions = self.short_model.forward(state).to(self.short_model.device)
            action = torch.argmax(actions).item()
        else:
            action = np.random.choice(self.action_space)
        self.epsilon = max(self.epsilon_end, self.epsilon * self.epsilon_decay)
        return action

    def get_samples(self):
        batch_count = min(self.max_mem, self.mem_counter)
        batch_indices = np.random.choice(batch_count, self.batch_size, replace=False)

        states = self.state_memory[batch_indices].clone().detach().to(self.short_model.device)
        actions = self.action_memory[batch_indices].clone().detach().to(self.short_model.device)
        rewards = self.reward_memory[batch_indices].clone().detach().to(self.short_model.device)
        next_states = self.new_state_memory[batch_indices].clone().detach().to(self.short_model.device)
        dones = self.terminal_memory[batch_indices].clone().detach().to(self.short_model.device)

        return states, actions, rewards, next_states, dones

    def learn(self):
        # Ensure sufficient memory
        if self.mem_counter < self.batch_size:
            return

        # Get samples from memory
        states, actions, rewards, next_states, dones = self.get_samples()

        # Move tensors to the appropriate device
        states = states.to(self.short_model.device)
        actions = actions.to(self.short_model.device)
        rewards = rewards.to(self.short_model.device)
        next_states = next_states.to(self.short_model.device)
        dones = dones.to(self.short_model.device)

        # Compute current Q values
        q_value = self.short_model(states).gather(1, actions.unsqueeze(0)).squeeze(0)

        # Compute next Q values
        q_next = self.long_model(next_states).max(1)[0].to(self.short_model.device).detach()

        # Compute target Q values
        q_target = rewards + self.gamma * q_next * (1 - dones.int())

        # Compute loss
        loss = self.short_model.criterion(q_value, q_target).to(self.short_model.device)

        # Zero the gradients
        self.short_model.optimizer.zero_grad()

        # Backpropagate the loss
        loss.backward()

        # Update model parameters
        self.short_model.optimizer.step()








In [5]:
from collections import deque

if __name__ == '__main__':
    game = pong_ai_game()
    state_layers = 1
    gamma = 0.9995
    epsilon = 1
    lr = 0.001
    batch_size = 128
    action_space =[0,1,2]
    training = True
    agent1 = Agent(lr, gamma, epsilon, batch_size, 0, action_space)
    agent2 = Agent(lr, gamma, epsilon, batch_size, 1, action_space)
    framestack = deque(maxlen = 4)
    print(agent1.short_model.device)
    if training:

        scores, eps_hist = [],[]
        n_games = 4000

        for i in range(n_games):
            score = 0
            total_reward1, total_reward2 = 0,0
            score1, score2 = 0,0
            done1, done2 = False, False
            game.reset()
            print("run: ",i)
            if i % 250 == 0 and (i != 0 and i != 250):
                agent1.epsilon = 0.5
                agent2.epsilon = 0.5
                agent1.short_model.save('agent1.pth')
                agent2.short_model.save('agent2.pth')
            observation1 = observation2 = game.get_state()
            if i % 20 == 0:
                agent1.long_model.load_state_dict(agent1.short_model.state_dict())
                agent2.long_model.load_state_dict(agent2.short_model.state_dict())

            while not (done1 or done2):
                if len(framestack) == 0:
                      framestack.append(observation1)
                      framestack.append(observation1)
                      framestack.append(observation1)
                      framestack.append(observation1)

                observation1 = torch.stack(list(framestack)).squeeze(1)
                observation2 = torch.stack(list(framestack)).squeeze(1)

                if game.ball_speed[0] < 0:
                    action1 = agent1.get_action(observation1)


                    reward1, done1, score1, new_state1 = game.play_step((1, action1))
                    framestack.append(new_state1)

                    new_state1 = torch.stack(list(framestack)).squeeze(1)

                    score = max(score1, score2)
                    agent1.store_transition(observation1, action1, reward1[0], new_state1, done1)
                    agent1.learn()


                    observation1 = new_state1


                else:
                    action2 = agent2.get_action(observation2)


                    reward2, done2, score2, new_state2 = game.play_step((2, action2))
                    framestack.append(new_state2)
                    new_state2 = torch.stack(list(framestack)).squeeze(1)

                    score = max(score1, score2)
                    agent2.store_transition(observation2, action2, reward2[1], new_state2, done2)
                    agent2.learn()


                    observation2 = new_state2

            # if i%20 == 0:
            #     to_pil = transforms.ToPILImage()
            #     for batch in range(min(agent1.state_memory.shape[0], agent1.mem_counter)):
            #         img_tensor = agent1.state_memory[batch, 1, :, :].squeeze(0)

            #         # Convert to PIL Image
            #         img = to_pil(img_tensor)
                    
            #         # Save image
            #         img.save(os.path.join("images", f'image_{i}-{batch}.png'))

            scores.append(score)
            eps_hist.append(agent1.epsilon)
            avg_scores = np.mean(scores[-100:])
            print(f"avg: {avg_scores:0.2f}, eps, {agent1.epsilon:0.3f}, {agent2.epsilon:0.3f}, {score}")


cuda:0
run:  0
avg: 0.00, eps, 0.980, 1.000, 0
run:  1
avg: 0.00, eps, 0.980, 0.981, 0
run:  2
avg: 0.00, eps, 0.961, 0.981, 0
run:  3
avg: 0.00, eps, 0.961, 0.963, 0
run:  4
avg: 0.00, eps, 0.942, 0.963, 0
run:  5
avg: 0.00, eps, 0.942, 0.945, 0
run:  6
avg: 0.00, eps, 0.942, 0.927, 0
run:  7
avg: 0.00, eps, 0.923, 0.927, 0
run:  8
avg: 0.00, eps, 0.923, 0.909, 0
run:  9
avg: 0.00, eps, 0.923, 0.892, 0
run:  10
avg: 0.00, eps, 0.905, 0.892, 0
run:  11
avg: 0.00, eps, 0.887, 0.892, 0
run:  12
avg: 0.00, eps, 0.887, 0.875, 0
run:  13
avg: 0.00, eps, 0.869, 0.875, 0
run:  14
avg: 0.00, eps, 0.869, 0.859, 0
run:  15
avg: 0.00, eps, 0.852, 0.859, 0
run:  16
avg: 0.00, eps, 0.835, 0.859, 0
run:  17
avg: 0.00, eps, 0.819, 0.859, 0
run:  18
avg: 0.00, eps, 0.802, 0.859, 0
run:  19
avg: 0.00, eps, 0.787, 0.859, 0
run:  20
avg: 0.00, eps, 0.771, 0.859, 0
run:  21
avg: 0.00, eps, 0.756, 0.859, 0
run:  22
avg: 0.00, eps, 0.756, 0.843, 0
run:  23
avg: 0.00, eps, 0.756, 0.827, 0
run:  24
avg: 0.00,

KeyboardInterrupt: 

In [None]:
array = torch.rand((100, 4, 40,30), dtype = torch.float32)
print(array.shape[1])

4


In [None]:
import torch

model = torch.load(r'model/agent1.pth')
print(model)

  model = torch.load(r'model/agent1.pth')


OrderedDict([('conv1.weight', tensor([[[[ 6.6953e-02,  1.6003e-01,  1.0531e-02],
          [ 2.8959e-02,  6.2716e-02,  4.1236e-02],
          [ 8.9511e-02,  1.3526e-01,  2.7370e-02]],

         [[-1.6957e-01,  4.2939e-02, -9.2830e-02],
          [-5.6206e-02, -1.2426e-01,  1.2467e-01],
          [-3.7077e-03,  1.0677e-01, -8.6597e-02]],

         [[ 1.5719e-01,  1.5460e-01,  7.8460e-02],
          [ 5.0477e-02,  9.5723e-02, -2.4567e-02],
          [ 1.4559e-01, -2.1400e-02, -1.4344e-01]],

         [[-4.3955e-02,  1.4602e-01,  9.3654e-03],
          [ 1.1913e-01, -6.1824e-02, -5.0796e-02],
          [ 9.5876e-02, -1.3421e-01, -3.2174e-02]]],


        [[[-1.3485e-01, -8.0284e-02, -1.3901e-01],
          [-1.6968e-01,  6.2023e-02, -2.2870e-02],
          [ 1.8313e-02, -1.2617e-01, -8.2946e-02]],

         [[-1.0370e-01, -6.9788e-02, -1.2310e-01],
          [-4.7219e-03, -1.9452e-02,  4.8692e-02],
          [-1.4473e-01, -5.8029e-02, -9.8315e-02]],

         [[ 1.5708e-01,  3.4339e-02, -

['agent1.pth', 'agent2.pth']