# Testing Strong Branching Agent

In [None]:
%load_ext autoreload
%autoreload
from retro_branching.environments import EcoleBranching
from retro_branching.agents import StrongBranchingAgent

import ecole
import pyscipopt

import matplotlib.pyplot as plt

In [None]:
# Strong branching agent under test, built with its default constructor arguments.
agent = StrongBranchingAgent()

In [None]:
# Two environments seeded identically; env2 is later reset on an untouched
# copy of the instance that env1 accepted.
env = EcoleBranching()
env.seed(0)
env2 = EcoleBranching()
env2.seed(0)

instances = ecole.instance.SetCoverGenerator(n_rows=100, n_cols=100, density=0.05)

# Keep drawing instances until env.reset yields an observation, i.e. until the
# instance was not already pre-solved during reset. A copy is taken before
# reset because reset consumes/modifies the instance passed to it.
obs = None
while True:
    instance = next(instances)
    instance_before_reset = instance.copy_orig()
    obs, action_set, reward, done, info = env.reset(instance)
    if obs is not None:
        break

_ = env2.reset(instance_before_reset.copy_orig())

# Testing Rollouts

In [None]:
import ecole
import pyscipopt
from retro_branching.environments import EcoleBranching

# ecole branching environment
# SCIP parameters passed to the environment:
# - every separating / presolving / propagating round-and-cut limit is set to 0;
# - both LP algorithm settings are 'd' (dual simplex, per SCIP's parameter docs);
# - 'limits/time' is capped at 3600.
default_scip_params = dict.fromkeys(
    (
        'separating/maxrounds',
        'separating/maxroundsroot',
        'separating/maxcuts',
        'separating/maxcutsroot',
        'presolving/maxrounds',
        'presolving/maxrestarts',
        'propagating/maxrounds',
        'propagating/maxroundsroot',
    ),
    0,
)
default_scip_params.update({
    'lp/initalgorithm': 'd',
    'lp/resolvealgorithm': 'd',
    'limits/time': 3600,
})

# class EcoleBranching(ecole.environment.Branching):
#     def __init__(
#         self,
#         observation_function="default",
#         information_function="default",
#         reward_function="default",
#         scip_params="default",
#     ):
#         if reward_function == "default":
#             reward_function = reward_function=({
#                      'num_nodes': -ecole.reward.NNodes(),
#                      'lp_iterations': -ecole.reward.LpIterations(),
#                      'primal_integral': -ecole.reward.PrimalIntegral(),
#                      'dual_integral': -ecole.reward.DualIntegral(),
#                      'primal_dual_integral': -ecole.reward.PrimalDualIntegral(),
#                  })
#         if information_function == 'default':
#             information_function=({
#                      'lp_iterations': ecole.reward.LpIterations().cumsum(),
#                      'num_nodes': ecole.reward.NNodes().cumsum(),
#                      'solving_time': ecole.reward.SolvingTime().cumsum(),
#                      'dual_integral': ecole.reward.DualIntegral(),
#                      'primal_dual_integral': ecole.reward.PrimalDualIntegral(),
#                  }),
#         if observation_function == 'default':    
#             observation_function=(
#                      ecole.observation.NodeBipartite()
#                  ),
#         if scip_params == 'default':
#             scip_params=default_scip_params
        
#         super(EcoleBranching, self).__init__(
#             observation_function=observation_function,
#             information_function=information_function,
#             reward_function=reward_function,
#             scip_params=scip_params,
#         )

# branching agent
class PseudocostBranchingAgent:
    """Branching agent that selects the candidate with the highest pseudocost score.

    Scores come from ``ecole.observation.Pseudocosts``. ``action_select``
    returns a *position within* ``action_set`` rather than a variable index,
    so callers map it back with ``action_set[action]``.
    """

    def __init__(self, name='pc'):
        self.name = name
        # ecole observation function used to score the branching variables.
        self.pc_branching_function = ecole.observation.Pseudocosts()

    def extract(self, model, done):
        """Return the pseudocost scores computed for ``model``."""
        scorer = self.pc_branching_function
        return scorer.extract(model, done)

    def action_select(self, action_set, model, done):
        """Return the index into ``action_set`` of the best-scoring candidate."""
        candidate_scores = self.extract(model, done)[action_set]
        best_position = candidate_scores.argmax()
        return best_position


