# Step 1: Time Gridding and Optional Filtering Demo

This notebook demonstrates the Step 1 processing workflow for mooring data:
- Loading multiple instrument datasets
- Optional time-domain filtering (applied BEFORE interpolation)
- Interpolating onto a common time grid
- Combining into a unified mooring dataset

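**Key Point**: Filtering is applied to each instrument record on its native time grid BEFORE interpolation. The filter therefore operates on the data as they were actually sampled, rather than on values that interpolation onto the common grid has already altered.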

Version: 1.0  
Date: 2025-09-07

In [None]:
import numpy as np
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import yaml

# Import the time gridding module
from oceanarray.time_gridding import (
    TimeGriddingProcessor,
    time_gridding_mooring,
    process_multiple_moorings_time_gridding
)

# Plot styling: start from matplotlib's default style, then switch the default
# colour cycle to seaborn's "husl" palette.
plt.style.use("default")
sns.set_palette(palette="husl")


## 1. Data Setup and Configuration

First, let's set up our data paths and examine the mooring configuration.

In [None]:
# Root of the mooring data tree and the mooring to process -- edit these two.
basedir = '/Users/eddifying/Dropbox/data/ifmro_mixsed/ds_data_eleanor/'
mooring_name = 'dsE_1_2018'

# Processed files for this mooring live in <basedir>/moor/proc/<mooring_name>/,
# alongside the mooring's YAML configuration.
proc_dir = Path(basedir).joinpath('moor', 'proc', mooring_name)
config_file = proc_dir.joinpath(f"{mooring_name}.mooring.yaml")

print("\n".join([
    f"Processing directory: {proc_dir}",
    f"Configuration file: {config_file}",
    f"Config exists: {config_file.exists()}",
]))

In [None]:
# Load and examine the mooring configuration
if config_file.exists():
    # safe_load: the config is plain data, never construct arbitrary objects.
    with open(config_file, 'r', encoding='utf-8') as f:
        # An empty YAML file parses to None; treat it as an empty config.
        config = yaml.safe_load(f) or {}

    # `instruments: null` (or a missing key) should read as "no instruments".
    instruments = config.get('instruments') or []

    print("Mooring Configuration:")
    print(f"Name: {config.get('name', mooring_name)}")
    print(f"Water depth: {config.get('waterdepth', 'unknown')} m")
    print(f"Location: {config.get('latitude', 'unknown')}°N, {config.get('longitude', 'unknown')}°E")
    print(f"\nInstruments ({len(instruments)}):")

    for i, inst in enumerate(instruments, start=1):
        # 'serial' is the key the instrument-file lookup in the next cell uses;
        # also accept a 'serial num.' key in case a config spells it that way.
        serial = inst.get('serial', inst.get('serial num.', 'unknown'))
        print(f"  {i}. {inst.get('instrument', 'unknown')} "
              f"(serial: {serial}) at {inst.get('depth', 'unknown')} m")
else:
    print("Configuration file not found!")
    print("Please check your data path and mooring name.")

## 2. Examine Individual Instrument Files

Let's look at the individual instrument files before processing to understand the different sampling rates and data characteristics.

In [None]:
# Find and examine individual instrument files
file_suffix = "_use"
instrument_files = []     # paths of the instrument files that were found
instrument_datasets = []  # the opened datasets, in the same order
rows = []                 # one summary row per configured instrument


def _format_sampling(times):
    """Return the median sampling interval of *times* as a readable string.

    Intervals longer than one minute are reported in minutes, shorter ones in
    seconds. Returns "" when there are fewer than two samples (no interval).
    """
    if len(times) < 2:
        return ""
    minutes = np.nanmedian(np.diff(times) / np.timedelta64(1, "m"))
    return f"{minutes:.1f} min" if minutes > 1 else f"{minutes * 60:.1f} sec"


if config_file.exists():
    # `config` can be None if the YAML was empty; treat that as no instruments.
    for inst_config in (config or {}).get("instruments") or []:
        instrument_type = inst_config.get("instrument", "unknown")
        serial = inst_config.get("serial", 0)
        depth = inst_config.get("depth", 0)

        # Expected location: <proc_dir>/<instrument>/<mooring>_<serial><suffix>.nc
        filepath = proc_dir / instrument_type / f"{mooring_name}_{serial}{file_suffix}.nc"

        # Start from the "missing file" row and fill in details if it exists.
        row = {
            "Instrument": instrument_type,
            "Serial": serial,
            "Depth [m]": depth,
            "File": "MISSING",
            "Start": "",
            "End": "",
            "Points": 0,
            "Sampling": "",
            "Variables": "",
        }

        if filepath.exists():
            ds = xr.open_dataset(filepath)
            instrument_files.append(filepath)
            instrument_datasets.append(ds)

            times = ds.time.values
            row.update(
                File=filepath.name,
                Points=len(times),
                Sampling=_format_sampling(times),
                Variables=", ".join(list(ds.data_vars)),
            )
            if len(times):
                row.update(Start=str(times[0])[:19], End=str(times[-1])[:19])

        rows.append(row)

    # Make a DataFrame summary; scope the wide-column setting to this print only.
    summary = pd.DataFrame(rows)
    with pd.option_context("display.max_colwidth", 80):  # allow long var lists
        try:
            print(summary.to_markdown(index=False))
        except ImportError:
            # to_markdown needs the optional 'tabulate' package.
            print(summary.to_string(index=False))

    print(f"\nFound {len(instrument_datasets)} instrument datasets")


## 4. Process with Time Gridding (No Filtering)

First, let's process the mooring without any filtering to see the basic time gridding functionality.

In [None]:
# Baseline run: put every instrument onto the common time grid, unfiltered.
print("Processing mooring with time gridding only (no filtering)...", "=" * 60, sep="\n")

result = time_gridding_mooring(mooring_name, basedir, file_suffix='_use')

print("\nProcessing result: " + ("SUCCESS" if result else "FAILED"))

In [None]:
# Load and examine the combined (unfiltered) dataset
output_file = proc_dir / f"{mooring_name}_mooring_use.nc"

if output_file.exists():
    print(f"Output file exists: {output_file}")

    # load_dataset reads the data into memory and closes the file (open_dataset
    # keeps the handle open), so re-running the processing cell above can
    # rewrite this file without it still being held open here.
    combined_ds = xr.load_dataset(output_file)

    # Quick look at what was produced (.sizes maps dimension name -> length).
    print(f"\nDataset shape: {dict(combined_ds.sizes)}")
    print(f"Variables: {', '.join(map(str, combined_ds.data_vars))}")
else:
    print("Output file not found - processing may have failed")

## 5. Visualize Combined Dataset

Let's plot the combined dataset to see how the different instruments look on the common time grid.

In [None]:
import numpy as np
import matplotlib.pyplot as plt

def plot_combined_timeseries(
    combined_ds,
    variables=("temperature", "salinity", "pressure"),
    cmap_name="viridis",
    line_alpha=0.8,
    line_width=1.2,
    percentile_limits=(1, 99),
):
    """
    Plot selected variables from a combined mooring dataset as stacked time series.

    One subplot per variable (all sharing the time axis) and one line per
    instrument level, coloured by level index.

    Parameters
    ----------
    combined_ds : xarray.Dataset or None
        Must have dims: time, N_LEVELS. Optional coords: nominal_depth, serial_number.
    variables : iterable[str]
        Variable names to try to plot (if present in dataset).
    cmap_name : str
        Matplotlib colormap name for coloring by instrument level.
    line_alpha : float
        Line transparency.
    line_width : float
        Line width.
    percentile_limits : (low, high)
        Percentiles to use for automatic y-limits (e.g., (1, 99)).

    Returns
    -------
    (fig, axes)
        The figure and a sequence of Axes (one per plotted variable), or
        ``(None, None)`` when the dataset is None or none of ``variables``
        are present.

    Raises
    ------
    ValueError
        If the dataset has no ``N_LEVELS`` dimension.
    """
    if combined_ds is None:
        print("Combined dataset not available.")
        return None, None
    n_levels = combined_ds.sizes.get("N_LEVELS")
    if n_levels is None:
        raise ValueError("Dataset must contain dimension 'N_LEVELS'.")

    available = [v for v in variables if v in combined_ds.data_vars]
    if not available:
        print("No requested variables found to plot.")
        return None, None

    # Colors by level
    colors = plt.get_cmap(cmap_name)(np.linspace(0, 1, n_levels))

    fig, axes = plt.subplots(
        len(available), 1, figsize=(14, 3.6 * len(available)), sharex=True, constrained_layout=True
    )
    if len(available) == 1:
        axes = [axes]

    depth_arr = combined_ds.get("nominal_depth")
    serial_arr = combined_ds.get("serial_number")

    def _level_label(level):
        """Legend text for one level from its depth/serial, or None if neither is known."""
        depth = None if depth_arr is None else depth_arr.values[level]
        serial = None if serial_arr is None else serial_arr.values[level]
        if depth is not None and np.isfinite(depth):
            return f"Serial {serial} ({int(depth)} m)" if serial is not None else f"({int(depth)} m)"
        if serial is not None:
            return f"Serial {serial}"
        return None

    # level -> (line, label), recorded the first time a level is drawn on ANY
    # axis, so an instrument with no data for the first variable still gets a
    # legend entry.
    legend_entries = {}

    for ax, var in zip(axes, available):
        values_for_limits = []
        for level in range(n_levels):
            da = combined_ds[var].isel(N_LEVELS=level)
            da = da.where(np.isfinite(da), drop=True)
            if da.size == 0:
                continue

            values_for_limits.append(da.values)

            (line,) = ax.plot(
                da["time"].values,
                da.values,
                color=colors[level],
                alpha=line_alpha,
                linewidth=line_width,
            )
            if level not in legend_entries:
                label = _level_label(level)
                if label is not None:
                    legend_entries[level] = (line, label)

        # Set labels and grid
        pretty_name = var.replace("_", " ").title()
        ax.set_ylabel(pretty_name)
        ax.grid(True, alpha=0.3)
        ax.set_title(f"{pretty_name} — Combined Time Grid")

        # Auto y-limits based on percentiles
        if values_for_limits:
            flat = np.concatenate(values_for_limits)
            low, high = np.nanpercentile(flat, percentile_limits)
            # A constant series gives low == high; setting equal limits makes
            # matplotlib warn and collapse the axis, so keep autoscaling then.
            if high > low:
                ax.set_ylim(low, high)

    # One legend (on the top panel) listing every labelled level, in level order.
    if legend_entries:
        ordered = [legend_entries[lvl] for lvl in sorted(legend_entries)]
        axes[0].legend(
            [handle for handle, _ in ordered],
            [text for _, text in ordered],
            ncol=3,
            fontsize=8,
            loc="upper right",
            frameon=False,
        )

    axes[-1].set_xlabel("Time")
    return fig, axes

# Usage:
if 'combined_ds' in locals():
    plot_combined_timeseries(combined_ds)


## 6. Process with Low-pass Filtering (RAPID-style)

Now let's apply the RAPID-style low-pass filter to remove tidal and inertial variability. As configured below, this is a 6th-order Butterworth filter with a 2-day cutoff. Remember: **filtering is applied to each instrument on its native time grid BEFORE interpolation**.

In [None]:
# Second run: the same gridding, but each instrument is first low-pass
# filtered (RAPID-style) on its own native time axis.
print("Processing mooring with 2-day low-pass filtering (RAPID-style)...")
print("=" * 60)
print(
    "IMPORTANT: Filtering is applied to each instrument on its native time grid",
    "BEFORE interpolation to preserve data integrity.",
    "",
    sep="\n",
)

# 2-day cutoff, 6th-order Butterworth
filter_params = dict(cutoff_days=2.0, order=6)

result_filtered = time_gridding_mooring(
    mooring_name,
    basedir,
    file_suffix='_use',
    filter_type='lowpass',
    filter_params=filter_params,
)

print("\nFiltered processing result: " + ("SUCCESS" if result_filtered else "FAILED"))

In [None]:
# Load the filtered dataset
filtered_output_file = proc_dir / f"{mooring_name}_mooring_use_lowpass.nc"

if filtered_output_file.exists():
    print(f"Filtered output file created: {filtered_output_file}")

    # load_dataset reads the data into memory and closes the file (open_dataset
    # keeps the handle open), so re-running the filtering cell can rewrite it.
    filtered_ds = xr.load_dataset(filtered_output_file)

    print("\nFiltered Dataset Attributes:")
    filter_attrs = {k: v for k, v in filtered_ds.attrs.items()
                    if 'filter' in k.lower()}
    for key, value in filter_attrs.items():
        print(f"  {key}: {value}")

    # .sizes is the dimension-name -> length mapping; Dataset.dims is moving
    # to a set of names, so dict(ds.dims) is deprecated.
    print(f"\nDataset shape: {dict(filtered_ds.sizes)}")
else:
    print("Filtered output file not found")
    filtered_ds = None

## 7. Compare Filtered vs Unfiltered Data

Let's compare the original and filtered data to see the effect of the low-pass filter.

In [None]:
if locals().get('combined_ds') is not None and locals().get('filtered_ds') is not None:
    # Compare over a 10-day window starting a quarter of the way into the record
    start_time = combined_ds.time.values[len(combined_ds.time) // 4]
    end_time = start_time + np.timedelta64(10, 'D')

    subset_orig = combined_ds.sel(time=slice(start_time, end_time))
    subset_filt = filtered_ds.sel(time=slice(start_time, end_time))

    if 'temperature' in subset_orig.data_vars and 'temperature' in subset_filt.data_vars:
        # Choose the first level that actually has unfiltered data in the window
        # (fall back to level 0 if none do). Assumes temperature is
        # (time, N_LEVELS), as in the combined-dataset plot above.
        valid_counts = subset_orig.temperature.notnull().sum(dim='time').values
        levels_with_data = np.flatnonzero(valid_counts > 0)
        level = int(levels_with_data[0]) if levels_with_data.size else 0

        # Instrument metadata for titles; tolerate datasets without these coords.
        depth_arr = subset_orig.get('nominal_depth')
        serial_arr = subset_orig.get('serial_number')
        depth = 'unknown' if depth_arr is None else depth_arr.values[level]
        serial = 'unknown' if serial_arr is None else serial_arr.values[level]

        orig_temp = subset_orig.temperature.isel(N_LEVELS=level)
        filt_temp = subset_filt.temperature.isel(N_LEVELS=level)

        # Two panels: original (top) and filtered (bottom)
        fig, axes = plt.subplots(2, 1, figsize=(14, 8), sharex=True)

        axes[0].plot(orig_temp.time, orig_temp, 'b-', alpha=0.7, linewidth=1,
                     label='Original')
        axes[0].set_ylabel('Temperature (°C)')
        axes[0].set_title(f'Original Data - Serial {serial} at {depth}m')
        axes[0].grid(True, alpha=0.3)
        axes[0].legend()

        axes[1].plot(filt_temp.time, filt_temp, 'r-', alpha=0.7, linewidth=1.5,
                     label='2-day Low-pass Filtered')
        axes[1].set_ylabel('Temperature (°C)')
        axes[1].set_xlabel('Time')
        axes[1].set_title(f'Filtered Data - Serial {serial} at {depth}m')
        axes[1].grid(True, alpha=0.3)
        axes[1].legend()

        plt.tight_layout()
        plt.show()

        # Overlay comparison on a single axis
        fig, ax = plt.subplots(1, 1, figsize=(14, 6))
        ax.plot(orig_temp.time, orig_temp, 'b-', alpha=0.5, linewidth=0.8,
                label='Original')
        ax.plot(filt_temp.time, filt_temp, 'r-', alpha=0.8, linewidth=2,
                label='2-day Low-pass Filtered')
        ax.set_ylabel('Temperature (°C)')
        ax.set_xlabel('Time')
        ax.set_title(f'Filtering Comparison - Serial {serial} at {depth}m')
        ax.legend()
        ax.grid(True, alpha=0.3)
        plt.tight_layout()
        plt.show()
    else:
        print("Temperature not available in both datasets - skipping comparison")
else:
    print("Cannot compare - one or both datasets not available")

## 8. Spectral Analysis: Effect of Filtering

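Let's compare the power spectra of the original and filtered temperature records to see how the filter affects different frequency components. Variance at periods shorter than the 2-day cutoff should be strongly attenuated in the filtered spectrum; this includes the diurnal and semidiurnal bands.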

In [None]:
if locals().get('combined_ds') is not None and locals().get('filtered_ds') is not None:
    from scipy import signal

    # Get temperature data (needed in both datasets)
    if 'temperature' in combined_ds.data_vars and 'temperature' in filtered_ds.data_vars:
        # Select the level with the most valid (unfiltered) temperature samples.
        # Assumes temperature is (time, N_LEVELS), as elsewhere in this notebook.
        valid_counts = combined_ds.temperature.notnull().sum(dim='time').values
        level = int(np.argmax(valid_counts)) if valid_counts.size else 0

        orig_temp = combined_ds.temperature.isel(N_LEVELS=level).dropna('time')
        filt_temp = filtered_ds.temperature.isel(N_LEVELS=level).dropna('time')

        # Both series need enough points for a Welch estimate (nperseg >= 25)
        if len(orig_temp) > 100 and len(filt_temp) > 100:
            # Calculate sampling rate
            dt_hours = float(np.median(np.diff(orig_temp.time.values)) / np.timedelta64(1, 'h'))
            fs = 1.0 / dt_hours  # samples per hour

            # Compute power spectral density (frequencies in cycles per hour)
            f_orig, psd_orig = signal.welch(orig_temp.values, fs=fs, nperseg=min(256, len(orig_temp) // 4))
            f_filt, psd_filt = signal.welch(filt_temp.values, fs=fs, nperseg=min(256, len(filt_temp) // 4))

            # Drop the zero-frequency bin BEFORE inverting, so no divide-by-zero;
            # period in days = 1 / (24 * frequency in cycles/hour).
            period_orig = 1.0 / (f_orig[1:] * 24)
            period_filt = 1.0 / (f_filt[1:] * 24)

            # Plot power spectral density
            fig, ax = plt.subplots(1, 1, figsize=(12, 6))

            ax.loglog(period_orig, psd_orig[1:], 'b-', alpha=0.7,
                      label='Original', linewidth=1.5)
            ax.loglog(period_filt, psd_filt[1:], 'r-', alpha=0.8,
                      label='2-day Low-pass Filtered', linewidth=2)

            # Mark important periods
            ax.axvline(2.0, color='gray', linestyle='--', alpha=0.7,
                       label='2-day cutoff')
            ax.axvline(1.0, color='gray', linestyle=':', alpha=0.7,
                       label='1-day (diurnal)')
            ax.axvline(0.5, color='gray', linestyle=':', alpha=0.7,
                       label='12-hour (semidiurnal)')

            ax.set_xlabel('Period (days)')
            ax.set_ylabel('Power Spectral Density')
            ax.set_title('Spectral Analysis: Effect of 2-day Low-pass Filter')
            ax.legend()
            ax.grid(True, alpha=0.3)
            ax.set_xlim(0.1, 100)

            plt.tight_layout()
            plt.show()

            print(f"Spectral analysis completed for level {level}")
            print(f"Sampling rate: {dt_hours:.2f} hours")
            print(f"Data length: {len(orig_temp)} points")
        else:
            print("Insufficient data for spectral analysis")
    else:
        print("Temperature data not available for spectral analysis")
else:
    print("Cannot perform spectral analysis - datasets not available")

## 10. Multiple Mooring Processing Example

The time gridding module also supports batch processing of multiple moorings.

In [None]:
# Base directory containing the mooring data
basedir = '/Users/eddifying/Dropbox/data/ifmro_mixsed/ds_data_eleanor/'

# List the moorings available under <basedir>/moor/proc. Sorted so the
# "first 5" shown below is stable between runs (iterdir order is arbitrary).
moor_base = Path(basedir) / 'moor' / 'proc'
if moor_base.is_dir():
    available_moorings = sorted(d.name for d in moor_base.iterdir() if d.is_dir())
else:
    print(f"Mooring directory not found: {moor_base}")
    available_moorings = []

print(f"Available moorings in {moor_base}:")
for mooring in available_moorings[:5]:  # Show first 5
    print(f"  - {mooring}")

if len(available_moorings) > 5:
    print(f"  ... and {len(available_moorings)-5} more")

# Batch-process an explicit list of moorings. NOTE: this call DOES run -- keep
# the list to the moorings you actually want (re)processed.
moorings_to_process = ['dsE_1_2018']  # Add your mooring names

results = process_multiple_moorings_time_gridding(
    moorings_to_process,
    basedir,
)

print("Batch processing results:")
for mooring, success in results.items():
    status = "SUCCESS" if success else "FAILED"
    print(f"  {mooring}: {status}")