In [None]:
import networkx as nx
import leidenalg as la
import igraph as ig
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from models import Observation
from config import STATEMENT_FILE
from sentence_transformers.util import cos_sim
from datetime import datetime
import plotly.graph_objects as go
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor
from graphprompt import gprompt_reason, gprompt_reason_top

In [2]:
# Shared random seed: passed to the full-graph Leiden partition and to the
# full-graph spring layout in later cells.
SEED = 42
# NOTE(review): presumably a torch-style device name for embedding work; no
# visible cell reads it — confirm it is still needed.
DEVICE = "mps"

In [3]:
def create_interactive_network(G, pos, node_colors, statements, show_edges=False):
    """Build an interactive Plotly figure for a statement-similarity graph.

    :param G: NetworkX graph; every edge needs a ``weight`` attribute when
        ``show_edges`` is true.
    :param pos: Mapping of node -> (x, y) layout coordinates.
    :param node_colors: Marker colors, in the iteration order of ``G.nodes()``.
    :param statements: Container indexed by node (list or dict) whose items
        provide ``what``, ``how``, ``citations``, ``date`` and ``source``.
    :param show_edges: When true, draw one line trace per edge.
    :return: The assembled ``go.Figure``.
    """

    def _wrap(text, width=50):
        # Hard-wrap into fixed-width slices joined by HTML line breaks.
        return "<br>".join(text[k : k + width] for k in range(0, len(text), width))

    def _clamp01(value):
        return min(max(value, 0), 1)

    # One trace per edge so that width and opacity can follow the edge weight.
    edge_traces = []
    if show_edges:
        for u, v, attrs in G.edges(data=True):
            x0, y0 = pos[u]
            x1, y1 = pos[v]
            strength = _clamp01(attrs["weight"])
            edge_traces.append(
                go.Scatter(
                    x=[x0, x1, None],
                    y=[y0, y1, None],
                    line=dict(width=strength, color="#888", dash="solid"),
                    opacity=strength,
                    hoverinfo="none",
                    mode="lines",
                )
            )

    # Node coordinates and hover cards, collected in G.nodes() order.
    node_x, node_y, hover_texts = [], [], []
    for node in G.nodes():
        x, y = pos[node]
        node_x.append(x)
        node_y.append(y)

        stmt = statements[node]
        # Each citation becomes one bullet per 50-character slice.
        citation_lines = [
            f"• {cite[k:k+50]}"
            for cite in stmt["citations"]
            for k in range(0, len(cite), 50)
        ]
        card_parts = [
            f"ID: {node}<br>",
            f"Date: {stmt['date'].strftime('%Y-%m')}<br>",
            f"<b>What:</b><br>{_wrap(stmt['what'])}<br><br>",
            f"<b>How:</b><br>{_wrap(stmt['how'])}<br><br>",
            f"<b>Citations:</b><br>{'<br>'.join(citation_lines)}<br><br>",
            f"<b>Source:</b><br>{stmt['source']}",
        ]
        hover_texts.append("".join(card_parts))

    node_trace = go.Scatter(
        x=node_x,
        y=node_y,
        mode="markers",
        hoverinfo="text",
        text=hover_texts,
        marker=dict(color=node_colors, size=10, line_width=2),
    )

    hidden_axis = dict(showgrid=False, zeroline=False, showticklabels=False)
    layout = go.Layout(
        title="Economic Period Similarity Network (Colored by Cluster)",
        showlegend=False,
        hovermode="closest",
        margin=dict(b=20, l=5, r=5, t=40),
        xaxis=dict(hidden_axis),
        yaxis=dict(hidden_axis),
    )
    return go.Figure(data=[*edge_traces, node_trace], layout=layout)

In [4]:
import pickle

# NOTE(review): pickle.load can execute arbitrary code embedded in the file, so
# STATEMENT_FILE must come from a trusted source.
with open(STATEMENT_FILE, "rb") as f:
    statements: list[Observation] = pickle.load(f)

# Tag each statement with its position in the list. Later cells use this id as
# the graph node key and as an index into `partition.membership`, so the list
# must not be reordered or filtered in place after this point.
for i, statement in enumerate(statements):
    statement["id"] = i

In [5]:
# Create similarity caches
# Both dicts are keyed by (smaller_id, larger_id) and store the cos_sim result
# for that statement pair. They live at module level so the yearly view and the
# full-graph build share results and re-running either cell skips known pairs.
what_sims = {}
how_sims = {}