import itertools

# Solve `max_instances` set-cover instances with the pseudocost agent and plot,
# for each one, how much of the dual / primal integral remains from each step on.
instances = ecole.instance.SetCoverGenerator(n_rows=500, n_cols=1000, density=0.05)
base_env = EcoleBranching()
base_env.seed(0)
agent = PseudocostBranchingAgent()
max_instances = 10  # previously enforced by raising an Exception inside the loop
num_instances = 0
while num_instances < max_instances:
    # find an instance not pre-solved by environment
    print('\n >>> New Instance <<<')
    obs, counter = None, 0
    while obs is None:
        counter += 1
        print(counter)
        instance = next(instances)
        instance_before_reset = instance.copy_orig()
        obs, action_set, reward, done, info = base_env.reset(instance)
    dual_integrals = []
    primal_integrals = []

    # use an agent to solve the instance in base_env
    while not done:

#         # rollout agent from current step to episode termination
#         print('>> New Rollout <<')
#         curr_state = base_env.model.as_pyscipopt()
#         rollout_m = pyscipopt.Model(sourceModel=curr_state, globalcopy=True)
#         rollout_env = EcoleBranching()
#         rollout_env.seed(0)
#         rollout_m = ecole.scip.Model.from_pyscipopt(rollout_m)
#         rollout_obs, rollout_action_set, rollout_reward, rollout_done, rollout_info = rollout_env.reset(rollout_m)
#         while not rollout_done:
#             rollout_action = agent.action_select(rollout_action_set, rollout_env.model, rollout_done)
#             rollout_action = rollout_action_set[rollout_action]
#             rollout_obs, rollout_action_set, rollout_reward, rollout_done, rollout_info = rollout_env.step(rollout_action)
#             rollout_m = rollout_env.model.as_pyscipopt()

#             print('rollout reward: {} | total num nodes: {} | reward primal-dual integral: {} | reward integral: {}'.format(rollout_reward['num_nodes'], rollout_m.getNTotalNodes(), rollout_reward['primal_dual_integral'], rollout_reward['dual_integral']))

        # get agent action at current step (agent returns a position in action_set)
        action = agent.action_select(action_set, base_env.model, done)
        action = action_set[action]

        # take step in agent environment
        obs, action_set, reward, done, info = base_env.step(action)
        print('reward: {}'.format(reward))
        print('info: {}'.format(info))
        base_m = base_env.model.as_pyscipopt()
        print('base reward: {} | total num nodes: {}'.format(reward['num_nodes'], base_m.getNTotalNodes()))

        # NOTE(review): the primal integral is abs()'d but the dual integral is
        # not -- confirm the sign convention of reward['dual_integral'].
        dual_integrals.append(reward['dual_integral'])
        primal_integrals.append(abs(reward['primal_integral']))

        if reward['num_nodes'] < -1e10:
            raise Exception('Agent reward {} invalid'.format(reward['num_nodes']))
    print(dual_integrals)
    # Trailing 0 so each "remaining integral" curve ends at zero.
    dual_integrals.append(0)
    primal_integrals.append(0)
    # Suffix sums: entry i is the total still to be accrued from step i onwards.
    # Accumulating over the reversed list is O(n); summing every suffix slice
    # separately (as before) was O(n^2) in the number of steps.
    dual_integrals_vs_time = list(itertools.accumulate(reversed(dual_integrals)))[::-1]
    primal_integrals_vs_time = list(itertools.accumulate(reversed(primal_integrals)))[::-1]
    fig = plt.figure()
    plt.plot(dual_integrals_vs_time, label='Dual Integral')
    plt.plot(primal_integrals_vs_time, label='Primal Integral')
    plt.legend()
    plt.show()
    num_instances += 1
print(f'{num_instances} instances reached')

In [None]:
import numpy as np

