In [14]:
import cv2
import numpy as np
import pandas as pd
from pathlib import Path
from scipy.optimize import minimize
from multiprocessing import Pool
from tqdm import tqdm
import matplotlib.pyplot as plt
import warnings

warnings.filterwarnings("ignore")

## K1/K2 Extraction — Edge-Based Optimization

**No SIFT.** Use edge maps instead of raw pixels — barrel moves edges, not flat regions. Cleaner signal.

**Dataset:** corr_path=generated=distorted, dist_path=original=corrected.

**Mac M2 Air:** Default `n_workers=4`.

In [15]:
# ══════════════════════════════════════════════
# EDGE-BASED OPTIMIZATION — Barrel moves edges
# ══════════════════════════════════════════════

def build_camera_matrix(h, w):
    f = float(max(w, h))
    cx, cy = w / 2.0, h / 2.0
    return np.array([[f, 0, cx], [0, f, cy], [0, 0, 1]], dtype=np.float64)


def compute_edge_map(img, blur=3):
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    if blur > 0:
        gray = cv2.GaussianBlur(gray, (blur, blur), 0)
    sx = cv2.Sobel(gray, cv2.CV_32F, 1, 0, ksize=3)
    sy = cv2.Sobel(gray, cv2.CV_32F, 0, 1, ksize=3)
    return np.sqrt(sx**2 + sy**2)


def quality_check_edge(k1, k2, improvement, original_diff):
    if k1 is None:
        return "error"
    if k1 > -0.001:
        return "bad_k1_positive"
    if k1 < -0.58:
        return "bad_k1_extreme"
    if abs(k2) > 0.15:
        return "bad_k2_extreme"
    if improvement < -0.5:
        return "validation_failed"
    if improvement / (original_diff + 1e-6) > 0.15:
        return "excellent"
    if improvement > 0:
        return "good"
    return "acceptable"


