Can we overfit an agent to individual 100x100 instances and show that it beats strong branching, or at least beats a network that was trained to generalise across 100x100 instances?

In [None]:
from retro_branching.agents import StrongBranchingAgent, REINFORCEAgent
from retro_branching.environments import EcoleBranching
from retro_branching.networks import BipartiteGCN 
from retro_branching.utils import plot_val_line, get_most_recent_checkpoint_foldername, sns_plot_val_line

import torch
import ecole
import pyscipopt

import glob
import os
import gzip
import pickle
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd
from collections import defaultdict
import json

In [None]:
# Instance size to evaluate; the matching folder holds the .mps files the agents were overfit to.
nrows = 100 # 100 500
ncols = 100 # 500 1000
instances_path = f'/home/zciccwf/phd_project/projects/retro_branching/scripts/instances_to_overfit/nrows_{nrows}_ncols_{ncols}/'
# sorted() already returns a list; os.path.join avoids the double '/' of an f-string join.
instance_paths = sorted(glob.glob(os.path.join(instances_path, '*.mps')))
print(f'instance paths: {instance_paths}')
# File stem without directory or final extension (keeps any other dots in the name intact).
instance_names = [os.path.splitext(os.path.basename(instance_path))[0] for instance_path in instance_paths]
print(f'\ninstance names: {instance_names}')

In [None]:
# Solver statistics read from the env's `info` dict and compared across agents.
metrics = [
    'num_nodes',
    'lp_iterations',
]

In [None]:
# Solve every instance with strong branching and keep the final value of each metric.
instance_name_to_sb_metrics = {metric: dict.fromkeys(instance_names) for metric in metrics}
for instance_path, instance_name in zip(instance_paths, instance_names):
    print(f'\n> Instance {instance_name} <')

    # A fresh, identically seeded environment for each instance.
    env = EcoleBranching(observation_function='default',
                         information_function='default',
                         reward_function='default',
                         scip_params='default')
    env.seed(0)

    agent = StrongBranchingAgent()

    # Read the .mps with PySCIPOpt, then convert it to an Ecole model.
    instance = pyscipopt.Model()
    instance.readProblem(instance_path)
    instance = ecole.scip.Model.from_pyscipopt(instance)

    # The agent is told about the instance before the env is reset with it.
    agent.before_reset(instance)
    obs, action_set, reward, done, info = env.reset(instance)

#     raise Exception() # comment if want to gen new data

    # Branch until the env reports the episode is done, logging progress per step.
    step = 0
    print(f'Step {step} | Num nodes: {info["num_nodes"]}')
    while not done:
        action, action_idx = agent.action_select(action_set, env.model, done)
        obs, action_set, reward, done, info = env.step(action)
        step += 1
        print(f'Step {step} | Num nodes: {info["num_nodes"]} | LP iterations: {info["lp_iterations"]}')

    # Statistics from the last step are the ones recorded for this instance.
    for metric, per_instance in instance_name_to_sb_metrics.items():
        per_instance[instance_name] = info[metric]

print(f'\nStrong branching metrics for each instance:\n{instance_name_to_sb_metrics}')

In [None]:
# save the strong-branching metrics next to the instances
sb_metrics_file = f'{instances_path}/instance_name_to_sb_metrics.json'
with open(sb_metrics_file, 'w') as out_file:
    json.dump(instance_name_to_sb_metrics, out_file)

In [None]:
# load previously saved strong-branching metrics (avoids re-solving)
sb_metrics_file = f'{instances_path}/instance_name_to_sb_metrics.json'
with open(sb_metrics_file, 'r') as in_file:
    instance_name_to_sb_metrics = json.load(in_file)
print(instance_name_to_sb_metrics)

In [None]:
# solve each instance with baseline rl agent

# Choose the pretrained baseline network for this instance size once, before the loop.
# emb_size / num_rounds are also reused by the final-agent cell, so they are always defined here.
if nrows == 100 and ncols == 100:
    emb_size, num_rounds = 128, 2
    policy_network_path = '/scratch/datasets/retro_branching/dqn_learner/dqn_gnn/dqn_gnn_161/checkpoint_11/value_network_1_params.pkl'
    name = 'dqn_gnn_161'
elif nrows == 500 and ncols == 1000:
    emb_size, num_rounds = 64, 1
    policy_network_path = '/scratch/datasets/retro_branching/supervised_learner/gnn/gnn_21/checkpoint_275/trained_params.pkl'
    name = 'gnn_21_checkpoint_275'