# Seed NumPy's global RNG so the draw below is reproducible. The original line
# was the bare attribute reference `np.random.seed`, which never called the
# function and therefore had no effect.
np.random.seed(0)
# One integer drawn uniformly from [1, 10) -- the upper bound is exclusive.
x = np.random.randint(1, 10)

In [None]:
import ecole
import pyscipopt

import copy
import pickle
import dill

# ecole branching environment
# SCIP parameters for the environment below. The zeroed limits switch off
# separation rounds/cuts, presolving and propagation; 'd' selects the dual
# simplex for the initial and re-solved LPs (per SCIP's parameter docs).
default_scip_params = {
    # separation
    'separating/maxrounds': 0,
    'separating/maxroundsroot': 0,
    'separating/maxcuts': 0,
    'separating/maxcutsroot': 0,
    # presolving
    'presolving/maxrounds': 0,
    'presolving/maxrestarts': 0,
    # propagation
    'propagating/maxrounds': 0,
    'propagating/maxroundsroot': 0,
    # LP solving
    'lp/initalgorithm': 'd',
    'lp/resolvealgorithm': 'd',
    # limits
    'limits/time': 3600,
}

class EcoleBranching(ecole.environment.Branching):
    """``ecole.environment.Branching`` preconfigured for these experiments.

    Each argument left as ``'default'`` is replaced by the configuration below,
    built fresh for every environment instance:

    - observation_function: ``ecole.observation.NodeBipartite()``
    - information_function: dict of cumulative LP iterations, node count and
      solving time, plus the dual and primal-dual integrals
    - reward_function: dict of the negated NNodes, LpIterations,
      PrimalIntegral, DualIntegral and PrimalDualIntegral rewards
    - scip_params: a copy of the module-level ``default_scip_params``

    Any other value (including ``None``) is passed to
    ``ecole.environment.Branching`` unchanged.
    """

    def __init__(self,
                 observation_function='default',
                 information_function='default',
                 reward_function='default',
                 scip_params='default'):
        # The defaults are built here rather than as default-argument values.
        # Default arguments are evaluated once, at class definition, so every
        # environment (e.g. base_env and each rollout_env) would otherwise
        # share the same ecole function objects and the same params dict.
        # NOTE(review): ecole reward functions presumably keep per-episode
        # state, so sharing them would couple the environments' rewards.
        if self._is_default(observation_function):
            observation_function = ecole.observation.NodeBipartite()
        if self._is_default(information_function):
            information_function = {
                'lp_iterations': ecole.reward.LpIterations().cumsum(),
                'num_nodes': ecole.reward.NNodes().cumsum(),
                'solving_time': ecole.reward.SolvingTime().cumsum(),
                'dual_integral': ecole.reward.DualIntegral(),
                'primal_dual_integral': ecole.reward.PrimalDualIntegral(),
            }
        if self._is_default(reward_function):
            reward_function = {
                'num_nodes': -ecole.reward.NNodes(),
                'lp_iterations': -ecole.reward.LpIterations(),
                'primal_integral': -ecole.reward.PrimalIntegral(),
                'dual_integral': -ecole.reward.DualIntegral(),
                'primal_dual_integral': -ecole.reward.PrimalDualIntegral(),
            }
        if self._is_default(scip_params):
            scip_params = dict(default_scip_params)
        super().__init__(observation_function=observation_function,
                         information_function=information_function,
                         reward_function=reward_function,
                         scip_params=scip_params)
#         m = self.model.as_pyscipopt()
#         self.prev_num_nodes = m.getNTotalNodes()

    @staticmethod
    def _is_default(arg):
        """Return True if ``arg`` is the ``'default'`` placeholder string."""
        # Check the type first so ``==`` is never invoked on ecole function
        # objects, which overload arithmetic operators.
        return isinstance(arg, str) and arg == 'default'

#     def scip_total_num_nodes(self):
#         m = self.model.as_pyscipopt()
#         return m.getNTotalNodes()

#     def scip_change_num_nodes(self):
#         m = self.model.as_pyscipopt()
#         change_num_nodes = m.getNTotalNodes() - self.prev_num_nodes
#         self.prev_num_nodes = m.getNTotalNodes()
#         return change_num_nodes
    
