# Explore effect of different graph structures for solving sudoku

In [1]:
import collections
import itertools
import random
import time

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import tensorflow as tf

from copy import deepcopy
from graph_nets import graphs
from graph_nets import utils_np
from graph_nets import utils_tf
from graph_nets.demos import models
from math import ceil, floor, sqrt
from scipy import spatial

# Seed the Python, NumPy and TensorFlow random number generators with the
# same fixed value (presumably so that notebook runs are repeatable).
SEED = 1
random.seed(SEED)
np.random.seed(SEED)
tf.set_random_seed(SEED)

In [2]:
# 9x9 sample puzzle as a list of rows. A 0 marks an empty cell; only
# non-zero entries are treated as given clues by the solver below.
test_grid_9 = [
    [0, 8, 0, 0, 5, 0, 0, 0, 9],
    [0, 0, 0, 0, 0, 0, 6, 0, 7],
    [0, 9, 1, 0, 0, 6, 0, 0, 0],
    [4, 0, 2, 3, 0, 0, 0, 0, 0],
    [3, 1, 0, 0, 0, 0, 0, 0, 8],
    [0, 6, 0, 0, 4, 7, 9, 0, 0],
    [0, 0, 0, 0, 0, 5, 0, 0, 0],
    [8, 0, 0, 0, 7, 0, 0, 4, 0],
    [1, 0, 0, 0, 2, 0, 8, 0, 0]]

# Second 9x9 sample puzzle (0 = empty cell); the "h" suffix presumably
# stands for "hard" -- TODO confirm.
test_grid_9h = [
    [8, 0, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 3, 6, 0, 0, 0, 0, 0],
    [0, 7, 0, 0, 9, 0, 2, 0, 0],
    [0, 5, 0, 0, 0, 7, 0, 0, 0],
    [0, 0, 0, 0, 4, 5, 7, 0, 0],
    [0, 0, 0, 1, 0, 0, 0, 3, 0],
    [0, 0, 1, 0, 0, 0, 0, 6, 8],
    [0, 0, 8, 5, 0, 0, 0, 1, 0],
    [0, 9, 0, 0, 0, 0, 4, 0, 0]]

# 12x12 sample puzzle with values 1..12 (0 = empty cell).
test_grid_12 = [
    [ 0,  0,  8,  0,  9,  0,  6,  0,  0,  3,  0,  2],
    [ 0,  3,  0, 11,  0,  2,  0,  1,  6, 12,  0,  0],
    [ 0,  0,  9,  6,  0,  0,  7,  0,  5,  0,  0,  0],
    [ 0,  0,  0,  0,  6,  0,  1,  0,  2,  0,  0,  0],
    [ 6,  2,  0,  0,  0,  8,  0,  9,  0,  0,  0,  0],
    [ 5,  0, 10,  0,  0, 12,  0, 11,  0,  1,  0,  0],
    [12,  0,  0,  0,  0,  0,  3,  0,  0,  0,  7,  8],
    [11,  8,  0,  0,  0,  0,  0,  0,  0,  0,  0,  6],
    [ 0,  0,  0,  0,  0, 10,  9,  0,  0,  0,  5,  0],
    [ 0,  0,  1,  0,  0,  7,  0,  2,  9,  0,  0,  5],
    [ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  3,  0],
    [ 0,  7,  0,  4,  0,  1, 12,  0,  0, 11,  0,  0]]

# 16x16 sample puzzle with values 1..16 (0 = empty cell).
test_grid_16 = [
    [ 0,  7,  5,  0,  0, 15, 12,  0,  0,  0,  0,  0, 14,  0, 10,  0],
    [13, 16,  0,  0,  7, 11,  4,  0,  6,  0, 12,  0,  0,  9,  0,  0],
    [ 0, 14,  8,  0,  0,  0,  0,  9,  0,  2,  5,  0, 16,  0,  0,  6],
    [ 0,  0,  0, 15,  0,  0,  0, 13,  0,  7,  0,  0,  0,  0,  0,  0],
    [ 0,  0, 11, 14,  0,  7,  0,  4,  0,  6,  0, 13,  0,  0,  0,  0],
    [ 0,  0,  4,  0,  0,  5,  6,  0,  0,  0,  0,  3,  0, 14,  0,  0],
    [ 0,  0,  0, 12,  0, 10, 15, 11,  0,  0,  0,  9,  0,  4,  3,  5],
    [16,  0,  0,  0, 14,  0,  0,  2,  1,  0,  8, 12, 10, 11,  0,  0],
    [ 0, 12,  0,  0, 10,  0,  0,  0,  0,  0,  0, 14,  0,  0,  0,  4],
    [ 0,  0, 15, 13,  0,  0,  0,  0,  4, 12,  0,  0, 11,  0,  0,  0],
    [ 6,  0, 16,  2,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0, 15,  0],
    [ 0,  8,  0,  0,  0,  0,  0, 12,  0,  0, 16, 10,  0,  3,  0,  7],
    [ 0, 13,  0,  0,  0,  0,  7, 14, 16,  9,  0, 15,  0,  0,  1, 11],
    [ 0,  0, 14,  6,  0,  0, 13,  1,  0,  3,  0,  0,  0,  0,  0,  0],
    [ 0,  4,  0,  0,  3,  0,  0,  0,  7,  0,  1,  0,  5,  2,  0,  0],
    [ 0,  0,  2, 10,  0, 12,  0,  0, 14,  0,  0,  0,  0,  0, 16,  0]]

