In [1]:
from ResNet import *
from Cifar10 import *
import cPickle
import numpy as np
import os

## Configuration

In [2]:
# Experiment hyper-parameters. main() reads these module-level names directly.
num_units = 5          # passed to ResNet(...) as its first argument
exp_id = 2             # used to name this experiment's log file
gpu_number = 0         # exposed to TensorFlow through CUDA_VISIBLE_DEVICES
epoch = 400            # number of training epochs run by main()
image_shape = [32, 32, 3]
train_batch_size = 128
test_batch_size = 100
# Only takes effect if set before TensorFlow initialises its GPU devices,
# i.e. before the first session is created in the run cell below.
os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_number)

# Directory for the per-experiment log. Override it with RESNET_LOG_DIR rather
# than editing the hard-coded default; it is created if it does not exist so
# open() below does not fail on a fresh machine.
log_dir = os.environ.get('RESNET_LOG_DIR', '/scratch/f1fan/ResNet')
if not os.path.isdir(log_dir):
    os.makedirs(log_dir)
log_file_path = os.path.join(log_dir, 'log_exp{}.txt'.format(exp_id))
# The log is only ever written, so plain 'w' (truncate + write) is sufficient.
log_file = open(log_file_path, 'w')

## Logging helper

In [3]:
def log(line):
    """Append `line` and a newline to `log_file`, flush it, then echo `line` to stdout."""
    log_file.write(line + '\n')
    # Flush after every line so the log on disk stays current during long runs.
    log_file.flush()
    print(line)

## Main function: build the model, then train and evaluate it each epoch

In [None]:
def _train_epoch(sess, model, dataset, train_op, train_loss, train_accuracy):
    """Run one shuffled, augmented pass over the training set.

    Returns:
        (mean_loss, mean_accuracy): per-batch values averaged over
        dataset.train_batch_count batches.
    """
    total_loss = 0.0
    total_accuracy = 0.0
    dataset.shuffle_dataset()
    for j in range(dataset.train_batch_count):
        batch_images, batch_labels = dataset.next_aug_train_batch(j)
        # Fetch the update op and the metrics in a single run: one forward/backward
        # pass per batch instead of an extra forward pass after the update. The
        # metrics therefore come from the same step as the update.
        _, curr_loss, curr_accuracy = sess.run(
            [train_op, train_loss, train_accuracy],
            feed_dict={model.train_image_placeholder: batch_images,
                       model.train_label_placeholder: batch_labels})
        total_loss += curr_loss
        total_accuracy += curr_accuracy
    return (total_loss / dataset.train_batch_count,
            total_accuracy / dataset.train_batch_count)


def _evaluate(sess, model, dataset, test_loss, test_accuracy):
    """Average the test loss and accuracy over all test batches.

    Returns:
        (mean_loss, mean_accuracy) over dataset.test_batch_count batches.
    """
    total_loss = 0.0
    total_accuracy = 0.0
    for k in range(dataset.test_batch_count):
        batch_images, batch_labels = dataset.next_test_batch(k)
        curr_loss, curr_accuracy = sess.run(
            [test_loss, test_accuracy],
            feed_dict={model.test_image_placeholder: batch_images,
                       model.test_label_placeholder: batch_labels})
        total_loss += curr_loss
        total_accuracy += curr_accuracy
    return (total_loss / dataset.test_batch_count,
            total_accuracy / dataset.test_batch_count)


def main(sess):
    """Build the ResNet train/test graph, initialise it in `sess`, and train.

    For each of `epoch` epochs, trains on the augmented training set, logs the
    global step, learning rate and mean train loss/error, then logs the mean
    test loss/error. Reads the module-level configuration (num_units,
    image_shape, batch sizes, epoch) and writes through log().

    Args:
        sess: a TensorFlow session whose graph is the current default graph,
            so that ops built here land in the graph it runs.
    """
    dataset = Cifar10(train_batch_size, test_batch_size)
    model = ResNet(num_units, image_shape, train_batch_size, test_batch_size)
    train_op, train_loss, train_accuracy = model.build_train_op()
    test_loss, test_accuracy = model.build_test_op()

    global_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='')
    log('Global variables:')
    for i, var in enumerate(global_variables):
        log('{0} {1}'.format(i, var.name))

    sess.run(tf.global_variables_initializer())

    for i in range(epoch):
        mean_train_loss, mean_train_accuracy = _train_epoch(
            sess, model, dataset, train_op, train_loss, train_accuracy)
        # One run fetches both scalars instead of two separate session calls.
        step, learning_rate = sess.run([model.train_step, model.learning_rate])
        log('Training epoch {0}, step {1}, learning rate {2}'.
            format(i, step, learning_rate))
        log('    train loss {0}, train error {1}'.format(
            mean_train_loss, 1.0 - mean_train_accuracy))

        mean_test_loss, mean_test_accuracy = _evaluate(
            sess, model, dataset, test_loss, test_accuracy)
        log('    test loss {0}, test_error {1}'.format(
            mean_test_loss, 1.0 - mean_test_accuracy))

In [None]:
# allow_soft_placement lets TensorFlow fall back to CPU for ops with no GPU kernel;
# allow_growth allocates GPU memory on demand instead of reserving it all up front.
config = tf.ConfigProto(allow_soft_placement=True)
config.gpu_options.allow_growth = True
# Build in a fresh graph so re-running this cell does not pile duplicate ResNet
# variables into the default graph.
graph = tf.Graph()
try:
    with graph.as_default():
        with tf.Session(config=config) as sess:
            main(sess)
finally:
    # Release the log file even if training raises or is interrupted.
    # Re-run the configuration cell to reopen it before training again.
    log_file.close()

Global variables:
0 Variable:0
1 ResNet/r0_conv/conv_filter:0
2 ResNet/r0_conv/conv_bias:0
3 ResNet/r0_bn/beta:0
4 ResNet/r0_bn/gamma:0
5 ResNet/r0_bn/moving_mean:0
6 ResNet/r0_bn/moving_variance:0
7 ResNet/res1.0/h1_conv/conv_filter:0
8 ResNet/res1.0/h1_conv/conv_bias:0
9 ResNet/res1.0/h1_bn/beta:0
10 ResNet/res1.0/h1_bn/gamma:0
11 ResNet/res1.0/h1_bn/moving_mean:0
12 ResNet/res1.0/h1_bn/moving_variance:0
13 ResNet/res1.0/h2_conv/conv_filter:0
14 ResNet/res1.0/h2_conv/conv_bias:0
15 ResNet/res1.1/h0_bn/beta:0
16 ResNet/res1.1/h0_bn/gamma:0
17 ResNet/res1.1/h0_bn/moving_mean:0
18 ResNet/res1.1/h0_bn/moving_variance:0
19 ResNet/res1.1/h1_conv/conv_filter:0
20 ResNet/res1.1/h1_conv/conv_bias:0
21 ResNet/res1.1/h1_bn/beta:0
22 ResNet/res1.1/h1_bn/gamma:0
23 ResNet/res1.1/h1_bn/moving_mean:0
24 ResNet/res1.1/h1_bn/moving_variance:0
25 ResNet/res1.1/h2_conv/conv_filter:0
26 ResNet/res1.1/h2_conv/conv_bias:0
27 ResNet/res1.2/h0_bn/beta:0
28 ResNet/res1.2/h0_bn/gamma:0
29 ResNet/res1.2/h0_bn/