In [None]:
#!/usr/bin/env python
from datetime import datetime
from collections import deque
import os
import random
import gym
import torch
from torch.distributions import Categorical
from torch.nn import Module, Linear
import torch.nn.functional as F


class QNetwork(Module):
    """Fully connected network that maps a state vector to one Q-value per action.

    Two ReLU hidden layers are followed by a linear output layer.

    Args:
        state_dim: Number of input features per state.
        hidden_dim1: Width of the first hidden layer.
        hidden_dim2: Width of the second hidden layer.
        num_actions: Number of outputs (one Q-value per action).

    The defaults reproduce the fixed 4 -> 48 -> 64 -> 2 layout, so
    ``QNetwork()`` builds the same architecture as before.
    """

    def __init__(self, state_dim=4, hidden_dim1=48, hidden_dim2=64, num_actions=2):
        super().__init__()
        # Attribute names and creation order are kept: the names are the
        # state_dict keys, and the order fixes which random numbers each
        # layer draws for its initial weights.
        self.fc = Linear(state_dim, hidden_dim1)
        self.fcQ1 = Linear(hidden_dim1, hidden_dim2)
        self.fcQ2 = Linear(hidden_dim2, num_actions)

    def forward(self, x):
        """Return the Q-values for ``x``.

        ``x`` has ``state_dim`` features on its last dimension; the result
        has ``num_actions`` values on its last dimension.
        """
        x = F.relu(self.fc(x))
        x = F.relu(self.fcQ1(x))
        # No activation on the output layer: Q-values are unbounded.
        return self.fcQ2(x)


# network and optimizer
Q = QNetwork()
optimizer = torch.optim.Adam(Q.parameters(), lr=0.0005)

# target network
# It starts as an exact copy of Q. Without the copy, the bootstrapped targets
# used before the first periodic synchronisation would come from an unrelated
# random initialisation instead of a lagged version of the online network.
Q_target = QNetwork()
Q_target.load_state_dict(Q.state_dict())

history = deque(maxlen=1000000)  # replay buffer
discount = 0.99  # discount factor gamma

def update_Q():
    """Take one optimizer step on ``Q`` using a minibatch from the replay buffer.

    Up to 32 transitions ``[state, action, state_next, reward, done]`` are
    sampled uniformly from ``history``. For each one the TD target is
    ``reward`` when ``done`` is true and
    ``reward + discount * max(Q_target(state_next))`` otherwise. The loss is
    the sum over the minibatch of ``(target - Q(state)[action]) ** 2``.

    Returns immediately, without touching the optimizer, when the replay
    buffer is empty.
    """
    batch = random.sample(history, min(32, len(history)))
    if not batch:
        # Nothing to learn from yet; there is no loss tensor to backpropagate.
        return

    states, actions, states_next, rewards, dones = zip(*batch)

    # One forward pass per network over the whole minibatch instead of one
    # per transition.
    # assumes every stored state is a 1-D tensor of the same length - TODO confirm
    states = torch.stack(states)
    states_next = torch.stack(states_next)
    actions = torch.tensor([int(a) for a in actions], dtype=torch.int64)
    rewards = torch.tensor([float(r) for r in rewards], dtype=torch.float32)
    dones = torch.tensor([bool(d) for d in dones], dtype=torch.bool)

    with torch.no_grad():
        best_next = Q_target(states_next).max(dim=1).values
        # torch.where selects rather than multiplies by a 0/1 mask, so a
        # non-finite Q_target value on a terminal row cannot leak into the
        # target.
        targets = torch.where(dones, rewards, rewards + discount * best_next)

    # Q-value of the action that was actually taken in each sampled state.
    predicted = Q(states).gather(1, actions.unsqueeze(1)).squeeze(1)

    # Summed (not averaged) squared TD error over the minibatch.
    loss = ((targets - predicted) ** 2).sum()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


# gym environment
env = gym.make("CartPole-v0")
max_time_steps = 1000  # upper bound on the number of steps taken per episode

# for computing average reward over 100 episodes
reward_history = deque(maxlen=100)


# for updating target network
target_interval = 1000  # copy Q into Q_target every this many environment steps
target_counter = 0  # environment steps taken so far, across all episodes

# probability of taking a uniformly random action instead of the greedy one
epsilon = 0.01

