# 04 — Interactive Feature Explorer

Scatter any two features against each other, then **click** a single point or
use the **lasso / box-select** tool in the plot toolbar to select multiple
points — the post text (up to 8 posts per selection, each truncated at 2,000
characters) appears in the panel below.

Useful for investigating:
- **Duplicate data points** (e.g. many posts with exactly the same `ppl_mean`
  and `emb_centroid_dist` → enable **Jitter** to separate them, then click to
  confirm they share identical text)
- **Outlier clusters** — why are these posts flagged?
- **Category or toxicity patterns** in the feature space

**Requirements** (install once if not already present):
```bash
pip install plotly ipywidgets anywidget
```

**Data sources:**
- `data/processed/posts_clean.parquet` — post content and metadata
- `outputs/features/features.parquet` — 19 extracted stylometric/perplexity/embedding features
- `outputs/outliers/outliers.parquet` — ensemble outlier flags and per-detector scores

In [1]:
import warnings
warnings.filterwarnings("ignore")

import textwrap
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import ipywidgets as widgets
from IPython.display import display

# ── Load & merge ──────────────────────────────────────────────────────
posts    = pd.read_parquet("../data/processed/posts_clean.parquet")
feats    = pd.read_parquet("../outputs/features/features.parquet")
outliers = pd.read_parquet("../outputs/outliers/outliers.parquet")

# Per-detector scores are optional: keep whichever ones the outlier stage wrote.
out_keep = ["id", "ensemble_flag"] + [
    c for c in ["iso_forest_score", "lof_score", "mahalanobis_score"]
    if c in outliers.columns
]
# Inner join (merge default): only posts that have extracted features remain.
df = posts.merge(feats, on="id", suffixes=("", "_feat"))
# Left join: posts absent from the outlier table are kept and treated as unflagged.
df = df.merge(outliers[out_keep], on="id", how="left")
# The left join turns the flag into an object column when any row is missing;
# fill the gaps and cast back so downstream code sees a genuine bool dtype.
df["ensemble_flag"] = df["ensemble_flag"].fillna(False).astype(bool)
# RangeIndex: row labels equal row positions, which the plotting code relies on.
df = df.reset_index(drop=True)

print(f"Loaded {len(df):,} rows × {df.shape[1]} columns")

# ── Feature / score columns available for plotting ────────────────────
# Candidate columns in dropdown order; only those actually present in df are
# offered.
_candidate_cols = (
    # text / stylometric statistics
    "char_count", "word_count", "sentence_count",
    "avg_word_length", "avg_sentence_length",
    "punctuation_density", "capitalization_ratio", "lexical_diversity",
    "first_person_rate", "hedge_count", "temporal_deixis_count",
    "anecdote_marker_count", "typo_proxy",
    # perplexity
    "ppl_mean", "ppl_var", "ppl_tail_95",
    # embedding geometry
    "emb_mean_nn_dist", "emb_local_density", "emb_centroid_dist",
    # outlier-detector scores
    "iso_forest_score", "lof_score", "mahalanobis_score",
)
_present_cols = set(df.columns)
FEAT_COLS = [name for name in _candidate_cols if name in _present_cols]

TEXT_COL = "content"   # from config.yaml
print(f"{len(FEAT_COLS)} columns available: {FEAT_COLS}")

Loaded 32,709 rows × 39 columns
22 columns available: ['char_count', 'word_count', 'sentence_count', 'avg_word_length', 'avg_sentence_length', 'punctuation_density', 'capitalization_ratio', 'lexical_diversity', 'first_person_rate', 'hedge_count', 'temporal_deixis_count', 'anecdote_marker_count', 'typo_proxy', 'ppl_mean', 'ppl_var', 'ppl_tail_95', 'emb_mean_nn_dist', 'emb_local_density', 'emb_centroid_dist', 'iso_forest_score', 'lof_score', 'mahalanobis_score']


In [None]:
# ────────────────────────────────────────────────────────────────────────
# Helpers
# ────────────────────────────────────────────────────────────────────────

