In [1]:
import os
import numpy as np

import tensorflow as tf
from tensorflow.python.framework import ops

from load_data import *
from models.customlayers import *
from models.activations import *
from training import *

import models.ConvAE as cae
L = tf.layers

import matplotlib.pyplot as plt
% matplotlib inline

In [None]:
# Location of the extracted video frames (earlier runs pointed at data/downsampled).
data_dir = os.path.expanduser(
    '~/Insight/video-representations/frames'
)
# In-memory loader, currently disabled. Later cells (dataset-size print,
# training loop, plotting) read X_train / y_train / X_test / y_test, so they
# will raise NameError on a fresh kernel unless this is re-enabled or those
# arrays are defined some other way. TODO: confirm which loader is intended.
# X_train, y_train, X_test, y_test = get_splits(*load_all_data_stacked(data_dir, every_n=2))

In [None]:
def read_record(filepath_queue):
    """Read one serialized example from a TFRecord file queue and decode it.

    Args:
        filepath_queue: queue of TFRecord file paths, e.g. the one built by
            ``tf.train.string_input_producer`` in ``inputs``.

    Returns:
        A uint8 tensor reshaped to ``(-1, height, width, 3)``. ``height`` and
        ``width`` come from the example itself. The leading dimension is
        inferred from the byte length (presumably the number of frames).
    """
    # tf.RecordReader does not exist; TFRecordReader reads .tfrecords files.
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filepath_queue)

    features = tf.parse_single_example(
        serialized_example,
        features={
            'height': tf.FixedLenFeature([], tf.int64),
            'width': tf.FixedLenFeature([], tf.int64),
            'image': tf.FixedLenFeature([], tf.string)
        }
    )

    video = tf.decode_raw(features['image'], tf.uint8)   # feature may be renamed to video in future
    h = tf.cast(features['height'], tf.int32)
    w = tf.cast(features['width'], tf.int32)

    # tf.pack was renamed to tf.stack in TF 1.0 (this notebook already relies
    # on TF >= 1.0 APIs such as tf.layers).
    video_shape = tf.stack([-1, h, w, 3])
    return tf.reshape(video, video_shape)
    
def inputs(split_type, batchsize, num_epochs):
    """Build a shuffled batch of videos read from ``<data_dir>/<split_type>.tfrecords``.

    Args:
        split_type: file stem of the split to read (e.g. 'train').
        batchsize: number of videos per batch.
        num_epochs: number of passes over the file. A falsy value (0 or None)
            is normalised to None, i.e. no epoch limit.

    Returns:
        The batched tensor produced by ``tf.train.shuffle_batch``.
    """
    if not num_epochs:
        num_epochs = None   # string_input_producer treats None as "no limit"

    filepath = os.path.join(data_dir, '{}.tfrecords'.format(split_type))

    with tf.name_scope('input'):
        filepath_queue = tf.train.string_input_producer(
            [filepath], num_epochs=num_epochs
        )

    video = read_record(filepath_queue)

    min_after_dequeue = 10
    # Capacity must leave room for at least one full batch above the
    # shuffle floor, otherwise dequeueing can block; keeps the old value (32)
    # for small batches.
    capacity = max(32, min_after_dequeue + batchsize)

    # NOTE(review): shuffle_batch needs a fully-defined static shape, but
    # read_record returns (-1, h, w, 3) with runtime h/w. A `shapes=` argument
    # (or video.set_shape(...)) is likely still required — confirm frame dims.
    videos = tf.train.shuffle_batch(
        [video],
        batch_size=batchsize, capacity=capacity, num_threads=2,
        min_after_dequeue=min_after_dequeue
    )
    return videos
    

In [None]:
tf.reset_default_graph()

# Inputs and reconstruction targets: (batch, 3, 60, 80) — presumably
# channels-first, matching the NCHW -> NHWC transpose used before plotting.
input_var = tf.placeholder(dtype=tf.float32, shape=(None, 3, 60, 80), name='input')
target_var = tf.placeholder(dtype=tf.float32, shape=(None, 3, 60, 80), name='target')

# Weight of the L2 penalty relative to the reconstruction loss.
l2_weight = .01

# Convolutional autoencoder, with encoder and decoder in separate variable scopes.
with tf.variable_scope('encoder'):
    encoded = cae.encoder(input_var)
with tf.variable_scope('decoder'):
    decoded = cae.decoder(encoded)

# L2 penalty over every trainable variable except biases.
l2_term = tf.add_n([
    tf.nn.l2_loss(var)
    for var in tf.trainable_variables()
    if 'bias' not in var.name
])

# Mean squared reconstruction error.
loss = tf.reduce_mean(tf.pow(decoded - target_var, 2))

