In [None]:
import polars as pl
import os

In [None]:
# Read every annotated parquet file into one polars frame with one row per action
# (the "actions" list column is exploded, then its struct fields become columns).
INPUT_PATH = r"./data/outs/"

# Sorted so the concatenated row order is reproducible across machines/runs;
# os.listdir returns entries in an arbitrary order.
files = sorted(f for f in os.listdir(INPUT_PATH) if f.endswith(".parquet"))
if not files:
    raise FileNotFoundError(f"No .parquet files found in {INPUT_PATH!r}")

actions_df = (
    pl.concat([pl.read_parquet(os.path.join(INPUT_PATH, f)) for f in files])
    .explode("actions")
    .unnest("actions")
)

actions_df

In [None]:
# Add metadata to the actions as in the original corpus:
# collect the ids of the annotated books, then load the original corpus.
CORPUS_PATH = "./data/gutenberg_en_novels.parquet"

id_processed_books = actions_df.get_column("gutenberg_id").unique().to_list()

# Metadata fields to carry over onto each action row.
# NOTE(review): this list is not referenced elsewhere in the visible notebook.
metadata_columns = [
    'SOURCE', 'book_id', 'language', 'title', 'issued',
    'authors', 'subjects', 'locc', 'bookshelves',
]

initial_corpus = pl.read_parquet(CORPUS_PATH)
print(initial_corpus.columns)

In [None]:
# For every annotated book, attach its original-corpus metadata by matching
# initial_corpus.text_id to actions_df.gutenberg_id.
# One vectorised join replaces a per-id filter loop, keeps the corpus column
# dtypes as-is, and still fails loudly on missing or duplicated ids.
meta_df = (
    initial_corpus
    .select(
        # Join key, cast to the actions' id dtype so e.g. Int32 vs Int64 ids still join.
        pl.col("text_id").cast(actions_df.schema["gutenberg_id"]).alias("gutenberg_id"),
        "SOURCE",
        pl.col("text_id").alias("book_id"),
        "authors", "title", "issued", "language", "subjects", "locc", "bookshelves",
    )
    .filter(pl.col("gutenberg_id").is_in(actions_df["gutenberg_id"].unique()))
)

annotated_ids = set(actions_df["gutenberg_id"].drop_nulls().unique().to_list())
missing_ids = annotated_ids - set(meta_df["gutenberg_id"].to_list())
if missing_ids:
    raise ValueError(f"No row in initial_corpus for gutenberg_id(s): {sorted(missing_ids)}")

duplicated_ids = (
    meta_df.filter(pl.col("gutenberg_id").is_duplicated())["gutenberg_id"].unique().to_list()
)
if duplicated_ids:
    raise ValueError(f"Several rows in initial_corpus for gutenberg_id(s): {sorted(duplicated_ids)}")

# Drop metadata columns left by a previous run of this cell so re-running it
# is idempotent instead of producing *_right duplicate columns.
# NOTE(review): assumes the annotation files do not ship their own copies of
# these columns — confirm.
stale_cols = [c for c in meta_df.columns if c != "gutenberg_id" and c in actions_df.columns]
actions_df = actions_df.drop(stale_cols).join(meta_df, on="gutenberg_id", how="left")

actions_df

In [None]:
from load_annotations import Book

