# Subtrees without optimality bounds fathoming

In B&B, a sub-tree can be fathomed in one of four ways:

1. **Optimality bounds**: The node's dual bound is worse than the global incumbent's primal bound -> no feasible solution better than the incumbent can exist within the sub-tree 

2. **Infeasibility**: No dual solution which meets MILP's non-integrality constraints exists 

3. **Integrality**: Dual bound solution is feasible w.r.t. integrality constraints -> no better feasible solution can exist within sub-tree 

4. **Completion**: global_primal_dual_gap = 0 -> solution is optimal 
    
Optimality bounds fathoming depends on the global incumbent, which can be updated by activity anywhere in the tree. We therefore do not want a local sub-tree episode to be terminated because of something which happened elsewhere in the tree, since the agent should not have to concern itself with activity outside of its own sub-tree.

Can we work this into our reward and/or sub-tree episode construction procedure?

In [None]:
%load_ext autoreload
%autoreload
from retro_branching.environments import EcoleBranching, EcoleConfiguring
from retro_branching.agents import StrongBranchingAgent, PseudocostBranchingAgent, RandomAgent
from retro_branching.utils import seed_stochastic_modules_globally
# from retro_branching.rewards import NormalisedLPGain

import ecole
import numpy as np
import random
import copy
import pyscipopt

import networkx as nx
from networkx.algorithms.shortest_paths.generic import shortest_path
from networkx.algorithms.traversal.depth_first_search import dfs_tree
from networkx.drawing.nx_pydot import graphviz_layout
import matplotlib.pyplot as plt
from ordered_set import OrderedSet
import math

## Search Tree

In [None]:
%autoreload

