In [1]:
# CELL 1: Install
# Installs the notebook's third-party dependencies with the pip that belongs
# to the interpreter running this kernel (sys.executable), then prints a
# confirmation. check_call raises CalledProcessError if pip exits non-zero.
import subprocess
import sys

pkgs = [
    "pandas",
    "numpy",
    "xarray",
    "netCDF4",
    "scipy",
    "plotly",
    "ipywidgets",
    "ipyfilechooser",
    "kaleido",
]
install_cmd = [sys.executable, "-m", "pip", "install", *pkgs, "--quiet"]
subprocess.check_call(install_cmd)
print("Done!")

Done!


In [2]:
# CELL 2: Imports
import os, re, json, warnings
from pathlib import Path
from datetime import datetime
import numpy as np
import pandas as pd
import xarray as xr
from scipy import stats
from scipy.interpolate import RegularGridInterpolator
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import ipywidgets as widgets
from ipyfilechooser import FileChooser
from IPython.display import display, clear_output
# Suppress every warning category for the rest of the session, then confirm
# the imports succeeded by echoing the pandas version.
warnings.filterwarnings(action='ignore')
print("OK - pandas {}".format(pd.__version__))

OK - pandas 2.3.3


In [3]:
# CELL 3: Dashboard

def compute_horizontal_gradient(field, lats, lons):
    """
    Compute horizontal gradient magnitude in physical units (per km).

    Derivatives are taken with numpy.gradient against the actual latitude and
    longitude coordinates (central differences in the interior, one-sided at
    the edges), so unevenly spaced grids are handled correctly. Angular
    derivatives are converted to distances on a spherical Earth:
    dy = R * dlat and dx = R * cos(lat) * dlon.

    Parameters
    ----------
    field : 2D array (lat, lon)
    lats : 1D array of latitudes in degrees, one per row of `field`
    lons : 1D array of longitudes in degrees, one per column of `field`

    Returns
    -------
    grad_magnitude : 2D array, gradient magnitude in [field_units / km]

    Notes
    -----
    NaN cells in `field` propagate to their neighbours. At the poles
    cos(lat) is ~0, so the zonal term is not meaningful there.
    """
    # Mean Earth radius in km
    R_EARTH = 6371.0

    field = np.asarray(field, dtype=float)
    lat_rad = np.radians(np.asarray(lats, dtype=float))
    lon_rad = np.radians(np.asarray(lons, dtype=float))

    # Passing the coordinate arrays makes numpy.gradient use the true local
    # spacing at every grid point instead of assuming a unit (or averaged)
    # step. Result units: [field_units / radian].
    dfdlat = np.gradient(field, lat_rad, axis=0)
    dfdlon = np.gradient(field, lon_rad, axis=1)

    # Convert to physical units [field_units / km]. The zonal arc length
    # shrinks with cos(lat), so it is applied per row.
    dfdy_km = dfdlat / R_EARTH
    dfdx_km = dfdlon / (R_EARTH * np.cos(lat_rad)[:, np.newaxis])

    # Gradient magnitude (sign of the coordinate direction drops out)
    return np.sqrt(dfdx_km**2 + dfdy_km**2)


class NetCDFScanner:
    """
    Index the NetCDF files found in the immediate sub-folders of a directory.

    After `scan()`, `self.folders` maps each sub-folder name to a dict with
    keys 'path', 'files', 'variables', 'dates', 'has_depth', 'depth_levels',
    'lon_convention' and 'has_time'. The metadata is read from the first
    (alphabetically) *.nc file of the folder. Folders whose probe file could
    not be read are left out of `self.folders` and listed in `self.skipped`
    (folder name -> error message).
    """

    def __init__(self, data_dir='./data'):
        self.data_dir = Path(data_dir)
        self.folders = {}
        self.skipped = {}

    def scan(self):
        """
        Rebuild and return the folder index.

        Returns an empty dict when the data directory does not exist.
        """
        self.folders = {}
        self.skipped = {}
        if not self.data_dir.exists():
            return {}

        for subdir in sorted(self.data_dir.iterdir()):
            if not subdir.is_dir():
                continue
            # Sorted so the file used to probe metadata is deterministic.
            nc_files = sorted(subdir.glob('*.nc'))
            if not nc_files:
                continue

            variables, has_depth, depth_levels = [], False, []
            lon_convention = '180'
            has_time = False

            try:
                with xr.open_dataset(nc_files[0]) as ds:
                    # Iterating ds.dims yields dimension names whether it is
                    # exposed as a mapping or as a set.
                    dims = set(ds.dims)
                    variables = [v for v in ds.data_vars if v not in dims]

                    for dim in ds.dims:
                        if any(x in str(dim).lower() for x in ['depth', 'lev']):
                            has_depth = True
                            depth_levels = ds[dim].values.tolist()
                            break

                    has_time = 'time' in ds.dims

                    # A maximum above 180 means the grid uses 0..360 longitudes.
                    for coord in ['lon', 'longitude']:
                        if coord in ds.coords:
                            if ds[coord].values.max() > 180:
                                lon_convention = '360'
                            break
            except Exception as exc:
                # Best effort: an unreadable folder is skipped, but the reason
                # is kept so callers can report it.
                self.skipped[subdir.name] = str(exc)
                continue

            # Map the date encoded in each file name (first run of 8 digits,
            # read as YYYYMMDD) to the file.
            dates = {}
            for f in nc_files:
                m = re.search(r'(\d{8})', f.name)
                if not m:
                    continue
                try:
                    dates[datetime.strptime(m.group(1), '%Y%m%d').date()] = f
                except ValueError:
                    # Eight digits that do not form a valid calendar date.
                    continue

            self.folders[subdir.name] = {
                'path': subdir, 'files': nc_files, 'variables': variables,
                'dates': dates, 'has_depth': has_depth, 'depth_levels': depth_levels,
                'lon_convention': lon_convention, 'has_time': has_time
            }
        return self.folders

    def get_file(self, folder, target_date, max_gap_days=7):
        """
        Return the file of `folder` that best matches `target_date`.

        Parameters
        ----------
        folder : name of a scanned folder
        target_date : datetime.date to look up
        max_gap_days : largest accepted distance, in days, between
            `target_date` and the nearest dated file (default 7)

        Returns
        -------
        The matching path, or None when the folder is unknown or no dated
        file lies within `max_gap_days`. A folder without any dated file
        returns its first file regardless of `target_date`.
        """
        if folder not in self.folders:
            return None
        dates = self.folders[folder]['dates']
        if not dates:
            files = self.folders[folder]['files']
            return files[0] if files else None
        if target_date in dates:
            return dates[target_date]
        nearest = min(dates.keys(), key=lambda d: abs((d - target_date).days))
        if abs((nearest - target_date).days) <= max_gap_days:
            return dates[nearest]
        return None


