In [None]:
import ecole

import torch
import torch.nn.functional as F
import torch_geometric

import networkx as nx
# from networkx.drawing.nx_pydot import graphviz_layout
import numpy as np
import ml_collections
import copy
import json
import time

## Backend: minimal `retro_branching` components

A minimal implementation of the required `retro_branching` components is included here so that this notebook can be reproduced without having to set up the whole `retro_branching` package.

In [None]:

######################### ENV ##################################
# SCIP parameters used by EcoleBranching when scip_params='default'.
# Cutting-plane separation, presolving (and restarts) and domain propagation
# are all switched off, LPs are solved with the dual simplex, and each solve
# is capped by SCIP's time limit.
default_scip_params = {
    # cutting planes: no separation rounds and no cuts, at the root or elsewhere
    'separating/maxrounds': 0,
    'separating/maxroundsroot': 0,
    'separating/maxcuts': 0,
    'separating/maxcutsroot': 0,
    # presolving: no rounds and no restarts
    'presolving/maxrounds': 0,
    'presolving/maxrestarts': 0,
    # propagation: no rounds, at the root or elsewhere
    'propagating/maxrounds': 0,
    'propagating/maxroundsroot': 0,
    # LP algorithm for initial solves and re-solves ('d' = dual simplex)
    'lp/initalgorithm': 'd',
    'lp/resolvealgorithm': 'd',
    # SCIP time limit (seconds)
    'limits/time': 3600,
}

class EcoleBranching(ecole.environment.Branching):
    '''
    ecole Branching environment whose functions and SCIP parameters may be
    given by name.

    Each of observation_function, information_function, reward_function and
    scip_params may be either an object that is passed straight through to
    ecole, or one of the string names below, which is resolved here:

        observation_function: 'default'         -> ecole.observation.NodeBipartite()
                              '43_var_features' -> NodeBipariteWith43VariableFeatures()
        information_function: 'default'         -> dict of cumulative 'num_nodes',
                                                   'lp_iterations' and 'solving_time'
        reward_function:      'default'         -> dict of per-step rewards keyed by name
        scip_params:          'default'         -> a copy of default_scip_params

    Any other string raises ValueError. The string names that were passed are
    stored on self.str_* (None for non-string arguments) so an identically
    configured environment can easily be created again.
    '''
    def __init__(
        self,
        observation_function='default',
        information_function='default',
        reward_function='default',
        scip_params='default',
        pseudo_candidates=False,
    ):
        # save string names so easy to initialise new environments
        self.str_observation_function = self._name_or_none(observation_function)
        self.str_information_function = self._name_or_none(information_function)
        self.str_reward_function = self._name_or_none(reward_function)
        self.str_scip_params = self._name_or_none(scip_params)

        self.pseudo_candidates = pseudo_candidates

        # init functions from strings if needed. The isinstance() guards keep
        # the '==' comparison away from arbitrary function objects.
        if isinstance(reward_function, str):
            if reward_function != 'default':
                raise ValueError(f"Unrecognised reward_function {reward_function!r}; expected 'default' or a reward function object.")
            reward_function = {
                'num_nodes': -ecole.reward.NNodes(),
                'lp_iterations': -ecole.reward.LpIterations(),
                'primal_integral': -ecole.reward.PrimalIntegral(),
                # NOTE(review): unlike the other cost-like entries this one is not
                # negated -- confirm the sign is intended.
                'dual_integral': ecole.reward.DualIntegral(),
                'primal_dual_integral': -ecole.reward.PrimalDualIntegral(),
                'solving_time': -ecole.reward.SolvingTime(),
                'normalised_lp_gain': NormalisedLPGain(use_prev_primal_bound=True),
            }

        if isinstance(information_function, str):
            if information_function != 'default':
                raise ValueError(f"Unrecognised information_function {information_function!r}; expected 'default' or an information function object.")
            information_function = {
                'num_nodes': ecole.reward.NNodes().cumsum(),
                'lp_iterations': ecole.reward.LpIterations().cumsum(),
                'solving_time': ecole.reward.SolvingTime().cumsum(),
            }

        if isinstance(observation_function, str):
            if observation_function == 'default':
                observation_function = ecole.observation.NodeBipartite()
            elif observation_function == '43_var_features':
                observation_function = NodeBipariteWith43VariableFeatures()
            else:
                raise ValueError(f"Unrecognised observation_function {observation_function!r}; expected 'default', '43_var_features' or an observation function object.")

        if isinstance(scip_params, str):
            if scip_params != 'default':
                raise ValueError(f"Unrecognised scip_params {scip_params!r}; expected 'default' or a dict of SCIP parameters.")
            # copy so the module-level defaults are never shared between environments
            scip_params = dict(default_scip_params)

        super(EcoleBranching, self).__init__(
            observation_function=observation_function,
            information_function=information_function,
            reward_function=reward_function,
            scip_params=scip_params,
            pseudo_candidates=pseudo_candidates,
        )

    @staticmethod
    def _name_or_none(arg):
        '''Returns arg if it is a string name, otherwise None.'''
        return arg if isinstance(arg, str) else None
        