PALETTE = [
    "#4c72b0", "#dd8452", "#55a868", "#c44e52",
    "#8172b2", "#937860", "#da8bc3", "#8c8c8c",
    "#ccb974", "#64b5cd",
]

# Colour-bar label given to rows whose colour column is missing (None/NaN/NA).
MISSING_LABEL = "(missing)"


def build_color_array(col, idx, data=None):
    """Map column *col* (rows *idx*) to integer colour codes for plotly.

    Parameters
    ----------
    col : str
        Column of *data* to colour by.
    idx : array-like
        Row labels to take, in plotting order.
    data : pandas.DataFrame, optional
        Frame to read from; defaults to the notebook-global ``df``.

    Returns
    -------
    (colors, colorscale, tickvals, ticktext)
        ``colors`` is an int array with one code per row. ``colorscale`` is a
        plotly ``[[pos, colour], ...]`` list whose k-th stop sits at code k
        when the colour range is 0..n-1. ``tickvals``/``ticktext`` label the
        colour bar.

    Boolean / 0-1 columns with no missing values are coded 0/1 and labelled
    "False"/"True". Everything else is treated as categorical: values are
    stringified, sorted and coded 0..n-1. Missing values get their own
    ``MISSING_LABEL`` category rather than an out-of-range code, which would
    otherwise shift the colour range and misalign every category.
    """
    frame = df if data is None else data
    series = frame.loc[idx, col]
    has_missing = bool(series.isna().any())
    unique_vals = series.dropna().unique()

    # Boolean / binary (ensemble_flag, is_exact_duplicate). Only valid when
    # nothing is missing: astype(int) cannot represent NaN/None.
    is_binary = series.dtype == bool or set(unique_vals) <= {True, False, 0, 1}
    if is_binary and not has_missing:
        colors = series.astype(int).values
        colorscale = [[0.0, "#4c72b0"], [1.0, "#dd8452"]]
        return colors, colorscale, [0, 1], ["False", "True"]

    # Categorical (topic_label, toxic_level): one label per row, with missing
    # values folded into their own category so every row gets a valid code.
    labels = series.astype(str).where(series.notna(), MISSING_LABEL)
    cats = sorted(labels.unique())
    cat_map = {c: i for i, c in enumerate(cats)}
    colors = labels.map(cat_map).astype(int).values
    n = len(cats)
    if n == 1:
        # Plotly colorscales must have stops at both 0 and 1.
        colorscale = [[0.0, PALETTE[0]], [1.0, PALETTE[0]]]
    else:
        colorscale = [[i / (n - 1), PALETTE[i % len(PALETTE)]] for i in range(n)]
    return colors, colorscale, list(range(n)), cats


def apply_log(vals, do_log):
    """Return *vals* as a float array, optionally log10-transformed.

    Parameters
    ----------
    vals : array-like
        Values to convert.
    do_log : bool
        When truthy, take log10 of strictly positive entries; every other
        entry (zero, negative, NaN) becomes NaN so plotly leaves it out.

    The logarithm is only evaluated where the input is positive, so no
    divide-by-zero / invalid-value RuntimeWarnings are raised for the
    excluded entries.
    """
    arr = np.asarray(vals, dtype=float)
    if not do_log:
        return arr
    result = np.full(arr.shape, np.nan)
    np.log10(arr, out=result, where=arr > 0)
    return result


# Seed for the plotting subsample. A fresh generator is created from it on
# every call, so redrawing never silently swaps in a different sample.
_SAMPLE_SEED = 42


