In [2]:
import itertools
import os
import random

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from sklearn.decomposition import PCA
from tqdm.auto import tqdm

from utils.grid_utils import *
from utils.model_utils import *

# Checkpoint directories of the two models analysed in this notebook.
# run_analysis compares `model_path` against these constants to pick the
# "SPRW" / "SPH" label used in log output and saved-plot file names.
# NOTE(review): __main__ runs SPRW with context_type="RW" and SPH with
# context_type="H"; presumably each checkpoint was trained on that context
# type — confirm before mixing them.
SPRW = "/cs/student/projects1/aibh/2024/cbaumgar/MSC_THESIS/sv3_model_ft/checkpoint-163000"
SPH = "/cs/student/projects1/aibh/2024/cbaumgar/MSC_THESIS/sv2_model_fixed/save-checkpoint-46000"

# Maps a (row delta, column delta) grid step to a compass label; a +1 row
# step is labelled SOUTH and a +1 column step EAST.
_STEP_DIRECTIONS = {
    (1, 0): 'SOUTH',
    (-1, 0): 'NORTH',
    (0, 1): 'EAST',
    (0, -1): 'WEST',
}


def _arrival_directions(path_nodes, pos_map):
    """Return the compass direction of every step along ``path_nodes``.

    Element ``j`` is the direction taken to move from ``path_nodes[j]`` to
    ``path_nodes[j + 1]``, so the result has ``len(path_nodes) - 1`` entries.
    A step is ``'UNK'`` when either endpoint is missing from ``pos_map`` or
    the move is not a single orthogonal grid step.
    """
    directions = []
    for prev, cur in zip(path_nodes, path_nodes[1:]):
        if prev not in pos_map or cur not in pos_map:
            directions.append('UNK')
            continue
        (r1, c1), (r2, c2) = pos_map[prev], pos_map[cur]
        directions.append(_STEP_DIRECTIONS.get((r2 - r1, c2 - c1), 'UNK'))
    return directions


