In [1]:
#!/usr/bin/env python
from datetime import datetime
from collections import deque
import os
import random
import gym
import torch
from torch.distributions import Categorical
from torch.nn import Module, Linear
import torch.nn.functional as F


class QNetwork(Module):
    """Fully connected Q-value network: 4 inputs -> 48 -> 64 -> 2 outputs.

    Hidden layers use ReLU; the output layer is linear, so each of the two
    outputs is an unbounded action-value estimate.
    """

    def __init__(self):
        super().__init__()
        # Layer attribute names (and their creation order) determine the
        # state_dict keys and parameter initialisation order; keep them fixed.
        self.fc = Linear(4, 48)
        self.fcQ1 = Linear(48, 64)
        self.fcQ2 = Linear(64, 2)

    def forward(self, x):
        """Return the two Q-values for ``x`` (whose last dimension is 4)."""
        hidden = F.relu(self.fc(x))
        hidden = F.relu(self.fcQ1(hidden))
        return self.fcQ2(hidden)


# network and optimizer
Q = QNetwork()
optimizer = torch.optim.Adam(Q.parameters(), lr=0.0005)

# target network: start it as an exact copy of Q so the very first bootstrap
# targets come from the function being trained, not from an unrelated random
# network. It is re-synced from Q every `target_interval` steps in training.
Q_target = QNetwork()
Q_target.load_state_dict(Q.state_dict())

# replay buffer of [state, action, state_next, reward, done] transitions
history = deque(maxlen=1000000)
discount = 0.99  # discount factor gamma

def update_Q(batch_size=32):
    """Run one gradient step on ``Q`` using a random minibatch from ``history``.

    Each transition in ``history`` is ``[state, action, state_next, reward,
    done]``. The target is ``reward`` for terminal transitions and
    ``reward + discount * max_a Q_target(state_next)[a]`` otherwise. The loss
    is the *sum* (not mean) of squared TD errors over the minibatch, matching
    the learning rate this script was tuned with.

    Uses the module-level ``history``, ``Q``, ``Q_target``, ``optimizer`` and
    ``discount``.

    Args:
        batch_size: maximum number of transitions to sample (default 32).
    """
    # Nothing to learn from yet; avoids calling .backward() on a plain int.
    if not history:
        return

    batch = random.sample(history, min(batch_size, len(history)))
    states, actions, states_next, rewards, dones = zip(*batch)

    # One batched forward pass instead of one network call per transition.
    states = torch.stack(states)
    states_next = torch.stack(states_next)
    actions = torch.tensor(actions, dtype=torch.long)
    rewards = torch.tensor(rewards, dtype=torch.float32)
    dones = torch.tensor(dones, dtype=torch.bool)

    with torch.no_grad():
        next_max = Q_target(states_next).max(dim=1).values
        # Terminal transitions do not bootstrap from the next state.
        targets = torch.where(dones, rewards, rewards + discount * next_max)

    # Q-value of the action actually taken in each sampled transition.
    predicted = Q(states).gather(1, actions.unsqueeze(1)).squeeze(1)
    loss = ((targets - predicted) ** 2).sum()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


# gym environment
env = gym.make("CartPole-v0")
max_time_steps = 1000

# for computing average reward over 100 episodes
reward_history = deque(maxlen=100)

# probability of taking a random action (epsilon-greedy exploration)
epsilon = 0.01

# for updating target network
target_interval = 1000
target_counter = 0

# training
for episode in range(1000):
    # sum of accumulated rewards
    rewards = 0

    # get initial observation
    observation = env.reset()
    state = torch.tensor(observation, dtype=torch.float32)

    # loop until an episode ends
    for t in range(1, max_time_steps + 1):
        # display current environment (uncomment to watch training)
        #env.render()

        # epsilon-greedy policy for the current state
        if random.random() < epsilon:
            action = env.action_space.sample()
        else:
            with torch.no_grad():
                action = torch.argmax(Q(state)).item()

        # get next observation and current reward for the chosen action
        observation_next, reward, done, info = env.step(action)
        state_next = torch.tensor(observation_next, dtype=torch.float32)

        # collect reward
        rewards = rewards + reward

        # An episode cut off by gym's TimeLimit wrapper is not a real terminal
        # state: the pole had not fallen, so the target must still bootstrap
        # from state_next. Only store "done" for genuine terminations.
        # (.get keeps this working on gym versions without the key.)
        terminal = done and not info.get('TimeLimit.truncated', False)

        # collect a transition
        history.append([state, action, state_next, reward, terminal])

        update_Q()

        # update target network
        target_counter = target_counter + 1
        if target_counter % target_interval == 0:
            Q_target.load_state_dict(Q.state_dict())

        # the episode itself still ends on any done, truncated or not
        if done:
            break

        # pass state to the next step
        state = state_next

    # compute average reward
    reward_history.append(rewards)
    avg = sum(reward_history) / len(reward_history)
    print('episode: {}, reward: {:.1f}, avg: {:.1f}'.format(episode, rewards, avg))

env.close()


