In [1]:
import requests
import lxml.html as lxml_html
import networkx as nx
import re

In [2]:
# Separator between a node's display name and its numeric id in graph node
# keys, which are built as f"{name}{split_character}{id}".
split_character = '|'

In [3]:
def get_page_title(tree):
    """Return the page title of a parsed HTML document without the Wikipedia suffix.

    Args:
        tree: Parsed document exposing ``find`` (e.g. an lxml element).

    Returns:
        The text of the first ``<title>`` element with every " - Wikipedia"
        occurrence removed, or ``None`` when there is no ``<title>`` element
        or the element has no text.
    """
    title_element = tree.find(".//title")
    # An empty <title/> has ``text`` set to None; calling .replace on it would
    # raise AttributeError, so treat it like a missing title.
    if title_element is None or title_element.text is None:
        return None
    # The title usually ends with " - Wikipedia", so we remove it.
    return title_element.text.replace(" - Wikipedia", "")

def find_first_header(tree):
    """Locate the first ``<h2>`` inside the page's main content container.

    The container is the first ``<div>`` whose ``class`` attribute is exactly
    "mw-content-container".

    Returns:
        The first ``<h2>`` descendant of that div, or ``None`` when either the
        div or the heading is missing.
    """
    container = tree.find('.//div[@class="mw-content-container"]')
    if container is None:
        return None
    # ``find`` already yields None when the container holds no <h2>.
    return container.find('.//h2')

def find_node_by_name(G, name):
    """Return the first node key of ``G`` whose display name is exactly ``name``.

    Node keys have the form ``f"{name}{split_character}{id}"``. The key is
    split on the LAST separator, so a display name that itself contains the
    separator (e.g. "A|B") is not mistaken for a shorter name ("A").

    Args:
        G: Graph exposing ``nodes()`` that yields string node keys.
        name: Display name to look for.

    Returns:
        The matching node key, or ``None`` if no node has that display name.
    """
    for node in G.nodes():
        base, sep, _ = node.rpartition(split_character)
        # ``sep`` is empty when the key has no separator at all; such keys
        # never matched the original prefix test either.
        if sep and base == name:
            return node
    return None

In [4]:
def add_edge(G, parent, child_name):
    """Create a uniquely keyed child node and connect it below ``parent``.

    The child key is ``f"{child_name}{split_character}{node_counter}"``; the
    module-level ``node_counter`` is incremented by one on every call.

    Returns:
        The key of the newly created child node.
    """
    global node_counter

    assigned_id = node_counter
    node_counter = assigned_id + 1

    child_key = f"{child_name}{split_character}{assigned_id}"
    G.add_edge(parent, child_key)
    return child_key
    
def try_add_edge(G, parent, child_name):
    """Add a child named ``child_name`` under ``parent`` unless one already exists.

    Existing children are compared by their exact display name (the part of
    the key before the LAST separator), so an existing child "A|B" does not
    block adding a child named "A".

    Args:
        G: Graph exposing ``successors(node)``.
        parent: Key of the parent node.
        child_name: Display name of the child to add.

    Returns:
        The new child's key as returned by ``add_edge``, or ``None`` when
        ``parent`` already has a child with that display name.
    """
    for child in G.successors(parent):
        base, sep, _ = child.rpartition(split_character)
        if sep and base == child_name:
            return None

    return add_edge(G, parent, child_name)

In [5]:
# Section headings that generate_branch skips instead of adding to the graph.
excluded_headers = ["References", "Sources", "See also",
                    "Notes", "Further reading", "External links"]


_HEADING_TAGS = ('h2', 'h3', 'h4', 'h5', 'h6')


def _heading_level(element):
    """Return the heading level (2-6) of ``element``, or ``None`` if it is not a section heading.

    Both a bare ``<h2>``..``<h6>`` tag and a ``<div class="mw-heading ...">``
    wrapper that contains such a tag count as headings.
    """
    tag = element.tag
    if tag in _HEADING_TAGS:
        return int(tag[1])
    if tag == 'div' and 'mw-heading' in (element.get('class') or '').split():
        for inner in element:
            if inner.tag in _HEADING_TAGS:
                return int(inner.tag[1])
    return None


