# Evaluating branchrules

In [None]:
import torch
import numpy as np

from matplotlib import pyplot as plt

`SeedSequence` has no method for spawning a single child, so we add a `spawn_one` helper.

In [None]:
from numpy.random import default_rng, SeedSequence


def spawn_one(ss: SeedSequence) -> SeedSequence:
    """Spawn exactly one child seed sequence from `ss` and return it."""
    (child,) = ss.spawn(1)
    return child

A function to add a dict record to a table (dict-of-lists)

In [None]:
from typing import Callable


def do_add(
    record: dict, to: dict[..., list], transform: Callable[..., dict] = None
) -> dict:
    """Append the fields of a record to a column-wise table (dict of lists).

    If `transform` is callable, it is invoked with the record's fields as
    keyword arguments and the dict it returns is stored instead. The table
    `to` is updated in place, and the record is returned as it was passed in.
    """
    row = transform(**record) if callable(transform) else record

    # fields are assumed to be always present, so the columns stay aligned
    for name, value in row.items():
        column = to.setdefault(name, [])
        column.append(value)

    return record


def collate(records: list[dict]) -> dict[..., list]:
    """Transpose a list of records into a dict of numpy arrays, one per field.

    Every record is assumed to carry the same set of fields.
    """
    table = {}
    for rec in records:
        do_add(rec, table)

    arrays = {}
    for name, column in table.items():
        arrays[name] = np.array(column)

    return arrays

Branchrules and wrappers

In [None]:
import ecole as ec

from toybnb.scip.ecole.il.data import Observation

from toybnb.scip.ecole.il.brancher import BranchRule, BranchRuleCallable

from toybnb.scip.ecole.il.brancher import strongbranch, randombranch
from toybnb.scip.ecole.il.brancher import batched_ml_branchrule, ml_branchrule

A branchrule that communicates with a central action server, which attempts to process the requests in batches for efficiency.

In [None]:
from toybnb.scip.ecole.il.threads import BatchProcessor

from ecole.core.scip import Stage
from ecole.environment import Branching


class BranchingServer(BatchProcessor):
    """A server of branching decisions, built on top of `BatchProcessor`."""

    def connect(self, env: Branching, config: dict = None) -> BranchRuleCallable:
        """Create a branchrule for `env` that forwards observations to the server.

        The returned callable yields `None` unless the env's SCIP model is in
        the solving stage; otherwise it sends the observation through the
        connection obtained from `BatchProcessor.connect` and returns the reply
        converted to `int`. `config` is accepted, but not used.
        """
        request = super().connect()

        def _branchrule(obs: Observation) -> int:
            # only ask the server while SCIP is actually solving
            if env.model.stage == Stage.Solving:
                return int(request(obs))

            return None

        return _branchrule

A procedure to seed Ecole's PRNG

In [None]:
from ecole import RandomGenerator

from toybnb.scip.ecole.il.env import ecole_seed

<br>

## The data source proper

A generator of observation-action-reward data collected from the nodes of SCIP's BnB search tree at which a branching decision was made.
- SCIP has a nasty habit of intercepting and muffling keyboard interrupts. A workaround is to check
  whether the SCIP model's status indicates termination due to a SIGINT.

In [None]:
from toybnb.scip.ecole.il.rollout import rollout

The inner logic of the parallel job feeder

In [None]:
from typing import Iterable
from queue import Empty, Queue, Full
from threading import Event


def t_feed(
    it: Iterable,
    to: Queue,
    *,
    signal: Event,
    err: Queue,
    timeout: float = 1.0,
) -> None:
    """Keep putting items from the iterable into the queue, until stopped.

    Parameters
    ----------
    it: any iterable (an iterator, a generator, or a plain container)
    to: the queue to put the items into
    signal: the event that tells the feeder to stop
    err: the queue into which any exception raised here is put
    timeout: how long a single blocking `put` waits before the stop signal
        is polled again

    The function returns when the iterable is exhausted or when `signal` is
    set. It never raises: exceptions are forwarded through `err`.
    """
    try:
        # a `for` loop accepts any iterable, whereas a bare `next` would
        #  fail on containers, which are not iterators themselves
        for item in it:
            # retry the put until it succeeds or until we are told to stop
            while not signal.is_set():
                try:
                    to.put(item, True, timeout)
                except Full:
                    continue

                break

            else:
                # the stop signal was raised before the item could be queued
                return

    except BaseException as e:
        err.put_nowait(e)

