# In [None]:
# -*- coding: utf-8 -*-
# https://github.com/openai/gym/wiki/CartPole-v0
import tensorflow as tf
import gym
import numpy as np
import random
import matplotlib.pyplot as plt
from collections import deque

from skimage.transform import resize
from skimage.color import rgb2gray



# Hyperparameters

max_episodes = 10000
discount_factor = 0.9

# Bookkeeping lists; each one starts out as its own empty list.
episode_list, train_error_list, actions_list = [], [], []

# Input frame height and width used for the network's input placeholder.
HEIGHT, WIDTH = 210, 160

# Test-episode period
TEST_PERIOD = 100

# Period for copying trainable variables from the src model to the target model
COPY_PERIOD = 10



# Network class definition
class DQN:
    """Convolutional Q-network built with the TensorFlow 1.x graph API.

    The graph consists of three VALID-padded convolution layers followed by
    two fully connected layers and outputs one Q-value per action.

    Args:
        session: TensorFlow session used to run the graph.
        height: Height of the single-channel input frame.
        width: Width of the single-channel input frame.
        output_size: Number of actions (size of the Q-value output).
        name: Variable scope name under which the network is created.
        learning_rate: Step size handed to the Adam optimizer. Defaults to
            0.001, which is Adam's own documented default.

    Raises:
        ValueError: If the frame is too small for the convolution stack.
    """

    # (kernel_height, kernel_width, stride) for each convolution layer.
    _CONV_LAYERS = ((6, 4, 4), (4, 4, 2), (5, 3, 2))
    # Number of output channels of the last convolution layer.
    _LAST_CONV_CHANNELS = 64

    def __init__(self, session, height, width, output_size, name="main",
                 learning_rate=0.001):
        # Store the network configuration.
        self.session = session
        self.height = height
        self.width = width
        self.output_size = output_size
        self.net_name = name
        self.learning_rate = learning_rate

        # Build the graph.
        self.build_network()

    @staticmethod
    def _valid_conv_size(size, kernel, stride):
        """Return the output length of a VALID-padded convolution on one axis."""
        return (size - kernel) // stride + 1

    def _flat_feature_count(self):
        """Return how many features the last convolution layer produces.

        Raises:
            ValueError: If any convolution layer would produce an empty output.
        """
        rows, cols = self.height, self.width
        for kernel_h, kernel_w, stride in self._CONV_LAYERS:
            rows = self._valid_conv_size(rows, kernel_h, stride)
            cols = self._valid_conv_size(cols, kernel_w, stride)
            if rows < 1 or cols < 1:
                raise ValueError(
                    "input of size {}x{} is too small for the convolution "
                    "stack".format(self.height, self.width)
                )
        return rows * cols * self._LAST_CONV_CHANNELS

    def build_network(self):
        """Create placeholders, weights, the Q-value output, loss and optimizer."""
        # Computed up front so that W_fc1 can still be created right after the
        # convolution biases (equals 11*9*64 for a 210x160 frame).
        flat_features = self._flat_feature_count()

        with tf.variable_scope(self.net_name):
            # Convolutional neural network (3 conv layers, 2 FC layers).
            self.X = tf.placeholder(shape=[None, self.height, self.width, 1], dtype=tf.float32)
            self.Y = tf.placeholder(shape=[None], dtype=tf.float32)

            # NOTE: the variables are unnamed, so TensorFlow auto-numbers them
            # in creation order; keep this order so that existing checkpoints
            # still match.
            W_conv1 = tf.Variable(tf.truncated_normal([6, 4, 1, 32], stddev=0.1))
            W_conv2 = tf.Variable(tf.truncated_normal([4, 4, 32, 64], stddev=0.1))
            W_conv3 = tf.Variable(tf.truncated_normal([5, 3, 64, 64], stddev=0.1))
            b_conv1 = tf.Variable(tf.constant(0.1, shape=[32]))
            b_conv2 = tf.Variable(tf.constant(0.1, shape=[64]))
            b_conv3 = tf.Variable(tf.constant(0.1, shape=[64]))

            W_fc1 = tf.Variable(tf.truncated_normal([flat_features, 512], stddev=0.1))
            b_fc1 = tf.Variable(tf.constant(0.1, shape=[512]))
            W_fc2 = tf.Variable(tf.truncated_normal([512, self.output_size], stddev=0.1))
            b_fc2 = tf.Variable(tf.constant(0.1, shape=[self.output_size]))

            h_conv1 = tf.nn.relu(tf.nn.conv2d(self.X, W_conv1, strides=[1, 4, 4, 1], padding='VALID') + b_conv1)
            h_conv2 = tf.nn.relu(tf.nn.conv2d(h_conv1, W_conv2, strides=[1, 2, 2, 1], padding='VALID') + b_conv2)
            h_conv3 = tf.nn.relu(tf.nn.conv2d(h_conv2, W_conv3, strides=[1, 2, 2, 1], padding='VALID') + b_conv3)

            L1 = tf.reshape(h_conv3, [-1, flat_features])
            L2 = tf.nn.relu(tf.matmul(L1, W_fc1) + b_fc1)
            self.Qpred = tf.matmul(L2, W_fc2) + b_fc2

        # Loss and optimizer. The predicted Q-values are weighted by
        # `self.action` (presumably a one-hot action mask - confirm against the
        # training code) and summed per sample before being compared with the
        # target `self.Y`.
        self.action = tf.placeholder(shape=[None, self.output_size], dtype=tf.float32)
        Q_action = tf.reduce_sum(tf.multiply(self.Qpred, self.action), axis=1)
        self.loss = tf.reduce_mean(tf.square(self.Y - Q_action))
        self.optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate).minimize(self.loss)

    def predict(self, state):
        """Return the predicted Q-values for one or more frames.

        Args:
            state: Array-like whose elements can be reshaped to
                ``[-1, height, width, 1]``.

        Returns:
            The result of evaluating the Q-value tensor for the reshaped input.
        """
        x = np.reshape(state, [-1, self.height, self.width, 1])
        return self.session.run(self.Qpred, feed_dict={self.X: x})