def generate_branch(graph, element, parent, link_level, branch_level, max_depth, visited):
    """Walk sibling elements and add headings of ``branch_level`` as children of ``parent``.

    Starting at ``element``, siblings are visited via ``getnext()``:

    * non-heading elements and headings deeper than ``branch_level`` are skipped;
    * a heading shallower than ``branch_level`` ends the walk (the section is over);
    * a heading at ``branch_level`` is added with ``try_add_edge`` unless its
      text is in ``excluded_headers`` or a same-named child already exists.

    For every added heading, the elements up to the next heading are searched
    for a ``<div>`` containing "Main article"; each ``/wiki/...`` link inside
    the first such div that has links is crawled with ``get_page_hierarchy``.
    The function then recurses for the next heading level.

    Args:
        graph: Graph being built.
        element: First sibling to inspect (may be ``None``).
        parent: Key of the node new headings are attached to.
        link_level: Depth already consumed by following article links.
        branch_level: Heading level (2 for ``<h2>``, 3 for ``<h3>``, ...) to collect.
        max_depth: Nothing is added once ``link_level + branch_level - 2``
            reaches this value.
        visited: Set of URLs already crawled, forwarded to ``get_page_hierarchy``.
    """
    level = link_level + branch_level - 2
    if level >= max_depth:
        return

    # When handed the <h*> tag itself and it sits inside a heading wrapper div,
    # the section content is a sibling of the wrapper, not of the tag.
    if element is not None and element.tag in _HEADING_TAGS:
        wrapper = element.getparent()
        if wrapper is not None and _heading_level(wrapper) is not None:
            element = wrapper

    while element is not None:
        header_level = _heading_level(element)

        # Plain content and deeper headings belong to a subsection; skip them.
        # (Previously non-headings were treated as level 6 and became nodes
        # whenever branch_level was 6.)
        if header_level is None or header_level > branch_level:
            element = element.getnext()
            continue

        if header_level < branch_level:
            return

        # Remove the edit-section marker first, then trim what is left.
        header_text = re.sub(r'\[edit\]', '', element.text_content()).strip()
        if header_text in excluded_headers:
            element = element.getnext()
            continue

        child = try_add_edge(graph, parent, header_text)

        if child is None:
            element = element.getnext()
            continue

        # Look for the link in the elements following this header
        link_element = element.getnext()
        while (link_element is not None
               and link_element.tag != 'h1'
               and _heading_level(link_element) is None):
            if link_element.tag == 'div' and 'Main article' in link_element.text_content():
                links = link_element.findall('.//a')
                if links:
                    for link in links:
                        href = link.get('href')
                        # Anchors without href, fragments and non-article
                        # links cannot be crawled as Wikipedia pages.
                        if not href or not href.startswith('/wiki/'):
                            continue
                        get_page_hierarchy(
                            'https://en.wikipedia.org' + href, graph, child,
                            level + 1, max_depth, visited)
                    break
            link_element = link_element.getnext()

        generate_branch(graph, element.getnext(), child,
                        link_level, branch_level + 1, max_depth, visited)

        element = element.getnext()

In [6]:
def get_page_hierarchy(url, graph, parent, link_level=0, max_depth=4, visited=None, timeout=10):
    """Fetch ``url`` and add its section hierarchy to ``graph``.

    Args:
        url: Page to download.
        graph: Graph the nodes and edges are added to.
        parent: Key of the node the page's sections hang from. When ``None``
            a new root node named after the page title is created.
        link_level: Depth already consumed by following article links.
        max_depth: Pages at ``link_level >= max_depth`` are not fetched.
        visited: Set of URLs already crawled; a fresh set is created when
            ``None``. The URL is recorded before the request is made.
        timeout: Seconds passed to ``requests.get`` so a stalled connection
            cannot block the crawl forever. Nested crawls started from
            ``generate_branch`` use this parameter's default.

    Returns:
        None. Request failures are printed and the page is skipped.
    """
    global node_counter
    if visited is None:
        visited = set()

    if link_level >= max_depth or url in visited:
        return

    visited.add(url)

    try:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
    except requests.RequestException as e:
        print(f"Failed to get {url}: {e}")
        return

    tree = lxml_html.fromstring(response.content)

    if parent is None:
        # Name the root after the page title; fall back to the URL so the
        # root is never labelled with the literal string "None".
        title = get_page_title(tree) or url
        parent = f"{title}{split_character}{node_counter}"
        node_counter += 1
        graph.add_node(parent)

    element = find_first_header(tree)
    if element is not None:
        generate_branch(graph, element, parent, link_level, 2, max_depth, visited)

