In [None]:
import re
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from tqdm.notebook import tqdm

import cv2
import os

from mmeval import EndPointError

from common.kitti import load_kitti_flow

In [None]:
def _load_ground_truth(kitti_path: str) -> dict[int, tuple[np.ndarray, np.ndarray]]:
    """Load every KITTI ground-truth flow map named ``<6 digits>_10.png``.

    Files are read from ``<kitti_path>/flow_occ`` with ``load_kitti_flow``;
    files that do not match the pattern are ignored.

    Returns:
        Mapping from the integer frame index to the ``(flow, valid)`` pair
        returned by ``load_kitti_flow``.
    """
    gt_dir = os.path.join(kitti_path, "flow_occ")
    pattern = re.compile(r'^(\d{6})_10\.png$')
    gt_map: dict[int, tuple[np.ndarray, np.ndarray]] = {}
    # Sorted so files are processed in a reproducible order (os.listdir is unordered).
    for filename in sorted(os.listdir(gt_dir)):
        match = pattern.match(filename)
        if match:
            gt_flow, gt_valid = load_kitti_flow(os.path.join(gt_dir, filename))
            gt_map[int(match.group(1))] = (gt_flow, gt_valid)
    return gt_map


def _index_predictions(pred_path: str) -> dict[str, dict[int, str]]:
    """Group prediction files named ``<model_name>-<7 digits>.png`` by model name.

    Only file paths are collected; the flow images are loaded later, one at a
    time, during evaluation. Files that do not match the pattern are ignored.

    Returns:
        ``{model_name: {frame_index: file_path}}``.
    """
    pattern = re.compile(r'^(.*)-(\d{7})\.png$')
    pred_map: dict[str, dict[int, str]] = {}
    for filename in sorted(os.listdir(pred_path)):
        match = pattern.match(filename)
        if match:
            model_name, index = match.group(1), int(match.group(2))
            pred_map.setdefault(model_name, {})[index] = os.path.join(pred_path, filename)
    return pred_map


def _epe_and_fl_all(pred_flow, gt_flow, gt_valid,
                    abs_thresh: float = 3.0, rel_thresh: float = 0.05) -> tuple[float, float]:
    """Per-image mean end-point error and Fl-all outlier percentage.

    A valid pixel is an outlier when its end-point error is strictly greater
    than ``abs_thresh`` pixels AND strictly greater than ``rel_thresh`` times
    the ground-truth flow magnitude (KITTI 2015 convention: 3 px and 5 %).

    Args:
        pred_flow: predicted flow; assumed (H, W, 2) like ``gt_flow`` — TODO confirm.
        gt_flow: ground-truth flow, last axis holds the flow vector.
        gt_valid: (H, W) validity mask.
        abs_thresh: absolute outlier threshold in pixels.
        rel_thresh: relative outlier threshold (fraction of the GT magnitude).

    Returns:
        ``(mean_epe, fl_all_percent)`` over the valid pixels, or
        ``(nan, nan)`` when no pixel is valid.
    """
    # The loader's mask dtype is not visible here. Casting makes a 0/1 integer
    # mask act as a boolean mask instead of as fancy integer indices.
    valid = np.asarray(gt_valid, dtype=bool)
    n_valid = np.count_nonzero(valid)
    if n_valid == 0:
        return float('nan'), float('nan')

    gt = np.asarray(gt_flow, dtype=np.float32)
    pred = np.asarray(pred_flow, dtype=np.float32)

    epe = np.linalg.norm(pred - gt, axis=-1)[valid]
    gt_mag = np.linalg.norm(gt[valid], axis=-1)
    outliers = (epe > abs_thresh) & (epe > rel_thresh * gt_mag)
    return float(epe.mean()), 100.0 * np.count_nonzero(outliers) / n_valid