class NodeBipariteWith43VariableFeatures(ecole.observation.NodeBipartite):
    '''
    Adds (mostly global) features to variable node features.

    Adds 24 extra variable features to each variable on top of standard ecole
    NodeBipartite obs variable features (19), so each variable will have
    43 features in total.

    The 24 extra features describe the solver / search-tree state rather than
    the individual variable, so the same row is appended to every variable in
    obs.column_features, in the fixed order listed at the end of extract().
    Any ratio whose denominator is zero is reported as 0 instead of raising.
    '''
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # also set here so extract() cannot hit an AttributeError if it is
        # called before before_reset()
        self.init_dual_bound = None
        self.init_primal_bound = None

    def before_reset(self, model):
        super().before_reset(model)

        # bounds at the first extract() of the episode; used as references
        self.init_dual_bound = None
        self.init_primal_bound = None

    @staticmethod
    def _safe_div(numerator, denominator):
        '''Returns numerator / denominator, or 0 when denominator is zero.'''
        if denominator == 0:
            return 0
        return numerator / denominator

    def extract(self, model, done):
        # get the NodeBipartite obs
        obs = super().extract(model, done)

        m = model.as_pyscipopt()
        div = self._safe_div

        dual_bound = m.getDualbound()
        primal_bound = m.getPrimalbound()
        num_nodes = m.getNNodes()

        if self.init_dual_bound is None:
            self.init_dual_bound = dual_bound
            self.init_primal_bound = primal_bound

        # dual/primal bound features
        dual_bound_frac_change = div(abs(self.init_dual_bound - dual_bound), self.init_dual_bound)
        primal_bound_frac_change = div(abs(self.init_primal_bound - primal_bound), self.init_primal_bound)

        primal_dual_gap = abs(primal_bound - dual_bound)
        max_dual_bound_frac_change = div(primal_dual_gap, self.init_dual_bound)
        max_primal_bound_frac_change = div(primal_dual_gap, self.init_primal_bound)

        curr_primal_dual_bound_gap_frac = m.getGap()

        # global tree features
        num_leaves_frac = div(m.getNLeaves(), num_nodes)
        num_feasible_leaves_frac = div(m.getNFeasibleLeaves(), num_nodes)
        num_infeasible_leaves_frac = div(m.getNInfeasibleLeaves(), num_nodes)
        # NOTE(review): despite its name this is nodes per LP iteration (the
        # inverse orientation of the other *_frac features). Kept as-is because
        # models trained on these observations depend on the exact values.
        num_lp_iterations_frac = div(num_nodes, m.getNLPIterations())

        # focus node features
        num_siblings_frac = div(m.getNSiblings(), num_nodes)
        curr_node = m.getCurrentNode()
        best_node = m.getBestNode()
        # None when no best node has been found yet
        best_node_id = best_node.getNumber() if best_node is not None else None
        is_curr_node_best = int(best_node_id is not None and curr_node.getNumber() == best_node_id)
        parent_node = curr_node.getParent()
        # 0 when the focus node is the root (no parent) or there is no best node yet
        is_curr_node_parent_best = int(best_node_id is not None
                                       and parent_node is not None
                                       and parent_node.getNumber() == best_node_id)
        curr_node_depth = div(m.getDepth(), num_nodes)
        curr_node_lb = curr_node.getLowerbound()
        curr_node_lower_bound_relative_to_init_dual_bound = div(self.init_dual_bound, curr_node_lb)
        curr_node_lower_bound_relative_to_curr_dual_bound = div(dual_bound, curr_node_lb)
        num_branching_changes, num_constraint_prop_changes, num_prop_changes = curr_node.getNDomchg()
        total_num_changes = num_branching_changes + num_constraint_prop_changes + num_prop_changes
        branching_changes_frac = div(num_branching_changes, total_num_changes)
        constraint_prop_changes_frac = div(num_constraint_prop_changes, total_num_changes)
        prop_changes_frac = div(num_prop_changes, total_num_changes)
        parent_branching_changes_frac = div(curr_node.getNParentBranchings(), num_nodes)

        best_sibling = m.getBestSibling()
        if best_sibling is None:
            is_best_sibling_none = 1
            is_best_sibling_best_node = 0
            best_sibling_lower_bound_relative_to_init_dual_bound = 0
            best_sibling_lower_bound_relative_to_curr_dual_bound = 0
            best_sibling_lower_bound_relative_to_curr_node_lower_bound = 0
        else:
            is_best_sibling_none = 0
            is_best_sibling_best_node = int(best_node_id is not None and best_sibling.getNumber() == best_node_id)
            best_sibling_lb = best_sibling.getLowerbound()
            best_sibling_lower_bound_relative_to_init_dual_bound = div(self.init_dual_bound, best_sibling_lb)
            best_sibling_lower_bound_relative_to_curr_dual_bound = div(dual_bound, best_sibling_lb)
            best_sibling_lower_bound_relative_to_curr_node_lower_bound = div(best_sibling_lb, curr_node_lb)

        # the order of this row is part of the observation's contract -- do not reorder
        extra_feats = np.array([dual_bound_frac_change,
                                primal_bound_frac_change,
                                max_primal_bound_frac_change,
                                max_dual_bound_frac_change,
                                curr_primal_dual_bound_gap_frac,
                                num_leaves_frac,
                                num_feasible_leaves_frac,
                                num_infeasible_leaves_frac,
                                num_lp_iterations_frac,
                                num_siblings_frac,
                                is_curr_node_best,
                                is_curr_node_parent_best,
                                curr_node_depth,
                                curr_node_lower_bound_relative_to_init_dual_bound,
                                curr_node_lower_bound_relative_to_curr_dual_bound,
                                branching_changes_frac,
                                constraint_prop_changes_frac,
                                prop_changes_frac,
                                parent_branching_changes_frac,
                                is_best_sibling_none,
                                is_best_sibling_best_node,
                                best_sibling_lower_bound_relative_to_init_dual_bound,
                                best_sibling_lower_bound_relative_to_curr_dual_bound,
                                best_sibling_lower_bound_relative_to_curr_node_lower_bound],
                               dtype=np.float64)

        # repeat the row once per variable; np.tile keeps the width at 24 even
        # when there are no variables
        num_vars = obs.column_features.shape[0]
        obs.column_features = np.column_stack((obs.column_features, np.tile(extra_feats, (num_vars, 1))))

        return obs

        
        
        
        
        
        

        