def get_sample_idx(n, seed=_SAMPLE_SEED, data=None):
    """Return row labels of a reproducible random subsample.

    Parameters
    ----------
    n : int
        Number of rows wanted; ``0`` (or any value >= the number of rows)
        means "all rows".
    seed : int
        Seed for the row permutation. Because the generator is rebuilt from
        it on each call, the same *n* always yields the same rows, and a
        smaller sample is always a subset of a larger one.
    data : pandas.DataFrame, optional
        Frame to sample from; defaults to the notebook-global ``df``.

    Returns
    -------
    numpy.ndarray
        Index labels of *data*, suitable for ``data.loc``.
    """
    frame = df if data is None else data
    labels = frame.index.values
    if n == 0 or n >= len(labels):
        return labels
    # Prefix of one fixed permutation -> nested samples (1 k ⊂ 5 k ⊂ 10 k ...).
    order = np.random.default_rng(seed).permutation(len(labels))
    return labels[order[:n]]


# ────────────────────────────────────────────────────────────────────────
# Widgets
# ────────────────────────────────────────────────────────────────────────

# Shared widget styling: description column width and three control widths.
ws = {"description_width": "80px"}
L_wide, L_mid, L_narr = (widgets.Layout(width=w) for w in ("290px", "230px", "120px"))

# Axis pickers over every plottable feature/score column.
x_dd, y_dd = (
    widgets.Dropdown(options=FEAT_COLS, value=default, description=label,
                     style=ws, layout=L_wide)
    for label, default in (("X axis:", "ppl_mean"), ("Y axis:", "emb_centroid_dist"))
)

# Colour-by choices: the ensemble flag is always offered, the others only when
# their column exists in df. Insertion order is the dropdown order.
COLOR_MAP = {"Ensemble flag": "ensemble_flag"}
COLOR_MAP.update(
    (label, column)
    for label, column in (
        ("Exact duplicate", "is_exact_duplicate"),
        ("Topic category", "topic_label"),
        ("Toxicity level", "toxic_level"),
    )
    if column in df.columns
)

color_dd = widgets.Dropdown(options=list(COLOR_MAP), value="Ensemble flag",
                            description="Color by:", style=ws, layout=L_mid)
logx_cb, logy_cb, jitter_cb = (
    widgets.Checkbox(value=False, description=caption, layout=L_narr)
    for caption in ("Log X", "Log Y", "Jitter")
)
# (label, number of rows) pairs; 0 means "plot every row".
_sample_choices = [("All", 0), ("20 k", 20_000), ("10 k", 10_000),
                   ("5 k", 5_000), ("1 k", 1_000)]
sample_dd = widgets.Dropdown(options=_sample_choices, value=10_000,
                             description="Sample:", style=ws, layout=L_mid)

# Output areas: a one-line status above the plot and a scrollable post panel.
status_out = widgets.Output()
text_out = widgets.Output(layout=widgets.Layout(
    border="1px solid #ccc", padding="10px",
    max_height="500px", overflow_y="auto", width="100%"))

# ────────────────────────────────────────────────────────────────────────
# Initial figure
# ────────────────────────────────────────────────────────────────────────

_state = {}   # mutable state shared with the callbacks; "idx" = plotted row labels

idx0 = get_sample_idx(sample_dd.value)
_state["idx"] = idx0

# Take the starting axes and colouring from the controls instead of repeating
# their defaults here, so the first render always matches the widgets.
x0_col, y0_col = x_dd.value, y_dd.value
color0_label = color_dd.value
c0, cs0, tv0, tt0 = build_color_array(COLOR_MAP[color0_label], idx0)

scatter = go.Scatter(
    x=df.loc[idx0, x0_col].values,
    y=df.loc[idx0, y0_col].values,
    mode="markers",
    marker=dict(
        size=4, opacity=0.5,
        color=c0, colorscale=cs0, showscale=True,
        # Pin the colour range to the codes 0..n-1 so each code sits exactly
        # on its colorscale stop (and the 0/1 ticks stay meaningful even when
        # the sample holds only one value).
        cmin=0, cmax=max(len(tv0) - 1, 1),
        colorbar=dict(
            title=color0_label,
            tickvals=tv0, ticktext=tt0, thickness=15,
        ),
    ),
    # customdata carries the df row label so callbacks can look posts up.
    customdata=idx0,
    hovertemplate="Post %{customdata}<br>X: %{x:.4g}<br>Y: %{y:.4g}<extra></extra>",
    selected=dict(marker=dict(color="red", size=9, opacity=1.0)),
    unselected=dict(marker=dict(opacity=0.08)),
)

