# RAFT Playground

## RAFT Models

RAFT has several pretrained models:
 - raft-chairs - trained on FlyingChairs
 - raft-things - trained on FlyingChairs + FlyingThings
 - raft-sintel - trained on FlyingChairs + FlyingThings + Sintel + KITTI
 - raft-kitti - raft-sintel fine-tuned on KITTI only
 - raft-small - trained on FlyingChairs + FlyingThings

## Prepare the Notebook

In [None]:
# clone the repo

!pwd
!cd RAFT
!git pull

In [None]:
# necessary imports

import os
import sys
import numpy as np
import cv2
import pandas as pd

import torch                     # for all things PyTorch
import torch.nn as nn            # for torch.nn.Module, the parent object for PyTorch models
import torch.nn.functional as F  # for the activation function

import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# Report the installed PyTorch build and every CUDA device it can see

print(torch.__version__)
print(torch.cuda.get_arch_list())

if not torch.cuda.is_available():
    print("CUDA is not available.")
else:
    print("CUDA is available!")
    print(f"Number of GPUs: {torch.cuda.device_count()}")
    for gpu_idx in range(torch.cuda.device_count()):
        print(f"GPU {gpu_idx}: {torch.cuda.get_device_name(gpu_idx)}")
        allocated_mb = torch.cuda.memory_allocated(gpu_idx) / 1024**2
        print(f"Allocated memory: {allocated_mb:.2f} MB")
        # "cached" here is the memory reserved by PyTorch's caching allocator
        reserved_mb = torch.cuda.memory_reserved(gpu_idx) / 1024**2
        print(f"Cached memory: {reserved_mb:.2f} MB")
        print(f"Properties: {torch.cuda.get_device_properties(gpu_idx)}")

In [None]:
# Add RAFT core to path
#
# Prefer the RAFT checkout next to this notebook (the other cells load
# checkpoints via relative paths such as "RAFT/models/..."); fall back to the
# author's absolute checkout only when the relative one does not exist.
RAFT_CORE_DIR = os.path.abspath(os.path.join('RAFT', 'core'))
if not os.path.isdir(RAFT_CORE_DIR):
    RAFT_CORE_DIR = '/home/max/Dokumente/Vitis-AI/CV_projects/RAFT/RAFT/core'

# Re-running this cell must not keep appending the same entry.
if RAFT_CORE_DIR not in sys.path:
    sys.path.append(RAFT_CORE_DIR)

### Helper functions

In [None]:
from collections import OrderedDict
from raft import RAFT
from utils import flow_viz
from utils.utils import InputPadder


def process_img(img, device):
    """Turn a channel-last image array into a batched float tensor on `device`.

    The array's axes are permuted from (H, W, C) to (C, H, W), the values are
    cast to float32, a leading batch axis of size 1 is added, and the result
    is moved to `device`.
    """
    print("[process_img] entering")
    chw = torch.from_numpy(img).permute(2, 0, 1)
    return chw.float().unsqueeze(0).to(device)

def load_model(weights_path, args):
    """Build a RAFT network and load pretrained weights into it.

    Args:
        weights_path: Path to a ``.pth`` checkpoint readable by ``torch.load``.
        args: Configuration object passed straight to ``RAFT(args)``.

    Returns:
        The model on CUDA, wrapped in ``torch.nn.DataParallel``, when at least
        one GPU is visible; otherwise the bare model on the CPU.

    Raises:
        RuntimeError: If the checkpoint cannot be read or its weights do not
            fit the model. The underlying exception is chained as the cause.
    """
    print("[load_model] entering")

    model = RAFT(args)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("[load_model] current device is: " + str(device))

    try:
        pretrained_weights = torch.load(weights_path, map_location=device)
    except Exception as e:
        raise RuntimeError(f"[load_model] Fehler beim laden der Gewichte: {e}") from e

    print("[load_model] device_count(): " + str(torch.cuda.device_count()))

    # A checkpoint saved from a DataParallel-wrapped model prefixes every key
    # with "module.". Strip that prefix and load into the bare model, so the
    # same checkpoint works on the CPU path (which skips the wrapper) as well
    # as on the GPU path.
    prefix = "module."
    try:
        state_dict = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in pretrained_weights.items()
        }
        model.load_state_dict(state_dict)
    except Exception as e:
        raise RuntimeError(f"[load_model] Fehler beim setzen der Gewichte: {e}") from e

    # Wrap only after loading; the wrapper adds the "module." prefix itself.
    if torch.cuda.device_count() >= 1:
        model = torch.nn.DataParallel(model)

    model.to(device)

    return model

