# A dqn agent

In [None]:
import torch
import numpy as np

from matplotlib import pyplot as plt

The pp-curve drawing procedure. We compare distributions using the pp-curve, which is analogous to the ROC curve: the pp-curve compares the distributions of two independent samples, while the ROC curve compares the true-positive and false-positive distributions.

In [None]:
from numpy import ndarray


def pp_curve(*, x: ndarray, y: ndarray, num: int = None) -> tuple[ndarray, ndarray]:
    """Compute the threshold-parameterized pp-curve of two samples.

    The curve is traced the same way as a ROC curve: a monotone sequence
    of thresholds `v` is mapped to the pair of empirical cdf values
    `(F_x(v), F_y(v))`.

    Parameters
    ----------
    x, y : ndarray
        The two samples to compare (keyword-only).
    num : int, optional
        If `None`, every value of the pooled sample is used as a threshold.
        Otherwise the thresholds are `num` quantiles of the pooled sample
        taken at equispaced probability levels.

    Returns
    -------
    p, q : ndarray
        The fractions of `x` and of `y` that lie strictly below each
        threshold (`np.searchsorted` is used with its default left side).
        The threshold sequence is padded with `-inf` and `+inf`.
    """
    # each sample is sorted once, so that its empirical cdf can be
    #  evaluated by binary search
    xs, ys = np.sort(x), np.sort(y)

    pooled = np.concatenate((xs, ys))
    if num is not None:
        # coarse grid: thresholds that are equispaced after being
        #  transformed by the empirical cdf of the pooled sample
        levels = np.linspace(0, 1, num=num)
        grid = np.quantile(pooled, levels, method="linear")

    else:
        # finest detail: every pooled value is a threshold
        grid = np.sort(pooled)

    # pad the threshold sequence with the infinite end points
    thresholds = np.r_[-np.inf, grid, +np.inf]

    def ecdf(sample: ndarray) -> ndarray:
        """The share of the sorted `sample` below each threshold."""
        return np.searchsorted(sample, thresholds) / len(sample)

    return ecdf(xs), ecdf(ys)

A simple viz for tracking loss and other runtime series

In [None]:
from matplotlib.axes import Axes


def plot_stats(ax: Axes = None, n_last: int = 25, **series) -> dict:
    """Plot named runtime series and mark their recent averages.

    Parameters
    ----------
    ax : Axes, optional
        The axes to draw on. Defaults to the current axes.
    n_last : int
        The number of trailing values of each series that are averaged
        (ignoring nans) to get the tick drawn at the right-hand side.
    **series
        The sequences to plot, keyed by their legend label.

    Returns
    -------
    dict
        The plotted line of each series, keyed by the series name.
    """
    # local import: only needed to silence numpy's all-nan warning
    import warnings

    ax = plt.gca() if ax is None else ax

    els = {}
    for name, x in series.items():
        (el,) = ax.plot(x, label=name)

        # the average over the tail of the series; an all-nan (or empty)
        #  tail yields nan, for which numpy emits a RuntimeWarning
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=RuntimeWarning)
            avg = np.nanmean(x[-n_last:])

        # add the average estimate tick to the right-hand side
        col = el.get_color()
        ax.axhline(
            avg,
            0.975,
            c=col,
            alpha=0.25,
            zorder=-10,
        )
        ax.annotate(
            f"{avg:.2g}",
            c=col,
            fontsize="xx-small",
            xy=(1.005, avg),
            xytext=(0.0, -2.0),
            xycoords=("axes fraction", "data"),
            textcoords="offset points",
            zorder=-10,
        )
        els[name] = el

    ax.legend(els.values(), list(els), loc="best", fontsize="x-small", ncol=3)

    return els

SeedSequence needs a `.spawn-one` method

In [None]:
import numpy as np
from numpy.random import default_rng, SeedSequence


def spawn_one(ss: SeedSequence) -> SeedSequence:
    """Spawn exactly one child seed sequence from `ss` and return it."""
    (child,) = ss.spawn(1)
    return child

A function to add a dict record to a table (dict-of-lists)

In [None]:
from typing import Callable


def do_add(
    record: dict, to: dict[..., list], transform: Callable[..., dict] = None
) -> dict:
    """Append the fields of a record to a table stored as a dict of lists.

    If `transform` is callable, it is invoked with the record's fields as
    keyword arguments and its result is stored instead. The original
    (untransformed) record is returned in either case.
    """
    row = transform(**record) if callable(transform) else record

    # every field is appended to its own column, which is created on
    #  demand; the columns stay aligned only if no fields are missing
    for field in row:
        column = to.setdefault(field, [])
        column.append(row[field])

    return record

<br>

## Composing the datastreams

Reservoir sampling for data from infinite data streams 

In [None]:
from typing import Iterable


def shuffle(
    it: Iterable,
    n_size: int = 1024,
    seed: int = None,
) -> Iterable:
    """Yield the values of the iterable in a locally shuffled order.

    A buffer of `n_size` values is kept. Once it is full, each incoming
    value evicts a uniformly picked buffered value, which is yielded.
    When the input ends, the buffer is shuffled and flushed.
    """
    rng = default_rng(seed)

    pool = []
    for incoming in it:
        if len(pool) >= n_size:
            # the pool is full: swap a random slot's value for the new one
            slot = rng.choice(n_size)
            evicted, pool[slot] = pool[slot], incoming
            yield evicted

        else:
            # still filling up the pool
            pool.append(incoming)

    # the leftovers are mixed once more, in case the input
    #  was too short for the pool to do any mixing
    rng.shuffle(pool)
    yield from pool

Batch the data in a sequence

In [None]:
def batch(it: Iterable, n_size: int = 16) -> Iterable[list]:
    """Group the values of the iterable into lists of `n_size` items.

    Every yielded batch is a fresh list, so it is safe for the consumer
    to keep a batch around after requesting the next one. The last batch
    may be shorter, if the iterable runs out.
    """
    chunk = []
    for sample in it:
        chunk.append(sample)

        # hand over the full batch and start a NEW list: clearing the
        #  yielded list in place would also wipe the consumer's copy
        if len(chunk) >= n_size:
            yield chunk
            chunk = []

    # don't forget the residual batch
    if chunk:
        yield chunk

a handy sequence limiter

In [None]:
from typing import Union


def limit(it: Iterable, limiter: Union[int, Iterable]) -> Iterable:
    """Yield values from `it` for as long as `limiter` has items left.

    An integer `limiter` caps the output at that many values. Any other
    iterable caps it at its own length, e.g. a progress bar.
    """
    limiter = range(limiter) if isinstance(limiter, int) else limiter

    # the limiter goes first in the zip: `zip` polls its arguments left
    #  to right, so this order stops before pulling (and dropping) an
    #  extra sample from `it` when the limit has been reached
    for _, sample in zip(limiter, it):
        yield sample

finally, a generator that mixes data from multiple iterables

In [None]:
def mixer(*its: Iterable, seed: int = None) -> Iterable:
    """Interleave several iterables by drawing from a random one each time.

    An iterable is dropped from the draw once it is exhausted, and the
    mixing stops when none are left.
    """
    rng = default_rng(seed)

    active = [iter(it) for it in its]
    while active:
        picked = rng.choice(active, shuffle=False)
        try:
            value = next(picked)

        except StopIteration:
            # the picked source has run dry
            active.remove(picked)

        else:
            yield value

<br>

## CO and Branching data sources

The following SCIP settings were inherited from [Gasse et al. 2019](), [Parsonson et al. 2022](), and
[Scavuzzo et al. 2022]().

In [None]:
import pyscipopt
import ecole as ec


