In [None]:
import glob
import os
import sys
import time
from datetime import datetime

import matplotlib.pyplot as plt
import torch
import torch.optim as optim
from torch.utils.tensorboard import FileWriter
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm_notebook as tqdm

# Make library available in path
!rm -rf 'fom-openai-gym-rl'
!git clone https://github.com/fom-big-data/fom-openai-gym-rl
lib_path = os.path.join(os.getcwd(), 'fom-openai-gym-rl', 'notebooks', '00-basemodel', 'atari-dqn', 'lib')
if not (lib_path in sys.path):
    sys.path.insert(0, lib_path)
common_lib_path = os.path.join(os.getcwd(), 'fom-openai-gym-rl', 'notebooks', '00-basemodel', 'common', 'lib')
if not (common_lib_path in sys.path):
    sys.path.insert(0, common_lib_path)
common_reward_shaper_path = os.path.join(os.getcwd(), 'fom-openai-gym-rl', 'notebooks', '00-basemodel', 'common', 'reward_shaper')
if not (common_reward_shaper_path in sys.path):
    sys.path.insert(0, common_reward_shaper_path)
    
# Make directory for models    
!mkdir -p ./model

# Import library classes
from action_selector import ActionSelector
from breakout_reward_shaper import BreakoutRewardShaper
from deep_q_network import DeepQNetwork
from environment_builder import EnvironmentBuilder
from environment_builder import EnvironmentWrapper
from environment_enum import Environment
from freeway_reward_shaper import FreewayRewardShaper
from input_extractor import InputExtractor
from model_optimizer import ModelOptimizer
from model_storage import ModelStorage
from performance_logger import PerformanceLogger
from performance_plotter import PerformancePlotter
from pong_reward_shaper import PongRewardShaper
from replay_memory import ReplayMemory
from screen_plotter import ScreenPlotter
from spaceinvaders_reward_shaper import SpaceInvadersRewardShaper

# 0 Setup

In [None]:
# Name of a previous run directory (inside OUTPUT_DIRECTORY) to resume from,
# or None to start a fresh run with the parameters defined further below
RUN_TO_LOAD = None
OUTPUT_DIRECTORY = "./model/"

if RUN_TO_LOAD is not None:
    # Resume from the most recently created checkpoint of the given run
    list_of_files = glob.glob(OUTPUT_DIRECTORY + RUN_TO_LOAD + "/*.model")
    if not list_of_files:
        # Fail with a clear message instead of max() complaining about an empty sequence
        raise FileNotFoundError("No *.model files found in " + OUTPUT_DIRECTORY + RUN_TO_LOAD)
    MODEL_TO_LOAD = max(list_of_files, key=os.path.getctime)

    RUN_DIRECTORY = RUN_TO_LOAD

    # The stored checkpoint is unpacked into the training progress, the saved
    # states and every hyperparameter of the original run (in this order)
    (FINISHED_FRAMES,
     FINISHED_EPISODES,
     MODEL_STATE_DICT,
     OPTIMIZER_STATE_DICT,
     REPLAY_MEMORY,
     LOSS,

     ENVIRONMENT,
     ENVIRONMENT_WRAPPERS,
     BATCH_SIZE,
     GAMMA,
     EPS_START,
     EPS_END,
     EPS_DECAY,
     NUM_ATOMS,
     VMIN,
     VMAX,
     TARGET_UPDATE_RATE,
     MODEL_SAVE_RATE,
     REPLAY_MEMORY_SIZE,
     NUM_FRAMES,
     REWARD_PONG_PLAYER_RACKET_HITS_BALL,
     REWARD_PONG_PLAYER_RACKET_COVERS_BALL,
     REWARD_PONG_PLAYER_RACKET_CLOSE_TO_BALL_LINEAR,
     REWARD_PONG_PLAYER_RACKET_CLOSE_TO_BALL_QUADRATIC,
     REWARD_PONG_OPPONENT_RACKET_HITS_BALL,
     REWARD_PONG_OPPONENT_RACKET_COVERS_BALL,
     REWARD_PONG_OPPONENT_RACKET_CLOSE_TO_BALL_LINEAR,
     REWARD_PONG_OPPONENT_RACKET_CLOSE_TO_BALL_QUADRATIC,
     REWARD_BREAKOUT_PLAYER_RACKET_HITS_BALL,
     REWARD_BREAKOUT_PLAYER_RACKET_COVERS_BALL,
     REWARD_BREAKOUT_PLAYER_RACKET_CLOSE_TO_BALL_LINEAR,
     REWARD_BREAKOUT_PLAYER_RACKET_CLOSE_TO_BALL_QUADRATIC,
     REWARD_SPACEINVADERS_PLAYER_AVOIDS_LINE_OF_FIRE,
     REWARD_FREEWAY_CHICKEN_VERTICAL_POSITION) = ModelStorage.loadModel(MODEL_TO_LOAD)
