# MISA Extract Transform & Engineer (ETE) Pipeline

This notebook implements a full end-to-end workflow for building a binned‐regression ionospheric background model from Millstone Hill Incoherent Scatter Radar (MISA) data:

1. **Data Extraction & Loading**  
   - Reads raw MISA HDF5 scan groups (timestamps, range, ne/ti/tr, etc.)  
   - Ingests and resamples ancillary geophysical indices (Hp, Ap, Kp, Dst, SME, FISM2) into a unified 30 min timeline

2. **Cleaning & Filtering**  
   - Masks out eclipse/extreme days and bad‐quality radar returns  
   - Removes outliers via relative‐error, standard‐deviation, and percentile‐based filters  
   - Applies NaN masking to bad points

3. **Transformation & Feature Engineering**  
   - Flattens the (time × range) grid to a tabular DataFrame  
   - Computes and reindexes solar-local time (SLT), geomagnetic lat/lon, and time-lagged indices  
   - Generates cyclical features (sin/cos of UT, SLT, DOY), powers/interactions, and log–transformations

4. **Binned Regression Training**  
   - Defines equal‐sized bins in azimuth and altitude  
   - Within each bin, scales inputs, fits a degree-4 polynomial Ridge regression for the target (ne, ti, te)  
   - Stores per-bin models (scaler, transformer, regressor) for fast lookup
   
Use this notebook to reproduce the full model products ionospheric background modeling used in MISA_pySLIME_query.

## 0. Loading Libraries and Defining Functions

In [67]:
# importing libraries
import matplotlib.pyplot as plt
from tqdm import tqdm
from pathlib import Path
import pandas as pd
import xarray as xr
import numpy as np
import netCDF4
import h5py
import dask

# dirs
group_name = 'Array_with_kinst=31.0_and_mdtyp=115.0_and_pl=0.002' # MISA instrument data group name
misa_dir = '../ancillary/data_products/MadrigalLib' #data product dir
# geophysical ancillary data dirs
hp_ap_dir = '../ancillary/geophysical_ancillary/Hp30_ap30_complete_series.txt'
kp_f107_dir = '../ancillary/geophysical_ancillary/Kp_ap_Ap_SN_F107_since_1932.txt'
fism2_dir = '../ancillary/geophysical_ancillary/fism2/60s/'
ae_dir = '../ancillary/geophysical_ancillary/supermag_ae.csv'
dst_dir = '../ancillary/geophysical_ancillary/dst2000to2024.txt'

In [68]:
# defining functions
def std_column_filter_indices(a, nbstd):
    """
    Identify outlier entries in each column of a 2D array.

    Marks any point more than `nbstd` standard deviations from the column mean.

    Parameters
    ----------
    a : ndarray, shape (n_rows, n_cols)
        Data matrix where rows are observations and columns are variables.
    nbstd : float
        Number of standard deviations beyond which a point is considered bad.

    Returns
    -------
    bad_indices : list of (row, col) tuples
        Coordinates of all detected outlier points.
    """
    bad_indices = []
    for col in range(a.shape[1]):
        col_mean = np.nanmean(a[:, col])
        col_std  = np.nanstd(a[:, col])
        # rows that exceed the threshold in either direction
        mask = ((a[:, col] > col_mean + nbstd*col_std) |
                (a[:, col] < col_mean - nbstd*col_std))
        bad_rows = np.where(mask)[0]
        bad_indices.extend([(row, col) for row in bad_rows])
    return bad_indices


def apply_interpolation(data, bad_indices):
    """
    Linearly interpolate over flagged bad entries.

    For each column with bad indices, sets those to NaN, then fills
    them by 1D linear interpolation of the remaining valid data.

    Parameters
    ----------
    data : ndarray, shape (n_rows, n_cols)
        Original data array (will be modified in-place).
    bad_indices : list of (row, col) tuples
        Points to replace via interpolation.

    Returns
    -------
    data : ndarray
        Array with bad points replaced by linear interpolation.
    """
    for col in np.unique([c for _, c in bad_indices]):
        # mask out the bad points
        col_data = data[:, col].copy()
        bad_rows = [r for r, c in bad_indices if c == col]
        col_data[bad_rows] = np.nan
        # only interpolate if there's at least two valid points
        valid = ~np.isnan(col_data)
        if valid.sum() > 1:
            col_data[bad_rows] = np.interp(
                bad_rows, np.where(valid)[0], col_data[valid]
            )
        data[:, col] = col_data
    return data