# branching agent
class PseudocostBranchingAgent:
    """Greedy pseudocost brancher built on ``ecole.observation.Pseudocosts``.

    ``action_select`` answers with an index into the given ``action_set``;
    the caller converts it to a variable via ``action_set[action]``.
    """

    def __init__(self, name='pc'):
        self.pc_branching_function = ecole.observation.Pseudocosts()
        self.name = name

    def extract(self, model, done):
        """Compute pseudocost scores for ``model`` via the ecole observation function."""
        extract_scores = self.pc_branching_function.extract
        return extract_scores(model, done)

    def action_select(self, action_set, model, done):
        """Pick the ``action_set`` position whose candidate has the largest score."""
        all_scores = self.extract(model, done)
        return all_scores[action_set].argmax()


instances = ecole.instance.SetCoverGenerator(n_rows=100, n_cols=100, density=0.05)
base_env = EcoleBranching()
base_env.seed(0)
agent = PseudocostBranchingAgent()
# Exploratory loop with no stop condition: runs until interrupted or until a
# step raises. Before every base_env step, the same agent is rolled out to
# termination in a fresh environment reset on a copy of base_env's current
# SCIP model. ecole's num_nodes reward is printed next to SCIP's own counts.
while True:
    # find an instance not pre-solved by environment
    print('\n >>> New Instance <<<')
    obs, counter = None, 0
    while obs is None:
        counter += 1
        print(counter)
        instance = next(instances)
        instance_before_reset = instance.copy_orig()
        obs, action_set, reward, done, info = base_env.reset(instance)

    # use an agent to solve the instance in base_env
    while not done:
        # check base env state before rollout
        base_m = base_env.model.as_pyscipopt()
        print(f'base_env before rollout num nodes: {base_m.getNTotalNodes()} | best primal bound: {base_m.getPrimalbound()} | best dual bound: {base_m.getDualbound()} | primal-dual gap: {base_m.getGap()}')

        # rollout agent from current step to episode termination
        print('>> New Rollout <<')
        curr_state = base_env.model.as_pyscipopt()
        rollout_m = pyscipopt.Model(sourceModel=curr_state, globalcopy=True)

#         rollout_env = copy.deepcopy(base_env)
#         print('copied!')
#         pickle_test = pickle.dumps(base_env)
#         pickle_test = pickle.loads(pickle_test)
#         print('pickled!')
        # Round-trip base_env through dill to check that it can be serialised.
        # dill.loads is the inverse of dill.dumps; the previous dill.dill(...)
        # call did not name a deserialising function.
        dill_test = dill.dumps(base_env)
        dill_test = dill.loads(dill_test)
        print('dilld!')

        rollout_env = EcoleBranching()
        rollout_env.seed(0)
        rollout_m = ecole.scip.Model.from_pyscipopt(rollout_m)
        rollout_obs, rollout_action_set, rollout_reward, rollout_done, rollout_info = rollout_env.reset(rollout_m)
        rollout_m = rollout_env.model.as_pyscipopt()
        prev_num_nodes = rollout_m.getNTotalNodes()
        while not rollout_done:
            rollout_action = agent.action_select(rollout_action_set, rollout_env.model, rollout_done)
            rollout_action = rollout_action_set[rollout_action]
            rollout_obs, rollout_action_set, rollout_reward, rollout_done, rollout_info = rollout_env.step(rollout_action)
            print('ecole rollout reward (num nodes): {}'.format(rollout_reward['num_nodes']))
            rollout_m = rollout_env.model.as_pyscipopt()
            # SCIP-side equivalent of the num_nodes reward: negated change in total nodes
            print('scip rollout reward (num nodes): {}'.format(-(rollout_m.getNTotalNodes()-prev_num_nodes)))
            prev_num_nodes = rollout_m.getNTotalNodes()
        rollout_m = rollout_env.model.as_pyscipopt()
        print(f'rollout_env final total nodes: {rollout_m.getNTotalNodes()} | best primal bound: {rollout_m.getPrimalbound()} | best dual bound: {rollout_m.getDualbound()} | primal-dual gap: {rollout_m.getGap()}')

        # check base env state after rollout
        base_m = base_env.model.as_pyscipopt()
        print(f'base_env after rollout num nodes: {base_m.getNTotalNodes()} | best primal bound: {base_m.getPrimalbound()} | best dual bound: {base_m.getDualbound()} | primal-dual gap: {base_m.getGap()}')

        # get agent action at current step
        action = agent.action_select(action_set, base_env.model, done)
        action = action_set[action]

        # take step in agent environment
        obs, action_set, reward, done, info = base_env.step(action)
        print('ecole agent reward (num nodes): {}'.format(reward['num_nodes']))
        base_m = base_env.model.as_pyscipopt()
        print('scip agent reward (num nodes): {}'.format(-base_m.getNNodes()))
        # Total node count is reported as-is (it was negated by a copy-paste slip).
        print('scip agent total nodes: {}'.format(base_m.getNTotalNodes()))

        # check if agent reward is valid
        if reward['num_nodes'] < -1e10:
            raise Exception('Agent reward {} invalid'.format(reward['num_nodes']))

