<a href="https://colab.research.google.com/github/mrbenbot/wimblepong/blob/main/ai_training/WimblepongCustomTrainingEnv.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Wimblepong 2124

# Introduction

Welcome, servant of Sovereign. Thank you for your assistance in training bots for the upcoming human vs AI competition. The stakes are incredibly high; while freeing humanity from Sovereign's control is a noble cause, it also puts Earth at great risk. Our primary goal is to save Earth, and your efforts in training these bots are crucial to achieving this balance. In this Colab notebook, you will guide the creation, training, and evaluation of AI models for the high-stakes tournament of WimblePong. Additionally, you will learn how to export the trained model for deployment across various platforms. Let's get started!

### Installing Required Libraries

We first need to install several libraries that are essential for our environment and model training.


In [None]:
# You can skip this cell if you have already installed requirements.txt locally.

%pip install gymnasium
%pip install stable-baselines3[extra]
%pip install tensorflowjs
%pip install onnx2tf onnx==1.15.0 onnxruntime==1.17.1 tensorflow==2.16.1

# onnx2tf deps:
%pip install onnx_graphsurgeon
%pip install sng4onnx
%pip install onnxsim


### Mounting Google Drive

We will mount Google Drive to save and load our models and other necessary files.


In [None]:

DAY = "monday-8-july"

try:
    from google.colab import drive
    drive.mount('/content/drive')

    DRIVE_PATH = "/content/drive/MyDrive/wimblepong"
except Exception:  # not running in Colab, or the mount failed
    from pathlib import Path

    print("Couldn't mount a Colab drive; falling back to local storage")
    DRIVE_PATH = "./data"
    Path(f"{DRIVE_PATH}/{DAY}").mkdir(parents=True, exist_ok=True)


print(f"Working in {DRIVE_PATH}/{DAY}")

### Importing Necessary Libraries

Import all the necessary libraries required for creating the custom environment, training, and evaluation.


In [None]:
# Import necessary libraries
import gymnasium as gym
from gymnasium import spaces
from gymnasium.envs.registration import register

import stable_baselines3
from stable_baselines3 import PPO
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.callbacks import EvalCallback, StopTrainingOnRewardThreshold
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize, VecCheckNan
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.policies import BasePolicy


from IPython import display

import os
import imageio
import glob
import math
import random
import subprocess
from typing import Tuple


from IPython.display import display, Image, HTML

import tensorflow as tf
import torch as th
import torch.onnx
import numpy as np


import onnx
from onnx2tf import convert
import onnxruntime as ort

import matplotlib.pyplot as plt

from moviepy.editor import ImageSequenceClip
import pygame


# Check versions
print("gymnasium version:", gym.__version__)
print("stable-baselines3 version:", stable_baselines3.__version__)


### Setting Up Environment Parameters

Define the constants that will be used in the Pong environment.


In [15]:
COURT_HEIGHT = 800
COURT_WIDTH = 1200
PADDLE_HEIGHT = 90
PADDLE_WIDTH = 15
BALL_RADIUS = 12
INITIAL_BALL_SPEED = 10
PADDLE_GAP = 10
PADDLE_SPEED_DIVISOR = 15
PADDLE_CONTACT_SPEED_BOOST_DIVISOR = 4
SPEED_INCREMENT = 0.6
SERVING_HEIGHT_MULTIPLIER = 2
PLAYER_COLOURS = {'Player1': 'grey', 'Player2': 'yellow'}
MAX_COMPUTER_PADDLE_SPEED = 30
DELTA_TIME = 2.5
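
To get a feel for these numbers, the short sketch below works out how far things move per simulation step. It repeats four of the constants so that it runs on its own; the derived names (`ball_px_per_step` and so on) are illustrative and are not used elsewhere in the notebook.

```python
# Standalone copy of the relevant constants from the cell above.
COURT_WIDTH = 1200
INITIAL_BALL_SPEED = 10
MAX_COMPUTER_PADDLE_SPEED = 30
DELTA_TIME = 2.5

# Positions advance by velocity * DELTA_TIME on every environment step.
ball_px_per_step = INITIAL_BALL_SPEED * DELTA_TIME           # 25.0
paddle_px_per_step = MAX_COMPUTER_PADDLE_SPEED * DELTA_TIME  # 75.0

# Steps a ball moving horizontally at serve speed needs to cover the court width.
steps_to_cross = COURT_WIDTH / ball_px_per_step              # 48.0

print(ball_px_per_step, paddle_px_per_step, steps_to_cross)
```

At serve speed a paddle moving flat out covers three times as much ground per step as the ball does.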

### Creating Helper Classes and Functions

Define helper classes and functions to manage game state, players, and reward system.


In [6]:
class Player:
    Player1 = 'Player1'
    Player2 = 'Player2'

class PlayerPositions:
    Initial = 'Initial'
    Reversed = 'Reversed'

def get_bounce_angle(paddle_y, paddle_height, ball_y):
    relative_intersect_y = (paddle_y + (paddle_height / 2)) - ball_y
    normalized_relative_intersect_y = relative_intersect_y / (paddle_height / 2)
    return normalized_relative_intersect_y * (math.pi / 4)

def bounded_value(value, min_value, max_value):
    return max(min_value, min(max_value, value))

def transform_action(action):
    button_pressed = action[0] > 0.5
    paddle_direction = max(min(action[1], 1), -1)
    actions = {'button_pressed': button_pressed, 'paddle_direction': paddle_direction * 30}
    return actions
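
The sketch below exercises standalone copies of `get_bounce_angle` and `transform_action` with made-up inputs (a paddle at `y=100` with height 90, and two sample actions) to show the ranges they produce. It is an illustration only and does not change the helpers above.

```python
import math

def get_bounce_angle(paddle_y, paddle_height, ball_y):
    # +pi/4 at the paddle's top edge, 0 at its centre, -pi/4 at its bottom edge.
    relative_intersect_y = (paddle_y + (paddle_height / 2)) - ball_y
    normalized_relative_intersect_y = relative_intersect_y / (paddle_height / 2)
    return normalized_relative_intersect_y * (math.pi / 4)

def transform_action(action):
    # action[0] is thresholded into a button press,
    # action[1] is clipped to [-1, 1] and scaled to a paddle speed.
    button_pressed = action[0] > 0.5
    paddle_direction = max(min(action[1], 1), -1)
    return {'button_pressed': button_pressed, 'paddle_direction': paddle_direction * 30}

# Hypothetical paddle: top edge at y=100, height 90, so the centre is at y=145.
for ball_y in (100, 145, 190):
    angle = get_bounce_angle(100, 90, ball_y)
    print(ball_y, round(math.degrees(angle), 1))

# The game sets dy = speed * -sin(angle), so a hit at the top edge sends the ball up.
print(round(10 * -math.sin(get_bounce_angle(100, 90, 100)), 2))  # -7.07

print(transform_action([0.7, 0.5]))   # button pressed, paddle speed 15.0
print(transform_action([0.2, -3.0]))  # no press, direction clipped to -1, speed -30
```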

### Computer Player Class

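Define the two opponent classes. `ComputerPlayer` is a scripted opponent that tracks the ball with a capped paddle speed; its serve direction and tracking offset are re-randomised on every reset, so it is rule-based rather than fully deterministic. `ModelPlayer` wraps a previously trained model so that a new model can train against it.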


In [7]:
class ComputerPlayer:
    def __init__(self):
        self.reset()

    def reset(self):
        self.serve_delay = random.randint(10, 10)
        self.direction = random.randint(-60, 60)
        self.offset = random.randint(-PADDLE_HEIGHT // 2, PADDLE_HEIGHT // 2)  # randint needs integer bounds
        self.serve_delay_counter = 0
        self.max_speed = MAX_COMPUTER_PADDLE_SPEED


    def get_actions(self, player, state):
        is_left = (player == Player.Player1 and not state['positions_reversed']) or (player == Player.Player2 and state['positions_reversed'])
        if state['ball']['score_mode']:
            return {'button_pressed': False, 'paddle_direction': 0}
        paddle = state[player]
        if state['ball']['serve_mode']:
            if paddle['y'] <= 0 or paddle['y'] + paddle['height'] >= COURT_HEIGHT:
                self.direction = -self.direction
            if self.serve_delay_counter > self.serve_delay:
                return {'button_pressed': True, 'paddle_direction': self.direction}
            else:
                self.serve_delay_counter += 1
                return {'button_pressed': False, 'paddle_direction': self.direction}
        if is_left:
            return {
                'button_pressed': False,
                'paddle_direction': bounded_value(
                    paddle['y'] + self.offset - state['ball']['y'] + paddle['height'] / 2,
                    -MAX_COMPUTER_PADDLE_SPEED,
                    MAX_COMPUTER_PADDLE_SPEED
                )
            }
        else:
            return {
                'button_pressed': False,
                'paddle_direction': -bounded_value(
                    paddle['y'] + self.offset - state['ball']['y'] + paddle['height'] / 2  ,
                    -MAX_COMPUTER_PADDLE_SPEED,
                    MAX_COMPUTER_PADDLE_SPEED
                )
            }


class ModelPlayer:
    def __init__(self, model):
        self.model = model

    def reset(self):
        return None

    def get_actions(self, player, state):
        return transform_action(self.model.predict(get_observation( player, state))[0])
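
During a rally the scripted opponent's behaviour reduces to one clamped subtraction. The sketch below isolates that calculation; `tracking_speed` is a hypothetical helper written for this illustration, and the sample positions are made up. The sign flip that `get_actions` applies on the right-hand side of the court is left out.

```python
MAX_COMPUTER_PADDLE_SPEED = 30

def bounded_value(value, min_value, max_value):
    return max(min_value, min(max_value, value))

def tracking_speed(paddle_y, paddle_height, ball_y, offset=0):
    # Signed gap between the (offset) paddle centre and the ball,
    # clamped to the maximum paddle speed in either direction.
    error = paddle_y + offset - ball_y + paddle_height / 2
    return bounded_value(error, -MAX_COMPUTER_PADDLE_SPEED, MAX_COMPUTER_PADDLE_SPEED)

print(tracking_speed(400, 90, 100))             # ball far above the paddle: saturates at 30
print(tracking_speed(400, 90, 440))             # ball near the centre: 5.0
print(tracking_speed(400, 90, 700, offset=20))  # ball far below the paddle: saturates at -30
```

Because the speed is capped, the opponent can be beaten by shots that change height faster than the paddle can follow.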


### Reward System Class

Define a class to manage the reward system for training the model. This class will calculate and apply rewards based on game events and player actions.


In [8]:
class RewardSystem:
    def __init__(self, rewarded_player):
        self.rewarded_player = rewarded_player
        self.total_reward = 0
        self.step_count = 0

    def reset(self):
        self.total_reward = 0
        self.step_count += 1

    def pre_serve_reward(self, player, game_state):
        if player == self.rewarded_player:
            self.total_reward -= 0.05

    def serve_reward(self, player, game_state):
        if player == self.rewarded_player:
            ball = game_state['ball']
            reward = abs(ball['dy']) * abs(ball['dy']) - 30
            self.total_reward += reward

    def hit_paddle_reward(self, player, game_state):
        if player == self.rewarded_player:
            reward = 50
            # reward = 50 + abs(game_state[player]['dy'])
            self.total_reward += reward

    def conceed_point_reward(self, player, game_state):
        if player == self.rewarded_player:
            punishment = abs(game_state['ball']['y'] - (game_state[player]['y'] + game_state[player]['height'])) / 8
            # punishment = abs(game_state['ball']['y'] - (game_state[player]['y'] + game_state[player]['height'])) / 4
            self.total_reward -= punishment

    def score_point_reward(self, player, game_state):
        if player == self.rewarded_player:
            reward = 200
            self.total_reward += reward

    def paddle_movement_reward(self, player, game_state):
        if player == self.rewarded_player:
            paddle = game_state[player]
            # reward = -0.5 if abs(paddle['dy']) < 0.2 else 0
            reward = 0
            self.total_reward += reward

    def end_episode(self, player, game_state):
        if player == self.rewarded_player:
            if game_state['ball']['serve_mode'] and player == game_state['server']:
                self.total_reward -= 200
            else:
                self.total_reward += 200
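
To make the reward magnitudes concrete, the sketch below re-implements two of the formulas above as free functions and evaluates them on made-up values. The function names are hypothetical and the numbers are examples only.

```python
import math

def serve_reward(ball_dy):
    # Same formula as RewardSystem.serve_reward: the square of the vertical speed minus 30.
    return abs(ball_dy) * abs(ball_dy) - 30

def concede_punishment(ball_y, paddle_y, paddle_height):
    # Same formula as RewardSystem.conceed_point_reward: the distance is measured
    # from the paddle's bottom edge (y + height) and divided by 8.
    return abs(ball_y - (paddle_y + paddle_height)) / 8

print(serve_reward(0))                   # -30: a serve with no vertical speed is penalised
print(serve_reward(8))                   # 34
print(round(math.sqrt(30), 2))           # 5.48: the |dy| at which the serve reward breaks even
print(concede_punishment(100, 400, 90))  # 48.75
```

For comparison, a paddle hit is worth 50 and a scored point 200, so missing from far away costs about as much as one successful return earns.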


### Pong Game Class

Define the main class for the Pong game, managing the game state and updates. This class handles the game logic, including ball movement, collisions, scoring, and paddle controls.


In [13]:
class PongGame:
    def __init__(self, server, positions_reversed, player, opponent):
        self.game_state = {
        'server': server,
        'positions_reversed': positions_reversed,
        'player': player,
        'opponent': opponent,
        Player.Player1: {'x': PADDLE_GAP, 'y': COURT_HEIGHT // 2 - PADDLE_HEIGHT // 2, 'dy': 0, 'width': PADDLE_WIDTH, 'height': PADDLE_HEIGHT, 'colour': PLAYER_COLOURS['Player1']},
        Player.Player2: {'x': COURT_WIDTH - PADDLE_WIDTH - PADDLE_GAP, 'y': COURT_HEIGHT // 2 - PADDLE_HEIGHT // 2, 'dy': 0, 'width': PADDLE_WIDTH, 'height': PADDLE_HEIGHT, 'colour': PLAYER_COLOURS['Player2']},
        'ball': {'x': COURT_WIDTH // 2, 'y': COURT_HEIGHT // 2, 'dx': INITIAL_BALL_SPEED, 'dy': INITIAL_BALL_SPEED, 'radius': BALL_RADIUS, 'speed': INITIAL_BALL_SPEED, 'serve_mode': True, 'score_mode': False, 'score_mode_timeout': 0},
        'stats': {'rally_length': 0, 'serve_speed': INITIAL_BALL_SPEED, 'server': server}
        }
        self.apply_meta_game_state()

    def apply_meta_game_state(self):
        game_state = self.game_state
        serving_player = game_state['server']
        positions_reversed = game_state['positions_reversed']
        if serving_player == Player.Player1:
            self.game_state[Player.Player1]['height'] = PADDLE_HEIGHT * SERVING_HEIGHT_MULTIPLIER
            self.game_state[Player.Player2]['height'] = PADDLE_HEIGHT
        else:
            self.game_state[Player.Player1]['height'] = PADDLE_HEIGHT
            self.game_state[Player.Player2]['height'] = PADDLE_HEIGHT * SERVING_HEIGHT_MULTIPLIER
        if positions_reversed:
            self.game_state[Player.Player1]['x'] = COURT_WIDTH - PADDLE_WIDTH - PADDLE_GAP
            self.game_state[Player.Player2]['x'] = PADDLE_GAP
            self.game_state[Player.Player1]['y'] = random.randint(0, COURT_HEIGHT)
            self.game_state[Player.Player2]['y'] = random.randint(0, COURT_HEIGHT)
        else:
            self.game_state[Player.Player1]['x'] = PADDLE_GAP
            self.game_state[Player.Player2]['x'] = COURT_WIDTH - PADDLE_WIDTH - PADDLE_GAP
        ball = self.game_state['ball']
        server_is_left = (serving_player == Player.Player1 and not positions_reversed) or (serving_player == Player.Player2 and positions_reversed)
        ball['y'] = self.game_state[serving_player]['height'] / 2 + self.game_state[serving_player]['y']
        ball['x'] = self.game_state[serving_player]['width'] + ball['radius'] + PADDLE_GAP if server_is_left else COURT_WIDTH - self.game_state[serving_player]['width'] - ball['radius'] - PADDLE_GAP
        ball['speed'] = INITIAL_BALL_SPEED
        ball['serve_mode'] = True
        ball['score_mode'] = False
        ball['score_mode_timeout'] = 0
        self.game_state['stats']['rally_length'] = 0

    def update_game_state(self, actions, delta_time, reward_system: RewardSystem):
        game_state = self.game_state
        ball = game_state['ball']
        stats = game_state['stats']
        server = game_state['server']
        paddle_left, paddle_right = (game_state[Player.Player2], game_state[Player.Player1]) if game_state['positions_reversed'] else (game_state[Player.Player1], game_state[Player.Player2])
        player_is_left = (game_state['player'] == Player.Player1 and not game_state['positions_reversed']) or (game_state['player'] == Player.Player2 and game_state['positions_reversed'])
        if ball['score_mode']:
            return True
        elif ball['serve_mode']:
            serving_from_left = (server == Player.Player1 and not game_state['positions_reversed']) or (server == Player.Player2 and game_state['positions_reversed'])

            reward_system.pre_serve_reward(server, game_state)

            if actions[server]['button_pressed']:
                ball['speed'] = INITIAL_BALL_SPEED
                ball['dx'] = INITIAL_BALL_SPEED if serving_from_left else -INITIAL_BALL_SPEED
                ball['serve_mode'] = False
                stats['rally_length'] += 1
                stats['serve_speed'] = abs(ball['dy']) + abs(ball['dx'])
                stats['server'] = server

                reward_system.serve_reward(server, game_state)

            ball['dy'] = (game_state[server]['y'] + game_state[server]['height'] / 2 - ball['y']) / PADDLE_SPEED_DIVISOR
            ball['y'] += ball['dy'] * delta_time
        else:
            ball['x'] += ball['dx'] * delta_time
            ball['y'] += ball['dy'] * delta_time
            if ball['y'] - ball['radius'] < 0:
                ball['dy'] = -ball['dy']
                ball['y'] = ball['radius']
            elif ball['y'] + ball['radius'] > COURT_HEIGHT:
                ball['dy'] = -ball['dy']
                ball['y'] = COURT_HEIGHT - ball['radius']
            if ball['x'] - ball['radius'] < paddle_left['x'] + paddle_left['width'] and ball['y'] + ball['radius'] > paddle_left['y'] and ball['y'] - ball['radius'] < paddle_left['y'] + paddle_left['height']:
                bounce_angle = get_bounce_angle(paddle_left['y'], paddle_left['height'], ball['y'])
                ball['dx'] = (ball['speed'] + abs(paddle_left['dy']) / PADDLE_CONTACT_SPEED_BOOST_DIVISOR) * math.cos(bounce_angle)
                ball['dy'] = (ball['speed'] + abs(paddle_left['dy']) / PADDLE_CONTACT_SPEED_BOOST_DIVISOR) * -math.sin(bounce_angle)
                ball['x'] = paddle_left['x'] + paddle_left['width'] + ball['radius']
                ball['speed'] += SPEED_INCREMENT
                stats['rally_length'] += 1

                if player_is_left:  # the trained player owns the left paddle
                    reward_system.hit_paddle_reward(self.game_state['player'], game_state)

            elif ball['x'] + ball['radius'] > paddle_right['x'] and ball['y'] + ball['radius'] > paddle_right['y'] and ball['y'] - ball['radius'] < paddle_right['y'] + paddle_right['height']:
                bounce_angle = get_bounce_angle(paddle_right['y'], paddle_right['height'], ball['y'])
                ball['dx'] = -(ball['speed'] + abs(paddle_right['dy']) / PADDLE_CONTACT_SPEED_BOOST_DIVISOR) * math.cos(bounce_angle)
                ball['dy'] = (ball['speed'] + abs(paddle_right['dy']) / PADDLE_CONTACT_SPEED_BOOST_DIVISOR) * -math.sin(bounce_angle)
                ball['x'] = paddle_right['x'] - ball['radius']
                ball['speed'] += SPEED_INCREMENT
                stats['rally_length'] += 1

                if not player_is_left:  # the trained player owns the right paddle
                    reward_system.hit_paddle_reward(self.game_state['player'], game_state)

            if ball['x'] - ball['radius'] < 0:
                ball['score_mode'] = True

                if player_is_left:
                    reward_system.conceed_point_reward(self.game_state['player'], game_state)
                else:
                    reward_system.score_point_reward(self.game_state['player'], game_state)

            elif ball['x'] + ball['radius'] > COURT_WIDTH:
                ball['score_mode'] = True

                if not player_is_left:
                    reward_system.conceed_point_reward(self.game_state['player'], game_state)
                else:
                    reward_system.score_point_reward(self.game_state['player'], game_state)

        if game_state['positions_reversed']:
            game_state[Player.Player1]['dy'] = actions[Player.Player1]['paddle_direction']
            game_state[Player.Player2]['dy'] = -actions[Player.Player2]['paddle_direction']
        else:
            game_state[Player.Player1]['dy'] = -actions[Player.Player1]['paddle_direction']
            game_state[Player.Player2]['dy'] = actions[Player.Player2]['paddle_direction']

        game_state[Player.Player1]['y'] += game_state[Player.Player1]['dy'] * delta_time
        game_state[Player.Player2]['y'] += game_state[Player.Player2]['dy'] * delta_time

        if paddle_left['y'] < 0:
            paddle_left['y'] = 0
        if paddle_left['y'] + paddle_left['height'] > COURT_HEIGHT:
            paddle_left['y'] = COURT_HEIGHT - paddle_left['height']
        if paddle_right['y'] < 0:
            paddle_right['y'] = 0
        if paddle_right['y'] + paddle_right['height'] > COURT_HEIGHT:
            paddle_right['y'] = COURT_HEIGHT - paddle_right['height']

        reward_system.paddle_movement_reward(self.game_state['player'], game_state)
        return False

# Set up observation space

The following code defines the observation space for the model we are training, as well as a function that transforms the game state into an observation the model can make predictions from.

## Normalisation

Normalisation is a crucial step in preparing data for machine learning models, especially when the data has features with different ranges. Read more about the specific normalisation happening in `get_observation` [in this document](https://github.com/mrbenbot/wimblepong/blob/main/docs/normalisation.md).
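
As a rough illustration, the plain-Python sketch below mirrors the scaling that `get_observation` applies. The helper name `normalise_observation`, the constant name `VELOCITY_SCALE`, and the sample game values are made up for this example; the divisor 40 is the one used in the notebook's code.

```python
COURT_HEIGHT = 800
COURT_WIDTH = 1200
VELOCITY_SCALE = 40  # divisor applied to the ball's dx and dy

def normalise_observation(ball, paddle, is_server):
    """Hypothetical list-based mirror of the scaling in get_observation."""
    return [
        ball['x'] / COURT_WIDTH,                        # 0..1 across the court
        ball['y'] / COURT_HEIGHT,                       # 0..1 down the court
        ball['dx'] / VELOCITY_SCALE,                    # signed horizontal speed
        ball['dy'] / VELOCITY_SCALE,                    # signed vertical speed
        0.0 if paddle['x'] < COURT_WIDTH / 2 else 1.0,  # which side the paddle is on
        paddle['y'] / COURT_HEIGHT,                     # paddle top edge, 0..1
        float(ball['serve_mode']),                      # 1.0 while waiting for a serve
        float(is_server),                               # 1.0 if this player serves
    ]

ball = {'x': 600, 'y': 200, 'dx': 10, 'dy': -20, 'serve_mode': False}
paddle = {'x': 1175, 'y': 400}
obs = normalise_observation(ball, paddle, is_server=True)
print(obs)  # [0.5, 0.25, 0.25, -0.5, 1.0, 0.5, 0.0, 1.0]
```

Every feature ends up on a comparable scale, so no single input dominates the network's first layer purely because of its units.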

## Customising the observation space

If you decide to change the observation space or the way it is normalised, you can upload a `.js` file along with your `model.json` and `weights.bin` to ensure your model is receiving the same input when playing real WimblePong as when being trained.

[read more here](https://github.com/mrbenbot/wimblepong/blob/main/docs/custom_observations.md)



In [10]:
observation_space = spaces.Box(
            low=np.array([0, 0, -1, -1, 0, 0, 0, 0], dtype=np.float32),
            high=np.array([1, 1, 1, 1, 1, 1, 1, 1], dtype=np.float32)
        )

def get_observation(player, game_state):
    is_server = 1 if game_state['server'] == player else 0
    paddle = game_state[player]
    return np.array([
        float(game_state['ball']['x'] / COURT_WIDTH),
        float(game_state['ball']['y'] / COURT_HEIGHT),
        float(game_state['ball']['dx'] / 40),
        float(game_state['ball']['dy'] / 40),
        float(0 if paddle['x'] < COURT_WIDTH / 2 else 1),
        float(paddle['y'] / COURT_HEIGHT),
        float(int(game_state['ball']['serve_mode'])),
        float(int(is_server)),
    ], dtype=np.float32)

### Custom Pong Environment Class

Define a custom Gymnasium environment for the Pong game. This environment interfaces with the stable-baselines3 library for training the reinforcement learning model.


In [11]:
class CustomPongEnv(gym.Env):
    def __init__(self, computer_player):
        super(CustomPongEnv, self).__init__()

        self.action_space = spaces.Box(low=np.array([0, -1], dtype=np.float32), high=np.array([1, 1], dtype=np.float32), dtype=np.float32)
        self.observation_space = observation_space
        self.starting_states = [
           {'server': Player.Player1, 'positions_reversed': False, 'opponent': Player.Player1, 'player': Player.Player2},
           {'server': Player.Player2, 'positions_reversed': False, 'opponent': Player.Player1, 'player': Player.Player2},
           {'server': Player.Player2, 'positions_reversed': True, 'opponent': Player.Player1, 'player': Player.Player2},
           {'server': Player.Player1, 'positions_reversed': True, 'opponent': Player.Player1, 'player': Player.Player2},
        ]
        self.starting_state_index = 0

        self.computer_player = computer_player
        self.screen = None
        self.frame_count = 0
        self.last_event = None
        self.reset(seed=0)

    def seed(self, seed=None):
        self.np_random, seed = gym.utils.seeding.np_random(seed)
        return [seed]

    def reset(self, seed=None, options=None):  # options is accepted for Gymnasium API compatibility
        super().reset(seed=seed)
        if seed is not None:
            self.seed(seed)
        self.starting_state_index = (self.starting_state_index + 1) % len(self.starting_states)
        starting_state = self.starting_states[self.starting_state_index]

        server = starting_state['server']
        positions_reversed = starting_state['positions_reversed']
        player = starting_state['player']
        opponent = starting_state['opponent']

        self.computer_player.reset()
        self.game = PongGame(server=server, positions_reversed=positions_reversed, opponent=opponent, player=player)
        self.reward_system = RewardSystem(rewarded_player=player)
        self.step_count = 0

        return self._get_obs(), {}

    def step(self, action):
        self.step_count += 1
        self.reward_system.reset()

        model_player_actions = transform_action(action)
        computer_player_actions = self.computer_player.get_actions(self.game.game_state['opponent'], self.game.game_state)
        actions = {self.game.game_state['opponent']: computer_player_actions, self.game.game_state['player']: model_player_actions}
        terminated = self.game.update_game_state(actions, DELTA_TIME, self.reward_system)
        obs = self._get_obs()
        info = {}
        truncated = False
        if self.step_count > 1000:
            self.reward_system.end_episode(self.game.game_state['player'], self.game.game_state)
            terminated = True

        reward = self.reward_system.total_reward
        return obs, reward, terminated, truncated, info

    def _get_obs(self):
        state = self.game.game_state
        player = state['player']
        return get_observation(player, state)

    def render(self, mode='human', close=False):
        if close:
            if pygame.get_init():
                pygame.quit()
            return

        if self.screen is None:
            pygame.init()
            self.screen = pygame.display.set_mode((COURT_WIDTH, COURT_HEIGHT))
        if not os.path.exists('./frames'):
            os.makedirs("./frames")

        # Clear screen
        self.screen.fill((255, 255, 255))  # Fill with white background
        state = self.game.game_state
        # Render paddles
        paddle1 = state[Player.Player1]
        paddle2 = state[Player.Player2]
        pygame.draw.rect(self.screen, paddle1['colour'], (paddle1['x'], paddle1['y'], paddle1['width'], paddle1['height']))
        pygame.draw.rect(self.screen, paddle2['colour'], (paddle2['x'], paddle2['y'], paddle2['width'], paddle2['height']))

        # Render ball
        ball = state['ball']
        pygame.draw.circle(self.screen, (0, 0, 0), (ball['x'], ball['y']), ball['radius'])

        # Update the display
        pygame.display.flip()

        # Save frame as image
        frame_path = f'./frames/frame_{self.frame_count:04d}.png'
        pygame.image.save(self.screen, frame_path)
        self.frame_count += 1


    def close(self):
        if not os.path.exists('./frames'):
            print("No frames directory found, skipping video creation.")
            return
        image_files = [f"./frames/frame_{i:04d}.png" for i in range(self.frame_count)]

        # Create a video clip from the image sequence
        clip = ImageSequenceClip(image_files, fps=24)  # 24 frames per second

        # Write the video file
        clip.write_videofile("./game_video.mp4", codec="libx264")
        pygame.quit()
        frames_dir = "./frames"
        if os.path.exists(frames_dir):
            for filename in os.listdir(frames_dir):
                file_path = os.path.join(frames_dir, filename)
                if os.path.isfile(file_path):
                    os.unlink(file_path)
            os.rmdir(frames_dir)


### Testing the Environment

Create and test a single instance of the custom Pong environment. This will help ensure that the environment behaves as expected before moving on to training.
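
The test loop calls `env.render()`, which opens a pygame window. On a machine without a display (a hosted notebook runtime is assumed to be one), that call can fail. One common workaround, which is an addition here and not part of the original notebook, is to select SDL's `dummy` video driver before pygame is initialised. Frames are then drawn off-screen and can still be saved to disk.

```python
import os

# Must run before pygame.init() or pygame.display.set_mode() is called.
# setdefault keeps any driver that is already configured on a desktop machine.
os.environ.setdefault("SDL_VIDEODRIVER", "dummy")

print("SDL video driver:", os.environ["SDL_VIDEODRIVER"])
```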


In [None]:
# Create and test single environment
env = Monitor(CustomPongEnv(computer_player=ComputerPlayer()))

obs, _ = env.reset()  # Gymnasium's reset returns (observation, info)
print("Initial observation:", obs)
i = 0
while True:
    i+=1
    action = env.action_space.sample()  # Sample random action
    obs, reward, terminated, truncated, info = env.step(action)
    done = terminated or truncated
    print("Action taken:", action)
    print("Observation:", obs)
    print("Reward:", reward)
    print('iteration:', i)
    print("Done:", done)
    env.render()
    if done:
        obs, _ = env.reset()
        print("Environment reset")
        break

env.close()

### Creating and Testing Vectorized Environment

Create and test a vectorized environment, which lets training step several instances of the game together. Note that `DummyVecEnv` steps its environments one after another in a single process rather than truly in parallel.


In [None]:
# Create and test a vectorised environment
env = DummyVecEnv([lambda: CustomPongEnv(computer_player=ComputerPlayer()) for _ in range(1)])  # Adjust number of instances as needed
env = VecNormalize(env, norm_obs=False, norm_reward=True)  # Normalise rewards only; observations are already scaled in get_observation

obs = env.reset()
print("Initial observation:", obs)
i = 0
while True:
    i+=1
    action = env.action_space.sample()  # Sample random action
    print("Action taken:", action)
    obs, reward, done, info = env.step([action for _ in range(1)])
    print("Observation:", obs)
    print("Reward:", reward)
    print('iteration:', i)
    print("Done:", done)
    if np.any(done):
        obs = env.reset()
        print("Environment reset")
        break

env.close()

### Custom Evaluation Callback

Define a custom evaluation callback to periodically evaluate and log the performance of the trained model during training.


In [18]:
class CustomEvalCallback(EvalCallback):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def _on_step(self) -> bool:
        result = super()._on_step()
        if self.n_calls % self.eval_freq == 0:
            print(f"Evaluation at step {self.n_calls}: mean reward {self.last_mean_reward:.2f}")
        return result

### Setting Up Hyperparameters and Training the Model

Define hyperparameters for the PPO algorithm and set up the training process. This section includes creating the training and evaluation environments, initializing the model, and starting the training process.
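
Before running the cell, it can help to work out what the chosen numbers imply. The sketch below is plain arithmetic on the values used in the training cell (four environments, `n_steps=2048`, and so on); the derived variable names are illustrative. It assumes the usual stable-baselines3 behaviour: one rollout collects `n_steps` transitions from every environment, and the evaluation callback counts vectorised steps rather than individual timesteps.

```python
import math

n_envs = 4                 # length of the DummyVecEnv list in the training cell
n_steps = 2048
batch_size = 64
n_epochs = 4
total_timesteps = 600_000
eval_freq = 1000

rollout_size = n_steps * n_envs                                   # 8192 transitions per update
minibatches_per_epoch = rollout_size // batch_size                # 128
max_gradient_steps_per_update = minibatches_per_epoch * n_epochs  # 512 (fewer if target_kl stops early)
n_updates = math.ceil(total_timesteps / rollout_size)             # 74 rollouts
timesteps_between_evals = eval_freq * n_envs                      # 4000

print(rollout_size, minibatches_per_epoch, max_gradient_steps_per_update)
print(n_updates, n_updates * rollout_size, timesteps_between_evals)
```

Training therefore runs slightly past the requested 600,000 timesteps, because it stops only at a rollout boundary.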


In [None]:
# Configuration
create_new = True  # Set to False to continue training the previously saved pong_bot_1 model
train_against_model = False  # Set to True to partially train against a previously trained model player (if there is one)
total_timesteps = 600000
eval_freq = 1000

# Hyperparameters for PPO
hyperparams = {
    'n_steps': 2048,
    'batch_size': 64,
    'n_epochs': 4,
    'gamma': 0.99,
    'gae_lambda': 0.95,
    'clip_range': 0.2,
    'clip_range_vf': 0.2,
    'ent_coef': 0.001,
    'vf_coef': 0.5,
    'max_grad_norm': 0.5,
    'target_kl': 0.01,
    'tensorboard_log': f'{DRIVE_PATH}/{DAY}/logs/tensorboard_logs'
}

# Load the previous model if training against it
model2 = PPO.load(f"{DRIVE_PATH}/{DAY}/logs/best_model/best_model") if train_against_model else None

# Create training environment
def make_env(index):
    return Monitor(CustomPongEnv(computer_player=ModelPlayer(model2) if model2 and index < 2 else ComputerPlayer()))

train_env = DummyVecEnv([lambda i=i: make_env(i) for i in range(4)])  # i=i binds each index now; a bare lambda would only ever see the last i
train_env = VecNormalize(train_env, norm_obs=False, norm_reward=True) if create_new else VecNormalize.load(f"{DRIVE_PATH}/{DAY}/pong_bot_1_normalize", train_env)
train_env.training = True

# Create evaluation environment
eval_env = DummyVecEnv([lambda: Monitor(CustomPongEnv(computer_player=ModelPlayer(model2) if model2 else ComputerPlayer()))])
eval_env = VecNormalize(eval_env, norm_obs=False, norm_reward=False)
eval_env.training = False

# Create the CustomEvalCallback
eval_callback = CustomEvalCallback(eval_env, best_model_save_path=f'{DRIVE_PATH}/{DAY}/logs/best_model',
                                   log_path=f'{DRIVE_PATH}/{DAY}/logs/results', eval_freq=eval_freq,
                                   deterministic=True, render=False)

# Create or load the model
model = PPO('MlpPolicy', env=train_env, **hyperparams) if create_new else PPO.load(f"{DRIVE_PATH}/{DAY}/pong_bot_1", env=train_env, **hyperparams, force_reset=True)

# Train the model
model.learn(total_timesteps=total_timesteps, callback=eval_callback)

# Save the model and the normalization statistics
model.save(f"{DRIVE_PATH}/{DAY}/pong_bot_1")
train_env.save(f"{DRIVE_PATH}/{DAY}/pong_bot_1_normalize")

print("Training completed and logs are saved.")
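
A general Python detail matters when you build a list of environment factories in a loop: a lambda looks up the loop variable when it is called, not when it is created. The sketch below shows the difference using a stand-in `make_env` that just returns a string; no stable-baselines3 code is involved.

```python
def make_env(index):
    # Stand-in for a real environment factory.
    return f"env-{index}"

# Late binding: every lambda reads i after the comprehension has finished.
late_bound = [lambda: make_env(i) for i in range(4)]

# Early binding: the default argument captures the current value of i.
early_bound = [lambda i=i: make_env(i) for i in range(4)]

print([factory() for factory in late_bound])   # ['env-3', 'env-3', 'env-3', 'env-3']
print([factory() for factory in early_bound])  # ['env-0', 'env-1', 'env-2', 'env-3']
```

This matters whenever the factory's behaviour depends on its index, as it does for a rule such as `index < 2` that selects which opponent an environment uses.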

### Evaluating the Trained Model

Load the trained model and evaluate its performance. This section demonstrates how to use the trained model to play the game and render the results.


In [None]:
# Load the trained model and evaluate
model = PPO.load(f"{DRIVE_PATH}/{DAY}/pong_bot_1")
# model2 = PPO.load(f"{DRIVE_PATH}/{DAY}/logs/best_model/best_model")

# Create a new environment for rendering
env = CustomPongEnv(computer_player=ComputerPlayer())
# env = CustomPongEnv(computer_player=ModelPlayer(model2))


obs, _ = env.reset()
count = 0


while count < 4:
    action, _ = model.predict(obs)  # Get action from the trained model
    obs, reward, terminated, truncated, info = env.step(action)
    env.render()
    if terminated or truncated:
        count += 1
        print("reset")
        obs, _ = env.reset()  # keep the fresh observation for the next prediction

env.close()

### Exporting the Model to ONNX

Export the trained model to ONNX format.

In [None]:
class OnnxableSB3Policy(th.nn.Module):
    def __init__(self, policy: BasePolicy):
        super().__init__()
        self.policy = policy

    def forward(self, observation: th.Tensor) -> th.Tensor:
        # Run the policy in deterministic mode so the exported graph
        # contains no random sampling; only the actions are returned.
        actions, values, log_prob = self.policy(observation, deterministic=True)
        return actions


# Load the trained PyTorch model
# model_path = f"{DRIVE_PATH}/{DAY}/logs/best_model/best_model"
model_path = f"{DRIVE_PATH}/{DAY}/pong_bot_1"

model = PPO.load(model_path, device="cpu")

onnx_policy = OnnxableSB3Policy(model.policy)
onnx_file_path = f"{DRIVE_PATH}/{DAY}/pong_bot_1_dynamo.onnx"

# Define dummy input based on the observation space shape
observation_size = model.observation_space.shape
dummy_input = th.randn(1, *observation_size)

# Export the model to ONNX
th.onnx.export(
    onnx_policy,
    dummy_input,
    onnx_file_path,
    opset_version=11,
    input_names=["input"],
    output_names=["actions"]
)

print(f"ONNX model saved at: {onnx_file_path}")

### Exporting the ONNX Model to TensorFlow

Convert the ONNX format model to TensorFlow SavedModel format. This allows for compatibility with various deployment platforms.


In [None]:
tf_model_path = f"{DRIVE_PATH}/{DAY}/pong_bot_1_tf"

# Convert ONNX to TensorFlow SavedModel
convert(
    input_onnx_file_path=onnx_file_path,
    output_folder_path=tf_model_path,
    output_signaturedefs=True,
)

print(f"TensorFlow SavedModel saved at: {tf_model_path}")


### Converting TensorFlow Model to TensorFlow.js

Convert the TensorFlow model to TensorFlow.js format for deployment in a web browser. The converter writes a `.json` file and a `.bin` file that together represent your model. Go to [wimblepong.netlify.app/upload](https://wimblepong.netlify.app/upload) to load them into local storage and make the model available to play in the browser.

Be careful: if a model, or the code that runs it, has security vulnerabilities, malicious actors can exploit them to execute arbitrary code or leak data. A model you trained yourself should not be cause for alarm. See the [source code](https://github.com/mrbenbot/wimblepong/blob/main/src/libs/tensorFlowPlayer.ts) to make sure you are happy with the way the model is run in the browser.


In [None]:
tfjs_model_path = f"{DRIVE_PATH}/{DAY}/pong_bot_1_tfjs"

# Convert the TensorFlow model to TensorFlow.js
subprocess.run([
    'tensorflowjs_converter',
    '--input_format', 'tf_saved_model',
    '--output_format', 'tfjs_graph_model',
    "--signature_name", "serving_default",
    tf_model_path,
    tfjs_model_path
], check=True)  # raise an error if the conversion fails

print(f"TensorFlow.js model saved at: {tfjs_model_path}")


### Visualizing Training with TensorBoard

Set up TensorBoard to visualize the training process, including metrics like rewards and losses over time.


In [None]:
%load_ext tensorboard
# change path to logs as needed
LOGDIR=f'{DRIVE_PATH}/{DAY}/logs/tensorboard_logs'
%tensorboard --logdir $LOGDIR
