<a href="https://colab.research.google.com/github/krzysztofrusek/net2vec/blob/master/jupyter_notebooks/routing_by_backprop_demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install the JAX GNN stack (Haiku, Optax, jraph), the XSLT processor used below,
# and clone the net2vec repository (topology stylesheet + pretrained snapshot).
!pip install dm_haiku optax jraph
!apt install xsltproc
!git clone https://github.com/krzysztofrusek/net2vec.git

In [None]:
%%bash

# Download the SNDlib network topologies (XML format).
curl -o sndlib-networks-xml.tgz http://sndlib.zib.de/download/sndlib-networks-xml.tgz

# Unpack the archive (-k: keep files that already exist).
tar -xvkf sndlib-networks-xml.tgz
# Convert the janos-us topology to GraphML using the stylesheet from the net2vec repo.
xsltproc -o janos-us.graphml net2vec/routing_by_backprop/topo/net2graphml.xslt sndlib-networks-xml/janos-us.xml

In [None]:
import itertools as it
import functools
import pickle
from typing import Tuple

import chex
import haiku as hk
import jax
import jax.numpy as jnp
import jraph
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import optax
import seaborn as sns

# Apply seaborn's default plotting theme to all figures below.
sns.set()


# Toy model

Let us begin with a simple toy model showing our approach.
Assume we wish to minimize the piecewise-constant (staircase) function $f(x)$ plotted in the figure below.

Since the gradient of $f$ is 0 almost everywhere, gradient methods cannot be applied directly.
However, we can approximate the staircase function $f(x)$ with a neural network $\hat f(x)$.
The NN is differentiable, so in gradient descent we can replace the uninformative gradient of the objective with the gradient of the approximation, $\nabla \hat f(x)$.
Notice that the sign of the surrogate gradient correctly reflects the monotonicity of the function.


In [None]:
def stair_case_function(x: "chex.Array") -> "Tuple[chex.Array, chex.Array]":
    """Toy staircase objective together with the smooth curve it is built from.

    The smooth curve is the monic quartic with roots at -1, -0.5, 0.5 and 1,
    divided by 0.2. The staircase is its element-wise ceiling, so it is
    piecewise constant with zero gradient almost everywhere.

    :param x: points at which to evaluate; evaluated element-wise.
    :return: ``(staircase, smooth)``, both with the same shape as ``x``.
    """
    # The root list is symmetric, so negating it leaves the polynomial unchanged.
    p = jnp.poly(-jnp.array([-1, -0.5, 0.5, 1]))
    smooth = jnp.polyval(p, x) / 0.2
    return jnp.ceil(smooth), smooth


# Dense grid used for plotting, with the exact staircase and smooth curves on it.
x = np.linspace(-1.2, 1.2, 100)
y, ysmooth = stair_case_function(x)

# 128 training inputs drawn uniformly over the plotted range, as a column vector.
# NOTE(review): np.random is unseeded here, so the training set differs per run.
xtrain = np.random.uniform(x[0], x[-1], (128, 1))
ytrain, _ = stair_case_function(xtrain)


@hk.transform
def nn(x):
    """Surrogate f_hat: tanh MLP with one 16-unit hidden layer and a 1-unit output."""
    mlp = hk.nets.MLP([16, 1], activation=jax.nn.tanh)
    return mlp(x)


rng = jax.random.PRNGKey(42)

# Initialise the surrogate from a sample input (xtrain) and set up Adam with lr = 0.01.
params = nn.init(rng, xtrain)
opt = optax.adam(0.01)
opt_state = opt.init(params)


def loss(params, x, y):
    """Mean ``optax.l2_loss`` between the surrogate's predictions on ``x`` and targets ``y``."""
    predictions = nn.apply(params, None, x)
    per_sample = optax.l2_loss(predictions, y)
    return jnp.mean(per_sample)


# Jitted gradient of `loss` w.r.t. params (currently unused in the visible notebook).
grad_loss = jax.jit(jax.grad(loss))


@jax.jit
def step(params, opt_state, x, y):
    """Take one Adam step of the surrogate on the batch ``(x, y)``.

    :return: the updated ``(params, opt_state)``.
    """
    # Differentiate on the x, y arguments. Reading the xtrain/ytrain globals here
    # would make jax.jit bake them in as constants and silently ignore x, y.
    grads = jax.grad(loss)(params, x, y)
    updates, opt_state = opt.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state


