In [1]:
import gym
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import collections
import copy
import time
import random

print(torch.cuda.is_available())
device = 'cuda'

True


In [2]:
class ReplayBuffer():
    def __init__(self):
        self.buffer = collections.deque(maxlen = 5000)
        self.minibatch_size = 32

    def append(self, state, action, reward, terminal, next_state):
        self.buffer.append([state, action, reward, terminal, next_state])

    def sample(self):
        mini_batch = random.sample(self.buffer, self.minibatch_size)
        #mini_batch.append(self.buffer[-1])
        s_lst, a_lst, r_lst, terminal, s_prime_lst = map(torch.FloatTensor, zip(*mini_batch))
        return s_lst.to(device), np.array(a_lst), r_lst.to(device), terminal.to(device), s_prime_lst.to(device)
    
    def size(self):
        return len(self.buffer)

In [3]:
def init_weights(l):
    if type(l) == nn.Linear:
        nn.init.xavier_normal_(l.weight, gain=0.25)
        
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(4, 128).to(device)
        self.fc2 = nn.Linear(128, 128).to(device)
        self.fc3 = nn.Linear(128, 2).to(device)
        self.epsilon = 0.1
        self.gamma = 0.9
        self.step_size = 0.1

    def forward(self, x):
        x = F.relu(self.fc1(x.to(device)))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
    
    def sample_action(self, state):
        pi = self.forward(torch.from_numpy(state).float().to(device))
        if np.random.random() < self.epsilon:
            action = np.random.randint(2)
            return action
        else : 
            action = torch.argmax(pi).item()
            return action

In [4]:
class Agent():    
    def __init__(self):
        self.discount = 0.98
        self.last_state = None
        self.replay_buffer = ReplayBuffer()
        self.network = Net()
        self.targetNet = Net()
        self.targetNet.load_state_dict(self.network.state_dict())
        self.num_replay = 10
        self.optimizer = optim.Adam(self.network.parameters(), lr = 0.0005)
        self.env = gym.make('CartPole-v1')
        self.total_step = 0
        self.tau = 0.001
        
    def train(self, epi):
        self.last_state = self.env.reset()
        while True:
            self.env.render()
            self.total_step +=1
            action = self.network.sample_action(self.last_state)
            state, reward, done, info = self.env.step(action)
            self.replay_buffer.append(self.last_state, action, reward, done, state) 
            if self.replay_buffer.size() >= self.replay_buffer.minibatch_size:
                for _ in range(self.num_replay):
                    self.optimize_network(self.network, self.targetNet)
            if done:
                break
            self.last_state = state
        if(epi%10 == 0):
            #self.soft_update(self.network, self.targetNet, self.tau)
            self.targetNet.load_state_dict(self.network.state_dict())
            print('#episode : ',epi, 'avg_step : ', self.total_step/10)
            self.total_step = 0
            
            
    def optimize_network(self, network, targetNet):
        states, actions, rewards, terminals, next_states = self.replay_buffer.sample()
        
        v_next_vec = torch.max(targetNet.forward(next_states), dim = -1)[0]*(1-terminals)
        target_vec = rewards + self.discount*v_next_vec
        
        q_mat = network.forward(states)
        batch_indices = np.arange(self.replay_buffer.minibatch_size)
        q_vec = q_mat[batch_indices,actions]
        
        loss = F.smooth_l1_loss(q_vec,target_vec)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

model = Agent()

In [5]:
pre = time.time()
for epi in range(1000):
    model.train(epi)
post = time.time()
print(post-pre)
#cpu 32/50/64 226.5114507675171
#cpu 128/50/5000 812.0476837158203

#episode :  0 avg_step :  1.0
#episode :  10 avg_step :  10.4
#episode :  20 avg_step :  11.2
#episode :  30 avg_step :  11.1
#episode :  40 avg_step :  17.7
#episode :  50 avg_step :  24.3
#episode :  60 avg_step :  183.1
#episode :  70 avg_step :  308.3
#episode :  80 avg_step :  417.6
#episode :  90 avg_step :  379.2
#episode :  100 avg_step :  300.4
#episode :  110 avg_step :  239.0
#episode :  120 avg_step :  280.8
#episode :  130 avg_step :  252.8
#episode :  140 avg_step :  354.7
#episode :  150 avg_step :  412.4
#episode :  160 avg_step :  473.8
#episode :  170 avg_step :  259.8
#episode :  180 avg_step :  269.8
#episode :  190 avg_step :  305.1
#episode :  200 avg_step :  269.6
#episode :  210 avg_step :  212.8
#episode :  220 avg_step :  240.9
#episode :  230 avg_step :  209.1
#episode :  240 avg_step :  254.9
#episode :  250 avg_step :  249.8
#episode :  260 avg_step :  184.4


KeyboardInterrupt: 