class SearchTree:
    '''
    Tracks the SCIP search tree in a networkx DiGraph. Call
    SearchTree.update_tree(ecole.Model) each time the ecole environment
    (and therefore the ecole.Model) is updated.

    N.B. SCIP does not store nodes which were pruned, infeasible, outside
    the search tree's optimality bounds, or which were optimal, therefore these
    nodes will not be stored in the SearchTree. This is why m.getNTotalNodes()
    (the total number of nodes processed by SCIP) will likely be more than the
    number of nodes in the search tree when an instance is solved.

    Attributes:
        tree (nx.DiGraph): Graph keyed by SCIP node number. Each node stores
            '_id' and 'lower_bound'; edges point from parent to child.
            tree.graph holds 'root_node' ({node_id: node} or None),
            'visited_nodes' (list of {node_id: node} dicts in visit order) and
            'visited_node_ids' (OrderedSet of node ids in visit order).
        curr_node (dict): {node_id: node} of the current focus node, or
            {None: None} once there is no current node.
        parent_node (dict): {node_id: node} of the current node's parent, or
            {None: None} if the current node is the root / there is no current node.
        open_leaves, open_children, open_siblings (dict): {node_id: node} maps
            built from the three lists returned by m.getOpenNodes().
    '''
    def __init__(self, model):
        '''
        Args:
            model: ecole.Model whose current state initialises the tree.
        '''
        self.tree = nx.DiGraph()

        self.tree.graph['root_node'] = None
        self.tree.graph['visited_nodes'] = []
        self.tree.graph['visited_node_ids'] = OrderedSet()

        self.update_tree(model)

    def update_tree(self, model):
        '''
        Call this method after each update to the ecole environment. Pass
        the updated ecole.Model, and the B&B tree tracker will be updated accordingly.

        Args:
            model: Updated ecole.Model (must provide as_pyscipopt()).
        '''
        m = model.as_pyscipopt()

        _curr_node = m.getCurrentNode()
        if _curr_node is not None:
            curr_node_id = _curr_node.getNumber()
        else:
            # branching finished, no curr node
            curr_node_id = None
        self.curr_node = {curr_node_id: _curr_node}

        if curr_node_id is None:
            self.parent_node = {None: None}
        else:
            if curr_node_id not in self.tree.graph['visited_node_ids']:
                # first time this node is the focus node -> register the visit
                self._add_nodes(self.curr_node)
                self.tree.graph['visited_nodes'].append(self.curr_node)
                self.tree.graph['visited_node_ids'].add(curr_node_id)

            _parent_node = _curr_node.getParent()
            if _parent_node is not None:
                parent_node_id = _parent_node.getNumber()
            else:
                # curr node is root node
                parent_node_id = None
            self.parent_node = {parent_node_id: _parent_node}

        open_leaves, open_children, open_siblings = m.getOpenNodes()
        self.open_leaves = {node.getNumber(): node for node in open_leaves}
        self.open_children = {node.getNumber(): node for node in open_children}
        self.open_siblings = {node.getNumber(): node for node in open_siblings}

        self._add_nodes(self.open_leaves)
        self._add_nodes(self.open_children)
        self._add_nodes(self.open_siblings)

    def _add_nodes(self, nodes, parent_node_id=None):
        '''
        Adds nodes if not already in tree.

        Args:
            nodes (dict): Maps node id to a node object providing
                getLowerbound(), getParent() and getNumber().
            parent_node_id: If not None, every newly added node which has a
                parent is attached to this id. If None, each node is attached
                to the id of its own parent.
        '''
        for node_id, node in nodes.items():
            if node_id in self.tree:
                continue

            # add node
            self.tree.add_node(node_id,
                               _id=node_id,
                               lower_bound=node.getLowerbound())

            # add edge
            _parent_node = node.getParent()
            if _parent_node is None:
                # is root node, has no parent
                self.tree.graph['root_node'] = {node_id: node}
                continue

            # Resolve the parent per node in a fresh local. Writing the looked-up
            # id back into parent_node_id would make every later node in this
            # call inherit the first node's parent.
            if parent_node_id is None:
                edge_parent_id = _parent_node.getNumber()
            else:
                # parent node id already given
                edge_parent_id = parent_node_id
            self.tree.add_edge(edge_parent_id, node_id)

    def render(self):
        '''Renders B&B search tree.'''
        # open a fresh figure so the tree is not drawn on top of an earlier plot
        plt.figure()

        pos = graphviz_layout(self.tree, prog='dot')
        node_labels = {node: node for node in self.tree.nodes}
        nx.draw_networkx_nodes(self.tree,
                               pos,
                               label=node_labels)
        nx.draw_networkx_edges(self.tree,
                               pos)

        nx.draw_networkx_labels(self.tree, pos, labels=node_labels)

        plt.show()


## Normalised LP Gain Reward

In [None]:
%autoreload

