In [8]:
# -*- coding: utf-8 -*-
"""
Compare Sentinel-3 vs. all Sentinel-2 chlorophyll-a versions with per-pair linear regression.

Model per pair: S2_version = intercept + slope * S3

Outputs:
- s2_vs_s3_regression_summary.csv : one row per S2 column with full stats
- plots/s2_vs_s3_<S2COL>.png      : scatter + OLS fit + 1:1 line with annotated stats
"""

import os
import re
import math
import numpy as np
import pandas as pd
import scipy.stats as stats
import statsmodels.api as sm
import matplotlib.pyplot as plt
from typing import List, Optional, Tuple

In [9]:
# -------------------- CONFIG --------------------
# Input table loaded by main() via pd.read_csv. Column roles (S3 vs. S2) are
# resolved later by name, so the file is expected to carry recognisable
# "S3"/"S2" (or "Sentinel-3"/"Sentinel-2") column names.
# NOTE(review): this is an absolute, machine-specific path; the notebook will
# only run where this file exists.
CSV_PATH = "/home/data/ocean-colour/s2_s3_sampling_wolines.csv"

In [10]:
# Explicit name of the Sentinel-3 column, or None to rely on auto-detection.
# NOTE(review): main() as written calls autodetect_columns() unconditionally,
# so confirm that a value set here is actually honoured.
S3_COL_NAME: Optional[str] = None   # e.g. "S3_chla"
# When set, autodetect_columns() takes as S2 columns every numeric column whose
# name matches this pattern (case-insensitive) instead of using its built-in
# "s2"/"sentinel-2" name patterns.
S2_INCLUDE_REGEX: Optional[str] = None  # e.g. r"^S2_.*chla$"

# If True, apply_transform() maps values to log10 (non-positive values become
# NaN) and scatter_plot() appends " (log10)" to both axis labels.
LOG10_TRANSFORM = False

# Master switch for remove_outliers_xy(); when False the data pass through untouched.
REMOVE_OUTLIERS = True
# Filter used by remove_outliers_xy(). Any other value leaves the finite samples unfiltered.
OUTLIER_METHOD = "iqr"  # "iqr" or "zscore"
# IQR method: keep values within [Q1 - IQR_K*IQR, Q3 + IQR_K*IQR] on both axes.
IQR_K = 1.5
# Z-score method: keep samples whose |z| is at most this value on both axes.
ZSCORE_THRESH = 4.0

# Directory that receives the per-pair PNG files (created by scatter_plot()).
PLOT_DIR = "plots"
# Passed to matplotlib as the figure's figsize.
FIGSIZE = (6, 5)
# Scatter marker size (matplotlib's `s` argument).
MARKER_SIZE = 15
# Scatter marker opacity (matplotlib's `alpha` argument).
ALPHA = 0.6
# ------------------------------------------------

def autodetect_columns(df: pd.DataFrame) -> Tuple[str, List[str]]:
    """
    Guess the Sentinel-3 column and the Sentinel-2 columns from column names.

    S3 detection, in order of preference:
      1. names containing a standalone "s3" token or "sentinel 3" (any of
         "-", "_", whitespace or nothing between the word and the digit);
      2. names matching the looser "S3...chl/chla/chlor/a" pattern;
      3. the only numeric column of the frame, if there is exactly one.
    Among the name matches, numeric columns are preferred so that a text column
    such as "S3_date" is not chosen ahead of the actual measurement column.

    S2 detection: numeric columns matching S2_INCLUDE_REGEX when that is set,
    otherwise numeric columns containing a standalone "s2" token or
    "sentinel 2". If nothing matches, every numeric column except the S3 one is
    used. The S3 column itself is never returned as an S2 column.

    Parameters
    ----------
    df : pd.DataFrame
        Table whose columns are inspected (values are only used for dtype checks).

    Returns
    -------
    (s3_col, s2_cols) : Tuple[str, List[str]]
        The chosen S3 column and the S2 columns in their original column order.

    Raises
    ------
    ValueError
        If no S3 column can be identified, or no S2 column is left.
    """
    cols = list(df.columns)

    def is_numeric(col) -> bool:
        return pd.api.types.is_numeric_dtype(df[col])

    def names_sensor(col, digit: str) -> bool:
        # str() so that non-string column labels (e.g. integers) cannot crash re.search.
        name = str(col)
        return bool(
            re.search(rf"(?:^|[^a-zA-Z0-9])s{digit}(?:[^a-zA-Z0-9]|$)", name, flags=re.I)
            or re.search(rf"sentinel[-_\s]*{digit}", name, flags=re.I)
        )

    # ---- Sentinel-3 column ----
    s3_candidates = [c for c in cols if names_sensor(c, "3")]
    if not s3_candidates:
        s3_candidates = [c for c in cols if re.search(r"(?i)S3.*(chl|chla|chlor|a)\b", str(c))]

    if s3_candidates:
        # Prefer a numeric match; fall back to the first match so that the
        # previous choice is kept when none of the matches is numeric.
        numeric_s3 = [c for c in s3_candidates if is_numeric(c)]
        s3_guess = (numeric_s3 or s3_candidates)[0]
    else:
        numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist()
        if len(numeric_cols) != 1:
            raise ValueError("Could not auto-detect an S3 column. Set S3_COL_NAME explicitly.")
        s3_guess = numeric_cols[0]

    # ---- Sentinel-2 columns ----
    if S2_INCLUDE_REGEX:
        s2_pool = [c for c in cols if re.search(S2_INCLUDE_REGEX, str(c), flags=re.I)]
    else:
        s2_pool = [c for c in cols if names_sensor(c, "2")]
    # Exclude the S3 column: a broad regex (or a name containing both tokens)
    # would otherwise make the script regress S3 against itself.
    s2_candidates = [c for c in s2_pool if c != s3_guess and is_numeric(c)]

    if not s2_candidates:
        numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist()
        s2_candidates = [c for c in numeric_cols if c != s3_guess]

    if not s2_candidates:
        raise ValueError("No S2 columns found. Consider setting S2_INCLUDE_REGEX.")

    return s3_guess, s2_candidates

