# layerごとの可視化

In [37]:
# pip install pandas plotly
from pathlib import Path
import json
import re
from typing import Iterable, Optional, Sequence, Union, Tuple
import pandas as pd
import plotly.express as px

def _resolve_distance_dir(base_path: Union[str, Path], model_name: str, probe: str = "distance") -> Path:
    """
    Resolve the directory that directly contains n_comp* subdirectories.
    If base_path already points to that directory, return it as-is; otherwise,
    assume layout: <base_path>/<model_name>/metrics/<probe>/
    """
    p = Path(base_path).expanduser().resolve()
    # Heuristic: if there are any n_comp* subdirs here, we are already at the target
    has_ncomp_here = any(child.is_dir() and child.name.startswith("n_comp") for child in p.iterdir() if child.is_dir())
    if has_ncomp_here:
        return p
    # Else, construct from conventional layout
    candidate = p / model_name / "metrics" / probe
    if not candidate.exists():
        raise FileNotFoundError(f"Cannot find metrics directory: {candidate}")
    return candidate

def _to_int_if_possible(s: str) -> Union[int, str]:
    """Extract the first integer from a string like 'n_comp300' -> 300; fallback to original string."""
    m = re.search(r"\d+", s)
    return int(m.group(0)) if m else s

def _collect_probe_df(distance_dir: Path,
                      metric_key: str = "uuas",
                      n_comp_select: Optional[Sequence[int]] = None) -> pd.DataFrame:
    """Walk n_comp* directories and layer*.json files; return a tidy DataFrame."""
    rows = []
    comp_dirs = sorted(
        [p for p in distance_dir.glob("n_comp*") if p.is_dir()],
        key=lambda p: _to_int_if_possible(p.name)
    )

    for comp_dir in comp_dirs:
        n_comp_val = _to_int_if_possible(comp_dir.name)

        # Optional filter by n_comp
        if n_comp_select is not None:
            if isinstance(n_comp_val, int):
                if n_comp_val not in set(n_comp_select):
                    continue
            else:
                if str(n_comp_val) not in {str(x) for x in n_comp_select}:
                    continue

        for jf in comp_dir.glob("layer*.json"):
            lm = re.search(r"layer(\d+)\.json$", jf.name)
            if not lm:
                continue
            layer_idx = int(lm.group(1))
            with open(jf, "r") as f:
                data = json.load(f)

            if metric_key not in data:
                # Skip files that don't contain the requested metric
                continue

            rows.append({
                "n_comp": n_comp_val,
                "layer": layer_idx,
                metric_key: float(data[metric_key])
            })

    if not rows:
        raise RuntimeError(f"No data found under: {distance_dir} (metric='{metric_key}', n_comp_select={n_comp_select})")

    df = pd.DataFrame(rows)
    # Normalize layer to [0, 1] range based on max observed layer index
    max_layer = df["layer"].max()
    df["normalized_layer"] = df["layer"] / max_layer if max_layer > 0 else df["layer"]

    # Sort for prettier lines
    df = df.sort_values(["n_comp", "layer"]).reset_index(drop=True)

    # Make sure n_comp sorts numerically when possible
    if df["n_comp"].map(type).eq(int).all():
        df["n_comp"] = pd.to_numeric(df["n_comp"], errors="ignore")

    return df

def _normalize_range(r: Optional[Tuple[float, float]]) -> Optional[Tuple[float, float]]:
    """Validate and coerce (min, max) ranges; return None if r is None."""
    if r is None:
        return None
    if not isinstance(r, (tuple, list)) or len(r) != 2:
        raise ValueError("Range must be a 2-tuple like (min, max).")
    lo, hi = float(r[0]), float(r[1])
    if lo == hi:
        raise ValueError("Range min and max must be different.")
    return (lo, hi)