class NormalisedLPGain:
    '''
    Reward tracker which retrospectively scores the previous branching decision
    from the lower bounds of the child nodes it introduced into the search tree.
    '''
    def __init__(self, 
                 normaliser='init_primal_bound', 
                 transform_with_log=False,
                 epsilon=None):
        '''
        Args:
            normaliser ('init_primal_bound', 'curr_primal_bound'): Which primal
                bound to use in the numerator and denominator when calculating
                the per-step normalised LP gain reward: the primal bound recorded
                at the first step, or the one recorded when the branching action
                was taken.
            transform_with_log: If True, the returned reward is transformed as
                reward = sign(reward) * log10(1 + |reward|) -> helps reduce variance
                of reward as in https://arxiv.org/pdf/1704.03732.pdf
            epsilon (None, float): If not None, each child which is missing from
                the search tree (pruned, infeasible, or outside of bounds)
                contributes a factor of epsilon to the score rather than being
                ignored. N.B. epsilon should be a small number e.g. 1e-6.
                N.B.2. The implementation assumes each branching decision results
                in 2 child nodes (regardless of whether or not SCIP stores these
                in memory).
        '''
        self.epsilon = epsilon
        self.transform_with_log = transform_with_log
        self.normaliser = normaliser

    def before_reset(self, model):
        '''Clears the per-episode state. The model argument is not used.'''
        self.prev_node = self.prev_node_id = None
        self.prev_primal_bound = self.init_primal_bound = None

    def extract(self, model, done):
        '''
        Scores the branching decision taken at the previously recorded node.

        Args:
            model: ecole.Model (must provide as_pyscipopt()).
            done: Not used by this method.

        Returns:
            0 while no previous node has been recorded yet, otherwise the
            (optionally log-transformed) score of the previous branching decision.

        Raises:
            Exception: If the normaliser is unrecognised (only checked when the
                previous node has children), or if the final score is below -1.
        '''
        scip_model = model.as_pyscipopt()

        if self.prev_node_id is None:
            # nothing to score yet: start tracking and remember the first node/bounds
            self.prev_node = scip_model.getCurrentNode()
            self.tree = SearchTree(model)
            if self.prev_node is not None:
                self.prev_node_id = copy.deepcopy(self.prev_node.getNumber())
                self.prev_primal_bound = self.init_primal_bound = scip_model.getPrimalbound()
            return 0

        # bring the tracked search tree in line with the current model state
        self.tree.update_tree(model)
        graph = self.tree.tree

        # lower bounds of the previous node and of the children it introduced
        parent_lb = graph.nodes[self.prev_node_id]['lower_bound']
        child_lbs = [graph.nodes[child]['lower_bound']
                     for child in graph.successors(self.prev_node_id)]

        if child_lbs:
            # product of the children's normalised remaining gaps, negated
            score = -1
            for child_lb in child_lbs:
                if self.normaliser == 'curr_primal_bound':
                    # primal bound at the step the branching action was taken
                    reference_bound = self.prev_primal_bound
                elif self.normaliser == 'init_primal_bound':
                    # primal bound recorded at the first step
                    reference_bound = self.init_primal_bound
                else:
                    raise Exception(f'Unrecognised normaliser {self.normaliser}')
                score *= (reference_bound - child_lb) / (reference_bound - parent_lb)
            if self.epsilon is not None:
                # account for children which SCIP never added to the search tree
                for _ in range(int(2 - len(child_lbs))):
                    score *= self.epsilon
        elif self.epsilon is not None:
            # no child was stored: both assumed children contribute epsilon
            score = -1 * (self.epsilon**2)
        else:
            # no child was stored -> don't punish brancher
            score = 0

        # the untransformed score is what gets stored on the tree node
        graph.nodes[self.prev_node_id]['score'] = score

        focus_node = scip_model.getCurrentNode()
        if focus_node is not None:
            # remember this node so the next call can score its branching decision
            self.prev_node = focus_node
            self.prev_node_id = copy.deepcopy(focus_node.getNumber())
            self.prev_primal_bound = scip_model.getPrimalbound()

        if self.transform_with_log:
            score = math.copysign(1, score) * math.log(1 + abs(score), 10)

        if score < -1:
            # dump the tracked lower bounds to help diagnose the out-of-range score
            print('Score < -1 found.')
            for node in graph.nodes():
                print(f'Node {node} lb: {graph.nodes[node]["lower_bound"]}')
            raise Exception()

        return score

## Sub-trees episode construction

In [None]:
%autoreload

