Strong branching works (for a minimisation problem) by calculating, for each candidate variable, the improvement in the dual bound (i.e. the reduction in the primal-dual gap) gained by branching on it, and then branching on the variable with the largest improvement. This is equivalent to a '1-step lookahead' on the dual bound for each candidate variable.
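As a minimal sketch of the 1-step lookahead, assume we already know the LP bounds of the two child nodes that each candidate would create (the numbers below are hypothetical, not taken from SCIP). For minimisation, the node's dual bound after branching is the lower of its two children's bounds, so the gain is how far that weaker child rises above the parent:

```python
from typing import Dict, Tuple


def one_step_sb(parent_bound: float,
                child_bounds: Dict[int, Tuple[float, float]]) -> int:
    """Pick the candidate whose branching raises the node's dual bound the most.

    For minimisation, the node's bound after branching is min(down, up),
    so the gain of branching on a variable is min(down, up) - parent_bound.
    """
    gains = {var: min(down, up) - parent_bound
             for var, (down, up) in child_bounds.items()}
    # 1-step lookahead: greedily take the largest immediate gain
    return max(gains, key=gains.get)


# hypothetical (down, up) child LP bounds for three candidate variables
children = {3: (10.0, 12.5), 7: (11.0, 11.5), 9: (10.0, 15.0)}
print(one_step_sb(parent_bound=10.0, child_bounds=children))  # 7
```

Variables 3 and 9 each leave one child at the parent's bound, so only variable 7 improves the node's bound.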

If we were to look ahead n steps, then as n tends towards the number of steps needed to close the primal-dual gap (i.e. to prove optimality), the branching decisions would approach those of a globally optimal branching policy. However, strong branching is expensive, and even 1-step SB (n=1) does not scale to large problems.

An open question is: how much better can we do than 1-step strong branching? We can test this by implementing n-step strong branching and checking whether, e.g., n=2 gives better results than n=1 (e.g. fewer nodes in the search tree). If it makes little difference for a given problem, then for that problem 1-step strong branching is already close to the optimal branching policy. In that case we would not expect our agent to beat SB, only to tend towards imitating it. If n>1 does give an improvement, we would expect our RL agent to be able to learn to beat strong branching, or at least to find actions that differ from SB's.

In this notebook, we implement n-step strong branching and apply it to small 100x100 set cover instances to try to answer the question: how much scope is there to improve on 1-step strong branching at these instance sizes?

To implement n-step strong branching, at each step in the episode we will:

1. Save the dual bound at the current step.
2. For each candidate variable, branch on it, then follow 1-step strong branching for the remaining n-1 steps.
3. After the n-th step, calculate the dual bound improvement of each candidate's trajectory relative to the initial dual bound saved in step 1, and branch on the variable whose trajectory gave the largest improvement.
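The three steps above can be sketched on a toy, fully deterministic search tree. Everything here is hypothetical: the state is just the tuple of variables branched on so far, `BOUNDS` is an invented table of dual bounds, and the `step`, `sb_action`, `dual_bound` and `action_set` callables stand in for the environment rather than the ecole API. The sketch scores each trajectory against the bound saved in step 1:

```python
from typing import Callable, Dict, List, Tuple

State = Tuple[int, ...]  # toy state: the sequence of variables branched on so far


def n_step_sb_select(state: State, n: int,
                     step: Callable[[State, int], State],
                     sb_action: Callable[[State], int],
                     dual_bound: Callable[[State], float],
                     action_set: Callable[[State], List[int]]) -> int:
    """n-step strong branching: branch on each candidate, then follow 1-step SB."""
    init_bound = dual_bound(state)              # 1. save the current dual bound
    totals: Dict[int, float] = {}
    for a in action_set(state):
        s = step(state, a)                      # 2. branch on the candidate...
        for _ in range(n - 1):                  # ...then follow 1-step SB
            if not action_set(s):               # instance solved: stop early
                break
            s = step(s, sb_action(s))
        totals[a] = dual_bound(s) - init_bound  # 3. improvement after n steps
    return max(totals, key=totals.get)


# hypothetical toy tree: dual bound reached after each sequence of branchings
BOUNDS = {(): 0.0, (0,): 3.0, (1,): 2.0, (2,): 1.0,
          (0, 1): 3.0, (0, 2): 4.0, (1, 0): 6.0, (1, 2): 5.0,
          (2, 0): 2.0, (2, 1): 2.0}


def toy_step(s: State, a: int) -> State:
    return s + (a,)


def toy_action_set(s: State) -> List[int]:
    # the toy instance is "solved" after two branchings
    return [a for a in range(3) if a not in s] if len(s) < 2 else []


def toy_dual_bound(s: State) -> float:
    return BOUNDS[s]


def toy_sb_action(s: State) -> int:
    # 1-step SB in the toy: take the child with the highest dual bound
    return max(toy_action_set(s), key=lambda a: toy_dual_bound(toy_step(s, a)))


args = (toy_step, toy_sb_action, toy_dual_bound, toy_action_set)
print(n_step_sb_select((), 1, *args))  # 0: greedy 1-step choice
print(n_step_sb_select((), 2, *args))  # 1: lookahead finds the better trajectory
```

In this toy, the greedy choice (variable 0) leads to a bound of 4 after two steps, while variable 1 looks worse at first but reaches 6. This is exactly the kind of case the n=2 experiment is meant to detect.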



In [None]:
import retro_branching
from retro_branching.environments import EcoleBranching

import ecole

from collections import defaultdict
import copy
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

In [None]:
class StrongBranchingAgent:
    def __init__(self, name='sb'):
        self.name = name
        self.strong_branching_function = ecole.observation.StrongBranchingScores()

    def before_reset(self, model):
        """
        This function will be called at initialization of the environment (before dynamics are reset).
        """
        self.strong_branching_function.before_reset(model)
    
    def extract(self, model, done, **kwargs):
        return self.strong_branching_function.extract(model, done)

    def action_select(self, action_set, model, done):
        scores = self.extract(model, done)[action_set]
        action = scores.argmax()
        
        # DEBUG
        print('sb action set: {}'.format(action_set))
        print('sb scores: {}'.format({action: score for action, score in zip(action_set, scores)}))
        print('sb action: idx={} (action={})'.format(action, action_set[action]))
        
        return action
    
class CustomStrongBranchingAgent:
    def __init__(self, name='sb'):
        self.name = name

    def before_reset(self, model):
        """
        This function will be called at initialization of the environment (before dynamics are reset).
        """
        pass
    
    def extract(self, model, done, action_set, action_history, instance_before_reset):
        action_to_score = {var: 0 for var in range(model.as_pyscipopt().getNVars())}
        for action in action_set:
            env = EcoleBranching(observation_function='default',
                                  information_function='default',
                                  reward_function='default',
                                  scip_params='default')
            env.seed(0)
            _ = env.reset(instance_before_reset.copy_orig())
            
            # rollout to current state
            for a in action_history:
                _ = env.step(a)
            
            m = env.model.as_pyscipopt()
            init_dual_bound = m.getDualbound()
            env.step(action)
            m = env.model.as_pyscipopt()
            final_dual_bound = m.getDualbound()
            action_to_score[action] = abs(init_dual_bound-final_dual_bound)
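        # one score per variable index (0 for non-candidates), so callers can index it with action_set
        return np.array(list(action_to_score.values()))

    def action_select(self, action_set, model, done, action_history, instance_before_reset):
        # extract() needs the action history and original instance to replay the current state
        scores = self.extract(model, done, action_set, action_history, instance_before_reset)[action_set]
        action = scores.argmax()
        return action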
    
class TwoStepStrongBranchingAgent:
    def __init__(self, name='2_step_sb'):
        '''N > 2 is not yet implemented because there are many trajectories to track, so N is fixed at 2 for now.'''
        self.N = 2
        self.name = name
    
    def before_reset(self, model):
        # initialises and resets sb functions when self.extract() called so no need to do anything here
        pass
        
    def extract(self, instance_before_reset, env, _action_history, action_set, done):
        action_history = copy.deepcopy(_action_history)
            
        # init tracking of step -> action -> scores trajectories
        step_scores = {step: defaultdict(lambda: 0) for step in range(1, self.N+1)}
            
        # init env for each candidate branching variable
        obs_func, info_func, reward_func, scip_params = env.str_observation_function, env.str_information_function, env.str_reward_function, env.str_scip_params
        envs = {action: EcoleBranching(observation_function=obs_func,
                                        information_function=info_func,
                                        reward_function=reward_func,
                                        scip_params=scip_params)
               for action in action_set}
        
        # init env to <step_return_params> maps
        env_to_action_set = {e: None for e in envs.keys()}
        env_to_done = {e: None for e in envs.keys()}
        env_to_action_history = {e: copy.deepcopy(action_history) for e in envs.keys()}

        # init corresponding sb agents to use for each env
        sb_agents = {e: CustomStrongBranchingAgent()  # or StrongBranchingAgent() to use SCIP's SB scores instead
               for e in envs.keys()}

        # reset envs and sb agents and rollout to current state
        for key in envs.keys():
            envs[key].seed(0)
            sb_agents[key].before_reset(instance_before_reset.copy_orig())
            _, env_to_action_set[key], _, env_to_done[key], _ = envs[key].reset(instance_before_reset.copy_orig())
            for a in action_history:
                _, env_to_action_set[key], _, env_to_done[key], _ = envs[key].step(a)

        # every env is now at the same state, so compute the 1-step SB scores once (using the last env),
        # then branch each candidate's env on its own candidate
        scores = sb_agents[key].extract(model=envs[key].model, done=done, action_set=env_to_action_set[key], action_history=env_to_action_history[key], instance_before_reset=instance_before_reset.copy_orig())
        scores = np.nan_to_num(scores)[env_to_action_set[key]]
        keep_running = True
        for idx, action in enumerate(env_to_action_set[key]):
            step_scores[1][action] += scores[idx]
            _, env_to_action_set[action], _, env_to_done[action], _ = envs[action].step(action)
            env_to_action_history[action].append(action)
            if env_to_done[action]:
                # this was first agent to solve problem, no need to do 2-step lookahead after
                keep_running = False
        
        # get range(2, N+1)-step scores for each action
        for n in range(2, self.N+1):
            if keep_running:
                for key in envs.keys():

                    if not env_to_done[key]:
                        # get sb scores
                        s = sb_agents[key].extract(model=envs[key].model, done=env_to_done[key], action_set=env_to_action_set[key], action_history=env_to_action_history[key], instance_before_reset=instance_before_reset.copy_orig())
                        # set nans to 0
                        s = np.nan_to_num(s)
                        # filter out invalid actions from scores
                        s = s[env_to_action_set[key]]
                        # get sb action index
                        a_idx = np.argmax(s)
                        env_to_action_history[key].append(env_to_action_set[key][a_idx])
                        # store action score
                        step_scores[n][key] += s[a_idx]
                        # step env with SB score
                        _, env_to_action_set[key], _, env_to_done[key], _ = envs[key].step(env_to_action_set[key][a_idx])
                        if env_to_done[key]:
                            # this was first agent to solve problem, no need to keep running other agents
                            keep_running = False
                            break
        
        # calc total dual bound reduction for each action's trajectory
        action_to_score = {action: 0 for action in action_set}
        for action in action_set:
            for n in step_scores.keys():
                if action in step_scores[n]:
                    action_to_score[action] += step_scores[n][action]

        return action_to_score, step_scores
        
    def action_select(self, instance_before_reset, env, action_history, action_set, done):
        scores, step_scores = self.extract(instance_before_reset=instance_before_reset, env=env, _action_history=action_history, action_set=action_set, done=done)
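        # map the best variable back to its index in action_set (the caller indexes action_set with it)
        best_action = max(scores, key=scores.get)
        action = int(np.flatnonzero(np.asarray(action_set) == best_action)[0])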
        
        # DEBUG
        print('2-step sb action set: {}'.format(action_set))
        print('2-step sb step -> action -> scores: {}'.format(step_scores))
        print('2-step sb action: idx={} (action={})'.format(action, max(scores, key=scores.get)))
        
        return action
            
        
        
    

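Rather than branching in the live environment, the agents above reproduce the current state by resetting a fresh environment with `env.seed(0)` and replaying `action_history`. This only works if the environment is fully deterministic given the seed and the action sequence. A minimal sketch of the pattern, using a hypothetical `ToyEnv` (not the ecole API) whose bound updates depend on a seeded random generator:

```python
import random
from typing import Callable, List


class ToyEnv:
    """Hypothetical stand-in for a seeded, deterministic branching environment."""

    def reset(self, seed: int) -> float:
        self.rng = random.Random(seed)
        self.bound = 0.0
        return self.bound

    def step(self, action: int) -> float:
        # the bound change depends on the action AND the seeded RNG, so a state is
        # only reproducible with the same seed and the same action sequence
        self.bound += action * self.rng.random()
        return self.bound


def replay(make_env: Callable[[], ToyEnv], seed: int, history: List[int]) -> ToyEnv:
    """Rebuild the current state in a fresh env by replaying the action history."""
    env = make_env()
    env.reset(seed)
    for a in history:
        env.step(a)
    return env


history = [2, 1, 3]
live = replay(ToyEnv, seed=0, history=history)
copy_env = replay(ToyEnv, seed=0, history=history)
print(live.bound == copy_env.bound)  # True: identical states
copy_env.step(5)                     # explore a branch in the copy only
print(live.bound != copy_env.bound)  # True: the live state is untouched
```

The cost of this design is that every candidate requires a full replay from the root, which is one reason n-step lookahead becomes expensive as the episode gets longer.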
In [None]:
# agents
one_agent = StrongBranchingAgent()
two_agent = TwoStepStrongBranchingAgent()

# envs
one_env = EcoleBranching(observation_function='default',
                          information_function='default',
                          reward_function='default',
                          scip_params='default')
one_env.seed(0)
two_env = EcoleBranching(observation_function='default',
                          information_function='default',
                          reward_function='default',
                          scip_params='default')
two_env.seed(0)

# instances
instances = ecole.instance.SetCoverGenerator(n_rows=100, n_cols=100, density=0.05)

In [None]:
num_episodes = 100
metrics = ['num_nodes', 'solving_time', 'lp_iterations']

plot_dict = {'1_step_sb': {metric: [] for metric in metrics},
             '2_step_sb': {metric: [] for metric in metrics}}
for ep in range(num_episodes):
    print('\n\n>>> Episode {} <<<'.format(ep))
    
    # find an instance not pre-solved by environment
    one_obs = None
    while one_obs is None:
        one_env.seed(0)
        instance = next(instances)
        instance_before_reset = instance.copy_orig()
        one_agent.before_reset(instance_before_reset.copy_orig())
        one_obs, one_action_set, one_reward, one_done, one_info = one_env.reset(instance)
    two_env.seed(0)
    two_obs, two_action_set, two_reward, two_done, two_info = two_env.reset(instance_before_reset.copy_orig())
    
    # 1-step SB agent
    # DEBUG
    m = one_env.model.as_pyscipopt()
    print('\ninit 1-step sb dual/primal/gap: {}/{}/{}'.format(m.getDualbound(), m.getPrimalbound(), m.getGap()))
    t = 1
    while not one_done:
        print(f'> t={t}')
        prev_dual = m.getDualbound()
        one_action = one_agent.action_select(one_action_set, one_env.model, one_done)
        one_action = one_action_set[one_action]
        one_obs, one_action_set, one_reward, one_done, one_info = one_env.step(one_action)
        # DEBUG
        m = one_env.model.as_pyscipopt()
        curr_dual = m.getDualbound()
        print('1-step sb dual/primal/gap: {}/{}/{} (dual change = {})'.format(m.getDualbound(), m.getPrimalbound(), m.getGap(), curr_dual-prev_dual))
        t += 1
    for metric in metrics:
        plot_dict['1_step_sb'][metric].append(one_info[metric])
    print('>> 1-step SB num nodes: {}'.format(one_info['num_nodes']))
        
    # 2-step SB agent
    # DEBUG
    m = two_env.model.as_pyscipopt()
    print('\ninit 2-step sb dual/primal/gap: {}/{}/{}'.format(m.getDualbound(), m.getPrimalbound(), m.getGap()))
    action_history = [] # store history of 2-step sb actions taken so can rollout envs to current state
    t = 1
    while not two_done:
        print(f'> t={t}')
        prev_dual = m.getDualbound()
        two_action = two_agent.action_select(instance_before_reset=instance_before_reset.copy_orig(), env=two_env, action_history=action_history, action_set=two_action_set, done=two_done)
        two_action = two_action_set[two_action]
        action_history.append(two_action)
        two_obs, two_action_set, two_reward, two_done, two_info = two_env.step(two_action)
        # DEBUG
        m = two_env.model.as_pyscipopt()
        curr_dual = m.getDualbound()
        print('2-step sb dual/primal/gap: {}/{}/{} (dual change = {})'.format(m.getDualbound(), m.getPrimalbound(), m.getGap(), curr_dual-prev_dual))
        t += 1
    for metric in metrics:
        plot_dict['2_step_sb'][metric].append(two_info[metric])
    print('>> 2-step SB num nodes: {}'.format(two_info['num_nodes']))

In [None]:
for agent in plot_dict.keys():
    for metric in plot_dict[agent].keys():
        fig = plt.figure()
        mean, std = np.mean(plot_dict[agent][metric]), np.std(plot_dict[agent][metric])
        title = f'{agent} {metric} (mean={round(mean, 3)}, std={round(std, 3)})'
        sns.histplot(plot_dict[agent][metric], edgecolor='k')
        plt.title(title)
        plt.xlabel(metric)
        plt.show()

# 2 key observations

Running the code above gives 2 key observations:

1. SCIP's strong branching scores are not the same as the total change in the dual bound: e.g. a variable whose branching gives a 0 change in the dual bound can still have the highest strong branching score. We therefore suspect that something under the hood is doing more than a plain 1-step lookahead to predict which variables will be better in the long term.

2. On an instance-by-instance basis, 2-step SB can end up with more nodes than 1-step SB, because the SCIP backend still controls node selection, pruning, etc.

## Observation 1: SCIP SB scores != change in dual bound
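One possible explanation, offered as a hypothesis rather than a confirmed diagnosis: the dual bound tracked above comes from `getDualbound()`, which is SCIP's global dual bound, i.e. (for minimisation) the minimum bound over all open nodes. A strong branching score, by contrast, looks only at the two children of the current node. To our understanding, SCIP's default score function combines the two child gains as a product. If another open node already holds the lowest bound, branching at the focus node cannot move the global bound, however strong its children are. The sketch below uses hypothetical numbers and a simplified product score (`EPS` is an assumed floor), not SCIP's exact formula:

```python
from typing import Dict, List, Tuple

EPS = 1e-6  # assumed floor on each child's gain, in the spirit of a product score


def product_score(parent: float, down: float, up: float, eps: float = EPS) -> float:
    """Simplified product of the two child gains (not SCIP's exact formula)."""
    return max(down - parent, eps) * max(up - parent, eps)


def global_bound_change(focus: float, other_open: List[float],
                        down: float, up: float) -> float:
    """Change in the global dual bound (min over open nodes) after branching at the focus node."""
    before = min([focus] + other_open)
    after = min(other_open + [down, up])
    return after - before


focus = 10.0
other_open = [10.0]  # hypothetical: another open node shares the lowest bound
candidates: Dict[str, Tuple[float, float]] = {'x1': (12.0, 13.0), 'x2': (10.5, 10.4)}

for var, (down, up) in candidates.items():
    print(var, product_score(focus, down, up),
          global_bound_change(focus, other_open, down, up))
```

Here `x1` gets by far the higher score, yet neither candidate changes the global dual bound. With `other_open = []` the same branching on `x1` would raise the bound by 2.0.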

## Observation 2: 1-step SB can have fewer nodes than 2-step SB
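Because both agents are run on the same instances in the same order, this observation is easiest to quantify with a paired, per-instance comparison rather than the two separate histograms above. A minimal sketch with a hypothetical helper `paired_node_comparison`. In this notebook it would be called with `plot_dict['1_step_sb']['num_nodes']` and `plot_dict['2_step_sb']['num_nodes']`. The numbers below are invented for illustration, not results from these experiments:

```python
from typing import Dict, Sequence

import numpy as np


def paired_node_comparison(one_step_nodes: Sequence[float],
                           two_step_nodes: Sequence[float]) -> Dict[str, float]:
    """Compare node counts instance by instance (lists must be aligned by instance)."""
    one = np.asarray(one_step_nodes, dtype=float)
    two = np.asarray(two_step_nodes, dtype=float)
    if one.shape != two.shape:
        raise ValueError('need exactly one result per instance for each agent')
    diff = two - one  # > 0 means 2-step SB needed more nodes on that instance
    return {
        'frac_two_step_worse': float(np.mean(diff > 0)),
        'frac_two_step_better': float(np.mean(diff < 0)),
        'mean_diff': float(np.mean(diff)),
        'median_diff': float(np.median(diff)),
    }


# hypothetical node counts for four instances
print(paired_node_comparison([5, 9, 3, 12], [4, 11, 3, 10]))
```

A negative mean difference can coexist with a non-zero fraction of instances where 2-step SB is worse, which is exactly the situation this observation describes.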