In [2]:
import networkx as nx
import plotly.graph_objects as go
import numpy as np

In [3]:
# Toy paper -> {topic: weight} mapping with fractional weights.
# NOTE(review): each paper's weights sum to 1.2, not 1.0 — confirm whether
# these are meant to be topic proportions.
clusters = dict(
    paper1=dict(topic1=0.6, topic2=0.4, topic3=0.2),
    paper2=dict(topic1=0.4, topic2=0.6, topic3=0.2),
)

In [26]:
import networkx as nx
import plotly.graph_objects as go

# Define the clustered data as a nested mapping: paper -> {topic: weight}.
# This replaces the `clusters` defined in the earlier cell. Here every paper
# links to the same three topics with a uniform weight of 100, so the mapping
# is generated instead of being spelled out entry by entry.
clusters = {
    paper: {topic: 100 for topic in ("topic1", "topic2", "topic3")}
    for paper in ("paper1", "paper2")
}
# Build an undirected graph with one node per paper and per topic, and a
# weighted edge from each paper to each of its topics.
G = nx.Graph()

for paper, topics in clusters.items():
    # Add the paper node explicitly so a paper with no topics still appears.
    G.add_node(paper)
    # add_weighted_edges_from creates any missing topic node and stores the
    # score under the "weight" edge attribute, matching add_edge(..., weight=w).
    G.add_weighted_edges_from(
        (paper, topic_name, score) for topic_name, score in topics.items()
    )

# Function to create a plotly graph figure from a NetworkX graph and layout
def create_plotly_graph(G, layout, **layout_kwargs):
    """Build a Plotly figure that draws a NetworkX graph.

    Parameters
    ----------
    G : networkx.Graph
        Graph to draw. Each node's label is shown as text above its marker.
    layout : callable
        A NetworkX layout function such as ``nx.spring_layout`` or
        ``nx.circular_layout``. It is called as ``layout(G, **layout_kwargs)``
        and is expected to return a mapping ``node -> (x, y)``.
    **layout_kwargs
        Extra keyword arguments forwarded to ``layout`` (e.g. ``seed=42`` or
        ``iterations=100`` for ``nx.spring_layout``). They are forwarded
        rather than hard-coded so that layouts without an ``iterations``
        parameter (such as ``nx.circular_layout``) also work.

    Returns
    -------
    plotly.graph_objects.Figure
        A figure with one line trace for the edges and one marker+text trace
        for the nodes. Nodes are coloured by their number of neighbours.
    """
    # Note: nx.spring_layout already defaults to iterations=50, so calling it
    # without kwargs matches the previously hard-coded behaviour.
    pos = layout(G, **layout_kwargs)

    # All edges go into a single line trace; the None after each segment
    # breaks the line so that every edge is drawn as a separate stroke.
    edge_x = []
    edge_y = []
    for u, v in G.edges():
        x0, y0 = pos[u]
        x1, y1 = pos[v]
        edge_x.extend((x0, x1, None))
        edge_y.extend((y0, y1, None))

    edge_trace = go.Scatter(
        x=edge_x, y=edge_y,
        line=dict(width=0.5, color='#888'),
        hoverinfo='none',
        mode='lines'
    )

    # Node coordinates, labels and neighbour counts, all in G.nodes() order so
    # the lists line up element-for-element.
    nodes = list(G.nodes())
    node_x = [pos[node][0] for node in nodes]
    node_y = [pos[node][1] for node in nodes]
    node_adjacencies = [len(G[node]) for node in nodes]

    node_trace = go.Scatter(
        x=node_x, y=node_y,
        mode='markers+text',
        text=nodes,
        textposition="top center",
        hoverinfo='text',
        marker=dict(
            showscale=True,
            colorscale='YlGnBu',
            colorbar=dict(
                thickness=15,
                # title.side replaces the deprecated `titleside` property,
                # which newer Plotly versions reject.
                title=dict(text='Node Connections', side='right'),
                xanchor='left',
            ),
            color=node_adjacencies,
            size=20,
            line_width=2
        )
    )

    fig = go.Figure(
        data=[edge_trace, node_trace],
        layout=go.Layout(
            # title.font.size replaces the deprecated `titlefont_size`.
            title=dict(text='Network Graph Visualization', font=dict(size=16)),
            showlegend=False,
            hovermode='closest',
            margin=dict(b=20, l=5, r=5, t=40),
            annotations=[dict(
                text="Network layout",
                showarrow=False,
                xref="paper", yref="paper"
            )],
            xaxis=dict(showgrid=False, zeroline=False),
            yaxis=dict(showgrid=False, zeroline=False),
        )
    )
    return fig

# Create a Plotly figure using the spring (force-directed) layout.
# No seed is passed to the layout, so node positions may differ between runs.
fig_spring = create_plotly_graph(G, nx.spring_layout)

# Show the figure
fig_spring.show()
