In [1]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import gym

In [2]:
from JSAnimation.IPython_display import display_animation
from matplotlib import animation
from IPython.display import display

def display_frames_as_gif(frames, filename='movie_cartpole_DQN.mp4'):
    '''Animate a list of frames, save the animation to `filename`, and display it inline with controls.

    Despite the name, the file written is whatever format `filename`'s extension
    selects (an .mp4 by default).

    Args:
        frames: non-empty sequence of image arrays (e.g. from env.render(mode='rgb_array'));
            frame.shape[0] is used as the height and frame.shape[1] as the width.
        filename: path passed to FuncAnimation.save; defaults to the previously
            hard-coded 'movie_cartpole_DQN.mp4'.

    Raises:
        ValueError: if `frames` is empty.
    '''
    if len(frames) == 0:
        raise ValueError('display_frames_as_gif needs at least one frame')

    height, width = frames[0].shape[0], frames[0].shape[1]
    # figsize in inches * dpi=72 gives a figure exactly one frame in pixel size.
    plt.figure(figsize=(width / 72.0, height / 72.0), dpi=72)
    patch = plt.imshow(frames[0])
    plt.axis('off')

    def animate(i):
        # Swap the image data in place for frame i.
        patch.set_data(frames[i])

    anim = animation.FuncAnimation(plt.gcf(), animate, frames=len(frames), interval=50)

    anim.save(filename)
    display(display_animation(anim, default_mode='loop'))

In [3]:
# using namedtuple
from collections import namedtuple

# Example of namedtuple field access:
#   Tr = namedtuple('tr', ('name_a', 'value_b'))
#   Tr_object = Tr('nameA', 100)
#   Tr_object.value_b  # -> 100

# One step of experience: the agent was in `state`, took `action`, moved to
# `next_state` (stored as None when the episode ended) and received `reward`.
Transition = namedtuple(
    'Transition',
    ['state', 'action', 'next_state', 'reward'],
)

In [4]:
ENV = 'CartPole-v0'   # Gym environment id
GAMMA = 0.99          # discount factor applied to future rewards
MAX_STEPS = 200       # maximum number of steps in one episode
NUM_EPISODES = 500    # maximum number of training episodes

In [5]:
# experience replay
class ReplayMemory:
    '''Fixed-capacity ring buffer of Transition records used for experience replay.'''

    def __init__(self, CAPACITY):
        self.capacity = CAPACITY  # maximum number of stored transitions
        self.memory = []          # the stored Transition records
        self.index = 0            # slot that the next push() writes to

    def push(self, state, action, state_next, reward):
        '''Store one (state, action, state_next, reward) transition.

        The list grows by one slot per call until it reaches capacity; after
        that the write position wraps around and overwrites the oldest slot.
        '''
        record = Transition(state, action, state_next, reward)
        if len(self.memory) < self.capacity:
            self.memory.append(None)
        self.memory[self.index] = record
        self.index += 1
        self.index %= self.capacity

    def sample(self, batch_size):
        '''Return batch_size stored transitions drawn uniformly at random, without replacement.'''
        chosen = random.sample(self.memory, batch_size)
        return chosen

    def __len__(self):
        '''Number of transitions currently stored.'''
        return len(self.memory)

In [6]:
# Added to every |TD error| so that no stored transition has zero sampling probability.
TD_ERROR_EPSILON = 0.0001

class TDerrorMemory:
    '''Ring buffer of TD errors, written in step with ReplayMemory, used for prioritized sampling.'''

    def __init__(self, CAPCITY):
        # The (misspelled) parameter name is kept so existing keyword callers keep working.
        # Previously the argument was ignored and a global CAPACITY was read instead.
        self.capacity = CAPCITY
        self.memory = []
        self.index = 0

    def push(self, td_error):
        '''Save a TD error, overwriting the oldest one once capacity is reached.'''
        if len(self.memory) < self.capacity:
            self.memory.append(None)

        self.memory[self.index] = td_error
        self.index = (self.index + 1) % self.capacity

    def __len__(self):
        '''Number of TD errors currently stored.'''
        return len(self.memory)

    def get_prioritized_indexes(self, batch_size):
        '''Draw batch_size indexes with probability proportional to |TD error| + TD_ERROR_EPSILON.

        Indexes are drawn with replacement and returned in ascending order as Python ints.

        Raises:
            ValueError: if the memory is empty.
        '''
        if len(self.memory) == 0:
            raise ValueError('cannot sample from an empty TD-error memory')

        weights = np.abs(np.asarray(self.memory, dtype=float)) + TD_ERROR_EPSILON
        cumulative = np.cumsum(weights)

        rand_list = np.sort(np.random.uniform(0, cumulative[-1], batch_size))

        # Element i owns the interval [cumulative[i-1], cumulative[i]); side='right' returns the
        # first i with cumulative[i] > r, i.e. exactly the element whose interval contains r.
        indexes = np.searchsorted(cumulative, rand_list, side='right')
        # Guard against a draw landing on/after the last edge through floating-point rounding.
        indexes = np.minimum(indexes, len(self.memory) - 1)
        return indexes.tolist()

    def update_td_error(self, updated_td_error):
        '''Replace all stored TD errors with `updated_td_error` (expected to match ReplayMemory order).'''
        self.memory = updated_td_error

