In [1]:
import warnings; warnings.simplefilter('ignore')

import os

import numpy as np
import tensorflow as tf
# Get rid of the deprecation warnings
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

from plotly import graph_objs as go
import matplotlib.pyplot as plt

from tensorflow.python.lib.io import tf_record
from tensorflow.core.util import event_pb2
from tensorflow.python.framework import tensor_util


# Change working directory
%cd ..

from models.models import GraphNeuralSolver

/Users/balthazardonon/Documents/PhD/Code/GraphNeuralSolver


# Reloading a trained model

Change `model_path` to a path where you stored your own model, and it will reload it.
Also, do not pay attention to all the TensorFlow warnings .

In [2]:
# Initialize session 
sess = tf.Session()

# Reload model. Make sure that you have unzipped the file results/1575991681/train.zip
model_path = 'results/1581611217' # <-- You can change this to reload your own model
model = GraphNeuralSolver(sess, model_to_restore=model_path, default_data_directory='datasets/spring/default')

# Performing a prediction

Build your own data

A is a list edges. Each edge is defined by the vertex on the origin, the vertex on the extremity, 
and the stiffness.
Here we have the following edges :
- An edge from vertex 0 to vertex 1, with a stiffness of 0.5;
- An edge from vertex 0 to vertex 2, with a stiffness of 0.2;
- An edge from vertex 1 to vertex 2, with a stiffness of 0.8.

In [54]:
A_custom = np.array([[[0,0,1.],[1,1,1.3],[2,2,0.9],[1,0,-0.5],[2,0,-0.1],[1,2,-0.8],[2,1,-0.8]]])

In [55]:
print(A[0])