fig = go.FigureWidget(
    data=[scatter],
    layout=go.Layout(
        title="Click a point or lasso-select to view post text",
        xaxis=dict(title=x0_col, zeroline=False),
        yaxis=dict(title=y0_col, zeroline=False),
        width=820, height=520,
        dragmode="lasso",
        # A click both emits a click event and selects the clicked point.
        clickmode="event+select",
        margin=dict(l=55, r=100, t=45, b=55),
    ),
)


# ────────────────────────────────────────────────────────────────────────
# Callbacks
# ────────────────────────────────────────────────────────────────────────

def refresh(change=None):
    """Rebuild the scatter trace from the current control values.

    Registered with ``observe(..., names="value")`` on every control; *change*
    (the traitlets change notification) is accepted but unused.

    Side effects: updates ``_state["idx"]``, the trace data, colours and axis
    titles of ``fig``, and the ``status_out`` line.
    """
    idx = get_sample_idx(sample_dd.value)
    prev_idx = _state.get("idx")
    _state["idx"] = idx
    # Plotly stores a selection as positions within the trace; those positions
    # only still refer to the same posts if the plotted rows are unchanged.
    rows_changed = prev_idx is None or not np.array_equal(prev_idx, idx)

    x_col, y_col = x_dd.value, y_dd.value
    xv = apply_log(df.loc[idx, x_col].values, logx_cb.value)
    yv = apply_log(df.loc[idx, y_col].values, logy_cb.value)

    if jitter_cb.value:
        # Fixed seed so the jitter pattern is identical on every redraw; the
        # noise scale is 0.5 % of each axis' standard deviation.
        jrng = np.random.default_rng(99)
        xv = xv + jrng.normal(0, np.nanstd(xv) * 0.005, len(xv))
        yv = yv + jrng.normal(0, np.nanstd(yv) * 0.005, len(yv))

    col = COLOR_MAP[color_dd.value]
    colors, cscale, tvals, ttexts = build_color_array(col, idx)

    xl = f"log\u2081\u2080({x_col})" if logx_cb.value else x_col
    yl = f"log\u2081\u2080({y_col})" if logy_cb.value else y_col

    with fig.batch_update():
        trace = fig.data[0]
        trace.x = xv
        trace.y = yv
        trace.customdata = idx
        trace.marker.color = colors
        trace.marker.colorscale = cscale
        # Pin the colour range to the codes 0..n-1 so each code lands on its
        # own colorscale stop instead of following the sample's min/max.
        trace.marker.cmin = 0
        trace.marker.cmax = max(len(tvals) - 1, 1)
        trace.marker.colorbar.tickvals = tvals
        trace.marker.colorbar.ticktext = ttexts
        trace.marker.colorbar.title.text = color_dd.value
        if rows_changed:
            # Drop the stale selection rather than highlight unrelated posts.
            trace.selectedpoints = None
        fig.layout.xaxis.title.text = xl
        fig.layout.yaxis.title.text = yl

    with status_out:
        status_out.clear_output()
        jitter_note = "  [jitter ON]" if jitter_cb.value else ""
        print(f"Showing {len(idx):,} points.{jitter_note}  "
              f"Click or lasso-select to read posts.")


def wrap_post_text(text, width=74, limit=2000):
    """Return *text* truncated to *limit* characters and wrapped to *width*.

    Each original line is wrapped on its own, so the post's line and
    paragraph breaks survive; a single ``textwrap.fill`` call would turn
    every newline into a space. Missing text (None/NaN) renders as "".
    """
    if pd.api.types.is_scalar(text) and pd.isna(text):
        text = ""
    text = str(text)
    if len(text) > limit:
        text = text[:limit] + " ... [truncated]"
    return "\n".join(textwrap.fill(line, width=width) for line in text.splitlines())