class RetroBranching:
    '''
    Reward which is only computed once the episode has finished, by
    retrospectively splitting the B&B tree into sub-tree episode paths.
    '''
    def __init__(self, 
                 normaliser='init_primal_bound', 
                 min_subtree_depth=1, 
                 retro_trajectory_construction='deepest',
                 remove_nonoptimal_fathomed_leaves=False,
                 debug_mode=False):
        '''
        Waits until the end of the episode to calculate rewards for each step, then
        retrospectively goes back through the steps of the episode and assigns each
        its reward. I.e. the reward returned will be None until the end of the episode,
        at which point a list of dicts mapping episode step index to reward is returned.

        At the end of the episode, an episode is first constructed from the root node
        to the last visited node (i.e. the 'optimal path'). Then, all other scored
        nodes in the B&B tree not already included in a retrospective sub-tree episode
        path are iteratively grouped into further sub-tree paths until every such node
        has been covered. Each element of the returned list is a dict mapping the step
        index in the original episode to the corresponding reward received by the agent.

        Args:
            normaliser ('init_primal_bound', 'curr_primal_bound'): What to normalise
                with respect to in the numerator and denominator to calculate
                the per-step normalised LP gain reward.
            min_subtree_depth (int): Minimum number of nodes in a sub-tree path
                (i.e. minimum length of sub-tree episode). Not applied to the optimal path.
            retro_trajectory_construction ('random', 'deepest'): Which policy to use
                when choosing a leaf node as the final node of a sub-tree path.
            remove_nonoptimal_fathomed_leaves (bool): If True, at end of episode, will
                remove all nodes in the tree with score == 0 except the final node of
                the optimal path, so that no other sub-tree episode contains a
                node/experience with score == 0.
            debug_mode (bool): If True, prints and renders the tree and prints the
                constructed episodes at the end of the episode.
        '''
        self.debug_mode = debug_mode
        self.remove_nonoptimal_fathomed_leaves = remove_nonoptimal_fathomed_leaves
        self.retro_trajectory_construction = retro_trajectory_construction
        self.min_subtree_depth = min_subtree_depth
        # normalised lp gain reward tracker (also owns the search tree)
        self.normalised_lp_gain = NormalisedLPGain(normaliser=normaliser)

    def before_reset(self, model):
        '''Resets this tracker and the wrapped normalised LP gain tracker.'''
        self.started = False
        self.normalised_lp_gain.before_reset(model)
        
    def get_path_node_scores(self, tree, path):
        '''Returns the 'score' stored on each node of path, in path order.'''
        scores = []
        for node in path:
            scores.append(tree.nodes[node]['score'])
        return scores
        
    def conv_root_final_pair_to_step_idx_reward_map(self, root_node, final_node, check_depth=True):
        '''
        Converts the path between two tree nodes into a step index -> reward map.

        Args:
            root_node: Id of the first node of the path.
            final_node: Id of the last node of the path.
            check_depth (bool): If True, paths with fewer than min_subtree_depth
                nodes are rejected.

        Returns:
            Dict mapping the episode step index of each path node to its score,
            or None if the path was rejected for being too short. Either way,
            all path nodes are registered in self.nodes_added.
        '''
        bnb_tree = self.normalised_lp_gain.tree.tree
        path = shortest_path(bnb_tree, source=root_node, target=final_node)
        
        # every path node counts as covered, even if the episode is discarded below
        self.nodes_added.update(path)
            
        if check_depth and len(path) < self.min_subtree_depth:
            # subtree not deep enough, do not use episode
            return None
        
        # reward received at each node of the sub-tree episode
        rewards = self.get_path_node_scores(bnb_tree, path)

        # episode step index at which each path node was visited
        step_indices = [self.visited_nodes_to_step_idx[node] for node in path]
        
        return dict(zip(step_indices, rewards))

    def extract(self, model, done):
        '''
        Updates the normalised LP gain tracker and, once the episode is done,
        builds the retrospective sub-tree episodes.

        Args:
            model: ecole.Model passed on to the normalised LP gain tracker.
            done (bool): Whether the episode has finished.

        Returns:
            None while not done. [{0: 0}] if the tracked tree has no root node
            (instance was pre-solved). Otherwise a list of dicts, one per sub-tree
            episode, mapping episode step index to reward; the first element is
            the optimal path.

        Raises:
            Exception: If retro_trajectory_construction is unrecognised (only
                checked when a sub-tree outside the optimal path is processed).
        '''
        # update normalised LP gain tracker (its per-step return value is not needed)
        self.normalised_lp_gain.extract(model, done)

        if not done:
            return None

        search_tree = self.normalised_lp_gain.tree
        bnb_tree = search_tree.tree
        if bnb_tree.graph['root_node'] is None:
            # instance was pre-solved
            return [{0: 0}]

        # instance finished, retrospectively create subtree episode paths
        subtrees_step_idx_to_reward = []

        # keep track of which nodes have been added to a sub-tree
        self.nodes_added = set()
        
        if self.debug_mode:
            print('\nB&B tree:')
            print(f'All nodes saved: {bnb_tree.nodes()}')
            print(f'Visited nodes: {bnb_tree.graph["visited_node_ids"]}')
            search_tree.render()

        # drop nodes which were never scored (never branched on by the brancher)
        visited_ids = bnb_tree.graph['visited_node_ids']
        nodes = list(bnb_tree.nodes)
        for node in nodes:
            if 'score' in bnb_tree.nodes[node]:
                continue
            bnb_tree.remove_node(node)
            if node in visited_ids:
                # hack: SCIP sometimes returns large int rather than None node_id when episode finished
                # no score was assigned, so do not count this node as visited when calculating paths below
                visited_ids.remove(node)

        # map which nodes were visited at which step in episode
        self.visited_nodes_to_step_idx = {node: step_idx for step_idx, node in enumerate(visited_ids)}

        # optimal path: from the root node to the last visited node
        root_node = next(iter(bnb_tree.graph['root_node']))
        final_node = visited_ids[-1]
        subtrees_step_idx_to_reward.append(
            self.conv_root_final_pair_to_step_idx_reward_map(root_node, final_node, check_depth=False))

        if self.remove_nonoptimal_fathomed_leaves:
            for node in nodes:
                if node not in bnb_tree.nodes:
                    # already removed above
                    continue
                if bnb_tree.nodes[node]['score'] == 0 and node != final_node:
                    # node fathomed and not the optimal path's final node, remove
                    bnb_tree.remove_node(node)
                    if self.debug_mode:
                        print(f'Removed non-optimal fathomed leaf node {node}')
        
        # create sub-tree episodes from remaining B&B nodes visited by agent
        while True:
            # each uncovered child of a covered node roots a prospective sub-tree
            nx_subtrees = []
            for covered_node in self.nodes_added:
                for child in bnb_tree.successors(covered_node):
                    if child not in self.nodes_added:
                        nx_subtrees.append(dfs_tree(bnb_tree, child))

            if not nx_subtrees:
                # all sub-trees added
                break
                        
            for subtree in nx_subtrees:
                # copy node scores across (dfs_tree does not transfer node attributes)
                for node in subtree.nodes:
                    subtree.nodes[node]['score'] = bnb_tree.nodes[node]['score']

                # root of the sub-tree is the node without an incoming edge
                subtree_root = next(node for node in subtree.nodes if subtree.in_degree(node) == 0)

                # get a path by choosing a leaf node as the final node in the path
                leaf_nodes = [node for node in subtree.nodes() if subtree.out_degree(node) == 0]
                if self.retro_trajectory_construction == 'random':
                    # randomly choose leaf node as final node
                    leaf_idx = random.choice(range(len(leaf_nodes)))
                    subtree_final = leaf_nodes[leaf_idx]
                elif self.retro_trajectory_construction == 'deepest':
                    # choose the (first) leaf node giving the longest path from the root
                    subtree_final = max(
                        leaf_nodes,
                        key=lambda leaf: len(shortest_path(subtree, source=subtree_root, target=leaf)))
                else:
                    raise Exception(f'Unrecognised retro_trajectory_construction {self.retro_trajectory_construction}')
                    
                episode = self.conv_root_final_pair_to_step_idx_reward_map(
                    subtree_root, subtree_final, check_depth=True)
                if episode is not None:
                    # None means the subtree was not deep enough to be added
                    subtrees_step_idx_to_reward.append(episode)
                
        if self.debug_mode:
            print(f'visited_nodes_to_step_idx: {self.visited_nodes_to_step_idx}')
            step_idx_to_visited_nodes = {
                step_idx: node for node, step_idx in self.visited_nodes_to_step_idx.items()}
            for episode_num, ep in enumerate(subtrees_step_idx_to_reward, start=1):
                print(f'>>> sub-tree episode {episode_num}: {ep}')
                ep_path = [step_idx_to_visited_nodes[idx] for idx in ep.keys()]
                print(f'path: {ep_path}')
        
        return subtrees_step_idx_to_reward


