# 카트폴 게임 마스터하기

In [3]:
import gym
import random
import math
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from collections import deque

### 하이퍼파라미터



In [4]:
# 하이퍼파라미터
EPISODES = 50    # 애피소드 반복횟수
EPS_START = 0.9  # 학습 시작시 에이전트가 무작위로 행동할 확률
EPS_END = 0.05   # 학습 막바지에 에이전트가 무작위로 행동할 확률
EPS_DECAY = 200  # 학습 진행시 에이전트가 무작위로 행동할 확률을 감소시키는 값
GAMMA = 0.8      # 할인계수
LR = 0.001       # 학습률
BATCH_SIZE = 64  # 배치 크기

## DQN 에이전트

In [5]:
class DQNAgent:
    def __init__(self):
        self.model = nn.Sequential(
            nn.Linear(4, 256),
            nn.ReLU(),
            nn.Linear(256, 2)
        )
        self.optimizer = optim.Adam(self.model.parameters(), LR)
        self.steps_done = 0
        self.memory = deque(maxlen=10000)

    def memorize(self, state, action, reward, next_state):
        self.memory.append((state,
                            action,
                            torch.FloatTensor([reward]),
                            torch.FloatTensor([next_state])))
    
    def act(self, state):
        eps_threshold = EPS_END + (EPS_START - EPS_END) * math.exp(-1. * self.steps_done / EPS_DECAY)
        self.steps_done += 1
        if random.random() > eps_threshold:
            return self.model(state).data.max(1)[1].view(1, 1)
        else:
            return torch.LongTensor([[random.randrange(2)]])
    
    def learn(self):
        """Experience Replay"""
        if len(self.memory) < BATCH_SIZE:
            return
        batch = random.sample(self.memory, BATCH_SIZE)
        states, actions, rewards, next_states = zip(*batch)

        states = torch.cat(states)
        actions = torch.cat(actions)
        rewards = torch.cat(rewards)
        next_states = torch.cat(next_states)

        print()
        print("=============")
        print("states:")
        print(states)
        print(states.size())
        
        model_output = self.model(states)
        print("model_output")
        print(model_output)
        print(model_output.size())
        
        print("actions")
        print(actions)
        print(actions.size())
        
        current_q = model_output.gather(1, actions)
        
        print("current_q")
        print(current_q)
        print(current_q.size())
        
        max_next_q = self.model(next_states).detach().max(1)[0]
        expected_q = rewards + (GAMMA * max_next_q)
        
        loss = F.mse_loss(current_q.squeeze(), expected_q)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

## 학습 준비하기

`gym`을 이용하여 `CartPole-v0`환경을 준비하고 앞서 만들어둔 DQNAgent를 agent로 인스턴스화 합니다.

자, 이제 `agent` 객체를 이용하여 `CartPole-v0` 환경과 상호작용을 통해 게임을 배우도록 하겠습니다.

In [6]:
env = gym.make('CartPole-v0')
agent = DQNAgent()
score_history = []

## 학습 시작

In [7]:
for e in range(1, EPISODES+1):
    state = env.reset()
    steps = 0
    while True:
        env.render()
        state = torch.FloatTensor([state])
        action = agent.act(state)
        next_state, reward, done, _ = env.step(action.item())

        # negative reward when attempt ends
        if done:
            reward = -1

        agent.memorize(state, action, reward, next_state)
        agent.learn()

        state = next_state
        steps += 1

        if done:
            print("에피소드:{0} 점수: {1}".format(e, steps))
            score_history.append(steps)
            break

에피소드:1 점수: 16
에피소드:2 점수: 14
에피소드:3 점수: 16
에피소드:4 점수: 12