else:
    # A fresh run gets a directory named after its start time
    RUN_DIRECTORY = datetime.now().strftime("%Y-%m-%d-%H:%M:%S")

    # Only use defined parameters if there is no previous output being loaded
    FINISHED_FRAMES = 0
    FINISHED_EPISODES = 0

    # Define setup; the environment can be overridden via the ENVIRONMENT_ID
    # environment variable
    ENVIRONMENT_ID = os.getenv('ENVIRONMENT_ID', Environment.BREAKOUT_NO_FRAMESKIP_V0.value)
    ENVIRONMENT = Environment(ENVIRONMENT_ID)
    ENVIRONMENT_WRAPPERS = [
        EnvironmentWrapper.KEEP_ORIGINAL_OBSERVATION,
        EnvironmentWrapper.NOOP_RESET,
        EnvironmentWrapper.MAX_AND_SKIP,
        EnvironmentWrapper.EPISODIC_LIFE,
        EnvironmentWrapper.FIRE_RESET,
        EnvironmentWrapper.WARP_FRAME,
        EnvironmentWrapper.IMAGE_TO_PYTORCH,
    ]
    BATCH_SIZE = 32
    GAMMA = 0.99
    EPS_START = 1.0
    EPS_END = 0.01
    EPS_DECAY = 10_000
    NUM_ATOMS = 51
    VMIN = -10
    VMAX = 10
    TARGET_UPDATE_RATE = 10_000
    MODEL_SAVE_RATE = 100
    REPLAY_MEMORY_SIZE = 100_000
    NUM_FRAMES = 1_000_000

    # Additional rewards per shaping mechanism; a value of 0.0 disables the
    # mechanism (the training loop skips shapers whose additional reward is 0)
    REWARD_PONG_PLAYER_RACKET_HITS_BALL = 0.0
    REWARD_PONG_PLAYER_RACKET_COVERS_BALL = 0.0
    REWARD_PONG_PLAYER_RACKET_CLOSE_TO_BALL_LINEAR = 0.0
    REWARD_PONG_PLAYER_RACKET_CLOSE_TO_BALL_QUADRATIC = 0.0
    REWARD_PONG_OPPONENT_RACKET_HITS_BALL = 0.0
    REWARD_PONG_OPPONENT_RACKET_COVERS_BALL = 0.0
    REWARD_PONG_OPPONENT_RACKET_CLOSE_TO_BALL_LINEAR = 0.0
    REWARD_PONG_OPPONENT_RACKET_CLOSE_TO_BALL_QUADRATIC = 0.0
    REWARD_BREAKOUT_PLAYER_RACKET_HITS_BALL = 0.0
    REWARD_BREAKOUT_PLAYER_RACKET_COVERS_BALL = 0.0
    REWARD_BREAKOUT_PLAYER_RACKET_CLOSE_TO_BALL_LINEAR = 0.0
    REWARD_BREAKOUT_PLAYER_RACKET_CLOSE_TO_BALL_QUADRATIC = 0.0
    REWARD_SPACEINVADERS_PLAYER_AVOIDS_LINE_OF_FIRE = 0.0
    REWARD_FREEWAY_CHICKEN_VERTICAL_POSITION = 0.0

    # Log parameters (eps_decay must receive EPS_DECAY, not EPS_END)
    PerformanceLogger.log_parameters(directory=OUTPUT_DIRECTORY + RUN_DIRECTORY,
                                     batch_size=BATCH_SIZE,
                                     gamma=GAMMA,
                                     eps_start=EPS_START,
                                     eps_end=EPS_END,
                                     eps_decay=EPS_DECAY,
                                     num_atoms=NUM_ATOMS,
                                     vmin=VMIN,
                                     vmax=VMAX,
                                     target_update_rate=TARGET_UPDATE_RATE,
                                     model_save_rate=MODEL_SAVE_RATE,
                                     replay_memory_size=REPLAY_MEMORY_SIZE,
                                     num_frames=NUM_FRAMES,
                                     reward_pong_player_racket_hits_ball=REWARD_PONG_PLAYER_RACKET_HITS_BALL,
                                     reward_pong_player_racket_covers_ball=REWARD_PONG_PLAYER_RACKET_COVERS_BALL,
                                     reward_pong_player_racket_close_to_ball_linear=REWARD_PONG_PLAYER_RACKET_CLOSE_TO_BALL_LINEAR,
                                     reward_pong_player_racket_close_to_ball_quadratic=REWARD_PONG_PLAYER_RACKET_CLOSE_TO_BALL_QUADRATIC,
                                     reward_pong_opponent_racket_hits_ball=REWARD_PONG_OPPONENT_RACKET_HITS_BALL,
                                     reward_pong_opponent_racket_covers_ball=REWARD_PONG_OPPONENT_RACKET_COVERS_BALL,
                                     reward_pong_opponent_racket_close_to_ball_linear=REWARD_PONG_OPPONENT_RACKET_CLOSE_TO_BALL_LINEAR,
                                     reward_pong_opponent_racket_close_to_ball_quadratic=REWARD_PONG_OPPONENT_RACKET_CLOSE_TO_BALL_QUADRATIC,
                                     reward_breakout_player_racket_hits_ball=REWARD_BREAKOUT_PLAYER_RACKET_HITS_BALL,
                                     reward_breakout_player_racket_covers_ball=REWARD_BREAKOUT_PLAYER_RACKET_COVERS_BALL,
                                     reward_breakout_player_racket_close_to_ball_linear=REWARD_BREAKOUT_PLAYER_RACKET_CLOSE_TO_BALL_LINEAR,
                                     reward_breakout_player_racket_close_to_ball_quadratic=REWARD_BREAKOUT_PLAYER_RACKET_CLOSE_TO_BALL_QUADRATIC,
                                     reward_spaceinvaders_player_avoids_line_of_fire=REWARD_SPACEINVADERS_PLAYER_AVOIDS_LINE_OF_FIRE,
                                     reward_freeway_chicken_vertical_position=REWARD_FREEWAY_CHICKEN_VERTICAL_POSITION
                                     )