def build_book_by_title(title: str, gutenberg_id=None) -> Book:
    """
    Instantiate a Book object for the given title using the global actions_df.

    Parameters
    ----------
    title : str
        Exact value to match against the ``title`` column.
    gutenberg_id : optional
        Extra filter on ``gutenberg_id``; use it when several books share
        the same title. ``None`` (default) matches on title only.

    Returns
    -------
    Book
        Built from the book's metadata columns plus ``df``: the matching
        action rows with a 0-based ``action_id`` row-number column.

    Raises
    ------
    ValueError
        If no rows match, or if the matching rows carry more than one
        distinct metadata record (e.g. two editions with the same title),
        which would otherwise silently mix several books' actions.
    """
    book_metadata_cols = [
        'book_id', 'SOURCE', 'language', 'gutenberg_id', 'title', 'issued',
        'authors', 'subjects', 'locc', 'bookshelves'
    ]

    selector = pl.col("title") == title
    if gutenberg_id is not None:
        selector = selector & (pl.col("gutenberg_id") == gutenberg_id)

    # NOTE: newer polars versions rename with_row_count to with_row_index.
    book_df = actions_df.filter(selector).with_row_count("action_id", offset=0)

    if book_df.is_empty():
        if gutenberg_id is None:
            raise ValueError(f"No book found with title: {title}")
        raise ValueError(f"No book found with title: {title} and gutenberg_id: {gutenberg_id}")

    # maintain_order=True makes the chosen record deterministic.
    candidates = book_df.select(book_metadata_cols).unique(maintain_order=True).to_dicts()
    if len(candidates) > 1:
        ids = sorted({c['gutenberg_id'] for c in candidates})
        raise ValueError(
            f"Title {title!r} matches several books (gutenberg_id {ids}); "
            "pass gutenberg_id=... to choose one."
        )
    return Book(**candidates[0], df=book_df)

### Functions to create the graph visualizations

In [None]:
from graph_postprocessing import character_interaction_graph
import numpy as np
import matplotlib.pyplot as plt


In [None]:
from collections import OrderedDict, Counter
import networkx as nx
import plotly.graph_objects as go
import plotly.express as px
import ipywidgets as widgets

In [None]:
def get_3d_layout(G, distance_factor=2.0, scale_factor=2.0, seed=42):
    """
    3-D spring layout with a larger default node–node distance.

    Parameters
    ----------
    G : networkx.Graph
    distance_factor : float
        Multiplier for the default optimal distance (k).  1.0 reproduces the
        NetworkX default (k = 1 / sqrt(n)).  Higher → more space between nodes.
    scale_factor : float
        Scales the final coordinates (acts like zoom-out if > 1).
    seed : int
        Random seed for the layout, so identical inputs give identical
        positions (default 42, as before).

    Returns
    -------
    dict
        {node: (x, y, z)} coordinates for all nodes in G; empty for an
        empty graph.
    """
    n = len(G)
    if n == 0:
        # Nothing to place; also avoids dividing by sqrt(0).
        return {}
    k = distance_factor / np.sqrt(n)      # default was 1 / sqrt(n)
    return nx.spring_layout(G,
                            dim=3,
                            seed=seed,
                            k=k,
                            scale=scale_factor)