def inference(model, frame1, frame2, device, pad_mode='sintel', iters=12, flow_init=None, upsample=True, test_mode=True):
    """Run one flow prediction between two frames with `model`.

    Both frames go through ``process_img`` and are padded by ``InputPadder``
    (with `pad_mode`) before the model is called once under ``torch.no_grad``.

    Returns:
        If `test_mode` is truthy, the two-element model output unpacked as
        ``(flow_low, flow_up)``. Per the original comments, these are the
        1/8-resolution flow and the upsampled flow. Otherwise, the raw model
        output, which the original comments describe as the flow estimates of
        every iteration.

    NOTE(review): the padding added here is not removed from the returned
    flows. Confirm whether callers need to crop them back to the input size.
    """
    print("[inference] entering")

    # eval mode switches off training-only behaviour (dropout, batch-norm updates)
    model.eval()

    # no gradients are needed for inference; skipping them saves time and memory
    with torch.no_grad():
        img_a = process_img(frame1, device)
        img_b = process_img(frame2, device)

        # RAFT needs spatial sizes divisible by 8, so pad both frames alike
        padder = InputPadder(img_a.shape, mode=pad_mode)
        img_a, img_b = padder.pad(img_a, img_b)

        print("[inference] Upsampled = " + str(upsample))

        prediction = model(img_a, img_b, iters=iters, flow_init=flow_init,
                           upsample=upsample, test_mode=test_mode)

        if not test_mode:
            # one flow estimate per refinement iteration
            return prediction

        flow_low, flow_up = prediction
        return flow_low, flow_up

def get_viz(flo):
    """Render the first flow field of a batch as a colour image.

    Assumes `flo` is a (N, 2, H, W) tensor (TODO confirm). Element 0 is moved
    to channel-last layout on the CPU and passed to ``flow_viz.flow_to_image``.
    """
    print("[get_viz] entering")
    first = flo[0]
    hw2 = first.permute(1, 2, 0).cpu().numpy()
    return flow_viz.flow_to_image(hw2)

def print_model_info(model):
    """Print the module tree of `model`, then the shape of every named parameter."""
    print("[print_model_info] Model architecture:")
    print(model)

    print("\n [print_model_info] Model parameters and their shapes:")
    shape_lines = [f"{pname}: {tensor.shape}" for pname, tensor in model.named_parameters()]
    for line in shape_lines:
        print(line)

def inspect_model(model, frame1_path=None, frame2_path=None):
    """Print parameter statistics for `model` and smoke-test one forward pass.

    Two frames are read with OpenCV, converted from BGR to RGB, turned into
    tensors, padded ('sintel' mode) and pushed through the model once
    (12 iterations, test mode). The shapes of the two returned flows are
    printed.

    Args:
        model: The network to inspect, e.g. the return value of ``load_model``.
        frame1_path: First frame. Defaults to the author's demo image.
        frame2_path: Second frame. Defaults to the author's demo image.

    Returns:
        The same ``model`` object, left in eval mode.

    Raises:
        FileNotFoundError: If OpenCV cannot read either frame.
    """
    print("[inspect_model] entering")

    if frame1_path is None:
        frame1_path = "/home/max/Dokumente/CV_projects/RAFT/custom_demo_frames/m_baseFrameGray.jpg"
    if frame2_path is None:
        frame2_path = "/home/max/Dokumente/CV_projects/RAFT/custom_demo_frames/m_nextFrameGray.jpg"

    # Total number of parameters (trainable or not)
    total_params = sum(p.numel() for p in model.parameters())
    print("[inspect_model] Total number of parameters: " + str(total_params))

    rgb_frames = []
    for path in (frame1_path, frame2_path):
        bgr = cv2.imread(path)
        # cv2.imread signals a missing/unreadable file by returning None
        # instead of raising, so fail early with a clear message.
        if bgr is None:
            raise FileNotFoundError(f"[inspect_model] could not read image: {path}")
        rgb_frames.append(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))
    frame1, frame2 = rgb_frames

    # eval mode switches off training-only behaviour (dropout, batch-norm updates)
    model.eval()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("[inspect_model] current device is: " + str(device))

    # no gradients are needed for inference; skipping them saves time and memory
    with torch.no_grad():
        frame1 = process_img(frame1, device)
        frame2 = process_img(frame2, device)

        # RAFT needs spatial sizes divisible by 8, so pad both frames alike
        padder = InputPadder(frame1.shape, mode='sintel')
        frame1, frame2 = padder.pad(frame1, frame2)

        # test mode returns the low-resolution and the upsampled flow
        flow_low, flow_up = model(frame1, frame2, iters=12, flow_init=None, upsample=True, test_mode=True)

    # Report output shapes instead of dumping the full tensors
    print("[inspect_model] flow_low shape: " + str(tuple(flow_low.shape)))
    print("[inspect_model] flow_up shape: " + str(tuple(flow_up.shape)))

    # Manual model summary: parameter count per named parameter
    print("[inspect_model] Model Summary: ")
    for name, param in model.named_parameters():
        print(f"{name}: {param.numel()} parameters")

    print("[inspect_model] Total number of parameters: " + str(total_params))

    return model