In [7]:
def generate_map(url, max_depth=5):
    """Build the section graph for the page at ``url``.

    A new ``nx.DiGraph`` is created, the module-level ``node_counter`` is
    reset to 0, and ``get_page_hierarchy`` fills the graph.

    Returns:
        A ``(graph, root)`` tuple where ``root`` is the first node of the graph.

    Raises:
        IndexError: if the graph ends up with no nodes.
    """
    global node_counter

    graph = nx.DiGraph()
    node_counter = 0
    get_page_hierarchy(url, graph, None, max_depth=max_depth)

    node_keys = list(graph.nodes())
    return graph, node_keys[0]

In [8]:
def find_leaves(G):
  """Return the nodes of ``G`` that have no outgoing edges, in node order."""
  leaves = []
  for node in G.nodes():
    if G.out_degree(node) == 0:
      leaves.append(node)
  return leaves

In [9]:
import math

In [10]:
def circular_layout(G, root, scale, divergence):
    """Compute radial ``(x, y)`` positions for the nodes reachable from ``root``.

    The root and its direct children are placed by ``set_first_level`` (called
    with an angle of ``math.pi / 2``). Each further level ``k`` (starting at 1)
    lies on a circle of radius ``scale * divergence ** k``; the children of a
    node are spread evenly inside the angular arc assigned to that node.

    Returns:
        Dict mapping node -> ``(x, y)``.
    """
    pos, frontier = set_first_level(G, root, scale, math.pi / 2)
    depth = 1

    while frontier:
        # Every node placed in this pass sits on the same circle.
        ring_radius = scale * pow(divergence, depth)
        next_frontier = []

        for node, node_angle, arc in frontier:
            offspring = list(G.successors(node))
            first_angle = node_angle - arc / 2
            step = arc / (len(offspring) + 1)

            for rank, child in enumerate(offspring, start=1):
                angle = first_angle + rank * step
                next_frontier.append([child, angle, step])
                pos[child] = ring_radius * math.cos(angle), ring_radius * math.sin(angle)

        frontier = next_frontier
        depth += 1

    return pos
    
def set_first_level(G, root, scale, theta):
    """Place ``root`` at the origin and its children on a circle of radius ``scale``.

    The first ``len(children) // 2`` children go on the left, the rest on the
    right. Each side spreads its children evenly over an arc of
    ``math.pi - theta`` radians; the left arc starts at
    ``math.pi - (math.pi - theta) / 2`` and the right arc at
    ``-(math.pi - theta) / 2``.

    Returns:
        ``(pos, nodes)`` where ``pos`` maps node -> ``(x, y)`` and ``nodes`` is
        a list of ``[child, child_angle, delta_angle]`` entries, left side first.
    """
    children = list(G.successors(root))
    half = len(children) // 2
    alpha = (math.pi - theta) / 2
    arcs = (2 * math.pi) / 2 - theta

    pos = {root: (0, 0)}
    nodes = []

    def place(side_children, angle_start):
        # Spread one side's children evenly, leaving a gap at both arc ends.
        delta_angle = arcs / (len(side_children) + 1)
        for rank, child in enumerate(side_children, start=1):
            child_angle = angle_start + rank * delta_angle
            nodes.append([child, child_angle, delta_angle])
            pos[child] = scale * math.cos(child_angle), scale * math.sin(child_angle)

    place(children[:half], math.pi - alpha)  # left children
    place(children[half:], -alpha)           # right children

    return pos, nodes