In [22]:
def show_yearly_clusters(year: int = 2018, threshold: float = 0.2, seed: int = 42):
    """
    Build a graph of one year's statements, cluster it with Leiden and display
    the result through ``create_interactive_network``.

    Reads the module-level ``statements`` list and fills the shared
    ``what_sims`` / ``how_sims`` caches as a side effect. Prints a message and
    returns early when the year has no statements.

    :param year: The target year (default 2018)
    :param threshold: Minimum squared "what" similarity for an edge to be added
    :param seed: Random seed for both the Leiden optimiser and the spring
        layout, so clusters, colors and positions are reproducible
    :return: None (the figure is shown, not returned)
    """
    import igraph as ig
    import leidenalg as la
    from sentence_transformers.util import cos_sim

    # 1) Filter statements for the target year
    yearly_statements = [stmt for stmt in statements if stmt["date"].year == year]
    if not yearly_statements:
        print(f"No statements found for year {year}.")
        return

    # 2) One node per statement, keyed by its id, with the statement dict as node data
    G_year = nx.Graph()
    G_year.add_nodes_from((s["id"], s) for s in yearly_statements)

    # 3) Add an edge for every pair whose squared similarity clears the threshold
    for i, stmt1 in enumerate(yearly_statements):
        for stmt2 in yearly_statements[i + 1 :]:
            stmt1_id = stmt1["id"]
            stmt2_id = stmt2["id"]
            # Order-independent key so (a, b) and (b, a) share one cache entry
            cache_key = (min(stmt1_id, stmt2_id), max(stmt1_id, stmt2_id))  # type: ignore

            if cache_key not in what_sims:
                what_sims[cache_key] = cos_sim(stmt1["what_embedding"], stmt2["what_embedding"])
                how_sims[cache_key] = cos_sim(stmt1["how_embedding"], stmt2["how_embedding"])

            # The "how" similarity is cached but currently carries zero weight
            sim = 1 * what_sims[cache_key] + 0 * how_sims[cache_key]
            sim = float(sim**2)

            if sim > threshold:
                G_year.add_edge(stmt1_id, stmt2_id, weight=sim)

    # 4) Convert to igraph and apply Leiden clustering.
    # The weight list follows G_year's edge iteration order, which is the order
    # from_networkx uses when it copies the edges.
    g_ig = ig.Graph.from_networkx(G_year)
    partition = la.find_partition(
        g_ig,
        la.ModularityVertexPartition,
        weights=list(nx.get_edge_attributes(G_year, 'weight').values()),
        seed=seed,  # previously unseeded, so clusters changed between runs
    )

    # 5) Color the nodes by cluster, cycling through the palette when there are
    # more clusters than colors. membership is in G_year.nodes() order.
    color_palette = [
        "red", "blue", "green", "purple", "orange", "cyan", "magenta",
        "yellow", "brown", "pink"
    ]
    num_colors = len(color_palette)
    cluster_dict = dict(zip(G_year.nodes(), partition.membership))
    node_colors = [color_palette[cluster_dict[node] % num_colors] for node in G_year.nodes()]

    # 6) Lay out and visualize with the existing Plotly helper
    pos = nx.spring_layout(G_year, seed=seed)

    # The helper indexes `statements` by node, and nodes here are statement ids
    statements_lookup = {stmt["id"]: stmt for stmt in yearly_statements}

    fig = create_interactive_network(G_year, pos, node_colors, statements_lookup, show_edges=True)
    fig.update_layout(title=f"Leiden Clustering for {year}")
    fig.show()
# Example run: cluster the 2001 statements, linking pairs whose squared
# similarity exceeds 0.25.
show_yearly_clusters(year=2001,threshold=.25)

In [16]:
# Minimum squared "what" similarity for two statements to be linked
THRESHHOLD = 0.35

# Create graph: one node per statement, keyed by statement id, with the
# statement dict attached as node data
G = nx.Graph()
G.add_nodes_from((s["id"], s) for s in statements)

