# Imports

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import pdb, pickle, torch, random

import numpy as np
import holoviews as hv
import pandas as pd
import xarray as xr

from holoviews import opts
from holoviews.streams import Pipe, Buffer
from holoviews.operation.timeseries import rolling

from replay_buffer import ReplayBuffer
from maddpg_agent import MaddpgAgent
from trainer import Trainer

hv.extension('bokeh')

In [None]:
from unityagents import UnityEnvironment

# Launch the Unity Tennis build located next to this notebook; `env` is the
# handle that later cells use to reset and step the simulation.
env = UnityEnvironment(
    file_name="./Tennis.app",
)

# Load environment

In [None]:
# Get the default brain and reset env
brain_name = env.brain_names[0]
brain = env.brains[brain_name]
env_info = env.reset(train_mode=True)[brain_name]

# Number of agents 
num_agents = len(env_info.agents)
print(f"Number of agents: {num_agents}")

# Size of the global state/action space (across all agents)
actions = env_info.previous_vector_actions
states = env_info.vector_observations
global_state_space_size = states.flatten().shape[0]
global_action_space_size = actions.flatten().shape[0]
print(f"Global states: {global_state_space_size}")
print(f"Global actions: {global_action_space_size}")

# Size of the local state/action space (for each agent individually)
action_space_size = brain.vector_action_space_size
state_space_size = brain.num_stacked_vector_observations * brain.vector_observation_space_size
print(f"Local states: {state_space_size}")
print(f"Local actions: {action_space_size}")

# Examine the state space 
print('The state for the first agent looks like:', states[0])

# Create/load replay buffer

In [None]:
# Create the replay buffer
replay_buffer_size_max = int(1e6)
min_samples_required = 10000
replay_buffer = ReplayBuffer(max_size=replay_buffer_size_max, min_samples_required=min_samples_required)

In [None]:
# Save the replay buffer to disk so later sessions can skip re-collecting experience.
# Opt-in: flip SAVE_REPLAY_BUFFER to True to write the file (overwrites any existing copy).
SAVE_REPLAY_BUFFER = False
if SAVE_REPLAY_BUFFER:
    # `with` guarantees the file is flushed and closed even if pickling fails.
    with open("replay_buffer.pickle", "wb") as f:
        pickle.dump(replay_buffer, f)

In [None]:
# Load replay buffer
replay_buffer = pickle.load( open( "replay_buffer.pickle", "rb" ) )
print(f"Loaded replay buffer with {len(replay_buffer)} samples.")

# Train

### Create a new trainer

In [None]:
# Reproducibility: seed the global NumPy, PyTorch and Python RNGs with one value.
seed = 0
np.random.seed(seed)
torch.manual_seed(seed)
random.seed(seed)

# Build the MADDPG trainer around the Unity env and the shared replay buffer.
trainer = Trainer(
    env=env,
    replay_buffer=replay_buffer,
    discount=0.99,
    tau=0.01,        # presumably the soft target-network update rate -- see Trainer
    actor_lr=1e-4,
    critic_lr=3e-4,
)

### Start/resume training session

In [None]:
# Boost agent 0's learning rates (4x the constructor values) because we're using batch norm.
# NOTE(review): only trainer.agents[0] is boosted -- confirm that is intentional.
# NOTE(review): confirm MaddpgAgent reads actor_lr / critic_lr after construction;
# if its optimisers were already built with the old values, these assignments
# may not change the effective learning rate.
trainer.agents[0].actor_lr = 1e-4 * 4
trainer.agents[0].critic_lr = 3e-4 * 4

# Run (or resume) training.
trainer.train(
    num_episodes=20_000,
    max_episode_length=250,
    batch_size=512,
    train_every_steps=4,
    noise_level=0.12,
    noise_decay=0.9999,
    print_episodes=10,
)

# Analyse results

In [None]:
# Display returns
max_returns = trainer.get_max_returns()
raw_returns = hv.Curve(max_returns, 'Episode', 'Return').relabel('Single episode')
smooth_returns = rolling(hv.Curve(
    max_returns, 'Episode', 'Return'), rolling_window=100).relabel('100 episode average')
