In [None]:
# Imports
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from scipy.fft import fft, ifft, fftfreq
from sklearn.preprocessing import PolynomialFeatures, StandardScaler
from sklearn.linear_model import LinearRegression
from joblib import Parallel, delayed
from tqdm import tqdm
import pandas as pd
from pandas.tseries.offsets import MonthBegin
pd.set_option('display.float_format', lambda x: '%.3f' % x)

## Hard-Coded Variables 
- uses domain knowledge 

In [None]:
# Hard-coded reference data (domain knowledge) consumed by the filtering helpers.
# Allowed (truck, axle) pairs; fast_filter keeps only rows whose pair appears here.
expected_groups = {
    ('A', 7),
    ('A', 8),
    ('B', 1),
    ('B', 2),
    ('C', 3),
    ('C', 4),
    ('D', 5),
    ('D', 6),
}

# Axle numbers of the two axle groups; each set is passed to fast_filter as `valid_axles`.
group_1 = {1,2,7,8}
group_2 = {3,4,5,6}
# (min, max) bounds per geometry feature for group-1 axles. fast_filter blanks values
# outside [0.8 * min, 1.2 * max] and clamps the remaining out-of-bounds values to min/max.
# NOTE(review): units are not stated anywhere in this notebook - confirm with the data owner.
group_1_feature_bounds = {
    "wheeldiameter": (29.0, 41),
    "wheelwidth": (5.50, 5.9),
    "flangeheight": (0.90, 1.55),
    "flangethickness": (0.85, 1.75),
    "flangeslope": (0.10, 0.60),
    "treadhollow": (0.01, 5.3),
    "rimthickness": (0.75, 1.85),
    "backtobackgauge": (52.0, 53.0),
}
# Same structure as group_1_feature_bounds, for group-2 axles.
group_2_feature_bounds = {
    "wheeldiameter": (29.0, 46.0),
    "wheelwidth": (5.50, 5.87),
    "flangeheight": (0.90, 1.55),
    "flangethickness": (0.85, 1.75),
    "flangeslope": (0.10, 0.60),
    "treadhollow": (0.01, 5.1),
    "rimthickness": (0.80, 2.1),
    "backtobackgauge": (52.0, 53.0),
}

# Severity thresholds for group-2 axles: for every feature, a list of
# (threshold value, severity level) pairs with levels -5..-1 and 1..5.
# NOTE(review): not referenced in the visible part of this notebook; presumably negative
# levels describe the low side and positive levels the high side of a feature - confirm.
group_2_severity_thresholds = {'backtobackgauge': [(52.0, -5), (52.0, -4), (52.0, -3), (52.0, -2), (52.0, -1),
                     (53.0, 1), (53.0, 2), (53.0, 3), (53.0, 4), (53.0, 5)],
 'flangeheight': [(1.009, -5), (1.027, -4), (1.036, -3), (1.041, -2),
                  (1.057, -1), (1.467, 1), (1.487, 2), (1.492, 3), (1.5, 4),
                  (1.508, 5)],
 'flangeslope': [(0.119, -5), (0.122, -4), (0.124, -3), (0.126, -2),
                 (0.132, -1), (0.505, 1), (0.528, 2), (0.535, 3), (0.544, 4),
                 (0.556, 5)],
 'flangethickness': [(0.878, -5), (0.897, -4), (0.911, -3), (0.924, -2),
                     (0.969, -1), (1.544, 1), (1.572, 2), (1.582, 3),
                     (1.597, 4), (1.62, 5)],
 'rimthickness': [(0.852, -5), (0.868, -4), (0.88, -3), (0.889, -2), (0.92, -1),
                  (1.983, 1), (2.012, 2), (2.021, 3), (2.032, 4), (2.05, 5)],
 'treadhollow': [(0.01, -5), (0.01, -4), (0.01, -3), (0.01, -2), (0.01, -1),
                 (4.029, 1), (4.369, 2), (4.473, 3), (4.601, 4), (4.769, 5)],
 'wheeldiameter': [(29.0, -5), (29.134, -4), (30.0, -3), (30.63, -2),
                   (32.756, -1), (43.622, 1), (44.331, 2), (44.488, 3),
                   (44.724, 4), (45.118, 5)],
 'wheelwidth': [(5.516, -5), (5.526, -4), (5.533, -3), (5.539, -2), (5.554, -1),
                (5.812, 1), (5.831, 2), (5.837, 3), (5.843, 4), (5.852, 5)]}

# Same structure as group_2_severity_thresholds, for group-1 axles.
group_1_severity_thresholds = {
    'backtobackgauge': [(52.0, -5), (52.0, -4), (52.0, -3), (52.0, -2), (52.0, -1),
                     (53.0, 1), (53.0, 2), (53.0, 3), (53.0, 4), (53.0, 5)],
 'flangeheight': [(1.03, -5), (1.044, -4), (1.051, -3), (1.058, -2),
                  (1.072, -1), (1.5, 1), (1.515, 2), (1.52, 3), (1.525, 4),
                  (1.533, 5)],
 'flangeslope': [(0.124, -5), (0.128, -4), (0.13, -3), (0.132, -2), (0.139, -1),
                 (0.52, 1), (0.543, 2), (0.55, 3), (0.559, 4), (0.572, 5)],
 'flangethickness': [(0.879, -5), (0.895, -4), (0.909, -3), (0.92, -2),
                     (0.96, -1), (1.575, 1), (1.604, 2), (1.615, 3), (1.63, 4),
                     (1.653, 5)],
 'rimthickness': [(0.778, -5), (0.796, -4), (0.81, -3), (0.822, -2),
                  (0.866, -1), (1.694, 1), (1.726, 2), (1.736, 3), (1.749, 4),
                  (1.768, 5)],
 'treadhollow': [(0.01, -5), (0.01, -4), (0.01, -3), (0.01, -2), (0.01, -1),
                 (4.196, 1), (4.53, 2), (4.636, 3), (4.757, 4), (4.938, 5)],
 'wheeldiameter': [(30.394, -5), (30.709, -4), (30.945, -3), (31.102, -2),
                   (31.575, -1), (38.976, 1), (39.528, 2), (39.764, 3),
                   (40.0, 4), (40.236, 5)],
 'wheelwidth': [(5.527, -5), (5.537, -4), (5.543, -3), (5.548, -2), (5.563, -1),
                (5.812, 1), (5.833, 2), (5.839, 3), (5.845, 4), (5.853, 5)]
}

# Measurement columns that are cleaned by fast_filter and processed per wheel by
# signal_process_wheel. 'flangeangle' is listed here but has no entry in the bounds tables,
# so it gets the zero -> NaN treatment but no range clamping.
geometry_features = ['flangeheight', 'rimthickness', 'wheeldiameter', 'wheelwidth','flangethickness','flangeslope', 'backtobackgauge','treadhollow', 'flangeangle']  


## Initializations
- loading datasets
- filtering invalid values and axles
- merging failure dfs into one

In [None]:
# Functions
def load_data(path_prefix: str = '../Datasets') -> dict:
    """
    Read the four CSV inputs of the notebook into DataFrames.

    Parameters
    ----------
    path_prefix : str
        Folder that contains the CSV files.

    Returns
    -------
    dict
        DataFrames keyed by 'failure', 'equipment', 'mileage' and 'wpd'.
        Missing 'failurereason' values in the failure table are replaced by
        'not failed', and the wpd table only keeps rows whose 'traindate'
        is before 2024-12-01.
    """
    sources = (
        ('failure', 'FailureTable0723.csv'),
        ('equipment', 'equipment_data_masked.csv'),
        ('mileage', 'Mileage0723.csv'),
        ('wpd', 'Wpd0723.csv'),
    )

    dfs = {}
    for name, fname in sources:
        dfs[name] = pd.read_csv(f"{path_prefix}/{fname}", engine='pyarrow')

    # a wheel without a recorded reason counts as "not failed"
    failure = dfs['failure']
    failure['failurereason'] = failure['failurereason'].fillna('not failed')

    # cut the measurements off at the end of November 2024
    wpd = dfs['wpd']
    dfs['wpd'] = wpd[wpd['traindate'] < '2024-12-01']
    return dfs



In [None]:
# Row filtering and range clean-up helpers for the measurement tables
def fast_filter(df, name,valid_axles, feature_bounds):
    """
    Keep the rows of one axle group and clean their measurement columns.

    Parameters
    ----------
    df : DataFrame with a 'truck' column and, optionally, an 'axle' column.
    name : label used only in the printed before/after shape line.
    valid_axles : collection of axle numbers that belong to the group.
    feature_bounds : dict mapping a feature name to its (min, max) bounds.

    Returns
    -------
    DataFrame
        A filtered copy with a fresh 0..n-1 index in which
        * only (truck, axle) pairs from `expected_groups` with an axle in
          `valid_axles` remain (only the truck is checked when `df` has no
          'axle' column),
        * zeros in the geometry features and in 'trainspeed' are NaN,
        * values below 0.8 * min or above 1.2 * max are NaN,
        * the remaining values outside [min, max] are clamped to the bound.
    """
    shape_before = df.shape

    allowed_pairs = {pair for pair in expected_groups if pair[1] in valid_axles}
    allowed_trucks = {truck for truck, _ in allowed_pairs}

    if 'axle' in df.columns:
        keep = [pair in allowed_pairs for pair in zip(df['truck'], df['axle'])]
    else:
        keep = df['truck'].isin(allowed_trucks)

    df_filtered = df[keep].copy()

    # a reading of exactly 0 is treated as "no measurement"
    zero_means_missing = geometry_features + ['trainspeed']
    df_filtered[zero_means_missing] = df_filtered[zero_means_missing].replace(0, np.nan)

    for feature, (min_val, max_val) in feature_bounds.items():
        if feature not in df_filtered.columns:
            continue

        hard_low = 0.8 * min_val
        hard_high = 1.2 * max_val

        # far outside the plausible range -> discard the reading
        values = df_filtered[feature]
        df_filtered.loc[(values < hard_low) | (values > hard_high), feature] = np.nan

        # slightly outside -> pull back onto the nearest bound
        values = df_filtered[feature]
        df_filtered.loc[(values >= hard_low) & (values < min_val), feature] = min_val
        values = df_filtered[feature]
        df_filtered.loc[(values <= hard_high) & (values > max_val), feature] = max_val

    shape_after = df_filtered.shape
    print(f"{name}: before = {shape_before}, after = {shape_after}")
    return df_filtered.reset_index(drop=True)


def fast_filter_by_group(df, name):
    """
    Run fast_filter once per axle group and stack the two results.

    Group 1 is filtered with `group_1` / `group_1_feature_bounds`, group 2 with
    `group_2` / `group_2_feature_bounds`. The input frame is copied first and
    the shapes before and after are printed.

    Returns the concatenation (group 1 rows first) with a fresh index.
    """
    working = df.copy()
    print(f"{name}: pre-filter shape: {working.shape}")

    filtered_parts = [
        fast_filter(working, f"{name} (group1)", group_1, group_1_feature_bounds),
        fast_filter(working, f"{name} (group2)", group_2, group_2_feature_bounds),
    ]
    df_combined = pd.concat(filtered_parts, ignore_index=True)

    print(f"{name}: combined shape = {df_combined.shape}")
    return df_combined



In [None]:
# Load the four raw tables from the default '../Datasets' folder into a dict of DataFrames
dfs = load_data()

In [None]:
# Keep only the expected (truck, axle) pairs and clean the geometry features per axle group
df_wpd = fast_filter_by_group(dfs['wpd'], 'wpd')
# The failure table is used as loaded; it does not go through fast_filter
df_failure = dfs['failure']

In [None]:
# Restrict both tables to rows with equipmentnumber <= 5.
# `.copy()` turns the boolean-indexed selections into independent frames, so the column
# assignments made on them in later cells do not raise SettingWithCopyWarning.
df_wpd = df_wpd[df_wpd['equipmentnumber'] <= 5].copy()
df_failure = df_failure[df_failure['equipmentnumber'] <= 5].copy()


## Adding applieddates to wpd records

In [None]:
# Columns that identify one wheel position on one piece of equipment.
wheel_keys = ['equipmentnumber', 'truck', 'axle', 'side']

# Make the cell safe to re-run: an 'applieddate' column left over from a previous run would
# otherwise collide in the merges below (applieddate_x / applieddate_y) and break the cell.
df_wpd = df_wpd.drop(columns=['applieddate'], errors='ignore')

# Month of each reading and the month after it, both as month-start timestamps.
df_wpd['recordmonth'] = pd.to_datetime(df_wpd['traindate']).values.astype('datetime64[M]')
df_wpd['recordmonth_next'] = df_wpd['recordmonth'] + pd.DateOffset(months=1)

# Truncate the failure table's recordmonth to the month BEFORE de-duplicating. De-duplicating
# on the raw value first can leave two rows that end up in the same month, and every such
# pair multiplies the matching wpd rows in the left merges below.
merge_keys = wheel_keys + ['recordmonth']
df_failure_temp = df_failure.copy()
df_failure_temp['recordmonth'] = pd.to_datetime(df_failure_temp['recordmonth']).values.astype('datetime64[M]')
# getting rid of duplicates due to vendornumbersuppliercode (earliest applieddate is kept)
df_failure_temp = df_failure_temp.sort_values('applieddate').drop_duplicates(merge_keys)

# merging current month's applieddate
df_wpd = df_wpd.merge(
    df_failure_temp[merge_keys + ['applieddate']].rename(columns={'applieddate': 'applieddate_initial'}),
    on=merge_keys,
    how='left'
)

# merging next month's applieddate
merge_keys_next = wheel_keys + ['recordmonth_next']
df_wpd = df_wpd.merge(
    df_failure_temp.rename(columns={'recordmonth': 'recordmonth_next', 'applieddate': 'applieddate_next'})[
        merge_keys_next + ['applieddate_next']
    ],
    on=merge_keys_next,
    how='left'
)

# If next month's applieddate is on or before the reading's traindate, the reading already
# belongs to next month's wheel; otherwise it belongs to the current month's wheel.
# (A missing applieddate_next compares as False, so the current month's value is used.)
df_wpd['applieddate'] = np.where(
    df_wpd['traindate'] >= df_wpd['applieddate_next'],
    df_wpd['applieddate_next'],
    df_wpd['applieddate_initial']
)

# the helper columns are no longer needed once 'applieddate' is resolved
df_wpd = df_wpd.drop(columns=['recordmonth_next','applieddate_next','applieddate_initial'])

## Make DF for each wheel
- starts at the applieddate, ends at the day before the next applieddate
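- for each geometry feature, interpolate the readings to a daily series and apply three filters to it: an FFT high-pass, a wavelet (db4) hard-threshold reconstruction, and an SSA high-pass residual; each filter's parameter is tuned per wheel and stored next to the result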
- more details in the notes

In [None]:
from numpy.fft import fft, ifft, fftfreq
from scipy.optimize import minimize_scalar
import pywt
from scipy.signal import find_peaks
from scipy.signal import welch
import numpy as np
from numpy.linalg import svd
from PyEMD import CEEMDAN

def eval_peak_metrics(orig, recon, peaks_true=None, height=None, distance=None):
    """
    Measure how well the peaks of `orig` survive in `recon`.

    Parameters
    ----------
    orig, recon : 1-D arrays on the same index grid.
    peaks_true : optional array of peak indices in `orig`; detected with
        scipy.signal.find_peaks when omitted.
    height, distance : forwarded to find_peaks for both series.

    Returns
    -------
    (recall, attenuation)
        recall      - share of true peaks that have a reconstructed peak
                      within +/- 2 indices.
        attenuation - mean over the matched true peaks of
                      (mean reconstructed peak value) / (original peak value).
        Both are NaN when `orig` has no peaks; attenuation alone is NaN when
        no peak could be matched.
    """
    if peaks_true is None:
        peaks_true = find_peaks(orig, height=height, distance=distance)[0]
    peaks_recon = find_peaks(recon, height=height, distance=distance)[0]

    n_true = len(peaks_true)
    if n_true == 0:
        return np.nan, np.nan

    tol = 2  # a reconstructed peak counts as a match within +/- 2 indices
    ratios = []
    for pt in peaks_true:
        nearby = peaks_recon[np.abs(peaks_recon - pt) <= tol]
        if nearby.size:
            # the small constant guards against division by an exact zero
            ratios.append(np.mean(recon[nearby]) / (orig[pt] + 1e-12))

    # every matched peak contributes exactly one ratio
    recall = len(ratios) / n_true
    attenuation = np.mean(ratios) if ratios else np.nan
    return recall, attenuation


def spectral_energy_retention(orig, recon, fs=1.0, band=(1/14, 1/1)):
    """
    Fraction of spectral energy retained in anomaly-relevant band.

    The power spectral density of both series is estimated with Welch's method
    and integrated over `band`; the result is energy(recon) / energy(orig).

    Parameters
    ----------
    orig, recon : 1-D array-likes.
    fs : sampling frequency passed to scipy.signal.welch.
    band : (low, high) frequency limits, inclusive, in the unit implied by `fs`
        (cycles/day for daily data, e.g. 1/14 to 1/1 ~ 1-14 day periods).

    Returns
    -------
    float
        Energy ratio, or NaN when no finite samples are available.

    Notes
    -----
    Non-finite samples are removed before the spectra are estimated, because a
    single NaN turns the whole Welch estimate into NaN. When both inputs have
    the same length the samples are dropped pairwise so the series stay
    aligned. Callers in this notebook pad the series with NaN outside the
    observed range, which previously made this metric always NaN for them.
    """
    orig = np.asarray(orig, dtype=float)
    recon = np.asarray(recon, dtype=float)

    if orig.shape == recon.shape:
        valid = np.isfinite(orig) & np.isfinite(recon)
        orig, recon = orig[valid], recon[valid]
    else:
        orig = orig[np.isfinite(orig)]
        recon = recon[np.isfinite(recon)]

    if len(orig) == 0 or len(recon) == 0:
        return np.nan

    f_orig, Pxx_orig = welch(orig, fs=fs, nperseg=min(256, len(orig)))
    f_recon, Pxx_recon = welch(recon, fs=fs, nperseg=min(256, len(recon)))

    band_mask_orig = (f_orig >= band[0]) & (f_orig <= band[1])
    band_mask_recon = (f_recon >= band[0]) & (f_recon <= band[1])

    # np.trapz is deprecated in NumPy 2.x in favour of np.trapezoid; support both
    integrate = getattr(np, "trapezoid", None) or np.trapz
    energy_orig = integrate(Pxx_orig[band_mask_orig], f_orig[band_mask_orig])
    energy_recon = integrate(Pxx_recon[band_mask_recon], f_recon[band_mask_recon])

    # the small constant guards against an empty/zero-energy band in `orig`
    return energy_recon / (energy_orig + 1e-12)


def false_peak_rate(orig, recon, height=None, distance=None):
    """
    Fraction of peaks in recon that are spurious (not in orig).

    A reconstructed peak is spurious when no peak of `orig` lies within
    +/- 2 indices of it. The count is divided by the number of reconstructed
    peaks, so the result is in [0, 1]. (The previous implementation divided by
    the number of samples, which contradicted this contract and made the value
    shrink with series length.)

    Parameters
    ----------
    orig, recon : 1-D arrays on the same index grid.
    height, distance : forwarded to scipy.signal.find_peaks for both series.

    Returns
    -------
    float
        Share of spurious peaks; 0.0 when `recon` has no peaks at all.
    """
    peaks_orig, _ = find_peaks(orig, height=height, distance=distance)
    peaks_recon, _ = find_peaks(recon, height=height, distance=distance)

    if len(peaks_recon) == 0:
        return 0.0
    if len(peaks_orig) == 0:
        # nothing to match against: every reconstructed peak is spurious
        return 1.0

    tol = 2
    # find_peaks returns sorted indices, so the nearest original peak of each
    # reconstructed peak is one of its two neighbours in the sorted array
    pos = np.searchsorted(peaks_orig, peaks_recon)
    last = len(peaks_orig) - 1
    left = peaks_orig[np.clip(pos - 1, 0, last)]
    right = peaks_orig[np.clip(pos, 0, last)]
    nearest = np.minimum(np.abs(peaks_recon - left), np.abs(peaks_recon - right))

    return float(np.mean(nearest > tol))


In [None]:
# Make sure both date columns are datetimes, then order the readings wheel position by
# wheel position, wheel by wheel (applieddate) and chronologically within each wheel.
df_wpd = df_wpd.assign(
    traindate=pd.to_datetime(df_wpd['traindate']),
    applieddate=pd.to_datetime(df_wpd['applieddate']),
).sort_values(by=['equipmentnumber', 'truck', 'axle', 'side', 'applieddate', 'traindate'])

def interpolate_daily_wheels(df, n_jobs=-1):
    """
    Build one daily, signal-processed frame per wheel and stack them.

    A "wheel" is the set of rows that share equipmentnumber, truck, axle, side
    and applieddate. Its time window starts at the applieddate (but not before
    2020-01-01) and ends the day before the next applieddate at the same wheel
    position; the last wheel of a position gets an open end (NaT), which
    signal_process_wheel replaces with its own default. Rows whose applieddate
    is missing are not part of any wheel (groupby drops missing keys).

    Parameters
    ----------
    df : DataFrame with the wheel key columns, 'applieddate' and the columns
        signal_process_wheel reads.
    n_jobs : number of joblib workers (-1 = all cores).

    Returns
    -------
    DataFrame
        Concatenation of all per-wheel results with a fresh index; an empty
        DataFrame when `df` contains no wheel at all.
    """
    group_cols = ['equipmentnumber', 'truck', 'axle', 'side']
    earliest_start = pd.to_datetime('2020-01-01')

    # Prepare tasks: each task = (sub_df, start_date, end_date)
    tasks = []
    for _, group_df in df.groupby(group_cols):
        # groupby sorts its keys, so the wheels come out in applieddate order
        wheels = list(group_df.groupby('applieddate'))
        applied_dates = [adate for adate, _ in wheels]

        for i, (adate, wheel_df) in enumerate(wheels):
            start_date = adate
            if start_date < earliest_start:
                start_date = earliest_start

            if i + 1 < len(applied_dates):
                end_date = pd.to_datetime(applied_dates[i + 1]) - pd.Timedelta(days=1)
            else:
                end_date = pd.NaT

            tasks.append((wheel_df.copy(), start_date, end_date))

    # pd.concat raises on an empty list, so handle "no wheels" explicitly
    if not tasks:
        return pd.DataFrame()

    results = Parallel(n_jobs=n_jobs)(
        delayed(signal_process_wheel)(sub_df, start, end)
        for sub_df, start, end in tqdm(tasks, desc="Parallel wheel processing")
    )

    # concatenate in chunks first, then combine the chunks
    chunk_size = 1000
    chunks = [
        pd.concat(results[i:i + chunk_size], ignore_index=True)
        for i in range(0, len(results), chunk_size)
    ]
    return pd.concat(chunks, ignore_index=True)


# Suffixes of the result columns written for every geometry feature.
_SINGLE_PARAM_SUFFIXES = ("highpass", "wavelet")   # <feat>_<s>, _param1, _score
_TWO_PARAM_SUFFIXES = ("ssa",)                     # additionally _param2
# NOTE(review): these methods are not computed anywhere in this function any more; their
# all-NaN placeholder columns are still written for sparse features so that existing
# consumers of the output keep finding them. Remove once nothing reads them.
_LEGACY_PLACEHOLDER_SUFFIXES = ("lowpass", "loess", "ema", "gpr", "kalman", "l1trend")


def _nan_to_zero(value):
    """Return 0.0 for a NaN metric so that it contributes nothing to a score."""
    return 0.0 if np.isnan(value) else value


def _anomaly_fidelity_score(orig_masked, candidate, attn_weight=2.0):
    """Combine peak recall, attenuation, spectral retention and false-peak rate into one score (higher is better)."""
    peak_recall, attn_ratio = eval_peak_metrics(orig_masked, candidate)
    spec_reten = spectral_energy_retention(orig_masked, candidate, fs=1.0, band=(1/14, 1/1))
    fpr = false_peak_rate(orig_masked, candidate)
    # want: HIGH recall, HIGH attenuation ratio, HIGH spectral retention, LOW false peaks
    return (
        + 2.0 * _nan_to_zero(peak_recall)
        + attn_weight * _nan_to_zero(attn_ratio)
        + 0.5 * _nan_to_zero(spec_reten)
        - 1.0 * _nan_to_zero(fpr)
    )


def _ssa_hankelize(y, L):
    """Build the trajectory (Hankel) matrix of shape (L, K = len(y) - L + 1) from the 1-D series y."""
    K = len(y) - L + 1
    return np.column_stack([y[i:i+L] for i in range(K)])


def _ssa_diagonal_averaging(X):
    """Average the anti-diagonals of X to turn an (L, K) matrix back into a series of length L + K - 1."""
    L, K = X.shape
    total = np.zeros(L + K - 1, dtype=float)
    count = np.zeros(L + K - 1, dtype=float)
    # row i of X lands on output positions i .. i+K-1
    for i in range(L):
        total[i:i+K] += X[i, :]
        count[i:i+K] += 1.0
    return total / np.maximum(count, 1e-12)


def _ssa_lowrank_recon(U, s, Vt, r, n_out):
    """Reconstruct the series from the first r SSA components of a precomputed SVD (U, s, Vt)."""
    r = max(1, min(r, len(s)))                    # safety
    Xr = np.zeros((U.shape[0], Vt.shape[1]), dtype=float)
    for i in range(r):
        Xr += s[i] * np.outer(U[:, i], Vt[i, :])
    return _ssa_diagonal_averaging(Xr)[:n_out]


def _add_placeholder_columns(merged_df, feat):
    """Write all-NaN result columns for a feature that has too few readings to be processed."""
    merged_df[f'{feat}_original'] = np.nan
    for suffix in _SINGLE_PARAM_SUFFIXES + _LEGACY_PLACEHOLDER_SUFFIXES:
        merged_df[f'{feat}_{suffix}'] = np.nan
        merged_df[f'{feat}_{suffix}_param1'] = np.nan
        merged_df[f'{feat}_{suffix}_score'] = np.nan
    for suffix in _TWO_PARAM_SUFFIXES:
        merged_df[f'{feat}_{suffix}'] = np.nan
        merged_df[f'{feat}_{suffix}_param1'] = np.nan
        merged_df[f'{feat}_{suffix}_param2'] = np.nan
        merged_df[f'{feat}_{suffix}_score'] = np.nan


def signal_process_wheel(df, start_date, end_date):
    """
    Expand one wheel's readings to a daily frame and add filtered signals.

    Parameters
    ----------
    df : DataFrame with the rows of a single wheel (key columns, 'applieddate',
        'traindate', 'siteid', 'direction', 'trainspeed' and every column in
        `geometry_features`).
    start_date : first day of the wheel's window.
    end_date : last day of the window; NaT/None means "never replaced" and is
        substituted by 2024-11-30.

    Returns
    -------
    DataFrame
        One row per day from start_date to end_date with the key columns, the
        latest reading of each day (NaN on days without a reading), a 'days'
        column (days since start_date) and, for every geometry feature:
          <feat>_original                      readings linearly interpolated to daily,
                                               NaN outside the observed date range
          <feat>_highpass, _param1, _score     FFT high-pass, chosen cutoff, score
          <feat>_wavelet,  _param1, _score     db4 hard-thresholded reconstruction,
                                               chosen threshold, score
          <feat>_ssa, _param1, _param2, _score series minus its leading SSA components,
                                               number of components removed,
                                               window length, score
        A feature with fewer than 3 non-zero readings gets the same columns
        filled with NaN (plus the legacy placeholder columns).

    Each method's single parameter is picked with a bounded scalar search that
    maximises `_anomaly_fidelity_score` against the interpolated original.
    """
    # column definition
    static_cols = ['equipmentnumber', 'truck', 'axle', 'side', 'applieddate']
    monthly_cols = ['siteid', 'direction', 'trainspeed']

    # no end date means the wheel was never replaced: run to the end of the data window
    if pd.isna(end_date):
        end_date = pd.to_datetime('2024-11-30')

    # daily skeleton carrying the wheel's key columns
    date_range = pd.date_range(start=start_date, end=end_date, freq='D')
    base_df = pd.DataFrame(date_range, columns=['traindate'])
    first_row = df.iloc[0]
    for col in static_cols:
        base_df[col] = first_row[col]

    # attach the readings, keeping the latest one per calendar day
    df_subset = df[['traindate'] + monthly_cols + geometry_features].copy()
    df_subset = df_subset.sort_values('traindate', ascending=False)
    df_subset['traindate'] = pd.to_datetime(df_subset['traindate']).dt.normalize()
    df_subset = df_subset.drop_duplicates(subset='traindate')

    # monthly_cols are deliberately left unfilled on days without a reading
    merged_df = pd.merge(base_df, df_subset, on='traindate', how='left')

    x_days_full = (date_range - start_date).days.values
    max_days = (end_date - start_date).days

    for feat in geometry_features:
        df_feat = merged_df[['traindate', feat]].copy()
        df_feat[feat] = df_feat[feat].replace(0, np.nan)
        df_feat = df_feat.dropna().sort_values('traindate')

        # --- too few readings to filter anything ---
        if len(df_feat) < 3:
            _add_placeholder_columns(merged_df, feat)
            continue

        # --- data prep ---
        x = (df_feat['traindate'] - start_date).dt.days.values
        y = df_feat[feat].values
        x_min, x_max = x.min(), x.max()

        # evenly spaced grid (same number of points as readings) for the transforms
        x_uniform = np.linspace(x_min, max_days, len(x))
        y_interp = np.interp(x_uniform, x, y)

        # daily reference series; days outside the observed range are not trusted
        outside_observed = (x_days_full < x_min) | (x_days_full > x_max)
        original_daily = np.interp(x_days_full, x, y)
        original_daily[outside_observed] = np.nan

        def to_daily(series_uniform):
            """Resample a series on x_uniform to the daily grid and blank unobserved days."""
            daily = np.interp(x_days_full, x_uniform, series_uniform[:len(x_uniform)])
            daily[outside_observed] = np.nan
            return daily

        def score_of(series_uniform, attn_weight=2.0):
            return _anomaly_fidelity_score(original_daily, to_daily(series_uniform), attn_weight)

        # =======================================================
        # FFT HIGHPASS (short-term residual reconstruction)
        # =======================================================
        spacing = x_uniform[1] - x_uniform[0]
        freqs = fftfreq(len(x_uniform), d=spacing)
        spectrum = fft(y_interp)   # independent of the cutoff, so computed once

        def fft_highpass(cutoff):
            kept = spectrum.copy()
            kept[np.abs(freqs) < cutoff] = 0   # remove low freqs, keep high
            return np.real(ifft(kept))

        # negative score because minimize_scalar minimizes
        res_fft = minimize_scalar(lambda cutoff: -score_of(fft_highpass(cutoff)),
                                  bounds=(0.0001, 0.1), method="bounded")
        best_cutoff_fft = res_fft.x
        best_score_fft = -res_fft.fun

        merged_df[f'{feat}_original'] = original_daily
        merged_df[f'{feat}_highpass'] = to_daily(fft_highpass(best_cutoff_fft))
        merged_df[f'{feat}_highpass_param1'] = best_cutoff_fft
        merged_df[f'{feat}_highpass_score'] = best_score_fft

        # =======================================================
        # WAVELET THRESHOLDING
        # =======================================================
        coeffs = pywt.wavedec(y_interp, wavelet="db4", level=None)   # computed once

        def wavelet_recon(thresh):
            # keep the approximation untouched, hard-threshold every detail level
            kept = [coeffs[0]] + [pywt.threshold(detail, thresh, mode="hard") for detail in coeffs[1:]]
            return pywt.waverec(kept, wavelet="db4")

        res_wav = minimize_scalar(lambda thresh: -score_of(wavelet_recon(thresh)),
                                  bounds=(0.01, 5.0), method="bounded")
        best_thresh = res_wav.x
        best_score_wav = -res_wav.fun

        merged_df[f'{feat}_wavelet'] = to_daily(wavelet_recon(best_thresh))
        merged_df[f'{feat}_wavelet_param1'] = best_thresh
        merged_df[f'{feat}_wavelet_score'] = best_score_wav

        # =======================================================
        # SSA (Singular Spectrum Analysis) HIGHPASS RECONSTRUCTION
        # =======================================================
        # Window length: at least 30 (or N // 6), capped at N // 2 - 1 and kept within [2, N - 1].
        N = len(y_interp)
        L_ssa = int(np.clip(round(min(max(30, N // 6), N // 2 - 1)), 2, max(2, N-1)))
        max_rank = min(L_ssa, N - L_ssa + 1)  # r <= number of singular values available
        # the decomposition does not depend on the rank, so it is computed once
        U, s, Vt = np.linalg.svd(_ssa_hankelize(y_interp, L_ssa), full_matrices=False)

        def ssa_highpass(rank_float):
            # rank = number of leading components modelled as low-frequency and removed
            r = int(np.clip(round(rank_float), 1, max_rank))
            return y_interp - _ssa_lowrank_recon(U, s, Vt, r, N)

        # NOTE(review): SSA weights the attenuation ratio with 1.0 while the FFT and wavelet
        # searches use 2.0, so the *_score columns are not comparable across methods.
        # Kept as found - confirm which weight is intended.
        res_ssa = minimize_scalar(lambda rank: -score_of(ssa_highpass(rank), attn_weight=1.0),
                                  bounds=(1, max_rank), method="bounded")
        best_rank = int(np.clip(round(res_ssa.x), 1, max_rank))
        best_score_ssa = -res_ssa.fun

        merged_df[f'{feat}_ssa'] = to_daily(ssa_highpass(best_rank))
        merged_df[f'{feat}_ssa_param1'] = best_rank         # number of leading SSA components removed
        merged_df[f'{feat}_ssa_param2'] = L_ssa           # window length used (for reproducibility)
        merged_df[f'{feat}_ssa_score'] = best_score_ssa

    # the raw geometry columns are kept next to their *_original counterparts
    merged_df['days'] = (pd.to_datetime(merged_df['traindate']) - pd.to_datetime(start_date)).dt.days
    return merged_df

In [None]:
# Build the daily, signal-processed frame for every wheel using all available cores (n_jobs=-1)
full_wpd = interpolate_daily_wheels(df_wpd, -1)
# optional: persist the result so later sessions can skip the expensive step above
#full_wpd.to_feather('wpd_trend_null.feather')

In [None]:
def compute_signal_metrics(full_wpd, geometry_features, type, param1, use_param2=False, anomaly_labels=None):
    """
    Summarise, per geometry feature, how well one filtered signal preserves anomalies.

    Parameters
    ----------
    full_wpd : DataFrame holding `<feat>_original`, `<feat>_<type>` and
        `<feat>_<type>_param1` (plus `_param2` when requested) columns.
    geometry_features : iterable of feature names to report on.
    type : suffix of the filtered signal, e.g. "highpass", "wavelet", "ssa".
    param1 : accepted but not used by this function.
    use_param2 : also report the mean of `<feat>_<type>_param2`.
    anomaly_labels : accepted but not used by this function.

    Returns
    -------
    DataFrame indexed by feature with the columns param1, (param2,)
    peak_recall, attenuation_ratio, spectral_retention and false_peak_rate.
    The four metrics are NaN when fewer than 3 rows have both the original
    and the filtered value.
    """
    metric_names = ("peak_recall", "attenuation_ratio", "spectral_retention", "false_peak_rate")
    rows = []

    for feat in geometry_features:
        signal_prefix = f"{feat}_{type}"
        orig = full_wpd[f"{feat}_original"].values
        recon = full_wpd[signal_prefix].values

        # only rows where both series are present can be compared
        usable = ~(np.isnan(orig) | np.isnan(recon))
        if usable.sum() >= 3:
            orig_ok, recon_ok = orig[usable], recon[usable]
            recall, attenuation = eval_peak_metrics(orig_ok, recon_ok)
            # band: periods between 1 and 14 days
            retention = spectral_energy_retention(orig_ok, recon_ok, fs=1.0, band=(1/14,1/1))
            spurious = false_peak_rate(orig_ok, recon_ok)
            metric_values = (recall, attenuation, retention, spurious)
        else:
            metric_values = (np.nan,) * len(metric_names)

        # mean of the tuned parameter(s) over all rows that have one
        row = {"feature": feat, "param1": full_wpd[f"{signal_prefix}_param1"].dropna().mean()}
        if use_param2:
            row["param2"] = full_wpd[f"{signal_prefix}_param2"].dropna().mean()
        row.update(zip(metric_names, metric_values))
        rows.append(row)

    return pd.DataFrame(rows).set_index("feature")


# One summary table per filtering method; only SSA stores a second parameter (its window length).
# The loop variable must not be called `type`: at notebook level that would overwrite the
# builtin `type` for every cell executed afterwards.
for signal_type, has_param2 in [("highpass", False), ("wavelet", False), ("ssa", True)]:
    summary = compute_signal_metrics(full_wpd, geometry_features, signal_type, "param1", use_param2=has_param2)
    display(summary)   # pretty df display


In [None]:
import matplotlib.pyplot as plt
# Pick one wheel position (equipment 1, axle 1, left side) and list its wheels (applieddates)
wheel_position_rows = full_wpd[(full_wpd['equipmentnumber'] == 1) & (full_wpd['axle'] == 1) & (full_wpd['side'] == 'L')]
applied_dates = wheel_position_rows['applieddate'].unique()
if len(applied_dates) == 0:
    raise ValueError("no rows for equipmentnumber == 1, axle == 1, side == 'L' in full_wpd")

# use the second wheel when there is one, otherwise fall back to the only wheel available
chosen_applieddate = applied_dates[min(1, len(applied_dates) - 1)]
# .copy() so that adding a column below works on an independent frame
temp = wheel_position_rows[wheel_position_rows['applieddate'] == chosen_applieddate].copy()

# make sure we have days (if not already in full_wpd)
if "days" not in temp.columns:
    temp["days"] = (temp["traindate"] - temp["traindate"].min()).dt.days

# Must be a signal that signal_process_wheel actually computes ("highpass", "wavelet", "ssa").
# The previous value "kalman" only ever exists as an all-NaN placeholder column, so nothing
# was drawn. "wavelet" is on the same scale as the raw readings, which suits an overlay.
signal = "wavelet"
signal_suffix = f"_{signal}"
# find all features that have a column for this signal
features = [c[:-len(signal_suffix)] for c in temp.columns if c.endswith(signal_suffix)]

for feat in features:
    raw_col = f"{feat}"
    signal_col = f"{feat}_{signal}"

    if raw_col not in temp or signal_col not in temp:
        continue

    fig, ax = plt.subplots(figsize=(10, 5))
    ax.scatter(temp["days"], temp[raw_col], label="Raw", alpha=0.3)
    ax.plot(temp["days"], temp[signal_col], color='orange', label=signal.capitalize())

    ax.set_title(f"{feat} - Raw vs {signal.capitalize()}")
    ax.set_xlabel("Days")
    ax.set_ylabel("Value")
    ax.legend()
    ax.grid(alpha=0.3)
    plt.show()
    break   # only the first feature is plotted; remove to plot every feature