else:
    raise Exception(f'Not implemented for nrows={nrows}, ncols={ncols}')

# Read the checkpoint from disk once. load_state_dict copies values into the module,
# so every instance still gets an independent, freshly built network.
baseline_state_dict = torch.load(policy_network_path, map_location='cpu')

instance_name_to_baseline_metrics = {metric: {instance_name: None for instance_name in instance_names} for metric in metrics}
for instance_path, instance_name in zip(instance_paths, instance_names):
    print(f'\n> Instance {instance_name} <')

    # init env
    env = EcoleBranching(observation_function='default',
                         information_function='default',
                         reward_function='default',
                         scip_params='default')
    env.seed(0)

    # init agent
    policy_network = BipartiteGCN(device='cpu',
                                  emb_size=emb_size, # 64 128
                                  num_rounds=num_rounds, # 1 2
                                  cons_nfeats=5,
                                  edge_nfeats=1,
                                  var_nfeats=19,
                                  aggregator='add')
    policy_network.load_state_dict(baseline_state_dict)

    agent = REINFORCEAgent(policy_network=policy_network,
                           device='cpu',
                           temperature=1.0,
                           name=name)
    agent.eval()

#     raise Exception() # comment if want to gen new data

    # init instance
    instance = pyscipopt.Model()
    instance.readProblem(instance_path)
    instance = ecole.scip.Model.from_pyscipopt(instance)

    # reset env with instance
    obs, action_set, reward, done, info = env.reset(instance)

    # solve
    t = 0
    print(f'Step {t} | Num nodes: {info["num_nodes"]}')
    while not done:
        action, action_idx = agent.action_select(action_set=action_set, obs=obs)
        obs, action_set, reward, done, info = env.step(action)
        t += 1
        print(f'Step {t} | Num nodes: {info["num_nodes"]} | LP iterations: {info["lp_iterations"]}')
    # record the statistics of the final step
    for metric in metrics:
        instance_name_to_baseline_metrics[metric][instance_name] = info[metric]

print(f'\nAgent {agent.name} metrics for each instance:\n{instance_name_to_baseline_metrics}')

In [None]:
# save the baseline agent's metrics next to the instances
baseline_metrics_file = f'{instances_path}/instance_name_to_baseline_metrics.json'
with open(baseline_metrics_file, 'w') as out_file:
    json.dump(instance_name_to_baseline_metrics, out_file)

In [None]:
# load previously saved baseline agent metrics (avoids re-solving)
baseline_metrics_file = f'{instances_path}/instance_name_to_baseline_metrics.json'
with open(baseline_metrics_file, 'r') as in_file:
    instance_name_to_baseline_metrics = json.load(in_file)
print(instance_name_to_baseline_metrics)

In [None]:
# plot learning curves relative to above metrics for each instance

learner = 'reinforce_learner' # 'reinforce_learner' 'dqn_learner'
base_name = 'rl_gnn' # 'rl_gnn' 'dqn_gnn'
net_name = 'policy_network_params' # 'policy_network_params' 'value_network_1_params'

# Agent folder f'{base_name}_{k}' is paired with the k-th instance (1-indexed, sorted order).
ids = list(range(1, 11))
# zip stops at the shorter input instead of raising IndexError when fewer instances were found.
id_to_instance_name = dict(zip(ids, instance_names))

instance_name_to_agent_metrics = {metric: {instance_name: None for instance_name in id_to_instance_name.values()} for metric in metrics}
nested_dict = lambda: defaultdict(nested_dict)

# `agent` is whatever the last executed solve cell left bound. If that was the overfit agent,
# its name would collide with the 'overfit_agent' key below. The duplicate key would collapse
# and baseline/overfit values would silently be appended to the same lists.
baseline_label = agent.name
if baseline_label == 'overfit_agent':
    raise RuntimeError('`agent` is the overfit agent; re-run the baseline rl agent cell before this one.')

logs_dict = {'overfit_agent': {}, f'{baseline_label}': {}, 'strong_branching': {}}
for metric in metrics:
    for key in logs_dict.keys():
        logs_dict[key][metric] = []