max_returns_curve = (raw_returns * smooth_returns).relabel('Max episode return')

# Display loss
average_loss = trainer.get_average_loss()
actor_loss = hv.Curve(average_loss[:,0], 'Training iteration', 'Loss').relabel('Actor')
critic_loss = hv.Curve(average_loss[:,1], 'Training iteration', 'Loss').relabel('Critic')
loss_curves = (actor_loss * critic_loss).relabel('Actor/critic loss')

(max_returns_curve + loss_curves).opts(opts.Curve(axiswise=True))

# Save/restore training state

### Pausing/resuming training progress

This is especially useful because the Unity environment handle will be corrupted if you interrupt the kernel whilst training. Simply save the trainer, restart the kernel and the Unity environment, then load your progress to resume.

In [None]:
# Save trainer to disk
pickle.dump( trainer, open( "saved_models/trainer.pickle", "wb" ) )

# Save torch params to file
for i, agent in enumerate(trainer.agents):
    torch.save(agent.actor_optimiser,   f"saved_models/agent_{i}_actor_optimiser.pt")
    torch.save(agent.critic_optimiser,  f"saved_models/agent_{i}_critic_optimiser.pt")
    torch.save(agent.actor,         f"saved_models/agent_{i}_actor_model.pt")
    torch.save(agent.actor_target,  f"saved_models/agent_{i}_actor_target_model.pt")
    torch.save(agent.critic,        f"saved_models/agent_{i}_critic_model.pt")
    torch.save(agent.critic_target, f"saved_models/agent_{i}_critic_target_model.pt")

In [None]:
# Load trainer from disk
trainer = pickle.load( open( "saved_models/trainer.pickle", "rb" ) )

# Load torch params from file (NOT safe across refactors)
for i, agent in enumerate(trainer.agents):
    agent.actor_optimiser  = torch.load(f"saved_models/agent_{i}_actor_optimiser.pt")
    agent.critic_optimiser = torch.load(f"saved_models/agent_{i}_critic_optimiser.pt")
    agent.actor         = torch.load(f"saved_models/agent_{i}_actor_model.pt")
    agent.actor_target  = torch.load(f"saved_models/agent_{i}_actor_target_model.pt")
    agent.critic        = torch.load(f"saved_models/agent_{i}_critic_model.pt")
    agent.critic_target = torch.load(f"saved_models/agent_{i}_critic_target_model.pt")

# Watch agent play

To view random play driven by the OU noise process, set the noise level to 1 (the cell below uses a noise level of 0.5). This is what we use to generate exploratory behaviour initially.

In [None]:
# Watch the agents play in non-training mode, with exploration noise applied.
NUM_WATCH_EPISODES = 14  # episodes are numbered 1..NUM_WATCH_EPISODES in the printout
WATCH_NOISE_LEVEL = 0.5  # set to 1 for random, noise-driven play

for i in range(1, NUM_WATCH_EPISODES + 1):
    env_info = env.reset(train_mode=False)[brain_name]    # reset the environment
    states = env_info.vector_observations                 # current observation for each agent
    scores = np.zeros(num_agents)                         # running score for each agent
    t = 0                                                 # steps taken this episode
    while True:
        t += 1

        # One action per agent, each computed from that agent's own observation.
        actions = [agent.act(state, noise_level=WATCH_NOISE_LEVEL)
                   for agent, state in zip(trainer.agents, states)]

        env_info = env.step(actions)[brain_name]          # send all actions to the environment
        next_states = env_info.vector_observations        # next observation for each agent
        rewards = env_info.rewards                        # reward for each agent
        dones = env_info.local_done                       # per-agent episode-finished flags
        scores += rewards                                 # accumulate each agent's score
        states = next_states                              # roll over to the next time step
        if np.any(dones):                                 # stop as soon as any agent is done
            break
    print(f'Episode: {i}; length: {t}, max score: {np.max(scores)}')