def apply_nonlinear_interpolation(data, bad_indices, kind='cubic'):
    """
    Replace bad entries via higher-order (cubic, quadratic, etc.) interpolation.

    Parameters
    ----------
    data : ndarray, shape (n_rows, n_cols)
        Original data array.
    bad_indices : list of (row, col) tuples
        Locations to fill.
    kind : str, default 'cubic'
        Interpolation method passed to scipy.interpolate.interp1d.

    Returns
    -------
    data : ndarray
        Array with bad points estimated by nonlinear interpolation.
    """
    for col in np.unique([c for _, c in bad_indices]):
        col_data = data[:, col]
        valid_idx   = np.where(~np.isnan(col_data))[0]
        invalid_idx = np.where(np.isnan(col_data))[0]
        # only proceed if enough valid points exist
        if valid_idx.size > 1:
            interp_fn = interp1d(
                valid_idx, col_data[valid_idx],
                kind=kind, fill_value='extrapolate'
            )
            col_data[invalid_idx] = interp_fn(invalid_idx)
        data[:, col] = col_data
    return data


def apply_nans(data, bad_indices):
    """
    Mask out bad data points by setting them to NaN.

    Parameters
    ----------
    data : ndarray, shape (n_rows, n_cols)
        Original data array.
    bad_indices : list of (row, col) tuples
        Points to invalidate.

    Returns
    -------
    data : ndarray
        Array with specified entries set to NaN.
    """
    for row, col in bad_indices:
        data[row, col] = np.nan
    return data


def load_datasets_with_timestamp_and_range(
    file_path, target_kinst, target_mdtyp, target_pl
):
    """
    Load an HDF5 group matching given instrument/type/parameter, returning an
    xarray.Dataset indexed by real timestamps and range.

    Parameters
    ----------
    file_path : str
        Path to the .hdf5 file.
    target_kinst : float
        Instrument code to match in group key.
    target_mdtyp : float
        Measurement type code.
    target_pl : float
        Physical parameter code.

    Returns
    -------
    ds : xarray.Dataset
        Contains variables at dims ('dates', 'range'), plus coords:
        - 'doy' (day-of-year)
        - 'dates' (datetime64 index).
    """
    with h5py.File(file_path, 'r') as f:
        # locate the matching group by parsing its key
        group_name = None
        for key in f.keys():
            if not key.startswith('Array_with_'): continue
            attrs = dict(p.split('=') for p in key.replace('Array_with_', '').split('_and_'))
            if (float(attrs.get('kinst', -1)) == target_kinst and
                float(attrs.get('mdtyp', -1)) == target_mdtyp and
                float(attrs.get('pl',    -1)) == target_pl):
                group_name = key
                break
        if group_name is None:
            print(f'No matching group for kinst={target_kinst}, mdtyp={target_mdtyp}, pl={target_pl}')
            return np.nan

        grp = f[group_name]
        timestamps = grp['timestamps'][()]
        ranges     = grp['range'][()]

        data_vars = {
            'range':     ('range',     ranges),
            'doy':       None  # will add after coords
        }
        # load each dataset, assigning dims based on shape
        for name, obj in grp.items():
            if name in ('timestamps', 'range'): continue
            arr = obj[()]
            if arr.ndim == 1 and arr.size == timestamps.size:
                data_vars[name] = ('timestamp', arr)
            elif arr.ndim == 2 and arr.shape == (timestamps.size, ranges.size):
                data_vars[name] = (('timestamp','range'), arr)

        # build xarray Dataset and swap to datetime index
        ds = xr.Dataset(data_vars)
        dates = pd.to_datetime(timestamps, unit='s')
        ds = ds.assign_coords(dates=('timestamp', dates)).swap_dims({'timestamp':'dates'})
        # add day-of-year coordinate for feature engineering
        ds['doy'] = ('dates', ds['dates'].dt.dayofyear.values)
    return ds

import os
def concatenate_hdf5_files_in_directory(
    directory_path, target_kinst, target_mdtyp, target_pl
):
    """
    Selects the highest‐revision HDF5 per day for given parameters, then
    concatenates them into one Dataset along the 'dates' dimension.

    Parameters
    ----------
    directory_path : str
        Folder containing .hdf5 files.
    target_kinst, target_mdtyp, target_pl : floats
        Matching criteria for groups within each file.

    Returns
    -------
    concatenated : xarray.Dataset
        Combined time series of all selected files.
    """
    # pick only the latest revision for each base filename
    best = {}
    for fn in os.listdir(directory_path):
        if not fn.endswith('.hdf5'): continue
        parts = fn.split('.')
        base, rev = ('.'.join(parts[:-2]), int(parts[-2])) if parts[-2].isdigit() else (fn, -1)
        if base not in best or rev > best[base][0]:
            best[base] = (rev, fn)

    datasets = []
    for _, fn in sorted(best.values(), key=lambda x: x[1]):
        path = os.path.join(directory_path, fn)
        ds = load_datasets_with_timestamp_and_range(path, target_kinst, target_mdtyp, target_pl)
        if isinstance(ds, xr.Dataset):
            datasets.append(ds)
    if not datasets:
        raise RuntimeError("No valid datasets to concatenate.")
    # merge all along the time axis
    concatenated = xr.concat(datasets, dim='dates')
    return concatenated

