In [None]:
%load_ext autoreload
%autoreload
from retro_branching.utils import get_most_recent_checkpoint_foldername, gen_co_name
from retro_branching.networks import BipartiteGCN
from retro_branching.agents import Agent, REINFORCEAgent, PseudocostBranchingAgent, StrongBranchingAgent, RandomAgent, DQNAgent, DoubleDQNAgent
from retro_branching.environments import EcoleBranching, EcoleConfiguring
from retro_branching.validators import ReinforcementLearningValidator

import ecole
import torch
import numpy as np
import os
import shutil
import glob
import time
import gzip
import pickle
import copy

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

# import ray
# import psutil
# num_cpus = psutil.cpu_count(logical=False)
# ray.init(num_cpus=int(num_cpus), ignore_reinit_error=True)

In [None]:
%autoreload

    

# @ray.remote
def run_rl_validator(path,
                     agents,
                     device,
                     co_class,
                     co_class_kwargs,
                     observation_function='default',
                     scip_params='default',
                     threshold_difficulty=None,
                     max_steps=int(1e12),
                     max_steps_agent=None,
                     overwrite=False):
    '''Run a ReinforcementLearningValidator for the agent stored under ``agents[path]``.

    Ecole objects cannot be pickled, so heuristic agents ('pseudocost', 'strong_branching',
    'scip_branching') must be given as strings; they are instantiated inside this function so
    that it can also run as a ray remote task.

    Args:
        path: Folder the validator saves its results to; also the key of the agent in ``agents``.
        agents: Dict mapping ``path`` to an agent object or to one of the strings above.
        device: Not used inside this function (kept so existing callers keep working).
        co_class: Problem class name; first component of the instance folder name.
        co_class_kwargs: Generator kwargs; each ``key, val`` pair is appended to the
            instance folder name as ``_{key}_{val}`` (in dict order).
        observation_function: Passed to the environment.
        scip_params: Passed to the environment, and selects the ``scip_{scip_params}``
            instance sub-folder.
        threshold_difficulty: Not used inside this function (kept for caller compatibility).
        max_steps: Passed to the validator.
        max_steps_agent: Passed to the validator.
        overwrite: If True, delete every existing ``rl_validator*`` entry under ``path``
            before running; also passed to the validator.

    Raises:
        ValueError: If the agent is an unrecognised string.
        FileNotFoundError: If no ``.mps`` instances exist for this configuration.
    '''
    start = time.time()
    agent = agents[path]

    if isinstance(agent, str):
        if agent == 'pseudocost':
            agent = PseudocostBranchingAgent(name='pseudocost')
        elif agent == 'strong_branching':
            agent = StrongBranchingAgent(name='strong_branching')
        elif agent == 'scip_branching':
            class SCIPBranchingAgent:
                # Placeholder that only carries a name; for this agent the EcoleConfiguring
                # env is used below (presumably leaving branching decisions to SCIP).
                def __init__(self):
                    self.name = 'scip_branching'
            agent = SCIPBranchingAgent()
        else:
            raise ValueError(f'Unrecognised agent str {agent}, cannot initialise.')

    if overwrite:
        # Clear all old rl_validator*/ folders, even ones the current config would not overwrite,
        # to prevent testing inconsistencies. os.path.join works with or without a trailing '/'.
        for old_dir in sorted(glob.glob(os.path.join(path, 'rl_validator*'))):
            print('Removing old {}'.format(old_dir))
            shutil.rmtree(old_dir)

    # instances
    instances_path = f'/scratch/datasets/retro_branching/instances/{co_class}'
    for key, val in co_class_kwargs.items():
        instances_path += f'_{key}_{val}'
    # scip_params distinguishes baselines and validation instances generated under different presets
    files = glob.glob(instances_path + f'/scip_{scip_params}/*.mps')
    # NOTE(review): glob order is filesystem-dependent. The per-instance comparison plots assume
    # every agent was validated in the same order, which holds only while the folder is unchanged.
    # Sorting here would be safer but would misalign with baselines already saved, so re-run those
    # if you ever change the ordering.
    if not files:
        raise FileNotFoundError(f'No .mps instances found in {instances_path}/scip_{scip_params}/')
    instances = iter([ecole.scip.Model.from_file(f) for f in files])
    print('Initialised instances.')

    # env
    if agent.name == 'scip_branching':
        env = EcoleConfiguring(observation_function=observation_function,
                               information_function='default',
                               scip_params=scip_params)
    else:
        env = EcoleBranching(observation_function=observation_function,  # 'default' '40_var_features'
                             information_function='default',
                             reward_function='default',
                             scip_params=scip_params)
    env.seed(0)
    print('Initialised env.')

    # metrics (also available: 'primal_dual_integral', 'primal_integral', 'dual_integral')
    metrics = ['num_nodes', 'solving_time', 'lp_iterations']
    print('Initialised metrics: {}'.format(metrics))

    # validator
    validator = ReinforcementLearningValidator(agents={agent.name: agent},
                                               envs={agent.name: env},
                                               instances=instances,
                                               metrics=metrics,
                                               calibration_config_path=None,
                                               seed=0,
                                               max_steps=max_steps,
                                               max_steps_agent=max_steps_agent,
                                               turn_off_heuristics=False,
                                               min_threshold_difficulty=None,
                                               max_threshold_difficulty=None,
                                               threshold_agent=None,
                                               threshold_env=None,
                                               episode_log_frequency=1,
                                               path_to_save=path,
                                               overwrite=overwrite,
                                               checkpoint_frequency=10)
    print('Initialised validator. Will save to {}'.format(validator.path_to_save))

    # run validation tests on every instance found
    validator.test(len(files))
    end = time.time()
    print('Finished path {} validator in {} s'.format(path, round(end - start, 3)))

