In [1]:
# Change into the project's `network` directory, presumably so that the local
# modules imported below (generate_network, automaton, exact, util) resolve.
# NOTE(review): absolute, machine-specific path — the notebook only runs on this
# machine as written; consider a path relative to the notebook or a config variable.
%cd /home/jeroen/repos/traffic-scheduling/network/
from generate_network import generate_simple_instance
from automaton import Automaton

/home/jeroen/repos/traffic-scheduling/network


In [2]:
# Seed the global RNGs so that the generated instance and the random
# intersection order sampled later (via `random.choice`) are reproducible.
# NOTE(review): assumes generate_simple_instance draws from `random` and/or
# `numpy.random` — TODO confirm which RNG it actually uses.
import random
import numpy as np

SEED = 0
random.seed(SEED)
np.random.seed(SEED)

instance = generate_simple_instance()

### Obtain expert demonstration

In [3]:
from exact import solve
# Solve the instance to optimality with the exact method. `y` is the crossing-time
# schedule (passed below as a dict keyed by (route, order, node) tuples);
# `obj` is presumably the objective value and is not used further here.
y, obj = solve(instance)

We solve the instance to optimality with an exact method. Next, we use the resulting crossing-time schedule to compute a sequence of automaton actions that reproduces the same schedule. This sequence is not unique, however: the order in which intersections are considered does not affect the final schedule. We therefore sample a random intersection order and replay the corresponding actions on the automaton to generate the matching sequence of state-action pairs. For now, we simply **copy the whole disjunctive graph** for each state, which is straightforward but stores a full graph per step.

In [4]:
from random import choice
from util import vehicle_indices

def collect_state_action_pairs(instance, schedule, rng=None):
    """Collect states and actions leading to the given schedule.

    At every intersection, vehicles are scheduled in increasing order of their
    crossing time. The intersection to act on next is drawn at random from
    those that still have unscheduled vehicles.

    Parameters
    ----------
    instance : dict
        Problem instance; ``instance['G'].intersections`` lists the
        intersections and ``Automaton(instance)`` must accept it.
    schedule : dict
        Maps (route, order, node) tuples to crossing times. Entries whose
        node is not one of the intersections are ignored.
    rng : random.Random, optional
        Source of randomness for the intersection order. When omitted, the
        global ``random`` state is used (via ``random.choice``); pass e.g.
        ``random.Random(seed)`` for reproducible demonstrations.

    Returns
    -------
    states : list
        Copies of ``automaton.D`` taken before the first action and after each
        action, so ``len(states) == len(actions) + 1``.
    actions : list of (route, intersection) tuples
        The actions passed to ``automaton.step``, in order.
    automaton : Automaton
        The automaton after all actions have been applied.
    """
    pick = choice if rng is None else rng.choice
    automaton = Automaton(instance)

    # Gather, per intersection, the (route, crossing time) of every vehicle
    # passing it, in a single pass over the schedule.
    crossings = {v: [] for v in instance['G'].intersections}
    for key, time in schedule.items():
        node = key[2]
        if node in crossings:
            crossings[node].append((key[0], time))

    # Sort by *decreasing* crossing time so that popping from the back of each
    # list yields the route of the next vehicle to cross that intersection.
    route_order = {
        v: [route for route, _time in sorted(passes, key=lambda p: -p[1])]
        for v, passes in crossings.items()
    }

    actions = []
    states = [automaton.D.copy()]  # initial state, before any action is taken

    # Intersections without any vehicles need no actions; keeping them pending
    # would eventually make `route_order[v].pop()` fail on an empty list.
    pending_intersections = [v for v, routes in route_order.items() if routes]
    while pending_intersections:
        v = pick(pending_intersections)
        r = route_order[v].pop()
        automaton.step(r, v)

        # record state-action pair
        actions.append((r, v))
        states.append(automaton.D.copy())

        # drop the intersection once all its vehicles have been scheduled
        if not route_order[v]:
            pending_intersections.remove(v)

    return states, actions, automaton

Verify the reconstruction by replaying the collected actions on a fresh automaton and checking that we arrive at the same schedule again.

In [5]:
from networkx import get_node_attributes
import numpy as np

states, actions, automaton = collect_state_action_pairs(instance, y)