# training
for episode in range(1000):
    # sum of accumulated rewards
    rewards = 0

    # get initial observation
    observation = env.reset()
    state = torch.tensor(observation, dtype=torch.float32)

    # loop until an episode ends
    for t in range(1, max_time_steps + 1):
        # display current environment
        #env.render()

        # epsilon greedy policy for current observation
        with torch.no_grad():
            if random.random() < epsilon:
                action = env.action_space.sample()
            else:
                action = torch.argmax(Q(state)).item()

        # get next observation and current reward for the chosen action
        observation_next, reward, done, info = env.step(action)
        state_next = torch.tensor(observation_next, dtype=torch.float32)

        # collect reward
        rewards = rewards + reward

        # An episode that only stopped because gym's TimeLimit wrapper cut it
        # off has not reached a real terminal state, so the stored transition
        # must still bootstrap from state_next. The wrapper reports this case
        # through info['TimeLimit.truncated']; when the key is absent the
        # stored flag is simply `done`.
        terminal = done and not info.get('TimeLimit.truncated', False)

        # collect a transition
        history.append([state, action, state_next, reward, terminal])

        update_Q()

        # update target network
        target_counter = target_counter + 1
        if target_counter % target_interval == 0:
            Q_target.load_state_dict(Q.state_dict())

        # the episode itself ends on `done`, whether terminated or truncated
        if done:
            break

        # pass observation to the next step
        observation = observation_next
        state = state_next

    # compute average reward
    reward_history.append(rewards)
    avg = sum(reward_history) / len(reward_history)
    print('episode: {}, reward: {:.1f}, avg: {:.1f}'.format(episode, rewards, avg))

env.close()


# TEST
# Run the greedy policy given by Q (no exploration) for 10 rendered episodes.
# A fresh environment is created because the training one was closed above.
env = gym.make("CartPole-v0")
episode = 0
state = env.reset()
try:
    while episode < 10:  # episode loop
        env.render()
        state = torch.tensor(state, dtype=torch.float32)
        # inference only: no autograd graph is needed
        with torch.no_grad():
            action = torch.argmax(Q(state)).item()
        next_state, reward, done, info = env.step(action)  # take the greedy action
        state = next_state

        if done:
            episode = episode + 1
            state = env.reset()
finally:
    # release the render window even if an episode raises
    env.close()




episode: 0, reward: 9.0, avg: 9.0
episode: 1, reward: 10.0, avg: 9.5
episode: 2, reward: 8.0, avg: 9.0
episode: 3, reward: 10.0, avg: 9.2
episode: 4, reward: 9.0, avg: 9.2
episode: 5, reward: 10.0, avg: 9.3
episode: 6, reward: 9.0, avg: 9.3
episode: 7, reward: 10.0, avg: 9.4
episode: 8, reward: 9.0, avg: 9.3
episode: 9, reward: 10.0, avg: 9.4
episode: 10, reward: 9.0, avg: 9.4
episode: 11, reward: 9.0, avg: 9.3
episode: 12, reward: 9.0, avg: 9.3
episode: 13, reward: 11.0, avg: 9.4
episode: 14, reward: 10.0, avg: 9.5
episode: 15, reward: 10.0, avg: 9.5
episode: 16, reward: 10.0, avg: 9.5
episode: 17, reward: 9.0, avg: 9.5
episode: 18, reward: 9.0, avg: 9.5
episode: 19, reward: 10.0, avg: 9.5
episode: 20, reward: 9.0, avg: 9.5
episode: 21, reward: 8.0, avg: 9.4
episode: 22, reward: 10.0, avg: 9.4
episode: 23, reward: 11.0, avg: 9.5
episode: 24, reward: 9.0, avg: 9.5
episode: 25, reward: 10.0, avg: 9.5
episode: 26, reward: 10.0, avg: 9.5
episode: 27, reward: 10.0, avg: 9.5
episode: 28, re

episode: 229, reward: 9.0, avg: 9.4
episode: 230, reward: 10.0, avg: 9.4
episode: 231, reward: 10.0, avg: 9.4
episode: 232, reward: 10.0, avg: 9.4
episode: 233, reward: 9.0, avg: 9.4
episode: 234, reward: 11.0, avg: 9.4