The body of a parallel rollout worker

In [None]:
from typing import Callable


def t_evaluate(
    feed: Queue,
    factory: Callable,
    branchrules: dict[str, BranchRule],
    into: Queue,
    *,
    signal: Event,
    err: Queue,
    timeout: float = 1.0,
) -> None:
    """Evaluate every rule in `branchrules` on instances from `feed` solved
    in the env created by `factory`, saving the results in `into`.

    Parameters
    ----------
    feed: the queue from which the instances to solve are taken
    factory: a callable without arguments, that creates the environment
    branchrules: the rules to evaluate, keyed by name
    into: the queue that receives one dict per instance, which maps the name
        of a rule to the value returned by its `rollout` generator
    signal: the event that tells the worker to stop
    err: the queue into which any exception raised here is put
    timeout: how long a blocking queue operation waits before the stop
        signal is polled again

    The function never raises: exceptions are forwarded through `err`.
    """
    try:
        # get the env
        env = factory()
        while not signal.is_set():
            # poll the feed queue for a new job
            try:
                p = feed.get(True, timeout)
            except Empty:
                continue

            out = {}
            # do a rollout on this instance with each branchrule
            for name, rule in branchrules.items():
                try:
                    # the items produced by the rollout are discarded: only
                    #  the generator's return value is of interest here
                    it = rollout(p, env, rule, {}, stop=signal.is_set)
                    while True:
                        next(it)

                except StopIteration as e:
                    # save the final `nfo` data from the branching env,
                    #  as it contains the post-search tree stats
                    out[name] = e.value

            # send the evaluation results of all branchrules
            #  (the results are dropped if the stop signal arrives first)
            while not signal.is_set():
                try:
                    into.put(out, True, timeout)
                except Full:
                    continue

                break

    except BaseException as e:
        err.put_nowait(e)

The procedure itself

In [None]:
from threading import Thread
from functools import partial
from toybnb.scip.ecole.il.env import make_env


def maybe_raise(err: Queue) -> None:
    """Pop the oldest exception from the error queue, if any, and raise it.

    Does nothing when the queue is empty. The queue's underlying container
    is inspected directly under the queue's own lock, so the check never
    blocks waiting for an item.
    """
    with err.mutex:
        if not err.queue:
            return

        exc = err.queue.popleft()

    raise exc


def multievaluate(
    feed: Iterable,
    ss: SeedSequence,
    branchrules: dict[str, BranchRule],
    n_jobs: int = 8,
) -> Iterable:
    """Evaluate the branchrules on instances from `feed` in parallel threads.

    Parameters
    ----------
    feed: the iterable of instances to evaluate the rules on
    ss: the seed sequence, children of which seed the workers' environments
    branchrules: the rules to evaluate, keyed by name
    n_jobs: the number of worker threads

    Yields
    ------
    The results in the order the workers finish them, one item per instance,
    as produced by `t_evaluate`.

    Raises
    ------
    Any exception that the feeder or a worker thread has put into the shared
    error queue.

    The generator finishes once a finite `feed` is exhausted and a result has
    been yielded for every instance taken from it. On exit, including when
    the generator is closed early, the threads are signalled to stop and are
    joined.
    """
    ctx = dict(signal=Event(), err=Queue(), timeout=1.0)

    # the number of instances handed over to the feeder thread so far
    n_fed = 0

    def _counted() -> Iterable:
        """Pass the items of `feed` through, counting them."""
        nonlocal n_fed
        for item in feed:
            n_fed += 1
            yield item

    # spawn feed thread and the workers
    feed_q, rollout_q = Queue(128), Queue(128)
    feeder = Thread(target=t_feed, args=(_counted(), feed_q), kwargs=ctx, daemon=True)
    threads = [feeder]
    for fork in ss.spawn(n_jobs):
        args = feed_q, partial(make_env, fork), branchrules, rollout_q
        threads.append(Thread(target=t_evaluate, args=args, kwargs=ctx, daemon=True))

    for t in threads:
        t.start()

    n_out = 0
    try:
        # the main thread yields results from the rollout output queue
        while True:
            # check the feeder BEFORE the error queue: a thread reports its
            #  error prior to terminating, so if the feeder is seen dead here,
            #  then its exception, if any, is already in the queue
            exhausted = not feeder.is_alive()
            maybe_raise(ctx["err"])

            # a finite feed has been fully processed: every instance taken
            #  from it yields exactly one result
            if exhausted and n_out >= n_fed:
                return

            try:
                result = rollout_q.get(True, timeout=5.0)

            except Empty:
                continue

            n_out += 1
            yield result

    finally:
        ctx["signal"].set()
        for t in threads:
            t.join()