def interactive_graph_3d(one_graph):
    """
    Interactive 3-D view of a character-interaction graph, revealed chunk by chunk.

    A "Chunks" slider selects how many chunks (in first-seen order) are drawn.
    Node positions come from a single spring layout of the full graph, so they
    stay fixed while the slider moves. Nodes are colored by the index of the
    chunk in which they first appear.

    NOTE: a later cell in this notebook redefines ``interactive_graph_3d``
    (adding a distance slider); after Run All that later definition is the
    one in effect.

    Parameters
    ----------
    one_graph
        Object exposing ``.edges``; each edge needs ``.source``, ``.target``
        and ``.chunk_id``. Edges missing a source or target are skipped.

    Returns
    -------
    Whatever ``ipywidgets.interact`` returns for the update callback.

    Raises
    ------
    ValueError
        If no edge has both a source and a target (nothing to plot).
    """
    # 1) Group every (undirected) edge by chunk_id, in first-seen order
    edges_by_chunk = OrderedDict()
    for action in one_graph.edges:
        if not (action.source and action.target):
            continue
        u, v = sorted((action.source, action.target))
        edges_by_chunk.setdefault(action.chunk_id, []).append((u, v))

    all_chunk_edges = list(edges_by_chunk.values())
    if not all_chunk_edges:
        raise ValueError("Graph has no edges with both a source and a target; nothing to plot.")

    # build the complete NetworkX graph to get a layout shared by all frames
    G_full = nx.Graph()
    G_full.add_edges_from(e for chunk in all_chunk_edges for e in chunk)

    # 2) first-appearance chunk index per node
    first_appearance = {}
    for idx, chunk_edges in enumerate(all_chunk_edges):
        for u, v in chunk_edges:
            first_appearance.setdefault(u, idx)
            first_appearance.setdefault(v, idx)

    # 3) color palette, cycled over chunk indices
    palette = px.colors.qualitative.Plotly
    chunk_colors = {i: palette[i % len(palette)] for i in range(len(all_chunk_edges))}

    # 4) fixed 3D layout and axis ranges
    pos = get_3d_layout(G_full)
    xs, ys, zs = zip(*pos.values())
    x_range, y_range, z_range = [min(xs), max(xs)], [min(ys), max(ys)], [min(zs), max(zs)]

    def update(chunk_idx):
        counts = Counter(e for chunk in all_chunk_edges[:chunk_idx + 1] for e in chunk)

        G_sub = nx.Graph()
        for (u, v), w in counts.items():
            if u in pos and v in pos:
                G_sub.add_edge(u, v, weight=w)

        edge_traces = []
        for u, v in G_sub.edges():
            x0, y0, z0 = pos[u]
            x1, y1, z1 = pos[v]
            edge_traces.append(go.Scatter3d(
                x=[x0, x1, None], y=[y0, y1, None], z=[z0, z1, None],
                mode='lines',
                line=dict(color='rgba(200,200,200,0.4)', width=2),
                hoverinfo='none'
            ))

        nodes = list(G_sub.nodes())
        node_trace = go.Scatter3d(
            x=[pos[n][0] for n in nodes],
            y=[pos[n][1] for n in nodes],
            z=[pos[n][2] for n in nodes],
            mode='markers+text',
            marker=dict(size=8,
                        color=[chunk_colors[first_appearance[n]] for n in nodes],
                        line=dict(color='white', width=1)),
            text=nodes,
            textposition='top center',
            textfont=dict(size=11, color='white'),
            hoverinfo='text'
        )

        layout = go.Layout(
            showlegend=False,
            scene=dict(
                xaxis=dict(range=x_range, visible=False),
                yaxis=dict(range=y_range, visible=False),
                zaxis=dict(range=z_range, visible=False),
                bgcolor='black',
                # drag-to-zoom inside the 3D scene
                dragmode='zoom'
            ),
            paper_bgcolor='black',
            plot_bgcolor='black',
            margin=dict(l=0, r=0, b=0, t=30),
            scene_aspectmode='cube'
        )

        fig = go.Figure(data=edge_traces + [node_trace], layout=layout)
        fig.show(config={'scrollZoom': True})
        # No return value: interact() displays any non-None result, so
        # returning fig as well would render the figure a second time.

    slider = widgets.IntSlider(
        value=0,
        min=0, max=len(all_chunk_edges) - 1,
        step=1, description='Chunks:',
        continuous_update=False
    )
    return widgets.interact(update, chunk_idx=slider)

