In [None]:
'''
All results processing functions are put together in this script, including the metric processing for multiple runs, prediction results plot, training loss plot, etc.
'''
%matplotlib inline

import os
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import matthews_corrcoef as mcc_fn
from sklearn.metrics import accuracy_score as acc_fn

# Root directory of the experiment: os.path.abspath("") resolves to the kernel's
# current working directory. The cells below look for the per-split run folders
# ('seed_604_split_<k>') directly underneath it.
exp_dir = os.path.abspath("")
print(exp_dir)

In [None]:
# Process the metric from multiple runs and save the mean+std results

def compute_metrics(preds, labels, hr_threshold, num_classes=7):
    """Compute runtime (rt), halt-rate (hr) and classification metrics for one cutoff.

    Args:
        preds: 2-D tensor with one row per instance and one score per solver;
            the highest score in a row marks the selected solver.
        labels: 2-D tensor of per-solver runtimes. When it has more than 7
            columns, the first two columns and the last column are dropped and
            the remainder is used as the runtime matrix.
        hr_threshold: runtime cutoff. A runtime strictly above it counts as a
            halt; a runtime at or below it counts as proved.
        num_classes: when 5, runtime columns 0 and 5 are removed before any
            metric is computed (kissat3 and bulky, per the inline comment).

    Returns:
        dict with the keys 'hr_min', 'hr_top1', 'hr_top2', 'base_solver_idx',
        'rt_base', 'rt_min', 'rt_top1', 'rt_top2', 'ratio_proved_top1',
        'ratio_proved_top2', 'mcc' and 'acc'. The hr_* entries are the
        percentage reduction in halts relative to the base solver (the solver
        with the lowest mean runtime); they are nan/inf when the base solver
        never exceeds the cutoff, because the base halt count is the divisor.
    """
    runtime = labels[:, 2:-1] if labels.shape[1] > 7 else labels

    if num_classes == 5:   # Remove kissat3 and bulky
        runtime = torch.cat([runtime[:, 1:5], runtime[:, 6:]], dim=-1)

    # Base solver = the single solver with the lowest mean runtime.
    base_solver_idx = runtime.mean(dim=0).argmin()

    base_time = runtime[:, base_solver_idx]
    min_time = runtime.min(dim=1).values    # per-instance best achievable runtime

    len_data = preds.shape[0]
    row_idx = torch.arange(len_data)

    base_hr_cnt = (base_time > hr_threshold).sum().float()
    min_hr_cnt = (min_time > hr_threshold).sum().float()
    hr_min = 100. * (base_hr_cnt - min_hr_cnt) / base_hr_cnt

    pred_top1 = preds.argmax(dim=1)     # Predicted top1 solver
    pred_time_top1 = runtime[row_idx, pred_top1]

    # Better runtime among the two highest-scored solvers. gather() keeps the
    # dtype of `runtime`; copying into a default (float32) buffer would round
    # float64 runtimes and could flip comparisons right at the cutoff.
    _, pred_idx_top2 = torch.topk(preds, 2, dim=1)
    pred_time_top2 = runtime.gather(1, pred_idx_top2).min(dim=1).values

    hr_cnt_top1 = (pred_time_top1 > hr_threshold).sum().float()
    hr_top1 = 100. * (base_hr_cnt - hr_cnt_top1) / base_hr_cnt
    hr_cnt_top2 = (pred_time_top2 > hr_threshold).sum().float()
    hr_top2 = 100. * (base_hr_cnt - hr_cnt_top2) / base_hr_cnt

    rt_top1 = pred_time_top1.mean()
    rt_top2 = pred_time_top2.mean()
    rt_min = min_time.mean()
    rt_base = base_time.mean()

    # Classification view: the target is the fastest solver of each instance.
    tgt_idx = runtime.argmin(dim=1)
    mcc = mcc_fn(tgt_idx, pred_top1)
    acc = acc_fn(tgt_idx, pred_top1)

    # Ratio of instances finished within the cutoff by the selected solver(s).
    ratio_proved_top1 = (pred_time_top1 <= hr_threshold).sum().float() / len_data
    ratio_proved_top2 = (pred_time_top2 <= hr_threshold).sum().float() / len_data

    return {'hr_min': hr_min, 'hr_top1': hr_top1, 'hr_top2': hr_top2, 'base_solver_idx': base_solver_idx,
            'rt_base': rt_base, 'rt_min': rt_min, 'rt_top1': rt_top1, 'rt_top2': rt_top2,
            'ratio_proved_top1': ratio_proved_top1, 'ratio_proved_top2': ratio_proved_top2, 'mcc': mcc, 'acc': acc}