def bot_play(DQN, env):
    """Play one episode with the trained network and return the total reward.

    Every frame is collapsed to two dimensions by summing over axis 2 before
    it is handed to ``DQN.predict``; the action with the highest predicted
    value is taken at each step until the environment reports completion.

    Args:
        DQN: Object exposing ``predict(state)``.
        env: Environment exposing ``reset()`` and ``step(action)``, where
            ``step`` returns ``(observation, reward, done, info)``.

    Returns:
        Sum of the rewards collected during the episode.
    """
    total_reward = 0
    frame = np.sum(env.reset(), axis=2)
    while True:
        q_values = DQN.predict(frame)
        observation, reward, finished, _ = env.step(np.argmax(q_values))
        total_reward += reward
        frame = np.sum(observation, axis=2)
        if finished:
            break

    return total_reward



    
def restoreModel(session, path='./breakout.ckpt'):
    """Restore saved variables from a checkpoint into the given session.

    Args:
        session: TensorFlow session that receives the restored values.
        path: Checkpoint path passed to the saver.
    """
    saver = tf.train.Saver()
    saver.restore(sess=session, save_path=path)
    print("Model restored successfully.")


if __name__ == "__main__":
    # Restore a trained network and let it play `max_episodes` episodes,
    # printing the total reward of each one.
    env = gym.make('Breakout-v0')
    height = HEIGHT
    width = WIDTH
    # Size of the discrete action space reported by the environment.
    output_size = env.action_space.n

    # Mini-batch size - number of replay samples drawn at once.
    BATCH_SIZE = 32

    try:
        with tf.Session() as sess:
            # Create the main DQN instance.
            mainDQN = DQN(sess, height, width, output_size, name='main')
            # NOTE(review): the checkpoint is called "cartpole" although the
            # environment is Breakout - confirm this is the intended file.
            restoreModel(sess, "./cartpole.ckpt")

            for episode in range(max_episodes):
                episode_reward = bot_play(mainDQN, env)
                print("Episode {}: reward {}".format(episode, episode_reward))
    finally:
        # Release the environment even when restoring or playing fails.
        env.close()