# FlowNet Expert – Deep Learning for Optical Flow Workshop

Welcome to FlowNet! In this project, we're going to build a FlowNet algorithm with PyTorch! The idea is simple: given two consecutive images, output the optical flow between them!
<p>

![flownet](https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcQ_p_REZwjQ1YqfV51j8vQ1qJodRUDRI8Dd7tPuwbWW-tWUQBhKibGi3Bq1ox6SNp5k2ts&usqp=CAU)

In this project, we're going to:

1.   **Load and Prepare the Dataset** for the Model
2.   Define a **FlowNet Architecture**
3.   **Train the Model** on KITTI
4.   **Run the Model**

Just a note before we begin: this code has been adapted from the work of Clement Pinard, who authored FlowNet PyTorch. I have been in contact with Clement numerous times, and he helped me make this course and this code easy to follow.
<p>

For your information, [here is the original repo link](https://github.com/ClementPinard/FlowNetPytorch)

[Link to the paper](https://arxiv.org/pdf/1504.06852.pdf)

Let's begin with some synchronization and imports!

In [None]:
!wget https://thinkautonomous-flownet.s3.eu-west-3.amazonaws.com/flownet-data.zip && unzip flownet-data.zip && rm flownet-data.zip
!mkdir output
!ls

Imports

In [None]:
import cv2
import numpy as np
import matplotlib.pyplot as plt
import pickle
from google.colab.patches import cv2_imshow
import glob
import os.path
import os
from imageio import imread
import numbers
from pathlib import Path
import shutil
import random
import time
from tqdm import tqdm
import torch.utils.data as data
import torch
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms
import torch.nn.functional as F
import torch.nn as nn
from torch.nn.init import kaiming_normal_, constant_
from torch.utils.tensorboard import SummaryWriter

# Part I - Load and Prepare the Dataset for the Model

There are a few Optical Flow Datasets we can use:

*   Flying Chairs
*   Scene Flow (KITTI)
*   Middlebury
*   MPI Sintel
*   Kinetics

For the purpose of this course, we'll use the [KITTI Dataset](http://www.cvlibs.net/datasets/kitti/eval_scene_flow.php?benchmark=flow) as it's the closest to autonomous driving.

In [None]:
images_dataset = sorted(glob.glob("dataset/images_2/*.png"))
labels_dataset = sorted(glob.glob("dataset/flow_occ/*.png"))

print(len(images_dataset))
print(len(labels_dataset))

### 1.1 – Understand input/labels

Here's what we want:
*   **Input:** A pair of 2 consecutive images
*   **Labels:** The Flow Map
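
The pairing can be sketched from file names alone. This sketch assumes the naming used in this workshop's data: `<frame>_10.png` and `<frame>_11.png` for the two images, and `<frame>_10.png` for the flow map. `make_triplet` is a hypothetical helper for illustration, not part of the original code.

```python
from pathlib import PurePosixPath

def make_triplet(flow_path, image_dir="dataset/images_2"):
    """Return [[img_t0, img_t1], flow_path] for one flow map path."""
    # "000042_10.png" -> stem "000042_10" -> frame id "000042"
    frame_id = PurePosixPath(flow_path).stem.split("_")[0]
    img_t0 = str(PurePosixPath(image_dir) / f"{frame_id}_10.png")
    img_t1 = str(PurePosixPath(image_dir) / f"{frame_id}_11.png")
    return [[img_t0, img_t1], flow_path]

triplet = make_triplet("dataset/flow_occ/000042_10.png")
print(triplet)
```

Parsing the stem instead of slicing by character position keeps the helper working if the directory name changes.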


In [None]:
"""
Make a List of (Img1, Img2, Flow Map)
"""

images = []
for flow_map in labels_dataset:
    root_filename = flow_map[-13:-7]
    img1 = os.path.join("dataset/images_2/", root_filename+'_10.png')
    img2 = os.path.join("dataset/images_2/", root_filename+'_11.png')
    images.append([[img1, img2], flow_map])

In [None]:
# images = [[[RGB_image_t, RGB_image_t+1], flowMap], [[RGB_image2_t, RGB_image2_t+1], flowMap2], ...]
# 400 RGB_images, 200 flowMaps
print(images)

In [None]:
def bgr2rgb(image):
    """
    Convert BGR TO RGB
    """
    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

In [None]:
cv2_imshow(bgr2rgb(cv2.imread(images[0][1])))

In [None]:
# SIMPLE WORKAROUND

# cv2.IMREAD_UNCHANGED (flag=-1): reads the image as is from the source. If the source image is an RGB, 
# it loads the image into array with Red, Green and Blue channels. If the source image is ARGB, it loads 
# the image with three color components along with the alpha or transparency channel.

#TODO: Read the Images with CV2 IMREAD and -1 as a Flag
yuv = cv2.imread(images[0][1], -1)
# print("yuv: ", yuv)
# yuv_img = yuv[:,:,2:0:-1].astype(np.float32)   # (375, 1242, 2)
# print("yuv_img: ", yuv_img)


#TODO: Convert from YUV to RGB 
rgb_map = cv2.cvtColor(yuv, cv2.COLOR_YUV2RGB)
# print("rgb_map: ", rgb_map)

# make the black background become white background 
rgb_map[np.where((rgb_map==[0,0,0]).all(axis=2))] = [255,255,255]

plt.imshow(rgb_map)
plt.show()

In [None]:
!pip install pypng
from read_kitti import read_png_file, flow_to_image
# Huge Thanks: https://github.com/liruoteng/OpticalFlowToolkit/tree/master/lib

In [None]:
"""
Let's read a triplet of images!
"""
idx = random.randint(0, len(images) - 1)  # randint includes both endpoints, so stop at the last valid index

image_t0 = bgr2rgb(cv2.imread(images[idx][0][0])) #Read the first image at idx in RGB
image_t1 = bgr2rgb(cv2.imread(images[idx][0][1])) #Read the second image at idx in RGB
flo_path = images[idx][1] #Get the Flow Path

flow_label = read_png_file(flo_path)
rgb_map = flow_to_image(flow_label)

"""
Visualize the Data
"""

f, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(30,20))
ax1.imshow(image_t0)
ax1.set_title('Image t0', fontsize=30)
ax2.imshow(image_t1)
ax2.set_title('Image t1', fontsize=30)
ax3.imshow(rgb_map)
ax3.set_title("Flow (label)", fontsize=30)

### 1.2 – Split the Data into Train/Test
We're going to split the dataset into training and testing. A good ratio would be 80% training and 20% testing.
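
A per-sample random draw gives a split that is only approximately 80/20. If exact sizes matter, one possible alternative is to shuffle once and cut the list. This is a minimal sketch; `split_exact` and its `seed` argument are hypothetical names, not part of the workshop code.

```python
import random

def split_exact(samples, train_ratio=0.8, seed=0):
    """Shuffle a copy of `samples` and cut it at `train_ratio`."""
    rng = random.Random(seed)          # local generator, so the split is reproducible
    shuffled = list(samples)
    rng.shuffle(shuffled)
    cut = int(round(train_ratio * len(shuffled)))
    return shuffled[:cut], shuffled[cut:]

train_part, test_part = split_exact(range(200))
print(len(train_part), len(test_part))
```

With 200 samples this yields exactly 160 training and 40 testing samples, and the same seed always reproduces the same split.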

In [None]:
def train_test_split(images, default_split=0.8):
    """
    Splits the Dataset Paths into Train/Test
    """
    split_values = np.random.uniform(0,1,len(images)) < default_split # Randomly decides if an image is train or test
    train_samples = [sample for sample, split in zip(images, split_values) if split]
    test_samples = [sample for sample, split in zip(images, split_values) if not split]
    return train_samples, test_samples

In [None]:
#Call Train/Test split function (easy easy)
train_samples, test_samples = train_test_split(images)

In [None]:
print(len(train_samples))
print(len(test_samples))

### 1.3 – Load the Images

So far, we have:
*   *images* – a list of triplet paths
*   *train_samples* and *test_samples* – these paths split into two sets

Now, we need to **convert these lists of paths into actual tensors** of images for PyTorch. We'll also need to **perform some transform operations** such as **normalization, random cropping,** etc...
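
Geometric transforms have to be applied to the two images and to the flow map together, otherwise the labels no longer match the pixels. The NumPy sketch below illustrates the idea under two assumptions: arrays are shaped (H, W, C), and channel 0 of the flow holds the horizontal component u. The function names are hypothetical and are not the `flow_transforms` API.

```python
import numpy as np

def random_crop_pair(images, flow, size, rng):
    """Crop both images and the flow map with the same random window."""
    h, w = flow.shape[:2]
    th, tw = size
    y = rng.integers(0, h - th + 1)
    x = rng.integers(0, w - tw + 1)
    images = [img[y:y + th, x:x + tw] for img in images]
    return images, flow[y:y + th, x:x + tw]

def horizontal_flip_pair(images, flow):
    """Mirror images and flow left-right; horizontal motion changes sign."""
    images = [np.fliplr(img).copy() for img in images]
    flow = np.fliplr(flow).copy()
    flow[:, :, 0] *= -1   # assumption: channel 0 is the horizontal component u
    return images, flow

rng = np.random.default_rng(0)
img_t0 = np.zeros((375, 1242, 3), dtype=np.float32)
img_t1 = np.zeros((375, 1242, 3), dtype=np.float32)
flow = np.ones((375, 1242, 2), dtype=np.float32)

(crop_t0, crop_t1), crop_flow = random_crop_pair([img_t0, img_t1], flow, (320, 448), rng)
_, flipped_flow = horizontal_flip_pair([crop_t0, crop_t1], crop_flow)
print(crop_t0.shape, crop_flow.shape)
```

The sign change is the important detail: a pixel moving right in the original moves left in the mirrored image.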
<p>


In [None]:
def load_flow_from_png(png_path):
    '''
    This is used to read flow label images from the KITTI Dataset
    '''
    # The Image is a 16 Bit(uint16) Image. We must read it with OpenCV and 
    # the flag cv2.IMREAD_UNCHANGED (-1)

    # As loaded by OpenCV (BGR channel order), the first channel denotes if the pixel is
    # valid or not (1 if true, 0 otherwise), the second channel contains the v-component
    # and the third channel the u-component.
    flo_file = cv2.imread(png_path, -1)   # (375, 1242, 3)
    flo_img = flo_file[:,:,2:0:-1].astype(np.float32)   # (375, 1242, 2)

    # See the README File in the KITTI DEVKIT AND THE FLOW READER FUNCTIONS
    # To convert the u-/v-flow into floating point values, convert the value 
    # to float, subtract 2^15(32768) and divide the result by 64.0
    invalid = (flo_file[:,:,0] == 0)
    flo_img = flo_img - 32768
    flo_img = flo_img / 64

    # Valid and Small Flow = 1e-10
    flo_img[np.abs(flo_img) < 1e-10] = 1e-10

    # Invalid Flow = 0
    flo_img[invalid, :] = 0
    return flo_img
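
The decoding rule used above (subtract 2^15, then divide by 64) can be checked on a few hand-picked values. This is a minimal sketch on a synthetic 1x3 array rather than a real KITTI file; `decode_kitti_flow` is a hypothetical helper.

```python
import numpy as np

def decode_kitti_flow(raw_uv, valid):
    """Turn raw uint16 flow values into pixel displacements; invalid pixels become 0."""
    flow = (raw_uv.astype(np.float32) - 2 ** 15) / 64.0
    flow[valid == 0] = 0
    return flow

# Three pixels: zero motion, (+1, -1) pixel of motion, and an invalid pixel
raw_uv = np.array([[[32768, 32768], [32832, 32704], [33408, 32768]]], dtype=np.uint16)
valid = np.array([[1, 1, 0]], dtype=np.uint8)

flow = decode_kitti_flow(raw_uv, valid)
print(flow)
```

A raw value of 32768 maps to zero displacement, and each step of 64 corresponds to one pixel.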

In [None]:
def KITTI_loader(root, path_imgs, path_flo):
    """
    Returns the Loaded Images in RGB, and the Loaded Optical Flow Labels
    """
    #TODO: Implement the function and return [[img1, img2], flow_image]
    imgs = [os.path.join(root, path) for path in path_imgs]
    flo = os.path.join(root, path_flo)
    # img[:,:,::-1] will do: bgr -> rgb or rgb -> bgr
    return [bgr2rgb(cv2.imread(img)).astype(np.float32) for img in imgs], load_flow_from_png(flo)

In [None]:
import flow_transforms

div_flow = 20 #Factor by which we divide the output (thus >=1). It makes training more stable to deal with low numbers than big ones.

#Normalization values for the Flying Chairs dataset (https://github.com/ClementPinard/FlowNetPytorch/issues/101#issuecomment-805222823)
input_transform = transforms.Compose([flow_transforms.ArrayToTensor(), transforms.Normalize(mean=[0,0,0], std=[255,255,255]), transforms.Normalize(mean=[0.45,0.432,0.411], std=[1,1,1])])

target_transform = transforms.Compose([flow_transforms.ArrayToTensor(),transforms.Normalize(mean=[0,0],std=[div_flow,div_flow])])

co_transform = flow_transforms.Compose([flow_transforms.RandomCrop((320,448)), flow_transforms.RandomVerticalFlip(),flow_transforms.RandomHorizontalFlip()])

In [None]:
class ListDataset(data.Dataset):
    def __init__(self, path_list, transform=None, target_transform=None, co_transform=None, loader=KITTI_loader):
        self.root = os.getcwd()
        self.path_list = path_list
        self.transform = transform
        self.target_transform = target_transform
        self.co_transform = co_transform
        self.loader = loader

    def __getitem__(self, index):
      
        """
        __getitem__ is the method Python calls when an object is indexed, e.g. dataset[index].
        Here it loads the sample at that index, applies the transforms and returns (inputs, target),
        so no separate .read() method is needed; the DataLoader uses it to fetch samples.
        __len__ plays the same role for len(dataset).
        """
        inputs, target = self.path_list[index]
        inputs, target = self.loader(self.root, inputs, target)
        if self.co_transform is not None:
            inputs, target = self.co_transform(inputs, target)
        if self.transform is not None:
            inputs[0] = self.transform(inputs[0])
            inputs[1] = self.transform(inputs[1])
        if self.target_transform is not None:
            target = self.target_transform(target)
        return inputs, target

    def __len__(self):
        return len(self.path_list)

In [None]:
train_dataset = ListDataset(train_samples, input_transform, target_transform, co_transform, loader=KITTI_loader)

test_dataset = ListDataset(test_samples, input_transform, target_transform, flow_transforms.CenterCrop((370,1224)), loader=KITTI_loader)

If you don't understand the ListDataset class and how `__getitem__()` works, a good exercise is to re-code the entire class in a way you understand better. Otherwise, let's move on to the Optical Flow Network.
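
As a warm-up for that exercise, here is a toy class in plain Python showing what `__getitem__` and `__len__` do. `SquaresDataset` is a hypothetical example, unrelated to KITTI.

```python
class SquaresDataset:
    """Toy dataset: item i is the pair (i, i squared)."""

    def __init__(self, n):
        self.n = n

    def __getitem__(self, index):
        # Called by Python for dataset[index]
        if not 0 <= index < self.n:
            raise IndexError(index)
        return index, index ** 2

    def __len__(self):
        # Called by Python for len(dataset)
        return self.n

dataset = SquaresDataset(4)
print(len(dataset), dataset[3])
print(list(dataset))
```

ListDataset follows the same pattern: the index selects a triplet of paths, and the returned item is the loaded and transformed sample.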

# Part II – Build a FlowNet Architecture

FlowNet has two variations:
*   **FlownetS** or Simple, which is a simple version using 2D Convolutions to get to the optical flow computation
*   **FlownetC** or Correlated, which processes the two images separately and then adds a correlation layer

In both, there are two main parts:
*   An **Encoder** Part, learning features
*   A **Refinement** Part, playing the decoder and creating the output flow map.
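
At each refinement level, the decoder concatenates three things along the channel axis: the encoder features of that level, the upsampled decoder features, and the 2-channel upsampled flow. The short sketch below reproduces that bookkeeping, using the channel counts of the FlowNetS class defined later in this workshop; `concat_depth` is a hypothetical helper.

```python
def concat_depth(skip_channels, deconv_channels, flow_channels=2):
    """Channels after concatenating encoder features, deconv features and flow."""
    return skip_channels + deconv_channels + flow_channels

levels = {
    "concat5": concat_depth(512, 512),   # conv5_1 + deconv5 + flow6
    "concat4": concat_depth(512, 256),   # conv4_1 + deconv4 + flow5
    "concat3": concat_depth(256, 128),   # conv3_1 + deconv3 + flow4
    "concat2": concat_depth(128, 64),    # conv2   + deconv2 + flow3
}
for name, depth in levels.items():
    print(name, depth)
```

These totals are the input depths the next deconvolution and flow prediction layers must accept.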

It looks like we've got some work! In this workshop, we'll build the FlowNet S architecture, as the researchers mentioned it worked best on KITTI! You can find the implementation for the FlowNet C architecture in the course for your information.
<p>
Here's a look at the FlowNet S model:

![flownets](https://miro.medium.com/max/1400/0*XVygX0wF3enVQJLe.)  

And the refinement part:<p>

![refinement](https://i1.wp.com/syncedreview.com/wp-content/uploads/2017/09/image-14.png?fit=692%2C268&ssl=1)

### 2.1 – Code the necessary operations

Operations we'll need:
* Convolutions in 2D (with or without batchnorm)
* Output Flow Prediction
* Transposed Convolutions
* Crop
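
The crop operation is needed because of size arithmetic. A strided convolution rounds odd sizes down, while a transposed convolution with kernel 4, stride 2 and padding 1 exactly doubles its input. The sketch below applies the standard output-size formulas to the kernel sizes and strides used in this workshop's encoder; the helper names are hypothetical.

```python
def conv_out(n, kernel, stride, padding):
    """Output length of a convolution along one dimension."""
    return (n + 2 * padding - kernel) // stride + 1

def deconv_out(n, kernel=4, stride=2, padding=1):
    """Output length of a transposed convolution along one dimension."""
    return (n - 1) * stride - 2 * padding + kernel

# (kernel, stride) for conv1, conv2, conv3, conv3_1, ..., conv6, conv6_1
ENCODER = [(7, 2), (5, 2), (5, 2), (3, 1), (3, 2), (3, 1), (3, 2), (3, 1), (3, 2), (3, 1)]

def encoder_sizes(n):
    sizes = []
    for kernel, stride in ENCODER:
        n = conv_out(n, kernel, stride, (kernel - 1) // 2)
        sizes.append(n)
    return sizes

print(encoder_sizes(320))   # training crop height
print(encoder_sizes(370))   # test crop height
print(deconv_out(24))       # upsampling the 24-row level of the 370 input
```

For a height of 320 every level halves cleanly, so nothing is cropped. For a height of 370 the encoder produces 47 rows at one level, while upsampling the 24-row level below gives 48, so one row has to be cropped before concatenation.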

In [None]:
# Define a Convolution with LeakyReLU, and with or without batchnorm
def conv(batchNorm, in_planes, out_planes, kernel_size=3, stride=1):
    if batchNorm:
        #TODO: Code a Convolution in 2D with Batchnorm and LeakyReLU of 0.1
        return nn.Sequential(
        nn.Conv2d(in_planes, out_planes, kernel_size, stride=stride, padding=(kernel_size-1)//2, bias=False),
        nn.BatchNorm2d(out_planes),
        nn.LeakyReLU(0.1, inplace=True)
        )
    else:
        #TODO: Code a Convolution in 2D with LeakyReLU of 0.1 and without Batchnorm
        return nn.Sequential(
        nn.Conv2d(in_planes, out_planes, kernel_size, stride=stride, padding=(kernel_size-1)//2, bias=True),  # no BatchNorm shift here, so keep a learnable bias
        nn.LeakyReLU(0.1, inplace=True)
        )

In [None]:
#Define the last convolution (optical flow map prediction)
def predict_flow(in_planes):
    # TODO: Code a Convolution to predict the output
    # The depth of the flow map is 2, which contains (u, v)
    # The kernel size in the paper is 5, but we need 3 to make the dimensions (width and height) right
    return nn.Conv2d(in_planes, 2, kernel_size=3, stride=1, padding=1, bias=False)

In [None]:
#Define a Deconvolution
def deconv(in_planes, out_planes):
    #TODO: Code a Transposed Convolution to upsample the results
    return nn.Sequential(
        nn.ConvTranspose2d(in_planes, out_planes, 4, stride=2, padding=1, bias=False),
        nn.LeakyReLU(0.1, inplace=True)
    )

In [None]:
#Define a Cropping Operation
def crop_like(input, target):
    if input.size()[2:] == target.size()[2:]:
        # input shape = [B, C, H, W]
        # if the width and height of input are the same as target, return input
        return input
    else:
        # otherwise, crop the width and height
        return input[:, :, :target.size(2), :target.size(3)]

In [None]:
"""
test crop_like function
"""
input = torch.from_numpy(np.array([[[[1,2,3], [1,2,3], [1,2,3], [1,2,3], [1,2,3], [1,2,3], [1,2,3]],
                  [[1,2,3], [1,2,3], [1,2,3], [1,2,3], [1,2,3], [1,2,3], [1,2,3]],
                  [[1,2,3], [1,2,3], [1,2,3], [1,2,3], [1,2,3], [1,2,3], [1,2,3]],
                  [[1,2,3], [1,2,3], [1,2,3], [1,2,3], [1,2,3], [1,2,3], [1,2,3]],
                  [[1,2,3], [1,2,3], [1,2,3], [1,2,3], [1,2,3], [1,2,3], [1,2,3]]]]))
target = torch.from_numpy(np.array([[[[9,8,7,6,5,4,3,2,1], [9,8,7,6,5,4,3,2,1], [9,8,7,6,5,4,3,2,1]],
                   [[9,8,7,6,5,4,3,2,1], [9,8,7,6,5,4,3,2,1], [9,8,7,6,5,4,3,2,1]],
                   [[9,8,7,6,5,4,3,2,1], [9,8,7,6,5,4,3,2,1], [9,8,7,6,5,4,3,2,1]]]]))
# crop = crop_like(input, target)
print(target.size(2))
print(target.size(3))
print("target shape: ", target.size())
print("original input shape: ", input.size())
print("After cropping: ", input[:, :, :target.size(2), :target.size(3)].size())

### 2.2 – Create the FlowNet S Model

What PyTorch needs to create a model:

*   An **`__init__()` function** that declares all the operations (layers). Weights can be initialized here.
*   A **forward()** function that takes an input and computes the flow. Note: in training mode we return the flow maps of all scales; in testing mode we only return the final output flow.
*   **Weights and Biases**

In [None]:
class FlowNetS(nn.Module):
    expansion = 1

    def __init__(self,batchNorm=True):
        super(FlowNetS,self).__init__()

        #ENCODER PART
        #TODO: Code the Encoder functions
        self.batchNorm = batchNorm
        self.conv1 =   conv(self.batchNorm, 6,   64,   7, 2)   # padding = 3, stride = 2  (padding = (kernel_size-1) // 2)
        self.conv2 =   conv(self.batchNorm, 64, 128,   5, 2)   # padding = 2, stride = 2
        self.conv3 =   conv(self.batchNorm, 128, 256,  5, 2)   # padding = 2, stride = 2
        self.conv3_1 = conv(self.batchNorm, 256, 256,  3)      # padding = 1, stride = 1
        self.conv4 =   conv(self.batchNorm, 256, 512,  3, 2)   # padding = 1, stride = 2
        self.conv4_1 = conv(self.batchNorm, 512, 512,  3)      # padding = 1, stride = 1
        self.conv5 =   conv(self.batchNorm, 512, 512,  3, 2)   # padding = 1, stride = 2
        self.conv5_1 = conv(self.batchNorm, 512, 512,  3)      # padding = 1, stride = 1
        self.conv6 =   conv(self.batchNorm, 512, 1024, 3, 2)   # padding = 1, stride = 2
        self.conv6_1 = conv(self.batchNorm, 1024, 1024)     # NOTE: this one doesn't exist in the paper, but it does in their implementation.

        # NOTE: the real implementation does not have a 1x1 convolution

        #REFINEMENT PART
        #TODO: Code the Decoder functions and Flow Predictions
        # NOTE: need to use the same name as the pre-trained weights
        self.deconv5 = deconv(1024, 512)
        self.deconv4 = deconv(1026, 256)   # inputDepth = 512 + 512(conv5_1) + 2(flow6) = 1026   NOTE: flow6 doesn't exist in the paper
        self.deconv3 = deconv(770, 128)    # inputDepth = 256 + 512(conv4_1) + 2(flow5) = 770
        self.deconv2 = deconv(386, 64)     # inputDepth = 128 + 256(conv3_1) + 2(flow4) = 386

        self.predict_flow6 = predict_flow(1024)   # NOTE: this one doesn't exist in the paper, but it does in their implementation
        self.predict_flow5 = predict_flow(1026)   
        self.predict_flow4 = predict_flow(770)
        self.predict_flow3 = predict_flow(386)
        self.predict_flow2 = predict_flow(194)    # 194 = 64(deconv2) + 128(conv2) + 2(flow3)

        self.upsampled_flow6_to_5 = nn.ConvTranspose2d(2, 2, 4, stride=2, padding=1, bias=False)
        self.upsampled_flow5_to_4 = nn.ConvTranspose2d(2, 2, 4, stride=2, padding=1, bias=False)
        self.upsampled_flow4_to_3 = nn.ConvTranspose2d(2, 2, 4, stride=2, padding=1, bias=False)
        self.upsampled_flow3_to_2 = nn.ConvTranspose2d(2, 2, 4, stride=2, padding=1, bias=False)
        
        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
                # Initialize the convolution weights with "He initialization", using a=0.1 (the LeakyReLU negative slope) (https://arxiv.org/pdf/1502.01852.pdf)
                kaiming_normal_(m.weight, 0.1)
                if m.bias is not None:
                    # Initialize all bias to 0
                    constant_(m.bias, 0)
            # Initialize the BatchNorm layers with weight (scale) 1 and bias (shift) 0
            elif isinstance(m, nn.BatchNorm2d):
                constant_(m.weight, 1)
                constant_(m.bias, 0)

    def forward(self, x):
        #TODO: ENCODER
        # NOTE: the input size is different than the paper
        # x = [batch_size, 6, 320, 448]   
        conv1 = self.conv1(x)           # batch_size x 64 x 160 x 224
        
        conv2 = self.conv2(conv1)       # batch_size x 128 x 80 x 112

        conv3 = self.conv3(conv2)       # batch_size x 256 x 40 x 56
        conv3_1 = self.conv3_1(conv3)   # batch_size x 256 x 40 x 56

        conv4 = self.conv4(conv3_1)     # batch_size x 512 x 20 x 28
        conv4_1 = self.conv4_1(conv4)   # batch_size x 512 x 20 x 28

        conv5 = self.conv5(conv4_1)     # batch_size x 512 x 10 x 14
        conv5_1 = self.conv5_1(conv5)   # batch_size x 512 x 10 x 14

        conv6 = self.conv6(conv5_1)     # batch_size x 1024 x 5 x 7
        conv6_1 = self.conv6_1(conv6)   # batch_size x 1024 x 5 x 7
        
        #TODO: REFINEMENT
        flow6 = self.predict_flow6(conv6_1)                                 # batch_size x 2 x 5 x 7
        flow6_upsampling = self.upsampled_flow6_to_5(flow6)                 # batch_size x 2 x 10 x 14
        flow6_upsampling = crop_like(flow6_upsampling, conv5_1)             # batch_size x 2 x 10 x 14

        deconv5 = self.deconv5(conv6_1)                                     # batch_size x 512 x 10 x 14
        deconv5 = crop_like(deconv5, conv5_1)                               # batch_size x 512 x 10 x 14
        concat5 = torch.cat((conv5_1, deconv5, flow6_upsampling), 1)        # batch_size x 1026 x 10 x 14 
        flow5 = self.predict_flow5(concat5)                                 # batch_size x 2 x 10 x 14      NOTE: the flow map size is different than the paper
        flow5_upsampling = self.upsampled_flow5_to_4(flow5)                 # batch_size x 2 x 20 x 28
        flow5_upsampling = crop_like(flow5_upsampling, conv4_1)             # batch_size x 2 x 20 x 28

        deconv4 = self.deconv4(concat5)                                     # batch_size x 256 x 20 x 28
        deconv4 = crop_like(deconv4, conv4_1)                               # batch_size x 256 x 20 x 28
        concat4 = torch.cat((conv4_1, deconv4, flow5_upsampling), 1)        # batch_size x 770 x 20 x 28
        flow4 = self.predict_flow4(concat4)                                 # batch_size x 2 x 20 x 28
        flow4_upsampling = self.upsampled_flow4_to_3(flow4)                 # batch_size x 2 x 40 x 56
        flow4_upsampling = crop_like(flow4_upsampling, conv3_1)             # batch_size x 2 x 40 x 56

        deconv3 = self.deconv3(concat4)                                     # batch_size x 128 x 40 x 56
        deconv3 = crop_like(deconv3, conv3_1)                               # batch_size x 128 x 40 x 56
        concat3 = torch.cat((conv3_1, deconv3, flow4_upsampling), 1)        # batch_size x 386 x 40 x 56
        flow3 = self.predict_flow3(concat3)                                 # batch_size x 2 x 40 x 56
        flow3_upsampling = self.upsampled_flow3_to_2(flow3)                 # batch_size x 2 x 80 x 112
        flow3_upsampling = crop_like(flow3_upsampling, conv2)               # batch_size x 2 x 80 x 112

        deconv2 = self.deconv2(concat3)                                     # batch_size x 64 x 80 x 112
        deconv2 = crop_like(deconv2, conv2)                                 # batch_size x 64 x 80 x 112
        concat2 = torch.cat((conv2, deconv2, flow3_upsampling), 1)          # batch_size x 194 x 80 x 112
        flow2 = self.predict_flow2(concat2)                                 # batch_size x 2 x 80 x 112

        if self.training:
            return flow2,flow3,flow4,flow5,flow6
        else:
            return flow2

    def weight_parameters(self):
        return [param for name, param in self.named_parameters() if 'weight' in name]

    def bias_parameters(self):
        return [param for name, param in self.named_parameters() if 'bias' in name]


### 2.3 – Create an Empty FlowNet S Model (with or without pretrained weights)

In [None]:
#Define FlowNet S
def flownets(data=None, batchNorm=False):
    """FlowNetS model architecture from the
    "Learning Optical Flow with Convolutional Networks" paper (https://arxiv.org/abs/1504.06852)
    Args:
        data : pretrained weights of the network. will create a new one if not set
    """
    model = FlowNetS(batchNorm=batchNorm)
    if data is not None:
        model.load_state_dict(data['state_dict'])
    return model

If you'd like to create an empty model, simply call the flownets() function without parameters.

⚠️ However, **the KITTI Dataset only has 200 data points**. It might be very hard to converge.

👉 A good solution is to **load the weights already trained on the Flying Chairs dataset** by Clement Pinard, and then **finetune the model on KITTI**. This is called Transfer Learning, and it works well in cases like this one, where the dataset is small.

In [None]:
model_to_load = "models/flownets_bn_EPE2.459.pth.tar"
#model_to_load = "models/model_best.pth.tar"
checkpoint = torch.load(model_to_load, map_location="cpu") # load on the CPU so this also works without a GPU; the model is moved to the GPU later if one is available
model = flownets(data=checkpoint, batchNorm=True)

print(model)
# model = flownets()   # alternative: an empty model without pretrained weights

# Part III - Train the Model on KITTI
To train a Deep Learning Model, we'll need:

*   Data
*   A Model
*   Parameters
*   A Loss Function


### 3.1 – Define Hyperparameters and Variables


In [None]:
# TODO: fill the hyperparameters
arch = "flownetsbn"
solver = "adam"
epochs = 200
epoch_size = 0
batch_size = 64
learning_rate = 1e-4
workers = 4   # how many subprocesses to use for data loading.
pretrained = None
bias_decay = 0
weight_decay = 4e-4
momentum = 0.9
milestones = [100, 150, 200] # epochs at which the learning rate is halved
             
save_path = '{},{},{}epochs,b{},lr{}'.format(arch, solver, epochs, batch_size, learning_rate)

if not os.path.exists(save_path):
    os.makedirs(save_path)

# When an optimizer is instantiated, it receives the parameters to optimize as well as hyperparameters such as the learning rate.
# Optimizers also accept hyperparameters specific to each optimization algorithm. It can be useful to apply different
# hyperparameters to different parts of the model. This is done with parameter groups: a list of dictionaries passed to the optimizer.

# The first argument must be an iterable of torch.Tensors or of dictionaries; each dictionary defines one parameter group, with a 'params' key plus any option overrides.
# The parameters themselves must be given as an ordered collection, such as a list, so that their order is consistent between runs.
param_groups = [{'params': model.bias_parameters(), 'weight_decay': bias_decay},
                {'params': model.weight_parameters(), 'weight_decay': weight_decay}]

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    model = torch.nn.DataParallel(model).cuda()
    cudnn.benchmark = True

if solver == 'adam':
    #TODO: Create a Pytorch ADAM Optimizer
    optimizer = torch.optim.Adam(param_groups, lr=learning_rate)
elif solver == 'sgd':
    #TODO: Create a Pytorch SGD
    optimizer = torch.optim.SGD(param_groups, lr=learning_rate, momentum=momentum)

Writers are used to log values (losses, images) so that they can be visualized in TensorBoard.

In [None]:
train_writer = SummaryWriter(os.path.join(save_path,'train'))
test_writer = SummaryWriter(os.path.join(save_path,'test'))
output_writers = []

for i in range(3):
    output_writers.append(SummaryWriter(os.path.join(save_path,'test',str(i))))

In [None]:
# pin_memory: If True, the data loader will copy Tensors into CUDA pinned memory before returning them.
# shuffle: set to True to have the data reshuffled at every epoch
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, num_workers=workers, pin_memory=True, shuffle=True)
val_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size,num_workers=workers, pin_memory=True, shuffle=False)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=0.5)
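
The effect of the milestones can be worked out by hand: each milestone that has been reached multiplies the learning rate by `gamma`. The sketch below computes this in plain Python, assuming the scheduler is stepped once per epoch; `multistep_lr` is a hypothetical helper that mirrors the schedule without calling PyTorch.

```python
from bisect import bisect_right

def multistep_lr(epoch, base_lr=1e-4, milestones=(100, 150, 200), gamma=0.5):
    """Learning rate at `epoch`: base_lr times gamma for every milestone already reached."""
    reached = bisect_right(sorted(milestones), epoch)
    return base_lr * gamma ** reached

for epoch in (0, 99, 100, 150, 199):
    print(epoch, multistep_lr(epoch))
```

With 200 epochs numbered from 0 to 199, the last milestone (200) is never reached, so the learning rate is halved only twice during training.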

In [None]:
for imgs, labels in train_loader:
    print("{} images with t=0 and t=1".format(len(imgs)))       # 2 images with t=0 and t=1
    print("image shape = [B={}, C={}, H={}, W={}]".format(len(imgs[0]), len(imgs[0][0]), len(imgs[0][0][0]), len(imgs[0][0][0][0])))   # image shape = [B=64, C=3, H=320, W=448]
    print("label shape = [B={}, C={}, H={}, W={}]".format(len(labels), len(labels[0]), len(labels[0][0]), len(labels[0][0][0])))       # label shape = [B=64, C=2, H=320, W=448]
    input = torch.cat(imgs, 1).to(device)
    output = model(input)
    print("{} predictions".format(len(output)))     # 5 predictions: flow2, flow3, flow4, flow5, flow6
    print("predicted flow2 shape = [B={}, C={}, H={}, W={}]".format(len(output[0]), len(output[0][0]), len(output[0][0][0]), len(output[0][0][0][0])))   # predicted flow2 shape = [B=64, C=2, H=80, W=112]
    print("predicted flow3 shape = [B={}, C={}, H={}, W={}]".format(len(output[0]), len(output[0][0]), len(output[1][0][0]), len(output[1][0][0][0])))   # predicted flow3 shape = [B=64, C=2, H=40, W=56]
    print("predicted flow4 shape = [B={}, C={}, H={}, W={}]".format(len(output[0]), len(output[0][0]), len(output[2][0][0]), len(output[2][0][0][0])))   # predicted flow4 shape = [B=64, C=2, H=20, W=28]
    print("predicted flow5 shape = [B={}, C={}, H={}, W={}]".format(len(output[0]), len(output[0][0]), len(output[3][0][0]), len(output[3][0][0][0])))   # predicted flow5 shape = [B=64, C=2, H=10, W=14]
    print("predicted flow6 shape = [B={}, C={}, H={}, W={}]".format(len(output[0]), len(output[0][0]), len(output[4][0][0]), len(output[4][0][0][0])))   # predicted flow6 shape = [B=64, C=2, H=5, W=7]
    break

for imgs, labels in val_loader:
    print("{} images with t=0 and t=1".format(len(imgs)))       # 2 images with t=0 and t=1
    # we only have 200 samples in total, so the batch size of the val_loader depends on len(test_samples), something like 44, 46...
    print("image shape = [B={}, C={}, H={}, W={}]".format(len(imgs[0]), len(imgs[0][0]), len(imgs[0][0][0]), len(imgs[0][0][0][0])))   # image shape = [B=, C=3, H=370, W=1224]
    print("label shape = [B={}, C={}, H={}, W={}]".format(len(labels), len(labels[0]), len(labels[0][0]), len(labels[0][0][0])))       # label shape = [B=, C=2, H=370, W=1224]
    input = torch.cat(imgs, 1).to(device)
    output = model(input)
    print("{} predictions".format(len(output)))     # 5 predictions: flow2, flow3, flow4, flow5, flow6
    print("predicted flow2 shape = [B={}, C={}, H={}, W={}]".format(len(output[0]), len(output[0][0]), len(output[0][0][0]), len(output[0][0][0][0])))   # predicted flow2 shape = [B=, C=2, H=93, W=306]
    print("predicted flow3 shape = [B={}, C={}, H={}, W={}]".format(len(output[0]), len(output[0][0]), len(output[1][0][0]), len(output[1][0][0][0])))   # predicted flow3 shape = [B=, C=2, H=47, W=153]
    print("predicted flow4 shape = [B={}, C={}, H={}, W={}]".format(len(output[0]), len(output[0][0]), len(output[2][0][0]), len(output[2][0][0][0])))   # predicted flow4 shape = [B=, C=2, H=24, W=77]
    print("predicted flow5 shape = [B={}, C={}, H={}, W={}]".format(len(output[0]), len(output[0][0]), len(output[3][0][0]), len(output[3][0][0][0])))   # predicted flow5 shape = [B=, C=2, H=12, W=39]
    print("predicted flow6 shape = [B={}, C={}, H={}, W={}]".format(len(output[0]), len(output[0][0]), len(output[4][0][0]), len(output[4][0][0][0])))   # predicted flow6 shape = [B=, C=2, H=6, W=20]

    break

### 3.2 – Define the Loss Function as the End Point Error (EPE)

FlowNet (and most optical flow algorithms) uses the end point error (EPE) both as the loss function and as the evaluation metric.
It is simply the Euclidean distance between the ground-truth flow vector and the predicted one, averaged over the pixels.<p>
EPE = ![](https://latex.codecogs.com/gif.latex?%5Cinline%20%5Cleft%20%5C%7CV_%7Best%7D%20-%20V_%7Bgt%7D%20%5Cright%20%5C%7C)

As the model outputs different flow maps, at different scales, we'll need to create different EPE functions.
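
Before the PyTorch version, here is a small NumPy sketch of the EPE on two pixels. It assumes flows shaped (2, H, W) and treats ground-truth pixels equal to (0, 0) as invalid, the convention used when loading the KITTI labels earlier. `epe_numpy` is a hypothetical helper.

```python
import numpy as np

def epe_numpy(pred, target, sparse=False):
    """Mean end point error for flows shaped (2, H, W)."""
    epe_map = np.linalg.norm(target - pred, axis=0)   # per-pixel Euclidean distance
    if sparse:
        # Ignore pixels whose ground truth is exactly (0, 0), i.e. invalid
        valid = ~((target[0] == 0) & (target[1] == 0))
        epe_map = epe_map[valid]
    return epe_map.mean()

# Two pixels: ground truth (u, v) = (3, 4) and an invalid (0, 0) pixel
target = np.array([[[3.0, 0.0]], [[4.0, 0.0]]])
# Predictions (u, v) = (0, 0) and (1, 0)
pred = np.array([[[0.0, 1.0]], [[0.0, 0.0]]])

print(epe_numpy(pred, target))
print(epe_numpy(pred, target, sparse=True))
```

The per-pixel errors are 5 and 1. The dense mean counts both pixels, while the sparse version keeps only the valid pixel, which is why the two results differ.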

In [None]:
import numpy as np
target = np.array([[0,1,2,3,4], [9,0,8,0,7]])
mask = (target[:,0] == 0) & (target[:,1]==0)
print(~mask)

In [None]:
def EPE(input_flow, target_flow, sparse=False, mean=True):
    #TODO: Define the norm between target and prediction
    EPE_map = torch.norm(target_flow - input_flow, p=2, dim=1)

    batch_size = EPE_map.size(0)
    if sparse:
        # invalid flow is defined with both flow coordinates to be exactly 0
        mask = (target_flow[:,0] == 0) & (target_flow[:,1] == 0)
        # print("mask: ", mask)
        EPE_map = EPE_map[~mask]
        # print("EPE_map: ", EPE_map)
    if mean:
        return EPE_map.mean()
    else:
        return EPE_map.sum()/batch_size

In [None]:
def realEPE(output, target, sparse=False):
    """
    Since the prediction is not the same size as the ground truth,
    we need to resize it to the same as the ground truth, and then calculate the EPE
    """
    b, _, h, w = target.size()
    upsampled_output = F.interpolate(output, (h,w), mode='bilinear', align_corners=False) # used to resize the output (import torch.nn.functional as F)
    return EPE(upsampled_output, target, sparse, mean=True)

In [None]:
def sparse_max_pool(input, size):
    '''
    Downsample the input by considering 0 values as invalid.
    Unfortunately, no generic interpolation mode can resize a sparse map correctly,
    the strategy here is to use max pooling for positive values and "min pooling"
    for negative values, the two results are then summed.
    This technique allows sparsity to be minimized, contrary to nearest interpolation,
    which could potentially lose information for isolated data points.
    '''

    positive = (input > 0).float()
    negative = (input < 0).float()
    output = F.adaptive_max_pool2d(input * positive, size) - F.adaptive_max_pool2d(-input * negative, size)
    return output


def multiscaleEPE(network_output, target_flow, weights=None, sparse=False):
    def one_scale(output, target, sparse):

        b, _, h, w = output.size()
        if sparse:
            target_scaled = sparse_max_pool(target, (h, w))
        else:
            target_scaled = F.interpolate(target, (h, w), mode='area')
        return EPE(output, target_scaled, sparse, mean=False)
    

    if not isinstance(network_output, (tuple, list)):
        # if the network_output is not a tuple or list, make it a list
        network_output = [network_output]
    if weights is None:
        weights = [0.005, 0.01, 0.02, 0.08, 0.32]  # as in original article
    assert(len(weights) == len(network_output))

    loss = 0
    for output, weight in zip(network_output, weights):
        loss += weight * one_scale(output, target_flow, sparse)
    return loss
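
The pooling trick described in the `sparse_max_pool` docstring can be reproduced on a tiny map. The sketch below is a NumPy re-implementation for fixed 2x2 blocks, not the PyTorch code used in the notebook. The name `sparse_pool_2x2` and the toy values are assumptions for illustration, and it assumes an even height and width.

```python
import numpy as np

def sparse_pool_2x2(flow):
    """Downsample an (H, W) map by 2, treating 0 as 'no measurement'."""
    h, w = flow.shape
    # blocks[i, a, j, b] == flow[2*i + a, 2*j + b]
    blocks = flow.reshape(h // 2, 2, w // 2, 2)
    # Max pooling over the positive values of each block
    positive = np.where(blocks > 0, blocks, 0.0).max(axis=(1, 3))
    # "Min pooling" over the negative values, done as max pooling of their opposites
    negative = np.where(blocks < 0, -blocks, 0.0).max(axis=(1, 3))
    return positive - negative

flow = np.array([
    [0.0, 0.0,  0.0, 0.0],
    [0.0, 2.5,  0.0, 0.0],
    [0.0, 0.0, -1.5, 0.0],
    [0.0, 0.0,  0.0, 0.0],
])
print(sparse_pool_2x2(flow))
```

Both isolated measurements survive the downsampling. Plain strided sampling such as `flow[::2, ::2]` would keep the `-1.5` here but lose the `2.5`, which is the loss of information the docstring warns about.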

### 3.3 Create functions to train and validate

We'll begin with a helper that is very common in PyTorch training scripts: an **AverageMeter()**. It is a small class that **stores the latest value** of a metric together with its running sum and count, so it can report a **running average** at any time. **That is useful here, because we have to average the loss and the EPE over many batches.**

In [None]:
class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __repr__(self):
        return '{:.3f} ({:.3f})'.format(self.val, self.avg)
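
To see why `update` takes the batch size `n`, here is the same bookkeeping done by hand on three batches. The values are made up for illustration.

```python
# (mean loss of the batch, number of samples in the batch); made-up values
batches = [(2.0, 8), (1.0, 8), (4.0, 4)]

running_sum = 0.0
running_count = 0
for batch_mean, n in batches:
    running_sum += batch_mean * n   # same as self.sum += val * n
    running_count += n              # same as self.count += n

weighted_avg = running_sum / running_count             # what AverageMeter.avg reports
naive_avg = sum(m for m, _ in batches) / len(batches)  # ignores the batch sizes
print(weighted_avg, naive_avg)
```

The two averages differ because the last batch is smaller. Weighting by `n` keeps a short final batch from counting as much as a full one.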

In [None]:
def save_checkpoint(state, is_best, save_path, filename='checkpoint.pth.tar'):
    """
    Save a checkpoint to continue training after a wifi problem 🙃
    """
    torch.save(state, os.path.join(save_path,filename))
    if is_best:
        shutil.copyfile(os.path.join(save_path,filename), os.path.join(save_path,'model_best.pth.tar'))

In [None]:
"""
The train() function trains the model for ONE EPOCH only.
"""

def train(train_loader, model, optimizer, epoch, train_writer):
    global n_iter, div_flow
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    flow2_EPEs = AverageMeter()

    multiscale_weights = [0.005,0.01,0.02,0.08,0.32] # from output_flow to flow6

    epoch_size = len(train_loader)

    # switch to train mode
    model.train()

    end = time.time()

    for i, (input, target) in enumerate(train_loader):
        # Go through the entire data loader
        data_time.update(time.time() - end)

        target = target.to(device)
        input = torch.cat(input,1).to(device)

        # Forward Pass
        output = model(input)  # in train mode the model returns one flow map per scale

        # Since Target pooling is not very precise when sparse,
        # take the highest resolution prediction and upsample it instead of downsampling target
        h, w = target.size()[-2:]
        output = [F.interpolate(output[0], (h,w)), *output[1:]]

        # Compute Multiscale EPE (for all predicted flows)
        loss = multiscaleEPE(output, target, weights=multiscale_weights, sparse=True)

        # Compute the Output EPE; multiplying by div_flow converts it back to pixels,
        # assuming the targets were divided by div_flow when the dataset was prepared
        flow2_EPE = div_flow * realEPE(output[0], target, sparse=True)

        # Record loss and EPE
        losses.update(loss.item(), target.size(0))
        train_writer.add_scalar('train_loss', loss.item(), n_iter)
        flow2_EPEs.update(flow2_EPE.item(), target.size(0))

        # compute gradient and do optimization step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % 1 == 0:
            # Print the Loss and EPE at every step (increase the modulus to print less often)
            print('Epoch: [{0}][{1}/{2}]\t Time {3}\t Data {4}\t Loss {5}\t EPE {6}'
                  .format(epoch, i, epoch_size, batch_time,
                          data_time, losses, flow2_EPEs))
        n_iter += 1
        if i >= epoch_size:
            break
    #Return the Average Loss and Average EPE on the Training Set
    return losses.avg, flow2_EPEs.avg
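
The `div_flow` factor in `train()` is easiest to understand with numbers. The sketch below assumes that the ground-truth flow was divided by `div_flow` when the dataset was prepared (that step is outside this section). The value `20.0` and the two vectors are made up for illustration.

```python
import math

div_flow = 20.0            # assumed value, for illustration only
target_px = (6.0, 8.0)     # true displacement in pixels
pred_px = (3.0, 4.0)       # predicted displacement in pixels

# What the loss sees: both vectors divided by div_flow
target_n = tuple(c / div_flow for c in target_px)
pred_n = tuple(c / div_flow for c in pred_px)

epe_normalized = math.hypot(pred_n[0] - target_n[0], pred_n[1] - target_n[1])
epe_pixels = div_flow * epe_normalized   # back to pixel units
print(epe_normalized, epe_pixels)
```

Because the endpoint error is a norm, scaling both vectors by `1 / div_flow` scales the error by the same factor, and multiplying by `div_flow` recovers the error in pixels.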


In [None]:
def validate(val_loader, model, epoch, output_writers):
    global div_flow
    batch_time = AverageMeter()
    flow2_EPEs = AverageMeter()

    # switch to evaluate mode
    model.eval()

    end = time.time()
    for i, (input, target) in enumerate(val_loader):
        #Go through the entire validation loader

        target = target.to(device)
        input = torch.cat(input,1).to(device)

        # Forward Pass
        output = model(input)  # in eval mode the model returns a single flow map

        # Compute the EPE, rescaled by div_flow as in train()
        flow2_EPE = div_flow * realEPE(output, target, sparse=True)

        # record EPE
        flow2_EPEs.update(flow2_EPE.item(), target.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i < len(output_writers):  # log first output of first batches
            if epoch == 0:
                mean_values = torch.tensor([0.45,0.432,0.411], dtype=input.dtype).view(3,1,1)
                output_writers[i].add_image('GroundTruth', flow2rgb(div_flow * target[0], max_value=10), 0)
                output_writers[i].add_image('Inputs', (input[0,:3].cpu() + mean_values).clamp(0,1), 0)
                output_writers[i].add_image('Inputs', (input[0,3:].cpu() + mean_values).clamp(0,1), 1)
            output_writers[i].add_image('FlowNet Outputs', flow2rgb(div_flow * output[0], max_value=10), epoch)

        if i % 5 == 0:
            print('Test: [{0}/{1}]\t Time {2}\t EPE {3}'
                  .format(i, len(val_loader), batch_time, flow2_EPEs))

    print(' * EPE {:.3f}'.format(flow2_EPEs.avg))
    # Return Average EPE on Validation Set
    return flow2_EPEs.avg

### 3.4 – Train the Model and Visualize the Output

In [None]:
def flow2rgb(flow_map, max_value):
    """
    Used to visualize the output after a forward pass
    https://github.com/ClementPinard/FlowNetPytorch/issues/86
    """
    flow_map_np = flow_map.detach().cpu().numpy()
    _, h, w = flow_map_np.shape
    flow_map_np[:,(flow_map_np[0] == 0) & (flow_map_np[1] == 0)] = float('nan')
    rgb_map = np.ones((3,h,w)).astype(np.float32)
    if max_value is not None:
        normalized_flow_map = flow_map_np / max_value
    else:
        # nanmax ignores the NaN values written above for pixels without flow
        normalized_flow_map = flow_map_np / np.nanmax(np.abs(flow_map_np))
    rgb_map[0] += normalized_flow_map[0]
    rgb_map[1] -= 0.5*(normalized_flow_map[0] + normalized_flow_map[1])
    rgb_map[2] += normalized_flow_map[1]
    return rgb_map.clip(0,1)
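
To get a feeling for the colors that `flow2rgb` produces, the channel formula can be applied to a single flow vector. This is a minimal sketch: the helper name `flow_color` is an assumption for illustration, and it leaves out the NaN masking of zero-flow pixels.

```python
import numpy as np

def flow_color(u, v, max_value):
    """Color of one flow vector (u, v), using the same channel formula as flow2rgb."""
    nu, nv = u / max_value, v / max_value
    rgb = np.array([1.0 + nu, 1.0 - 0.5 * (nu + nv), 1.0 + nv])
    return rgb.clip(0, 1)

print(flow_color(-10.0, 0.0, max_value=10))   # large negative first component
print(flow_color(0.0, -10.0, max_value=10))   # large negative second component
print(flow_color(10.0, 10.0, max_value=10))   # both components large and positive
```

With `max_value=10`, these three vectors map to cyan, yellow and magenta respectively, and everything in between blends toward white as the flow gets smaller.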

In [None]:
save_path = '{},{},{}epochs{},b{},lr{}'.format(arch, solver, epochs, ',epochSize'+str(epoch_size) if epoch_size > 0 else '', batch_size, learning_rate)
n_iter = 0
best_EPE = -1

# We'll start from a model pretrained on "Flying Chairs" and finetune it to KITTI

save_path = os.path.join("models",save_path)

print('=> will save everything to {}'.format(save_path))

if not os.path.exists(save_path):
    os.makedirs(save_path)

for epoch in range(0, epochs):
    # Train for one epoch
    train_loss, train_EPE = train(train_loader, model, optimizer, epoch, train_writer)
    train_writer.add_scalar('mean EPE', train_EPE, epoch)

    # Update the learning rate after the optimizer steps of this epoch
    # (the order PyTorch >= 1.1 expects; stepping first skips the first value of the schedule)
    scheduler.step()

    # Evaluate on validation set
    with torch.no_grad():
        endpointerror = validate(val_loader, model, epoch, output_writers)
    test_writer.add_scalar('mean EPE', endpointerror, epoch)

    # Store the best EPE
    if best_EPE < 0:
        best_EPE = endpointerror

    is_best = endpointerror < best_EPE
    best_EPE = min(endpointerror, best_EPE)
    save_checkpoint({
        'epoch': epoch + 1,
        'arch': arch,
        'state_dict': model.module.state_dict(),
        'best_EPE': best_EPE,
        'div_flow': div_flow
    }, is_best, save_path)
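
The bookkeeping around `best_EPE` can be traced on a short list of validation errors. The values below are made up for illustration; the logic is copied from the loop above.

```python
val_epes = [5.2, 4.8, 5.0, 4.1]   # made-up validation EPEs, one per epoch

best_EPE = -1
flags = []
for endpointerror in val_epes:
    if best_EPE < 0:
        best_EPE = endpointerror          # the first epoch only initializes the record
    is_best = endpointerror < best_EPE
    best_EPE = min(endpointerror, best_EPE)
    flags.append(is_best)

print(flags, best_EPE)
```

Note that the first epoch is never flagged as best, so `model_best.pth.tar` is only written once a later epoch improves on it. If the first epoch should count too, one option is to start from `best_EPE = float('inf')` and drop the `if best_EPE < 0` branch.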

# Part IV – Run the Model

### 4.1 – On 2 Images

In [None]:
input_transform=transforms.Compose([flow_transforms.ArrayToTensor(),
        transforms.Normalize(mean=[0,0,0], std=[255,255,255]),
        transforms.Normalize(mean=[0.411,0.432,0.45], std=[1,1,1])
    ])

network_data = torch.load("models/model_best.pth.tar")
#network_data = torch.load("models/flownetsbn,adam,200epochs,b64,lr0.001/checkpoint.pth.tar")
div_flow = network_data['div_flow']

model = flownets(network_data, batchNorm=True).to(device)

model.eval()

cudnn.benchmark = True

In [None]:
idx = random.randint(0, len(train_samples) - 1)  # randint includes both endpoints
img1_file = train_samples[idx][0][0]
img2_file = train_samples[idx][0][1]
flow_target = flow_to_image(read_png_file(train_samples[idx][1]))

with torch.no_grad():
    img1 = input_transform(imread(img1_file))
    img2 = input_transform(imread(img2_file))
    input_var = torch.cat([img1, img2]).unsqueeze(0)
    input_var = input_var.to(device)
    output = model(input_var)

    for suffix, flow_output in zip(['flow', 'inv_flow'], output):
        filename = img1_file[:-4]+"flow"
        rgb_flow = flow2rgb(div_flow * flow_output, max_value=None)
        rgb_flow= (rgb_flow * 255).astype(np.uint8).transpose(1,2,0)

f, (ax0, ax1, ax2) = plt.subplots(1, 3, figsize=(30,20))
ax0.imshow(cv2.imread(img1_file)[:,:,::-1])
ax0.set_title("Original Image", fontsize=30)
ax1.imshow(rgb_flow)
ax1.set_title('Prediction', fontsize=30)
ax2.imshow(flow_target)
ax2.set_title('Ground Truth', fontsize=30)


![](https://miro.medium.com/max/592/0*tRzHPmhbfDOfH6qw.jpg)

### 4.2 – On a Video

In [None]:
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

In [None]:
video_idx = 2
video_images = sorted(glob.glob("videos/video"+str(video_idx)+"/*.png"))
result_video = []

for idx_run, img in enumerate(video_images):
    if idx_run == 0:
        # The first frame has no predecessor, so we only store it
        im1 = imread(img)
    else:
        im2 = imread(img)
        with torch.no_grad():
            img1 = input_transform(im1)
            img2 = input_transform(im2)
            input_var = torch.cat([img1, img2]).unsqueeze(0)
            input_var = input_var.to(device)

            output = model(input_var)

            for suffix, flow_output in zip(['flow', 'inv_flow'], output):
                rgb_flow = flow2rgb(div_flow * flow_output, max_value=None)
                rgb_flow = (rgb_flow * 255).astype(np.uint8).transpose(1,2,0)
                result_video.append(cv2.cvtColor(rgb_flow, cv2.COLOR_RGB2BGR))
        # Slide the window so the next flow is computed between consecutive frames
        im1 = im2

In [None]:
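# VideoWriter expects (width, height); frames of another size are typically not written,
# so we read the size from the first flow image instead of hardcoding it
frame_h, frame_w = result_video[0].shape[:2]
out = cv2.VideoWriter("output/out-"+str(video_idx)+".mp4",cv2.VideoWriter_fourcc(*'MP4V'), 15.0, (frame_w, frame_h))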

for i in range(len(result_video)):
    out.write(result_video[i])
out.release()
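
Optical flow on a video is computed between consecutive frames, so N frames give N-1 flow maps. Here is a minimal sketch of that pairing, with hypothetical file names standing in for the sorted `video_images` list:

```python
frames = ["000.png", "001.png", "002.png", "003.png"]  # hypothetical, already sorted

# Pair every frame with its successor: (0, 1), (1, 2), (2, 3)
pairs = list(zip(frames, frames[1:]))
for first, second in pairs:
    print(first, "->", second)
```

The pairing only makes sense if the frames are in temporal order. `sorted(glob.glob(...))` sorts the names as text, so it relies on zero-padded frame numbers.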

In [None]:
from IPython.display import HTML
from base64 import b64encode
mp4 = open("output/out-"+str(video_idx)+".mp4",'rb').read()
data_url = "data:video/mp4;base64," + b64encode(mp4).decode()
HTML("""
<video width=800 controls>
      <source src="%s" type="video/mp4">
</video>
""" % data_url)

In [None]:
video_idx = 3
video_images = sorted(glob.glob("videos/video"+str(video_idx)+"/*.png"))  # same folder layout as the previous video cell
vid = []
for idx_run, img in enumerate(video_images):
    vid.append(cv2.imread(img).astype(np.uint8))

out = cv2.VideoWriter("output/out-"+str(video_idx)+".mp4",cv2.VideoWriter_fourcc(*'MP4V'), 15.0, (1242 ,375))

for i in range(len(vid)):
    out.write(vid[i])
out.release()

from IPython.display import HTML
from base64 import b64encode
mp4 = open("output/out-"+str(video_idx)+".mp4",'rb').read()
data_url = "data:video/mp4;base64," + b64encode(mp4).decode()
HTML("""
<video width=800 controls>
      <source src="%s" type="video/mp4">
</video>
""" % data_url)