states:
tensor([[ 3.8397e-02, -2.8788e-02,  1.3879e-02,  1.1811e-02],
        [-4.4595e-02, -6.1357e-01,  1.3354e-02,  8.4660e-01],
        [-7.1810e-02, -6.0002e-01,  7.2500e-02,  8.7961e-01],
        [-2.3952e-02, -6.1374e-01, -1.4719e-02,  8.5047e-01],
        [ 1.6865e-02, -3.7205e-01,  3.1165e-02,  6.2646e-01],
        [-6.6591e-02,  6.7299e-03,  1.7847e-01,  2.9823e-01],
        [ 3.7821e-02,  1.6613e-01,  1.4115e-02, -2.7646e-01],
        [-7.0230e-02, -1.9558e-01,  2.0551e-01,  7.5988e-01],
        [ 1.6443e-02, -8.0449e-01,  1.9730e-03,  1.1100e+00],
        [ 9.7426e-03, -6.1492e-01,  5.2397e-02,  9.0760e-01],
        [-1.7194e-02, -5.6901e-01,  8.6967e-02,  9.6234e-01],
        [-3.6226e-02, -4.1842e-01,  2.2901e-03,  5.5319e-01],
        [-6.0691e-02, -8.0812e-01,  1.0840e-01,  1.1957e+00],
        [-5.5921e-02, -7.9447e-01,  4.9372e-02,  1.1564e+00],
        [ 4.6793e-02, -2.6456e-02, -3.5080e-02, -6.6964e-03],
     

tensor([[-8.3810e-02, -4.0596e-01,  9.0092e-02,  6.1057e-01],
        [ 4.1334e-02, -2.2014e-01, -3.0296e-02,  2.5430e-01],
        [-7.0265e-02,  1.7096e-03,  1.9727e-01,  4.1206e-01],
        [-5.5309e-02, -3.7842e-01,  1.5230e-01,  7.7485e-01],
        [-1.3985e-01, -8.0189e-01,  1.7760e-01,  1.3374e+00],
        [-4.4595e-02, -6.1357e-01,  1.3354e-02,  8.4660e-01],
        [-7.6854e-02, -1.0045e+00,  1.3231e-01,  1.5203e+00],
        [-9.6943e-02, -1.2009e+00,  1.6272e-01,  1.8511e+00],
        [-1.1085e-02, -2.2433e-01, -3.1713e-02,  2.8359e-01],
        [ 4.8364e-02,  1.6581e-01, -2.7069e-03, -2.6928e-01],
        [-3.5872e-02, -4.0351e-01,  2.1208e-02,  5.5446e-01],
        [ 2.8635e-02, -6.0957e-01, -1.4465e-02,  8.2191e-01],
        [-6.6591e-02,  6.7299e-03,  1.7847e-01,  2.9823e-01],
        [-1.2096e-01, -1.3974e+00,  1.9974e-01,  2.1896e+00],
        [-3.2332e-02, -8.0606e-01,  6.7954e-02,  1.1465e+00],
        [-6.7542e-02, -8.1598e-01,  1.7102e-01,  1.3472e+00],
        

tensor([[0.8932],
        [1.5801],
        [1.4630],
        [1.1028],
        [1.3740],
        [1.5888],
        [0.9677],
        [0.8950],
        [0.7752],
        [0.9752],
        [0.9477],
        [0.9457],
        [1.0875],
        [1.3674],
        [1.5881],
        [0.9617],
        [0.8964],
        [1.8905],
        [1.1605],
        [0.9607],
        [1.1714],
        [2.2082],
        [1.7068],
        [1.3461],
        [1.8695],
        [1.0634],
        [1.6549],
        [0.8910],
        [0.7529],
        [1.3508],
        [1.0963],
        [1.4737],
        [0.9887],
        [1.2987],
        [1.0037],
        [1.1116],
        [1.5777],
        [0.7556],
        [1.7706],
        [0.9369],
        [1.3191],
        [1.3091],
        [1.3244],
        [0.9727],
        [1.3010],
        [1.3499],
        [1.4249],
        [1.2763],
        [1.3528],
        [1.3506],
        [0.8846],
        [1.1321],
        [1.8382],
        [2.0285],
        [0.8895],
        [1


states:
tensor([[ 4.1843e-02, -2.5452e-02, -2.9719e-02, -2.8863e-02],
        [ 4.1143e-02,  3.6105e-01,  8.5862e-03, -5.6466e-01],
        [ 3.8224e-02, -6.1439e-01,  1.0670e-02,  8.9517e-01],
        [-8.3810e-02, -4.0596e-01,  9.0092e-02,  6.1057e-01],
        [-1.1835e-02, -4.1461e-01,  4.0532e-02,  5.3296e-01],
        [ 1.4526e-02, -2.2259e-01, -4.2436e-02,  2.8622e-01],
        [-3.6226e-02, -4.1842e-01,  2.2901e-03,  5.5319e-01],
        [-4.4595e-02, -6.1357e-01,  1.3354e-02,  8.4660e-01],
        [-2.5475e-02, -1.9412e-01,  3.1507e-02,  3.7158e-01],
        [ 2.0397e-02, -1.7659e-01,  2.4642e-02,  3.2611e-01],
        [ 1.7322e-03, -2.2147e-01, -2.5407e-02,  2.6120e-01],
        [-1.0397e-01, -7.9856e-01,  1.2091e-01,  1.2532e+00],
        [-8.4898e-02, -8.0845e-01,  9.0515e-02,  1.1784e+00],
        [-6.8444e-02, -5.8846e-01,  1.0907e-01,  1.0525e+00],
        [-3.5971e-02, -4.1583e-01,  2.2956e-02,  5.3719e-01],
        [ 2.5936e-02, -8.0966e-01,  2.8574e-02,  1.1912e+00],

tensor([[-4.7358e-02, -1.0092e+00,  1.3915e-01,  1.5935e+00],
        [ 2.8635e-02, -6.0957e-01, -1.4465e-02,  8.2191e-01],
        [-6.2877e-02, -1.8568e-01,  1.6780e-01,  5.3370e-01],
        [-1.2116e-01, -1.2009e+00,  1.4405e-01,  1.8241e+00],
        [-4.4288e-02, -6.1127e-01,  3.3700e-02,  8.3702e-01],
        [-1.6130e-01, -1.0091e+00,  1.8322e-01,  1.5714e+00],
        [-1.0397e-01, -7.9856e-01,  1.2091e-01,  1.2532e+00],
        [ 4.1334e-02, -2.2014e-01, -3.0296e-02,  2.5430e-01],
        [-2.9492e-02,  2.4044e-03,  3.3836e-02,  4.7914e-02],
        [ 3.8397e-02, -2.8788e-02,  1.3879e-02,  1.1811e-02],
        [ 5.1681e-02, -2.9275e-02, -8.0926e-03,  2.2545e-02],
        [-9.6943e-02, -1.2009e+00,  1.6272e-01,  1.8511e+00],
        [-6.0691e-02, -8.0812e-01,  1.0840e-01,  1.1957e+00],
        [-8.0214e-02, -7.8485e-01,  1.3011e-01,  1.3773e+00],
        [-3.1453e-02, -1.2799e-02,  1.6897e-02, -4.1216e-02],
        [-7.0230e-02, -1.9558e-01,  2.0551e-01,  7.5988e-01],
        


states:
tensor([[ 4.6264e-02, -2.2106e-01, -3.5214e-02,  2.7472e-01],
        [-1.5714e-02, -6.2009e-01,  1.2338e-01,  1.0539e+00],
        [-3.5971e-02, -4.1583e-01,  2.2956e-02,  5.3719e-01],
        [-9.3131e-02, -1.2001e+00,  8.2064e-02,  1.7543e+00],
        [-2.8300e-02, -1.6708e-01,  1.7764e-02,  3.0602e-01],
        [ 3.6931e-02, -4.1481e-01, -2.5211e-02,  5.3727e-01],
        [-6.6591e-02,  6.7299e-03,  1.7847e-01,  2.9823e-01],
        [-1.0491e-02, -2.9681e-02, -3.1735e-02,  1.0831e-03],
        [-3.1709e-02, -2.0816e-01,  1.6073e-02,  2.5675e-01],
        [ 4.1454e-02, -2.6366e-02,  3.6392e-02, -1.5212e-02],
        [ 9.7426e-03, -6.1492e-01,  5.2397e-02,  9.0760e-01],
        [-1.6130e-01, -1.0091e+00,  1.8322e-01,  1.5714e+00],
        [-1.1085e-02, -2.2433e-01, -3.1713e-02,  2.8359e-01],
        [-4.8454e-02, -6.1189e-01,  9.0883e-02,  8.7584e-01],
        [ 4.0927e-02,  1.6822e-01,  3.6088e-02, -2.9619e-01],
        [-1.2116e-01, -1.2009e+00,  1.4405e-01,  1.8241e+00],

tensor([[ 4.1334e-02, -2.2014e-01, -3.0296e-02,  2.5430e-01],
        [-1.2825e-01, -8.1035e-01,  1.8242e-01,  1.4571e+00],
        [ 2.1324e-02, -2.3685e-01, -2.5680e-02,  2.4697e-01],
        [-2.8866e-02,  2.8295e-02,  1.7608e-02,  7.8310e-03],
        [-7.2410e-03, -4.2364e-01,  1.0880e-01,  7.2905e-01],
        [-2.9492e-02,  2.4044e-03,  3.3836e-02,  4.7914e-02],
        [-5.2795e-02, -7.8245e-01,  8.2702e-02,  1.3182e+00],
        [-4.1064e-02, -5.8655e-01,  6.2571e-02,  1.0065e+00],
        [-4.8454e-02, -6.1189e-01,  9.0883e-02,  8.7584e-01],
        [-9.4800e-02, -1.1469e+00,  1.2921e-01,  1.8745e+00],
        [-2.0127e-02, -6.1028e-01,  5.1191e-02,  8.3813e-01],
        [-4.9509e-02, -5.5897e-01,  5.4813e-02,  9.2870e-01],
        [-4.7358e-02, -1.0092e+00,  1.3915e-01,  1.5935e+00],
        [ 2.2166e-02, -4.2092e-02, -2.4925e-02, -3.7745e-02],
        [ 2.0019e-02,  1.8874e-02,  2.4124e-02,  2.5916e-02],
        [ 4.0927e-02,  1.6822e-01,  3.6088e-02, -2.9619e-01],
        

tensor([[ 4.1454e-02, -2.6366e-02,  3.6392e-02, -1.5212e-02],
        [-7.8123e-02, -1.2137e+00,  1.0733e-01,  1.7440e+00],
        [-9.5272e-02, -1.0013e+00,  9.9721e-03,  1.2118e+00],
        [-4.5728e-03, -8.2140e-01,  6.2385e-03,  1.1070e+00],
        [-2.8574e-02, -7.6518e-01,  1.0621e-01,  1.2810e+00],
        [-1.1085e-02, -2.2433e-01, -3.1713e-02,  2.8359e-01],
        [ 4.6793e-02, -2.6456e-02, -3.5080e-02, -6.6964e-03],
        [-1.0325e-01, -2.2000e-01,  1.3531e-01,  4.5229e-01],
        [ 2.2166e-02, -4.2092e-02, -2.4925e-02, -3.7745e-02],
        [-2.8300e-02, -1.6708e-01,  1.7764e-02,  3.0602e-01],
        [ 3.6931e-02, -4.1481e-01, -2.5211e-02,  5.3727e-01],
        [-9.5911e-02, -9.8134e-01,  1.5766e-01,  1.7077e+00],
        [-7.9147e-02, -8.0628e-01, -8.4642e-03,  9.2182e-01],
        [-6.4712e-02, -1.2098e+00,  2.0645e-01,  2.0577e+00],
        [-1.5438e-02, -4.1581e-01, -4.3320e-03,  5.3654e-01],
        [-2.2296e-02, -3.4795e-02, -5.6015e-02, -5.0613e-02],
        

tensor([[ 1.6230e-02, -4.2855e-01,  5.1071e-02,  5.9323e-01],
        [ 4.1143e-02,  3.6105e-01,  8.5862e-03, -5.6466e-01],
        [ 1.3536e-02, -4.2138e-01,  7.5407e-02,  6.7708e-01],
        [-6.8444e-02, -5.8846e-01,  1.0907e-01,  1.0525e+00],
        [-4.1367e-03, -1.6216e-01,  3.9948e-02,  3.1622e-01],
        [-3.3249e-02, -3.9079e-01,  4.8593e-02,  6.9894e-01],
        [-1.5588e-01, -6.0939e-01,  2.0434e-01,  1.1051e+00],
        [-8.3810e-02, -4.0596e-01,  9.0092e-02,  6.1057e-01],
        [-7.5784e-02, -9.5077e-01,  9.8149e-02,  1.5528e+00],
        [-4.1064e-02, -5.8655e-01,  6.2571e-02,  1.0065e+00],
        [-2.5558e-03, -8.1071e-01,  7.0549e-02,  1.2163e+00],
        [ 3.8224e-02, -6.1439e-01,  1.0670e-02,  8.9517e-01],
        [-1.3059e-01, -1.2164e+00,  1.8357e-01,  1.8225e+00],
        [-4.3877e-02, -5.7156e-01,  1.3183e-01,  1.0234e+00],
        [-1.5702e-02, -3.5735e-01,  2.3585e-02,  6.3693e-01],
        [-4.5728e-03, -8.2140e-01,  6.2385e-03,  1.1070e+00],
        

tensor([[10.6915,  8.9379],
        [ 8.5971,  7.1886],
        [ 6.1280,  5.2283],
        [ 8.4101,  7.0040],
        [12.2576, 10.2902],
        [ 9.2044,  7.6871],
        [ 6.4995,  5.4453],
        [ 5.9690,  5.1022],
        [ 7.2625,  6.0335],
        [ 8.1635,  6.8016],
        [11.8004,  9.8760],
        [ 9.3059,  7.7675],
        [ 6.0657,  5.2288],
        [ 6.4106,  5.3610],
        [ 9.7606,  8.1477],
        [ 7.3049,  6.0620],
        [ 5.8343,  5.1590],
        [ 7.5298,  6.2772],
        [ 9.0471,  7.5338],
        [ 6.4792,  5.4249],
        [ 7.3384,  6.1112],
        [10.0443,  8.3981],
        [ 7.3301,  6.0922],
        [ 7.3622,  6.1162],
        [ 6.8861,  5.7511],
        [ 6.0138,  5.1755],
        [ 8.0687,  6.7097],
        [ 5.8648,  5.0628],
        [ 7.6398,  6.3714],
        [ 8.7043,  7.2656],
        [ 7.1135,  5.9544],
        [ 6.4716,  5.4019],
        [ 7.6902,  6.4179],
        [ 6.5395,  5.4857],
        [ 9.2789,  7.7656],
        [ 7.1750,  5

tensor([[-1.5438e-02, -4.1581e-01, -4.3320e-03,  5.3654e-01],
        [-4.5464e-02,  1.7780e-01,  4.6511e-02, -3.0395e-01],
        [-2.5339e-02, -5.5763e-01,  6.3445e-02,  9.5897e-01],
        [ 7.9555e-03, -6.2642e-01, -1.0112e-02,  8.1752e-01],
        [ 2.1324e-02, -2.3685e-01, -2.5680e-02,  2.4697e-01],
        [-6.8444e-02, -5.8846e-01,  1.0907e-01,  1.0525e+00],
        [-1.8770e-02, -6.1657e-01,  9.4875e-02,  9.4651e-01],
        [-5.5921e-02, -7.9447e-01,  4.9372e-02,  1.1564e+00],
        [-1.9771e-02, -3.3184e-02, -6.4235e-02, -8.6259e-02],
        [-2.9346e-02, -1.9511e-01,  4.0718e-02,  3.9370e-01],
        [-4.8454e-02, -6.1189e-01,  9.0883e-02,  8.7584e-01],
        [-1.7804e-02, -3.5929e-01,  6.5566e-02,  6.5407e-01],
        [-1.8103e-02, -3.6184e-01,  5.0428e-02,  6.5085e-01],
        [ 4.1143e-02,  3.6105e-01,  8.5862e-03, -5.6466e-01],
        [ 9.4238e-03, -5.6759e-01,  4.3694e-02,  9.2880e-01],
        [-2.9357e-02,  5.4387e-04,  3.8939e-02,  8.8994e-02],
        

model_output
tensor([[4.2965, 4.5779],
        [4.7338, 4.7008],
        [5.3524, 5.7138],
        [4.9440, 4.8463],
        [4.3173, 4.4756],
        [5.4692, 5.8432],
        [4.6946, 4.6986],
        [4.9732, 4.8876],
        [5.1737, 5.3233],
        [5.0204, 4.9503],
        [5.0923, 5.0210],
        [4.2698, 4.5020],
        [4.5818, 4.6087],
        [5.4331, 5.8042],
        [4.3143, 4.4355],
        [4.9952, 4.8949],
        [4.9803, 4.9408],
        [5.2256, 5.3532],
        [4.4024, 4.5197],
        [4.8554, 4.7820],
        [4.9270, 4.8266],
        [4.3261, 4.6977],
        [5.1393, 5.2723],
        [4.6931, 4.7163],
        [4.2487, 4.4339],
        [4.8222, 4.7537],
        [4.4576, 4.5177],
        [4.7890, 4.7537],
        [4.3237, 4.4710],
        [5.0872, 5.0486],
        [4.2684, 4.6817],
        [4.6739, 4.6935],
        [5.0990, 5.0396],
        [4.3988, 4.6498],
        [4.9735, 4.8844],
        [5.0595, 4.9841],
        [4.9256, 4.8670],
        [4.8230, 4.7826],

tensor([[ 3.9284e-02, -4.1848e-01,  3.6511e-02,  6.1178e-01],
        [ 1.5947e-03, -1.9246e-01, -4.0182e-02,  2.6086e-01],
        [ 1.1715e-02, -8.0953e-01,  6.4320e-02,  1.2447e+00],
        [ 4.6264e-02, -2.2106e-01, -3.5214e-02,  2.7472e-01],
        [ 9.4238e-03, -5.6759e-01,  4.3694e-02,  9.2880e-01],
        [-1.5438e-02, -4.1581e-01, -4.3320e-03,  5.3654e-01],
        [-2.2296e-02, -3.4795e-02, -5.6015e-02, -5.0613e-02],
        [-1.9102e-01, -1.3939e+00,  1.3135e-01,  1.8611e+00],
        [ 1.2163e-02, -5.9909e-01, -2.4583e-02,  7.7748e-01],
        [ 2.0253e-02, -4.0447e-01, -3.4500e-02,  4.9586e-01],
        [-1.2462e-02, -1.6199e-01,  1.6805e-02,  3.3899e-01],
        [-2.7668e-02, -4.0367e-01,  2.7550e-02,  4.7801e-01],
        [ 4.8364e-02,  1.6581e-01, -2.7069e-03, -2.6928e-01],
        [-2.4700e-02, -7.5408e-03, -6.3755e-03,  3.9610e-02],
        [-5.5608e-02, -7.6233e-01,  1.6760e-01,  1.5387e+00],
        [-2.5339e-02, -5.5763e-01,  6.3445e-02,  9.5897e-01],
        

KeyboardInterrupt: 

In [8]:
print(score_history)

[21, 28, 21, 51, 12, 20, 8, 9, 10, 12, 14, 11, 9, 9, 10, 26, 11, 9, 11, 25, 12, 19, 12, 27, 30, 66, 24, 45, 47, 35, 35, 40, 44, 34, 57, 52, 70, 124, 118, 33, 128, 55, 178, 88, 103, 101, 120, 140, 113, 85]


In [None]:
import matplotlib