for i, instance_name in id_to_instance_name.items():
    plot_dict = nested_dict()

    print(f'\nInstance {instance_name}')

    _agent = '{}_{}'.format(base_name, i)
    agent_path = f'{instances_path}/{learner}/{base_name}/{_agent}/'

    foldername = get_most_recent_checkpoint_foldername(agent_path, idx=-1)
    path = '{}{}/'.format(agent_path, foldername)

    # Explicit path instead of gzip.open(*glob.glob(...)), which raised an opaque TypeError
    # when the log was missing.
    episodes_log_path = os.path.join(path, 'episodes_log.pkl')
    if not os.path.isfile(episodes_log_path):
        raise FileNotFoundError(f'No episodes log found at {episodes_log_path}')
    # NOTE: pickle can execute arbitrary code on load; only use with locally produced logs.
    with gzip.open(episodes_log_path, 'rb') as f:
        log = pickle.load(f)

    for metric in metrics:
        plot_dict[_agent]['y_values'] = log[metric]
        plot_dict[_agent]['x_values'] = list(range(len(log[metric])))

        # store mean of last 100 episodes
        instance_name_to_agent_metrics[metric][instance_name] = np.mean(log[metric][-100:])

        # update logs_dict
        logs_dict['overfit_agent'][metric].append(instance_name_to_agent_metrics[metric][instance_name])
        logs_dict[f'{baseline_label}'][metric].append(instance_name_to_baseline_metrics[metric][instance_name])
        logs_dict['strong_branching'][metric].append(instance_name_to_sb_metrics[metric][instance_name])

        horizontal_lines = {'strong_branching': instance_name_to_sb_metrics[metric][instance_name], f'{baseline_label}': instance_name_to_baseline_metrics[metric][instance_name]}
        _ = sns_plot_val_line(plot_dict,
                              moving_average_window=100, # 7 60 200 800
                              plot_unfiltered_data=True,
                              horizontal_lines=horizontal_lines,
                              unfiltered_data_alpha=0.3,
                              xlabel='Episode',
                              ylabel=metric,
                              title=f'{instance_name} {metric}',
                              show_fig=True)

print(f'\nOverfit agent ({base_name}_<id>) metrics (mean of last 100 episodes) for each instance:\n{instance_name_to_agent_metrics}')

In [None]:
# solve each instance with final agent network
# Each instance is solved by the overfit agent paired with it: ids / base_name / learner /
# net_name come from the learning-curve cell, and agent id k belongs to the k-th instance.
instance_name_to_agent_metrics = {metric: {instance_name: None for instance_name in instance_names} for metric in metrics}
for agent_id, instance_path, instance_name in zip(ids, instance_paths, instance_names):
    print(f'\n> Instance {instance_name} <')

    # init env
    env = EcoleBranching(observation_function='default',
                         information_function='default',
                         reward_function='default',
                         scip_params='default')
    env.seed(0)

    # init agent: latest checkpoint of this instance's own agent. Its id comes from this
    # loop, not from a loop variable left over from another cell.
    _agent = '{}_{}'.format(base_name, agent_id)
    agent_path = f'{instances_path}/{learner}/{base_name}/{_agent}/'
    agent_path += get_most_recent_checkpoint_foldername(agent_path, idx=-1)

    # NOTE(review): reuses emb_size / num_rounds chosen for the baseline network; this assumes
    # the overfit agents were trained with the same architecture — confirm.
    policy_network = BipartiteGCN(device='cpu',
                                  emb_size=emb_size, # 64 128
                                  num_rounds=num_rounds, # 1 2
                                  cons_nfeats=5,
                                  edge_nfeats=1,
                                  var_nfeats=19,
                                  aggregator='add')
    policy_network.load_state_dict(torch.load(agent_path+f'/{net_name}.pkl', map_location='cpu'))

    agent = REINFORCEAgent(policy_network=policy_network,
                           device='cpu',
                           temperature=1.0,
                           name='overfit_agent')
    agent.eval()

#     raise Exception() # comment if want to gen new data

    # init instance
    instance = pyscipopt.Model()
    instance.readProblem(instance_path)
    instance = ecole.scip.Model.from_pyscipopt(instance)

    # reset env with instance
    obs, action_set, reward, done, info = env.reset(instance)

    # solve
    t = 0
    print(f'Step {t} | Num nodes: {info["num_nodes"]}')
    while not done:
        action, action_idx = agent.action_select(action_set=action_set, obs=obs)
        obs, action_set, reward, done, info = env.step(action)
        t += 1
        print(f'Step {t} | Num nodes: {info["num_nodes"]} | LP iterations: {info["lp_iterations"]}')
    # record the statistics of the final step
    for metric in metrics:
        instance_name_to_agent_metrics[metric][instance_name] = info[metric]

