In [None]:
from pathlib import Path

from matplotlib.axes import Axes
from matplotlib.pyplot import close, subplot_mosaic, subplots
from more_itertools import last
from pandas import DataFrame, Series, read_csv
from scipy.fft import rfft, rfftfreq
from scipy.signal import filtfilt, iirnotch

# Width, in samples, of the initial block average applied to the raw data.
# ! Reduce to "1" to keep all raw data, especially for time-domain sensitive data
WINDOW = 1  # Raw data will be initially averaged by this `WINDOW` width

# Turn off various features in the notebook during experiment runs to reduce compute
START_TIME = 0.0  # (s) Samples at or before this time are dropped (`index > START_TIME`)
AVERAGE_PLC = False  # Whether to get PLC (power line cycle) averaged data
GET_FFT = True  # Whether to get FFTs for all other data
NOTCH = True  # Whether to apply notch filter

# Index and column names
INDEX = "Time (s)"  # CSV column used as the index when loading data
VOLTAGE = "Voltage (mV)"  # Y-axis label of the time-domain plots
VOLT_TO_MILLIVOLT = 1_000  # Scale factor applied to the loaded data
FREQUENCY = "Frequency (Hz)"  # Index name of the FFT frames

# Data source
LATEST = True  # Use the latest data file
DATA = Path("data")
PROCESSED_DATA = DATA / "processed"
# Created as soon as this code runs. `parents` is not passed, so `DATA` itself
# must already exist.
PROCESSED_DATA.mkdir(exist_ok=True)
SPECIFIC = DATA / "data_initial_2025-06-13T123902.csv"  # If not LATEST use this

# Power line frequency parameters for DC offset zeroing
POWER_LINE_FREQUENCY = 60.0  # (Hz) Frequency of the power line noise
# Number of power line cycles at the start of the data whose mean is subtracted
# from every column to zero the DC offset
NPLC_TO_AVERAGE_DC_OFFSET = 2.0

# Number of power line cycles to average
#   Keithley recommends >=1
#   0.005 is lowest (0.005 cycles / 60 Hz is about 83 us)
#   0.006 is 10kHz
NUMBER_OF_POWER_LINE_CYCLES_TO_AVERAGE = 2.0

# Frequencies to notch and the "widths" of the notch at that frequency
#   Keys are in Hz. Each width `w` reaches the filter design as the quality
#   factor `Q = 1 / w`, so a smaller width gives a narrower notch.
FREQUENCIES_TO_NOTCH = {
    POWER_LINE_FREQUENCY: 0.012,
    (2 * POWER_LINE_FREQUENCY): 0.006,
    (3 * POWER_LINE_FREQUENCY): 0.010,
}

# Plot parameters
COLORS = {"AIN0": "red", "AIN1": "green", "AIN2": "blue", "AIN3": "orange"}  # Data column name -> line color
MAX_FFT_PLOT_FREQUENCY = 600  # Max frequency to show on FFT plots
MAX_ZOOMED_FFT_PLOT_FREQUENCY = 120  # Max frequency to show on zoomed FFT plots
FIGURE_HEIGHT = 16  # Height of the whole figure
FIGURE_WIDTH_PER_PLOT = 8  # Figure width contributed by each subplot column

# Output data file prefixes and subplot names
#   Each string doubles as a plot title. Passed through `get_id`, it becomes
#   the subplot key and the prefix of the saved CSV file.
RAW = "Raw Data"
PLC_AVERAGED = "PLC Averaged Data"
NOTCHED = "Notched Data"
FFT_RAW = "FFT of Raw Data"
FFT_OF_PLC_AVERAGED = "FFT of PLC Averaged Data"
FFT_NOTCHED = "FFT of Notched Data"
ZOOMED_FFT_RAW = "FFT of Raw Data (Zoomed)"
ZOOMED_FFT_OF_PLC_AVERAGED = "FFT of PLC Averaged Data (Zoomed)"
ZOOMED_FFT_NOTCHED = "FFT of Notched Data (Zoomed)"


def get_power_line_cycle_window(
    data: DataFrame, nplc: float, power_line_frequency: float
) -> int:
    """Return how many samples of `data` span `nplc` power line cycles.

    Args:
        data: Frame whose index holds the sample times; its sampling period is
            taken from `get_sampling_period`.
        nplc: Number of power line cycles the window should cover.
        power_line_frequency: Power line frequency, in cycles per index unit.

    Returns:
        The window width in samples, rounded to the nearest whole sample and
        never smaller than 1.
    """
    samples = nplc / power_line_frequency / get_sampling_period(data)
    # Round instead of truncating: the quotient of floats often lands a hair
    # below the exact count (e.g. 199.99999999999997), which `int()` alone
    # would turn into a window that is one sample short.
    # Clamp to one sample so a window shorter than the sampling period cannot
    # produce an empty slice or a zero-width rolling window downstream.
    return max(1, int(round(float(samples))))


