In [1]:
from ResNet import *
import cPickle
import numpy as np
import os

## Configuration

In [2]:
# --- Experiment settings -----------------------------------------------------
num_units = 5   # passed to ResNet(...) as its first argument in main()
exp_id = 0      # only used to name the log file below
epoch = 400     # number of passes main() makes over the training images

# --- GPU selection -----------------------------------------------------------
gpu_number = 0
os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_number)

# --- CIFAR-10 input files (all resolved relative to cifar_dir) ---------------
cifar_dir = '/scratch/f1fan/ResNet/data/cifar-10-batches-py'
train_files = [
    'data_batch_1', 'data_batch_2', 'data_batch_3', 'data_batch_4', 'data_batch_5',
]
test_file = 'test_batch'

# --- Log file ----------------------------------------------------------------
# Opened once here and shared by log(); mode 'w+' truncates an existing log
# that has the same exp_id.
log_file_path = os.path.join('/scratch/f1fan/ResNet', 'log_exp{}.txt'.format(exp_id))
log_file = open(log_file_path, 'w+')

## Log

In [3]:
def log(line):
    """Record one message in the experiment log file and echo it to stdout.

    Args:
        line: Text to record. It is written to the module-level `log_file`
            followed by a newline, and then printed.
    """
    log_file.write(line)
    log_file.write('\n')
    # Flush after every message so the file on disk stays current during long runs.
    log_file.flush()
    # Function-call form: prints the same text under Python 2 (single argument)
    # and, unlike the bare `print line` statement, also parses under Python 3.
    print(line)

## Dataset Functions

In [4]:
def load_train_data():
    """Load every training batch listed in `train_files` from `cifar_dir`.

    Each file is unpickled; its `data` entry is reshaped to
    (count, 3, 32, 32) and transposed to (count, 32, 32, 3), and its `labels`
    entry is appended to a running list.

    Returns:
        A tuple `(images, labels)`. `images` is a float32 array with all
        batches stacked along axis 0; `labels` is an array built from the
        concatenated label lists. When `train_files` is empty, `images` has
        shape (0, 32, 32, 3).
    """
    batches = []
    labels = []

    for data_file in train_files:
        full_path = os.path.join(cifar_dir, data_file)
        # NOTE: unpickling can execute arbitrary code; only read trusted files.
        with open(full_path, 'rb') as f:
            raw = cPickle.load(f)

        data = raw['data']
        # Interpret each flat row as (channel, row, col), then move channels last.
        batches.append(
            np.transpose(data.reshape((data.shape[0], 3, 32, 32)), (0, 2, 3, 1)))
        labels.extend(raw['labels'])

    if not batches:
        # np.concatenate rejects an empty sequence; return a well-shaped empty set.
        return np.empty((0, 32, 32, 3), dtype=np.float32), np.array(labels)

    # A single concatenate replaces the former detour through a Python list
    # holding one small array per image.
    images = np.concatenate(batches, axis=0).astype(np.float32)
    return images, np.array(labels)

In [5]:
def load_test_data():
    """Load the test batch named by `test_file` from `cifar_dir`.

    The file is unpickled; its `data` entry is reshaped to (count, 3, 32, 32)
    and transposed to (count, 32, 32, 3).

    Returns:
        A tuple `(images, labels)`: `images` is a C-contiguous float32 array
        and `labels` is an array built from the file's `labels` entry.
    """
    full_path = os.path.join(cifar_dir, test_file)
    # NOTE: unpickling can execute arbitrary code; only read trusted files.
    with open(full_path, 'rb') as f:
        raw = cPickle.load(f)

    data = raw['data']
    # Interpret each flat row as (channel, row, col), then move channels last.
    batch = np.transpose(data.reshape((data.shape[0], 3, 32, 32)), (0, 2, 3, 1))

    # Convert directly instead of round-tripping through a Python list of
    # per-image arrays; order='C' keeps the contiguous layout the list-based
    # construction used to produce.
    return batch.astype(np.float32, order='C'), np.array(raw['labels'])

In [6]:
def get_per_pixel_mean(train_images, test_images):
    """Return the element-wise mean over the train and test images combined.

    Args:
        train_images: Array of training samples, indexed along axis 0.
        test_images: Array of test samples whose trailing dimensions match
            those of `train_images`.

    Returns:
        The mean along axis 0 of both arrays joined along axis 0, i.e. an
        array shaped like a single sample.
    """
    # NOTE(review): test images contribute to this statistic, so test-set
    # information reaches the training preprocessing -- confirm this is intended.
    combined = np.concatenate([train_images, test_images])
    return combined.mean(axis=0)

