In [None]:
import tensorflow as tf
import numpy as np
import C3D_model
import time
import data_processing
import os
import os.path
from os.path import join
import pickle
import tensorflow.contrib.slim as slim

# Restrict TensorFlow to one physical GPU (index 3, numbered in PCI bus order)
# and reduce the verbosity of TensorFlow's C++ logging.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "3",
    "TF_CPP_MIN_LOG_LEVEL": "1",
})

# TensorBoard logs go to a per-run sub-directory named after the local start time.
TRAIN_LOG_DIR = os.path.join('Log/train/', time.strftime('%Y-%m-%d %H:%M:%S'))

# Checkpoint directories.
TRAIN_CHECK_POINT = 'pretrained_C3D_checkpoint/'  # pretrained C3D checkpoint to restore from
# Alternative source checkpoint: TRAIN_CHECK_POINT = 'check_point/'
NL_CHECK_POINT = 'check_point_nl/'  # where checkpoints of the C3D + non-local model are saved

# Input geometry: clips of CLIP_LENGTH frames, each CROP_SIZE x CROP_SIZE with CHANNEL_NUM channels.
CLIP_LENGTH = 16
CROP_SIZE = 112
CHANNEL_NUM = 3

# Dataset / training settings.
BATCH_SIZE = 10
NUM_CLASSES = 11
EPOCH_NUM = 30
LEARNING_RATE = 1e-5

## 1) Load the UCF11 (UCF YouTube Action) dataset split

In [None]:
# Load the pre-computed train/test split of UCF11 video paths.
DATA_SPLIT_PATH = 'data_split.pkl'
# NOTE(security): pickle.load can execute arbitrary code -- only load split files you created or trust.
# The `with` block guarantees the file handle is closed after loading.
with open(DATA_SPLIT_PATH, 'rb') as split_file:
    ucf11_dataset = pickle.load(split_file)
# Expected to be a dict with 'train' and 'test' entries (inferred from the keys used here).
train_set = ucf11_dataset['train']
test_set = ucf11_dataset['test']

## 2) Get shuffled video indices

In [None]:
# Build one index sequence per split, train first and then test, sized by the number of videos.
# Presumably a shuffled ordering over the split -- TODO confirm in data_processing.get_video_indices.
train_video_indices, test_video_indices = (
    data_processing.get_video_indices(len(split_videos))
    for split_videos in (train_set, test_set)
)

## 3) Build the graph

In [None]:

# Define graph
# Start from a fresh default graph so that re-running this cell does not add a second
# copy of the C3D variables (which breaks variable creation and the restore list below).
tf.reset_default_graph()

# Inputs: a batch of clips plus one NUM_CLASSES-wide label row per clip
# (presumably one-hot -- TODO confirm against data_processing.get_batches).
batch_clips = tf.placeholder(tf.float32, [BATCH_SIZE, CLIP_LENGTH, CROP_SIZE, CROP_SIZE, CHANNEL_NUM], name='X')
batch_labels = tf.placeholder(tf.int32, [BATCH_SIZE, NUM_CLASSES], name='Y')
# Dropout keep probability; the training cell feeds 0.5 for training and 1.0 for evaluation.
keep_prob = tf.placeholder(tf.float32)
logits = C3D_model.C3D(batch_clips, NUM_CLASSES, keep_prob, non_local=True)

# Only variables under these scopes are restored from the pretrained C3D checkpoint;
# everything else (presumably the non-local block and the output layer) keeps its initializer.
# This is computed before the optimizer exists, so Adam's slot variables are not included.
include_layers = ['C3D/conv1', 'C3D/conv2', 'C3D/conv3', 'C3D/conv4', 'C3D/conv5', 'C3D/fc6', 'C3D/fc7']
variables_to_restore = slim.get_variables_to_restore(include=include_layers)

with tf.name_scope('loss'):
    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=batch_labels))
    tf.summary.scalar('entropy_loss', loss)

with tf.name_scope('accuracy'):
    # Fraction of clips whose arg-max prediction matches the arg-max label.
    accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(logits, 1), tf.argmax(batch_labels, 1)), np.float32))
    tf.summary.scalar('accuracy', accuracy)

optimizer = tf.train.AdamOptimizer(LEARNING_RATE).minimize(loss)
# `saver` checkpoints the whole model; `restorer` loads only the pretrained C3D subset.
saver = tf.train.Saver()
restorer = tf.train.Saver(variables_to_restore)
summary_op = tf.summary.merge_all()
    
    
    

