How much does step run time vary across an episode? 

If it varies a lot, this suggests that the chosen actions affect the backend SCIP/ecole run times, and so to reduce overall run time we should incorporate run time into the agent's reward. If step times are roughly constant, the run time per step is similar regardless of which action we take, so the goal reduces to converging on a solution in as few steps as possible. In that case, rewards based on e.g. the dual bound or the number of nodes are sufficient to optimise our overall objective of faster solving times.
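One way to make "varies a lot" concrete is the coefficient of variation (standard deviation divided by mean) of the per-step times. The sketch below is illustrative only: the function name `step_time_cv` and the sample values are hypothetical, and it does not imply any particular threshold for what counts as "a lot".

```python
import statistics


def step_time_cv(step_times):
    """Return the coefficient of variation (std / mean) of a list of step times."""
    if len(step_times) < 2:
        raise ValueError("need at least two step times")
    mean = statistics.fmean(step_times)
    if mean == 0:
        raise ValueError("mean step time is zero")
    # Population standard deviation, matching the default of np.std (ddof=0).
    return statistics.pstdev(step_times) / mean


# Hypothetical step times in seconds (illustrative values, not measured data).
constant_like = [0.10, 0.11, 0.09, 0.10]
variable_like = [0.02, 0.30, 0.05, 0.45]

print(f"constant-like CV: {step_time_cv(constant_like):.3f}")
print(f"variable-like CV: {step_time_cv(variable_like):.3f}")
```

A CV near zero corresponds to the "roughly constant" case; a large CV corresponds to the case where step time is worth considering in the reward.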

We will use two agents: imitation_1k and imitation_100k. Both agents have the same model architecture and so should have the same inference time. We will set imitation_100k as the 'base agent' and imitation_1k as the 'rollout agent'. Before each step, we will create a fresh rollout environment, reset it, and replay the base agent's action history so that it is in the same state as the base environment. We will then take one step with each agent and record the two step times. In this way, the agents' step times are directly comparable, and any difference between them should be due to different action selection (assuming hardware performance does not vary between steps).

**N.B. Whenever you do rollouts and want several environments to reset in the same way, you MUST call env.seed(seed) just before calling env.reset(instance_before_reset.copy_orig()) for ALL envs, or you will get different env.reset() behaviour!**
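The seed-then-reset-then-replay pattern can be illustrated with a toy environment. This is a minimal sketch under stated assumptions: `ToyEnv` and `replay` are hypothetical names, and `ToyEnv` is a stand-in whose reset consumes randomness; it is not the ecole or retro_branching API.

```python
import random


class ToyEnv:
    """Hypothetical stand-in for a seeded environment (not the ecole API)."""

    def __init__(self):
        self._rng = random.Random()
        self.state = None

    def seed(self, seed):
        self._rng.seed(seed)

    def reset(self, instance):
        # The reset consumes randomness, so its result depends on the RNG state.
        self.state = instance + self._rng.randint(0, 10**6)
        return self.state

    def step(self, action):
        self.state = self.state * 31 + action + self._rng.randint(0, 10**6)
        return self.state


def replay(env, instance, actions, seed=0):
    """Seed the env, reset it, then replay a recorded action history."""
    env.seed(seed)  # seed immediately before reset, for every env
    env.reset(instance)
    for action in actions:
        env.step(action)
    return env.state


history = [3, 1, 4]  # hypothetical recorded actions
base_state = replay(ToyEnv(), instance=7, actions=history)
rollout_state = replay(ToyEnv(), instance=7, actions=history)
print(base_state == rollout_state)
```

Because both environments are seeded with the same value immediately before the reset, the two replays end in the same state. Skipping the `seed` call in either one would leave that environment's reset dependent on whatever state its random number generator happened to be in.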

In [None]:
import retro_branching
from retro_branching.learners import REINFORCELearner
from retro_branching.agents import REINFORCEAgent, StrongBranchingAgent
from retro_branching.networks import BipartiteGCN
from retro_branching.environments import EcoleBranching

import ecole
import torch

import time
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

In [None]:
# base agent
device = 'cpu'
policy = BipartiteGCN(device=device,
                   emb_size=64,
                   num_rounds=1,
                   aggregator='add')
# map_location ensures a checkpoint saved on GPU can be loaded onto `device`
policy.load_state_dict(torch.load('/scratch/datasets/retro_branching/supervised_learner/gnn/gnn_21/checkpoint_275/trained_params.pkl', map_location=device))
policy.eval()
base_agent = REINFORCEAgent(policy,
                            device=device,
                            temperature=1.0,
                            name='base_agent')

# rollout agent
policy = BipartiteGCN(device=device,
                   emb_size=64,
                   num_rounds=1,
                   aggregator='add')
# map_location ensures a checkpoint saved on GPU can be loaded onto `device`
policy.load_state_dict(torch.load('/scratch/datasets/retro_branching/supervised_learner/gnn/gnn_1/checkpoint_1/trained_params.pkl', map_location=device))
policy.eval()
rollout_agent = REINFORCEAgent(policy,
                               device=device,
                               temperature=1.0,
                               name='rollout_agent')


# base env
base_env = EcoleBranching(observation_function='default',
                          information_function='default',
                          reward_function='default',
                          scip_params='default')
base_env.seed(0)

# instances
instances = ecole.instance.SetCoverGenerator(n_rows=500, n_cols=1000, density=0.05)

In [None]:
num_episodes = 10

plot_dicts = {'base': {'step_times': []},
              'rollout': {'step_times': []}}
for i in range(num_episodes):
    print(f'\nEpisode {i}')
    base_actions, base_action_sets = [], [] # store base agent history for rollouts
    
    # find an instance not pre-solved by environment
    base_obs = None
    while base_obs is None:
        base_env.seed(0)
        instance = next(instances)
        instance_before_reset = instance.copy_orig()
        base_obs, base_action_set, base_reward, base_done, base_info = base_env.reset(instance)
    
    # run episode
    with torch.no_grad():
        while not base_done:
            # get rollout env to same state as base env
            rollout_env = EcoleBranching(observation_function=base_env.str_observation_function,
                                         information_function=base_env.str_information_function,
                                         reward_function=base_env.str_reward_function,
                                         scip_params=base_env.str_scip_params)
            rollout_env.seed(0)
            rollout_obs, rollout_action_set, rollout_reward, rollout_done, rollout_info = rollout_env.reset(instance_before_reset.copy_orig())
            for action, action_set in zip(base_actions, base_action_sets):
                rollout_obs, rollout_action_set, rollout_reward, rollout_done, rollout_info = rollout_env.step(action_set[action.item()])
        
            # take action with base agent
            base_action = base_agent.action_select(base_action_set, base_obs)
            base_actions.append(base_action)
            base_action_sets.append(base_action_set)
            base_action = base_action_set[base_action.item()]
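            # perf_counter is a monotonic, high-resolution clock suited to timing short intervals
            base_start_step = time.perf_counter()
            base_obs, base_action_set, base_reward, base_done, base_info = base_env.step(base_action)
            base_end_step = time.perf_counter()
            base_time = base_end_step - base_start_step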
            plot_dicts['base']['step_times'].append(base_time)
        
            # take action with rollout agent
            rollout_action = rollout_agent.action_select(rollout_action_set, rollout_obs)
            rollout_action = rollout_action_set[rollout_action.item()]
            # perf_counter is a monotonic, high-resolution clock suited to timing short intervals
            rollout_start_step = time.perf_counter()
            rollout_obs, rollout_action_set, rollout_reward, rollout_done, rollout_info = rollout_env.step(rollout_action)
            rollout_end_step = time.perf_counter()
            rollout_time = rollout_end_step - rollout_start_step
            plot_dicts['rollout']['step_times'].append(rollout_time)
            
            print(f'Base env step time: {round(base_time, 3)} | Rollout env step time: {round(rollout_time, 3)}')

In [None]:
# plot a histogram of each recorded statistic for each agent
for name, stats in plot_dicts.items():
    for stat, values in stats.items():
        plt.figure()
        mean, std = np.mean(values), np.std(values)
        print(f'mean: {mean}, std: {std}')
        title = f"Agent '{name}' episode {stat} (mean={round(mean, 3)}, std={round(std, 3)})"
        sns.histplot(values, edgecolor='k')
        plt.title(title)
        plt.xlabel(stat)
        plt.show()
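
Since the base and rollout step times are recorded in lockstep from the same state, they can also be compared as pairs rather than as two separate histograms. The following is a minimal sketch, assuming two equal-length lists such as `plot_dicts['base']['step_times']` and `plot_dicts['rollout']['step_times']`; the function name `paired_step_time_summary` and the sample values are hypothetical.

```python
import statistics


def paired_step_time_summary(base_times, rollout_times):
    """Summarise per-step differences (rollout minus base) between paired timings."""
    if len(base_times) != len(rollout_times):
        raise ValueError("base and rollout timings must be paired one-to-one")
    if not base_times:
        raise ValueError("no step times recorded")
    diffs = [r - b for b, r in zip(base_times, rollout_times)]
    return {
        "mean_diff": statistics.fmean(diffs),
        "std_diff": statistics.pstdev(diffs),
        # Fraction of steps where the rollout agent's action led to a slower step.
        "frac_rollout_slower": sum(d > 0 for d in diffs) / len(diffs),
    }


# Hypothetical paired step times in seconds (illustrative values, not measured data).
base_example = [0.10, 0.20, 0.15]
rollout_example = [0.12, 0.18, 0.30]
summary = paired_step_time_summary(base_example, rollout_example)
print(summary)
```

If the differences cluster around zero, action choice has little effect on step time from a given state; a wide spread of differences points the other way.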