<a href="https://colab.research.google.com/github/itberrios/CV_projects/blob/main/RAFT/RAFT_deep_dive.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **RAFT Deep Dive**

## NOTE: GPU is required for this tutorial

Get RAFT and the RAFT helper script

In [None]:
!git clone https://github.com/princeton-vl/RAFT.git

In [None]:
!wget https://raw.githubusercontent.com/itberrios/CV_projects/main/RAFT/raft_utils.py

In [None]:
import os
import sys
import numpy as np
import cv2
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F

from raft_utils import *

import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# add RAFT to core path
sys.path.append('RAFT/core')

In [None]:
# download models
%cd RAFT
!./download_models.sh
%cd ..

### Test Run RAFT

In [None]:
demo_path = 'RAFT/demo-frames'
frame1 = cv2.imread(os.path.join(demo_path, 'frame_0020.png'))
frame2 = cv2.imread(os.path.join(demo_path, 'frame_0021.png'))

frame1 = cv2.cvtColor(frame1, cv2.COLOR_BGR2RGB)
frame2 = cv2.cvtColor(frame2, cv2.COLOR_BGR2RGB)

In [None]:
_, ax = plt.subplots(1, 2, figsize=(20, 10))
ax[0].imshow(frame1)
ax[1].imshow(frame2);

In [None]:
# load model
model = load_model("RAFT/models/raft-sintel.pth", args=Args())

# predict Optical Flow
flow_iters = inference(model, frame1, frame2, device='cuda', test_mode=False)

In [None]:
_, (ax0, ax1) = plt.subplots(1,2, figsize=(20,10))

ax0.imshow(get_viz(flow_iters[0]))
ax0.set_title('first flow iteration')
ax1.imshow(get_viz(flow_iters[-1]))
ax1.set_title('final flow iteration');

Get the final optical flow iteration and find some pixels of interest, i.e. pixels with high flow and pixels with low flow

In [None]:
flow = flow_iters[-1].squeeze(0).cpu().numpy()
abs_flow = np.abs(flow)
flow.shape

In [None]:
# highest abs flow in each direction
hi_flow_1 = np.where(abs_flow == abs_flow[0, :, :].max()) # u - horizontal
hi_flow_2 = np.where(abs_flow == abs_flow[1, :, :].max()) # v - vertical

# lowest abs flow in each direction
lo_flow_1 = np.where(abs_flow == abs_flow[0, :, :].min()) # u - horizontal
lo_flow_2 = np.where(abs_flow == abs_flow[1, :, :].min()) # v - vertical

In [None]:
flow[lo_flow_1], flow[lo_flow_2], flow[hi_flow_1], flow[hi_flow_2]
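As a minimal sketch of picking out extreme-flow pixels one channel at a time, assuming a `(2, H, W)` flow array like the one above: the toy array and the helper name `channel_extreme` are illustrative only and are not part of `raft_utils`.

```python
import numpy as np

def channel_extreme(flow, channel, mode="max"):
    """Return (row, col) of the largest or smallest |flow| within one channel."""
    mag = np.abs(flow[channel])
    flat = mag.argmax() if mode == "max" else mag.argmin()
    return tuple(int(v) for v in np.unravel_index(flat, mag.shape))

# toy flow field: channel 0 is u (horizontal), channel 1 is v (vertical)
toy_flow = np.zeros((2, 4, 5), dtype=np.float32)
toy_flow[0, 1, 3] = -7.5   # strongest horizontal motion
toy_flow[1, 2, 0] = 3.0    # strongest vertical motion

hi_u = channel_extreme(toy_flow, 0, "max")
hi_v = channel_extreme(toy_flow, 1, "max")
print(hi_u, hi_v)
```

Searching inside a single channel means a match can only come from that channel, whereas comparing the full array against one channel's extreme value can also return locations from the other channel.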

# **Explore Different Blocks of RAFT**

First preprocess the data

In [None]:
demo_path = 'RAFT/demo-frames'
frame1 = cv2.imread(os.path.join(demo_path, 'frame_0020.png'))
frame2 = cv2.imread(os.path.join(demo_path, 'frame_0021.png'))

frame1 = cv2.cvtColor(frame1, cv2.COLOR_BGR2RGB)
frame2 = cv2.cvtColor(frame2, cv2.COLOR_BGR2RGB)


# neural net processing
frame1 = process_img(frame1, device='cuda')
frame2 = process_img(frame2, device='cuda')

padder = InputPadder(frame1.shape, mode='sintel')
frame1, frame2 = padder.pad(frame1, frame2)

frame1 = 2 * (frame1 / 255.0) - 1.0
frame2 = 2 * (frame2 / 255.0) - 1.0

frame1 = frame1.contiguous()
frame2 = frame2.contiguous()
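The intensity scaling used above maps 8-bit values onto `[-1, 1]`. A small NumPy sketch of that mapping and its inverse follows; the helper names `to_unit_range` and `from_unit_range` are made up for this illustration.

```python
import numpy as np

def to_unit_range(img_uint8):
    """Map 8-bit intensities [0, 255] to floats in [-1, 1]."""
    return 2.0 * (img_uint8.astype(np.float32) / 255.0) - 1.0

def from_unit_range(img_float):
    """Invert the mapping and round back to 8-bit intensities."""
    return np.round((img_float + 1.0) * 0.5 * 255.0).astype(np.uint8)

pixels = np.array([0, 128, 255], dtype=np.uint8)
scaled = to_unit_range(pixels)
restored = from_unit_range(scaled)
print(scaled, restored)
```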

## **Explore the Correlation Block**

In [None]:
from corr import CorrBlock

# compute feature maps
with torch.autocast(device_type='cuda', enabled=True):
    fmap1, fmap2 = model.module.fnet([frame1, frame2])

fmap1 = fmap1.float()
fmap2 = fmap2.float()

# compute correlation pyramid
corr_fn = CorrBlock(fmap1, fmap2, num_levels=4, radius=4)

In [None]:
# check sizes of correlation pyramid
print(frame1.shape)
cov_mats = []
for i in range(4):
  print(corr_fn.corr_pyramid[i].shape)
  cov_mats.append(np.cov(corr_fn.corr_pyramid[i].detach().cpu().numpy().reshape((7040, -1), order='c'), rowvar=False))

From this printout we can see that each fine pixel (from frame 1) has a corresponding correlation map over frame 2. As we go up the pyramid, these maps get smaller due to the average pooling operation. The average pooling operation should also introduce correlations. For example, in the first level (no pooling) each entry of a map relates one frame 1 pixel to a single frame 2 pixel. In the second level, each entry is an average over a 2x2 block of four frame 2 pixels, and so forth. So we are essentially introducing greater and greater spatial correlations by average pooling the last two dimensions of the correlation volume.
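To make the pooling step concrete, here is a minimal NumPy sketch of an all-pairs correlation volume and its pyramid. The toy sizes and the helper `avg_pool2x2` are assumptions for illustration; this is not the `CorrBlock` code itself, which works on PyTorch tensors.

```python
import numpy as np

rng = np.random.default_rng(0)
D, H, W = 16, 8, 12                      # toy feature depth and feature-grid size
fmap1 = rng.standard_normal((D, H, W)).astype(np.float32)
fmap2 = rng.standard_normal((D, H, W)).astype(np.float32)

# all-pairs dot product: one (H, W) response map per frame 1 pixel
corr = fmap1.reshape(D, H * W).T @ fmap2.reshape(D, H * W)
corr = corr.reshape(H * W, H, W) / np.sqrt(D)

def avg_pool2x2(x):
    """2x2 average pooling with stride 2 over the last two axes (odd edges dropped)."""
    n, h, w = x.shape
    x = x[:, : h // 2 * 2, : w // 2 * 2]
    return x.reshape(n, h // 2, 2, w // 2, 2).mean(axis=(2, 4))

# only the frame 2 axes shrink; the frame 1 axis keeps full resolution
pyramid = [corr]
for _ in range(3):
    pyramid.append(avg_pool2x2(pyramid[-1]))

shapes = [level.shape for level in pyramid]
print(shapes)
```

The first axis never changes size, which is why every level of the pyramid still has one response map per fine pixel of frame 1.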

In [None]:
corr_fn.corr_pyramid[i].detach().cpu().numpy().reshape((7040, -1)).shape

At this point we want to visualize how the feature maps that correspond to each fine pixel of frame 1 vary with each other. At each level we have 7040 samples of each feature map, and we want to see how these feature maps are related.

In [None]:
normed_covs = []
for c in cov_mats:
    normed_covs.append(cv2.normalize(c, dst=None, alpha=0.00001, beta=1.00001, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F))

In [None]:
fig, ax = plt.subplots(2, 2, figsize=(10, 10))
fig.suptitle("Covariances of each feature map")
ax[0, 0].imshow(normed_covs[0])
ax[0, 0].set_title("Pyramid Level 0")
ax[0, 1].imshow(normed_covs[1])
ax[0, 1].set_title("Pyramid Level 1")
ax[1, 0].imshow(normed_covs[2])
ax[1, 0].set_title("Pyramid Level 2")
ax[1, 1].imshow(normed_covs[3])
ax[1, 1].set_title("Pyramid Level 3");

Even though we can see the strong cross-correlations of the higher feature maps, the displayed covariances of each feature map aren't too helpful for this purpose. Let's look at the middle pixel of each pyramid level and see how it relates to the surrounding pixels. To do this, we will just index the middle pixel row of each covariance matrix and plot the columns.
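One way to locate a "middle pixel" row is sketched below, assuming each level's map is flattened in row-major order and that the level sizes are those of a 55 x 128 feature grid halved at each step. The helper `center_row_index` is hypothetical, and the hard-coded indices used in the plotting cells need not coincide with these values.

```python
def center_row_index(height, width):
    """Flat (row-major) index of the middle pixel of a height x width grid."""
    return (height // 2) * width + (width // 2)

# assumed pyramid level sizes (height, width)
level_sizes = [(55, 128), (27, 64), (13, 32), (6, 16)]
center_indices = [center_row_index(h, w) for h, w in level_sizes]
print(center_indices)
```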

In [None]:
_, ax = plt.subplots(2, 2, figsize=(10, 10), sharey=True)
ax[0, 0].plot(np.log(normed_covs[0][3250, :]))
ax[0, 0].set_title("Pyramid Level 0")
ax[0, 1].plot(np.log(normed_covs[1][864, :]))
ax[0, 1].set_title("Pyramid Level 1")
ax[1, 0].plot(np.log(normed_covs[2][208, :]))
ax[1, 0].set_title("Pyramid Level 2")
ax[1, 1].plot(np.log(normed_covs[3][48, :]))
ax[1, 1].set_title("Pyramid Level 3");

We can see at level 0 (no pooling and no correlations) that there is a single spike (self-correlation), but the remaining pixels are just noise. There is a similar relationship at pyramid level 1 (one average pool). However, at levels 2 and 3 we can start to see some roll-off from the center, indicating that the middle pixel has relationships with its surrounding pixels.

In [None]:
_, ax = plt.subplots(2, 2, figsize=(10, 10), sharey=True)
ax[0, 0].plot(np.log(normed_covs[0][813, :]))
ax[0, 0].set_title("Pyramid Level 0")
ax[0, 1].plot(np.log(normed_covs[1][216, :]))
ax[0, 1].set_title("Pyramid Level 1")
ax[1, 0].plot(np.log(normed_covs[2][52, :]))
ax[1, 0].set_title("Pyramid Level 2")
ax[1, 1].plot(np.log(normed_covs[3][12, :]))
ax[1, 1].set_title("Pyramid Level 3");

Initialize the flow

In [None]:
# returns a mesh grid tensor at 1/8 the size of the input frame
coords0, coords1 = model.module.initialize_flow(frame1)
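A rough NumPy picture of what the two grids look like, assuming a padded frame of 440 x 1024 and an `(x, y)` channel order. The helper `coords_grid` here is a stand-in written for this sketch, not the function from the RAFT repository.

```python
import numpy as np

def coords_grid(height, width):
    """Pixel coordinate grid of shape (2, height, width); channel 0 is x, channel 1 is y."""
    ys, xs = np.meshgrid(np.arange(height), np.arange(width), indexing="ij")
    return np.stack([xs, ys], axis=0).astype(np.float32)

# assumed padded frame size; the feature grid is 8x smaller in each dimension
H, W = 440, 1024
coords0 = coords_grid(H // 8, W // 8)
coords1 = coords0.copy()

flow = coords1 - coords0   # all zeros before the first update
print(coords0.shape, flow.max())
```

Both grids start out identical, so the initial flow estimate is zero everywhere; only `coords1` moves as the updates are applied.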

## **Explore Correlation Look Up Operator**
#### Go through code to index correlation volume

In [None]:
r = 4 # radius
coords = coords1.detach().permute(0, 2, 3, 1)
batch, h1, w1, _ = coords.shape
out_pyramid = []

In [None]:
i = 0
_corr = corr_fn.corr_pyramid[i]
dx = torch.linspace(-r, r, 2*r+1, device=coords.device)
dy = torch.linspace(-r, r, 2*r+1, device=coords.device)
delta = torch.stack(torch.meshgrid(dy, dx, indexing='ij'), axis=-1)  # 'ij' is the legacy default, stated explicitly

In [None]:
dx, dy, delta.shape

In [None]:
centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i
delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2)
coords_lvl = centroid_lvl + delta_lvl
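The broadcasting in this step can be reproduced on a couple of points with NumPy. This is only a sketch: the two centers below are hypothetical values, and the arrays mirror the shapes used in the cell above rather than the real tensors.

```python
import numpy as np

r = 4
level = 2
offsets = np.linspace(-r, r, 2 * r + 1)
dy, dx = np.meshgrid(offsets, offsets, indexing="ij")
delta = np.stack([dy, dx], axis=-1)            # (9, 9, 2), same layout as delta above

# two hypothetical flow-displaced centers (x, y) on the 1/8-resolution grid
centers = np.array([[40.0, 20.0], [100.5, 7.25]])
centroid_lvl = centers.reshape(-1, 1, 1, 2) / 2**level
coords_lvl = centroid_lvl + delta[None]        # broadcasts to (2, 9, 9, 2)

print(coords_lvl.shape)
print(coords_lvl[0, :, 0, 0])                  # first coordinate varies along window axis 0
```

Because the `dy` offsets land in the first (x) slot, x changes along the first window axis. This may be why one of the two patches is transposed when the input and output correlations are compared later on.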

In [None]:
centroid_lvl[1, :, :, :]

In [None]:
centroid_lvl.shape, delta_lvl.shape, coords_lvl.shape

In [None]:
centroid_lvl[10, :, :, :]

#### Set Exploration Index

In [None]:
lo_flow_1

In [None]:
frame1.shape

In [None]:
1024/8

In [None]:
322/8, 373/8

In [None]:
def get_corr_idx(loc):
    u = np.clip(np.round(loc[2]/8), 0, 127)
    v = np.clip(np.round(loc[1]/8), 0, 54)
    return int(u + 128*v)
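The same index arithmetic in plain Python, with the inverse mapping added so a flat index can be checked by hand. The function names, the default grid size of 55 x 128, and the example pixel are assumptions made for this sketch.

```python
def pixel_to_corr_idx(row, col, stride=8, grid_h=55, grid_w=128):
    """Map a full-resolution pixel (row, col) to a flat row-major index on the feature grid."""
    u = min(max(round(col / stride), 0), grid_w - 1)   # horizontal grid position
    v = min(max(round(row / stride), 0), grid_h - 1)   # vertical grid position
    return u + grid_w * v

def corr_idx_to_grid(idx, grid_w=128):
    """Recover (grid_row, grid_col) from a flat index."""
    return divmod(idx, grid_w)

idx_example = pixel_to_corr_idx(row=343, col=439)      # hypothetical pixel
print(idx_example, corr_idx_to_grid(idx_example))
```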

In [None]:
flow.shape

In [None]:
hi_flow_1

In [None]:
hi_flow_2

In [None]:
439/8, 343/8

In [None]:
40 + 57*128

In [None]:
43 + 128*55

In [None]:
centroid_lvl.shape

In [None]:
get_corr_idx(hi_flow_2)

In [None]:
57 + 128*40

In [None]:
flow.shape

In [None]:
hi_flow_2

In [None]:
439/8, 343/8

In [None]:
get_corr_idx(hi_flow_1)

In [None]:
idx = get_corr_idx(lo_flow_1)

pixel_loc = centroid_lvl[idx, :, :, :].cpu().numpy().squeeze()
pixel_loc

In [None]:
idx

In [None]:
coords_lvl[idx, :, :, 0],coords_lvl[idx, :, :, 1]

In [None]:
xi0, xi1 = coords_lvl[idx, 0, 0, 0].cpu().numpy(), coords_lvl[idx, -1, 0, 0].cpu().numpy() + 1
yi0, yi1 = coords_lvl[idx, 0, 0, 1].cpu().numpy(), coords_lvl[idx, 0, -1, 1].cpu().numpy() + 1

# np.infty was removed in NumPy 2.0; np.inf works on all versions
xi0, xi1 = int(np.clip(xi0, 0, np.inf)), int(np.clip(xi1, 0, np.inf))
yi0, yi1 = int(np.clip(yi0, 0, np.inf)), int(np.clip(yi1, 0, np.inf))

# xi0, xi1 = int(xi0), int(xi1)
# yi0, yi1 = int(yi0), int(yi1)

xi0, xi1, yi0, yi1

Sample from the correlation response at each fine pixel

In [None]:
from RAFT.core.utils.utils import bilinear_sampler

corr = bilinear_sampler(_corr, coords_lvl)
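For intuition about what sampling at a non-integer location means, here is a minimal single-point bilinear interpolation in NumPy. It assumes zero padding outside the map and is a sketch only; the helper `bilinear_sample` is written for this example and is not RAFT's `bilinear_sampler`.

```python
import numpy as np

def bilinear_sample(img, x, y):
    """Bilinearly interpolate a 2D array at float location (x, y); outside reads as 0."""
    h, w = img.shape
    x0, y0 = int(np.floor(x)), int(np.floor(y))
    wx, wy = x - x0, y - y0

    def at(yy, xx):
        # zero padding for out-of-bounds neighbours
        return img[yy, xx] if 0 <= yy < h and 0 <= xx < w else 0.0

    top = (1 - wx) * at(y0, x0) + wx * at(y0, x0 + 1)
    bottom = (1 - wx) * at(y0 + 1, x0) + wx * at(y0 + 1, x0 + 1)
    return (1 - wy) * top + wy * bottom

response = np.arange(12, dtype=np.float64).reshape(3, 4)
on_grid = bilinear_sample(response, 2.0, 1.0)    # exactly on a pixel
between = bilinear_sample(response, 1.5, 0.5)    # centre of a 2x2 block
outside = bilinear_sample(response, -3.0, 0.0)   # fully outside the map
print(on_grid, between, outside)
```

At integer coordinates the sample equals the stored value, so when the lookup window falls on whole pixels the sampled patch should match a direct slice of the correlation map.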

In [None]:
corr.shape, _corr.shape, coords_lvl.shape

### Plot Results

In [None]:
# get RGB correlation response for display
corr_response = _corr[idx, 0, :, :].detach().cpu().numpy()
# corr_response = cv2.normalize(corr_response, dst=None, alpha=0, beta=255, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_8UC1)
# corr_response = np.dstack((corr_response, corr_response, corr_response))

In [None]:
import matplotlib.patches as patches

_, ax = plt.subplots(1, 1, figsize=(15, 5))
ax.imshow(corr_response);
ax.set_title(f"Correlation Response at Pyramid level: {i} and pixel: {pixel_loc}");

# mark I2 pixel under test
rect = patches.Rectangle(pixel_loc - 0.5, 1, 1, linewidth=2, edgecolor='r', facecolor='none')
# ax.add_patch(rect)

plt.show();

In [None]:
corr_response.min(), corr_response.max(), corr_response.mean(), np.median(corr_response), corr_response.sum()

In [None]:
_, (ax0, ax1) = plt.subplots(1,2, figsize=(20,10))

# ax0.imshow(_corr[idx, 0, 19:28, 52:61].detach().cpu().numpy().T)
# ax0.set_title(f"Correlation Input at index {idx}")
# ax1.imshow(corr[idx, 0, :, :].detach().cpu().numpy())
# ax1.set_title(f"Correlation Output at index {idx}");

# ensure orientations are consistent
# input_corr = _corr[idx, 0, yi0:yi1, xi0:xi1].detach().cpu().numpy()
# output_corr = corr[idx, 0, :, :].detach().cpu().numpy().T

input_corr = _corr[idx, 0, yi0:yi1, xi0:xi1].detach().cpu().numpy().T
output_corr = corr[idx, 0, :, :].detach().cpu().numpy()

ax0.imshow(input_corr)
ax0.set_title(f"Correlation Input at index {pixel_loc} (zero padding not shown)")
ax1.imshow(output_corr)
ax1.set_title(f"Correlation Output at index {pixel_loc}");

Manually Inspect results

In [None]:
np.round(input_corr, 2), np.round(output_corr, 2)

In [None]:
_corr[idx, 0, :, :].max(), corr[idx, 0, :, :].max()

In [None]:
torch.where(_corr[idx, 0, :, :] == _corr[idx, 0, :, :].max()), torch.where(corr[idx, 0, :, :] == corr[idx, 0, :, :].max())

In [None]:
_corr.shape, coords_lvl.shape

In [None]:
# _corr, coords_lvl

H, W = _corr.shape[-2:]
xgrid, ygrid = coords_lvl.split([1,1], dim=-1)
xgrid = 2*xgrid/(W-1) - 1
ygrid = 2*ygrid/(H-1) - 1
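The rescaling above maps pixel coordinates to `[-1, 1]`, with `-1` and `1` landing on the centers of the first and last pixels (the convention `grid_sample` uses when `align_corners=True`). A short worked example follows; the helper names are assumptions made for this sketch.

```python
def normalize_coord(p, size):
    """Map a pixel coordinate in [0, size - 1] to [-1, 1]."""
    return 2.0 * p / (size - 1) - 1.0

def denormalize_coord(n, size):
    """Map a normalized coordinate in [-1, 1] back to pixel units."""
    return (n + 1.0) * (size - 1) / 2.0

W = 128
left, center, right = (normalize_coord(p, W) for p in (0, 63.5, 127))
round_trip = denormalize_coord(normalize_coord(40.25, W), W)
print(left, center, right, round_trip)
```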

In [None]:
xgrid.shape, ygrid.shape, coords_lvl.shape

In [None]:
xgrid[idx, :, :].squeeze()

In [None]:
ygrid[idx, :, :].squeeze()

## **Explore the Recurrent Update Block**

In [None]:
hdim = model.module.hidden_dim
cdim = model.module.context_dim

with torch.autocast(device_type='cuda', enabled=True):
    cnet = model.module.cnet(frame1)
    net, inp = torch.split(cnet, [hdim, cdim], dim=1)
    net = torch.tanh(net)
    inp = torch.relu(inp)
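The split-and-activate step can be mimicked with NumPy to see the value ranges each half ends up in. The channel counts and the random stand-in for the context network output are assumptions for illustration only.

```python
import numpy as np

hdim, cdim = 128, 128                      # assumed hidden and context channel counts
rng = np.random.default_rng(0)
cnet_out = rng.standard_normal((1, hdim + cdim, 5, 8)).astype(np.float32)

# first hdim channels become the hidden state, the rest the context features
net_demo, inp_demo = np.split(cnet_out, [hdim], axis=1)
net_demo = np.tanh(net_demo)               # hidden state is squashed into [-1, 1]
inp_demo = np.maximum(inp_demo, 0.0)       # context features are non-negative

print(net_demo.shape, inp_demo.shape)
print(float(net_demo.min()), float(net_demo.max()), float(inp_demo.min()))
```

This is why the hidden maps are bounded while the context maps can take arbitrarily large positive values.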

## Compare Learned Feature maps to the context features

In [None]:
_, (ax0, ax1) = plt.subplots(2, 2, figsize=(20,10))
ax0[0].imshow(fmap1.squeeze(0)[0, :, :].detach().cpu().numpy())
ax0[0].set_title("Feature Net Feature Maps", size=18)

ax0[1].imshow(net.squeeze(0)[0, :, :].detach().cpu().numpy())
ax0[1].set_title("Context Net Hidden Feature Maps", size=18)

ax1[0].imshow(fmap1.squeeze(0)[10, :, :].detach().cpu().numpy())

ax1[1].imshow(net.squeeze(0)[10, :, :].detach().cpu().numpy());

In [None]:
_, (ax0, ax1) = plt.subplots(2, 2, figsize=(20,10))
ax0[0].imshow(inp.squeeze(0)[20, :, :].detach().cpu().numpy())
ax0[0].set_title("Context Net Context Feature Maps", size=18)

ax0[1].imshow(net.squeeze(0)[20, :, :].detach().cpu().numpy())
ax0[1].set_title("Context Net Hidden Feature Maps", size=18)

ax1[0].imshow(inp.squeeze(0)[120, :, :].detach().cpu().numpy())

ax1[1].imshow(net.squeeze(0)[120, :, :].detach().cpu().numpy());

In [None]:
inp.max(), net.max()

Since the network uses the hidden maps as the optical flow output, it makes sense for the context network to decode hidden feature maps that emphasize different aspects of the input image.

We can see that the feature extraction network learns something super abstract, while

In [None]:
# initialize flow
coords0, coords1 = model.module.initialize_flow(frame1)

In [None]:
coords1 = coords1.detach()
corr = corr_fn(coords1) # index correlation volume

flow = coords1 - coords0
with torch.autocast(device_type='cuda', enabled=True):
    net, up_mask, delta_flow = model.module.update_block(net, inp, corr, flow)



In [None]:
net.shape, delta_flow.shape

In [None]:
_, (ax0, ax1, ax2) = plt.subplots(1, 3, figsize=(20,10))
ax0.imshow(net.squeeze(0)[0, :, :].detach().cpu().numpy())
ax1.imshow(delta_flow.squeeze(0)[0, :, :].detach().cpu().numpy())
ax2.imshow(delta_flow.squeeze(0)[1, :, :].detach().cpu().numpy())

In [None]:
# F(t+1) = F(t) + \Delta(t)
coords1 = coords1 + delta_flow
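The update rule is easy to trace on a tiny grid. In this sketch the constant residuals stand in for the update block's `delta_flow`; they are hypothetical values, not model output.

```python
import numpy as np

h, w = 4, 6
ys, xs = np.meshgrid(np.arange(h), np.arange(w), indexing="ij")
coords0_demo = np.stack([xs, ys]).astype(np.float32)   # fixed reference grid
coords1_demo = coords0_demo.copy()                     # current correspondence estimate

# hypothetical per-iteration residuals
deltas = [np.full((2, h, w), 0.5, dtype=np.float32),
          np.full((2, h, w), 0.25, dtype=np.float32)]

flows = []
for delta_step in deltas:
    coords1_demo = coords1_demo + delta_step           # F(t+1) = F(t) + delta(t)
    flows.append(coords1_demo - coords0_demo)          # flow is the offset from the reference grid

print(flows[0][0, 0, 0], flows[1][0, 0, 0])
```

The reference grid never changes, so the flow after any iteration is simply the accumulated sum of the residuals so far.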

## **Re-Explore the Correlation using the new flow estimate**

In [None]:
r = 4 # radius
coords = coords1.detach().permute(0, 2, 3, 1)
batch, h1, w1, _ = coords.shape
out_pyramid = []

In [None]:
i = 2
_corr = corr_fn.corr_pyramid[i]
dx = torch.linspace(-r, r, 2*r+1, device=coords.device)
dy = torch.linspace(-r, r, 2*r+1, device=coords.device)
delta = torch.stack(torch.meshgrid(dy, dx, indexing='ij'), axis=-1)  # 'ij' is the legacy default, stated explicitly

In [None]:
centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i
delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2)
coords_lvl = centroid_lvl + delta_lvl

In [None]:
delta_lvl.squeeze(0)[:, :, 0]

In [None]:

idx = 1000

pixel_loc = centroid_lvl[idx, :, :, :].cpu().numpy().squeeze()
pixel_loc

In [None]:
coords_lvl[idx, :, :, 0]

In [None]:
xi0, xi1 = coords_lvl[idx, 0, 0, 0].cpu().numpy(), coords_lvl[idx, -1, 0, 0].cpu().numpy() + 1
yi0, yi1 = coords_lvl[idx, 0, 0, 1].cpu().numpy(), coords_lvl[idx, 0, -1, 1].cpu().numpy() + 1

# np.infty was removed in NumPy 2.0; np.inf works on all versions
xi0, xi1 = int(np.clip(xi0, 0, np.inf)), int(np.clip(xi1, 0, np.inf))
yi0, yi1 = int(np.clip(yi0, 0, np.inf)), int(np.clip(yi1, 0, np.inf))

# xi0, xi1 = int(xi0), int(xi1)
# yi0, yi1 = int(yi0), int(yi1)

xi0, xi1, yi0, yi1

In [None]:
corr = bilinear_sampler(_corr, coords_lvl)

In [None]:
corr_response = _corr[idx, 0, :, :].detach().cpu().numpy()

In [None]:
import matplotlib.patches as patches
_, ax = plt.subplots(1, 1, figsize=(15, 5))
ax.imshow(corr_response);
ax.set_title(f"Correlation Response at Pyramid level: {i} and pixel: {pixel_loc}");

# mark I2 pixel under test
rect = patches.Rectangle(pixel_loc - 0.5, 1, 1, linewidth=2, edgecolor='r', facecolor='none')
ax.add_patch(rect)

plt.show();

In [None]:
_, (ax0, ax1) = plt.subplots(1,2, figsize=(20,10))

# ax0.imshow(_corr[idx, 0, 19:28, 52:61].detach().cpu().numpy().T)
# ax0.set_title(f"Correlation Input at index {idx}")
# ax1.imshow(corr[idx, 0, :, :].detach().cpu().numpy())
# ax1.set_title(f"Correlation Output at index {idx}");

# ensure orientations are consistent
# input_corr = _corr[idx, 0, yi0:yi1, xi0:xi1].detach().cpu().numpy()
# output_corr = corr[idx, 0, :, :].detach().cpu().numpy().T

input_corr = _corr[idx, 0, yi0:yi1, xi0:xi1].detach().cpu().numpy().T
output_corr = corr[idx, 0, :, :].detach().cpu().numpy()

ax0.imshow(input_corr)
ax0.set_title(f"Correlation Input at index {pixel_loc} (zero padding not shown)")
ax1.imshow(output_corr)
ax1.set_title(f"Correlation Output at index {pixel_loc}");