In [None]:
import json
import os
import pickle
import re
from collections import defaultdict

import cv2
import matplotlib.pyplot as plt
import numpy as np
from mmeval import EndPointError
from tqdm.notebook import tqdm

from common.kitti import load_kitti_flow
from common.warp import forward_warp_bilinear
from common.metrics import reconstruction_error

%matplotlib widget

In [None]:
%%html
<style>
/* Make ipywidget outputs (e.g. interactive `%matplotlib widget` figures) use a transparent
   background and follow the VS Code editor's foreground color and font size. */
.cell-output-ipywidget-background {
    background-color: transparent !important;
}
:root {
    --jp-widgets-color: var(--vscode-editor-foreground);
    --jp-widgets-font-size: var(--vscode-editor-font-size);
}  
</style>

In [None]:
def collect_sources(kitti_path: str, pred_path: str):
    """Index KITTI ground truth, input frame pairs and prediction files.

    Args:
        kitti_path: KITTI root directory containing ``flow_occ/`` (ground-truth flow files
            named ``NNNNNN_10.png``) and ``image_2/`` (frames named ``NNNNNN_FF.png``).
        pred_path: directory of predicted flow files named ``<model_name>-NNNNNNN.png``.

    Returns:
        ``(gt_map, frame_map, pred_map)`` where
        - ``gt_map[index] = (gt_flow, gt_valid)`` as returned by ``load_kitti_flow``,
        - ``frame_map[index] = (frame_10, frame_11)`` as read by ``cv2.imread``; indices whose
          frame 10 or 11 is missing/unreadable are reported and skipped,
        - ``pred_map[model_name][index]`` is the path of that model's prediction file.

    Directories are scanned in sorted filename order so that dict insertion order (and hence
    the order of per-sample results downstream) is reproducible across filesystems.
    """
    gt_map: dict[int, tuple[np.ndarray, np.ndarray]] = {}
    frame_map: dict[int, tuple[np.ndarray, np.ndarray]] = {}
    pred_map: dict[str, dict[int, str]] = {}  # grouped by model name, then by index. Stores filenames

    gt_dir = os.path.join(kitti_path, "flow_occ")
    gt_pattern = re.compile(r'^(\d{6})_10\.png$')
    for filename in sorted(os.listdir(gt_dir)):
        match = gt_pattern.match(filename)
        if match:
            index = int(match.group(1))
            gt_flow, gt_valid = load_kitti_flow(os.path.join(gt_dir, filename))
            gt_map[index] = (gt_flow, gt_valid)

    # index -> {frame_number: image}
    frames: dict[int, dict[int, np.ndarray]] = {}

    image_dir = os.path.join(kitti_path, "image_2")
    frame_pattern = re.compile(r'^(\d{6})_(\d{2})\.png$')
    for filename in sorted(os.listdir(image_dir)):
        match = frame_pattern.match(filename)
        if match:
            index = int(match.group(1))
            frame_number = int(match.group(2))
            frames.setdefault(index, {})[frame_number] = cv2.imread(os.path.join(image_dir, filename))

    for index, framedict in frames.items():
        # .get(): a missing frame file must be skipped like an unreadable one (imread -> None),
        # not raise KeyError.
        frame_10 = framedict.get(10)
        frame_11 = framedict.get(11)

        if frame_10 is None or frame_11 is None:
            print(f"Error: missing or unreadable frame 10/11 for index {index:06d}")
            continue

        frame_map[index] = (frame_10, frame_11)

    # NOTE(review): prediction indices are 7 digits while KITTI uses 6; int() makes them comparable.
    pred_pattern = re.compile(r'^(.*)-(\d{7})\.png$')
    for filename in sorted(os.listdir(pred_path)):
        match = pred_pattern.match(filename)
        if match:
            model_name = match.group(1)
            index = int(match.group(2))
            pred_map.setdefault(model_name, {})[index] = os.path.join(pred_path, filename)

    return gt_map, frame_map, pred_map

   