# Lightweight stand-in for the argparse.Namespace that RAFT is normally given.
class Args():
    """Configuration object passed to ``RAFT(args)``.

    Stores ``small``, ``mixed_precision`` and ``alternate_corr``, presumably
    read by RAFT (confirm against RAFT's constructor). It also stores the
    ``model``/``path`` fields that RAFT's ``demo.py`` takes as command-line
    flags.
    """

    def __init__(self, model='', path='', small=False, mixed_precision=True, alternate_corr=False):
        self.model = model
        self.path = path
        self.small = small
        self.mixed_precision = mixed_precision
        self.alternate_corr = alternate_corr

    def __contains__(self, name):
        """Support ``'attr' in args`` the way ``argparse.Namespace`` does."""
        # Without this method, `in` falls back to __iter__ below, which yields
        # nothing. Every membership test was then False, even for attributes
        # that had been set (e.g. alternate_corr).
        return name in self.__dict__

    # Kept for backward compatibility: iterating an Args yields nothing.
    def __iter__(self):
        return self

    def __next__(self):
        raise StopIteration

## Download models

In [None]:
# Download the pretrained checkpoints with the repository's own script.
# %cd changes the notebook kernel's working directory (a "!cd" would not), so
# the script runs inside RAFT/. The later cells expect the checkpoints under
# RAFT/models/.
%cd RAFT

# make the download script executable, then run it
!chmod +x download_models.sh
!./download_models.sh
# Optional: run RAFT's demo.py on the demo-frames folder
# !python demo.py --model=models/raft-things.pth --path=demo-frames

# return to the notebook's original working directory
%cd ..

## Load Model

In [None]:
# Load the pretrained raft-sintel checkpoint (.pth file)

model = load_model(weights_path="RAFT/models/raft-sintel.pth", args=Args())

## Use the Model

In [None]:
# Inspect the model structure

model = load_model("RAFT/models/raft-sintel.pth", args=Args())

# print_model_info already prints the full architecture, followed by the
# parameter shapes, so a separate print(model) would only repeat it.
print_model_info(model)

In [None]:
# Example call: load the raft-sintel weights and run the inspection helper.
# Use the same relative checkpoint path as the other cells rather than a
# machine-specific absolute path.
model_path = "RAFT/models/raft-sintel.pth"
args = Args()

model = load_model(model_path, args)
model = inspect_model(model)

## Define a custom model and inspect it

In [None]:
# ------------------------------------------------------------------------------------
# Modell definieren und speichern

import torch
import torch.nn as nn
import torch.optim as optim

# Define model
class SimpleModel(nn.Module):
    """Tiny two-layer MLP: 10 inputs -> 50 hidden units (ReLU) -> 1 output."""

    def __init__(self):
        super().__init__()
        # Registration order (fc1, relu, fc2) fixes the named_children() order
        self.fc1 = nn.Linear(10, 50)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(50, 1)

    def forward(self, x):
        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden)