### Locate pretrained C3D weights

In [None]:
# Find the newest pretrained C3D checkpoint in TRAIN_CHECK_POINT. Per the TF docs this is
# None when the directory holds no checkpoint; the actual restore happens in the training cell.
pretrained_path = tf.train.latest_checkpoint(checkpoint_dir=TRAIN_CHECK_POINT)

### Train C3D + non-local block

In [None]:
# Set graph
# Run training:
# - initialise variables;
# - restore the pretrained C3D weights when a checkpoint exists;
# - then, for each epoch: train on the training split, print the averages,
#   evaluate on the test split and save a checkpoint to NL_CHECK_POINT.
config = tf.ConfigProto()
config.gpu_options.allow_growth = True  # allocate GPU memory on demand

# Only full batches are used; a trailing partial batch is skipped.
num_train_batches = len(train_video_indices) // BATCH_SIZE
num_test_batches = len(test_video_indices) // BATCH_SIZE
# Fail fast rather than dividing by zero when averaging per-epoch metrics below.
if num_train_batches == 0 or num_test_batches == 0:
    raise ValueError('Each split needs at least BATCH_SIZE=%d videos (train: %d, test: %d)'
                     % (BATCH_SIZE, len(train_video_indices), len(test_video_indices)))

# Saver.save() does not create a missing parent directory, so create it up front.
if not os.path.isdir(NL_CHECK_POINT):
    os.makedirs(NL_CHECK_POINT)

with tf.Session(config=config) as sess:
    train_summary_writer = tf.summary.FileWriter(TRAIN_LOG_DIR, sess.graph)
    try:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        # latest_checkpoint() yields None (never 0) when nothing is found, so test truthiness.
        if pretrained_path:
            restorer.restore(sess, pretrained_path)
            print('Pretrained C3D Model is restored')
        else:
            print('No pretrained C3D checkpoint found in %s; training from initial weights' % TRAIN_CHECK_POINT)
        # Global step for TensorBoard summaries; initialised whether or not a restore happened.
        step = 0
        for epoch in range(EPOCH_NUM):
            # ---- Training pass (dropout active) ----
            accuracy_epoch = 0
            loss_epoch = 0
            batch_index = 0
            for i in range(num_train_batches):
                step += 1
                # get_batches also returns an updated batch_index (presumably the next read position).
                batch_data, batch_index = data_processing.get_batches(train_set, NUM_CLASSES, batch_index,
                                                                      train_video_indices, BATCH_SIZE)
                _, loss_out, accuracy_out, summary = sess.run(
                    [optimizer, loss, accuracy, summary_op],
                    feed_dict={batch_clips: batch_data['clips'],
                               batch_labels: batch_data['labels'],
                               keep_prob: 0.5})
                loss_epoch += loss_out
                accuracy_epoch += accuracy_out

                if i % 10 == 0:
                    print('Epoch %d, Batch %d: Loss is %.5f; Accuracy is %.5f' % (epoch + 1, i, loss_out, accuracy_out))
                    train_summary_writer.add_summary(summary, step)

            print('Epoch %d: Average loss is: %.5f; Average accuracy is: %.5f' % (
                epoch + 1, loss_epoch / num_train_batches, accuracy_epoch / num_train_batches))

            # ---- Evaluation pass on the test split (dropout disabled) ----
            accuracy_epoch = 0
            loss_epoch = 0
            batch_index = 0
            for i in range(num_test_batches):
                batch_data, batch_index = data_processing.get_batches(test_set, NUM_CLASSES, batch_index,
                                                                      test_video_indices, BATCH_SIZE)
                loss_out, accuracy_out = sess.run(
                    [loss, accuracy],
                    feed_dict={batch_clips: batch_data['clips'],
                               batch_labels: batch_data['labels'],
                               keep_prob: 1.0})
                loss_epoch += loss_out
                accuracy_epoch += accuracy_out

            print('Test loss is %.5f; Accuracy is %.5f' % (
                loss_epoch / num_test_batches, accuracy_epoch / num_test_batches))
            saver.save(sess, NL_CHECK_POINT + 'c3d_nonlocal_ckpt', global_step=epoch)
    finally:
        # Flush pending summaries to disk even if training is interrupted.
        train_summary_writer.close()