In [1]:
# Install plotly from inside the notebook; the %pip magic (rather than !pip) is
# used so the install targets the environment of the running kernel.
%pip install plotly

Note: you may need to restart the kernel to use updated packages.



[notice] A new release of pip is available: 25.1.1 -> 25.3
[notice] To update, run: python.exe -m pip install --upgrade pip


In [2]:
import pandas as pd
import numpy as np
import networkx as nx
import plotly.graph_objects as go
from tqdm import tqdm


In [3]:
# --- Inputs -----------------------------------------------------------
# Conflict records written by an earlier export; "week" is parsed as datetime.
conflict_df = pd.read_csv("conflict_df_export.csv", parse_dates=["week"])

# Subreddit -> category lookup: names lower-cased, reduced to the two columns
# used below and renamed to subreddit / cluster.
clusters = (
    pd.read_csv("subreddit_cluster_categories.csv")
    .assign(name=lambda table: table['name'].str.lower())
    .loc[:, ['name', 'category']]
    .rename(columns={'name': 'subreddit', 'category': 'cluster'})
)

# --- Label both endpoints of every record -------------------------------
# Join keys are lower-cased on both sides so the match ignores case.
for endpoint in ('source', 'target'):
    conflict_df[endpoint] = conflict_df[endpoint].str.lower()

# NOTE(review): a subreddit listed more than once in the cluster table would
# duplicate rows in these left joins — confirm names are unique after lower-casing.
source_labels = clusters.rename(columns={'subreddit': 'source', 'cluster': 'src_cluster'})
target_labels = clusters.rename(columns={'subreddit': 'target', 'cluster': 'tgt_cluster'})

merged = conflict_df.merge(source_labels, on='source', how='left')
merged = merged.merge(target_labels, on='target', how='left')

# --- Keep spikes whose two clusters are both known ----------------------
# A row survives only if it is flagged as a spike and neither cluster label
# is missing (no match in the lookup) or the literal "Unknown".
cluster_cols = merged[['src_cluster', 'tgt_cluster']]
keep = (
    (merged['is_spike'] == True)
    & cluster_cols.notna().all(axis=1)
    & (cluster_cols != "Unknown").all(axis=1)
)
spikes = merged[keep].copy()

print("Total spike records:", len(spikes))


Total spike records: 2390


In [4]:
# Calendar month (pandas Period) that contains each spike's week.
spikes['month'] = spikes['week'].dt.to_period('M')

# Number of spike rows per (month, source cluster, target cluster), plus the
# timestamp of the first instant of that month in `month_start`.
pair_keys = ['month', 'src_cluster', 'tgt_cluster']
cluster_month = (
    spikes.groupby(pair_keys)
          .size()
          .rename('spike_count')
          .reset_index()
          .assign(month_start=lambda counts: counts['month'].dt.start_time)
)

cluster_month.head()


Unnamed: 0,month,src_cluster,tgt_cluster,spike_count,month_start
0,2013-12,Informative,Informative,2,2013-12-01
1,2013-12,Politics,Informative,2,2013-12-01
2,2014-01,Informative,Gaming,3,2014-01-01
3,2014-01,Informative,Informative,16,2014-01-01
4,2014-01,Informative,Politics,4,2014-01-01


In [5]:
# Number of subreddits assigned to each cluster, without the "Unknown" bucket.
cluster_sizes = (
    clusters.groupby('cluster')
            .size()
            .rename("size")
            .reset_index()
            .loc[lambda counts: counts['cluster'] != "Unknown"]
)
cluster_sizes


Unnamed: 0,cluster,size
0,AdultContent,3900
1,Education,51
2,Entertainment,3300
3,Gaming,6430
4,General,881
5,Geography,94
6,Informative,15763
7,Lifestyle,14694
8,Politics,4278
9,Sports,1155


In [6]:
# Marker-size range (in Plotly marker size units) that cluster sizes are mapped onto.
min_size = 30
max_size = 90