def process_metrics(num_classes=7):
    """Evaluate every available split and save the mean±std summary over splits.

    For each 'seed_604_split_<k>' folder (k = 0..4) under the global `exp_dir`
    that contains both 'test_labels.csv' and 'test_pred_probs.csv', the metrics
    from `compute_metrics` are computed for the cutoffs 100..500 and written to
    'test_metrics_v2.csv' inside that folder. The per-split tables are then
    aggregated into one table of "mean±std" strings, which is saved as
    '<exp_name>_all_runs.csv' in `exp_dir` and printed.

    Args:
        num_classes: forwarded unchanged to `compute_metrics`.

    Raises:
        ValueError: if no split folder provides both result files.
    """
    hr_options = [100, 200, 300, 400, 500]
    per_split = []
    for split_idx in range(5):
        run_dir = os.path.join(exp_dir, 'seed_604_split_'+str(split_idx))
        test_label_file = os.path.join(run_dir, 'test_labels.csv')
        test_pred_file = os.path.join(run_dir, 'test_pred_probs.csv')

        # Skip splits that have not produced both files yet.
        if not (os.path.isfile(test_pred_file) and os.path.isfile(test_label_file)):
            continue

        labels = torch.tensor(pd.read_csv(test_label_file).to_numpy())
        preds = torch.tensor(pd.read_csv(test_pred_file).to_numpy())

        test_metrics_list = [
            compute_metrics(preds, labels, hr_threshold, num_classes)
            for hr_threshold in hr_options
        ]

        # One row per cutoff, one column per metric.
        test_metrics = pd.DataFrame.from_records(test_metrics_list, index=hr_options).astype(float)
        test_metrics.to_csv(os.path.join(run_dir, 'test_metrics_v2.csv'))

        per_split.append(test_metrics)

    if not per_split:
        raise ValueError(
            f"No split under {exp_dir} contains both 'test_labels.csv' and "
            "'test_pred_probs.csv'; nothing to aggregate."
        )

    # Stack to (cutoff, metric, split) and reduce over the split axis.
    data_arr = np.dstack([d.to_numpy() for d in per_split])
    avg = data_arr.mean(axis=2).round(3)
    std = data_arr.std(axis=2).round(3)

    new_df = pd.DataFrame({
        col: [str(avg[i, j]) + "\u00B1" + str(std[i, j]) for i in range(avg.shape[0])]
        for j, col in enumerate(per_split[0].columns)
    })

    # basename() instead of splitting on '/', so this also works with Windows paths.
    exp_name = os.path.basename(exp_dir)
    save_path = os.path.join(exp_dir, f'{exp_name}_all_runs.csv')
    # 'utf-8-sig' writes a BOM so spreadsheet tools decode the '±' correctly.
    new_df.to_csv(save_path, encoding='utf-8-sig')
    print(new_df)

# Number of solver classes to evaluate. Kept as a notebook-level variable because
# the plotting cells further down also read `num_classes` (5 drops two solvers).
num_classes=7
process_metrics(num_classes)

In [None]:
# Plot prediction vs labels
# For every split: per-solver count of top-1 predictions next to the count of
# instances for which that solver is the fastest, annotated with the per-solver
# precision / recall. The figure is saved next to the prediction file.