In [3]:
# Sudoku generator functions

def select(X, Y, r):
    '''Commit to row `r` and cover every column it satisfies.

    For each column listed in `Y[r]`, every row that also hits that column
    is removed from all *other* columns it appears in, and the column itself
    is popped from `X`.

    Args:
        X: dict mapping column -> set of candidate rows (mutated in place).
        Y: dict mapping row -> list of columns it satisfies.
        r: key of the row being selected.

    Returns:
        The popped column sets, in `Y[r]` order, so that `deselect` can
        undo the operation.
    '''
    removed = []
    for col in Y[r]:
        for clash in X[col]:
            for other in Y[clash]:
                if other == col:
                    continue
                X[other].remove(clash)
        removed.append(X.pop(col))
    return removed


def deselect(X, Y, r, cols):
    '''Undo a previous `select(X, Y, r)`.

    Columns are restored in the reverse of the order in which they were
    covered, and each row of a restored column is added back to every other
    column it hits.

    Args:
        X: dict mapping column -> set of candidate rows (mutated in place).
        Y: dict mapping row -> list of columns it satisfies.
        r: key of the row being un-selected.
        cols: list returned by the matching `select` call; it is consumed
            (popped from the end) by this function.
    '''
    for col in reversed(Y[r]):
        restored = cols.pop()
        X[col] = restored
        for row_key in restored:
            for other in Y[row_key]:
                if other == col:
                    continue
                X[other].add(row_key)


def exact_cover(X, Y):
    '''Build the column -> rows index used by the exact-cover solver.

    Args:
        X: iterable of column (constraint) keys.
        Y: dict mapping row -> list of columns that row satisfies.

    Returns:
        A pair `(columns, Y)` where `columns` maps every key of `X` to the
        set of rows in `Y` that list it; `Y` is returned unchanged.
    '''
    columns = {}
    for col in X:
        columns[col] = set()
    for row_key in Y:
        for col in Y[row_key]:
            columns[col].add(row_key)
    return columns, Y


def solver(X, Y, solution):
    '''Recursive backtracking search over the exact-cover structures.

    Yields a copy of `solution` each time no uncovered column is left.
    `X` and `solution` are mutated while searching and restored before the
    next candidate is tried.

    Args:
        X: dict mapping column -> set of candidate rows.
        Y: dict mapping row -> list of columns it satisfies.
        solution: list of rows chosen so far.
    '''
    if X:
        # Branch on the column with the fewest remaining candidate rows.
        pivot = min(X, key=lambda col: len(X[col]))
        # Iterate over a snapshot: select/deselect mutate the sets in X.
        for candidate in list(X[pivot]):
            solution.append(candidate)
            covered = select(X, Y, candidate)
            yield from solver(X, Y, solution)
            deselect(X, Y, candidate, covered)
            solution.pop()
    else:
        yield list(solution)