In [None]:
%autoreload
# --- which trained agent family to evaluate ---
learner = 'dqn_learner'  # options: 'reinforce_learner', 'dqn_learner', 'supervised_learner'
base_name = 'dqn_gnn'  # options: 'rl_gnn', 'dqn_gnn', 'gnn'
AgentClass = Agent  # options: Agent, REINFORCEAgent, DQNAgent, DoubleDQNAgent

# --- SCIP preset ---
# Also selects the instances' scip_{scip_params}/ sub-folder, which keeps baselines and
# validation instances for different presets apart.
# options: 'default', 'gasse_2019', 'dfs', 'bfs', 'uct'
scip_params = 'dfs'

# --- problem class and generator kwargs ---
# configurations used so far:
#   'set_covering'                  n_rows/n_cols: 100/100, 165/230, 250/500, 300/500, 500/1000, 1000/1000
#   'combinatorial_auction'         n_items/n_bids: 10/50, 23/67, 37/83
#   'capacitated_facility_location' n_customers=5 with n_facilities: 5, 8, 12
#   'maximum_independent_set'       n_nodes: 25, 42, 58
co_class = 'set_covering'
co_class_kwargs = dict(n_rows=500, n_cols=1000)

# --- validation settings ---
threshold_difficulty = None
last_checkpoint_idx = None
max_steps = 10 ** 12  # effectively unlimited; a small value such as 3 is handy for debugging
overwrite = True