## Main function

In [None]:
def _run_batches(sess, loss_op, accuracy_op, image_placeholder, label_placeholder,
                 images, labels, batch_size, train_op=None, order=None):
    """Run one pass over `images`/`labels` in full batches and average the metrics.

    Args:
        sess: Session used to evaluate the ops.
        loss_op: Tensor fetched as the batch loss.
        accuracy_op: Tensor fetched as the batch accuracy.
        image_placeholder: Placeholder fed with each batch of images.
        label_placeholder: Placeholder fed with each batch of labels.
        images: Array of samples indexed along axis 0.
        labels: Array of labels aligned with `images`.
        batch_size: Samples per batch; a trailing partial batch is skipped.
        train_op: Optional op fetched in the same `sess.run` call as the
            metrics, so a batch is pushed through the graph only once.
        order: Optional index array giving the order in which samples are
            visited; `None` visits them sequentially.

    Returns:
        `(mean_loss, mean_accuracy)` over the batches that were run.
    """
    batch_count = images.shape[0] // batch_size
    fetches = [loss_op, accuracy_op]
    if train_op is not None:
        fetches.append(train_op)

    loss_sum = 0.0
    accuracy_sum = 0.0
    for b in range(batch_count):
        window = slice(b * batch_size, (b + 1) * batch_size)
        # Resolve the batch indices once and reuse them for images and labels.
        index = window if order is None else order[window]
        results = sess.run(fetches,
                           feed_dict={image_placeholder: images[index],
                                      label_placeholder: labels[index]})
        loss_sum += results[0]
        accuracy_sum += results[1]

    return loss_sum / batch_count, accuracy_sum / batch_count


def main(sess):
    """Train the ResNet for `epoch` epochs, evaluating on the test set after each.

    Loads the train and test images, subtracts their joint per-pixel mean and
    divides by 128.0, builds the model's train and test ops, logs the names of
    the global variables, initialises them, and then alternates one shuffled
    training pass (batch size 128) with one test pass (batch size 100). The
    averaged loss and accuracy of every pass are written through `log`.

    Training loss and accuracy are fetched in the same `sess.run` call as
    `train_op`, rather than in a second call on the same batch after the update.

    Args:
        sess: TensorFlow session in which the graph is built and run.
    """
    train_images, train_labels = load_train_data()
    test_images, test_labels = load_test_data()
    pp_mean = get_per_pixel_mean(train_images, test_images)
    train_images = (train_images - pp_mean) / 128.0
    test_images = (test_images - pp_mean) / 128.0

    image_shape = [32, 32, 3]
    train_batch_size = 128
    test_batch_size = 100
    model = ResNet(num_units, image_shape, train_batch_size, test_batch_size)
    train_op, train_loss, train_accuracy = model.build_train_op()
    test_loss, test_accuracy = model.build_test_op()

    global_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='')
    log('Global variables:')
    for i, var in enumerate(global_variables):
        log('{0} {1}'.format(i, var.name))

    sess.run(tf.global_variables_initializer())

    for i in range(epoch):
        # Fresh random visiting order for every epoch.
        shuffle = np.random.permutation(train_images.shape[0])
        total_loss, total_accuracy = _run_batches(
            sess, train_loss, train_accuracy,
            model.train_image_placeholder, model.train_label_placeholder,
            train_images, train_labels, train_batch_size,
            train_op=train_op, order=shuffle)

        # presumably the model's step counter -- verify against ResNet
        train_step = sess.run(model.train_step)
        log('Training epoch {0}, step {1}, train_loss {2}, train_accuracy {3}'.format(
            i, train_step, total_loss, total_accuracy))

        total_loss, total_accuracy = _run_batches(
            sess, test_loss, test_accuracy,
            model.test_image_placeholder, model.test_label_placeholder,
            test_images, test_labels, test_batch_size)
        log('Testing after epoch {0}, loss {1}, accuracy {2}'.format(
            i, total_loss, total_accuracy))

In [None]:
# Build a dedicated graph and session, run the experiment, and always close the
# log file -- including when session setup or training raises.
try:
    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.allow_growth = True
    graph = tf.Graph()
    with graph.as_default():
        with tf.Session(config=config) as sess:
            main(sess)
finally:
    log_file.close()