In [1]:
from __future__ import division, print_function, absolute_import

import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import os 
import random
import cv2
import time
#this is for directory listing


In [2]:
import tensorflow.contrib.slim as slim


In [3]:
dir_path='data/'
dir_log='log/'

In [4]:
#load data and put them into three different lists
def loadData(fileName):
    #get the dir list
    fileList=os.listdir(fileName)
    fileList.sort()
    #iterate and get them into three lists
    img1List=[]
    img2List=[]
    groTruth=[]
    print(len(fileList))
    for i in range(0,int(len(fileList)/3)):
        img1List.append(dir_path+fileList[3*i+1])
        img2List.append(dir_path+fileList[3*i+2])
        groTruth.append(dir_path+fileList[3*i])
    return img1List,img2List,groTruth

In [5]:
#putting all the hyperparameters
initLr=1e-3
epochMax=1          #max number of epocs;1 epoch=all training examples through the NN.
epochLrDecay=5
batchSize=4        #the batch size for every iteration. 1 epoch = 1 batch_size*iterations
numExamples=100
#number of training examples to use.
useGpu=False
W,H=512,384
iterPerEpoch=numExamples//batchSize


In [6]:
#get the data class
"""
The data class will encapusulate:
    1.reading the .flo files
    2.hold together all the 
"""
class Data(object):
    def __init__(self,img1List,img2List,groTruth,bs=batchSize,shuffle=True,minusMean=True):
        self.img1List=img1List
        self.img2List=img2List
        self.groTruth=groTruth
        self.bs=bs
        self.index=0
        self.shuffle=shuffle #wtf is this??????
        self.minusMean=minusMean
        self.range=len(self.img1List)
        self.allIndices=range(self.range)
        
        
    #optical flow .flo type data reading; Courtesy - Univ. of Freiburg website.
    def readFlow(self,name):
        if name.endswith('.pfm') or name.endswith('.PFM'):
            return readPFM(name)[0][:,:,0:2]

        f = open(name, 'rb')

        header = f.read(4)
        if header.decode("utf-8") != 'PIEH':
            raise Exception('Flow file header does not contain PIEH')

        width = np.fromfile(f, np.int32, 1).squeeze()
        height = np.fromfile(f, np.int32, 1).squeeze()

        flow = np.fromfile(f, np.float32, width * height * 2).reshape((height, width, 2))

        return flow.astype(np.float32)       

    
    def upBatch(self):
        start=self.index
        #now point the list index to the next batch
        self.index+=self.bs
        #if all the batches complete, then reinitiate the indices with 0/
        if self.index>self.range:
            #if shuffle is set out to be true
            if(self.shuffle):
                random.shuffle(self.allIndices)
                self.index=0
                start=self.index
                self.index+=self.bs
        end=self.index
        img1Batch=[]
        img2Batch=[]
        groTruBatch=[]
        for i in range(start,end):
            img1=cv2.imread(self.img1List[self.allIndices[i]]).astype(np.float32)
            img1Batch.append(img1)
            img2=cv2.imread(self.img2List[self.allIndices[i]]).astype(np.float32)
            img2Batch.append(img2)
            flow=self.readFlow(self.groTruth[self.allIndices[i]])
            groTruBatch.append(flow)
        return np.array(img1Batch), np.array(img2Batch),np.array(groTruBatch)

In [7]:
#this is hugely based on work by Lin Jian at https://github.com/linjian93/tf-flownet
"""
This class is going to encapsulate everything on the architecture of the 
optical flow.
"""

class Net(object):
    def __init__(self,useGpu=True):
        self.img1=tf.placeholder(tf.float32,[batchSize,H,W,3])
        self.img2=tf.constant(1,shape=[batchSize, H, W, 3],dtype=tf.float32)
        self.flow=tf.constant(1,shape=[4,H,W,2],dtype=tf.float32)
        self.learnRate=tf.constant(0.01)
        concat1=tf.concat([self.img1, self.img2],3,name='input')

        #concat the first and second images on the third axis

        #applies 64 5*5 filters
        conv1=slim.conv2d(concat1,8,[5,5],2,scope='conv1')
        conv2=slim.conv2d(conv1,8,[5,5],2,scope='conv2')
        conv3=slim.conv2d(conv2,8,[5,5],2,scope='conv3')
        conv4=slim.conv2d(conv3,8,[5,5],2,scope='conv4')
        conv5=slim.conv2d(conv4,8,[5,5],2,scope='conv5')
        deconv1=slim.conv2d_transpose(conv5,1,[3,3],2,scope='deconv1')
        final = tf.reshape(deconv1, [-1,1,24*32], 'final')
            
        with tf.variable_scope('loss'):
            #take the ground truth flow values and resize them 
            flow=tf.image.resize_images(self.flow,[384,512])
            self.loss=tf.reduce_mean(tf.abs(final-tf.ones([4,24*32,1])))
            self.merged=tf.summary.scalar('lossval',self.loss)
        
        #get the optimizer to run with the described learning rate
        optimizer=tf.train.AdamOptimizer(self.learnRate)
        
        #encapsulating the loss and the optimizer.
        self.trainOp=slim.learning.create_train_op(self.loss,optimizer)
        
        
        #gpu settings
        self.tf_config=tf.ConfigProto()
        self.tf_config.gpu_options.allow_growth=True
        if useGpu==True:
            self.tf_config.gpu_options.visible_device_list='1'
    
        self.init_all=tf.global_variables_initializer()
        

