# A dqn agent

In [None]:
import torch
import numpy as np

from matplotlib import pyplot as plt

The pp-curve drawing procedure. We compare distributions using the pp-curve, which is analogous to the ROC curve: the pp-curve compares two independent distributions, while the ROC curve compares the true-positive and false-positive distributions.

In [None]:
from numpy import ndarray

from toybnb.scip.ecole.il.plotting import pp_curve

A simple viz for tracking loss and other runtime series

In [None]:
from toybnb.scip.ecole.il.plotting import plot_series

`SeedSequence` lacks a `spawn_one` method, so we add a helper that spawns a single child sequence

In [None]:
import numpy as np
from numpy.random import default_rng, SeedSequence


def spawn_one(ss: SeedSequence) -> SeedSequence:
    """Spawn exactly one child seed sequence from `ss` and return it."""
    (child,) = ss.spawn(1)
    return child

A function to add a dict record to a table (dict-of-lists)

In [None]:
from typing import Callable


def do_add(
    record: dict, to: dict[..., list], transform: Callable[..., dict] = None
) -> dict:
    """Append `record` as a new row of the column-oriented table `to`.

    If `transform` is callable, the row that is actually appended is
    `transform(**record)`. Each field is appended to the list stored under
    its name in `to`, and that list is created on first use. The table is
    modified in place, and the ORIGINAL `record` is returned.
    """
    row = transform(**record) if callable(transform) else record

    # no padding: a field missing from this row leaves its column shorter
    for key, value in row.items():
        to.setdefault(key, []).append(value)

    return record

<br>

## Composing the datastreams

A reservoir-based shuffle for data from (possibly infinite) data streams

In [None]:
from typing import Iterable


def shuffle(
    it: Iterable,
    n_size: int = 1024,
    seed: int = None,
) -> Iterable:
    """Shuffle the values from the iterable using a buffer of `n_size` items.

    The first `n_size` values fill the buffer. After that, every new value
    evicts and yields a uniformly chosen buffered one. When `it` ends, the
    remaining buffer is shuffled and flushed, so every value is yielded
    exactly once. A non-positive `n_size` disables buffering and yields the
    values in their original order. `seed` is passed to `default_rng`.
    """
    if n_size < 1:
        # an empty buffer cannot mix anything (and `rng.choice(0)` would
        #  raise), so pass the values through unchanged
        yield from it
        return

    rng, reservoir = default_rng(seed), []
    for sample in it:
        # stack elements until the reservoir is full
        if len(reservoir) < n_size:
            reservoir.append(sample)
            continue

        # replace a random element with the new sample
        ix = rng.choice(n_size)
        yield reservoir[ix]
        reservoir[ix] = sample

    # re-shuffle the remaining samples (in the case
    #  the sequence was too short for proper mixing)
    rng.shuffle(reservoir)
    yield from reservoir

Batch the data in a sequence

In [None]:
def batch(it: Iterable, n_size: int = 16) -> Iterable[list]:
    """Group the values from the iterable into lists of `n_size` items.

    The last list holds the residual values and may be shorter. Every
    yielded list is a fresh object, so consumers may safely keep it.
    """
    chunk = []
    for sample in it:
        chunk.append(sample)

        # produce the batch when it's full
        if len(chunk) >= n_size:
            yield chunk
            # start a NEW list: clearing the yielded one in place would
            #  mutate the batch the consumer is still holding
            chunk = []

    # don't forget the residual batch
    if chunk:
        yield chunk

a handy sequence limiter

In [None]:
from typing import Union


def limit(it: Iterable, limiter: Union[int, Iterable]) -> Iterable:
    """Limit the length of the sequence to at most `limiter` values.

    `limiter` is either an int `n`, which is treated as `range(n)`, or an
    iterable whose length caps the output, e.g. a progress bar such as
    `trange(n)`. No value beyond the cap is ever drawn from `it`.
    """
    limiter = range(limiter) if isinstance(limiter, int) else limiter
    # draw from `limiter` FIRST: when it runs out, `zip` stops without
    #  pulling (and silently discarding) one extra value from `it`
    for _, sample in zip(limiter, it):
        yield sample

finally, a generator that mixes data from multiple iterables

In [None]:
def mixer(*its: Iterable, seed: int = None) -> Iterable:
    """Yield a value from an iterable picked at random each time.

    Each step picks one of the not-yet-exhausted iterables uniformly at
    random and yields its next value. Exhausted iterables are dropped, and
    the generator ends once all of them are exhausted. `seed` is passed to
    `default_rng`.
    """
    iters = list(map(iter, its))

    rng = default_rng(seed)
    while iters:
        # draw an index instead of calling `rng.choice(iters)`: the latter
        #  converts the iterators into an ndarray on every draw, which
        #  breaks for sequence-like iterators
        ix = rng.choice(len(iters))
        try:
            yield next(iters[ix])

        except StopIteration:
            del iters[ix]

<br>

## CO and Branching data sources

The following SCIP settings were inherited from [Gasse et al. 2019](https://arxiv.org/abs/1906.01629), [Parsonson et al. 2022](https://arxiv.org/abs/2205.14345), and
[Scavuzzo et al. 2022](https://arxiv.org/abs/2205.11107).

In [None]:
import pyscipopt
import ecole as ec

# from toybnb.scip.ecole.il.env import default_scip_params

A derived branching env that disables SCIP's presolver
- without presolve training becomes much slower

In [None]:
from ecole.environment import Branching

# from toybnb.scip.ecole.il.env import BranchingWithoutPresolve

We want to train a problem-dependent branching heuristic. Strong branching and pseudocost branching have the best problem dependency -- they have access to the problem's geometry, but are slow. Our goal is to train a neural net that extracts geometric info from the parametric representation of a problem at a node.

In [None]:
from ecole.observation import NodeBipartite
from ecole.reward import Constant


from toybnb.scip.ecole.il.env import make_env

We use a special representation of a batch of observations

In [None]:
from typing import NamedTuple
from ecole.observation import NodeBipartiteObs
from torch import Tensor
from numpy import ndarray


from toybnb.scip.ecole.il.data import Observation, BatchObservation

A special collation logic for `Observation`

In [None]:
from torch import as_tensor
from toybnb.scip.ecole.il.data import collate

We source supervised data from the strong branching heuristic

In [None]:
from ecole.observation import StrongBranchingScores, Pseudocosts
from ecole.core.scip import Model, Stage

from toybnb.scip.ecole.il.brancher import BranchRule, BranchRuleCallable
from toybnb.scip.ecole.il.brancher import strongbranch

We also compare to a random-branching expert, although its utility is unclear.

In [None]:
from toybnb.scip.ecole.il.brancher import randombranch

And, finally, a function to use the trained machine learning model to pick branching vars

In [None]:
from torch.nn import Module

from toybnb.scip.ecole.il.brancher import batched_ml_branchrule, ml_branchrule

A branchrule that communicates with a central action server, that attempts to process the requests in batches for efficiency.

In [None]:
from threading import Thread, Event
from queue import Queue, Empty as QueueEmpty

from toybnb.scip.ecole.il.threads import BatchProcessor


class BranchingServer(BatchProcessor):
    """Branching variable Server.

    Hands out one branchrule per env. Each branchrule routes its observations
    through the connection obtained from `BatchProcessor.connect`.
    """

    def connect(self, env: Branching) -> BranchRuleCallable:
        """Spawn a new branchrule for `env`.

        The returned callable gives `None` whenever the env's SCIP model is
        not in the solving stage. Otherwise it submits the observation over
        the connection and returns the reply converted to `int`.
        """
        request = super().connect()

        def _branchrule(obs: Observation, **ignored: dict) -> int:
            # no branching decision is requested outside the solving stage
            return None if env.model.stage != Stage.Solving else int(request(obs))

        return _branchrule

A procedure to seed Ecole's PRNG

In [None]:
from ecole import RandomGenerator

from toybnb.scip.ecole.il.env import ecole_seed

<br>

## The data source proper

During training we need a generator of observations, actions, and rewards from the decision nodes of the BnB tree. During evaluation we only care about the final `nfo` data from the branching environment, as it contains the post-search tree stats. `rollout` returns both kinds of data, and `evaluate` is a handy wrapper around `rollout`.

In [None]:
from toybnb.scip.ecole.il.rollout import rollout, evaluate

We use infinite problem generators, for which we implement a continuous rollout wrapper

In [None]:
def continuous_rollout(
    it: Iterable[Model], env: Branching, branchrule: BranchRule, kwargs: dict = None
) -> Iterable:
    """Chain the `rollout` streams of every problem drawn from `it`.

    All problems are solved in the same `env` with the same `branchrule`
    and `kwargs`. A plain for-loop drives `it`, so a supposedly infinite
    generator that turns out to be finite simply ends the stream.
    """
    for problem in iter(it):
        yield from rollout(problem, env, branchrule, kwargs)

<br>

### A multithreaded version for especially slow problem instances

The inner logic of the parallel job feeder

In [None]:
from queue import Empty, Queue, Full
from threading import Event


def t_feed(
    it: Iterable,
    to: Queue,
    *,
    signal: Event,
    err: Queue,
    timeout: float = 1.0,
) -> None:
    """Keep putting items from the iterable into the queue, until stopped.

    Runs until `it` is exhausted or `signal` is set. `it` may be any
    iterable, such as a list or a generator. Any other exception raised
    while drawing or enqueueing items is put into `err` instead of being
    raised.
    """
    try:
        # `next` needs an iterator; plain iterables (e.g. lists) are
        #  accepted too
        it = iter(it)
        item = next(it)
        while not signal.is_set():
            # wait at most `timeout` so that `signal` is re-checked regularly
            try:
                to.put(item, True, timeout)
            except Full:
                continue

            item = next(it)

    except StopIteration:
        pass

    except BaseException as e:
        err.put_nowait(e)

The body of a parallel rollout worker

In [None]:
from typing import Callable


def t_rollout(
    feed: Queue,
    factory: Callable,
    branchrule: Callable,
    into: Queue,
    *,
    signal: Event,
    err: Queue,
    timeout: float = 1.0,
) -> None:
    """Rollout the `branchrule` on instances from `feed` solved in `factory`,
    saving the observations in `into`

    The env is built once with `factory()`. Every blocking queue operation
    waits at most `timeout` seconds, so `signal` is re-checked regularly,
    and the worker returns once it is set. Any exception is put into `err`.
    """

    def _deliver(item) -> None:
        # retry the put until it succeeds or a stop is requested
        while not signal.is_set():
            try:
                into.put(item, True, timeout)
            except Full:
                continue
            return

    try:
        env = factory()
        while not signal.is_set():
            # poll the feed queue for a new problem instance
            try:
                problem = feed.get(True, timeout)
            except Empty:
                continue

            # the rollout itself also stops early once `signal` is set
            for item in rollout(problem, env, branchrule, {}, stop=signal.is_set):
                _deliver(item)

    except BaseException as e:
        err.put_nowait(e)

The procedure itself

In [None]:
from threading import Thread
from functools import partial


def maybe_raise(err: Queue) -> None:
    """Raise the oldest exception reported in the error queue, if any.

    Uses the public non-blocking `get_nowait` instead of the queue's
    internal `mutex` and `queue` attributes, so the queue's own bookkeeping
    (e.g. waking `put`-ers blocked on a bounded queue) stays consistent.
    """
    try:
        exc = err.get_nowait()
    except Empty:
        return

    raise exc


def multirollout(
    feed: Iterable,
    ss: SeedSequence,
    branchrule: BranchRule,
    n_jobs: int = 1,
) -> Iterable:
    """Stream rollout items for the problems drawn from `feed`.

    With `n_jobs < 2`, everything runs in the calling thread via
    `continuous_rollout`. Otherwise a daemon feeder thread moves problems
    from `feed` into a bounded job queue. Then `n_jobs` daemon workers roll
    them out and push the items into a shared output queue, which this
    generator drains. Each worker builds its env with `make_env` from its
    own child of `ss`.

    Exceptions reported by the threads are re-raised here. On exit (close
    or error) the stop signal is set and all threads are joined.

    NOTE: in threaded mode the generator never ends on its own, even for a
    finite `feed`. It keeps polling until it is closed or a thread reports
    an error.
    """
    if n_jobs < 2:
        # no threads: roll out sequentially in the caller's thread
        child = ss.spawn(1)[0]
        yield from continuous_rollout(feed, make_env(child), branchrule, {})
        return

    # keyword arguments shared by the feeder and every worker
    ctx = {"signal": Event(), "err": Queue(), "timeout": 1.0}
    stop, errors = ctx["signal"], ctx["err"]

    jobs, results = Queue(32), Queue()

    # the feeder thread comes first, then one worker per spawned seed
    threads = [Thread(target=t_feed, args=(feed, jobs), kwargs=ctx, daemon=True)]
    threads.extend(
        Thread(
            target=t_rollout,
            args=(jobs, partial(make_env, child), branchrule, results),
            kwargs=ctx,
            daemon=True,
        )
        for child in ss.spawn(n_jobs)
    )

    for thread in threads:
        thread.start()

    try:
        # relay worker output, surfacing thread errors between polls
        while True:
            maybe_raise(errors)
            try:
                yield results.get(True, timeout=5.0)

            except Empty:
                continue

    finally:
        stop.set()
        for thread in threads:
            thread.join()

<br>

## The architecture

In [None]:
from torch import nn

from torch.nn import functional as F

A tracer hook to help with debugging live models

```python
hook = model.register_forward_pre_hook(tracer)

bt, by = next(feed)
model(bt, by)

...

hook.remove()
```
- _judicious placement_ of breakpoints `b` is advised so that continuation `c` would work
    - try `n`, `b 1208`
- otherwise use `n` for next, `s` for step inside, `u/d` to move between stack frames

In [None]:
def tracer(module, input) -> None:
    """Forward pre-hook that opens the `pdb` debugger inside this function.

    Attach it with `module.register_forward_pre_hook(tracer)`. In the
    debugger, `module` and `input` hold the hook's two arguments.
    """
    from pdb import set_trace

    set_trace()

A good old trusted MLP

In [None]:
from toybnb.scip.ecole.il.nn import mlp

The original message passing architecture is
$$
x^k_i
    = \gamma\bigl(
        x^{k-1}_i,
        \operatorname{\diamond} \bigl(
            \bigl\{
                \phi(x^{k-1}_i, x^{k-1}_j, e_{ij})
                \colon j \in G_i
            \bigr\}
        \bigr)
    \bigr)
    \,, $$

where $\diamond$ is a permutation-invariant set-to-real _aggregation_ operator,
$\phi$ is the _message_ function, $\gamma$ is the _output_ function, and $G_i$
is the graph neighborhood of the vertex $i$.

Below we implement a variant of message passing where the _aggregation_ and _message_ operations are based on multi-headed cross-attention. This means that $\diamond$ and $\phi$ are _fused_:
$$
    (\diamond \circ \phi)
    \colon \mathbb{R}^d \times \bigl(\mathbb{R}^d\bigr)^* \to \mathbb{R}^d
    \colon (x_i, \{x_j\colon j \in G_i\}) \mapsto
        \sum_{j \in G_i} p_j \phi^v(x_j)
    \,. $$
with $\log p_j \propto \phi^q(x_i)^\top \phi^k(x_j)$ for $j \in G_i$.

Implement a special layer that computes message passing with cross-attention between the parts of a bipartite graph. We make heavy use of `torch-scatter` operations

In [None]:
from math import sqrt

from einops import rearrange
from torch_scatter import scatter_softmax, scatter_sum

# from toybnb.scip.ecole.il.nn import BipartiteMHXA

A bipartite transformer block
* we opt for post-norm architecture
* [Zhang et al. (2022)](https://arxiv.org/abs/2206.11925) mention that norm-first layer norm may impact expressivity, but do not analyze the case when the norm is at the other end
* while [Xiong et al. (2020)](https://arxiv.org/abs/2002.04745) claim that pre-LN stabilizes the gradients, whereas with post-LN the expected magnitude of the gradients grows with the depth
  - strange, but this appears to contradict their own theoretical results

In [None]:
# from toybnb.scip.ecole.il.nn import BipartiteBlock

The full model

In [None]:
from torch_scatter import scatter_log_softmax, scatter_logsumexp
from torch_scatter import scatter_max, scatter_mean

from toybnb.scip.ecole.il.nn import NeuralVariableSelector

The network from Gasse et al. 2019, with the PreNorm layers replaced by batchnorms (both kinds of normalization are currently commented out in the encoders).

In [None]:
from toybnb.scip.ecole.il.nn import BipartiteGConv, NeuralClassifierBranchruleMixin


class Gasse2019(Module, NeuralClassifierBranchruleMixin):
    """Bipartite graph-convolutional variable scorer after Gasse et al. (2019).

    Variable and constraint features are embedded by two-layer ReLU MLPs.
    Edge values get a trailing unit feature axis. One graph convolution
    updates the constraint embeddings, another updates the variable
    embeddings, and an MLP head turns each variable embedding into a raw
    logit. The input normalizations of the reference model (PreNorm /
    BatchNorm) are disabled here.
    """

    def __init__(
        self,
        n_dim_vars: int = 19,
        n_dim_cons: int = 5,
        n_embed: int = 64,
    ) -> None:
        super().__init__()

        def embedder(n_in: int) -> nn.Sequential:
            # Linear-ReLU-Linear-ReLU; a disabled input norm would go first:
            #  PreNormLayer or nn.BatchNorm1d(n_in, affine=False)
            return nn.Sequential(
                nn.Linear(n_in, n_embed),
                nn.ReLU(),
                nn.Linear(n_embed, n_embed),
                nn.ReLU(),
            )

        # NOTE: key order and layer order define the `state_dict` layout
        self.encoder = nn.ModuleDict(
            {
                "vars": embedder(n_dim_vars),
                "cons": embedder(n_dim_cons),
                # a disabled nn.BatchNorm1d(1, affine=False) would follow
                "edge": nn.Sequential(nn.Unflatten(-1, (-1, 1))),
            }
        )

        # NOTE(review): `c2v` updates the constraints and `v2c` the
        #  variables (see `forward`); the names look swapped, but renaming
        #  them would change the `state_dict` keys
        self.v2c = BipartiteGConv(n_embed, True)
        self.c2v = BipartiteGConv(n_embed, True)
        self.head = nn.Sequential(
            nn.Linear(n_embed, n_embed),
            nn.ReLU(),
            nn.Linear(n_embed, 1),
        )

    def forward(self, input: BatchObservation) -> Tensor:
        """Return the raw-logit score of every variable in the batch."""
        ix_cons, ix_vars = input.ctov_ij

        # embed the constraint, edge and variable features
        h_cons = self.encoder.cons(input.cons)
        h_edge = self.encoder.edge(input.ctov_v)
        h_vars = self.encoder.vars(input.vars)

        # bipartite vcv ladder
        # (v_{t-1}, c_{t-1}) -->> (v_{t-1}, c_t) -->> (v_t, c_t)
        h_cons = self.c2v(h_cons, h_vars, (ix_cons, ix_vars), h_edge)
        h_vars = self.v2c(h_vars, h_cons, (ix_vars, ix_cons), h_edge)

        return self.head(h_vars).squeeze(-1)

Some procs for handling dicts

In [None]:
def transpose_dict(dict: dict[..., dict]) -> dict[..., dict]:
    """Swap the two key levels of a nested dict: `d[a][b]` -> `out[b][a]`.

    Inner keys appear in the result in order of first occurrence.
    """
    swapped = {}
    for outer, row in dict.items():
        for inner, value in row.items():
            column = swapped.setdefault(inner, {})
            # every (outer, inner) pair occurs once, so nothing is overwritten
            assert outer not in column
            column[outer] = value

    return swapped


def collate_dict(records: list[dict]) -> dict[..., list]:
    """Collate records assuming no fields are missing.

    Gathers each field's values across `records` (in record order) and
    converts every column with `np.array`. Fields keep first-seen order.
    """
    columns = {}
    for row in records:
        for key, item in row.items():
            columns.setdefault(key, []).append(item)

    return dict(zip(columns, map(np.array, columns.values())))

<br>

## Training

Allow for 100k samples

In [None]:
# number of rollout samples drawn from the feed (via `limit`) and SGD batch size
n_total, n_batch_size = 100_000, 16  # XXX 100k is too long
# size of the `shuffle` buffer that mixes the sample stream
n_reservoir = 512  # XXX 128 reservoir was ok for CAuc

# weights of the forbidden-action and entropy terms in the training loss
C_neg_actset = 0.0  # 1e-3  # XXX the original CAuc used to have 0.0
C_entropy = 0.0  # 1e-2  # XXX was set to zero in the first CAuc

# use a seed sequence with a fixed entropy pool
ss = SeedSequence(83278314352113072500167414370310027453)
# ss = SeedSequence(None)  # use `ss.entropy` for future reproducibility

# the hyperparameters below are forwarded to `NeuralVariableSelector`
s_project_graph = "pre"  # "post"

p_drop = 0.2
n_embed, n_heads, n_blocks = 32, 1, 1
b_edges = True
b_norm_first = True  # False

Pipe the generator that mixes several CO problem generators into the continuous rollout iterator.

In [None]:
from toybnb.scip.ecole.benchmarks import gasse2019
from ecole.instance import CombinatorialAuctionGenerator


# init the branching env
env = make_env(spawn_one(ss))

# CAuc(100, 500), CAuc(50, 250)
gens = [
    CombinatorialAuctionGenerator(100, 500, rng=ecole_seed(spawn_one(ss))),
    CombinatorialAuctionGenerator(50, 250, rng=ecole_seed(spawn_one(ss))),
]

# Use co problems from Gasse 2019
# gens = transpose_dict(gasse2019(spawn_one(ss)))
# gens = gens["train"].values()
itco = mixer(*gens, seed=spawn_one(ss))

set up the rollout observation feed

In [None]:
# stream (obs, act, rew) samples of strong-branching rollouts on the problems
#  from `itco`, using 6 worker threads with their own envs seeded from `ss`
# NOTE(review): the single `strongbranch(False)` branchrule is shared by all
#  worker threads -- confirm it is thread-safe
feed = multirollout(
    itco,
    ss,
    strongbranch(False),
    n_jobs=6,
)

Then feed the branching observation data into a shuffler, limiter and then batcher.

In [None]:
from tqdm import trange

# set up the data stream
# XXX if we put a limiter on the source iter, then shuffle's reservoir
#  would consume `n_reservoir`, which are never going to be yielded,
#  if the source is infinite. In this case, instead, we should put
#  a limiter on the `shuffle` itself.
# NOTE(review): the code below DOES limit the source, which makes it finite;
#  `shuffle` then flushes its reservoir when the source ends, so all
#  `n_total` samples still reach `batch`.
# NOTE(review): `limit` never closes `feed`, and the global name keeps it
#  alive, so its worker threads presumably keep rolling out in the
#  background after training -- call `feed.close()` when done with it
it = feed
it = limit(it, trange(n_total, ncols=70))  # `trange` doubles as a progress bar
it = shuffle(it, n_reservoir, spawn_one(ss))
it = batch(it, n_batch_size)

Set up the model and optimizers

In [None]:
from torch.optim.lr_scheduler import SequentialLR
from torch.optim.lr_scheduler import LinearLR, ConstantLR
from torch.optim.lr_scheduler import CosineAnnealingLR


# the policy network; the hyperparameters come from the config cell
mod = NeuralVariableSelector(
    19,  # presumably the variable feature dim (cf. Gasse2019 n_dim_vars=19)
    5,  # presumably the constraint feature dim (cf. Gasse2019 n_dim_cons=5)
    n_embed,
    n_heads,
    n_blocks,
    p_drop,
    b_norm_first=b_norm_first,
    s_project_graph=s_project_graph,
    b_edges=b_edges,
)

# mod = Gasse2019()

# optim = torch.optim.AdamW(mod.parameters(), lr=3e-4)
optim = torch.optim.AdamW(mod.parameters(), lr=1e-3)

# lr schedule, stepped once per batch: 50 steps at 0.5x lr, then a linear
#  ramp 0.5x -> 1x over 250 steps, then cosine annealing
# NOTE(review): `CosineAnnealingLR(T_max=50)` is periodic -- past 50 steps
#  the lr climbs back up, so over the remaining ~5.9k of the
#  n_total / n_batch_size = 6250 steps it oscillates between 1e-3 and 1e-6
#  with a 100-step period; confirm this is intended (vs. T_max ~ 5950)
sched = None
sched = SequentialLR(
    optim,
    [
        ConstantLR(optim, factor=0.5, total_iters=50),
        LinearLR(optim, start_factor=0.5, total_iters=250),
        CosineAnnealingLR(optim, T_max=50, eta_min=1e-6),
    ],
    [50, 300],
)

train an IL model

In [None]:
from IPython.display import clear_output


# imitation learning: fit the policy to the expert's branching decisions
device = torch.device("cpu")
# per-step loss terms as a dict-of-lists (filled by `do_add`)
log = dict()
for bt in it:
    # collate the batch
    obs, act, rew = zip(*bt)
    bx = collate(obs, device=device)
    act = as_tensor(np.array(act, dtype=int), device=device)
    # NOTE(review): `rew` is not used by the imitation loss below
    rew = np.array(rew, dtype=np.float32)

    # forward pass
    mod.train()
    # `terms` maps loss-term names to tensors that are averaged here;
    #  presumably per-sample values -- TODO confirm against `compute`
    _, terms = mod.compute(bx, target=act)
    terms = {k: v.mean() for k, v in terms.items()}

    # log likelihoods and entropy have the same units (and thus scale)
    loss = (
        # (min) -ve log-likelihood
        terms["neg_target"]
        # (max) entropy
        - C_entropy * terms["entropy"]
        # (max) -ve log-likelihood of forbidden actions
        - C_neg_actset * terms["neg_actset"]
    )

    # backprop
    # (`zero_grad(True)` is set_to_none: grads are dropped, not zero-filled)
    mod.zero_grad(True)
    loss.backward()
    optim.step()

    # the lr schedule advances once per batch
    if sched is not None:
        sched.step()

    mod.eval()
    do_add({k: float(v) for k, v in terms.items()}, log)

    # redraw the loss curves after every batch, replacing the previous plot
    clear_output(True)
    fig, ax0 = plt.subplots(1, 1, figsize=(5, 2), dpi=200, sharex=True)
    plot_series(ax0, **log)
    ax0.set_title("terms")
    ax0.legend(loc="lower left", fontsize="xx-small")

    plt.show()

<br>

In [None]:
from time import strftime

# snapshot the trained weights along with the local time of saving
torch.save(
    {
        "__dttm__": strftime("%Y%m%d-%H%M%S"),
        "state_dict": mod.state_dict(),
    },
    "ninth-cauc-norm-first.pt",
)

<br>

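Online update of the mean and the biased variance, as used by `PreNormLayer`. Let $\mu_n$ and $S^2_n$ be the mean and the biased variance of the first $n$ samples, let $\nu_m$ and $T^2_m$ be those of the next $m$ samples, and let $\delta = \nu_m - \mu_n$. Then
$$
(n+m) \mu_{n+m} = m \nu_m + n \mu_n
    \,. $$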
$$
\begin{align}
    (n+m) S^2_{n+m}
        &= \sum_{j < m} (x_{n+j} - \mu_{n+m})^2 + \sum_{j < n} (x_j - \mu_{n+m})^2
        \\
        &= (n+m) (\mu_n - \mu_{n+m})^2
        + \sum_{j < m} (x_{n+j} - \mu_n)^2 + \sum_{j < n} (x_j - \mu_n)^2
        + 2 (\mu_n - \mu_{n+m}) \sum_{j < m} (x_{n+j} - \mu_n)
        \\
        &= (n+m) S^2_n
        + (n+m) (\mu_n - \mu_{n+m})^2
        - m S^2_n
        + 2 m (\mu_n - \mu_{n+m}) (\nu_m - \mu_n)
        + \sum_{j < m} (x_{n+j} - \mu_n)^2
        \,. \\
    (n+m) S^2_{n+m} - n S^2_n
        &= n (\mu_n - \mu_{n+m})^2
        + m (\nu_m - \mu_{n+m})^2
        + m T^2_m
        \\
        &= n (\frac{m}{n+m} \nu_m - \frac{m}{n+m} \mu_n)^2
        + m (\frac{-n}{n+m} \nu_m + \frac{n}{n+m} \mu_n)^2
        + m T^2_m
        \\
        &= \frac{m n}{n+m} (\nu_m - \mu_n)^2 + m T^2_m
        \,. \\
    \mu_{n+m} - \mu_n
        &= \frac{m}{n+m} (\nu_m - \mu_n)
        = \frac{m}{n+m} \delta
        \\
    S^2_{n+m} - S^2_n
        &= \frac{m n}{(n+m)^2} (\nu_m - \mu_n)^2 + \frac{m}{n+m} (T^2_m - S^2_n)
        \\
        &= \frac{m}{n+m} \Bigl(
            \Bigl(1 - \frac{m}{n+m}\Bigr) \delta^2 + (T^2_m - S^2_n)
        \Bigr)
        \\
\end{align}
$$

In [None]:
from torch_scatter import scatter_log_softmax, scatter_logsumexp
from torch_scatter import scatter_max, scatter_mean


class PreNormLayer(Module):
    """Not quite batch norm: standardize inputs with streamed statistics.

    `_update` merges each batch into a running mean (`bias`) and a running
    biased variance (`weight`) along `dim`, using the online merge formulas
    for mean and variance. `forward` then optionally subtracts the mean
    (`location`) and divides by the standard deviation (`scale`). Both
    statistics are buffers, so nothing is learnt by gradient descent.

    Before the first `_update` the statistics are empty and `forward` is
    the identity. Entries with zero running variance (constant features)
    are left unscaled instead of being divided by zero.

    Parameters
    ----------
    dim : int
        The axis along which the statistics are computed.
    location : bool
        Subtract the running mean in `forward`.
    scale : bool
        Divide by the running standard deviation in `forward`.
    """

    def __init__(
        self,
        dim: int = 0,
        location: bool = False,
        scale: bool = False,
    ) -> None:
        super().__init__()
        self.dim, self.location, self.scale = dim, location, scale
        self.register_buffer("n_samples", torch.tensor(0, dtype=int))

        # empty until the first `_update` resizes them to the shape of the
        #  reduced statistics
        self.register_buffer("weight", torch.ones(0))
        self.register_buffer("bias", torch.zeros(0))

    @torch.inference_mode()
    def _update(self, input: Tensor) -> None:
        """Merge the statistics of the batch `input` into the running ones."""
        # the batch contributes its size ALONG `dim`; `len(input)` would
        #  always count along axis 0
        n_new = input.shape[self.dim]
        if n_new < 1:
            return

        self.n_samples += n_new
        ratio = n_new / int(self.n_samples)

        # online average: mu_{n+m} = mu_n + m/(n+m) (nu_m - mu_n)
        nu = input.mean(self.dim)
        if self.bias.numel() < 1:
            self.bias.resize_as_(nu).zero_()

        delta = nu - self.bias
        self.bias.add_(delta, alpha=ratio)

        # biased online variance:
        #  S2_{n+m} = S2_n + m/(n+m) ((1 - m/(n+m)) delta^2 + (T2_m - S2_n))
        T2 = input.var(self.dim, unbiased=False)
        if self.weight.numel() < 1:
            self.weight.resize_as_(T2).zero_()

        sigma = (1 - ratio) * delta * delta + (T2 - self.weight)
        self.weight.add_(sigma, alpha=ratio)

    def forward(self, input: Tensor) -> Tensor:
        # re-insert the reduced axis so the statistics broadcast along `dim`
        #  for any `dim` (for `dim=0` this equals plain broadcasting)
        if self.location and self.bias.numel() > 0:
            input = input.sub(self.bias.unsqueeze(self.dim))

        if self.scale and self.weight.numel() > 0:
            # constant features have zero variance: leave them unscaled
            std = self.weight.sqrt().masked_fill(self.weight <= 0, 1.0)
            input = input.div(std.unsqueeze(self.dim))

        return input

In [None]:
class BaseVQHelper:
    """A context object, which tracks the VQ layers in the module.

    On construction, a forward hook is attached to every submodule of
    `module` that is an instance of `cls`. While the helper is active
    (inside its `with` block), each forward call of a tracked layer is
    passed to `self.on_forward(module, inputs, output)`, which subclasses
    must provide. Outside the `with` block the hooks do nothing. Call
    `remove()` to detach the hooks.
    """

    def __init__(
        self,
        module: nn.Module,
        cls: type = PreNormLayer,
    ) -> None:
        # enumerate the layers and attach our forward hook to them
        self._hooks = {}
        self._names = {}
        self._register(module, cls)
        self._enabled = False

    def __enter__(self) -> object:
        self._enabled = True
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        self._enabled = False

    def _register(self, module: nn.Module, cls: type) -> None:
        for nom, mod in module.named_modules():
            if isinstance(mod, cls) and mod not in self._hooks:
                # a (post-)forward hook: `_hook` takes the layer's output,
                #  but a forward PRE-hook is called with `(module, inputs)`
                #  only, which made every forward call fail
                self._hooks[mod] = mod.register_forward_hook(self._hook)
                self._names[mod] = nom

    def remove(self) -> None:
        """Detach every hook installed by this helper."""
        for handle in self._hooks.values():
            handle.remove()

        self._hooks.clear()
        self._names.clear()

    def _hook(
        self,
        module: Module,
        inputs: tuple[Tensor, ...],
        output: Tensor,
    ) -> object:
        # don't do anything OUTSIDE the with-scope
        if self._enabled:
            if module not in self._hooks:
                raise RuntimeError(
                    f"untracked {type(module).__name__} triggered the hook"
                )

            # a non-None return value replaces the layer's output
            return self.on_forward(module, inputs, output)

In [None]:
# sanity check: the statistics streamed over 8-row chunks should match the
#  full-batch mean and biased variance up to float error
pnl = PreNormLayer(0, True, True)
# NOTE(review): this shadows the builtin `input` at notebook scope
input = torch.randn(1024, 8) + 10

for x in input.split(8, 0):
    pnl._update(x)

# normalize the last chunk (not displayed: not the cell's final expression)
pnl(x)

# both differences should be ~0
input.mean(0) - pnl.bias, input.var(0, False) - pnl.weight

<br>