def default_scip_params() -> dict:
    """Return a fresh dict with the default SCIP parameter overrides."""
    params = {}

    # although we use ecole's `.disable_presolve()`, we still keep these params
    params["separating/maxrounds"] = 0  # separate (cut) only at root node
    params["presolving/maxrestarts"] = 0  # disable solver restarts

    # determines scip's inner clock and affects the time limit
    params["timing/clocktype"] = 1  # 1: CPU user seconds, 2: wall clock time
    params["limits/time"] = 60 * 60  # solver time limit

    return params

A derived branching env that disables SCIP's presolver
- without presolve training becomes much slower

In [None]:
from ecole.environment import Branching


class BranchingWithoutPresolve(Branching):
    """A `Branching` env that disables SCIP's presolver on every reset.

    The class has to derive from `Branching`: `reset` delegates to the
    parent implementation, and the env is constructed with the same
    keyword arguments as `Branching` itself.
    """

    def reset(self, instance, *dynamics_args, **dynamics_kwargs):
        """Reset the env on a presolve-free copy of `instance`.

        The instance is copied first, so the caller's model is not
        modified. The remaining arguments are passed to the parent's
        `reset` unchanged, and its result is returned.
        """
        # disable presolve through ecole
        # XXX [.disable_presolve](./libecole/src/scip/model.cpp#L195)
        #  calls [SCIPsetPresolving](./src/scip/scip_params.c#L913-937)
        instance = instance.copy_orig()
        instance.disable_presolve()

        return super().reset(instance, *dynamics_args, **dynamics_kwargs)

SCIP intercepts the keyboard interrupt, making it impossible to abort a seemingly stuck loop. We intercept the signal and delay it until the end of the `with` scope.

* for `CAuc(100, 500)` the bipartite obs takes up about `260kb`, which means that precomputing 100k samples is out of the question

In [None]:
import signal


class DelaySIGINT:
    """Context manager that postpones SIGINT until the end of its scope.

    While the scope is active, a received SIGINT is only recorded. On exit
    the previous handler is restored and the recorded signal, if any, is
    replayed to it. The object is truthy if a signal has been recorded.

    Parameters
    ----------
    disable : bool
        If True, the context manager leaves the signal handling untouched.
    """

    def __init__(self, disable: bool = False) -> None:
        self.disable = disable
        # initialized here, so that `bool(...)` works before `__enter__`
        self.signal = None
        self.old_handler = None

    def __bool__(self) -> bool:
        return self.signal is not None

    def __enter__(self):
        self.signal = None
        if self.disable:
            return self

        self.old_handler = signal.getsignal(signal.SIGINT)
        signal.signal(signal.SIGINT, self.handler)
        return self

    def handler(self, sig, frame):
        """Record the signal instead of acting on it."""
        self.signal = sig, frame

    def __exit__(self, type, value, traceback):
        if self.disable:
            return

        signal.signal(signal.SIGINT, self.old_handler)
        if self.signal is None:
            return

        # replay the delayed signal to the previous handler
        if callable(self.old_handler):
            self.old_handler(*self.signal)

        elif self.old_handler != signal.SIG_IGN:
            # the previous handler is not a python callable (e.g. SIG_DFL),
            #  so re-send the signal now that it has been reinstated
            signal.raise_signal(signal.SIGINT)

We want to train a problem-dependent branching heuristic. Strong branching and pseudocost have the best problem dependency -- they have access to its geometry, but are slow. Our goal is to train a neural net that extracts geometric info from the parametric representation of a problem at a node.

In [None]:
from ecole.observation import NodeBipartite
from ecole.reward import Constant


def make_env(
    entropy: int = None, presolve: bool = True, scip_params: dict = np._NoValue
) -> Branching:
    """Create and seed a branching env that observes bipartite node graphs.

    Parameters
    ----------
    entropy : int or SeedSequence, optional
        The entropy for the env's seed. A `SeedSequence` is used as is,
        anything else is wrapped into a new one.
    presolve : bool
        Use the `Branching` env if True, and `BranchingWithoutPresolve`
        otherwise.
    scip_params : dict, optional
        The SCIP parameters for the env. When omitted, the result of
        `default_scip_params()` is used; an explicit `None` is passed on.
    """
    # fork the seed sequence from the given entropy
    if isinstance(entropy, SeedSequence):
        ss = entropy
    else:
        ss = SeedSequence(entropy)

    # the sentinel lets the caller pass a `true None` for the scip-params
    if scip_params is np._NoValue:
        scip_params = default_scip_params()

    # we track the aggregate tree stats
    # XXX not sure if nnodes is a `clean` metric
    trackers = {
        "n_nodes": ec.reward.NNodes().cumsum(),
        "n_lpiter": ec.reward.LpIterations().cumsum(),
        "f_soltime": ec.reward.SolvingTime().cumsum(),
    }

    # the branching env, with or without presolve
    env_cls = BranchingWithoutPresolve if not presolve else Branching
    env = env_cls(
        # the node's LP is represented by a bipartite graph
        observation_function=NodeBipartite(),
        # there is no reward function, since we imitate an expert
        reward_function=Constant(float("nan")),
        information_function=trackers,
        scip_params=scip_params,
    )

    # `RandomGenerator.max_seed` reports 2^{32}-1
    seed = ss.generate_state(1, dtype=np.uint32)[0]
    env.seed(int(seed))

    return env

We use a special representation of a batch of observations

In [None]:
from typing import NamedTuple
from ecole.observation import NodeBipartiteObs
from torch import Tensor
from numpy import ndarray


class Observation(NamedTuple):
    """A node observation paired with its set of allowed actions."""

    # the bipartite graph observation of the node
    obs: NodeBipartiteObs
    # the action set that accompanies the observation
    actset: ndarray


class BatchObservation(NamedTuple):
    """A batch of observations collated into flat tensors.

    The graphs of the batch items are stacked one after another, see
    `collate` for the procedure that fills in the fields.
    """

    # variable, constraint, and cons-to-vars link features
    vars: Tensor
    cons: Tensor

    # the values of the cons-to-vars links, and their (cons, vars) index
    #  pairs, which refer to the rows of the collated `cons` and `vars`
    ctov_v: Tensor
    ctov_ij: Tensor

    # the action sets, as indices into the rows of the collated `vars`
    actset: Tensor

    # the first index in the collated batch of the span of data
    #  originating from each uncollated batch element (each is one
    #  longer than the batch, the last entry being the total size)
    ptr_vars: Tensor
    ptr_cons: Tensor
    ptr_ctov: Tensor
    ptr_actset: Tensor

    # batch affinity: the index of the batch element that each
    #  collated row, link, or action belongs to
    inx_vars: Tensor
    inx_cons: Tensor
    inx_ctov: Tensor
    inx_actset: Tensor

A special collation logic for `Observation`

In [None]:
from torch import as_tensor


