# Evaluating branchrules

In [None]:
import torch
import numpy as np

from matplotlib import pyplot as plt

numpy's `SeedSequence` has no method that spawns a single child, so we define a small `spawn_one` helper.

In [None]:
from numpy.random import default_rng, SeedSequence


def spawn_one(ss: SeedSequence) -> SeedSequence:
    """Spawn exactly one child of `ss` and return it.

    Equivalent to taking the only element of `ss.spawn(1)`; it advances the
    spawn counter of `ss` by one, so successive calls give distinct children.
    """
    (child,) = ss.spawn(1)
    return child

Helpers that append a dict record to a table stored as a dict of lists (`do_add`), and collate a list of records into a dict of numpy arrays (`collate`).

In [None]:
from typing import Any, Callable, Optional


def do_add(
    record: dict,
    to: dict[Any, list],
    transform: Optional[Callable[..., dict]] = None,
) -> dict:
    """Add the record to a transposed dict of lists.

    Parameters
    ----------
    record : dict
        The row to add, mapping field name to value.
    to : dict
        The column-oriented table, mapping field name to the list of values
        appended so far. It is updated in place; a field seen for the first
        time gets a new list.
    transform : callable, optional
        If callable, it is invoked as ``transform(**record)`` (so the keys of
        ``record`` must be strings). Its returned dict is appended instead of
        ``record``.

    Returns
    -------
    dict
        The original, untransformed ``record``.

    Notes
    -----
    Columns are not aligned. If records carry different fields, the lists in
    ``to`` end up with different lengths.
    """
    original = record
    if callable(transform):
        record = transform(**record)

    # assume no fields are missing: values are appended column by column, so
    # a record lacking a field silently misaligns that column
    for field, value in record.items():
        to.setdefault(field, []).append(value)

    return original


def collate(records: list[dict]) -> dict[..., list]:
    """Transpose a list of records into a dict of numpy arrays, one per field.

    Records are accumulated with `do_add`, which assumes that no fields are
    missing. An empty `records` yields an empty dict.
    """
    columns = {}
    for rec in records:
        do_add(rec, columns)

    # keys and values of the same dict iterate in matching order
    return dict(zip(columns, map(np.array, columns.values())))

Branchrules and wrappers

In [None]:
import ecole as ec

from toybnb.scip.ecole.il.data import Observation

from toybnb.scip.ecole.il.brancher import BranchRule, BranchRuleCallable

from toybnb.scip.ecole.il.brancher import strongbranch, randombranch
from toybnb.scip.ecole.il.brancher import batched_ml_branchrule, ml_branchrule

A branching rule that communicates with a central action server, which tries to process requests in batches for efficiency.

In [None]:
from toybnb.scip.ecole.il.brancher import BranchingServer

Helpers to seed Ecole's PRNG (`ecole_seed`) and to build environments (`make_env`).

In [None]:
from ecole import RandomGenerator

from functools import partial
from toybnb.scip.ecole.il.env import ecole_seed, make_env

<br>

## The data source proper

A generator of observation-action-reward data collected from the nodes of SCIP's branch-and-bound search tree at which a branching decision was made.
- SCIP has a nasty habit of intercepting and muffling keyboard interrupts. A workaround is to check whether the SCIP model's status indicates termination due to a SIGINT.

In [None]:
from toybnb.scip.ecole.il.rollout import pool_rollout, evaluate

<br>

## Evaluation

In [None]:
# use the first GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# All randomness below is derived from this single SeedSequence.
# With SEED = None a fresh entropy pool is drawn on every run. To reproduce
# the seed streams of a previous run, set SEED to the `ss.entropy` it printed.
SEED = None
ss = SeedSequence(SEED)
print(f"{ss.entropy = }")

Create a seeded generator of CO problem instances to pipe into the continuous rollout iterator. The intent is to mix several CO problems; for now only combinatorial auctions are generated.

In [None]:
from ecole.instance import CombinatorialAuctionGenerator

# Combinatorial-auction instances; sizes tried so far: CAuc(100, 500), CAuc(50, 250).
# The generator is seeded from its own child of `ss`, so (presumably, via
# `ecole_seed`) its instance stream is determined by `ss.entropy`.
itco = CombinatorialAuctionGenerator(
    100,
    500,
    rng=ecole_seed(spawn_one(ss)),
)

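Model hyperparameters. These must match the architecture of the checkpoint loaded below.

**TODO**: the note previously here read "Allow for 100k samples", which does not describe this cell. Confirm whether it belongs elsewhere.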

In [None]:
# Architecture hyper-parameters passed to NeuralVariableSelector below;
# they must agree with the checkpoint that is loaded into it.
s_project_graph = "pre"  # the other option is "post"
p_drop = 0.2  # drop probability handed to the model

# embedding width, attention heads, number of blocks
n_embed = 64
n_heads = 4
n_blocks = 1

b_edges = True
# NOTE(review): the checkpoint file is named "cauc-norm-first.pt" while this
# flag is False — confirm the two actually match
b_norm_first = False

Load the model

In [None]:
from toybnb.scip.ecole.il.nn import NeuralVariableSelector

# trained weights; the hyper-parameters above must match this checkpoint
CKPT_PATH = "dump/cauc-norm-first.pt"

mod = NeuralVariableSelector(
    19,  # presumably the per-variable feature count of the observation — TODO confirm
    5,  # presumably the per-constraint (row) feature count — TODO confirm
    n_embed,
    n_heads,
    n_blocks,
    p_drop,
    b_norm_first=b_norm_first,
    s_project_graph=s_project_graph,
    b_edges=b_edges,
).to(device)

# This notebook only runs inference. nn.Module.eval() switches off train-time
# behaviour such as dropout (p_drop > 0), which would otherwise randomise the
# branching decisions being evaluated.
mod.eval()

# map the stored tensors onto `device` so that a checkpoint saved on a GPU
# also loads on a CPU-only machine
ckpt = torch.load(CKPT_PATH, map_location=device)
mod.load_state_dict(ckpt["state_dict"])

Start the batching branching server, then list all the rules we want to evaluate.

In [None]:
# try the branching server: it is built around the same model `mod` on
# `device` and started here, presumably so that the rule obtained from
# `server.connect` (used in the rules cell below) can reach it
# NOTE(review): nothing visible in this notebook stops the server — confirm
# whether BranchingServer needs an explicit shutdown after evaluation
server = BranchingServer(mod, device)
server.start()

In [None]:
# Branching rules to compare, keyed by name:
#  - trained: ml_branchrule over the loaded model on `device`
#  - server: `server.connect` of the branching server started above
#  - strongbranch: the strongbranch() rule
# (disabled) "pseudocostbranch": strongbranch(True)
rules = dict(
    trained=ml_branchrule(mod, device),
    server=server.connect,
    strongbranch=strongbranch(),
)

Evaluate the branchrules in parallel threads

In [None]:
from itertools import islice

from tqdm import tqdm, trange

# evaluation budget
N_EVAL = 1000  # number of evaluation results to collect
N_FORKS = 12  # environment factories, each seeded with its own child of `ss`

# Draw instances from the seeded generator `itco` rather than a fresh,
# unseeded CombinatorialAuctionGenerator, so the evaluation instances are
# tied to `ss` like everything else.
it_co = itco

factories = [partial(make_env, fork) for fork in ss.spawn(N_FORKS)]
it_eval = pool_rollout(it_co, factories, rules, maxsize=2 * N_FORKS)

# Collect the evaluation results. `islice` stops after exactly N_EVAL items.
# `zip(it_eval, trange(N_EVAL))` would pull, and then discard, one extra
# rollout before noticing that the progress range is exhausted.
nfos = {}
for item in tqdm(islice(it_eval, N_EVAL), total=N_EVAL, ncols=70):
    do_add(item, nfos)

metrics = {k: collate(nfo) for k, nfo in nfos.items()}

Fields of interest in the collated `metrics`:
* `n_nodes`, `n_requests`
* `n_lpiter`
* `f_soltime`

<br>