################################## CUSTOM REWARD ##################################
class NormalisedLPGain:
    '''
    Retrospective reward for the previous branching decision.

    On each step the children created by the previous focus node are looked up
    in a SearchTree. The score starts at -1 and is multiplied, for every such
    child, by (primal_bound - child_lb) / (primal_bound - parent_lb). If the
    previous focus node has no children in the tree, the score is 0. The first
    call of an episode only records state and returns 0.
    '''
    def __init__(self, use_prev_primal_bound=True):
        '''
        Args:
            use_prev_primal_bound (bool): If True, will normalise the reward w.r.t.
                the previous step's primal bound (i.e. the primal bound of the
                instance at the step when the agent took the branching action).
                Otherwise, will normalise w.r.t. the primal bound of the new step.
        '''
        self.use_prev_primal_bound = use_prev_primal_bound

    def before_reset(self, model):
        self.prev_node = None
        self.prev_node_id = None
        self.prev_primal_bound = None

    def _record_focus_node(self, scip_model, node):
        '''Stores the id of node and the current primal bound for the next step.'''
        self.prev_node_id = node.getNumber()
        self.prev_primal_bound = scip_model.getPrimalbound()

    def extract(self, model, done):
        scip_model = model.as_pyscipopt()

        if self.prev_node_id is None:
            # nothing to score yet: start tracking the tree from the current focus node
            self.prev_node = scip_model.getCurrentNode()
            if self.prev_node is not None:
                self.tree = SearchTree(model)
                self._record_focus_node(scip_model, self.prev_node)
            return 0

        # bring the tracked tree up to date with the current model state
        self.tree.update_tree(model)
        graph = self.tree.tree

        parent_lb = graph.nodes[self.prev_node_id]['lower_bound']
        child_lbs = [graph.nodes[child_id]['lower_bound'] for child_id in graph.successors(self.prev_node_id)]

        # no children means they were all pruned / infeasible / out of bounds -> neutral score
        score = 0
        if child_lbs:
            score = -1
            for child_lb in child_lbs:
                primal_bound = self.prev_primal_bound if self.use_prev_primal_bound else scip_model.getPrimalbound()
                score *= (primal_bound - child_lb) / (primal_bound - parent_lb)

        focus_node = scip_model.getCurrentNode()
        if focus_node is not None:
            # remember this step's focus node; when the instance is solved there is none
            self.prev_node = focus_node
            self._record_focus_node(scip_model, focus_node)

        return score

class SearchTree:
    '''
    Tracks SCIP search tree. Call SearchTree.update_tree(ecole.Model) each
    time the ecole environment (and therefore the ecole.Model) is updated.

    Nodes are stored in a networkx DiGraph keyed by SCIP node number, with
    attributes `_id` and `lower_bound`; edges run from parent to child.

    N.B. SCIP does not store nodes which were pruned, infeasible, or outside
    the search tree's optimality bounds, therefore these nodes will not be
    stored in the SearchTree. This is why m.getNTotalNodes() (the total number
    of nodes processed by SCIP) will likely be more than the number of nodes in
    the search tree when an instance is solved.
    '''
    def __init__(self, model):
        self.tree = nx.DiGraph()
        self.update_tree(model)

    def update_tree(self, model):
        '''
        Call this method after each update to the ecole environment. Pass
        the updated ecole.Model, and the B&B tree tracker will be updated accordingly.

        Sets self.curr_node / self.parent_node to single-entry {id: node} dicts
        ({None: None} when absent) and self.open_leaves / open_children /
        open_siblings to {id: node} dicts, adding all of them to the tree.
        '''
        m = model.as_pyscipopt()

        _curr_node = m.getCurrentNode()
        # no focus node once branching has finished
        curr_node_id = _curr_node.getNumber() if _curr_node is not None else None
        self.curr_node = {curr_node_id: _curr_node}

        if curr_node_id is not None:
            self._add_nodes(self.curr_node)
            _parent_node = _curr_node.getParent()
            # the root node has no parent
            parent_node_id = _parent_node.getNumber() if _parent_node is not None else None
            self.parent_node = {parent_node_id: _parent_node}
        else:
            self.parent_node = {None: None}

        open_leaves, open_children, open_siblings = m.getOpenNodes()
        self.open_leaves = {node.getNumber(): node for node in open_leaves}
        self.open_children = {node.getNumber(): node for node in open_children}
        self.open_siblings = {node.getNumber(): node for node in open_siblings}

        self._add_nodes(self.open_leaves)
        self._add_nodes(self.open_children)
        self._add_nodes(self.open_siblings)

    def _add_nodes(self, nodes):
        '''
        Adds nodes (a {node_id: node} dict) that are not yet fully recorded.

        A node may already be in the graph only as an attribute-less placeholder
        (add_edge() creates the parent endpoint implicitly when a child is seen
        first); such a node is completed here rather than skipped, so it always
        ends up with a `lower_bound`.
        '''
        for node_id, node in nodes.items():
            if node_id in self.tree and 'lower_bound' in self.tree.nodes[node_id]:
                continue

            # add (or complete) node
            self.tree.add_node(node_id,
                               _id=node_id,
                               lower_bound=node.getLowerbound())

            # add edge from parent; the root node has no parent
            _parent_node = node.getParent()
            if _parent_node is not None:
                self.tree.add_edge(_parent_node.getNumber(), node_id)

    def render(self):
        '''
        Renders B&B search tree using a graphviz 'dot' layout.

        Requires pydot/graphviz for the layout and matplotlib for drawing.
        '''
        # local import: the matching top-level import is commented out in the imports cell
        from networkx.drawing.nx_pydot import graphviz_layout

        # NOTE(review): `plt` (matplotlib.pyplot) is not imported in this
        # notebook's imports cell; add it there before calling render().
        plt.figure()

        pos = graphviz_layout(self.tree, prog='dot')
        node_labels = {node: node for node in self.tree.nodes}
        nx.draw_networkx_nodes(self.tree, pos)
        nx.draw_networkx_edges(self.tree, pos)
        nx.draw_networkx_labels(self.tree, pos, labels=node_labels)

        plt.show()