# Fit the surrogate with 1000 full-batch steps on the training set.
for _ in range(1000):
    params, opt_state = step(params, opt_state, xtrain, ytrain)


# %%
@jax.jit
def grad_fn(params, rng, x):
    """Derivative of the scalar surrogate w.r.t. its input, for every entry along axis 0 of ``x``."""

    def scalar_output(point, net_params, key):
        # Add a feature axis so the MLP sees a length-1 input; keep its single output.
        return nn.apply(net_params, key, jnp.expand_dims(point, 0))[0]

    # grad w.r.t. `point`; map over points, sharing params and key.
    per_point_grad = jax.vmap(jax.grad(scalar_output), in_axes=(0, None, None))
    return per_point_grad(x, params, rng)


# Surrogate prediction and its input-derivative on the plotting grid.
yhat = nn.apply(params, rng, x[..., np.newaxis])[:, 0]
grads = grad_fn(params, rng, x)

fig, ax = plt.subplots(constrained_layout=True)

sns.lineplot(x=x, y=y, label='$f(x)$', ax=ax)
# Raw string: '\h' is not a valid escape sequence (SyntaxWarning on Python 3.12+).
sns.lineplot(x=x, y=yhat, label=r'$\hat f(x)$', ax=ax)
# Gradient scaled by 0.1 so it fits on the same axes as f; draw it dashed.
ax = sns.lineplot(x=x, y=0.1 * grads, label=r'$\nabla \hat f(x)$', ax=ax)
# Style the line just added rather than relying on a hard-coded position.
ax.get_lines()[-1].set_linestyle("--")
plt.legend(ncol=3)



# Real Network

In [None]:
# Load Janos-US, relabel nodes as consecutive integers and convert to a directed graph.
G = nx.DiGraph(nx.convert_node_labels_to_integers(nx.read_graphml("janos-us.graphml")))

# Unit weight on every link, so shortest paths minimise hop count.
w = dict.fromkeys(G.edges, 1.)
nx.set_edge_attributes(G, w, name='weight')

# Precompute shortest paths between all node pairs and attach them to the graph.
all_pairs = dict(nx.all_pairs_dijkstra_path(G, weight='weight'))
G.graph['all_pairs'] = all_pairs

# Node drawing positions from the 'x'/'y' node attributes.
pos = {node: [float(attrs['x']), float(attrs['y'])] for node, attrs in G.nodes(data=True)}
nx.draw_networkx(G, pos=pos)

## Utility functions

In [None]:
def pairwise(iterable):
    """Return overlapping consecutive pairs: s -> (s0, s1), (s1, s2), (s2, s3), ..."""
    leading, trailing = it.tee(iterable)
    # Advance the second copy by one element so zip pairs each item with its successor.
    next(trailing, None)
    return zip(leading, trailing)

def to_jraph(G: nx.Graph, src: int, dst: int) -> Tuple[jraph.GraphsTuple, chex.Array]:
    """Encode one (src, dst) routing query on ``G`` as a jraph graph plus edge labels.

    Node features are two flag columns: column 0 is 1 at ``src``, column 1 is
    1 at ``dst``. Edge features are each edge's ``'weight'`` attribute, in
    ``G.edges`` order.

    :param G: topology whose ``G.graph['all_pairs']`` holds precomputed shortest
        paths (``routing[src][dst]`` is a node list) and whose edges carry a
        ``'weight'``. Node labels are assumed to be integers ``0..len(G)-1``
        (they are used directly as array indices).
    :param src: source node.
    :param dst: destination node.
    :return: ``(graph, y)``. ``y`` has shape ``(num_edges, 1)`` and is 1 for
        edges on the precomputed ``src -> dst`` path, 0 elsewhere.
    """
    nnodes = len(G)
    routing = G.graph['all_pairs']
    # Edge -> row index, built once: one dict lookup per hop instead of list.index.
    edge_index = {edge: i for i, edge in enumerate(G.edges)}
    # Rows of (sender, receiver, weight). Node ids share the array dtype with the
    # weights and are cast back to int below.
    relation_data = np.array(list(G.edges.data('weight')))

    nodes = np.zeros((nnodes, 2))
    nodes[src, 0] = 1
    nodes[dst, 1] = 1
    nodes = jnp.array(nodes)

    senders, receivers = jnp.split(relation_data[:, :2].astype(np.int32), 2, 1)
    edges = jnp.array(relation_data[:, 2:])

    x = jraph.GraphsTuple(
        n_node=jnp.asarray([nodes.shape[0]]),
        n_edge=jnp.asarray([edges.shape[0]]),
        nodes=nodes,
        edges=edges,
        senders=senders.flatten(),
        receivers=receivers.flatten(),
        globals=None
    )
    y = np.zeros((relation_data.shape[0], 1))

    path = [edge_index[e] for e in pairwise(routing[src][dst])]
    y[path, 0] = 1
    return x, y



