# $k$ Calibration

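Following the method presented by Lin *et al.*, this notebook fits the distribution of per-pixel event intervals in bins of mean luminance and luminance difference, and writes the results to `*_fit_results.csv`.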

In [2]:
import torch, math, os, json, itertools
import numpy as np
import pandas as pd
from pathlib import Path
from scipy.stats import invgauss, levy

## 1. Config

In [3]:
# Tag used in the single-file output name, e.g. 'RAW_fit_results.csv'.
OUTPUT_PREFIX = 'RAW'

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
DEVICE = 'cpu'
if torch.cuda.is_available():
    DEVICE = 'cuda'

# Grid resolution (mean luminance x luminance difference) shared by the
# valid-range estimation and the per-bin distribution fitting.
MEAN_BINS = 50
DIFF_BINS = 50

# A histogram cell must hold strictly more than this many samples to count
# towards the valid-range estimate.
COUNT_THRESHOLD = 100

# Fewest samples a (mean, diff) bin needs before a distribution is fitted.
MIN_EVENTS_PER_GROUP = 50

# Exclusive upper bound on the accepted per-pixel event interval tau, in µs (50 ms).
MAX_INTERVAL_US = 50_000


## 2. Read event tensor

In [None]:
def load_dvs_events(pt_path, device='cuda'):
    """Load a DVS event tensor and split it into ON and OFF event tables.

    The file must contain a single 2-D tensor with 7 columns, in this order:
    timestamp, x, y, polarity, previous-frame luminance, next-frame luminance,
    frame interval. The column names given to the timestamp and frame
    interval (``*_us``) indicate microseconds.

    Args:
        pt_path (str | Path): Path to the ``.pt`` file. It is read with
            ``torch.load``, which unpickles the file -- only load trusted files.
        device (str): ``map_location`` passed to ``torch.load``. The data is
            copied to host memory for pandas afterwards either way.

    Returns:
        tuple[pd.DataFrame, pd.DataFrame]: ``(df_on, df_off)`` -- the rows with
        polarity ``p == 1`` and ``p == 0`` respectively. Both have the columns
        ``timestamp_us, x, y, p, prev_lum, next_lum, frame_dt_us``; timestamps
        and frame intervals are float64, luminances float32. Rows with any
        other polarity value appear in neither table.

    Raises:
        TypeError: If the file does not contain a tensor.
        ValueError: If the tensor is not of shape ``(N, 7)``.
    """
    data = torch.load(pt_path, map_location=device)
    if not torch.is_tensor(data):
        raise TypeError('Expected a tensor, got {}'.format(type(data).__name__))
    if data.dim() != 2 or data.size(1) != 7:
        raise ValueError('Expected shape (N, 7), got {}'.format(data.size()))

    # detach() so .numpy() also works when the stored tensor requires grad.
    timestamp_us, x, y, p, prev_lum, next_lum, frame_dt_us = (
        data.detach().t().cpu().numpy()
    )

    df = pd.DataFrame(dict(
        timestamp_us=timestamp_us.astype(np.float64),
        x=x,
        y=y,
        p=p,
        prev_lum=prev_lum.astype(np.float32),
        next_lum=next_lum.astype(np.float32),
        frame_dt_us=frame_dt_us.astype(np.float64),
    ))

    # Polarity 1 -> ON events, polarity 0 -> OFF events.
    df_on = df[df.p == 1]
    df_off = df[df.p == 0]

    return df_on, df_off

# Single-file run: read one calibration recording and split it by polarity.
pt_path = Path(
    '../k_calib_20250509/calib_1/frames_analysis_full/events_with_luminance_raw.pt'
)
df_on, df_off = load_dvs_events(pt_path, device=DEVICE)


## 3. Compute the event interval $\tau$, mean luminance $\bar L$, and luminance difference $\Delta L$

In [5]:
def df_to_arrays(sub):
    """Unpack an event table into the per-column arrays used downstream.

    Args:
        sub (pd.DataFrame): Event table with the columns produced by
            ``load_dvs_events``.

    Returns:
        tuple: ``(timestamp_us, x, y, prev_lum, next_lum, frame_dt_us)`` as
        arrays, in the positional order ``compute_event_intervals`` expects.
        The polarity column is not included.
    """
    wanted = ('timestamp_us', 'x', 'y', 'prev_lum', 'next_lum', 'frame_dt_us')
    return tuple(sub[name].values for name in wanted)

# Per-column arrays for each polarity, in compute_event_intervals' order.
(ts_on, x_on, y_on,
 prev_on, next_on, dtframe_on) = df_to_arrays(df_on)
(ts_off, x_off, y_off,
 prev_off, next_off, dtframe_off) = df_to_arrays(df_off)


In [None]:
def compute_event_intervals(ts, x, y, prev_lum, next_lum, frame_dt, max_interval_us=1e10):
    """Pair every event with the preceding event at the same pixel.

    Events are ordered by pixel (x first, then y) and, within a pixel, by
    timestamp. Each event that has an earlier event at the same pixel forms a
    pair. The pair's interval is the timestamp gap. Its luminance statistics
    and frame interval are taken from the *earlier* event of the pair.

    Args:
        ts (np.ndarray): Event timestamps (µs).
        x, y (np.ndarray): Pixel coordinates.
        prev_lum, next_lum (np.ndarray): Previous/next frame luminance per event.
        frame_dt (np.ndarray): Frame interval per event (µs).
        max_interval_us (float): Exclusive upper bound on kept intervals (µs).

    Returns:
        tuple: ``(tau, Lbar, dL, dt)`` for the pairs with
        ``0 < tau < max_interval_us``:
            - tau: interval between the two events (µs)
            - Lbar: ``(prev_lum + next_lum) / 2`` of the earlier event
            - dL: ``next_lum - prev_lum`` of the earlier event
            - dt: frame interval of the earlier event (µs)
    """
    # np.lexsort treats its LAST key as the primary one -> order by x, y, ts.
    perm = np.lexsort((ts, y, x))
    t_s = ts[perm]
    x_s = x[perm]
    y_s = y[perm]
    lo_s = prev_lum[perm]
    hi_s = next_lum[perm]
    fdt_s = frame_dt[perm]

    n = t_s.shape[0]
    # Sorted position k's predecessor is k - 1. Position 0 wraps to -1, but it
    # can never start a same-pixel pair, so the wrapped entry is always masked.
    before = np.arange(n) - 1
    same_pixel = np.zeros(n, dtype=bool)
    same_pixel[1:] = (x_s[1:] == x_s[:-1]) & (y_s[1:] == y_s[:-1])

    gap = t_s - t_s[before]
    # Drop non-positive gaps (duplicate timestamps) and overly long ones.
    keep = same_pixel & (gap > 0) & (gap < max_interval_us)

    src = before[keep]
    tau = gap[keep]
    Lbar = (lo_s[src] + hi_s[src]) / 2
    dL = hi_s[src] - lo_s[src]
    dt = fdt_s[src]
    return tau, Lbar, dL, dt

# Pair consecutive same-pixel events separately for each polarity.
tau_on, Lbar_on, dL_on, dt_on = compute_event_intervals(
    ts_on, x_on, y_on, prev_on, next_on, dtframe_on,
    max_interval_us=MAX_INTERVAL_US,
)
tau_off, Lbar_off, dL_off, dt_off = compute_event_intervals(
    ts_off, x_off, y_off, prev_off, next_off, dtframe_off,
    max_interval_us=MAX_INTERVAL_US,
)

print("Number of valid ON event pairs:", len(tau_on))
print("Number of valid OFF event pairs:", len(tau_off))


## 4. Automatically estimate valid luminance range

In [None]:
def estimate_valid_ranges(Lbar, dL, mean_bins=30, diff_bins=30, count_threshold=100):
    """Estimate the well-populated ranges of mean luminance and luminance difference.

    A 2-D histogram of ``(Lbar, dL)`` is built over the data's own extent.
    Every cell holding strictly more than ``count_threshold`` samples is
    "valid", and the returned ranges are the bounding box (in bin edges) of
    all valid cells. If no cell qualifies, the full data extent is returned.
    Samples with a non-finite ``Lbar`` or ``dL`` are ignored.

    Args:
        Lbar (np.ndarray): Array of mean luminance
        dL (np.ndarray): Array of luminance differences
        mean_bins (int): Number of histogram bins for mean luminance
        diff_bins (int): Number of histogram bins for luminance difference
        count_threshold (int): A cell is valid when its count is strictly
            greater than this value

    Returns:
        tuple: (L_min, L_max, dL_min, dL_max)
            - L_min, L_max: Valid range of mean luminance
            - dL_min, dL_max: Valid range of luminance difference

    Raises:
        ValueError: If there is no sample with finite ``Lbar`` and ``dL``.
    """
    Lbar = np.asarray(Lbar)
    dL = np.asarray(dL)

    # np.histogram2d refuses to auto-detect a range that contains NaN/inf,
    # so keep only fully finite samples.
    finite = np.isfinite(Lbar) & np.isfinite(dL)
    if not finite.any():
        raise ValueError('estimate_valid_ranges: no finite (Lbar, dL) samples')
    Lbar = Lbar[finite]
    dL = dL[finite]

    H, edges_L, edges_dL = np.histogram2d(Lbar, dL, bins=[mean_bins, diff_bins])

    # Cells with count > count_threshold.
    rows, cols = np.nonzero(H > count_threshold)
    if rows.size == 0:
        # No dense cell: fall back to the full data extent.
        return Lbar.min(), Lbar.max(), dL.min(), dL.max()

    # Lower edge of the lowest valid cell, upper edge of the highest one.
    L_min = edges_L[rows].min()
    L_max = edges_L[rows + 1].max()
    dL_min = edges_dL[cols].min()
    dL_max = edges_dL[cols + 1].max()
    return L_min, L_max, dL_min, dL_max

# Valid (Lbar, dL) ranges, estimated independently for each polarity.
L_min_on, L_max_on, dL_min_on, dL_max_on = estimate_valid_ranges(
    Lbar_on, dL_on, MEAN_BINS, DIFF_BINS, COUNT_THRESHOLD
)
L_min_off, L_max_off, dL_min_off, dL_max_off = estimate_valid_ranges(
    Lbar_off, dL_off, MEAN_BINS, DIFF_BINS, COUNT_THRESHOLD
)

print('Valid L_on range:', L_min_on, L_max_on)
print('Valid dL_on range:', dL_min_on, dL_max_on)
print('Valid L_off range:', L_min_off, L_max_off)
print('Valid dL_off range:', dL_min_off, dL_max_off)


## 5. Bin and fit inverse Gaussian distribution

In [None]:
def _bin_index(values, edges):
    """Map each value to its bin index on the ascending ``edges`` grid.

    Bins are ``[edges[i], edges[i + 1])``, except the last bin, which also
    includes ``edges[-1]`` (the ``np.histogram`` convention). Values outside
    the grid, and NaNs, get index -1.
    """
    values = np.asarray(values)
    nbins = len(edges) - 1
    idx = np.searchsorted(edges, values, side='right') - 1
    idx[values == edges[-1]] = nbins - 1
    idx[(idx < 0) | (idx >= nbins)] = -1
    return idx


def fit_event_distributions(Lbar, dL, tau, L_min, L_max, dL_min, dL_max, p,
                            mean_bins=50, diff_bins=50, min_events_per_group=200):
    """Bin event intervals by (Lbar, dL) and fit an inverse Gaussian per bin.

    The rectangle ``[L_min, L_max] x [dL_min, dL_max]`` is split into
    ``mean_bins x diff_bins`` equal cells. Cells are half-open ``[lo, hi)``,
    except the last cell along each axis, which also includes its upper edge,
    so samples exactly at ``L_max`` / ``dL_max`` are not dropped. For every
    cell with at least ``min_events_per_group`` samples, an inverse Gaussian
    IG(mean, lambda) with ``loc`` fixed at 0 is fitted to the cell's ``tau``.

    If the fitted drift ``|Mu| = 1 / MuHat`` is below ``MU_THRESHOLD_KDL``, the
    drift is treated as zero. A Lévy distribution is then fitted instead, and
    its scale ``c`` is reported as ``LambdaHat``.

    Args:
        Lbar (np.ndarray): Array of mean luminance
        dL (np.ndarray): Array of luminance differences
        tau (np.ndarray): Array of event time intervals
        L_min, L_max (float): Valid range for mean luminance
        dL_min, dL_max (float): Valid range for luminance difference
        p (int): Polarity written to the ``P`` column. ``p == 1`` (ON) makes
            ``Mu`` negative; any other value makes it positive.
        mean_bins (int): Number of bins for mean luminance
        diff_bins (int): Number of bins for luminance difference
        min_events_per_group (int): Minimum number of samples a cell needs
            to be fitted

    Returns:
        pd.DataFrame: One row per fitted cell, in row-major (mean, diff) order,
        with columns ``P, MeanMin, MeanMax, DiffMin, DiffMax, MuHat, LambdaHat,
        Mu, Sigma, Count``. The columns are present even if nothing is fitted.
    """
    MU_THRESHOLD_KDL = 1e-4  # Drift rate |k_dL| (= |Mu|) below this is treated as zero

    Lbar = np.asarray(Lbar)
    dL = np.asarray(dL)
    tau = np.asarray(tau)

    # Create bin grids
    L_bins = np.linspace(L_min, L_max, mean_bins + 1)
    dL_bins = np.linspace(dL_min, dL_max, diff_bins + 1)

    # Assign every sample to a flat cell id once (-1 = outside the grid),
    # instead of re-scanning all samples for each of the mean_bins*diff_bins cells.
    li = _bin_index(Lbar, L_bins)
    dj = _bin_index(dL, dL_bins)
    inside = (li >= 0) & (dj >= 0)
    cell = np.where(inside, li * diff_bins + dj, -1)
    counts = np.bincount(cell[inside], minlength=mean_bins * diff_bins)

    results = []
    # Flat ids ascend in row-major order, i.e. the (i, j) order of a nested
    # "for i in mean bins: for j in diff bins" loop.
    for k in np.flatnonzero(counts >= min_events_per_group):
        i, j = divmod(int(k), diff_bins)
        tau_bin = tau[cell == k]  # Use only the data in the current cell

        shape, _, scale = invgauss.fit(tau_bin, floc=0)
        # scipy's invgauss(shape, scale=s) is IG with mean shape*s and lambda s.
        mu_hat = shape * scale       # IG mean; corresponds to |Θ| / μ in the paper
        lambda_hat = scale           # IG shape λ; corresponds to Θ² / σ² in the paper
        if p == 1:
            mu = -1.0 / mu_hat       # Negate for ON events
        else:                        # For OFF events
            mu = 1.0 / mu_hat
        sigma = 1.0 / math.sqrt(lambda_hat)

        if abs(mu) < MU_THRESHOLD_KDL:
            # ----- Near-zero drift: switch to Lévy fitting -----
            # The IG tends to a Lévy distribution as its mean grows without
            # bound; the Lévy scale c takes the role of λ.
            _, scale_lv = levy.fit(tau_bin, floc=0)
            mu = 0.0
            mu_hat = 0.0
            lambda_hat = scale_lv    # λ = c  (Θ² / σ²)
            sigma = 1.0 / math.sqrt(lambda_hat)

        results.append({
            'P': p,
            'MeanMin': L_bins[i],
            'MeanMax': L_bins[i + 1],
            'DiffMin': dL_bins[j],
            'DiffMax': dL_bins[j + 1],
            'MuHat': mu_hat,
            'LambdaHat': lambda_hat,
            'Mu': mu,
            'Sigma': sigma,
            'Count': counts[k],
        })

    columns = ['P', 'MeanMin', 'MeanMax', 'DiffMin', 'DiffMax',
               'MuHat', 'LambdaHat', 'Mu', 'Sigma', 'Count']
    return pd.DataFrame(results, columns=columns)

# Fit each polarity on its own valid range.
df_results_on = fit_event_distributions(
    Lbar_on, dL_on, tau_on,
    L_min_on, L_max_on, dL_min_on, dL_max_on,
    p=1,
    mean_bins=MEAN_BINS, diff_bins=DIFF_BINS,
    min_events_per_group=MIN_EVENTS_PER_GROUP,
)
df_results_off = fit_event_distributions(
    Lbar_off, dL_off, tau_off,
    L_min_off, L_max_off, dL_min_off, dL_max_off,
    p=0,
    mean_bins=MEAN_BINS, diff_bins=DIFF_BINS,
    min_events_per_group=MIN_EVENTS_PER_GROUP,
)

# Stack ON rows above OFF rows and write a single CSV.
df_results = pd.concat((df_results_on, df_results_off), ignore_index=True)
out_csv = f'{OUTPUT_PREFIX}_fit_results.csv'
df_results.to_csv(out_csv, index=False)
print('Saved:', out_csv, 'Number of bins:', len(df_results))


## 6. Process all subfolders

In [None]:
def _fit_polarity(df_pol, p):
    """Run interval pairing, range estimation and fitting for one polarity.

    Args:
        df_pol (pd.DataFrame): ON or OFF event table from ``load_dvs_events``.
        p (int): Polarity label forwarded to ``fit_event_distributions``
            (1 for ON, 0 for OFF).

    Returns:
        pd.DataFrame: Per-bin fit results for this polarity.
    """
    tau, Lbar, dL, _dt = compute_event_intervals(
        *df_to_arrays(df_pol), MAX_INTERVAL_US
    )
    L_min, L_max, dL_min, dL_max = estimate_valid_ranges(
        Lbar, dL,
        mean_bins=MEAN_BINS,
        diff_bins=DIFF_BINS,
        count_threshold=COUNT_THRESHOLD,
    )
    return fit_event_distributions(
        Lbar=Lbar,
        dL=dL,
        tau=tau,
        L_min=L_min,
        L_max=L_max,
        dL_min=dL_min,
        dL_max=dL_max,
        p=p,
        mean_bins=MEAN_BINS,
        diff_bins=DIFF_BINS,
        min_events_per_group=MIN_EVENTS_PER_GROUP,
    )


def process_pt(pt_path, mode):
    """Process a single .pt file and return its combined ON+OFF fit results.

    Uses the module-level configuration (``DEVICE``, ``MAX_INTERVAL_US``,
    ``MEAN_BINS``, ``DIFF_BINS``, ``COUNT_THRESHOLD``,
    ``MIN_EVENTS_PER_GROUP``).

    Args:
        pt_path (Path): Path to the .pt file
        mode (str): Data mode ('RAW' or 'RGB'). Currently not used by the
            processing itself; kept for the caller's interface.

    Returns:
        pd.DataFrame: ON rows followed by OFF rows, re-indexed from 0.
    """
    df_on, df_off = load_dvs_events(pt_path, DEVICE)
    # Merge ON + OFF results and return them.
    return pd.concat(
        [_fit_polarity(df_on, 1), _fit_polarity(df_off, 0)],
        ignore_index=True,
    )


# Top-level directory holding the experiment outputs.
ROOT_DIR = Path('../data')

for mode in ['RAW', 'RGB']:                        # the two data modes (bit depths)
    all_dfs = []
    # Search for .pt files recursively. sorted() makes the processing order,
    # and hence the row order of the merged CSV, independent of the
    # filesystem's enumeration order.
    for pt_path in sorted(ROOT_DIR.rglob(f'events_with_luminance_{mode.lower()}.pt')):
        try:
            df_one = process_pt(pt_path, mode)     # process single file
        except Exception as e:
            # Best-effort batch: report the failing file and move on.
            print('skip', pt_path, e)
            continue
        all_dfs.append(df_one)
        print('processed', pt_path)

    if not all_dfs:
        continue

    # Merge all per-file results for this mode into one CSV.
    merged = pd.concat(all_dfs, ignore_index=True)
    merged.to_csv(f'{mode}_fit_results_all.csv', index=False)
    print(f'{mode}_fit_results_all.csv written, buckets =', len(merged))