Choose 0 ≤ alpha ≤ 1. Given a set of strong-branching (SB) scores `sb_scores`, label each score 1 if score >= (1 - alpha) × max(sb_scores), and 0 otherwise. E.g. if alpha=0.2, scores within 20% of the top SB score are labelled 1 and all others are labelled 0.

The idea is to filter 'poor' actions (as judged by dual-bound improvement, i.e. the strong-branching score) out of the agent's action space. This keeps the action space small while still letting the agent follow promising trajectories. At alpha=0, only the action(s) tied for the top score remain available. At alpha=1, every action with a non-negative score remains available (i.e. all actions when SB scores are non-negative). We therefore want to know how large the agent's action space would be for different values of alpha.


In [None]:
import numpy as np
import glob
import pickle
import gzip
import seaborn as sns
import matplotlib.pyplot as plt
from tqdm import tqdm

In [None]:
def conv_scores_to_bipartite_ranking(alpha, scores):
    """Binarise scores by how close each one is to the best score.

    A score is labelled 1 if it is within ``alpha`` (as a fraction of the
    magnitude of the top score) of the top score, and 0 otherwise. For a
    non-negative top score this is exactly
    ``score >= (1 - alpha) * max(scores)``.

    Args:
        alpha: Tolerance in [0, 1]. ``alpha=0`` keeps only the score(s) tied
            for the maximum; ``alpha=1`` keeps every score >= 0 when the
            maximum is non-negative.
        scores: 1-D array-like of scores. Should not contain NaN: a NaN
            maximum makes every comparison False, so everything is labelled 0.

    Returns:
        Integer array shaped like ``scores`` with 1 for kept entries and 0
        otherwise. An empty input yields an empty array.
    """
    scores = np.asarray(scores)
    if scores.size == 0:
        # np.amax raises on empty input (e.g. a sample whose scores were all NaN).
        return np.zeros(scores.shape, dtype=int)
    max_score = np.amax(scores)
    if max_score >= 0:
        threshold = (1 - alpha) * max_score
    else:
        # For a negative maximum, (1 - alpha) * max_score would lie ABOVE the
        # maximum and label nothing; widen downwards instead so the threshold
        # is max_score - alpha * |max_score|.
        threshold = (1 + alpha) * max_score
    return np.where(scores >= threshold, 1, 0)

In [None]:
# Which strong-branching dataset to load, and how many samples to take from it.
branching = 'pure_strong_branch'
max_steps = 3
nrows, ncols = 100, 100
num_samples = 5000

# Directory holding the aggregated .pkl samples (note the trailing slash,
# which the glob pattern below relies on).
path = (
    '/scratch/datasets/retro_branching/strong_branching/'
    f'{branching}/max_steps_{max_steps}/nrows_{nrows}_ncols_{ncols}/'
    'samples/aggregated_samples/'
)


In [None]:
# Take a fixed subset of the aggregated strong-branching samples. glob returns
# files in arbitrary (filesystem-dependent) order, so sort first to make the
# `num_samples` subset reproducible between runs.
files = np.array(sorted(glob.glob(path + '*.pkl')))[:num_samples]

alphas = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
# alphas = [0.2, 0.5, 0.9]
# alpha -> {'filtered_actions': [# actions kept by the filter, one per sample]}
plot_dicts = {alpha: {'filtered_actions': []} for alpha in alphas}

# Read and unpickle each file once and evaluate every alpha on it, rather
# than re-reading every file once per alpha.
pbar = tqdm(total=len(files))
pbar.set_description('Reading files')
for file in files:
    # NOTE(review): pickle.load can execute arbitrary code from the file;
    # only run this on trusted, locally generated datasets.
    with gzip.open(file, 'rb') as f:
        sample = pickle.load(f)
    # sample[2] is used to index the scores in sample[3], i.e. it selects the
    # scores of the candidate actions (sample layout assumed -- confirm).
    action_set, scores = sample[2], sample[3]
    scores = scores[action_set]
    # drop NaN scores so they cannot poison the max used for the threshold
    scores = scores[~np.isnan(scores)]
    for alpha in alphas:
        if scores.size == 0:
            # no valid scores in this sample -> no actions survive the filter
            num_kept = 0
        else:
            labels = conv_scores_to_bipartite_ranking(alpha=alpha, scores=scores)
            num_kept = np.count_nonzero(labels == 1)
        plot_dicts[alpha]['filtered_actions'].append(num_kept)
    pbar.update(1)
pbar.close()

In [None]:
# For each alpha, plot the distribution over samples of how many actions
# survive the filter (i.e. the size of the agent's reduced action space).
for alpha, metrics in plot_dicts.items():
    for metric_name, counts in metrics.items():
        title = f'{metric_name} alpha={alpha}'
        plt.figure()
        # The values are integer action counts, so use one bin per integer
        # rather than seaborn's automatic (possibly fractional-width) bins.
        sns.histplot(counts, discrete=True, edgecolor='k')
        plt.title(title)
        plt.xlabel('# Actions')
        plt.show()