# Semantic Shift Visualization via Lorenz Attractor
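Track how the senses of a target word evolve across parliamentary terms: the word's contexts are clustered per term-year, clusters are aligned over time, and each cluster's semantic drift drives a Lorenz-attractor trajectory.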

In [None]:
# ## Imports
%pip install "elasticsearch==8.6.2" sentence-transformers scikit-learn pandas matplotlib scipy plotly
from elasticsearch import Elasticsearch
from sentence_transformers import SentenceTransformer
from sklearn.cluster import AffinityPropagation
from sklearn.metrics.pairwise import cosine_similarity
from scipy.integrate import odeint
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import plotly.graph_objects as go
import re, os

In [None]:
# ## Configuration
# Defined before connecting, because the Elasticsearch client below reads ES_URL.
INDEX_NAME = "parliament_speeches"
# Override with the ES_URL environment variable. The trycloudflare.com fallback
# looks like a temporary tunnel address - TODO confirm it is still reachable.
ES_URL = os.environ.get("ES_URL", "https://analog-advisory-many-specialists.trycloudflare.com")
TARGET_WORD = "vergi"
START_TERM = 17
END_TERM = 27
YEARS_PER_TERM = 5
BASELINE_MAX_CLUSTERS = 50  # cluster cap used until the first slice yields clusters
MAX_CLUSTERS = 100          # cluster cap for later slices and max number of global cluster IDs
SIMILARITY_THRESHOLD = 0.8  # min cosine similarity for a cluster to reuse an existing global ID
TOP_K_CLUSTERS = 3  # Track top-3 clusters per year
OUTPUT_DIR = "./lorenz_results"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# ## Connect to Elasticsearch
es = Elasticsearch(ES_URL)
# es.info() is the first real request; report success only after it returns.
es_version = es.info().body["version"]["number"]
print("Connected to Elasticsearch")
print(es_version)

## Helper Functions


In [None]:
def fetch_speeches(term, year, size=10000):
    """Fetch the ``content`` of speeches indexed for one (term, year) slice.

    Args:
        term: value matched exactly against the ``term`` field.
        year: value matched exactly against the ``year`` field.
        size: maximum number of hits requested in the single search call.
            10000 is Elasticsearch's default ``index.max_result_window``,
            the most one request can return without paging.

    Returns:
        List of ``content`` strings in hit order. Hits with a missing or
        empty ``content`` field are skipped.

    Prints a warning when more documents match than were returned, so a
    truncated slice is visible instead of silently under-sampled.
    """
    query = {
        "size": size,
        "_source": ["content", "term", "year"],
        # Count every match exactly (by default counting stops at 10,000) so
        # truncation can be detected below.
        "track_total_hits": True,
        "query": {
            "bool": {
                # Exact-match clauses: filter context skips relevance scoring.
                "filter": [
                    {"term": {"term": term}},
                    {"term": {"year": year}},
                ]
            }
        },
    }
    res = es.search(index=INDEX_NAME, body=query)
    hits = res["hits"]["hits"]
    total = res["hits"]["total"]["value"]
    if total > len(hits):
        print(
            f"  Warning: term {term}, year {year} matched {total} speeches "
            f"but only {len(hits)} were fetched (size={size})."
        )
    return [hit["_source"]["content"] for hit in hits if hit["_source"].get("content")]


In [None]:
def make_term_year_tuples(start_term, end_term, years_per_term=5):
    """Return every (term, year) pair from ``start_term`` to ``end_term`` inclusive.

    Years are numbered 1..``years_per_term`` within each term. The pairs come
    back in chronological order: all years of a term before the next term.
    An empty list is returned when ``end_term < start_term``.
    """
    return [
        (term, year)
        for term in range(start_term, end_term + 1)
        for year in range(1, years_per_term + 1)
    ]

# Every (term, year) slice to analyse, in chronological order.
TERM_YEAR_TUPLES = make_term_year_tuples(START_TERM, END_TERM)
print(
    "Processing {} term-year pairs from {} to {}".format(
        len(TERM_YEAR_TUPLES), TERM_YEAR_TUPLES[0], TERM_YEAR_TUPLES[-1]
    )
)

In [None]:
def extract_contexts(texts, target_word, window=10):
    """Extract token windows around every word that starts with ``target_word``.

    Matching is prefix-based, so suffixed forms (e.g. "vergisi", "vergiler")
    are included. Each text is lowercased and split into runs of word
    characters. Every matching token yields the tokens from ``window`` before
    it to ``window`` after it (clipped at the text edges), joined by single
    spaces.

    Args:
        texts: iterable of raw speech strings.
        target_word: word (prefix) to search for; matched case-insensitively.
        window: number of tokens kept on each side of a match.

    Returns:
        List of context snippets, one per matching token, in input order.
    """
    def _lower(text):
        # str.lower() maps the Turkish dotted capital "İ" (U+0130) to "i" plus
        # U+0307 COMBINING DOT ABOVE. U+0307 is not a word character, so
        # "VERGİSİ" would otherwise split into the tokens "vergi" and "si".
        # NOTE(review): plain "I" still lowercases to "i", not Turkish "ı".
        return text.replace("\u0130", "i").lower()

    prefix = _lower(target_word)
    contexts = []
    for text in texts:
        tokens = re.findall(r"\w+", _lower(text))
        for i, tok in enumerate(tokens):
            # Tokens are whole runs of word characters, so a prefix test finds
            # the target word and all of its suffixed forms.
            if tok.startswith(prefix):
                start = max(0, i - window)
                end = min(len(tokens), i + window + 1)
                contexts.append(" ".join(tokens[start:end]))
    return contexts


In [None]:
def compute_embeddings(model, contexts):
    """Encode each context snippet with ``model`` (progress bar shown).

    When ``contexts`` is empty no encoding happens. An empty float array of
    shape (0, embedding_dim) is returned instead, with the width taken from
    the model.
    """
    if len(contexts) > 0:
        return model.encode(contexts, show_progress_bar=True)
    embedding_dim = model.get_sentence_embedding_dimension()
    return np.empty((0, embedding_dim))


In [None]:
def get_cluster_prototypes(X, labels, return_label_ids=False):
    """Compute the mean vector (centroid) of every non-noise cluster.

    Args:
        X: array whose first axis indexes samples (one row per label).
        labels: cluster label per row of ``X``; -1 marks noise and is skipped.
        return_label_ids: also return the label of each centroid.

    Returns:
        Array of centroids in ascending label order, shape
        (n_clusters, *X.shape[1:]). When no non-noise cluster exists the
        shape is (0, *X.shape[1:]), so the feature dimension is kept.
        With ``return_label_ids`` a tuple (centroids, label_ids) is returned.
    """
    X = np.asarray(X)
    labels = np.asarray(labels)
    # Every label from np.unique has at least one member, so no emptiness check is needed.
    label_ids = [label for label in np.unique(labels) if label != -1]
    if label_ids:
        clusters = np.stack([X[labels == label].mean(axis=0) for label in label_ids])
    else:
        clusters = np.empty((0,) + X.shape[1:])
    if return_label_ids:
        return clusters, label_ids
    return clusters


In [None]:
def limit_clusters(labels, max_clusters):
    """Keep only the ``max_clusters`` largest clusters and relabel the rest as noise (-1).

    Args:
        labels: 1-D sequence of integer cluster labels; -1 marks noise.
        max_clusters: number of clusters to keep, or None to keep them all.

    Returns:
        ``(filtered, keep)``. ``filtered`` is a NumPy array of the same dtype
        with dropped clusters set to -1. ``keep`` lists the kept labels,
        largest first; equally sized clusters keep ascending label order.
        The noise label -1 is never reported as kept.
    """
    labels = np.asarray(labels)
    unique, counts = np.unique(labels, return_counts=True)
    cluster_counts = [
        (label, count) for label, count in zip(unique, counts) if label != -1
    ]
    # list.sort is stable (also with reverse=True), so ties stay in label order.
    cluster_counts.sort(key=lambda item: item[1], reverse=True)
    # Slicing with None keeps everything, so max_clusters=None needs no special case.
    keep = [label for label, _ in cluster_counts[:max_clusters]]
    filtered = labels.copy()
    filtered[~np.isin(labels, keep)] = -1
    return filtered, keep


In [None]:
class ClusterAligner:
    """Map per-slice cluster labels onto shared global cluster IDs and fixed colours.

    Each local cluster centroid is compared, by cosine similarity, with every
    centroid registered so far:
      * If the best match reaches ``similarity_threshold``, the cluster reuses
        that global ID.
      * Otherwise a new global ID is registered, up to ``max_clusters``.
      * Beyond that cap, the cluster's samples are labelled -1.

    Registered centroids are never updated afterwards, so a global ID stays
    anchored to the centroid of the cluster that created it. Global IDs also
    index a fixed colour palette, which keeps a cluster's colour consistent
    across plots.
    """

    def __init__(self, max_clusters=100, similarity_threshold=0.8, cmap_name="gist_ncar"):
        self.max_clusters = max_clusters
        self.similarity_threshold = similarity_threshold
        self.centroids = []   # (1, dim) centroid per global ID, in creation order
        self.global_ids = []  # global ID stored at the same index as its centroid
        # pyplot.get_cmap(name, lut) returns the colormap resampled to `lut`
        # colours. matplotlib.cm.get_cmap (plt.cm.get_cmap) was deprecated in
        # Matplotlib 3.7 and removed in 3.9; pyplot.get_cmap remains available.
        self.cmap = plt.get_cmap(cmap_name, max_clusters)
        self.palette = [self.cmap(i) for i in range(self.cmap.N)]
        self.overflow_color = (0.65, 0.65, 0.65, 1.0)  # grey for IDs outside the palette

    def _add_centroid(self, centroid):
        """Register ``centroid`` under a new global ID; return it, or -1 when the cap is reached."""
        if len(self.global_ids) >= self.max_clusters:
            return -1
        new_id = len(self.global_ids)
        self.centroids.append(centroid)
        self.global_ids.append(new_id)
        return new_id

    def _match_or_create(self, centroid):
        """Return the global ID of the most similar registered centroid, or register a new one."""
        centroid = centroid.reshape(1, -1)
        if not self.centroids:
            return self._add_centroid(centroid)
        stacked = np.vstack(self.centroids)
        sims = cosine_similarity(stacked, centroid)[:, 0]
        best_idx = int(np.argmax(sims))
        if sims[best_idx] >= self.similarity_threshold:
            return self.global_ids[best_idx]
        return self._add_centroid(centroid)

    def align(self, raw_labels, centroid_map):
        """Translate local labels into global IDs.

        Args:
            raw_labels: array of local cluster labels for one slice.
            centroid_map: dict of local label -> centroid vector.

        Returns:
            Array shaped like ``raw_labels``. Samples whose local cluster is
            missing from ``centroid_map``, or got no global ID (cap reached),
            are set to -1.
        """
        aligned = np.full_like(raw_labels, -1)
        for local_label, centroid in centroid_map.items():
            global_id = self._match_or_create(centroid)
            if global_id == -1:
                continue
            aligned[raw_labels == local_label] = global_id
        return aligned

    def get_color(self, label):
        """Return the RGBA palette colour for a global ID, or the grey overflow colour."""
        if 0 <= label < len(self.palette):
            return self.palette[label]
        return self.overflow_color


In [None]:
def create_cluster_guide(cluster_contexts_map, target_word, output_dir, aligner):
    """
    Write a guide describing what each global cluster contains.

    Two files are written to ``output_dir``:
      * ``cluster_guide_<target_word>_summary.csv``: one row per global cluster
        with its context count and the term-years it appears in, largest
        cluster first.
      * ``cluster_guide_<target_word>_contexts.txt``: for each cluster (largest
        first), its colour, statistics and up to 15 stored example contexts,
        each truncated to 250 characters.

    Args:
        cluster_contexts_map: dict of global cluster ID -> list of
            ``{'term', 'year', 'context'}`` dicts.
        target_word: tracked word; used in titles and file names.
        output_dir: existing directory to write into.
        aligner: object whose ``get_color(global_id)`` returns an RGB(A) tuple
            of floats in [0, 1].

    Returns:
        The summary DataFrame, or None when ``cluster_contexts_map`` is empty.
    """
    if not cluster_contexts_map:
        print("  No clusters to document.")
        return None

    max_examples = 15
    max_chars = 250

    # "T<term>Y<year>" labels per cluster, computed once. The (term, year)
    # pairs are sorted numerically so that, e.g., T9 precedes T17.
    spans = {
        global_id: [f"T{t}Y{y}" for t, y in sorted({(ctx['term'], ctx['year']) for ctx in contexts})]
        for global_id, contexts in cluster_contexts_map.items()
    }

    guide_rows = [
        {
            'global_id': global_id,
            'color_index': global_id,
            'total_contexts': len(cluster_contexts_map[global_id]),
            'term_year_span': ', '.join(spans[global_id]),
            'num_appearances': len(spans[global_id]),
        }
        for global_id in sorted(cluster_contexts_map)
    ]

    df_summary = pd.DataFrame(guide_rows).sort_values('total_contexts', ascending=False)
    summary_path = os.path.join(output_dir, f"cluster_guide_{target_word}_summary.csv")
    df_summary.to_csv(summary_path, index=False)
    print(f"  Saved cluster summary to {summary_path}")

    context_file_path = os.path.join(output_dir, f"cluster_guide_{target_word}_contexts.txt")
    with open(context_file_path, 'w', encoding='utf-8') as f:
        f.write(f"{'='*80}\n")
        f.write(f"CLUSTER GUIDE FOR '{target_word.upper()}'\n")
        f.write(f"Generated: {pd.Timestamp.now()}\n")
        f.write(f"Total clusters: {len(cluster_contexts_map)}\n")
        f.write(f"{'='*80}\n\n")

        # Most common clusters first (sorted() is stable for equal counts).
        for global_id in sorted(cluster_contexts_map,
                                key=lambda gid: len(cluster_contexts_map[gid]),
                                reverse=True):
            contexts = cluster_contexts_map[global_id]
            term_years = spans[global_id]

            color = aligner.get_color(global_id)
            # Channels are truncated (not rounded) to 0-255.
            color_hex = '#{:02x}{:02x}{:02x}'.format(
                int(color[0]*255), int(color[1]*255), int(color[2]*255)
            )

            f.write(f"\n{'='*80}\n")
            f.write(f"CLUSTER {global_id} (Color: {color_hex})\n")
            f.write(f"{'-'*80}\n")
            f.write(f"Total contexts: {len(contexts)}\n")
            f.write(f"Appearances: {len(term_years)} term-years\n")
            f.write(f"Term-year span: {', '.join(term_years)}\n")
            f.write("\nREPRESENTATIVE CONTEXTS:\n")
            f.write(f"{'-'*80}\n")

            # The first stored examples, in insertion order (not selected for diversity).
            for ctx_item in contexts[:max_examples]:
                text = ctx_item['context']
                f.write(f"\n[{ctx_item['term']}-{ctx_item['year']}] ")
                f.write(text[:max_chars])
                if len(text) > max_chars:
                    f.write("...")
                f.write("\n")

            if len(contexts) > max_examples:
                f.write(f"\n... and {len(contexts) - max_examples} more contexts\n")

    print(f"  Saved detailed contexts to {context_file_path}")
    print(f"  Total clusters documented: {len(cluster_contexts_map)}")

    return df_summary


In [None]:
# Create Color Reference and Mapping
def create_color_reference(df_timeline, target_word, output_dir, aligner):
    """Document which colour each tracked cluster is drawn with.

    Saves ``cluster_colors_<target_word>.csv`` (global_id, hex_color,
    rgb_color) and prints the same mapping. It then saves and shows a
    legend-only chart, ``color_reference_<target_word>.png``.

    Returns the colour DataFrame, or None when ``df_timeline`` is None or empty.
    """
    if df_timeline is None or len(df_timeline) == 0:
        print("  No timeline data to map.")
        return None

    cluster_ids = sorted(df_timeline['global_id'].unique())

    def _channels(global_id):
        # 0-255 integer RGB channels of the cluster colour (truncated, not rounded).
        rgba = aligner.get_color(global_id)
        return int(rgba[0] * 255), int(rgba[1] * 255), int(rgba[2] * 255)

    rows = []
    for global_id in cluster_ids:
        red, green, blue = _channels(global_id)
        rows.append({
            'global_id': global_id,
            'hex_color': f'#{red:02x}{green:02x}{blue:02x}',
            'rgb_color': f"({red}, {green}, {blue})",
        })

    df_colors = pd.DataFrame(rows)
    color_csv_path = os.path.join(output_dir, f'cluster_colors_{target_word}.csv')
    df_colors.to_csv(color_csv_path, index=False)
    print(f"  Saved color mapping to {color_csv_path}")

    # Console listing of the same mapping.
    print("\n=== Color Reference ===")
    for global_id, hex_color in zip(df_colors['global_id'], df_colors['hex_color']):
        print(f"Cluster {global_id}: {hex_color}")

    # Legend-only chart: one coloured patch per cluster, no axes.
    import matplotlib.patches as mpatches

    fig, ax = plt.subplots(figsize=(12, max(6, len(cluster_ids) // 4)))
    legend_handles = [
        mpatches.Patch(color=aligner.get_color(global_id), label=f'Cluster {global_id}')
        for global_id in cluster_ids
    ]
    ax.legend(handles=legend_handles, loc='center', ncol=min(4, len(legend_handles)), fontsize=10)
    ax.axis('off')
    plt.title(f"Color Reference for '{target_word}' Clusters", fontsize=14, fontweight='bold')
    plt.tight_layout()

    chart_path = os.path.join(output_dir, f'color_reference_{target_word}.png')
    plt.savefig(chart_path, dpi=150, bbox_inches='tight')
    plt.show()
    print(f"  Saved color reference chart to {chart_path}")

    return df_colors


In [None]:
# ## Load Sentence Transformer Model
# NOTE: the corpus appears to be Turkish (TARGET_WORD = "vergi"), but
# all-MiniLM-L6-v2 was trained on English data only. A multilingual model
# gives meaningful embeddings for Turkish contexts. Set EMBEDDING_MODEL_NAME
# back to "all-MiniLM-L6-v2" to reproduce earlier results.
EMBEDDING_MODEL_NAME = "paraphrase-multilingual-MiniLM-L12-v2"
model = SentenceTransformer(EMBEDDING_MODEL_NAME)
print("Model loaded")


## Clustering and Tracking


In [None]:
# Data structure to track cluster evolution
cluster_timeline = []  # one dict per tracked cluster per slice: term, year, global_id, centroid, size, total_contexts
cluster_contexts_map = {}  # global_id -> list of {'term', 'year', 'context'} examples for the cluster guide

aligner = ClusterAligner(max_clusters=MAX_CLUSTERS, similarity_threshold=SIMILARITY_THRESHOLD)
# Becomes True once a slice yields clusters; until then BASELINE_MAX_CLUSTERS is the cap.
baseline_used = False

print(f"\n=== Analyzing '{TARGET_WORD}' across {len(TERM_YEAR_TUPLES)} term-year pairs ===")

for term, year in TERM_YEAR_TUPLES:
    print(f"\n--- Term {term}, Year {year} ---")
    texts = fetch_speeches(term, year)
    contexts = extract_contexts(texts, TARGET_WORD)
    print(f"  Contexts: {len(contexts)}")

    if len(contexts) < 10:
        print("  Not enough contexts, skipping this slice.")
        continue

    embeddings = compute_embeddings(model, contexts)
    # Affinity propagation chooses the number of clusters; labels are local to this slice.
    ap = AffinityPropagation(random_state=42)
    ap.fit(embeddings)
    local_labels = ap.labels_

    cap = BASELINE_MAX_CLUSTERS if not baseline_used else MAX_CLUSTERS
    limited_labels, kept_clusters = limit_clusters(local_labels, cap)
    print(f"  Raw clusters: {len(np.unique(local_labels))}, kept: {len(kept_clusters)} (cap={cap})")

    prototypes, proto_labels = get_cluster_prototypes(embeddings, limited_labels, return_label_ids=True)
    centroid_map = dict(zip(proto_labels, prototypes))
    if not centroid_map:
        print("  No clusters survived filtering, skipping.")
        continue

    baseline_used = True
    aligned_labels = aligner.align(limited_labels, centroid_map)

    # Count cluster sizes (-1 = dropped by limit_clusters or by the aligner's ID cap)
    cluster_sizes = {}
    for label in aligned_labels:
        if label >= 0:
            cluster_sizes[label] = cluster_sizes.get(label, 0) + 1

    # Get top-K clusters by size
    top_clusters = sorted(cluster_sizes.items(), key=lambda x: x[1], reverse=True)[:TOP_K_CLUSTERS]
    print(f"  Top {len(top_clusters)} clusters: {top_clusters}")

    # Store cluster info and contexts
    for global_id, size in top_clusters:
        member_mask = aligned_labels == global_id
        # Use this slice's own centroid for the global cluster. The aligner's
        # stored centroid is fixed when the ID is created. Taking it from there
        # would give every snapshot of a cluster the same centroid, so the
        # baseline/local drift computed later would always be zero.
        centroid = embeddings[member_mask].mean(axis=0)

        cluster_context_examples = [
            ctx for ctx, is_member in zip(contexts, member_mask) if is_member
        ]
        # Store up to 10 examples per slice for the cluster guide
        cluster_contexts_map.setdefault(global_id, []).extend(
            {'term': term, 'year': year, 'context': ctx}
            for ctx in cluster_context_examples[:10]
        )

        cluster_timeline.append({
            'term': term,
            'year': year,
            'global_id': global_id,
            'centroid': centroid,
            'size': size,
            'total_contexts': len(contexts)
        })

print(f"\n=== Collected {len(cluster_timeline)} cluster snapshots ===")
print(f"=== Stored contexts for {len(cluster_contexts_map)} unique clusters ===")


In [None]:
# Convert to DataFrame for easier manipulation. Explicit columns keep the frame
# well-formed, and the derived columns below computable, even when no slice
# produced any clusters.
df_timeline = pd.DataFrame(
    cluster_timeline,
    columns=['term', 'year', 'global_id', 'centroid', 'size', 'total_contexts'],
)
df_timeline['size_share'] = df_timeline['size'] / df_timeline['total_contexts']
# Per-cluster ordinal of each snapshot (0 = first appearance, in collection order).
df_timeline['time_idx'] = df_timeline.groupby('global_id').cumcount()
print(df_timeline.head(10))

# Generate cluster guide
print("\n=== Generating Cluster Guide ===")
create_cluster_guide(cluster_contexts_map, TARGET_WORD, OUTPUT_DIR, aligner)

# Generate color reference
print("\n=== Generating Color Reference ===")
create_color_reference(df_timeline, TARGET_WORD, OUTPUT_DIR, aligner)


## Compute Lorenz Variables


In [None]:
def _cosine_distance(a, b):
    """Return 1 - cosine similarity of two vectors (flattened); zero vectors have similarity 0."""
    a = np.ravel(a).astype(float)
    b = np.ravel(b).astype(float)
    norm_product = np.linalg.norm(a) * np.linalg.norm(b)
    if norm_product == 0:
        return 1.0
    return float(1.0 - np.dot(a, b) / norm_product)


def compute_lorenz_coords(df_timeline):
    """Compute Lorenz driving coordinates (x, y, z) for each cluster snapshot.

    Each global cluster's snapshots are ordered by (term, year), and then:
      * x = cosine distance from the cluster's first snapshot centroid (baseline drift)
      * y = cosine distance from the previous snapshot centroid (local drift; 0 for the first)
      * z = the snapshot's ``size_share``

    Args:
        df_timeline: DataFrame with columns term, year, global_id, centroid,
            size and size_share.

    Returns:
        DataFrame with columns term, year, global_id, x, y, z, size. It is
        empty, but still has these columns, when ``df_timeline`` is empty.
    """
    columns = ['term', 'year', 'global_id', 'x', 'y', 'z', 'size']
    if len(df_timeline) == 0:
        return pd.DataFrame(columns=columns)

    lorenz_data = []
    for global_id in df_timeline['global_id'].unique():
        cluster_df = df_timeline[df_timeline['global_id'] == global_id].sort_values(['term', 'year'])
        baseline_centroid = cluster_df.iloc[0]['centroid']
        prev_centroid = None

        for _, row in cluster_df.iterrows():
            current_centroid = row['centroid']
            x = _cosine_distance(current_centroid, baseline_centroid)
            y = 0.0 if prev_centroid is None else _cosine_distance(current_centroid, prev_centroid)
            lorenz_data.append({
                'term': row['term'],
                'year': row['year'],
                'global_id': global_id,
                'x': x,
                'y': y,
                'z': row['size_share'],
                'size': row['size']
            })
            prev_centroid = current_centroid

    return pd.DataFrame(lorenz_data, columns=columns)

# Lorenz driving coordinates for every tracked cluster snapshot.
df_lorenz = compute_lorenz_coords(df_timeline)
print(df_lorenz.head(n=10))
print("\nLorenz coordinates computed for {} snapshots".format(len(df_lorenz)))


## Lorenz System Integration


In [None]:
def lorenz_system(state, t, sigma=10.0, rho=28.0, beta=8.0/3.0):
    """Right-hand side of the classic Lorenz equations.

    ``t`` is unused because the system is autonomous. It is accepted so the
    function has the ``f(state, t)`` signature used by scipy.integrate.odeint.
    Returns ``[dx/dt, dy/dt, dz/dt]`` as a list.
    """
    x, y, z = state
    dx = sigma * (y - x)
    dy = x * (rho - z) - y
    dz = x * y - beta * z
    return [dx, dy, dz]

def integrate_cluster_trajectory(cluster_data, dt=0.01, steps_per_point=50,
                                 scale=(10.0, 10.0, 30.0), coupling=0.1):
    """Integrate a Lorenz trajectory nudged towards a cluster's (x, y, z) snapshots.

    The state starts at the first snapshot's (x, y, z) multiplied by ``scale``.
    For each snapshot, in the row order of ``cluster_data``, the state takes
    ``steps_per_point`` explicit-Euler steps of size ``dt`` under

        d(state)/dt = lorenz_system(state) + coupling * (target - state)

    where ``target`` is that snapshot's scaled (x, y, z).

    Args:
        cluster_data: DataFrame with columns x, y, z, term, year, global_id, size.
        dt: Euler step size.
        steps_per_point: integration steps spent on each snapshot.
        scale: per-axis multipliers applied to (x, y, z) (default 10, 10, 30).
        coupling: strength of the pull towards the current snapshot (default 0.1).

    Returns:
        Float array of shape (len(cluster_data) * steps_per_point, 7). Each row
        is ``[x, y, z, term, year, global_id, size]`` after one Euler step.
        The shape is (0, 7) when ``cluster_data`` is empty.
    """
    if len(cluster_data) == 0:
        return np.empty((0, 7))

    scale = np.asarray(scale, dtype=float)
    first = cluster_data.iloc[0]
    state = np.array([first['x'], first['y'], first['z']], dtype=float) * scale

    trajectory = []
    for _, row in cluster_data.iterrows():
        # The snapshot acts as a forcing target rather than a hard reset of the state.
        target = np.array([row['x'], row['y'], row['z']], dtype=float) * scale
        meta = [row['term'], row['year'], row['global_id'], row['size']]
        for _ in range(steps_per_point):
            d_state = np.asarray(lorenz_system(state, 0), dtype=float)
            state = state + (d_state + (target - state) * coupling) * dt
            trajectory.append([state[0], state[1], state[2], *meta])

    return np.array(trajectory, dtype=float)

# Integrate one Lorenz trajectory per global cluster, driven by its snapshots in (term, year) order.
all_trajectories = {}
for global_id in df_lorenz['global_id'].unique():
    cluster_data = df_lorenz.loc[df_lorenz['global_id'] == global_id].sort_values(['term', 'year'])
    traj = integrate_cluster_trajectory(cluster_data)
    if len(traj) == 0:
        continue
    all_trajectories[global_id] = traj
    print(f"Cluster {global_id}: {len(traj)} trajectory points")

print(f"\n=== Generated {len(all_trajectories)} cluster trajectories ===")


## 3D Visualization

In [None]:
# Matplotlib 3D plot: one line per cluster trajectory; circle marks the start, square the end.
fig = plt.figure(figsize=(12, 10))
ax = fig.add_subplot(111, projection='3d')

for global_id, traj in all_trajectories.items():
    color = aligner.get_color(global_id)
    ax.plot(traj[:, 0], traj[:, 1], traj[:, 2],
            color=color, alpha=0.7, linewidth=2, label=f'Cluster {global_id}')
    # Mark start and end
    ax.scatter(traj[0, 0], traj[0, 1], traj[0, 2],
               color=color, s=100, marker='o', edgecolors='black', linewidths=2)
    ax.scatter(traj[-1, 0], traj[-1, 1], traj[-1, 2],
               color=color, s=100, marker='s', edgecolors='black', linewidths=2)

ax.set_xlabel('X (Baseline Drift)', fontsize=12)
ax.set_ylabel('Y (Local Drift)', fontsize=12)
ax.set_zlabel('Z (Size Share)', fontsize=12)
ax.set_title(f'Lorenz Attractor: Semantic Evolution of "{TARGET_WORD}"\n({START_TERM},1) to ({END_TERM},{YEARS_PER_TERM})',
             fontsize=14, fontweight='bold')
if all_trajectories:
    # With no labelled lines, legend() would only emit a "No artists with labels" warning.
    ax.legend(loc='upper right', fontsize=8)
ax.view_init(elev=20, azim=45)

plt.tight_layout()
plot_path = os.path.join(OUTPUT_DIR, f'lorenz_{TARGET_WORD}.png')
plt.savefig(plot_path, dpi=150, bbox_inches='tight')
plt.show()
print(f"Saved plot to {plot_path}")

In [None]:
# Interactive Plotly 3D plot: each cluster gets a trajectory line plus start/end markers.
fig_plotly = go.Figure()

for global_id, traj in all_trajectories.items():
    rgba = aligner.get_color(global_id)
    channels = [int(c * 255) for c in rgba[:3]]
    color_str = 'rgba({},{},{},{})'.format(*channels, rgba[3])

    hover_labels = [f'T{int(t)},Y{int(y)}' for t, y in zip(traj[:, 3], traj[:, 4])]
    fig_plotly.add_trace(go.Scatter3d(
        x=traj[:, 0], y=traj[:, 1], z=traj[:, 2],
        mode='lines',
        line=dict(color=color_str, width=4),
        name=f'Cluster {global_id}',
        hovertext=hover_labels,
        hoverinfo='text'
    ))

    # Start (circle) then end (square) marker, both hidden from the legend.
    for point_idx, symbol, prefix in ((0, 'circle', 'Start'), (-1, 'square', 'End')):
        fig_plotly.add_trace(go.Scatter3d(
            x=[traj[point_idx, 0]], y=[traj[point_idx, 1]], z=[traj[point_idx, 2]],
            mode='markers',
            marker=dict(size=8, color=color_str, symbol=symbol, line=dict(color='black', width=2)),
            name=f'{prefix} {global_id}',
            showlegend=False
        ))

fig_plotly.update_layout(
    title=f'Lorenz Attractor: Semantic Evolution of "{TARGET_WORD}"',
    scene=dict(
        xaxis_title='X (Baseline Drift)',
        yaxis_title='Y (Local Drift)',
        zaxis_title='Z (Size Share)'
    ),
    width=1000,
    height=800
)

html_path = os.path.join(OUTPUT_DIR, f'lorenz_{TARGET_WORD}_interactive.html')
fig_plotly.write_html(html_path)
print(f"Saved interactive plot to {html_path}")
fig_plotly.show()

## Animated Plotly Visualization


In [None]:
# Create animated Plotly visualization with frame-by-frame progression
def create_animated_lorenz(all_trajectories, aligner, target_word, output_dir, frame_step=5):
    """
    Create an animated 3D plot where trajectories are progressively revealed.

    Args:
        all_trajectories: dict of global cluster ID -> 2-D array whose rows start
            with x, y, z, term, year (as produced by integrate_cluster_trajectory).
        aligner: object whose ``get_color(global_id)`` returns an RGBA tuple of
            floats in [0, 1].
        target_word: used in the title and the output file name.
        output_dir: directory the HTML file is written to.
        frame_step: integration steps advanced per frame (lower = smoother
            but more frames); values below 1 are treated as 1.

    Frame k reveals the first ``k * frame_step + 1`` points of every
    trajectory. A final frame always shows every trajectory in full.
    Trajectories are revealed by point index, not by term/year, so
    trajectories of different lengths are not time-aligned.

    Returns:
        The animated Figure after writing ``lorenz_<target_word>_animated.html``,
        or an empty Figure (nothing written) when there is nothing to animate.
    """
    trajectories = {gid: traj for gid, traj in all_trajectories.items() if len(traj) > 0}
    if not trajectories:
        print("No trajectories to animate.")
        return go.Figure()

    frame_step = max(1, int(frame_step))
    max_len = max(len(traj) for traj in trajectories.values())
    # range() alone can stop up to frame_step - 1 points short of the end;
    # append the last index so the final frame is complete.
    frame_indices = list(range(0, max_len, frame_step))
    if frame_indices[-1] != max_len - 1:
        frame_indices.append(max_len - 1)

    # Colours are fixed per cluster, so build them once rather than once per frame.
    colors = {}
    for global_id in trajectories:
        color_tuple = aligner.get_color(global_id)
        colors[global_id] = f'rgba({int(color_tuple[0]*255)},{int(color_tuple[1]*255)},{int(color_tuple[2]*255)},{color_tuple[3]})'

    # Fixed axis ranges, so the scene does not rescale (and appear to jump) between frames.
    points = np.vstack([traj[:, :3] for traj in trajectories.values()])
    lower, upper = points.min(axis=0), points.max(axis=0)
    padding = np.maximum((upper - lower) * 0.05, 1e-6)
    axis_ranges = [[float(lo - pad), float(hi + pad)] for lo, hi, pad in zip(lower, upper, padding)]

    frames = []
    for frame_idx in frame_indices:
        frame_data = []
        for global_id, traj in trajectories.items():
            color_str = colors[global_id]
            current_traj = traj[:min(frame_idx + 1, len(traj))]

            # Trajectory line (growing). Legend visibility stays constant across
            # frames, so the legend does not vanish once playback starts.
            frame_data.append(go.Scatter3d(
                x=current_traj[:, 0],
                y=current_traj[:, 1],
                z=current_traj[:, 2],
                mode='lines',
                line=dict(color=color_str, width=4),
                name=f'Cluster {global_id}',
                hovertext=[f'T{int(t)},Y{int(y)}' for t, y in zip(current_traj[:, 3], current_traj[:, 4])],
                hoverinfo='text',
                showlegend=True
            ))

            # Current position marker (moving point)
            frame_data.append(go.Scatter3d(
                x=[current_traj[-1, 0]],
                y=[current_traj[-1, 1]],
                z=[current_traj[-1, 2]],
                mode='markers',
                marker=dict(size=10, color=color_str, symbol='diamond',
                            line=dict(color='white', width=2)),
                name=f'Current {global_id}',
                showlegend=False
            ))

        frames.append(go.Frame(data=frame_data, name=str(frame_idx)))

    fig_anim = go.Figure(
        data=frames[0].data,
        frames=frames
    )

    # Add play/pause buttons and slider
    fig_anim.update_layout(
        title=dict(
            text=f'Lorenz Attractor: Semantic Evolution of "{target_word}" (Animated)',
            font=dict(size=18)
        ),
        scene=dict(
            xaxis=dict(title='X (Baseline Drift)', range=axis_ranges[0]),
            yaxis=dict(title='Y (Local Drift)', range=axis_ranges[1]),
            zaxis=dict(title='Z (Size Share)', range=axis_ranges[2]),
            camera=dict(
                eye=dict(x=1.5, y=1.5, z=1.2)
            )
        ),
        width=1200,
        height=900,
        updatemenus=[
            dict(
                type='buttons',
                showactive=False,
                buttons=[
                    dict(
                        label='Play',
                        method='animate',
                        args=[None, dict(
                            frame=dict(duration=50, redraw=True),
                            fromcurrent=True,
                            mode='immediate',
                            transition=dict(duration=0)
                        )]
                    ),
                    dict(
                        label='Pause',
                        method='animate',
                        args=[[None], dict(
                            frame=dict(duration=0, redraw=False),
                            mode='immediate',
                            transition=dict(duration=0)
                        )]
                    )
                ],
                x=0.1, y=0.0, xanchor='left', yanchor='bottom'
            )
        ],
        sliders=[
            dict(
                active=0,
                steps=[
                    dict(
                        args=[[f.name], dict(
                            frame=dict(duration=0, redraw=True),
                            mode='immediate',
                            transition=dict(duration=0)
                        )],
                        label=str(i),
                        method='animate'
                    ) for i, f in enumerate(frames)
                ],
                x=0.1, y=0.0, len=0.9, xanchor='left', yanchor='top',
                currentvalue=dict(
                    visible=True,
                    prefix='Frame: ',
                    xanchor='right'
                )
            )
        ]
    )

    # Save and display
    html_path = os.path.join(output_dir, f'lorenz_{target_word}_animated.html')
    fig_anim.write_html(html_path)
    print(f"Saved animated plot to {html_path}")

    return fig_anim

# Build the animated figure (10 integration steps per frame) and display it.
fig_animated = create_animated_lorenz(
    all_trajectories,
    aligner,
    TARGET_WORD,
    OUTPUT_DIR,
    frame_step=10,
)
fig_animated.show()


In [None]:
# Persist the per-snapshot Lorenz driving coordinates next to the plots.
csv_path = os.path.join(OUTPUT_DIR, 'lorenz_data_{}.csv'.format(TARGET_WORD))
df_lorenz.to_csv(path_or_buf=csv_path, index=False)
print("Saved Lorenz coordinates to " + csv_path)