def plot_mean_scatter(model_metrics, reference, labels=None):
    """Scatter each model's mean EPE against its mean reconstruction error.

    Args:
        model_metrics: mapping model name -> {"epe": [...], "recon": [...]} of per-sample values.
        reference: y-value at which a dashed red horizontal reference line is drawn.
        labels: model names to annotate with text; if None or empty, every model is annotated.
    """
    means = [(model, np.mean(vals["epe"]), np.mean(vals["recon"])) for model, vals in model_metrics.items()]

    fig, ax = plt.subplots(figsize=(10, 6))

    for model, mean_epe, mean_recon in means:
        ax.scatter(mean_epe, mean_recon, s=30)
        if not labels or model in labels:
            ax.annotate(model, xy=(mean_epe, mean_recon), fontsize=9, xytext=(5, -5), textcoords="offset points")

    ax.axhline(reference, color='r', linestyle='--', linewidth=1, label='Reference Line')

    ax.set_xlabel("Mean EPE")
    ax.set_ylabel("Mean Reconstruction Error")
    ax.set_title("Mean EPE vs. Mean Reconstruction Error per Model")

    # Only the reference line carries a label, so the legend shows just that entry.
    ax.legend(loc='upper left')

    # Use the explicit Figure/Axes objects rather than pyplot's "current axes" state.
    ax.grid(True)
    fig.tight_layout()
    plt.show()


def plot_per_sample_scatter(model_metrics):
    """Draw one per-sample EPE vs. reconstruction-error scatter per model, stacked vertically.

    Args:
        model_metrics: mapping model name -> {"epe": [...], "recon": [...]} of per-sample values.

    Raises:
        ValueError: if ``model_metrics`` is empty (there is nothing to plot).
    """
    if not model_metrics:
        raise ValueError("plot_per_sample_scatter: model_metrics is empty, nothing to plot")

    n_models = len(model_metrics)
    # squeeze=False always returns a 2-D array of Axes, so a single model needs no special case.
    fig, axs = plt.subplots(n_models, 1, figsize=(6, 4 * n_models), squeeze=False)

    for ax, (model, vals) in zip(axs[:, 0], model_metrics.items()):
        ax.scatter(vals["epe"], vals["recon"])
        ax.set_title(f"{model}: EPE vs Recon Error")
        ax.set_xlabel("EPE")
        ax.set_ylabel("Reconstruction Error")
        ax.grid(True)

    fig.tight_layout()
    plt.show()

def _load_or_compute_warp(frame_10, pred_flow_uv, warp_output_path, save_warps, load_warps):
    """Return frame 10 forward-warped by ``pred_flow_uv``, using the on-disk warp cache.

    A cached warp is read from ``warp_output_path`` when ``load_warps`` is set and the file
    exists and decodes; otherwise the warp is recomputed and, if ``save_warps`` is set,
    written to ``warp_output_path``.
    """
    if load_warps and os.path.exists(warp_output_path):
        cached = cv2.imread(warp_output_path)
        if cached is not None:
            return cached
        print(f"Could not read cached warp {warp_output_path}; recomputing")

    # NOTE(review): the cache round-trips through an image file; if forward_warp_bilinear returns a
    # non-uint8 image, cached and freshly computed warps may give different reconstruction errors.
    pred_frame_11, _ = forward_warp_bilinear(frame_10, pred_flow_uv)
    if save_warps:
        cv2.imwrite(warp_output_path, pred_frame_11)
    return pred_frame_11


