In [2]:
# @title adding correlation section
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import ipywidgets as widgets
import requests
import warnings
import calendar
from datetime import datetime, timedelta
from IPython.display import display, HTML

# Disable Colab formatting
try:
    from google.colab import data_table
    data_table.disable_dataframe_formatter()
except:
    pass

warnings.filterwarnings("ignore", category=FutureWarning)
pd.set_option('future.no_silent_downcasting', True)

# --- HELPERS ---
def get_date_suffix(day):
    """Return the English ordinal suffix for a day-of-month number.

    11, 12 and 13 always take 'th'. Otherwise the last digit decides:
    1 -> 'st', 2 -> 'nd', 3 -> 'rd', anything else -> 'th'.
    """
    if 11 <= day <= 13:
        return 'th'
    last_digit = day % 10
    if last_digit == 1:
        return 'st'
    if last_digit == 2:
        return 'nd'
    if last_digit == 3:
        return 'rd'
    return 'th'

def get_stations_for_state(state_code):
    """Return ``{"<name> (<sid>)": sid}`` for ACIS stations in *state_code*.

    Queries the ACIS ``StnMeta`` endpoint for stations in the state that report
    ``maxt`` between 1991-01-01 and today. Each station's first ``sids`` entry
    is used as its identifier. The result is sorted by label.

    Entries missing a name or station id are skipped individually, so one
    malformed record no longer empties the whole list. Returns ``{}`` if the
    request fails, returns a non-200 status, or the body is not valid JSON.
    """
    url = "https://data.rcc-acis.org/StnMeta"
    payload = {"state": state_code, "sdate": "1991-01-01", "edate": datetime.now().strftime("%Y-%m-%d"), "elems": "maxt", "meta": ["name", "sids"]}
    try:
        response = requests.post(url, json=payload, timeout=10)
        if response.status_code != 200:
            return {}
        body = response.json()
    except (requests.RequestException, ValueError) as e:  # network/timeout or bad JSON
        print(f"Station List Error: {e}")
        return {}

    meta = body.get("meta", []) if isinstance(body, dict) else []
    stations = {}
    for station in meta:
        sids = station.get("sids") or []
        name = station.get("name")
        if not sids or not name:
            continue  # skip a malformed entry rather than discarding the whole state
        sid = sids[0]
        stations[f"{name} ({sid})"] = sid
    return dict(sorted(stations.items()))

def get_xmacis_data(station_id):
    """Download daily temperature, precipitation and snowfall records for one ACIS station.

    For each of ``avgt``, ``pcpn`` and ``snow`` the request asks ACIS
    ``StnData`` for the daily observed value, its 1991-2020 normal
    (``normal: "91"``) and its departure (``normal: "departure91"``), from
    1975-01-01 through today.

    Value handling:
    - Trace ('T') becomes 0.005.
    - Missing ('M') and any other non-numeric value becomes NaN.
    - A missing snowfall on a day with zero precipitation is set to 0.

    Note: the column named 'MinTemperatureDeparture' actually holds the avgt
    departure (third element requested). The name is kept because downstream
    code reads it.

    Returns a DataFrame sorted by 'Date'. Returns an empty DataFrame on a
    non-200 response, an empty payload, or any exception (the exception is
    printed).
    """
    url = "https://data.rcc-acis.org/StnData"
    elems = [
        {"name": element, "interval": "dly", **variant}
        for element in ("avgt", "pcpn", "snow")
        for variant in ({}, {"normal": "91"}, {"normal": "departure91"})
    ]
    payload = {"sid": station_id, "sdate": "1975-01-01", "edate": datetime.now().strftime("%Y-%m-%d"), "elems": elems}
    column_names = ['Date', 'AvgTemp', 'AvgTempNormal', 'MinTemperatureDeparture', 'Precipitation', 'PrecipitationNormal', 'Precip_Dep', 'Snowfall', 'SnowfallNormal', 'Snow_Dep']
    try:
        response = requests.post(url, json=payload, timeout=20)
        if response.status_code != 200:
            return pd.DataFrame()
        rows = response.json().get("data", [])
        if not rows:
            return pd.DataFrame()
        frame = pd.DataFrame(rows, columns=column_names)
        for name in column_names[1:]:
            cleaned = frame[name].astype(str).str.strip()
            cleaned = cleaned.replace({'T': '0.005', 'M': np.nan})
            frame[name] = pd.to_numeric(cleaned, errors='coerce')
        dry_and_unreported = frame['Snowfall'].isna() & (frame['Precipitation'] == 0)
        frame.loc[dry_and_unreported, 'Snowfall'] = 0
        frame['Date'] = pd.to_datetime(frame['Date'])
        return frame.dropna(subset=['Date']).sort_values('Date')
    except Exception as e:
        print(f"Data Fetch Error: {e}")
    return pd.DataFrame()

