In [1]:
import pickle

# Load the precision-recall records from disk (path is relative to the current
# working directory). `pr_data` is later passed as the `data` argument to
# save_precision_recall_plots and save_legend_strip.
# NOTE(review): pickle.load can execute arbitrary code while unpickling -- only
# run this cell on a pr-data.pkl file that comes from a trusted source.
with open("./precision_recall_data/pr-data.pkl", 'rb') as file:
    pr_data = pickle.load(file)

In [2]:
import math
import warnings

import matplotlib.lines as mlines
import matplotlib.pyplot as plt
import numpy as np

def save_precision_recall_plots(data, abs_rate=0.1):
    """
    Save one precision-recall figure (PDF) per cascade for one abstention rate.

    Each figure overlays one curve per benchmark plus a dashed grey "Random"
    reference line at y = abs_rate. No legend is drawn on the figures.

    data: a list of dicts with fields:
        {
            'cascade': str,
            'benchmark': str,
            'benchmark_pretty_name': str,
            'abs_rate_bottom': float,
            'precision': np.array,
            'recall': np.array,
            'thresholds': np.array
        }
        Only 'cascade', 'benchmark', 'abs_rate_bottom', 'precision' and
        'recall' are read by this function. Records whose 'abs_rate_bottom'
        does not match `abs_rate` are ignored. If several records share the
        same (cascade, benchmark) pair, the last one wins.
    abs_rate: float
        The assumed abstention rate, which sets the precision of the random baseline.

    Side effects:
        - Mutates the global `plt.rcParams` (the settings stay in effect
          after the call, so later figures in the session inherit them).
        - Writes "precision_recall_{cascade}_abs_rate={abs_rate}.pdf" into the
          current working directory for every cascade, overwriting existing
          files with the same name.
        - Emits a UserWarning when no record matches `abs_rate` (in which
          case no file is written).
    """

    ### STEP 1: identify unique cascades and benchmarks

    # Materialize once: the records are traversed more than once below.
    records = list(data)

    def _matches_abs_rate(value):
        # Tolerant comparison, so a stored rate such as 0.30000000000000004
        # still matches abs_rate=0.3. Non-numeric values fall back to ==.
        try:
            return math.isclose(value, abs_rate, rel_tol=1e-9, abs_tol=1e-12)
        except TypeError:
            return value == abs_rate

    # For consistent colors across the 6 benchmarks, define a 6-color palette
    color_palette = [
        "#377eb8",  # blue
        "#e41a1c",  # red
        "#4daf4a",  # green
        "#984ea3",  # purple
        "#ff7f00",  # orange
        "#a65628"   # brown
    ]
    # Map benchmarks to colors using ALL benchmarks in `data`, not only those
    # present at this abs_rate. save_legend_strip derives its colors from the
    # unfiltered data too, so the curves keep matching the legend even when a
    # benchmark has no record for the requested abstention rate.
    all_benchmarks = sorted({record['benchmark'] for record in records})
    benchmark_color_map = {
        bm: color_palette[i % len(color_palette)]
        for i, bm in enumerate(all_benchmarks)
    }

    # get only the relevant data
    selected = [
        record for record in records
        if _matches_abs_rate(record['abs_rate_bottom'])
    ]
    if not selected:
        warnings.warn(
            f"No records with abs_rate_bottom == {abs_rate}; "
            "no precision-recall plots were written.",
            stacklevel=2,
        )

    all_cascades = sorted({record['cascade'] for record in selected})

    # Create a mapping from (cascade, benchmark) -> data
    cascade_benchmark_dict = {
        (record['cascade'], record['benchmark']): record
        for record in selected
    }

    ### STEP 2: modify matplotlib style settings

    # NOTE: deliberately global (not plt.rc_context), so figures created
    # after this call keep the same look and save settings.
    plt.rcParams.update({
        "figure.figsize": (5, 4),
        "font.size": 10,
        "axes.linewidth": 1.1,
        "axes.labelsize": 11,
        "xtick.labelsize": 9,
        "ytick.labelsize": 9,
        "lines.linewidth": 2,
        "savefig.bbox": "tight",
        "savefig.pad_inches": 0.02,
    })

    ### STEP 3: create one plot per cascade (no legend)

    for cascade in all_cascades:
        fig, ax = plt.subplots()
        try:
            # Plot each benchmark's precision-recall
            for bm in all_benchmarks:
                entry = cascade_benchmark_dict.get((cascade, bm))
                if entry is None:
                    continue  # skip if missing data

                ax.plot(
                    entry['recall'],
                    entry['precision'],
                    color=benchmark_color_map[bm]
                )

            # Add a dashed horizontal line at 'abs_rate'
            ax.axhline(
                y=abs_rate,
                color='grey',
                linestyle='--',
                linewidth=1
            )

            # Place the caption "Random" just below the dashed line.
            # get_yaxis_transform(): x is in axes fractions, y in data units.
            ax.text(
                0.2,
                abs_rate - 0.01,
                "Random",
                color='grey',
                fontsize=9,
                transform=ax.get_yaxis_transform(),
                va='top',
                ha='center'
            )

            ax.set_xlabel("Recall")
            ax.set_ylabel("Precision")
            ax.grid(True, linestyle='--', linewidth=0.5, alpha=0.7)

            # Save figure (explicitly this figure, not pyplot's "current" one)
            out_filename = f"precision_recall_{cascade}_abs_rate={abs_rate}.pdf"
            fig.savefig(out_filename, bbox_inches='tight')
        finally:
            # Always release the figure, even if plotting or saving failed.
            plt.close(fig)