print(f'\nAgent {agent.name} metrics for each instance:\n{instance_name_to_agent_metrics}')

In [None]:
# which agents have per-instance entries in logs_dict
logged_agent_names = logs_dict.keys()
print(logged_agent_names)

In [None]:
# plot hist: per-instance bars of each metric for every agent in logs_dict,
# plus a dashed line at each agent's mean over instances

x_tick_freq = 1  # show every x_tick_freq-th instance label
for metric in metrics:
    data = {'Instance': [], f'{metric}': [], 'Agent': []}
    for baseline_name in logs_dict.keys():
        values = logs_dict[baseline_name][metric]
        data['Instance'] += list(range(len(values)))
        data[f'{metric}'] += values
        data['Agent'] += [baseline_name] * len(values)

    data = pd.DataFrame(data)
    # catplot creates its own figure; calling plt.figure() first only left an empty figure behind.
    g = sns.catplot(data=data, kind='bar', x='Instance', y=f'{metric}', hue='Agent', palette='hls')

    # Same palette and agent order as the bars, so each mean line is intended to match its agent's colour.
    colours = iter(sns.color_palette(palette='hls', n_colors=len(logs_dict), desat=None))
    for baseline_name in logs_dict.keys():
        g.ax.axhline(y=np.mean(logs_dict[baseline_name][metric]), color=next(colours), linestyle='--', alpha=1)

    for counter, label in enumerate(g.ax.xaxis.get_ticklabels()):
        label.set_visible(counter % x_tick_freq == 0)
    plt.show()

In [None]:
# horizontal bar of agent score normalised w.r.t. some other baseline (e.g. strong branching)
agent_name = 'overfit_agent'
baseline_agent_name = 'strong_branching'
for metric in metrics:
    # float dtype so the ratio is computed even when the logged values are ints
    # (in-place /= on an int array raises).
    agent_values = np.asarray(logs_dict[agent_name][metric], dtype=float)
    baseline_values = np.asarray(logs_dict[baseline_agent_name][metric], dtype=float)
    num_instances = len(agent_values)
    if num_instances == 0:
        # nothing to plot, and the percentages below would divide by zero
        print(f'No {metric} data logged for {agent_name}; skipping plot.')
        continue

    fig = plt.figure()

    # normalise agent metric data w.r.t. baseline metric data
    agent_data = {'Instance': list(range(num_instances)),
                  f'{metric}': agent_values / baseline_values,
                  'Agent': [agent_name] * num_instances}
    ratios = agent_data[f'{metric}']

    # count % of instances agent metric was lesser/equal/greater than baseline agent metric
    percent_lesser = 100 * np.count_nonzero(ratios < 1) / num_instances
    percent_equal = 100 * np.count_nonzero(ratios == 1) / num_instances
    percent_greater = 100 * np.count_nonzero(ratios > 1) / num_instances

    # plot
    agent_data = pd.DataFrame(agent_data)
    sns.barplot(data=agent_data, x=f'{metric}', y='Instance', hue='Agent', palette='pastel', orient='h')
    plt.axvline(x=1, linestyle='--', label=baseline_agent_name, alpha=1, color='g')

    # legend, informative axes titles, title
    ax = plt.gca()
    ax.legend(ncol=1, frameon=True)
    ax.set(xlim=None, ylabel='Instance', xlabel=f'{baseline_agent_name}-normalised {metric}')
    plt.title(f'{agent_name} <, ==, > {baseline_agent_name}: {percent_lesser:.1f}%, {percent_equal:.1f}%, {percent_greater:.1f}%')

    plt.show()
    
    

# Save Results

In [None]:
# persist the per-instance comparison so the plots can be regenerated without re-solving
print(logs_dict)
logs_dict_file = f'{instances_path}/logs_dict.json'
with open(logs_dict_file, 'w') as out_file:
    json.dump(logs_dict, out_file)

In [None]:
# can now re-load as desired (kept separate from logs_dict so both can be compared)
logs_dict_file = f'{instances_path}/logs_dict.json'
with open(logs_dict_file, 'r') as in_file:
    loaded_logs_dict = json.load(in_file)
print(loaded_logs_dict)