def collate(batch: tuple[Observation], device: torch.device = None) -> BatchObservation:
    """Stack a batch of `Observation`-s into the tensors of one big graph.

    The variable, constraint and link data of the batch items are laid
    out one after another. The link indices and the action sets are
    shifted, so that they refer to the rows of the stacked tensors.
    Features are stored as float32 and all indices as long tensors, which
    are allocated on `device`.
    """
    obs, actset = zip(*batch)
    n_items = len(obs)

    # the total sizes of the parts of the collated batch
    n_vars = sum(len(o.variable_features) for o in obs)
    n_cons = sum(len(o.row_features) for o in obs)
    n_ctov = sum(len(o.edge_features.values) for o in obs)
    n_actset = sum(len(a) for a in actset)

    # storage for the features, the coo links (cons-to-vars), and actions
    x_vars = torch.empty(
        (n_vars, *obs[0].variable_features.shape[1:]),
        dtype=torch.float32,
        device=device,
    )
    x_cons = x_vars.new_empty((n_cons, *obs[0].row_features.shape[1:]))
    x_ctov_v = x_vars.new_empty(n_ctov)
    x_ctov_ij = x_vars.new_empty((2, n_ctov), dtype=torch.long)
    x_actset = x_vars.new_empty(n_actset, dtype=torch.long)

    # the span pointers start at zero and are one longer than the batch
    def new_ptr() -> Tensor:
        return x_vars.new_zeros(1 + n_items, dtype=torch.long)

    ptr_vars, ptr_cons, ptr_ctov, ptr_actset = (
        new_ptr(),
        new_ptr(),
        new_ptr(),
        new_ptr(),
    )

    # the batch item index of every collated element
    inx_vars = x_vars.new_empty(n_vars, dtype=torch.long)
    inx_cons = x_vars.new_empty(n_cons, dtype=torch.long)
    inx_ctov = x_vars.new_empty(n_ctov, dtype=torch.long)
    inx_actset = x_vars.new_empty(n_actset, dtype=torch.long)

    # copy the numpy data of each item into its span of the storage
    at_v = at_c = at_e = at_a = 0
    for b, (node, cands) in enumerate(batch):
        links = node.edge_features
        to_v = at_v + len(node.variable_features)
        to_c = at_c + len(node.row_features)
        to_e = at_e + len(links.values)
        to_a = at_a + len(cands)
        span_v, span_c = slice(at_v, to_v), slice(at_c, to_c)
        span_e, span_a = slice(at_e, to_e), slice(at_a, to_a)

        # the vars and the cons features
        x_vars[span_v].copy_(as_tensor(node.variable_features))
        x_cons[span_c].copy_(as_tensor(node.row_features))

        # the links: values, and indices shifted by the span offsets
        x_ctov_v[span_e].copy_(as_tensor(links.values))
        x_ctov_ij[:, span_e].copy_(as_tensor(links.indices.astype(int)))
        x_ctov_ij[0, span_e] += at_c
        x_ctov_ij[1, span_e] += at_v

        # the action set, shifted by the offset of the item's vars
        x_actset[span_a].copy_(as_tensor(cands.astype(int)))
        x_actset[span_a] += at_v

        # the batch assignment
        inx_vars[span_v] = b
        inx_cons[span_c] = b
        inx_ctov[span_e] = b
        inx_actset[span_a] = b

        # the end of this span is the start of the next one
        ptr_vars[1 + b] = to_v
        ptr_cons[1 + b] = to_c
        ptr_ctov[1 + b] = to_e
        ptr_actset[1 + b] = to_a
        at_v, at_c, at_e, at_a = to_v, to_c, to_e, to_a

    return BatchObservation(
        vars=x_vars,
        cons=x_cons,
        ctov_v=x_ctov_v,
        ctov_ij=x_ctov_ij,
        actset=x_actset,
        ptr_vars=ptr_vars,
        ptr_cons=ptr_cons,
        ptr_ctov=ptr_ctov,
        ptr_actset=ptr_actset,
        inx_vars=inx_vars,
        inx_cons=inx_cons,
        inx_ctov=inx_ctov,
        inx_actset=inx_actset,
    )

We source supervised data from the strong branching heuristic

In [None]:
from ecole.observation import StrongBranchingScores, Pseudocosts
from ecole.core.scip import Model, Stage

# a branchrule callable maps an observation to the index of the variable
#  to branch on, or to `None` when the solver is not in the solving stage
BranchRuleCallable = Callable[[Observation], Union[int, None]]

# a branchrule is a factory that binds a branchrule callable to an env
BranchRule = Callable[[Branching], BranchRuleCallable]


def strongbranch(pseudocost: bool = False) -> BranchRule:
    """Make a branchrule that picks the best-scoring candidate variable.

    The scores come from `Pseudocosts` if `pseudocost` is True, and from
    `StrongBranchingScores` otherwise. One scorer object is created here
    and shared by all rules spawned from the returned factory.
    """
    if pseudocost:
        scorer = Pseudocosts()

    else:
        scorer = StrongBranchingScores(pseudo_candidates=False)

    def _spawn(env: Branching) -> BranchRuleCallable:
        def _branchrule(obs: Observation, **ignored) -> int:
            # there is nothing to branch on outside of the solving stage
            if env.model.stage != Stage.Solving:
                return None

            # pick the action with the highest score among the allowed ones
            scores = scorer.extract(env.model, False)
            best = scores[obs.actset].argmax()

            # return a plain python int, like the other branchrules do
            return int(obs.actset[best])  # SCIPvarGetProbindex

        return _branchrule

    return _spawn

We also compare to a randombranching expert, although its utility is vague.

In [None]:
def randombranch(seed: int = None) -> BranchRule:
    """Make a branchrule that picks a candidate variable at random.

    All rules spawned from the returned factory draw from one shared
    random generator, which is seeded by `seed`.
    """
    rng = default_rng(seed)

    def _spawn(env: Branching) -> BranchRuleCallable:
        def _branchrule(obs: Observation, **ignored) -> int:
            # there is nothing to branch on outside of the solving stage
            if env.model.stage != Stage.Solving:
                return None

            pick = rng.choice(obs.actset)
            return int(pick)

        return _branchrule

    return _spawn

And, finally, a function to use the trained machine learning model to pick branching vars

In [None]:
from torch.nn import Module


def batched_ml_branchrule(module: Module) -> BranchRuleCallable:
    """Make a rule that picks the actions for a whole batch of observations.

    The returned callable switches `module` to the eval mode, collates the
    batch, and returns the output of `module.predict` as a list of ints.
    It runs in torch's inference mode.
    """

    @torch.inference_mode(True)
    def _branchrule(batch: tuple[Observation], **ignored) -> tuple[int]:
        module.eval()
        picks = module.predict(collate(batch))
        return np.asarray(picks.cpu(), dtype=int).tolist()

    return _branchrule


def ml_branchrule(module: Module) -> BranchRule:
    """Make a branchrule that queries `module` one observation at a time."""
    predict_batch = batched_ml_branchrule(module)

    def _spawn(env: Branching) -> BranchRuleCallable:
        def _branchrule(obs: Observation, **ignored) -> int:
            if env.model.stage == Stage.Solving:
                # the model is applied to a batch with a single item
                picks = predict_batch([obs])
                return int(picks[0])

            # there is nothing to branch on outside of the solving stage
            return None

        return _branchrule

    return _spawn

A server that attempts to batch-process the requests

In [None]:
from threading import Thread, Event
from queue import Queue, Empty as QueueEmpty


class BatchProcessor(Thread):
    """A server thread that processes the collected requests in batches.

    The clients submit their inputs through the closures made by
    `connect`. The thread gathers whatever requests are available, calls
    `target` once on the tuple of their inputs, and sends each output
    back to the client that made the respective request. The server shuts
    down if `target` raises, and keeps the exception in `exception_`.
    """

    timeout: float = 3.0

    def __init__(
        self,
        target: Callable,
        name: str = None,
        daemon: bool = None,
    ) -> None:
        super().__init__(name=name, daemon=daemon)
        self.exception_ = None
        self.target = target
        self.is_finished = Event()
        self.requests = Queue()

    def stop(self) -> None:
        """Raise the termination flag and wait for the thread to end."""
        self.is_finished.set()
        self.join()

    def run(self) -> None:
        pending = []
        while not self.is_finished.is_set():
            # wait for the first request, but not for too long, so that
            #  the termination flag is checked regularly
            try:
                first = self.requests.get(True, timeout=self.timeout)

            except QueueEmpty:
                continue

            pending.append(first)

            # grab every request that is available __right now__
            while True:
                try:
                    pending.append(self.requests.get_nowait())

                except QueueEmpty:
                    break

            # split the reply queues from the inputs
            coms, inputs = zip(*pending)
            pending.clear()

            # process the inputs and reply to each origin; a failure
            #  is recorded and shuts the server down
            try:
                outputs = self.target(inputs)
                for com, out in zip(coms, outputs):
                    com.put(out)

            except Exception as e:
                self.exception_ = e
                break

        self.is_finished.set()

    def connect(self) -> Callable:
        """Make a closure that sends an input and waits for its result."""
        com = Queue()

        def co_yield(input: ...) -> ...:
            # refuse new requests if the server has been terminated
            if self.is_finished.is_set():
                raise RuntimeError

            self.requests.put((com, input))

            # poll the exclusive reply queue in short waits, so that
            #  a server shutdown does not go unnoticed
            while True:
                if self.is_finished.is_set():
                    raise self.exception_ or RuntimeError

                try:
                    return com.get(True, timeout=self.timeout)

                except QueueEmpty:
                    pass

        return co_yield