# The model

Our model maps a weighted graph to a soft routing: for every link, the probability that it belongs to the shortest path between two given nodes.

In [None]:
NUM_LAYERS = 2  # Hard-code number of layers in the edge/node/global models.
LATENT_SIZE = 128  # Hard-code latent layer sizes for demos.


class MLPModule(hk.Module):
    """MLP of NUM_LAYERS layers of LATENT_SIZE units (final layer activated), then LayerNorm.

    Parameter names come from Haiku's automatic module naming. Presumably the
    pretrained snapshot relies on them, so keep the MLP -> LayerNorm structure unchanged.
    """

    def __call__(self, x, is_training=False) -> jnp.ndarray:
        # `is_training` is accepted for interface compatibility but not used.
        hidden = hk.nets.MLP([LATENT_SIZE] * NUM_LAYERS, activate_final=True)(x)
        normalized = hk.LayerNorm(axis=1, create_offset=True, create_scale=True)(hidden)
        return normalized


def network_definition(graph: jraph.GraphsTuple) -> "list[jraph.GraphsTuple]":
    """Encode-process-decode graph network producing per-edge routing logits.

    Node and edge features are first encoded with MLPModules. A shared
    GraphNetwork core then runs 8 message-passing steps. Before each step the
    current latent is concatenated with the initial encoding (skip connection).
    After each step the result is decoded and mapped by a Linear(1) to one logit
    per edge.

    NOTE: Haiku names modules in construction order. Presumably the pretrained
    snapshot depends on that, so keep the MLPModule/Linear creation order as is.

    :param graph: input graph (query flags as node features, link weights as edge features).
    :return: list with one output graph per message-passing step; callers apply a
        sigmoid to ``.edges`` of an element to get link probabilities.
    """
    # Encoder: separate MLPModules for the edge and node features.
    latent = jraph.GraphMapFeatures(MLPModule(), MLPModule())(graph)
    # Initial encoding, re-concatenated into the core input at every step.
    latent0 = latent
    outputs = []

    # Core: edge update, then node update, each on concatenated arguments.
    core = jraph.GraphNetwork(
        update_edge_fn=jraph.concatenated_args(MLPModule()),
        update_node_fn=jraph.concatenated_args(MLPModule())
    )
    decoder = jraph.GraphMapFeatures(MLPModule(), MLPModule())

    # Readout: a single logit per edge.
    output_transform = jraph.GraphMapFeatures(
        embed_edge_fn=hk.Linear(1)
    )

    num_message_passing_steps = 8
    for _ in range(num_message_passing_steps):
        # Skip connection: feed [current latent, initial encoding] to the shared core.
        core_input = latent._replace(
            nodes=jnp.concatenate((latent.nodes, latent0.nodes), axis=1),
            edges=jnp.concatenate((latent.edges, latent0.edges), axis=1)
        )
        latent = core(core_input)
        decoded = decoder(latent)
        outputs.append(output_transform(decoded))

    return outputs


Let's load a pretrained model, since training from scratch would take about 30 minutes.

In [None]:
# Pretrained (params, state) snapshot shipped with the cloned net2vec repository.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted snapshots.
with open('net2vec/routing_by_backprop/log/labsim/sp1/snapshot-8000.pickle', 'rb') as snapshot_file:
    params, state = pickle.load(snapshot_file)


Now we can evaluate the model and compare its prediction with the route found by Dijkstra's algorithm.

In [None]:
# Example query: route traffic from node 4 to node 7.
src, dst = 4, 7

x, y = to_jraph(G, src, dst)

In [None]:
# Stateful Haiku transform whose apply needs no RNG key.
network = hk.without_apply_rng(
    hk.transform_with_state(network_definition)
)
# One output graph per message-passing step; the returned state is discarded.
hats, _ = network.apply(params, state, x)

