In [None]:
'''
Edited by Silver on 2019/03/06

This template is for TensorFlow beginners who need a starting point for training a network.

This file contains three parts

1. Prepare inputs
2. Model Setup
3. Training

'''

In [None]:
import os
import pickle
import sys
from os import listdir
from os.path import isfile, join

import numpy as np
import tensorflow as tf
import tqdm

import config
# Read the run configuration; the second value returned by get_config() is discarded.
conf, _ = config.get_config()

# Loading target model from a given file name.
# NOTE: the quoted names below are template placeholders. Replace both the string in
# the comparison and the triple-quoted module name with the real model file name;
# until then the import line is not valid Python.
if conf.tar_model == '[filename of your target model]':
    import '''filename of your model''' as model
    print('Load', conf.tar_model)
else:
    # Stop here when the configured model name does not match the expected one.
    sys.exit("Sorry, Wrong Model!")

# Setting visible GPU
# Here I set CUDA to only see one GPU (the one given by conf.gpu).
# The list form allows several ids; they are joined into a comma-separated string.
gpus = [conf.gpu]
os.environ['CUDA_VISIBLE_DEVICES']=','.join([str(i) for i in gpus])

In [None]:
# Folder the dataset file is read from.
data_dir = './dataset/'

# Per-model working folder, named after the configured target model.
model_dir = ''.join(('./', conf.tar_model, '/'))

# 'logs' sub-folder of the model folder; checkpoints and summaries are written here.
# Only the path string is built at this point.
logs_dir = ''.join((model_dir, 'logs/'))

In [None]:
#####################
# 1. Prepare inputs #
#####################

# Load the dataset from a pickle file (Python's object serialization format).
# SECURITY NOTE: pickle.load can execute arbitrary code embedded in the file,
# so only load dataset files that come from a trusted source.
# join() is used because data_dir already ends with a separator; plain string
# concatenation with a leading '/' produced a doubled separator in the path.
# The bracketed file name is a template placeholder to be replaced by the user.
with open(join(data_dir, '[filename of your dataset].pickle'), 'rb') as f:
    raw_data = pickle.load(f)

# Having a glance at raw_data
print(type(raw_data))
print(raw_data.shape)

# You can preprocess the raw data here.
# As written this is an alias of raw_data, not a copy.
processed_data = raw_data

# Having a glance at processed data
print(type(processed_data))
print(processed_data.shape)

In [None]:
# setting a batch generator for getting a batch of dataset
def batch_generator(data, idx, batch_size = 128):
    """Yield fixed-size mini-batches of ``data`` forever, in a shuffled order.

    The indices in ``idx`` are shuffled once; the generator then cycles over the
    same order endlessly. Only full batches are produced, so the trailing
    ``len(idx) % batch_size`` indices of the shuffled order are never yielded.

    Args:
        data: Container indexable by an array of indices (e.g. a numpy array).
        idx: Indices into ``data`` that this generator may draw from. The
            caller's sequence is left untouched; a shuffled copy is used.
        batch_size: Number of samples in every yielded batch.

    Yields:
        ``data`` indexed by the next ``batch_size`` shuffled indices.

    Raises:
        ValueError: On the first ``next()`` call, if ``batch_size`` is not
            positive or ``idx`` holds fewer than ``batch_size`` entries.
            Without this check the endless loop would never yield and the
            caller would hang.
    """
    if batch_size <= 0:
        raise ValueError('batch_size must be positive, got %r' % (batch_size,))

    n_batches = len(idx) // batch_size
    if n_batches == 0:
        raise ValueError(
            'Not enough samples (%d) to form a single batch of size %d'
            % (len(idx), batch_size))

    # permutation() shuffles a copy, so the caller's index array is not modified.
    order = np.random.permutation(idx)
    while True:
        for i in range(n_batches):
            yield data[order[i*batch_size:(i+1)*batch_size]]
            # If you also have training labels, yield both with the same selection:
            # sel = order[i*batch_size:(i+1)*batch_size]
            # yield data[sel], label[sel]
            