def get_cbrfc_volumes(station_id):
    """Fetch observed runoff volumes for a CBRFC forecast point.

    Downloads the CBRFC ``esprank.py`` page for *station_id* and parses the
    first HTML table on it. Returns a DataFrame with integer ``Year`` and
    numeric ``Obs_Volume_KAF`` columns; rows whose year or volume is not
    numeric are dropped rather than aborting the whole parse.

    Returns an empty DataFrame (after printing the reason) if the request
    fails, times out, or the page has no usable table. Returns an empty
    DataFrame silently if the table lacks the expected columns.
    """
    from io import StringIO  # read_html expects literal HTML wrapped in a file-like object

    url = "https://www.cbrfc.noaa.gov/wsup/graph/esprank.py"
    # Let requests URL-encode the user-typed id instead of splicing it into the URL.
    params = {"id": station_id.strip().upper(), "status": "1", "db": ""}
    try:
        # Explicit timeout: pd.read_html(url) has none and could hang the dashboard.
        response = requests.get(url, params=params, timeout=20)
        response.raise_for_status()
        tables = pd.read_html(StringIO(response.text))
    except Exception as e:  # best-effort: callers treat an empty frame as "no CBRFC data"
        print(f"CBRFC Fetch Error: {e}")
        return pd.DataFrame()

    if not tables:
        return pd.DataFrame()
    df = tables[0].dropna(how='all').dropna(axis=1, how='all')
    # The year and observed-volume columns have no header text on the page.
    df = df.rename(columns={'Unnamed: 1': 'Year', 'Unnamed: 2': 'Obs_Volume_KAF'})
    if not {'Year', 'Obs_Volume_KAF'}.issubset(df.columns):
        return pd.DataFrame()  # page layout differs from what this parser expects
    return (
        df[['Year', 'Obs_Volume_KAF']]
        .apply(pd.to_numeric, errors='coerce')
        .dropna()
        .astype({'Year': int})
    )

# --- INITIALIZATION ---
# The Month/Day sliders default to yesterday's date.
yesterday = datetime.now() - timedelta(days=1)
initial_state = 'UT'
# Station label -> ACIS id for the initial state ({} if the ACIS request fails).
initial_stations = get_stations_for_state(initial_state)

# Preferred default station. Falls back to the first listed station, or to this
# id itself when no stations could be loaded.
target_id = "24127 1"
default_sid = target_id if target_id in initial_stations.values() else (list(initial_stations.values())[0] if initial_stations else "24127 1")
# Module-level daily data read by update_dashboard; replaced by on_station_change.
weather_df = get_xmacis_data(default_sid)

# Axis label shown in the dropdowns -> column that update_dashboard computes.
options = {'Avg Temp anomaly': 'Temp_Dep_Avg', 'Snowfall anomaly': 'Snow_Dep_Sum', 'Precip anomaly': 'Precip_Dep_Sum'}
# All plots, tables and messages are rendered into this Output widget.
out = widgets.Output()
status_label = widgets.HTML(value="<b>Status:</b> <span style='color: green;'>Ready</span>")

# --- WIDGETS ---
c_style = {'description_width': '110px'}
c_layout = widgets.Layout(width='380px')