## Define the network

In [None]:
class BipartiteGCN(torch.nn.Module):
    '''
    Bipartite graph convolutional network producing one value per variable.

    Constraint and variable features are embedded to emb_size, refined by
    num_rounds rounds of variable->constraint then constraint->variable
    BipartiteGraphConvolution half-steps, and fed to num_heads output heads.
    Head outputs may then be aggregated (head_aggregator) and passed through
    an output activation (activation).
    '''
    def __init__(self, 
                 device, 
                 config=None,
                 emb_size=64,
                 num_rounds=1,
                 aggregator='add',
                 activation=None,
                 cons_nfeats=5,
                 edge_nfeats=1,
                 var_nfeats=19,
                 num_heads=1,
                 linear_weight_init=None,
                 linear_bias_init=None,
                 layernorm_weight_init=None,
                 layernorm_bias_init=None,
                 head_aggregator=None,
                 name='gnn'):
        '''
        Args:
            device: torch device the network is moved to and onto which
                observations given to forward() are placed.
            config (str, ml_collections.ConfigDict()): If not None, will initialise 
                from config dict (the other keyword arguments are then ignored).
                Can be either string (path to config.json) or
                ml_collections.ConfigDict object.
            emb_size (int): width of the node embeddings.
            num_rounds (int): number of message-passing rounds.
            aggregator (str): message aggregation used by BipartiteGraphConvolution.
            activation (None, 'sigmoid', 'relu', 'leaky_relu', 'inverse_leaky_relu', 'elu', 'hard_swish',
                'softplus', 'mish', 'softsign'): activation applied to the head outputs.
            cons_nfeats (int): number of features per constraint node.
            edge_nfeats (int): number of features per edge (stored only; edge
                features are not used by forward()).
            var_nfeats (int): number of features per variable node.
            num_heads (int): Number of heads (final layers) to use. Will use
                head_aggregator to reduce all heads.
            linear_weight_init (None, 'uniform', 'normal', 
                'xavier_uniform', 'xavier_normal', 'kaiming_uniform', 'kaiming_normal')
            linear_bias_init (None, 'zeros', 'uniform', 'normal')
            layernorm_weight_init (None, 'normal')
            layernorm_bias_init (None, 'zeros', 'normal')
            head_aggregator (None, 'add', 'mean'): Reduce operation to use to aggregate outputs
                of heads. If None, forward() will return output of each head.
            name (str): network name, stored on self.name.
        '''
        super().__init__()
        self.device = device

        if config is not None:
            self.init_from_config(config)
        else:
            self.name = name
            self.init_nn_modules(emb_size=emb_size, 
                                 num_rounds=num_rounds, 
                                 cons_nfeats=cons_nfeats, 
                                 edge_nfeats=edge_nfeats, 
                                 var_nfeats=var_nfeats, 
                                 aggregator=aggregator, 
                                 activation=activation, 
                                 num_heads=num_heads,
                                 linear_weight_init=linear_weight_init,
                                 linear_bias_init=linear_bias_init,
                                 layernorm_weight_init=layernorm_weight_init,
                                 layernorm_bias_init=layernorm_bias_init,
                                 head_aggregator=head_aggregator)

        # forward() asks about surplus variable features only once
        self.printed_warning = False
        self.to(self.device)

    def init_from_config(self, config):
        '''
        Builds the network's modules from a config.

        Args:
            config (str, ml_collections.ConfigDict()): path to a config.json
                or a config object. The JSON file may contain either the config
                itself or a JSON-encoded string of it (e.g. written with
                json.dump(config.to_json(), f)). Keys missing from configs saved
                by older versions are filled in with defaults (mutating config).
        '''
        if isinstance(config, str):
            # load from json
            with open(config, 'r') as f:
                json_config = json.load(f)
            if isinstance(json_config, str):
                # config was stored double-encoded (a JSON string holding JSON)
                json_config = json.loads(json_config)
            config = ml_collections.ConfigDict(json_config)
        self.name = config.name

        optional_defaults = {
            'activation': None,
            'num_heads': 1,
            'linear_weight_init': None,
            'linear_bias_init': None,
            'layernorm_weight_init': None,
            'layernorm_bias_init': None,
            'head_aggregator': None,
        }
        for key, default in optional_defaults.items():
            if key not in config.keys():
                setattr(config, key, default)

        self.init_nn_modules(emb_size=config.emb_size, 
                             num_rounds=config.num_rounds, 
                             cons_nfeats=config.cons_nfeats, 
                             edge_nfeats=config.edge_nfeats, 
                             var_nfeats=config.var_nfeats, 
                             aggregator=config.aggregator, 
                             activation=config.activation,
                             num_heads=config.num_heads,
                             linear_weight_init=config.linear_weight_init,
                             linear_bias_init=config.linear_bias_init,
                             layernorm_weight_init=config.layernorm_weight_init,
                             layernorm_bias_init=config.layernorm_bias_init,
                             head_aggregator=config.head_aggregator)

    def get_networks(self):
        return {'network': self}

    def init_model_parameters(self):
        '''
        (Re-)initialises Linear and LayerNorm parameters per the *_init settings.

        A setting of None leaves PyTorch's default initialisation in place.
        Unrecognised settings raise Exception.

        NOTE(review): the xavier_* / kaiming_* options pass self.activation
        (the *output* activation) to torch's gain calculation. That calculation
        only accepts names such as 'linear', 'sigmoid', 'relu', 'leaky_relu', so
        activation=None or e.g. 'elu' / 'mish' raises ValueError there.
        '''

        def init_params(m):
            if isinstance(m, torch.nn.Linear):
                # weights
                if self.linear_weight_init is None:
                    pass
                elif self.linear_weight_init == 'uniform':
                    torch.nn.init.uniform_(m.weight, a=0.0, b=1.0)
                elif self.linear_weight_init == 'normal':
                    torch.nn.init.normal_(m.weight, mean=0.0, std=0.01)
                elif self.linear_weight_init == 'xavier_uniform':
                    torch.nn.init.xavier_uniform_(m.weight, gain=torch.nn.init.calculate_gain(self.activation))
                elif self.linear_weight_init == 'xavier_normal':
                    torch.nn.init.xavier_normal_(m.weight, gain=torch.nn.init.calculate_gain(self.activation))
                elif self.linear_weight_init == 'kaiming_uniform':
                    torch.nn.init.kaiming_uniform_(m.weight, nonlinearity=self.activation)
                elif self.linear_weight_init == 'kaiming_normal':
                    torch.nn.init.kaiming_normal_(m.weight, nonlinearity=self.activation)
                else:
                    raise Exception(f'Unrecognised linear_weight_init {self.linear_weight_init}')

                # biases
                if m.bias is not None:
                    if self.linear_bias_init is None:
                        pass
                    elif self.linear_bias_init == 'zeros':
                        torch.nn.init.zeros_(m.bias)
                    elif self.linear_bias_init == 'uniform':
                        torch.nn.init.uniform_(m.bias)
                    elif self.linear_bias_init == 'normal':
                        torch.nn.init.normal_(m.bias)
                    else:
                        raise Exception(f'Unrecognised bias initialisation {self.linear_bias_init}')

            elif isinstance(m, torch.nn.LayerNorm):
                # weights
                if self.layernorm_weight_init is None:
                    pass
                elif self.layernorm_weight_init == 'normal':
                    torch.nn.init.normal_(m.weight, mean=0.0, std=0.01)
                else:
                    raise Exception(f'Unrecognised layernorm_weight_init {self.layernorm_weight_init}')

                # biases
                if self.layernorm_bias_init is None:
                    pass
                elif self.layernorm_bias_init == 'zeros':
                    torch.nn.init.zeros_(m.bias)
                elif self.layernorm_bias_init == 'normal':
                    torch.nn.init.normal_(m.bias)
                else:
                    raise Exception(f'Unrecognised layernorm_bias_init {self.layernorm_bias_init}')

        # init base GNN params
        self.apply(init_params)

        # init head output params. self.apply() above already visits the heads
        # (heads_module is a registered sub-module), so this initialises them a
        # second time. Kept so seeded initialisations stay reproducible.
        for h in self.heads_module:
            h.apply(init_params)

    def init_nn_modules(self, 
                        emb_size=64, 
                        num_rounds=1, 
                        cons_nfeats=5, 
                        edge_nfeats=1, 
                        var_nfeats=19, 
                        aggregator='add', 
                        activation=None,
                        num_heads=1,
                        linear_weight_init=None,
                        linear_bias_init=None,
                        layernorm_weight_init=None,
                        layernorm_bias_init=None,
                        head_aggregator='add'):
        '''
        Stores the hyperparameters on self, builds all sub-modules, then
        initialises their parameters.

        Sub-module creation order determines state_dict layout and random
        initialisation order, so it must not be changed.
        '''
        self.emb_size = emb_size
        self.num_rounds = num_rounds
        self.cons_nfeats = cons_nfeats
        self.edge_nfeats = edge_nfeats
        self.var_nfeats = var_nfeats
        self.aggregator = aggregator
        self.activation = activation
        self.num_heads = num_heads
        self.linear_weight_init = linear_weight_init
        self.linear_bias_init = linear_bias_init
        self.layernorm_weight_init = layernorm_weight_init
        self.layernorm_bias_init = layernorm_bias_init
        self.head_aggregator = head_aggregator

        # CONSTRAINT EMBEDDING
        self.cons_embedding = torch.nn.Sequential(
            torch.nn.LayerNorm(cons_nfeats),
            torch.nn.Linear(cons_nfeats, emb_size),
            torch.nn.LeakyReLU(),
            torch.nn.Linear(emb_size, emb_size),
            torch.nn.LeakyReLU(),
        )

        # VARIABLE EMBEDDING
        self.var_embedding = torch.nn.Sequential(
            torch.nn.LayerNorm(var_nfeats),
            torch.nn.Linear(var_nfeats, emb_size),
            torch.nn.LeakyReLU(),
            torch.nn.Linear(emb_size, emb_size),
            torch.nn.LeakyReLU(),
        )

        self.conv_v_to_c = BipartiteGraphConvolution(emb_size=emb_size, aggregator=aggregator)
        self.conv_c_to_v = BipartiteGraphConvolution(emb_size=emb_size, aggregator=aggregator)

        # heads: each maps a variable embedding to a single value
        self.heads_module = torch.nn.ModuleList([
            torch.nn.Sequential(
                torch.nn.Linear(emb_size, emb_size),
                torch.nn.LeakyReLU(),
                torch.nn.Linear(emb_size, 1, bias=True)
                )
            for _ in range(self.num_heads)
            ])

        # output activation ('inverse_leaky_relu' is LeakyReLU negated in forward())
        if self.activation is None:
            self.activation_module = None
        elif self.activation == 'sigmoid':
            self.activation_module = torch.nn.Sigmoid()
        elif self.activation == 'relu':
            self.activation_module = torch.nn.ReLU()
        elif self.activation == 'leaky_relu' or self.activation == 'inverse_leaky_relu':
            self.activation_module = torch.nn.LeakyReLU()
        elif self.activation == 'elu':
            self.activation_module = torch.nn.ELU()
        elif self.activation == 'hard_swish':
            self.activation_module = torch.nn.Hardswish()
        elif self.activation == 'softplus':
            self.activation_module = torch.nn.Softplus()
        elif self.activation == 'mish':
            self.activation_module = torch.nn.Mish()
        elif self.activation == 'softsign':
            self.activation_module = torch.nn.Softsign()
        else:
            raise Exception(f'Unrecognised activation {self.activation}')

        self.init_model_parameters()

    def forward(self, *_obs):
        '''
        Returns output of each head.

        Args:
            *_obs: either a single observation with row_features,
                edge_features.indices and column_features arrays, or the three
                tensors (constraint_features, edge_indices, variable_features).

        Returns:
            list of tensors with one value per variable: one entry per head,
            or a single entry when head_aggregator is set.

        Raises:
            ValueError: if variable_features has fewer than var_nfeats columns.
        '''
        if len(_obs) > 1:
            # no need to pre-process observation features
            constraint_features, edge_indices, variable_features = _obs
        else:
            # need to pre-process observation features
            obs = _obs[0] # unpack
            constraint_features = torch.from_numpy(obs.row_features.astype(np.float32)).to(self.device)
            edge_indices = torch.from_numpy(obs.edge_features.indices.astype(np.int64)).to(self.device)
            variable_features = torch.from_numpy(obs.column_features.astype(np.float32)).to(self.device)

        reversed_edge_indices = torch.stack([edge_indices[1], edge_indices[0]], dim=0)

        # First step: linear embedding layers to a common dimension (emb_size)
        constraint_features = self.cons_embedding(constraint_features)

        num_var_feats = variable_features.shape[1]
        if num_var_feats < self.var_nfeats:
            # slicing cannot fix missing features; fail clearly instead of inside LayerNorm
            raise ValueError(f'variable_features has {num_var_feats} features per variable but var_nfeats is {self.var_nfeats}.')
        if num_var_feats > self.var_nfeats:
            if not self.printed_warning:
                ans = None
                while ans not in {'y', 'n'}:
                    ans = input(f'WARNING: variable_features is shape {variable_features.shape} but var_nfeats is {self.var_nfeats}. Will index out extra features. Continue? (y/n): ')
                if ans != 'y':
                    raise Exception('User stopped programme.')
                self.printed_warning = True
            # keep only the first var_nfeats features
            variable_features = variable_features[:, 0:self.var_nfeats]
        variable_features = self.var_embedding(variable_features)

        # Two half convolutions (message passing round)
        for _ in range(self.num_rounds):
            constraint_features = self.conv_v_to_c(variable_features, reversed_edge_indices, constraint_features)
            variable_features = self.conv_c_to_v(constraint_features, edge_indices, variable_features)

        # get output for each head
        head_output = [self.heads_module[head](variable_features).squeeze(-1) for head in range(self.num_heads)]

        # check if should aggregate head outputs
        if self.head_aggregator is not None:
            if self.head_aggregator == 'add':
                head_output = [torch.stack(head_output, dim=0).sum(dim=0)]
            elif self.head_aggregator == 'mean':
                head_output = [torch.stack(head_output, dim=0).mean(dim=0)]
            else:
                raise Exception(f'Unrecognised head_aggregator {self.head_aggregator}')

        # activation
        if self.activation_module is not None:
            head_output = [self.activation_module(head) for head in head_output]
            if self.activation == 'inverse_leaky_relu':
                # invert
                head_output = [-1 * head for head in head_output]

        return head_output

    def create_config(self):
        '''
        Returns config dict so that can re-initialise easily.

        The config is a deep copy of the instance attributes, excluding the
        sub-module registry (`_modules`). `_modules` is filtered out *before*
        copying so the network's weights are not deep-copied only to be thrown
        away; leaving it out also avoids circular references.
        '''
        network_dict = copy.deepcopy({key: value for key, value in self.__dict__.items() if key != '_modules'})

        # create config dict
        config = ml_collections.ConfigDict(network_dict)

        return config
    
    
