In [2]:
import gym
import math
import random
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple
from collections import deque

from itertools import count
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as T

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class memory():

    def __init__(self, size):
        self.mem = deque(maxlen=size)

    def append_sample(self, state, action, reward, next_state, done):
        self.mem.append((state, action, reward, next_state, done))

    def get_train_data(self,batch_size):
        return random.sample(self.mem, batch_size)

    def __len__(self):
        return len(self.mem)



In [3]:
class NN(nn.Module):

    def __init__(self, state_size, action_size):
        super(NN,self).__init__()
        self.fc1 = nn.Linear(state_size, 32)
        self.fc2 = nn.Linear(32,32)
        self.out = nn.Linear(32, action_size)
        nn.init.xavier_normal_(self.fc1.weight)
        nn.init.xavier_normal_(self.fc2.weight)
        nn.init.xavier_normal_(self.out.weight)



    def forward(self,x):
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = F.relu(x)
        x = self.out(x)
        return x
        

In [4]:
action = 0
env = gym.make('CartPole-v1')
state_size = env.observation_space.shape[0]
action_size = env.action_space.n


WARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.




In [5]:
class Agent():
    def __init__(self, mem_size, state_size, action_size):
        self.state_size = state_size
        self.action_size = action_size
        
        self.epsilon = 1
        self.epsilon_min = 0.01
        self.epsilon_decay = 0.9
        self.batch_size = 64
        self.GAMMA =0.9
        
        self.worker_NN = NN(state_size, action_size).to(device)
        self.target_NN = NN(state_size, action_size).to(device)
        self.target_NN.load_state_dict(self.worker_NN.state_dict())
        
        self.optimizer = optim.Adam(self.worker_NN.parameters(),lr=0.001)
        self.loss_func = nn.MSELoss()
        self.mem = memory(mem_size)
    
    def get_action(self,state):
        if np.random.rand() <= self.epsilon:
            self.epsilon = self.epsilon * self.epsilon_decay
            return env.action_space.sample()
        else:
            state = torch.Tensor(state).to(device)
            a=self.worker_NN(state)
            out = torch.argmax(a)
            return  out.cpu().numpy()
    
    def train_model(self):
        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay
        
        mini_batch = self.mem.get_train_data(self.batch_size)
        
        states = np.zeros((self.batch_size, self.state_size))
        next_states = np.zeros((self.batch_size, self.state_size))
        actions, rewards, dones = [],[],[]
        
        for i in range(self.batch_size):
            states[i] = mini_batch[i][0]
            actions.append(mini_batch[i][1])
            rewards.append(mini_batch[i][2])
            next_states[i] = mini_batch[i][3]
            dones.append(mini_batch[i][4])
        
        states = torch.Tensor(states).to(device)
        next_states = torch.Tensor(next_states).to(device)
        rewards = torch.Tensor(rewards).to(device)
        
        q_eval = self.worker_NN(states)
        q_eval = q_eval.max(1)[0]
        q_next = self.target_NN(next_states)
        
        q_target = rewards + (self.GAMMA * q_next.max(1)[0])
        
            
        loss = F.mse_loss(q_eval, q_target)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        
        
    def weight_copy(self):
        self.target_NN.load_state_dict(self.worker_NN.state_dict())
            
    

In [6]:
agent = Agent(mem_size=2000, state_size=state_size, action_size =action_size)

In [None]:
for i_episode in range(1500):
    state = env.reset()
    score = 0
    for t in range(500):
        env.render()
        #action = env.action_space.sample()
        action = agent.get_action(state)
        next_state, reward, done, info = env.step(action)
        next_state = np.reshape(next_state, [1, state_size])

        reward = reward if not done or score > 499 else -100 #안끝났거나 score가 499보다 큰경우가 아니고서야 reward는 -100

        agent.mem.append_sample(state, action, reward, next_state, done)
        state=next_state
        score +=reward
        if len(agent.mem) > 1000:
            agent.train_model()

        if done:
            if t+1 > 100:
                print("{} : Episode finished after {} timesteps and ".format(i_episode, t+1))
            if (i_episode % 10) == 0:
                agent.weight_copy()
            break
env.close()


In [17]:
for i in range(50):
    print(agent.mem.get_train_data(1))