def collect_embeddings(model, tokenizer, num_prompts, layer_idx, grid_size, context_type,
                       min_walk_len, max_walk_len, analysis_target, aggregation_method):
    """Collect one hidden-state vector per path node from random grid prompts.

    For each of ``num_prompts`` iterations a fresh ``grid_size`` x ``grid_size``
    grid with random node names is built, a shortest-path prompt is generated
    from it, and the layer-``layer_idx`` embedding of every node along either the
    target path (``'task'`` in ``analysis_target``) or the context path is
    gathered from ``"<prompt> <target> [EOS]"``.

    Args:
        model, tokenizer: Passed through to ``get_hidden_states_with_offsets``.
        num_prompts: Number of random grids / prompts to generate.
        layer_idx: Layer whose hidden states are read.
        grid_size: Side length of the square grid.
        context_type, min_walk_len, max_walk_len: Passed through to
            ``generate_sp_prompt_components``.
        analysis_target: Must contain ``'nodes'``; containing ``'task'`` selects
            the target path, otherwise the context path is used.
        aggregation_method: Must be ``'all'`` (one point per node occurrence).

    Returns:
        A list of dicts with keys ``'vector'``, ``'coordinate'`` (``str`` of the
        ``(row, col)`` tuple, or ``'UNK'``), ``'path_index'``,
        ``'arrival_direction'`` and, for task targets, ``'role'``
        (``'Start'`` / ``'Intermediate'`` / ``'Goal'``).

    Raises:
        ValueError: If ``aggregation_method != 'all'`` or ``analysis_target``
            does not mention nodes.
    """
    if aggregation_method != 'all' or 'nodes' not in analysis_target:
        raise ValueError("This simplified script is designed for 'all' aggregation on nodes.")

    print(f"Collecting embeddings for '{analysis_target}'...")
    is_task = 'task' in analysis_target
    path_key = 'target_path_nodes' if is_task else 'context_path_nodes'
    num_nodes = grid_size * grid_size
    all_points = []

    for _ in tqdm(range(num_prompts), desc="Processing prompts"):
        node_names = generate_random_names(num_nodes)
        G = get_grid_graph(node_names, size=grid_size)
        # NOTE(review): assumes get_grid_graph lays node_names out row-major,
        # i.e. name i sits at (i // grid_size, i % grid_size) — confirm.
        pos_map = {node_names[i]: divmod(i, grid_size) for i in range(num_nodes)}

        components = generate_sp_prompt_components(G, context_type, min_walk_len, max_walk_len)
        if not components:
            continue

        prompt, target_str = components['prompt'], components['target_str']
        full_text = f"{prompt} {target_str} [EOS]"
        hidden_states, offsets = get_hidden_states_with_offsets(model, tokenizer, full_text, layer_idx)

        path_nodes = components[path_key]
        directions = _arrival_directions(path_nodes, pos_map)

        # Target nodes follow "<prompt> " in full_text; context nodes are part of
        # the prompt itself, so bound that search to the prompt so a context
        # node can never be matched inside the target answer.
        if is_task:
            search_start, search_end = len(prompt) + 1, len(full_text)
        else:
            search_start, search_end = 0, len(prompt)

        offset = search_start
        for i, node in enumerate(path_nodes):
            # NOTE(review): plain substring search — if one random name can be a
            # prefix/substring of another, this may match the wrong node.
            spans = substring_positions(full_text, node, offset, search_end)
            if not spans:
                print(f"Warning: Could not find node '{node}' (index {i}) in path. Skipping remainder of prompt.")
                break

            span_start, span_end = spans[0]
            # Resume after this occurrence so repeated visits to the same node
            # map to successive occurrences in the text.
            offset = span_end

            vec = gather_embeddings_for_span(hidden_states, offsets, (span_start, span_end))
            if vec is None:
                continue

            point = {
                "vector": vec,
                # Unknown nodes get 'UNK' rather than being mislabelled as (0, 0).
                "coordinate": str(pos_map[node]) if node in pos_map else 'UNK',
                "path_index": i,
                "arrival_direction": 'Start' if i == 0 else directions[i - 1],
            }

            if is_task:
                # Goal takes precedence over Start for a single-node path.
                if i == len(path_nodes) - 1:
                    point['role'] = 'Goal'
                elif i == 0:
                    point['role'] = 'Start'
                else:
                    point['role'] = 'Intermediate'

            all_points.append(point)

    print(f"Collected {len(all_points)} data points.")
    return all_points

def _category_sort_key(label):
    """Sort key for string category labels.

    Purely numeric labels come first in numeric order (so path indices sort
    0, 1, 2, ..., 10 rather than lexicographically), followed by all other
    labels alphabetically. Unlike ``int(x) if x.isdigit() else x`` this never
    compares an ``int`` with a ``str``, so mixed columns do not raise TypeError.
    """
    if label.isdecimal():
        return (0, int(label), '')
    return (1, 0, label)


