# Use this Notebook to watch training as it is happening

#### To use this notebook, change only the one line indicated below, then execute all cells.

In [1]:
!python -m pip install tqdm torch tensorboard -e ./PyGame-Learning-Environment/PyGame-Learning-Environment-master -qqq

In [2]:
from tqdm import tqdm
import time
import torch
from torch import Tensor
from torch import nn
import numpy as np
import ple
from IPython.display import clear_output
clear_output()

In [3]:
class PolicyModel(torch.nn.Module):
    """
    Two-layer policy network that turns a state vector into action probabilities.

    The input passes through ``linear1``, a ReLU, ``linear2`` and finally a
    softmax over the last dimension, so each output row sums to 1.
    """

    def __init__(self,
                 in_feature_len: int,
                 out_feature_len: int,
                 hidden_neuron_len: int,
                 learning_rate: float,
                 model_path: str = None
                 ):
        """
        Create the two linear layers.

        Args:
            in_feature_len: size of the state vector fed to ``linear1``.
            out_feature_len: number of outputs produced by ``linear2``.
            hidden_neuron_len: width of the hidden layer between the two.
            learning_rate: accepted but not used anywhere in this class.
            model_path: accepted but not used anywhere in this class.
        """
        super().__init__()
        # Attribute names are the state_dict keys, so they must stay as-is for
        # saved weights to load.
        self.linear1 = torch.nn.Linear(in_feature_len, hidden_neuron_len)
        self.linear2 = torch.nn.Linear(hidden_neuron_len, out_feature_len)

    def forward(self, x: Tensor) -> Tensor:
        """Return the softmax-normalised output of the network for ``x``."""
        hidden = torch.relu(self.linear1(x))
        logits = self.linear2(hidden)
        return torch.softmax(logits, dim=-1)

def try_load(func: callable) -> any:
    """
    Call ``func`` repeatedly until it succeeds, then return its result.

    Retry policy (there is no retry limit):
      * EOFError / RuntimeError: print the error and 'Reloading', then wait
        1 second before the next attempt.
      * FileNotFoundError: print the error and a hint, wait 10 seconds, clear
        the cell output, then wait the common 1 second before retrying.

    Any other exception raised by ``func`` propagates to the caller.
    """
    while True:
        try:
            result = func()
        except (EOFError, RuntimeError) as err:
            print(err)
            print('Reloading')
        except FileNotFoundError as err:
            print(err)
            print("Try running PPO_From_Scratch.ipynb")
            print()
            print("Trying again in 10 seconds")
            time.sleep(10)
            clear_output()
        else:
            return result

        # Pause shared by every failed attempt before trying again.
        time.sleep(1)

def _load_model(
    in_feature_len: int, 
    out_feature_len: int, 
    hidden_neuron_len: int, 
    learning_rate: float
) -> "PolicyModel":
    """
    Build a PolicyModel and populate it with the weights stored at MODEL_PATH.

    Args:
        in_feature_len: size of the state vector the policy receives.
        out_feature_len: number of action outputs.
        hidden_neuron_len: hidden layer width; must match the saved weights.
        learning_rate: forwarded to PolicyModel, which does not use it.

    Returns:
        The loaded policy, cast to float64 and placed on the CPU.

    Raises:
        Whatever ``torch.load`` or ``load_state_dict`` raise; this function
        does no error handling of its own.
    """
    policy = PolicyModel(
        in_feature_len = in_feature_len, 
        out_feature_len = out_feature_len, 
        hidden_neuron_len = hidden_neuron_len, 
        learning_rate = learning_rate,
    )
    # map_location="cpu" lets a checkpoint written from a CUDA model load on a
    # machine without CUDA. Without it torch.load raises RuntimeError there.
    # The model is moved to the CPU below anyway, so the result is unchanged.
    state_dict = torch.load(MODEL_PATH, weights_only=True, map_location="cpu")
    policy.load_state_dict(state_dict)
    policy.double()
    policy.cpu()
    return policy

def _print_summary(
    highest_score: int, 
    score: int, 
    hit: int,
    miss: int,
    frame_i: int,
    reload_frames: int,
    reload_threshold_frames: int,
    max_frames: int
):
    print(f"Highest Score: {highest_score}")    
    print(f"Score: {score}")
    print(f"Hit: {hit}")
    print(f"Miss: {miss}")
    print(f"Frame: {frame_i}")
    print(f"Frames until model reload: {reload_frames} of {reload_threshold_frames}")
    print(f"Frames until forced new game: {frame_i} of {max_frames}")

## Uncomment only the line for the game of your choice
NOTE: Ensure that the game is running, or has already run, in the [PPO_From_Scratch.ipynb](PPO_From_Scratch.ipynb) notebook.

In [4]:
# Exactly one line below should be active. Each line builds the game object and
# sets MODEL_PATH, the file name/prefix the saved model, ".config" and
# ".min-max" files are loaded from.
game = ple.games.Catcher(width=64 *8, height=64*8, init_lives=35); MODEL_PATH = 'catcher.model'
#game = ple.games.flappybird.FlappyBird();  MODEL_PATH = 'flappybird.model'

In [5]:
# Make float64 the default dtype for tensors created from here on.
torch.set_default_dtype(torch.float64)