def interpolate_nc(ds, var, lat, lon, depth_mode='surface', lon_convention='180', compute_grad=False):
    """
    Interpolate NetCDF variable to lat/lon point.
    Optionally compute horizontal gradient magnitude first.

    Parameters
    ----------
    ds : opened dataset; `ds[var]` must expose `dims`, `values`, `isel`,
        `mean` and `transpose` (as an xarray DataArray does)
    var : name of the variable to sample
    lat, lon : query point in degrees
    depth_mode : 'surface' (first depth index), 'bottom' (last depth index)
        or 'column_mean' (NaN-skipping mean over depth); anything else is
        treated as 'surface'. Ignored when the variable has no depth axis.
    lon_convention : '360' when the grid uses 0..360 longitudes, otherwise
        the grid is taken to use -180..180
    compute_grad : when True, sample |grad(var)| in [units / km] instead of
        the variable itself

    Returns
    -------
    Bilinearly interpolated value, or NaN when the point lies outside the
    grid, the variable has no lat/lon dimensions, or any error occurs
    (best effort: errors are deliberately not raised).
    """
    try:
        data = ds[var]
        lat_dim = lon_dim = depth_dim = None

        # Dimensions are recognised by name fragments.
        for dim in data.dims:
            dl = str(dim).lower()
            if 'lat' in dl: lat_dim = dim
            elif 'lon' in dl: lon_dim = dim
            elif any(x in dl for x in ['depth', 'lev']): depth_dim = dim

        if not lat_dim or not lon_dim:
            return np.nan

        # Handle time: only the first time step is used.
        if 'time' in data.dims:
            data = data.isel(time=0)

        # Handle depth
        if depth_dim:
            if depth_mode == 'bottom':
                # NOTE(review): this is the last depth index, which may be a
                # fill value where the water column is shallower - confirm.
                data = data.isel({depth_dim: -1})
            elif depth_mode == 'column_mean':
                data = data.mean(dim=depth_dim, skipna=True)
            else:
                data = data.isel({depth_dim: 0})

        # Everything below indexes the array as (lat, lon); files that store
        # the variable as (lon, lat) must be reordered first.
        if tuple(data.dims) != (lat_dim, lon_dim):
            data = data.transpose(lat_dim, lon_dim)

        lats = np.asarray(ds[lat_dim].values, dtype=float)
        lons = np.asarray(ds[lon_dim].values, dtype=float)
        vals = data.values

        # Handle fill values: non-finite or huge magnitudes become NaN.
        vals = np.where(np.isfinite(vals) & (np.abs(vals) < 1e30), vals, np.nan)

        # Compute gradient if requested; NaN cells propagate to neighbours.
        if compute_grad:
            vals = compute_horizontal_gradient(vals, lats, lons)

        # Express the query longitude in the grid's convention.
        query_lon = lon
        if lon_convention == '360' and lon < 0:
            query_lon = lon + 360
        elif lon_convention != '360' and lon > 180:
            query_lon = lon - 360

        # Check bounds
        if lat < lats.min() or lat > lats.max():
            return np.nan
        if query_lon < lons.min() or query_lon > lons.max():
            return np.nan

        # RegularGridInterpolator needs ascending axes.
        if lats[0] > lats[-1]:
            lats = lats[::-1]
            vals = vals[::-1, :]
        if lons[0] > lons[-1]:
            lons = lons[::-1]
            vals = vals[:, ::-1]

        interp = RegularGridInterpolator((lats, lons), vals, method='linear',
                                         bounds_error=False, fill_value=np.nan)
        return interp([[lat, query_lon]])[0]
    except Exception:
        # Best effort by design: a failed lookup yields a missing value.
        return np.nan