# Assemble reward shapings: each entry pairs a shaper method with the keyword
# arguments it is called with in the training loop
REWARD_SHAPINGS = [
    {"method": PongRewardShaper().reward_player_racket_hits_ball,
     "arguments": {"additional_reward": REWARD_PONG_PLAYER_RACKET_HITS_BALL}},
    {"method": PongRewardShaper().reward_player_racket_covers_ball,
     "arguments": {"additional_reward": REWARD_PONG_PLAYER_RACKET_COVERS_BALL}},
    {"method": PongRewardShaper().reward_player_racket_close_to_ball_linear,
     "arguments": {"additional_reward": REWARD_PONG_PLAYER_RACKET_CLOSE_TO_BALL_LINEAR}},
    {"method": PongRewardShaper().reward_player_racket_close_to_ball_quadratic,
     "arguments": {"additional_reward": REWARD_PONG_PLAYER_RACKET_CLOSE_TO_BALL_QUADRATIC}},
    {"method": PongRewardShaper().reward_opponent_racket_hits_ball,
     "arguments": {"additional_reward": REWARD_PONG_OPPONENT_RACKET_HITS_BALL}},
    {"method": PongRewardShaper().reward_opponent_racket_covers_ball,
     "arguments": {"additional_reward": REWARD_PONG_OPPONENT_RACKET_COVERS_BALL}},
    {"method": PongRewardShaper().reward_opponent_racket_close_to_ball_linear,
     "arguments": {"additional_reward": REWARD_PONG_OPPONENT_RACKET_CLOSE_TO_BALL_LINEAR}},
    {"method": PongRewardShaper().reward_opponent_racket_close_to_ball_quadratic,
     "arguments": {"additional_reward": REWARD_PONG_OPPONENT_RACKET_CLOSE_TO_BALL_QUADRATIC}},
    {"method": BreakoutRewardShaper().reward_player_racket_hits_ball,
     "arguments": {"additional_reward": REWARD_BREAKOUT_PLAYER_RACKET_HITS_BALL}},
    {"method": BreakoutRewardShaper().reward_player_racket_covers_ball,
     "arguments": {"additional_reward": REWARD_BREAKOUT_PLAYER_RACKET_COVERS_BALL}},
    {"method": BreakoutRewardShaper().reward_player_racket_close_to_ball_linear,
     "arguments": {"additional_reward": REWARD_BREAKOUT_PLAYER_RACKET_CLOSE_TO_BALL_LINEAR}},
    {"method": BreakoutRewardShaper().reward_player_racket_close_to_ball_quadratic,
     "arguments": {"additional_reward": REWARD_BREAKOUT_PLAYER_RACKET_CLOSE_TO_BALL_QUADRATIC}},
    {"method": SpaceInvadersRewardShaper().reward_player_avoids_line_of_fire,
     "arguments": {"additional_reward": REWARD_SPACEINVADERS_PLAYER_AVOIDS_LINE_OF_FIRE}},
    {"method": FreewayRewardShaper().reward_chicken_vertical_position,
     "arguments": {"additional_reward": REWARD_FREEWAY_CHICKEN_VERTICAL_POSITION}},
]

