In [None]:
import matplotlib as mpl
from matplotlib import pyplot as plt
import pickle
import numpy as np
from tqdm.notebook import tqdm
from cost_based_selection import preprocessing_utils
from glob import glob
from scipy import stats
import itertools as it
from IPython.display import display
import pandas as pd
import os

# Figure defaults: high-resolution rendering, then the shared document style sheet.
mpl.rcParams['figure.dpi'] = 144
mpl.style.use('../scrartcl.mplstyle')

# Model to generate figures for; can be overridden with the MODEL environment variable.
MODEL = os.environ.get('MODEL', 'ba')
# Feature-selection methods whose evaluation results are plotted.
METHODS = [
    'JMI', 'JMIM', 'mRMR', 'reliefF_l1', 'reliefF_rf',
    'pen_rf_importance_impurity', 'pen_rf_importance_permutation',
    'weighted_rf_importance_impurity', 'weighted_rf_importance_permutation',
    'random_ranking',
]
# Split that was used to identify features (evaluation is always on the test set). Changing it
# is useful for looking at different numbers of nodes with pilot simulations.
SPLIT = "train"

# Cost regularization.

Figures for the section on cost-based feature selection methods with a penalty parameter $\lambda$.

In [None]:
# Load the pickled evaluation result of every method for the configured MODEL and SPLIT.
results_by_method = {}
for method in METHODS:
    path = f'../workspace/{MODEL}/evaluation/{SPLIT}/{method}.pkl'
    with open(path, 'rb') as fh:
        loaded = pickle.load(fh)
    results_by_method[method] = loaded

In [None]:
def plot_accuracy_cost_map(result, ax=None, debug=False, cost_levels=None, colorbar=None,
                           xlabel=r'Penalty $\lambda$', ylabel=r'Number of features $k$',
                           noise_marker='o', vmin=None, clabel_fmt="%1.3f", linscale=0.25):
    """
    Plot the accuracy heat map with cost contours.

    The heat map shows the accuracy averaged over the last axis of ``result['accuracy']``.
    The penalty is on the x-axis and the number of selected features k is on the y-axis.
    Black contours show the normalized cumulative cost. Markers highlight selected features
    whose names start with ``noise``.

    Args:
        result: Mapping with the following keys:
            ``accuracy``: indexed as (penalty, number of features - 1, repetition).
            ``normalized_cumulative_costs``: indexed as (penalty, number of features - 1).
            ``penalties``: 1-D array; half of its second entry is the symlog linear threshold.
            ``features``: feature names.
            ``rankings``: per-penalty indices into ``features``.
        ax: Axes to draw on. Defaults to the current axes.
        debug: If true, additionally mark the linear threshold and the evaluated penalties.
        cost_levels: Contour levels for the cumulative cost.
            Defaults to ``[0.01, 0.05, 0.1, 0.2]`` when falsy.
        colorbar: If truthy, the ``location`` of a colorbar attached to ``ax`` (e.g. ``'right'``).
        xlabel, ylabel: Axis labels; falsy values leave the labels untouched.
        noise_marker: Marker used for noise features.
        vmin: Lower limit of the color scale. Defaults to ``min(0.95, mean accuracy minimum)``.
        clabel_fmt: Passed as ``fmt`` when labelling the cost contours.
        linscale: ``linscale`` of the symlog x-axis.

    Returns:
        The mappable returned by ``pcolormesh``, e.g. for building a shared colorbar.
    """
    if ax is None:
        ax = plt.gca()
    cost_levels = cost_levels or [0.01, 0.05, 0.1, 0.2]

    # Plot the mean accuracies.
    accuracies = result['accuracy']
    cumulative_costs = result['normalized_cumulative_costs']
    penalties = result['penalties']
    mean_accuracies = accuracies.mean(axis=-1)
    num_features = 1 + np.arange(accuracies.shape[1])
    if vmin is None:
        # Test against None so that an explicit vmin=0 is honoured (`vmin or ...` dropped it).
        vmin = min(.95, mean_accuracies.min())
    im = ax.pcolormesh(penalties, num_features, mean_accuracies.T, rasterized=True,
                       vmax=1, vmin=vmin)

    # Plot the cumulative cost levels.
    cs = ax.contour(penalties, num_features, cumulative_costs.T, levels=cost_levels, colors='k')

    # Mark noise features among the top-k ranked features for each penalty.
    # Assumes `features` holds strings.
    is_noise = np.char.startswith(result['features'], 'noise')
    is_noise = is_noise[result['rankings']][:, :accuracies.shape[1]]
    penalty_idx, rank_idx = np.nonzero(is_noise)
    pts = ax.scatter(penalties[penalty_idx], 1 + rank_idx, marker=noise_marker)
    pts.set_edgecolor('w')

    if xlabel:
        ax.set_xlabel(xlabel)
    if ylabel:
        ax.set_ylabel(ylabel)
    # Linear region of the symlog scale ends halfway to the second penalty; mark it with a line.
    linthresh = penalties[1] / 2
    ax.set_xscale('symlog', linthresh=linthresh, linscale=linscale)
    ax.axvline(linthresh, color='k', ls=':')

    # Use the explicit axes API rather than pyplot's current-axes state.
    ax.clabel(cs, fmt=clabel_fmt)

    # Drop the second x tick (same result as the former `ticks[1] = ticks[0]; ticks[1:]`).
    # Presumably this avoids crowding near the linear/log transition; TODO confirm.
    ticks = ax.get_xticks()
    ax.set_xticks(np.delete(ticks, 1))

    if colorbar:
        # Attach the colorbar to the figure that owns `ax`, not to a global `fig`.
        cb = ax.figure.colorbar(im, ax=ax, location=colorbar)
        cb.set_label('Accuracy')
        cb.locator = mpl.ticker.MaxNLocator(5)
        cb.update_ticks()

    if debug:
        ax.axvline(linthresh)
        ax.scatter(penalties, np.ones_like(penalties), marker='x', color='C1')

    return im