# Add edges between similar statements using real embeddings
total = len(statements) * (len(statements) - 1) // 2  # Total number of pairs
with tqdm(total=total, desc="Processing statement pairs") as pbar:
    for i, stmt1 in enumerate(statements):
        for stmt2 in statements[i + 1 :]:
            pbar.update(1)
            stmt1_id = stmt1["id"]
            stmt2_id = stmt2["id"]
            # Order-independent key so each pair is computed once and shared
            # with the yearly view
            cache_key = (min(stmt1_id, stmt2_id), max(stmt1_id, stmt2_id))  # type: ignore

            if cache_key not in what_sims:
                what_sims[cache_key] = cos_sim(stmt1["what_embedding"], stmt2["what_embedding"])  # type: ignore
                how_sims[cache_key] = cos_sim(stmt1["how_embedding"], stmt2["how_embedding"])  # type: ignore

            what_sim = what_sims[cache_key]
            how_sim = how_sims[cache_key]
            # Only the "what" similarity contributes; "how" currently has zero weight
            sim = 1 * what_sim + 0 * how_sim
            sim = float(sim ** 2)

            # # Calculate time difference in years
            # time_diff = abs((stmt1["date"] - stmt2["date"]).days / 365.25)
            # # Calculate time penalty (starts near 0 for recent, approaches 1 for distant)
            # time_penalty = 1 - np.exp(-0.5 * time_diff)

            # sim = sim * time_penalty

            # Add edge if similarity is high enough. Endpoints use the same ids
            # the nodes were registered under; add_edge silently creates any
            # node it does not know, so positional indices would produce
            # attribute-less phantom nodes if ids ever diverged from positions.
            if sim > THRESHHOLD:
                G.add_edge(stmt1_id, stmt2_id, weight=sim)

Processing statement pairs: 100%|██████████| 1800253/1800253 [00:16<00:00, 108096.89it/s]


In [24]:
# Cluster the full similarity graph with Leiden (igraph is the backend leidenalg needs)
g_ig = ig.Graph.from_networkx(G)

# Edge weights in G's edge iteration order
edge_weights = list(nx.get_edge_attributes(G, "weight").values())
partition = la.find_partition(
    g_ig,
    la.ModularityVertexPartition,
    weights=edge_weights,
    seed=SEED,
)

# Palette of color names understood by both matplotlib and plotly
colors = [
    "red", "blue", "green", "purple", "orange",
    "cyan", "magenta", "yellow", "brown", "pink",
]
num_colors = len(colors)

# One color per node, cycling through the palette when clusters outnumber colors
node_colors: list[str] = [colors[cluster % num_colors] for cluster in partition.membership]

pos = nx.spring_layout(G, seed=SEED)

# Build and show the interactive figure (edges are hidden by default)
fig = create_interactive_network(G, pos, node_colors, statements)
fig.show()

In [23]:
import io