def get_data(
    path: Path,
    nplc_to_average_dc_offset: float,
    power_line_frequency: float,
    window: int = 1,
) -> DataFrame:
    """Load a recording, zero its DC offset, and scale it by `VOLT_TO_MILLIVOLT`.

    Rows that are entirely NaN and rows whose index is not greater than
    `START_TIME` are discarded. The mean of the first
    `nplc_to_average_dc_offset` power line cycles is then subtracted from
    every column before scaling.

    Args:
        path: CSV file to read; its `INDEX` column becomes the index.
        nplc_to_average_dc_offset: Number of power line cycles at the start of
            the data used to estimate the DC offset.
        power_line_frequency: Power line frequency used to size that window.
        window: Block width in samples. With `1` the data is returned at full
            resolution; otherwise the mean is taken over a rolling window of
            this width advanced in steps of the same width.

    Returns:
        The offset-corrected, scaled data.
    """
    # File line 1 (the first line after the header) is skipped on read; the
    # original author labelled it the zero-time NaN row.
    nan_row_at_time_zero = 1
    loaded = read_csv(path, index_col=INDEX, skiprows=[nan_row_at_time_zero])
    recording = DataFrame(loaded.dropna(how="all").query(f"index > {START_TIME}"))

    offset_samples = get_power_line_cycle_window(
        recording, nplc_to_average_dc_offset, power_line_frequency
    )
    dc_offset = recording.iloc[:offset_samples, :].mean()
    zeroed = VOLT_TO_MILLIVOLT * (recording - dc_offset)

    if window == 1:
        return zeroed
    block_means = zeroed.rolling(window=window, step=window).mean()
    # Windows that are not yet full evaluate to NaN and are removed here.
    return block_means.dropna(how="all")

def get_id(name: str):
    """Turn a display name into an identifier-like string.

    Parentheses are removed, spaces become underscores, and the result is
    case-folded, e.g. "FFT of Raw Data (Zoomed)" -> "fft_of_raw_data_zoomed".
    """
    substitutions = {ord("("): None, ord(")"): None, ord(" "): "_"}
    return name.translate(substitutions).casefold()


def save_data(data: DataFrame, original_path: Path, prefix: str):
    """Write `data` as CSV into the processed-data directory.

    The output keeps the position of `original_path` relative to `DATA`, and
    its file stem is prefixed with the identifier form of `prefix`.

    Args:
        data: Frame to write.
        original_path: Path of the source file; must be located under `DATA`
            so that it can be made relative to it.
        prefix: Human-readable label, converted with `get_id`.
    """
    prefixed_stem = f"{get_id(prefix)}_{original_path.stem}"
    renamed = original_path.with_stem(prefixed_stem)
    destination = PROCESSED_DATA / renamed.relative_to(DATA)
    data.to_csv(destination)


def get_sampling_period(data: DataFrame) -> float:
    """Return the spacing between the first two index values of `data`.

    Only that first step is inspected, so the result describes the whole
    frame only if its index is evenly spaced.
    """
    index_steps = Series(data.index).diff()
    # Position 0 has no predecessor (NaN); label 1 holds index[1] - index[0].
    return index_steps.loc[1]  # pyright: ignore[reportReturnType]


def apply_power_line_cycle_rolling_average(
    data: DataFrame, number_of_power_line_cycles: float, power_line_frequency: float
) -> DataFrame:
    """Smooth `data` with a rolling mean spanning a number of power line cycles.

    Args:
        data: Frame to average; its index supplies the sampling period.
        number_of_power_line_cycles: Width of the rolling window, expressed in
            power line cycles.
        power_line_frequency: Power line frequency used to size the window.

    Returns:
        The rolling mean of `data`, without the all-NaN rows produced while
        the window is still filling.
    """
    # TODO: Simulate power line cycling actually reducing sampling rate by integration
    # TODO: Rolling/step and groupby approaches create strange tapering "windows"
    # TODO: "Pinching" in FFT is aliasing caused by rolling average
    samples_per_average = get_power_line_cycle_window(
        data, number_of_power_line_cycles, power_line_frequency
    )
    rolling_means = data.rolling(window=samples_per_average).mean()
    return rolling_means.dropna(how="all")