def create_geophysical_index_xr(
    hp_ap_dir, kp_f107_dir, dst_dir, supermag_ae_csv, fism2_1min_dir
):
    """
    Reads and merges multiple geophysical indices onto a uniform 30-minute timeline:
      - Hp30 & ap30 from MJD-based ascii files
      - Adjusted F10.7 flux and expanded 3-h Kp values
      - Dst index from hourly ascii
      - SME from SuperMAG AE CSV (1-min → 30-min max)
      - FISM2 irradiance (1-min netCDF → 30-min max)
    Returns:
        xr.Dataset with coords 'dates' (30-min) and data_vars:
          ut, hp30, ap30, f107, kp, dst, sme, fism2
    """
    # 1) LOAD Hp30 & ap30
    hp_ap_df = pd.read_csv(
        hp_ap_dir,
        delim_whitespace=True,
        skiprows=30,   # header offset in ascii file
        header=None,
        names=['YYYY','MM','DD','hh.h','hh._m','days','d_m','Hp30','ap30','D']
    )
    # build datetime index from separate columns
    hp_ap_df['date'] = pd.to_datetime({
        'year':  hp_ap_df['YYYY'],
        'month': hp_ap_df['MM'],
        'day':   hp_ap_df['DD'],
        'hour':  hp_ap_df['hh.h']
    })
    hp_series = hp_ap_df.set_index('date')['Hp30']
    ap_series = hp_ap_df.set_index('date')['ap30']

    # 2) LOAD Kp & adjusted F10.7
    f10_kp_df = pd.read_csv(
        kp_f107_dir,
        delim_whitespace=True,
        skiprows=40,
        header=None,
        names=[
            'YYYY','MM','DD','days','days_m','BSR','dB',
            'Kp1','Kp2','Kp3','Kp4','Kp5','Kp6','Kp7','Kp8',
            'ap1','ap2','ap3','ap4','ap5','ap6','ap7','ap8',
            'Ap','SN','F10.7obs','F10.7adj','D'
        ]
    )
    # replace sentinel -1 with NaN before processing
    f10_kp_df.replace(-1, np.nan, inplace=True)
    # create a daily datetime index for F10.7
    f10_kp_df['date'] = pd.to_datetime({
        'year':  f10_kp_df['YYYY'],
        'month': f10_kp_df['MM'],
        'day':   f10_kp_df['DD']
    })
    f107_series = pd.Series(f10_kp_df['F10.7adj'].values,
                             index=f10_kp_df['date'])

    # expand 3-hourly Kp into 8 snapshots per day
    kp_intervals = pd.to_timedelta([0,3,6,9,12,15,18,21], unit='h')
    expanded = f10_kp_df.loc[f10_kp_df.index.repeat(8)].copy()
    # shift each repeat by the corresponding 3-h offset
    expanded['date'] += np.tile(kp_intervals, len(f10_kp_df))
    # flatten Kp1…Kp8 into a single series
    kp_cols = [f'Kp{i}' for i in range(1,9)]
    expanded['Kp'] = np.concatenate(
        [f10_kp_df[col].values for col in kp_cols]
    )
    kp_series = pd.Series(expanded['Kp'].values, index=expanded['date'])

    # 3) LOAD Dst index
    dst_df = pd.read_csv(dst_dir, delim_whitespace=True, comment='#')
    dst_df['datetime'] = pd.to_datetime(dst_df['DATE'] + ' ' + dst_df['TIME'])
    dst_series = pd.Series(dst_df['DST'].values, index=dst_df['datetime'])

    # 4) LOAD SME (1-min → 30-min max)
    sme_df = pd.read_csv(supermag_ae_csv, parse_dates=['Date_UTC'])
    sme_df.set_index('Date_UTC', inplace=True)
    sme_1m   = sme_df['SME']
    sme_30m  = sme_1m.resample('30T').max()

    # 5) LOAD & PROCESS FISM2 irradiance
    def fix_fism2_dataset(ds):
        # Drop old 'date' dim by converting YYYYDOY -> timestamp
        if 'date' in ds.coords and ds.dims.get('date',0)==1:
            raw = str(ds['date'].item()).strip()
            year, doy = int(raw[:4]), int(raw[4:])
            base = pd.Timestamp(year,1,1) + pd.Timedelta(days=doy-1)
            ds = ds.drop_vars('date')
        else:
            base = None
        # rename & convert 'utc' seconds to real time
        if 'utc' in ds.coords:
            ds = ds.rename(utc='time')
            if base is not None:
                ds['time'] = base + pd.to_timedelta(ds['time'], unit='s')
            ds = ds.assign_coords(time=ds['time'])
        # remove 'jd' index if present
        if 'jd' in ds.dims:    ds = ds.drop_dims('jd')
        if 'jd' in ds.coords:  ds = ds.drop_vars('jd')
        return ds

    files = sorted(Path(fism2_1min_dir).glob("FISM_60sec_*.nc"))
    if not files:
        raise FileNotFoundError(f"No FISM2 files in {fism2_1min_dir}")

    ds_fism2_1m = xr.open_mfdataset(
        files, combine='by_coords', preprocess=fix_fism2_dataset
    )
    # pick the 30.4 nm channel, then resample to 30-min max irradiance
    ds_fism2_1m = ds_fism2_1m.sel(wavelength=30.4, method='nearest')
    fism2_df_1m = ds_fism2_1m['irradiance'].to_dataframe()
    fism2_30m   = fism2_df_1m.resample('30T').max()

    # 6) BUILD THE COMMON 30-MIN TIMELINE
    all_starts = [
        hp_series.index.min(), ap_series.index.min(), f107_series.index.min(),
        kp_series.index.min(), dst_series.index.min(), sme_30m.index.min(),
        fism2_30m.index.min()
    ]
    all_ends   = [
        hp_series.index.max(), ap_series.index.max(), f107_series.index.max(),
        kp_series.index.max(), dst_series.index.max(), sme_30m.index.max(),
        fism2_30m.index.max()
    ]
    dates_full = pd.date_range(min(all_starts), max(all_ends), freq='30T')

    # 7) FORWARD-FILL each series onto the common timeline
    hp30_full  = hp_series.reindex(dates_full, method='ffill')
    ap30_full  = ap_series.reindex(dates_full, method='ffill')
    f107_full  = f107_series.reindex(dates_full, method='ffill')
    kp_full    = kp_series.reindex(dates_full, method='ffill')
    dst_full   = dst_series.reindex(dates_full, method='ffill')
    sme_full   = sme_30m.reindex(dates_full, method='ffill')
    fism2_full = fism2_30m['irradiance'].reindex(dates_full, method='ffill')

    # 8) COMPUTE solar local time (SLT) if needed, here we only output UT
    ut_full = [t.hour + t.minute/60 + t.second/3600 for t in dates_full]

    # 9) ASSEMBLE INTO AN xarray.Dataset
    ds_out = xr.Dataset({
        'ut':   ('dates', ut_full),
        'hp30': ('dates', hp30_full),
        'ap30': ('dates', ap30_full),
        'f107': ('dates', f107_full),
        'kp':   ('dates', kp_full),
        'dst':  ('dates', dst_full),
        'sme':  ('dates', sme_full),
        'fism2':('dates', fism2_full),
    }, coords={'dates': dates_full})

    return ds_out

