In [None]:
import cv2
import numpy as np
import os
import re
from mmeval import EndPointError

import pickle

import matplotlib.pyplot as plt
from matplotlib.colors import NoNorm
from collections import defaultdict

from common.kitti import load_kitti_flow
from common.blur import masked_blur
from common.metrics import reconstruction_error

import common.warp


In [None]:
import importlib

# importlib.reload() re-executes only the module object it is handed. Reloading
# the `common` package re-runs common/__init__.py but does NOT reload its
# already-imported submodules, so the star imports below would otherwise keep
# handing back stale definitions after common/*.py is edited. Reload each
# star-imported submodule explicitly. NOTE(review): helpers are reloaded before
# common.warp on the assumption that warp may import them — confirm the order.
importlib.reload(common)
for _module in (common.blur, common.metrics, common.warp):
    importlib.reload(_module)

from common.warp import *
from common.blur import *
from common.metrics import *


def collect_sources(kitti_path: str):
    """Load ground-truth flow, frame pairs and point annotations from disk.

    Scans three sub-directories of ``kitti_path``:

    * ``flow_occ/NNNNNN_10.png``: ground-truth flow, read with
      ``load_kitti_flow``.
    * ``image_2/NNNNNN_FF.png``: frames, converted from OpenCV's BGR to RGB.
      Only indices that have both frame 10 and frame 11 are kept.
    * ``raw_annotations/NNNNNN.pkl``: pickled dicts. Their
      ``'convex_hull_annotations'`` entry is stored per index.

    Files whose names do not match these patterns are ignored.

    Args:
        kitti_path: Root directory containing the three sub-directories.

    Returns:
        ``(gt_map, frame_map, annotation_map)``, each keyed by the integer
        value of the six-digit index in the file name:

        * ``gt_map[i]``: the ``(flow, valid)`` pair from ``load_kitti_flow``.
        * ``frame_map[i]``: ``(frame_10, frame_11)`` as RGB arrays.
        * ``annotation_map[i]``: the ``'convex_hull_annotations'`` value.

    Raises:
        FileNotFoundError: if one of the three sub-directories is missing
            (raised by ``os.listdir``).
    """
    gt_map: dict[int, tuple[np.ndarray, np.ndarray]] = {}
    frame_map: dict[int, tuple[np.ndarray, np.ndarray]] = {}
    annotation_map: dict[int, list[tuple[tuple[int, int], tuple[int, int]]]] = {}

    # Ground-truth flow: one file per index, always for frame 10.
    flow_dir = os.path.join(kitti_path, "flow_occ")
    flow_pattern = re.compile(r'^(\d{6})_10\.png$')
    for filename in os.listdir(flow_dir):
        match = flow_pattern.match(filename)
        if match:
            gt_flow, gt_valid = load_kitti_flow(os.path.join(flow_dir, filename))
            gt_map[int(match.group(1))] = (gt_flow, gt_valid)

    # Frames: group every image by index, then by frame number.
    image_dir = os.path.join(kitti_path, "image_2")
    image_pattern = re.compile(r'^(\d{6})_(\d{2})\.png$')
    frames: dict[int, dict[int, np.ndarray]] = {}
    for filename in os.listdir(image_dir):
        match = image_pattern.match(filename)
        if not match:
            continue
        path = os.path.join(image_dir, filename)
        frame = cv2.imread(path)
        if frame is None:
            # cv2.imread reports unreadable/corrupt files by returning None
            # instead of raising; cvtColor would then fail with an opaque error.
            print(f"Error: could not read image {path}")
            continue
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frames.setdefault(int(match.group(1)), {})[int(match.group(2))] = frame

    for index, framedict in frames.items():
        # .get() rather than [] so an index missing one of the two frames is
        # reported and skipped instead of aborting with KeyError.
        frame_10 = framedict.get(10)
        frame_11 = framedict.get(11)
        if frame_10 is None or frame_11 is None:
            print(f"Error: index {index} is missing frame 10 and/or frame 11")
            continue
        frame_map[index] = (frame_10, frame_11)

    # Point annotations.
    annotation_dir = os.path.join(kitti_path, "raw_annotations")
    annotation_pattern = re.compile(r'^(\d{6})\.pkl$')
    for filename in os.listdir(annotation_dir):
        match = annotation_pattern.match(filename)
        if match:
            # NOTE: pickle.load can execute arbitrary code; only point this at
            # trusted, locally produced annotation files.
            with open(os.path.join(annotation_dir, filename), 'rb') as f:
                annotation = pickle.load(f)
            annotation_map[int(match.group(1))] = annotation['convex_hull_annotations']

    return gt_map, frame_map, annotation_map

