In [2]:
import pandas as pd
import sys
sys.path.append("..")
from src.utils import html_highlight_diffs
from IPython.core.display import display, HTML
import numpy as np

In [3]:
from src.utils import load_predictor, get_ints_to_labels

In [4]:
# Notebook configuration: the task and Stage 2 experiment names select which
# edits file under ../results is analysed by the cells below.
TASK = "imdb"
STAGE2EXP = "mice_binary"
EDIT_PATH = "../results/{}/edits/{}/edits.csv".format(TASK, STAGE2EXP)

In [5]:
def read_edits(path):
    """Load a tab-separated edits file and normalise its prediction columns.

    Args:
        path: Location of the edits file to read. The file is parsed as
            tab-separated text with "\\n" line terminators; malformed rows
            are skipped with a warning rather than aborting the load.

    Returns:
        A DataFrame in which missing values of 'orig_pred', 'new_pred' and
        'contrast_pred' are replaced by "". When 'new_pred' was parsed as
        float64, all three columns are additionally rewritten as integer
        strings (e.g. 1.0 -> "1"); otherwise the non-missing values are left
        as read.
    """
    pred_columns = ('new_pred', 'orig_pred', 'contrast_pred')

    # `error_bad_lines` / `warn_bad_lines` were superseded by `on_bad_lines`;
    # newer pandas rejects the old keywords and older pandas rejects the new
    # one (both with TypeError), so try the modern spelling first.
    try:
        edits = pd.read_csv(path, sep="\t", lineterminator="\n", on_bad_lines="warn")
    except TypeError:
        edits = pd.read_csv(path, sep="\t", lineterminator="\n",
                            error_bad_lines=False, warn_bad_lines=True)

    # `pd.np` is a deprecated alias; use numpy directly.
    if edits['new_pred'].dtype == np.dtype('float64'):
        # Column-wise map instead of a row-wise apply: same per-value result,
        # and it stays a Series even when the frame has no rows.
        for col in pred_columns:
            edits[col] = edits[col].map(lambda v: "" if pd.isna(v) else str(int(v)))
    else:
        # Assign the result back: an in-place fillna on a column selection is
        # chained assignment and is not guaranteed to update `edits`.
        for col in pred_columns:
            edits[col] = edits[col].fillna("")
    return edits

In [6]:
def get_best_edits(edits):
    """Keep only the top-ranked edit rows.

    MiCE writes every edit found in Stage 2, but evaluation should use only
    the smallest edit per input. Those are selected here as the rows whose
    'sorted_idx' equals 0.

    Args:
        edits: DataFrame of edits with a 'sorted_idx' column.

    Returns:
        The subset of `edits` where 'sorted_idx' is 0, original index kept.
    """
    is_top_ranked = edits['sorted_idx'].eq(0)
    return edits.loc[is_top_ranked]
    
def evaluate_edits(edits):
    """Compute and print summary metrics over the top-ranked edits.

    Only rows with 'sorted_idx' == 0 are considered. An edit counts as
    flipped when its 'new_pred' equals its 'contrast_pred' after both are
    converted to strings.

    Args:
        edits: DataFrame with 'sorted_idx', 'data_idx', 'new_pred',
            'contrast_pred', 'minimality' and 'duration' columns.

    Returns:
        dict with keys 'num_total' (number of distinct 'data_idx' values),
        'num_flipped', 'flip_rate' (num_flipped / num_total), 'minimality'
        (mean) and 'duration' (mean). Each metric is also printed rounded to
        three decimals. When there are no top-ranked rows, 'flip_rate' is NaN
        instead of raising ZeroDivisionError.
    """
    best = edits[edits['sorted_idx'] == 0]
    num_total = best['data_idx'].nunique()

    # Compare as strings so that e.g. 1 and "1" are treated as the same label.
    is_flipped = best['new_pred'].astype(str) == best['contrast_pred'].astype(str)
    num_flipped = int(is_flipped.sum())

    # The rate is undefined without any inputs; report NaN rather than crash.
    flip_rate = num_flipped / num_total if num_total else float("nan")

    metrics = {
        "num_total": num_total,
        "num_flipped": num_flipped,
        "flip_rate": flip_rate,
        "minimality": best['minimality'].mean(),
        "duration": best['duration'].mean(),
    }
    for k, v in metrics.items():
        print(f"{k}: \t{round(v, 3)}")
    return metrics

