In [None]:
import gym
import collections
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Seed every random number generator this notebook draws from: torch is used
# for the network, while the stdlib `random` module is used for epsilon-greedy
# exploration and for sampling mini-batches from the replay buffer.
SEED = 100
torch.manual_seed(SEED)
random.seed(SEED)

In [None]:
# Hyperparameters
learning_rate = 5e-4     # step size handed to the Adam optimizer
gamma = 0.98             # discount factor used in the bootstrapped target
buffer_limit = 50_000    # maximum number of transitions the replay buffer keeps
batch_size = 32          # transitions drawn from the buffer per gradient step

# Check env

In [None]:
# Build the CartPole-v1 environment that the later cells interact with.
env = gym.make('CartPole-v1')

In [None]:
# Reset once and show the raw return value. Later cells index it as
# `observation[0]`, i.e. it is treated as a sequence whose first element is
# the state vector rather than as the state itself.
observation = env.reset()
print(observation)

# Replay Buffer

In [None]:
class ReplayBuffer():
    """Fixed-capacity FIFO store of transitions used for experience replay.

    Transitions are kept in a ``collections.deque`` with ``maxlen`` set, so
    once the buffer is full each new transition evicts the oldest one.
    """

    def __init__(self, limit=None):
        """
        limit: maximum number of transitions to keep. When omitted, the
               notebook-level ``buffer_limit`` is used.
        """
        capacity = buffer_limit if limit is None else limit
        self.buffer = collections.deque(maxlen=capacity)

    def put(self, experience):
        '''
        experience should be a 5-tuple (st, at, rt, st+1, done_mask)
        done_mask should be 0.0 for a terminal transition and 1.0 otherwise
        '''
        self.buffer.append(experience)

    def sample(self, n=None):
        """Draw ``n`` distinct transitions uniformly at random.

        n: number of transitions to draw. When omitted, the notebook-level
           ``batch_size`` is used.

        Returns a tuple ``(s, a, r, s_prime, done_mask)`` of tensors:
        ``s`` and ``s_prime`` stack the stored states row by row (float),
        while ``a`` (int64), ``r`` (float) and ``done_mask`` (float) are
        column tensors with one row per sampled transition.

        Raises ValueError (from ``random.sample``) when ``n`` is larger than
        the number of stored transitions.
        """
        if n is None:
            n = batch_size
        mini_batch = random.sample(self.buffer, n)
        s_lst, a_lst, r_lst, s_prime_lst, done_mask_lst = [], [], [], [], []

        for transition in mini_batch:
            s, a, r, s_prime, done_mask = transition
            s_lst.append(s)
            # Scalars are wrapped in a list so the resulting tensors are
            # columns of shape (n, 1), matching `gather(1, a)` in training.
            a_lst.append([a])
            r_lst.append([r])
            s_prime_lst.append(s_prime)
            done_mask_lst.append([done_mask])

        # dtypes are spelled out so the batch does not silently depend on
        # whether the environment hands back Python or numpy scalars.
        return torch.tensor(s_lst, dtype=torch.float), \
               torch.tensor(a_lst, dtype=torch.long), \
               torch.tensor(r_lst, dtype=torch.float), \
               torch.tensor(s_prime_lst, dtype=torch.float), \
               torch.tensor(done_mask_lst, dtype=torch.float)

    def size(self):
        """Return the number of transitions currently stored."""
        return len(self.buffer)

# QNetwork


In [None]:
# Smoke test: run one forward pass on the state obtained from the earlier reset.
# NOTE(review): `Qnet` is defined in a later cell, so on a fresh kernel this
# cell raises NameError unless that cell has been executed first.
Q = Qnet()
out = Q(torch.tensor(observation[0]))
out

In [None]:
# Draw one epsilon-greedy action (epsilon = 0.1) for the same state.
# NOTE(review): relies on `Q` from the previous cell and on `Qnet`, which is
# only defined further down in the notebook.
Q.sample_action(torch.tensor(observation[0]), 0.1)

In [None]:
class Qnet(nn.Module):
    """Three-layer fully connected Q-network: 4 inputs -> 128 -> 128 -> 2 outputs.

    Each output is the estimated Q-value of one action for the given state.
    """

    def __init__(self):
        super().__init__()
        # Per-episode scores; the training loop appends to this list.
        self.score_lst = []
        self.fc1 = nn.Linear(4, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, 2)

    def forward(self, x):
        """Return the Q-values for state(s) ``x`` (last dimension of size 4)."""
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    def sample_action(self, obs, epsilon):
        """Pick an action epsilon-greedily for a single state.

        obs:     tensor holding one state.
        epsilon: probability of choosing a uniformly random action.

        Returns the chosen action index as a Python int.
        """
        if random.random() < epsilon:
            # Explore: no forward pass is needed for a random action.
            return random.randint(0, self.fc3.out_features - 1)
        # Exploit: action selection needs no gradients, so skip graph building.
        with torch.no_grad():
            return self.forward(obs).argmax().item()
            