def _normalize_fig_size(fig_size: Optional[Tuple[int, int]],
                        width: Optional[int],
                        height: Optional[int]) -> Tuple[Optional[int], Optional[int]]:
    """
    Resolve figure size preference.
    Priority: fig_size (tuple) > width/height (individual) > None (Plotly auto).
    """
    if fig_size is not None:
        if (not isinstance(fig_size, (tuple, list))) or len(fig_size) != 2:
            raise ValueError("fig_size must be a tuple like (width, height).")
        w, h = int(fig_size[0]), int(fig_size[1])
        if w <= 0 or h <= 0:
            raise ValueError("fig_size width/height must be positive integers.")
        return w, h

    # Fall back to width/height if provided
    w = int(width) if width is not None else None
    h = int(height) if height is not None else None
    if w is not None and w <= 0:
        raise ValueError("width must be a positive integer.")
    if h is not None and h <= 0:
        raise ValueError("height must be a positive integer.")
    return w, h

def plot_uuas_vs_normalized_layer(
    base_path: Union[str, Path],
    model_name: str = "gpt2",
    dims: Optional[Iterable[int]] = None,
    probe: str = "distance",
    metric_key: str = "uuas",
    title_prefix: Optional[str] = None,
    save_html: Optional[Union[str, Path]] = None,
    show_figure: bool = True,
    x_range: Optional[Tuple[float, float]] = None,
    y_range: Optional[Tuple[float, float]] = None,
    fig_size: Optional[Tuple[int, int]] = None,   # (width, height) in pixels
    width: Optional[int] = None,                  # alternative: individual width
    height: Optional[int] = None,                 # alternative: individual height
):
    """
    Create a Plotly line chart for y=metric_key (default 'uuas') vs x=Normalized Layer.

    Parameters
    ----------
    base_path : str | Path
        Directory that already contains n_comp* subdirectories (e.g., .../metrics/distance)
        OR the project root such that the data lives at <base_path>/<model_name>/metrics/<probe>/
    model_name : str
        Model name for labeling (e.g., "gpt2").
    dims : Iterable[int] | None
        Optional subset of n_comp values to include (e.g., [5, 50, 100, 300, 768]).
    probe : str
        Probe/metric group folder name (default "distance").
    metric_key : str
        Which metric from JSON to plot on y-axis (default "uuas").
    title_prefix : str | None
        Optional title prefix; if None, a sensible default is constructed.
    save_html : str | Path | None
        If given, save the interactive figure to this HTML file.
    show_figure : bool
        If True, display the figure via fig.show().
    x_range : (float, float) | None
        Axis range for x (Normalized Layer). Example: (0.2, 0.8).
    y_range : (float, float) | None
        Axis range for y (e.g., UUAS). Example: (0.7, 0.95).
    fig_size : (int, int) | None
        Figure size in pixels as (width, height). Takes precedence over width/height.
    width : int | None
        Figure width in pixels when fig_size is not provided.
    height : int | None
        Figure height in pixels when fig_size is not provided.

    Returns
    -------
    plotly.graph_objs._figure.Figure
        The Plotly figure object.
    """
    # Resolve directory that holds n_comp* subdirs
    distance_dir = _resolve_distance_dir(base_path, model_name=model_name, probe=probe)

    # Build dataframe
    df = _collect_probe_df(distance_dir, metric_key=metric_key, n_comp_select=list(dims) if dims is not None else None)

    # Build title
    if title_prefix is None:
        title_prefix = f"{metric_key.upper()} vs Normalized Layer"
    title = f"{title_prefix} — model={model_name}, probe={probe}"
    # if dims is not None:
    #     title += f", n_comp∈{sorted(list(dims))}"

    # Plot
    fig = px.line(
        df,
        x="normalized_layer",
        y=metric_key,
        color="n_comp",
        markers=True,
        title=title,
        labels={"normalized_layer": "Normalized Layer", metric_key: metric_key, "n_comp": "n_comp"},
        hover_data={"layer": True, "normalized_layer": ":.3f", metric_key: ":.3f"}
    )

    # Base styling
    fig.update_layout(
        template="plotly_white",
        legend_title_text="n_comp",
        xaxis=dict(tickmode="array", tickvals=[0.0, 0.25, 0.5, 0.75, 1.0], tickformat=".2f")
    )

    # Axis ranges
    xr = _normalize_range(x_range)
    yr = _normalize_range(y_range)
    if xr is not None:
        fig.update_xaxes(range=[xr[0], xr[1]])
    if yr is not None:
        fig.update_yaxes(range=[yr[0], yr[1]])

    # Figure size (width/height in pixels)
    w, h = _normalize_fig_size(fig_size, width, height)
    if w is not None or h is not None:
        fig.update_layout(width=w, height=h, margin=dict(l=60, r=20, t=60, b=60))

    # Save & show
    if save_html is not None:
        save_path = Path(save_html).expanduser()
        save_path.parent.mkdir(parents=True, exist_ok=True)
        fig.write_html(str(save_path), include_plotlyjs="cdn")

    return fig


