In [None]:
# Imports
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import os
import random
from math import cos, sin, radians
from collections import deque
import itertools

In [None]:
# Params
# Step-based schedules are scaled by N_FRAMES because main's `step` counter is
# measured in game frames while the agent only acts every N_FRAMES frames.
N_FRAMES = 2  # frames per agent decision (action repeat)

# Q-learning / optimiser
GAMMA = 0.95
BATCH_SIZE = 64
LEARNING_RATE = 2.5e-4

# Replay buffer: capacity, and number of frames of random play before training starts
MAX_MEMORY = 200_000
MIN_REPLAY_SIZE = N_FRAMES * 100_000

# Exploration: epsilon is interpolated linearly from START to END over EPSILON_DECAY frames
EPSILON_START, EPSILON_END = 1.0, 0.02
EPSILON_DECAY = N_FRAMES * 80_000

# Frames between copies of the online network into the target network
TARGET_UPDATE_FREQ = N_FRAMES * 10_000

DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(DEVICE)

In [None]:
class Game:
    """Headless two-paddle Pong: player1 on the left (driven by the DQN in main),
    player2 on the right (bot and/or self-play)."""

    def __init__(self):
        self.screen = None
        self.player1 = Paddle(self.screen, 5, [-45, -30, -15, -10, 10, 15, 30, 45])
        # player2's bounce angles mirror player1's across the vertical axis (a -> 180 - a),
        # listed in the same zone order so index i means the same paddle zone on both sides.
        self.player2 = Paddle(self.screen, 625, [225, 210, 195, 190, 170, 165, 150, 135])
        self.ball = Ball(self.screen)

    def update(self):
        """Advance one frame (both paddles, then the ball); return the ball's (reward, done)."""
        self.player1.update()
        self.player2.update()
        reward, done = self.ball.update(self.player1, self.player2)
        return reward, done

    def _mirrored_angle(self):
        """Return the ball angle reflected left<->right, via the paired paddle angle tables.

        Assumes ball.angle is one of the paddles' bounce angles; the angles Ball
        generates (serves, wall bounces, deflections) all stay within that set.
        """
        angle = self.ball.angle
        if angle > 45:
            return self.player1.angles[self.player2.angles.index(angle)]
        return self.player2.angles[self.player1.angles.index(angle)]

    def getState(self, swap):
        """Return the 5-feature observation as a float64 numpy array.

        Features, each rounded to 2 decimals:
            own paddle y / (480 - height), opponent paddle y / (480 - height),
            ball y / 480, ball x / 640, ball angle / 255.
        With swap=False the observer is player1. With swap=True the observer is
        player2 and the field is mirrored horizontally (x -> 640 - x, angle mirrored),
        so one network can play either side.
        """
        if swap:
            own, opponent = self.player2, self.player1
            ball_x = (640 - self.ball.x) / 640
            angle = self._mirrored_angle()
        else:
            own, opponent = self.player1, self.player2
            ball_x = self.ball.x / 640
            angle = self.ball.angle

        features = [
            own.y / (480 - own.height),
            opponent.y / (480 - opponent.height),
            self.ball.y / 480,
            ball_x,
            angle / 255,
        ]
        # np.float (an alias of the builtin float) was removed in NumPy 1.24; float64 is equivalent.
        return np.array([round(value, 2) for value in features], dtype=np.float64)

    def run(self):
        """Run one frame; alias of update()."""
        reward, done = self.update()
        return reward, done