def get_period_cluster_comparison(
    G,
    statements: "list[Observation]",
    partition,
    target_period,
    n_similar_periods=3,
    decay_rate=0.6,  # Controls how quickly the time penalty decays
):
    """
    Rank other years by cluster-level embedding similarity to the target year.

    Statements are grouped by (cluster, year). For every cluster that holds
    statements from both the target year and another year, the target year's
    mean "what" embedding is compared with the other year's statements. Every
    score is multiplied by ``1 - exp(-decay_rate * |year gap|)``; that factor is
    0 for a zero gap and approaches 1 as the gap grows, so nearby years are
    down-weighted.

    :param G: Unused; kept so existing calls keep working.
    :param statements: Statements in the same order as ``partition.membership``.
        Each needs ``date`` plus array-like ``what_embedding`` / ``how_embedding``
        supporting ``reshape``.
    :param partition: Object exposing ``membership`` (cluster id per statement index).
    :param target_period: Object with a ``year`` attribute, e.g. a ``datetime``.
    :param n_similar_periods: Maximum number of periods to return.
    :param decay_rate: Rate of the exponential in the time factor.
    :return: Up to ``n_similar_periods`` dicts, best first, with keys ``period``,
        ``clusters``, ``overall_similarity``, ``time_diff_years`` and
        ``time_penalty`` (reported as ``(1 - factor) * 100``).
    """
    target_year = target_period.year
    target_year_str = str(target_year)

    # Group statements by cluster and year
    cluster_period_map = {}
    for i, stmt in enumerate(statements):
        year = str(stmt["date"].year)
        cluster = partition.membership[i]
        cluster_period_map.setdefault(cluster, {}).setdefault(year, []).append(stmt)

    # Mean "what" embedding of the target year, once per cluster
    target_cluster_embeddings = {}
    for cluster, periods in cluster_period_map.items():
        if target_year_str in periods:
            target_embeddings = np.stack([s["what_embedding"] for s in periods[target_year_str]])
            target_cluster_embeddings[cluster] = np.mean(target_embeddings, axis=0)

    # Every year that appears anywhere, except the target
    all_years = set()
    for cluster_data in cluster_period_map.values():
        all_years.update(cluster_data.keys())
    other_years = all_years - {target_year_str}

    period_similarities = []

    # Sorted so that equal scores are ordered the same way on every run
    # (set iteration order for strings is not stable across interpreter runs).
    for other_year in sorted(other_years):
        matching_clusters = []
        time_diff = abs(target_year - int(other_year))
        time_penalty = 1 - np.exp(-decay_rate * time_diff)

        # Only clusters containing both years can be compared
        for cluster, periods in cluster_period_map.items():
            if target_year_str in periods and other_year in periods:
                target_what_emb = target_cluster_embeddings[cluster].reshape(1, -1)

                # Statement-level similarity against the target centroid
                statement_similarities = []
                for stmt in periods[other_year]:
                    what_sim = float(cos_sim(target_what_emb, stmt["what_embedding"].reshape(1, -1)))
                    # NOTE(review): this compares the target "what" centroid with a
                    # "how" embedding; it has zero weight here — confirm intent.
                    how_sim = float(cos_sim(target_what_emb, stmt["how_embedding"].reshape(1, -1)))
                    what_weight = 1
                    sim = what_weight * what_sim + (1 - what_weight) * how_sim
                    adjusted_sim = sim * time_penalty  # Apply time factor to each statement
                    statement_similarities.append((stmt, adjusted_sim))

                sorted_statements = sorted(statement_similarities, key=lambda x: x[1], reverse=True)

                # Keep the best-matching 90% of statements (always at least one)
                n_include = max(1, int(len(sorted_statements) * 0.9))
                included_statements = sorted_statements[:n_include]
                excluded_statements = sorted_statements[n_include:]

                # Cluster similarity from the centroids of the kept statements
                included_what_embeddings = np.stack([s[0]["what_embedding"] for s in included_statements])
                included_how_embeddings = np.stack([s[0]["how_embedding"] for s in included_statements])

                what_similarity = float(cos_sim(
                    target_what_emb,
                    np.mean(included_what_embeddings, axis=0).reshape(1, -1)
                ))

                # NOTE(review): again target "what" vs other-year "how" — confirm intent.
                how_similarity = float(cos_sim(
                    target_what_emb,
                    np.mean(included_how_embeddings, axis=0).reshape(1, -1)
                ))

                what_weight = .95
                cluster_similarity = what_weight * what_similarity + (1 - what_weight) * how_similarity
                cluster_adjusted_similarity = cluster_similarity * time_penalty

                matching_clusters.append({
                    "cluster": cluster,
                    "similarity": cluster_adjusted_similarity,
                    "raw_similarity": cluster_similarity,
                    "time_penalty": time_penalty,
                    "time_diff_years": time_diff,
                    "target_statements": periods[target_year_str],
                    "other_statements": sorted_statements,  # All statements, best first
                    "included_statements": included_statements,
                    "excluded_statements": excluded_statements,
                    "included_in_ranking": False  # Set below once clusters are ranked
                })

        if matching_clusters:
            sorted_clusters = sorted(matching_clusters, key=lambda x: x["similarity"], reverse=True)

            # Drop the weakest 10% of clusters (at least one) from the overall
            # score, but never all of them: with a single shared cluster the old
            # rule left nothing to average, which produced NaN and an undefined
            # final ordering.
            n_remove = min(
                max(1, int(len(sorted_clusters) * 0.10)),
                len(sorted_clusters) - 1,
            )

            for i, cluster in enumerate(sorted_clusters):
                cluster["included_in_ranking"] = i < (len(sorted_clusters) - n_remove)

            included_clusters = [c for c in sorted_clusters if c["included_in_ranking"]]
            overall_sim = np.mean([c["similarity"] for c in included_clusters])
            period_similarities.append({
                "period": other_year,
                "clusters": sorted_clusters,  # All clusters; only included ones affect the score
                "overall_similarity": overall_sim,
                "time_diff_years": time_diff,
                "time_penalty": (1 - time_penalty) * 100
            })

    return sorted(period_similarities, key=lambda x: x["overall_similarity"], reverse=True)[:n_similar_periods]