class BipartiteGraphConvolution(torch_geometric.nn.MessagePassing):
    """
    Half-convolution over a bipartite graph (left node set -> right node set).

    PyTorch Geometric's ``MessagePassing`` performs the gather/aggregate steps;
    this class only supplies the node projections, the per-edge message and the
    post-aggregation output network.

    Args:
        aggregator: aggregation scheme forwarded to ``MessagePassing`` (e.g. 'add').
        emb_size: width of the node embeddings on both sides of the graph.
    """
    def __init__(self,
                 aggregator='add',
                 emb_size=64):
        super().__init__(aggr=aggregator)

        # NOTE: sub-modules are created in a fixed order so that random weight
        # initialisation consumes the RNG identically across runs.

        # projection of the left-hand (source) node features, with bias
        self.feature_module_left = torch.nn.Sequential(
            torch.nn.Linear(emb_size, emb_size),
        )
        # projection of the right-hand (target) node features, bias-free
        # (the left projection already contributes a bias to their sum)
        self.feature_module_right = torch.nn.Sequential(
            torch.nn.Linear(emb_size, emb_size, bias=False),
        )
        # applied to each per-edge message
        self.feature_module_final = torch.nn.Sequential(
            torch.nn.LayerNorm(emb_size),
            torch.nn.LeakyReLU(),
            torch.nn.Linear(emb_size, emb_size),
        )

        # normalises the aggregated messages before they are combined
        self.post_conv_module = torch.nn.Sequential(
            torch.nn.LayerNorm(emb_size),
        )

        # maps [aggregated messages || original right features] back to emb_size
        self.output_module = torch.nn.Sequential(
            torch.nn.Linear(2 * emb_size, emb_size),
            torch.nn.LeakyReLU(),
            torch.nn.Linear(emb_size, emb_size),
        )

    def forward(self, left_features, edge_indices, right_features):
        """
        Project both node sets, aggregate messages along ``edge_indices`` onto the
        right-hand nodes and combine them with the original right-hand features.

        Edge attributes are not used by this convolution.

        Returns:
            Updated embeddings for the right-hand nodes.
        """
        left_projected = self.feature_module_left(left_features)
        right_projected = self.feature_module_right(right_features)
        aggregated = self.propagate(
            edge_indices,
            size=(left_features.shape[0], right_features.shape[0]),
            node_features=(left_projected, right_projected),
        )
        combined = torch.cat([self.post_conv_module(aggregated), right_features], dim=-1)
        return self.output_module(combined)

    def message(self, node_features_i, node_features_j):
        """Per-edge message: sum of both endpoint embeddings -> LayerNorm -> LeakyReLU -> Linear."""
        # the `_i` / `_j` suffixes are significant: PyG uses them to gather the
        # target / source entries of `node_features` for every edge
        return self.feature_module_final(node_features_i + node_features_j)


