# In [None]:
import tensorflow as tf

from model.resnet import ResNet
from model.resnet import data_augmentation

import numpy as np
import random

def load_cifar10(num_test=10000):
    """Load CIFAR-10 through Keras, normalize it, and flatten the labels.

    Both splits are standardized with statistics computed on the training
    split only (see ``normalize``). Labels are flattened to 1-D arrays and
    kept as integer class indices (no one-hot encoding).

    Args:
        num_test: Keep only the first ``num_test`` test examples. The default
            of 10000 equals the size of the CIFAR-10 test split (per the Keras
            dataset docs), so by default nothing is dropped. ``None`` keeps
            every test example.

    Returns:
        Tuple ``(train_data, train_labels, test_data, test_labels)``.
    """
    (train_data, train_labels), (test_data, test_labels) = tf.keras.datasets.cifar10.load_data()
    train_data, test_data = normalize(train_data, test_data)
    train_labels = train_labels.flatten()
    test_labels = test_labels.flatten()

    # Optionally restrict evaluation to a prefix of the test split;
    # slicing with None keeps the whole array.
    test_data = test_data[:num_test]
    test_labels = test_labels[:num_test]
    return train_data, train_labels, test_data, test_labels

def normalize(X_train, X_test):
    """Standardize train and test data using training-set statistics.

    A single scalar mean and standard deviation are computed over every
    element of ``X_train`` (all samples, pixels and channels together) and
    applied to both arrays, so the test data never influences the statistics.

    Args:
        X_train: Training data; numeric array-like of any rank.
        X_test: Test data, scaled with the training statistics.

    Returns:
        Tuple ``(X_train, X_test)`` of float arrays; the training array ends
        up with zero mean and unit standard deviation.

    If ``X_train`` is constant (standard deviation 0), only the mean is
    subtracted instead of dividing by zero and producing NaN/inf.
    """
    X_train = np.asarray(X_train)
    X_test = np.asarray(X_test)
    # axis=None reduces over all elements: identical to an explicit
    # axis=(0, 1, 2, 3) for 4-D input, but also valid for any other rank.
    mean = np.mean(X_train)
    std = np.std(X_train)
    if std == 0:
        # Avoid 0/0 on constant data; the output is then just centered.
        std = 1.0
    return (X_train - mean) / std, (X_test - mean) / std

# Training hyper-parameters.
epoch = 82        # total number of training epochs
batchSize = 256
resN = 18         # depth argument passed to ResNet(...)
lr = 0.1          # initial learning rate; divided by 10 at 50% and 75% of training

train_x, train_y, test_x, test_y = load_cifar10()
# Number of full mini-batches per epoch; the trailing partial batch is dropped.
iteration = len(train_x) // batchSize

# Epoch indices at which the learning rate is divided by 10.
lr_decay_epochs = {int(epoch * 0.5), int(epoch * 0.75)}

config = tf.ConfigProto()
config.gpu_options.allow_growth = True
with tf.Session(config=config) as sess:
    resnet = ResNet(resN)
    loss, y_hat = resnet.build_model()
    train_op = tf.train.MomentumOptimizer(resnet.lr, momentum=0.9).minimize(loss)
    accuracy = tf.reduce_mean(tf.cast(tf.equal(y_hat, resnet.test_y), tf.float32))

    tf.global_variables_initializer().run()

    # The test set never changes, so its feed dict is built once.
    test_feed_dict = {
        resnet.test_x: test_x,
        resnet.test_y: test_y,
    }

    # Use a distinct loop variable: reusing `epoch` here would make the decay
    # check compare the index with a fraction of itself (true only at 0),
    # cutting lr immediately and never at the intended epochs.
    for cur_epoch in range(epoch):
        if cur_epoch in lr_decay_epochs:
            lr *= 0.1

        # Reshuffle every epoch so mini-batches differ between epochs.
        order = np.random.permutation(len(train_x))

        for idx in range(iteration):
            batch_idx = order[idx * batchSize:(idx + 1) * batchSize]
            # Fancy indexing copies only this batch, not the whole training set.
            # 32 is passed through unchanged; presumably the image size -- confirm
            # against data_augmentation.
            batch_x = data_augmentation(train_x[batch_idx], 32)
            batch_y = train_y[batch_idx]

            train_feed_dict = {
                resnet.train_x: batch_x,
                resnet.train_y: batch_y,
                resnet.lr: lr,
            }

            # update network
            _, _loss = sess.run([train_op, loss], feed_dict=train_feed_dict)

            # display training status
            print("Epoch: [%2d] [%5d/%5d] , loss: %.2f"
                  % (cur_epoch, idx, iteration, _loss))

        # Evaluate on the full test set once per epoch rather than after every
        # step, which cost a 10k-image forward pass per 256-image update.
        _accuracy = sess.run(accuracy, feed_dict=test_feed_dict)
        print("Epoch: [%2d] test acc: %.4f" % (cur_epoch, _accuracy))