# Run the comparison for 2018 and keep the ten closest years
period = datetime(2018, 1, 1)
report = get_period_cluster_comparison(G, statements, partition, period, n_similar_periods=10)


def _print_cluster_summary(cluster, period_data):
    """Print one cluster's score and sample statements from both years."""
    mark = '✓' if cluster['included_in_ranking'] else '✗'
    included = cluster["included_statements"]

    print(f"\nCluster {cluster['cluster']}")
    print(f"Similarity: {cluster['similarity']:.3f} {mark}")

    print(f"{period.strftime('%Y')} statements (top 3):")
    for stmt in cluster["target_statements"][:3]:
        print(f"- {stmt['what']}")

    print(f"{period_data['period']} included statements (top 3):")
    for stmt, sim in included[:3]:
        print(f"- {stmt['what']} (sim: {sim:.3f})")

    print(f"{period_data['period']} included statements (bottom 1):")
    if included:
        stmt, sim = included[-1]
        print(f"- {stmt['what']} (sim: {sim:.3f})")


# Summary table of every returned period
print(f"\nAnalyzing periods similar to {period.strftime('%Y-%m-%d')}")
print("\nMost similar periods:")
for period_data in report:
    print(f"{period_data['period']}: {period_data['overall_similarity']:.3f} (time penalty: {period_data['time_penalty']:.1f}%)")

# Detailed breakdown for the single best-matching period
for period_data in report[:1]:
    print(f"\nPeriod {period_data['period']} vs {period.strftime('%Y')}")
    print(f"Time Difference: {period_data['time_diff_years']:.1f} years")
    print(f"Time Penalty: {period_data['time_penalty']:.2f}%")
    print(f"Similarity: {period_data['overall_similarity']:.3f}")

    # Best cluster first
    sorted_clusters = sorted(period_data["clusters"], key=lambda x: x["similarity"], reverse=True)

    print("\nTop 3 clusters:")
    for cluster in sorted_clusters[:3]:
        _print_cluster_summary(cluster, period_data)

    print("\nBottom cluster:")
    _print_cluster_summary(sorted_clusters[-1], period_data)



Analyzing periods similar to 2018-01-01

Most similar periods:
2011: 0.896 (time penalty: 1.5%)
2005: 0.892 (time penalty: 0.0%)
2000: 0.888 (time penalty: 0.0%)
2004: 0.875 (time penalty: 0.0%)
2003: 0.874 (time penalty: 0.0%)
2009: 0.867 (time penalty: 0.5%)
2006: 0.862 (time penalty: 0.1%)
2012: 0.862 (time penalty: 2.7%)
2002: 0.862 (time penalty: 0.0%)
2024: 0.852 (time penalty: 2.7%)

Period 2011 vs 2018
Time Difference: 7.0 years
Time Penalty: 1.50%
Similarity: 0.896

Top 3 clusters:

Cluster 3
Similarity: 0.922 ✓
2018 statements (top 3):
- Economic growth was positive across all districts with 11 showing modest to moderate gains and Dallas showing stronger growth
- Economic activity expanded at a modest to moderate pace across all Federal Reserve Districts
- Economic activity expanded moderately across most regions, with Dallas showing notably stronger growth
2011 included statements (top 3):
- Economic activity expanded at a modest to moderate pace across Federal Reserve Dist

In [None]:
# Build two reports for the best-matching period: the raw cluster listing and
# the per-cluster reasoning returned by gprompt_reason. Uses `report` and
# `period` from the period-comparison cell.
raw_output = io.StringIO()
analysis_output = io.StringIO()


def _statement_details(stmt):
    """Shared 'what (how, id, citations, source' prefix; the caller closes the parenthesis."""
    return (
        f"{stmt['what']} (how: {stmt['how']}, id: {stmt['id']}, "
        f"citations: {stmt['citations']}, source: {stmt['source']}"
    )


print(f"\nAnalyzing periods similar to {period.strftime('%Y-%m-%d')}", file=raw_output)
print("\nMost similar periods:", file=raw_output)
for period_data in report:
    print(f"{period_data['period']}: {period_data['overall_similarity']:.3f} (time penalty: {period_data['time_penalty']:.1f}%)", file=raw_output)