# Start on the same state whose station list was fetched at initialization.
state_dropdown = widgets.Dropdown(options=['AZ','CA','CO','ID','MT','NM','NV','OR','UT','WA','WY'], value=initial_state, description='State:', style=c_style, layout=c_layout)
# Dropdown raises TraitError when `value` is not among `options`. That used to
# happen whenever the initial station request failed (options={} with a value
# set). Fall back to a single entry for the default id so the dashboard loads.
station_dropdown = widgets.Dropdown(description="Station:", options=initial_stations or {default_sid: default_sid}, value=default_sid, style=c_style, layout=c_layout)
rfc_stn_input = widgets.Text(value='DELU1', description='CBRFC ID:', style=c_style, layout=c_layout)
# Comma-separated extra years to highlight on the plot and in the runoff table.
extra_years_input = widgets.Text(value='', description='Add Years:', style=c_style, layout=c_layout)
x_drop = widgets.Dropdown(options=list(options.keys()), value='Avg Temp anomaly', description='X Axis:', style=c_style, layout=c_layout)
y_drop = widgets.Dropdown(options=list(options.keys()), value='Snowfall anomaly', description='Y Axis:', style=c_style, layout=c_layout)
# Date sliders default to yesterday. Day allows 31 for every month; dates that
# do not exist in the chosen month match no rows in update_dashboard.
m_sld = widgets.IntSlider(min=1, max=12, value=yesterday.month, description='Month:', style=c_style, layout=c_layout)
d_sld = widgets.IntSlider(min=1, max=31, value=yesterday.day, description='Day:', style=c_style, layout=c_layout)
# Length of the trailing rolling window, in days.
window_sld = widgets.IntSlider(min=1, max=180, value=60, description='window days:', style=c_style, layout=c_layout)

# --- DASHBOARD LOGIC ---
def _parse_manual_years(text):
    """Return the integer years in a comma-separated string, skipping non-numeric tokens.

    ``isdecimal`` matches exactly the characters ``int()`` accepts. ``isdigit``
    also admits e.g. superscript digits, which made ``int()`` raise and
    discarded every year.
    """
    tokens = (tok.strip() for tok in text.split(','))
    return [int(tok) for tok in tokens if tok.isdecimal()]


def _multiple_correlation(x, y, z, same_axis=False):
    """Return the multiple correlation R of *z* on predictors *x* and *y* (0-1, or NaN).

    Uses R^2 = (rxz^2 + ryz^2 - 2*rxz*ryz*rxy) / (1 - rxy^2).

    A Pearson correlation involving a constant series is NaN. An example is a
    snowfall anomaly that is 0 in every year during summer. Such cases reduce
    to the simple correlation of the remaining predictor. Previously the NaN
    went through ``min(1, nan)``, which returns 1, and so reported a spurious
    R = 1.0.
    """
    rxz = x.corr(z)  # X vs runoff
    ryz = y.corr(z)  # Y vs runoff
    rxy = x.corr(y)  # X vs Y
    if np.isnan(rxz) and np.isnan(ryz):
        return np.nan  # nothing usable to correlate (e.g. runoff constant)
    if np.isnan(ryz):
        return abs(rxz)  # y has no variance: only x can explain runoff
    if np.isnan(rxz):
        return abs(ryz)  # x has no variance: only y can explain runoff
    if same_axis or np.isnan(rxy) or abs(rxy) > 0.99:
        # Same or (near-)collinear predictors: the denominator 1 - rxy^2 blows up,
        # so reduce to the simple correlation.
        return abs(rxz)
    r_sq = (rxz**2 + ryz**2 - 2 * rxz * ryz * rxy) / (1 - rxy**2)
    return float(np.sqrt(np.clip(r_sq, 0.0, 1.0)))  # clip keeps NaN as NaN