class Ball:
    """The ball: moves `speed` px per frame along `angle` (degrees, y increasing
    downward as x increases for positive sin), bounces off the top/bottom walls
    and the paddles, and handles scoring."""

    def __init__(self, screen):
        self.screen = screen
        self.frame = 0  # frames since the last serve; a rally longer than 1000 frames is scored for player2
        self.x = 320
        self.y = 240
        self.angle = random.choice([-45, -30, -15, -10, 10, 15, 30, 45]) + 180 * random.randint(0, 1)
        self.speed = 8
        self.radius = 6

    def _deflection(self, paddle):
        """Return the bounce angle for a ball touching `paddle`, or None when the ball
        is outside the paddle's vertical span.

        The offset of the ball centre from the paddle's top edge must lie in
        [-radius, height + radius]. Zone k (0-based) ends at
        (k + 1) / 8 * (height + radius), and zone k selects paddle.angles[k].
        """
        offset = self.y - paddle.y
        if offset < -self.radius:
            return None
        for k in range(1, 9):
            if offset <= k / 8 * (paddle.height + self.radius):
                return paddle.angles[k - 1]
        return None

    def update(self, player1, player2):
        """Advance one frame and return (reward, done) from player1's point of view.

        reward: +2 when player1's paddle deflects the ball, +10 when player1 scores,
        -10 when player2 scores or the rally exceeds 1000 frames. A goal in the same
        frame overrides the deflection reward.
        done: True when player2's score reaches a multiple of 5.
        """
        reward = 0
        done = False
        self.frame += 1

        # Top/bottom wall: flip the vertical direction, keep the horizontal one.
        if self.y + self.radius > 480 or self.y - self.radius < 0:
            if self.angle <= 45:
                self.angle = -self.angle
            else:
                self.angle = 360 - self.angle

        # A paddle can only deflect a ball travelling toward it. Without this check a
        # ball that was just hit (and is still inside the paddle column on the next
        # frame) is "hit" again, re-aimed and rewarded twice.
        horizontal = cos(radians(self.angle))

        # left paddle (player1)
        if player1.x <= self.x - self.radius <= player1.x + player1.width:
            new_angle = self._deflection(player1) if horizontal < 0 else None
            # Reward only a real hit; a ball passing outside the paddle's span earns nothing.
            if new_angle is not None:
                self.angle = new_angle
                reward = 2

        # right paddle (player2)
        elif player2.x <= self.x + self.radius <= player2.x + player2.width:
            new_angle = self._deflection(player2) if horizontal > 0 else None
            if new_angle is not None:
                self.angle = new_angle

        self.y += self.speed * sin(radians(self.angle))
        self.x += self.speed * cos(radians(self.angle))

        # Ball left the field on the right: point for player1, serve toward player1.
        if self.x - self.radius >= 670:
            player1.score += 1
            reward = 10
            self.x = player2.x - player2.width * 2 - self.radius
            self.y = 240
            self.angle = random.choice(player2.angles[2:-2])
            self.frame = 0

        # Ball left the field on the left (or the rally timed out): point for player2.
        if self.x + self.radius <= -30 or self.frame > 1000:
            player2.score += 1
            reward = -10
            if player2.score % 5 == 0:
                done = True
            self.x = player1.x + player1.width * 2 + self.radius
            self.y = 240
            self.angle = random.choice(player1.angles[2:-2])
            self.frame = 0

        return reward, done

class Paddle:
    """A vertical paddle at a fixed x that moves `mode * speed` px per frame.

    `y` is the paddle's top edge and is kept within [0, 480 - height].
    `mode` starts at 0 and is set externally (main / the bot) to choose direction.
    `angles` holds the 8 bounce angles, one per vertical zone of the paddle.
    """

    def __init__(self, screen, x, angles):
        self.screen, self.x, self.angles = screen, x, angles
        self.speed, self.width, self.height = 4, 10, 80
        self.y = 240 - self.height / 2  # top edge, paddle centred vertically
        self.score = self.mode = 0

    def update(self):
        """Move one frame in the current mode, clamping y to [0, 480 - height]."""
        self.y = max(0, min(self.y + self.mode * self.speed, 480 - self.height))