def batch_eval(kitti_path: str, pred_path: str):
    """Evaluate every model's KITTI flow predictions against the ground truth.

    Ground truth is read from ``<kitti_path>/flow_occ/<6 digits>_10.png``.
    Predictions are read from ``<pred_path>/<model_name>-<7 digits>.png``.
    The model name is split at its last ``_`` into a model and a checkpoint
    label (``raft_things`` -> ``raft`` / ``things``). A name without ``_`` is
    used as the model, with an empty checkpoint.

    Frames without a prediction, and predictions without ground truth, are
    reported with ``print`` and skipped. Per-image EPE and Fl-all are
    averaged over the evaluated frames.

    Returns:
        A list with one dict per model, with keys ``'Model'``, ``'Checkpoint'``,
        ``'Mean EPE'`` and ``'Mean F1-All'`` (a percentage).
    """
    gt_map = _load_ground_truth(kitti_path)
    pred_map = _index_predictions(pred_path)

    results = []
    for model_name, model_pred_map in tqdm(pred_map.items()):
        model, sep, ckpt = model_name.rpartition('_')
        if not sep:
            # rpartition yields ('', '', name) when there is no '_', which would
            # leave the model column empty.
            model, ckpt = model_name, ''

        missing = set(gt_map) - set(model_pred_map)
        if missing:
            print(model_name, 'is missing predictions for:', missing)
        unmatched = set(model_pred_map) - set(gt_map)
        if unmatched:
            print(model_name, 'has predictions without ground truth for:', unmatched)

        epes = []
        fl_alls = []
        for index in sorted(set(model_pred_map) & set(gt_map)):
            pred_flow, _ = load_kitti_flow(model_pred_map[index])
            gt_flow, gt_valid = gt_map[index]
            epe, fl_all = _epe_and_fl_all(pred_flow, gt_flow, gt_valid)
            epes.append(epe)
            fl_alls.append(fl_all)

        # nanmean ignores frames that had no valid ground-truth pixel.
        results.append({
            'Model': model,
            'Checkpoint': ckpt,
            'Mean EPE': np.nanmean(epes) if epes else np.nan,
            'Mean F1-All': np.nanmean(fl_alls) if fl_alls else np.nan,
        })

    return results


In [None]:
# Score every model found in the inference folder against the KITTI ground truth.
results = batch_eval(kitti_path=r"./data_kitti", pred_path=r"./results/inference")

In [None]:
# Per-model summary, best (lowest) EPE first, with readable checkpoint labels.
# Built as one chain instead of `df['Checkpoint'].replace(..., inplace=True)`:
# that is chained assignment, which under pandas Copy-on-Write never writes
# back to `df`. The cell is also idempotent on re-run.
df = (
    pd.DataFrame(results)
    .sort_values("Mean EPE")
    .assign(Checkpoint=lambda d: d["Checkpoint"].replace("mixed", "mix").str.capitalize())
    .rename(columns={"Mean F1-All": "F1-All", "Mean EPE": "EPE"})
)
df

In [None]:
# Wide table: one row per model, columns grouped as (Checkpoint, metric).
pivot_df = (
    df.pivot(index='Model', columns='Checkpoint')
    .swaplevel(axis=1)
    .sort_index(axis=1, level=0)
)
epe_table = pivot_df.xs("EPE", level=1, axis=1)

# Rows: lowest EPE, averaged over the checkpoints a model has, first.
row_order = epe_table.mean(axis=1, skipna=True).sort_values(kind='stable').index

# Column groups: checkpoints evaluated for the most models first. The stable
# sort keeps the alphabetical order from sort_index for ties.
ckpt_order = epe_table.notna().sum().sort_values(ascending=False, kind='stable').index
col_order = [col for ckpt in ckpt_order for col in pivot_df.columns if col[0] == ckpt]

pivot_df = pivot_df.loc[row_order, col_order]
pivot_df


In [None]:
# LaTeX version of the pivot table for the report. '-' marks model/checkpoint
# combinations that were not evaluated. The caption names both metric columns
# (EPE and F1-All) and states the sort key used for the rows.
latex_table = pivot_df.to_latex(
    index=True,
    escape=True,
    float_format="%.4f",
    caption=("Mean End-Point Error (EPE) and F1-All outlier percentage for different "
             "models and checkpoints, sorted by ascending mean EPE."),
    label="tab:epe_results",
    multicolumn=True,
    multicolumn_format='c',
    na_rep='-',
    longtable=True,
)

print(latex_table)

In [None]:
# Export the pivot table. Create the output folder first so a fresh checkout
# does not fail with FileNotFoundError.
excel_path = 'results/data/epe_evaluation.xlsx'
os.makedirs(os.path.dirname(excel_path), exist_ok=True)
pivot_df.to_excel(excel_path)