A batching branchrule server

In [None]:
class BranchingServer(BatchProcessor):
    """A batch processor that serves branching variables to the envs."""

    def connect(self, env: Branching) -> BranchRuleCallable:
        """Make a branchrule for `env` that queries the server."""
        ask = super().connect()

        def _branchrule(obs: Observation, **ignored: dict) -> int:
            if env.model.stage == Stage.Solving:
                return int(ask(obs))

            # there is nothing to branch on outside of the solving stage
            return None

        return _branchrule

A procedure to seed Ecole's PRNG

In [None]:
from ecole import RandomGenerator


def ecole_seed(ss: SeedSequence) -> RandomGenerator:
    """Make an ecole random generator seeded from the seed sequence."""
    # a single 32-bit word: `RandomGenerator.max_seed` reports 2^{32}-1
    seed = ss.generate_state(1, dtype=np.uint32)[0]
    return RandomGenerator(seed)

<br>

## The data source proper

A generator of observation-action-reward data collected from the nodes of SCIP's BnB search tree at which a branching decision was made.

In [None]:
from typing import Iterable
from ecole.scip import Model

# from pyscipopt.scip import Model as SCIPModel


def maybe_raise_sigint(m: Model) -> None:
    """Raise `KeyboardInterrupt` if SCIP reports a user interrupt."""
    status = m.as_pyscipopt().getStatus()
    if status != "userinterrupt":
        return

    raise KeyboardInterrupt from None


def rollout(
    p: Model,
    env: Branching,
    branchrule: BranchRule,
    *,
    delay: bool = True,
    kwargs: dict = None,
) -> Iterable:
    """Solve the problem, yielding the data of every branching decision.

    Parameters
    ----------
    p : Model
        The problem instance to reset the env with.
    env : Branching
        The branching env to step through.
    branchrule : BranchRule
        The factory of the rule, which is bound to `env` at the start.
    delay : bool
        If True, a SIGINT received while the rule or the env is busy is
        delayed until they are done. Otherwise the solver's status is
        checked for a user interrupt after every step.
    kwargs : dict, optional
        Extra keyword arguments passed to the rule on every call.

    Yields
    ------
    tuple
        The observation (with its action set) at which the decision was
        made, the action taken, and the reward that the env returned for it.

    Returns
    -------
    The last `info` reported by the env (as the generator's return value).
    """
    kwargs = {} if kwargs is None else kwargs
    do_branch = branchrule(env)

    obs, act_set, rew, fin, nfo = env.reset(p)
    maybe_raise_sigint(env.model)
    while not fin:
        # the action set should be treated as a part of the observation
        obs_ = Observation(obs, act_set)

        # query the expert and branch
        with DelaySIGINT(disable=not delay):
            act_ = do_branch(obs_, **kwargs)
            obs, act_set, rew, fin, nfo = env.step(act_)  # t -->> t+1
            if not delay:
                maybe_raise_sigint(env.model)

        # send out the `x_{t-1}, a_{t-1}, r_t`. This is done OUTSIDE of
        #  the `with` scope: a generator suspended inside of it would keep
        #  the SIGINT handler swapped while the consumer is running
        yield obs_, act_, rew  # XXX no underscore in `rew`!
        # XXX SCIP has a complex node selection strategy, which even
        #  when set to prioritize DFS, still may switch to an arbitrary
        #  node after branching. For the purpose of this experiment
        #  we make the worst-case assumption about the transition
        #  function that the next focus node is not at all related
        #  to the prior branching decision. The only assumption is
        #  that the returned reward reflects the quality of the branching
        #  decision

    # no need to yield anything on fin=True, since ecole's terminal
    #  observation is None
    return nfo

During evaluation we only care about the final `nfo` data from the branching env, as it contains the post-search tree stats.

In [None]:
def evaluate(
    p: Model, env: Branching, branchrule: BranchRule, *, delay: bool = True
) -> dict[str, float]:
    """Run a full rollout and return the stats that it ends with.

    The result is the final info of the rollout, extended by the number
    of branching decisions made, stored under `n_requests`.
    """
    it = rollout(p, env, branchrule, delay=delay)

    # the generator is advanced by hand, because its return value is
    #  delivered through the `StopIteration` exception
    n_steps = 0
    while True:
        try:
            next(it)

        except StopIteration as e:
            return dict(n_requests=n_steps, **e.value)

        n_steps += 1

We use infinite problem generators, for which we implement a continuous rollout wrapper

In [None]:
def continuous_rollout(
    it: Iterable[Model], env: Branching, branchrule: BranchRule
) -> Iterable:
    """Chain the rollouts of the problems from `it` into a single stream.

    Each problem is rolled out in `env` with `branchrule`, which is passed
    to `rollout` as is, so it must be a branchrule factory.
    """
    # we use for-loop in case an infinite generator is actually finite
    for p in it:
        yield from rollout(p, env, branchrule)

A problem instance server for the multithreaded version

In [None]:
class Job(NamedTuple):
    """A unit of work for a rollout worker."""

    # the problem instance to roll out
    p: Model
    # the extra keyword arguments for the branchrule calls
    parameters: dict

A multithreaded version for especially slow branchrules

In [None]:
from queue import Full as QueueFull


class RolloutPool:
    """A pool of threads that roll out jobs and share one output buffer.

    Each worker makes its own env, takes `Job`-s from the `jobs` queue,
    and puts the data yielded by `rollout` into a bounded output queue.
    Iterating over the pool yields that data. An exception in any worker
    terminates the pool and is re-raised by the iterator.

    Parameters
    ----------
    jobs : Queue
        The queue that the `Job`-s are taken from.
    factories : tuple of (Callable, BranchRule) pairs
        One pair per worker: a factory that makes the worker's env, and
        the branchrule to roll out with.
    maxsize : int
        The capacity of the output buffer.
    timeout : float
        The time in seconds for which a blocking queue operation waits
        before the termination flag is checked again.
    """

    def __init__(
        self,
        jobs: Queue,
        factories: tuple[Callable, BranchRule],
        *,
        maxsize: int = 64,
        timeout: float = 0.5,
    ) -> None:
        self.jobs, self.timeout = jobs, timeout

        # spawn branching workers
        self.workers = []
        for factory, branchrule in factories:
            t = Thread(target=self.worker, args=(factory, branchrule), daemon=True)
            self.workers.append(t)

        self.errors, self.output = Queue(), Queue(maxsize)
        self.is_finished = Event()

    def start(self) -> None:
        """Start all worker threads."""
        for t in self.workers:
            t.start()

    def stop(self) -> None:
        """Raise the termination flag and wait for the workers to end."""
        self.is_finished.set()
        for t in self.workers:
            t.join()

        self.workers.clear()

    def _emit(self, output) -> bool:
        """Put the output into the buffer, unless the pool is terminated.

        Returns True if the output was delivered, and False if the pool
        was terminated before room could be found for it.
        """
        while not self.is_finished.is_set():
            try:
                self.output.put(output, True, timeout=self.timeout)

            except QueueFull:
                continue

            return True

        return False

    def worker(self, factory: Callable, branchrule: BranchRule) -> None:
        """Serve the jobs until the pool is terminated."""
        try:
            # the env is made inside the `try`, so that a failure of
            #  the factory is reported like any other worker error
            env = factory()
            while not self.is_finished.is_set():
                # busy-check the termination flag until we receive a job
                try:
                    input = self.jobs.get(True, timeout=self.timeout)

                except QueueEmpty:
                    continue

                assert isinstance(input, Job)

                # do a rollout, sending the results into the buffer
                for output in rollout(
                    input.p,
                    env,
                    branchrule,  # XXX get a new do_branch on each new rollout
                    delay=False,
                    kwargs=input.parameters,
                ):
                    # abandon the rollout as soon as the pool is terminated,
                    #  instead of stepping through the rest of the search
                    if not self._emit(output):
                        break

        except BaseException as e:
            self.is_finished.set()
            self.errors.put(e)

    def __iter__(self) -> Iterable:
        # keep regularly checking the termination flag until
        #  we receive a result
        while not self.is_finished.is_set():
            try:
                yield self.output.get(True, timeout=self.timeout)

            except QueueEmpty:
                continue

        # the pool has been terminated: report the first worker error
        self.stop()
        if not self.errors.empty():
            raise self.errors.get()