In [None]:
def plot_wheel_graph_colored(graph, num_chunks=None):
    """
    Render a circular (wheel) layout of the Graph, with
    edge thickness ∝ frequency and node color by first chunk.

    Parameters
    ----------
    graph
        Object exposing ``.edges``; each edge needs ``.source``, ``.target``
        and ``.chunk_id``. Edges missing a source or target are skipped.
    num_chunks : int or None
        Keep only edges whose ``chunk_id`` is <= this value (inclusive).
        ``None`` (default) keeps every edge.

    Returns
    -------
    None. The figure is displayed via ``fig.show`` with scroll-zoom enabled.
    Returning it would make a notebook cell ending in this call display it twice.
    """
    # 1) Group every edge by its chunk_id in insertion order
    edges_by_chunk = OrderedDict()
    for action in graph.edges:
        if num_chunks is not None and action.chunk_id > num_chunks:
            continue
        if not (action.source and action.target):
            continue
        # normalize to undirected
        u, v = sorted((action.source, action.target))
        edges_by_chunk.setdefault(action.chunk_id, []).append((u, v))

    all_chunk_edges = list(edges_by_chunk.values())
    # flatten for global counts
    edge_counts = Counter(e for chunk in all_chunk_edges for e in chunk)

    # 2) Build NetworkX graph with weighted edges
    G = nx.Graph()
    for (u, v), w in edge_counts.items():
        G.add_edge(u, v, weight=w)

    # 3) Compute first-appearance chunk index for each node
    first_appearance = {}
    for idx, chunk_edges in enumerate(all_chunk_edges):
        for u, v in chunk_edges:
            first_appearance.setdefault(u, idx)
            first_appearance.setdefault(v, idx)

    # 4) Pick a Plotly qualitative palette & map chunk index → color.
    # Uses a separate name so the num_chunks parameter is not overwritten.
    palette = px.colors.qualitative.Plotly
    n_seen_chunks = len(all_chunk_edges)
    chunk_colors = {i: palette[i % len(palette)] for i in range(n_seen_chunks)}

    # 5) Compute a circular layout in 2D
    pos = nx.circular_layout(G, scale=1.0)

    # 6) Build edge traces (thickness grows linearly with interaction count)
    edge_traces = []
    for u, v, dat in G.edges(data=True):
        x0, y0 = pos[u]
        x1, y1 = pos[v]
        edge_traces.append(go.Scatter(
            x=[x0, x1, None],
            y=[y0, y1, None],
            mode='lines',
            line=dict(
                width=1 + dat['weight'] * 0.3,
                color='lightgray'
            ),
            hoverinfo='none'
        ))

    # 7) Build node trace, coloring by first appearance
    nodes = list(G.nodes())
    node_trace = go.Scatter(
        x=[pos[n][0] for n in nodes],
        y=[pos[n][1] for n in nodes],
        mode='markers+text',
        marker=dict(
            size=10,
            color=[chunk_colors[first_appearance[n]] for n in nodes],
            line=dict(width=2, color='white')
        ),
        text=nodes,
        textposition='top center',
        hoverinfo='text'
    )

    # 8) Final layout: clean, fixed size
    layout = go.Layout(
        showlegend=False,
        xaxis=dict(showgrid=False, zeroline=False, visible=False),
        yaxis=dict(showgrid=False, zeroline=False, visible=False),
        plot_bgcolor='white',
        margin=dict(l=20, r=20, t=20, b=20),
        width=600, height=600
    )

    fig = go.Figure(data=edge_traces + [node_trace], layout=layout)
    # enable scroll-to-zoom
    fig.show(config={'scrollZoom': True})