## 0.1 Configure device

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)

## 0.2 Set up matplotlib

In [None]:
# Enable matplotlib's interactive mode; it is switched off again with
# plt.ioff() at the end of the training cell
plt.ion()

## 0.3 Set up TensorBoard

In [None]:
# Load the TensorBoard notebook extension, which provides the %tensorboard
# magic used further below to display the dashboard
%load_ext tensorboard

## 0.4 Set up environment

In [None]:
# Initialize environment: the builder receives the configured environment id
# together with the list of wrappers from ENVIRONMENT_WRAPPERS
env = EnvironmentBuilder.make_environment_with_wrappers(ENVIRONMENT.value, ENVIRONMENT_WRAPPERS)
# Reset environment before the first screen is extracted from it
env.reset()
# Plot initial screen (disabled debugging aid)
# InputExtractor.plot_screen(InputExtractor.get_sharp_screen(env=env, device=device), 'Example extracted screen')

# 1 Set up nets

## 1.1 Define nets

In [None]:
# Get the screen size so that the network layers can be initialized based on
# the shape of the extracted screen; the shape is unpacked as four dimensions
# of which the last two are used as height and width
init_screen = InputExtractor.get_screen(env=env, device=device)
_, _, screen_height, screen_width = init_screen.shape

# Get number of actions from gym action space
n_actions = env.action_space.n

# Policy net and target net share the same architecture
policy_net = DeepQNetwork(screen_height, screen_width, n_actions).to(device)
target_net = DeepQNetwork(screen_height, screen_width, n_actions).to(device)

