In [2]:
# --- imports ---
import os

import numpy as np
import xarray as xr
import panel as pn
import hvplot.xarray  # activates .hvplot on xarray objects
import holoviews as hv

# Select the Bokeh backend for Panel layouts/widgets and for HoloViews objects
# (the hvplot plots built in `view` are used as HoloViews stream sources).
pn.extension('bokeh')
hv.extension('bokeh')



In [8]:
# Path to the AIFS inference output. Set the AIFS_TEST_FILE environment variable
# to point the viewer at another run without editing the notebook.
test_file = os.environ.get(
    'AIFS_TEST_FILE',
    '/p/project1/training2533/zampieri2/inference/test_inference/aifs_forced_2012-01-02_1y_1deg.nc',
)

# Optional: keep only the first N time steps for faster iteration (None = all steps).
MAX_TIMESTEPS = None

ds = xr.open_dataset(test_file)
if MAX_TIMESTEPS is not None:
    ds = ds.isel(time=slice(0, MAX_TIMESTEPS))


In [None]:
# --- helpers ---
def _vars_3d(ds):
    v = [v for v in ds.data_vars if {'time','lat','lon'}.issubset(ds[v].dims)]
    return v or list(ds.data_vars)

def _has_datashader():
    try:
        import datashader  # noqa: F401
        return True
    except Exception:
        return False

def compute_firstN_clims(ds, var_list=None, n=10):
    """Return {var: (vmin, vmax)} from the first n timesteps (ignores NaNs).

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset to scan.
    var_list : list of str, optional
        Variables to scan; defaults to ``_vars_3d(ds)``.
    n : int
        Number of leading time steps to use. Values larger than the time axis
        use the whole axis; negative values are treated as 0. Variables without
        a ``time`` dimension are scanned in full.

    Returns
    -------
    dict
        Maps each variable name to a ``(vmin, vmax)`` tuple of Python floats.
        Empty selections fall back to ``(0.0, 1.0)``, and a non-finite bound
        (e.g. all-NaN data) falls back to 0.0 / 1.0 individually.
    """
    if var_list is None:
        var_list = _vars_3d(ds)
    n = max(int(n), 0)
    out = {}
    for v in var_list:
        da = ds[v]
        # `_vars_3d` may fall back to variables without a time dimension, and
        # `isel(time=...)` raises for those, so only slice when `time` exists.
        # Slicing past the end of the axis is clamped, so no explicit min() is needed.
        if 'time' in da.dims:
            da = da.isel(time=slice(0, n))
        vmin, vmax = 0.0, 1.0
        if da.size:
            vmin = float(da.min(skipna=True).values)
            vmax = float(da.max(skipna=True).values)
        if not np.isfinite(vmin):
            vmin = 0.0
        if not np.isfinite(vmax):
            vmax = 1.0
        out[v] = (vmin, vmax)
    return out

# --- per-variable default colour limits, taken from the first 10 time steps ---
default_clims = compute_firstN_clims(ds)

# --- widgets ---
# Variable picker, pre-selecting the first candidate variable.
variable_names = _vars_3d(ds)
var_w = pn.widgets.Select(
    name="Variable",
    options=variable_names,
    value=variable_names[0],
)
show_cb = pn.widgets.Checkbox(name="Show colorbar", value=True)

# Scaling controls: per-frame auto scaling (on by default), robust percentiles,
# and symmetric-about-zero limits.
free_w = pn.widgets.Checkbox(name="Auto color scaling per frame", value=True)
robust_w = pn.widgets.Checkbox(name="Robust auto limits (2–98%)", value=True)
sym_w = pn.widgets.Checkbox(name="Symmetric around zero", value=False)

# Manual limits: only consulted when per-frame auto scaling is off.
# vmin/vmax start empty; the custom checkbox decides whether they override defaults.
use_clim = pn.widgets.Checkbox(name="Use custom color limits", value=False)
vmin_w = pn.widgets.FloatInput(name="vmin", value=None, step=0.1)
vmax_w = pn.widgets.FloatInput(name="vmax", value=None, step=0.1)

cmap_w = pn.widgets.Select(
    name="Colormap",
    options=[
        "viridis", "plasma", "inferno", "magma", "cividis", "turbo",
        "RdBu_r", "coolwarm", "Spectral_r",
    ],
    value="viridis",
)

# Map size in pixels.
fw_w = pn.widgets.IntInput(name="Frame width", value=900, step=10)
fh_w = pn.widgets.IntInput(name="Frame height", value=600, step=10)

def _sync_defaults(*_):
    """Copy the selected variable's default limits into the vmin/vmax inputs.

    Does nothing while "Use custom color limits" is ticked, so user-typed limits
    are preserved. Also does nothing when the variable has no entry in
    ``default_clims``. Positional arguments (the param event) are ignored.
    """
    if use_clim.value:
        return
    clims = default_clims.get(var_w.value)
    if clims is None:
        return
    lo, hi = clims
    vmin_w.value, vmax_w.value = float(lo), float(hi)


# Re-sync when the variable changes or the custom-limits toggle flips, and once now.
var_w.param.watch(_sync_defaults, 'value')
use_clim.param.watch(_sync_defaults, 'value')
_sync_defaults()

# --- click-to-timeseries state ---
_selected_pts = []         # list of dicts: {'lon': float, 'lat': float, 'label': str}
_last_tap = {'x': None, 'y': None}
_MAX_PTS = 10

# --- view: map (with built-in time slider) + time series below ---
def _color_scaling_kwargs(var, free_scale, robust, sym_zero, use_custom, vmin, vmax):
    """Return the hvplot colour-scaling kwargs used by `view`.

    Auto mode (``free_scale``) yields ``framewise=True`` plus the ``robust`` flag.
    Fixed mode yields ``framewise=False`` and a ``clim``. The clim comes from the
    custom vmin/vmax (when enabled and both are set) or else from
    ``default_clims``. With ``sym_zero`` the limits become ``(-M, M)`` with
    ``M = max(|lo|, |hi|)``.
    """
    if free_scale:
        return {"framewise": True, "robust": bool(robust)}

    if use_custom and (vmin is not None) and (vmax is not None):
        lo, hi = float(vmin), float(vmax)
    else:
        lo, hi = default_clims.get(var, (0.0, 1.0))
    # Tolerate limits typed in the wrong order instead of handing the plot an
    # inverted (vmin > vmax) colour range.
    if lo > hi:
        lo, hi = hi, lo
    if sym_zero:
        m = max(abs(lo), abs(hi))
        return {"framewise": False, "clim": (-m, m)}
    return {"framewise": False, "clim": (lo, hi)}


def _register_tap(x, y):
    """Record a map tap at (x, y) as the nearest grid point in `_selected_pts`.

    Ignores missing coordinates and an exact repeat of the previous tap. Keeps
    at most `_MAX_PTS` points, dropping the oldest first.
    """
    if x is None or y is None:
        return
    if _last_tap['x'] == x and _last_tap['y'] == y:
        return
    _last_tap['x'], _last_tap['y'] = x, y
    # Snap to the nearest grid point so the stored coordinates are exact grid values.
    lon0 = float(ds['lon'].sel(lon=x, method='nearest').values)
    lat0 = float(ds['lat'].sel(lat=y, method='nearest').values)
    _selected_pts.append({'lon': lon0, 'lat': lat0, 'label': f"({lon0:.2f}, {lat0:.2f})"})
    # `while` rather than a single pop, so lowering _MAX_PTS mid-session trims fully.
    while len(_selected_pts) > max(_MAX_PTS, 0):
        _selected_pts.pop(0)


def _timeseries_overlay(var, frame_width):
    """Overlay one hvplot line per selected point for `var`; None if nothing is plottable."""
    units = ds[var].attrs.get('units')
    # Only append "(units)" when units are known, mirroring the map title logic.
    ylabel = f"{var} ({units})" if units else f"{var}"
    overlay = None
    for pt in _selected_pts:
        series = ds[var].sel(lon=pt['lon'], lat=pt['lat'], method='nearest')
        # Only plot selections that reduce to a 1-D series along time.
        if 'time' in series.dims and series.ndim == 1:
            line = series.hvplot.line(
                label=pt['label'],
                responsive=False,
                frame_width=int(frame_width),
                frame_height=250,
                ylabel=ylabel,
                legend='right',
            )
            overlay = line if overlay is None else overlay * line
    return overlay


@pn.depends(var_w, show_cb, free_w, robust_w, sym_w, use_clim, vmin_w, vmax_w, cmap_w, fw_w, fh_w)
def view(var, show_cb, free_scale, robust, sym_zero, use_custom, vmin, vmax, cmap, fw, fh):
    """Build the map (with time player) and the click-to-timeseries panel below it.

    Re-run by Panel whenever one of the control widgets changes. Each run creates
    a fresh Tap stream, while the selected points live in module-level state.
    """
    da = ds[var]

    title = f"{var}"
    if da.attrs.get("long_name"):
        title += f" — {da.attrs['long_name']}"
    if da.attrs.get("units"):
        title += f" [{da.attrs['units']}]"

    kwargs = dict(
        x="lon", y="lat",
        groupby="time",
        rasterize=_has_datashader(),
        colorbar=bool(show_cb),
        cmap=cmap,
        frame_width=int(fw),
        frame_height=int(fh),
        title=title,
        tools=['tap', 'hover'],   # enable clicking
    )
    kwargs.update(_color_scaling_kwargs(var, free_scale, robust, sym_zero, use_custom, vmin, vmax))

    # Build hvplot (grouped over time) and attach a Tap stream to capture clicks (lon/lat).
    hm = da.hvplot.quadmesh(**kwargs)
    tap = hv.streams.Tap(source=hm, x=None, y=None)

    # Time-series panel: accumulates up to _MAX_PTS points across taps.
    @pn.depends(tap.param.x, tap.param.y, var_w)
    def ts_view(x, y, current_var):
        _register_tap(x, y)
        overlay = _timeseries_overlay(current_var, fw)
        if overlay is None:
            return pn.pane.Markdown(
                f"⬆️ Click on the map to add up to **{_MAX_PTS}** points and see their time series here.",
                height=260,
            )
        return overlay

    # Map on top (with a DiscretePlayer under it), time series below.
    map_panel = pn.panel(hm, widgets={'time': pn.widgets.DiscretePlayer}, widget_location='bottom')
    return pn.Column(map_panel, ts_view)

# --- layout ---
# Sidebar: each Markdown header is followed by the widgets of its section, in order.
controls = pn.WidgetBox(
    *(
        item
        for header, members in (
            ("### Variable", (var_w,)),
            ("### Colorbar & colormap", (show_cb, cmap_w)),
            ("### Scaling", (free_w, robust_w, sym_w)),
            ("### Custom fixed limits", (use_clim, vmin_w, vmax_w)),
            ("### Sizing", (fw_w, fh_w)),
        )
        for item in (header, *members)
    ),
    width=300,
)



# How to use & customize the map + time-series viewer

The code cell above defines an interactive map (with a time slider and ▶ play button) and a time-series plot underneath it; the cell at the end of the notebook displays them. You can change appearance, scaling, and behavior via the widgets and a few small code tweaks.

---

## Controls (left panel)

- **Variable** — choose any 3-D field with dims `(time, lat, lon)`.
- **Show colorbar** — toggle the colorbar.
- **Colormap** — pick from common colormaps (e.g. `viridis`, `RdBu_r` for anomalies).
- **Auto color scaling per frame** — when ON, each frame rescales colors (nice for evolving ranges).
  - **Robust auto limits (2–98%)** — ignore outliers in auto mode.
- **Symmetric around zero** — with **fixed** scaling only (ignored while auto scaling is on), sets the limits to `[-M, +M]` where `M = max(|vmin|, |vmax|)`.
- **Use custom color limits / vmin / vmax** — override defaults with fixed limits.
- **Frame width / height** — size of the map; the time-series plot reuses the frame width and has a fixed height of 250 px.

---

## Color limits logic

- **Defaults per variable** come from the **first 10 time steps** (precomputed).
- With **Auto color scaling per frame = ON**:
  - Uses `framewise=True`; if **Robust** is ON, scales by the 2–98th percentiles per frame.
- With **Auto color scaling per frame = OFF**:
  - Uses **custom vmin/vmax** if provided, else the **first-10-steps defaults**.
  - If **Symmetric around zero** is ON, uses `[-M, +M]`.

🔧 Change how many steps define the defaults by passing `n` in the call `default_clims = compute_firstN_clims(ds)`, e.g. `default_clims = compute_firstN_clims(ds, n=30)`.

🔧 If you’ve precomputed limits elsewhere, set:
    default_clims = {"t2m": (-10.0, 25.0), "10u": (-15.0, 15.0), ...}

---

## Playback & time slider

- The slider under the map includes a **▶ play/pause** button.
- **Speed**: set a custom interval in milliseconds per frame. Panel's player default is 500 ms; larger values play slower. Inside `view`, replace the `map_panel = ...` line with:

    time_player = pn.widgets.DiscretePlayer(interval=1000)  # slower playback
    map_panel = pn.panel(hm, widgets={'time': time_player}, widget_location='bottom')

---

## Click to add time-series

- Click on the **map** to add a point; the nearest grid cell is used.
- The **time-series plot** below accumulates up to **10** lines (oldest drops off first).
- Hover on the map shows coordinates; the legend shows `(lon, lat)` for each series.

🔧 **Clear the selection** (without changing code): run in a small cell:

    _selected_pts.clear(); _last_tap.update(x=None, y=None)

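The time-series panel refreshes on the next click or the next change to any control. Alternatively, re-run the code cell above that defines the widgets and `view` (this recreates `_selected_pts`), then re-run the display cell below.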

In [11]:
# Display the dashboard: controls on the left, reactive map + time series on the right.
dashboard = pn.Row(controls, view)
dashboard