<br>

## the Architecture

In [None]:
from torch import nn

from torch.nn import functional as F

A tracer hook to help with debugging live models

```python
hook = model.register_forward_pre_hook(tracer)

bt, by = next(feed)
model(bt, by)

...

hook.remove()
```
- _judicious placement_ of breakpoints `b` is advised so that continuation `c` would work
    - try `n`, `b 1208`
- otherwise use `n` for next, `s` for step inside, `u/d` to move between stack frames

In [None]:
def tracer(module, input) -> None:
    """A forward pre-hook that drops into the debugger."""
    from pdb import set_trace

    set_trace()

A good old trusted MLP

In [None]:
def mlp(activation: type = nn.ReLU, /, *n_dims: int) -> nn.Sequential:
    """Stack linear layers of the given sizes with activations in between.

    Consecutive pairs of `n_dims` give the input and output sizes of the
    linear layers. There is no activation after the last layer.
    """
    stack = []
    for n_in, n_out in zip(n_dims[:-1], n_dims[1:]):
        # the activation goes between the layers only
        if stack:
            stack.append(activation())

        stack.append(nn.Linear(n_in, n_out))

    return nn.Sequential(*stack)

The original message passing architecture is
$$
x^k_i
    = \gamma\bigl(
        x^{k-1}_i,
        \operatorname{\diamond} \bigl(
            \bigl\{
                \phi(x^{k-1}_i, x^{k-1}_j, e_{ij})
                \colon j \in G_i
            \bigr\}
        \bigr)
    \bigr)
    \,, $$

where $\diamond$ is a permutation-invariant set-to-real _aggregation_ operator,
$\phi$ is the _message_ function, $\gamma$ is the _output_ function, and $G_i$
is the graph neighborhood of the vertex $i$.

Below we implement a variant of the message passing where the _aggregation_ and _message_ operations are based on multi-headed cross attention. This means that $\diamond$ and $\phi$ are _fused_:
$$
    (\diamond \circ \phi)
    \colon \mathbb{R}^d \times \bigl(\mathbb{R}^d\bigr)^* \to \mathbb{R}^d
    \colon (x_i, \{x_j\colon j \in G_i\}) \mapsto
        \sum_{j \in G_i} p_j \phi^v(x_j)
    \,. $$
with $\log p_j \propto \phi^q(x_i)^\top \phi^k(x_j)$ for $j \in G_i$.

Implement a special layer that computes message passing with cross-attention between the parts of a bipartite graph. We make heavy use of `torch-scatter` operations

In [None]:
from math import sqrt

from einops import rearrange
from torch_scatter import scatter_softmax, scatter_sum


class BipartiteMHXA(Module):
    """Multi-headed cross attention over the links of a bipartite graph.

    Each row of `input` makes a query, and attends to the rows of `other`
    that it is linked to, which supply the keys and the values.
    """

    def __init__(self, n_embed: int = 32, n_heads: int = 1) -> None:
        super().__init__()
        self.n_heads = n_heads

        # the queries are projected from `input`
        self.p_q = nn.Linear(n_embed, n_embed, bias=False)

        # the keys and the values are jointly projected from `other`
        self.p_kv = nn.Linear(n_embed, 2 * n_embed, bias=False)

        # the layer that mixes the outputs of the heads
        self.out = nn.Linear(n_embed, n_embed)

    def forward(
        self,
        input: Tensor,
        other: Tensor,
        coupling: tuple[Tensor, Tensor],
        weights: Tensor = None,
    ) -> Tensor:
        """Cross attention from `input` (query) to `other` (keys and values)

        The pair `coupling` lists the links: its k-th entries `(u, v)` mean
        that `input[u]` attends to `other[v]`. Link `weights` are not
        supported, and must be `None`.
        """
        if weights is not None:
            raise NotImplementedError

        # dst (in input) <<-- src (in other)
        dst, src = coupling
        n_heads = self.n_heads

        # split the projections into heads: the queries are N x H x 1 x F,
        #  the keys are N x H x F x 1, and the values are N x H x F
        queries = rearrange(self.p_q(input), "N (h f) -> N h () f", h=n_heads)
        keys, values = self.p_kv(other).chunk(2, -1)
        keys = rearrange(keys, "N (h f) -> N h f ()", h=n_heads)
        values = rearrange(values, "N (h f) -> N h f", h=n_heads)

        n_queries, scale = len(queries), sqrt(queries.shape[-1])

        # the scaled dot-product score of every link
        # XXX indexing by `dst` and `src` materializes potentially
        #  large tensors!
        logits = torch.matmul(queries[dst], keys[src]).div(scale).squeeze(-1)

        # the softmax is taken over the links that share the same query
        attn = scatter_softmax(logits, dst, dim=0, dim_size=n_queries)

        # each query gets the attention-weighted sum of its links' values
        mixed = scatter_sum(attn * values[src], dst, dim=0, dim_size=n_queries)
        return self.out(rearrange(mixed, "... h f -> ... (h f)"))