In [7]:
# NN
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    '''Fully connected network n_in -> n_mid -> n_mid -> n_out with ReLU on the hidden layers.'''

    def __init__(self, n_in, n_mid, n_out):
        super().__init__()
        # Attribute names (fc1/fc2/fc3) are the state_dict keys; creation order fixes
        # the order in which parameter initialisation draws random numbers.
        self.fc1 = nn.Linear(n_in, n_mid)
        self.fc2 = nn.Linear(n_mid, n_mid)
        self.fc3 = nn.Linear(n_mid, n_out)

    def forward(self, x):
        '''Return the unactivated output of fc3 (one value per output unit).'''
        hidden = x
        for layer in (self.fc1, self.fc2):
            hidden = F.relu(layer(hidden))
        return self.fc3(hidden)

In [8]:
import random
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F

BATCH_SIZE = 32    # transitions per training minibatch
CAPACITY = 10000   # maximum number of transitions kept in replay / TD-error memory

class Brain:
    '''Double-DQN learner with prioritized experience replay.

    Holds a replay memory, a TD-error memory written in step with it, a main
    Q-network trained with Adam, and a target Q-network used to evaluate
    bootstrap targets.
    '''

    def __init__(self, num_states, num_actions):
        self.num_actions = num_actions  # number of discrete actions (2 for CartPole)

        # Memory that stores transitions for experience replay.
        self.memory = ReplayMemory(CAPACITY)

        # Main and target Q-networks share the same architecture.
        n_in, n_mid, n_out = num_states, 32, num_actions
        self.main_q_network = Net(n_in, n_mid, n_out)
        self.target_q_network = Net(n_in, n_mid, n_out)
        print(self.main_q_network)

        self.optimizer = optim.Adam(self.main_q_network.parameters(), lr=0.0001)

        # Memory that stores one TD error per stored transition (same indexing as self.memory).
        self.td_error_memory = TDerrorMemory(CAPACITY)

    def replay(self, episode):
        '''Train the main Q-network on one minibatch drawn from replay memory.

        Does nothing until at least BATCH_SIZE transitions are stored.
        '''
        # 1. Wait until enough transitions are stored.
        if len(self.memory) < BATCH_SIZE:
            return

        # 2. Build the minibatch.
        (self.batch, self.state_batch, self.action_batch,
         self.reward_batch, self.non_final_next_states) = self.make_minibatch(episode)

        # 3. Compute the target values for Q(s_t, a_t).
        self.expected_state_action_values = self.get_expected_state_action_values()

        # 4. Update the network weights.
        self.update_main_q_network()

    def decide_action(self, state, episode):
        '''Choose an action for `state` with an epsilon-greedy policy.

        epsilon = 0.5 / (episode + 1), so the share of greedy actions grows as episodes pass.
        Returns a LongTensor of shape [1, 1] holding the action index.
        '''
        epsilon = 0.5 * (1 / (episode + 1))

        if epsilon <= np.random.uniform(0, 1):
            self.main_q_network.eval()  # switch the network to inference mode
            with torch.no_grad():
                # max(1)[1] is the index of the largest output; view(1, 1) reshapes it to [1, 1].
                action = self.main_q_network(state).max(1)[1].view(1, 1)
        else:
            # Random action index in [0, num_actions), shaped [1, 1].
            action = torch.LongTensor([[random.randrange(self.num_actions)]])

        return action

    def make_minibatch(self, episode):
        '''Sample BATCH_SIZE transitions and stack their fields into tensors.

        Uniform sampling is used while episode < 30, TD-error-prioritized sampling afterwards.

        Returns:
            (batch, state_batch, action_batch, reward_batch, non_final_next_states), where
            non_final_next_states is None when every sampled transition is terminal.
        '''
        if episode < 30:
            transitions = self.memory.sample(BATCH_SIZE)
        else:
            indexes = self.td_error_memory.get_prioritized_indexes(BATCH_SIZE)
            transitions = [self.memory.memory[n] for n in indexes]

        # List of Transitions -> one Transition whose fields are tuples.
        batch = Transition(*zip(*transitions))

        state_batch = torch.cat(batch.state)
        action_batch = torch.cat(batch.action)
        reward_batch = torch.cat(batch.reward)
        _, non_final_next_states = self._non_final(batch.next_state)

        return batch, state_batch, action_batch, reward_batch, non_final_next_states

    def get_expected_state_action_values(self):
        '''Compute Q(s_t, a_t) for the current minibatch and return its Double-DQN target.

        Stores Q(s_t, a_t) (shape [batch, 1], with gradient) in self.state_action_values and
        returns r + GAMMA * Q_target(s', argmax_a Q_main(s', a)) (shape [batch]); the
        bootstrap term is 0 for terminal transitions.
        '''
        self.main_q_network.eval()
        self.target_q_network.eval()

        # The network outputs one Q value per action; gather keeps the value of the action taken.
        self.state_action_values = self.main_q_network(self.state_batch).gather(1, self.action_batch)

        return self._double_dqn_targets(self.reward_batch, self.batch.next_state)

    def update_main_q_network(self):
        '''Take one Adam step on the Huber loss between Q(s_t, a_t) and its target.'''
        self.main_q_network.train()

        # smooth_l1_loss is the Huber loss.
        loss = F.smooth_l1_loss(self.state_action_values, self.expected_state_action_values.unsqueeze(1))

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def update_target_q_network(self):
        '''Copy the main Q-network's weights into the target Q-network.'''
        self.target_q_network.load_state_dict(self.main_q_network.state_dict())

    def update_td_error_memory(self):
        '''Recompute the TD error of every stored transition and store them in the TD-error memory.'''
        if len(self.memory) == 0:
            return

        self.main_q_network.eval()
        self.target_q_network.eval()

        batch = Transition(*zip(*self.memory.memory))
        state_batch = torch.cat(batch.state)
        action_batch = torch.cat(batch.action)
        reward_batch = torch.cat(batch.reward)

        # Evaluation only: no autograd graph is needed over the whole memory.
        with torch.no_grad():
            state_action_values = self.main_q_network(state_batch).gather(1, action_batch).squeeze(1)
            td_errors = self._double_dqn_targets(reward_batch, batch.next_state) - state_action_values

        # Same order as self.memory.memory, so indexes stay aligned with the replay memory.
        self.td_error_memory.update_td_error(td_errors.numpy().tolist())

    # Backward-compatible alias for the original, misspelled method name.
    updtae_td_error_memory = update_td_error_memory

    @staticmethod
    def _non_final(next_states):
        '''Return (bool mask of non-None entries, those states concatenated or None if there are none).'''
        mask = torch.tensor([s is not None for s in next_states], dtype=torch.bool)
        kept = [s for s in next_states if s is not None]
        return mask, (torch.cat(kept) if kept else None)

    def _double_dqn_targets(self, reward_batch, next_states):
        '''Return reward_batch + GAMMA * Q_target(s', argmax_a Q_main(s', a)), using 0 for terminal s'.'''
        mask, non_final_next_states = self._non_final(next_states)
        next_state_values = torch.zeros(len(next_states))

        if non_final_next_states is not None:
            with torch.no_grad():
                # Double DQN: the main network selects the action ...
                a_m = self.main_q_network(non_final_next_states).max(1)[1].view(-1, 1)
                # ... and the target network evaluates it.
                next_state_values[mask] = self.target_q_network(non_final_next_states).gather(1, a_m).squeeze(1)

        return reward_batch + GAMMA * next_state_values