optimizer = tf.train.AdamOptimizer()
train_step = optimizer.minimize(loss + l2_weight * l2_term)

saver = tf.train.Saver()
init = tf.global_variables_initializer()

In [None]:
# In-memory size of the training array; nbytes / 1e9 converts bytes to GB
# (the previous / 1e6 reported MB under a "GB" label).
print('Training data: {} examples, {:.2f} GB'.format(X_train.shape[0], X_train.nbytes / 1e9))

In [None]:
os.makedirs('tmp/models', exist_ok=True)
epochs = 30


def _mean_loss(sess, X, y, batchsize=128, shuffle=False):
    """Average the graph's `loss` over the minibatches of (X, y).

    Returns nan instead of raising ZeroDivisionError when
    `iterate_minibatches` yields no batches.
    """
    total, n_batches = 0.0, 0
    for X_batch, y_batch in iterate_minibatches(X, y, batchsize=batchsize, shuffle=shuffle):
        total += sess.run(loss, {input_var: X_batch, target_var: y_batch})
        n_batches += 1
    return total / n_batches if n_batches else float('nan')


def _predict_batches(sess, X, y, batchsize=128):
    """Run the autoencoder over X in order and return a list of per-batch outputs."""
    return [
        sess.run(decoded, {input_var: X_batch})
        for X_batch, _ in iterate_minibatches(X, y, batchsize=batchsize, shuffle=False)
    ]


with tf.Session() as sesh:
    sesh.run(init)

    train_trace = []        # loss of every training minibatch
    validation_trace = []   # mean test loss after each epoch

    print('Initial loss:\t{}'.format(_mean_loss(sesh, X_test, y_test)))
    saver.save(sesh, 'tmp/models/prototype_ae initial.ckpt')

    for epoch in range(epochs):
        batch_no = 0
        current_loss = 0

        for X_batch, y_batch in iterate_minibatches(X_train, y_train, batchsize=64, shuffle=True):
            _, batch_loss = sesh.run([train_step, loss], {input_var: X_batch, target_var: y_batch})

            current_loss += batch_loss
            batch_no += 1
            train_trace.append(batch_loss)

            if batch_no % 200 == 0:
                print('\t\t', current_loss / batch_no)

        epoch_loss = current_loss / batch_no if batch_no else float('nan')
        print('Epoch {} train loss:\t{}'.format(epoch, epoch_loss))

        # Evaluate in a fixed order so every epoch is scored on the same
        # batches; the result is also recorded (it was previously discarded).
        val_loss = _mean_loss(sesh, X_test, y_test)
        validation_trace.append(val_loss)
        print('\t test loss:\t{}'.format(val_loss))

    # Reconstructions in dataset order, one array per minibatch.
    train_outputs = _predict_batches(sesh, X_train, y_train)
    test_outputs = _predict_batches(sesh, X_test, y_test)

    saver.save(sesh, 'tmp/models/prototype_ae.ckpt')

In [None]:
# Number of frames to pull out for the visual comparison below.
lim = 30


def _first_frames_nhwc(batches, n):
    """Return the first n frames from a list of (N, C, H, W) batches as (n, H, W, C).

    Concatenates only the leading batches needed to reach n frames, so `n`
    counts frames here just as it does in ``X_test[:n]``.
    """
    needed, count = [], 0
    for batch in batches:
        if count >= n and needed:
            break
        needed.append(batch)
        count += len(batch)
    return np.concatenate(needed)[:n].transpose((0, 2, 3, 1))


new_test_frames = _first_frames_nhwc(test_outputs, lim)
new_train_frames = _first_frames_nhwc(train_outputs, lim)
X_test_ims = X_test[:lim].transpose((0, 2, 3, 1))
X_train_ims = X_train[:lim].transpose((0, 2, 3, 1))

In [None]:
# Inputs next to their reconstructions, for training and test frames.
n_rows = 5
probe = 4  # index of the first frame shown; each row advances by 2 frames

panels = [
    ('train input', X_train_ims),
    ('train reconstruction', new_train_frames),
    ('test input', X_test_ims),
    ('test reconstruction', new_test_frames),
]

fig, axes = plt.subplots(n_rows, len(panels), figsize=(20, 20))

for i in range(n_rows):
    # The original index also added (i // 5), which is always 0 for 5 rows.
    frame_idx = probe + 2 * i
    for j, (title, frames) in enumerate(panels):
        ax = axes[i, j]
        ax.imshow(frames[frame_idx])
        ax.set_xticks([])
        ax.set_yticks([])
        if i == 0:
            ax.set_title(title)
        if j == 0:
            ax.set_ylabel('frame {}'.format(frame_idx))

fig.tight_layout()