In [38]:
def plot_metric_vs_layer_across_models(
    base_path: Union[str, Path],
    models: Union[Sequence[str], str],
    n_comp: Union[int, str],
    probe: str = "distance",
    metric_key: str = "uuas",
    label_name: Optional[Union[Sequence[str], str]] = None,  # <-- NEW
    title_prefix: Optional[str] = None,
    save_html: Optional[Union[str, Path]] = None,
    show_figure: bool = True,
    x_range: Optional[Tuple[float, float]] = None,
    y_range: Optional[Tuple[float, float]] = None,
    fig_size: Optional[Tuple[int, int]] = None,   # (width, height) in pixels
    width: Optional[int] = None,                  # alternative: individual width
    height: Optional[int] = None,                 # alternative: individual height
):
    """
    Compare y=metric_key (default 'uuas') vs x=Normalized Layer across multiple models,
    while fixing the dimensionality (n_comp) to a single value.

    Parameters
    ----------
    base_path : str | Path
        Project root containing <model>/metrics/<probe>/n_comp*/layer*.json, OR
        a concrete metrics directory if comparing a single model.
    models : Sequence[str] | str
        One or more model identifiers, e.g. ["gpt2", "meta-llama/Meta-Llama-3-8B"].
    n_comp : int | str
        A single n_comp to include (e.g., 300). This is fixed across all models.
    probe : str
        Probe/metric group folder name (default "distance").
    metric_key : str
        Which metric from JSON to plot on y-axis (default "uuas").
    label_name : Sequence[str] | str | None
        Optional legend labels corresponding to `models`.
        - If a list is provided, its length must match `models`.
        - If a string is provided, `models` must be a single model.
        - If None, model identifiers are used as labels.
    title_prefix : str | None
        Optional title prefix; if None, a sensible default is constructed.
    save_html : str | Path | None
        If given, save the interactive figure to this HTML file.
    show_figure : bool
        If True, display the figure via fig.show().
    x_range, y_range : (float, float) | None
        Optional axis ranges.
    fig_size, width, height :
        Figure sizing options. fig_size takes precedence.

    Returns
    -------
    plotly.graph_objs._figure.Figure
        The Plotly figure object.
    """
    # Coerce models to a list for convenience
    if isinstance(models, str):
        models = [models]

    # Build a label map from model -> label
    if label_name is None:
        label_map = {m: m for m in models}
    else:
        if isinstance(label_name, str):
            if len(models) != 1:
                raise ValueError("Single label_name string is only allowed when a single model is provided.")
            label_map = {models[0]: label_name}
        else:
            labels = list(label_name)
            if len(labels) != len(models):
                raise ValueError("label_name must have the same length as models.")
            label_map = {m: l for m, l in zip(models, labels)}

    # Optional safety check: base_path should be a common root when comparing multiple models
    if len(models) > 1:
        p = Path(base_path).expanduser().resolve()
        try:
            has_ncomp_here = any(
                child.is_dir() and child.name.startswith("n_comp")
                for child in p.iterdir() if child.is_dir()
            )
        except FileNotFoundError:
            has_ncomp_here = False
        if has_ncomp_here:
            raise ValueError(
                "For multi-model comparison, set base_path to the common parent that contains "
                "<model>/metrics/<probe>/..., not a specific 'distance' directory."
            )

    # Collect data for each model at the fixed n_comp
    dfs = []
    for model in models:
        # Resolve the directory that contains n_comp* for this model
        distance_dir = _resolve_distance_dir(base_path, model_name=model, probe=probe)

        # Build a DataFrame filtered to the requested n_comp
        try:
            df_m = _collect_probe_df(
                distance_dir,
                metric_key=metric_key,
                n_comp_select=[n_comp]
            )
        except RuntimeError as e:
            # Add model name context to make debugging easier
            raise RuntimeError(f"Model '{model}': {e}")

        # Annotate with model (raw id) and human-friendly label used in legend
        df_m["model"] = model
        df_m["label"] = label_map[model]
        dfs.append(df_m)

    if not dfs:
        raise RuntimeError("No data aggregated. Check model names, base_path, and n_comp.")

    df_all = pd.concat(dfs, ignore_index=True)

    # Build title
    if title_prefix is None:
        title_prefix = f"{metric_key.upper()} vs Normalized Layer"
    title = f"{title_prefix} — n_comp={n_comp}, probe={probe}"

    # Use 'label' for color (legend). Keep 'model' in hover for traceability.
    color_col = "label"
    fig = px.line(
        df_all,
        x="normalized_layer",
        y=metric_key,
        color=color_col,
        markers=True,
        title=title,
        labels={"normalized_layer": "Normalized Layer", metric_key: metric_key, color_col: "series"},
        hover_data={"layer": True, "normalized_layer": ":.3f", metric_key: ":.3f", "model": True, "n_comp": True}
    )

    # Base styling
    fig.update_layout(
        template="plotly_white",
        legend_title_text="series",
        xaxis=dict(tickmode="array", tickvals=[0.0, 0.25, 0.5, 0.75, 1.0], tickformat=".2f")
    )

    # Axis ranges
    xr = _normalize_range(x_range)
    yr = _normalize_range(y_range)
    if xr is not None:
        fig.update_xaxes(range=[xr[0], xr[1]])
    if yr is not None:
        fig.update_yaxes(range=[yr[0], yr[1]])

    # Figure size (width/height in pixels)
    w, h = _normalize_fig_size(fig_size, width, height)
    if w is not None or h is not None:
        fig.update_layout(width=w, height=h, margin=dict(l=60, r=20, t=60, b=60))

    # Save & show
    if save_html is not None:
        save_path = Path(save_html).expanduser()
        save_path.parent.mkdir(parents=True, exist_ok=True)
        fig.write_html(str(save_path), include_plotlyjs="cdn")
        
    return fig


