In [None]:
# Imports
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import os
import random
from math import cos, sin, radians
from collections import deque
import itertools

In [None]:
# Hyperparameters
GAMMA = 0.95                  # discount applied to the bootstrapped next-state Q-value
BATCH_SIZE = 64               # transitions drawn from replay memory per gradient step
LEARNING_RATE = 2.5e-4        # Adam step size
MAX_MEMORY = 200_000          # capacity of the replay deque
MIN_REPLAY_SIZE = 100_000     # learning (and epsilon annealing) starts past this many steps
EPSILON_START = 1.0           # exploration rate at the start of annealing
EPSILON_END = 0.02            # exploration rate once annealing has finished
EPSILON_DECAY = 80_000        # number of steps over which epsilon is linearly annealed
TARGET_UPDATE_FREQ = 10_000   # steps between copies of the online net into the target net

N_FRAMES = 2                  # the agent acts on every N_FRAMES-th loop iteration
N_FRAMES_HAND = 4             # not referenced by the cells visible in this notebook

# Prefer the GPU when one is available, otherwise run on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(DEVICE)

In [None]:
# Pygame
class Game:
    """Pong environment made of two paddles and one ball on a 640x480 field.

    Args:
        play: when truthy, a pygame window is created and every step is drawn;
            otherwise the game runs headless and ``screen`` is ``None``.
        train: forwarded to ``Ball.update``; enables the ball's frame counter.
    """

    def __init__(self, play, train):
        self.play = play
        self.train = train
        if self.play:
            # NOTE(review): `pygame` is not imported in the visible import cell;
            # it has to be in scope before a Game is created with play=True.
            pygame.init()
            pygame.display.set_caption('Pong')
            pygame.display.set_icon(pygame.image.load('assets/icon.png'))
            self.myFont = pygame.font.SysFont('arial', 30)
            self.screen = pygame.display.set_mode((640, 480))
            self.clock = pygame.time.Clock()
        else:
            self.screen = None

        # Each paddle carries the list of outgoing ball angles (degrees) it can
        # produce, ordered from the top of the paddle to the bottom.
        self.player1 = Paddle(self.screen, 5, [-45, -30, -15, -10, 10, 15, 30, 45])
        self.player2 = Paddle(self.screen, 625, [225, 210, 195, 190, 170, 165, 150, 135]) #  [-135, -150, -165, 170, 190, 165, 150, 135]
        self.ball = Ball(self.screen)

    def update(self):
        """Advance both paddles and the ball by one tick.

        Returns:
            The ``(reward, done)`` pair produced by ``Ball.update``.
        """
        self.player1.update()
        self.player2.update()
        reward, done = self.ball.update(self.player1, self.player2, self.train)
        return reward, done

    def render(self):
        """Draw the paddles, the ball and the score line, then flip the display."""
        self.screen.fill((0, 0, 0))
        self.player1.render()
        self.player2.render()
        self.ball.render()
        textSurface = self.myFont.render(f'AI {self.player1.score}:{self.player2.score} YOU', False, (255, 255, 255))
        self.screen.blit(textSurface, (250, 20))
        pygame.display.flip()

    def getState(self):
        """Return the observation as a float64 array of five rounded features.

        Order: player1 y, player2 y (both divided by the paddle's travel range),
        ball y / 480, ball x / 640, ball angle / 255.
        """
        # NOTE(review): the angle is scaled by 255 while paddle angles span
        # -45..225 degrees - confirm this divisor is intended.
        state = [
                round(self.player1.y / (480 - self.player1.height), 2),
                round(self.player2.y / (480 - self.player2.height), 2),
                round(self.ball.y / 480, 2),
                round(self.ball.x / 640, 2),
                round(self.ball.angle / 255, 2),
        ]
        # `np.float` was deprecated in NumPy 1.20 and removed in 1.24;
        # np.float64 is the dtype that alias used to resolve to.
        return np.array(state, dtype=np.float64)

    def run(self):
        """Step the simulation once and, when a window exists, redraw it.

        Returns:
            The ``(reward, done)`` pair from ``update``.
        """
        reward, done = self.update()
        if self.play:
            self.render()
            # self.clock.tick(60)
        return reward, done