In [7]:
def display_edits(row):
    """Print an edit's minimality and render its before/after text as HTML.

    Args:
        row: Mapping-like record providing 'orig_editable_seg',
            'edited_editable_seg' and 'minimality'.

    The two segments are passed to `html_highlight_diffs`, whose result is
    unpacked into exactly two values; each is wrapped in `HTML` and displayed,
    original first and edited second.
    """
    original_markup, edited_markup = html_highlight_diffs(
        row['orig_editable_seg'],
        row['edited_editable_seg'],
    )
    rounded_minimality = round(row['minimality'], 3)
    print(f"MINIMALITY: \t{rounded_minimality}")
    print("")
    for markup in (original_markup, edited_markup):
        display(HTML(markup))

def display_classif_results(rows):
    """Print label/probability summaries for classification edits and show each diff.

    Args:
        rows: DataFrame whose rows provide 'orig_pred', 'contrast_pred',
            'new_pred', 'orig_contrast_prob_pred' and
            'new_contrast_prob_pred', plus the columns `display_edits` reads.

    For every row, the original, contrast and new labels are printed (the two
    probabilities rounded to three decimals), followed by the highlighted
    edit rendered via `display_edits`.
    """
    for _, edit in rows.iterrows():
        prob_before = round(edit['orig_contrast_prob_pred'], 3)
        prob_after = round(edit['new_contrast_prob_pred'], 3)
        summary_lines = [
            "-----------------------",
            f"ORIG LABEL: \t{edit['orig_pred']}",
            f"CONTR LABEL: \t{edit['contrast_pred']} (Orig Pred Prob: {prob_before})",
            f"NEW LABEL: \t{edit['new_pred']} (New Pred Prob: {prob_after})",
            "",
        ]
        for line in summary_lines:
            print(line)
        display_edits(edit)

def display_race_results(rows):
    """Print question, options and label summaries for multiple-choice edits, then show each diff.

    Args:
        rows: DataFrame whose rows provide 'orig_input' (a string that
            evaluates to a mapping with 'question' and 'options'),
            'orig_pred', 'contrast_pred', 'new_pred',
            'orig_contrast_prob_pred' and 'new_contrast_prob_pred', plus the
            columns `display_edits` reads.

    For every row, the question and its enumerated options are printed,
    followed by the original, contrast and new labels (probabilities rounded
    to three decimals) and the highlighted edit rendered via `display_edits`.
    """
    for _, edit in rows.iterrows():
        prob_before = round(edit['orig_contrast_prob_pred'], 3)
        prob_after = round(edit['new_contrast_prob_pred'], 3)
        # NOTE(review): eval executes whatever expression is stored in the
        # 'orig_input' column. Only run this on edits files you trust; if the
        # column always holds a plain literal, ast.literal_eval would be a
        # safer parser — confirm before switching.
        parsed_input = eval(edit['orig_input'])
        options = parsed_input['options']

        heading = [
            "-----------------------",
            f"QUESTION: {parsed_input['question']}",
            "\nOPTIONS:",
        ]
        option_lines = [f"  ({opt_idx}) {opt}" for opt_idx, opt in enumerate(options)]
        label_lines = [
            f"\nORIG LABEL: \t{edit['orig_pred']}",
            f"CONTR LABEL: \t{edit['contrast_pred']} (Orig Pred Prob: {prob_before})",
            f"NEW LABEL: \t{edit['new_pred']} (New Pred Prob: {prob_after})",
            "",
        ]
        for line in heading + option_lines + label_lines:
            print(line)
        display_edits(edit)

In [16]:
# Load the edits file, keep only the sorted_idx == 0 rows, then print and
# store the summary metrics for those rows.
edits = get_best_edits(read_edits(EDIT_PATH))
metrics = evaluate_edits(edits)

num_total: 	235
num_flipped: 	170
flip_rate: 	0.723
minimality: 	0.195
duration: 	9.175


  if edits['new_pred'].dtype == pd.np.dtype('float64'):


In [19]:
# Draw one edit at random (no seed is set, so each run can show a different
# row) and print its classification summary with the highlighted diff.
random_rows = edits.sample(n=1)
display_classif_results(random_rows)
# display_race_results(random_rows)

-----------------------
ORIG LABEL: 	1
CONTR LABEL: 	0 (Orig Pred Prob: 0.001)
NEW LABEL: 	0 (New Pred Prob: 0.831)

MINIMALITY: 	0.106