In [None]:
# Link probabilities from the final message-passing step, the same output that
# apply_for_pair uses during optimization.
p = jax.nn.sigmoid(hats[-1].edges)
fig, axs = plt.subplots(ncols=2, figsize=(12, 4))


def plot_route(G, pos, alpha, ax):
    """Draw ``G`` on ``ax`` with edge ``i`` in black at opacity ``alpha[i]``.

    The source (green) and destination (red) are taken from the notebook
    globals ``src`` and ``dst``. ``alpha`` must follow ``G.edges`` order.
    """
    endpoints = {src, dst}
    others = [v for v in G.nodes if v not in endpoints]
    nx.draw_networkx_nodes(G, pos, nodelist=others, ax=ax)
    nx.draw_networkx_nodes(G, pos, nodelist=[src], node_color='green', ax=ax)
    nx.draw_networkx_nodes(G, pos, nodelist=[dst], node_color='red', ax=ax)
    # Faint dotted outline of every link, then the route drawn with per-edge opacity.
    nx.draw_networkx_edges(G, pos, edge_color='blue', alpha=0.2, ax=ax, style=':')
    nx.draw_networkx_edges(G, pos, edge_color=[(0., 0., 0., float(a)) for a in alpha], ax=ax)
    nx.draw_networkx_labels(G, pos, ax=ax)


# Edges on the precomputed Dijkstra path get opacity 1, all others 0.
spl = set(pairwise(G.graph['all_pairs'][src][dst]))
alpha = [float(e in spl) for e in G.edges]

plot_route(G, pos, alpha, axs[0])
axs[0].set_title('Hard routing');
plot_route(G, pos, p.flatten(), axs[1])
axs[1].set_title('Soft routing');


# Optimization

For network optimization we need to evaluate our model for all possible source-destination pairs.
An efficient implementation is obtained with the `vmap` transformation.

In [None]:
def apply_for_pair(w: jnp.array, src: jnp.array, dst: jnp.array, params: hk.Params, state: hk.State,
                   graph_template: jraph.GraphsTuple) -> jnp.array:
    """Soft routing for one (src, dst) query under link weights ``w``.

    Builds the query's node flags (column 0 at ``src``, column 1 at ``dst``),
    replaces the template's edge features with ``w`` (presumably the same shape
    as ``graph_template.edges``), runs the GNN, and returns the sigmoid of the
    final step's edge logits as a flat per-edge vector.
    """
    query_flags = (
        jnp.zeros_like(graph_template.nodes)
        .at[src, 0].set(1)
        .at[dst, 1].set(1)
    )
    query_graph = graph_template._replace(nodes=query_flags, edges=w)
    step_outputs, _ = network.apply(params, state, query_graph)
    return jax.nn.sigmoid(step_outputs[-1].edges.flatten())

# Batch apply_for_pair over (src, dst) pairs; weights, params, state and template are shared.
apply_for_graph = jax.vmap(
    apply_for_pair,
    in_axes=(None, 0, 0, None, None, None),
)

temperature = 0.1


@jax.jit
def cost_fn(x: jnp.ndarray, temperature: float = temperature) -> jnp.ndarray:
    '''Smooth approximation to the max function.

    Computes ``temperature * dot(softmax(x / temperature), x / temperature)``,
    i.e. a softmax-weighted mean of ``x``. It never exceeds ``max(x)`` and
    approaches it as ``temperature`` shrinks.

    :param x: vector of values (here: per-link traffic).
    :param temperature: smoothing strength. It defaults to the value of the
        module-level ``temperature`` when this function is defined. Pass it
        explicitly to change it: a global read inside ``jax.jit`` is frozen at
        first trace.
    :return: scalar smooth maximum.
    '''
    scaled_x = x / temperature
    return temperature * jnp.dot(jax.nn.softmax(scaled_x), scaled_x)


def cost(w: jnp.array, src: jnp.array, dst: jnp.array, params: hk.Params, state: hk.State,
         graph_template: jraph.GraphsTuple, tm: jnp.array) -> jnp.array:
    '''Differentiable surrogate of the maximum link load under link weights ``w``.

    Each pair's soft routing (one row per (src, dst) pair, one column per edge)
    is weighted by its demand ``tm[src, dst]``. The rows are summed into
    per-link traffic, which is reduced with the smooth max ``cost_fn``.
    '''
    soft_routing = apply_for_graph(w, src, dst, params, state, graph_template)
    demands = tm[src, dst]
    return cost_fn(demands @ soft_routing)