In [None]:
model_name = "gpt2"
fig = plot_uuas_vs_normalized_layer(
    base_path=f"/home/masaki/hierarchical-repr/EntityTree/scripts/structural_probe/probe_results/{model_name}/metrics/distance",
    model_name=model_name,
    dims=[5, 50, 100, 300, 768],
    fig_size=(800, 400),
)
fig.show()


fig = plot_uuas_vs_normalized_layer(
    base_path=f"/home/masaki/hierarchical-repr/EntityTree/scripts/structural_probe/probe_results/{model_name}/metrics/distance",
    model_name=model_name,
    dims=[50, 100],
    fig_size=(800, 400),
)
fig.show()



errors='ignore' is deprecated and will raise in a future version. Use to_numeric without passing `errors` and catch exceptions explicitly instead




errors='ignore' is deprecated and will raise in a future version. Use to_numeric without passing `errors` and catch exceptions explicitly instead



In [8]:
model_name = "meta-llama/Meta-Llama-3-8B"
fig = plot_uuas_vs_normalized_layer(
    base_path=f"/home/masaki/hierarchical-repr/EntityTree/scripts/structural_probe/probe_results/{model_name}/metrics/distance",
    model_name=model_name,
    dims=[5, 50, 100, 300, 768],
    fig_size=(800, 400),
)
fig.show()


fig = plot_uuas_vs_normalized_layer(
    base_path=f"/home/masaki/hierarchical-repr/EntityTree/scripts/structural_probe/probe_results/{model_name}/metrics/distance",
    model_name=model_name,
    dims=[768],
    fig_size=(800, 400),
)
fig.show()



errors='ignore' is deprecated and will raise in a future version. Use to_numeric without passing `errors` and catch exceptions explicitly instead




errors='ignore' is deprecated and will raise in a future version. Use to_numeric without passing `errors` and catch exceptions explicitly instead



In [40]:
# Compare GPT-2 and Llama 3 8B 
fig = plot_metric_vs_layer_across_models(
    base_path="/home/masaki/hierarchical-repr/EntityTree/scripts/structural_probe/probe_results",
    models=["gpt2", "meta-llama/Meta-Llama-3-8B"],
    label_name=["GPT-2", "Llama 3 8B"],
    n_comp=300,
    metric_key="uuas",
    fig_size=(800, 400),
    probe="distance"
)
fig.show()