# Min-max scale the subreddit count of each cluster into [min_size, max_size].
sizes = cluster_sizes["size"]
size_range = sizes.max() - sizes.min()

if size_range == 0:
    # All clusters have the same size: min-max scaling would compute 0/0 = NaN,
    # so give every node the midpoint of the range instead.
    norm_sizes = pd.Series((min_size + max_size) / 2, index=sizes.index)
else:
    norm_sizes = min_size + (sizes - sizes.min()) * (max_size - min_size) / size_range

# `assign` builds a new frame rather than writing a column into the filtered
# frame from the previous cell, which avoids pandas' SettingWithCopyWarning and
# keeps this cell safe to re-run.
cluster_sizes = cluster_sizes.assign(node_size=norm_sizes)


In [7]:
import plotly.express as px

# Cluster names in the row order of `cluster_sizes`; this order is reused for
# colour assignment here and for node order when drawing.
unique_clusters = cluster_sizes['cluster'].to_list()
colors = px.colors.qualitative.Set3  # 12-friendly palette

# Walk the palette in order, wrapping around if there are more clusters than colours.
palette_cycle = (colors[position % len(colors)] for position in range(len(unique_clusters)))
cluster_color_map = dict(zip(unique_clusters, palette_cycle))

cluster_color_map


{'AdultContent': 'rgb(141,211,199)',
 'Education': 'rgb(255,255,179)',
 'Entertainment': 'rgb(190,186,218)',
 'Gaming': 'rgb(251,128,114)',
 'General': 'rgb(128,177,211)',
 'Geography': 'rgb(253,180,98)',
 'Informative': 'rgb(179,222,105)',
 'Lifestyle': 'rgb(252,205,229)',
 'Politics': 'rgb(217,217,217)',
 'Sports': 'rgb(188,128,189)',
 'Technology': 'rgb(204,235,197)'}

In [8]:
# Accumulators: `frames` is (re)filled by the frame-building cell; `pos_cache`
# is reserved for cached layout positions and is not read in the visible cells.
frames = list()
pos_cache = dict()   # cache layout positions for stability

# Distinct month starts in chronological order — one animation frame each.
months = sorted(
    cluster_month["month_start"].unique()
)


In [9]:
# ======================================
# FRAME LOOP — STRAIGHT ARROWS ON A FIXED SPRING LAYOUT
# ======================================
# Builds `frames`: one go.Frame per entry of `months`. Every frame draws the same
# nodes at the same positions; only the edges (one line trace plus one arrow
# annotation per src_cluster -> tgt_cluster pair) change from month to month.

import numpy as np

# --- Tunables (fractions are relative to the centre-to-centre edge length) ---
LAYOUT_MONTH_INDEX = 1            # entry of `months` whose edges shape the layout
RECIPROCAL_SHIFT_FRACTION = 0.04  # sideways shift for one edge of an A<->B pair
BORDER_TRIM_FRACTION = 0.06       # trimmed off each end; tune within 0.04–0.08
MIN_EDGE_LENGTH = 1e-9            # floor that avoids dividing by zero (self-loops)
EDGE_LINE_COLOR = "rgba(150,20,20,0.55)"
EDGE_ARROW_COLOR = "rgba(150,20,20,0.9)"


