# In [1]:
import pandas as pd
import numpy as np
import os
import matplotlib.pyplot as plt
from matplotlib import rc
from IPython.display import display
from scipy.stats import pearsonr
import seaborn as sns
from scipy.optimize import nnls  # for non-negative least squares
from scipy.stats import linregress
import matplotlib.patches as mpatches

# Render mathtext ($...$ in labels) in the same font as regular text instead of the default math font.
rc('mathtext', default='regular')

from functions.functions_outliers_cr2sub import *

# === Load data ===
# Root folder for the cr2sub inputs/outputs, kept in a single place.
data_dir = '../cr2sub'

# Monthly GWL table: a 'date' column plus one column per well code.
gw_all = pd.read_csv(os.path.join(data_dir, 'cr2sub_v1_mon.csv'), parse_dates=['date'])
gw_all = gw_all.sort_values('date')
well_codes = [col for col in gw_all.columns if col != 'date']
# Shared, chronologically sorted date axis; every per-well series is aligned to it.
date_index = pd.DatetimeIndex(gw_all['date'])

# (optional) figure generation
generate_figs = False  # set True to save per-well figures
out_figs = os.path.join(data_dir, 'plots_outliers_removal')
if generate_figs:
    os.makedirs(out_figs, exist_ok=True)

# === Global parameters ===
# The exact meaning of most thresholds is defined by the helpers in
# functions_outliers_cr2sub (run_pass / final_tail_filters /
# remove_isolated_islands); notes marked TODO are assumptions to confirm there.
min_obs = 100          # wells with fewer non-NaN values are skipped (saved as all-NaN)
win_months = 6         # window length passed to the filters, presumably in months
perc_thr = 0.99        # P99 for jump/residual thresholds
tail_p = 0.01          # extreme-tail probability: 1%/99% if applied per tail, 0.5%/99.5% if split across both tails — TODO confirm
skew_thr = 0.3         # skewness threshold for the one-sided (skew) tail filter — TODO confirm
min_n = 50             # minimum sample size used by final_tail_filters — TODO confirm
skew_tail_p = 0.01     # tail probability for the skew-based one-sided filter — TODO confirm
min_nan_gap = 6        # e.g., need ≥6-month NaN deserts around an island — TODO confirm
max_island_len = 3     # longest run of values still treated as an isolated island — TODO confirm

# Collector for a single wide file (one column per well).
# Columns are gathered in a dict and turned into a DataFrame once after the
# loop, which avoids the fragmentation caused by repeated column inserts.
cols = {}


def _roll6_trailing(series):
    """Return the trailing 6-calendar-month rolling mean of ``series``.

    The series is first laid on a gap-free month-start grid spanning its
    first to last date, so "6 months" means 6 calendar months even when
    observations are missing. Returns an empty float Series when ``series``
    has no valid values.
    """
    if series.dropna().empty:
        return pd.Series([], dtype=float)
    # resample('MS') builds the complete monthly grid and, unlike reindexing
    # onto it, keeps values whose dates are not stamped on day 1 (a reindex
    # would silently turn those into NaN) and tolerates duplicate dates.
    monthly = series.resample("MS").mean()
    return monthly.rolling(window=6, min_periods=1).mean()  # trailing window


