In [None]:
import pandas as pd
import os

mnli_baseline_files = [
    "20251111_162712-mnli_baseline-1/eval_predictions.jsonl",
    "20251111_193408-mnli_baseline-2/eval_predictions.jsonl",
    "20251111_223025-mnli_baseline-3/eval_predictions.jsonl"
]

def combine_eval_predicitons(files):
    dfs = []
    for f in files:
        parsed_path = parse_eval_path(f)
        df = pd.read_json(f, lines=True)
        df['date'] = parsed_path['date']
        df['time'] = parsed_path['time']
        df['model'] = parsed_path['model']
        df['seed'] = parsed_path['seed']
        dfs.append(df)
    all_df = pd.concat(dfs, ignore_index=True)
    return all_df

def parse_eval_path(path):
    dir_name = os.path.dirname(path)
    date_time, model_name, seed = dir_name.split("-")
    date_str, time_str = date_time.split("_")
    return {
        "date": date_str,
        "time": time_str,
        "model": model_name,
        "seed": seed
    }

def parse_eval_predictions(all_df):
# 1) (Optional) ensure each pairID has all seeds
    expected_seeds = sorted(all_df['seed'].unique())
    pairs_with_all = (
        all_df.groupby('pairID')['seed'].nunique()
        .eq(len(expected_seeds))
    )
    all_df = all_df[all_df['pairID'].isin(pairs_with_all[pairs_with_all].index)]

    # 2) pivot so each seed’s predictions share one row
    base_cols = ['promptID','pairID','premise','hypothesis','genre','label']
    wide = (
        all_df
        .pivot_table(index=base_cols,
                    columns='seed',
                    values=['predicted_label','predicted_scores'],
                    aggfunc='first')
        .reset_index()
    )

    # flatten multiindex columns like ('predicted_label', '42') -> 'predicted_label_seed42'
    wide.columns = [
        f"{c0}_seed{c1}" if isinstance(c0, str) and c0.startswith('predicted_') else c0
        for c0, c1 in (wide.columns if isinstance(wide.columns, pd.MultiIndex)
                    else [(c, '') for c in wide.columns])
    ]

    # 3) count how many models were wrong on each row
    pl_cols = [c for c in wide.columns if str(c).startswith('predicted_label_seed')]
    wide['wrong_count'] = sum((wide[c] != wide['label']).astype(int) for c in pl_cols)

    # 4) keep only rows all models got wrong
    all_wrong = wide[wide['wrong_count'] == len(pl_cols)].copy()

    # 5) (Optional) select only the columns you want to see
    # Base columns + every seed’s predicted_label and predicted_scores + wrong_count
    ps_cols = [c for c in wide.columns if str(c).startswith('predicted_scores_seed')]
    final_cols = base_cols + pl_cols + ps_cols + ['wrong_count']
    result = all_wrong[final_cols] 
    return result   

def display_all_results(df):
    with pd.option_context('display.max_rows', None, 'display.max_columns', None, 'display.max_colwidth', None):
        display(df)
        
all_df = combine_eval_predicitons(mnli_baseline_files)
hard_prompts = parse_eval_predictions(all_df)

In [65]:
genre_counts = (
    hard_prompts["genre"]
    .value_counts()
    .rename_axis("genre")
    .reset_index(name="num_wrong")
    .sort_values("num_wrong", ascending=False)
)

print(genre_counts)

genre_rate = (
    hard_prompts.groupby("genre").size() /
    all_df.groupby("genre").size()
).sort_values(ascending=False)
print(genre_rate)

        genre  num_wrong
0       slate        315
1   telephone        252
2     fiction        242
3      travel        208
4  government        171
genre
slate         0.053708
telephone     0.042726
fiction       0.040885
travel        0.035088
government    0.029306
dtype: float64


In [76]:
all_df.count()

promptID                   29445
pairID                     29445
premise                    29445
premise_binary_parse       29445
premise_parse              29445
hypothesis                 29445
hypothesis_binary_parse    29445
hypothesis_parse           29445
genre                      29445
label                      29445
predicted_scores           29445
predicted_label            29445
date                       29445
time                       29445
model                      29445
seed                       29445
dtype: int64

In [75]:
hard_prompts.count()

promptID                  1188
pairID                    1188
premise                   1188
hypothesis                1188
genre                     1188
label                     1188
predicted_label_seed1     1188
predicted_label_seed2     1188
predicted_label_seed3     1188
predicted_scores_seed1    1188
predicted_scores_seed2    1188
predicted_scores_seed3    1188
wrong_count               1188
dtype: int64

In [None]:
)