In [0]:
import tensorflow as tf

# Fail fast when the runtime has no GPU attached, so the problem surfaces
# before any data is loaded or any training is started.
device_name = tf.test.gpu_device_name()
if device_name == '/device:GPU:0':
    print('Found GPU at: {}'.format(device_name))
else:
    raise SystemError('GPU device not found')

Found GPU at: /device:GPU:0


In [0]:
# Mount Google Drive into the Colab VM at /content/drive. This is interactive:
# the captured output below shows it asking for an authorization code.
from google.colab import drive
drive.mount('/content/drive') 

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3Aietf%3Awg%3Aoauth%3A2.0%3Aoob&scope=email%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdocs.test%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdrive%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdrive.photos.readonly%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fpeopleapi.readonly&response_type=code

Enter your authorization code:
··········
Mounted at /content/drive


In [0]:
cd drive/My Drive/Colab Notebooks/NVIDIA-SuperSlowMo

/content/drive/My Drive/Colab Notebooks/NVIDIA-SuperSlowMo


In [0]:
# Sanity check: list the project folder to confirm 'data', 'libs' and 'utils'
# are present before importing from them.
!ls

 32Xslow_original.gif   data   Main.ipynb	      utils
 8Xslow_original.gif    libs  'Train(mtahir).ipynb'


In [0]:
import numpy as np
import sys
sys.path.insert(0, 'utils')
sys.path.insert(0, 'libs')
import cv2
import math
import datetime

from keras.callbacks import ModelCheckpoint, LambdaCallback, LearningRateScheduler
from keras import backend as K

import matplotlib.pyplot as plt
from Network import Network
from DataGenerator import FrameGenerator

# Video directories, one per dataset split; each is handed to a
# DataFrameGenerator further down.
TRAIN_FILES = 'data/samples/DiDeMoRelease/train_videos'
TEST_FILES = 'data/samples/DiDeMoRelease/test_videos'
VAL_FILES = 'data/samples/DiDeMoRelease/val_videos'
# Root for weight files: the VGG16 weights, the per-epoch checkpoints and the
# checkpoint that training resumes from are all addressed relative to it.
WEIGHTS_FILES ='data/model_data/weights'
# Destination directory for the model architecture diagram.
PLOT_FILES = 'data/model_data/plots'
# NOTE(review): CHECKPOINTS is not referenced by any code visible in this
# notebook, and its value ("model checkpoints", with a space) differs from the
# directory plot_callback writes to ("model_checkpoints", with an underscore).
# Confirm which spelling exists on disk.
CHECKPOINTS = 'data/model_data/model checkpoints'

Using TensorFlow backend.