fig = plot_metric_vs_layer_across_models(
    base_path="/home/masaki/hierarchical-repr/EntityTree/scripts/structural_probe/probe_results",
    models=["gpt2", "meta-llama/Meta-Llama-3-8B"],
    label_name=["GPT-2", "Llama 3 8B"],
    n_comp=768,
    metric_key="uuas",
    fig_size=(800, 400),
    probe="distance"
)
fig.show()



errors='ignore' is deprecated and will raise in a future version. Use to_numeric without passing `errors` and catch exceptions explicitly instead


errors='ignore' is deprecated and will raise in a future version. Use to_numeric without passing `errors` and catch exceptions explicitly instead




errors='ignore' is deprecated and will raise in a future version. Use to_numeric without passing `errors` and catch exceptions explicitly instead


errors='ignore' is deprecated and will raise in a future version. Use to_numeric without passing `errors` and catch exceptions explicitly instead



# 次元数ごとの可視化

In [23]:
# pip install pandas plotly
from pathlib import Path
import json
import re
from typing import Iterable, Optional, Sequence, Union, Tuple, Literal
import pandas as pd
import plotly.express as px

# ---------- helpers ----------
def _resolve_distance_dir(base_path: Union[str, Path], model_name: str, probe: str = "distance") -> Path:
    """
    Resolve the directory that directly contains n_comp* subdirectories.
    If base_path already points to that directory, return it as-is; otherwise,
    assume layout: <base_path>/<model_name>/metrics/<probe>/
    """
    p = Path(base_path).expanduser().resolve()
    # Heuristic: if there are any n_comp* subdirs here, we are already at the target
    has_ncomp_here = any(child.is_dir() and child.name.startswith("n_comp") for child in p.iterdir() if child.is_dir())
    if has_ncomp_here:
        return p
    # Else, construct from conventional layout
    candidate = p / model_name / "metrics" / probe
    if not candidate.exists():
        raise FileNotFoundError(f"Cannot find metrics directory: {candidate}")
    return candidate

def _to_int_if_possible(s: str) -> Union[int, str]:
    """Extract the first integer from a string like 'n_comp300' -> 300; fallback to original string."""
    m = re.search(r"\d+", s)
    return int(m.group(0)) if m else s

def _collect_probe_df(distance_dir: Path,
                      metric_key: str = "uuas",
                      n_comp_select: Optional[Sequence[int]] = None,
                      layer_select: Optional[Sequence[int]] = None) -> pd.DataFrame:
    """Walk n_comp* directories and layer*.json files; return a tidy DataFrame."""
    rows = []
    comp_dirs = sorted(
        [p for p in distance_dir.glob("n_comp*") if p.is_dir()],
        key=lambda p: _to_int_if_possible(p.name)
    )

    for comp_dir in comp_dirs:
        n_comp_val = _to_int_if_possible(comp_dir.name)

        # Optional filter by n_comp
        if n_comp_select is not None:
            if isinstance(n_comp_val, int):
                if n_comp_val not in set(n_comp_select):
                    continue
            else:
                if str(n_comp_val) not in {str(x) for x in n_comp_select}:
                    continue

        for jf in comp_dir.glob("layer*.json"):
            lm = re.search(r"layer(\d+)\.json$", jf.name)
            if not lm:
                continue
            layer_idx = int(lm.group(1))

            # Optional filter by layer
            if layer_select is not None and layer_idx not in set(layer_select):
                continue

            with open(jf, "r") as f:
                data = json.load(f)

            if metric_key not in data:
                # Skip files that don't contain the requested metric
                continue

            rows.append({
                "n_comp": n_comp_val,
                "layer": layer_idx,
                metric_key: float(data[metric_key])
            })

    if not rows:
        raise RuntimeError(
            f"No data found under: {distance_dir} (metric='{metric_key}', "
            f"n_comp_select={n_comp_select}, layer_select={layer_select})"
        )

    df = pd.DataFrame(rows)

    # Ensure n_comp is numeric when possible (so x-axis is numeric, not categorical)
    df["n_comp"] = df["n_comp"].apply(lambda v: int(v) if isinstance(v, int) or str(v).isdigit() else v)

    return df