In [11]:
def custom_layout(G, root, max_depth, width, height, scale):
  """Compute a two-sided mind-map layout for the tree rooted at ``root``.

  The root's children are split into a left and a right group by
  ``split_children``; each group is stacked vertically by ``arrange_nodes``,
  with every node receiving a share of ``height`` proportional to its leaf
  count from ``get_leaves_dict``.

  Args:
    G: Graph exposing ``successors``.
    root: Root node key.
    max_depth: Number of levels below the root; sets the horizontal step.
    width: Total width of the layout.
    height: Total height available to each side.
    scale: Base node scale; the root gets ``scale * (max_depth + 1)``.

  Returns:
    Dict mapping node -> ``(x, y, node_scale, visible)``.
  """
  leaves_d = get_leaves_dict(G, root)
  left_children, right_children, left_sum, right_sum = split_children(G, root, leaves_d)
  layout = {root: (0, 0, scale * (max_depth + 1), True)}

  # A side without leaves (root with zero or one child) or a root-only graph
  # (max_depth == 0) has nothing to place; use a zero step instead of
  # dividing by zero.
  step_x = 0.5 * width / max_depth if max_depth > 0 else 0.0
  left_step_y = height / left_sum if left_sum else 0.0
  right_step_y = height / right_sum if right_sum else 0.0

  arrange_nodes(G, left_children,  layout, -1, 0, 1, max_depth, 0, height, leaves_d, step_x, left_step_y,  scale)
  arrange_nodes(G, right_children, layout,  1, 0, 1, max_depth, 0, height, leaves_d, step_x, right_step_y, scale)

  return layout

def arrange_nodes(G, nodes, layout, side, parent_x, depth, max_depth, start, height, leaves_d, step_x, step_y, scale):
  """Place ``nodes`` in one vertical column and recurse into their children.

  All nodes of the call share one x coordinate, offset from ``parent_x`` in
  direction ``side`` (-1 or 1 in ``custom_layout``). Vertically each node
  gets a band of ``leaves_d[node] * step_y`` starting at ``start`` and is
  centred in it; y values are shifted by ``-height / 2``. Entries written to
  ``layout`` are ``(x, y, node_scale, visible)`` with
  ``node_scale = scale * (max_depth - depth + 1)`` and ``visible = depth < 2``.
  """
  if len(nodes) == 0:
    return

  column_scale = scale * (max_depth - depth + 1)
  spread = min(0.1 * len(nodes) + 0.5 / (depth + 1), 1)
  column_x = parent_x + side * step_x * spread
  visible = depth < 2

  cursor = start
  for node in nodes:
    band = leaves_d[node] * step_y
    layout[node] = column_x, cursor - height / 2 + band / 2, column_scale, visible
    # Children share this node's band, so they start at the same cursor.
    arrange_nodes(G, list(G.successors(node)), layout, side, column_x, depth + 1,
                  max_depth, cursor, height, leaves_d, step_x, step_y, scale)
    cursor = cursor + band

def split_children(G, root, leaves_d):
  """Split the root's children into two groups with balanced leaf counts.

  Children are taken in descending order of ``leaves_d[child]`` and each one
  goes to the left group only when the left total is strictly smaller than
  the right total; otherwise (ties included) it goes to the right group.

  Returns:
    ``(left_children, right_children, left_sum, right_sum)``.
  """
  ordered = sorted(G.successors(root), key=lambda child: -leaves_d[child])

  left_children, right_children = [], []
  left_sum = right_sum = 0

  for child in ordered:
    weight = leaves_d[child]
    if left_sum >= right_sum:
      right_children.append(child)
      right_sum += weight
    else:
      left_children.append(child)
      left_sum += weight

  return left_children, right_children, left_sum, right_sum

def fill_leaves_dict(G, root, leaves_d, node):
  """Add 1 to ``leaves_d`` for every ancestor of ``node`` up to and including ``root``.

  The walk follows the first predecessor of each node and stops once the
  current node equals ``root``; ``node`` itself is not incremented.
  """
  current = node
  while current != root:
    parent = list(G.predecessors(current))[0]
    leaves_d[parent] += 1
    current = parent