In [None]:
def interactive_graph_3d(one_graph):
    """
    Interactive 3-D view of a character-interaction graph, revealed chunk by chunk.

    This definition replaces the earlier ``interactive_graph_3d`` in this
    notebook and adds a "Distance" slider. That slider sets both
    ``distance_factor`` and ``scale_factor`` of ``get_3d_layout``, re-spacing
    the spring layout. "Chunks" selects how many chunks (in first-seen order)
    are drawn. Nodes are colored by the index of the chunk in which they
    first appear.

    Parameters
    ----------
    one_graph
        Object exposing ``.edges``; each edge needs ``.source``, ``.target``
        and ``.chunk_id``. Edges missing a source or target are skipped.

    Returns
    -------
    Whatever ``ipywidgets.interact`` returns for the update callback.

    Raises
    ------
    ValueError
        If no edge has both a source and a target (nothing to plot).
    """
    # --- pre-processing: edges grouped by chunk, first appearances, colors ---
    edges_by_chunk = OrderedDict()
    for action in one_graph.edges:
        if not (action.source and action.target):
            continue
        u, v = sorted((action.source, action.target))
        edges_by_chunk.setdefault(action.chunk_id, []).append((u, v))

    all_chunk_edges = list(edges_by_chunk.values())
    if not all_chunk_edges:
        raise ValueError("Graph has no edges with both a source and a target; nothing to plot.")

    G_full = nx.Graph()
    G_full.add_edges_from(e for chunk in all_chunk_edges for e in chunk)

    first_appearance = {}
    for idx, chunk_edges in enumerate(all_chunk_edges):
        for u, v in chunk_edges:
            first_appearance.setdefault(u, idx)
            first_appearance.setdefault(v, idx)

    palette = px.colors.qualitative.Plotly
    chunk_colors = {i: palette[i % len(palette)] for i in range(len(all_chunk_edges))}

    # The layout depends only on `distance` (the seed is fixed), so cache it:
    # moving the chunk slider then does not recompute the spring layout.
    layout_cache = {}

    def layout_for(distance):
        if distance not in layout_cache:
            layout_cache[distance] = get_3d_layout(G_full,
                                                   distance_factor=distance,
                                                   scale_factor=distance)
        return layout_cache[distance]

    # ---------- callback ----------
    def update(chunk_idx, distance):
        pos = layout_for(distance)
        xs, ys, zs = zip(*pos.values())
        x_range, y_range, z_range = [min(xs), max(xs)], [min(ys), max(ys)], [min(zs), max(zs)]

        counts = Counter(e for chunk in all_chunk_edges[:chunk_idx + 1] for e in chunk)

        G_sub = nx.Graph()
        for (u, v), w in counts.items():
            if u in pos and v in pos:
                G_sub.add_edge(u, v, weight=w)

        # ---------- visuals ----------
        edge_traces = []
        for u, v in G_sub.edges():
            x0, y0, z0 = pos[u]
            x1, y1, z1 = pos[v]
            edge_traces.append(go.Scatter3d(
                x=[x0, x1, None], y=[y0, y1, None], z=[z0, z1, None],
                mode='lines',
                line=dict(color='rgba(200,200,200,0.4)', width=2),
                hoverinfo='none'
            ))

        nodes = list(G_sub.nodes())
        node_trace = go.Scatter3d(
            x=[pos[n][0] for n in nodes],
            y=[pos[n][1] for n in nodes],
            z=[pos[n][2] for n in nodes],
            mode='markers+text',
            marker=dict(size=8,
                        color=[chunk_colors[first_appearance[n]] for n in nodes],
                        line=dict(color='white', width=1)),
            text=nodes,
            textposition='top center',
            textfont=dict(size=11, color='white'),
            hoverinfo='text'
        )

        layout = go.Layout(
            showlegend=False,
            scene=dict(
                xaxis=dict(range=x_range, visible=False),
                yaxis=dict(range=y_range, visible=False),
                zaxis=dict(range=z_range, visible=False),
                bgcolor='black',
                dragmode='zoom'
            ),
            paper_bgcolor='black',
            plot_bgcolor='black',
            margin=dict(l=0, r=0, b=0, t=30),
            scene_aspectmode='cube'
        )

        fig = go.Figure(data=edge_traces + [node_trace], layout=layout)
        fig.show(config={'scrollZoom': True})
        # No return value: interact() displays any non-None result, so
        # returning fig as well would render the figure a second time.

    # ---------- sliders ----------
    chunk_slider = widgets.IntSlider(
        value=0, min=0, max=len(all_chunk_edges) - 1, step=1,
        description='Chunks:', continuous_update=False
    )

    # step=0.5 (was 5) so the 5.0 default and the 0.5 minimum lie on the
    # slider grid; matches the 2-D widget.
    distance_slider = widgets.FloatSlider(
        value=5.0, min=0.5, max=25.0, step=0.5,
        description='Distance:', readout_format='.1f',
        continuous_update=False
    )

    return widgets.interact(update,
                            chunk_idx=chunk_slider,
                            distance=distance_slider)