In [None]:
%autoreload
# agent
# Load an RL agent from a training checkpoint: the agent is built from the checkpoint's
# config.json, each of its networks is loaded from '<network_name>_params.pkl', and the agent is
# registered in `agents` under the folder its baseline results will be saved to.
# (Alternative agent set-ups are kept commented out further below.)
i = 1481  # training run index; others used: 1236 343 1094 341
checkpoint = 166  # checkpoint of run i; others used: 457 233 108 120
observation_function = '43_var_features'  # 'default' '43_var_features'
agent_name = f'{base_name}_{i}'
device = 'cuda:3'  # 'cpu' 'cuda:0'
agents, agent_paths = {}, []
path = f'/scratch/datasets/retro_branching/{learner}/{base_name}/{agent_name}/checkpoint_{checkpoint}/'
config = os.path.join(path, 'config.json')
agent = AgentClass(device=device, config=config)
agent.name = f'{agent_name}_checkpoint_{checkpoint}'
for network_name, network in agent.get_networks().items():
    if network is None:
        print(f'{network_name} is None.')
        continue
    # Params are saved under '<network_name>_params.pkl'. If the agent stores the network under an
    # attribute of that same name, load into it; otherwise load into the generic 'network' attribute
    # (as in the Agent class). The target is resolved before touching the file, so a KeyError raised
    # while reading the file cannot be mistaken for a missing attribute.
    target_attr = network_name if network_name in agent.__dict__ else 'network'
    target_network = agent.__dict__[target_attr]
    target_network.load_state_dict(torch.load(os.path.join(path, f'{network_name}_params.pkl'), map_location=device))

# # TEMPORARY
# agent.default_epsilon = 0
# agent.name = agent.name + f'_eps_{agent.default_epsilon}'

agent.eval()  # put in test mode
path_to_save_baseline = f'/scratch/datasets/retro_branching/instances/{gen_co_name(co_class, co_class_kwargs)}/scip_{scip_params}/baselines/{agent.name}/'
agents[path_to_save_baseline] = agent










# # # ######### gnn
# gnn = 'gnn_302' # 'gnn_1' 'gnn_21' 'gnn_265'
# checkpoint = 'checkpoint_36' # 'checkpoint_1' 'checkpoint_275' 'checkpoint_305'
# policy_network = BipartiteGCN(device=device,
#                            config=None,
#                            emb_size=128,
#                            num_rounds=2,
#                            cons_nfeats=5,
#                            edge_nfeats=1,
#                            var_nfeats=19,
#                            aggregator='add',
#                            name=gnn+'_'+checkpoint)
# try:
#     policy_network.load_state_dict(torch.load(f'/scratch/datasets/retro_branching/supervised_learner/gnn/{gnn}/{checkpoint}/trained_params.pkl'))
# except FileNotFoundError:
#     policy_network.load_state_dict(torch.load(f'/scratch/datasets/retro_branching/supervised_learner/gnn/{gnn}/{checkpoint}/network_params.pkl'))
# # agent = REINFORCEAgent(policy_network=policy_network,
# #                        device=device,
# #                        name=policy_network.name)
# # agent = DQNAgent(value_network=policy_network,
# #                        device=device,
# #                        name=policy_network.name)
# agent = DoubleDQNAgent(value_network_1=policy_network,
#                        value_network_2=copy.deepcopy(policy_network),
#                        device=device,
#                        name=policy_network.name)
# agent.eval()
# agent_path = f'/scratch/datasets/retro_branching/instances/set_cover_nrows_{nrows}_ncols_{ncols}_density_005_threshold_{threshold_difficulty}/baselines/{agent.name}/'
# agents[agent_path] = agent





############ rand
# device = 'cpu'
# num_rand_agents = 20
# policy_networks = {f'random_{i}': {'filter_network': None, # None 'gnn_235
#                                    'filter_method': 'method_2',
#                                    'checkpoints': [1],
#                                    'emb_size': 64,
#                                    'num_rounds': 1,
#                                    'cons_nfeats': 5,
#                                    'edge_nfeats': 1,
#                                    'var_nfeats': 19,
#                                    'aggregator': 'add'} for i in range(1, num_rand_agents+1)}

# # load and initialise agents
# agents, envs, agent_paths = {}, {}, []


# for agent_name in policy_networks.keys():
#     agent_path = f'/scratch/datasets/retro_branching/instances/set_cover_nrows_{nrows}_ncols_{ncols}_density_005_threshold_{threshold_difficulty}/baselines/{agent_name}/'
#     agent_paths.append(agent_path) # useful for overwriting later in script