# Replay the recorded actions on a fresh automaton (deliberately replacing the
# one returned above) so the check depends on the action sequence alone.
automaton = Automaton(instance)
for route, intersection in actions:
    automaton.step(route, intersection)

# Compare lower bounds and exact crossing times key by key, instead of relying
# on both dicts happening to iterate in the same order.
# NOTE(review): assumes the nodes of `automaton.D` use the same
# (route, order, node) keys as `y` — a missing key raises KeyError, which also
# signals a mismatch.
LB = get_node_attributes(automaton.D, 'LB')
assert len(LB) == len(y), "replayed graph and schedule have different sizes"
np.testing.assert_allclose([LB[key] for key in y], list(y.values()))

### Imitation learning with GNN policy

In [36]:
from torch_geometric.utils.convert import from_networkx
import torch  # imported here too: on a fresh kernel torch is otherwise only imported in a later cell

# Force double precision during conversion. Note that this sets the *global*
# default dtype, so tensors and parameters created afterwards (including the
# GNN below) are float64 as well.
torch.set_default_dtype(torch.float64)

# One PyG graph per recorded state; the 'LB' and 'done' node attributes are
# grouped into the node feature matrix `x`.
graphs = [from_networkx(state, group_node_attrs=['LB', 'done']) for state in states]

# Number of distinct (route, intersection) actions in the demonstration.
num_actions = len(set(actions))

We now have the following graph classification task: map each state graph to the action taken in that state, which is a (route, intersection) pair.

In [85]:
import torch
import torch.nn.functional as F
from torch.nn import Linear
from torch_geometric.nn import GCNConv

class GNN(torch.nn.Module):
    """Two-layer GCN followed by a small MLP that scores every node of a graph.

    Each node receives a scalar score. The scores are normalised with a softmax
    over *all* nodes in ``data``, so the output is a probability distribution
    over the nodes.

    NOTE(review): because the softmax spans every node in ``data``, a batch of
    several graphs would have their nodes mixed into one distribution — use a
    per-graph softmax (keyed on ``data.batch``) before batching.
    NOTE(review): probabilities (not logits) are returned, so a training loss
    must take their log explicitly.
    """

    def __init__(self, in_channels=2, hidden_dim=16, node_embedding_dim=16,
                 mlp_hidden_dim=16):
        """
        Parameters
        ----------
        in_channels : int
            Number of node features per node; the default 2 matches the
            ``group_node_attrs=['LB', 'done']`` conversion, presumably one
            feature per attribute.
        hidden_dim : int
            Output width of the first GCN layer.
        node_embedding_dim : int
            Size of the node embeddings produced by the second GCN layer.
        mlp_hidden_dim : int
            Hidden width of the two-layer scoring MLP.
        """
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, node_embedding_dim)

        self.lin1 = Linear(node_embedding_dim, mlp_hidden_dim)
        self.lin2 = Linear(mlp_hidden_dim, 1)

    def forward(self, data):
        """Return a softmax distribution over the nodes of ``data``.

        ``data`` must provide ``x`` (node features) and ``edge_index``. The
        result is a 1-D tensor with one entry per node.
        """
        x, edge_index = data.x, data.edge_index

        x = F.relu(self.conv1(x, edge_index))
        x = self.conv2(x, edge_index)

        scores = self.lin2(F.relu(self.lin1(x)))
        # Squeeze only the trailing size-1 score dimension: a plain squeeze()
        # would turn a single-node graph into a 0-d scalar.
        scores = scores.squeeze(-1)

        return F.softmax(scores, dim=0)

In [86]:
# Seed torch so the randomly initialised weights, and therefore the output
# shown below, are reproducible.
torch.manual_seed(0)
model = GNN()
# Sanity check: a probability for every node of the initial state graph.
model(graphs[0])

tensor([0.0307, 0.0294, 0.0329, 0.0298, 0.0305, 0.0365, 0.0297, 0.0336, 0.0382,
        0.0300, 0.0294, 0.0338, 0.0287, 0.0310, 0.0369, 0.0299, 0.0337, 0.0383,
        0.0298, 0.0295, 0.0339, 0.0379, 0.0288, 0.0314, 0.0373, 0.0410, 0.0304,
        0.0349, 0.0393, 0.0428], grad_fn=<SoftmaxBackward0>)