[[ 1.          0.          0.        ]
 [ 2.          0.         -0.13507664]
 [ 3.          0.         -0.3494819 ]
 [ 4.          1.          0.        ]
 [ 5.          2.         -0.34530476]
 [ 6.          3.         -0.14551462]
 [ 7.          3.         -0.27744132]
 [ 8.          0.         -0.79557663]
 [ 9.          6.         -0.9281321 ]
 [ 1.          1.          1.        ]
 [ 3.          5.         -0.17370829]
 [ 7.          4.         -0.6379149 ]
 [ 2.          5.         -0.2928816 ]
 [ 0.          1.          0.        ]
 [ 1.          6.          0.        ]
 [ 0.          1.          0.        ]
 [ 0.          2.          0.        ]
 [ 0.          3.          0.        ]
 [ 1.          4.          0.        ]
 [ 2.          5.         -0.34530476]
 [ 3.          6.         -0.14551462]
 [ 3.          7.         -0.27744132]
 [ 0.          8.          0.        ]
 [ 6.          9.         -0.9281321 ]
 [ 1.          1.          1.        ]
 [ 5.          3.        

B is a list of external forces. It defines, for each vertex the value of the external force.
They should also sum to zero.

Here we have the following external forces :
    - A force applied on vertex 0 of 1.5;
    - A force applied on vertex 1 of -0.5;
    - A force applied on vertex 2 of -1.;

In [56]:
B_custom = np.array([[[1.5],[-0.5],[-1.]]])

Now let's predict using our previously trained GNS:

In [57]:
X_hat = sess.run(model.X_final, feed_dict={model.A:A_custom, model.B:B_custom})
print('Graph Neural Solver prediction :')
print('    X_0 = {0:.4f}'.format(X_hat[0,0,0]))
print('    X_1 = {0:.4f}'.format(X_hat[0,1,0]))
print('    X_2 = {0:.4f}'.format(X_hat[0,2,0]))

Graph Neural Solver prediction :
    X_0 = 4.4489
    X_1 = 0.2670
    X_2 = -0.4157


In [58]:
def solve(A,B):
    """
    Solves a linear system AX=B.
    Assumes that A is non-singular
    
    Args:
        - A : numpy array, sparse description of matrix [n_edges, 3]
        - B : numpy array [n_nodes, 1]
    """
    
    # Get relevant values
    n_nodes = np.shape(B)[0]
    n_edges = np.shape(A)[0]
    
    # Build matrix A
    A_mat = np.zeros([n_nodes, n_nodes])
    for edge in A:
        A_mat[int(edge[0]), int(edge[1])] = edge[2]
    print(A_mat)
    
    # Inverse A
    A_mat_inv = np.linalg.inv(A_mat)
    
    # Find X
    X = np.matmul(A_mat_inv, B)
    return X

In [59]:
X_gt = solve(A_custom[0], B_custom[0])
print('Ground truth :')
print('    X_0 = {0:.4f}'.format(X_gt[0,0]))
print('    X_1 = {0:.4f}'.format(X_gt[1,0]))
print('    X_2 = {0:.4f}'.format(X_gt[2,0]))

[[ 1.   0.   0. ]
 [-0.5  1.3 -0.8]
 [-0.1 -0.8  0.9]]
Ground truth :
    X_0 = 1.5000
    X_1 = -0.8585
    X_2 = -1.7075


# Animating the evolution of predictions

One of the basic ideas of the Graph Neural Solver architecture is to iteratively update the prediction (as well as latent messages) by propagating information between direct neighbors. During the training process, the loss is a (weighted) sum of how much each subsequent update violates the target equation.

Let's visualize how well the algorithm predicts the output at each update!

In [60]:
# Import numpy data 
data_dir = 'datasets/spring/default'
mode = 'test'

B_np = np.load(os.path.join(data_dir, 'B_'+mode+'.npy'))
A_np = np.load(os.path.join(data_dir, 'A_'+mode+'.npy'))
X_np = np.load(os.path.join(data_dir, 'X_'+mode+'.npy'))

In [61]:
# Sample one instance of linear system, RELOAD FROM THERE TO SAMPLE ANOTHER DATAPOINT
idx_x = np.random.randint(0,B_np.shape[0])
idx_y = np.random.randint(0,B_np.shape[0])

B_x = B_np[idx_x:idx_x+1]
B_y = B_np[idx_y:idx_y+1]
A = A_np[idx_x:idx_x+1]
X = solve(A_np[idx_x], B_np[idx_x])
Y = solve(A_np[idx_x], B_np[idx_y])

# Predict with GNS
X_hat = sess.run(model.X, feed_dict={model.A:A, model.B:B_x})
Y_hat = sess.run(model.X, feed_dict={model.A:A, model.B:B_y})
X_hat = np.array([X_hat[str(i)] for i in range(model.correction_updates+1)])[:,0,:,0]
#X_hat = X_hat - np.expand_dims(X_hat[:,0], -1)
Y_hat = np.array([Y_hat[str(i)] for i in range(model.correction_updates+1)])[:,0,:,0]
#Y_hat = Y_hat - np.expand_dims(Y_hat[:,0], -1)

[[ 1.          0.          0.          0.          0.          0.
   0.          0.          0.          0.        ]
 [-0.93447876  3.38364315 -0.34848607  0.         -0.13357964  0.
  -0.48632306 -0.84786296  0.         -0.63291234]
 [-0.70500618 -0.34848607  1.05349231  0.          0.          0.
   0.          0.          0.          0.        ]
 [-0.75141394  0.          0.          2.51994801  0.         -0.88061219
   0.          0.         -0.88792187  0.        ]
 [-0.68036813 -0.13357964  0.          0.          0.8139478   0.
   0.          0.          0.          0.        ]
 [ 0.          0.          0.         -0.88061219  0.          0.88061219
   0.          0.          0.          0.        ]
 [ 0.         -0.48632306  0.          0.          0.          0.
   1.77709877 -0.76407218 -0.52670354  0.        ]
 [-0.95362556 -0.84786296  0.          0.          0.          0.
  -0.76407218  2.70515299 -0.1395923   0.        ]
 [ 0.          0.          0.         -0.8879218

In [64]:
arrow_length = 0.1

node_traces = []
edge_traces = []
force_traces = []
node_traces_true = []
edge_traces_true = []

pos_x = []
pos_y = []
frames = []
slider_steps = []

for t in range(X_hat.shape[0]):
    
    ## GNS ##
    
    # Nodes 
    node_x = []
    node_y = []
    for i in range(X_hat.shape[1]):
        x, y = X_hat[t, i], Y_hat[t, i]
        node_x.append(x)
        node_y.append(y)
    node_trace = go.Scatter(x=node_x, 
                            y=node_y, 
                            mode='markers', 
                            marker_size=5, 
                            marker=dict(color='black'), 
                            name="GNS prediction - Nodes")
    node_traces.append(node_trace)
    pos_x.extend(node_x)
    pos_y.extend(node_y)
    
    # Edges
    edge_x = []
    edge_y = []
    for edge in A[0]:
        from_side = int(edge[0])
        to_side = int(edge[1])
        x0, y0 = X_hat[t, from_side], Y_hat[t, from_side]
        x1, y1 = X_hat[t, to_side]+1e-2, Y_hat[t, to_side]+1e-2
        edge_x.append(x0)
        edge_x.append(x1)
        edge_x.append(None)
        edge_y.append(y0)
        edge_y.append(y1)
        edge_y.append(None)
    edge_trace = go.Scatter(x=edge_x, 
                            y=edge_y, 
                            mode='lines',
                            line=dict(color='black', width=1),
                            name='GNS prediction - Edges')
    edge_traces.append(edge_trace)
    pos_x.extend(edge_x)
    pos_y.extend(edge_y)
    
    # Forces
    force_x = []
    force_y = []
    for i in range(X_hat.shape[1]):
        x0, y0 = X_hat[t, i], Y_hat[t, i]
        x1, y1 = X_hat[t, i]-B_x[0,i,0], Y_hat[t, i]-B_y[0,i,0]
        
        angle = np.arctan2(-B_y[0,i,0], -B_x[0,i,0])
        angle_left = angle + np.pi*3/4
        angle_right = angle - np.pi*3/4
        
        force_x.append(x0)
        force_x.append(x1)
        force_x.append(None)
        
        force_x.append(x1)
        force_x.append(x1+arrow_length*np.cos(angle_left))
        force_x.append(None)
        
        force_x.append(x1)
        force_x.append(x1+arrow_length*np.cos(angle_right))
        force_x.append(None)
        
        force_y.append(y0)
        force_y.append(y1)
        force_y.append(None)
        
        force_y.append(y1)
        force_y.append(y1+arrow_length*np.sin(angle_left))
        force_y.append(None)
        
        force_y.append(y1)
        force_y.append(y1+arrow_length*np.sin(angle_right))
        force_y.append(None)
        
    force_trace = go.Scatter(x=force_x, 
                             y=force_y, 
                             mode='lines',
                             line=dict(color='red', width=1),
                             name='Forces')
    force_traces.append(force_trace)
    pos_x.extend(force_x)
    pos_y.extend(force_y)
    
    
    ## GROUND TRUTH ##
    
    # Nodes
    node_x_true = []
    node_y_true = []
    for i in range(X_hat.shape[1]):
        x_true, y_true = X[i,0], Y[i,0]
        node_x_true.append(x_true)
        node_y_true.append(y_true)
    node_trace_true = go.Scatter(x=node_x_true, 
                                 y=node_y_true, 
                                 mode='markers', 
                                 marker_size=5, 
                                 marker=dict(color='lightgrey'),
                                 name='Ground truth - Nodes')
    node_traces_true.append(node_trace_true)
    pos_x.extend(node_x_true)
    pos_y.extend(node_y_true)
        
    # Edges
    edge_x_true = []
    edge_y_true = []
    for edge in A[0]:
        from_side = int(edge[0])
        to_side = int(edge[1])
        x0, y0 = X[from_side,0], Y[from_side,0]
        x1, y1 = X[to_side,0], Y[to_side,0]
        edge_x_true.append(x0)
        edge_x_true.append(x1)
        edge_x_true.append(None)
        edge_y_true.append(y0)
        edge_y_true.append(y1)
        edge_y_true.append(None)
    edge_trace_true = go.Scatter(x=edge_x_true, 
                                 y=edge_y_true, 
                                 mode='lines',
                                 line=dict(color='lightgrey', width=1),
                                 name='Ground truth - Edges')
    edge_traces_true.append(edge_trace_true)
    pos_x.extend(edge_x_true)
    pos_y.extend(edge_y_true)
    
    frame = go.Frame(data=[edge_trace_true, node_trace_true, force_trace, edge_trace, node_trace],
                          layout=go.Layout(title_text="Correction update {}".format(t)),
                    name= str(t))
    frames.append(frame)
    
    slider_step = {"args": [
        [str(t)],
        {"frame": {"duration": 300, "redraw": False},
         "mode": "immediate",
         "transition": {"duration": 300}}
    ],
        "label": str(t),
        "method": "animate"}
    slider_steps.append(slider_step)
    
pos_x = [x for x in pos_x if x is not None]
pos_y = [y for y in pos_y if y is not None]

fig = go.Figure(
    data=[edge_traces_true[0],  node_traces_true[0], force_traces[0], edge_traces[0], node_traces[0]],#[go.Scatter(x=[0, 1], y=[0, 1])],
    layout=go.Layout(
        xaxis=dict(range=[min(min(pos_x), min(pos_y)), max(max(pos_x), max(pos_y))], autorange=False),
        yaxis=dict(range=[min(min(pos_x), min(pos_y)), max(max(pos_x), max(pos_y))], autorange=False),
        paper_bgcolor='rgba(0,0,0,0)',
        plot_bgcolor='rgba(0,0,0,0)',
        width=800,
        height=700
    ),
    frames=frames
)

fig["layout"]["updatemenus"] = [
    {
        "buttons": [
            {
                "args": [None, {"frame": {"duration": 500, "redraw": False},
                                "fromcurrent": True, "transition": {"duration": 300,
                                                                    "easing": "quadratic-in-out"}}],
                "label": "Play",
                "method": "animate"
            },
            {
                "args": [[None], {"frame": {"duration": 0, "redraw": False},
                                  "mode": "immediate",
                                  "transition": {"duration": 0}}],
                "label": "Pause",
                "method": "animate"
            }
        ],
        "direction": "left",
        "pad": {"r": 10, "t": 87},
        "showactive": False,
        "type": "buttons",
        "x": 0.1,
        "xanchor": "right",
        "y": 0,
        "yanchor": "top"
    }
]

fig["layout"]["sliders"] = [{
    "active": 0,
    "yanchor": "top",
    "xanchor": "left",
    "currentvalue": {
        "font": {"size": 20},
        "prefix": "Correction update:",
        "visible": True,
        "xanchor": "right"
    },
    "transition": {"duration": 300, "easing": "cubic-in-out"},
    "pad": {"b": 10, "t": 50},
    "len": 0.9,
    "x": 0.1,
    "y": 0,
    "steps": slider_steps
}]

fig.show()


IndexError: too many indices for array

You can press play, or directly select the correction update you would like to visualize. You can also zoom in to better see how well or how bad the GNS performs.

In light grey is shown the actual ground truth (using a matrix inversion), and in black the prediction of the current update. Moreover, in red are shown the external forces applied to each vertex (they are indeed constant).


# Visualizing the training process

As previously mentionned, the loss that is minimized during the training process is a weighted sum of the loss of each correction update. Thus, it is possible to visualize how each of these losses evolve during the training process. Let's try to visualize that!

In [14]:
expe_path = os.path.join(model_path, 'train/')

MOVING_WINDOW = 200

def my_summary_iterator(path):
    try:
        for r in tf_record.tf_record_iterator(path):
            yield event_pb2.Event.FromString(r)
    except:
        pass
        
loss_dict = {}
for filename in os.listdir(expe_path):
    event_path = os.path.join(expe_path, filename)
    for event in my_summary_iterator(event_path):
        for value in event.summary.value:
            try:
                loss_dict[value.tag].append(value.simple_value)
            except:
                loss_dict[value.tag] = [value.simple_value]
                
for key in loss_dict.keys():           
    loss_dict[key] = np.array(loss_dict[key])
    loss_dict[key] = np.convolve(loss_dict[key], 
                                       np.ones((MOVING_WINDOW,))/MOVING_WINDOW, 
                                       mode='valid')
    loss_dict[key] = loss_dict[key][::MOVING_WINDOW]


In [15]:
def get_rgb(i, N):
    color_0 = np.array([223, 109, 81])
    color_1 = np.array([116., 20., 12.])
    color_i = color_0 * (1.-i/N) + color_1 * i/N
    return 'rgb({}, {}, {})'.format(color_i[0], color_i[1], color_i[2])


fig = go.Figure()
# Create and style traces
x = np.linspace(0, len(loss_dict['loss_final'])-1, len(loss_dict['loss_final'])) * MOVING_WINDOW

for update in range(1, model.correction_updates+1):
    fig.add_trace(go.Scatter(x=x, 
                         y=loss_dict['loss_{}'.format(update)], name='Loss at update {}'.format(update),
                         line=dict(color=get_rgb(update,model.correction_updates) , width=1)))
fig.update_layout(yaxis_type="log", 
                  paper_bgcolor='rgba(0,0,0,0)',
                  plot_bgcolor='rgba(0,0,0,0)',
                  width=800,
                  height=400,
                  title='Evolution of the loss across the Graph Neural Solver',
                  xaxis_title='Learning iteration',
                  yaxis_title='Linear equation violation')

fig.show()

# Accuracy on the Test set

Here, we visualize, for each vertex its prediction versus its actual ground truth, across the correction updates. Ideally, all vertices should end up on the diagonal y=x, which would mean that the accuracy is perfect.

In [16]:
# Import numpy data 
data_dir = 'datasets/spring/default'
mode = 'test'

B_np = np.load(os.path.join(data_dir, 'B_'+mode+'.npy'))
A_np = np.load(os.path.join(data_dir, 'A_'+mode+'.npy'))
X_np = np.load(os.path.join(data_dir, 'X_'+mode+'.npy'))

#n_samples = B_np.shape[0]
#n_nodes = B_np.shape[1]
#n_edges = A_np.shape[1]


#A_offset = np.linspace(0, n_samples-1, n_samples)*n_nodes
#A_offset = np.reshape(A_offset, [-1, 1, 1])
#offset = np.c_[A_offset, A_offset, np.zeros_like(A_offset)]
#A_np = A_np + offset

#A_np = np.reshape(A_np, [-1, 3])
#B_np = np.reshape(B_np, [-1, 1])

X_hat = sess.run(model.X, feed_dict={model.A:A_np, model.B:B_np})

#for i in range(0, model.correction_updates+1):
#    X_hat[str(i)] = np.reshape(X_hat[str(i)], [n_samples, n_nodes])
#    X_hat[str(i)] -= np.reshape(X_hat[str(i)][:,0], [n_samples, 1])
#    X_hat[str(i)] = np.reshape(X_hat[str(i)], [-1])
    
#X_np = np.reshape(X_np, [-1])

In [17]:
frames = []
slider_steps = []

middle_traces = []
right_traces = []
top_traces = []

for t in range(model.correction_updates+1):
    
    middle_trace = go.Histogram2d(
        x=np.reshape(X_np, [-1]),
        y=np.reshape(X_hat[str(t)], [-1]),
        xaxis = 'x',
        yaxis = 'y',
        xbins = dict(start=-20, end=20, size=0.2),
        ybins = dict(start=-20, end=20, size=0.2),
        colorscale= [
            [0, 'rgb(255, 255, 255)'],       
            [1./3333, 'rgb(251, 232, 204)'],
            [1./1000, 'rgb(247, 213, 165)'],
            [1./333, 'rgb(244, 189, 140)'],  
            [1./100, 'rgb(238, 146, 100)'],      
            [1./33.3, 'rgb(223, 109, 81)'],             
            [1./10, 'rgb(199, 63, 45)'],
            [1./3.33, 'rgb(165, 32, 21)'],
            [1, 'rgb(116, 20, 12)']

        ],
        colorbar= dict(
            tick0= 0,
            tickmode= 'array',
            tickvals= [0, 3, 10, 30]
        )
    )
    middle_traces.append(middle_trace)
    
    right_trace = go.Histogram(
        y = np.reshape(X_hat[str(t)], [-1]),
        xaxis = 'x2',
        marker = dict(
            color = 'rgba(116, 20, 12, 1)'
        )
    )
    right_traces.append(right_trace)
    
    top_trace = go.Histogram(
        x = np.reshape(X_np, [-1]),
        yaxis = 'y2',
        marker = dict(
            color = 'rgba(116, 20, 12, 1)'
        )
    )
    top_traces.append(top_trace)
    
    
    
    frame = go.Frame(data=[middle_trace, right_trace, top_trace],
                    name= str(t),
                    layout = dict(
                        showlegend=False,
                        annotations=[
                            go.layout.Annotation(
                                x=15,
                                y=-19,
                                xref="x",
                                yref="y",
                                text="Correlation = {0:.3f}".format(np.nan_to_num(np.corrcoef(np.reshape(X_np, -1), np.reshape(X_hat[str(t)], -1))[0,1])),
                                showarrow=False,
                                arrowhead=0,
                                ax=0,
                                ay=-40
                            )
                        ]
                    ))
    frames.append(frame)
    
    slider_step = {"args": [
        [str(t)],
        {"frame": {"duration": 0, "redraw": True},
         "mode": "immediate",
         "transition": {"duration": 0}}
    ],
        "label": str(t),
        "method": "animate"}
    slider_steps.append(slider_step)

fig = go.Figure(
    data=[middle_traces[0], right_traces[0], top_traces[0]],#[go.Scatter(x=[0, 1], y=[0, 1])],
    layout=go.Layout(
        autosize = False,
        xaxis = dict(
            zeroline = False,
            domain = [0,0.85],
            showgrid = False,
            tickvals = [-10, 0, 10],
            title = 'Ground_truth'
        ),
        yaxis = dict(
            zeroline = False,
            domain = [0,0.85],
            showgrid = False,
            tickvals = [-10, 0, 10],
            title = 'Prediction'
        ),
        xaxis2 = dict(
            zeroline = True,
            zerolinewidth=1, 
            zerolinecolor='black',
            domain = [0.85,1],
            showgrid = True,
            gridcolor='LightGrey',
            tickvals = [10, 50, 100, 500],
        ),
        yaxis2 = dict(
            zeroline = True,
            zerolinewidth=1, 
            zerolinecolor='black',
            domain = [0.85,1],
            showgrid = True,
            gridcolor='LightGrey',
            tickvals = [100, 500],
        ),
        height = 800,
        width = 800,
        bargap = 0,
        hovermode = 'closest',
        showlegend = False,
        paper_bgcolor='rgba(0,0,0,0)',
        plot_bgcolor='rgba(0,0,0,0)',
        title='Prediction vs. Ground truth - Test set'
    ),
    frames=frames
)
fig["layout"]["updatemenus"] = [
    {
        "buttons": [
            {
                "args": [None, {"frame": {"duration": 500, "redraw": True},
                                "fromcurrent": False, "transition": {"duration": 500,
                                                                    "easing": "quadratic-in-out"}}],
                "label": "Play",
                "method": "animate"
            },
            {
                "args": [[None], {"frame": {"duration": 0, "redraw": True},
                                  "mode": "immediate",
                                  "transition": {"duration": 0}}],
                "label": "Pause",
                "method": "animate"
            }
        ],
        "direction": "left",
        "pad": {"r": 10, "t": 87},
        "showactive": False,
        "type": "buttons",
        "x": 0.1,
        "xanchor": "right",
        "y": 0,
        "yanchor": "top"
    }
]

fig["layout"]["sliders"] = [{
    "active": 0,
    "yanchor": "top",
    "xanchor": "left",
    "currentvalue": {
        "font": {"size": 20},
        "prefix": "Correction update:",
        "visible": True,
        "xanchor": "right"
    },
    "transition": {"duration": 0, "easing": "cubic-in-out"},
    "pad": {"b": 10, "t": 50},
    "len": 0.9,
    "x": 0.1,
    "y": 0,
    "steps": slider_steps
}]

fig.show()

The color is in log scale. One can see that the final prediction has a correlation of 0.996 with the actual ground truth, which means that our predictor is very accurate!
You can rescale the density on the right by double-clicking.