def extract_k1_k2_edge_based(row_dict, debug=False):
    """
    Edge-based optimization: barrel moves edges, so compare edge maps not raw pixels.
    Dataset: dist_path=original=corrected, corr_path=generated=distorted.
    We undistort the distorted (corr_path) and compare to corrected (dist_path).
    """
    try:
        # Path swap: corr_path = distorted (input), dist_path = corrected (target)
        dist_img = cv2.imread(str(row_dict["corr_path"]))  # distorted
        corr_img = cv2.imread(str(row_dict["dist_path"]))  # corrected
        if dist_img is None or corr_img is None:
            return {**row_dict, "k1": None, "k2": None, "status": "load_error"}

        h, w = dist_img.shape[:2]
        if h > w:
            dist_img = cv2.rotate(dist_img, cv2.ROTATE_90_CLOCKWISE)
            corr_img = cv2.rotate(corr_img, cv2.ROTATE_90_CLOCKWISE)
            h, w = dist_img.shape[:2]

        scale = 0.50
        sw, sh = int(w * scale), int(h * scale)
        dist_s = cv2.resize(dist_img, (sw, sh))
        corr_s = cv2.resize(corr_img, (sw, sh))
        corr_edges = compute_edge_map(corr_s).astype(np.float32)

        f = float(max(sw, sh))
        cam = np.array([[f, 0, sw / 2.0], [0, f, sh / 2.0], [0, 0, 1.0]], dtype=np.float64)

        def edge_error(k1, k2=0.0):
            if k1 > -0.001:
                return 999.0
            dist_coeffs = np.array([k1, k2, 0.0, 0.0, 0.0])
            undist = cv2.undistort(dist_s, cam, dist_coeffs)
            undist_edges = compute_edge_map(undist).astype(np.float32)
            return np.abs(undist_edges - corr_edges).mean()

        # Stage 1: Coarse k1 grid
        best_k1, best_err = -0.15, 999.0
        grid_errors = []
        k1_range = np.linspace(-0.55, -0.005, 80)
        for k1 in k1_range:
            err = edge_error(k1)
            grid_errors.append(err)
            if err < best_err:
                best_err, best_k1 = err, k1

        if debug:
            print(f"  Coarse grid best: k1={best_k1:.4f} err={best_err:.4f}")
            safe_id = "".join(c if c.isalnum() else "_" for c in str(row_dict["image_id"])[:20])
            plt.figure(figsize=(10, 4))
            plt.plot(k1_range, grid_errors)
            plt.xlabel("k1")
            plt.ylabel("edge MAE")
            plt.title(f"Error curve: {row_dict['image_id'][:40]}")
            plt.axvline(x=best_k1, color="red", linestyle="--", label=f"best k1={best_k1:.3f}")
            plt.legend()
            plt.tight_layout()
            plt.savefig(f"error_curve_{safe_id}.png", dpi=100)
            plt.close()

        # Stage 2: Fine k1 refinement
        for k1 in np.linspace(best_k1 - 0.03, best_k1 + 0.03, 40):
            if k1 >= -0.001:
                continue
            err = edge_error(k1)
            if err < best_err:
                best_err, best_k1 = err, k1

        if debug:
            print(f"  Fine k1: k1={best_k1:.4f} err={best_err:.4f}")

        # Stage 3: k2 search (tight range)
        best_k2, best_err2 = 0.0, best_err
        for k2 in np.linspace(-0.05, 0.12, 35):
            err = edge_error(best_k1, k2)
            if err < best_err2:
                best_err2, best_k2 = err, k2

        if debug:
            print(f"  With k2: k1={best_k1:.4f} k2={best_k2:.4f} err={best_err2:.4f}")

        # Stage 4: Joint fine refinement
        final_k1, final_k2, final_err = best_k1, best_k2, best_err2
        for k1 in np.linspace(best_k1 - 0.02, best_k1 + 0.02, 10):
            if k1 >= -0.001:
                continue
            for k2 in np.linspace(best_k2 - 0.02, best_k2 + 0.02, 10):
                if abs(k2) > 0.12:
                    continue
                err = edge_error(k1, k2)
                if err < final_err:
                    final_err, final_k1, final_k2 = err, k1, k2

        if debug:
            print(f"  Joint refined: k1={final_k1:.4f} k2={final_k2:.4f} err={final_err:.4f}")

        # Stage 5: Full-res pixel validation
        cam_full = np.array(
            [[float(max(w, h)), 0, w / 2.0], [0, float(max(w, h)), h / 2.0], [0, 0, 1.0]],
            dtype=np.float64,
        )
        undist_full = cv2.undistort(
            dist_img, cam_full, np.array([final_k1, final_k2, 0.0, 0.0, 0.0])
        )
        undist_resized = cv2.resize(undist_full, (w, h))
        original_diff = cv2.absdiff(dist_img, corr_img).mean()
        after_diff = cv2.absdiff(undist_resized, corr_img).mean()
        improvement = original_diff - after_diff

        if debug:
            print(f"  Pixel validation: orig={original_diff:.3f} after={after_diff:.3f} improv={improvement:.3f}")

        status = quality_check_edge(final_k1, final_k2, improvement, original_diff)
        return {
            "image_id": row_dict["image_id"],
            "dist_path": row_dict["dist_path"],
            "corr_path": row_dict["corr_path"],
            "k1": float(final_k1),
            "k2": float(final_k2),
            "fit_error": float(final_err),
            "orig_diff": float(original_diff),
            "after_diff": float(after_diff),
            "improvement": float(improvement),
            "category": row_dict.get("category", ""),
            "weight": row_dict.get("weight", 1.0),
            "status": status,
        }
    except Exception as e:
        return {**row_dict, "k1": None, "k2": None, "status": f"error:{str(e)}"}


def extract_k1_k2_single(row_dict, debug=False):
    """Wrapper for edge-based extraction; returns same shape as before for pipeline compatibility."""
    r = extract_k1_k2_edge_based(row_dict, debug=debug)
    out = {
        "k1": r.get("k1"),
        "k2": r.get("k2"),
        "fit_error": r.get("fit_error", 0),
        "pixel_validation": r.get("after_diff", 0),
        "original_diff": r.get("orig_diff", 0),
        "improvement": r.get("improvement", 0),
        "status": r.get("status", "error"),
        "method": "edge",
        "n_matches": 0,
    }
    out["image_id"] = r.get("image_id", "")
    out["category"] = r.get("category", "")
    out["weight"] = r.get("weight", 1.0)
    return out

In [16]:
def process_single(row_dict):
    try:
        result = extract_k1_k2_single(row_dict, debug=False)
        return result
    except Exception as e:
        return {
            "image_id": row_dict["image_id"],
            "status": f"exception:{str(e)}",
            "k1": None,
            "k2": None,
        }


def run_full_extraction(clean_csv, n_workers=4, max_images=None):
    df = pd.read_csv(clean_csv)
    df_train = df[df["use_in_train"]].copy()
    if max_images is not None:
        df_train = df_train.head(max_images)
        print(f"Test mode: processing only {max_images} images")

    print(f"Extracting k1/k2 from {len(df_train):,} images")
    print(f"Using {n_workers} parallel workers")
    print(f"Estimated time: ~{len(df_train) * 4 / n_workers / 60:.0f} min (edge-based, ~4s/img)")

    rows = df_train.to_dict("records")
    results = []
    with Pool(processes=n_workers) as pool:
        for result in tqdm(pool.imap(process_single, rows, chunksize=5), total=len(rows)):
            results.append(result)

    df_results = pd.DataFrame(results)
    out_path = "k1k2_extracted_test.csv" if max_images else "k1k2_extracted.csv"
    df_results.to_csv(out_path, index=False)
    print(f"Saved to {out_path}")

    print_extraction_summary(df_results)
    return df_results