# --- Tunables -----------------------------------------------------------------
FPS = 30
MAX_MINUTES_PER_GAME = 5        # a game is force-restarted after this long
RELOAD_THRESHOLD_FRAMES = 1000  # frames played between mid-game model reloads
SCREEN_DELAY_SECONDS = 3        # pause so status messages stay readable

# --- Running statistics ---------------------------------------------------------
highest_score = 0
hit = 0
miss = 0
global_frame_i = 0  # frames played across all games
last_reload = 0     # value of global_frame_i at the most recent reload

# --- Game environment -----------------------------------------------------------
# Index 0 is None, followed by the game's own actions; the policy's argmax
# indexes into this list. (None is presumably PLE's "no-op" — confirm.)
ACTIONS = [None] + list(game.actions.values())
p = ple.PLE(game, fps=FPS, display_screen=True, force_fps=False)
p.init()


# --- Loaders (each is retried through try_load) ---------------------------------
def load_config():
    """Load the saved hyper-parameter mapping from '<MODEL_PATH>.config'."""
    # NOTE(review): no weights_only argument is passed, so torch's
    # version-dependent default applies — confirm this is intended.
    return torch.load(f"{MODEL_PATH}.config")


def load_min_max():
    """Load the saved 'min'/'max' normalisation statistics from '<MODEL_PATH>.min-max'."""
    return torch.load(f"{MODEL_PATH}.min-max", weights_only=False)


def load_model():
    """Build the policy from the current globals `min_max` and `config`, then load its weights."""
    return _load_model(
        in_feature_len = min_max['min'].shape[0],
        out_feature_len = len(ACTIONS),
        hidden_neuron_len = int(config['POLICY_HIDDEN_LEN']),
        learning_rate = float(config['POLICY_LR']),
    )


config = try_load(load_config)
min_max = try_load(load_min_max)
policy = try_load(load_model)

# Frame budget for a single game, computed once instead of on every frame.
MAX_FRAMES_PER_GAME = MAX_MINUTES_PER_GAME * 60 * FPS
# Bound before the loop so the KeyboardInterrupt handler can always print a
# summary, even when the interrupt arrives before the first frame is played.
frame_i = 0

try:
    while True:
        p.reset_game()
        for frame_i in range(MAX_FRAMES_PER_GAME):
            clear_output(wait=True)
            if p.score() > highest_score:
                highest_score = p.score()

            if p.game_over():
                # hit/miss are zeroed right after this loop, so no reset is needed here.
                print("\nGame Over. Resetting.")
                p.reset_game()
                break

            # Min-max scale the raw game state with the statistics loaded from disk.
            # NOTE(review): a feature whose saved min equals its max divides by
            # zero here — confirm the training notebook cannot produce that.
            state = list(p.getGameState().values())
            state = (state - min_max['min'])/(min_max['max'] - min_max['min'])

            # Greedy action: take the index of the largest policy output.
            # no_grad avoids building an autograd graph that is never used.
            with torch.no_grad():
                action_probs = policy(torch.tensor(state))
            action_index = int(torch.argmax(action_probs))
            action = ACTIONS[action_index]

            reward = p.act(action)
            if reward > 0:
                hit += 1
            elif reward == -1:
                # NOTE(review): only a reward of exactly -1 counts as a miss —
                # confirm this matches the selected game's reward values.
                miss += 1

            _print_summary(
                highest_score = highest_score,
                score = p.score(),
                hit = hit,
                miss = miss,
                frame_i = frame_i,
                reload_frames = global_frame_i - last_reload,
                reload_threshold_frames = RELOAD_THRESHOLD_FRAMES,
                max_frames = MAX_FRAMES_PER_GAME
            )

            # Mid-game refresh: pick up newer weights written since the last reload.
            if (global_frame_i - last_reload) >= RELOAD_THRESHOLD_FRAMES:
                print("\nLoading Latest Model")
                min_max = try_load(load_min_max)
                policy = try_load(load_model)
                last_reload = global_frame_i
                time.sleep(SCREEN_DELAY_SECONDS)
            global_frame_i += 1

            if frame_i + 1 == MAX_FRAMES_PER_GAME:
                print("\nTime limit reached. Resetting.")
                time.sleep(SCREEN_DELAY_SECONDS)

        # A game has ended (game over or time limit): reset the per-game
        # counters and reload the newest model before starting the next one.
        hit = 0
        miss = 0

        print("\nLoading Latest Model. Starting New Game.")
        min_max = try_load(load_min_max)
        policy = try_load(load_model)
        last_reload = global_frame_i
        time.sleep(SCREEN_DELAY_SECONDS)

except KeyboardInterrupt:
    # Interrupting the kernel is the way to stop watching; show a final summary.
    clear_output(wait=True)
    _print_summary(
        highest_score = highest_score,
        score = p.score(),
        hit = hit,
        miss = miss,
        frame_i = frame_i,
        reload_frames = global_frame_i - last_reload,
        reload_threshold_frames = RELOAD_THRESHOLD_FRAMES,
        max_frames = MAX_FRAMES_PER_GAME
    )
    print("\nExecution stopped by user")    

Highest Score: 10.0
Score: 10.0
Hit: 11
Miss: 1
Frame: 423
Frames until model reload: 423 of 1000
Frames until forced new game: 423 of 9000

Execution stopped by user
