In [None]:
'''Outlier detection'''

# Function to detect outliers based on z-scores of the mean, std, and AUC metrics.
def get_z_outliers(df, signal_type, threshold=3):
    """
    Detect outlier rows of one signal type from per-row summary statistics.

    Rows whose index label contains ``signal_type`` are selected. Each selected
    row is reduced to three features: its mean, its sample standard deviation
    and its trapezoidal area under the curve (unit spacing). A row is flagged
    when the absolute z-score of any of those features, computed across the
    selected rows, is strictly greater than ``threshold``.

    Args:
        df: DataFrame with one sample per row. Index labels must support
            ``signal_type in label`` (e.g. strings) and the values must be
            castable to float.
        signal_type: Substring identifying the rows to analyse.
        threshold: Absolute z-score above which a row counts as an outlier.

    Returns:
        List of unique outlier index labels in the order they appear in
        ``df``. Empty when no row matches ``signal_type`` or nothing exceeds
        the threshold.
    """
    # Boolean mask instead of a label list: stays correct even if the index
    # contains duplicated labels.
    mask = np.fromiter(
        (signal_type in idx for idx in df.index),
        dtype=bool,
        count=len(df.index),
    )
    if not mask.any():
        return []

    values = df.loc[mask].astype(float)

    # np.trapz was renamed to np.trapezoid in NumPy 2.0; prefer the new name
    # when present and fall back for older releases.
    trapezoid = getattr(np, "trapezoid", None) or np.trapz

    features = {
        'mean': values.mean(axis=1).to_numpy(),
        'std': values.std(axis=1).to_numpy(),
        # One vectorised call over all rows instead of a Python-level apply.
        'auc': trapezoid(values.to_numpy(), axis=1),
    }

    is_outlier = np.zeros(len(values), dtype=bool)
    for feature_values in features.values():
        feature = pd.Series(feature_values)
        # Standard z-score (sample std, NaNs skipped). A zero or undefined
        # spread yields NaN/inf-free comparisons that evaluate to False, so
        # nothing is flagged for that feature.
        z_scores = (feature - feature.mean()) / feature.std()
        is_outlier |= (z_scores.abs() > threshold).to_numpy()

    # dict.fromkeys de-duplicates while keeping DataFrame order, so the
    # result is reproducible from run to run.
    return list(dict.fromkeys(values.index[is_outlier]))

def plot_outliers_vs_normal(df, signal_type, outlier_ids, class_id, n_samples=1):
    """
    Plot randomly chosen normal and outlier rows of one signal type.

    Rows whose index label contains ``signal_type`` are split into outliers
    (label present in ``outlier_ids``) and normal rows. Up to ``n_samples``
    rows of each group are drawn at random and plotted on one figure:
    normal rows as solid blue lines, outliers as dashed red lines.

    Args:
        df: DataFrame with one sample per row; values must be castable to
            float.
        signal_type: Substring identifying the rows to consider.
        outlier_ids: Iterable of hashable index labels regarded as outliers.
        class_id: Class identifier shown in the figure title.
        n_samples: Maximum number of rows drawn per group (default 1).

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    import matplotlib.pyplot as plt

    # Set membership keeps the split linear in the number of rows.
    outlier_set = set(outlier_ids)
    signal_rows = [idx for idx in df.index if signal_type in idx]
    outlier_rows = [idx for idx in signal_rows if idx in outlier_set]
    normal_rows = [idx for idx in signal_rows if idx not in outlier_set]

    groups = (
        (normal_rows, 'Normal', {'alpha': 0.4, 'color': 'blue'}),
        (outlier_rows, 'Outlier', {'alpha': 0.6, 'color': 'red', 'linestyle': '--'}),
    )

    plt.figure(figsize=(12, 5))
    anything_plotted = False
    for rows, legend_label, line_style in groups:
        sample_size = min(max(n_samples, 0), len(rows))
        # NOTE: relies on the stdlib `random` module being imported at file
        # level, as the original code did.
        for position, row in enumerate(random.sample(rows, sample_size)):
            plt.plot(
                df.loc[row].astype(float),
                # Only the first curve of a group carries the legend entry.
                label=legend_label if position == 0 else "_nolegend_",
                **line_style,
            )
            anything_plotted = True

    plt.title(f"Class {class_id} — {signal_type}: Outliers vs Normal")
    plt.xlabel("Time Step")
    plt.ylabel("Amplitude")
    # An empty legend only triggers a matplotlib warning, so skip it.
    if anything_plotted:
        plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()



# Run z-score outlier detection for every class and signal type, plot a
# normal-vs-outlier comparison, then drop the flagged rows from each class.
cleaned_dataframes, all_outliers = [], []

for class_idx, df in enumerate(dataframes):
    class_outliers = set()

    for sig_type in ('signal_raw', 'signal_deriv', 'control_raw', 'control_deriv'):
        signal_outliers = get_z_outliers(df, signal_type=sig_type)

        # Report the count (or absence) first, then the labels themselves.
        print(
            f"Class {class_idx+1}, {sig_type}: found {len(signal_outliers)} outliers"
            if signal_outliers
            else f"Class {class_idx+1}, {sig_type}: no outliers found"
        )
        print(f"Class {class_idx+1}, {sig_type} outliers: {signal_outliers}")

        class_outliers |= set(signal_outliers)
        plot_outliers_vs_normal(df, sig_type, signal_outliers, class_idx + 1)

    # errors='ignore' tolerates labels that are not present in the index.
    cleaned_df = df.drop(index=class_outliers, errors='ignore')
    print(f"Class {class_idx+1}: removed {len(class_outliers)} outliers, new shape: {cleaned_df.shape}")

    cleaned_dataframes.append(cleaned_df)
    all_outliers.append(class_outliers)


In [None]:
# Total number of flagged rows across all classes.
total_z = sum(map(len, all_outliers))
print(f"Total outliers detected using z-score method: {total_z}")

# Shape of each class's DataFrame after outlier removal.
for i, df in enumerate(cleaned_dataframes):
    print(f"Class {i+1} cleaned shape: {df.shape}")