In [None]:
# DQN
class Linear_QNet(nn.Module):
    """Fully-connected Q-network: input_size -> 128 -> 256 -> 1024 -> 256 -> output_size.

    Hidden layers use ReLU and the output layer is linear (one Q-value per action).
    The layers are registered as linear1..linear5. Those names are the state_dict keys
    written by `save` and read back by `load_state_dict` in main, so they must not change.
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        widths = (input_size, 128, 256, 1024, 256, output_size)
        # Created in order linear1..linear5 so parameter initialisation consumes the RNG
        # in the same sequence as explicit attribute assignment would.
        for idx, (fan_in, fan_out) in enumerate(zip(widths, widths[1:]), start=1):
            setattr(self, f"linear{idx}", nn.Linear(fan_in, fan_out))

    def forward(self, x):
        """Return Q-values for a batch of states."""
        for hidden_layer in (self.linear1, self.linear2, self.linear3, self.linear4):
            x = F.relu(hidden_layer(x))
        return self.linear5(x)

    def act(self, state):
        """Return the greedy action index (a Python int) for a single, unbatched state."""
        with torch.no_grad():
            batch = torch.as_tensor(state, dtype=torch.float32, device=DEVICE).unsqueeze(0)
            return self(batch).argmax(dim=1)[0].item()

    def save(self, file_name='model.pth'):
        """Write the state_dict to ./model/<file_name>, creating ./model if needed."""
        folder = './model'
        if not os.path.exists(folder):
            os.makedirs(folder)
        torch.save(self.state_dict(), os.path.join(folder, file_name))


In [None]:
def bot_action(game):
    """Scripted opponent: steer player2's paddle centre toward the ball's y.

    Sets game.player2.mode to -1 when the ball's y is smaller than the paddle
    centre, +1 when it is larger, and 0 (hold) when exactly aligned.
    The previous `elif` compared against `y - height/2` rather than the centre,
    which made it always true whenever the first test failed.
    """
    paddle = game.player2
    centre = paddle.y + paddle.height / 2
    if game.ball.y < centre:
        paddle.mode = -1
    elif game.ball.y > centre:
        paddle.mode = 1
    else:
        paddle.mode = 0

def _set_paddle_mode(paddle, action):
    """Apply a discrete network action to a paddle: 0 -> mode -1 (y decreases), otherwise mode +1."""
    paddle.mode = -1 if action == 0 else 1


def _advance(game, n_frames, use_bot):
    """Run `n_frames` frames with the current paddle modes and return (summed reward, any done).

    The caller has already run the bot for the first frame. Before each later frame
    the bot (if enabled) re-aims player2, so it still acts on every frame.
    """
    total_reward = 0
    any_done = False
    for frame in range(n_frames):
        if use_bot and frame > 0:
            bot_action(game)
        reward, done = game.run()
        total_reward += reward
        any_done = any_done or done
    return total_reward, any_done


def _dqn_update(online_net, target_net, optimizer, memory):
    """Sample BATCH_SIZE transitions from `memory` and take one optimiser step on `online_net`.

    The loss is smooth-L1 between Q(s, a) and r + GAMMA * (1 - done) * max_a' Q_target(s', a').
    """
    transitions = random.sample(memory, BATCH_SIZE)
    obses, actions, rews, dones, new_obses = (np.asarray(column) for column in zip(*transitions))

    obses_t = torch.as_tensor(obses, dtype=torch.float32, device=DEVICE)
    actions_t = torch.as_tensor(actions, dtype=torch.int64, device=DEVICE).unsqueeze(-1)
    rews_t = torch.as_tensor(rews, dtype=torch.float32, device=DEVICE).unsqueeze(-1)
    dones_t = torch.as_tensor(dones, dtype=torch.float32, device=DEVICE).unsqueeze(-1)
    new_obses_t = torch.as_tensor(new_obses, dtype=torch.float32, device=DEVICE)

    # The bootstrap target is a constant for this step. Without no_grad the target
    # network builds an autograd graph and accumulates .grad on every update.
    with torch.no_grad():
        max_target_q_values = target_net(new_obses_t).max(dim=1, keepdim=True)[0]
        targets = rews_t + GAMMA * (1 - dones_t) * max_target_q_values

    action_q_values = torch.gather(input=online_net(obses_t), dim=1, index=actions_t)
    loss = F.smooth_l1_loss(action_q_values, targets)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


def main(args):
    """Train a DQN controlling player1 against a bot and/or a self-play copy (player2).

    args (dict):
        'load'      -- if truthy, initialise from model/model.pth.
        'bot'       -- if truthy, bot_action drives player2 every frame.
        'selfPlay'  -- if truthy, the online network drives player2 on decision frames
                       using the mirrored state (this overrides the bot on those frames).
        'max_steps' -- optional frame budget; when absent the loop runs forever.
        'see', 'human', 'train' are accepted but not read here.

    `step` counts game frames. One decision is made every N_FRAMES frames, the action is
    repeated for all of them, and their rewards are summed into a single transition so
    hits and goals on in-between frames are not lost. Returns the online network once
    `max_steps` is reached.
    """
    game = Game()
    online_net = Linear_QNet(5, 2).to(DEVICE)
    if args['load']:
        online_net.load_state_dict(torch.load('model/model.pth', map_location=DEVICE))
    target_net = Linear_QNet(5, 2).to(DEVICE)
    target_net.load_state_dict(online_net.state_dict())
    memory = deque(maxlen=MAX_MEMORY)
    optimizer = optim.Adam(online_net.parameters(), lr=LEARNING_RATE)
    record = -50
    score = 0
    max_steps = args.get('max_steps')
    state_old = game.getState(False)

    for step in itertools.count(0, N_FRAMES):
        if max_steps is not None and step >= max_steps:
            break

        if args['bot']:
            bot_action(game)
        if args['selfPlay']:
            _set_paddle_mode(game.player2, online_net.act(game.getState(True)))

        warming_up = step < MIN_REPLAY_SIZE  # fill the replay buffer with random play first
        if warming_up:
            action = random.randint(0, 1)
        else:
            epsilon = np.interp(
                step - MIN_REPLAY_SIZE, [0, EPSILON_DECAY], [EPSILON_START, EPSILON_END])
            if random.random() <= epsilon:
                action = random.randint(0, 1)
            else:
                action = online_net.act(state_old)
        _set_paddle_mode(game.player1, action)

        reward, done = _advance(game, N_FRAMES, args['bot'])
        state_new = game.getState(False)
        memory.append((state_old, action, reward, done, state_new))
        state_old = state_new

        if warming_up:
            continue

        score += reward
        _dqn_update(online_net, target_net, optimizer, memory)

        # max(1, ...) skips the copy at the very first training step (0 % freq == 0).
        if max(1, step - MIN_REPLAY_SIZE) % TARGET_UPDATE_FREQ == 0:
            print(step, epsilon)
            target_net.load_state_dict(online_net.state_dict())

        if done:
            if score > record:
                record = score
                print('Record:', record, 'Step:', step)
                online_net.save()
            score = 0

    return online_net

In [None]:
# Train with self-play (player2 driven by the same network), starting from model/model.pth.
main(dict(see=False, human=False, bot=False, selfPlay=True, load=True, train=True))