In [1]:
#!/usr/bin/env python
from collections import deque
import os
import random
import gym
import torch
from torch.distributions import Categorical
import torch.nn.functional as F
from IPython.display import clear_output
import numpy as np
from skimage.color import rgb2gray
from skimage.transform import rescale, resize, downscale_local_mean
import matplotlib.pyplot as plt


class QNetwork(torch.nn.Module):
    """Small convolutional network that outputs one Q-value per action.

    Two convolutional layers are followed by two fully connected layers.
    ``fc1`` is fixed at 3200 input features, which is what the conv stack
    yields for 80x80 inputs (spatial size 80 -> 20 -> 10, with 32 channels:
    32 * 10 * 10 = 3200). Other input resolutions will not fit ``fc1``.

    Args:
        num_frames: number of input channels (frames stacked per sample).
        num_actions: size of the output layer (one value per action).
    """

    def __init__(self, num_frames=1, num_actions=2):
        super().__init__()
        self.num_frames = num_frames
        self.num_actions = num_actions

        # Convolutional feature extractor
        self.conv1 = torch.nn.Conv2d(num_frames, 16, kernel_size=8, stride=4, padding=2)
        self.conv2 = torch.nn.Conv2d(16, 32, kernel_size=4, stride=2, padding=1)

        # Fully connected head
        self.fc1 = torch.nn.Linear(3200, 256)
        self.fc2 = torch.nn.Linear(256, num_actions)

        # Shared non-linearity
        self.relu = torch.nn.ReLU()

    def flatten(self, x):
        """Collapse every dimension after the first (batch) one into a single axis."""
        return x.view(x.size(0), -1)

    def forward(self, x):
        """Map a batch of frames to Q-values.

        Shapes for an 80x80 input, batch dimension omitted:
            conv1: (num_frames, 80, 80) -> (16, 20, 20)
            conv2: (16, 20, 20)         -> (32, 10, 10)
            flatten:                    -> (3200,)
            fc1:                        -> (256,)
            fc2:                        -> (num_actions,)
        """
        features = self.relu(self.conv2(self.relu(self.conv1(x))))
        hidden = self.relu(self.fc1(self.flatten(features)))
        # No activation on the output: Q-values are unbounded.
        return self.fc2(hidden)
    
    
    


def update_Q(batch_size=32):
    """Run one gradient step of Q-learning on a random minibatch.

    Reads the module-level ``history`` (replay buffer of
    ``[state, action, state_next, reward, done]`` transitions), ``Q``,
    ``Q_target``, ``optimizer``, ``discount`` and ``device``.

    The loss is the *sum* over the minibatch of squared TD errors
    ``(target - Q(state)[action]) ** 2`` where ``target`` is ``reward`` for
    terminal transitions and ``reward + discount * max_a Q_target(state_next)``
    otherwise. The whole minibatch is evaluated in a single forward pass per
    network instead of one pass per transition.

    Each stored state is expected to carry a leading batch dimension of 1
    (as produced by ``process``), so states can be concatenated along dim 0.

    Args:
        batch_size: maximum number of transitions to sample; fewer are used
            while the buffer holds less than that.

    Returns:
        None. Does nothing when the replay buffer is empty.
    """
    # Nothing to learn from yet; without this guard the update would fail
    # because there is no loss tensor to back-propagate.
    if len(history) == 0:
        return

    batch = random.sample(history, min(batch_size, len(history)))
    states, actions, next_states, rewards, dones = zip(*batch)

    states = torch.cat(states).to(device)
    next_states = torch.cat(next_states).to(device)
    # Column vector of action indices, as required by gather along dim 1.
    actions = torch.tensor(actions, dtype=torch.long, device=device).unsqueeze(1)
    rewards = torch.tensor(rewards, dtype=torch.float, device=device)
    dones = torch.tensor(dones, dtype=torch.bool, device=device)

    # Targets are treated as constants: no gradient flows through Q_target.
    with torch.no_grad():
        next_q = torch.max(Q_target(next_states), dim=1)[0]
        # Terminal transitions have no bootstrap term.
        targets = torch.where(dones, rewards, rewards + discount * next_q)

    # Q-value of the action that was actually taken in each transition.
    q_taken = Q(states).gather(1, actions).squeeze(1)
    loss = ((targets - q_taken) ** 2).sum()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