def apply_transform(x: np.ndarray) -> np.ndarray:
    """
    Return `x` as a float array, optionally on a log10 scale.

    With LOG10_TRANSFORM disabled the values are only cast to float. With it
    enabled, entries that are not strictly positive are replaced by NaN before
    taking log10, so they can be dropped by a later finite-value filter.
    """
    values = x.astype(float)
    if LOG10_TRANSFORM:
        positive_only = np.where(values > 0, values, np.nan)
        values = np.log10(positive_only)
    return values

def remove_outliers_xy(x: np.ndarray, y: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """
    Drop paired samples that are outliers in either `x` or `y`.

    Behaviour is driven by the module-level settings:
      * REMOVE_OUTLIERS False -> `x` and `y` are returned unchanged.
      * OUTLIER_METHOD "iqr"    -> keep values inside
        [Q1 - IQR_K*IQR, Q3 + IQR_K*IQR] on both axes.
      * OUTLIER_METHOD "zscore" -> keep values with |z| <= ZSCORE_THRESH on both
        axes (population standard deviation, as scipy.stats.zscore uses by default).
      * any other method        -> a warning is issued and only the
        non-finite pairs are removed.
    The method name is compared case-insensitively.

    Non-finite pairs are always removed first when filtering is enabled. A
    sample is kept only if it is an inlier on both axes, so the outputs stay
    paired and have equal length.

    Parameters
    ----------
    x, y : np.ndarray
        One-dimensional arrays of equal length.

    Returns
    -------
    (x_kept, y_kept) : Tuple[np.ndarray, np.ndarray]
    """
    if not REMOVE_OUTLIERS:
        return x, y

    finite = np.isfinite(x) & np.isfinite(y)
    xv, yv = x[finite], y[finite]
    if xv.size == 0:
        # Nothing left to filter; also avoids percentile warnings on empty input.
        return xv, yv

    method = str(OUTLIER_METHOD).strip().lower()

    if method == "iqr":
        def inliers(a: np.ndarray) -> np.ndarray:
            q1, q3 = np.percentile(a, [25, 75])
            spread = q3 - q1
            return (a >= q1 - IQR_K * spread) & (a <= q3 + IQR_K * spread)

    elif method == "zscore":
        def inliers(a: np.ndarray) -> np.ndarray:
            sd = a.std()
            if not sd > 0:
                # Zero variance: z-scores are undefined (0/0). A constant series
                # has no outliers, so keep everything instead of letting NaN
                # comparisons discard every sample.
                return np.ones(a.shape, dtype=bool)
            return np.abs(a - a.mean()) <= ZSCORE_THRESH * sd

    else:
        import warnings
        warnings.warn(
            f"Unknown OUTLIER_METHOD {OUTLIER_METHOD!r}; expected 'iqr' or 'zscore'. "
            "No outliers were removed.",
            RuntimeWarning,
            stacklevel=2,
        )
        return xv, yv

    keep = inliers(xv) & inliers(yv)
    return xv[keep], yv[keep]

def regression_and_stats(x: np.ndarray, y: np.ndarray) -> dict:
    """
    Fit the OLS model ``y = intercept + slope * x`` and collect agreement statistics.

    Inputs may be numpy arrays, lists or pandas objects. They are paired by
    position (any pandas index is ignored), and pairs containing NaN or +/-inf
    are dropped before fitting.

    Parameters
    ----------
    x : array-like
        Predictor values (Sentinel-3 in this script).
    y : array-like
        Response values (one Sentinel-2 version), same length as `x`.

    Returns
    -------
    dict
        Keys: n, slope, slope_se, slope_ci_lo, slope_ci_hi, intercept,
        intercept_se, intercept_ci_lo, intercept_ci_hi, r_squared, pearson_r,
        pearson_p, spearman_rho, spearman_p, rmse, mae, bias.
        `rmse` and `mae` are computed from the residuals of the fitted line;
        `bias` is mean(y - x). The confidence bounds are those returned by
        statsmodels' `conf_int()` with its default settings.
        When fewer than 3 valid pairs remain, or `x` has no spread (the line is
        not identifiable), every statistic except `n` is NaN.

    Raises
    ------
    ValueError
        If `x` and `y` do not have the same number of elements.
    """
    # Work on plain float arrays so that pairing is positional; building Series
    # directly from pandas inputs would align them on index labels instead.
    x_arr = np.asarray(x, dtype=float).ravel()
    y_arr = np.asarray(y, dtype=float).ravel()
    if x_arr.shape != y_arr.shape:
        raise ValueError(
            f"x and y must have the same length (got {x_arr.size} and {y_arr.size})."
        )

    paired = np.isfinite(x_arr) & np.isfinite(y_arr)
    x_s = pd.Series(x_arr[paired], name="x")
    y_s = pd.Series(y_arr[paired], name="y")
    n = int(len(x_s))

    stat_keys = (
        "slope", "slope_se", "slope_ci_lo", "slope_ci_hi",
        "intercept", "intercept_se", "intercept_ci_lo", "intercept_ci_hi",
        "r_squared", "pearson_r", "pearson_p",
        "spearman_rho", "spearman_p", "rmse", "mae", "bias",
    )
    # A constant predictor cannot separate slope from intercept; with the
    # default add_constant() it would also yield a design matrix without a
    # 'const' column and a KeyError further down.
    if n < 3 or np.ptp(x_s.to_numpy()) == 0:
        row = {"n": n}
        row.update(dict.fromkeys(stat_keys, np.nan))
        return row

    # has_constant="add" guarantees the 'const' column is present.
    X = sm.add_constant(x_s, has_constant="add")  # columns: ['const', 'x']
    model = sm.OLS(y_s, X).fit()

    a = float(model.params["const"])
    b = float(model.params["x"])
    a_se = float(model.bse["const"])
    b_se = float(model.bse["x"])

    ci = model.conf_int()
    a_ci_lo, a_ci_hi = (float(v) for v in ci.loc["const"])
    b_ci_lo, b_ci_hi = (float(v) for v in ci.loc["x"])

    x_vals = x_s.to_numpy()
    y_vals = y_s.to_numpy()
    pearson_r, pearson_p = stats.pearsonr(x_vals, y_vals)
    spearman_rho, spearman_p = stats.spearmanr(x_vals, y_vals)

    resid = y_vals - (a + b * x_vals)
    rmse = float(np.sqrt(np.mean(resid ** 2)))
    mae = float(np.mean(np.abs(resid)))
    bias = float(np.mean(y_vals - x_vals))  # mean difference S2 - S3

    return dict(
        n=n,
        slope=b, slope_se=b_se, slope_ci_lo=b_ci_lo, slope_ci_hi=b_ci_hi,
        intercept=a, intercept_se=a_se, intercept_ci_lo=a_ci_lo, intercept_ci_hi=a_ci_hi,
        r_squared=float(model.rsquared),
        pearson_r=float(pearson_r), pearson_p=float(pearson_p),
        spearman_rho=float(spearman_rho), spearman_p=float(spearman_p),
        rmse=rmse, mae=mae, bias=bias
    )

def scatter_plot(x: np.ndarray, y: np.ndarray, s3_name: str, s2_name: str, stats_row: dict, out_path: str):
    """
    Save a scatter plot of `y` against `x` with the fitted line, the 1:1 line
    and a block of regression statistics printed underneath the axes.

    Parameters
    ----------
    x, y : np.ndarray
        Paired values to plot (S3 on the x-axis, S2 on the y-axis).
    s3_name, s2_name : str
        Column names used for the axis labels and the title.
    stats_row : dict
        Must provide the keys n, slope, slope_ci_lo, slope_ci_hi, intercept,
        intercept_ci_lo, intercept_ci_hi, r_squared, pearson_r, pearson_p,
        spearman_rho, spearman_p, rmse, mae and bias.
    out_path : str
        Destination image file. Its parent directory is created if needed.
    """
    fig, ax = plt.subplots(figsize=FIGSIZE)
    try:
        ax.scatter(x, y, s=MARKER_SIZE, alpha=ALPHA)

        a, b = stats_row["intercept"], stats_row["slope"]
        x_arr = np.asarray(x, dtype=float)
        x_finite = x_arr[np.isfinite(x_arr)]
        if x_finite.size:
            # Guarded so that empty / all-NaN input yields an empty plot rather
            # than an exception from nanmin/nanmax.
            xgrid = np.linspace(x_finite.min(), x_finite.max(), 200)
            ax.plot(xgrid, a + b * xgrid, linewidth=2)
            ax.plot(xgrid, xgrid, linestyle="--", linewidth=1)  # 1:1

        xlabel = f"{s3_name}"
        ylabel = f"{s2_name}"
        if LOG10_TRANSFORM:
            xlabel += " (log10)"
            ylabel += " (log10)"
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        ax.set_title(f"{s2_name} vs {s3_name}")

        txt = (f"n={stats_row['n']}\n"
               f"slope={stats_row['slope']:.3f}  (95% CI {stats_row['slope_ci_lo']:.3f},{stats_row['slope_ci_hi']:.3f})\n"
               f"intercept={stats_row['intercept']:.3f}  (95% CI {stats_row['intercept_ci_lo']:.3f},{stats_row['intercept_ci_hi']:.3f})\n"
               f"R²={stats_row['r_squared']:.3f}  r={stats_row['pearson_r']:.3f} (p={stats_row['pearson_p']:.1e})  "
               f"ρ={stats_row['spearman_rho']:.3f} (p={stats_row['spearman_p']:.1e})\n"
               f"RMSE={stats_row['rmse']:.3f}  MAE={stats_row['mae']:.3f}  bias(S2-S3)={stats_row['bias']:.3f}")
        fontsize = 9
        text_bottom = 0.02
        fig.text(0.02, text_bottom, txt, fontsize=fontsize, va="bottom", ha="left")

        # tight_layout() does not know about figure-level text, so the axes
        # would otherwise be laid out on top of the statistics block. Reserve a
        # strip at the bottom sized from the number of text lines
        # (points -> inches -> fraction of the figure height).
        n_lines = txt.count("\n") + 1
        text_height = n_lines * fontsize * 1.3 / 72.0 / fig.get_figheight()
        reserved = min(0.5, text_bottom + text_height)
        fig.tight_layout(rect=(0, reserved, 1, 1))

        # Create the directory the file is actually written to, which need not
        # be PLOT_DIR if a caller passes a different out_path.
        target_dir = os.path.dirname(out_path)
        if target_dir:
            os.makedirs(target_dir, exist_ok=True)
        fig.savefig(out_path, dpi=200)
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close(fig)

def _resolve_columns(df: pd.DataFrame) -> Tuple[str, List[str]]:
    """Return (s3_col, s2_cols), using S3_COL_NAME when it is set and auto-detection otherwise."""
    if not S3_COL_NAME:
        return autodetect_columns(df)

    if S3_COL_NAME not in df.columns:
        raise KeyError(f"S3_COL_NAME={S3_COL_NAME!r} is not a column of {CSV_PATH}")

    try:
        # Only the S2 half of the auto-detection result is needed here.
        _, s2_cols = autodetect_columns(df)
    except ValueError:
        # Auto-detection gave up (e.g. no column looks like S3 by name). The S3
        # column is known explicitly, so fall back to the other numeric columns,
        # still honouring S2_INCLUDE_REGEX when it is set.
        s2_cols = df.select_dtypes(include=[np.number]).columns.tolist()
        if S2_INCLUDE_REGEX:
            s2_cols = [c for c in s2_cols if re.search(S2_INCLUDE_REGEX, str(c), flags=re.I)]

    s2_cols = [c for c in s2_cols if c != S3_COL_NAME]
    if not s2_cols:
        raise ValueError("No S2 columns found. Consider setting S2_INCLUDE_REGEX.")
    return S3_COL_NAME, s2_cols


def main():
    """
    Run the full S2-vs-S3 comparison.

    Steps: load CSV_PATH, convert fully numeric text columns to numbers,
    resolve the S3 and S2 columns, drop duplicate rows, then for every S2
    column: pair it with S3, apply the optional log10 transform and outlier
    filter, fit the regression, save a scatter plot into PLOT_DIR and collect
    the statistics. Pairs with fewer than 3 usable samples are skipped.

    Writes "s2_vs_s3_regression_summary.csv" (one row per S2 column, sorted by
    descending R² then Pearson r) to the current working directory and prints a
    short report.

    Raises
    ------
    KeyError
        If S3_COL_NAME is set but is not a column of the input file.
    ValueError
        If the S3/S2 columns cannot be determined.
    RuntimeError
        If no S2 column yields a usable regression.
    """
    df = pd.read_csv(CSV_PATH)

    # Convert columns that are entirely numeric but were parsed as text. Columns
    # that cannot be converted are deliberately left as they are (this replaces
    # the deprecated errors='ignore' option of pd.to_numeric).
    for c in df.columns:
        if pd.api.types.is_numeric_dtype(df[c]):
            continue
        try:
            df[c] = pd.to_numeric(df[c])
        except (ValueError, TypeError):
            pass

    # Honour an explicitly configured S3 column; previously the configured
    # value was always overwritten by the auto-detected one.
    s3_col, s2_cols = _resolve_columns(df)

    df = df.drop_duplicates()
    s3_series = df[s3_col].astype(float)

    rows = []
    for s2_col in s2_cols:
        s2_series = df[s2_col].astype(float)
        xy = pd.DataFrame({s3_col: s3_series, s2_col: s2_series}).replace([np.inf, -np.inf], np.nan).dropna()
        if xy.shape[0] < 3:
            continue

        x = apply_transform(xy[s3_col].to_numpy())
        y = apply_transform(xy[s2_col].to_numpy())

        # The log10 transform turns non-positive values into NaN, so filter again.
        valid = np.isfinite(x) & np.isfinite(y)
        x, y = x[valid], y[valid]
        if len(x) < 3:
            continue

        x, y = remove_outliers_xy(x, y)
        if len(x) < 3:
            continue

        stats_row = regression_and_stats(x, y)
        stats_row.update(dict(s3_col=s3_col, s2_col=s2_col))

        # Replace characters that are awkward in file names.
        safe_name = re.sub(r"[^A-Za-z0-9_.-]+", "_", s2_col)
        out_plot = os.path.join(PLOT_DIR, f"s2_vs_s3_{safe_name}.png")
        scatter_plot(x, y, s3_col, s2_col, stats_row, out_plot)

        rows.append(stats_row)

    if not rows:
        raise RuntimeError("No valid S2/S3 pairs found to regress. Check column names and data quality.")

    summary = pd.DataFrame(rows).sort_values(["r_squared", "pearson_r"], ascending=[False, False])
    summary_cols = [
        "s2_col", "s3_col", "n",
        "slope", "slope_se", "slope_ci_lo", "slope_ci_hi",
        "intercept", "intercept_se", "intercept_ci_lo", "intercept_ci_hi",
        "r_squared", "pearson_r", "pearson_p", "spearman_rho", "spearman_p",
        "rmse", "mae", "bias"
    ]
    summary = summary[summary_cols]
    summary.to_csv("s2_vs_s3_regression_summary.csv", index=False)

    print("✅ Done.")
    print(f"S3 column: {s3_col}")
    print(f"Compared S2 columns (n={len(summary)}):")
    for c in summary['s2_col']:
        print(f"  - {c}")
    print("\nWrote: s2_vs_s3_regression_summary.csv")
    print(f"Plots in: {os.path.abspath(PLOT_DIR)}/")

# Entry point: run the whole comparison when this code is executed directly
# (as a script, or as a notebook cell — the captured output below this cell
# shows it running there). Importing the module does not trigger a run.
if __name__ == "__main__":
    main()


✅ Done.
S3 column: S3
Compared S2 columns (n=5):
  - S2_orig
  - s2_fp2
  - s2_fp
  - S2_anchor
  - s2_fp1

Wrote: s2_vs_s3_regression_summary.csv
Plots in: /home/notebooks/ocean-colour/plots/
