# Branching with Imitation Learning and a GNN

In this tutorial, we reproduce the learning-to-branch approach of Gasse et al. (2019) using Ecole and TensorFlow 2. We collect strong branching examples on randomly generated Combinatorial Auctions instances, then train a graph neural network on bipartite state encodings to imitate the expert by classification. Finally, we evaluate the quality of the learned policy.

To avoid burdening the code too much, some simplifications were made:

## 1. Data collection

Our first step will be to run explore-then-strong-branch on randomly generated Combinatorial Auctions instances, and save the branching decisions to build a dataset. We will also record the state of the branch-and-bound process as a bipartite graph, which is already implemented in Ecole with the same features as Gasse et al. (2019).
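
To make the bipartite encoding concrete, here is a toy example built by hand with NumPy. It mimics the `(row_features, (edge indices, edge values), column_features)` layout that the collection loop below extracts from Ecole's observation: one node per constraint, one node per variable, and one edge per nonzero coefficient of the constraint matrix. The feature values are placeholders of our own choosing; the features Ecole actually computes are richer.

```python
import numpy as np

# Toy MILP with 2 constraints (rows) and 3 variables (columns):
#   c0:  1*x0 + 2*x1        <= 3
#   c1:         1*x1 + 1*x2 <= 1
A = np.array([[1.0, 2.0, 0.0],
              [0.0, 1.0, 1.0]])

# One edge per nonzero coefficient: row 0 holds constraint indices,
# row 1 holds variable indices, i.e. an array of shape (2, n_edges).
rows, cols = np.nonzero(A)
edge_indices = np.stack([rows, cols])
edge_values = A[rows, cols]

# Placeholder node features (hypothetical): right-hand side per constraint,
# objective coefficient per variable.
row_features = np.array([[3.0], [1.0]])
column_features = np.array([[1.0], [-1.0], [0.5]])

observation = (row_features, (edge_indices, edge_values), column_features)
print(edge_indices.tolist())  # [[0, 0, 1, 1], [0, 1, 1, 2]]
print(edge_values.tolist())   # [1.0, 2.0, 1.0, 1.0]
```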

```python
import gzip
import pickle
from pathlib import Path

import numpy as np
import ecole

from utilities import InstanceGenerator, generate_cauctions  # local helper module, not part of Ecole
```

We will generate Combinatorial Auctions instances on the fly.

```python
instances = InstanceGenerator(generate_cauctions, n_items=100, n_bids=100, add_item_prob=0.7)
```

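The explore-then-strong-branch scheme described in the paper is not implemented in Ecole by default, but we can easily write it in Python as a custom observation function, which showcases the flexibility of Ecole. At each node, it returns strong branching scores with a small probability and cheap pseudocost scores otherwise, together with a flag telling which of the two was used.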

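```python
class ExploreThenStrongBranch:
    """Observation function returning strong branching scores with probability
    `expert_probability`, and pseudocost scores otherwise. The second element
    of the returned tuple is True when the scores come from the expert."""

    def __init__(self, expert_probability):
        self.expert_probability = expert_probability
        self.pseudocosts_function = ecole.observation.Pseudocosts()
        self.strong_branching_function = ecole.observation.StrongBranchingScores()

    def reset(self, model):
        # Called when the environment is reset at the start of an episode
        self.pseudocosts_function.reset(model)
        self.strong_branching_function.reset(model)

    def obtain_observation(self, model):
        # Flip a biased coin to decide whether to query the (expensive) expert
        probabilities = [1 - self.expert_probability, self.expert_probability]
        expert_chosen = bool(np.random.choice(np.arange(2), p=probabilities))
        if expert_chosen:
            return (self.strong_branching_function.obtain_observation(model), True)
        else:
            return (self.pseudocosts_function.obtain_observation(model), False)
```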
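
The coin flip in `obtain_observation` is a Bernoulli draw. The standalone sketch below isolates that logic from Ecole: `explore_then_expert` is a hypothetical helper, and plain callables stand in for the two observation functions. It draws from a seeded `numpy.random.Generator` so the choice is reproducible. Note that `env.seed(0)` in the collection loop seeds the environment, which, as far as we know, does not affect NumPy's global generator used by `np.random.choice`.

```python
import numpy as np

def explore_then_expert(expert_fn, cheap_fn, expert_probability, rng):
    """Call the expensive expert with probability `expert_probability`,
    otherwise the cheap rule. Returns (scores, is_expert)."""
    if rng.random() < expert_probability:
        return expert_fn(), True
    return cheap_fn(), False

rng = np.random.default_rng(0)
n_calls = 10_000
# Count how often the expert is queried with a 5% probability
n_expert = sum(
    explore_then_expert(lambda: "strong", lambda: "pseudo", 0.05, rng)[1]
    for _ in range(n_calls)
)
print(n_expert / n_calls)  # close to 0.05
```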

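We can now create the environment with the following settings: no cutting plane separation after the root node (`separating/maxrounds` set to 0), no restarts, a 1h wall-clock time limit, and a 5% probability of querying the strong branching expert.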

```python
# SCIP parameters: no separation rounds after the root node, no restarts,
# and a 1h time limit measured in wall-clock time (clock type 2)
scip_parameters = {'separating/maxrounds': 0, 'presolving/maxrestarts': 0,
                   'limits/time': 3600, 'timing/clocktype': 2}
observation_function = ecole.observation.TupleFunction(
    ExploreThenStrongBranch(expert_probability=0.05),
    ecole.observation.NodeBipartite())
env = ecole.environment.Branching(observation_function=observation_function,
                                  scip_params=scip_parameters)
```

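We can then loop over the instances, following the strong branching expert 5% of the time (and the pseudocost rule otherwise) and saving the expert's decisions, until 100 samples are collected. This small number keeps the tutorial quick; it should be raised considerably for a meaningful training run.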

```python
MAX_SAMPLES = 100  # kept small for this tutorial

episode_counter = 0
sample_counter = 0
max_samples_reached = False
Path('samples/').mkdir(exist_ok=True)
env.seed(0)

while not max_samples_reached:
    episode_counter += 1
    with next(instances) as instance:
        observation, action_set, _, done = env.reset(instance.name)

    while not done:
        scores, node_observation = observation
        scores, scores_are_expert = scores
        # Pack the bipartite graph as
        # (constraint features, (edge indices, edge values), variable features)
        node_observation = (node_observation.row_features,
                            (node_observation.edge_features.indices,
                             node_observation.edge_features.values),
                            node_observation.column_features)

        # Branch on the candidate variable with the highest score
        action = action_set[scores[action_set].argmax()]

        # Only decisions made by the strong branching expert are saved
        if scores_are_expert and not max_samples_reached:
            sample_counter += 1
            data = [node_observation, action, action_set, scores]
            filename = f'samples/sample_{sample_counter}.pkl'

            with gzip.open(filename, 'wb') as f:
                pickle.dump(data, f)

            if sample_counter == MAX_SAMPLES:
                max_samples_reached = True

        observation, action_set, _, done, _ = env.step(action)

    if episode_counter % 25 == 0:
        print(f"Episode {episode_counter}, {sample_counter} samples collected so far")
```

Episode 25, 17 samples collected so far
Episode 50, 39 samples collected so far
Episode 75, 52 samples collected so far
Episode 100, 76 samples collected so far
Episode 125, 95 samples collected so far
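
The action selection in the loop, `action_set[scores[action_set].argmax()]`, takes the argmax over the candidates' scores only, then maps the resulting position back to a variable index. A toy NumPy illustration, assuming for the sake of the example that non-candidate entries of the score vector hold NaN:

```python
import numpy as np

# Scores exist for every variable, but only candidates are meaningful here
scores = np.array([np.nan, 3.0, np.nan, 5.0, 1.0])
action_set = np.array([1, 3, 4])  # indices of the branching candidates

# argmax over the candidates' scores gives a position inside action_set,
# which is then mapped back to a variable index
action = action_set[scores[action_set].argmax()]
print(action)  # 3

# Taking the argmax over all scores would be wrong: NumPy returns the
# position of the first NaN, which is not a candidate
naive = scores.argmax()
print(naive)  # 0
```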


## 2. Train a GNN

Our next step is to train a GNN classifier on the collected samples, so that it learns to imitate the choices of the strong branching expert.
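
Imitating the expert is a classification problem: for each sample, a softmax over the logits of the candidate variables is compared, via cross-entropy, with the index of the candidate the expert chose. Since samples have different numbers of candidates, batched logits must be padded. The NumPy sketch below illustrates this, assuming padding with a very negative value so padded entries get zero probability; `pad_logits`, `cross_entropy` and `PAD_VALUE` are our own illustrative names, not the `gnn` module's API.

```python
import numpy as np

PAD_VALUE = -1e8  # hypothetical padding value for missing candidates

def pad_logits(flat_logits, n_cands):
    """Split a flat vector of candidate logits into one row per sample,
    padding shorter rows with PAD_VALUE."""
    out = np.full((len(n_cands), max(n_cands)), PAD_VALUE)
    start = 0
    for i, n in enumerate(n_cands):
        out[i, :n] = flat_logits[start:start + n]
        start += n
    return out

def cross_entropy(logits, labels):
    """Mean softmax cross-entropy, stabilised by subtracting the row max."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels].mean()

# Two samples with 3 and 2 candidates; labels index the expert's choice
# within each sample's own candidate list
flat = np.array([2.0, 0.5, -1.0, 1.0, 1.0])
logits = pad_logits(flat, [3, 2])
labels = np.array([0, 1])
print(cross_entropy(logits, labels))
```

For the second sample, the padded entry contributes nothing, so its loss is exactly that of two equal logits, log 2.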

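```python
import tensorflow as tf
from gnn import GCNPolicy as GNN  # local module defining the model, not part of Ecole

LEARNING_RATE = 0.001
MAX_EPOCHS = 10
PATIENCE = 10        # epochs without improvement before decreasing the learning rate
EARLY_STOPPING = 20  # epochs without improvement before stopping training
```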
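
These constants drive a simple plateau schedule in the training loop: after every `PATIENCE` epochs without a lower validation loss the learning rate is multiplied by 0.2, and after `EARLY_STOPPING` such epochs training stops. The pure-Python sketch below replays that logic on a made-up sequence of validation losses (`plateau_schedule` is a hypothetical helper). With `MAX_EPOCHS = 10`, early stopping can never trigger in this short run.

```python
def plateau_schedule(valid_losses, lr, patience, early_stopping):
    """Replay the plateau logic on a list of validation losses.
    Returns the learning rate used at each epoch and the epoch at which
    training stops (None if it never stops early)."""
    best_loss = float("inf")
    plateau_count = 0
    lrs = []
    for epoch, loss in enumerate(valid_losses):
        lrs.append(lr)  # learning rate used during this epoch
        if loss < best_loss:
            best_loss = loss
            plateau_count = 0
        else:
            plateau_count += 1
            if plateau_count % early_stopping == 0:
                return lrs, epoch
            if plateau_count % patience == 0:
                lr *= 0.2  # takes effect from the next epoch
    return lrs, None

# Improvement at epochs 0 and 1, then a long plateau
lrs, stop = plateau_schedule([3.0, 2.9] + [3.0] * 25,
                             lr=0.001, patience=10, early_stopping=20)
print(stop)  # 21
```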

We will first define a helper function that can batch a set of samples to feed to the neural net.

```python
def load_batch_gcnn(sample_files):
    """
    Loads and concatenates a bunch of samples into one mini-batch.
    """
    c_features, e_indices, e_features, v_features = [], [], [], []
    candss, cand_choices, cand_scoress = [], [], []

    # load samples
    for filename in sample_files:
        # inside tf.py_function, each filename is a scalar string tensor
        with gzip.open(filename.numpy(), 'rb') as f:
            sample = pickle.load(f)

        sample_observation, sample_action, sample_action_set, sample_scores = sample

        sample_action_set = np.array(sample_action_set)
        cand_choice = np.where(sample_action_set == sample_action)[0][0]  # action index relative to candidates
        cand_scores = sample_scores[sample_action_set]

        c, (ei, ev), v = sample_observation
        c_features.append(c)
        e_indices.append(ei)
        e_features.append(np.expand_dims(ev, -1))  # edge features of shape (n_edges, 1)
        v_features.append(v)
        candss.append(sample_action_set)
        cand_choices.append(cand_choice)
        cand_scoress.append(cand_scores)

    n_cs_per_sample = [c.shape[0] for c in c_features]
    n_vs_per_sample = [v.shape[0] for v in v_features]
    n_cands_per_sample = [cds.shape[0] for cds in candss]

    # concatenate samples in one big graph
    c_features = np.concatenate(c_features, axis=0)
    v_features = np.concatenate(v_features, axis=0)
    e_features = np.concatenate(e_features, axis=0)
    # edge indices have to be shifted by the number of constraints (row 0)
    # and variables (row 1) in all previous samples
    cv_shift = np.cumsum([[0] + n_cs_per_sample[:-1],
                          [0] + n_vs_per_sample[:-1]], axis=1)
    e_indices = np.concatenate([e_ind + cv_shift[:, j:(j+1)]
        for j, e_ind in enumerate(e_indices)], axis=1)
    # candidate indices are variable indices, so they get the variable shift
    candss = np.concatenate([cands + shift
        for cands, shift in zip(candss, cv_shift[1])])
    cand_choices = np.array(cand_choices)
    cand_scoress = np.concatenate(cand_scoress, axis=0)

    # convert to tensors
    c_features = tf.convert_to_tensor(c_features, dtype=tf.float32)
    e_indices = tf.convert_to_tensor(e_indices, dtype=tf.int32)
    e_features = tf.convert_to_tensor(e_features, dtype=tf.float32)
    v_features = tf.convert_to_tensor(v_features, dtype=tf.float32)
    n_cs_per_sample = tf.convert_to_tensor(n_cs_per_sample, dtype=tf.int32)
    n_vs_per_sample = tf.convert_to_tensor(n_vs_per_sample, dtype=tf.int32)
    candss = tf.convert_to_tensor(candss, dtype=tf.int32)
    cand_choices = tf.convert_to_tensor(cand_choices, dtype=tf.int32)
    cand_scoress = tf.convert_to_tensor(cand_scoress, dtype=tf.float32)
    n_cands_per_sample = tf.convert_to_tensor(n_cands_per_sample, dtype=tf.int32)

    return c_features, e_indices, e_features, v_features, n_cs_per_sample, n_vs_per_sample, \
            n_cands_per_sample, candss, cand_choices, cand_scoress


def load_batch_tf(x):
    # Wrap the NumPy loader so it can run inside a tf.data pipeline;
    # the output dtypes follow the order of the values returned above
    return tf.py_function(load_batch_gcnn, [x], [tf.float32, tf.int32, tf.float32, tf.float32,
                                tf.int32, tf.int32, tf.int32, tf.int32, tf.int32, tf.float32])
```
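
The trickiest part of `load_batch_gcnn` is merging several graphs into one: each sample's constraint and variable indices must be offset by the number of constraints and variables in the samples before it, and candidate indices (which are variable indices) get the variable offset. The same steps on two hand-made toy samples:

```python
import numpy as np

# Two toy samples: (n_constraints, n_variables, edge_indices, candidates),
# where edge_indices row 0 = constraint index, row 1 = variable index
samples = [
    (2, 3, np.array([[0, 0, 1], [0, 1, 2]]), np.array([0, 2])),
    (1, 2, np.array([[0, 0], [0, 1]]), np.array([1])),
]
n_cs = [s[0] for s in samples]
n_vs = [s[1] for s in samples]

# Offsets: each sample's nodes start after all nodes of the previous samples
cv_shift = np.cumsum([[0] + n_cs[:-1], [0] + n_vs[:-1]], axis=1)

edge_indices = np.concatenate(
    [s[2] + cv_shift[:, j:j + 1] for j, s in enumerate(samples)], axis=1)
candidates = np.concatenate(
    [s[3] + shift for s, shift in zip(samples, cv_shift[1])])

print(cv_shift.tolist())      # [[0, 2], [0, 3]]
print(edge_indices.tolist())  # [[0, 0, 1, 2, 2], [0, 1, 2, 3, 4]]
print(candidates.tolist())    # [0, 2, 4]
```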

We can then split the samples into training (70%), validation (15%) and test (15%) sets, and prepare the data loaders.

```python
sample_files = [str(path) for path in Path('samples/').glob('sample_*.pkl')]
train_files = sample_files[:int(0.7*len(sample_files))]
valid_files = sample_files[int(0.7*len(sample_files)):int(0.85*len(sample_files))]
test_files = sample_files[int(0.85*len(sample_files)):]

train_data = tf.data.Dataset.from_tensor_slices(train_files)
train_data = train_data.shuffle(len(train_files))  # reshuffle training samples at every epoch
train_data = train_data.batch(32)
train_data = train_data.map(load_batch_tf)
train_data = train_data.prefetch(1)

valid_data = tf.data.Dataset.from_tensor_slices(valid_files)
valid_data = valid_data.batch(128)
valid_data = valid_data.map(load_batch_tf)
valid_data = valid_data.prefetch(1)
```
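
`Path.glob` yields files in an arbitrary, filesystem-dependent order, so the contiguous slices above are not guaranteed to be identical across machines. One option, sketched below with a hypothetical `split_files` helper, is to sort the paths and shuffle them with a fixed seed before applying the same 70/15/15 cut:

```python
import random

def split_files(files, train_end=0.7, valid_end=0.85, seed=0):
    """Sort, shuffle with a fixed seed, then cut into train/valid/test
    using the same fractions as the slicing above."""
    files = sorted(files)  # remove the dependence on glob order
    random.Random(seed).shuffle(files)
    n = len(files)
    return (files[:int(train_end * n)],
            files[int(train_end * n):int(valid_end * n)],
            files[int(valid_end * n):])

files = [f"samples/sample_{i}.pkl" for i in range(1, 101)]
train, valid, test = split_files(files)
print(len(train), len(valid), len(test))  # 70 15 15
```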

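Next, we will define a helper function that trains or evaluates the model for one epoch and computes, for monitoring, the loss and the top-k accuracy (k = 1, 3, 5, 10): the fraction of samples for which one of the k candidates with the highest logits has the best strong branching score.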

```python
def process(model, dataloader, optimizer=None):
    """Runs the model over one epoch of `dataloader`, training it if an
    optimizer is given and only evaluating it otherwise.
    Returns the mean loss and the mean top-k accuracies."""
    top_k = [1, 3, 5, 10]
    mean_loss = 0
    mean_kacc = np.zeros(len(top_k))

    @tf.function(input_signature=model.input_signature)
    def forward(inputs, training):
        return model.call(inputs, training)

    n_samples_processed = 0
    for batch in dataloader:
        c, ei, ev, v, n_cs, n_vs, n_cands, cands, best_cands, cand_scores = batch
        batched_states = (c, ei, ev, v, tf.reduce_sum(n_cs, keepdims=True), tf.reduce_sum(n_vs, keepdims=True))  # prevent padding
        batch_size = len(n_cs.numpy())

        if optimizer:
            with tf.GradientTape() as tape:
                logits = forward(batched_states, tf.convert_to_tensor(True))  # training mode
                logits = tf.expand_dims(tf.gather(tf.squeeze(logits, 0), cands), 0)  # filter candidate variables
                logits = model.pad_output(logits, n_cands.numpy())  # apply padding now
                loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=best_cands, logits=logits)
                loss = tf.reduce_mean(loss)
            grads = tape.gradient(target=loss, sources=model.trainable_variables)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))
        else:
            logits = forward(batched_states, tf.convert_to_tensor(False))  # eval mode
            logits = tf.expand_dims(tf.gather(tf.squeeze(logits, 0), cands), 0)  # filter candidate variables
            logits = model.pad_output(logits, n_cands.numpy())  # apply padding now
            loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=best_cands, logits=logits)
            loss = tf.reduce_mean(loss)

        # a prediction is correct if one of its top-k candidates has the best expert score
        true_scores = model.pad_output(tf.reshape(cand_scores, (1, -1)), n_cands)
        true_bestscore = tf.reduce_max(true_scores, axis=-1, keepdims=True)
        true_scores = true_scores.numpy()
        true_bestscore = true_bestscore.numpy()

        kacc = []
        for k in top_k:
            pred_top_k = tf.nn.top_k(logits, k=k)[1].numpy()
            pred_top_k_true_scores = np.take_along_axis(true_scores, pred_top_k, axis=1)
            kacc.append(np.mean(np.any(pred_top_k_true_scores == true_bestscore, axis=1)))
        kacc = np.asarray(kacc)

        # accumulate batch-size-weighted sums, averaged after the loop
        mean_loss += loss.numpy() * batch_size
        mean_kacc += kacc * batch_size
        n_samples_processed += batch_size

    mean_loss /= n_samples_processed
    mean_kacc /= n_samples_processed

    return mean_loss, mean_kacc
```
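
The top-k accuracy in `process` counts a sample as correct when any of the k candidates with the highest logits has the best strong branching score, so candidates tied with the expert's best score also count. A NumPy version of the same metric on unpadded toy data, with `np.argsort` standing in for `tf.nn.top_k` (tie-breaking between equal logits may differ); `topk_accuracy` is a hypothetical helper:

```python
import numpy as np

def topk_accuracy(logits, true_scores, ks=(1, 3, 5, 10)):
    """Fraction of samples whose k highest-logit candidates include one
    with the best expert score, for each k in `ks`."""
    true_best = true_scores.max(axis=1, keepdims=True)
    order = np.argsort(-logits, axis=1)  # candidates by decreasing logit
    accs = []
    for k in ks:
        top_k = order[:, :k]
        top_k_scores = np.take_along_axis(true_scores, top_k, axis=1)
        accs.append(np.any(top_k_scores == true_best, axis=1).mean())
    return np.array(accs)

logits = np.array([[0.1, 2.0, 0.5],
                   [3.0, 1.0, 2.0]])
true_scores = np.array([[5.0, 1.0, 5.0],   # best score shared by candidates 0 and 2
                        [0.0, 9.0, 1.0]])
print(topk_accuracy(logits, true_scores, ks=(1, 2, 3)))  # [0.  0.5 1. ]
```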

Finally, we can actually create the model and train it.

```python
model = GNN()

lr = LEARNING_RATE
optimizer = tf.keras.optimizers.Adam(learning_rate=lambda: lr)  # callable, so updates to `lr` take effect
best_loss = np.inf
plateau_count = 0
for epoch in range(MAX_EPOCHS):
    train_loss, train_kacc = process(model, train_data, optimizer)
    print(f"TRAIN LOSS: {train_loss:0.3f} " + "".join([f" acc@{k}: {acc:0.3f}" for k, acc in zip([1, 3, 5, 10], train_kacc)]))

    valid_loss, valid_kacc = process(model, valid_data, None)
    print(f"VALID LOSS: {valid_loss:0.3f} " + "".join([f" acc@{k}: {acc:0.3f}" for k, acc in zip([1, 3, 5, 10], valid_kacc)]))

    if valid_loss < best_loss:
        plateau_count = 0
        best_loss = valid_loss
        model.save_state('trained_params.pkl')
        print("  best model so far")
    else:
        plateau_count += 1
        if plateau_count % EARLY_STOPPING == 0:
            print(f"  {plateau_count} epochs without improvement, early stopping")
            break
        if plateau_count % PATIENCE == 0:
            lr *= 0.2
            print(f"  {plateau_count} epochs without improvement, decreasing learning rate to {lr}")

# reload the parameters that achieved the best validation loss
model.restore_state('trained_params.pkl')
valid_loss, valid_kacc = process(model, valid_data, None)
print(f"BEST VALID LOSS: {valid_loss:0.3f} " + "".join([f" acc@{k}: {acc:0.3f}" for k, acc in zip([1, 3, 5, 10], valid_kacc)]))
```

TRAIN LOSS: 3.613  acc@1: 0.243 acc@3: 0.471 acc@5: 0.557 acc@10: 0.714
VALID LOSS: 3.491  acc@1: 0.133 acc@3: 0.333 acc@5: 0.467 acc@10: 0.667
  best model so far
TRAIN LOSS: 3.532  acc@1: 0.229 acc@3: 0.357 acc@5: 0.486 acc@10: 0.743
VALID LOSS: 3.453  acc@1: 0.133 acc@3: 0.333 acc@5: 0.400 acc@10: 0.667
  best model so far
TRAIN LOSS: 3.442  acc@1: 0.214 acc@3: 0.329 acc@5: 0.514 acc@10: 0.743
VALID LOSS: 3.403  acc@1: 0.133 acc@3: 0.333 acc@5: 0.400 acc@10: 0.667
  best model so far
TRAIN LOSS: 3.329  acc@1: 0.214 acc@3: 0.357 acc@5: 0.529 acc@10: 0.743
VALID LOSS: 3.352  acc@1: 0.200 acc@3: 0.333 acc@5: 0.400 acc@10: 0.667
  best model so far
TRAIN LOSS: 3.202  acc@1: 0.243 acc@3: 0.386 acc@5: 0.571 acc@10: 0.757
VALID LOSS: 3.266  acc@1: 0.267 acc@3: 0.400 acc@5: 0.467 acc@10: 0.667
  best model so far
TRAIN LOSS: 3.082  acc@1: 0.271 acc@3: 0.471 acc@5: 0.586 acc@10: 0.800
VALID LOSS: 3.192  acc@1: 0.400 acc@3: 0.600 acc@5: 0.600 acc@10: 0.733
  best model so far
TRAIN LOSS: 2.99