[(array([[-0.11311529, -1.1258027 ,  0.15989465,  1.83356267]]), array(0, dtype=int64), 1.0, array([[-0.13563135, -1.32229222,  0.1965659 ,  2.17134632]]), False)]




[(array([[-0.0370611 , -0.19936407,  0.0007396 ,  0.25806303]]), 1, 1.0, array([[-0.04104838, -0.00425268,  0.00590087, -0.03438653]]), False)]




[(array([ 0.01098016,  0.04884696, -0.02438574,  0.01569977]), array(0, dtype=int64), 1.0, array([[ 0.0119571 , -0.14591693, -0.02407175,  0.30059006]]), False)]




[(array([[-0.06404237, -1.31715866,  0.09880463,  2.07431604]]), array(0, dtype=int64), 1.0, array([[-0.09038555, -1.51313479,  0.14029095,  2.39584783]]), False)]




[(array([[-0.08036308, -0.9693422 ,  0.09075955,  1.52128121]]), array(1, dtype=int64), 1.0, array([[-0.09974993, -0.77542666,  0.12118518,  1.25825177]]), False)]




[(array([[-0.05334093, -0.57768114,  0.04857063,  0.90095015]]), 0, 1.0, array([[-0.06489456, -0.77342636,  0.06658963,  1.20849586]]), False)]




[(array([[-0.12833072, -0.79193004,  0.13571234,  1.30485414]]), 0, 1.0, array([[-0.14416932, -0.98848649,  0.16180943,  1.63675493]]), False)]




[(array([[-0.04211623, -0.60611181, -0.01723127,  0.84315718]]), array(0, dtype=int64), 1.0, array([[-5.42384701e-02, -8.00994401e-01, -3.68127785e-04,
         1.13037193e+00]]), False)]




[(array([[-0.08036308, -0.9693422 ,  0.09075955,  1.52128121]]), array(1, dtype=int64), 1.0, array([[-0.09974993, -0.77542666,  0.12118518,  1.25825177]]), False)]




[(array([[-0.11311529, -1.1258027 ,  0.15989465,  1.83356267]]), array(0, dtype=int64), 1.0, array([[-0.13563135, -1.32229222,  0.1965659 ,  2.17134632]]), False)]




[(array([[-0.13469592, -1.16840001,  0.17807624,  1.92081419]]), array(0, dtype=int64), -100, array([[-0.15806392, -1.36493335,  0.21649252,  2.26303047]]), True)]




[(array([[ 4.67285289e-03, -3.82343143e-01, -4.44854855e-04,
         5.95267921e-01]]), array(0, dtype=int64), 1.0, array([[-0.00297401, -0.57745887,  0.0114605 ,  0.88781069]]), False)]




[(array([[-0.01452319, -0.77273447,  0.02921672,  1.18407418]]), array(1, dtype=int64), 1.0, array([[-0.02997788, -0.5780035 ,  0.0528982 ,  0.90069087]]), False)]




[(array([[-0.10844027, -0.39906412,  0.10305538,  0.65478757]]), array(0, dtype=int64), 1.0, array([[-0.11642155, -0.59545849,  0.11615113,  0.97806084]]), False)]




[(array([[-0.01279376, -0.938636  ,  0.0979811 ,  1.49870919]]), array(0, dtype=int64), 1.0, array([[-0.03156648, -1.13480244,  0.12795529,  1.82030791]]), False)]




[(array([[-0.09974993, -0.77542666,  0.12118518,  1.25825177]]), array(0, dtype=int64), 1.0, array([[-0.11525846, -0.97187283,  0.14635021,  1.58630123]]), False)]




[(array([[-0.06410741, -1.01313957,  0.06784579,  1.43150902]]), array(0, dtype=int64), 1.0, array([[-0.0843702 , -1.20903013,  0.09647597,  1.74460022]]), False)]




[(array([[-0.02916851, -0.39462968, -0.01034011,  0.55398572]]), 1, 1.0, array([[-0.0370611 , -0.19936407,  0.0007396 ,  0.25806303]]), False)]




[(array([[-0.10844027, -0.39906412,  0.10305538,  0.65478757]]), array(0, dtype=int64), 1.0, array([[-0.11642155, -0.59545849,  0.11615113,  0.97806084]]), False)]