In [0]:
class DataFrameGenerator():
    """Endless source of ``([t, I0, I1], It)`` batches for frame interpolation.

    Wraps a ``FrameGenerator`` and, for every run of frames it samples, emits
    one batch per time step between a start frame ``I0`` and an end frame
    ``I1`` that lie ``fps`` frames apart.
    """

    def __init__(self, directory, height, width, batch_size, rescale = 1.0, seed = None):
        """Create the generator.

        Args:
            directory: forwarded to ``FrameGenerator`` (video directory).
            height: forwarded to ``FrameGenerator``.
            width: forwarded to ``FrameGenerator``.
            batch_size: number of (I0, I1, It) triplets per yielded batch.
            rescale: factor every sampled frame is multiplied by.
            seed: forwarded to ``FrameGenerator``.
        """
        self.batch_size = batch_size
        self.frame_generator = FrameGenerator(height, width, directory, seed)
        self.rescale = rescale
        # Frame gaps to draw from when flow_from_directory is called with
        # fps=None; duplicated entries are drawn proportionally more often.
        self.fps = [2, 4, 8, 8, 16, 16, 32, 32]

    def flow_from_directory(self, fps = 8):
        """Yield ``([t, I0, I1], It)`` batches forever.

        Args:
            fps: number of frames between ``I0`` and ``I1``. When ``None``, a
                new gap is drawn at random from ``self.fps`` for every sampled
                run of frames.

        Yields:
            ``([t, I0, I1], It)`` where ``t`` is an array of shape
            ``(batch_size, 1, 1, 1)`` filled with ``frame / fps``, and ``I0``,
            ``I1``, ``It`` are arrays stacking ``batch_size`` frames each.
            For each sampled run, ``frame`` goes from 1 to ``fps`` inclusive,
            so the last batch of a run has ``t == 1.0`` and ``It`` equal to
            ``I1`` (this matches what the default ``fps = 8`` always produced).
        """
        while True:
            # Resolve the gap into a local so that fps=None keeps meaning
            # "draw again" on every pass; rebinding `fps` itself would freeze
            # the first random draw for the lifetime of the generator.
            if fps is None:
                gap = int(np.random.choice(self.fps, 1, replace = True)[0])
            else:
                gap = fps

            # Assumes sample(n) returns n frames in temporal order -- TODO
            # confirm against FrameGenerator.
            frames = self.frame_generator.sample(self.batch_size + gap)
            frames = [f * self.rescale for f in frames]

            # Step count follows the gap. A fixed count of 8 indexed past the
            # end of `frames` for gaps below 8 and stopped early for gaps
            # above 8.
            for frame in range(1, gap + 1):
                t = (1.0 / gap) * frame

                I0 = np.asarray([frames[j] for j in range(self.batch_size)])
                I1 = np.asarray([frames[j + gap] for j in range(self.batch_size)])
                It = np.asarray([frames[j + frame] for j in range(self.batch_size)])
                t = np.full((self.batch_size, 1, 1, 1), t)
                yield [t, I0, I1], It

In [0]:
# One generator per dataset split. All three produce 320x320 frames with a
# batch size of 1, and scale frames by 1/255.
train_datagen = DataFrameGenerator(TRAIN_FILES, 320, 320, 1, rescale = 1.0/255)
train_generator = train_datagen.flow_from_directory()

val_datagen = DataFrameGenerator(VAL_FILES, 320, 320, 1, rescale = 1.0/255)
# Must come from val_datagen: building it from train_datagen made every
# validation step score training videos and left val_datagen unused.
val_generator = val_datagen.flow_from_directory()

test_datagen = DataFrameGenerator(TEST_FILES, 320, 320, 1, rescale = 1.0/255)
test_generator = test_datagen.flow_from_directory()

>> Found 7189 video files in data/samples/DiDeMoRelease/train_videos
>> Found 985 video files in data/samples/DiDeMoRelease/val_videos
>> Found 884 video files in data/samples/DiDeMoRelease/test_videos


In [0]:
# Draw one training batch and show, for every item in it, the start frame,
# the end frame and the target intermediate frame side by side.
sample, output = next(train_generator)
t, I0, I1 = sample
print(I1.shape)

for i in range(len(I1)):
    _, axes = plt.subplots(1, 3, figsize=(20, 5))
    panels = (('I0', I0), ('I1', I1), ('It', output))
    for ax, (title, batch) in zip(axes, panels):
        ax.imshow(batch[i,:,:,:])
        ax.set_title(title)
    plt.show()

In [0]:
# Base learning rate: the starting point of the step_decay schedule, and also
# passed to Network(...) and model.load(...).
lr = 1e-4

In [0]:
def step_decay(epoch):
    """Return the learning rate for ``epoch`` under a step schedule.

    Starts from the notebook-level ``lr`` and multiplies it by 0.1 once for
    every 200 epochs, counting epochs as ``epoch + 1``.
    """
    drop_factor = 0.1
    epochs_per_drop = 200.0
    drops_so_far = math.floor((1 + epoch) / epochs_per_drop)
    return lr * math.pow(drop_factor, drops_so_far)