def apply_notch_filter(
    data: DataFrame, frequencies_to_notch: dict[float, float]
) -> DataFrame:
    """Notch out each requested frequency from every column of `data`.

    One IIR notch is designed per entry and applied forward and backward with
    `filtfilt`; the notches are applied one after another.

    Args:
        data: Frame to filter; its index supplies the sampling period.
        frequencies_to_notch: Maps each frequency to remove to its notch
            "width"; the filter's quality factor is `1 / width`.

    Returns:
        The filtered frame. If `data` carried a `name` attribute, the result
        carries the same one.
    """
    # `DataFrame.apply` returns a fresh frame without the ad-hoc `name`
    # attribute, so remember it here. A private sentinel distinguishes "no
    # attribute at all" from an attribute explicitly set to None, and lets
    # frames that were never named pass through instead of raising.
    missing = object()
    name = getattr(data, "name", missing)
    # The sampling rate comes from the index, which filtering does not change,
    # so compute it once rather than on every iteration.
    sampling_frequency = 1 / get_sampling_period(data)
    for frequency_to_notch, notch_width in frequencies_to_notch.items():
        b, a = iirnotch(
            w0=frequency_to_notch, Q=1 / notch_width, fs=sampling_frequency
        )
        data = data.apply(  # pyright: ignore[reportCallIssue]
            axis="index", func=lambda ser, a, b: filtfilt(a=a, b=b, x=ser), a=a, b=b
        )
    if name is not missing:
        data.name = name
    return data


def get_axis(ax: Axes | None = None) -> Axes:
    """Return `ax` unchanged, or the axes of a newly created figure if it is None."""
    if ax is not None:
        return ax
    _, new_ax = subplots()  # pyright: ignore[reportAssignmentType]
    return new_ax


def plot_data(data: DataFrame, ax: Axes | None = None):
    """Plot every column of `data` against its index.

    Args:
        data: Frame to plot. Its optional `name` attribute becomes the title;
            "Data" is used when the attribute is missing, None, or empty.
        ax: Axes to draw on; a new figure is created when omitted.
    """
    # `str(None)` is the truthy string "None", so the fallback has to be
    # decided before converting the name to text.
    name = getattr(data, "name", None)
    title = str(name) if name is not None and str(name) else "Data"
    data.plot(ax=get_axis(ax), color=COLORS, title=title, ylabel=VOLTAGE)


def get_fft(signals: DataFrame) -> DataFrame:
    """Return the magnitude of the real-input FFT of each column of `signals`.

    Args:
        signals: Frame with one signal per column. Its index supplies the
            sampling period and it must carry a `name` attribute.

    Returns:
        A frame with the same columns, indexed by the FFT sample frequencies
        (index named `FREQUENCY`), holding absolute values and named
        "FFT of <signals.name>".
    """
    sample_count = signals.shape[0]
    frequencies = Series(
        name=FREQUENCY,
        data=rfftfreq(sample_count, d=get_sampling_period(signals)),
    )
    # Transform down the rows so each column is handled as its own signal.
    spectrum = DataFrame(
        data=rfft(signals.values, axis=0),
        index=frequencies,
        columns=signals.columns,
    )
    magnitudes = spectrum.abs()
    magnitudes.name = f"FFT of {signals.name}"
    return magnitudes


def plot_fft(
    fft: DataFrame, max_frequency: float = 0.0, ax: Axes | None = None, title: str = ""
):
    """Plot FFT magnitudes on a logarithmic y-axis.

    Args:
        fft: Frame of magnitudes indexed by frequency, in ascending order.
        max_frequency: Upper x-axis limit. It is capped at the highest
            frequency in `fft`; a value of zero or less (the default) shows
            the full frequency range.
        ax: Axes to draw on; a new figure is created when omitted.
        title: Plot title. When empty, the `name` attribute of `fft` is used,
            and "FFT of data" if that is missing, None, or empty.
    """
    # `str(None)` is the truthy string "None", so the fallback has to be
    # decided before converting the name to text.
    name = getattr(fft, "name", None)
    default_title = str(name) if name is not None and str(name) else "FFT of data"

    highest_frequency = fft.index[-1]
    # Without this guard the default of 0.0 would collapse the axis to [0, 0].
    upper_limit = (
        min(max_frequency, highest_frequency)
        if max_frequency > 0
        else highest_frequency
    )
    fft.plot(
        ax=get_axis(ax),
        color=COLORS,
        logy=True,
        title=title or default_title,
        xlim=[0, upper_limit],
        ylabel="Amplitude",
    )


# Get data path and prepare to plot
#   With LATEST set, the files directly inside `DATA` are ordered by the last
#   17 characters of their stem and the final one is used. This assumes every
#   stem ends in a timestamp shaped like the one in `SPECIFIC`
#   ("2025-06-13T123902"), so that text order matches chronological order.
path = (
    last(
        sorted(
            [p for p in DATA.iterdir() if p.is_file()],
            key=lambda p: p.stem[-17:],
        )
    )
    if LATEST
    else SPECIFIC
)
close("all")