for split_idx in range(5):
    run_dir = os.path.join(exp_dir, 'seed_604_split_'+str(split_idx))
    pred_filepath = os.path.join(run_dir, 'test_pred_probs.csv')
    label_filepath = os.path.join(run_dir, 'test_labels.csv')

    if not os.path.isfile(pred_filepath):
        continue

    preds = pd.read_csv(pred_filepath, index_col=False).to_numpy()
    # Note that the labels saved in sat_selection_light has variables other than runtime
    labels = pd.read_csv(label_filepath, index_col=False).to_numpy()
    if labels.shape[1] == 10:
        labels = labels[:, 2:-1]

    if num_classes == 5:   # Remove kissat3 and bulky
        labels = np.concatenate([labels[:, 1:5], labels[:, 6:]], axis=-1)

    pred_idx = preds.argmax(axis=1)
    tgt_idx = labels.argmin(axis=1)

    # Count over the full solver range so that a solver which is never predicted
    # (or never the fastest) still gets a zero-height bar at its own index.
    n_solvers = max(preds.shape[1], labels.shape[1])
    solver_ids = np.arange(n_solvers)
    preds_cnt = np.bincount(pred_idx, minlength=n_solvers)
    tgts_cnt = np.bincount(tgt_idx, minlength=n_solvers)
    tp_cnt = np.bincount(pred_idx[pred_idx == tgt_idx], minlength=n_solvers)

    # Precision / recall are reported as 0 when their denominator is empty.
    precision = [tp / cnt if cnt else 0 for tp, cnt in zip(tp_cnt, preds_cnt)]
    recall = [tp / cnt if cnt else 0 for tp, cnt in zip(tp_cnt, tgts_cnt)]

    fig, ax = plt.subplots()
    bar_preds = ax.bar(solver_ids-0.15, preds_cnt, width=0.30, color='tab:grey', edgecolor='black', alpha=0.6, label='predictions')
    ax.bar(solver_ids+0.15, tgts_cnt, width=0.30, color='tab:blue', edgecolor='black', alpha=0.6, label='targets')
    # Add precision and recall values above the prediction bars
    for rect, p, r in zip(bar_preds, precision, recall):
        x = rect.get_x() + rect.get_width()/2
        y = rect.get_height() + 5
        ax.text(x, y, f"p:{p:.2f}\nr:{r:.2f}", ha='center', va='bottom')

    ax.set_title(f"Top1 Prediction vs Targets - split_{str(split_idx)}")
    ax.legend(loc='best')
    ax.set_xlabel('Solver Index')
    ax.set_ylabel('Count')
    # splitext() only strips the file extension; splitting on the first '.' would
    # truncate the path whenever a directory name contains a dot.
    plot_savepath = os.path.splitext(pred_filepath)[0]+'.png'
    fig.savefig(plot_savepath)
    plt.show()
    plt.close(fig)   # release the figure before the next split


In [None]:
# Plot training loss
# For every split: rolling mean of the per-step training loss, with the per-epoch
# validation loss overlaid when 'epoch_metrics.csv' is available. The figure is
# saved as 'step_loss.png' in the run folder.