def _normalize_range(r: Optional[Tuple[float, float]]) -> Optional[Tuple[float, float]]:
    """Validate and coerce (min, max) ranges; return None if r is None."""
    if r is None:
        return None
    if not isinstance(r, (tuple, list)) or len(r) != 2:
        raise ValueError("Range must be a 2-tuple like (min, max).")
    lo, hi = float(r[0]), float(r[1])
    if lo == hi:
        raise ValueError("Range min and max must be different.")
    return (lo, hi)

def _normalize_fig_size(fig_size: Optional[Tuple[int, int]],
                        width: Optional[int],
                        height: Optional[int]) -> Tuple[Optional[int], Optional[int]]:
    """
    Resolve figure size preference.
    Priority: fig_size (tuple) > width/height (individual) > None (Plotly auto).
    """
    if fig_size is not None:
        if (not isinstance(fig_size, (tuple, list))) or len(fig_size) != 2:
            raise ValueError("fig_size must be a tuple like (width, height).")
        w, h = int(fig_size[0]), int(fig_size[1])
        if w <= 0 or h <= 0:
            raise ValueError("fig_size width/height must be positive integers.")
        return w, h
    w = int(width) if width is not None else None
    h = int(height) if height is not None else None
    if w is not None and w <= 0:
        raise ValueError("width must be a positive integer.")
    if h is not None and h <= 0:
        raise ValueError("height must be a positive integer.")
    return w, h

# ---------- main plotting ----------
def plot_uuas_vs_probe_rank_dims(
    base_path: Union[str, Path],
    model_name: str = "gpt2",
    dims: Optional[Iterable[int]] = None,              # subset of n_comp values (e.g., [5, 50, 100, 300, 768])
    layers: Optional[Iterable[int]] = None,            # subset of layers (e.g., [0, 6, 12])
    probe: str = "distance",
    metric_key: str = "uuas",
    color_by: Literal["layer", "n_comp"] = "layer",    # default: lines per layer across dims
    title_prefix: Optional[str] = None,
    save_html: Optional[Union[str, Path]] = None,
    show_figure: bool = True,
    x_range: Optional[Tuple[float, float]] = None,     # numeric range over dims (e.g., (5, 768))
    y_range: Optional[Tuple[float, float]] = None,     # e.g., (0.7, 0.95)
    fig_size: Optional[Tuple[int, int]] = None,        # (width, height) in px
    width: Optional[int] = None,
    height: Optional[int] = None,
):
    """
    Plot y=metric_key (default 'uuas') vs x=Probe Rank, where "Probe Rank" means the actual
    probe dimension (n_comp). For example, dims=[5, 50, 100, 300, 768] => x values are 5, 50, 100, 300, 768.

    Lines are colored by 'layer' by default so you can see, for each layer, how the metric changes
    as the probe dimension increases.
    """
    # Resolve directory and load table
    distance_dir = _resolve_distance_dir(base_path, model_name=model_name, probe=probe)
    df = _collect_probe_df(
        distance_dir,
        metric_key=metric_key,
        n_comp_select=list(dims) if dims is not None else None,
        layer_select=list(layers) if layers is not None else None
    )

    # Ensure numeric and sort along x (n_comp)
    # If any n_comp is non-numeric, they will appear as categories and may break the numeric axis.
    if not pd.api.types.is_numeric_dtype(df["n_comp"]):
        # Best effort: coerce to numeric where possible
        df["n_comp_num"] = pd.to_numeric(df["n_comp"], errors="coerce")
        if df["n_comp_num"].notna().all():
            df["n_comp"] = df["n_comp_num"].astype(int)
            df.drop(columns=["n_comp_num"], inplace=True)
        else:
            raise ValueError("Found non-numeric n_comp values that cannot be plotted on a numeric x-axis.")

    df = df.sort_values(by=["layer", "n_comp"]).reset_index(drop=True)

    # Title
    if title_prefix is None:
        title_prefix = f"{metric_key.upper()} vs Probe Rank (Probe Rank = n_comp)"
    title = f"{title_prefix} — model={model_name}, probe={probe}"
    # if dims is not None:
    #     title += f", n_comp∈{sorted(list(dims))}"
    if layers is not None:
        title += f", layers∈{sorted(list(layers))}"

    # Plot
    fig = px.line(
        df,
        x="n_comp",
        y=metric_key,
        color=color_by,
        markers=True,
        title=title,
        labels={
            "n_comp": "Probe Rank",     # label name is "Probe Rank" per your definition
            metric_key: metric_key,
            "layer": "layer",
            "n_comp": "Probe Rank"      # ensure legend/hover also says Probe Rank when color_by='n_comp'
        },
        hover_data={
            "layer": True,
            "n_comp": True,             # will show as "Probe Rank" due to labels
            metric_key: ":.3f"
        }
    )

    # Layout & axes
    fig.update_layout(
        template="plotly_white",
        legend_title_text=color_by
    )

    # Optional axis ranges
    xr = _normalize_range(x_range)
    yr = _normalize_range(y_range)
    if xr is not None:
        fig.update_xaxes(range=[xr[0], xr[1]])
    if yr is not None:
        fig.update_yaxes(range=[yr[0], yr[1]])

    # Figure size
    w, h = _normalize_fig_size(fig_size, width, height)
    if w is not None or h is not None:
        fig.update_layout(width=w, height=h, margin=dict(l=60, r=20, t=60, b=60))

    # Save and/or show
    if save_html is not None:
        save_path = Path(save_html).expanduser()
        save_path.parent.mkdir(parents=True, exist_ok=True)
        fig.write_html(str(save_path), include_plotlyjs="cdn")

    return fig