## 1. Data Extraction & Loading

In [None]:
### LOAD & PROCESS MISA DATA ### (if already processed previously, skip this step)
# 1) Load & concatenate all relevant MISA HDF5 files into one xarray.Dataset
concatenated_data_df = concatenate_hdf5_files_in_directory(
    misa_dir,
    target_kinst=31.0,    # instrument code for MISA
    target_mdtyp=115.0,   # measurement type code for electron density etc.
    target_pl=0.002       # physical parameter code
)

# 2) Build boolean masks to remove data from unwanted scan regions:
#    • az_mask: restrict azimuth between –185° and +5° (densest + eclipse region)
#    • el_mask: only include beams at 6° elevation (both el1 and el2)
#    • gdalt_mask: keep ground‐to‐scatter distance between 0 and 1100 km
az_mask    = (concatenated_data_df.az1.data < 5) & (concatenated_data_df.az1.data >= -185)
el_mask    = (concatenated_data_df.el1.data == 6) & (concatenated_data_df.el2.data == 6)
gdalt_mask = (concatenated_data_df.gdalt.data > 0) & (concatenated_data_df.gdalt.data < 1100)

# 3) Copy the concatenated dataset so we can mask it in place
MISA_ds = concatenated_data_df.copy()

# 4) Apply each mask to every data variable (skip temporal coords 'doy' & 'slt'):
for mask in (el_mask, gdalt_mask, az_mask):
    for var in MISA_ds.data_vars:
        if var in ('doy', 'slt'):
            continue  # preserve DOY and SLT for later feature alignment
        data = MISA_ds[var]
        # Match mask dimensionality to the variable before masking:
        if data.ndim == mask.ndim:
            # same shape (e.g. 1D → 1D or 2D → 2D)
            MISA_ds[var] = data.where(mask)
        elif data.ndim < mask.ndim:
            # variable is lower‐dimensional (e.g. 1D var masked by 2D mask)
            MISA_ds[var] = data.where(mask.any(axis=1))
        else:
            # variable is higher‐dimensional (e.g. 2D var masked by 1D mask)
            MISA_ds[var] = data.where(mask[:, np.newaxis])