# TEST: run the trained greedy policy for 10 rendered episodes.
# The training loop ends with env.close(), so use a fresh environment here
# instead of reusing the closed one.
env = gym.make("CartPole-v0")
episode = 0
observation = env.reset()
while episode < 10:  # episode loop
    env.render()
    state = torch.tensor(observation, dtype=torch.float32)
    # inference only: no autograd graph is needed
    with torch.no_grad():
        action = torch.argmax(Q(state)).item()  # greedy action from Q
    observation, reward, done, info = env.step(action)

    if done:
        episode = episode + 1
        observation = env.reset()
env.close()




episode: 0, reward: 9.0, avg: 9.0
episode: 1, reward: 10.0, avg: 9.5
episode: 2, reward: 8.0, avg: 9.0
episode: 3, reward: 9.0, avg: 9.0
episode: 4, reward: 8.0, avg: 8.8
episode: 5, reward: 10.0, avg: 9.0
episode: 6, reward: 10.0, avg: 9.1
episode: 7, reward: 10.0, avg: 9.2
episode: 8, reward: 10.0, avg: 9.3
episode: 9, reward: 10.0, avg: 9.4
episode: 10, reward: 10.0, avg: 9.5
episode: 11, reward: 10.0, avg: 9.5
episode: 12, reward: 9.0, avg: 9.5
episode: 13, reward: 11.0, avg: 9.6
episode: 14, reward: 9.0, avg: 9.5
episode: 15, reward: 9.0, avg: 9.5
episode: 16, reward: 9.0, avg: 9.5
episode: 17, reward: 9.0, avg: 9.4
episode: 18, reward: 10.0, avg: 9.5
episode: 19, reward: 10.0, avg: 9.5
episode: 20, reward: 10.0, avg: 9.5
episode: 21, reward: 9.0, avg: 9.5
episode: 22, reward: 9.0, avg: 9.5
episode: 23, reward: 10.0, avg: 9.5
episode: 24, reward: 10.0, avg: 9.5
episode: 25, reward: 8.0, avg: 9.5
episode: 26, reward: 9.0, avg: 9.4
episode: 27, reward: 9.0, avg: 9.4
episode: 28, rew

episode: 229, reward: 21.0, avg: 9.6
episode: 230, reward: 10.0, avg: 9.6
episode: 231, reward: 11.0, avg: 9.6
episode: 232, reward: 11.0, avg: 9.6
episode: 233, reward: 9.0, avg: 9.6
episode: 234, reward: 9.0, avg: 9.6
episode: 235, reward: 11.0, avg: 9.6
episode: 236, reward: 10.0, avg: 9.7
episode: 237, reward: 9.0, avg: 9.6
episode: 238, reward: 11.0, avg: 9.7
episode: 239, reward: 9.0, avg: 9.7
episode: 240, reward: 12.0, avg: 9.7
episode: 241, reward: 13.0, avg: 9.7
episode: 242, reward: 12.0, avg: 9.7
episode: 243, reward: 9.0, avg: 9.7
episode: 244, reward: 10.0, avg: 9.7
episode: 245, reward: 11.0, avg: 9.7
episode: 246, reward: 13.0, avg: 9.8
episode: 247, reward: 14.0, avg: 9.8
episode: 248, reward: 10.0, avg: 9.8
episode: 249, reward: 11.0, avg: 9.8
episode: 250, reward: 11.0, avg: 9.8
episode: 251, reward: 10.0, avg: 9.8
episode: 252, reward: 10.0, avg: 9.8
episode: 253, reward: 11.0, avg: 9.8
episode: 254, reward: 11.0, avg: 9.8
episode: 255, reward: 10.0, avg: 9.8
episod

episode: 448, reward: 42.0, avg: 12.7
episode: 449, reward: 13.0, avg: 12.7
episode: 450, reward: 23.0, avg: 12.8
episode: 451, reward: 34.0, avg: 13.1
episode: 452, reward: 16.0, avg: 13.1
episode: 453, reward: 16.0, avg: 13.2
episode: 454, reward: 39.0, avg: 13.5
episode: 455, reward: 16.0, avg: 13.5
episode: 456, reward: 19.0, avg: 13.6
episode: 457, reward: 25.0, avg: 13.8
episode: 458, reward: 11.0, avg: 13.8
episode: 459, reward: 21.0, avg: 13.8
episode: 460, reward: 19.0, avg: 13.9
episode: 461, reward: 27.0, avg: 14.1
episode: 462, reward: 17.0, avg: 14.2
episode: 463, reward: 50.0, avg: 14.6
episode: 464, reward: 23.0, avg: 14.7
episode: 465, reward: 32.0, avg: 14.8
episode: 466, reward: 10.0, avg: 14.7
episode: 467, reward: 21.0, avg: 14.8
episode: 468, reward: 15.0, avg: 14.7
episode: 469, reward: 20.0, avg: 14.8
episode: 470, reward: 15.0, avg: 14.8
episode: 471, reward: 15.0, avg: 14.9
episode: 472, reward: 16.0, avg: 15.0
episode: 473, reward: 17.0, avg: 15.0
episode: 474

KeyboardInterrupt: 