In [1]:
import numpy as np 
import matplotlib.pyplot as plt 
%matplotlib inline 
import gym

In [3]:
from JSAnimation.IPython_display import display_animation
from matplotlib import animation
from IPython.display import display

def display_frames_as_gif(frames, filename='movie_cartpole_DQN.gif', interval=50):
    '''
    Displays a list of frames as a gif, with controls, and saves it to disk.

    Parameters
    ----------
    frames : sequence of image arrays
        Frames to animate, in playback order. The first frame's shape sets
        the figure size (shape[1] wide, shape[0] high, at 72 dpi).
    filename : str, optional
        Path the animation is written to. Defaults to the file name this
        notebook has always used.
    interval : int, optional
        Value forwarded to FuncAnimation's ``interval`` argument
        (delay between frames).
    '''
    first_frame = frames[0]
    # Keep an explicit handle on the figure instead of re-querying the
    # pyplot "current figure" later on.
    fig = plt.figure(figsize=(first_frame.shape[1]/72.0, first_frame.shape[0]/72.0), dpi=72)
    patch = plt.imshow(first_frame)
    plt.axis('off')

    def animate(i):
        # Swap the pixel data of the single image artist for frame i.
        patch.set_data(frames[i])

    anim = animation.FuncAnimation(fig, animate, frames=len(frames), interval=interval)

    anim.save(filename)
    display(display_animation(anim, default_mode = 'loop'))

In [4]:
from collections import namedtuple

# Small namedtuple demo: build a two-field record type and read it back.
Tr = namedtuple('tr', ['name_a', 'value_b'])
Tr_object = Tr(name_a='이름 A', value_b=100)

print(Tr_object)          # whole record, shown with its field names
print(Tr_object.value_b)  # a single field, accessed by name

tr(name_a='이름 A', value_b=100)
100


In [5]:
# Record type for one stored step: (state, action, next_state, reward).
Transition = namedtuple('Transition', ['state', 'action', 'next_state', 'reward'])

In [6]:
# Environment identifier (presumably handed to gym when the environment is
# created -- not shown in this part of the notebook).
ENV = 'CartPole-v0'
# NOTE(review): looks like the reward discount factor -- confirm where it is used.
GAMMA = 0.99
# Max steps per episode.
MAX_STEPS = 200
# Max number of episodes.
NUM_EPISODES = 500

In [7]:
#memory class for saving transition

class ReplayMemory:
    '''Bounded store of Transition records; once full, slots are reused in order.'''

    def __init__(self, CAPACITY):
        self.capacity = CAPACITY  # upper bound on the number of stored transitions
        self.memory = []          # the stored transitions
        self.index = 0            # slot the next push() writes to

    def push(self, state, action, state_next, reward):
        '''Store Transition(state, action, state_next, reward) in the current slot.'''
        slot = self.index

        # While below capacity the list grows by one placeholder per push;
        # afterwards the slot at `slot` is simply overwritten.
        below_capacity = len(self.memory) < self.capacity
        if below_capacity:
            self.memory.append(None)

        self.memory[slot] = Transition(state, action, state_next, reward)

        # Advance the write position, wrapping back to 0 at capacity.
        following = slot + 1
        self.index = following % self.capacity

    def sample(self, batch_size):
        '''Return batch_size stored transitions picked by random.sample.'''
        stored = self.memory
        return random.sample(stored, batch_size)

    def __len__(self):
        '''Number of transitions currently stored.'''
        stored = self.memory
        return len(stored)

In [8]:
# brain class for DQN execution

import random
import torch
from torch import nn 
from torch import optim 
import torch.nn.functional as F 

# Number of transitions drawn per replay minibatch; replay is skipped while
# the memory holds fewer than this many.
BATCH_SIZE = 32
# Capacity handed to ReplayMemory.
CAPACITY = 10_000

class Brain:
    '''Bundles the Q-network, its optimizer and the replay memory.'''

    def __init__(self, num_states, num_actions):
        '''Build the network and an empty replay memory.

        num_states  -- input width of the first linear layer
        num_actions -- output width of the last linear layer
        '''
        self.num_actions = num_actions #num of actions

        self.memory = ReplayMemory(CAPACITY)

        #configure neural network: num_states -> 32 -> 32 -> num_actions
        self.model = nn.Sequential()
        self.model.add_module('fc1', nn.Linear(num_states, 32))
        self.model.add_module('relu1', nn.ReLU())
        self.model.add_module('fc2', nn.Linear(32, 32))
        self.model.add_module('relu2', nn.ReLU())
        self.model.add_module('fc3', nn.Linear(32, num_actions))

        print(self.model)

        self.optimizer = optim.Adam(self.model.parameters(), lr=0.0001)

    def replay(self):
        '''train nn with Experience Replay

        Returns None without doing anything while the memory holds fewer
        than BATCH_SIZE transitions.
        '''

        #------------------------------------
        # 1. check transition length
        #------------------------------------
        if len(self.memory) < BATCH_SIZE:
            return

        #------------------------------------
        # 2. create minibatch
        #------------------------------------
        transitions = self.memory.sample(BATCH_SIZE)

        # make transition minibatch
        # (state, action, state_next, reward) * BATCH_SIZE into
        # (state * BATCH_SIZE, action * BATCH_SIZE, state_next * BATCH_SIZE, reward * BATCH_SIZE)
        batch = Transition(*zip(*transitions))

        # assumes every stored state/action/reward is a tensor that can be
        # concatenated along dim 0 -- TODO confirm against the code that pushes
        state_batch = torch.cat(batch.state)
        action_batch = torch.cat(batch.action)
        reward_batch = torch.cat(batch.reward)
        # A next_state of None is filtered out here; identity comparison
        # (`is not None`) is required because the entries are tensors.
        non_final_next_states = torch.cat([s for s in batch.next_state
                                            if s is not None])

        #-------------------------------------
        # 3. calculate Q(s_t, a_t)
        #-------------------------------------

        self.model.eval() # set nn to evaluation mode

        # calc Q(s_t, a_t): pick, per row, the network output at the index
        # given by action_batch
        # (assumes action_batch is an int64 column of shape (BATCH_SIZE, 1) -- TODO confirm)
        state_action_values = self.model(state_batch).gather(1, action_batch)

        # NOTE(review): the method stops here. state_action_values,
        # reward_batch and non_final_next_states are computed but not used
        # yet, and self.optimizer is never stepped -- the update itself
        # still has to be written.

Sequential(
  (fc1): Linear(in_features=4, out_features=32, bias=True)
  (relu1): ReLU()
  (fc2): Linear(in_features=32, out_features=32, bias=True)
  (relu2): ReLU()
  (fc3): Linear(in_features=32, out_features=100, bias=True)
)