# Lay out one subplot column per enabled dataset (raw is always shown) and one
# row per view: time domain, plus FFT and zoomed FFT when GET_FFT is set.
mosaic_rows = [[RAW, PLC_AVERAGED, NOTCHED]]
if GET_FFT:
    mosaic_rows += [
        [FFT_RAW, FFT_OF_PLC_AVERAGED, FFT_NOTCHED],
        [ZOOMED_FFT_RAW, ZOOMED_FFT_OF_PLC_AVERAGED, ZOOMED_FFT_NOTCHED],
    ]
column_enabled = (True, AVERAGE_PLC, NOTCH)
axes: dict[str, Axes]
figure, axes = subplot_mosaic([
    [get_id(n) for n, enabled in zip(row, column_enabled) if enabled]
    for row in mosaic_rows
])  # pyright: ignore[reportAssignmentType]
num_plot_columns = 1 + sum([AVERAGE_PLC, NOTCH])
figure.set_figheight(FIGURE_HEIGHT)
figure.set_figwidth(FIGURE_WIDTH_PER_PLOT * num_plot_columns)

# Get raw data
#   `WINDOW` is forwarded so that changing the constant actually changes the
#   initial averaging.
raw_data = get_data(
    path,
    nplc_to_average_dc_offset=NPLC_TO_AVERAGE_DC_OFFSET,
    power_line_frequency=POWER_LINE_FREQUENCY,
    window=WINDOW,
)
raw_data.name = RAW
plot_data(data=raw_data, ax=axes[get_id(RAW)])

# Get FFT of raw data
if GET_FFT:
    raw_fft = get_fft(raw_data)
    save_data(data=raw_fft, original_path=path, prefix=FFT_RAW)
    plot_fft(
        ax=axes[get_id(FFT_RAW)], fft=raw_fft, max_frequency=MAX_FFT_PLOT_FREQUENCY
    )
    plot_fft(
        ax=axes[get_id(ZOOMED_FFT_RAW)],
        fft=raw_fft,
        max_frequency=MAX_ZOOMED_FFT_PLOT_FREQUENCY,
        title=ZOOMED_FFT_RAW,
    )

# Apply power line cycle averaging
#   Runs whenever AVERAGE_PLC is set, independent of GET_FFT: the layout above
#   creates this subplot in both cases, so it must be filled in both cases.
if AVERAGE_PLC:
    plc_averaged_data = apply_power_line_cycle_rolling_average(
        raw_data,
        number_of_power_line_cycles=NUMBER_OF_POWER_LINE_CYCLES_TO_AVERAGE,
        power_line_frequency=POWER_LINE_FREQUENCY,
    )
    plc_averaged_data.name = PLC_AVERAGED
    save_data(data=plc_averaged_data, original_path=path, prefix=PLC_AVERAGED)
    plot_data(data=plc_averaged_data, ax=axes[get_id(PLC_AVERAGED)])

    # Get FFT of PLC averaged data
    if GET_FFT:
        fft_plc_averaged_data = get_fft(plc_averaged_data)
        save_data(
            data=fft_plc_averaged_data, original_path=path, prefix=FFT_OF_PLC_AVERAGED
        )
        plot_fft(
            ax=axes[get_id(FFT_OF_PLC_AVERAGED)],
            fft=fft_plc_averaged_data,
            max_frequency=MAX_FFT_PLOT_FREQUENCY,
        )
        plot_fft(
            ax=axes[get_id(ZOOMED_FFT_OF_PLC_AVERAGED)],
            fft=fft_plc_averaged_data,
            max_frequency=MAX_ZOOMED_FFT_PLOT_FREQUENCY,
            title=ZOOMED_FFT_OF_PLC_AVERAGED,
        )

# Notch the data
#   Likewise independent of GET_FFT, for the same reason as above.
if NOTCH:
    notched_data = apply_notch_filter(
        data=raw_data, frequencies_to_notch=FREQUENCIES_TO_NOTCH
    )
    notched_data.name = NOTCHED
    save_data(data=notched_data, original_path=path, prefix=NOTCHED)
    plot_data(ax=axes[get_id(NOTCHED)], data=notched_data)

    # Get FFT of notched data
    if GET_FFT:
        notched_fft = get_fft(notched_data)
        save_data(data=notched_fft, original_path=path, prefix=FFT_NOTCHED)
        plot_fft(
            ax=axes[get_id(FFT_NOTCHED)],
            fft=notched_fft,
            max_frequency=MAX_FFT_PLOT_FREQUENCY,
        )
        plot_fft(
            ax=axes[get_id(ZOOMED_FFT_NOTCHED)],
            fft=notched_fft,
            max_frequency=MAX_ZOOMED_FFT_PLOT_FREQUENCY,
            title=ZOOMED_FFT_NOTCHED,
        )