def _edge_endpoints(start, end, sidestep=False):
    """Return ((x0, y0), (x1, y1)) for a straight edge drawn from `start` to `end`.

    Both ends are pulled towards each other by BORDER_TRIM_FRACTION of the edge
    length so the line does not start/stop exactly on the node centres. When
    `sidestep` is true the whole segment is first translated perpendicular to
    its direction by RECIPROCAL_SHIFT_FRACTION of its length; a translation does
    not change direction or length, so nothing is recomputed afterwards.

    NOTE(review): for a self-loop (start == end) the direction vector is zero,
    so both returned points coincide and the edge has zero length.
    """
    x0, y0 = start
    x1, y1 = end

    dx, dy = x1 - x0, y1 - y0
    dist = max(np.hypot(dx, dy), MIN_EDGE_LENGTH)
    ux, uy = dx / dist, dy / dist

    if sidestep:
        # The perpendicular of (ux, uy) is (-uy, ux).
        # These locals are deliberately NOT called px/py: at notebook top level
        # `px` is the plotly.express alias and must not be overwritten.
        shift = dist * RECIPROCAL_SHIFT_FRACTION
        shift_x, shift_y = -uy * shift, ux * shift
        x0, y0 = x0 + shift_x, y0 + shift_y
        x1, y1 = x1 + shift_x, y1 + shift_y

    trim = dist * BORDER_TRIM_FRACTION
    return (x0 + ux * trim, y0 + uy * trim), (x1 - ux * trim, y1 - uy * trim)


# ======================================
# FIXED LAYOUT (shared by every frame)
# ======================================
# NOTE: despite their names, `last_month` / `sub_last` refer to the entry at
# LAYOUT_MONTH_INDEX (the SECOND month chronologically), exactly as before; the
# names are kept in case other cells read them. With a single month available
# that month is used instead of raising IndexError.
if months:
    last_month = months[min(LAYOUT_MONTH_INDEX, len(months) - 1)]
    sub_last = cluster_month[cluster_month['month_start'] == last_month]
else:
    last_month = None
    sub_last = cluster_month.iloc[0:0]

# Layout graph: every cluster is a node; only the layout month's edges are added,
# so they alone influence the spring layout's spacing.
G_layout = nx.DiGraph()
G_layout.add_nodes_from(unique_clusters)
G_layout.add_edges_from(zip(sub_last['src_cluster'], sub_last['tgt_cluster']))

# Seeded so the positions are reproducible across runs.
base_pos = nx.spring_layout(G_layout, seed=42, k=1.8)

# Node appearance never changes between frames, so compute it once.
node_size_by_cluster = dict(zip(cluster_sizes['cluster'], cluster_sizes['node_size']))
node_kwargs = dict(
    x=[base_pos[n][0] for n in unique_clusters],
    y=[base_pos[n][1] for n in unique_clusters],
    mode='markers+text',
    text=unique_clusters,
    textposition="top center",
    marker=dict(
        size=[node_size_by_cluster[n] for n in unique_clusters],
        color=[cluster_color_map[n] for n in unique_clusters],
        line_width=2,
        line_color="black"
    )
)

# ======================================
# PASS 1 — edge traces and arrows per month
# ======================================
frame_parts = []   # (frame name, edge traces, arrow annotations)

for month in months:
    sub = cluster_month[cluster_month['month_start'] == month]

    # Monthly graph: all clusters as nodes, one directed edge per cluster pair.
    G = nx.DiGraph()
    G.add_nodes_from(unique_clusters)
    for src, tgt, count in zip(sub['src_cluster'], sub['tgt_cluster'], sub['spike_count']):
        G.add_edge(src, tgt, spike=count)

    edge_traces = []
    arrow_annotations = []

    for u, v, data in G.edges(data=True):
        spike = data["spike"]
        # Log scaling keeps very active pairs from dominating the picture.
        width = 1 + 2.3 * np.log1p(spike)

        # When v -> u also exists, exactly one of the two edges (the one with
        # u < v) is shifted sideways so the pair does not overlap.
        sidestep = G.has_edge(v, u) and u < v
        (x0b, y0b), (x1b, y1b) = _edge_endpoints(base_pos[u], base_pos[v], sidestep)

        edge_traces.append(go.Scatter(
            x=[x0b, x1b],
            y=[y0b, y1b],
            mode='lines',
            line=dict(width=width, color=EDGE_LINE_COLOR),
            hoverinfo='text',
            text=f"{u} → {v}<br>spikes: {spike}",
            showlegend=False
        ))

        arrow_annotations.append(dict(
            x=x1b, y=y1b,     # arrow tip
            ax=x0b, ay=y0b,   # arrow tail
            xref='x', yref='y',
            axref='x', ayref='y',
            showarrow=True,
            arrowhead=3,
            arrowsize=1.5,
            arrowwidth=max(1.2, width),
            arrowcolor=EDGE_ARROW_COLOR
        ))

    frame_parts.append((str(month)[:10], edge_traces, arrow_annotations))