def _plot_well_outliers(cod, s0, gw_final, flagged, out_dir):
    """Save a per-well PDF showing which points each filter removed.

    Parameters
    ----------
    cod : str
        Well code; used in the title and in the output file name.
    s0 : pd.Series
        Raw, date-indexed series; flagged points are drawn at their raw values.
    gw_final : pd.Series
        Cleaned series; used for the "Final 6-mo mean" line.
    flagged : list of (label, dates, color)
        Filter outputs in the order they were applied. A date flagged by
        several filters is drawn only for the earliest one.
    out_dir : str
        Directory the PDF is written to.
    """
    roll6_raw = _roll6_trailing(s0)
    roll6_final = _roll6_trailing(gw_final)

    fig, ax = plt.subplots(figsize=(10, 4))
    ax.scatter(s0.index, s0.values, s=12, c='0.6', label='Original')

    # De-duplicate across categories (earlier category wins).
    seen = set()
    for label, dates, color in flagged:
        new_idx = pd.Index([d for d in pd.Index(dates) if d not in seen])
        if len(new_idx):
            ax.scatter(new_idx, s0.loc[new_idx], marker='x', c=color, s=40, label=label)
            seen.update(new_idx)

    ax.set_title(f"Outliers removal — Well {cod}")
    ax.set_ylabel('GWL [m]')
    ax.grid(ls='--', alpha=.5)

    # 6-month rolling means: raw vs final
    if len(roll6_raw):
        ax.plot(roll6_raw.index, roll6_raw.values, lw=1, ls='--',
                label='Raw 6-mo mean', color='steelblue')
    if len(roll6_final):
        ax.plot(roll6_final.index, roll6_final.values, lw=1.2,
                label='Final 6-mo mean', color='black')

    ax.legend(ncol=1, fontsize=8)
    fig.tight_layout()
    fig.savefig(os.path.join(out_dir, f"{cod}_outliers_rem_detailed.pdf"))
    plt.close(fig)  # close this figure explicitly to avoid accumulating open figures


for cod in well_codes:
    gw_raw = pd.Series(gw_all[cod].values, index=date_index)
    if gw_raw.notna().sum() < min_obs:
        # Too few observations to clean: keep the well as an all-NaN column.
        cols[cod] = pd.Series(np.nan, index=date_index)
        continue

    # ===================== Raw time series =====================
    s0 = gw_raw.sort_index()

    # ===================== Pass 1 (on s0) and Pass 2 (on s1) ==============
    # smooth1 / smooth2 are returned by run_pass but not used further here.
    s1, rule0_1, rule1_1, smooth1 = run_pass(s0, win_months, perc_thr)
    s2, rule0_2, rule1_2, smooth2 = run_pass(s1, win_months, perc_thr)

    # =============== Filter tails (on s2) =======
    s3, tail_idx1, tail_idx2 = final_tail_filters(s2,
                                                  win_months=win_months,
                                                  perc_thr=perc_thr,
                                                  tail_p=tail_p,
                                                  skew_thr=skew_thr,
                                                  min_n=min_n,
                                                  skew_tail_p=skew_tail_p)

    # =============== Filter isolated data (on s3) =======
    gw_final, islands_idx = remove_isolated_islands(s3,
                                                    min_nan_gap=min_nan_gap,
                                                    max_island_len=max_island_len)

    # =============== Store on the common date grid =======
    cols[cod] = gw_final.reindex(date_index)

    # ================== DETAILED PLOT (optional) ==================
    if generate_figs:
        # Categories in chronological order of application.
        _plot_well_outliers(cod, s0, gw_final, [
            ("Rule 0 — pass 1 (spike/dip)", rule0_1, 'saddlebrown'),
            ("Rule 1 — pass 1 (mean jump)", rule1_1, 'crimson'),
            ("Rule 0 — pass 2 (spike/dip)", rule0_2, 'peru'),
            ("Rule 1 — pass 2 (mean jump)", rule1_2, 'darkorange'),
            ("Tail A — skew one-sided", tail_idx1, 'purple'),
            ("Tail B — two-sided robust-z", tail_idx2, 'violet'),
            ("Isolated islands", islands_idx, 'black'),
        ], out_figs)

    

# Assemble every per-well column into one wide DataFrame in a single step
# (adding columns one by one would fragment the frame).
gw_filtered_wide = pd.DataFrame(cols, index=date_index)

# =============== SAVE csv =======
# Nothing is written when the table has no rows or no columns.
if not gw_filtered_wide.empty:
    gw_filtered_wide = gw_filtered_wide.sort_index().rename_axis('date')
    gw_filtered_wide.to_csv(
        os.path.join('../cr2sub', "cr2sub_v1_mon_clean.csv"),
        na_rep='NA',
    )