def process(img):
    """Convert a rendered frame into a network-ready tensor.

    The frame is resized to 80x80 with anti-aliasing (no cropping is done),
    converted to greyscale, given leading batch and channel axes, and returned
    as a float tensor on the module-level ``device``. For an RGB frame the
    result is expected to have shape (1, 1, 80, 80).
    """
    # Resize first, then drop the colour channels.
    small = resize(img, (80, 80), anti_aliasing=True)
    gray = rgb2gray(small)

    # Prepend batch and channel axes.
    batched = gray[None, None, :, :]

    return torch.tensor(batched, device=device, dtype=torch.float)





In [2]:
    
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed the Python, NumPy and PyTorch RNGs (used for epsilon draws, replay
# sampling and network initialisation).
# NOTE(review): the environment's own RNG is not seeded here, so runs are
# still not bit-for-bit repeatable — seed it too if the installed gym allows.
seed = 0
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# gym environment
env = gym.envs.make("CartPole-v0")
#env = gym.envs.make("Breakout-v0")

# network and optimizer
# Size the output layer from the environment instead of relying on the
# QNetwork default, so switching environments does not silently mismatch.
n_actions = env.action_space.n
Q = QNetwork(num_actions=n_actions).to(device)
optimizer = torch.optim.Adam(Q.parameters(), lr=0.0005)

# target network: starts as an exact copy of Q and is re-synchronised
# periodically during training
Q_target = QNetwork(num_actions=n_actions).to(device)
Q_target.load_state_dict(Q.state_dict())

history = deque(maxlen=100000)  # replay buffer
discount = 0.99  # discount factor gamma

In [3]:
max_time_steps = 1000

# probability of taking a random (exploratory) action in the epsilon-greedy policy
epsilon = 0.05

# for computing average reward over 100 episodes
reward_history = deque(maxlen=100)

# for updating target network
target_interval = 1000
target_counter = 0

# training
for episode in range(300):
    # sum of accumulated rewards
    rewards = 0

    # get initial observation; the network is fed the rendered frame, the
    # state vector returned by the environment is not used for learning
    state = env.reset()
    screen = process(env.render(mode='rgb_array'))

    # loop until an episode ends
    for t in range(1, max_time_steps + 1):
        # epsilon greedy policy for current observation
        if random.random() < epsilon:
            action = env.action_space.sample()
        else:
            # action selection needs no gradients
            with torch.no_grad():
                action = torch.argmax(Q(screen.to(device))).item()

        # get next observation and current reward for the chosen action
        state_next, reward, done, info = env.step(action)
        screen_next = process(env.render(mode='rgb_array'))

        # collect reward
        rewards = rewards + reward

        # collect a transition
        history.append([screen, action, screen_next, reward, done])

        update_Q()

        # update target network
        target_counter = target_counter + 1
        if target_counter % target_interval == 0:
            Q_target.load_state_dict(Q.state_dict())

        # the environment is reused by the next episode, so it is NOT closed
        # here; it is closed once after training finishes
        if done:
            break

        # pass observation to the next step
        state = state_next
        screen = screen_next

    # compute average reward
    reward_history.append(rewards)
    avg = sum(reward_history) / len(reward_history)
    clear_output(wait=True)
    print('episode: {}, reward: {:.1f}, avg: {:.1f}'.format(episode, rewards, avg))

env.close()








episode: 299, reward: 26.0, avg: 19.4


In [4]:
# TEST: run the learned greedy policy (no exploration) for a few episodes
n_test_episodes = 10
episode = 0
episode_reward = 0
state = env.reset()
screen = process(env.render(mode='rgb_array'))

while episode < n_test_episodes:  # episode loop
    # pick the action with the highest predicted Q-value; no gradients needed
    with torch.no_grad():
        action = torch.argmax(Q(screen.to(device))).item()
    next_state, reward, done, info = env.step(action)
    episode_reward = episode_reward + reward
    state = next_state
    screen = process(env.render(mode='rgb_array'))
    if done:
        print('test episode: {}, reward: {:.1f}'.format(episode, episode_reward))
        episode = episode + 1
        episode_reward = 0
        # start a new episode and refresh the frame, otherwise the first
        # action would be chosen from the previous episode's terminal frame
        state = env.reset()
        screen = process(env.render(mode='rgb_array'))
env.close()