## Playing Pong Using Deep Q-Network

Author's note: This code does not work. Either there is an implementation error, or the agent simply takes an absurdly long time to learn the game. I will be switching to a policy gradient method to see whether that works better. I'm posting this code on GitHub anyway in case I want to work on it again in the future.

In [1]:
# Import stuff
import time
import gym
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter

# Fix the seeds of both random number generators so runs are repeatable
torch.manual_seed(42)
np.random.seed(42)

# Build the environment (Pong for now), seed it and record how many discrete actions it offers
env = gym.make('Pong-v0')
env.seed(42)
n_actions = env.action_space.n

# TensorBoard log directory
# Pick a fresh directory for every new run so that curves do not get mixed up
log_dir = 'logs/pong_dqn_1'
writer = SummaryWriter(log_dir)

# Run on the GPU when one is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

cuda


In [2]:
# Class for the Q-network. It uses a convolutional neural network to predict the Q-values for every
# action in a given state. Assume input size is (N, 4, 210, 160).

class QNetwork(nn.Module):
    """Convolutional Q-network producing one Q-value per action.

    The network is two strided convolutions followed by two fully connected
    layers. Inputs are expected to be shaped
    (N, input_channels, input_height, input_width).
    """

    def __init__(self, input_channels=4, input_height=210, input_width=160, output_size=n_actions):
        """Create the layers, sizing the first linear layer from the input dimensions."""
        super().__init__()
        self.conv1 = nn.Conv2d(input_channels, 16, kernel_size=8, stride=4)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=4, stride=2)
        # NOTE: there are no pooling layers; only the strides shrink the feature maps.

        # Track how each spatial dimension shrinks through the two convolutions
        out_height, out_width = input_height, input_width
        for kernel_size, stride in ((8, 4), (4, 2)):
            out_height = self.conv2d_size_out(out_height, kernel_size, stride)
            out_width = self.conv2d_size_out(out_width, kernel_size, stride)

        self.fc1 = nn.Linear(32 * out_height * out_width, 256)
        self.fc2 = nn.Linear(256, output_size)

    def conv2d_size_out(self, size, kernel_size, stride):
        """Return the length of a dimension after an unpadded, undilated convolution."""
        return (size - kernel_size) // stride + 1

    def forward(self, x):
        """Run a forward pass and return the Q-values, one row per input sample."""
        features = F.relu(self.conv2(F.relu(self.conv1(x))))
        flattened = features.view(features.shape[0], -1)
        hidden = F.relu(self.fc1(flattened))
        return self.fc2(hidden)

In [3]:
# Class for experience replay. It stores state transitions (state, action, reward, next_state) in a
# fixed-size memory, and random transitions are sampled from that memory when training the agent.
# Sampling at random breaks the correlation between consecutive frames, which is meant to reduce variance.

