[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/nepslor/B5203E-TSAF/blob/main/W3/forecastability_indicators.ipynb)

### Forecastability KPIs — definitions & interpretation
In this notebook we characterize some time series from the M4 weekly dataset, which contains series with different levels of forecastability. In addition to the a priori forecastability indices introduced during the lecture, we compute a few more. The complete list is the following:

| KPI | Formula | Meaning | More forecastable when… | Range |
|---|---|---|---|---|
| **Variance ratio** | $ \dfrac{\operatorname{Var}(\Delta y_t)}{\operatorname{Var}(y_t)} $, with $ \Delta y_t = y_t - y_{t-1} $ | Roughness vs overall scale (higher ⇒ rougher) | **Lower** | $[0,\infty)$ |
| **Spectral entropy (normalized)** | $ H_{\text{spec}} = -\dfrac{1}{\log K}\sum_{i=1}^{K} p_i \log p_i,\ \ p_i=\dfrac{P_i}{\sum_j P_j} $ | How concentrated the spectrum is (periodicity) | **Lower** | $[0,1]$ |
| **Spectral forecastability** $\,\Omega$ | $ \Omega = 1 - H_{\text{spec}} $ | Complement of spectral entropy | **Higher** | $[0,1]$ |
| **SVD entropy (normalized)** | $ H_{\text{SVD}} = -\dfrac{1}{\log m}\sum_{i=1}^{m} p_i \log p_i,\ \ p_i=\dfrac{\sigma_i}{\sum_j \sigma_j} $ | Rank-like complexity of delay embedding | **Lower** | $[0,1]$ |
| **Permutation entropy** ($m{=}5,\ \tau{=}1$) | $ H_{\text{perm}} = -\dfrac{1}{\log(m!)} \sum_{\pi} p(\pi)\log p(\pi) $ | Ordinal-pattern unpredictability (amp-invariant) | **Lower** | $[0,1]$ |
| **Sample entropy** ($m{=}2,\ r{=}\,0.2\,\sigma$) | $ \mathrm{SampEn}(m,r) = -\ln\!\left(\dfrac{A}{B}\right) $ | Similar patterns remain similar one step ahead | **Lower** | $[0,\infty)$ (often $0$–$3$) |
| **Approximate entropy** ($m{=}2,\ r{=}\,0.2\,\sigma$) | $ \operatorname{ApEn}(m,r) = \Phi_m(r) - \Phi_{m+1}(r) $ | Older entropy (counts self-matches); regularity | **Lower** | $[0,\infty)$ (often $0$–$3$) |
| **Lempel–Ziv complexity** (8-quantile symbols) | $ C_{\text{LZ}}^{\text{norm}} $ (normalized LZ76 parsing count) | Algorithmic randomness of discretized series | **Lower** | $[0,1]$ |
| **Permutation entropy (ordpy)** | same as above, via ordinal distribution | Entropy on the complexity–entropy plane | **Lower** | $[0,1]$ |
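| **Statistical complexity (JS)** | $ C_{\text{JS}} $ (Jensen–Shannon-based complexity of the ordinal distribution $p$ vs the uniform one) | Structure that is neither fully random nor fully ordered; peaks at intermediate entropy | **Non-monotone** (descriptive) | typically $[0,1]$ |
| **Autocorr. at lag 1** | $ \rho(1) $ | Linear dependence at 1 step | **Higher in magnitude** ($\lvert\rho(1)\rvert$) | $[-1,1]$ |
| **Mean abs. autocorr. (lags 1 to 52)** | $ \dfrac{1}{52}\sum_{k=1}^{52} \lvert\rho(k)\rvert $ | Average linear structure up to one year (weekly data) | **Higher** | $[0,1]$ |
| **DFA exponent** $\,\alpha$ | Slope of $\log F(n)$ vs $\log n$ (detrended fluctuation analysis) | Long-range correlation: $\alpha\approx0.5$ for white noise, $\approx1.5$ for a random walk | **Farther from 0.5** ($\lvert\alpha-0.5\rvert$) | $(0,\infty)$ |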
| **Periodicity index** | $ \dfrac{\max_{f>0} P(f)}{\sum_{f>0} P(f)} $ | Share of power at the dominant non-DC frequency | **Higher** | $(0,1]$ |
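
A minimal NumPy/SciPy sketch of three rows of the table (variance ratio, normalized spectral entropy with $\Omega = 1 - H_{\text{spec}}$, periodicity index), applied to two synthetic weekly-like signals: a yearly sine plus a little noise, and pure white noise. The helper names are hypothetical. The sketch uses a plain periodogram without the DC bin, whereas the notebook's `spectral_measures` calls `ant.spectral_entropy(..., method='welch')`, so the spectral-entropy values will not match exactly.

```python
import numpy as np
from scipy import signal

def variance_ratio_np(x):
    # Var(diff(y)) / Var(y): small for smooth series, about 2 for white noise
    x = np.asarray(x, float)
    return float(np.var(np.diff(x)) / np.var(x))

def spectral_stats_np(x):
    # One-sided periodogram with linear detrending; drop the DC bin (f = 0)
    _, psd = signal.periodogram(np.asarray(x, float), detrend="linear")
    p = psd[1:] / psd[1:].sum()          # p_i = P_i / sum_j P_j
    nz = p[p > 0]                        # convention: 0 * log(0) = 0
    h_spec = float(-np.sum(nz * np.log(nz)) / np.log(p.size))  # normalized by log K
    return {
        "spectral_entropy": h_spec,
        "omega": 1.0 - h_spec,                 # spectral forecastability
        "periodicity_index": float(p.max()),   # max_f P(f) / sum_f P(f), f > 0
    }

rng = np.random.default_rng(0)
t = np.arange(520)  # 10 years of weekly observations
sine = np.sin(2 * np.pi * t / 52) + 0.1 * rng.standard_normal(t.size)  # yearly cycle + noise
noise = rng.standard_normal(t.size)

for name, x in [("sine", sine), ("noise", noise)]:
    print(name, round(variance_ratio_np(x), 3), spectral_stats_np(x))
```

The periodic signal should come out smoother (lower variance ratio), with a more concentrated spectrum (lower entropy, higher periodicity index) than the noise.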

### Notes
For composite scores, invert the “lower-is-better” metrics, use $|\rho(1)|$ as a direction-free measure of linear dependence, and consider z-scoring each metric before aggregating. Many entropy/complexity measures assume z-normalized inputs and need reasonably long series to give stable estimates.
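
The autocorrelation rows of the table and the $|\rho(1)|$ advice can be checked with a plain NumPy ACF. This is a sketch with hypothetical helpers (`acf_np`, `acf_strength`), assuming the biased estimator, compared on an AR(1) process with $\phi = 0.9$ (where $\rho(k) = 0.9^k$) and on white noise:

```python
import numpy as np

def acf_np(x, max_lag):
    # Biased sample ACF: every lag is divided by the same sum of squares
    # (statsmodels' acf uses this estimator by default, adjusted=False)
    x = np.asarray(x, float)
    x = x - x.mean()
    n, denom = x.size, np.dot(x, x)
    return np.array([np.dot(x[: n - k], x[k:]) / denom for k in range(max_lag + 1)])

def acf_strength(x, max_lag=52):
    r = acf_np(x, max_lag)
    return {"abs_acf1": float(abs(r[1])), "acf_mean_abs": float(np.mean(np.abs(r[1:])))}

rng = np.random.default_rng(1)
e = rng.standard_normal(2000)
ar1 = np.zeros_like(e)
for t in range(1, e.size):
    ar1[t] = 0.9 * ar1[t - 1] + e[t]   # AR(1) with phi = 0.9: rho(k) = 0.9**k

print("AR(1):", acf_strength(ar1))
print("white noise:", acf_strength(e))
```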


### Preprocessing tips
Z-normalize and handle missing values. For strongly seasonal data, compute the metrics on both the raw and the seasonally adjusted series. For a monotone composite score (higher = more forecastable), invert all “lower-is-better” metrics, use $|\mathrm{ACF}(1)|$, and optionally transform the DFA exponent $\alpha$ to $|\alpha-0.5|$ (its distance from the white-noise value) before z-scoring.
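
The notebook computes the DFA exponent $\alpha$ with `nolds.dfa`. For intuition, here is a minimal DFA-1 sketch (hypothetical helper `dfa_alpha_np`) with non-overlapping windows and a fixed set of window sizes; `nolds` chooses its window sizes differently, so the numbers will not match exactly. White noise should give $\alpha \approx 0.5$ and a random walk $\alpha \approx 1.5$, which is why $|\alpha - 0.5|$ can serve as a rough "distance from unpredictable noise".

```python
import numpy as np

def dfa_alpha_np(x, scales=(8, 16, 32, 64, 128)):
    # 1) Profile: cumulative sum of the mean-removed series
    x = np.asarray(x, float)
    y = np.cumsum(x - x.mean())
    fluct = []
    for n in scales:
        n_win = y.size // n
        segs = y[: n_win * n].reshape(n_win, n)   # non-overlapping windows of length n
        t = np.arange(n)
        # 2) Remove a least-squares line from each window; keep the mean squared residual
        msr = [np.mean((seg - np.polyval(np.polyfit(t, seg, 1), t)) ** 2) for seg in segs]
        fluct.append(np.sqrt(np.mean(msr)))       # F(n)
    # 3) alpha is the slope of log F(n) against log n
    return float(np.polyfit(np.log(scales), np.log(fluct), 1)[0])

rng = np.random.default_rng(2)
white = rng.standard_normal(4096)            # expected alpha close to 0.5
walk = np.cumsum(rng.standard_normal(4096))  # expected alpha close to 1.5

for name, x in [("white noise", white), ("random walk", walk)]:
    a = dfa_alpha_np(x)
    print(f"{name}: alpha = {a:.2f}, |alpha - 0.5| = {abs(a - 0.5):.2f}")
```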


In [None]:

# Install dependencies
%pip -q install pandas numpy antropy ordpy nolds statsmodels scipy joblib tqdm requests umap-learn anywidget ipywidgets plotly seaborn scikit-learn matplotlib


In [None]:

import os
import re
import warnings
from pathlib import Path

import numpy as np
import pandas as pd
import requests

from tqdm import tqdm
from joblib import Parallel, delayed

from scipy import signal
from statsmodels.tsa.stattools import acf

import antropy as ant
import ordpy
import nolds

warnings.filterwarnings("ignore")
DATA_DIR = Path('data')
DATA_DIR.mkdir(exist_ok=True)
RESULTS_DIR = Path('results')
RESULTS_DIR.mkdir(exist_ok=True)

# Toggle for speed: set to True to skip heavier metrics if needed
FAST_MODE = False
N_JOBS = -1  # Use all cores
MAX_LAG_ACF = 52  # weekly: capture ~annual seasonality
PERM_ORDER = 5
SAMPEN_M = 2
APEN_M = 2
R_FRACTION = 0.2  # r = 0.2 * std
LZC_BINS = 8
RANDOM_SEED = 1337
np.random.seed(RANDOM_SEED)



## 1) Download M4 Weekly dataset

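Primary source: **Zenodo record 4656410** (Monash Time Series Forecasting Repository).  
We look for `M4_weekly_dataset.tsf` in the record; if it is not available, we fall back to `Weekly-train.csv` from the M4 competition GitHub repository (`Mcompetitions/M4-methods`).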


In [None]:

ZENODO_RECORD = "https://zenodo.org/api/records/4656410"
TARGET_TSF_NAME = "M4_weekly_dataset.tsf"
FALLBACK_CSV_URL = "https://raw.githubusercontent.com/Mcompetitions/M4-methods/master/Dataset/Train/Weekly-train.csv"

def download_zenodo_tsf(record_url: str, target_name: str, out_path: Path) -> bool:
    try:
        r = requests.get(record_url, timeout=30)
        r.raise_for_status()
        meta = r.json()
        files = meta.get("files", [])
        for f in files:
            if f.get("key","").endswith(target_name):
                file_url = f["links"]["self"]
                print(f"Downloading from Zenodo: {file_url}")
                fr = requests.get(file_url, timeout=60)
                fr.raise_for_status()
                out_path.write_bytes(fr.content)
                print(f"Saved TSF to {out_path}")
                return True
        print("TSF file not found in Zenodo record files.")
        return False
    except Exception as e:
        print(f"Zenodo download failed: {e}")
        return False

def download_fallback_csv(url: str, out_path: Path) -> bool:
    try:
        print(f"Downloading fallback CSV: {url}")
        r = requests.get(url, timeout=60)
        r.raise_for_status()
        out_path.write_bytes(r.content)
        print(f"Saved CSV to {out_path}")
        return True
    except Exception as e:
        print(f"Fallback CSV download failed: {e}")
        return False

tsf_path = DATA_DIR / TARGET_TSF_NAME
csv_path = DATA_DIR / "Weekly-train.csv"

HAVE_TSF = False
if not tsf_path.exists():
    HAVE_TSF = download_zenodo_tsf(ZENODO_RECORD, TARGET_TSF_NAME, tsf_path)
else:
    HAVE_TSF = True

if not HAVE_TSF and not csv_path.exists():
    _ = download_fallback_csv(FALLBACK_CSV_URL, csv_path)

print("Files present:", list(DATA_DIR.iterdir()))



## 2) Load dataset (TSF or CSV)

- **TSF** (`M4_weekly_dataset.tsf`) — Monash TSF format.
- **CSV** (`Weekly-train.csv`) — official M4 layout (each row is a series; first column = ID).
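
A minimal sketch of how a single Monash TSF data line can be split, assuming the usual layout of those files: the attribute values declared by the `@attribute` header lines (here a series name and a start timestamp) separated by `:`, followed by the comma-separated observations, with `?` marking a missing value. The timestamp uses dashes in its time part, so `:` stays unambiguous. The line and the name `W1` are made up for illustration, and `parse_tsf_data_line` is a hypothetical helper, not part of any library.

```python
import numpy as np

def parse_tsf_data_line(line, n_attributes):
    # Attribute values separated by ':', then the comma-separated observations
    fields = line.strip().split(":")
    if len(fields) != n_attributes + 1:
        raise ValueError(f"expected {n_attributes + 1} ':'-separated fields, got {len(fields)}")
    attrs, raw = fields[:-1], fields[-1]
    # '?' marks a missing observation
    values = np.array([np.nan if v == "?" else float(v) for v in raw.split(",")])
    return attrs, values

attrs, values = parse_tsf_data_line("W1:2001-01-01 00-00-00:10.0,12.5,?,11.0", n_attributes=2)
print(attrs, values)
```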


In [None]:

from typing import List, Tuple
import numpy as np
import pandas as pd

def parse_tsf(path: Path) -> Tuple[List[str], List[np.ndarray]]:
    """Parse a Monash .tsf file.

    Header lines start with '@' and end at '@data'. Each data line holds the
    attribute values separated by ':' (series name first), followed by the
    comma-separated observations; '?' marks a missing value.
    """
    ids, series = [], []
    in_data = False
    with path.open('r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line or line.startswith('#'):
                continue
            if not in_data:
                # Skip header lines until the @data marker
                if line.lower() == '@data':
                    in_data = True
                continue
            fields = line.split(':')
            sid, raw = fields[0], fields[-1]
            vals = np.array([np.nan if v.strip() in {'?', ''} else float(v)
                             for v in raw.split(',')], dtype=float)
            # Drop missing values (note: this would shift the time index of later points)
            vals = vals[~np.isnan(vals)]
            if vals.size == 0:
                continue
            ids.append(sid)
            series.append(vals)
    if not in_data:
        raise ValueError("Invalid TSF: missing @data section")
    return ids, series

def parse_m4_weekly_csv(path: Path) -> Tuple[List[str], List[np.ndarray]]:
    df = pd.read_csv(path)
    id_col = df.columns[0]
    ids = df[id_col].astype(str).tolist()
    vals = df.drop(columns=[id_col]).to_numpy()
    out = []
    for row in vals:
        m = ~np.isnan(row)
        if m.any():
            out.append(row[m].astype(float))
        else:
            out.append(np.array([], dtype=float))
    ids, out = zip(*[(i, x) for i, x in zip(ids, out) if x.size > 0])
    return list(ids), list(out)

if (DATA_DIR / "M4_weekly_dataset.tsf").exists():
    print("Loading TSF...")
    series_ids, series_list = parse_tsf(DATA_DIR / "M4_weekly_dataset.tsf")
elif (DATA_DIR / "Weekly-train.csv").exists():
    print("Loading CSV...")
    series_ids, series_list = parse_m4_weekly_csv(DATA_DIR / "Weekly-train.csv")
else:
    raise FileNotFoundError("No dataset file available. Please ensure internet access and re-run the download cell.")

print(f"Loaded {len(series_list)} series.")
print("Example lengths (first 5):", [len(x) for x in series_list[:5]])


In [None]:
# plot some examples
import matplotlib.pyplot as plt
fig, ax = plt.subplots(3, 2, figsize=(12, 8))
for i in range(6):
    ax.ravel()[i].plot(series_list[i])
    ax.ravel()[i].set_title(series_ids[i])
plt.tight_layout()

In [None]:
# discard series with fewer than 260 observations (about 5 years of weekly data)
min_length = 260
filtered = [(sid, x) for sid, x in zip(series_ids, series_list) if len(x) >= min_length]
series_ids, series_list = zip(*filtered)


## 3) Metric functions
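
The metric functions below mostly delegate to `antropy`, `ordpy` and `nolds`. To see what the permutation-entropy row of the KPI table actually computes, here is a from-scratch sketch (hypothetical helper `perm_entropy_np`; it uses $m=3$ rather than the notebook's `PERM_ORDER = 5` to keep it quick). Ties inside a window are broken by position via a stable `argsort`; libraries may treat ties differently, so values need not match `ant.perm_entropy` exactly.

```python
import math
from collections import Counter

import numpy as np

def perm_entropy_np(x, m=3, tau=1):
    """Normalized permutation entropy: Shannon entropy of ordinal patterns / log(m!)."""
    x = np.asarray(x, float)
    n_vec = x.size - (m - 1) * tau          # number of delay vectors
    # The ordinal pattern of a delay vector is the permutation that sorts it
    counts = Counter(
        tuple(np.argsort(x[i : i + (m - 1) * tau + 1 : tau], kind="stable"))
        for i in range(n_vec)
    )
    p = np.array(list(counts.values()), float) / n_vec
    return float(-np.sum(p * np.log(p)) / math.log(math.factorial(m)))

rng = np.random.default_rng(3)
pe_mono = perm_entropy_np(np.arange(200.0))            # a single pattern: entropy 0
pe_noise = perm_entropy_np(rng.standard_normal(5000))  # all 3! patterns about equally likely
print(pe_mono, pe_noise)
```

Because only the ordering of values matters, any strictly increasing transform of the series (e.g. `np.exp`) leaves the result unchanged, which is what "amp-invariant" in the table refers to.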


In [None]:

from typing import Dict, Union, Optional
from scipy import signal
from statsmodels.tsa.stattools import acf

def variance_ratio(x: np.ndarray) -> Optional[float]:
    x = np.asarray(x, float)
    if x.size < 3:
        return np.nan
    num = np.nanvar(np.diff(x))
    den = np.nanvar(x)
    return float(num / den) if den > 0 else np.nan

def spectral_measures(x: np.ndarray) -> Dict[str, float]:
    x = np.asarray(x, float)
    if x.size < 16:
        return dict(spectral_entropy=np.nan, spectral_forecastability=np.nan, periodicity_index=np.nan)
    try:
        Hs = ant.spectral_entropy(x, sf=1.0, method='welch', normalize=True)
    except Exception:
        Hs = np.nan
    Omega = np.nan if np.isnan(Hs) else 1.0 - Hs
    try:
        freqs, psd = signal.periodogram(x, detrend='linear', scaling='density')
        if psd.size > 1:
            peak = np.nanmax(psd[1:])
            total = np.nansum(psd[1:])
            per_idx = (peak / total) if total and total > 0 else np.nan
        else:
            per_idx = np.nan
    except Exception:
        per_idx = np.nan
    return dict(spectral_entropy=Hs, spectral_forecastability=Omega, periodicity_index=per_idx)

def svd_entropy_safe(x: np.ndarray, order: int = 3, delay: int = 1) -> float:
    if x.size < order * delay + 1:
        return np.nan
    try:
        return ant.svd_entropy(x, order=order, delay=delay, normalize=True)
    except Exception:
        return np.nan

def perm_entropy_safe(x: np.ndarray, order: int = 5, delay: int = 1) -> float:
    if x.size < order * delay + 1:
        return np.nan
    try:
        return ant.perm_entropy(x, order=order, delay=delay, normalize=True)
    except Exception:
        return np.nan

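def sample_entropy_safe(x: np.ndarray, m: int = 2, r_fraction: float = 0.2) -> float:
    x = np.asarray(x, float)
    if x.size < 20:
        return np.nan
    s = np.nanstd(x)
    if s == 0:
        return np.nan
    # SampEn is infinite when no (m+1)-length matches exist: retry with other tolerances r = rf * std
    for rf in (r_fraction, 0.25, 0.15, 0.3):
        try:
            # recent antropy releases accept an absolute `tolerance` (r)
            val = ant.sample_entropy(x, order=m, tolerance=rf * s)
        except TypeError:
            # older antropy: no `tolerance` argument, r is fixed at 0.2 * std
            val = ant.sample_entropy(x, order=m)
        except Exception:
            continue
        if np.isfinite(val):
            return float(val)
    return np.nan

def approximate_entropy_safe(x: np.ndarray, m: int = 2, r_fraction: float = 0.2) -> float:
    x = np.asarray(x, float)
    if x.size < 20:
        return np.nan
    s = np.nanstd(x)
    if s == 0:
        return np.nan
    # ApEn counts self-matches, so it stays finite: a single tolerance is enough
    try:
        val = ant.app_entropy(x, order=m, tolerance=r_fraction * s)
    except TypeError:
        # older antropy: no `tolerance` argument, r is fixed at 0.2 * std
        val = ant.app_entropy(x, order=m)
    except Exception:
        return np.nan
    return float(val) if np.isfinite(val) else np.nan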

def quantize_series(x: np.ndarray, n_bins: int = 8) -> np.ndarray:
    x = np.asarray(x, float)
    if np.unique(x).size <= 1:
        return np.zeros_like(x, dtype=int)
    try:
        q = pd.qcut(x, q=min(n_bins, max(2, np.unique(x).size)), labels=False, duplicates='drop')
        # pd.qcut on an ndarray returns an ndarray (not a Series), so no .to_numpy()
        return np.asarray(q).astype(int)
    except Exception:
        bins = np.linspace(np.nanmin(x), np.nanmax(x), n_bins + 1)
        return np.digitize(x, bins[:-1], right=False).astype(int)

def lzc_safe(x: np.ndarray, n_bins: int = 8) -> float:
    if x.size < 16:
        return np.nan
    try:
        sym = quantize_series(x, n_bins=n_bins)
        return ant.lziv_complexity(sym, normalize=True)
    except Exception:
        return np.nan

def ordpy_complexity_entropy(x: np.ndarray, dx: int = 5, tau: int = 1) -> Dict[str, float]:
    """
    Returns:
      ordpy_perm_entropy: permutation entropy H (normalized)
      ordpy_stat_complexity: Jensen–Shannon statistical complexity C (normalized)
    Notes:
      - dx is the embedding dimension along time, taux=tau the embedding delay.
      - dy=1, tauy=1 select 1D (time-series) ordinal patterns.
      - Returns NaNs instead of raising, so one bad series cannot abort the parallel run.
    """
    x = np.asarray(x, float)
    if x.size < dx * tau + 1:
        return dict(ordpy_perm_entropy=np.nan, ordpy_stat_complexity=np.nan)
    try:
        H, C = ordpy.complexity_entropy(x, dx=dx, dy=1, taux=tau, tauy=1, probs=False)
    except Exception:
        return dict(ordpy_perm_entropy=np.nan, ordpy_stat_complexity=np.nan)
    return dict(ordpy_perm_entropy=float(H), ordpy_stat_complexity=float(C))

def acf_measures(x: np.ndarray, max_lag: int = 52) -> Dict[str, float]:
    if x.size < max_lag + 2:
        return dict(acf1=np.nan, acf_mean_abs=np.nan)
    try:
        a = acf(x, nlags=max_lag, fft=True, missing='drop')
        a1 = float(a[1]) if a.size > 1 else np.nan
        a_mean = float(np.nanmean(np.abs(a[1:])))  # mean |rho(k)| over k = 1..max_lag
        return dict(acf1=a1, acf_mean_abs=a_mean)
    except Exception:
        return dict(acf1=np.nan, acf_mean_abs=np.nan)

def dfa_alpha_safe(x: np.ndarray) -> float:
    if x.size < 64:
        return np.nan
    try:
        return float(nolds.dfa(x))
    except Exception:
        return np.nan
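
To make the $A/B$ ratio in the sample-entropy row concrete, here is a brute-force $O(n^2)$ sketch (hypothetical helper `sample_entropy_np`, Chebyshev distance, $r = 0.2\,\sigma$ as in the table). It is meant for understanding on short series, not as a replacement for `ant.sample_entropy`, whose implementation details may differ.

```python
import numpy as np

def sample_entropy_np(x, m=2, r_fraction=0.2):
    """Brute-force SampEn(m, r) = -ln(A / B) with r = r_fraction * std(x)."""
    x = np.asarray(x, float)
    r = r_fraction * np.std(x)
    n_t = x.size - m  # use the same n_t templates for lengths m and m + 1

    def matching_pairs(length):
        emb = np.array([x[i : i + length] for i in range(n_t)])
        count = 0
        for i in range(n_t - 1):
            # Chebyshev distance to every later template, so self-matches are excluded
            d = np.max(np.abs(emb[i + 1 :] - emb[i]), axis=1)
            count += int(np.sum(d <= r))
        return count

    b = matching_pairs(m)       # B: pairs of length-m templates within r
    a = matching_pairs(m + 1)   # A: pairs that still match at length m + 1
    return float(-np.log(a / b)) if a > 0 and b > 0 else float("inf")

rng = np.random.default_rng(5)
t = np.arange(600)
regular = np.sin(2 * np.pi * t / 20)   # perfectly periodic
irregular = rng.standard_normal(600)   # white noise
print(sample_entropy_np(regular), sample_entropy_np(irregular))
```

The infinite value returned when $A = 0$ is exactly the case that `sample_entropy_safe` tries to avoid by retrying with other tolerances.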





## 4) Compute metrics for all series


In [None]:
def compute_metrics_for_series(sid: str, x: np.ndarray) -> Dict[str, Union[str, float]]:
    x = np.asarray(x, float)
    x = x - np.nanmean(x)
    s = np.nanstd(x)
    if s > 0:
        x = x / s
    out = {
        "series_id": sid,
        "length": int(x.size),
        "var": float(np.nanvar(x)) if x.size > 0 else np.nan,
        "variance_ratio": variance_ratio(x),
    }
    out.update(spectral_measures(x))
    out["svd_entropy"] = svd_entropy_safe(x, order=3, delay=1)
    out["perm_entropy"] = perm_entropy_safe(x, order=PERM_ORDER, delay=1)
    if FAST_MODE:
        out["sample_entropy"] = np.nan
        out["approx_entropy"] = np.nan
        out["lzc"] = np.nan
        out["ordpy_perm_entropy"] = np.nan
        out["ordpy_stat_complexity"] = np.nan
    else:
        out["sample_entropy"] = sample_entropy_safe(x, m=SAMPEN_M, r_fraction=R_FRACTION)
        out["approx_entropy"] = approximate_entropy_safe(x, m=APEN_M, r_fraction=R_FRACTION)
        out["lzc"] = lzc_safe(x, n_bins=LZC_BINS)
        oc = ordpy_complexity_entropy(x, dx=PERM_ORDER, tau=1)
        out.update(oc)
    out.update(acf_measures(x, max_lag=MAX_LAG_ACF))
    out["dfa_alpha"] = dfa_alpha_safe(x)
    return out

results = Parallel(n_jobs=N_JOBS, backend="loky")(delayed(compute_metrics_for_series)(sid, x)
                                                  for sid, x in tqdm(zip(series_ids, series_list), total=len(series_list)))


df_metrics = pd.DataFrame(results)
df_metrics.head()



## 5) Composite forecastability index (optional)
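
The recipe from the preprocessing tips (invert “lower-is-better” metrics, use $|\rho(1)|$, map DFA $\alpha$ to $|\alpha-0.5|$, z-score, then average) fits in a few lines of NumPy. This is a hypothetical helper (`composite_index`) run on made-up values for three series, an illustration rather than the notebook's scikit-learn implementation.

```python
import numpy as np

def composite_index(metrics, lower_is_better=(), abs_cols=(), dfa_col=None):
    """Equal-weight composite: metrics maps name -> array with one value per series."""
    zs = []
    for name, v in metrics.items():
        v = np.asarray(v, float)
        if name == dfa_col:
            v = np.abs(v - 0.5)       # distance from the white-noise value alpha = 0.5
        if name in abs_cols:
            v = np.abs(v)             # direction-free strength, e.g. |rho(1)|
        if name in lower_is_better:
            v = -v                    # orient everything as "higher = more forecastable"
        # z-score ignoring NaNs (a constant column would give NaN here)
        zs.append((v - np.nanmean(v)) / np.nanstd(v))
    return np.nanmean(np.column_stack(zs), axis=1)

toy = {
    "perm_entropy": [0.20, 0.50, 0.90],   # lower is better
    "acf1":         [0.90, -0.50, 0.05],  # use |rho(1)|
    "dfa_alpha":    [1.40, 0.90, 0.50],   # use |alpha - 0.5|
}
score = composite_index(toy, lower_is_better={"perm_entropy"},
                        abs_cols={"acf1"}, dfa_col="dfa_alpha")
print(score, np.argsort(-score))  # series 0 should rank first
```

Equal weights are the simplest choice; strongly correlated metrics (e.g. spectral entropy and $\Omega$) effectively get counted more than once.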


In [None]:

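from sklearn.preprocessing import StandardScaler
import numpy as np

# Metrics where lower = more forecastable: flip the sign
cols_invert = ["spectral_entropy","svd_entropy","perm_entropy","sample_entropy","approx_entropy",
               "ordpy_perm_entropy","lzc","variance_ratio"]
# Metrics where higher = more forecastable. ordpy_stat_complexity is left out because it is
# non-monotone in forecastability (see the KPI table); it stays in df_metrics as a descriptor.
cols_keep = ["spectral_forecastability","periodicity_index","acf_mean_abs"]

for c in cols_invert:
    if c in df_metrics.columns:
        df_metrics[c + "_inv"] = -df_metrics[c]
# Direction-free transforms from the preprocessing tips
df_metrics["acf1_abs"] = df_metrics["acf1"].abs()
df_metrics["dfa_dev"] = (df_metrics["dfa_alpha"] - 0.5).abs()

use_cols = ([c + "_inv" for c in cols_invert if c in df_metrics.columns]
            + [c for c in cols_keep if c in df_metrics.columns] + ["acf1_abs", "dfa_dev"])

df_comp = df_metrics[["series_id"] + use_cols].copy()
df_comp = df_comp.replace([np.inf,-np.inf], np.nan)

# StandardScaler ignores NaNs when fitting and keeps them in the output
scaler = StandardScaler()
X = scaler.fit_transform(df_comp[use_cols].to_numpy(dtype=float))
comp_score = np.nanmean(X, axis=1)  # equal-weight average of the z-scores
df_metrics["forecastability_index"] = comp_score
df_metrics["rank"] = df_metrics["forecastability_index"].rank(ascending=False, method="average")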

df_metrics.sort_values("forecastability_index", ascending=False).head(10)


In [None]:
# plot violinplots of metrics using seaborn
import seaborn as sns
melted = df_metrics.melt(id_vars=["series_id"], value_vars=use_cols, var_name="metric", value_name="value")
plt.figure(figsize=(12, 6))
# density_norm replaces the deprecated `scale` argument (seaborn >= 0.13)
sns.violinplot(data=melted, x="metric", y="value", inner="quartile", density_norm="width")
plt.xticks(rotation=45, ha='right')
plt.title("Distribution of Forecastability Metrics")
plt.tight_layout()

In [None]:
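# plot the most and least forecastable series according to each metric
# (every column in use_cols is oriented so that higher = more forecastable)
n_show = 3
fig, ax = plt.subplots(len(use_cols), 2, figsize=(12, 2.5 * len(use_cols)))
for i, m in enumerate(use_cols):
    # drop NaNs so they never end up among the "best" or "worst" series
    idx_sorted = df_metrics[m].dropna().sort_values(ascending=False).index
    best_idxs = idx_sorted[:n_show]
    worst_idxs = idx_sorted[-n_show:]
    for b, w in zip(best_idxs, worst_idxs):
        ax[i, 0].plot(series_list[b], linewidth=1)
        ax[i, 1].plot(series_list[w], linewidth=1)
    ax[i, 0].set_title(f"Best {n_show}: {m}")
    ax[i, 1].set_title(f"Worst {n_show}: {m}")
plt.tight_layout()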

In [None]:
# use UMAP to visualize forecastability space
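from umap import UMAP
# UMAP cannot handle NaNs: use the standardized metrics X from the composite cell,
# imputing missing values with 0 (the column mean after standardization)
X_umap = np.nan_to_num(X, nan=0.0)
y_umap = np.nan_to_num(df_metrics['forecastability_index'].to_numpy(dtype=float), nan=0.0)
# target_metric='l2' because the target is continuous (the default 'categorical' expects labels)
reducer = UMAP(random_state=RANDOM_SEED, target_metric='l2')
X_emb = reducer.fit_transform(X_umap, y=y_umap)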
plt.figure(figsize=(10, 8))
sc = plt.scatter(X_emb[:,0], X_emb[:,1], c=df_metrics["forecastability_index"], cmap='viridis', s=10)
plt.colorbar(sc, label='Forecastability Index')
plt.title('UMAP Projection of Forecastability Metrics')
plt.xlabel('UMAP 1')
plt.ylabel('UMAP 2')
plt.show()

## 6) Visualization via UMAP
In the following cell we use the UMAP embedding computed above to explore, interactively, the forecastability space defined by our indicators: hovering over a point shows the corresponding (z-normalized) series.

UMAP (Uniform Manifold Approximation and Projection) is a dimensionality-reduction technique based on manifold learning. Like PCA, it projects a high-dimensional space onto a lower-dimensional one (usually 2D or 3D for visualization); unlike PCA, the projection is nonlinear.

The target (`y`) argument of UMAP's `fit_transform` method lets you provide additional information, such as class labels or a continuous variable, to guide the embedding. When it is given, UMAP uses the target to shape the layout, typically making the embedding more discriminative with respect to it. This is especially useful for supervised or semi-supervised visualization; here the target is the composite `forecastability_index`.
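
For comparison with the linear technique mentioned above, a minimal PCA sketch via NumPy's SVD (hypothetical helper `pca_2d`), shown here on synthetic data that is essentially two-dimensional:

```python
import numpy as np

def pca_2d(X):
    # Center the columns, then project onto the top-2 right singular vectors
    Xc = X - X.mean(axis=0)
    _, S, Vt = np.linalg.svd(Xc, full_matrices=False)
    explained = S**2 / np.sum(S**2)   # share of total variance per component
    return Xc @ Vt[:2].T, explained[:2]

rng = np.random.default_rng(4)
latent = rng.standard_normal((300, 2))
X_toy = latent @ rng.standard_normal((2, 8)) + 0.05 * rng.standard_normal((300, 8))
emb, ev = pca_2d(X_toy)
print(emb.shape, ev)
```

In the notebook, `pca_2d(np.nan_to_num(X))`, with `X` the standardized metric matrix from the composite cell, would give a linear counterpart of the UMAP plot; a low explained-variance share there would suggest the metric space is not well summarized by two linear directions.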

In [None]:
from IPython.display import display
import ipywidgets as widgets
from plotly.subplots import make_subplots
import plotly.graph_objects as go

df_emb = df_metrics.copy()
df_emb['umap_x'] = X_emb[:,0]
df_emb['umap_y'] = X_emb[:,1]

id_to_series = {sid: s for sid, s in zip(series_ids, series_list)}
if 'series_id' not in df_emb.columns:
    raise KeyError("df_metrics must contain a 'series_id' column.")
default_color = 'forecastability_index' if 'forecastability_index' in df_emb.columns else use_cols[0]
fig = make_subplots(rows=1, cols=2, column_widths=[0.58, 0.42],
                    subplot_titles=('UMAP of Forecastability Metrics', 'Time Series Preview'))
scatter = go.Scattergl(
    x=df_emb['umap_x'],
    y=df_emb['umap_y'],
    mode='markers',
    marker=dict(size=6, color=df_emb[default_color], colorscale='Viridis', showscale=True,
                colorbar=dict(title=default_color)),
    text=df_emb['series_id'],
    hovertemplate='<b>%{text}</b><br>UMAP1=%{x:.3f}<br>UMAP2=%{y:.3f}<extra></extra>',
    name='UMAP points'
)
fig.add_trace(scatter, row=1, col=1)
line = go.Scatter(x=[], y=[], mode='lines', name='series')
fig.add_trace(line, row=1, col=2)
fig.update_xaxes(title_text='UMAP 1', row=1, col=1)
fig.update_yaxes(title_text='UMAP 2', row=1, col=1)
fig.update_xaxes(title_text='t (index)', row=1, col=2)
fig.update_yaxes(title_text='value (z-norm)', row=1, col=2)
fig.update_layout(height=600, margin=dict(l=10, r=10, t=40, b=10))
figw = go.FigureWidget(fig)
scatter_trace = figw.data[0]
line_trace = figw.data[1]
def znorm(x):
    x = np.asarray(x, float)
    if x.size == 0:
        return x
    m, s = np.nanmean(x), np.nanstd(x)
    return (x - m) / s if s > 0 else x - m

def hover_fn(trace, points, state):
    # Show the series under the cursor in the right-hand panel
    if not points.point_inds:
        return
    idx = points.point_inds[0]
    sid = df_emb.iloc[idx]['series_id']
    series = id_to_series.get(sid, None)
    with figw.batch_update():
        if series is None or len(series) == 0:
            line_trace.x, line_trace.y = [], []
            figw.layout.annotations[1].text = f'Time Series Preview (No data for {sid})'
        else:
            line_trace.x = np.arange(len(series))
            line_trace.y = znorm(series)
            line_trace.name = str(sid)
            figw.layout.annotations[1].text = f'Time Series Preview: {sid}'

scatter_trace.on_hover(hover_fn)
options = [c for c in df_emb.columns if c not in ['umap_x','umap_y'] and df_emb[c].dtype != 'O']
if default_color not in options: options.insert(0, default_color)
color_dd = widgets.Dropdown(options=options, value=default_color, description='Color by:', layout=widgets.Layout(width='350px'))
def on_color_change(change):
    col = change['new']
    with figw.batch_update():
        figw.data[0].marker.color = df_emb[col]
        figw.data[0].marker.colorbar.title = col
color_dd.observe(on_color_change, names='value')
display(color_dd)
figw