def _daily_predictive_power(df, runoff, x_col, y_col, min_years=5):
    """Return a DataFrame of ``PlotDate`` / ``R``: multiple R of (x, y) vs runoff per calendar day.

    Dates in Oct-Dec are assigned to the following water year's runoff. Every
    date is mapped onto a 2000-10-01 .. 2001-09-30 plotting axis, with Feb 29
    dropped because 2001 has none. Days with fewer than *min_years* years get
    R = NaN. Returns an empty frame if nothing matches.
    """
    corr_df = df.dropna(subset=[x_col, y_col]).copy()
    corr_df['WaterYear'] = corr_df['Date'].dt.year + (corr_df['Date'].dt.month >= 10).astype(int)
    merged = corr_df.merge(runoff[['Year', 'Obs_Volume_KAF']], left_on='WaterYear', right_on='Year')
    merged = merged[~((merged['Date'].dt.month == 2) & (merged['Date'].dt.day == 29))]
    if merged.empty:
        return pd.DataFrame(columns=['PlotDate', 'R'])

    dates = merged['Date']
    merged = merged.assign(PlotDate=pd.to_datetime(pd.DataFrame({
        'year': np.where(dates.dt.month >= 10, 2000, 2001),
        'month': dates.dt.month,
        'day': dates.dt.day,
    }, index=merged.index)))

    records = []
    for plot_date, day_slice in merged.groupby('PlotDate'):
        if len(day_slice) < min_years:
            r_value = np.nan
        else:
            r_value = _multiple_correlation(day_slice[x_col], day_slice[y_col], day_slice['Obs_Volume_KAF'], same_axis=(x_col == y_col))
        records.append({'PlotDate': plot_date, 'R': r_value})
    return pd.DataFrame(records).sort_values('PlotDate')