class CorrelationDashboard:
    """
    Five-tab ipywidgets dashboard for correlating CSV columns with NetCDF fields.

    Workflow: (1) load a CSV and pick its latitude/longitude/date columns and
    the numeric columns to analyse, (2) optionally filter rows by species and
    date range, (3) scan the data directory for NetCDF folders and pick
    variables, (4) sample every selected NetCDF variable at each row's
    position and date, (5) plot a pairwise correlation matrix and save it.

    State
    -----
    csv_df : DataFrame loaded from the chosen CSV (None until loaded)
    merged_df : result of the last extraction (None until extracted)
    corr_details : list of per-pair statistics from the last plot
    current_fig : last plotted figure (None until plotted)
    csv_selected : list of (name, (column, use_log))
    nc_selected : list of (name, (folder, var, depth_mode, use_log, use_grad))
    """

    def __init__(self, data_dir='./data'):
        self.data_dir = Path(data_dir)
        self.scanner = NetCDFScanner(data_dir)
        self.csv_df = None
        self.merged_df = None
        self.corr_details = None
        self.current_fig = None

        self.csv_selected = []
        self.nc_selected = []

        self._build_ui()

    def _build_ui(self):
        """Create all widgets, wire their callbacks and assemble `self.ui`."""
        style = {'description_width': 'initial'}

        # === TAB 1: CSV ===
        self.csv_chooser = FileChooser(str(self.data_dir), filter_pattern='*.csv')
        self.csv_chooser.register_callback(self._load_csv)

        self.lat_dd = widgets.Dropdown(description='Latitude:', style=style, layout=widgets.Layout(width='300px'))
        self.lon_dd = widgets.Dropdown(description='Longitude:', style=style, layout=widgets.Layout(width='300px'))
        self.date_dd = widgets.Dropdown(description='Date:', style=style, layout=widgets.Layout(width='300px'))

        self.csv_var_dd = widgets.Dropdown(description='Variable:', options=[], style=style, layout=widgets.Layout(width='250px'))
        self.csv_log_cb = widgets.Checkbox(description='log', value=False, indent=False)
        self.csv_add_btn = widgets.Button(description='Add', button_style='success', layout=widgets.Layout(width='60px'))
        self.csv_add_btn.on_click(self._add_csv_var)

        self.csv_selected_box = widgets.VBox([], layout=widgets.Layout(border='1px solid #ccc', padding='5px', min_height='100px'))
        self.csv_clear_btn = widgets.Button(description='Clear All', button_style='danger', layout=widgets.Layout(width='100px'))
        self.csv_clear_btn.on_click(lambda b: self._clear_csv_vars())

        # === TAB 2: FILTERS ===
        self.species_select = widgets.SelectMultiple(description='Species:', options=[],
                                                      layout=widgets.Layout(height='200px', width='450px'), style=style)
        self.date_start = widgets.DatePicker(description='Start:', style=style)
        self.date_end = widgets.DatePicker(description='End:', style=style)
        self.filter_info = widgets.HTML('')
        self.check_filter_btn = widgets.Button(description='Check Filter', button_style='info')
        self.check_filter_btn.on_click(self._check_filter)

        # === TAB 3: NetCDF ===
        self.scan_btn = widgets.Button(description='Scan Folders', button_style='info')
        self.scan_btn.on_click(self._scan_nc)

        self.nc_folder_dd = widgets.Dropdown(description='Folder:', options=[], style=style, layout=widgets.Layout(width='180px'))
        self.nc_folder_dd.observe(self._on_folder_change, 'value')
        self.nc_var_dd = widgets.Dropdown(description='Variable:', options=[], style=style, layout=widgets.Layout(width='180px'))
        self.nc_depth_dd = widgets.Dropdown(description='Depth:', options=['surface', 'bottom', 'column_mean'],
                                            value='surface', style=style, layout=widgets.Layout(width='150px'))
        self.nc_log_cb = widgets.Checkbox(description='log', value=False, indent=False, layout=widgets.Layout(width='50px'))
        self.nc_grad_cb = widgets.Checkbox(description='∇ grad', value=False, indent=False, layout=widgets.Layout(width='70px'))
        self.nc_add_btn = widgets.Button(description='Add', button_style='success', layout=widgets.Layout(width='50px'))
        self.nc_add_btn.on_click(self._add_nc_var)

        self.nc_selected_box = widgets.VBox([], layout=widgets.Layout(border='1px solid #ccc', padding='5px', min_height='100px'))
        self.nc_clear_btn = widgets.Button(description='Clear All', button_style='danger', layout=widgets.Layout(width='100px'))
        self.nc_clear_btn.on_click(lambda b: self._clear_nc_vars())
        self.nc_info = widgets.HTML('')

        # === TAB 4: Extract ===
        self.process_btn = widgets.Button(description='EXTRACT DATA', button_style='success',
                                          layout=widgets.Layout(width='100%', height='45px'))
        self.process_btn.on_click(self._extract)
        self.progress = widgets.IntProgress(min=0, max=100, description='Progress:')
        self.status = widgets.HTML('')
        self.save_name = widgets.Text(value='correlation_data.csv', description='Output:', style=style)
        self.save_btn = widgets.Button(description='Save CSV', button_style='warning')
        self.save_btn.on_click(self._save)
        self.extract_summary = widgets.HTML('')

        # === TAB 5: Plot ===
        self.plot_btn = widgets.Button(description='PLOT CORRELATION MATRIX', button_style='primary',
                                       layout=widgets.Layout(width='100%', height='45px'))
        self.plot_btn.on_click(self._plot)

        self.save_plot_btn = widgets.Button(description='Save PNG + Metrics', button_style='warning')
        self.save_plot_btn.on_click(self._save_plot_and_metrics)
        self.plot_filename = widgets.Text(value='correlation_matrix', description='Filename:', style=style)

        self.plot_out = widgets.Output()

        self.log = widgets.Output(layout=widgets.Layout(border='1px solid #ddd', max_height='150px', overflow='auto'))

        # === Tabs ===
        tab1 = widgets.VBox([
            widgets.HTML('<h3>1. Load CSV & Select Variables</h3>'),
            self.csv_chooser,
            widgets.HTML('<b>Coordinate columns:</b>'),
            self.lat_dd, self.lon_dd, self.date_dd,
            widgets.HTML('<hr><b>Add CSV variables:</b>'),
            widgets.HBox([self.csv_var_dd, self.csv_log_cb, self.csv_add_btn]),
            widgets.HTML('<b>Selected:</b>'),
            self.csv_selected_box,
            self.csv_clear_btn
        ])

        tab2 = widgets.VBox([
            widgets.HTML('<h3>2. Filter Data (optional)</h3>'),
            widgets.HTML('<b>Species (Ctrl+click):</b>'),
            self.species_select,
            widgets.HTML('<b>Date range:</b>'),
            widgets.HBox([self.date_start, self.date_end]),
            self.check_filter_btn,
            self.filter_info
        ])

        tab3 = widgets.VBox([
            widgets.HTML('<h3>3. NetCDF Variables</h3>'),
            self.scan_btn,
            self.nc_info,
            widgets.HTML('<hr><b>Add variable:</b>'),
            widgets.HTML('<i>Options: log=log10(var), ∇grad=horizontal gradient magnitude [units/km]</i>'),
            widgets.HBox([self.nc_folder_dd, self.nc_var_dd, self.nc_depth_dd]),
            widgets.HBox([self.nc_log_cb, self.nc_grad_cb, self.nc_add_btn]),
            widgets.HTML('<b>Selected:</b>'),
            self.nc_selected_box,
            self.nc_clear_btn
        ])

        tab4 = widgets.VBox([
            widgets.HTML('<h3>4. Extract & Merge</h3>'),
            widgets.HTML('<b>Summary:</b>'),
            self.extract_summary,
            self.process_btn,
            self.progress,
            self.status,
            widgets.HTML('<hr>'),
            widgets.HBox([self.save_name, self.save_btn])
        ])

        tab5 = widgets.VBox([
            widgets.HTML('<h3>5. Correlation Matrix</h3>'),
            self.plot_btn,
            widgets.HBox([self.plot_filename, self.save_plot_btn]),
            self.plot_out
        ])

        self.tabs = widgets.Tab([tab1, tab2, tab3, tab4, tab5])
        for i, t in enumerate(['1.CSV', '2.Filter', '3.NetCDF', '4.Extract', '5.Plot']):
            self.tabs.set_title(i, t)

        # Refresh the selection summary whenever the user switches tabs.
        self.tabs.observe(self._update_summary, 'selected_index')

        self.ui = widgets.VBox([
            widgets.HTML('<h2>Correlation Dashboard v7</h2>'),
            self.tabs,
            widgets.HTML('<b>Log:</b>'),
            self.log
        ])

    def _log(self, msg):
        """Append a line to the log panel at the bottom of the dashboard."""
        with self.log:
            print(msg)

    @staticmethod
    def _is_constant(arr):
        """Return True when every element of a non-empty array is equal."""
        return arr.min() == arr.max()

    def _load_csv(self, chooser):
        """
        FileChooser callback: read the selected CSV and populate the dropdowns.

        Coordinate columns are pre-selected by name heuristics ('lat'+'dec',
        'lon'+'dec', 'gmt'+'date' without '1'); the species list is filled
        from the SCIENTIFIC_NAME column when it exists.
        """
        if not chooser.selected:
            return
        try:
            self.csv_df = pd.read_csv(chooser.selected)
            cols = list(self.csv_df.columns)

            self.lat_dd.options = cols
            self.lon_dd.options = cols
            self.date_dd.options = cols

            for c in cols:
                cl = c.lower()
                if 'lat' in cl and 'dec' in cl: self.lat_dd.value = c
                elif 'lon' in cl and 'dec' in cl: self.lon_dd.value = c
                elif 'gmt' in cl and 'date' in cl and '1' not in cl: self.date_dd.value = c

            numeric = sorted(self.csv_df.select_dtypes(include=[np.number]).columns.tolist())
            self.csv_var_dd.options = numeric

            # Always reset the species list so entries from a previously
            # loaded CSV cannot linger.
            if 'SCIENTIFIC_NAME' in cols:
                species = sorted(self.csv_df['SCIENTIFIC_NAME'].dropna().unique().tolist())
            else:
                species = []
            self.species_select.options = species

            self._log(f"Loaded: {len(self.csv_df)} rows, {len(numeric)} numeric vars")
        except Exception as e:
            self._log(f"Error: {e}")

    def _add_csv_var(self, b):
        """Add the CSV column chosen in the dropdown (optionally as log10)."""
        var = self.csv_var_dd.value
        use_log = self.csv_log_cb.value
        if not var:
            return
        name = f"{var}_log" if use_log else var
        if any(n == name for n, _ in self.csv_selected):
            self._log(f"{name} already added")
            return
        self.csv_selected.append((name, (var, use_log)))
        self._refresh_csv_selected()
        self._log(f"Added: {name}")

    def _remove_csv_var(self, name):
        """Remove one selected CSV variable by its display name."""
        self.csv_selected = [(n, v) for n, v in self.csv_selected if n != name]
        self._refresh_csv_selected()

    def _clear_csv_vars(self):
        """Remove every selected CSV variable."""
        self.csv_selected = []
        self._refresh_csv_selected()

    def _refresh_csv_selected(self):
        """Redraw the list of selected CSV variables with a remove button each."""
        items = []
        for name, _ in self.csv_selected:
            btn = widgets.Button(description='X', button_style='danger', layout=widgets.Layout(width='30px', height='25px'))
            # Bind the name as a default argument so each button removes its own row.
            btn.on_click(lambda b, n=name: self._remove_csv_var(n))
            items.append(widgets.HBox([btn, widgets.HTML(f"<code>{name}</code>")]))
        if not items:
            items = [widgets.HTML('<i style="color:gray">No variables</i>')]
        self.csv_selected_box.children = items

    def _check_filter(self, b):
        """Show how many rows survive the current species/date filter."""
        df = self._get_filtered_df()
        if df is None:
            self._log("Load CSV!")
            return
        self.filter_info.value = f"<b style='color:green'>Filtered: {len(df)} / {len(self.csv_df)} rows</b>"

    def _get_filtered_df(self):
        """
        Return a filtered copy of the loaded CSV, or None when none is loaded.

        Rows are kept when their SCIENTIFIC_NAME is among the selected species
        (if any are selected) and their calendar day lies within
        [date_start, date_end], both ends inclusive. Rows whose date cannot be
        parsed are dropped as soon as either date bound is set.
        """
        if self.csv_df is None:
            return None
        df = self.csv_df.copy()
        if self.species_select.value and 'SCIENTIFIC_NAME' in df.columns:
            df = df[df['SCIENTIFIC_NAME'].isin(self.species_select.value)]

        date_col = self.date_dd.value
        start, end = self.date_start.value, self.date_end.value
        if date_col and (start or end):
            # Compare whole days: the pickers return dates (midnight), so a
            # row stamped later on the end day must still count as inside.
            days = pd.to_datetime(df[date_col], errors='coerce').dt.normalize()
            keep = pd.Series(True, index=df.index)
            if start:
                keep &= days >= pd.to_datetime(start)
            if end:
                keep &= days <= pd.to_datetime(end)
            df = df[keep]
        return df

    def _scan_nc(self, b):
        """Scan the data directory and list the NetCDF folders that were found."""
        self._log("Scanning...")
        folders = self.scanner.scan()
        if not folders:
            self._log("No folders!")
            return
        self.nc_folder_dd.options = list(folders.keys())

        info_lines = []
        for fname, info in folders.items():
            flags = []
            if info['has_depth']:
                flags.append(f"{len(info['depth_levels'])} depths")
            if info['lon_convention'] == '360':
                flags.append("lon 0-360")
            if not info['has_time']:
                flags.append("static")
            flag_str = f" <i>({', '.join(flags)})</i>" if flags else ""
            info_lines.append(f"<li><b>{fname}</b>: {len(info['dates'])} dates, vars: {', '.join(info['variables'])}{flag_str}</li>")
        self.nc_info.value = f"<ul>{''.join(info_lines)}</ul>"
        self._log(f"Found {len(folders)} folders")

    def _on_folder_change(self, change):
        """Update the variable list and depth selector for the chosen folder."""
        folder = change['new']
        if folder and folder in self.scanner.folders:
            info = self.scanner.folders[folder]
            self.nc_var_dd.options = info['variables']
            self.nc_depth_dd.layout.display = 'block' if info['has_depth'] else 'none'

    def _add_nc_var(self, b):
        """Add the chosen NetCDF variable with its depth/log/gradient options."""
        folder = self.nc_folder_dd.value
        var = self.nc_var_dd.value
        if not folder or not var:
            return
        info = self.scanner.folders.get(folder, {})
        depth_mode = self.nc_depth_dd.value if info.get('has_depth') else 'surface'
        use_log = self.nc_log_cb.value
        use_grad = self.nc_grad_cb.value

        # Build name: folder_var[_depth][_grad][_log]
        name = f"{folder}_{var}"
        if info.get('has_depth'):
            name += f"_{depth_mode}"
        if use_grad:
            name += "_grad"
        if use_log:
            name += "_log"

        if any(n == name for n, _ in self.nc_selected):
            self._log(f"{name} already added")
            return

        # Store: (folder, var, depth_mode, use_log, use_grad)
        self.nc_selected.append((name, (folder, var, depth_mode, use_log, use_grad)))
        self._refresh_nc_selected()
        self._log(f"Added: {name}")

    def _remove_nc_var(self, name):
        """Remove one selected NetCDF variable by its display name."""
        self.nc_selected = [(n, v) for n, v in self.nc_selected if n != name]
        self._refresh_nc_selected()

    def _clear_nc_vars(self):
        """Remove every selected NetCDF variable."""
        self.nc_selected = []
        self._refresh_nc_selected()

    def _refresh_nc_selected(self):
        """Redraw the list of selected NetCDF variables with a remove button each."""
        items = []
        for name, _ in self.nc_selected:
            btn = widgets.Button(description='X', button_style='danger', layout=widgets.Layout(width='30px', height='25px'))
            # Bind the name as a default argument so each button removes its own row.
            btn.on_click(lambda b, n=name: self._remove_nc_var(n))
            items.append(widgets.HBox([btn, widgets.HTML(f"<code>{name}</code>")]))
        if not items:
            items = [widgets.HTML('<i style="color:gray">No variables</i>')]
        self.nc_selected_box.children = items

    def _update_summary(self, change):
        """Refresh the selection summary when the Extract tab (index 3) opens."""
        if change['new'] == 3:
            csv_vars = [n for n, _ in self.csv_selected]
            nc_vars = [n for n, _ in self.nc_selected]
            html = f"<ul><li>CSV ({len(csv_vars)}): {', '.join(csv_vars) or 'none'}</li>"
            html += f"<li>NetCDF ({len(nc_vars)}): {', '.join(nc_vars) or 'none'}</li></ul>"
            self.extract_summary.value = html

    def _extract(self, b):
        """
        Build `self.merged_df`: coordinates plus every selected variable.

        Rows without a parsable date or numeric latitude/longitude are
        dropped. CSV variables are copied (optionally log10 of positive
        values). Each NetCDF variable is sampled per row from the file that
        the scanner returns for the row's date; rows with no matching file
        stay NaN.
        """
        if self.csv_df is None:
            self._log("Load CSV!")
            return

        lat_col, lon_col, date_col = self.lat_dd.value, self.lon_dd.value, self.date_dd.value
        if not all([lat_col, lon_col, date_col]):
            self._log("Select columns!")
            return
        if len({lat_col, lon_col, date_col}) < 3:
            # The dropdowns default to the first column when no name matched.
            self._log("Latitude, longitude and date must be different columns!")
            return
        if not self.csv_selected and not self.nc_selected:
            self._log("Select variables!")
            return

        self._log(f"Extracting {len(self.csv_selected)} CSV + {len(self.nc_selected)} NC...")

        source_df = self._get_filtered_df()
        df = source_df[[lat_col, lon_col, date_col]].copy()
        # Coordinates stored as text would make every interpolation fail.
        df[lat_col] = pd.to_numeric(df[lat_col], errors='coerce')
        df[lon_col] = pd.to_numeric(df[lon_col], errors='coerce')
        df['_date'] = pd.to_datetime(df[date_col], errors='coerce').dt.date
        valid = df['_date'].notna() & df[lat_col].notna() & df[lon_col].notna()
        df = df[valid].reset_index(drop=True)
        source_df = source_df[valid].reset_index(drop=True)

        # CSV vars
        for name, (var, use_log) in self.csv_selected:
            if var not in source_df.columns:
                # Selection left over from a previously loaded CSV.
                self._log(f"  {name}: column '{var}' not in CSV, skipped")
                continue
            vals = pd.to_numeric(source_df[var], errors='coerce').values
            if use_log:
                vals = np.log10(np.where(vals > 0, vals, np.nan))
            df[name] = vals
            self._log(f"  {name}: {np.sum(np.isfinite(vals))}/{len(vals)}")

        # NC vars
        n_rows = len(df)
        self.progress.max = len(self.nc_selected) if self.nc_selected else 1
        self.progress.value = 0

        for vi, (name, params) in enumerate(self.nc_selected):
            folder, var, depth_mode, use_log, use_grad = params

            self.status.value = f"Extracting {name}..."
            values = np.full(n_rows, np.nan)
            folder_info = self.scanner.folders.get(folder, {})
            lon_conv = folder_info.get('lon_convention', '180')

            # Group rows by day so each NetCDF file is opened once per variable.
            for dt, group in df.groupby('_date'):
                nc_file = self.scanner.get_file(folder, dt)
                if nc_file is None:
                    continue
                try:
                    with xr.open_dataset(nc_file) as ds:
                        for i in group.index:
                            lat = df.loc[i, lat_col]
                            lon = df.loc[i, lon_col]
                            val = interpolate_nc(ds, var, lat, lon, depth_mode, lon_conv, compute_grad=use_grad)
                            if use_log:
                                val = np.log10(val) if val > 0 else np.nan
                            values[i] = val
                except Exception as e:
                    self._log(f"    Error {nc_file.name}: {e}")

            df[name] = values
            self.progress.value = vi + 1
            valid_n = np.sum(np.isfinite(values))
            self._log(f"  {name}: {valid_n}/{n_rows}")

        df = df.drop(columns=['_date'])
        self.merged_df = df
        self.status.value = f"<b style='color:green'>Done! {n_rows} rows</b>"

    def _save(self, b):
        """Write the merged table to the file named in the Output box."""
        if self.merged_df is None:
            self._log("No data!")
            return
        try:
            self.merged_df.to_csv(self.save_name.value, index=False)
            self._log(f"Saved: {self.save_name.value}")
        except Exception as e:
            self._log(f"Save error: {e}")

    def _plot(self, b):
        """
        Draw the pairwise matrix for every numeric column of `merged_df`.

        Diagonal: histogram. Lower triangle: scatter (at most 5000 points,
        drawn with a fixed seed) with least-squares line and Pearson r.
        Upper triangle: Pearson r as text. Pairs with fewer than three
        complete rows, or where either variable is constant, get a
        placeholder instead of statistics. Per-pair statistics of the lower
        triangle are stored in `self.corr_details`.
        """
        if self.merged_df is None:
            self._log("Extract first!")
            return

        with self.plot_out:
            clear_output()

            lat_col, lon_col, date_col = self.lat_dd.value, self.lon_dd.value, self.date_dd.value
            exclude = {lat_col, lon_col, date_col}
            num_cols = [c for c in self.merged_df.columns
                        if c not in exclude
                        and pd.api.types.is_numeric_dtype(self.merged_df[c])
                        and not pd.api.types.is_bool_dtype(self.merged_df[c])]

            if len(num_cols) < 2:
                print("Need >= 2 variables!")
                return

            n = len(num_cols)
            df = self.merged_df
            self._log(f"Plotting {n}x{n}...")

            self.corr_details = []
            # Fixed seed: the plotted subsample is the same on every run.
            rng = np.random.default_rng(0)

            fig = make_subplots(rows=n, cols=n, horizontal_spacing=0.02, vertical_spacing=0.02)

            for i, vy in enumerate(num_cols):
                for j, vx in enumerate(num_cols):
                    row, col = i + 1, j + 1
                    # Plotly numbers subplot axes row-major; the first has no suffix.
                    idx = (row - 1) * n + col
                    xref = 'x domain' if idx == 1 else f'x{idx} domain'
                    yref = 'y domain' if idx == 1 else f'y{idx} domain'

                    if i == j:
                        vals = df[vx].dropna()
                        fig.add_trace(go.Histogram(x=vals, nbinsx=20, marker_color='steelblue',
                                                   showlegend=False), row=row, col=col)
                        fig.add_annotation(x=0.95, y=0.95, xref=xref, yref=yref,
                                          text=f'n={len(vals)}', showarrow=False,
                                          font=dict(size=8), bgcolor='white')

                    elif i > j:
                        pair = df[[vx, vy]].dropna()
                        n_valid = len(pair)
                        x, y = pair[vx].values, pair[vy].values

                        if n_valid > 2 and not (self._is_constant(x) or self._is_constant(y)):
                            r, p = stats.pearsonr(x, y)
                            slope, intercept, _, _, stderr = stats.linregress(x, y)

                            self.corr_details.append({
                                'var_x': vx, 'var_y': vy, 'n': n_valid,
                                'r': r, 'r2': r**2, 'p_value': p,
                                'slope': slope, 'intercept': intercept, 'stderr': stderr
                            })

                            if n_valid > 5000:
                                idx_s = rng.choice(n_valid, 5000, replace=False)
                                x_plot, y_plot = x[idx_s], y[idx_s]
                            else:
                                x_plot, y_plot = x, y

                            fig.add_trace(go.Scatter(x=x_plot, y=y_plot, mode='markers',
                                                     marker=dict(size=2, color='steelblue', opacity=0.3),
                                                     showlegend=False), row=row, col=col)
                            xl = np.array([x.min(), x.max()])
                            fig.add_trace(go.Scatter(x=xl, y=slope*xl+intercept, mode='lines',
                                                    line=dict(color='red', width=1.5),
                                                    showlegend=False), row=row, col=col)

                            rcolor = 'red' if abs(r) > 0.5 else ('darkorange' if abs(r) > 0.3 else 'black')
                            fig.add_annotation(x=0.95, y=0.95, xref=xref, yref=yref,
                                             text=f'r={r:.2f}<br>n={n_valid}', showarrow=False,
                                             font=dict(size=7, color=rcolor), bgcolor='white', align='right')
                        elif n_valid > 2:
                            # Correlation and regression are undefined for a
                            # constant variable; say so instead of failing.
                            fig.add_annotation(x=0.5, y=0.5, xref=xref, yref=yref,
                                             text=f'constant<br>n={n_valid}', showarrow=False,
                                             font=dict(size=9, color='gray'))
                        else:
                            fig.add_annotation(x=0.5, y=0.5, xref=xref, yref=yref,
                                             text=f'n={n_valid}', showarrow=False,
                                             font=dict(size=9, color='gray'))

                    else:
                        pair = df[[vx, vy]].dropna()
                        n_valid = len(pair)

                        r = np.nan
                        if n_valid > 2:
                            xu, yu = pair[vx].values, pair[vy].values
                            if not (self._is_constant(xu) or self._is_constant(yu)):
                                r, _ = stats.pearsonr(xu, yu)

                        if np.isfinite(r):
                            color = 'blue' if r >= 0 else 'red'
                            size = 10 + int(abs(r) * 6)
                            fig.add_trace(go.Scatter(x=[0.5], y=[0.6], mode='text',
                                                    text=[f'{r:.2f}'],
                                                    textfont=dict(size=size, color=color),
                                                    showlegend=False), row=row, col=col)
                            fig.add_annotation(x=0.5, y=0.25, xref=xref, yref=yref,
                                             text=f'n={n_valid}', showarrow=False,
                                             font=dict(size=7, color='gray'))
                        else:
                            fig.add_trace(go.Scatter(x=[0.5], y=[0.5], mode='text',
                                                    text=['--'], textfont=dict(size=10, color='gray'),
                                                    showlegend=False), row=row, col=col)

                        fig.update_xaxes(showticklabels=False, showgrid=False, row=row, col=col)
                        fig.update_yaxes(showticklabels=False, showgrid=False, row=row, col=col)

            # Variable names label only the bottom row and the left column.
            for i, v in enumerate(num_cols):
                short = v[:15] + '..' if len(v) > 15 else v
                fig.update_xaxes(title_text=short, row=n, col=i+1, title_font_size=7, tickfont_size=6)
                fig.update_yaxes(title_text=short, row=i+1, col=1, title_font_size=7, tickfont_size=6)

            fig.update_layout(height=120*n, width=120*n,
                             title=dict(text=f'Correlation Matrix ({len(df)} rows)', font_size=14),
                             showlegend=False, margin=dict(l=60, r=20, t=40, b=60))

            self.current_fig = fig
            fig.show()

    def _save_plot_and_metrics(self, b):
        """
        Save the last figure as PNG plus the pair statistics and merged data.

        Writes `<base>.png`, `<base>_metrics.csv` and `<base>_data.csv`, where
        `<base>` is the Filename box. Each file is attempted independently and
        failures are reported in the log.
        """
        if getattr(self, 'current_fig', None) is None:
            self._log("Plot first!")
            return

        base = self.plot_filename.value

        try:
            self.current_fig.write_image(f"{base}.png", scale=2)
            self._log(f"Saved: {base}.png")
        except Exception as e:
            self._log(f"PNG error: {e}")

        if self.corr_details:
            try:
                metrics_df = pd.DataFrame(self.corr_details)
                metrics_df.to_csv(f"{base}_metrics.csv", index=False)
                self._log(f"Saved: {base}_metrics.csv")
            except Exception as e:
                self._log(f"Metrics error: {e}")

        if self.merged_df is not None:
            try:
                self.merged_df.to_csv(f"{base}_data.csv", index=False)
                self._log(f"Saved: {base}_data.csv")
            except Exception as e:
                self._log(f"Data error: {e}")

    def show(self):
        """Display the dashboard in the notebook output."""
        display(self.ui)


# === RUN ===
# Print the version banner, pick the first data directory that exists
# (falling back to './data'), then build and display the dashboard.
banner_rule = "=" * 50
print(banner_rule)
print("Correlation Dashboard v7")
print("- NEW: Horizontal gradient option (∇)")
print("  Computes |∇f| in [units/km] using central diff.")
print(banner_rule)

data_dir = './data'
for p in ('./data', '../data', '../../data'):
    if not os.path.exists(p):
        continue
    data_dir = p
    break

app = CorrelationDashboard(data_dir)
app.show()

Correlation Dashboard v7
- NEW: Horizontal gradient option (∇)
  Computes |∇f| in [units/km] using central diff.


VBox(children=(HTML(value='<h2>Correlation Dashboard v7</h2>'), Tab(children=(VBox(children=(HTML(value='<h3>1…