# Use this Notebook to watch training as it is happening

To use this notebook, edit only the game-selection cell as directed below, then execute all cells.

In [1]:
import time
import torch
from torch import Tensor
from torch import nn
import numpy as np
import ple
from ple.games.catcher import Catcher
from IPython.display import clear_output
clear_output()

In [2]:
class PolicyModel(torch.nn.Module):
    """
    Small fully connected policy network: one hidden ReLU layer followed by a
    softmax output, so ``forward`` yields one probability per action.

    ``learning_rate`` and ``model_path`` are part of the constructor signature
    but are not used anywhere inside this class.
    """

    def __init__(self,
                 in_feature_len: int,
                 out_feature_len: int,
                 hidden_neuron_len: int,
                 learning_rate: float,
                 model_path: str = None
                ):
        super().__init__()
        # The attribute names double as the state_dict key prefixes, so they
        # must stay `linear1` / `linear2` for saved weights to load.
        self.linear1 = torch.nn.Linear(in_features=in_feature_len,
                                       out_features=hidden_neuron_len)
        self.linear2 = torch.nn.Linear(in_features=hidden_neuron_len,
                                       out_features=out_feature_len)

    def forward(self, x: Tensor) -> Tensor:
        """Return the softmax (over the last dimension) of the two-layer network applied to ``x``."""
        hidden = torch.relu(self.linear1(x))
        logits = self.linear2(hidden)
        return torch.softmax(logits, dim=-1)

def try_load(func: callable) -> any:
    """
    Call ``func`` repeatedly until it returns, and hand back its result.

    Handled failures (anything else propagates to the caller):
      * ``EOFError`` / ``RuntimeError`` - the error and 'Reloading' are printed.
      * ``FileNotFoundError`` - the error and a hint are printed, followed by a
        10 second wait and a clear of the cell output.

    After any handled failure there is a further 1 second pause before the
    next attempt. There is no upper bound on the number of attempts.
    """
    retry_quickly = (EOFError, RuntimeError)

    while True:
        try:
            result = func()
        except retry_quickly as err:
            print(err)
            print('Reloading')
        except FileNotFoundError as err:
            print(err)
            print("Try running PPO_From_Scratch.ipynb")
            print()
            print("Trying again in 10 seconds")
            time.sleep(10)
            clear_output()
        else:
            return result

        # Common back-off shared by every handled failure.
        time.sleep(1)

def _load_model(
    in_feature_len: int, 
    out_feature_len: int, 
    hidden_neuron_len: int, 
    learning_rate: float
):
    """
    Build a ``PolicyModel`` with the given layer sizes and fill it with the
    weights stored at the module-level ``MODEL_PATH``.

    Args:
        in_feature_len: size of the state vector fed to the network.
        out_feature_len: number of outputs (one per action).
        hidden_neuron_len: width of the hidden layer.
        learning_rate: forwarded to ``PolicyModel``.

    Returns:
        The loaded model, converted to float64 and placed on the CPU.

    Raises:
        Whatever ``torch.load`` / ``load_state_dict`` raise, e.g.
        ``FileNotFoundError`` when the file does not exist or ``RuntimeError``
        when the stored tensors do not match the requested layer sizes.
    """
    policy = PolicyModel(
        in_feature_len = in_feature_len, 
        out_feature_len = out_feature_len, 
        hidden_neuron_len = hidden_neuron_len, 
        learning_rate = learning_rate,
    )
    # map_location='cpu' lets a checkpoint that was saved from a CUDA device be
    # read on a machine without CUDA. Without it torch.load raises RuntimeError
    # there, which try_load() would keep retrying indefinitely.
    state_dict = torch.load(MODEL_PATH, map_location='cpu', weights_only=True)
    policy.load_state_dict(state_dict)
    policy.double()
    policy.cpu()
    return policy

def _print_summary(
    highest_score: int, 
    score: int, 
    hit: int,
    miss: int,
    frame_i: int,
    reload_frames: int,
    reload_threshold: int
):
    print(f"Highest Score: {highest_score}")    
    print(f"Score: {score}")
    print(f"Hit: {hit}")
    print(f"Miss: {miss}")
    print(f"Frame: {frame_i}")
    print(f"Frames until model reload: {reload_frames} of {reload_threshold}")

## Select the game: leave exactly one of the lines below uncommented

In [3]:
# Game selection: keep exactly one of the lines below active and the other commented out.
# Each line sets both `game` and `MODEL_PATH`. MODEL_PATH is the weights file read by
# _load_model(); the later cell also reads "<MODEL_PATH>.config" and "<MODEL_PATH>.min-max".
game = Catcher(width=64 *8, height=64*8, init_lives=1); MODEL_PATH = 'catcher.model'
#game = ple.games.flappybird.FlappyBird();  MODEL_PATH = 'flappybird.model'