#     # collect agent NN training checkpoint parameters
#     policy_network = BipartiteGCN(device,
#                                 emb_size=policy_networks[agent_name]['emb_size'],
#                                 num_rounds=policy_networks[agent_name]['num_rounds'],
#                                 cons_nfeats=policy_networks[agent_name]['cons_nfeats'],
#                                 edge_nfeats=policy_networks[agent_name]['edge_nfeats'],
#                                 var_nfeats=policy_networks[agent_name]['var_nfeats'],
#                                 aggregator=policy_networks[agent_name]['aggregator'])

#     if policy_networks[agent_name]['filter_network'] is not None:
#         filter_name = policy_networks[agent_name]['filter_network']
#         filter_network = BipartiteGCN(device,
#                                 emb_size=filter_networks[filter_name]['emb_size'],
#                                 num_rounds=filter_networks[filter_name]['num_rounds'],
#                                 cons_nfeats=filter_networks[filter_name]['cons_nfeats'],
#                                 edge_nfeats=filter_networks[filter_name]['edge_nfeats'],
#                                 var_nfeats=filter_networks[filter_name]['var_nfeats'],
#                                 aggregator=filter_networks[filter_name]['aggregator'])
# #             filter_network.load_state_dict(torch.load('/scratch/datasets/retro_branching/supervised_learner/gnn/{}/checkpoint_{}/trained_params.pkl'.format(filter_name, filter_networks[filter_name]['checkpoint']), map_location=device))
#     else:
#         filter_network = None

#     path = '{}/'.format(agent_path)
# #     policy_network.load_state_dict(torch.load(path+'trained_params.pkl', map_location=device))
# #     print('Loaded params from {}'.format(path))
#     agent = REINFORCEAgent(policy_network=policy_network, filter_network=filter_network, device=device, name=agent_name, filter_method=policy_networks[agent_name]['filter_method'])
#     agent.eval() # turn on evaluation mode
#     agents[path] = agent

    
    
    
############### random agent
# device = 'cpu'
# for i in range(1, num_rand_agents+1):
#     agents[f'/scratch/datasets/retro_branching/instances/set_cover_nrows_{nrows}_ncols_{ncols}_density_005_threshold_{threshold_difficulty}/baselines/random_agent_{i}/'] = RandomAgent(name=f'random_agent_{i}')





# ############# pseudocost agent
# device = 'cpu'
# agent = 'pseudocost'
# observation_function = 'default'
# agent_path = f'/scratch/datasets/retro_branching/instances/{gen_co_name(co_class, co_class_kwargs)}/scip_{scip_params}/baselines/pseudocost/'
# agents[agent_path] = agent

# ############ strong branching agent
# device = 'cpu'
# agent = 'strong_branching'
# observation_function = 'default'
# agent_path = f'/scratch/datasets/retro_branching/instances/{gen_co_name(co_class, co_class_kwargs)}/scip_{scip_params}/baselines/strong_branching/'
# agents[agent_path] = agent

# ############ scip branching agent
# device = 'cpu'
# agent = 'scip_branching'
# observation_function = 'default'
# agent_path = f'/scratch/datasets/retro_branching/instances/{gen_co_name(co_class, co_class_kwargs)}/scip_{scip_params}/baselines/{agent}/'
# agents[agent_path] = agent



# path = '/scratch/datasets/retro_branching/reinforce_learner/rl_gnn/rl_gnn_634/checkpoint_1'
# device = 'cpu'
# config = f'{path}/config.json'
# # config = f'/scratch/datasets/retro_branching/reinforce_learner/rl_gnn/rl_gnn_634/checkpoint_1/config.json'
# print(config)
# agent = REINFORCEAgent(device=device, config=config)
# agent.policy_network.load_state_dict(torch.load(f'{path}/trained_params.pkl', map_location=device))
# if agent.filter_network is not None:
#     agent.filter_network.load_state_dict(torch.load(f'{path}/filter_params.pkl', map_location=device))
# agent.eval()


# Forwarded to the validator as `max_steps_agent`; None leaves it unset.
max_steps_agent = None