## Init

In [None]:
%autoreload

# Notebook setup cell: seeds the stochastic modules, then creates the branching
# agent, the ecole branching environment and the instance generator used by the run cell.
seed = 0
seed_stochastic_modules_globally(default_seed=seed)

# agent which selects the branching actions in the run loop
agent = PseudocostBranchingAgent()

# environment built with the 'default' option for every function/parameter set
env = EcoleBranching(observation_function='default',
                      information_function='default',
                      reward_function='default',
                      scip_params='default')
env.seed(seed)

# set-cover instance generator (300 rows, 300 columns, density 0.05); consumed with next()
instances = ecole.instance.SetCoverGenerator(n_rows=300, n_cols=300, density=0.05)

# Run

In [None]:
# Notebook run cell: solves one instance with the agent, feeding every environment
# update to the RetroBranching reward and printing per-step statistics.
obs = None
custom_reward = RetroBranching(normaliser='init_primal_bound', 
                                         min_subtree_depth=1, 
                                         retro_trajectory_construction='deepest',
                                         remove_nonoptimal_fathomed_leaves=True,
                                         debug_mode=True)
# keep drawing instances until reset returns an observation
# (presumably obs is None when the instance is solved before any branching — confirm)
while obs is None:
    env.seed(seed)
    instance = next(instances)
    custom_reward.before_reset(instance)
    agent.before_reset(instance)
    obs, action_set, reward, done, info = env.reset(instance)
    _custom_reward = custom_reward.extract(env.model, done)
    