def true_cost(H: nx.DiGraph, tm: np.array, src: np.array, dst: np.array, x: jraph.GraphsTuple, w: np.array) -> float:
    '''
    Compute the true cost without any soft approximation.

    Every (src, dst) demand is routed on its Dijkstra shortest path under link
    weights ``w``. The result is the largest resulting link load.

    Side effect: overwrites the ``'weight'`` attribute of the edges of ``H``.

    :param H: topology; assumed to iterate ``H.edges`` in the same order as the
        edges of ``x`` (true when H is a copy of the graph x was built from)
    :param tm: traffic matrix; the demand of pair (s, d) is ``tm[s, d]``
    :param src: sender indices
    :param dst: receiver indices
    :param x: template graph; ``x.senders``/``x.receivers`` name the edge each
        weight belongs to
    :param w: weights, one per edge (e.g. shape ``(num_edges, 1)``)
    :return: max load
    '''
    # Convert to host NumPy once; element-wise iteration over JAX arrays is slow.
    senders = np.asarray(x.senders)
    receivers = np.asarray(x.receivers)
    weights = np.asarray(w).reshape(-1)
    src_idx = np.asarray(src)
    dst_idx = np.asarray(dst)

    edge_attributes = {(int(s), int(r)): float(aw) for s, r, aw in zip(senders, receivers, weights)}
    nx.set_edge_attributes(H, edge_attributes, 'weight')
    all_pairs = dict(nx.all_pairs_dijkstra_path(H, weight='weight'))

    # Row p marks (with 1) the edges on pair p's shortest path; columns follow H.edges.
    edge_index = {edge: i for i, edge in enumerate(H.edges)}
    hard_routing = np.zeros((src_idx.shape[0], weights.shape[0]))
    for row, (s, d) in enumerate(zip(src_idx, dst_idx)):
        path = [edge_index[e] for e in pairwise(all_pairs[int(s)][int(d)])]
        hard_routing[row, path] = 1

    flat_tm = tm[src_idx, dst_idx]
    link_traffic = flat_tm @ hard_routing
    return float(np.max(link_traffic))

@jax.jit
def update(w: jnp.array, src: jnp.array, dst: jnp.array, params: hk.Params, state: hk.State,
         graph_template: jraph.GraphsTuple, tm: jnp.array,lr: chex.Array) -> jnp.array:
    '''One projected gradient-descent step on the link weights.

    Steps ``w`` by ``lr`` against the gradient of ``cost``. Any weight that would
    become negative is reset to ``3 * lr``, so the weights stay non-negative.
    '''
    gradient = jax.grad(cost)(w, src, dst, params, state, graph_template, tm)
    stepped = w - lr * gradient
    return jnp.where(stepped < 0, 3 * lr, stepped)



The neural network builds a differentiable surrogate for the discrete Dijkstra algorithm.
Let's use the gradient of this approximation to minimize the traffic-engineering (TE) objective, here the maximum link load.

In [None]:
n = len(G)

# Every ordered (src, dst) pair with src != dst: upper-triangle pairs first, then lower-triangle.
srcu, dstu = np.triu_indices(n, k=1)
srcl, dstl = np.tril_indices(n, k=-1)
src = jnp.concatenate([srcu, srcl], axis=0)
dst = jnp.concatenate([dstu, dstl], axis=0)

# This controls congestion level
tmscale = 0.6

# Start from unit link weights and a random traffic matrix scaled by tmscale / (n - 1).
w = jnp.ones_like(x.edges)
tm = tmscale * np.random.uniform(0.0, 1.0, size=(n, n)) / (n - 1)

# Scratch copy of the topology: true_cost overwrites its edge weights.
H = G.copy()

initialcost = true_cost(H, tm, src, dst, x, w)
result = {'initialcost': initialcost}

lr = 2e-2

# A few projected gradient steps; after each, record the exact (Dijkstra-routed) max load.
for i in range(3):
    w = update(w, src, dst, params, state, x, tm, lr)
    result[f'finalcost_{i}'] = true_cost(H, tm, src, dst, x, w)

w = w.block_until_ready()

result

Notice how the ***true*** maximum load is reduced.