def save_legend_strip(data, out_filename="precision_recall_legend_strip.pdf"):
    """
    Create legend for the precision-recall plots, which color-codes the benchmarks.

    Saves the legend strip to file in PDF format.

    data: a list of dicts; only the 'benchmark' and 'benchmark_pretty_name'
        fields are read. Benchmarks are sorted by their 'benchmark' value and
        assigned palette colors in that order (cycling if there are more
        benchmarks than colors). The label shown for a benchmark is the
        'benchmark_pretty_name' of the first record carrying that benchmark.
    out_filename: str
        Path of the PDF to write; an existing file is overwritten. Defaults
        to "precision_recall_legend_strip.pdf" in the current working directory.

    The figure is rendered with whatever `plt.rcParams` are active at call
    time; this function does not modify them.
    """
    # One pass over the records: remember the pretty name of the first record
    # seen for each benchmark.
    pretty_names = {}
    for record in data:
        bm = record['benchmark']
        if bm not in pretty_names:
            pretty_names[bm] = record['benchmark_pretty_name']

    # Identify benchmarks in sorted order
    all_benchmarks = sorted(pretty_names)

    # Define consistent color palette
    color_palette = [
        "#377eb8",
        "#e41a1c",
        "#4daf4a",
        "#984ea3",
        "#ff7f00",
        "#a65628"
    ]

    # Create line handles for each benchmark (empty proxy lines: they exist
    # only to give the legend a color and a label).
    handles = [
        mlines.Line2D(
            [], [],
            color=color_palette[i % len(color_palette)],
            label=pretty_names[bm],
            linewidth=2
        )
        for i, bm in enumerate(all_benchmarks)
    ]

    # Create the figure for the legend:
    fig, ax = plt.subplots(figsize=(6, 0.6))
    try:
        ax.axis('off')  # no actual plot

        # Place a single legend in the center, with one column per benchmark
        # so all entries sit on one row. Keep ncol >= 1 so an empty handle
        # list does not request a zero-column legend.
        ax.legend(
            handles=handles,
            loc='center',
            ncol=max(1, len(handles)),
            frameon=False,
            bbox_to_anchor=(0.5, 0.5)
        )

        # Save explicitly this figure, not pyplot's "current" one.
        fig.savefig(out_filename, bbox_inches='tight')
    finally:
        # Always release the figure, even if drawing or saving failed.
        plt.close(fig)

In [3]:
### Produce the per-cascade precision-recall PDFs, then the shared legend strip

### WARNING: executing this cell overwrites existing files named
### - "precision_recall_legend_strip.pdf"
### - "precision_recall_{x}_abs_rate={y}.pdf" for each cascade x and abstention rate y

# One batch of per-cascade figures for every abstention rate of interest
for abs_rate in (0.2, 0.3):
    save_precision_recall_plots(pr_data, abs_rate=abs_rate)

# A single legend strip is shared by all of the figures above
save_legend_strip(pr_data)