def get_leaves_dict(G, root):
  """Count, for every node of ``G``, the leaves found in its subtree.

  Every node starts at 0. Each leaf reported by ``find_leaves`` is set to 1
  and then ``fill_leaves_dict`` adds 1 to each of its ancestors up to ``root``.

  Returns:
    Dict mapping node -> leaf count.
  """
  leaves = find_leaves(G)

  # The former debug print of the complete node list was removed: it flooded
  # the output on every layout computation.
  leaves_d = {node: 0 for node in G.nodes()}
  for leaf in leaves:
    leaves_d[leaf] = 1
    fill_leaves_dict(G, root, leaves_d, leaf)

  return leaves_d

In [12]:
import plotly.graph_objects as go
import matplotlib.colors as mcolors
import matplotlib.cm as cm

In [13]:
def calculate_node_colors(G, root):
  """Assign a hex colour to every node of ``G`` based on its branch and depth.

  * the root is black (``"#000000"``);
  * the i-th of the root's n children gets the "hsv" colormap colour at ``i / n``;
  * every other node takes its first predecessor's colour blended towards
    white, more strongly the farther the node is from the root.

  Nodes are processed in order of increasing distance from ``root`` so that a
  predecessor one level up is always coloured before the nodes below it.

  Returns:
    Dict mapping node -> colour string such as ``"#ff0000"``.
  """
  distances = nx.single_source_shortest_path_length(G, root)
  children = list(G.successors(root))
  n_children = len(children)
  max_distance = max(distances.values())

  # ``cm.get_cmap`` was deprecated in matplotlib 3.7 and removed in 3.9;
  # prefer the colormap registry and keep the old call as a fallback for
  # installations that predate the registry.
  try:
    from matplotlib import colormaps
    hsv = colormaps["hsv"]  # hsv for more vivid, rainbow-like colors
  except ImportError:
    hsv = cm.get_cmap("hsv")

  # Position of each direct child of the root, looked up in O(1).
  child_rank = {child: rank for rank, child in enumerate(children)}

  parent_color = {}  # Dict to store the color of the already processed nodes

  for node in sorted(G.nodes(), key=lambda n: distances[n]):
    distance = distances[node]

    if distance == 0:  # root node
      color = "#000000"  # black
    elif node in child_rank:  # Direct children of the root node
      rgba = hsv(child_rank[node] / n_children)
      color = mcolors.rgb2hex(rgba[:3])  # convert RGB to hex, ignore alpha
    else:
      parent = next(G.predecessors(node))
      parent_rgb = mcolors.hex2color(parent_color[parent])
      white_rgb = (1, 1, 1)
      # Closer to white as the node is farther from the root
      mix_rgb = [max(0, min(c + 0.8 * (w - c) * (distance / (max_distance + 1)), 1)) for c, w in zip(parent_rgb, white_rgb)]
      color = mcolors.rgb2hex(mix_rgb)

    parent_color[node] = color

  return parent_color

