In [None]:
%matplotlib inline
import os
import time
from itertools import count

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.optim as optim
import torch.nn.functional as F

from entities.network import DQN
from entities.environment import EnvironmentManager
from entities.strategy import EpsilonGreedyStrategy
from entities.memory import ReplayMemory
from entities.experience import Experience
from entities.q_values import QValues
from entities.agent import Agent, MODES
from entities.utils import plot, create_torch_device, extract_tensors

from entities.constants import POLICY_NET_FILE

## Example of a non-processed screen vs. a processed screen

In [None]:
# Build a device and an environment manager just for the visual examples below
device = create_torch_device()
env_manager = EnvironmentManager(device)

# Capture the first screen twice: the raw frame from the renderer ('rgb_array' mode)
# and the tensor produced by the manager's preprocessing step
non_processed = env_manager.render('rgb_array')
processed = env_manager.get_processed_screen()

# Draw both versions side by side; each panel is (image, extra imshow kwargs, title).
# The processed tensor is squeezed on dim 0 and permuted to channels-last for matplotlib
# (assumes a (1, C, H, W) layout from get_processed_screen — TODO confirm).
fig, axs = plt.subplots(1, 2, figsize=(10, 5))
panels = [
    (non_processed, {}, 'Non-processed screen'),
    (processed.squeeze(0).permute(1, 2, 0), {'interpolation': 'none'}, 'Processed screen'),
]
for panel_ax, (image, imshow_kwargs, title) in zip(axs, panels):
    panel_ax.imshow(image, **imshow_kwargs)
    panel_ax.set_title(title)
plt.show()

## Example of starting state

In [None]:
# Fetch the state the agent would observe right after the environment was created
screen = env_manager.get_state()
# Drop dim 0 and move channels last so matplotlib can draw the tensor
# (assumes get_state returns a (1, C, H, W) tensor — TODO confirm)
screen_hwc = screen.squeeze(0).permute(1, 2, 0)

plt.figure()
plt.imshow(screen_hwc, interpolation='none')
plt.title('Starting screen')
plt.show()

## Example of non-starting states

In [None]:
# Advance the environment five steps with a fixed action so the state is no longer the start state.
# NOTE(review): action index 1 is used for every step; what it means depends on the environment.
for _ in range(5):
    env_manager.take_action(torch.tensor([1]))
screen = env_manager.get_state()

# Drop dim 0 and move channels last for matplotlib (assumes a (1, C, H, W) tensor — TODO confirm)
plt.figure()
plt.imshow(screen.squeeze(0).permute(1, 2, 0), interpolation='none')
plt.title('Non-starting screen')
plt.show()

## Example of end state

In [None]:
# Force the manager into its "episode finished" mode to show what get_state returns then.
# NOTE(review): this mutates the demo env_manager and never resets it; the training cell
# below builds a fresh EnvironmentManager, so training is unaffected.
env_manager.done = True
screen = env_manager.get_state()

# Drop dim 0 and move channels last for matplotlib (assumes a (1, C, H, W) tensor — TODO confirm)
plt.figure()
plt.imshow(screen.squeeze(0).permute(1, 2, 0), interpolation='none')
plt.title('End screen')
plt.show()

## Example of the `plot` method

In [None]:
# Smoke-test the plotting helper with 300 random values in [0, 1); the second argument
# is presumably a moving-average window (TODO confirm in entities.utils.plot).
random_values = np.random.rand(300)
plot(random_values, 100)

# Main training loop

In [None]:
# Define hyperparameters
batch_size = 256        # experiences sampled from replay memory per optimisation step
gamma = 0.999           # discount factor applied to next-state Q-values in the target
eps_start = 1           # epsilon schedule passed to EpsilonGreedyStrategy
eps_end = 0.01
eps_decay_rate = 0.001
target_update = 10      # copy policy_net weights into target_net every N episodes
memory_size = 100000    # passed to ReplayMemory (presumably its capacity — confirm)
learning_rate = 0.001   # Adam learning rate
num_episodes = 1000

# Define main components
device = create_torch_device()
env_manager = EnvironmentManager(device)
strategy = EpsilonGreedyStrategy(eps_start, eps_end, eps_decay_rate)
agent = Agent(strategy, env_manager.num_actions_available(), device)
memory = ReplayMemory(memory_size)

# Define the neural networks: the policy net is trained, the target net supplies
# the bootstrapped next-state values and is only synced periodically.
policy_net = DQN(env_manager.get_screen_width(), env_manager.get_screen_height()).to(device)
target_net = DQN(env_manager.get_screen_width(), env_manager.get_screen_height()).to(device)
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()
optimizer = optim.Adam(params=policy_net.parameters(), lr=learning_rate)

# Number of steps each episode lasted, in episode order
episode_durations = []

# try/finally so the environment is released even if training is interrupted
# (e.g. a KeyboardInterrupt from the notebook's stop button).
try:
    for i_episode in range(num_episodes):
        # Reset the environment and get the first state
        env_manager.reset()
        state = env_manager.get_state()

        # Play one episode until the environment reports it is done
        for timestep in count():
            action = agent.select_action(state, policy_net)
            reward = env_manager.take_action(action)
            next_state = env_manager.get_state()

            # Store the transition and move on to the next state
            memory.push(Experience(state, action, next_state, reward))
            state = next_state

            # Once enough experiences are stored, do one optimisation step on a sampled batch
            if memory.can_sample_memory(batch_size):
                experiences = memory.sample(batch_size)
                states, actions, rewards, next_states = extract_tensors(experiences)

                current_q_values = QValues.get_current(policy_net, states, actions)
                # Next-state value estimates from the frozen target network
                # (presumably max over actions, zero for terminal states — confirm in QValues.get_next)
                next_q_values = QValues.get_next(target_net, next_states)
                target_q_values = next_q_values * gamma + rewards

                loss = F.mse_loss(current_q_values, target_q_values.unsqueeze(1))
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            if env_manager.done:
                # count() starts at 0, so the episode lasted timestep + 1 steps
                episode_durations.append(timestep + 1)
                plot(episode_durations, 100)
                break

        # Periodically sync the target network with the policy network
        if i_episode % target_update == 0:
            target_net.load_state_dict(policy_net.state_dict())
finally:
    env_manager.close()


In [None]:
# Save the models for inference
data_path = os.path.join(os.getcwd(), 'models')
policy_net_file = os.path.join(data_path, 'policy_network.pth')
target_net_file = os.path.join(data_path, 'target_network.pth')

torch.save(policy_net.state_dict(), policy_net_file)
torch.save(target_net.state_dict(), target_net_file)