[(array([[-0.0370611 , -0.19936407,  0.0007396 ,  0.25806303]]), 1, 1.0, array([[-0.04104838, -0.00425268,  0.00590087, -0.03438653]]), False)]




[(array([[-0.05334093, -0.57768114,  0.04857063,  0.90095015]]), 0, 1.0, array([[-0.06489456, -0.77342636,  0.06658963,  1.20849586]]), False)]




[(array([[-0.01704529, -0.15059256, -0.03252903,  0.2578068 ]]), array(0, dtype=int64), 1.0, array([[-0.02005714, -0.34523537, -0.02737289,  0.54005471]]), False)]




[(array([-0.00064894, -0.00067757, -0.04410509, -0.02751637]), array(0, dtype=int64), 1.0, array([[-0.00066249, -0.19514018, -0.04465542,  0.25093119]]), False)]




[(array([[-0.12833072, -0.79193004,  0.13571234,  1.30485414]]), 0, 1.0, array([[-0.14416932, -0.98848649,  0.16180943,  1.63675493]]), False)]




[(array([[-0.14175393, -1.58305711,  0.12575804,  2.34828353]]), array(0, dtype=int64), 1.0, array([[-0.17341507, -1.77906324,  0.17272371,  2.6768464 ]]), False)]




[(array([[-0.14252298, -0.93337063,  0.15943937,  1.49453295]]), array(0, dtype=int64), 1.0, array([[-0.16119039, -1.13003194,  0.18933003,  1.83246018]]), False)]




[(array([[-0.06410741, -1.01313957,  0.06784579,  1.43150902]]), array(0, dtype=int64), 1.0, array([[-0.0843702 , -1.20903013,  0.09647597,  1.74460022]]), False)]




[(array([[-0.04104838, -0.00425268,  0.00590087, -0.03438653]]), array(0, dtype=int64), 1.0, array([[-0.04113344, -0.19945876,  0.00521313,  0.26015233]]), False)]




[(array([[-0.04920447, -0.3949418 ,  0.01509993,  0.5608266 ]]), array(0, dtype=int64), 1.0, array([[-0.05710331, -0.59027238,  0.02631646,  0.85822828]]), False)]




[(array([[-0.09018059, -1.19150132,  0.0506981 ,  1.72248941]]), array(0, dtype=int64), 1.0, array([[-0.11401062, -1.38716566,  0.08514789,  2.03050766]]), False)]




[(array([[-5.42384701e-02, -8.00994401e-01, -3.68127785e-04,
         1.13037193e+00]]), array(0, dtype=int64), 1.0, array([[-0.07025836, -0.99611153,  0.02223931,  1.42293937]]), False)]




[(array([[-0.12697034, -1.16895123,  0.20592531,  1.93480025]]), array(1, dtype=int64), -100, array([[-0.15034937, -0.97654162,  0.24462132,  1.71239869]]), True)]




[(array([[-0.02309778, -0.92585185,  0.03436635,  1.45934805]]), array(0, dtype=int64), 1.0, array([[-0.04161481, -1.12137799,  0.06355331,  1.76256581]]), False)]




[(array([[-0.0045653 , -0.38959694, -0.03963679,  0.52920123]]), array(0, dtype=int64), 1.0, array([[-0.01235724, -0.58413952, -0.02905277,  0.80913572]]), False)]




[(array([[-0.12833072, -0.79193004,  0.13571234,  1.30485414]]), 0, 1.0, array([[-0.14416932, -0.98848649,  0.16180943,  1.63675493]]), False)]




[(array([[-0.09018059, -1.19150132,  0.0506981 ,  1.72248941]]), array(0, dtype=int64), 1.0, array([[-0.11401062, -1.38716566,  0.08514789,  2.03050766]]), False)]




[(array([[ 4.67285289e-03, -3.82343143e-01, -4.44854855e-04,
         5.95267921e-01]]), array(0, dtype=int64), 1.0, array([[-0.00297401, -0.57745887,  0.0114605 ,  0.88781069]]), False)]




[(array([[-3.77610907e-02, -7.34853454e-01, -9.20245434e-05,
         1.11141376e+00]]), array(0, dtype=int64), 1.0, array([[-0.05245816, -0.9299742 ,  0.02213625,  1.40406782]]), False)]