t = 1  # step counter, only used in the printed log
instance_transitions = []
prev_obs = copy.deepcopy(obs)
# same SearchTree object which the reward tracker updates inside custom_reward.extract
tree = custom_reward.normalised_lp_gain.tree
while not done:
    # select branching action
    action, action_idx = agent.action_select(action_set, model=env.model, done=done)
    obs, action_set, reward, done, info = env.step(action)
    # None until the final step, where the list of sub-tree episodes is returned
    _custom_reward = custom_reward.extract(env.model, done)
    
    if done:
        # on the terminal step, reuse the previous observation as next_obs
        obs = copy.deepcopy(prev_obs)
    
    # store transition
    # reward is indexed by key here, so it is presumably a dict of named rewards — confirm
    instance_transitions.append({'obs': prev_obs,
                               'action': action,
#                                'reward': reward['normalised_lp_gain'],
                               'reward': reward['num_nodes'],
                               'done': done,
                               'next_obs': obs})
    
    # update prev obs
    prev_obs = copy.deepcopy(obs)
    
    m = env.model.as_pyscipopt()
    print(f'Step {t} | Reward: {reward["num_nodes"]:.3f} | primal bound: {m.getPrimalbound()} | dual bound: {m.getDualbound()}')
    print(f'Custom reward: {_custom_reward}')
    
    # update search tree and analyse branching action
    # NOTE(review): custom_reward.extract above already calls update_tree on this same
    # object; on the final step this second call runs after RetroBranching.extract has
    # removed nodes from the tree — confirm this is intended
    tree.update_tree(env.model)
#     tree.render()
    
    print('')
    
    t += 1
    
# final summary and rendering of whatever remains in the tracked tree
m = env.model.as_pyscipopt()
print(f'\nFinished | primal bound: {m.getPrimalbound()} | dual bound: {m.getDualbound()} | # nodes: {m.getNTotalNodes()} | Final node: {tree.tree.graph["visited_nodes"][-1]}')
print(f'Custom reward: {_custom_reward}')
tree.render()