# Heat equation on a graph

We are interested in solving the heat equation

\begin{equation}
\frac{\partial u}{\partial t} = \nabla^2 u,
\end{equation}

on a graph $G(V,E)$, where $V=\{v_1, v_2, \ldots, v_n\}$ is the set of vertices and $E=\{e_1, e_2, \ldots, e_m\}$ is the set of edges.

Rather than discretizing the PDE spatially in the classical sense, we solve the system of ODEs:

\begin{equation}
    \frac{\mathrm{d}\mathbf{u}}{\mathrm{d} t} = L \mathbf{u},
\end{equation}

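where $\mathbf{u} = [u(v_1),~u(v_2),~\ldots,~u(v_n)]^\top$ collects the values at the vertices, and ${L} = A - D$ is the graph Laplacian with the sign convention of $\nabla^2$ (negative semi-definite). Here $A$ is the adjacency matrix and $D$ the diagonal degree matrix, so $L$ is the negative of the combinatorial Laplacian $D - A$.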

## Import packages

In [1]:
import plotly.graph_objects as go
import networkx as nx
import networkx.linalg as nla
import numpy as np
import scipy.sparse as sps
import time 

from IPython.display import clear_output

## Utility functions for plotting

In [2]:
def get_edge_trace(G: nx.Graph) -> go.Scatter:
    """Build a single plotly line trace that draws every edge of ``G``.

    Each edge contributes its two endpoint coordinates (read from the node
    attribute ``'pos'``) followed by a ``None`` separator, which breaks the
    line between consecutive edges.
    """
    xs, ys = [], []
    for src, dst in G.edges():
        (x_src, y_src), (x_dst, y_dst) = G.nodes[src]['pos'], G.nodes[dst]['pos']
        xs += [x_src, x_dst, None]
        ys += [y_src, y_dst, None]

    return go.Scatter(
        x=xs,
        y=ys,
        line=dict(width=0.5, color='#888'),
        hoverinfo='none',
        mode='lines',
    )


def get_node_trace(G: nx.Graph, color_vals=None) -> go.Scatter:
    """Produce a node trace for plotting ``G`` with plotly.

    Parameters
    ----------
    G : nx.Graph
        Graph whose nodes carry a ``'pos'`` attribute holding an ``(x, y)`` pair.
    color_vals : sequence of float, optional
        One value per node, in ``G.nodes()`` order. When given, nodes are
        coloured by these values on a fixed [0, 1] "Hot" colour scale and the
        hover text shows each value with 3 decimals. When omitted, nodes are
        drawn in black and the hover text shows each node's number of
        connections.

    Returns
    -------
    go.Scatter
        Marker-only scatter trace with one marker per node.
    """
    node_x = []
    node_y = []
    for node in G.nodes():
        x, y = G.nodes[node]['pos']
        node_x.append(x)
        node_y.append(y)

    if color_vals is not None:
        # Same 3-decimal format as the hover text of the time-stepped frames.
        node_text = [f"{val:.3f}" for val in color_vals]
        marker = dict(
            showscale=True,
            reversescale=True,
            # Fixed colour range so frames at different times are comparable;
            # assumes the values lie in [0, 1].
            cmin=0,
            cmax=1,
            color=color_vals,
            colorscale="Hot",
            size=10,
            colorbar=dict(
                thickness=15,
                # `titleside` is deprecated (removed in plotly 6); the side
                # is now a property of the colorbar title itself.
                title=dict(text='Value', side='right'),
                xanchor='left',
            ),
            line_width=2,
        )
    else:
        node_text = [f'# of connections: {len(nbrs)}' for _, nbrs in G.adjacency()]
        marker = dict(color="Black", size=10, line_width=2)

    return go.Scatter(
        x=node_x,
        y=node_y,
        mode='markers',
        hoverinfo='text',
        marker=marker,
        text=node_text,
    )


def plot_graph(G: nx.Graph, show_plot=False, title=None) -> go.Figure:
    """Plot the structure of ``G``: grey edges and black nodes.

    Hovering over a node shows its number of connections.

    Parameters
    ----------
    G : nx.Graph
        Graph whose nodes carry a ``'pos'`` attribute holding an ``(x, y)`` pair.
    show_plot : bool
        If True, display the figure immediately.
    title : str, optional
        Figure title. Defaults to
        "Connected random graph with <n> nodes and <m> edges".

    Returns
    -------
    go.Figure
        The constructed figure.
    """
    edge_trace = get_edge_trace(G)
    # Without colour values, get_node_trace already sets the
    # "# of connections" hover text, so it is not recomputed here.
    node_trace = get_node_trace(G)

    if title is None:
        title = (f"Connected random graph with {G.number_of_nodes()} nodes "
                 f"and {G.number_of_edges()} edges")

    fig = go.Figure(
        data=[edge_trace, node_trace],
        layout=go.Layout(
            title=title,
            showlegend=False,
            hovermode='closest',
            margin=dict(b=20, l=5, r=5, t=40),
            xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        ),
    )

    if show_plot:
        fig.show()

    return fig


def plot_colormap(G: nx.Graph,
                  edge_trace: go.Scatter,
                  node_trace: go.Scatter,
                  time: float,
                  show_plot=False
                  ):
    """Plot one state of the node variable as a colour map.

    Parameters
    ----------
    G : nx.Graph
        The graph being plotted (not used directly; the traces already hold
        the geometry). Kept for interface compatibility.
    edge_trace, node_trace : go.Scatter
        Traces as produced by ``get_edge_trace`` / ``get_node_trace``.
    time : float
        Simulation time of this state, shown as ``t = <time>`` above the
        graph. (The parameter name shadows the ``time`` module inside this
        function; it is kept because callers pass it by keyword.)
    show_plot : bool
        If True, display the figure immediately.

    Returns
    -------
    go.Figure
        The constructed figure.
    """
    fig = go.Figure(
        data=[edge_trace, node_trace],
        layout=go.Layout(
            # `titlefont_size` is deprecated (removed in plotly 6); set the
            # size on the title's own font instead.
            title=dict(text="<br>Heat equation on a graph", x=0.45,
                       font=dict(size=20)),
            showlegend=False,
            margin=dict(b=20, l=5, r=5, t=70),
            xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        ),
    )

    # Previously hard-coded to "t = 0"; now reflects the `time` argument.
    fig.add_annotation(
        x=0,
        y=1.10,
        text=f"t = {time:g}",
        showarrow=False,
        font=dict(family="Courier New, monospace", size=22, color="black"),
    )

    if show_plot:
        fig.show()

    return fig



## Generate graph

In [None]:
# Sample a random geometric graph with 200 nodes and connection radius 0.125
G = nx.random_geometric_graph(200, 0.125)

# Restrict to the largest connected component so the resulting graph is connected
G = G.subgraph(max(nx.connected_components(G), key=len))

# Display the graph structure (the returned figure is not needed)
_ = plot_graph(G, show_plot=True)

## Establish initial conditions

In [None]:
# Static edge geometry, reused for every plotted state
edge_trace = get_edge_trace(G)

# Initial condition: one uniformly random value in [0, 1) per node
u0 = np.random.rand(G.number_of_nodes())
node_trace = get_node_trace(G, color_vals=u0)

# Show the initial state as a colour map
fig = plot_colormap(G, edge_trace, node_trace, time=0, show_plot=True)

In [None]:
def create_animation(data: list, times: list = None):
    """Animate a sequence of plot states with a Play button and show it.

    Parameters
    ----------
    data : list
        One entry per frame; each entry is the list of traces for that frame
        (e.g. ``[edge_trace, node_trace]``). ``data[0]`` is the initial view.
    times : list of float, optional
        Simulation time of each frame, shown as a ``t = ...`` annotation.
        Defaults to the frame index when omitted.

    Returns
    -------
    go.Figure
        The animated figure (also displayed via ``fig.show()``).

    Raises
    ------
    ValueError
        If ``times`` is given but does not have one entry per frame.
    """
    if times is None:
        times = list(range(len(data)))
    elif len(times) != len(data):
        raise ValueError(f"expected {len(data)} times (one per frame), got {len(times)}")

    def time_label(t):
        # Time annotation placed above the graph, styled like plot_colormap's.
        return dict(
            x=0,
            y=1.10,
            text=f"t = {t:g}",
            showarrow=False,
            font=dict(family="Courier New, monospace", size=22, color="black"),
        )

    # Each frame carries its own annotation so the displayed time advances.
    frames = [
        go.Frame(data=plot, name=str(k), layout=go.Layout(annotations=[time_label(t)]))
        for k, (plot, t) in enumerate(zip(data, times))
    ]

    fig = go.Figure(
        data=data[0],
        layout=go.Layout(
            # `titlefont_size` is deprecated (removed in plotly 6).
            title=dict(text="<br>Heat equation on a graph", x=0.45,
                       font=dict(size=20)),
            showlegend=False,
            margin=dict(b=20, l=5, r=5, t=70),
            annotations=[time_label(times[0])] if times else [],
            updatemenus=[dict(
                type="buttons",
                buttons=[dict(label="Play",
                              method="animate",
                              args=[None])])],
            xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        ),
        frames=frames,
    )

    fig.show()
    return fig

## Obtain the graph Laplacian and integrate in time (explicit Euler)

In [None]:
T = 0.3  # final time
t = 0  # current time
dt = 0.01  # time step

# networkx returns the combinatorial Laplacian D - A; negate it so the ODE
# reads du/dt = L u with L = A - D (negative semi-definite, like nabla^2).
L = -nla.laplacian_matrix(G)
eye = sps.eye(G.number_of_nodes())  # Identity matrix
# Forward (explicit) Euler step matrix: u_{k+1} = (I + dt L) u_k.
A = dt * L + eye

# Forward Euler is stable when dt * lambda_max(D - A) <= 2. Since
# lambda_max(D - A) <= 2 * max_degree, dt * max_degree <= 1 is a sufficient
# (not necessary) condition; warn when it is not met.
max_degree = max((deg for _, deg in G.degree()), default=0)
if dt * max_degree > 1:
    import warnings
    warnings.warn(
        f"dt={dt} may be too large for explicit Euler "
        f"(max degree {max_degree}); the solution could become unstable."
    )

# Fixed number of steps: `while t <= T` with `t += dt` gives a step count
# that depends on floating-point rounding of the accumulated time.
n_steps = int(round(T / dt))

u = u0
data = [[edge_trace, node_trace]]  # frame 0: the initial condition
times = [t]

for k in range(1, n_steps + 1):
    t = k * dt

    # `@` is the matrix-vector product for both scipy sparse matrices and
    # sparse arrays; `*` is element-wise for sparse arrays (which networkx
    # 3.x returns), so `A * u` would silently compute the wrong thing.
    u = A @ u

    # Build a new node trace for every step. Mutating one shared trace and
    # re-appending it would make every frame (including frame 0) alias the
    # final state.
    data.append([edge_trace, get_node_trace(G, color_vals=u)])
    times.append(t)

In [None]:
# Frame k shows the solution after k explicit-Euler steps, i.e. at t = k * dt.
# `times` must be passed explicitly: it is a required argument of
# create_animation, so `create_animation(data)` raises a TypeError.
create_animation(data, [k * dt for k in range(len(data))])

In [None]:
# fig = go.Figure(
#     data=[go.Scatter(x=[0, 1], y=[0, 1])],
#     layout=go.Layout(
#         xaxis=dict(range=[0, 5], autorange=False),
#         yaxis=dict(range=[0, 5], autorange=False),
#         title="Start Title",
#         updatemenus=[dict(
#             type="buttons",
#             buttons=[dict(label="Play",
#                           method="animate",
#                           args=[None])])]
#     ),
#     frames=[go.Frame(data=[go.Scatter(x=[1, 2], y=[1, 2])]),
#             go.Frame(data=[go.Scatter(x=[1, 4], y=[1, 4])]),
#             go.Frame(data=[go.Scatter(x=[3, 4], y=[3, 4])],
#                      layout=go.Layout(title_text="End Title"))]
# )

# fig.show()