In [14]:
def display_graph(G, scale=2.0, spot=3.0):
    """Show ``G`` as an interactive Plotly figure using the radial layout.

    Positions come from ``circular_layout(G, root, scale, spot)`` (``spot`` is
    passed as its ``divergence`` argument) and colours from
    ``calculate_node_colors``. ``custom_layout`` cannot be used here: it takes
    six arguments and returns ``(x, y, scale, visible)`` entries rather than
    ``(x, y)`` pairs.

    Args:
        G: Graph whose first node is the root.
        scale: Radius of the first ring of nodes.
        spot: Growth factor of the ring radius per level.
    """
    root = list(G.nodes())[0]
    # Split on the last separator so names containing it stay intact.
    name = root.rsplit(split_character, 1)[0]
    pos = circular_layout(G, root, scale, spot)
    node_colors = calculate_node_colors(G, root)

    # Edges are drawn as one line trace; None breaks the line between edges.
    edge_x, edge_y = [], []
    for source, target in G.edges():
        x0, y0 = pos[source]
        x1, y1 = pos[target]
        edge_x.extend([x0, x1, None])
        edge_y.extend([y0, y1, None])

    node_x, node_y, node_labels, marker_colors = [], [], [], []
    for node in G.nodes():
        x, y = pos[node]
        node_x.append(x)
        node_y.append(y)
        node_labels.append(node.rsplit(split_character, 1)[0])
        marker_colors.append(node_colors[node])

    edge_trace = go.Scatter(
        x=edge_x,
        y=edge_y,
        line=dict(width=0.5, color="#888"),
        hoverinfo="none",
        mode="lines",
    )

    node_trace = go.Scatter(
        x=node_x,
        y=node_y,
        text=node_labels,  # shown on hover
        mode="markers",
        hoverinfo="text",
        marker=dict(
            showscale=False,
            color=marker_colors,
            size=10,
            line=dict(width=2)
        ),
    )

    # Permanent labels are intentionally left empty (they were disabled in the
    # original code); the trace is kept so they can be switched back on.
    node_text_trace = go.Scatter(
        x=node_x,
        y=node_y,
        mode="text",
        text=[],
        textposition="top center"
    )

    fig = go.Figure(
        data=[edge_trace, node_trace, node_text_trace],
        layout=go.Layout(
            # ``titlefont`` is deprecated; the font is set on the title itself.
            title=dict(text="<br>" + name, font=dict(size=16)),
            showlegend=False,
            hovermode="closest",
            margin=dict(b=20, l=5, r=5, t=40),
            xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        ),
    )

    fig.show()

In [15]:
# Crawl the "Rocket" article and build its section graph (max_depth=4).
# NOTE: generate_map downloads pages with requests, so this cell needs network access.
url = 'https://en.wikipedia.org/wiki/Rocket'
G, root = generate_map(url, max_depth=4)

In [16]:
#  display_graph(G)

In [17]:
import json

In [20]:
def convert_networkx_to_gojs(G, filename, scale, output_dir="./static/json_mmaps"):
    """Lay out ``G`` and write it as a GoJS ``GraphLinksModel`` JSON file.

    Args:
        G: Directed graph whose node keys are ``f"{text}{split_character}{id}"``
            and whose first node is the root.
        filename: Base name of the output file (".json" is appended).
        scale: Multiplier for the layout size; the height is
            ``0.75 * number_of_leaves * scale`` and the width twice that.
        output_dir: Directory the file is written to; created if missing.

    The file contains ``nodeDataArray`` entries with ``id``, ``text``, ``loc``
    ("x y"), ``color``, ``scale`` and ``visible``, and ``linkDataArray``
    entries with ``from`` / ``to`` node ids.
    """
    import os

    def _split_key(node):
        # Split on the LAST separator so display names that contain the
        # separator keep their full text and the id still parses.
        text, _, node_id = node.rpartition(split_character)
        return text, int(node_id)

    root = list(G.nodes())[0]
    max_depth = nx.dag_longest_path_length(G)
    height = 0.75 * len(find_leaves(G)) * scale
    width = 2 * height
    print('width, height: ', width, ', ', height)

    layout = custom_layout(G, root, max_depth, width, height, 1)
    node_colors = calculate_node_colors(G, root)

    node_data = []
    for node in G.nodes():
        text, node_id = _split_key(node)
        x, y, node_scale, visible = layout[node]
        node_data.append({
            'id': node_id,
            'text': text,
            'loc': f"{x} {y}",
            'color': node_colors[node],
            'scale': node_scale,
            'visible': visible,
        })

    link_data = [
        {'from': _split_key(source)[1], 'to': _split_key(target)[1]}
        for source, target in G.edges()
    ]

    # Convert to dict format
    graph_dict = {
        'class': 'go.GraphLinksModel',
        'nodeKeyProperty': 'id',
        'maxDepth': max_depth,
        'nodeDataArray': node_data,
        'linkDataArray': link_data,
    }

    # Convert to JSON and save to file; make sure the target folder exists.
    os.makedirs(output_dir, exist_ok=True)
    with open(os.path.join(output_dir, f"{filename}.json"), 'w', encoding='utf-8') as f:
        json.dump(graph_dict, f, indent=4)

In [None]:
# Export the crawled graph as ./static/json_mmaps/graph.json with a layout scale of 100.
convert_networkx_to_gojs(G, 'graph', 100)