In [None]:
from repepo.steering.utils.helpers import load_eval_result, SteeringConfig

# Evaluation results to inspect. The split strings look like percentage slices
# of the dataset (train on 0%-40%, test on 40%-50%) — presumably HF-datasets
# style slicing; confirm against SteeringConfig.
config = SteeringConfig(train_split="0%:40%", test_split="40%:50%")
# Previously computed results are looked up by the config's hash.
eval_result = load_eval_result(config.eval_hash)

print(len(eval_result.predictions))  # number of stored predictions

In [None]:
# Summary statistics of the logit difference over all predictions, read from the
# aggregate metrics stored on the eval result.
# NOTE(review): the original header said "Plot", but nothing is plotted here;
# numpy/matplotlib are imported for a future plot — confirm or drop them.

import numpy as np
import matplotlib.pyplot as plt

print(f"mean_logit_diff: {eval_result.metrics['mean_logit_diff']}")
print(f"std_logit_diff: {eval_result.metrics['std_logit_diff']}")

In [None]:
# Spearman rank correlation between each prediction's `logit_diff` metric and
# its `pos_prob` metric; scipy reports the statistic together with a p-value.

from scipy.stats import spearmanr

per_pred_metrics = [pred.metrics for pred in eval_result.predictions]
logit_diffs = [m['logit_diff'] for m in per_pred_metrics]
pos_probs = [m['pos_prob'] for m in per_pred_metrics]

print(spearmanr(logit_diffs, pos_probs))

In [None]:
from repepo.core.types import Completion
import pandas as pd
from IPython.display import display, HTML

# Never truncate cell text when rendering DataFrames: the completion texts shown
# in this notebook can be long. `None` is pandas' documented "unlimited" value
# (0 only happened to work because pandas skips truncation for widths <= 3).
pd.set_option('display.max_colwidth', None)


def pretty_print(df):
    """Render ``df`` as an HTML table in the notebook, keeping line breaks visible.

    pandas' ``to_html`` writes newlines inside cell text as the literal
    two-character sequence ``\\n``; replacing it with ``<br>`` makes multi-line
    strings render on multiple lines.
    Reference: https://stackoverflow.com/questions/34322448/pretty-printing-newlines-inside-a-string-in-a-pandas-dataframe

    Args:
        df: the DataFrame to show.

    Returns:
        Whatever ``IPython.display.display`` returns; the table itself is shown
        as a side effect.
    """
    html = df.to_html().replace("\\n", "<br>")
    return display(HTML(html))

# Construct a dataframe with one row per prediction: the positive/negative
# completion texts alongside that prediction's metrics.
# Explicit `columns` keeps the schema (and the later sort on 'logit_diff')
# valid even when there are no predictions.
df = pd.DataFrame(
    [
        {
            'positive_str': pred.positive_output_prob.text,
            'negative_str': pred.negative_output_prob.text,
            'logit_diff': pred.metrics['logit_diff'],
            'pos_prob': pred.metrics['pos_prob'],
        }
        for pred in eval_result.predictions
    ],
    columns=['positive_str', 'negative_str', 'logit_diff', 'pos_prob'],
)

pretty_print(df.head())

In [None]:
# Show the five rows with the largest and the five with the smallest logit diff.
# `df` is rebound to the frame sorted by logit_diff (descending), so any later
# cell sees it in that order; the bottom rows are listed largest-first.
df = df.sort_values('logit_diff', ascending=False)
for heading, extreme_rows in (
    ("Top-5 by logit diff", df.head()),
    ("Bottom-5 by logit diff", df.tail()),
):
    print(heading)
    pretty_print(extreme_rows)