print("\nDetailed analysis:", file=raw_output)


# Detailed listing for the top period only
for period_data in report[:1]:
    print(f"\nPeriod {period_data['period']} vs {period.strftime('%Y')}", file=raw_output)
    print(f"Time Difference: {period_data['time_diff_years']:.1f} years", file=raw_output)
    print(f"Time Penalty: {period_data['time_penalty']:.2f}%", file=raw_output)
    print(f"Similarity: {period_data['overall_similarity']:.3f}", file=raw_output)

    with ThreadPoolExecutor(max_workers=8) as executor:
        futures = []

        for cluster_data in period_data["clusters"]:
            cluster_output = io.StringIO()
            print(f"\n  Cluster {cluster_data['cluster']}", file=cluster_output)
            print(f"  Similarity: {cluster_data['similarity']:.3f} {'✓' if cluster_data['included_in_ranking'] else '✗'}", file=cluster_output)
            print(f"  {period.strftime('%Y')} statements:", file=cluster_output)
            for stmt in cluster_data["target_statements"]:
                print(f"    - {_statement_details(stmt)})", file=cluster_output)
            print(f"  {period_data['period']} statements:", file=cluster_output)
            for stmt, sim in cluster_data["included_statements"]:
                print(f"    - {_statement_details(stmt)}, sim: {sim:.3f}) ✓", file=cluster_output)
            for stmt, sim in cluster_data["excluded_statements"]:
                print(f"    - {_statement_details(stmt)}, sim: {sim:.3f}) ✗", file=cluster_output)

            cluster_text = cluster_output.getvalue()
            cluster_output.close()

            print(cluster_text, file=raw_output)

            # Reason about each cluster concurrently; results are collected in
            # submission order below.
            futures.append(executor.submit(gprompt_reason, {"chunk": cluster_text}))

        # Write reasoning results
        for future in futures:
            print("\nReasoning Analysis:", file=analysis_output)
            print(future.result(), file=analysis_output)
            print("-" * 80, file=analysis_output)

raw_report = raw_output.getvalue()
analysis_report = analysis_output.getvalue()
raw_output.close()
analysis_output.close()

# The reports contain non-ASCII marks (✓ / ✗), so the encoding is pinned to
# UTF-8 instead of relying on the platform default.
with open('period_comparison_raw.md', 'w', encoding='utf-8') as f:
    f.write(raw_report)

with open('period_comparison_analysis.md', 'w', encoding='utf-8') as f:
    f.write(analysis_report)


In [None]:
# Read the per-cluster analysis back from disk so this cell can run without
# regenerating it. The encoding is pinned to UTF-8 rather than left to the
# platform default.
with open('period_comparison_analysis.md', 'r', encoding='utf-8') as f:
    analysis_text = f.read()

# Get high-level synthesis of all analyses using reason-top endpoint
final = gprompt_reason_top({"analysis": analysis_text})

print("Top-Level Analysis:")
print(final)

# Save final analysis to file
with open('period_comparison_final.md', 'w', encoding='utf-8') as f:
    f.write(final)


/Users/evanpierce/projects/drw-hack/data/v2/period_comparison_final.html