A bipartite transformer block
* we opt for post-norm architecture
* [Zhang et al. (2022)](https://arxiv.org/abs/2206.11925.pdf) mention that norm-first layer norm may impact expressivity, but do not analyze the case when the norm is at the other end
* while [Xiong et al. (2020)](https://arxiv.org/abs/2002.04745.pdf) claim that pre-LN stabilizes the gradient, whereas with post-LN the expected magnitude of the gradients grows with the depth
  - strange, but this appears to contradict their own theoretical results

In [None]:
class BipartiteBlock(Module):
    """A transformer block with bipartite cross attention.

    The block is a residual cross attention sub-layer followed by a
    residual point-wise feed-forward sub-layer. The layer norm is applied
    to the input of each sub-layer if `b_norm_first` is True (pre-norm),
    and to the output of each residual sum otherwise (post-norm).
    """

    def __init__(
        self,
        n_embed: int = 32,
        n_heads: int = 1,
        p_drop: float = 0.2,
        b_norm_first: bool = False,
    ) -> None:
        super().__init__()
        self.b_norm_first = b_norm_first

        # the attention sub-layer with its dropout and norm
        self.attn = BipartiteMHXA(n_embed, n_heads)
        self.do_1 = nn.Dropout(p_drop)
        self.ln_1 = nn.LayerNorm(n_embed)

        # the feed-forward sub-layer with its dropout and norm
        n_hidden = 4 * n_embed
        self.pwff = nn.Sequential(
            nn.Linear(n_embed, n_hidden),
            nn.LeakyReLU(),
            nn.Dropout(p_drop),
            nn.Linear(n_hidden, n_embed),
        )
        self.do_2 = nn.Dropout(p_drop)
        self.ln_2 = nn.LayerNorm(n_embed)

    def forward(
        self,
        input: Tensor,
        other: Tensor,
        coupling: tuple[Tensor, Tensor],
        weights: Tensor = None,
    ) -> Tensor:
        """Update `input` by attending to `other` along the `coupling`."""

        def attend(h: Tensor) -> Tensor:
            return self.do_1(self.attn(h, other, coupling, weights))

        def feedforward(h: Tensor) -> Tensor:
            return self.do_2(self.pwff(h))

        if not self.b_norm_first:
            # post-norm: normalize the output of each residual sum
            hidden = self.ln_1(input + attend(input))
            return self.ln_2(hidden + feedforward(hidden))

        # pre-norm: normalize the input of each sub-layer
        hidden = input + attend(self.ln_1(input))
        return hidden + feedforward(self.ln_2(hidden))

The full model

In [None]:
from torch_scatter import scatter_log_softmax, scatter_logsumexp
from torch_scatter import scatter_max, scatter_mean


class NeuralVariableSelector(Module):
    """A bipartite attention network that scores the variables of a node.

    The variable and constraint features are embedded by MLPs, refined
    by a ladder of `BipartiteBlock` pairs, and every variable is then
    scored by an MLP head from its own embedding joined with the mean
    variable embedding of its graph.
    """

    def __init__(
        self,
        n_dim_vars: int = 19,
        n_dim_cons: int = 5,
        n_embed: int = 32,
        n_heads: int = 1,
        n_blocks: int = 1,
        p_drop: float = 0.2,
        b_norm_first: bool = False,
    ) -> None:
        # NOTE(review): the defaults of `n_dim_vars` and `n_dim_cons` are
        #  presumably the feature sizes of `NodeBipartite` -- TODO confirm
        super().__init__()
        # the feature encoders (the edge values are not encoded)
        self.encoder = nn.ModuleDict(
            dict(
                vars=mlp(nn.LeakyReLU, n_dim_vars, 4 * n_embed, n_embed),
                cons=mlp(nn.LeakyReLU, n_dim_cons, 4 * n_embed, n_embed),
                # edge=mlp(nn.LeakyReLU, 1, 4 * n_embed, n_embed)
                edge=None,
            )
        )

        # the blocks that update the cons by attending to the vars
        blk = [
            BipartiteBlock(n_embed, n_heads, p_drop, b_norm_first)
            for _ in range(n_blocks)
        ]
        self.block_cv = nn.ModuleList(blk)
        # the blocks that update the vars by attending to the cons
        blk = [
            BipartiteBlock(n_embed, n_heads, p_drop, b_norm_first)
            for _ in range(n_blocks)
        ]
        self.block_vc = nn.ModuleList(blk)

        # the head takes a variable's embedding joined with the graph's
        self.head = mlp(nn.LeakyReLU, 2 * n_embed, 4 * n_embed, 1)

    def forward(
        self, input: BatchObservation, target: Tensor = None
    ) -> tuple[Tensor, dict[str, Tensor]]:
        """Score the variables and, given a target, compute the loss terms.

        Returns the raw score of every variable in the collated batch and
        a dict of per-item terms, which is empty if `target` is `None`.
        Otherwise it has the negative log-likelihood of the target
        variable (`neg_target`), the negated log-probability of the
        variables outside of the action set (`neg_actset`), and the
        entropy of the distribution over the variables (`entropy`).
        """
        jc, jv = input.ctov_ij

        # encode the vars and cons features
        cons = self.encoder.cons(input.cons)
        vars = self.encoder.vars(input.vars)
        edge = None
        if self.encoder.edge is not None:
            edge = self.encoder.edge(input.ctov_v)

        # bipartite ladder
        for m_cv, m_vc in zip(self.block_cv, self.block_vc):
            cons = m_cv(cons, vars, (jc, jv), edge)
            vars = m_vc(vars, cons, (jv, jc), edge)

        # compute the graph-level variable embedding
        graph = scatter_mean(vars, input.inx_vars, 0)

        # get the raw-logit scores of each variable
        x = torch.cat((vars, graph[input.inx_vars]), -1)
        raw = self.head(x).squeeze(-1)

        # optionally compute the loss
        if target is None:
            return raw, {}  # raw.new_full((), float("nan"))

        # get the log-probas and compute the loss terms
        logp_vars = scatter_log_softmax(raw, input.inx_vars)

        # get the log-likelihood of the target variables
        # XXX `ptr_vars` corrects the variable indices in the target batch
        j = input.ptr_vars[:-1] + target.to(logp_vars.device)
        loglik_target = logp_vars[j]

        # get the log likelihood of the forbidden variable set
        mask = torch.ones_like(logp_vars, dtype=bool)
        mask[input.actset] = False
        logp_forbidden = scatter_logsumexp(
            logp_vars[mask],
            input.inx_vars[mask],  # XXX group by batch assignment
            dim=0,
            # It is possible that the forbidden mask to be
            #  empty in certain nodes of certain instances
            # XXX ptr_* is one longer than the batch size
            dim_size=len(input.ptr_vars) - 1,
        )

        # XXX no need to compute either the act-set sizes
        #  or the vars sizes with `input.ptr_vars.diff()`
        # n_forbidden_size = scatter_sum(mask.float(), input.inx_vars, 0)
        # NOTE(review): despite its name, this is the log-probability of
        #  the FORBIDDEN set -- confirm the sign with which it is used
        loglik_actset = logp_forbidden  # .div(n_forbidden_size)

        # compute the discrete entropy $- \sum_j \pi_j \log \pi_j$
        # XXX For input x and target y, `F.kl_div(x, y, "none", log_target=True)`
        #  returns $e^y (y - x)$, correctly handling infinite `-ve` logits.
        #  `log_target` relates to `y` being logits (True) or probas (False).
        # XXX `.new_zeros(())` creates a scalar zero (yes, an EMPTY tuple)
        # XXX from nle-toolbox
        zero = logp_vars.new_zeros(())
        p_log_p = F.kl_div(zero, logp_vars, reduction="none", log_target=True)
        entropy = scatter_sum(p_log_p, input.inx_vars, 0).neg()

        return raw, dict(
            neg_target=loglik_target.neg(),
            neg_actset=loglik_actset.neg(),
            entropy=entropy,
        )

    @torch.inference_mode()
    def predict(self, input: BatchObservation) -> Tensor:
        """Pick the best-scoring allowed variable of each batch item.

        Returns a tensor with the index of the picked variable of each
        item, counted from the start of the item's own variables.
        """
        scores, _ = self.forward(input)

        # mask forbidden variables
        mask = torch.ones_like(scores, dtype=bool)
        mask[input.actset] = False
        scores.masked_fill_(mask, float("-inf"))

        # pick the maximally probable action in each batch item
        _, j = scatter_max(scores, input.inx_vars, 0)
        if mask[j].any():
            raise RuntimeError("Empty actset!")  # sanity check

        # subtract the base index of each batch item
        j -= input.ptr_vars[:-1]
        return j

<br>

Some procs for handling dicts

In [None]:
def transpose_dict(dict: dict[..., dict]) -> dict[..., dict]:
    """Swap the two key levels of a dict of dicts.

    Turns `{outer: {inner: value}}` into `{inner: {outer: value}}`. The inner
    dicts need not share the same keys. Key order follows first occurrence.
    """
    flipped = {}
    for outer_key, row in dict.items():
        for inner_key in row:
            # fetch (or lazily create) the column for this inner key
            column = flipped.setdefault(inner_key, {})
            assert outer_key not in column
            column[outer_key] = row[inner_key]

    return flipped


def collate_dict(records: list[dict]) -> dict[..., list]:
    """Stack a list of records into per-field arrays.

    Every value is appended to the column of its field in the order the
    records are visited, and each column is then converted with `np.array`.
    Records are assumed to have no missing fields: a record that lacks
    a field simply contributes nothing to that column.
    """
    columns = {}
    for row in records:
        for name, cell in row.items():
            if name not in columns:
                columns[name] = []

            columns[name].append(cell)

    return {name: np.array(cells) for name, cells in columns.items()}

<br>

## Training

Allow for 100k samples

In [None]:
# sample budget of the data stream and the size of each mini-batch
n_total, n_batch_size = 100_000, 16  # XXX 100k is too long
# capacity of the shuffle reservoir
n_reservoir = 512  # XXX 128 reservoir was ok for CAuc

# coefficients of the act-set and the entropy terms in the training loss
C_neg_actset = 0.0  # 1e-3  # XXX the original CAuc used to have 0.0
C_entropy = 0.0  # 1e-2  # XXX was set to zero in the first CAuc

# use a seed sequence with a fixed entropy pool
ss = SeedSequence(83278314352113072500167414370310027453)
# ss = SeedSequence(None)  # use `ss.entropy` for future reproducibility

Pipe the generator that mixes several CO problems into the continuous rollout iterator.

In [None]:
from toybnb.scip.ecole.benchmarks import gasse2019
from ecole.instance import CombinatorialAuctionGenerator


# init the branching env
# NOTE(review): every `spawn_one(ss)` call draws from the shared seed
#  sequence, so presumably reordering the statements of this cell changes
#  which seed each component receives -- confirm before rearranging
env = make_env(spawn_one(ss))

# CAuc(100, 500), CAuc(50, 250)
# each instance generator gets its own seed derived from `ss`
gens = [
    CombinatorialAuctionGenerator(100, 500, rng=ecole_seed(spawn_one(ss))),
    CombinatorialAuctionGenerator(50, 250, rng=ecole_seed(spawn_one(ss))),
]

# Use co problems from Gasse 2019
# gens = transpose_dict(gasse2019(spawn_one(ss)))
# gens = gens["train"].values()
# combine the generators into a single instance iterator
itco = mixer(*gens, seed=spawn_one(ss))

Set up the rollout observation feed

In [None]:
from functools import partial


# the `continuous_rollout` branch is switched off: the pool-based feed
#  in the `else` branch is the one that actually runs
if False:
    feed = continuous_rollout(itco, env, strongbranch())

else:

    def t_job_generator(it: Iterable, jobs: Queue) -> None:
        """Wrap every sample of `it` into a `Job` and put it on `jobs`."""
        for sample in it:
            jobs.put(Job(sample, {}))

    n_jobs = 6
    # bounded queue: `put` blocks the producer once it is full
    jobs = Queue(max(16, 2 * n_jobs))
    # background daemon thread feeding instances from `itco` into the queue
    source = Thread(
        target=t_job_generator,
        args=(itco, jobs),
        daemon=True,
    )
    source.start()

    # one (env factory, branchrule) pair per worker: each factory is bound
    #  to its own child seed, while the `branchrule` object is shared
    branchrule = strongbranch()
    pool = RolloutPool(
        jobs,
        [(partial(make_env, fork), branchrule) for fork in ss.spawn(n_jobs)],
        maxsize=64,
        timeout=0.5,
    )
    pool.start()

    # the observation feed is whatever iterating over the pool yields
    feed = iter(pool)

Then feed the branching observation data into a limiter, a shuffler, and then a batcher.

In [None]:
from tqdm import trange

# set up the data stream
# XXX if we put a limiter on the source iter, then shuffle's reservoir
#  would consume `n_reservoir` items, which are never going to be yielded,
#  if the source is infinite. In this case, instead, we should put
#  a limiter on the `shuffle` itself.
# stages are chained in this order: feed -> limit -> shuffle -> batch
it = feed
it = limit(it, trange(n_total, ncols=70))
it = shuffle(it, n_reservoir, spawn_one(ss))
it = batch(it, n_batch_size)

Set up the model and optimizers

In [None]:
# build the model to be trained
# NOTE(review): the meaning of the positional hyper-parameters is defined
#  by `NeuralVariableSelector.__init__`, which is not shown here -- confirm
mod = NeuralVariableSelector(19, 5, 32, 2, 1, 0.2)
# hook = mod.register_forward_pre_hook(tracer)

# AdamW over all model parameters with a fixed learning rate
optim = torch.optim.AdamW(mod.parameters(), lr=3e-4)

Train an IL model

In [None]:
from IPython.display import clear_output


# Imitation-learning loop: one optimizer step per batch from the data stream,
#  with the plot of the logged loss terms redrawn after every step
device = torch.device("cpu")
log = dict()
for bt in it:
    # collate the batch
    # each batch element is an (observation, action, reward) triple
    obs, act, rew = zip(*bt)
    bx = collate(obs, device=device)
    act = as_tensor(np.array(act, dtype=int), device=device)
    rew = np.array(rew, dtype=np.float32)

    # forward pass
    mod.train()
    _, terms = mod(bx, target=act)
    # reduce every loss term to a scalar by averaging
    terms = {k: v.mean() for k, v in terms.items()}

    # log likelihoods and entropy have the same units (and thus scale)
    loss = (
        # (min) -ve log-likelihood
        terms["neg_target"]
        # (max) entropy
        - C_entropy * terms["entropy"]
        # (max) -ve log-likelihood of forbidden actions
        - C_neg_actset * terms["neg_actset"]
    )

    # backprop
    mod.zero_grad(True)
    loss.backward()
    optim.step()

    # leave the model in eval mode between steps and record the scalar terms
    mod.eval()
    do_add({k: float(v) for k, v in terms.items()}, log)

    # replace the previous cell output with a fresh plot of the logged series
    clear_output(True)
    fig, ax0 = plt.subplots(1, 1, figsize=(5, 2), dpi=200, sharex=True)
    plot_stats(ax0, **log)
    ax0.set_title("terms")
    ax0.legend(loc="lower left", fontsize="xx-small")

    plt.show()

<br>

In [None]:
from time import strftime

# Save a timestamped checkpoint of the model weights. The commented-out
#  names below are the file names that were used for earlier runs.
torch.save(
    dict(
        __dttm__=strftime("%Y%m%d-%H%M%S"),
        state_dict=mod.state_dict(),
    ),
    #     "CAuc100500-50250-good.pt"
    #     "second-good.pt",
    #     "second-good--zero-actset--logp.pt",
    #     "second-CAuc100500-50250-good-no-reg.pt",
    #     "second-CAuc50250-good-no-reg.pt",
    #     "second-CAuc50250-good-no-reg-sum.pt",
    #     "third-CAuc-dropout.pt",
    #     "third-CAuc-dropout__large.pt",
    "fouth-All-dropout.pt",
)

In [None]:
# Re-create the model and restore its weights from the saved checkpoint.
# The constructor arguments have to yield the same parameter shapes as
#  the model that produced the checkpoint for `load_state_dict` to succeed.
mod = NeuralVariableSelector(19, 5, 32, 2, 1, 0.2)

ckpt = torch.load("fouth-All-dropout.pt")
mod.load_state_dict(ckpt["state_dict"])

List all the rules

In [None]:
# branching rules to evaluate, keyed by the name under which their results
#  are collected; only the trained model is active, the rest are disabled
rules = {
    "trained": ml_branchrule(mod),
    #     "strongbranch": strongbranch(),
    #     "pseudocostbranch": strongbranch(True),
}

Evaluate the branchrules in parallel threads

In [None]:
def t_evaluate(rk: int, ws: int, ss: SeedSequence, rx: Queue, tx: Queue) -> None:
    """Thread worker that evaluates every rule in `rules` on queued problems.

    Problems are taken from `rx` until an item that is not a `Model` arrives,
    which is treated as the shutdown signal. For each problem a dict, mapping
    the rule name to the output of `evaluate` on a fresh `copy_orig()` of the
    problem, is put on `tx`. The `rk` and `ws` arguments are accepted,
    but not used.
    """
    # each worker owns a private env, built from its own seed
    env = make_env(ss, presolve=True)

    # keep serving until a non-`Model` sentinel is received
    while isinstance(p := rx.get(), Model):
        results = {
            nom: evaluate(p.copy_orig(), env, rule, delay=False)
            for nom, rule in rules.items()
        }
        tx.put(results)

Let's do this

In [None]:
from tqdm import trange

n_jobs = 8

# spawn thread workers consuming from `rx` and producing into `tx`
# `rx` is bounded, so the producer below blocks whenever workers fall behind
threads, rx, tx = [], Queue(2 * n_jobs), Queue()
for rk, sk in enumerate(ss.spawn(n_jobs)):
    args = rk, n_jobs, sk, rx, tx
    t = Thread(target=t_evaluate, args=args, daemon=True)
    threads.append(t)
    t.start()

# start filling the rx queue
print("Populating the jobs queue")
# NOTE(review): unlike the training generators, this one is created without
#  an `rng=` derived from `ss`, so the evaluation instances are presumably
#  not tied to the seed sequence -- confirm whether that is intended
it_co = CombinatorialAuctionGenerator(100, 500)
# zipping with the progress bar caps the number of instances at 1000
for p, _ in zip(it_co, trange(1000, ncols=70)):
    rx.put(p)

# submit termination signals through the queue and wait for shutdown
# one sentinel per worker: any item that is not a `Model` stops a worker
print("Shutting workers down")
for t in threads:
    rx.put(StopIteration)

for t in threads:
    t.join()

# collect the evaluation results
# all workers have been joined, so nothing is added to `tx` anymore
nfos = {}
while not tx.empty():
    do_add(tx.get(), nfos)

* `n_nodes`, `n_requests`
* `n_lpiter`
* `f_soltime`

<br>

In [None]:
from time import strftime

# raise RuntimeError

# Collate the per-rule evaluation records into arrays and save them together
#  with a timestamp. The commented-out names below are the file names that
#  were used for earlier runs.
__dttm__ = strftime("%Y%m%d-%H%M%S")
metrics = {k: collate_dict(nfo) for k, nfo in nfos.items()}
torch.save(
    dict(__dttm__=__dttm__, metrics=metrics),
    #     "cauc__il_xattn__zero_logactst_reg.pt",
    #     "cauc__il_xattn__nonzero_logactst_reg.pt",
    #     "cauc__il_xattn__no-reg.pt",
    #     "cauc50250__il_xattn__no-reg.pt",
    #     "cauc50250__il_xattn__no-reg_sum.pt",
    #     "cauc__il_xattn__dropout.pt",
    #     "cauc__il_xattn__dropout__large.pt",
    "all__il_xattn__dropout.pt",
)

<br>

In [None]:
# Load the saved evaluation metrics of every experiment variant.
# NOTE(review): these files hold numpy arrays (see `collate_dict`), while
#  recent torch releases default `torch.load` to `weights_only=True` --
#  confirm that they still load on the torch version in use
m_zer = torch.load("cauc__il_xattn__zero_logactst_reg.pt")["metrics"]
m_nnz = torch.load("cauc__il_xattn__nonzero_logactst_reg.pt")["metrics"]
m_org = torch.load("cauc__il_xattn__no-reg.pt")["metrics"]
m_smol = torch.load("cauc50250__il_xattn__no-reg.pt")["metrics"]
m_smol_sum = torch.load("cauc50250__il_xattn__no-reg_sum.pt")["metrics"]
m_dropout = torch.load("cauc__il_xattn__dropout.pt")["metrics"]
m_large = torch.load("cauc__il_xattn__dropout__large.pt")["metrics"]
m_all = torch.load("all__il_xattn__dropout.pt")["metrics"]

In [None]:
# pick the evaluation series to compare across the experiments
# series = "n_nodes", "n_requests", "f_soltime"
series = "n_nodes"

# one sample of the chosen series per curve on the plot
metric = {
    "nonzero": m_nnz["trained"][series],
    "zero": m_zer["trained"][series],
    "original": m_org["trained"][series],
    "small": m_smol["trained"][series],
    "small-sum": m_smol_sum["trained"][series],
    "dropout": m_dropout["trained"][series],
    "large": m_large["trained"][series],
    "all": m_all["trained"][series],
    # the baselines are pooled over the runs from which they are read here
    "strongbranch": np.r_[
        m_nnz["strongbranch"][series],
        m_zer["strongbranch"][series],
        m_org["strongbranch"][series],
        m_smol["strongbranch"][series],
        m_smol_sum["strongbranch"][series],
    ],
    "pseudocostbranch": np.r_[
        m_nnz["pseudocostbranch"][series],
        m_zer["pseudocostbranch"][series],
        m_org["pseudocostbranch"][series],
        m_smol["pseudocostbranch"][series],
        m_smol_sum["pseudocostbranch"][series],
    ],
}

# line colour of each curve, keyed like `metric`
colors = {
    "strongbranch": "black",
    "pseudocostbranch": "C1",
    "nonzero": "C0",
    "zero": "C2",
    "original": "C3",
    "small": "C4",
    "small-sum": "C5",
    "dropout": "C6",
    "large": "C7",
    "all": "C8",
}
# legend label of each curve, keyed like `metric`
xlat = {
    "strongbranch": "SB",
    "pseudocostbranch": "PC",
    "zero": "IL-XAttn-ent",
    "original": "IL-XAttn-original",
    "nonzero": "IL-XAttn-act-set-ent",
    "small": "IL-XAttn-50250",
    "small-sum": "IL-XAttn-50250-sum",
    "dropout": "IL-XAttn-dropout",
    "large": "IL-XAttn-dropout-64",
    "all": "all-IL-XAttn-dropout",
}

In [None]:
# Draw the pp-curve of every series against the baseline one
base = "strongbranch"

fig, ax = plt.subplots(1, 1, figsize=(5, 5), dpi=200)
for name, data in metric.items():
    # the baseline is represented by the diagonal drawn below
    if base == name:
        continue
    # the baseline is on the x-axis, the compared series is on the y-axis
    p, q = pp_curve(x=metric[base], y=data, num=None)
    ax.plot(p, q, label=xlat[name], c=colors[name])

# the diagonal is the pp-curve of the baseline against itself
ax.plot((0, 1), (0, 1), c=colors[base], zorder=10, alpha=0.25, label=xlat[base])
ax.set_xlim(-0.025, 1.025)
ax.set_ylim(-0.025, 1.025)
ax.set_aspect(1.0)
ax.legend(loc="best", fontsize="xx-small")

# `__dttm__` is the timestamp that was set when the metrics were saved
fig.savefig(f"cauc__{__dttm__}.pdf")

Test it

In [None]:
# Plot the per-variable branching probabilities that the model assigns on
#  the batch `bx`, shading the span of variables of each batch item.
mod.eval()
# the forward pass returns a `(scores, terms)` pair (`predict` unpacks it
#  the same way), so only the raw scores are kept here
scores, _ = mod(bx)

fig, ax = plt.subplots(1, 1, figsize=(7, 3), dpi=300)

# softmax within each batch item: the grouping index must come from the very
#  same batch `bx`, from which the scores were computed
probas = scatter_softmax(scores, bx.inx_vars, 0)
ax.plot(probas.detach().numpy())
ax.set_xlim(-1.5, len(scores) + 1.5)

# shade the variables of consecutive batch items with alternating colours
# XXX a separate name is used for the palette in order not to clobber
#  the global `colors` dict, which the pp-curve plot depends on
n_spans = int(bx.inx_vars.max()) + 1
f, span_colors = 0, ("red", "blue")
for s in range(n_spans):
    # first and last flat index of the variables that belong to item `s`
    jx = (bx.inx_vars == s).nonzero()
    a, b = int(jx.min()), int(jx.max())

    ax.axvspan(a, b, color=span_colors[f], alpha=0.05)
    f = 1 - f

<br>