def solve_sudoku(size=9, grid=None):
    '''Sudoku solver using Algorithm X (exact cover).

    Args:
        size: side length N of the N x N board; cell values are 1..N.
        grid: list of N rows of N ints, where 0 marks an empty cell. The
            grid is filled in place. When omitted or empty, a blank board
            is created.

    Yields:
        The same `grid` object, overwritten with each solution in turn.
        Copy it if a solution must survive the next iteration.

    Raises:
        KeyError: when the given clues contradict each other (a constraint
            is selected twice), raised on the first `next()` call.
    '''
    N = size
    if not grid:
        # The old default `grid=[]` could never be filled in (IndexError);
        # an empty grid now means "no clues".
        grid = [[0] * N for _ in range(N)]
    R = C = int(ceil(sqrt(N)))
    # Columns of the exact-cover matrix: every cell is filled ("rc") and
    # every number appears once per row ("rn"), column ("cn") and box ("bn").
    X = ([("rc", rc) for rc in itertools.product(range(N), range(N))] +
         [("rn", rn) for rn in itertools.product(range(N), range(1, N + 1))] +
         [("cn", cn) for cn in itertools.product(range(N), range(1, N + 1))] +
         [("bn", bn) for bn in itertools.product(range(N), range(1, N + 1))])
    # Rows of the matrix: placing number n at (r, c).
    Y = {}
    for r, c, n in itertools.product(range(N), range(N), range(1, N + 1)):
        # Box index from the column band (C cells wide) and the row band
        # (N // R cells tall).
        # NOTE(review): this must land in range(N) for the chosen size;
        # confirm for sizes other than those used in this file.
        b = (c // C) * C + r // (N // R)
        Y[(r, c, n)] = [
            ("rc", (r, c)),
            ("rn", (r, n)),
            ("cn", (c, n)),
            ("bn", (b, n))]
    X, Y = exact_cover(X, Y)
    # Pre-select the given clues.
    for i, row in enumerate(grid):
        for j, n in enumerate(row):
            if n:
                select(X, Y, (i, j, n))

    for row in solver(X, Y, []):
        for (r, c, n) in row:
            grid[r][c] = n
        yield grid


def get_sudoku_solution(size=9, grid=None):
    '''Solve a sudoku and report whether the solution is unique.

    The caller's grid is never modified; a deep copy is solved instead.

    Args:
        size: side length of the board.
        grid: list of rows, with 0 marking an empty cell.

    Returns:
        The solved grid (list of rows) when exactly one solution exists,
        the string "NO SOLUTION" when none exists (including contradictory
        or malformed clues), or the string "MORE THAN ONE SOLUTION".

    Example:
        test_grid_9 = [
            [0, 8, 0, 0, 5, 0, 0, 0, 9],
            [0, 0, 0, 0, 0, 0, 6, 0, 7],
            [0, 9, 1, 0, 0, 6, 0, 0, 0],
            [4, 0, 2, 3, 0, 0, 0, 0, 0],
            [3, 1, 0, 0, 0, 0, 0, 0, 8],
            [0, 6, 0, 0, 4, 7, 9, 0, 0],
            [0, 0, 0, 0, 0, 5, 0, 0, 0],
            [8, 0, 0, 0, 7, 0, 0, 4, 0],
            [1, 0, 0, 0, 2, 0, 8, 0, 0]]
        solution = get_sudoku_solution(9, test_grid_9)
        print(*solution, sep='\n')
        # [6, 8, 4, 7, 5, 3, 2, 1, 9]
        # [2, 3, 5, 4, 1, 9, 6, 8, 7]
        # [7, 9, 1, 2, 8, 6, 3, 5, 4]
        # [4, 7, 2, 3, 9, 8, 1, 6, 5]
        # [3, 1, 9, 5, 6, 2, 4, 7, 8]
        # [5, 6, 8, 1, 4, 7, 9, 3, 2]
        # [9, 4, 6, 8, 3, 5, 7, 2, 1]
        # [8, 2, 3, 9, 7, 1, 5, 4, 6]
        # [1, 5, 7, 6, 2, 4, 8, 9, 3]
    '''
    solutions = solve_sudoku(size, deepcopy(grid) if grid else [])
    try:
        solution = next(solutions)
    except (StopIteration, LookupError):
        # StopIteration: the search space is exhausted.
        # LookupError: clues that contradict each other (KeyError while
        # covering an already-covered constraint) or a grid of the wrong
        # shape (IndexError). Callers such as generate_sudoku rely on
        # these being reported as "NO SOLUTION" rather than raised.
        return "NO SOLUTION"
    try:
        # Only existence of a second solution matters, not its content.
        next(solutions)
    except StopIteration:
        return solution
    return "MORE THAN ONE SOLUTION"


def generate_sudoku(size=9):
    '''Generate a random sudoku together with its unique solution.

    Cells are visited in random order. Each visited cell receives a random
    value that keeps the puzzle solvable, and generation stops as soon as
    the clues placed so far admit exactly one solution.

    Args:
        size: side length of the board.

    Returns:
        (grid, solution): the puzzle as a list of rows (0 = empty cell),
        including the final clue that makes it unique, and its solved grid.

    Example (illustrative; the actual output depends on the RNG state):
        unsolved, solved = generate_sudoku(9)
        print("Unsolved:", *unsolved, "Solved:", *solved, sep='\n')
        # Unsolved:
        # [0, 0, 7, 0, 6, 2, 0, 3, 0]
        # [6, 2, 3, 0, 0, 0, 0, 7, 0]
        # [0, 1, 0, 4, 0, 7, 0, 2, 0]
        # [0, 0, 0, 6, 0, 0, 0, 0, 0]
        # [1, 0, 8, 0, 9, 0, 5, 0, 7]
        # [2, 0, 6, 0, 0, 5, 0, 1, 0]
        # [0, 0, 2, 0, 4, 0, 0, 5, 0]
        # [0, 0, 0, 0, 2, 8, 7, 9, 6]
        # [0, 0, 0, 0, 7, 1, 3, 0, 2]
        # Solved:
        # [4, 8, 7, 1, 6, 2, 9, 3, 5]
        # [6, 2, 3, 8, 5, 9, 1, 7, 4]
        # [9, 1, 5, 4, 3, 7, 6, 2, 8]
        # [7, 5, 4, 6, 1, 3, 2, 8, 9]
        # [1, 3, 8, 2, 9, 4, 5, 6, 7]
        # [2, 9, 6, 7, 8, 5, 4, 1, 3]
        # [3, 7, 2, 9, 4, 6, 8, 5, 1]
        # [5, 4, 1, 3, 2, 8, 7, 9, 6]
        # [8, 6, 9, 5, 7, 1, 3, 4, 2]
    '''
    grid = [[0] * size for _ in range(size)]
    coords = [divmod(i, size) for i in range(size ** 2)]
    random.shuffle(coords)
    solution = None
    for i, j in coords:
        g = deepcopy(grid)
        options = list(range(1, size + 1))
        random.shuffle(options)
        g[i][j] = options.pop()
        check_solve = get_sudoku_solution(size, g)
        # Try other values until the new clue is compatible with the rest.
        # The grid stays solvable after every accepted clue, so at least
        # one option always works.
        while check_solve == "NO SOLUTION":
            g[i][j] = options.pop()
            check_solve = get_sudoku_solution(size, g)
        # Keep the clue in both remaining cases. Previously the clue that
        # made the puzzle unique was dropped, so the returned puzzle still
        # had several solutions.
        grid = g
        if isinstance(check_solve, list):
            solution = check_solve
            break
    return grid, solution

In [4]:
# Solve the `test_grid_9h` sample puzzle and print the result one row per line.
solution = get_sudoku_solution(9, test_grid_9h)
print(*solution, sep='\n')

[8, 1, 2, 7, 5, 3, 6, 4, 9]
[9, 4, 3, 6, 8, 2, 1, 7, 5]
[6, 7, 5, 4, 9, 1, 2, 8, 3]
[1, 5, 4, 2, 3, 7, 8, 9, 6]
[3, 6, 9, 8, 4, 5, 7, 2, 1]
[2, 8, 7, 1, 6, 9, 5, 3, 4]
[5, 2, 1, 9, 7, 4, 3, 6, 8]
[4, 3, 8, 5, 2, 6, 9, 1, 7]
[7, 9, 6, 3, 1, 8, 4, 5, 2]


In [3]:
# Helper functions

# Attribute name for edge weights (not referenced by the code visible here).
EDGE_WEIGHT_NAME = 'weight'
# Weight given to every edge of the generated sudoku graphs.
EDGE_WEIGHT = 1.


def pairwise(iterable):
    '''Pair each item with its successor: s -> (s0, s1), (s1, s2), ...

    Returns a lazy `zip` iterator; inputs with fewer than two items give
    no pairs.
    '''
    leading, trailing = itertools.tee(iterable, 2)
    # Advance the trailing copy by one so it runs a step ahead.
    next(trailing, None)
    return zip(leading, trailing)


def set_diff(seq0, seq1):
    '''List the distinct items of `seq0` that do not occur in `seq1`.

    The order of the returned list is unspecified (set order).
    '''
    keep = set(seq0)
    drop = set(seq1)
    return list(keep - drop)


def to_one_hot(indices, max_value, axis=-1):
    '''One-hot encode `indices` by picking rows of an identity matrix.

    Args:
        indices: integer index (array) into `range(max_value)`.
        max_value: number of classes, i.e. length of the one-hot axis.
        axis: position of the one-hot axis in the result. With -1 (or a
            value equal to the result's ndim) it stays last.

    Returns:
        Float array with one extra axis of length `max_value`.
    '''
    encoded = np.eye(max_value)[indices]
    keep_last = axis == -1 or axis == encoded.ndim
    if keep_last:
        return encoded
    return np.moveaxis(encoded, -1, axis)


def get_node_dict(graph, attr):
    '''Return a `dict` of node:attribute pairs from a graph.

    Args:
        graph: a networkx-style graph exposing its node data mapping as
            `graph.node` (older networkx) or `graph.nodes` (`graph.node`
            was removed in networkx 2.4).
        attr: key to look up in each node's data dict.

    Raises:
        KeyError: if some node has no `attr` entry.
    '''
    node_view = getattr(graph, "node", None)
    if node_view is None:
        node_view = graph.nodes
    return {k: v[attr] for k, v in node_view.items()}


def base_graph_dict(size, edge_weight):
    '''Define a basic sudoku graph structure.

    The board is composed of `size` x `size` nodes connected to their
    horizontal and vertical neighbours in a grid-like structure, with one
    directed edge in each direction. Each node carries a vector of `size`
    values (all zero here) meant to hold the per-digit probabilities /
    one-hot clue of that cell.

    Nodes are numbered row-major: cell (row, col) is node
    `row * size + col`, which matches flattening a `grid[row][col]` board.

    Args:
        size: int representing dimension of square 'size' by 'size'
        edge_weight: float value assigned as edge weights

    Returns:
        data_dict: dictionary with globals, nodes, edges, receivers,
            senders, n_node and n_edge describing the graph. `nodes` has
            shape (size * size, size); `senders`/`receivers` are lists of
            integer node indices aligned with `edges`.
    '''
    n_node = size * size
    nodes = np.zeros((n_node, size), dtype=np.float32)
    edges, senders, receivers = [], [], []

    def connect(a, b):
        # Add one edge per direction between neighbouring nodes a and b.
        for sender, receiver in ((a, b), (b, a)):
            edges.append([edge_weight])
            senders.append(sender)
            receivers.append(receiver)

    for row in range(size):
        for col in range(size):
            node = row * size + col
            # Bounds checks keep edges from running off the board.
            if col + 1 < size:
                connect(node, node + 1)        # right-hand neighbour
            if row + 1 < size:
                connect(node, node + size)     # neighbour below
    return {
        "globals": [size],
        "nodes": nodes,
        "edges": edges,
        "receivers": receivers,
        "senders": senders,
        "n_node": n_node,
        # Derived from the lists so the count can never disagree with them.
        "n_edge": len(edges)
    }


def _grid_to_one_hot(grid, depth):
    '''Flatten a sudoku grid row-major into one-hot node features.

    A cell holding value v (1-based) becomes a `depth`-long vector with a
    1 at index v - 1; empty cells (0) become all-zero vectors.
    '''
    values = np.asarray(grid, dtype=np.int64).reshape(-1)
    nodes = np.zeros((values.size, depth), dtype=np.float32)
    filled = np.flatnonzero(values > 0)
    nodes[filled, values[filled] - 1] = 1.0
    return nodes


def generate_graph_dicts(num_examples, graph_sizes=(4, 9)):
    '''Generate random sudoku puzzles as graph data dicts.

    Args:
        num_examples: number of puzzles to generate.
        graph_sizes: allowable board sizes n (for n x n boards); one is
            drawn uniformly at random per example with Python's `random`.

    Returns:
        input_graph_dicts: data dicts whose nodes encode the unsolved
            puzzle (empty cells are all-zero vectors).
        target_graph_dicts: data dicts whose nodes encode the solution.
        raw_graph_dicts: the unmodified `base_graph_dict` structures.

        Input and target node vectors are `max(graph_sizes)` wide for every
        example, so graphs of different board sizes share one feature width
        and can be batched together.

    Raises:
        ValueError: if `graph_sizes` is empty.
    '''
    sizes = list(graph_sizes)
    if not sizes:
        raise ValueError("graph_sizes must contain at least one size")
    depth = max(sizes)
    edge_weight = EDGE_WEIGHT
    input_graph_dicts, target_graph_dicts, raw_graph_dicts = [], [], []
    for _ in range(num_examples):
        # Sizes are plain Python ints: generate_sudoku and base_graph_dict
        # use them in range() and list arithmetic, which TF tensors do not
        # support.
        size = random.choice(sizes)
        input_grid, target_grid = generate_sudoku(size)
        graph_dict = base_graph_dict(size, edge_weight)
        input_dict = deepcopy(graph_dict)
        input_dict["nodes"] = _grid_to_one_hot(input_grid, depth)
        target_dict = deepcopy(graph_dict)
        target_dict["nodes"] = _grid_to_one_hot(target_grid, depth)
        input_graph_dicts.append(input_dict)
        target_graph_dicts.append(target_dict)
        raw_graph_dicts.append(graph_dict)
    return input_graph_dicts, target_graph_dicts, raw_graph_dicts


def create_placeholders(rand, batch_size, graph_sizes=[4, 9]):
    """Build input/target placeholders from a throw-away example batch.

    Args:
        rand: A random seed (np.RandomState instance). Not used by this
            function.
        batch_size: Number of example graphs generated to size the
            placeholders.
        graph_sizes: A list such as [4, 9, 12, 16] with the allowable sizes
            n of the n x n graph, forwarded to `generate_graph_dicts`.

    Returns:
        input_ph: The input graph's placeholders, as a graph namedtuple.
        target_ph: The target graph's placeholders, as a graph namedtuple.
    """
    # The generated data only serves to reveal the field shapes.
    sample_inputs, sample_targets, _raw = generate_graph_dicts(
        batch_size, graph_sizes)
    return (utils_tf.placeholders_from_data_dicts(sample_inputs),
            utils_tf.placeholders_from_data_dicts(sample_targets))


def create_feed_dict(rand,
                     input_ph,
                     target_ph,
                     batch_size,
                     graph_sizes=[4, 9]):
    """Creates the feed dict for one batch of model training or evaluation.

    Args:
        rand: A random seed (np.RandomState instance). Not used by this
            function.
        input_ph: The input graph's placeholders, as a graph namedtuple.
        target_ph: The target graph's placeholders, as a graph namedtuple.
        batch_size: Total number of graphs per batch.
        graph_sizes: A list such as [4, 9, 12, 16] with the allowable size
            n of the n x n graph, forwarded to `generate_graph_dicts`.

    Returns:
        feed_dict: The feed `dict` of input and target placeholders and data.
        raw_graphs: The list of raw graph data dicts for the batch.
    """
    # generate_graph_dicts takes (num_examples, graph_sizes); `rand` is not
    # one of its parameters.
    inputs, targets, raw_graphs = generate_graph_dicts(batch_size, graph_sizes)
    # Feed values must be NumPy data, so build the graphs tuples with
    # utils_np; the utils_tf variant produces tensors, which cannot be fed.
    input_graphs = utils_np.data_dicts_to_graphs_tuple(inputs)
    target_graphs = utils_np.data_dicts_to_graphs_tuple(targets)
    feed_dict = {input_ph: input_graphs, target_ph: target_graphs}
    return feed_dict, raw_graphs


def compute_accuracy(target, output, use_nodes=True, use_edges=False):
    '''Calculate model accuracy.

    Compares the arg-max label of every node and/or edge in `output` with
    the one in `target`, graph by graph.

    Args:
        target: A `graphs.GraphsTuple` that contains the target graph.
        output: A `graphs.GraphsTuple` that contains the output graph.
        use_nodes: A `bool` indicator of whether to compute node accuracy or not.
        use_edges: A `bool` indicator of whether to compute edge accuracy or not.

    Returns:
        correct: A `float` fraction of correctly labeled nodes/edges.
        solved: A `float` fraction of graphs that are completely correctly labeled.

    Raises:
        ValueError: Nodes or edges (or both) must be used
    '''
    if not (use_nodes or use_edges):
        raise ValueError(
            "Nodes or edges (or both) must be used to compute accuracy")
    target_dicts = utils_np.graphs_tuple_to_data_dicts(target)
    output_dicts = utils_np.graphs_tuple_to_data_dicts(output)
    matches_per_graph = []
    solved_per_graph = []
    for expected, predicted in zip(target_dicts, output_dicts):
        node_truth = np.argmax(expected["nodes"], axis=-1)
        node_guess = np.argmax(predicted["nodes"], axis=-1)
        edge_truth = np.argmax(expected["edges"], axis=-1)
        edge_guess = np.argmax(predicted["edges"], axis=-1)
        parts = []
        if use_nodes:
            parts.append(node_truth == node_guess)
        if use_edges:
            parts.append(edge_truth == edge_guess)
        matches = np.concatenate(parts, axis=0)
        matches_per_graph.append(matches)
        # A graph counts as solved only when every compared label matches.
        solved_per_graph.append(np.all(matches))
    correct = np.mean(np.concatenate(matches_per_graph, axis=0))
    solved = np.mean(np.stack(solved_per_graph))
    return correct, solved


def create_loss_ops(target_op, output_ops):
    '''Build one softmax cross-entropy loss op per output graph.

    Args:
        target_op: graph whose `nodes` hold the target labels.
        output_ops: iterable of output graphs whose `nodes` hold logits.

    Returns:
        A list of loss ops, one per element of `output_ops`, in order.
    '''
    losses = []
    for step_output in output_ops:
        step_loss = tf.losses.softmax_cross_entropy(
            target_op.nodes, step_output.nodes)
        losses.append(step_loss)
    return losses


def make_all_runnable_in_session(*args):
    '''Apply `utils_tf.make_runnable_in_session` to every argument.

    Returns the converted graphs as a list, in argument order.
    '''
    runnable = []
    for graph in args:
        runnable.append(utils_tf.make_runnable_in_session(graph))
    return runnable




In [5]:
class GraphPlotter(object):
    """Draws a networkx graph, optionally highlighting a marked solution.

    Nodes may carry the boolean attributes "start", "end" and "solution",
    and edges the boolean attribute "solution"; missing attributes count as
    False. The node/edge subsets derived from them are computed lazily and
    cached on first access.
    """

    def __init__(self, ax, graph, pos):
        """Store the drawing target and switch its axes off.

        Args:
            ax: matplotlib axes to draw on.
            graph: networkx graph to draw.
            pos: node positions, passed as `pos` to the networkx drawing
                functions.
        """
        self._ax = ax
        self._graph = graph
        self._pos = pos
        self._base_draw_kwargs = dict(G=self._graph,
                                      pos=self._pos, ax=self._ax)
        self._solution_length = None
        self._nodes = None
        self._edges = None
        self._start_nodes = None
        self._end_nodes = None
        self._solution_nodes = None
        self._intermediate_solution_nodes = None
        self._solution_edges = None
        self._non_solution_nodes = None
        self._non_solution_edges = None
        self._ax.set_axis_off()

    def _node_attrs(self, n):
        """Return the attribute dict of node `n`.

        Uses `graph.node` when present (older networkx) and falls back to
        `graph.nodes`, since `graph.node` was removed in networkx 2.4.
        """
        node_view = getattr(self._graph, "node", None)
        if node_view is None:
            node_view = self._graph.nodes
        return node_view[n]

    @property
    def solution_length(self):
        """Number of edges flagged as part of the solution."""
        if self._solution_length is None:
            # Go through the property, not the private cache, so this works
            # even when `solution_edges` has not been accessed before.
            self._solution_length = len(self.solution_edges)
        return self._solution_length

    @property
    def nodes(self):
        """All nodes of the graph."""
        if self._nodes is None:
            self._nodes = self._graph.nodes()
        return self._nodes

    @property
    def edges(self):
        """All edges of the graph."""
        if self._edges is None:
            self._edges = self._graph.edges()
        return self._edges

    @property
    def start_nodes(self):
        """Nodes whose "start" attribute is truthy."""
        if self._start_nodes is None:
            self._start_nodes = [
                n for n in self.nodes if self._node_attrs(n).get("start", False)
            ]
        return self._start_nodes

    @property
    def end_nodes(self):
        """Nodes whose "end" attribute is truthy."""
        if self._end_nodes is None:
            self._end_nodes = [
                n for n in self.nodes if self._node_attrs(n).get("end", False)
            ]
        return self._end_nodes

    @property
    def solution_nodes(self):
        """Nodes whose "solution" attribute is truthy."""
        if self._solution_nodes is None:
            self._solution_nodes = [
                n for n in self.nodes
                if self._node_attrs(n).get("solution", False)
            ]
        return self._solution_nodes

    @property
    def intermediate_solution_nodes(self):
        """Solution nodes that are neither start nor end nodes."""
        if self._intermediate_solution_nodes is None:
            self._intermediate_solution_nodes = [
                n for n in self.nodes
                if self._node_attrs(n).get("solution", False) and
                not self._node_attrs(n).get("start", False) and
                not self._node_attrs(n).get("end", False)
            ]
        return self._intermediate_solution_nodes

    @property
    def solution_edges(self):
        """Edges whose "solution" attribute is truthy."""
        if self._solution_edges is None:
            self._solution_edges = [
                e for e in self.edges
                if self._graph.get_edge_data(e[0], e[1]).get("solution", False)
            ]
        return self._solution_edges

    @property
    def non_solution_nodes(self):
        """Nodes whose "solution" attribute is missing or falsy."""
        if self._non_solution_nodes is None:
            self._non_solution_nodes = [
                n for n in self.nodes
                if not self._node_attrs(n).get("solution", False)
            ]
        return self._non_solution_nodes

    @property
    def non_solution_edges(self):
        """Edges whose "solution" attribute is missing or falsy."""
        if self._non_solution_edges is None:
            self._non_solution_edges = [
                e for e in self.edges
                if not self._graph.get_edge_data(e[0], e[1]).get("solution", False)
            ]
        return self._non_solution_edges

    def _make_draw_kwargs(self, **kwargs):
        """Merge caller kwargs with the fixed G/pos/ax arguments."""
        kwargs.update(self._base_draw_kwargs)
        return kwargs

    def _draw(self, draw_function, zorder=None, **kwargs):
        """Call a networkx draw function and optionally set its z-order."""
        draw_kwargs = self._make_draw_kwargs(**kwargs)
        collection = draw_function(**draw_kwargs)
        if collection is not None and zorder is not None:
            try:
                # This is for compatibility with older matplotlib.
                collection.set_zorder(zorder)
            except AttributeError:
                # This is for compatibility with newer matplotlib.
                collection[0].set_zorder(zorder)
        return collection

    def draw_nodes(self, **kwargs):
        """Useful kwargs: nodelist, node_size, node_color, linewidths."""
        # `collections.Sequence` was removed in Python 3.10; the ABC lives
        # in `collections.abc`.
        from collections.abc import Sequence

        node_color = kwargs.get("node_color")
        # A single RGB(A) colour given as a flat sequence of 3 or 4 scalars
        # is repeated once per node.
        if (isinstance(node_color, Sequence) and
                len(node_color) in {3, 4} and
                not isinstance(node_color[0], (Sequence, np.ndarray))):
            num_nodes = len(kwargs.get("nodelist", self.nodes))
            kwargs["node_color"] = np.tile(
                np.array(node_color)[None], [num_nodes, 1])
        return self._draw(nx.draw_networkx_nodes, **kwargs)

    def draw_edges(self, **kwargs):
        """Useful kwargs: edgelist, width."""
        return self._draw(nx.draw_networkx_edges, **kwargs)

    def draw_graph(self,
                   node_size=200,
                   node_color=(0.4, 0.8, 0.4),
                   node_linewidth=1.0,
                   edge_width=1.0):
        """Draw all nodes and edges in one uniform style."""
        # Plot nodes.
        self.draw_nodes(
            nodelist=self.nodes,
            node_size=node_size,
            node_color=node_color,
            linewidths=node_linewidth,
            zorder=20)
        # Plot edges.
        self.draw_edges(edgelist=self.edges, width=edge_width, zorder=10)

    def draw_graph_with_solution(self,
                                 node_size=200,
                                 node_color=(0.4, 0.8, 0.4),
                                 node_linewidth=1.0,
                                 edge_width=1.0,
                                 start_color="w",
                                 end_color="k",
                                 solution_node_linewidth=3.0,
                                 solution_edge_width=3.0):
        """Draw the graph with solution nodes and edges emphasised.

        `node_color` may be a single colour or a dict mapping node -> colour.
        The axes title is set to the number of solution edges.

        Returns:
            A dict mapping a description of each drawn group to the
            collection returned by the drawing call.
        """
        node_border_color = (0.0, 0.0, 0.0, 1.0)
        node_collections = {}
        # Plot start nodes.
        node_collections["start nodes"] = self.draw_nodes(
            nodelist=self.start_nodes,
            node_size=node_size,
            node_color=start_color,
            linewidths=solution_node_linewidth,
            edgecolors=node_border_color,
            zorder=100)
        # Plot end nodes.
        node_collections["end nodes"] = self.draw_nodes(
            nodelist=self.end_nodes,
            node_size=node_size,
            node_color=end_color,
            linewidths=solution_node_linewidth,
            edgecolors=node_border_color,
            zorder=90)
        # Plot intermediate solution nodes.
        if isinstance(node_color, dict):
            c = [node_color[n] for n in self.intermediate_solution_nodes]
        else:
            c = node_color
        node_collections["intermediate solution nodes"] = self.draw_nodes(
            nodelist=self.intermediate_solution_nodes,
            node_size=node_size,
            node_color=c,
            linewidths=solution_node_linewidth,
            edgecolors=node_border_color,
            zorder=80)
        # Plot solution edges.
        node_collections["solution edges"] = self.draw_edges(
            edgelist=self.solution_edges, width=solution_edge_width, zorder=70)
        # Plot non-solution nodes.
        if isinstance(node_color, dict):
            c = [node_color[n] for n in self.non_solution_nodes]
        else:
            c = node_color
        node_collections["non-solution nodes"] = self.draw_nodes(
            nodelist=self.non_solution_nodes,
            node_size=node_size,
            node_color=c,
            linewidths=node_linewidth,
            edgecolors=node_border_color,
            zorder=20)
        # Plot non-solution edges.
        node_collections["non-solution edges"] = self.draw_edges(
            edgelist=self.non_solution_edges, width=edge_width, zorder=10)
        # Set title as solution length.
        self._ax.set_title("Solution length: {}".format(self.solution_length))
        return node_collections