In [8]:
def main(_):
    #load data
    imgList1,imgList2,floTruth=loadData('data/')
    trainDataset=Data(imgList1,imgList2,floTruth,shuffle=True,minusMean=False)
    
    #call the model class
    model=Net(useGpu=True)
    #saver for the graph
    saver=tf.train.Saver(max_to_keep=5)
    with tf.Session() as sess:
        sess.run(model.init_all)
        #write the file to a directory
        writerTrain=tf.summary.FileWriter(dir_log,sess.graph)
        #for every epoch,run the iterations for all batches
        #xrange is being used since it uses lot less space
        for epoch in xrange(epochMax):
            lrDecay=0.1**(epoch/epochLrDecay)
            learnRate=initLr*lrDecay
            for iterations in xrange(iterPerEpoch):
                time_start = time.time()
                globalIter=epoch*iterPerEpoch+iterations
                feedImg1List,feedImg2List,groTruthList=trainDataset.upBatch()
                feedDict={model.img1:feedImg1List,model.img2:feedImg2List,
                        model.flow:groTruthList,model.learnRate:learnRate}
                _,mergedOut,lossOut=sess.run([model.trainOp,model.merged,model.loss],feedDict)
                
                writerTrain.add_summary(mergedOut,globalIter+1)
                hour_per_epoch = iterPerEpoch * ((time.time() - time_start) / 3600)
                print('%.2f h/epoch, epoch %03d/%03d, iter %04d/%04d, lr %.5f, loss: %.5f' %
                      (hour_per_epoch, epoch + 1, epochMax, iterations + 1, iterPerEpoch, learnRate, lossOut))
                
            if not (epoch+1)%1:
                saver.save(sess,("checkpoints/model"),global_step=epoch+1)
                tf.train.write_graph(sess.graph, "checkpoints/",
                     'saved_model.pb', as_text=False)

    

In [9]:

if __name__=='__main__':
    tf.reset_default_graph() 
    tf.app.run()
    

300
0.00 h/epoch, epoch 001/001, iter 0001/0025, lr 0.00100, loss: 2.75190
0.00 h/epoch, epoch 001/001, iter 0002/0025, lr 0.00100, loss: 1.66727
0.00 h/epoch, epoch 001/001, iter 0003/0025, lr 0.00100, loss: 1.12314
0.00 h/epoch, epoch 001/001, iter 0004/0025, lr 0.00100, loss: 0.92529
0.00 h/epoch, epoch 001/001, iter 0005/0025, lr 0.00100, loss: 0.86887
0.00 h/epoch, epoch 001/001, iter 0006/0025, lr 0.00100, loss: 0.87187
0.00 h/epoch, epoch 001/001, iter 0007/0025, lr 0.00100, loss: 0.93431
0.00 h/epoch, epoch 001/001, iter 0008/0025, lr 0.00100, loss: 0.95756
0.00 h/epoch, epoch 001/001, iter 0009/0025, lr 0.00100, loss: 0.93524
0.00 h/epoch, epoch 001/001, iter 0010/0025, lr 0.00100, loss: 0.94743
0.00 h/epoch, epoch 001/001, iter 0011/0025, lr 0.00100, loss: 0.93062
0.00 h/epoch, epoch 001/001, iter 0012/0025, lr 0.00100, loss: 0.91495
0.00 h/epoch, epoch 001/001, iter 0013/0025, lr 0.00100, loss: 0.87749
0.00 h/epoch, epoch 001/001, iter 0014/0025, lr 0.00100, loss: 0.84018
0.

SystemExit: 

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)
