In [1]:
%cd /home/jeroen/repos/traffic-scheduling/network/
from generate_network import generate_simple_instance
from automaton import Automaton

/home/jeroen/repos/traffic-scheduling/network


### Obtain expert demonstration

In [2]:
from exact import solve
instance = generate_simple_instance()
y, obj = solve(instance)

We solve the instance to optimality with an exact method. Next, we use the resulting crossing time schedule to compute actions for the automaton that lead to the same schedule. However, this sequence of actions is not unique: the order in which intersections are considered does not matter for the final schedule. Therefore, we sample some intersection order and replay the sequence of actions on the automaton to generate the corresponding sequence of state-action pairs. At this point, we **copy the whole disjunctive graph** for each state. Alternatively, we could use some sort of masking for non-final states.

In [3]:
from random import choice
from util import vehicle_indices

def collect_state_action_pairs(instance, schedule):
    """Collect states and actions leading to the given schedule.
    `schedule` is a dict mapping (route, order, node) tuples to crossing times.
    """
    automaton = Automaton(instance)

    # Compute the order in which vehicles pass each intersection
    # based on the order of the route to which the vehicles belong.
    route_order = {}
    indices = schedule.keys() # (r, k, v) tuples
    for v in instance['G'].intersections:
        # note the minus sign: reverse sorted order, so last route first
        route_order[v] = sorted(filter(lambda x: x[2] == v, indices), key=lambda i: -schedule[i])
        route_order[v] = list(map(lambda x: x[0], route_order[v])) # take route index
    
    actions = []
    states = [automaton.D.copy()] # initial state is empty disjunctive graph

    # keep track of which intersections still have unscheduled vehicles
    pending_intersections = list(route_order.keys())
    while len(pending_intersections) > 0:
        v = choice(pending_intersections)
        # we can now pop from back because route_order[v] was reverse sorted
        r = route_order[v].pop() 
        automaton.step(r, v)
    
        # record action:
        # Instead of (r,v) pairs, which are used in the automaton, we
        # use the full (r,k,v) tuple, like in Zhang et al., where they
        # use the full operation a_t = O_{ij} as action.
        k = automaton.last_scheduled[r, v]
        actions.append((r, k, v))
        # record state by copying current disjunctive graph
        states.append(automaton.D.copy())
    
        # remove intersection if done
        if len(route_order[v]) == 0:
            pending_intersections.remove(v)

    return states, actions, automaton

Verify the reconstruction by replaying the given actions and checking whether we arrive at the same schedule again.

In [6]:
from networkx import get_node_attributes
import numpy as np

states, actions, automaton = collect_state_action_pairs(instance, y)
automaton = Automaton(instance)
for action in actions:
    r, k, v = action
    automaton.step(r, v)
LB = get_node_attributes(automaton.D, 'LB')
np.testing.assert_allclose(np.array(list(LB.values())), np.array(list(y.values())))

### Create training data

We solve a couple of instances and collect all the state-action pairs in a single dataset to support mini-batching via the `DataLoader` class.

In [7]:
import torch
from torch_geometric.loader import DataLoader
from torch_geometric.utils.convert import from_networkx
from util import vehicle_indices, route_indices

# force double during conversion
torch.set_default_dtype(torch.float64)

# number of valid actions is number of intersection-route pairs
instance = generate_simple_instance()
automaton = Automaton(instance)
# actions are all route-intersection pairs
valid_actions = list(automaton.D.nodes)
num_actions = len(valid_actions)

# TODO: generate single network
# TODO: generate different vehicle arrivals for this same network

def generate_data(N):
    """Generate set graphs based on solving N problem instances to optimality."""
    graphs = []
    for _ in range(N):
        instance = generate_simple_instance()
        y, obj = solve(instance)
    
        states, actions, _ = collect_state_action_pairs(instance, y)
        for state, action in zip(states, actions):        
            graph = from_networkx(state, group_node_attrs=['LB', 'done'])
    
            
            graph.action = valid_actions.index(action) # map to integers
            
            graphs.append(graph)
    return graphs

Obtain and inspect a single batch:

In [8]:
data = DataLoader(generate_data(10), batch_size=2, shuffle=True)
next(iter(data))

DataBatch(edge_index=[2, 93], label=[2], action_mask=[60], weight=[93], x=[60, 2], action=[2], batch=[60], ptr=[3])

What are the following attributes:

- label
- weight
- batch: index of graph to which this node belongs
- ptr

**Assumption**: node order (see `batch.label`) is the same among all state graphs.

### Imitation learning with GNN policy

We now have the following classification task: map disjunctive **graph** to an **action** (route-intersection pair). We use a GIN to compute an embedding for each node, which is fed through an MLP and softmax to produce a probability over nodes. In Zhang et al., each action corresponds to a unique node, encoding the operations that is dispatched next. However, we only really need to provide a route-intersection pair, but **how to exploit this in the policy model**?

In [47]:
import torch.nn.functional as F
from torch.nn import Sequential, Linear
from torch_geometric.nn import GCNConv, GINConv