# Show the accuracy/cost map for one method as an example.
key = "JMI"
fig, ax = plt.subplots()
# Pass the axes explicitly instead of relying on pyplot's current axes.
plot_accuracy_cost_map(results_by_method[key], ax=ax, colorbar='right')
ax.set_title(key);

In [None]:
# Generate the four-panel figures for different methods.
# Panel configuration for the two four-panel figures, in panel order.
# Each method maps to:
#   "label": keyword arguments for its panel label, drawn with `ax.text` in axes coordinates.
#   "levels" (optional): cost contour levels; the plotting function's defaults apply otherwise.
configs_list = [
    {
        "JMI": {
            "label": {"s": "(a)", "x": 0.05, "y": 0.95, "va": "top"},
            "levels": [0.01, 0.05, 0.2],
        },
        "mRMR": {
            "label": {"s": "(b)", "x": 0.05, "y": 0.95, "va": "top"},
            "levels": [0.001, 0.01, 0.05],
        },
        # When MODEL == "ba", this label sits in the bottom-right corner instead of the top-left.
        "reliefF_l1": {
            "label": (
                {"s": "(c)", "x": 0.95, "y": 0.05, "ha": "right", "va": "bottom"}
                if MODEL == "ba" else
                {"s": "(c)", "x": 0.05, "y": 0.95, "ha": "left", "va": "top"}
            ),
        },
        "pen_rf_importance_permutation": {
            "label": {"s": "(d)", "x": 0.05, "y": 0.95, "va": "top"},
            "levels": [0.01, 0.05, 0.2],
        },
    },
    {
        "JMIM": {
            "label": {"s": "(a)", "x": 0.05, "y": 0.95, "va": "top"},
        },
        "reliefF_rf": {
            "label": {"s": "(b)", "x": 0.05, "y": 0.95, "va": "top"},
        },
        "weighted_rf_importance_permutation": {
            "label": {"s": "(c)", "x": 0.05, "y": 0.95, "va": "top"},
        },
        "pen_rf_importance_impurity": {
            "label": {"s": "(d)", "x": 0.05, "y": 0.95, "va": "top"},
        },
    },
]