# Show which result folders will be populated and by which agent.
print(agents)

In [None]:
%autoreload

# # RAY
# result_ids = []
# for path in agents.keys():
#     result_ids.append(run_rl_validator.remote(path, 
#                                               agents, 
#                                               device,
#                                               nrows, 
#                                               ncols, 
#                                               observation_function, 
#                                               threshold_difficulty, 
#                                               max_steps, 
#                                               max_steps_agent, 
#                                               overwrite))
    

# start = time.time()
# _ = ray.get(result_ids)
# end = time.time()


# NON-RAY
start = time.time()
# Validate each registered agent in turn; its dict key is the folder its results are saved to.
for path in agents:
    run_rl_validator(
        path=path,
        agents=agents,
        device=device,
        co_class=co_class,
        co_class_kwargs=co_class_kwargs,
        observation_function=observation_function,
        scip_params=scip_params,
        threshold_difficulty=threshold_difficulty,
        max_steps=max_steps,
        max_steps_agent=max_steps_agent,
        overwrite=overwrite,
    )
end = time.time()

print(f'Finished in {end - start}')

# Baseline Visualisation

In [None]:
# Only every x_tick_freq-th instance label is drawn on the per-instance plots, keeping long axes legible.
x_tick_freq = 10

## Metrics Mean

In [None]:
%autoreload
# Heuristic agents are selected by name; wrap such a string in a minimal object exposing
# `.name` so the rest of this cell can treat every agent the same way.
if isinstance(agent, str) and agent in ('scip_branching', 'pseudocost', 'strong_branching'):
    class SCIPBranchingAgent:
        def __init__(self, name):
            self.name = name
    agent = SCIPBranchingAgent(agent)

baselines = sorted(glob.glob(f'/scratch/datasets/retro_branching/instances/{gen_co_name(co_class, co_class_kwargs)}/scip_{scip_params}/baselines/*'))
print(f'Saved baselines available: {baselines}')

# Baselines to plot for each problem configuration:
#   co_class -> (co_class_kwargs key identifying the instance size, {size: baseline names})
_BASELINES_TO_SHOW = {
    'set_covering': ('n_rows', {
        100: ['strong_branching', 'pseudocost', 'gnn_341_checkpoint_120', 'gnn_339_checkpoint_90', 'dqn_gnn_1318_checkpoint_58'],
        165: ['gnn_356_checkpoint_104', 'pseudocost', 'strong_branching'],
        250: ['pseudocost', 'strong_branching', 'gnn_358_checkpoint_268'],
        300: ['gnn_357_checkpoint_173'],
        500: ['gnn_343_checkpoint_233', 'dqn_gnn_1147_checkpoint_64', 'pseudocost', 'strong_branching', 'dqn_gnn_1226_checkpoint_200', 'dqn_gnn_1484_checkpoint_79', 'gnn_361_checkpoint_139'],
        1000: ['gnn_343_checkpoint_233'],
    }),
    'combinatorial_auction': ('n_items', {
        10: ['pseudocost', 'strong_branching', 'gnn_347_checkpoint_124'],
        23: ['pseudocost', 'strong_branching', 'gnn_348_checkpoint_128'],
        37: ['pseudocost', 'strong_branching', 'gnn_349_checkpoint_98'],
    }),
    'capacitated_facility_location': ('n_facilities', {
        5: ['pseudocost', 'strong_branching', 'gnn_350_checkpoint_104'],
        8: ['pseudocost', 'strong_branching', 'gnn_351_checkpoint_69'],
        12: ['pseudocost', 'strong_branching', 'gnn_352_checkpoint_131'],
    }),
    'maximum_independent_set': ('n_nodes', {
        25: ['pseudocost', 'strong_branching', 'gnn_353_checkpoint_209'],
        42: ['pseudocost', 'strong_branching', 'gnn_354_checkpoint_193'],
        58: ['pseudocost', 'strong_branching', 'gnn_355_checkpoint_158'],
    }),
}