class ExperienceReplay:
    """Fixed-capacity ring buffer of single-frame transitions.

    Each slot holds one (state, action, reward, next_state) transition, where
    ``next_state`` is ``None`` for a terminal transition. Frame stacks are
    assembled at sampling time from consecutive slots, so consecutive slots
    must hold consecutive frames of the same episode for a stack to be valid.
    """

    def __init__(self, memory_cap):
        """Initialize an empty memory that holds at most ``memory_cap`` transitions."""
        self.memory_cap = memory_cap
        self.size = 0    # number of slots currently filled
        self.index = 0   # next slot to overwrite once the memory is full
        self.states = []
        self.actions = []
        self.rewards = []
        self.next_states = []

    def store(self, state, action, reward, next_state):
        """Store a transition, overwriting the oldest one when the memory is full."""
        if self.size < self.memory_cap:
            self.states.append(state)
            self.actions.append(action)
            self.rewards.append(reward)
            self.next_states.append(next_state)
            self.size += 1
        else:
            self.states[self.index] = state
            self.actions[self.index] = action
            self.rewards[self.index] = reward
            self.next_states[self.index] = next_state
            self.index = (self.index + 1) % self.memory_cap

    def stack_states(self, i, stack_size):
        """Stack the ``stack_size`` stored states ending at slot ``i``.

        Returns a tensor with a leading batch dimension of 1, or ``None`` when
        the window is not a run of consecutive frames from a single episode.
        """
        start = i - stack_size + 1

        # The window must lie entirely inside the filled part of the memory
        if start < 0 or i >= self.size:
            return None

        # Once the memory has wrapped around, slot index-1 holds the newest
        # transition and slot index the oldest one, so a window containing
        # both would stack frames that are not adjacent in time.
        if self.size == self.memory_cap and start < self.index <= i:
            return None

        # A terminal transition before slot i means the window spans two episodes
        if any(self.next_states[k] is None for k in range(start, i)):
            return None

        return torch.cat(self.states[start:i + 1]).unsqueeze(0)

    def sample(self, batch_size, stack_size=4):
        """Sample ``batch_size`` transitions with stacked states.

        Returns a tuple of four lists (states, actions, rewards, next_states),
        where the next-state entry is ``None`` for terminal transitions, or
        ``None`` when the memory does not contain enough valid stacks.
        """
        states_batch = []
        actions_batch = []
        rewards_batch = []
        next_states_batch = []

        for random_index in np.random.permutation(self.size):
            stacked_states = self.stack_states(random_index, stack_size)
            if stacked_states is None:
                continue

            states_batch.append(stacked_states)
            actions_batch.append(self.actions[random_index])
            rewards_batch.append(self.rewards[random_index])
            if self.next_states[random_index] is None:
                # This state is a terminal state
                next_states_batch.append(None)
            else:
                # stack_states already verified that every next state in the window exists
                window = self.next_states[random_index - stack_size + 1:random_index + 1]
                next_states_batch.append(torch.cat(window).unsqueeze(0))

            if len(states_batch) == batch_size:
                break

        if len(states_batch) < batch_size:
            return None  # not enough samples
        return (states_batch, actions_batch, rewards_batch, next_states_batch)

    def __len__(self):
        """Return the number of transitions currently stored."""
        return self.size

In [4]:
# Class for the agent.

class Agent:
    """DQN agent holding an online Q-network, a target network and a replay memory."""

    def __init__(self, name='Agent', frame_stack_size=4, memory_cap=10000, learning_rate=0.01,
                 epsilon_max=1.0, epsilon_min=0.05, epsilon_decay=10000, gamma=0.99):
        """Initialize agent.

        epsilon is annealed linearly from ``epsilon_max`` to ``epsilon_min``
        over ``epsilon_decay`` steps; ``gamma`` is the discount factor.
        """
        self.name = name
        self.learning_rate = learning_rate
        self.epsilon_max = epsilon_max
        self.epsilon_min = epsilon_min
        self.epsilon_decay = epsilon_decay
        self.gamma = gamma
        self.Q = QNetwork(input_channels=frame_stack_size).to(device)
        self.targetQ = QNetwork(input_channels=frame_stack_size).to(device)
        self.targetQ.load_state_dict(self.Q.state_dict())
        self.memory = ExperienceReplay(memory_cap)
        self.loss = nn.MSELoss()
        self.optimizer = optim.RMSprop(self.Q.parameters(), learning_rate)

    def get_action(self, state, global_step):
        """Use epsilon-greedy to determine action.

        Returns a tensor holding one action index, placed on the same device
        as ``state``.
        """
        prior_device = state.device
        epsilon = max(self.epsilon_min,
                      self.epsilon_max - (self.epsilon_max - self.epsilon_min) * global_step / self.epsilon_decay)
        if np.random.rand() < epsilon:
            # Take a random action
            action = torch.tensor([np.random.randint(n_actions)])
        else:
            # Take the action with the highest Q-value
            with torch.no_grad():
                Q_values = self.Q(state.to(device))
                _, action = Q_values.max(dim=1)
        return action.to(prior_device)

    def get_Q_values(self, state):
        """Get Q-values of a given state (computed without tracking gradients)."""
        with torch.no_grad():
            Q_values = self.Q(state.to(device))
        return Q_values

    def store(self, state, action, reward, next_state):
        """Store the transition in memory."""
        self.memory.store(state, action, reward, next_state)

    def update_parameters(self, batch_size=32, stack_size=4):
        """Sample memory and take one optimization step on the Q-network.

        Returns the loss as a Python float, or ``None`` when the memory could
        not provide ``batch_size`` samples.
        """
        samples = self.memory.sample(batch_size, stack_size)
        if samples is None:
            return None  # not enough samples
        states_batch, actions_batch, rewards_batch, next_states_batch = samples

        # Build the batch tensors directly on the compute device
        states = torch.cat(states_batch).to(device)
        actions = torch.cat(actions_batch).to(device)
        rewards = torch.cat(rewards_batch).to(device).float()
        non_final_flags = [next_state is not None for next_state in next_states_batch]

        # Compute predicted Q-values of the actions that were actually taken
        Q_values = self.Q(states).gather(dim=1, index=actions.unsqueeze(1))

        # Compute target Q-values: reward, plus the discounted best next value for non-terminal
        # transitions. A batch made only of terminal transitions has nothing to bootstrap from,
        # and torch.cat would fail on an empty list.
        targets = rewards.unsqueeze(1).clone().detach()
        if any(non_final_flags):
            non_final_mask = torch.tensor(non_final_flags, dtype=torch.bool, device=device)
            non_final_next_states = torch.cat(
                [next_state for next_state in next_states_batch if next_state is not None]).to(device)
            # The target network is only evaluated, never trained, so no graph is needed
            with torch.no_grad():
                next_values = self.targetQ(non_final_next_states).max(dim=1, keepdim=True)[0]
            targets[non_final_mask] += self.gamma * next_values

        # Compute loss
        loss = self.loss(Q_values, targets)

        # Optimize
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Return loss
        return loss.item()

    def update_target_network(self):
        """Update target network to match current Q-network."""
        self.targetQ.load_state_dict(self.Q.state_dict())

    def save_parameters(self, path):
        """Save model parameters in a file."""
        torch.save(self.Q.state_dict(), path)

    def load_parameters(self, path):
        """Load model parameters from a file into both networks."""
        # map_location lets a checkpoint saved on one device be loaded on another
        self.Q.load_state_dict(torch.load(path, map_location=device))
        self.targetQ.load_state_dict(self.Q.state_dict())