def batch_eval(
    gt_map: dict[int, tuple[np.ndarray, np.ndarray]],
    frame_map: dict[int, tuple[np.ndarray, np.ndarray]],
    pred_map: dict[str, dict[int, str]],
    warp_output_dir: str,
    save_warps: bool,
    load_warps: bool,
    use_mask: bool,
    blur_kernel_sizes=range(1, 21, 2),
):
    """Evaluate every model's predicted flow against the KITTI ground truth.

    For each ground-truth sample and each model with a prediction for it, frame 10 is
    forward-warped with the predicted flow (or a cached warp is loaded), and two metrics
    are recorded per blur kernel size k:
    - the end-point error (EPE) of the predicted flow vs. the ground-truth flow,
    - the reconstruction error between frame 11 and the warped frame, both box-blurred
      with a k x k kernel (restricted to GT-valid pixels when ``use_mask`` is set).

    Args:
        gt_map: index -> (gt_flow, gt_valid).
        frame_map: index -> (frame_10, frame_11).
        pred_map: model name -> index -> prediction file path.
        warp_output_dir: directory used to cache warped frames (created if missing).
        save_warps: write freshly computed warps to the cache.
        load_warps: read warps from the cache when present.
        use_mask: pass ``gt_valid`` as mask to ``reconstruction_error``.
        blur_kernel_sizes: kernel sizes k to evaluate (default 1, 3, ..., 19).

    Returns:
        Nested defaultdict ``model_metrics[k][model_name] == {"epe": [...], "recon": [...]}``
        with one entry per evaluated sample. The EPE lists are identical for every k.
    """
    os.makedirs(warp_output_dir, exist_ok=True)

    model_metrics = defaultdict(lambda: defaultdict(lambda: {"epe": [], "recon": []}))

    for index, (flow_uv, flow_valid) in tqdm(gt_map.items()):
        # frame_map may lack this index; check before unpacking so the sample is skipped, not crashed on.
        frames = frame_map.get(index)
        if frames is None or frames[0] is None or frames[1] is None:
            print(f"Could not find frame of index {index}")
            continue
        frame_10, frame_11 = frames

        # The blurred target depends only on the sample and k, not on the model: blur it once.
        blurred_targets = {k: cv2.blur(frame_11, (k, k)) for k in blur_kernel_sizes}
        recon_mask = flow_valid if use_mask else None

        for model_name, path_map in pred_map.items():
            pred_path = path_map.get(index)
            if not pred_path:
                print("Error pred path:", pred_path, model_name, index)
                continue

            warp_output_path = os.path.join(warp_output_dir, f"warp_{os.path.basename(pred_path)}")

            pred_flow_uv, _ = load_kitti_flow(pred_path)
            pred_frame_11 = _load_or_compute_warp(frame_10, pred_flow_uv, warp_output_path, save_warps, load_warps)

            # The reconstruction error compares against frame 11, so that is the shape that must match.
            if pred_frame_11.shape != frame_11.shape:
                print(f"Error: warp shape {pred_frame_11.shape} != frame shape {frame_11.shape} "
                      f"(model {model_name}, index {index})")
                continue

            # EPE does not depend on the blur kernel: compute it once and record it for every k.
            epe = EndPointError()([pred_flow_uv], [flow_uv], [flow_valid])['EPE'][0]

            for k, blur_frame_11 in blurred_targets.items():
                blur_pred_frame_11 = cv2.blur(pred_frame_11, (k, k))
                recon_error = reconstruction_error(blur_frame_11, blur_pred_frame_11, recon_mask)

                model_metrics[k][model_name]["epe"].append(epe)
                model_metrics[k][model_name]["recon"].append(recon_error)

    return model_metrics


In [None]:
# --- Configuration -----------------------------------------------------------
# Inputs: KITTI root (must contain flow_occ/ and image_2/) and the directory of model predictions.
kitti_path = r"./data_kitti"
pred_path = r"./results/inference"

# Cache directory for forward-warped frames (read when load_warps, written when save_warps).
warp_output_path = r"./results/warp"

# use_mask: restrict the reconstruction error to pixels where the ground-truth flow is valid.
save_warps, load_warps, use_mask = True, True, False

In [None]:
# Load ground-truth flow, input frame pairs (10, 11) and the per-model prediction file paths.
gt_map, frame_map, pred_map = collect_sources(kitti_path=kitti_path, pred_path=pred_path)

In [None]:
# Evaluate every model. Despite its name, the result is a nested mapping indexed as
# metrics_list[blur_kernel_size][model_name] -> {"epe": [...], "recon": [...]}.
metrics_list = batch_eval(
    gt_map,
    frame_map,
    pred_map,
    warp_output_dir=warp_output_path,
    save_warps=save_warps,
    load_warps=load_warps,
    use_mask=use_mask,
)

In [None]:
# Set blur kernel size. It must be one of the sizes batch_eval evaluated (odd 1..19 by default).
k = 3
# metrics_list is a defaultdict: indexing an unevaluated k would silently insert an empty entry
# (which would then be plotted and saved), so fail loudly instead.
if k not in metrics_list:
    raise KeyError(f"No metrics for blur kernel size {k}; available: {sorted(metrics_list)}")