In [None]:
import math, networkx as nx, plotly.graph_objects as go, ipywidgets as widgets
from collections import OrderedDict, Counter
import plotly.express as px

# ---- 2-D helper ------------------------------------------------------------
def get_2d_layout(G, distance_factor=2.0, scale_factor=2.0, seed=42):
    """
    Plain 2-D spring layout, but with a tunable 'distance_factor' slider.

    Parameters
    ----------
    G : networkx.Graph
    distance_factor : float
        Multiplier on the NetworkX default optimal distance k = 1 / sqrt(n).
    scale_factor : float
        Passed to ``spring_layout`` as ``scale``.
    seed : int
        Random seed, so identical inputs give identical positions.

    Returns
    -------
    dict
        {node: (x, y) position}; empty for an empty graph.
    """
    n = len(G)
    if n == 0:
        # math.sqrt(0) == 0.0, so the division below would raise ZeroDivisionError.
        return {}
    k = distance_factor / math.sqrt(n)          # default was 1/√n
    return nx.spring_layout(G, dim=2, k=k,
                            scale=scale_factor, seed=seed)



# ---- interactive widget ----------------------------------------------------
def interactive_graph_2d(one_graph):
    """
    Interactive 2-D view of a character-interaction graph, revealed chunk by chunk.

    "Chunks" selects how many chunks (in first-seen order) are drawn.
    "Distance" sets both ``distance_factor`` and ``scale_factor`` of
    ``get_2d_layout``, re-spacing the spring layout. Nodes are colored by the
    index of the chunk in which they first appear. Dragging pans and the
    scroll wheel zooms.

    Parameters
    ----------
    one_graph
        Object exposing ``.edges``; each edge needs ``.source``, ``.target``
        and ``.chunk_id``. Edges missing a source or target are skipped.

    Returns
    -------
    Whatever ``ipywidgets.interact`` returns for the update callback.

    Raises
    ------
    ValueError
        If no edge has both a source and a target (nothing to plot).
    """
    # ---------- data wrangling ----------
    edges_by_chunk = OrderedDict()
    for action in one_graph.edges:
        if not (action.source and action.target):
            continue
        u, v = sorted((action.source, action.target))
        edges_by_chunk.setdefault(action.chunk_id, []).append((u, v))

    all_chunk_edges = list(edges_by_chunk.values())
    if not all_chunk_edges:
        raise ValueError("Graph has no edges with both a source and a target; nothing to plot.")

    G_full = nx.Graph()
    G_full.add_edges_from(e for chunk in all_chunk_edges for e in chunk)

    first_appearance = {}
    for idx, chunk_edges in enumerate(all_chunk_edges):
        for u, v in chunk_edges:
            first_appearance.setdefault(u, idx)
            first_appearance.setdefault(v, idx)

    palette = px.colors.qualitative.Plotly
    chunk_colors = {i: palette[i % len(palette)]
                    for i in range(len(all_chunk_edges))}

    # The layout depends only on `distance` (the seed is fixed), so cache it:
    # moving the chunk slider then does not recompute the spring layout.
    layout_cache = {}

    def layout_for(distance):
        if distance not in layout_cache:
            layout_cache[distance] = get_2d_layout(G_full,
                                                   distance_factor=distance,
                                                   scale_factor=distance)
        return layout_cache[distance]

    def padded_range(values, padding_factor=0.1):
        # 10% margin on each side; fall back to 1.0 when every value is equal
        # (e.g. a single self-interacting node), else the range would be empty.
        lo, hi = min(values), max(values)
        pad = (hi - lo) * padding_factor or 1.0
        return [lo - pad, hi + pad]

    # ---------- callback ----------
    def update(chunk_idx, distance):
        pos = layout_for(distance)
        xs, ys = zip(*pos.values())
        x_range, y_range = padded_range(xs), padded_range(ys)

        counts = Counter(e for chunk in all_chunk_edges[:chunk_idx + 1] for e in chunk)

        G_sub = nx.Graph()
        for (u, v), w in counts.items():
            if u in pos and v in pos:
                G_sub.add_edge(u, v, weight=w)

        # ---- edge & node traces (2-D); all edges in one trace, None-separated ----
        edge_x, edge_y = [], []
        for u, v in G_sub.edges():
            x0, y0 = pos[u]
            x1, y1 = pos[v]
            edge_x += [x0, x1, None]
            edge_y += [y0, y1, None]

        edge_trace = go.Scatter(x=edge_x, y=edge_y,
                                mode='lines',
                                line=dict(color='rgba(200,200,200,0.4)', width=2),
                                hoverinfo='none')

        nodes = list(G_sub.nodes())
        node_trace = go.Scatter(
            x=[pos[n][0] for n in nodes],
            y=[pos[n][1] for n in nodes],
            mode='markers+text',
            marker=dict(size=8,
                        color=[chunk_colors[first_appearance[n]] for n in nodes],
                        line=dict(color='white', width=1)),
            text=nodes,
            textposition='top center',
            textfont=dict(size=11, color='white'),
            hoverinfo='text'
        )

        fig = go.Figure(data=[edge_trace, node_trace],
                        layout=go.Layout(
                            showlegend=False,
                            xaxis=dict(range=x_range, visible=False),
                            # equal aspect so the layout is not distorted
                            yaxis=dict(range=y_range, visible=False,
                                       scaleanchor='x', scaleratio=1),
                            paper_bgcolor='black',
                            plot_bgcolor='black',
                            margin=dict(l=60, r=60, t=60, b=60),
                            dragmode='pan',
                            width=800,    # pixels; increase as needed
                            height=800
                        ))

        # Wheel zoom comes from the scrollZoom config. The former
        # `scene_dragmode` update only affects 3-D scenes, so it is dropped.
        fig.show(config={'scrollZoom': True})
        # No return value: interact() displays any non-None result, so
        # returning fig as well would render the figure a second time.

    # ---------- sliders ----------
    chunk_slider = widgets.IntSlider(
        value=0, min=0, max=len(all_chunk_edges) - 1, step=1,
        description='Chunks:', continuous_update=False
    )

    distance_slider = widgets.FloatSlider(
        value=5.0, min=0.5, max=25.0, step=0.5,
        description='Distance:', readout_format='.1f',
        continuous_update=False
    )

    return widgets.interact(update,
                            chunk_idx=chunk_slider,
                            distance=distance_slider)