In [9]:
class Agent:
    '''Thin facade that forwards the Environment's requests to a Brain.'''

    def __init__(self, num_states, num_actions):
        self.brain = Brain(num_states, num_actions)

    def update_q_function(self, episode):
        '''Run one training step (Brain.replay) for the given episode number.'''
        self.brain.replay(episode)

    def get_action(self, state, episode):
        '''Return the action the brain chooses for `state`.'''
        action = self.brain.decide_action(state, episode)
        return action

    def memorize(self, state, action, state_next, reward):
        '''Store one transition in the brain's replay memory.'''
        self.brain.memory.push(state, action, state_next, reward)

    def update_target_q_function(self):
        '''Sync the brain's target Q-network with its main Q-network.'''
        self.brain.update_target_q_network()

    def memorize_td_error(self, td_error):
        '''Store one TD error in the brain's TD-error memory.'''
        self.brain.td_error_memory.push(td_error)

    def update_td_error_memory(self):
        '''Ask the brain to recompute its stored TD errors.'''
        # Previously referenced the undefined name `sefl`, raising NameError.
        self.brain.update_td_error_memory()

In [10]:
class Environment:
    '''Runs CartPole episodes with an Agent until it succeeds 10 episodes in a row.'''

    def __init__(self):
        self.env = gym.make(ENV)
        num_states = self.env.observation_space.shape[0]
        num_actions = self.env.action_space.n
        self.agent = Agent(num_states, num_actions)

    def run(self):
        '''Train for up to NUM_EPISODES episodes.

        An episode counts as a success when it ends at step index >= 195. After 10
        consecutive successes, exactly one more (final) episode is played and then
        training stops.
        '''
        episode_10_list = np.zeros(10)  # step counts of the last 10 episodes
        complete_episodes = 0           # number of consecutive successful episodes
        episode_final = False           # True -> the current episode is the last one
        frames = []                     # frames of the final episode (rendering is disabled)

        for episode in range(NUM_EPISODES):
            # NOTE(review): assumes the classic gym API where reset() returns only the
            # observation and step() returns a 4-tuple — TODO confirm for the installed gym.
            observation = self.env.reset()

            state = torch.from_numpy(observation).type(torch.FloatTensor)
            state = torch.unsqueeze(state, 0)

            for step in range(MAX_STEPS):
                # To record the final episode, enable:
                # if episode_final: frames.append(self.env.render(mode='rgb_array'))
                action = self.agent.get_action(state, episode)

                observation_next, _, done, _ = self.env.step(action.item())

                if done:
                    state_next = None
                    episode_10_list = np.hstack((episode_10_list[1:], step + 1))

                    if step < 195:
                        reward = torch.FloatTensor([-1.0])
                        complete_episodes = 0
                    else:
                        reward = torch.FloatTensor([1.0])
                        complete_episodes = complete_episodes + 1
                else:
                    reward = torch.FloatTensor([0.0])
                    state_next = torch.from_numpy(observation_next).type(torch.FloatTensor)
                    state_next = torch.unsqueeze(state_next, 0)

                self.agent.memorize(state, action, state_next, reward)

                # NOTE(review): the TD-error memory only ever receives these 0 placeholders; nothing
                # in this loop recomputes stored TD errors, so prioritized sampling sees uniform
                # priorities. TODO: refresh them once per episode at `done` through the agent
                # (e.g. self.agent.update_td_error_memory()) after verifying that call path.
                self.agent.memorize_td_error(0)

                self.agent.update_q_function(episode)

                state = state_next

                if done:
                    print('%d Episode : Finished after %d steps : Current avg of 10 episode\'s step = %.1lf' % (episode, step + 1, episode_10_list.mean()))
                    if episode % 2 == 0:
                        self.agent.update_target_q_function()
                    break

            # These checks belong at episode level. Inside the step loop they cut every later
            # episode short after one step and printed the success message on every step.
            if episode_final:
                # display_frames_as_gif(frames)
                break

            if complete_episodes >= 10:
                print("10 episode success")
                episode_final = True

In [11]:
cartpole_env = Environment()
cartpole_env.run()

Net(
  (fc1): Linear(in_features=4, out_features=32, bias=True)
  (fc2): Linear(in_features=32, out_features=32, bias=True)
  (fc3): Linear(in_features=32, out_features=2, bias=True)
)
0 Episode : Finished after 8 steps : Current avg of 10 episode's step = 0.8
1 Episode : Finished after 10 steps : Current avg of 10 episode's step = 1.8
2 Episode : Finished after 8 steps : Current avg of 10 episode's step = 2.6
3 Episode : Finished after 8 steps : Current avg of 10 episode's step = 3.4
4 Episode : Finished after 9 steps : Current avg of 10 episode's step = 4.3
5 Episode : Finished after 10 steps : Current avg of 10 episode's step = 5.3
6 Episode : Finished after 9 steps : Current avg of 10 episode's step = 6.2
7 Episode : Finished after 13 steps : Current avg of 10 episode's step = 7.5
8 Episode : Finished after 9 steps : Current avg of 10 episode's step = 8.4
9 Episode : Finished after 16 steps : Current avg of 10 episode's step = 10.0
10 Episode : Finished after 15 steps : Current avg

TypeError: 'NoneType' object is not iterable