# Validation

This notebook runs validation rollouts to evaluate trained branching policies, alongside baselines such as strong branching, on a set of test instances.

In [None]:
%load_ext autoreload
%autoreload
from retro_branching.environments import EcoleBranching
from retro_branching.networks import BipartiteGCN
from retro_branching.agents import REINFORCEAgent, StrongBranchingAgent
from retro_branching.validators import ReinforcementLearningValidator
from retro_branching.utils import plot_val_line

import ecole
import torch
import os
import gzip
import pickle
from collections import defaultdict
import numpy as np

In [None]:
def get_most_recent_agent_checkpoint_foldername(agent_path):
    '''Return the name of the entry in `agent_path` with the highest numeric suffix.

    Entries are expected to be named ``<name>_<number>`` (e.g. ``checkpoint_12``). The text
    after the final underscore is compared as an integer, so ``checkpoint_10`` ranks above
    ``checkpoint_9``. Entries whose suffix is not an integer (stray files or folders such as
    ``config.json`` or ``.ipynb_checkpoints``) are skipped. If several entries share the
    highest number, the first one in ``os.listdir`` order wins.

    Args:
        agent_path: Directory containing the checkpoint folders.

    Returns:
        The bare entry name (not a full path).

    Raises:
        ValueError: If no entry in ``agent_path`` has an integer suffix.
    '''
    # List the directory only once, so the returned name is guaranteed to come from the
    # same listing that was ranked.
    numbered = []
    for name in os.listdir(agent_path):
        suffix = name.rsplit('_', 1)[-1]
        try:
            numbered.append((int(suffix), name))
        except ValueError:
            continue  # not a <name>_<number> entry
    if not numbered:
        raise ValueError('No <name>_<number> checkpoint folders found in {}'.format(agent_path))
    # max() returns the first maximal item, preserving listdir order on ties.
    return max(numbered, key=lambda item: item[0])[1]

## Initialise Agents

Load saved neural-network parameters and initialise the agents to be evaluated. Agents are collected in the `agents` dict, keyed by name.

In [None]:
%autoreload
DEVICE = 'cuda:1'
agents = {}

# # LOAD RL AGENTS
# ids = [57, 58, 60, 65]
# for i in ids:
#     # load agent params
#     agent = 'rl_gnn_{}'.format(i)
#     agent_path = '/scratch/datasets/retro_branching/reinforce_learner/rl_gnn/{}/'.format(agent)
#     foldername = get_most_recent_agent_checkpoint_foldername(agent_path)
# #     foldername='checkpoint_5'
#     path = '{}{}/'.format(agent_path, foldername)
#     policy_network = BipartiteGCN(DEVICE)
#     policy_network.load_state_dict(torch.load(path+'trained_params.pkl'))

#     # init agent
#     agent = REINFORCEAgent(policy_network=policy_network, device=DEVICE, name=agent_name)
#     agent.eval() # turn on evaluation mode
#     agents[agent_name] = agent

# LOAD GNN AGENT (NO RL)
policy_network = BipartiteGCN(DEVICE)
agent = 'gnn_21'
agent_path = '/scratch/datasets/retro_branching/supervised_learner/gnn/{}/'.format(agent)
foldername = get_most_recent_agent_checkpoint_foldername(agent_path)
# foldername='checkpoint_5'
path = '{}{}/'.format(agent_path, foldername)
print(path)
policy_network.load_state_dict(torch.load(path+'trained_params.pkl', map_location=DEVICE))
agent = REINFORCEAgent(policy_network=policy_network, device=DEVICE, name='gnn_100k')
agent.eval() # turn on evaluation mode
agents['gnn_100k'] = agent

DEVICE = 'cuda:0'
policy_network = BipartiteGCN(DEVICE)
path = '../scripts/'
print(path)
policy_network.load_state_dict(torch.load(path+'trained_params_dict_disabler.pkl', map_location=DEVICE))
agent = REINFORCEAgent(policy_network=policy_network, device=DEVICE, name='gnn_1k')
agent.eval() # turn on evaluation mode
agents['gnn_1k'] = agent

# LOAD STRONG BRANCHING AGENT
agent = StrongBranchingAgent(name='sb')
agents['sb'] = agent
print('Initialised agents {}'.format(agents.keys()))

print(agents)

## Run Validation Rollouts

In [None]:
# Test-instance distribution: Ecole set-cover instances.
# Smaller alternative: N_ROWS = 100, N_COLS = 100 (same density).
N_ROWS = 500
N_COLS = 1000
DENSITY = 0.05
VALIDATION_SEED = 0  # passed to the validator as `seed`
NUM_VALIDATION_ROLLOUTS = 100  # argument to validator.test(); presumably one rollout per instance -- confirm

instances = ecole.instance.SetCoverGenerator(n_rows=N_ROWS, n_cols=N_COLS, density=DENSITY)
print('Initialised instances.')

# init envs: one separate EcoleBranching environment per agent
envs = {agent_name: EcoleBranching() for agent_name in agents}
print('Initialised agent envs.')

# metrics
metrics = ['num_nodes', 'solving_time', 'lp_iterations']
print('Initialised metrics.')

validator = ReinforcementLearningValidator(agents=agents,
                                           envs=envs,
                                           instances=instances,
                                           metrics=metrics,
                                           seed=VALIDATION_SEED,
                                           turn_off_heuristics=False,
                                           threshold_difficulty=None,
                                           epoch_log_frequency=1)
validator.test(NUM_VALIDATION_ROLLOUTS)

## Plot Validation Rollout Results

In [None]:
%autoreload
nested_dict = lambda: defaultdict(nested_dict)
plot_dicts = {metric: nested_dict() for metric in validator.metrics}

for agent in validator.epochs_log.keys():
    for metric in validator.metrics:
        plot_dicts[metric][agent]['y_values'] = validator.epochs_log[agent][metric]
        plot_dicts[metric][agent]['x_values'] = list(range(len(validator.epochs_log[agent][metric])))
        print('Agent {} mean {}: {}'.format(agent, metric, np.mean(validator.epochs_log[agent][metric])))
        
res = 1
for metric in plot_dicts.keys():
    _ = plot_val_line(plot_dicts[metric],
                      xlabel='Epoch',
                      ylabel='{}'.format(metric),
                      show_fig=True)
    _ = plot_val_line(plot_dicts[metric],
                      xlabel='Epoch',
                      ylabel='Mean {}'.format(metric),
                      smooth_data_res=res,
                      show_fig=True)