class GNN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # we need a separate layer for the first iteration, because the initial
        # feature dimension in different from the node embedding dimension
        lin0 = Sequential(Linear(2, 32))
        self.gin0 = GINConv(lin0, train_eps=True)

        lin = Sequential(Linear(32, 32))
        self.gin = GINConv(lin, train_eps=True)
        
        self.lin1 = Linear(32, 32)
        self.lin2 = Linear(32, 1)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index

        x = self.gin0(x, edge_index)
        x = F.relu(x)
        for _ in range(1): # the rest of the K-1 iterations
            x = self.gin(x, edge_index)

        x = self.lin1(x)
        x = F.relu(x)
        x = self.lin2(x)

        return x.squeeze()

model = GNN()

The GNN computes node embeddings, which are mapped to a score for each node. We compute the softmax over the scores of the nodes and then compute the negative log likelihood loss for backpropagation.

In [54]:
from torch.nn import CrossEntropyLoss
from torch_geometric.utils import softmax
import torch.nn.functional as F
import torch.optim as optim

def stacked_batch(batch, indices):
    """Transform flat batch of dimension (N*B), where N is number of graphs
    in batch and B is number of nodes in graph, to a stacked batch of
    dimension (N, B), based on batch indices `batch.batch`."""
    unique = torch.unique(indices)
    return torch.vstack([batch[indices == i] for i in unique])

N = 500 # number of instances
data_train = DataLoader(generate_data(N), batch_size=10, shuffle=True)
learning_rate = 1e-3
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
epochs = 15
    
print("\ntraining model\n")
model.train()
for i in range(epochs):
    loss_total = 0
    print(f'epoch: {i}')
    for batch in data_train:
        optimizer.zero_grad()

        # compute node scores
        y = model(batch)
        # softmax over node scores per graph, using batch indices
        y = softmax(y, batch.batch)
        # restack for loss calculation
        pred = stacked_batch(y, batch.batch)

        target = batch.action
        loss = F.cross_entropy(pred, target)
        loss.backward()
        loss_total += loss
        optimizer.step()
    print(f"loss: {loss_total.item()}")


training model

epoch: 0
loss: 1873.8948934434998
epoch: 1
loss: 1874.3125676702848
epoch: 2
loss: 1869.5252868922591
epoch: 3
loss: 1863.432925770807
epoch: 4
loss: 1859.0543117467428
epoch: 5
loss: 1862.4085723285118
epoch: 6
loss: 1855.2669235564033
epoch: 7
loss: 1854.8674736717348
epoch: 8
loss: 1854.5387522919525
epoch: 9
loss: 1853.0164628663333
epoch: 10
loss: 1859.0068659607434
epoch: 11
loss: 1853.9021937203777
epoch: 12
loss: 1853.012929535769
epoch: 13
loss: 1853.1915699958354
epoch: 14
loss: 1852.6229600303213


### Evaluate imitation

In [11]:
from torch import argmax

def evaluate_imitation(model, N=500):
    """Measure accuracy based on unseen expert demonstration state-action pairs."""
    print("\nevaluating imitation accuracy\n")
    model.eval()
    data_test = DataLoader(generate_data(N))
    total_correct = 0
    for batch in data_test:
        # compute node scores
        y = model(batch)
        # softmax over node scores per graph, using batch indices
        y = softmax(y, batch.batch)
        # restack for loss calculation
        pred = stacked_batch(y, batch.batch)
        pred = argmax(pred, dim=1)
        target = batch.action
        total_correct += pred == target
    print(f"accuracy: {total_correct.item() / len(data_test)}")

In [49]:
evaluate_imitation(model)


evaluating imitation accuracy

accuracy: 0.3215


### Evaluate scheduling

Current definition of objective in `exact.py` is total sum of crossing times, including at exit points.

In [66]:
from torch import masked_select

def evaluate_scheduling(model, N=100):
    """Evaluate average objective when executing the policy over full
    unseen problem instances compared to average optimal objective."""
    print("\nevaluating policy\n")
    model.eval()
    obj_opt = 0
    obj_model = 0
    for _ in range(N):
        instance = generate_simple_instance()

        # solve optimally
        y, obj = solve(instance)
        obj_opt += obj

        # TEST objective definition
        # _, actions, _ = collect_state_action_pairs(instance, y)

        # execute learned heuristic
        automaton = Automaton(instance)
        while not automaton.done:
            state = automaton.D
            graph = from_networkx(state, group_node_attrs=['LB', 'done'])
            # compute node scores
            y = model(graph)
            # mask valid actions (set to -inf)
            y = y.masked_fill(~graph.action_mask.bool(), -torch.inf)
            y = argmax(y)
            action = valid_actions[y]
            r, k, v = action

            # transform to valid action
            automaton.step(r, v)

        # TEST objective definition
        # for action in actions:
        #     r, k, v = action
        #     automaton.step(r, v)

        # compute obj from automaton
        obj_model += automaton.get_obj()

    print(f"obj_opt={obj_opt / N} vs obj_model={obj_model / N}")

In [67]:
evaluate_scheduling(model)


evaluating policy

obj_opt=312.57050607140025 vs obj_model=322.8857129350818