# Draw each group of four methods as a 2x2 grid with one shared accuracy colorbar.
# Each grid is saved as a PDF.
for offset, configs in enumerate(configs_list):
    fig = plt.figure()
    # Two columns of panels plus a narrow third column for the colorbar.
    gs = fig.add_gridspec(2, 3, width_ratios=[1, 1, .05])
    ax = None
    axes = []
    for i in range(2):
        for j in range(2):
            # Each panel shares its x and y axes with the previously created panel.
            ax = fig.add_subplot(gs[i, j], sharex=ax, sharey=ax)
            axes.append(ax)
            # Only the outer panels show tick labels and axis labels. Unlike hiding the
            # current label objects, `tick_params` also applies to ticks created later
            # (e.g., when `plot_accuracy_cost_map` resets the x ticks).
            if i == 0:
                ax.tick_params(labelbottom=False)
            else:
                ax.set_xlabel(r"Cost regularization $\lambda$")

            if j == 1:
                ax.tick_params(labelleft=False)
            else:
                ax.set_ylabel("Number of features $k$")

    cax = fig.add_subplot(gs[:, 2])

    for ax, (key, config) in zip(axes, configs.items()):
        im = plot_accuracy_cost_map(
            results_by_method[key], ax=ax, colorbar=False, clabel_fmt=None,
            xlabel=None, ylabel=None, vmin=0.5, cost_levels=config.get("levels"),
            linscale=0.5,
        )
        # Give each axis its own locator instance; locators are bound to a single axis.
        ax.yaxis.set_major_locator(mpl.ticker.MaxNLocator(integer=True))
        ax.text(**config['label'], transform=ax.transAxes)

    # All panels use vmin=0.5 and vmax=1, so the last mappable represents the shared scale.
    fig.colorbar(im, cax=cax).set_label("Accuracy")
    fig.tight_layout()
    filename = f"{MODEL}-accuracy-matrix-{offset}.pdf"
    fig.savefig(filename)
    print(filename)

# Pilot simulations

Figures for the cost of pilot simulations: classification accuracy as a function of the number of nodes used in the pilot simulations.

In [None]:
# Load pilot-simulation results: one evaluation per (method, number of nodes, seed).
sizes = 100 * (1 + np.arange(10))  # 100, 200, ..., 1000 nodes
seeds = np.arange(5)
configs = [
    (method, f'../workspace/{MODEL}/evaluation/pilot/num_nodes-{num_nodes}/seed-{seed}/{method}.pkl')
    for method, num_nodes, seed in it.product(METHODS, sizes, seeds)
]
# NOTE: this rebinds `results_by_method` from the cost-regularization section. Re-run that
# section's loading cell before re-plotting its figures.
results_by_method = {}
accuracies_by_method = {}
for method, filename in configs:
    with open(filename, 'rb') as fp:
        result = pickle.load(fp)
    results_by_method.setdefault(method, []).append(result)

    # Verify the shape: one cost regularization (lambda=0), fifteen numbers of features, and
    # ten-fold cross validation.
    accuracy = result["accuracy"]
    assert accuracy.shape == (1, 15, 10), \
        f"unexpected accuracy shape {accuracy.shape} in {filename}"

    # Select the single penalty (lambda=0) and average over the cross-validation folds.
    # Per method, entries are appended in (size, seed) order with seeds varying fastest; the
    # plotting cell relies on this when reshaping.
    accuracy = accuracy[0].mean(axis=-1)
    accuracies_by_method.setdefault(method, []).append(accuracy)

In [None]:
# Check that the accuracy of the top `num_features` features does not degrade as the number of
# nodes in the pilot simulations changes. Lines show the mean over seeds; shaded bands span the
# minimum to maximum over seeds.
fig, ax = plt.subplots()
num_features = 5

for method, accuracy in accuracies_by_method.items():
    # Rows follow the (size, seed) order in which results were loaded.
    accuracy_grid = np.asarray(accuracy).reshape((len(sizes), len(seeds), -1))
    # Index k - 1 holds the accuracy for k features. This assumes the same layout as in
    # `plot_accuracy_cost_map`, where k = 1 + index. The previous `num_features - 2`
    # selected k - 1 features.
    ys = accuracy_grid[..., num_features - 1]
    y = ys.mean(axis=-1)
    line, = ax.plot(sizes, y, label=method)
    lower, upper = np.quantile(ys, [0, 1], axis=-1)
    ax.fill_between(sizes, lower, upper, color=line.get_color(), alpha=0.2)

ax.legend(loc='best', fontsize='small', ncol=2)
ax.set_title(f'Top {num_features} features')
ax.set_xlabel('Number of nodes in pilot simulation')
ax.set_ylabel('Classification accuracy')
ax.set_ylim(bottom=0.95)
fig.tight_layout()