# Where the toy checkpoint is written to and read back from
SIMPLE_MODEL_PATH = '/home/max/Dokumente/CV_projects/RAFT/simple_model.pth'

# Instantiate the model and optimizer.
# Dedicated names keep this toy example from overwriting the RAFT `model`
# that the flow-estimation cells further down still pass to inference().
simple_model = SimpleModel()
optimizer = optim.SGD(simple_model.parameters(), lr=0.01)

# Save the model and optimizer state
torch.save({
    'model_state_dict': simple_model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
}, SIMPLE_MODEL_PATH)


# ------------------------------------------------------------------------------------
# Load the model and inspect it

# Load the saved file
checkpoint = torch.load(SIMPLE_MODEL_PATH)

# Create a fresh model object and restore the saved weights
simple_model = SimpleModel()
simple_model.load_state_dict(checkpoint['model_state_dict'])

# Examine the model structure
print("Model structure:")
print(simple_model)

# Display detailed information about each layer
print("\nLayer details:")
for name, layer in simple_model.named_children():
    print(f"Layer {name}: {layer}")

# Generate dummy input data: a batch of one 10-feature sample
dummy_input = torch.randn(1, 10)

# Perform a prediction with the model (no gradients needed for inspection)
with torch.no_grad():
    output = simple_model(dummy_input)

# Display input and output shapes
print(f"\nInput shape: {dummy_input.shape}")
print(f"Output shape: {output.shape}")

# Optional detailed output for each module (including the root module)
print("\nModule details:")
for name, module in simple_model.named_modules():
    print(f"Module {name}: {module}")

# Optional output of parameters
print("\nParameter details:")
for name, param in simple_model.named_parameters():
    print(f"Parameter {name}: {param.size()}")

# Display parameter details with sizes and trainability
print("\nParameter details with sizes:")
for name, param in simple_model.named_parameters():
    print(f"Parameter {name}: {param.size()}, requires_grad: {param.requires_grad}")

# Display each layer together with the parameters it owns directly
print("\nLayer and parameter details:")
for name, layer in simple_model.named_children():
    print(f"Layer {name}: {layer}")
    for param_name, param in layer.named_parameters(recurse=False):
        print(f"  Param {param_name}: {param.size()}, requires_grad: {param.requires_grad}")

## Flow estimation on a custom pair of frames

In [None]:
# Flow estimation on a custom pair of frames: load both frames and show them

demo_path = '/home/max/Dokumente/CV_projectsFork/RAFT/custom_demo_frames'
frame1 = cv2.imread(os.path.join(demo_path, 'm_baseFrameGray.jpg'))
frame2 = cv2.imread(os.path.join(demo_path, 'm_nextFrameGray.jpg'))

# cv2.imread returns None instead of raising when a file is missing or
# unreadable; fail here with a clear message rather than inside cvtColor.
if frame1 is None or frame2 is None:
    raise FileNotFoundError(f"could not read the demo frames from {demo_path}")

# OpenCV loads BGR; matplotlib and the model pipeline use RGB
frame1 = cv2.cvtColor(frame1, cv2.COLOR_BGR2RGB)
frame2 = cv2.cvtColor(frame2, cv2.COLOR_BGR2RGB)

_, ax = plt.subplots(1, 2, figsize=(15, 8))
ax[0].imshow(frame1)
ax[1].imshow(frame2);

# OPTIONAL: use the KITTI-only fine-tuned model instead
# del model
# model = load_model("RAFT/models/raft-kitti.pth", args=Args())

# load_model puts the model on CUDA only when it is available, so pick the
# input device the same way instead of hard-coding 'cuda'.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# test_mode=False returns the flow estimate of every iteration; try other
# values such as iters=5 / 20 or upsample=False to compare.
flow_iters = inference(model, frame1, frame2, device=device, pad_mode='kitti', iters=10, flow_init=None, upsample=True, test_mode=False)

f, (ax0, ax1) = plt.subplots(1, 2, figsize=(15, 10))

ax0.imshow(get_viz(flow_iters[0]))
ax0.set_title('first flow iteration')
ax1.imshow(get_viz(flow_iters[-1]))
ax1.set_title('final flow iteration');