### Apply the visualizations to a book

In [None]:
# Load the annotated actions (with corpus metadata) of one book as a Book object
BOOK_TITLE = "Alice's Adventures in Wonderland"
abook = build_book_by_title(BOOK_TITLE)

In [None]:
# Build the book's full annotated graph, then derive the character
# interaction graph from it with the options used throughout this notebook.
agraph = abook.create_full_graph()
subgraph = character_interaction_graph(
    agraph,
    split=True,
    naive_beautify=True,
)

In [None]:
# Names of every node in the character interaction graph
chars = [node.node_name for node in subgraph.nodes]
print(chars)


In [None]:
# Distinct edge sources, sorted so the printed list is identical on every
# kernel restart. Iteration order of a set of strings changes with Python's
# per-process hash randomization. key=str lets a possible None source (the
# plotting helpers skip such edges) sort alongside the names.
sources = sorted({edge.source for edge in subgraph.edges}, key=str)
print(sources)

In [None]:
# Trailing ';' suppresses the cell's echo of interact()'s return value (the
# update callback), which would otherwise print a "<function ...>" repr.
interactive_graph_2d(subgraph);

In [None]:
# Static wheel view limited to edges with chunk_id <= WHEEL_MAX_CHUNK_ID
WHEEL_MAX_CHUNK_ID = 400
plot_wheel_graph_colored(subgraph, num_chunks=WHEEL_MAX_CHUNK_ID)