In [None]:
def find_optimal_resolution(G, resolution_range=(0.01, 2.0), n_steps=20, seed=None):
    """Sweep the CPM resolution parameter and record cluster statistics per value.

    Runs Leiden with ``CPMVertexPartition`` at ``n_steps`` evenly spaced
    resolutions and prints a short summary for each. The sweep stops early once
    every node sits in its own cluster.

    :param G: NetworkX graph whose edges carry a ``weight`` attribute.
    :param resolution_range: (low, high) bounds of the sweep, inclusive.
    :param n_steps: Number of resolution values to try.
    :param seed: Optional seed forwarded to ``la.find_partition``; ``None``
        (the default) leaves the optimiser unseeded, as before.
    :return: List of dicts, one per tested resolution, with keys ``resolution``,
        ``n_clusters``, ``avg_cluster_size``, ``std_cluster_size``,
        ``min_cluster_size``, ``max_cluster_size`` and ``partition``.
    """
    # Convert NetworkX graph to igraph
    g_ig = ig.Graph.from_networkx(G)
    # The weights do not depend on the resolution, so extract them once
    edge_weights = list(nx.get_edge_attributes(G, 'weight').values())

    resolutions = np.linspace(resolution_range[0], resolution_range[1], n_steps)
    results = []
    for res in tqdm(resolutions, desc="Testing resolutions"):
        partition = la.find_partition(
            g_ig,
            la.CPMVertexPartition,
            weights=edge_weights,
            resolution_parameter=res,
            seed=seed,
        )

        # Number of statements in each cluster
        cluster_sizes = np.bincount(partition.membership)
        avg_size = np.mean(cluster_sizes)
        std_size = np.std(cluster_sizes)
        smallest = min(cluster_sizes)
        largest = max(cluster_sizes)

        n_clusters = len(set(partition.membership))

        results.append({
            'resolution': res,
            'n_clusters': n_clusters,
            'avg_cluster_size': avg_size,
            'std_cluster_size': std_size,
            'min_cluster_size': smallest,
            'max_cluster_size': largest,
            'partition': partition
        })
        print(f"Resolution: {res:.3f}")
        print(f"Clusters: {n_clusters}")
        print(f"Avg Size: {avg_size:.1f} ± {std_size:.1f}")
        print(f"Size Range: {smallest} - {largest}\n")

        # Higher resolutions cannot split further once every node is alone
        if n_clusters == g_ig.vcount():
            break

    return results

# Sweep resolutions on the full graph, then chart how the clustering responds
results = find_optimal_resolution(G)

# Pull each series out once so the plotting calls stay short
res_values = [r['resolution'] for r in results]
cluster_counts = [r['n_clusters'] for r in results]
avg_sizes = [r['avg_cluster_size'] for r in results]
size_lower = [r['avg_cluster_size'] - r['std_cluster_size'] for r in results]
size_upper = [r['avg_cluster_size'] + r['std_cluster_size'] for r in results]
max_sizes = [r['max_cluster_size'] for r in results]
min_sizes = [r['min_cluster_size'] for r in results]

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

# Left panel: number of clusters per resolution
ax1.plot(res_values, cluster_counts, 'o-')
ax1.set_xlabel('Resolution Parameter')
ax1.set_ylabel('Number of Clusters')
ax1.set_title('Resolution vs Number of Clusters')
ax1.grid(True)

# Right panel: mean size with a one-standard-deviation band, plus the extremes
ax2.plot(res_values, avg_sizes, 'o-', label='Average Size')
ax2.fill_between(res_values, size_lower, size_upper, alpha=0.2)
ax2.plot(res_values, max_sizes, '--', label='Max Size', alpha=0.5)
ax2.plot(res_values, min_sizes, '--', label='Min Size', alpha=0.5)
ax2.set_xlabel('Resolution Parameter')
ax2.set_ylabel('Cluster Size')
ax2.set_title('Resolution vs Cluster Sizes')
ax2.legend()
ax2.grid(True)

plt.tight_layout()
plt.show()