def update_dashboard(x_axis, y_axis, month, day, rfc_id, extra_yrs, window_size):
    """Redraw the analog scatter plot, runoff table and predictive-strength chart.

    Registered with ``widgets.interactive_output``. Reads the module-level
    ``weather_df``, ``options``, ``out``, ``status_label`` and
    ``station_dropdown``, and renders everything into ``out``.

    Parameters
    ----------
    x_axis, y_axis : str
        Keys of ``options`` naming the anomaly plotted on each axis.
    month, day : int
        Calendar date at which the trailing windows end.
    rfc_id : str
        CBRFC forecast point whose observed volumes are tabulated.
    extra_yrs : str
        Comma-separated years to highlight in addition to the automatic analogs.
    window_size : int
        Length, in daily rows, of the trailing rolling windows.
    """
    plt.close('all')
    status_label.value = "<b>Status:</b> <span style='color: blue;'>Updating Analysis...</span>"
    stn_full_name = next((k for k, v in station_dropdown.options.items() if v == station_dropdown.value), station_dropdown.value)
    date_str = f"{calendar.month_name[month]} {day}{get_date_suffix(day)}"

    df = weather_df.copy()
    if df.empty:
        with out:
            out.clear_output(wait=True)
            print("No weather data available.")
            status_label.value = "<b>Status:</b> <span style='color: red;'>Data Fetch Failed</span>"
        return

    # Trailing rolling windows needing ~95% of days present. Keep at least 1 so
    # a 1-day window does not turn a missing day into a 0 sum (min_periods=0).
    thresh = max(1, int(window_size * 0.95))
    # 'MinTemperatureDeparture' holds the avgt departure fetched by get_xmacis_data.
    df['Temp_Dep_Avg'] = df['MinTemperatureDeparture'].rolling(window_size, min_periods=thresh).mean()
    df['Snow_Dep_Sum'] = (df['Snowfall'] - df['SnowfallNormal']).rolling(window_size, min_periods=thresh).sum()
    df['Precip_Dep_Sum'] = (df['Precipitation'] - df['PrecipitationNormal']).rolling(window_size, min_periods=thresh).sum()

    x_col, y_col = options[x_axis], options[y_axis]
    manual_years = _parse_manual_years(extra_yrs)

    with out:
        out.clear_output(wait=True)
        analogs = df[(df['Date'].dt.month == month) & (df['Date'].dt.day == day)].copy()
        analogs['Year'] = analogs['Date'].dt.year
        plot_df = analogs.dropna(subset=[x_col, y_col]).copy()

        if plot_df.empty:
            print(f"No valid data found for {date_str} with {window_size}-day window.")
            status_label.value = "<b>Status:</b> <span style='color: orange;'>Insufficient Data</span>"
            return

        # The most recent year with data is the target. Analogs are the 10 other
        # years nearest to it in range-normalized (x, y) space.
        latest_yr = int(plot_df['Year'].max())
        latest_row = plot_df[plot_df['Year'] == latest_yr]
        target_x = latest_row[x_col].values[0]
        target_y = latest_row[y_col].values[0]

        xr = plot_df[x_col].max() - plot_df[x_col].min() or 1.0
        yr = plot_df[y_col].max() - plot_df[y_col].min() or 1.0
        plot_df['Closeness'] = np.sqrt(((plot_df[x_col]-target_x)/xr)**2 + ((plot_df[y_col]-target_y)/yr)**2)
        analogs_list = plot_df[plot_df['Year'] != latest_yr].sort_values('Closeness').head(10)['Year'].tolist()
        highlighted = [latest_yr] + analogs_list + manual_years

        # --- MAIN SCATTER PLOT ---
        plt.figure(figsize=(12, 7))

        # Symmetric limits so the zero lines cross at the centre of the plot.
        max_abs_x = max(abs(plot_df[x_col].min()), abs(plot_df[x_col].max()))
        max_abs_y = max(abs(plot_df[y_col].min()), abs(plot_df[y_col].max()))
        limit_x = max_abs_x * 1.15 if max_abs_x != 0 else 1.0
        limit_y = max_abs_y * 1.15 if max_abs_y != 0 else 1.0

        plt.xlim(-limit_x, limit_x)
        plt.ylim(-limit_y, limit_y)

        hist_df = plot_df[~plot_df['Year'].isin(highlighted)]
        plt.scatter(hist_df[x_col], hist_df[y_col], c='#D1D5DB', s=50, alpha=0.4, zorder=2)

        # Latest year red, automatic analogs blue, user-added years green.
        spec_df = plot_df[plot_df['Year'].isin(highlighted)]
        colors = spec_df['Year'].apply(lambda year: '#E11D48' if year == latest_yr else ('#2563EB' if year in analogs_list else '#059669'))
        plt.scatter(spec_df[x_col], spec_df[y_col], c=colors, s=spec_df['Year'].apply(lambda year: 200 if year == latest_yr else 120), edgecolor='white', linewidth=1.5, zorder=4)

        for _, row in plot_df.iterrows():
            y_v = int(row['Year'])
            is_s = y_v in highlighted
            plt.text(row[x_col], row[y_col]+(limit_y*0.04), str(y_v), fontsize=9 if is_s else 7, weight='bold' if is_s else 'normal', color='#111827' if is_s else '#9CA3AF', ha='center', zorder=5)

        plt.axhline(0, color='black', linewidth=1.2, zorder=1)
        plt.axvline(0, color='black', linewidth=1.2, zorder=1)

        plt.title(f"{stn_full_name} {window_size} day anomaly departure ending {date_str}", loc='left', weight='bold', fontsize=12)
        plt.xlabel(x_axis)
        plt.ylabel(y_axis)
        plt.grid(True, linestyle=':', alpha=0.3)
        sns.despine()
        plt.show()

        # --- RUNOFF DATA & TABLE ---
        runoff = get_cbrfc_volumes(rfc_id)
        if not runoff.empty:
            b_avg = runoff[(runoff['Year']>=1991) & (runoff['Year']<=2020)]['Obs_Volume_KAF'].mean()
            match_table = runoff.merge(plot_df[['Year', 'Closeness']], on='Year', how='inner')
            match_table = match_table[match_table['Year'].isin(highlighted)].copy()
            match_table['% of Avg'] = (match_table['Obs_Volume_KAF']/b_avg*100).round(1)

            final_table = match_table[['Year', 'Closeness', 'Obs_Volume_KAF', '% of Avg']].sort_values('Closeness')
            final_table['Closeness'] = final_table['Closeness'].map('{:.3f}'.format)
            final_table['Obs_Volume_KAF'] = final_table['Obs_Volume_KAF'].map('{:.1f}'.format)
            final_table['% of Avg'] = final_table['% of Avg'].map('{:.1f}%'.format)

            table_title = f"Observed Runoff Volumes: {rfc_id.upper()} (91-20 Avg: {b_avg:.1f} KAF)"

            html = f"""
            <style>
                .table-container {{ font-family: sans-serif; margin-top: 20px; margin-bottom: 20px; }}
                .table-title {{ font-size: 1.1em; font-weight: bold; margin-bottom: 8px; color: #333; }}
                .sortable-table {{ border-collapse: collapse; font-size: 0.9em; min-width: 400px; border: 1px solid #ddd; }}
                .sortable-table th {{ cursor: pointer; background-color: #f2f2f2; padding: 8px; text-align: left; border-bottom: 2px solid #ddd; }}
                .sortable-table td {{ padding: 8px; border-bottom: 1px solid #ddd; }}
                .sortable-table tr:hover {{ background-color: #f5f5f5; }}
            </style>
            <div class="table-container">
                <div class="table-title">{table_title}</div>
                <table id="myTable" class="sortable-table">
                    <thead><tr>{''.join(f'<th onclick="sortTable({i})">{c} ↕</th>' for i, c in enumerate(final_table.columns))}</tr></thead>
                    <tbody>{''.join('<tr>' + ''.join(f'<td>{v}</td>' for v in row) + '</tr>' for row in final_table.values)}</tbody>
                </table>
            </div>
            <script>
            function sortTable(n) {{
              var table, rows, switching, i, x, y, shouldSwitch, dir, switchcount = 0;
              table = document.getElementById("myTable");
              switching = true; dir = "asc";
              while (switching) {{
                switching = false; rows = table.rows;
                for (i = 1; i < (rows.length - 1); i++) {{
                  shouldSwitch = false; x = rows[i].getElementsByTagName("TD")[n]; y = rows[i+1].getElementsByTagName("TD")[n];
                  var xVal = isNaN(parseFloat(x.innerHTML)) ? x.innerHTML.toLowerCase() : parseFloat(x.innerHTML);
                  var yVal = isNaN(parseFloat(y.innerHTML)) ? y.innerHTML.toLowerCase() : parseFloat(y.innerHTML);
                  if (dir == "asc") {{ if (xVal > yVal) {{ shouldSwitch = true; break; }} }}
                  else if (dir == "desc") {{ if (xVal < yVal) {{ shouldSwitch = true; break; }} }}
                }}
                if (shouldSwitch) {{ rows[i].parentNode.insertBefore(rows[i + 1], rows[i]); switching = true; switchcount ++; }}
                else {{ if (switchcount == 0 && dir == "asc") {{ dir = "desc"; switching = true; }} }}
              }}
            }}
            </script>
            """
            display(HTML(html))

            # --- COMBINED PREDICTIVE POWER ANALYSIS ---
            daily_stats = _daily_predictive_power(df, runoff, x_col, y_col)

            if not daily_stats.empty:
                plt.figure(figsize=(12, 3))
                plt.plot(daily_stats['PlotDate'], daily_stats['R'], color='#7C3AED', linewidth=2, label='Combined Power')

                # Mark the selected date on the same 2000/2001 axis (Feb 29 -> Feb 28).
                curr_plot_year = 2000 if month >= 10 else 2001
                if month == 2 and day == 29:
                    curr_plot_date = datetime(curr_plot_year, 2, 28)
                else:
                    curr_plot_date = datetime(curr_plot_year, month, day)

                try:
                    curr_val = daily_stats[daily_stats['PlotDate'] == curr_plot_date]['R'].values[0]
                    plt.axvline(curr_plot_date, color='#E11D48', linestyle='--', linewidth=2)
                    plt.text(curr_plot_date + timedelta(days=5), 0.1, f"Combined Strength\nR = {curr_val:.2f}", color='#E11D48', weight='bold')
                except IndexError:
                    pass  # no R was computed for the selected calendar day

                plt.title(f"Predictive Strength: How well {x_axis} + {y_axis} combined predict Runoff", loc='left', weight='bold', fontsize=10)
                plt.ylabel("Multiple R (0-1)")
                plt.ylim(0, 1.05)  # multiple R lies in [0, 1]
                plt.grid(True, linestyle=':', alpha=0.5)

                import matplotlib.dates as mdates
                plt.gca().xaxis.set_major_formatter(mdates.DateFormatter('%b'))
                sns.despine()
                plt.show()

            status_label.value = f"<b>Status:</b> <span style='color: green;'>Updated for {rfc_id.upper()}</span>"
        else:
            status_label.value = f"<b>Status:</b> <span style='color: orange;'>Updated (No CBRFC Data)</span>"