# splitting training and validation sets
def get_generators(training_data, tv_ratio = 0.8, shuffle = True, batch_size = 128):
    """Split the data into training and validation parts and wrap each in a batch generator.

    Args:
        training_data: Full dataset; its first axis is treated as the sample axis.
        tv_ratio: Fraction of the samples assigned to the training part.
        shuffle: When true, the sample indices are shuffled before splitting.
        batch_size: Batch size handed to both batch generators.

    Returns:
        Tuple ``(training_gen, n_training, validation_gen, n_validation)`` holding
        the two generators and the number of indices in each part.
    """
    print('Size of all training data', training_data.shape)

    sample_ids = np.arange(training_data.shape[0])
    if shuffle:
        print('*Shuffle before training and validation split')
        np.random.shuffle(sample_ids)

    n_total = len(sample_ids)
    split_at = int(tv_ratio * n_total)
    training_idx, validation_idx = sample_ids[:split_at], sample_ids[split_at:]

    # Sanity check: both parts together should cover every index exactly once.
    surplus = len(training_idx) + len(validation_idx) - n_total
    if surplus > 0:
        print('Some data are duplicated in the training and validation split')
    elif surplus < 0:
        print('Some data are missing in the training and validation split')

    return (
        batch_generator(training_data, training_idx, batch_size),
        len(training_idx),
        batch_generator(training_data, validation_idx, batch_size),
        len(validation_idx),
    )

# Build the batch generators required by the selected mode (training or testing).
if not conf.training:
    # Testing mode: one generator over the indices of the whole processed dataset.
    testing_idx = np.arange(processed_data.shape[0])
    get_testing_batch = batch_generator(processed_data, testing_idx, conf.batch_size)
else:
    # Training mode: 80/20 training/validation split, shuffled before splitting.
    # Original author's note on this call: (128, 30, 450) -- presumably an example batch shape.
    (get_training_batch,
     n_training_samples,
     get_validation_batch,
     n_validation_samples) = get_generators(
        processed_data,
        tv_ratio=0.8,
        shuffle=True,
        batch_size=conf.batch_size,
    )

# Having a glance at the output of batch_generator.
# This example is for an autoencoder (the generator yields a single array):
# training_batch = next(get_training_batch)
# print(training_batch.shape)

# Usually there are two outputs, a training batch and the corresponding label batch:
# training_batch, label_batch = next(get_training_batch)
# print(training_batch.shape)
# print(label_batch.shape)