# 5) Sort by timestamp and compute decimal‐hour 'ut'
MISA_ds = MISA_ds.sortby('dates')
MISA_ds['ut'] = (
    MISA_ds['dates'].dt.hour +
    MISA_ds['dates'].dt.minute / 60 +
    MISA_ds['dates'].dt.second / 3600
)
MISA_ds['az_normalized'] = (['dates'],((MISA_ds.az1 + MISA_ds.az2) / 2).values)  # mean beam azimuth (n_times,)

# 6) Load master geophysical indices (Hp, Ap, Kp, Dst, SME, FISM2) on a 30-min grid
master_geo_ds = create_geophysical_index_xr(
    hp_ap_dir, kp_f107_dir, dst_dir, ae_dir, fism2_dir
)

# 7) Align (ffill) geophysical indices to radar timestamps & append to radar data
master_geo_ds_reindexed = master_geo_ds.reindex(
    dates=MISA_ds['dates'], method='ffill'
)
for var in master_geo_ds_reindexed.data_vars:
    MISA_ds[var] = master_geo_ds_reindexed[var]

# 8) Generate time‐lagged indices for feature engineering:
#    - AP at [0,3,5,6,7,9,12,24,48,72] hr prior
for lag in [0, 3, 5, 6, 7, 9, 12, 24, 48, 72]:
    master_geo_ds[f"ap_{lag}hr_prior"] = master_geo_ds['ap30'].interp(
        dates=master_geo_ds['dates'] - pd.Timedelta(f"{lag}H")
    )
#    - 24 hr & 48 hr prior for ap30, dst, sme, fism2
for var in ("ap30", "dst", "sme", "fism2"):
    master_geo_ds[f"{var}_24hr_prior"] = master_geo_ds[var].interp(
        dates=master_geo_ds['dates'] - pd.Timedelta("1D")
    )
    master_geo_ds[f"{var}_48hr_prior"] = master_geo_ds[var].interp(
        dates=master_geo_ds['dates'] - pd.Timedelta("2D")
    )
#    - HP at [0,3,6,9,12,24,48,72] hr prior
for lag in [0, 3, 6, 9, 12, 24, 48, 72]:
    master_geo_ds[f"hp_{lag}hr_prior"] = master_geo_ds['hp30'].interp(
        dates=master_geo_ds['dates'] - pd.Timedelta(f"{lag}H")
    )

# 9) Final align: nearest‐neighbor reindex within 1 hr tolerance and append these lagged vars
master_geo_ds_reindexed = master_geo_ds.reindex(
    dates=MISA_ds['dates'], method='nearest', tolerance=pd.Timedelta("1H")
)
for var in master_geo_ds_reindexed.data_vars:
    MISA_ds[var] = master_geo_ds_reindexed[var]

## PROCESSING AND SAVING ANCILLARY GRID DATASET ##
# grid_ds is used to compute bidirectional interpolators between (az, alt) ↔ (lat, lon)
# it is derived from MISA_ds as such
grid_ds = MISA_ds[['gdalt', 'az_normalized','glon','gdlat']]
grid_ds = grid_ds.where(
    (grid_ds['dates'].dt.date >= np.datetime64('2024-04-01')) &
    (grid_ds['dates'].dt.date <= np.datetime64('2024-04-30')),
    drop=True
)
grid_ds = grid_ds.where(
    (grid_ds['gdalt'] > 0) & (grid_ds['gdalt'] < 510),
    drop=True
)
grid_ds = grid_ds.where(
    (grid_ds['az_normalized'] > -145) & (grid_ds['az_normalized'] < -45),
    drop=True
)
## save grid_ds to netCDF
grid_ds.to_netcdf(
    '../ancillary/grid_ds_2.0.7.nc',
    format='NETCDF4', engine='netcdf4', unlimited_dims='dates'
)