rolling_window_size = 100
for split_idx in range(5):
    run_dir = os.path.join(exp_dir, 'seed_604_split_'+str(split_idx))
    step_metric_file = os.path.join(run_dir, 'step_metrics.csv')
    epoch_metric_file = os.path.join(run_dir, 'epoch_metrics.csv')
    if not os.path.isfile(step_metric_file):
        continue

    step_metric = pd.read_csv(step_metric_file, index_col=0)
    total_steps = step_metric.shape[0]

    fig, ax = plt.subplots()
    # Smooth only the column that is plotted.
    train_loss_smooth = step_metric['train_loss'].rolling(rolling_window_size).mean()
    ax.plot(train_loss_smooth, label='train_loss')

    # Validation loss is optional: a run may have step metrics but no finished epoch.
    if os.path.isfile(epoch_metric_file):
        epoch_metric = pd.read_csv(epoch_metric_file, index_col=0)
        total_epoch = epoch_metric.shape[0]
        if total_epoch > 0:
            # Epoch e is placed at step e * (steps per epoch).
            # NOTE(review): assumes the index of step_metrics.csv counts steps from 0 - confirm.
            val_loss_step = np.arange(total_epoch) * (total_steps // total_epoch)
            val_loss = epoch_metric['val_loss'].values
            ax.plot(val_loss_step, val_loss, color='red', label='val_loss')

    ax.legend(loc='best')
    ax.set_title(f'Train / val loss in training - split_{split_idx}')
    plot_savepath = os.path.join(run_dir, 'step_loss.png')
    fig.savefig(plot_savepath)
    plt.show()
    plt.close(fig)   # release the figure before the next split

In [None]:
# Plot the cost of not predicting the best solver
# For every split: histogram of (runtime of the predicted solver - best runtime)
# over the instances where that difference is positive, annotated with mean / std.
# Drawn with matplotlib only, since this notebook does not import seaborn.

for split_idx in range(5):
    run_dir = os.path.join(exp_dir, 'seed_604_split_'+str(split_idx))
    pred_filepath = os.path.join(run_dir, 'test_pred_probs.csv')
    label_filepath = os.path.join(run_dir, 'test_labels.csv')

    if not os.path.isfile(pred_filepath):
        continue

    preds = pd.read_csv(pred_filepath, index_col=False).to_numpy()
    # Note that the labels saved in sat_selection_light has variables other than runtime
    labels = pd.read_csv(label_filepath, index_col=False).to_numpy()
    # Always bind `runtime` in this iteration; files without the extra columns are
    # used as-is instead of falling back to a value left over from an earlier run.
    runtime = labels[:, 2:-1] if labels.shape[1] == 10 else labels

    if num_classes == 5:   # Remove kissat3 and bulky
        # Slice the runtime matrix (not the raw labels), so the metadata columns
        # stripped above are not pulled back in.
        runtime = np.concatenate([runtime[:, 1:5], runtime[:, 6:]], axis=-1)

    pred_idx = preds.argmax(axis=1)

    len_data = preds.shape[0]
    pred_rt = runtime[np.arange(len_data), pred_idx]
    tgt_rt = runtime.min(axis=1)

    rt_delta = pred_rt - tgt_rt

    # Keep only the instances where the selected solver was slower than the best one.
    rt_delta_cost = rt_delta[rt_delta>0]
    if rt_delta_cost.size == 0:
        print(f"split_{split_idx}: predicted solver is never slower than the best one, nothing to plot")
        continue

    cost_mean = rt_delta_cost.mean()
    cost_std = rt_delta_cost.std()

    fig, ax = plt.subplots()
    ax.hist(rt_delta_cost, bins='auto', edgecolor='black')
    # Axes-relative placement keeps the annotation visible for any data range.
    ax.text(0.95, 0.95, f"mean={cost_mean:.3f}\nstd={cost_std:.3f}",
            transform=ax.transAxes, ha='right', va='top')

    ax.set_title(f"Cost of predicting wrong - split_{str(split_idx)}")
    ax.set_xlabel('Diff to min runtime')
    ax.set_ylabel('Count')
    plt.show()
    plt.close(fig)

In [None]:
# Plot runtime distribution of solvers
# For every split: one box per solver over the test-set runtimes, each annotated
# with "mean(std)" at the height of the mean.
# Drawn with matplotlib only, since this notebook does not import seaborn.

for split_idx in range(5):
    run_dir = os.path.join(exp_dir, 'seed_604_split_'+str(split_idx))
    pred_filepath = os.path.join(run_dir, 'test_pred_probs.csv')
    label_filepath = os.path.join(run_dir, 'test_labels.csv')

    if not os.path.isfile(pred_filepath):
        continue

    # Note that the labels saved in sat_selection_light has variables other than runtime
    labels = pd.read_csv(label_filepath, index_col=False)

    if labels.shape[1] == 10:
        col_names = ['#var', '#clause', 'base', 'HyWalk', 'MOSS', 'mabgb', 'ESA', 'bulky', 'UCB', 'MIN']
        labels.columns = col_names
        runtime = labels.iloc[:, 2:-1]
    else:
        # NOTE(review): assumes a file without the extra columns holds runtimes only,
        # as the prediction-plot cell does - confirm.
        runtime = labels

    n_solvers = runtime.shape[1]
    positions = np.arange(n_solvers)

    fig, ax = plt.subplots()
    # One box per column; missing values are dropped per solver.
    ax.boxplot([runtime.iloc[:, k].dropna().to_numpy() for k in positions], positions=positions)
    ax.set_xticks(positions)
    ax.set_xticklabels(runtime.columns)

    means = runtime.mean(axis=0)
    stds = runtime.std(axis=0)
    # The Series are labelled by solver name, so positional access must use .iloc.
    for k in positions:
        ax.text(k, means.iloc[k], f"{means.iloc[k]:.1f}({stds.iloc[k]:.1f})",
                horizontalalignment='center', size='x-small', color='k')

    ax.set_title("Runtime by each solver")
    ax.set_ylabel('Runtime (s)')
    plt.show()
    plt.close(fig)