<br>

## Evaluation

In [None]:
# seed sequence with a fresh entropy pool: no entropy is supplied, so a new
#  one is drawn on every run. Record the printed value and pass it to
#  `SeedSequence` to reproduce a run
ss = SeedSequence(entropy=None)
print(f"{ss.entropy = }")

Pipe the generator that mixes several CO problems into the continuous rollout iterator.

In [None]:
from ecole.instance import CombinatorialAuctionGenerator

# Combinatorial auction instance generator with its PRNG seeded from a child
#  of the seed sequence `ss`
# candidate settings: CAuc(100, 500), CAuc(50, 250)
# NOTE(review): `itco` does not appear to be referenced by the other visible
#  cells: the evaluation cell constructs its own `it_co` -- confirm which
#  of the two is intended
itco = CombinatorialAuctionGenerator(100, 500, rng=ecole_seed(spawn_one(ss)))

Allow for 100k samples

In [None]:
# hyperparameters of the `NeuralVariableSelector` built in the next cell

# graph projection mode, the alternative value is "post"
s_project_graph = "pre"

# presumably the dropout probability -- TODO confirm against the model
p_drop = 0.2

# the embedding size, the number of heads, and the number of blocks
n_embed = 32
n_heads = 1
n_blocks = 1

# boolean switches that are forwarded to the model as is
b_edges = True
b_norm_first = False

Load the model

In [None]:
from toybnb.scip.ecole.il.nn import NeuralVariableSelector

# Build the variable selector from the hyperparameters set in the cell above
mod = NeuralVariableSelector(
    # NOTE(review): 19 and 5 are presumably the sizes of the variable and
    #  the constraint features of the observation -- TODO confirm
    19,
    5,
    n_embed,
    n_heads,
    n_blocks,
    p_drop,
    b_norm_first=b_norm_first,
    s_project_graph=s_project_graph,
    b_edges=b_edges,
)

# Restore the trained weights from the checkpoint's "state_dict" entry
# NOTE(review): no `map_location` is given, so the tensors are restored onto
#  the devices they were saved from -- confirm this suits the eval machine
ckpt = torch.load("dump/eighth-cauc-pseudocost.pt")
mod.load_state_dict(ckpt["state_dict"])
# NOTE(review): the module is never switched to `.eval()` in this cell, while
#  `p_drop` is non-zero -- confirm that `ml_branchrule` does it, otherwise
#  the dropout remains active during the evaluation

List all rules we want to evaluate

In [None]:
# the branchrules to compare, keyed by the name under which their results
#  are reported
rules = dict(
    trained=ml_branchrule(mod),
    strongbranch=strongbranch(),
    #     pseudocostbranch=strongbranch(True),
)

Evaluate the branchrules in parallel threads

In [None]:
from tqdm import trange

# seed the instance generator from a child of `ss`, so that the evaluated
#  problems can be reproduced from the recorded entropy
it_co = CombinatorialAuctionGenerator(100, 500, rng=ecole_seed(spawn_one(ss)))

it_eval = multievaluate(it_co, ss, rules, n_jobs=4)

# collect the evaluation results
nfos = {}
try:
    # the progress bar goes first: `zip` stops as soon as its first iterable
    #  is exhausted, hence no extra result is requested from the workers
    #  only to be thrown away
    for _, item in zip(trange(1000, ncols=70), it_eval):
        do_add(item, nfos)

finally:
    # closing the generator triggers its cleanup, which stops and joins
    #  the threads; otherwise they keep solving instances in the background
    it_eval.close()

# one table of numpy arrays per branchrule
metrics = {k: collate(nfo) for k, nfo in nfos.items()}

* `n_nodes`, `n_requests`
* `n_lpiter`
* `f_soltime`

<br>