if co_class in _BASELINES_TO_SHOW:
    size_key, baselines_by_size = _BASELINES_TO_SHOW[co_class]
    if co_class_kwargs[size_key] not in baselines_by_size:
        raise ValueError(f'Unrecognised {size_key} {co_class_kwargs[size_key]}')
    # Copy so that appending the evaluated agent below never mutates the table above.
    baselines_to_show = list(baselines_by_size[co_class_kwargs[size_key]])
else:
    print(f'Not yet configured which baselines to show for co_class {co_class}')
    # With no configured baselines, plot only the evaluated agent.
    baselines_to_show = []
baselines_to_show.append(agent.name)


def _load_baseline_log(baseline_dir):
    '''Load the gzipped pickle log from the most recent validator checkpoint of `baseline_dir`.

    Raises:
        FileNotFoundError: If the checkpoint folder holds no '*log.pkl' file.
        RuntimeError: If it holds more than one, so the choice would be ambiguous.
    '''
    checkpoint_dir = baseline_dir + '/rl_validator/rl_validator_1/'
    checkpoint_dir += get_most_recent_checkpoint_foldername(checkpoint_dir)
    log_files = glob.glob(checkpoint_dir + '/*log.pkl')
    if not log_files:
        raise FileNotFoundError(f'No *log.pkl file found in {checkpoint_dir}')
    if len(log_files) > 1:
        raise RuntimeError(f'Expected one *log.pkl file in {checkpoint_dir}, found {log_files}')
    # pickle.load is only safe on trusted files; these are presumably logs written by run_rl_validator.
    with gzip.open(log_files[0], 'rb') as f:
        return pickle.load(f)


baselines_logs_dict = {}
baseline_to_mean = {}
log = None  # reset so a stale log from an earlier run of this cell is never reused
for baseline in baselines:
    baseline_name = baseline.split('/')[-1]
    if baseline_name not in baselines_to_show:
        continue
    print('')
    log = _load_baseline_log(baseline)
    baselines_logs_dict[baseline_name] = {}
    for metric in log['metrics']:
        # per-instance total: absolute value of the summed episode rewards for this metric
        baselines_logs_dict[baseline_name][metric] = [abs(np.sum(rewards)) for rewards in log[baseline_name][metric]]
        metric_mean = np.mean(baselines_logs_dict[baseline_name][metric])
        print(f'{baseline_name} mean {metric}: {metric_mean}')

if log is None:
    raise RuntimeError(f'None of {baselines_to_show} were found among the saved baselines {baselines}')

# One figure per metric: a dashed horizontal line at each baseline's mean value.
for metric in log['metrics']:
    fig = plt.figure()
    class_colours = iter(sns.color_palette(palette='hls', n_colors=len(baselines_logs_dict), desat=None))
    for baseline_name in sorted(baselines_logs_dict):
        metric_mean = np.mean(baselines_logs_dict[baseline_name][metric])
        plt.axhline(y=metric_mean, color=next(class_colours), linestyle='--', label=baseline_name)
    plt.gca().axes.get_xaxis().set_visible(False)
    plt.ylabel(f'mean {metric}')
    plt.legend()
    plt.show()

## Metrics Hist

In [None]:
# Reference agent that the evaluated agent (`agent`) is compared against on every instance.
if co_class == 'set_covering':
    if co_class_kwargs['n_rows'] == 100:
        baseline_agent_name = 'gnn_341_checkpoint_120'
    elif co_class_kwargs['n_rows'] in (500, 1000):
        baseline_agent_name = 'gnn_343_checkpoint_233'
    else:
        raise ValueError(f'Unrecognised nrows {co_class_kwargs["n_rows"]}')
else:
    # Every plot below needs a reference agent, so stop with a clear message instead of
    # printing and then failing on an undefined `baseline_agent_name`.
    raise NotImplementedError(f'Not yet configured which baseline agent to compare against for co_class {co_class}')