In [5]:
# Function to preprocess input images
def preprocess_image(image):
    """Convert an RGB image to a float32 grayscale image scaled by 1/255."""
    luma_weights = [0.299, 0.587, 0.114]  # what if we just take the mean?
    grayscale = np.dot(image, luma_weights).astype(np.float32)
    grayscale /= 255.0
    return grayscale

# Return the current frame as a tensor
def get_image(env):
    """Render the environment's current frame and return it as a preprocessed tensor.

    The preprocessed frame gets a new leading dimension of size 1.
    """
    frame = preprocess_image(env.render(mode='rgb_array'))
    return torch.from_numpy(frame).unsqueeze(0)

In [6]:
def simulate_episode(env, agent, render=False, fps=30, detailed=False):
    """Play one episode with the agent and return the total reward.

    The environment is reset at the start and closed at the end, so pass an
    environment that is not in the middle of another episode. Actions are
    chosen by ``agent.get_action`` using the module-level ``global_step``, and
    the frame stack depth comes from the module-level ``states_stack_size``.

    Args:
        env: environment to play in.
        agent: agent providing ``get_action`` and ``get_Q_values``.
        render: when true, display every frame and sleep ``1/fps`` seconds.
        fps: display rate used when ``render`` is true.
        detailed: when true, print the Q-values of every visited state.
    """
    env.reset()

    # Fill the stack with consecutive frames. The last frame of the stack is
    # the current frame of the environment, so no frame is skipped before the
    # first decision. Action 0 is used for the warm-up steps.
    states_stack = [get_image(env)]
    for _ in range(states_stack_size - 1):
        env.step(0)
        states_stack.append(get_image(env))

    done = False
    score = 0

    try:
        while not done:
            if render:
                env.render()
                time.sleep(1/fps)

            states_tensor = torch.cat(states_stack).unsqueeze(0)
            action = agent.get_action(states_tensor, global_step)

            if detailed:
                Qs = agent.get_Q_values(states_tensor)
                print(Qs.detach().cpu().numpy(), end='\r')

            _, reward, done, _ = env.step(action.item())
            score += reward

            # Slide the window: drop the oldest frame, append the newest one
            states_stack = states_stack[1:] + [get_image(env)]
    finally:
        # Release the environment's resources even if the episode is interrupted
        env.close()

    return score

In [None]:
# Agent hyperparameters
states_stack_size = 4
memory_cap = 50000
# NOTE(review): 0.01 is a large step size for RMSprop on a DQN; the original DQN setup used a much
# smaller value. Left unchanged here - confirm before the next long run.
learning_rate = 0.01
epsilon_max = 1.0
epsilon_min = 0.1
epsilon_decay = 1000000
gamma = 0.99