def plot_interactive_pca(points, title_prefix, color_options, save_path=None):
    """Show a 3-D PCA scatter of point vectors with a colour-by dropdown.

    Each entry of ``color_options`` is a key of the point dicts; for every
    option one trace per distinct value is added, and a dropdown toggles which
    option's traces are visible (the first option is shown initially).

    Args:
        points: Dicts with a ``"vector"`` entry plus metadata keys.
        title_prefix: Figure title.
        color_options: Metadata keys offered in the dropdown.
        save_path: If given, the figure is also written there as HTML.
    """
    if not points:
        print("No points to plot.")
        return

    print(f"Generating interactive PCA plot with color options: {color_options}")

    X = np.array([p["vector"] for p in points])
    # PCA cannot extract more components than min(n_samples, n_features);
    # fit what is available and zero-pad so PC1..PC3 always exist.
    n_components = min(3, X.shape[0], X.shape[1])
    pca = PCA(n_components=n_components, random_state=42)
    X3 = np.zeros((X.shape[0], 3))
    X3[:, :n_components] = pca.fit_transform(X)

    df = pd.DataFrame(X3, columns=['PC1', 'PC2', 'PC3'])
    for option in color_options:
        df[option] = [p.get(option) for p in points]

    # Hover text is the same for every colouring, so build it once, from
    # per-column lists (avoids slow per-row df.loc and dtype upcasting).
    meta_cols = [c for c in df.columns if c not in ('PC1', 'PC2', 'PC3')]
    meta_values = {col: df[col].tolist() for col in meta_cols}
    hover_all = pd.Series(
        [
            "<br>".join([f"index: {idx}"] + [f"{col}: {meta_values[col][k]}" for col in meta_cols])
            for k, idx in enumerate(df.index)
        ],
        index=df.index,
    )

    fig = go.Figure()

    custom_color_maps = {
        'arrival_direction': {'Start': 'black', 'NORTH': '#1f77b4', 'SOUTH': '#ff7f0e', 'EAST': '#2ca02c', 'WEST': '#d62728'},
        'role': {'Start': 'green', 'Goal': 'red', 'Intermediate': 'blue'},
    }
    colormap_variables = {'coordinate': 'plasma', 'path_index': 'viridis'}
    # (color option index, value) for every trace, in the order traces are added.
    trace_info = []

    for i, color_by in enumerate(color_options):
        labels = df[color_by].astype(str)
        unique_vals = sorted(labels.unique(), key=_category_sort_key)

        if color_by in custom_color_maps:
            color_map = custom_color_maps[color_by]
        elif color_by in colormap_variables:
            colormap_name = colormap_variables[color_by]
            colorscale = getattr(px.colors.sequential, colormap_name.capitalize(), px.colors.sequential.Viridis)
            # Spread the distinct values evenly across the sequential scale.
            color_map = {val: colorscale[int(j * (len(colorscale) - 1) / max(1, len(unique_vals) - 1))] for j, val in enumerate(unique_vals)}
        else:
            color_sequence = px.colors.qualitative.Set1
            color_map = {val: color_sequence[j % len(color_sequence)] for j, val in enumerate(unique_vals)}

        for val in unique_vals:
            # Every value came from `labels`, so the mask is never empty.
            mask = labels == val
            fig.add_trace(go.Scatter3d(
                x=df.loc[mask, 'PC1'], y=df.loc[mask, 'PC2'], z=df.loc[mask, 'PC3'],
                mode='markers', marker=dict(size=5, color=color_map.get(val, 'grey')),
                text=hover_all[mask].tolist(), hoverinfo='text', name=str(val),
                visible=(i == 0), showlegend=True
            ))
            trace_info.append((i, val))

    buttons = []
    for i, color_by in enumerate(color_options):
        visibility = [trace_color_idx == i for trace_color_idx, _ in trace_info]
        buttons.append(dict(method='update', label=color_by.replace('_', ' ').title(), args=[{'visible': visibility}]))

    fig.update_layout(
        updatemenus=[dict(
            active=0, buttons=buttons, direction="down",
            pad={"r": 10, "t": 10}, showactive=True,
            x=0.01, xanchor="left", y=0.99, yanchor="top"
        )],
        title=title_prefix,
        scene=dict(xaxis_title='PC 1', yaxis_title='PC 2', zaxis_title='PC 3'),
        margin=dict(l=0, r=0, b=0, t=50)
    )

    fig.show()
    if save_path:
        fig.write_html(save_path)
        print(f"Interactive plot saved to: {save_path}")