def print_extraction_summary(df):
    total = len(df)
    good_statuses = ["excellent", "good", "acceptable"]
    good = df[df["status"].isin(good_statuses)]

    print("\n" + "=" * 55)
    print("K1/K2 EXTRACTION SUMMARY")
    print("=" * 55)
    print(f"\nTotal processed:  {total:,}")
    print(f"Successful:       {len(good):,} ({100*len(good)/total:.1f}%)")
    print(f"Failed:           {len(df)-len(good):,}")

    print("\nSuccess breakdown:")
    for s in ["excellent", "good", "acceptable"]:
        n = (df["status"] == s).sum()
        print(f"  {s:15s}: {n:,}")

    fail = df[~df["status"].isin(good_statuses)]
    if len(fail) > 0:
        print("\nFailure breakdown:")
        for s, cnt in fail["status"].value_counts().items():
            print(f"  {s[:35]:35s}: {cnt:,}")

    if len(good) > 0:
        print(f"\nK1 (successful): min={good['k1'].min():.4f} max={good['k1'].max():.4f} mean={good['k1'].mean():.4f}")
        print(f"K2 (successful): min={good['k2'].min():.4f} max={good['k2'].max():.4f} mean={good['k2'].mean():.4f}")
        print(f"Fit error: mean={good['fit_error'].mean():.4f}px median={good['fit_error'].median():.4f}px")
        print(f"Method: edge-based (all)")
        if 'improvement' in good.columns:
            print(f"Validation: mean improved={good['improvement'].mean():.2f}px")

In [17]:
def visualize_extraction_results(csv_path):
    df = pd.read_csv(csv_path)
    good = df[df["status"].isin(["excellent", "good", "acceptable"])]
    if len(good) == 0:
        print("No successful extractions to plot")
        return

    fig, axes = plt.subplots(2, 3, figsize=(18, 10))
    fig.suptitle("K1/K2 Extraction Results", fontsize=14)

    axes[0, 0].hist(good["k1"], bins=100, color="blue", alpha=0.7)
    axes[0, 0].set_xlabel("k1 value")
    axes[0, 0].set_title("K1 Distribution (negative = barrel)")
    axes[0, 0].axvline(good["k1"].median(), color="red", linestyle="--", label=f"median={good['k1'].median():.3f}")
    axes[0, 0].legend()

    axes[0, 1].hist(good["k2"], bins=100, color="green", alpha=0.7)
    axes[0, 1].set_xlabel("k2 value")
    axes[0, 1].set_title("K2 Distribution")

    axes[0, 2].scatter(good["k1"], good["k2"], alpha=0.1, s=3, color="purple")
    axes[0, 2].set_xlabel("k1")
    axes[0, 2].set_ylabel("k2")
    axes[0, 2].set_title("K1 vs K2 (should show curve)")

    axes[1, 0].hist(good["fit_error"].clip(0, 3), bins=100, color="orange", alpha=0.7)
    axes[1, 0].set_xlabel("fit error (px)")
    axes[1, 0].set_title("Reprojection Error")
    axes[1, 0].axvline(1.0, color="red", linestyle="--", label="1px")
    axes[1, 0].legend()

    axes[1, 1].scatter(good["k1"], good["fit_error"].clip(0, 3), alpha=0.1, s=3, color="red")
    axes[1, 1].set_xlabel("k1")
    axes[1, 1].set_ylabel("fit error (px)")
    axes[1, 1].set_title("K1 vs Fit Error")

    vals = good["improvement"].clip(0, 20) if "improvement" in good.columns else good["n_matches"].clip(0, 500)
    axes[1, 2].hist(vals, bins=50, color="teal", alpha=0.7)
    axes[1, 2].set_xlabel("improvement (px)" if "improvement" in good.columns else "matches")
    axes[1, 2].set_title("Improvement" if "improvement" in good.columns else "Match Count")

    plt.tight_layout()
    plt.savefig("extraction_results.png", dpi=150)
    plt.show()

### Step 1: Test on 5 pairs

In [19]:
# Check 0002be68 in shift analysis
df_shifts = pd.read_csv("shift_analysis.csv")
prop_check = df_shifts[df_shifts["image_id"].str.contains("0002be68")]
print("0002be68 in shift_analysis:")
print(prop_check[["image_id", "shift_mag", "is_shifted", "overall_diff"]].to_string())