# Initialize agent
agent = Agent('PongPong_v0', states_stack_size, memory_cap, learning_rate,
              epsilon_max, epsilon_min, epsilon_decay, gamma)

# Training hyperparameters
global_step = 0                # number of parameter updates performed so far (also drives epsilon)
frame_skip_count = 3           # number of frames each chosen action is repeated for
training_length = 20000001
update_interval = 10           # loop iterations between parameter updates
batch_size = 64
# NOTE(review): the target network is synced every 10 updates, which keeps it almost identical to
# the online network. Left unchanged here - confirm whether a much longer interval was intended.
target_update_interval = 10
load_parameters_before_training = True
save_parameters = True
save_interval = 10000
save_path = 'pongpong_v0.pth'
simulate = True
simulate_interval = 10000

# Load model parameters if enabled. A missing checkpoint (e.g. on the very first run) is not an
# error: training simply starts from the freshly initialized networks.
if load_parameters_before_training:
    try:
        agent.load_parameters(save_path)
    except FileNotFoundError:
        print('No checkpoint found at {}, starting from scratch'.format(save_path))


def reset_frame_stack(env, stack_size):
    """Reset env and return a list of stack_size consecutive frames ending at the current frame."""
    env.reset()
    frames = [get_image(env)]
    for _ in range(stack_size - 1):
        env.step(0)  # warm-up steps use action 0
        frames.append(get_image(env))
    return frames


# Initialize environment
env.seed(42)

# Initialize states stack; `state` is always the newest frame of the stack so that the transitions
# stored in memory line up with the stacks the agent actually acted on
states_stack = reset_frame_stack(env, states_stack_size)
state = states_stack[-1]
initial_states_tensor = torch.cat(states_stack).unsqueeze(0)

# Track expected values and loss over time
values = []
losses = []
scores = []

# Main training loop
for i in range(training_length):
    # Choose action
    states_tensor = torch.cat(states_stack).unsqueeze(0)
    action = agent.get_action(states_tensor, global_step)

    # Perform the action for a few frames, this should increase the number of samples
    for _ in range(frame_skip_count):
        _, reward, done, _ = env.step(action.item())

        # Store transition (a terminal transition has no next state)
        if done:
            next_state = None
        else:
            next_state = get_image(env)
        agent.store(state, action, torch.tensor([reward]), next_state)

        if done:
            break
        else:
            state = next_state
            states_stack = states_stack[1:] + [next_state]

    # Update model parameters
    if i % update_interval == 0:
        loss = agent.update_parameters(batch_size)
        if loss is not None:
            values.append(agent.get_Q_values(initial_states_tensor).max().item())
            losses.append(loss)
            global_step += 1
            if global_step % target_update_interval == 0:
                agent.update_target_network()
                writer.add_scalar('initial_value', values[-1], global_step)
                writer.add_scalar('loss', losses[-1], global_step)
                # print(i, values[-1], losses[-1])

    # Save model parameters
    if save_parameters:
        if i % save_interval == 0:
            agent.save_parameters(save_path)

    # Show agent every once in a while. Evaluation runs in its own environment: simulate_episode
    # resets, plays out and closes the environment it is given, which would otherwise destroy the
    # training episode that is still in progress.
    if simulate:
        if i % simulate_interval == 0:
            eval_env = gym.make(env.spec.id)
            score = simulate_episode(eval_env, agent)
            scores.append(score)

    # Start a new training episode once the current one has ended
    if done:
        states_stack = reset_frame_stack(env, states_stack_size)
        state = states_stack[-1]

# Make sure everything logged so far reaches the event file
writer.flush()

In [None]:
# Plot the Q-value estimate of the initial state recorded after every parameter update
value_figure, value_axes = plt.subplots()
value_axes.plot(values)
plt.show()

In [None]:
# Plot the training loss recorded after every parameter update
loss_figure, loss_axes = plt.subplots()
loss_axes.plot(losses)
plt.show()

In [None]:
# Plot the scores of the periodic evaluation episodes on a fixed vertical range of -25 to 25
axes = plt.gca()
axes.plot(scores)
axes.set_ylim(-25, 25)
plt.show()

In [None]:
# Watch the agent play one rendered episode while printing the Q-values of every visited state;
# the call evaluates to the episode's total score
simulate_episode(env, agent, render=True, detailed=True)