No matching group for kinst=31.0, mdtyp=115.0, pl=0.002
No matching group for kinst=31.0, mdtyp=115.0, pl=0.002
No matching group for kinst=31.0, mdtyp=115.0, pl=0.002
No matching group for kinst=31.0, mdtyp=115.0, pl=0.002
No matching group for kinst=31.0, mdtyp=115.0, pl=0.002
No matching group for kinst=31.0, mdtyp=115.0, pl=0.002
No matching group for kinst=31.0, mdtyp=115.0, pl=0.002
No matching group for kinst=31.0, mdtyp=115.0, pl=0.002
No matching group for kinst=31.0, mdtyp=115.0, pl=0.002
No matching group for kinst=31.0, mdtyp=115.0, pl=0.002
No matching group for kinst=31.0, mdtyp=115.0, pl=0.002
No matching group for kinst=31.0, mdtyp=115.0, pl=0.002
No matching group for kinst=31.0, mdtyp=115.0, pl=0.002
No matching group for kinst=31.0, mdtyp=115.0, pl=0.002
No matching group for kinst=31.0, mdtyp=115.0, pl=0.002
No matching group for kinst=31.0, mdtyp=115.0, pl=0.002
No matching group for kinst=31.0, mdtyp=115.0, pl=0.002
No matching group for kinst=31.0, mdtyp=115.0, p

  hp_ap_df = pd.read_csv(
  f10_kp_df = pd.read_csv(
  dst_df = pd.read_csv(dst_dir, delim_whitespace=True, comment='#')


ParserError: Error tokenizing data. C error: Calling read(nbytes) on source failed. Try engine='python'.

## 2. Cleaning and Filtering

In [None]:
# Specify which variable we’re targeting for filtering
target_variable = 'ne'

# --- 1) LOAD & EXCLUDE UNWANTED DATES ---
# Load eclipse dates to mask out days with total or partial solar eclipse
eclipse_days_df = pd.read_csv('../ancillary/eclipse_days.csv')
eclipse_dates = pd.to_datetime(eclipse_days_df['Date'])

# Load extreme (storm) dates if you wanted to mask them (not used below but available)
extreme_days_df = pd.read_csv('../ancillary/extreme_days.csv')
extreme_dates = pd.to_datetime(extreme_days_df['Storm date'])

# Normalize radar timestamps to just the date (drop time component)
MISA_ds_dates_only = pd.to_datetime(MISA_ds['dates'].values).normalize()

# Mask out any rows whose date is in the eclipse list
MISA_ds = MISA_ds.sel(
    dates=~np.isin(MISA_ds_dates_only, eclipse_dates),
    drop=False
)

# Restrict to days between vernal equinox (Mar 20) and summer solstice (Jun 21)
MISA_ds = MISA_ds.where(
    (
        (MISA_ds['dates'].dt.month == 3) & (MISA_ds['dates'].dt.day >= 20)
    ) | (
        (MISA_ds['dates'].dt.month > 3) & (MISA_ds['dates'].dt.month < 6)
    ) | (
        (MISA_ds['dates'].dt.month == 6) & (MISA_ds['dates'].dt.day <= 21)
    )
)

# --- 2) DERIVE ADDITIONAL VARIABLES ---
# Compute electron temperature: te = tr * ti
MISA_ds['te']  = MISA_ds['tr'] * MISA_ds['ti']
# Compute delta electron temperature: dte = dtr * dti
MISA_ds['dte'] = MISA_ds['dtr'] * MISA_ds['dti']

# --- 3) COMPUTE RELATIVE ERROR FOR NE ---
dne      = MISA_ds[f'd{target_variable}'].values
ne_vals  = MISA_ds[target_variable].values
relerr_ne = dne / ne_vals

# --- 4) DEFINE PERCENTILE‐BASED OUTLIER DETECTOR ---
def filter_data_by_bins_indices(target_array, percentileRange, data=MISA_ds):
    """
    Identify outliers in `target_array` separately within each 100-km gdalt bin.

    For each bin of ground-to-scatter distance:
      1. Extract values in that range (drop NaNs).
      2. Compute the lower and upper percentiles (e.g., [5, 95]).
      3. Flag points below the lower or above the upper cutoff.
    
    Parameters
    ----------
    target_array : xarray.DataArray
        2D array of the variable to filter (dims: dates × range).
    percentileRange : list or tuple of two floats
        [lower_percentile, upper_percentile], e.g. [5, 95].
    data : xarray.Dataset
        Must contain 'gdalt' for bin definitions.
    
    Returns
    -------
    bad_indices : list of (row, col) tuples
        Coordinates of all flagged outlier points.
    """
    # Define 100-km altitude bins from 0 up to max+100
    gdalt_bins = np.arange(0, data['gdalt'].max().data + 100, 100)
    bad_indices = []

    # Loop through each bin interval
    for i in range(len(gdalt_bins) - 1):
        # Mask array to that bin
        bin_mask = (data['gdalt'] >= gdalt_bins[i]) & (data['gdalt'] < gdalt_bins[i + 1])
        binned = target_array.where(bin_mask, drop=False)

        # Flatten & drop NaNs for percentile computation
        flat = binned.data.flatten()
        flat = flat[~np.isnan(flat)]
        if flat.size < 2:
            continue  # skip bins with insufficient data

        # Compute percentile thresholds
        low, high = np.percentile(flat, percentileRange, method='inverted_cdf')

        # Flag points outside [low, high)
        mask_pct = (binned < low) | (binned >= high)
        rows, cols = np.where(mask_pct.data)
        bad_indices.extend(zip(rows, cols))

    return bad_indices

# --- 5) COLLECT ALL BAD‐POINT INDICES ---
# 5a) Points with >3% relative error
rel_err_bad_indices = np.argwhere(relerr_ne > 0.03)

# 5b) Points >3.67 std dev from mean in each column
std_bad_indices = std_column_filter_indices(MISA_ds[target_variable].data, nbstd=3.67)

# 5c) Points outside the 5th–95th percentile per altitude bin
percentile_bad_indices = filter_data_by_bins_indices(
    MISA_ds[target_variable], percentileRange=[5, 95]
)

# Combine into a unique set of (row, col) tuples
all_bad_indices = set(map(tuple, rel_err_bad_indices))
all_bad_indices.update(std_bad_indices)
all_bad_indices.update(percentile_bad_indices)

# --- 6) APPLY NaN MASK FOR ALL BAD POINTS ---
filtered_ne = MISA_ds[target_variable].data.copy()
filtered_ne = apply_nans(filtered_ne, list(all_bad_indices))

# --- 7) REPORT FILTERING RESULTS & UPDATE DATASET ---
orig_count     = np.count_nonzero(~np.isnan(MISA_ds[target_variable].data))
filtered_count = np.count_nonzero(~np.isnan(filtered_ne))
print(f"Original dataset size: {orig_count}")
print(f"Filtered dataset size: {filtered_count}")

# Overwrite 'ne' in MISA_ds with the cleaned array
MISA_ds[target_variable] = (('dates', 'range'), filtered_ne)

  col_mean = np.nanmean(a[:, col])
  var = nanvar(a, axis=axis, dtype=dtype, out=out, ddof=ddof,


Original dataset size: 335742
Filtered dataset size: 95280


## 3. Transformation & Feature Engineering

In [None]:
# 1) Extract core coords & time indices from the cleaned dataset
dates_data     = MISA_ds['dates'].values                # (n_times,)
range_data     = MISA_ds['range'].values                # (n_range,)
normalized_az  = MISA_ds['az_normalized'].values       # mean beam azimuth (n_times,)
ut_data        = MISA_ds['ut'].values                   # universal time (decimal hours) (n_times,)

# 2) Flatten 2D fields (time × range) to 1D
flat_alt    = MISA_ds['gdalt'].values.flatten()    # ground-to-scatter distance
flat_gdlat  = MISA_ds['gdlat'].values.flatten()    # geomagnetic latitude
flat_glon   = MISA_ds['glon'].values.flatten()     # geomagnetic longitude
flat_doy    = MISA_ds['doy'].values.flatten()      # day-of-year

# 3) Repeat 1D arrays to align with each time-range pair
n_times = len(dates_data)
n_range = len(range_data)

dates_repeated = np.repeat(dates_data, n_range)      # timestamp per point
az_repeated    = np.repeat(normalized_az, n_range)   # az per point
ut_repeated    = np.repeat(ut_data, n_range)         # UT per point
doy_repeated   = np.repeat(flat_doy, n_range)        # DOY per point
range_repeated = np.tile(range_data, n_times)        # range per point

# 4) Prepare geophysical indices: repeat each time-series along range dim
geophysical_vars = [
    'dst','dst_24hr_prior','dst_48hr_prior',
    'sme','sme_24hr_prior','sme_48hr_prior',
    'fism2','fism2_24hr_prior','fism2_48hr_prior',
    'ap30','ap_3hr_prior','ap_5hr_prior','ap_6hr_prior',
    'ap_7hr_prior','ap_9hr_prior','ap_12hr_prior',
    'hp30','hp_3hr_prior','hp_6hr_prior','hp_9hr_prior','hp_12hr_prior'
]
flat_geophysical_data = []
for var in geophysical_vars:
    # Each series has length = n_times; repeat over n_range
    flat_geophysical_data.append(np.repeat(MISA_ds[var].values, n_range))

# 5) Stack all inputs into a single (n_times*n_range × n_features) array
combined_inputs = np.column_stack([
    ut_repeated,
    range_repeated,
    az_repeated,
    flat_alt,
    flat_gdlat,
    flat_glon,
    *flat_geophysical_data
])

# 6) Extract and append the target variable, then assemble into a DataFrame
target_data = MISA_ds[target_variable].values.flatten()
combined_inputs_df = pd.DataFrame(combined_inputs)
combined_inputs_df[target_variable] = target_data
# Add the timestamp column back for reference
combined_inputs_df['date'] = dates_repeated

# 7) Correct Solar Local Time (SLT) from UT + lon/15, wrapping negatives
slt_corrected = ut_repeated + (flat_glon / 15)
neg_mask = slt_corrected < 0
# For points where SLT < 0, decrement the DOY by 1
doy_repeated[neg_mask] -= 1
# Wrap SLT into [0,24)
slt_corrected[neg_mask] += 24

# 8) Insert corrected SLT and DOY into the DataFrame
combined_inputs_df.insert(1, 'slt', slt_corrected)
combined_inputs_df.insert(1, 'doy', doy_repeated)

# 9) Rename columns to meaningful feature names and drop missing rows
features = (
    ['date','doy','slt','ut','range','az','alt','lat','lon']
    + geophysical_vars
    + [target_variable]
)
combined_inputs_df.columns = features
combined_inputs_df = combined_inputs_df.dropna()

## 3. Binned Regression Training

In [None]:
# remove a test day from training
# test_date = np.datetime64('2024-04-06')
# print(f'removing test day ({test_date}) from dataset.')
# entries = np.sum(~np.isnan(engineered_inputs_df.where(engineered_inputs_df['date'].dt.floor(freq='1D') == test_date))) # remove april 12th
# if entries.sum() > 0:
#     engineered_inputs_df = engineered_inputs_df.where(engineered_inputs_df['date'].dt.floor(freq='1D') != test_date) # remove test day
#     print(f'{entries} dataset entries removed.')
# else:
#     print('no entries removed.')

In [None]:
# === TRAINING: Binned Polynomial Ridge Regression for each (az, alt) cell ===

import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, PolynomialFeatures
from sklearn.linear_model import Ridge

# 1) Prepare the DataFrame and restrict to your best-performing az/alt domain
target_df = combined_inputs_df.copy(deep=True)

# Define the “core” azimuth & altitude ranges + a small buffer
set_az_domain  = (-140, -40)  # degrees
set_alt_domain = (100,  500)  # km
buffer = 5                    # slack outside the selected domain

# Mask out rows outside the buffered az/alt region
target_df = target_df.where(
    (target_df['az']  > set_az_domain[0] - buffer) &
    (target_df['az']  <= set_az_domain[1] + buffer) &
    (target_df['alt'] > set_alt_domain[0] - buffer) &
    (target_df['alt'] < set_alt_domain[1] + buffer * 5)
)

# 2) Define the bin edges for azimuth & altitude based on grid-search results
target_dAz  = 6.5   # bin width in degrees
target_dAlt = 28    # bin width in km

az_bin_edges  = np.arange(set_az_domain[0], set_az_domain[1] + target_dAz, target_dAz)
alt_bin_edges = np.arange(set_alt_domain[0], set_alt_domain[1] + target_dAlt, target_dAlt)

# Assign each row to an integer bin index
target_df['az_bin']  = pd.cut(target_df['az'],  bins=az_bin_edges,  labels=False, right=False)
target_df['alt_bin'] = pd.cut(target_df['alt'], bins=alt_bin_edges, labels=False, right=False)

# 3) Specify input features and the target variable for regression
target_indices_list = ['fism2_48hr_prior', 'ap_7hr_prior']
features = ['doy', 'slt'] + target_indices_list
target   = target_variable  # e.g., 'ne'

# 4) Prepare dictionaries to collect models and (optionally) debug info
debug_bin_data = {}     # store train/test splits for a specific bin if needed
bin_models     = {}     # final storage for (scaler, poly, model) per bin

# 5) Loop over each (az_bin, alt_bin) cell
for (az_bin, alt_bin), group in target_df.groupby(['az_bin', 'alt_bin']):

    # Skip bins with too few samples to fit a reliable model
    if len(group) <= 5:
        continue

    # Extract feature matrix X and target vector y for this bin
    X_bin = group[features].values
    y_bin = group[target].values

    # Split into train/test (33% for testing)
    X_train_bin, X_test_bin, y_train_bin, y_test_bin = train_test_split(
        X_bin, y_bin, test_size=0.33, random_state=42
    )

    # Compute the actual physical ranges of this bin (for metadata)
    az_range  = (az_bin_edges[int(az_bin)],   az_bin_edges[int(az_bin) + 1])
    alt_range = (alt_bin_edges[int(alt_bin)], alt_bin_edges[int(alt_bin) + 1])

    # Optional DEBUGGING: capture the splits for one “interesting” bin
    # if (az_range and alt_range near your test_az, test_alt):
    #     debug_bin_data = {
    #         'X_train': X_train_bin, 'y_train': y_train_bin,
    #         'X_test':  X_test_bin,  'y_test':  y_test_bin,
    #         'az_range': az_range, 'alt_range': alt_range,
    #         'n_samples': len(group)
    #     }

    # 6) Standardize inputs (zero mean, unit variance) based on training data
    scaler_bin = StandardScaler()
    X_train_scaled = scaler_bin.fit_transform(X_train_bin)

    # 7) Expand features into polynomial terms (up to 4th degree)
    poly_bin = PolynomialFeatures(degree=4, include_bias=False)
    X_train_poly = poly_bin.fit_transform(X_train_scaled)

    # 8) Fit a Ridge regression (L2-regularized linear model)
    model_bin = Ridge(alpha=1.0)
    model_bin.fit(X_train_poly, y_train_bin)

    # 9) Save the components for this bin to be used later in prediction
    bin_models[(az_bin, alt_bin)] = {
        'model':   model_bin,
        'scaler':  scaler_bin,
        'poly':    poly_bin,
        'az_range': az_range,
        'alt_range': alt_range,
        'n_samples': len(group)
    }

In [None]:
np.save(f'../model/{target_variable}_model_2_0_7_alpha.npy', bin_models)