# ======================================
# PASS 2 — assemble frames with a constant number of traces
# ======================================
# Plotly animation frames update the figure's existing traces by position: a
# frame with MORE traces than the initial figure cannot add the extras, and a
# frame with FEWER leaves stale lines from the previous frame on screen. Every
# frame is therefore padded with empty line traces up to the busiest month.
max_edges = max((len(traces) for _, traces, _ in frame_parts), default=0)

frames = []
for frame_name, edge_traces, arrow_annotations in frame_parts:
    padding = [
        go.Scatter(x=[None], y=[None], mode='lines', hoverinfo='skip', showlegend=False)
        for _ in range(max_edges - len(edge_traces))
    ]
    frames.append(go.Frame(
        # A fresh node trace per frame, so no trace object is shared between frames.
        data=[go.Scatter(**node_kwargs)] + edge_traces + padding,
        layout=go.Layout(annotations=arrow_annotations),
        name=frame_name
    ))


In [10]:
import plotly.io as pio
# Make "browser" the default Plotly renderer for this session.
# NOTE(review): with this renderer fig.show() is expected to open the figure in
# a web-browser tab instead of inline in the notebook — confirm against the
# Plotly renderers documentation for the installed version.
pio.renderers.default = "browser"

In [None]:
# ======================================
# FINAL FIGURE (TIME SLIDER)
# ======================================
# Assembles the animated figure from `frames`, shows it, and exports it as a
# standalone HTML file.

from pathlib import Path  # previously missing: Path(...) below raised NameError

if not frames:
    raise ValueError("No frames available — run the frame-building cell first.")

# The figure starts out showing the first month: its traces become the figure's
# data and its arrows become the initial annotations.
first_frame = frames[0]

fig = go.Figure(
    data=list(first_frame.data),
    frames=frames
)

play_button = {
    "label": "Play",
    "method": "animate",
    "args": [
        None,   # None = play through all frames
        {
            "frame": {"duration": 900, "redraw": True},
            "fromcurrent": True,
            "mode": "immediate"
        }
    ]
}

pause_button = {
    "label": "Pause",
    "method": "animate",
    "args": [
        [None],   # [None] with zero duration interrupts the running animation
        {"frame": {"duration": 0}, "mode": "immediate"}
    ]
}

# One slider step per month; the step's target must equal the frame name,
# which the frame-building cell sets to str(month)[:10].
slider_steps = [
    {
        "label": str(m)[:10],
        "method": "animate",
        "args": [
            [str(m)[:10]],
            {"frame": {"duration": 600, "redraw": True},
             "mode": "immediate"}
        ]
    }
    for m in months
]

fig.update_layout(
    title="Monthly Cluster Conflict Network (Directed, Straight Edges)",
    showlegend=False,
    plot_bgcolor='white',
    xaxis=dict(visible=False),
    yaxis=dict(visible=False),

    updatemenus=[{
        "type": "buttons",
        "buttons": [play_button, pause_button],
        "direction": "left",
        "pad": {"r": 10, "t": 80},
        "x": 0.1,
        "y": 0
    }],

    sliders=[{
        "active": 0,
        "pad": {"t": 50},
        "steps": slider_steps
    }]
)

fig.update_layout(annotations=first_frame.layout.annotations)
fig.show()

# Export for the website; plotly.js is loaded from the CDN rather than embedded.
out = Path("website/assets/figures/graphA.html")  # adjust "website/" if needed
out.parent.mkdir(parents=True, exist_ok=True)

fig.write_html(out.as_posix(), include_plotlyjs="cdn")