# Testing Multiple Envs Using the Same Instance

In [None]:
from retro_branching.environments import EcoleBranching
import ecole

In [None]:
# Two environments with default settings, created env1 first and then env2.
# Seeding is currently commented out.
env1, env2 = EcoleBranching(), EcoleBranching()
# env1.seed(0)
# env2.seed(0)

In [None]:
instances = ecole.instance.SetCoverGenerator(n_rows=100, n_cols=200, density=0.05)

# Draw instances until env1.reset returns an observation, printing the attempt
# number each time. A copy is kept so env2 can be reset on the same instance.
obs, counter = None, 1
while True:
    print(counter)
    instance = next(instances)
    instance_before_reset = instance.copy_orig()
    obs, action_set, reward, done, info = env1.reset(instance)
    counter += 1
    if obs is not None:
        break

env2.reset(instance_before_reset)

m = env1.model.as_pyscipopt()
m2 = env2.model.as_pyscipopt()

# Compare variables, constraints and bounds of the two post-reset models.
print('\nafter reset')
vars1, vars2 = m.getVars(), m2.getVars()
print('vars: {} {}'.format(len(vars1), vars1))
print('vars2: {} {}'.format(len(vars2), vars2))

conss1, conss2 = m.getConss(), m2.getConss()
print('\nconss: {} {}'.format(len(conss1), conss1))
print('conss2: {} {}'.format(len(conss2), conss2))

for prefix, model in (('\n', m), ('', m2)):
    print(f'{prefix}best primal bound: {model.getPrimalbound()} | best dual bound: {model.getDualbound()} | primal-dual gap: {model.getGap()}')

In [None]:
# Build two generators with identical parameters, reset one environment on the
# first instance of each, and print both models' bounds and gap after reset.
# NOTE(review): ecole generators/environments may draw their default seeds from
# a shared global random source, so construction order could affect the
# result -- keep the statement order below unchanged.
instances1 = ecole.instance.SetCoverGenerator(n_rows=100, n_cols=100, density=0.05)
instances2 = ecole.instance.SetCoverGenerator(n_rows=100, n_cols=100, density=0.05)

instance1 = next(instances1)
instance2 = next(instances2)

env1 = EcoleBranching()
env2 = EcoleBranching()

env1.reset(instance1)
env2.reset(instance2)

# pyscipopt views of the two models, used only for the comparison prints
m = env1.model.as_pyscipopt()
m2 = env2.model.as_pyscipopt()

print(f'\nbest primal bound: {m.getPrimalbound()} | best dual bound: {m.getDualbound()} | primal-dual gap: {m.getGap()}')
print(f'best primal bound: {m2.getPrimalbound()} | best dual bound: {m2.getDualbound()} | primal-dual gap: {m2.getGap()}')

# Validator

In [None]:
%autoreload
from retro_branching.environments import EcoleBranching
from retro_branching.networks import BipartiteGCN
from retro_branching.agents import REINFORCEAgent, StrongBranchingAgent
from retro_branching.validators import ReinforcementLearningValidator