if RUN_TO_LOAD is not None:
    # Resuming a previous run: restore the stored weights into both nets
    policy_net.load_state_dict(MODEL_STATE_DICT)
    target_net.load_state_dict(MODEL_STATE_DICT)
else:
    # Fresh run: start the target net as an exact copy of the policy net
    target_net.load_state_dict(policy_net.state_dict())

# The target net is put into evaluation mode in both cases; previously this
# was skipped when a stored model was loaded
target_net.eval()

## 1.2 Define optimizer and replay memory

In [None]:
# The optimizer is constructed the same way for fresh and resumed runs
optimizer = optim.RMSprop(policy_net.parameters())

if RUN_TO_LOAD is None:
    # Fresh run: start with an empty replay memory of the configured size
    memory = ReplayMemory(REPLAY_MEMORY_SIZE)
else:
    # Resumed run: restore the optimizer state and reuse the stored replay memory
    optimizer.load_state_dict(OPTIMIZER_STATE_DICT)
    memory = REPLAY_MEMORY

# 2 Training

In [None]:
# Counters and histories that span the whole training run; the episode counter
# continues from a previously stored run
total_frames, total_episodes = 0, FINISHED_EPISODES
total_original_rewards, total_shaped_rewards, total_losses = [], [], []
total_start_time = time.time()

# Counters that are reset at the end of every episode
episode_frames = episode_original_reward = episode_shaped_reward = 0
episode_start_time = time.time()

# Start from a freshly reset environment; the state is the difference between
# two extracted screens
env.reset()
last_screen = InputExtractor.get_screen(env=env, device=device)
current_screen = InputExtractor.get_screen(env=env, device=device)
state = current_screen - last_screen

## 2.1 Display TensorBoard

In [None]:
# Initialize writers. SummaryWriter() is created without arguments, so it uses
# its default log directory (a sub-directory of "runs" according to the PyTorch
# documentation), which is where the %tensorboard magic below points.
tensorboard_summary_writer = SummaryWriter()
# Second writer targeting the "images" directory
# NOTE(review): not referenced by any other cell visible here — confirm it is still needed
tensorboard_file_writer = FileWriter("images")
# Display the TensorBoard dashboard inline
%tensorboard --logdir=runs

## 2.2 Training loop

In [None]:
# Directory that receives logs, checkpoints and plots of this run
run_directory_path = OUTPUT_DIRECTORY + RUN_DIRECTORY