def show_posts(point_inds, customdata):
    """Print the selected posts into the ``text_out`` panel.

    Parameters
    ----------
    point_inds : sequence of int
        Positions within the scatter trace, as reported by plotly callbacks.
    customdata : array-like
        The trace's ``customdata``; ``customdata[pos]`` is the ``df`` row
        label of the point at position *pos*.

    At most 8 posts are printed, each with its metadata, the current x/y
    feature values and its text; any remainder is summarised as a count.
    """
    max_shown = 8
    # The axes are the same for every post, so read the dropdowns once.
    x_col, y_col = x_dd.value, y_dd.value
    with text_out:
        text_out.clear_output()
        n = len(point_inds)
        if n == 0:
            print("No points selected.")
            return
        show = min(n, max_shown)
        print(f"Selected {n:,} post(s) — showing first {show}:\n")
        for rank, pi in enumerate(point_inds[:show], 1):
            df_i = int(customdata[pi])
            row  = df.loc[df_i]

            x_raw = row.get(x_col, float("nan"))
            y_raw = row.get(y_col, float("nan"))
            cat   = row.get("topic_label",      "?")
            tox   = row.get("toxic_level",       "?")
            sub   = row.get("submolt_name",      "?")
            flag  = row.get("ensemble_flag",     "?")
            is_dup = row.get("is_exact_duplicate", "?")
            dup_cnt = row.get("duplicate_count",  "?")
            dup_grp = row.get("duplicate_group_id", "?")

            print(f"{'─' * 74}")
            print(f"[{rank}/{show}]  id = {row['id']}")
            print(f"  Category: {cat}   Toxicity: {tox}   "
                  f"Submolt: {sub}   Flagged: {flag}")
            print(f"  Exact duplicate: {is_dup}   "
                  f"Copies in dataset: {dup_cnt}   Group ID: {dup_grp}")
            try:
                print(f"  {x_col}: {float(x_raw):.4g}   {y_col}: {float(y_raw):.4g}")
            except (TypeError, ValueError):
                # Non-numeric value: show it verbatim.
                print(f"  {x_col}: {x_raw}   {y_col}: {y_raw}")
            print()
            print(wrap_post_text(row.get(TEXT_COL, "")))
            print()
        if n > show:
            print(f"  … and {n - show} more posts not shown.")


def on_click(trace, points, state_):
    """Plotly click handler: show the clicked post(s) in the text panel."""
    show_posts(point_inds=points.point_inds, customdata=trace.customdata)


def on_select(trace, points, selector):
    """Plotly box/lasso selection handler: show the selected posts."""
    show_posts(point_inds=points.point_inds, customdata=trace.customdata)


# The handlers attach to the single trace object; refresh() mutates that trace
# in place, so they stay attached across data updates.
_trace = fig.data[0]
_trace.on_click(on_click)
_trace.on_selection(on_select)

# Any change to any control triggers a full redraw.
for ctrl in (x_dd, y_dd, logx_cb, logy_cb, color_dd, jitter_cb, sample_dd):
    ctrl.observe(refresh, names="value")

# ────────────────────────────────────────────────────────────────────────
# Display
# ────────────────────────────────────────────────────────────────────────

row1 = widgets.HBox(children=[x_dd, y_dd])
row2 = widgets.HBox(children=[color_dd, logx_cb, logy_cb, jitter_cb, sample_dd])

# Seed the status line for the first render; refresh() rewrites it later.
with status_out:
    print(f"Showing {len(idx0):,} points.  Click or lasso-select to read posts.")

# Vertical layout: controls, status, plot, then the selected-posts panel.
panel_children = [
    row1,
    row2,
    status_out,
    fig,
    widgets.HTML("<b>Selected Posts</b>"),
    text_out,
]
display(widgets.VBox(children=panel_children))

VBox(children=(HBox(children=(Dropdown(description='X axis:', index=13, layout=Layout(width='290px'), options=…