import ecole
import torch
import os
import gzip
import pickle

# Device for the GNN policy networks. 'cuda:1' was the hard-coded choice; fall
# back to the first GPU, then the CPU, when that device is unavailable.
if torch.cuda.device_count() > 1:
    DEVICE = 'cuda:1'
elif torch.cuda.is_available():
    DEVICE = 'cuda:0'
else:
    DEVICE = 'cpu'
# agent name -> agent. Defined here (not only inside the commented-out RL
# loader) because the live GNN / strong branching code below writes into it.
agents = {}

# # load RL agents
# ids = [57, 58, 60, 65]
# for i in ids:
#     # load agent params
#     agent_name = 'rl_gnn_{}'.format(i)
#     agent_path = '../scripts/reinforce_learner/rl_gnn/{}/'.format(agent_name)
#     path = '{}{}/'.format(agent_path, [name for name in os.listdir(agent_path)][-1])
#     policy_network = BipartiteGCN(DEVICE)
#     policy_network.load_state_dict(torch.load(path+'trained_params.pkl'))

#     # init agent
#     agent = REINFORCEAgent(policy_network=policy_network, device=DEVICE, name=agent_name)
#     agent.eval() # turn on evaluation mode
#     agents[agent_name] = agent
# load gnn agent (no RL)
policy_network = BipartiteGCN(DEVICE)
# map_location lets parameters saved on another device load onto DEVICE.
policy_network.load_state_dict(torch.load('../scripts/trained_params_dict_disabler.pkl', map_location=DEVICE))
agent = REINFORCEAgent(policy_network=policy_network, device=DEVICE, name='gnn')
agent.eval() # turn on evaluation mode
agents['gnn'] = agent
print(agents)
# load strong branching agent
agent = StrongBranchingAgent(name='sb')
agents['sb'] = agent
print('Initialised agents {}'.format(agents.keys()))


# init instances generator
# instances = ecole.instance.SetCoverGenerator(n_rows=100, n_cols=100, density=0.05)
instances = ecole.instance.SetCoverGenerator(n_rows=500, n_cols=1000, density=0.05)
print('Initialised instances.')


# init envs: one independent environment per agent
envs = {agent_name: EcoleBranching() for agent_name in agents}
print('Initialised agent envs.')

# metrics
metrics = ['num_nodes', 'solving_time', 'lp_iterations']


validator = ReinforcementLearningValidator(agents=agents,
                                           envs=envs,
                                           instances=instances,
                                           metrics=metrics,
                                           seed=0,
                                           turn_off_heuristics=False,
                                           threshold_difficulty=None,
                                           epoch_log_frequency=1,
                                           path_to_save='.')
validator.test(10)

In [None]:
%autoreload
from retro_branching.utils import plot_val_line

from collections import defaultdict
import numpy as np


def nested_dict():
    """Return a defaultdict whose missing keys auto-create further nested_dicts."""
    return defaultdict(nested_dict)


# metric -> agent -> {'x_values': epoch indices, 'y_values': logged values}
plot_dicts = {metric: nested_dict() for metric in validator.metrics}

for agent, agent_log in validator.epochs_log.items():
    for metric in validator.metrics:
        metric_log = agent_log[metric]
        plot_dicts[metric][agent]['y_values'] = metric_log
        plot_dicts[metric][agent]['x_values'] = list(range(len(metric_log)))
        print('Agent {} mean {}: {}'.format(agent, metric, np.mean(metric_log)))

# Smoothing resolution passed to plot_val_line for the "Mean" plots.
res = 1
for metric, metric_plot_dict in plot_dicts.items():
    # raw per-epoch curve
    _ = plot_val_line(metric_plot_dict,
                      xlabel='Epoch',
                      ylabel='{}'.format(metric),
                      show_fig=True)
    # smoothed curve
    _ = plot_val_line(metric_plot_dict,
                      xlabel='Epoch',
                      ylabel='Mean {}'.format(metric),
                      smooth_data_res=res,
                      show_fig=True)