# --- EVENT HANDLERS ---
def on_state_change(change):
    """Reload the station dropdown after the user selects a different state.

    Prefers station "24127 1" when the new state lists it; otherwise selects
    the first station. The status label reports success or failure.
    """
    status_label.value = "<b>Status:</b> <span style='color: blue;'>Loading Stations...</span>"
    stations = get_stations_for_state(change.new)
    if not stations:
        status_label.value = "<b>Status:</b> <span style='color: red;'>Failed to Load Stations</span>"
        return
    station_dropdown.options = stations
    station_ids = list(stations.values())
    station_dropdown.value = "24127 1" if "24127 1" in station_ids else station_ids[0]
    status_label.value = "<b>Status:</b> <span style='color: green;'>Stations Loaded</span>"

def on_station_change(change):
    """Download daily data for the newly selected station and redraw once.

    Replaces the module-level ``weather_df``. On success the dashboard is
    redrawn a single time with the current control values. On failure the
    status label turns red and a message is printed to ``out``.
    """
    global weather_df
    if not change.new:
        return
    status_label.value = f"<b>Status:</b> <span style='color: blue;'>Fetching Data for {change.new}...</span>"
    with out:
        out.clear_output(wait=True)
        print(f"Downloading daily data for station {change.new}...")
    weather_df = get_xmacis_data(change.new)
    if weather_df.empty:
        status_label.value = "<b>Status:</b> <span style='color: red;'>Data Fetch Failed</span>"
        with out:
            print("Failed to download data.")
        return
    status_label.value = "<b>Status:</b> <span style='color: green;'>Data Loaded. Refreshing...</span>"
    # Redraw directly with the current control values. Nudging d_sld away and
    # back triggered two full redraws (two CBRFC downloads) and briefly
    # rendered an arbitrary day 30/31.
    update_dashboard(x_drop.value, y_drop.value, m_sld.value, d_sld.value,
                     rfc_stn_input.value, extra_years_input.value, window_sld.value)

# Reload the station list / station data whenever the selection changes.
state_dropdown.observe(on_state_change, names='value')
station_dropdown.observe(on_station_change, names='value')

# Empty placeholder that pads the window-slider row to the two-column layout.
spacer = widgets.Label(layout=widgets.Layout(width='380px'))
ui = widgets.VBox([widgets.HBox([state_dropdown, station_dropdown]), status_label, widgets.HBox([rfc_stn_input, extra_years_input]), widgets.HBox([x_drop, y_drop]), widgets.HBox([window_sld, spacer]), widgets.HBox([m_sld, d_sld])])
# Re-run update_dashboard whenever any analysis control changes. The Output
# widget returned here is discarded; update_dashboard renders into `out` itself.
widgets.interactive_output(update_dashboard, {'x_axis': x_drop, 'y_axis': y_drop, 'month': m_sld, 'day': d_sld, 'rfc_id': rfc_stn_input, 'extra_yrs': extra_years_input, 'window_size': window_sld})
display(ui, out)

VBox(children=(HBox(children=(Dropdown(description='State:', index=8, layout=Layout(width='380px'), options=('…

Output()