## Define the Batch object for input state

In [None]:
class BipartiteNodeData(torch_geometric.data.Data):
    """
    This class encodes a node bipartite graph observation as returned by the `ecole.observation.NodeBipartite`
    observation function in a format understood by the pytorch geometric data handlers.

    Args:
        obs: observation exposing ``row_features``, ``column_features`` and
            ``edge_features.indices`` (numpy arrays).
        candidates: numpy array of candidate variable indices (the action set).

    Both arguments default to ``None`` because PyTorch Geometric re-instantiates
    the data class without arguments (or with ``None`` for every argument) when it
    builds batches or extracts single examples from a batch. In that case an empty
    placeholder is created and PyG fills in the attributes itself.
    """
    def __init__(self, obs=None, candidates=None):
        super().__init__()
        if obs is None and candidates is None:
            # empty placeholder requested by PyG's batching machinery
            return
        # raw observation is kept alongside the tensors
        # NOTE(review): it is carried through batching as a plain Python object
        self.obs = obs
        self.constraint_features = torch.FloatTensor(obs.row_features)
        self.edge_index = torch.LongTensor(obs.edge_features.indices.astype(np.int64))
        # edge attributes (obs.edge_features.values) are intentionally not stored:
        # the graph convolutions used by the network take no edge features
        self.variable_features = torch.FloatTensor(obs.column_features)
        # `candidates` is offset per graph when batching (see __inc__); `raw_candidates`
        # is not special-cased there, so it presumably keeps the per-graph indices
        self.candidates = torch.from_numpy(candidates.astype(np.int64)).long()
        self.raw_candidates = torch.from_numpy(candidates.astype(np.int64)).long()

        self.num_candidates = len(candidates)
        self.num_variables = self.variable_features.size(0)
        # total node count across both sides of the bipartite graph
        self.num_nodes = self.constraint_features.size(0) + self.variable_features.size(0)

    def __inc__(self, key, value, *args, **kwargs):
        """
        We overload the pytorch geometric method that tells how to increment indices when concatenating graphs
        for those entries (edge index, candidates) for which this is not obvious. This
        enables batching.

        Extra positional/keyword arguments (PyG >= 2.0 passes a ``store`` argument)
        are forwarded unchanged to the base implementation, so both old and new
        PyG versions are supported.
        """
        if key == 'edge_index':
            # row 0 indexes constraint nodes, row 1 indexes variable nodes
            return torch.tensor([[self.constraint_features.size(0)], [self.variable_features.size(0)]])
        elif key == 'candidates':
            # actions are variable nodes
            return self.variable_features.size(0)
        else:
            return super().__inc__(key, value, *args, **kwargs)