metrics = metrics_list[k]

In [None]:
# Plot mean EPE vs. mean reconstruction error for every model at the selected blur kernel size.
# NOTE(review): 10.184816 is a fixed, unexplained reference value; confirm it was computed for the same k and use_mask.
plot_mean_scatter(model_metrics=metrics, reference=10.184816)


In [None]:
# Plot per-sample EPE vs. reconstruction error for a single model.
single_model = 'sea_raft_m_kitti'
# metrics is a defaultdict: looking up an unknown (e.g. mistyped) name would silently insert an
# empty entry into it, polluting the aggregation and saved data below, so check first.
if single_model not in metrics:
    raise KeyError(f"Unknown model {single_model!r}; available: {sorted(metrics)}")
plot_per_sample_scatter({single_model: metrics[single_model]})

In [None]:
# Remove the separation by training checkpoint: pool the samples of all checkpoints of a model
# (e.g. "sea_raft_m_kitti" -> "sea_raft_m") so its mean is taken over all of those samples.
metrics_agg = {}
for model_name, valuesdict in metrics.items():
    base_name, sep, _ = model_name.rpartition('_')
    # A name without '_' has no checkpoint suffix; keep it whole instead of collapsing it to ''.
    agg_name = base_name if sep else model_name
    pooled = metrics_agg.setdefault(agg_name, {})
    for metric, values in valuesdict.items():
        pooled.setdefault(metric, []).extend(values)

# Only label interesting models
labels = [
    'ms_raft_p',
    'ccmr_p',
    'ccmr',
    'liteflownet3s',
    'sea_raft_l',
    'gma',
    'llaflow',
    'liteflownet',
    'flowformer',
    'raft',
    'sea_raft_m',
    'memflow_t',
    'dpflow',
]

# NOTE(review): 10.184816 is a fixed, unexplained reference value; confirm it matches the chosen k and use_mask.
plot_mean_scatter(metrics_agg, 10.184816, labels)


In [None]:
# Summarise the evaluation instead of dumping every per-sample value:
# number of evaluated samples per model, for each blur kernel size.
{k: {model: len(vals["epe"]) for model, vals in per_model.items()} for k, per_model in metrics_list.items()}

In [None]:
def sanitize(obj):
    """Recursively convert NumPy values inside nested containers into plain Python objects.

    - dicts (including defaultdicts) become plain dicts; keys are sanitized too,
    - lists and tuples are rebuilt with sanitized items (tuples stay tuples),
    - NumPy arrays become (nested) Python lists,
    - NumPy floating/integer/bool scalars become Python float/int/bool.

    Anything else is returned unchanged. The result can be written with ``json``, and it can be
    pickled even when the input used defaultdicts with lambda factories.
    """
    if isinstance(obj, dict):
        return {sanitize(key): sanitize(value) for key, value in obj.items()}
    if isinstance(obj, list):
        return [sanitize(item) for item in obj]
    if isinstance(obj, tuple):
        return tuple(sanitize(item) for item in obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.bool_):
        return bool(obj)
    return obj

# sanitize() also turns the defaultdicts (whose lambda factories cannot be pickled) into plain dicts.
raw_data = sanitize(metrics_list)

# Persist the evaluation as a pickle (exact Python objects) and as JSON (portable; JSON turns
# the integer kernel-size keys into strings).
data_output_dir = './results/data'
os.makedirs(data_output_dir, exist_ok=True)

with open(os.path.join(data_output_dir, 'fci_evaluation.pkl'), 'wb') as f:
    pickle.dump(raw_data, f)

with open(os.path.join(data_output_dir, 'fci_evaluation.json'), 'w') as f:
    json.dump(raw_data, f)

In [None]:
# Sanity check: the set of Python types of all stored metric values. After sanitize() these
# should be plain built-ins (e.g. float) rather than NumPy scalars, which have a `.dtype`.
{type(value) for per_model in raw_data.values() for metric_dict in per_model.values() for values in metric_dict.values() for value in values}