def _show_reconstruction(frame_11, pred_frame_11, k=3):
    """Blur both frames with a k x k box filter, then plot them, their
    per-pixel colour distance, and a 50/50 overlay."""
    print(f"Using blur with kernel size {k}")
    frame_11 = cv2.blur(frame_11, (k, k))
    pred_frame_11 = cv2.blur(pred_frame_11, (k, k))

    # Error against an all-black image, presumably as a reference scale for
    # the reconstruction errors printed earlier.
    print("Mean image error", reconstruction_error(frame_11, np.zeros_like(frame_11)))

    # Cast to float32 so the subtraction cannot wrap around for unsigned
    # integer inputs.
    vis_image = np.linalg.norm(
        frame_11.astype(np.float32) - pred_frame_11.astype(np.float32), axis=2
    )

    plt.title('Actual second frame')
    plt.imshow(frame_11, cmap='gray')
    plt.show()

    plt.title('Forward flowed second frame based on ground-truth flow vectors')
    plt.imshow(pred_frame_11, cmap='gray')
    plt.show()

    plt.title('Euclidean distance on colors')
    plt.imshow(vis_image, cmap='gray')
    plt.axis('off')
    plt.show()

    vis_image = cv2.addWeighted(frame_11, 0.5, pred_frame_11, 0.5, 0)
    plt.title('Actual second frame and forward flowed second frame overlayed')
    plt.imshow(vis_image)
    plt.axis('off')
    plt.show()


def batch_eval(
    gt_map: dict[int, tuple[np.ndarray, np.ndarray]],
    frame_map: dict[int, tuple[np.ndarray, np.ndarray]],
    annotation_map: dict[int, list[tuple[tuple[int,int],tuple[int,int]]]],
    indices: "tuple[int, ...] | None" = (34,),
    vis_kernel_size: int = 3,
    ):
    """Measure how well ground-truth flow reconstructs the second frame.

    For each selected index:

    1. Frame 10 is forward-warped with the ground-truth flow
       (``forward_warp_bilinear``).
    2. The warp is compared with the real frame 11 using
       ``reconstruction_error`` over the warp's validity mask.
    3. The comparison is repeated after ``masked_blur`` with odd kernel sizes
       1..19.

    Some indices have more than 100 valid flow pixels (treated as dense) and
    at least one annotated point pair. For those, the colour difference
    between the annotated points is recorded; every other index records the
    warp error. The function prints the per-kernel means and the mean of the
    recorded errors, and plots the last evaluated index.

    Args:
        gt_map: index -> ``(flow_uv, flow_valid)``.
        frame_map: index -> ``(frame_10, frame_11)``.
        annotation_map: index -> list of ``((x1, y1), (x2, y2))`` point pairs.
            The first point of each pair is in frame 10 and the second in
            frame 11, in (column, row) order.
        indices: Indices to evaluate. The default ``(34,)`` keeps the previous
            hard-coded selection; pass ``None`` to evaluate every index in
            ``gt_map``.
        vis_kernel_size: Box-blur kernel size applied before plotting.

    Returns:
        ``model_metrics``. NOTE(review): nothing populates it yet, so it is
        always empty.
    """
    model_metrics = defaultdict(lambda: {"epe": [], "recon": []})
    kernel_errors = defaultdict(list)  # blur kernel size -> per-index errors
    recons = []
    last_pair = None  # (frame_11, pred_frame_11) of the last evaluated index

    if indices is not None:
        wanted = set(indices)
        gt_map = {k: v for k, v in gt_map.items() if k in wanted}

    for index, (flow_uv, flow_valid) in gt_map.items():
        # A default pair keeps a missing index from crashing the unpacking.
        frame_10, frame_11 = frame_map.get(index, (None, None))
        if frame_10 is None or frame_11 is None:
            print(f"Could not find frame of index {index}")
            continue

        isdense = np.count_nonzero(flow_valid) > 100

        pred_frame_11, pred_valid = forward_warp_bilinear(frame_10, flow_uv, flow_valid)

        for k in range(1, 21, 2):
            kernel_errors[k].append(reconstruction_error(
                masked_blur(frame_11, pred_valid, k),
                masked_blur(pred_frame_11, pred_valid, k),
                pred_valid,
            ))

        error = reconstruction_error(frame_11, pred_frame_11, pred_valid)

        annotations = annotation_map.get(index)
        # An empty annotation list cannot be split into two point arrays.
        if isdense and annotations is not None and len(annotations) > 0:
            # (N, 2) arrays of (x, y) points in frame 10 and frame 11.
            p1, p2 = np.array(list(zip(*annotations)))
            annotation_error = reconstruction_error(
                frame_10[p1[:, 1], p1[:, 0]], frame_11[p2[:, 1], p2[:, 0]]
            )
            print(index, "Dense  error:", error, "Annotation error:", annotation_error)
            recons.append(annotation_error)
        else:
            print(index, "Sparse error:", error)
            recons.append(error)

        last_pair = (frame_11, pred_frame_11)

    if last_pair is None:
        # Nothing was evaluated, so there is nothing to summarise or plot.
        print("No indices were evaluated")
        return model_metrics

    print("Reconstruction error per blur kernel size:", {k: np.mean(v) for k, v in kernel_errors.items()})

    _show_reconstruction(last_pair[0], last_pair[1], vis_kernel_size)

    print(np.mean(recons))

    return model_metrics
        

if __name__ == "__main__":
    # Load the local KITTI copy, then run the ground-truth reconstruction check.
    # Names stay module-level so later notebook cells can inspect them.
    kitti_path = r"./data_kitti"

    gt_map, frame_map, annotation_map = collect_sources(kitti_path)

    metrics = batch_eval(
        gt_map=gt_map,
        frame_map=frame_map,
        annotation_map=annotation_map,
    )