## Generate states

In [None]:
# Branching environment emitting the '43_var_features' observation, with the
# 'default' information function, reward function and SCIP parameters.
env = EcoleBranching(
    observation_function='43_var_features',
    information_function='default',
    reward_function='default',
    scip_params='default',
)

# Set-cover instance generator (500 rows, 1000 columns, density 0.05),
# consumed one instance at a time with `next()`.
instances = ecole.instance.SetCoverGenerator(n_rows=500, n_cols=1000, density=0.05)

def gen_obs(env, instances):
    """Reset ``env`` on successive instances until a reset yields an observation.

    Instances whose ``env.reset`` returns ``None`` as the observation are skipped.

    Args:
        env: environment whose ``reset(instance)`` returns a 5-tuple
            ``(observation, action_set, reward_offset, done, info)``.
        instances: iterator of problem instances.

    Returns:
        ``(observation, action_set)`` from the first reset with a non-None observation.
    """
    while True:
        instance = next(instances)
        observation, action_set, _reward_offset, _done, _info = env.reset(instance)
        if observation is not None:
            return (observation, action_set)

## Pass through network

In [None]:
# Prefer the second GPU (the original experiment setup), but fall back to the first
# GPU or the CPU so the notebook also runs on single-GPU / CPU-only machines.
if torch.cuda.device_count() > 1:
    device = 'cuda:1'