# Iterate over frames; the frame counter continues from a previously stored run
progress_bar = tqdm(iterable=range(NUM_FRAMES), unit='frames', initial=FINISHED_FRAMES)
for frame_index in progress_bar:
    total_frames = frame_index + FINISHED_FRAMES

    # Select and perform an action
    action = ActionSelector.select_action(state=state,
                                          n_actions=n_actions,
                                          total_frames=total_frames,
                                          policy_net=policy_net,
                                          epsilon_end=EPS_END,
                                          epsilon_start=EPS_START,
                                          epsilon_decay=EPS_DECAY,
                                          device=device)

    # Do step
    observation, reward, done, info = env.step(action.item())

    # Keep the environment's reward and start the shaped reward from it
    original_reward = reward
    shaped_reward = reward

    # The reward shapers work on the observation returned by the step
    screen = observation

    # Iterate over all reward shaping mechanisms; mechanisms whose additional
    # reward is 0 are skipped
    for reward_shaping in REWARD_SHAPINGS:
        if reward_shaping["arguments"]["additional_reward"] != 0:
            shaped_reward += reward_shaping["method"](environment=ENVIRONMENT,
                                                      screen=screen,
                                                      reward=reward,
                                                      done=done,
                                                      info=info,
                                                      **reward_shaping["arguments"])

    # # Plot intermediate screen
    # if total_frames % 50 == 0:
    #     InputExtractor.plot_screen(InputExtractor.get_sharp_screen(env=env, device=device), "Frame " + str(
    #         total_frames) + " / shaped reward " + str(round(shaped_reward, 4)))

    # Use shaped reward for further processing
    reward = shaped_reward

    # Add reward to episode reward
    episode_original_reward += original_reward
    episode_shaped_reward += shaped_reward

    # Transform reward into a tensor
    reward = torch.tensor([reward], device=device)

    # Observe new state
    last_screen = current_screen
    current_screen = InputExtractor.get_screen(env=env, device=device)

    # Update next state
    next_state = current_screen - last_screen

    # Store the transition in memory
    memory.push(state, action, next_state, reward)

    # Move to the next state
    state = next_state

    # Perform one step of the optimization
    loss = ModelOptimizer.optimize_model(policy_net=policy_net,
                                         target_net=target_net,
                                         optimizer=optimizer,
                                         memory=memory,
                                         batch_size=BATCH_SIZE,
                                         gamma=GAMMA,
                                         device=device)

    # The optimizer step may yield no loss (the episode logging below guards
    # against None as well) — presumably while the replay memory holds fewer
    # transitions than one batch; TODO confirm. Only real losses are recorded.
    if loss is not None:
        # Keep the per-frame history that feeds the "losses" plot
        total_losses.append(loss.item())
        # Write values to TensorBoard
        tensorboard_summary_writer.add_scalar("Loss", loss, total_frames)

    if done:
        # Track episode time
        episode_end_time = time.time()
        episode_duration = episode_end_time - episode_start_time
        total_duration = episode_end_time - total_start_time

        # Add rewards to total reward
        total_original_rewards.append(episode_original_reward)
        total_shaped_rewards.append(episode_shaped_reward)

        # Write values to TensorBoard
        tensorboard_summary_writer.add_scalar("Reward (original)", episode_original_reward, total_episodes)
        tensorboard_summary_writer.add_scalar("Reward (shaped)", episode_shaped_reward, total_episodes)
        tensorboard_summary_writer.add_scalar("Episode duration", episode_duration, total_episodes)

        if loss is not None:
            # Counters are zero-based at this point, hence the + 1
            PerformanceLogger.log_episode(directory=run_directory_path,
                                          total_episodes=total_episodes + 1,
                                          total_frames=total_frames,
                                          total_duration=total_duration,
                                          total_original_rewards=total_original_rewards,
                                          total_shaped_rewards=total_shaped_rewards,
                                          episode_frames=episode_frames + 1,
                                          episode_original_reward=episode_original_reward,
                                          episode_shaped_reward=episode_shaped_reward,
                                          episode_loss=loss.item(),
                                          episode_duration=episode_duration)

        # Update the target network, copying all weights and biases from policy net into target net
        # NOTE(review): the rate is applied to the episode counter, so with the
        # default TARGET_UPDATE_RATE of 10_000 the target net is only refreshed
        # every 10,000 episodes — confirm that frames were not intended
        if total_episodes % TARGET_UPDATE_RATE == 0:
            target_net.load_state_dict(policy_net.state_dict())

        if total_episodes % MODEL_SAVE_RATE == 0:
            # Save output
            # NOTE(review): the checkpoint stores target_net, not policy_net — confirm
            ModelStorage.saveModel(directory=run_directory_path,
                                   total_frames=total_frames,
                                   total_episodes=total_episodes,
                                   net=target_net,
                                   optimizer=optimizer,
                                   memory=memory,
                                   loss=loss,
                                   environment=ENVIRONMENT,
                                   environment_wrappers=ENVIRONMENT_WRAPPERS,
                                   batch_size=BATCH_SIZE,
                                   gamma=GAMMA,
                                   eps_start=EPS_START,
                                   eps_end=EPS_END,
                                   eps_decay=EPS_DECAY,
                                   num_atoms=NUM_ATOMS,
                                   vmin=VMIN,
                                   vmax=VMAX,
                                   target_update_rate=TARGET_UPDATE_RATE,
                                   model_save_rate=MODEL_SAVE_RATE,
                                   replay_memory_size=REPLAY_MEMORY_SIZE,
                                   num_frames=NUM_FRAMES,
                                   reward_pong_player_racket_hits_ball=REWARD_PONG_PLAYER_RACKET_HITS_BALL,
                                   reward_pong_player_racket_covers_ball=REWARD_PONG_PLAYER_RACKET_COVERS_BALL,
                                   reward_pong_player_racket_close_to_ball_linear=REWARD_PONG_PLAYER_RACKET_CLOSE_TO_BALL_LINEAR,
                                   reward_pong_player_racket_close_to_ball_quadratic=REWARD_PONG_PLAYER_RACKET_CLOSE_TO_BALL_QUADRATIC,
                                   reward_pong_opponent_racket_hits_ball=REWARD_PONG_OPPONENT_RACKET_HITS_BALL,
                                   reward_pong_opponent_racket_covers_ball=REWARD_PONG_OPPONENT_RACKET_COVERS_BALL,
                                   reward_pong_opponent_racket_close_to_ball_linear=REWARD_PONG_OPPONENT_RACKET_CLOSE_TO_BALL_LINEAR,
                                   reward_pong_opponent_racket_close_to_ball_quadratic=REWARD_PONG_OPPONENT_RACKET_CLOSE_TO_BALL_QUADRATIC,
                                   reward_breakout_player_racket_hits_ball=REWARD_BREAKOUT_PLAYER_RACKET_HITS_BALL,
                                   reward_breakout_player_racket_covers_ball=REWARD_BREAKOUT_PLAYER_RACKET_COVERS_BALL,
                                   reward_breakout_player_racket_close_to_ball_linear=REWARD_BREAKOUT_PLAYER_RACKET_CLOSE_TO_BALL_LINEAR,
                                   reward_breakout_player_racket_close_to_ball_quadratic=REWARD_BREAKOUT_PLAYER_RACKET_CLOSE_TO_BALL_QUADRATIC,
                                   reward_spaceinvaders_player_avoids_line_of_fire=REWARD_SPACEINVADERS_PLAYER_AVOIDS_LINE_OF_FIRE,
                                   reward_freeway_chicken_vertical_position=REWARD_FREEWAY_CHICKEN_VERTICAL_POSITION
                                   )

            # Save plots of the collected histories
            # NOTE(review): xlabel/ylabel look swapped in the three calls below
            # — verify against PerformancePlotter.save_values_plot
            PerformancePlotter.save_values_plot(directory=run_directory_path,
                                                total_frames=total_frames,
                                                values=total_original_rewards,
                                                title="original rewards",
                                                xlabel="reward",
                                                ylabel="episode")

            PerformancePlotter.save_values_plot(directory=run_directory_path,
                                                total_frames=total_frames,
                                                values=total_shaped_rewards,
                                                title="shaped rewards",
                                                xlabel="reward",
                                                ylabel="episode")

            PerformancePlotter.save_values_plot(directory=run_directory_path,
                                                total_frames=total_frames,
                                                values=total_losses,
                                                title="losses",
                                                xlabel="loss",
                                                ylabel="frame")

            ScreenPlotter.save_screen_plot(directory=run_directory_path,
                                           total_frames=total_frames,
                                           env=env,
                                           title="screenshot",
                                           device=device)

        # Reset episode variables
        episode_frames = 0
        episode_original_reward = 0
        episode_shaped_reward = 0
        episode_start_time = time.time()

        # Reset the environment and state
        env.reset()
        last_screen = InputExtractor.get_screen(env=env, device=device)
        current_screen = InputExtractor.get_screen(env=env, device=device)
        state = current_screen - last_screen

        # Increment counter
        total_episodes += 1

    # Increment counter; also runs after an episode reset, so a new episode
    # starts counting at 1
    episode_frames += 1

print('Complete')
env.render()
env.close()
plt.ioff()
plt.show()