class Ball:
    """The Pong ball: position, heading (degrees) and per-tick movement rules."""

    def __init__(self, screen):
        self.screen = screen
        self.frame = 0  # ticks since the last serve; only counted while training
        self.x = 320
        self.y = 240
        # Random shallow heading, pointing right (as drawn) or left (+180 degrees).
        self.angle = random.choice([-45, -30, -15, -10, 10, 15, 30, 45]) + 180 * random.randint(0, 1)
        self.speed = 8
        self.radius = 6

    def _deflect(self, paddle):
        """Set the outgoing angle from where the ball meets ``paddle``.

        The paddle (extended by the ball radius) is split into equal vertical
        bands, one per entry of ``paddle.angles``.

        Returns:
            True when the ball is within the paddle's vertical extent and the
            angle was updated, False when the ball is above or below it.
        """
        offset = self.y - paddle.y
        if offset < -self.radius:
            return False
        bands = len(paddle.angles)
        for i, angle in enumerate(paddle.angles):
            if offset <= (i + 1) / bands * (paddle.height + self.radius):
                self.angle = angle
                return True
        # Ball is below the paddle: no contact.
        return False

    def _serve_from(self, paddle, side):
        """Re-place the ball next to ``paddle`` and pick one of its middle angles.

        ``side`` is +1 to place the ball to the right of the paddle and -1 to
        place it to the left. Also resets the stall counter.
        """
        self.x = paddle.x + side * (paddle.width * 2 + self.radius)
        self.y = 240
        self.angle = random.choice(paddle.angles[2:-2])
        self.frame = 0

    def update(self, player1, player2, train):
        """Move the ball one tick and resolve bounces, paddle hits and scoring.

        Args:
            player1: left paddle (reads x, y, width, height, angles; updates score).
            player2: right paddle (same attributes).
            train: when truthy, the stall counter ``frame`` is advanced.

        Returns:
            ``(reward, done)`` - two 2-element lists indexed [player1, player2].
            A paddle hit yields 2 for the hitter, a point yields +10/-10, and
            ``done`` is set for the conceding player each time the scorer's
            total reaches a multiple of 5.
        """
        reward = [0, 0]
        done = [False, False]
        if train:
            self.frame += 1

        # Bounce off the top or bottom wall by mirroring the heading vertically.
        if self.y + self.radius > 480 or self.y - self.radius < 0:
            if self.angle <= 45:
                self.angle = -self.angle
            else:
                self.angle = 360 - self.angle

        # Paddle contact. The hit reward is only granted when the ball was
        # actually deflected; previously a ball slipping past *below* the paddle
        # was still rewarded as a hit.
        if player1.x <= self.x - self.radius <= player1.x + player1.width:
            if self._deflect(player1):
                reward = [2, 0]
        elif player2.x <= self.x + self.radius <= player2.x + player2.width:
            if self._deflect(player2):
                reward = [0, 2]

        self.y += self.speed*sin(radians(self.angle))
        self.x += self.speed*cos(radians(self.angle))

        # Ball left the field on the right: point for player1.
        if self.x - self.radius >= 670:
            player1.score += 1
            reward = [10, -10]
            if player1.score % 5 == 0:
                done = [False, True]
            self._serve_from(player2, -1)

        # Ball left the field on the left: point for player2.
        if self.x + self.radius <= -30:
            player2.score += 1
            reward = [-10, 10]
            if player2.score % 5 == 0:
                done = [True, False]
            self._serve_from(player1, 1)

        # Stalled rally during training: penalise player1 and re-serve from a
        # randomly chosen side.
        if self.frame > 1000:
            reward = [-10, 10]
            if random.randint(0, 1) == 1:
                self._serve_from(player1, 1)
            else:
                self._serve_from(player2, -1)

        return reward, done

    def render(self):
        """Draw the ball as a white circle on the pygame screen."""
        pygame.draw.circle(self.screen, (255, 255, 255), (self.x, self.y), self.radius)

class Paddle:
    """A vertical paddle at a fixed x position on the 480-pixel-high field.

    ``mode`` is the current movement direction (-1 up, 0 idle, 1 down) and
    ``angles`` is the list of ball headings this paddle hands out on contact.
    """

    def __init__(self, screen, x, angles):
        self.screen = screen
        self.angles = angles
        self.x = x
        # Geometry and movement speed, in pixels.
        self.width = 10
        self.height = 80
        self.speed = 4
        # Start vertically centred, idle, with no points.
        self.y = 240 - self.height / 2
        self.mode = 0
        self.score = 0

    def update(self):
        """Move by ``mode * speed`` and keep the paddle inside the field."""
        lowest_y = 480 - self.height
        moved = self.y + self.mode * self.speed
        self.y = max(0, min(moved, lowest_y))

    def render(self):
        """Draw the paddle as a white rectangle on the pygame screen."""
        white = (255, 255, 255)
        bounds = (self.x, self.y, self.width, self.height)
        pygame.draw.rect(self.screen, white, bounds)