def run_analysis(model_path, layer_idx, num_prompts, grid_size, context_type,
                 analysis_target, aggregation_method,
                 min_walk_len=20, max_walk_len=20, exclude_first_n=0, save_plot=False):
    """Load a checkpoint, collect node embeddings and show an interactive PCA plot.

    Args:
        model_path: Checkpoint directory passed to ``load_model_and_tokenizer``.
        layer_idx, num_prompts, grid_size, context_type, analysis_target,
        aggregation_method, min_walk_len, max_walk_len: Forwarded to
            ``collect_embeddings``.
        exclude_first_n: Drop points whose ``path_index`` is below this value.
        save_plot: If true, also write the plot as HTML in the working directory.

    Raises:
        ValueError: If ``exclude_first_n > 0`` with an aggregation other than
            ``'all'`` (and, via ``collect_embeddings``, for unsupported targets).
    """
    if exclude_first_n > 0 and aggregation_method != 'all':
        raise ValueError(
            f"Cannot use 'exclude_first_n' with '{aggregation_method}' aggregation; "
            "it requires 'all' aggregation."
        )

    # Label known checkpoints by their constant name; unknown paths fall back
    # to the checkpoint directory name instead of being mislabelled as "SPH".
    known_models = {SPRW: "SPRW", SPH: "SPH"}
    model_name = known_models.get(model_path) or os.path.basename(os.path.normpath(model_path))

    print("-" * 50)
    print("Starting Simplified Interactive SP Analysis:")
    print(f"  - Model: {model_name}")
    print(f"  - Layer: {layer_idx}, Grid: {grid_size}x{grid_size}, Context: {context_type}")
    if exclude_first_n > 0:
        print(f"  - Excluding first {exclude_first_n} points from each path.")
    print("-" * 50)

    model, tokenizer = load_model_and_tokenizer(model_path)
    points = collect_embeddings(model, tokenizer, num_prompts, layer_idx, grid_size, context_type,
                                min_walk_len, max_walk_len, analysis_target, aggregation_method)

    if not points:
        print("No data points were collected. Exiting.")
        return

    if exclude_first_n > 0:
        original_count = len(points)
        points = [p for p in points if p.get('path_index', -1) >= exclude_first_n]
        print(f"Filtered points: {original_count} -> {len(points)}")

        if not points:
            print("No points left to plot after filtering.")
            return

    title_prefix = f"L{layer_idx}, {grid_size}x{grid_size} | {analysis_target.replace('_', ' ')}"
    if exclude_first_n > 0:
        title_prefix += f" (p_idx ≥ {exclude_first_n})"

    # Every metadata key present on any point becomes a colour option.
    available_keys = sorted(set().union(*(p.keys() for p in points)) - {'vector'})

    save_path = None
    if save_plot:
        node_type = analysis_target.replace('_', '')
        save_path = f"{model_name}_{node_type}_{grid_size}x{grid_size}_L{layer_idx}.html"

    plot_interactive_pca(points, title_prefix, color_options=available_keys, save_path=save_path)

if __name__ == '__main__':
    # Each entry is one run_analysis call; runs execute in list order.
    # Both runs inspect context-node embeddings on a 3x3 grid with 500 prompts.
    experiments = [
        # SPH checkpoint, layer 1, H-type contexts of fixed length 9.
        dict(model_path=SPH, layer_idx=1, num_prompts=500, grid_size=3,
             context_type="H", analysis_target="context_nodes",
             aggregation_method="all", exclude_first_n=0,
             min_walk_len=9, max_walk_len=9, save_plot=False),
        # SPRW checkpoint, layer 12, RW-type contexts of fixed length 50.
        dict(model_path=SPRW, layer_idx=12, num_prompts=500, grid_size=3,
             context_type="RW", analysis_target="context_nodes",
             aggregation_method="all", exclude_first_n=0,
             min_walk_len=50, max_walk_len=50, save_plot=False),
    ]

    for config in experiments:
        run_analysis(**config)

--------------------------------------------------
Starting Simplified Interactive SP Analysis:
  - Model: SPH
  - Layer: 1, Grid: 3x3, Context: H
--------------------------------------------------
Collecting embeddings for 'context_nodes'...


Processing prompts:  85%|████████▍ | 423/500 [00:01<00:00, 338.77it/s]

Processing prompts: 100%|██████████| 500/500 [00:01<00:00, 334.32it/s]


Collected 4500 data points.
Generating interactive PCA plot with color options: ['arrival_direction', 'coordinate', 'path_index']


--------------------------------------------------
Starting Simplified Interactive SP Analysis:
  - Model: SPRW
  - Layer: 12, Grid: 3x3, Context: RW
--------------------------------------------------
Collecting embeddings for 'context_nodes'...


Processing prompts: 100%|██████████| 500/500 [00:02<00:00, 202.84it/s]


Collected 25000 data points.
Generating interactive PCA plot with color options: ['arrival_direction', 'coordinate', 'path_index']