In [None]:
# Clearing tensorflow graph
tf.reset_default_graph()
graph = tf.Graph()
with graph.as_default() as g:

    ##################
    # 2. Model Setup #
    ##################

    # Model input/label.
    # NOTE: the triple-quoted strings below are template placeholders; replace them
    # with the actual feature dimension(s) before running this cell.
    model_input = tf.placeholder(tf.float32, shape=[conf.batch_size, '''Shape of the feature for trainging'''], name="model_input")
    model_GT = tf.placeholder(tf.float32, shape=[conf.batch_size, '''Shape of the feature for trainging'''], name='model_GT')
    # Keep ratio for dropout. model.inference() receives it as an argument, so it must
    # be defined. It defaults to 1.0 whenever it is not fed, which lets the feed_dicts
    # below omit it (presumably 1.0 means "drop nothing" -- confirm against the model).
    keep_probability = tf.placeholder_with_default(1.0, shape=[], name="keep_probability")
    # Tag for indicating training phase
    train_phase = tf.placeholder(tf.bool, name='phase_train')

    # Build the model
    code, model_out = model.inference(model_input, keep_probability, train_phase)

    # Count the trainable weights contained in the model
    total_parameters = 0
    for variable in tf.trainable_variables():
        total_parameters += int(np.prod(variable.get_shape().as_list()))
    print('Total parameters', total_parameters)

    # Loss function (MSE here)
    loss = tf.reduce_mean(tf.losses.mean_squared_error(predictions=model_out, labels=model_GT))
    # save to summary for monitoring
    loss_summary = tf.summary.scalar("Tot_loss", loss)

    # Defining the optimizer and update algorithm (BP)
    def train(loss_val, var_list):
        """Return the op that applies one Adam update of ``var_list`` w.r.t. ``loss_val``."""
        optimizer = tf.train.AdamOptimizer(conf.lr)
        grads = optimizer.compute_gradients(loss_val, var_list=var_list)
        return optimizer.apply_gradients(grads)

    trainable_var = tf.trainable_variables()

    train_op = train(loss, trainable_var)

    # These ops update the parameters in the batch normalization layers;
    # they are run together with train_op at every training step.
    extra_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

    def eval_feed(batch):
        """Feed dict for evaluating ``batch`` against itself with the training phase off."""
        return {model_input: batch, model_GT: batch, train_phase: False}

    # Preparing the saver and summary op, and defining a session for training

    print("Setting up summary op...")
    summary_op = tf.summary.merge_all()

    print("Setting up Saver...")
    saver = tf.train.Saver()

    # Defining the session
    sess = tf.InteractiveSession(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False), graph=g)

    train_writer = validation_writer = None
    try:
        if conf.training:
            # Training: initialise all model weights
            init = tf.global_variables_initializer()
            sess.run(init)
        else:
            # Testing: load the pretrained model from the 'logs' folder
            ckpt = tf.train.get_checkpoint_state(logs_dir)
            if not (ckpt and ckpt.model_checkpoint_path):
                print('No checkpoint file found')
                # An explicit error replaces the bare `raise`, which has no active
                # exception to re-raise and only produced an unrelated RuntimeError.
                raise RuntimeError('No checkpoint file found in ' + logs_dir)
            saver.restore(sess, ckpt.model_checkpoint_path)
            print('Loading sucessfully')

        # Two summary writers so training and validation loss show up in the same chart
        train_writer = tf.summary.FileWriter(logs_dir+'/train', sess.graph)
        validation_writer = tf.summary.FileWriter(logs_dir+'/validation', sess.graph)

        ###############
        # 3. Training #
        ###############
        if conf.training:
            print("-------- Start Training --------")

            n_validation_batches = int(n_validation_samples/conf.batch_size)
            if n_validation_batches == 0:
                raise ValueError(
                    'Validation split (%d samples) is smaller than one batch (%d)'
                    % (n_validation_samples, conf.batch_size))

            # Best (lowest) validation loss seen so far; the name is kept from the
            # original template. Infinity guarantees the first validation is saved.
            max_validloss = float('inf')

            for itr in range(int(conf.MAX_ITERATION)):
                # Preparing training input
                batch_xs = next(get_training_batch)

                # Training step (input doubles as the target, autoencoder style)
                sess.run([train_op, extra_update_ops], feed_dict={model_input: batch_xs,
                                                                  model_GT: batch_xs,
                                                                  # keep_probability: 0.4,
                                                                  train_phase: True})

                # have a glance at the training loss for the current batch
                if itr % 500 == 0:
                    train_loss, summary_str = sess.run([loss, loss_summary], feed_dict=eval_feed(batch_xs))
                    print("[T] Step: %d, loss:%g" % (itr, np.mean(train_loss)))
                    train_writer.add_summary(summary_str, itr)

                # Validation
                if itr % 1000 == 0:
                    valid_losses = []
                    for i in tqdm.trange(n_validation_batches):
                        batch_xs_valid = next(get_validation_batch)
                        valid_loss = sess.run(loss, feed_dict=eval_feed(batch_xs_valid))
                        valid_losses.append(valid_loss)

                    calc_v_loss = float(np.mean(valid_losses))

                    # Log the mean over all validation batches (previously only the
                    # last batch's summary was written). The tag is meant to match
                    # loss_summary so both curves share one chart.
                    summary_sva = tf.Summary(value=[tf.Summary.Value(tag="Tot_loss", simple_value=calc_v_loss)])
                    validation_writer.add_summary(summary_sva, itr)

                    # Saving the ckpt if reaching a better loss
                    if calc_v_loss < max_validloss:
                        saver.save(sess, logs_dir + "model.ckpt", itr)
                        print("[V*] Step: %d, loss:%g" % (itr, calc_v_loss))
                        max_validloss = calc_v_loss
                    else:
                        print("[V] Step: %d, loss:%g" % (itr, calc_v_loss))
        else:
            print("Start Testing....")

            test_losses = []
            for i in tqdm.trange(int(processed_data.shape[0]/conf.batch_size)):
                batch_xs_test = next(get_testing_batch)

                test_loss = sess.run(loss, feed_dict=eval_feed(batch_xs_test))
                test_losses.append(test_loss)

                # You might want to save the results; you can add that here.

            # See the average loss here
            calc_test_loss = np.mean(test_losses)
            print("[Test] Avg. loss:%g" % (calc_test_loss))
    finally:
        # finish: flush/close the writers and release the session even on errors
        for writer in (train_writer, validation_writer):
            if writer is not None:
                writer.close()
        sess.close()