def train(q, q_target, memory, optimizer, n_updates=10, batch=None, discount=None):
    """Run a few gradient steps of Q-learning on mini-batches from ``memory``.

    q:         online network being optimised.
    q_target:  target network used only to compute the bootstrap target.
    memory:    object whose ``sample(n)`` returns
               ``(s, a, r, s_prime, done_mask)`` tensors.
    optimizer: optimizer built over ``q``'s parameters.
    n_updates: number of mini-batch updates to perform (default 10).
    batch:     mini-batch size; defaults to the notebook-level ``batch_size``.
    discount:  discount factor; defaults to the notebook-level ``gamma``.
    """
    if batch is None:
        batch = batch_size
    if discount is None:
        discount = gamma

    for _ in range(n_updates):
        s, a, r, s_prime, done_mask = memory.sample(batch)

        # Q-value of the action that was actually taken in each transition.
        q_a = q(s).gather(1, a)

        # The target is a constant with respect to the optimisation, so it is
        # built without autograd; otherwise backward() would keep accumulating
        # gradients on the target network's parameters.
        with torch.no_grad():
            max_q_prime = q_target(s_prime).max(1)[0].unsqueeze(1)
            # done_mask is 0 for terminal transitions, which removes the
            # bootstrap term and leaves only the immediate reward.
            target = r + discount * max_q_prime * done_mask

        loss = F.smooth_l1_loss(q_a, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

In [None]:
# Scratch check: pair the state from a fresh reset with the state reached
# after taking action 0 once.
obs = env.reset()[0]
s = [obs, env.step(0)[0]]
s

# Learning

In [None]:
# STEP 0: Initialize (1) the replay memory D and (2) the Q-network.
# Both are fresh objects; `Q` is rebound here to a newly constructed network.
memory = ReplayBuffer()
Q = Qnet()

In [None]:
def dqn(env, q, memory, max_epi=1000, min_memory=2000):
    """Train ``q`` with DQN on ``env`` and return it.

    env:        environment whose ``reset()`` returns ``(state, info)`` and
                whose ``step(a)`` returns a 5-tuple
                ``(state, reward, terminated, truncated, info)``.
    q:          online Q-network; its ``score_lst`` receives one score per episode.
    memory:     replay buffer providing ``put``, ``sample`` and ``size``.
    max_epi:    number of episodes to run.
    min_memory: training only starts once the buffer holds more than this
                many transitions.

    A target network is created as a copy of ``q`` and re-synchronised with it
    every ``print_interval`` episodes, at the same time progress is printed.
    """
    q_target = Qnet()
    q_target.load_state_dict(q.state_dict())

    print_interval = 20
    optimizer = optim.Adam(q.parameters(), lr=learning_rate)

    # iterate over episodes
    for n_epi in range(max_epi):
        score = 0.0
        epsilon = max(0.01, 0.08 - 0.01*(n_epi/200)) #Linear annealing from 8% to 1%
        s, _ = env.reset()
        done = False

        # iterate over steps until the episode ends
        while not done:

            # STEP 1: Sample (st, at, rt, st+1)
            a = q.sample_action(torch.tensor(s, dtype=torch.float), epsilon)
            s_prime, r, terminated, truncated, _ = env.step(a)

            # The episode is over on either flag, otherwise a time-limit
            # truncation would never stop the loop.
            done = terminated or truncated
            # Only a true terminal state cuts the bootstrap; a truncated
            # episode still has a meaningful next-state value.
            done_mask = 0.0 if terminated else 1.0
            memory.put((s, a, r, s_prime, done_mask))
            s = s_prime

            score += r

        q.score_lst.append(score)

        # STEP 2: update the network once the buffer holds more than
        # `min_memory` transitions
        if memory.size() > min_memory:
            train(q, q_target, memory, optimizer)

        # Sync the target network and report progress on the same schedule.
        if n_epi % print_interval == 0 and n_epi != 0:
            q_target.load_state_dict(q.state_dict())
            print("n_episode :{}, score : {:.1f}, n_buffer : {}, eps : {:.1f}%".format(
                                                            n_epi, score, memory.size(), epsilon*100))

    return q

In [None]:
# Run the full training loop. `dqn` returns the network it was given, so the
# cell output is the repr of `Q` after training.
dqn(env, Q, memory)

In [None]:
import matplotlib.pyplot as plt

# Learning curve: the score recorded for each training episode.
fig, ax = plt.subplots()
ax.plot(Q.score_lst)
ax.set(title="DQN on CartPole-v1", xlabel="Episode", ylabel="Score")
plt.show()