elif torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

net = BipartiteGCN(device=device,
                  emb_size=128,
                  num_rounds=2,
                  cons_nfeats=5,
                  edge_nfeats=1,
                  var_nfeats=43, # 19 20 28 45 40
                  aggregator='add',
                  activation='inverse_leaky_relu',
                  num_heads=1,
                  linear_weight_init='normal',
                  linear_bias_init='zeros',
                  layernorm_weight_init=None,
                  layernorm_bias_init=None,
                  head_aggregator=None) # None 'add'

In [None]:
# Build a single (unbatched) input state on the target device.
obs = gen_obs(env, instances)
graph = BipartiteNodeData(*obs).to(device)
# the network is called with (constraint_features, edge_indices, variable_features);
# edge attributes are not passed
state = (graph.constraint_features, graph.edge_index, graph.variable_features)

In [None]:
%timeit -n10 net(*state)

In [None]:
# Profile one forward pass on a freshly generated single (unbatched) state.
obs = gen_obs(env, instances)
graph = BipartiteNodeData(*obs).to(device)
# the network is called with (constraint_features, edge_indices, variable_features);
# edge attributes are not passed
state = (graph.constraint_features, graph.edge_index, graph.variable_features)

with torch.autograd.profiler.profile(use_cuda=True) as prof:
    output = net(*state)

profile_table = prof.table(sort_by='cuda_time_total')
print(profile_table)

## Create batch of states

In [None]:
batch_size = 128

# Sample `batch_size` independent states, wrap each as a bipartite graph and
# collate them into a single PyG batch.
observations = [gen_obs(env, instances) for _sample in range(batch_size)]
states = [BipartiteNodeData(observation, action_set) for observation, action_set in observations]
batched_state = torch_geometric.data.Batch.from_data_list(states)
print(batched_state)

In [None]:
state = batched_state.to(device)
# the network is called with (constraint_features, edge_indices, variable_features);
# edge attributes are not passed
state = (state.constraint_features, state.edge_index, state.variable_features)
for el in state:
    print(el.shape)

# Time the forward pass only:
# - perf_counter is the high-resolution clock intended for measuring intervals;
# - CUDA work is asynchronous, so synchronise before stopping the clock;
# - tensors are printed afterwards, so the (syncing, I/O-bound) print is not timed.
output_start = time.perf_counter()
output = net(*state)
if str(device).startswith('cuda'):
    torch.cuda.synchronize(device)
output_t = time.perf_counter() - output_start
print(output[0][0])
print(output[0].shape)
print(f'output_t: {output_t*1e3:.3f} ms')

In [None]:
# Time the backward pass only (synchronised, printing excluded from the timed region).
# NOTE: backward() frees the autograd graph, so re-running this cell without first
# re-running the forward cell raises a RuntimeError.
backward_start = time.perf_counter()
output[0].sum().backward()
if str(device).startswith('cuda'):
    torch.cuda.synchronize(device)
backward_t = time.perf_counter() - backward_start
print(output[0][0])
print(f'backward_t: {backward_t*1e3:.3f} ms')