In [None]:
# DQN
class Linear_QNet(nn.Module):
    """Fully connected Q-network: four ReLU hidden layers and a linear head.

    Layer widths are input_size -> 128 -> 256 -> 1024 -> 256 -> output_size.
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        widths = (input_size, 128, 256, 1024, 256, output_size)
        # Register the layers as linear1 .. linear5, in order, so parameter
        # names and initialisation order stay the same.
        for idx, (n_in, n_out) in enumerate(zip(widths, widths[1:]), start=1):
            setattr(self, f"linear{idx}", nn.Linear(n_in, n_out))

    def forward(self, x):
        """Return one Q-value per action for each row of ``x``."""
        hidden_layers = (self.linear1, self.linear2, self.linear3, self.linear4)
        for layer in hidden_layers:
            x = F.relu(layer(x))
        # The output layer has no activation.
        return self.linear5(x)

In [None]:
class Agent:
    """DQN / Double-DQN agent with replay memory and a periodically synced target net."""

    def __init__(self):
        self.online_net = Linear_QNet(5, 2).to(DEVICE)
        self.target_net = Linear_QNet(5, 2).to(DEVICE)
        self.target_net.load_state_dict(self.online_net.state_dict())
        self.memory = deque(maxlen=MAX_MEMORY)
        self.optimizer = optim.Adam(self.online_net.parameters(), lr=LEARNING_RATE)
        self.record = -50  # best episode score seen so far; beating it triggers a save
        self.score = 0     # reward accumulated in the current episode

    def get_action(self, state, step):
        """Pick an action (0 or 1) for ``state``.

        Args:
            state: observation accepted by the online network (5 features).
            step: current training step. When it is a real number the action is
                epsilon-greedy, with epsilon annealed linearly from
                EPSILON_START to EPSILON_END over EPSILON_DECAY steps starting
                at MIN_REPLAY_SIZE. Any other value (e.g. a string or a bool)
                selects the greedy action.

        Returns:
            The chosen action index as a Python int.
        """
        # Accept float steps too (callers may compute them with `/`), but not
        # bools, which are used as "not training" flags.
        is_numeric_step = (
            isinstance(step, (int, float, np.integer, np.floating))
            and not isinstance(step, bool)
        )
        if is_numeric_step:
            epsilon = np.interp(step, [MIN_REPLAY_SIZE, EPSILON_DECAY + MIN_REPLAY_SIZE], [EPSILON_START, EPSILON_END])
            if random.random() <= epsilon:
                # Return right away: a truthiness check here would discard a
                # randomly drawn action 0 and silently fall back to greedy.
                return random.randint(0, 1)

        with torch.no_grad():
            state_t = torch.as_tensor(state, dtype=torch.float32, device=DEVICE)
            q_values = self.online_net(state_t.unsqueeze(0))
            return torch.argmax(q_values, dim=1)[0].item()

    def train(self, experience, step, DDQN):
        """Store one transition and, once enough are buffered, do one gradient step.

        Args:
            experience: ``(state, action, reward, done, next_state)`` tuple.
            step: current training step; drives the target-network sync.
            DDQN: when truthy use the Double-DQN target (online net picks the
                next action, target net evaluates it); otherwise the target
                net's own maximum is used.
        """
        self.memory.append(experience)
        if len(self.memory) <= MIN_REPLAY_SIZE:
            return

        self.score += experience[2] # reward

        sample_experiences = random.sample(self.memory, BATCH_SIZE)

        state_olds = np.asarray([t[0] for t in sample_experiences])
        actions = np.asarray([t[1] for t in sample_experiences])
        rewards = np.asarray([t[2] for t in sample_experiences])
        dones = np.asarray([t[3] for t in sample_experiences])
        state_news = np.asarray([t[4] for t in sample_experiences])

        state_olds_t = torch.as_tensor(state_olds, dtype=torch.float32, device=DEVICE)
        actions_t = torch.as_tensor(actions, dtype=torch.int64, device=DEVICE).unsqueeze(-1)
        rewards_t = torch.as_tensor(rewards, dtype=torch.float32, device=DEVICE).unsqueeze(-1)
        dones_t = torch.as_tensor(dones, dtype=torch.float32, device=DEVICE).unsqueeze(-1)
        state_news_t = torch.as_tensor(state_news, dtype=torch.float32, device=DEVICE)

        # Q-values of the actions that were actually taken.
        q_values = self.online_net(state_olds_t)
        q_values_action = q_values.gather(dim=1, index=actions_t)

        # Bootstrapped targets are constants for the loss: build them without
        # autograd so no gradients are computed or accumulated for target_net.
        with torch.no_grad():
            q_values_target = self.target_net(state_news_t)
            if DDQN:
                q_values_online = self.online_net(state_news_t)
                q_values_online_max = q_values_online.argmax(dim=1, keepdim=True)
                q_values_next = q_values_target.gather(dim=1, index=q_values_online_max)
            else:
                q_values_next = q_values_target.max(dim=1, keepdim=True)[0]
            # (1 - done) removes the bootstrap term on terminal transitions.
            targets = rewards_t + GAMMA * (1 - dones_t) * q_values_next

        # Compute Loss
        loss = nn.functional.smooth_l1_loss(q_values_action, targets)

        # Gradient Descent
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Update Target Net
        if max(-1, step - MIN_REPLAY_SIZE) % TARGET_UPDATE_FREQ == 0:
            print(step)
            self.target_net.load_state_dict(self.online_net.state_dict())

        # check for record and save model if done
        if experience[3]: # done
            if self.score > self.record:
                self.record = self.score
                print('Record:', self.record, 'Step:', step)
                self.save()
            self.score = 0

    def load(self, name):
        """Load online-network weights from ``model/<name>.pth``.

        The target network is synced to the loaded weights so that resumed
        training does not bootstrap from a randomly initialised target.
        """
        self.online_net.load_state_dict(torch.load(f'model/{name}.pth', map_location=DEVICE))
        self.target_net.load_state_dict(self.online_net.state_dict())

    def save(self, file_name='model.pth'):
        """Write the online network's weights to ``./model/<file_name>``.

        Only the online network's state dict is stored; the optimizer state
        and the target network are not persisted.
        """
        model_folder_path = './model'
        os.makedirs(model_folder_path, exist_ok=True)
        file_name = os.path.join(model_folder_path, file_name)
        torch.save(self.online_net.state_dict(), file_name)

In [None]:
def bot_action(game, player):
    """Scripted opponent: steer the chosen paddle towards the ball's height.

    ``player == 1`` controls ``game.player1``; any other value controls
    ``game.player2``. Only the paddle's ``mode`` is changed.
    """
    paddle = game.player1 if player == 1 else game.player2
    half_height = paddle.height / 2
    ball_y = game.ball.y

    if ball_y < paddle.y + half_height:
        # Ball is above the paddle's centre: move up.
        paddle.mode = -1
    elif ball_y > paddle.y - half_height:
        # NOTE(review): this compares against y - height/2, so it holds whenever
        # the first test fails - confirm whether y + height/2 was intended.
        paddle.mode = 1

def main(args):
    """Run the game loop, letting a DQN agent control player1.

    Args:
        args: dict of options. Keys read here:
            'see'   - open a window and render each step,
            'train' - explore and learn from every transition,
            'load'  - load saved weights before starting,
            'bot'   - let the scripted bot drive player2.

    The loop has no exit condition; it runs until the kernel is interrupted.
    """
    game = Game(args['see'], args['train'])
    player1 = Agent()
    # A second learning agent for player2 is currently disabled; player2 is
    # only moved by the scripted bot when args['bot'] is set.
    if args['load']:
        # NOTE(review): Agent.save() writes model/model.pth by default while
        # this reads model/player1.pth - confirm the checkpoint is renamed.
        player1.load('player1')

    for frame in itertools.count():
        if args['bot']:
            bot_action(game, 2)
        if frame % N_FRAMES != 0:
            continue

        # Integer step count. True division would produce a float, which
        # Agent.get_action does not treat as a training step, so epsilon-greedy
        # exploration would never run.
        step = frame // N_FRAMES

        state_old = game.getState()
        action1 = player1.get_action(state_old, step if args['train'] else 'testing')
        game.player1.mode = -1 if action1 == 0 else 1

        reward, done = game.run()
        if args['train']:
            state_new = game.getState()
            player1.train((state_old, action1, reward[0], done[0], state_new), step, True)

In [None]:
# Train player1 headlessly against the scripted bot, starting from fresh weights.
main({
    'see': False,
    'human': False,
    'bot': True,
    'selfPlay': False,
    'load': False,
    'train': True,
})