# Edge-based extraction test: exclude bad properties, use clean normal/heavy
df = pd.read_csv("final_clean_dataset.csv")
bad_properties = ["0002be68", "0c3bb579"]
bad_pattern = "|".join(bad_properties)

clean_normal = df[
    (df["use_in_train"] == True)
    & (df["category"] == "normal")
    & (~df["image_id"].str.contains(bad_pattern))
    & (df["overall_diff"] > 8)
].head(5)

clean_heavy = df[
    (df["use_in_train"] == True)
    & (df["category"] == "heavy")
    & (~df["image_id"].str.contains(bad_pattern))
].head(5)


def run_test(subset, label, debug=True):
    print(f"\n{'='*100}")
    print(f"{label}")
    print(f"{'='*100}")
    print(f"{'image_id':<45} {'orig':>6} {'k1':>8} {'k2':>7} {'fit':>6} {'after':>7} {'improv':>8} {'status'}")
    print("-" * 100)
    for _, row in subset.iterrows():
        result = extract_k1_k2_edge_based(row.to_dict(), debug=debug)
        imp = result.get("improvement", 0) or 0
        imp_str = f"+{imp:.2f}" if imp > 0 else f"{imp:.2f}"
        print(
            f"{str(result['image_id'])[:44]:<45} "
            f"{result.get('orig_diff',0):>6.2f} "
            f"{result.get('k1') or 0:>8.4f} {result.get('k2') or 0:>7.4f} "
            f"{result.get('fit_error',0):>6.3f} {result.get('after_diff',0):>7.2f} "
            f"{imp_str:>8} {result.get('status','')}"
        )


run_test(clean_normal, "Clean Normal (diff>8, different properties)", debug=True)
run_test(clean_heavy, "Clean Heavy (different properties)", debug=True)

0002be68 in shift_analysis:
Empty DataFrame
Columns: [image_id, shift_mag, is_shifted, overall_diff]
Index: []

Clean Normal (diff>8, different properties)
image_id                                        orig       k1      k2    fit   after   improv status
----------------------------------------------------------------------------------------------------
  Coarse grid best: k1=-0.5293 err=19.8765
  Fine k1: k1=-0.5301 err=19.8737
  With k2: k1=-0.5301 k2=-0.0400 err=19.8519
  Joint refined: k1=-0.5190 k2=-0.0600 err=19.8305
  Pixel validation: orig=13.201 after=17.739 improv=-4.538
0077cdbd-7265-4202-88d1-cd158f88fab5_g0_orig   13.20  -0.5190 -0.0600 19.831   17.74    -4.54 validation_failed
  Coarse grid best: k1=-0.0188 err=18.9862
  Fine k1: k1=-0.0188 err=18.9862
  With k2: k1=-0.0188 k2=0.0500 err=18.9683
  Joint refined: k1=-0.0210 k2=0.0611 err=18.9589
  Pixel validation: orig=10.509 after=10.468 improv=0.041
0077cdbd-7265-4202-88d1-cd158f88fab5_g11_ori   10.51  -0.0210  0.0611

### Step 2: Test on 200 images (verify before full run)

In [None]:
# Test run: 200 images, 4 workers (safe for M2 Air)
df_test200 = run_full_extraction("full_dataset_clean.csv", n_workers=4, max_images=200)

Test mode: processing only 200 images
Extracting k1/k2 from 200 images
Using 4 parallel workers
Estimated time: ~4 min


  0%|          | 0/200 [00:00<?, ?it/s]Process SpawnPoolWorker-4:
Process SpawnPoolWorker-1:
Process SpawnPoolWorker-3:
Process SpawnPoolWorker-2:
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
  File "/Users/darshanrao/anaconda3/lib/python3.11/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/Users/darshanrao/anaconda3/lib/python3.11/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/Users/darshanrao/anaconda3/lib/python3.11/multiprocessing/pool.py", line 114, in worker
    task = get()
           ^^^^^
  File "/Users/darshanrao/anaconda3/lib/python3.11/multiprocessing/queues.py", line 367, in get
    return _ForkingPickler.loads(res)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/darshanrao/anaconda3/lib/python3.11/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/Users/darshanrao/ana

KeyboardInterrupt: 

### Step 3: Full extraction (~22k images, ~4 hours on 4 workers)

In [None]:
# Full run — remove max_images to process all. Run overnight.
# df_full = run_full_extraction("full_dataset_clean.csv", n_workers=4)

### Step 4: Visualize results

In [None]:
# After extraction, run visualization
# visualize_extraction_results("k1k2_extracted.csv")
# For test run:
visualize_extraction_results("k1k2_extracted_test.csv")