print(baselines_logs_dict.keys())
if baseline_agent_name not in baselines_logs_dict:
    raise KeyError(f'No loaded results for baseline agent {baseline_agent_name}; '
                   f'available: {sorted(baselines_logs_dict)}')

for metric in log['metrics']:
    data = {'Instance': [], f'{metric}': [], 'Agent': []}
    # Stack the reference agent's and the evaluated agent's per-instance values so that
    # seaborn draws them as side-by-side bars for each instance.
    for plotted_name in (baseline_agent_name, agent.name):
        values = baselines_logs_dict[plotted_name][metric]
        data['Instance'].extend(range(len(values)))
        data[f'{metric}'].extend(values)
        data['Agent'].extend([plotted_name] * len(values))

    data = pd.DataFrame(data)
    # catplot is a figure-level function that creates its own figure; calling plt.figure()
    # beforehand would only leave an extra empty figure behind.
    g = sns.catplot(data=data, kind='bar', x='Instance', y=f'{metric}', hue='Agent', palette='hls')
    for counter, label in enumerate(g.ax.xaxis.get_ticklabels()):
        label.set_visible(counter % x_tick_freq == 0)
    plt.show()

## Metrics Horizontal Bar Chart

In [None]:

# if nrows == 100:
# #     baseline_agent_name = 'strong_branching' # agent to normalise performance with respect to
# #     baseline_agent_name = 'gnn_339_checkpoint_90'
#     baseline_agent_name = 'gnn_341_checkpoint_120'
# elif nrows == 500:
# #     baseline_agent_name = 'gnn_21_checkpoint_275'
#     baseline_agent_name = 'gnn_341_checkpoint_120'
# else:
#     raise Exception(f'Unrecognised nrows {nrows}')
# print(baselines_logs_dict.keys())

# For every metric and every non-reference agent, plot the agent's per-instance values
# normalised by the reference agent's (1 == same as reference, < 1 == lower than reference).
for metric in log['metrics']:
    baseline_values = np.asarray(baselines_logs_dict[baseline_agent_name][metric], dtype=float)

    for agent_name in sorted(baselines_logs_dict.keys()):
        if agent_name == baseline_agent_name:
            continue

        # Cast to float and divide out-of-place: an in-place `/=` on an integer-valued
        # array raises a numpy casting error.
        agent_values = np.asarray(baselines_logs_dict[agent_name][metric], dtype=float)
        if agent_values.shape != baseline_values.shape:
            raise ValueError(
                f'{agent_name} has {len(agent_values)} {metric} values but {baseline_agent_name} has '
                f'{len(baseline_values)}; both must be validated on the same instances to be compared.')
        normalised = agent_values / baseline_values

        # % of instances on which the agent's metric was lesser/equal/greater than the reference's
        num_instances = len(normalised)
        percent_lesser = 100 * np.count_nonzero(normalised < 1) / num_instances
        percent_equal = 100 * np.count_nonzero(normalised == 1) / num_instances
        percent_greater = 100 * np.count_nonzero(normalised > 1) / num_instances

        # plot
        fig = plt.figure()
        agent_data = pd.DataFrame({'Instance': list(range(num_instances)),
                                   f'{metric}': normalised,
                                   'Agent': [agent_name] * num_instances})
        sns.barplot(data=agent_data, x=f'{metric}', y='Instance', hue='Agent', palette='pastel', orient='h')
        plt.axvline(x=1, linestyle='--', label=baseline_agent_name, alpha=1, color='g')

        # legend, informative axes titles, title
        ax = plt.gca()
        ax.legend(ncol=1, frameon=True)
        ax.set(xlim=None, ylabel='Instance', xlabel=f'{baseline_agent_name}-normalised {metric}')
        plt.title(f'{agent_name} <, ==, > {baseline_agent_name}: {percent_lesser:.1f}%, {percent_equal:.1f}%, {percent_greater:.1f}%')

        # show only every x_tick_freq-th instance label
        for counter, label in enumerate(ax.yaxis.get_ticklabels()):
            label.set_visible(counter % x_tick_freq == 0)

        plt.show()