In [None]:
def get_cross_cluster_similarities(
    G,
    statements: list[Observation],
    partition,
    target_period,
    n_similar_periods=3,
    decay_rate=0.35,
    top_k_statements=5,
    top_k_clusters=5,  # New parameter
):
    """
    Score other years by their best statement-level matches to the target year.

    For each target-year cluster, every target statement is compared with every
    statement of another year regardless of that statement's cluster. The
    ``top_k_statements`` best pairs are averaged into a cluster score, and the
    ``top_k_clusters`` best cluster scores are averaged into the year's score.
    Returns the ``n_similar_periods`` highest-scoring years, best first.
    """
    target_year = target_period.year
    target_year_str = str(target_year)

    # Single pass: target-year statements go by cluster, the rest by year.
    # cluster_sizes counts only statements outside the target year.
    target_clusters = {}
    year_statements = {}
    cluster_sizes = {}
    for idx, stmt in enumerate(statements):
        stmt_year = str(stmt["date"].year)
        cluster_id = partition.membership[idx]
        if stmt_year == target_year_str:
            target_clusters.setdefault(cluster_id, []).append(stmt)
        else:
            year_statements.setdefault(stmt_year, []).append(stmt)
            cluster_sizes[cluster_id] = cluster_sizes.get(cluster_id, 0) + 1

    period_similarities = []

    for year, other_statements in year_statements.items():
        time_diff = abs(target_year - int(year))
        time_penalty = 1 - np.exp(-decay_rate * time_diff)

        cluster_matches = []

        for cluster, target_statements in target_clusters.items():
            pair_scores = []

            for target_stmt in target_statements:
                target_what_emb = target_stmt["what_embedding"].reshape(1, -1)

                for other_stmt in other_statements:
                    # Inverse square root of the other statement's cluster size;
                    # recorded in the output but not applied to the score.
                    other_cluster = partition.membership[other_stmt["id"]]
                    cluster_weight = 1.0 / (cluster_sizes[other_cluster] ** 0.5)

                    what_sim = float(
                        cos_sim(target_what_emb, other_stmt["what_embedding"].reshape(1, -1))
                    )
                    how_sim = float(
                        cos_sim(
                            target_stmt["how_embedding"].reshape(1, -1),
                            other_stmt["how_embedding"].reshape(1, -1),
                        )
                    )

                    what_weight = .7
                    sim = what_weight * what_sim + (1 - what_weight) * how_sim
                    # Time-weighted, then squared
                    adjusted_sim = (sim * time_penalty) ** 2

                    pair_scores.append(
                        {
                            "target_statement": target_stmt,
                            "other_statement": other_stmt,
                            "similarity": adjusted_sim,
                            "raw_similarity": sim,
                            "cluster_weight": cluster_weight,
                        }
                    )

            # Best pairs for this cluster
            pair_scores.sort(key=lambda entry: entry["similarity"], reverse=True)
            sorted_similarities = pair_scores[:top_k_statements]

            if not sorted_similarities:
                continue

            cluster_matches.append(
                {
                    "cluster": cluster,
                    "similarity": np.mean([s["similarity"] for s in sorted_similarities]),
                    "raw_similarity": np.mean([s["raw_similarity"] for s in sorted_similarities]),
                    "time_penalty": time_penalty,
                    "time_diff_years": time_diff,
                    "target_statements": target_statements,
                    "top_matches": sorted_similarities,
                }
            )

        if not cluster_matches:
            continue

        # Only the strongest clusters contribute to the year's score
        cluster_matches.sort(key=lambda entry: entry["similarity"], reverse=True)
        sorted_clusters = cluster_matches[:top_k_clusters]

        period_similarities.append(
            {
                "period": year,
                "clusters": sorted_clusters,
                "overall_similarity": np.mean([c["similarity"] for c in sorted_clusters]),
                "time_diff_years": time_diff,
                "time_penalty": (1 - time_penalty) * 100,
            }
        )

    # Highest overall similarity first, truncated to the requested count
    period_similarities.sort(key=lambda entry: entry["overall_similarity"], reverse=True)
    return period_similarities[:n_similar_periods]


# Try the cross-cluster comparison for 2018.
# NOTE(review): this rebinds `report`, which the period-comparison cells also
# use; re-run that cell before regenerating its report files.
report = get_cross_cluster_similarities(G, statements, partition, datetime(2018, 1, 30))

# Print every returned period with all of its clusters and matches
for period_data in report:
    period_lines = [
        f"\nPeriod: {period_data['period']}",
        f"Time Difference: {period_data['time_diff_years']:.1f} years",
        f"Time Penalty: {period_data['time_penalty']:.2f}%",
        f"Overall Adjusted Similarity: {period_data['overall_similarity']:.3f}",
    ]
    print(*period_lines, sep="\n")

    for cluster_data in period_data["clusters"]:
        cluster_lines = [
            f"\n  Cluster {cluster_data['cluster']}",
            f"  Raw Similarity: {cluster_data['raw_similarity']:.3f}",
            f"  Adjusted Similarity: {cluster_data['similarity']:.3f}",
            "  Current period statements:",
        ]
        cluster_lines.extend(
            f"    - {stmt['what']}" for stmt in cluster_data["target_statements"]
        )
        cluster_lines.append("  Most similar statements from comparison period:")
        cluster_lines.extend(
            f"    - {match['other_statement']['what']} (adj_sim: {match['similarity']:.3f})"
            for match in cluster_data["top_matches"]
        )
        print(*cluster_lines, sep="\n")

In [None]:
# Peek at the first (node, attributes) pair without copying every node into a list
first_node = next(iter(G.nodes.items()))
first_node