def plot_callback(model):
    """Save a visual comparison of one test batch against the model output.

    Pulls the next batch from ``test_generator``, runs ``model.predict`` on
    it and writes one PNG per batch item showing I0, the ground-truth It, I1
    and the prediction. Files are named ``img_<index>_<timestamp>.png``.

    Args:
        model: object exposing ``predict(sample)``; the result is indexed as
            ``y_pred[i,:,:,:]``.
    """
    import os  # local import: only this function needs it

    # Create the target directory up front. savefig raises when it is
    # missing, and since this runs at the end of an epoch that would abort
    # training.
    out_dir = r'data/model_data/model_checkpoints/mtahir'
    os.makedirs(out_dir, exist_ok=True)

    sample, output = next(test_generator)
    t, I0, I1 = sample

    y_pred = model.predict(sample)
    # One timestamp per call so all images of the same batch share it.
    pred_time = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')

    for i in range(len(I1)):
        fig, axes = plt.subplots(1, 4, figsize=(20, 5))
        axes[0].imshow(I0[i,:,:,:])
        axes[1].imshow(output[i,:,:,:])
        axes[2].imshow(I1[i,:,:,:])
        axes[3].imshow(y_pred[i,:,:,:])
        axes[0].set_title('I0')
        axes[1].set_title('It')
        axes[2].set_title('I1')
        axes[3].set_title('y_pred')

        fig.savefig(out_dir + '/img_{}_{}.png'.format(i, pred_time))
        # Close explicitly so figures do not accumulate over the run.
        plt.close(fig)

#     print(It[0])
    
# Training callbacks: the step-decay learning-rate schedule, and a checkpoint
# that writes weights only, after every epoch (save_best_only is off), with
# the epoch number and training loss in the file name.
lrate = LearningRateScheduler(step_decay)
checkpoint = ModelCheckpoint(
    WEIGHTS_FILES + '/mtahir/weights.{epoch:02d}-{loss:.4f}.h5',
    monitor='val_loss',
    save_best_only=False,
    save_weights_only=True,
)

In [0]:
# Build the network wrapper for 320x320 RGB input (the same size the data
# generators produce), pointing it at the VGG16 weights file and the base
# learning rate.
model = Network(vgg_model = WEIGHTS_FILES + "/main/vgg16.h5", shape = (320, 320, 3), lr = lr)

In [0]:
# Resume from a previously saved checkpoint. The file name follows the
# ModelCheckpoint pattern weights.<epoch>-<loss>.h5, so this one is epoch 37.
# NOTE(review): the checkpoint is hard-coded, so a fresh run without this exact
# file cannot execute this cell. Presumably load() also sets
# model.current_epoch, which the training call reads -- confirm in Network.load.
model.load(WEIGHTS_FILES + '/mtahir/weights.37-32.9055.h5', lr = lr)

In [0]:
# Train, continuing from model.current_epoch. Each epoch runs 7000 generator
# steps; validation runs 50 steps on every second epoch. After each epoch the
# learning rate is rescheduled, weights are checkpointed and a comparison
# image is written by plot_callback.
# NOTE(review): train_generator is a plain Python generator; confirm that
# use_multiprocessing=True is intended and that shuffle has any effect here.
model.fit_generator(
    train_generator,
    initial_epoch=model.current_epoch,
    epochs=500,
    steps_per_epoch=7000,
    validation_data=val_generator,
    validation_steps=50,
    validation_freq=2,
    callbacks=[
        lrate,
        checkpoint,
        LambdaCallback(on_epoch_end=lambda epoch, logs: plot_callback(model)),
    ],
    use_multiprocessing=True,
    shuffle=False,
)

In [0]:
from keras.utils import plot_model

# Render the underlying Keras model's architecture to an image file, laid out
# left to right with layer names and tensor shapes.
plot_model(
    model.model,
    to_file=PLOT_FILES + '/model.png',
    show_shapes=True,
    show_layer_names=True,
    rankdir='LR',
)

In [0]:
sample, output = next(test_generator)
t, I0, I1 = sample

y_pred = model.predict(sample)
pred_time = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')

for i in range(len(I1)):
  _, axes = plt.subplots(1, 4, figsize=(20, 5))
  axes[0].imshow(I0[i,:,:,:])
  axes[1].imshow(output[i,:,:,:])
  axes[2].imshow(I1[i,:,:,:])
  axes[3].imshow(y_pred[i,:,:,:])
  axes[0].set_title('I0')
  axes[1].set_title('It')
  axes[2].set_title('I1')
  axes[3].set_title('y_pred')