# ---------- example usage ----------
# 1) Direct distance dir, color lines by layer, set figure size to 1000x600
# plot_uuas_vs_probe_rank_dims(
#     base_path="/home/masaki/hierarchical-repr/EntityTree/scripts/structural_probe/probe_results/gpt2/metrics/distance",
#     model_name="gpt2",
#     dims=[5, 50, 100, 300, 768],
#     layers=None,                  # or e.g., [0, 6, 12]
#     x_range=(5, 800),
#     y_range=(0.70, 0.95),
#     fig_size=(1000, 600),
#     save_html="uuas_vs_probe_rank_dims.html",
# )

# 2) Project root dir, color by probe (each n_comp as a line across layers -> not recommended when x=n_comp)
#    but still available; in that case you might facet by layer separately.
# plot_uuas_vs_probe_rank_dims(
#     base_path="/home/masaki/hierarchical-repr/EntityTree/scripts/structural_probe/probe_results",
#     model_name="gpt2",
#     dims=[5, 50, 100, 300, 768],
#     layers=[0, 6, 12],
#     color_by="layer",
#     width=900,
# )


In [28]:
model_name = "gpt2"
plot_uuas_vs_probe_rank_dims(
    base_path=f"/home/masaki/hierarchical-repr/EntityTree/scripts/structural_probe/probe_results/{model_name}/metrics/distance",
    model_name=model_name,
    dims=[5, 50, 100, 300, 768],
    layers=[2, 6, 12],                 
    fig_size=(800, 400),
    save_html="uuas_vs_probe_rank_dims.html",
)

In [18]:
model_name = "meta-llama/Meta-Llama-3-8B"
fig = plot_uuas_vs_probe_rank_dims(
    base_path=f"/home/masaki/hierarchical-repr/EntityTree/scripts/structural_probe/probe_results/{model_name}/metrics/distance",
    model_name=model_name,
    dims=[50, 100, 300, 768],
    layers=[0, 12, 24, 32],                  # or e.g., [0, 6, 12]
    fig_size=(800, 400),
    save_html="uuas_vs_probe_rank_dims.html",
)
fig.show()


fig = plot_uuas_vs_probe_rank_dims(
    base_path=f"/home/masaki/hierarchical-repr/EntityTree/scripts/structural_probe/probe_results/{model_name}/metrics/distance",
    model_name=model_name,
    dims=[50, 100, 300, 768],
    layers=[8, 12, 16],                  # or e.g., [0, 6, 12]
    fig_size=(800, 400),
    save_html="uuas_vs_probe_rank_dims.html",
)
fig.show()