In [4]:
# Make float64 the default floating point dtype for tensors created from here on.
torch.set_default_dtype(torch.float64)

# Action table indexed by the position of the policy's largest output. Slot 0 holds None,
# which is handed to p.act() unchanged - presumably PLE's "do nothing" input (TODO confirm).
ACTIONS = [None, *game.actions.values()]

p = ple.PLE(game, fps=30, display_screen=True, force_fps=False)
p.init()


def load_config():
    """Read the "<MODEL_PATH>.config" file."""
    return torch.load(f"{MODEL_PATH}.config")


def load_min_max():
    """Read the "<MODEL_PATH>.min-max" file (full unpickling, weights_only=False)."""
    return torch.load(f"{MODEL_PATH}.min-max", weights_only=False)


def load_model():
    """Build and load the policy, sized from the current global `min_max` and `config`."""
    return _load_model(
        in_feature_len = min_max['min'].shape[0],
        out_feature_len = len(ACTIONS),
        hidden_neuron_len = int(config['POLICY_HIDDEN_LEN']),
        learning_rate = float(config['POLICY_LR']),
    )


# Each load goes through try_load so a missing or unreadable file is retried.
# The order matters: load_model reads both `config` and `min_max`.
config = try_load(load_config)
min_max = try_load(load_min_max)
policy = try_load(load_model)

# Number of acted frames between two reloads of min_max / policy from disk.
RELOAD_THRESHOLD = 1000

# Running statistics shown in the summary.
highest_score = 0
hit = miss = 0

# Total acted frames, and its value at the most recent reload.
global_frame_i = last_reload = 0

# The KeyboardInterrupt handler at the bottom reports `frame_i`. Bind it up front so that
# stopping the kernel before the first frame has run cannot raise a NameError there.
frame_i = 0

try:
    while True:
        # Every pass of this outer loop is one episode.
        p.reset_game()
        for frame_i in range(100000000): # Number of Frames
            # Redraw the status report in place.
            clear_output(wait=True)
            score = p.score()
            if score > highest_score:
                highest_score = score
            _print_summary(
                highest_score = highest_score,
                score = score,
                hit = hit,
                miss = miss,
                frame_i = frame_i,
                reload_frames = global_frame_i - last_reload,
                reload_threshold = RELOAD_THRESHOLD
            )

            if p.game_over():
                # Counters are cleared right after the loop and the game is reset at the
                # top of the outer loop, so nothing else needs to happen here.
                print("Game Over. Resetting.")
                break

            # Min-max scale the raw game state with the bounds stored next to the model.
            # NOTE(review): a feature whose stored min equals its max divides by zero
            # here - confirm the saved bounds rule that out.
            state = list(p.getGameState().values())
            state = (state - min_max['min'])/(min_max['max'] - min_max['min'])

            # Greedy action choice: take the index of the largest policy output.
            # no_grad() skips autograd bookkeeping, which is not needed for playback.
            with torch.no_grad():
                action_probs = policy(torch.tensor(state))
            action_index = int(torch.argmax(action_probs))
            action = ACTIONS[action_index]

            # A positive reward is counted as a hit, a negative one as a miss.
            reward = p.act(action)
            if reward > 0:
                hit += 1
            elif reward < 0:
                miss += 1

            # Pick up newer files from disk every RELOAD_THRESHOLD acted frames.
            # min_max is refreshed first because load_model sizes the network from it.
            if (global_frame_i - last_reload) >= RELOAD_THRESHOLD:
                print("Reloading")
                min_max = try_load(load_min_max)
                policy = try_load(load_model)
                last_reload = global_frame_i
                time.sleep(2)
            global_frame_i += 1

        # Episode finished: hit/miss are per-episode counters.
        hit = 0
        miss = 0

        # The in-loop check runs before the frame counter is incremented, so the
        # threshold can be reached exactly on the last acted frame of an episode.
        if (global_frame_i - last_reload) >= RELOAD_THRESHOLD:
            print("Reloading on reset")
            min_max = try_load(load_min_max)
            policy = try_load(load_model)
            last_reload = global_frame_i
            time.sleep(2)
except KeyboardInterrupt as e:
    # Interrupting the kernel is the intended way to stop watching; show a final report.
    clear_output(wait=True)
    print("Execution stopped by user")
    _print_summary(
        highest_score = highest_score,
        score = p.score(),
        hit = hit,
        miss = miss,
        frame_i = frame_i,
        reload_frames = global_frame_i - last_reload,
        reload_threshold = RELOAD_THRESHOLD
    )

Execution stopped by user
Highest Score: 11.0
Score: 11.0
Hit: 11
Miss: 0
Frame: 420
Frames until model reload: 420 of 1000
