In [1]:
#!/usr/bin/env python3
import os
import pandas as pd
from typing import List, Set

# ─── 1. Evaluation functions ────────────────────────────────────────────────────

def precision_at_k(preds: List[str], gold: Set[str], k: int) -> float:
    """Return P@k: number of gold hits in the first k predictions, divided by k.

    The denominator is always ``k``, even when fewer than ``k`` predictions
    are available. A non-positive ``k`` yields 0.0.
    """
    if k > 0:
        hits = 0
        for candidate in preds[:k]:
            if candidate in gold:
                hits += 1
        return hits / k
    return 0.0

def recall_at_k(preds: List[str], gold: Set[str], k: int) -> float:
    """Recall at K: fraction of distinct gold items recovered in the top-k.

    Args:
        preds: Ranked predictions, best first.
        gold: Relevant items for the query.
        k: Cut-off rank.

    Returns:
        A value in [0, 1]. Returns 0.0 when ``gold`` is empty or ``k`` is
        not positive.
    """
    # Guard k <= 0 explicitly (mirrors precision_at_k): a negative k would
    # otherwise make preds[:k] slice from the end and report spurious hits.
    if not gold or k <= 0:
        return 0.0
    # Count each gold item at most once so a prediction list containing
    # duplicates can never push recall above 1.0.
    recovered = set(preds[:k]).intersection(gold)
    return len(recovered) / len(gold)

def f1_at_k(preds: List[str], gold: Set[str], k: int) -> float:
    """Return F1@k, the harmonic mean of precision_at_k and recall_at_k.

    Yields 0.0 when precision and recall are both zero.
    """
    precision = precision_at_k(preds, gold, k)
    recall = recall_at_k(preds, gold, k)
    total = precision + recall
    if total > 0:
        return 2 * precision * recall / total
    return 0.0

def reciprocal_rank(preds: List[str], gold: Set[str]) -> float:
    """Return 1/rank of the first prediction found in ``gold`` (ranks start at 1).

    Returns 0.0 when no prediction is in ``gold``; averaging this value over
    queries gives MRR.
    """
    first_hit = next(
        (rank for rank, item in enumerate(preds, start=1) if item in gold),
        None,
    )
    return 0.0 if first_hit is None else 1.0 / first_hit

# ─── 2. Main evaluation pipeline ────────────────────────────────────────────────

def main(path: str = "/home/mhoveyda1/REASON/runs/OLMo-2-0325-32B-20250601_15-41-54/full_details.csv"):
    """Evaluate one run's ranked predictions and print the metric tables.

    Computes mean P@K, R@K, F1@K (K in 1, 3, 5, 10) and MRR per
    (context, method), prints the flat results, then prints a pivoted table
    in which bm25 is collapsed to a single row.

    Args:
        path: CSV to evaluate. It must provide the columns read below:
            ``context``, ``method``, ``query_id``, ``rank``, ``entity`` and
            ``is_gold``. Defaults to the OLMo-2 32B run.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
    """
    # 2.1 Make sure the CSV exists
    if not os.path.exists(path):
        raise FileNotFoundError(f"CSV not found at {path}")

    # 2.2 Load data
    df = pd.read_csv(path)

    # 2.3 Define your parameters
    Ks       = [1, 3, 5, 10]
    contexts = df['context'].unique()
    methods  = df['method'].unique()

    # 2.4 Collect results
    results = []
    for context in contexts:
        df_ctx = df[df['context'] == context]
        for method in methods:
            df_m = df_ctx[df_ctx['method'] == method]

            # Build per-query predictions + gold sets
            per_query = {}
            for qid, group in df_m.groupby('query_id'):
                preds = group.sort_values('rank')['entity'].tolist()
                gold  = set(group.loc[group['is_gold'] == 1, 'entity'])
                per_query[qid] = (preds, gold)

            # `methods` is collected over the whole file, so a method may have
            # no rows in this context; skip it instead of dividing by zero.
            if not per_query:
                continue
            n_queries = len(per_query)

            # MRR does not depend on K, so compute it once per (context, method).
            mrr = sum(reciprocal_rank(p, g) for p, g in per_query.values()) / n_queries

            # Compute metrics for each K
            for K in Ks:
                _ps, _rs, _fs = [], [], []
                for preds, gold in per_query.values():
                    _ps.append(precision_at_k(preds, gold, K))
                    _rs.append(recall_at_k(preds, gold, K))
                    _fs.append(f1_at_k(preds, gold, K))

                results.append({
                    'context': context,
                    'method':  method,
                    'K':       K,
                    'P@K':     sum(_ps) / n_queries,
                    'R@K':     sum(_rs) / n_queries,
                    'F1@K':    sum(_fs) / n_queries,
                    'MRR':     mrr
                })

    # 2.5 Turn into DataFrame
    res_df = pd.DataFrame(results)

    # 2.6 (Optional) Print "flat" summary
    pd.set_option('display.precision', 4)
    print("\n=== Flat results ===")
    print(res_df)

    # ─── 3. Convert to a nested table ─────────────────────────────────────────────

    pivot_metrics = res_df.pivot_table(
        index=['context', 'method'],
        columns='K',
        values=['P@K', 'R@K', 'F1@K']
    )

    # MRR is repeated on every K row, so the first value per group is enough.
    mrr_series = res_df.groupby(['context', 'method'])['MRR'].first()
    pivot_metrics[('MRR', '')] = mrr_series

    # Column order is derived from Ks so the two cannot drift apart.
    k_metrics = ['P@K', 'R@K', 'F1@K']
    ordered_cols = [(metric, K) for metric in k_metrics for K in Ks] + [('MRR', '')]
    pivot_metrics = pivot_metrics.reindex(columns=pd.MultiIndex.from_tuples(ordered_cols))

    pd.set_option('display.expand_frame_repr', False)

    # ─── 4. Collapse BM25 into a single "method-only" line on top ────────────────

    df_flat = pivot_metrics.reset_index()

    parts = []
    bm25_rows = df_flat[df_flat['method'] == 'bm25']
    if not bm25_rows.empty:
        # NOTE(review): keeps only the first bm25 row, which presumes bm25
        # results are the same in every context — confirm.
        bm25_one = bm25_rows.iloc[[0]].copy()
        bm25_one.loc[:, 'context'] = ''   # blank out its context
        parts.append(bm25_one)

    parts.append(df_flat[df_flat['method'].isin(['prob', 'rag'])].copy())

    df_collapsed = pd.concat(parts, ignore_index=True)
    df_collapsed = df_collapsed.set_index(['context', 'method'])

    print("\n=== Final (collapsed) table ===")
    print(df_collapsed)


# Entry point: call main() only when this code is executed directly
# (not when it is imported as a module).
if __name__ == "__main__":
    main()



=== Flat results ===
   context method   K     P@K     R@K    F1@K     MRR
0     atom   prob   1  0.2190  0.0863  0.1137  0.3817
1     atom   prob   3  0.1841  0.2380  0.1840  0.3817
2     atom   prob   5  0.1657  0.3714  0.2043  0.3817
3     atom   prob  10  0.1600  0.6447  0.2339  0.3817
4     atom   bm25   1  0.2190  0.0912  0.1168  0.3824
5     atom   bm25   3  0.1937  0.2294  0.1867  0.3824
6     atom   bm25   5  0.1810  0.3640  0.2158  0.3824
7     atom   bm25  10  0.1486  0.6271  0.2191  0.3824
8     atom    rag   1  0.2571  0.1135  0.1438  0.4224
9     atom    rag   3  0.2159  0.2818  0.2158  0.4224
10    atom    rag   5  0.2057  0.4369  0.2490  0.4224
11    atom    rag  10  0.1676  0.7171  0.2488  0.4224
12    wiki   prob   1  0.5048  0.2870  0.3368  0.6555
13    wiki   prob   3  0.3873  0.5829  0.4212  0.6555
14    wiki   prob   5  0.3048  0.7132  0.3869  0.6555
15    wiki   prob  10  0.2076  0.9063  0.3086  0.6555
16    wiki   bm25   1  0.2190  0.0912  0.1168  0.3824
17    

In [1]:
#!/usr/bin/env python3
import os
import pandas as pd
import matplotlib.pyplot as plt
from typing import List, Set

# ─── 1. Evaluation functions ────────────────────────────────────────────────────

def precision_at_k(preds: List[str], gold: Set[str], k: int) -> float:
    """Return P@k: number of gold hits in the first k predictions, divided by k.

    The denominator is always ``k``, even when fewer than ``k`` predictions
    are available. A non-positive ``k`` yields 0.0.
    """
    if k > 0:
        hits = 0
        for candidate in preds[:k]:
            if candidate in gold:
                hits += 1
        return hits / k
    return 0.0

def recall_at_k(preds: List[str], gold: Set[str], k: int) -> float:
    """Recall at K: fraction of distinct gold items recovered in the top-k.

    Args:
        preds: Ranked predictions, best first.
        gold: Relevant items for the query.
        k: Cut-off rank.

    Returns:
        A value in [0, 1]. Returns 0.0 when ``gold`` is empty or ``k`` is
        not positive.
    """
    # Guard k <= 0 explicitly (mirrors precision_at_k): a negative k would
    # otherwise make preds[:k] slice from the end and report spurious hits.
    if not gold or k <= 0:
        return 0.0
    # Count each gold item at most once so a prediction list containing
    # duplicates can never push recall above 1.0.
    recovered = set(preds[:k]).intersection(gold)
    return len(recovered) / len(gold)

def f1_at_k(preds: List[str], gold: Set[str], k: int) -> float:
    """Return F1@k, the harmonic mean of precision_at_k and recall_at_k.

    Yields 0.0 when precision and recall are both zero.
    """
    precision = precision_at_k(preds, gold, k)
    recall = recall_at_k(preds, gold, k)
    total = precision + recall
    if total > 0:
        return 2 * precision * recall / total
    return 0.0

def reciprocal_rank(preds: List[str], gold: Set[str]) -> float:
    """Return 1/rank of the first prediction found in ``gold`` (ranks start at 1).

    Returns 0.0 when no prediction is in ``gold``; averaging this value over
    queries gives MRR.
    """
    first_hit = next(
        (rank for rank, item in enumerate(preds, start=1) if item in gold),
        None,
    )
    return 0.0 if first_hit is None else 1.0 / first_hit

# ─── 2. Main evaluation pipeline ────────────────────────────────────────────────

def main(path: str = "/home/mhoveyda1/REASON/runs/OLMo-2-0325-32B-20250601_15-41-54/full_details.csv"):
    """Evaluate one run, print its metric tables and save the final table as a PNG.

    Computes mean P@K, R@K, F1@K (K in 1, 3, 5, 10) and MRR per
    (context, method), prints the flat and the collapsed tables, and writes
    the collapsed table to ``<path without extension>_table.png``.

    Args:
        path: CSV to evaluate. It must provide the columns read below:
            ``context``, ``method``, ``query_id``, ``rank``, ``entity`` and
            ``is_gold``. Defaults to the OLMo-2 32B run.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
    """
    # 2.1 Make sure the CSV exists
    if not os.path.exists(path):
        raise FileNotFoundError(f"CSV not found at {path}")

    # 2.2 Load data
    df = pd.read_csv(path)

    # 2.3 Define your parameters
    Ks       = [1, 3, 5, 10]
    contexts = df['context'].unique()
    methods  = df['method'].unique()

    # 2.4 Collect results
    results = []
    for context in contexts:
        df_ctx = df[df['context'] == context]
        for method in methods:
            df_m = df_ctx[df_ctx['method'] == method]

            # Build per-query predictions + gold sets
            per_query = {}
            for qid, group in df_m.groupby('query_id'):
                preds = group.sort_values('rank')['entity'].tolist()
                gold  = set(group.loc[group['is_gold'] == 1, 'entity'])
                per_query[qid] = (preds, gold)

            # `methods` is collected over the whole file, so a method may have
            # no rows in this context; skip it instead of dividing by zero.
            if not per_query:
                continue
            n_queries = len(per_query)

            # MRR does not depend on K, so compute it once per (context, method).
            mrr = sum(reciprocal_rank(p, g) for p, g in per_query.values()) / n_queries

            # Compute metrics for each K
            for K in Ks:
                _ps, _rs, _fs = [], [], []
                for preds, gold in per_query.values():
                    _ps.append(precision_at_k(preds, gold, K))
                    _rs.append(recall_at_k(preds, gold, K))
                    _fs.append(f1_at_k(preds, gold, K))

                results.append({
                    'context': context,
                    'method':  method,
                    'K':       K,
                    'P@K':     sum(_ps) / n_queries,
                    'R@K':     sum(_rs) / n_queries,
                    'F1@K':    sum(_fs) / n_queries,
                    'MRR':     mrr
                })

    # 2.5 Turn into DataFrame
    res_df = pd.DataFrame(results)

    # 2.6 (Optional) Print "flat" summary
    pd.set_option('display.precision', 4)
    print("\n=== Flat results ===")
    print(res_df)

    # ─── 3. Convert to a nested table ─────────────────────────────────────────────

    pivot_metrics = res_df.pivot_table(
        index=['context', 'method'],
        columns='K',
        values=['P@K', 'R@K', 'F1@K']
    )

    # MRR is repeated on every K row, so the first value per group is enough.
    mrr_series = res_df.groupby(['context', 'method'])['MRR'].first()
    pivot_metrics[('MRR', '')] = mrr_series

    # Column order is derived from Ks so the two cannot drift apart.
    k_metrics = ['P@K', 'R@K', 'F1@K']
    ordered_cols = [(metric, K) for metric in k_metrics for K in Ks] + [('MRR', '')]
    pivot_metrics = pivot_metrics.reindex(columns=pd.MultiIndex.from_tuples(ordered_cols))

    pd.set_option('display.expand_frame_repr', False)

    # ─── 4. Collapse BM25 into a single "method-only" line on top ────────────────

    df_flat = pivot_metrics.reset_index()

    parts = []
    bm25_rows = df_flat[df_flat['method'] == 'bm25']
    if not bm25_rows.empty:
        # NOTE(review): keeps only the first bm25 row, which presumes bm25
        # results are the same in every context — confirm.
        bm25_one = bm25_rows.iloc[[0]].copy()
        bm25_one.loc[:, 'context'] = ''   # blank out its context
        parts.append(bm25_one)

    parts.append(df_flat[df_flat['method'].isin(['prob', 'rag'])].copy())

    df_collapsed = pd.concat(parts, ignore_index=True)
    df_collapsed = df_collapsed.set_index(['context', 'method'])

    # ─── 5. Print the final, "real" nested table ─────────────────────────────────

    # Round numeric values to 3 decimal places for console display
    df_to_print = df_collapsed.round(3)

    print("\n=== Final (collapsed) table with 3-decimal precision ===")
    print(df_to_print)

    # ─── 6. Create and save a figure from the final table ─────────────────────────

    # 6.1 Flatten the (metric, K) columns to "metric_K" on a copy, so the
    #     frame printed above is left untouched.
    flat_cols = [
        'MRR' if metric == 'MRR' else f"{metric}_{k_val}"
        for metric, k_val in df_collapsed.columns
    ]
    df_fig = df_collapsed.copy()
    df_fig.columns = flat_cols

    # 6.2 Reset index so "context" and "method" become ordinary columns
    df_print = df_fig.reset_index()

    # 6.3 Round numeric values to 3 decimal places for the figure
    numeric_cols = [c for c in df_print.columns if c not in ['context', 'method']]
    df_print[numeric_cols] = df_print[numeric_cols].round(3)

    # 6.4 Output filename: next to the CSV, with suffix "_table.png"
    base, _ = os.path.splitext(path)
    fig_path = base + "_table.png"

    # 6.5 Size the figure from the table's shape
    num_rows, num_cols = df_print.shape
    fig_height = max(2, 0.5 * num_rows)  # height proportional to number of rows
    fig_width  = max(6, 1.0 * num_cols)   # width proportional to number of columns

    fig, ax = plt.subplots(figsize=(fig_width, fig_height))
    try:
        ax.axis('off')

        table = ax.table(
            cellText=df_print.values,
            colLabels=df_print.columns,
            cellLoc='center',
            loc='center'
        )
        table.auto_set_font_size(False)
        table.set_fontsize(8)
        table.scale(1, 1.2)

        plt.tight_layout()
        fig.savefig(fig_path, dpi=300, bbox_inches='tight')
    finally:
        # Release the figure even if building or saving the table fails.
        plt.close(fig)

    print(f"\nSaved table figure to: {fig_path}")


# Entry point: call main() only when this code is executed directly
# (not when it is imported as a module).
if __name__ == "__main__":
    main()


=== Flat results ===
   context method   K     P@K     R@K    F1@K     MRR
0     atom   prob   1  0.2190  0.0863  0.1137  0.3817
1     atom   prob   3  0.1841  0.2380  0.1840  0.3817
2     atom   prob   5  0.1657  0.3714  0.2043  0.3817
3     atom   prob  10  0.1600  0.6447  0.2339  0.3817
4     atom   bm25   1  0.2190  0.0912  0.1168  0.3824
5     atom   bm25   3  0.1937  0.2294  0.1867  0.3824
6     atom   bm25   5  0.1810  0.3640  0.2158  0.3824
7     atom   bm25  10  0.1486  0.6271  0.2191  0.3824
8     atom    rag   1  0.2571  0.1135  0.1438  0.4224
9     atom    rag   3  0.2159  0.2818  0.2158  0.4224
10    atom    rag   5  0.2057  0.4369  0.2490  0.4224
11    atom    rag  10  0.1676  0.7171  0.2488  0.4224
12    wiki   prob   1  0.5048  0.2870  0.3368  0.6555
13    wiki   prob   3  0.3873  0.5829  0.4212  0.6555
14    wiki   prob   5  0.3048  0.7132  0.3869  0.6555
15    wiki   prob  10  0.2076  0.9063  0.3086  0.6555
16    wiki   bm25   1  0.2190  0.0912  0.1168  0.3824
17    

In [1]:
# Run label -> path of that run's full_details.csv.
# (The earlier one-element list assigned to `paths` was dead code: it was
# overwritten by this dict immediately.)
paths = {
    "Llama-3-8B-Instruct": "/home/mhoveyda1/REASON/runs/Meta-Llama-3-8B-Instruct-20250601_15-41-54/full_details.csv",
    "Llama-3.3-70B-Instruct": "/home/mhoveyda1/REASON/runs/Llama-3.3-70B-Instruct-20250601_15-41-54/full_details.csv",
    "Mistral-v1-7B-Instruct": "/home/mhoveyda1/REASON/runs/Mistral-7B-Instruct-v0.1-20250601_15-41-54/full_details.csv",
    "Mistral-v1-8x7B-Instruct": "/home/mhoveyda1/REASON/runs/Mixtral-8x7B-Instruct-v0.1-20250601_15-41-54/full_details.csv",
    # This entry used to point at the run directory only; it now names the
    # CSV like every other entry.
    "Olmo-2-32B": "/home/mhoveyda1/REASON/runs/OLMo-2-0325-32B-20250601_15-41-54/full_details.csv",
    "Olmo-2-7B-Instruct": "/home/mhoveyda1/REASON/runs/OLMo-2-1124-7B-Instruct-20250601_15-41-54/full_details.csv",
}

SyntaxError: expression expected after dictionary key and ':' (370925904.py, line 10)

In [3]:
#!/usr/bin/env python3
import os
import pandas as pd
import matplotlib.pyplot as plt
from typing import List, Set

# ─── 1. Evaluation functions ────────────────────────────────────────────────────

def precision_at_k(preds: List[str], gold: Set[str], k: int) -> float:
    """Return P@k: number of gold hits in the first k predictions, divided by k.

    The denominator is always ``k``, even when fewer than ``k`` predictions
    are available. A non-positive ``k`` yields 0.0.
    """
    if k > 0:
        hits = 0
        for candidate in preds[:k]:
            if candidate in gold:
                hits += 1
        return hits / k
    return 0.0

def recall_at_k(preds: List[str], gold: Set[str], k: int) -> float:
    """Recall at K: fraction of distinct gold items recovered in the top-k.

    Args:
        preds: Ranked predictions, best first.
        gold: Relevant items for the query.
        k: Cut-off rank.

    Returns:
        A value in [0, 1]. Returns 0.0 when ``gold`` is empty or ``k`` is
        not positive.
    """
    # Guard k <= 0 explicitly (mirrors precision_at_k): a negative k would
    # otherwise make preds[:k] slice from the end and report spurious hits.
    if not gold or k <= 0:
        return 0.0
    # Count each gold item at most once so a prediction list containing
    # duplicates can never push recall above 1.0.
    recovered = set(preds[:k]).intersection(gold)
    return len(recovered) / len(gold)

def f1_at_k(preds: List[str], gold: Set[str], k: int) -> float:
    """Return F1@k, the harmonic mean of precision_at_k and recall_at_k.

    Yields 0.0 when precision and recall are both zero.
    """
    precision = precision_at_k(preds, gold, k)
    recall = recall_at_k(preds, gold, k)
    total = precision + recall
    if total > 0:
        return 2 * precision * recall / total
    return 0.0

def reciprocal_rank(preds: List[str], gold: Set[str]) -> float:
    """Return 1/rank of the first prediction found in ``gold`` (ranks start at 1).

    Returns 0.0 when no prediction is in ``gold``; averaging this value over
    queries gives MRR.
    """
    first_hit = next(
        (rank for rank, item in enumerate(preds, start=1) if item in gold),
        None,
    )
    return 0.0 if first_hit is None else 1.0 / first_hit

# ─── 2. Main evaluation pipeline ────────────────────────────────────────────────

def main(path: str = "/home/mhoveyda1/REASON/runs/OLMo-2-0325-32B-20250601_15-41-54/full_details.csv"):
    """Evaluate one run, print its metric tables and save a green-shaded table PNG.

    Computes mean P@K, R@K, F1@K (K in 1, 3, 5, 10) and MRR per
    (context, method), prints the flat and the collapsed tables, and writes
    the collapsed table, shaded per column, to
    ``<path without extension>_table.png``.

    Args:
        path: CSV to evaluate. It must provide the columns read below:
            ``context``, ``method``, ``query_id``, ``rank``, ``entity`` and
            ``is_gold``. Defaults to the OLMo-2 32B run.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
    """
    # 2.1 Make sure the CSV exists
    if not os.path.exists(path):
        raise FileNotFoundError(f"CSV not found at {path}")

    # 2.2 Load data
    df = pd.read_csv(path)

    # 2.3 Define your parameters
    Ks       = [1, 3, 5, 10]
    contexts = df['context'].unique()
    methods  = df['method'].unique()

    # 2.4 Collect results
    results = []
    for context in contexts:
        df_ctx = df[df['context'] == context]
        for method in methods:
            df_m = df_ctx[df_ctx['method'] == method]

            # Build per-query predictions + gold sets
            per_query = {}
            for qid, group in df_m.groupby('query_id'):
                preds = group.sort_values('rank')['entity'].tolist()
                gold  = set(group.loc[group['is_gold'] == 1, 'entity'])
                per_query[qid] = (preds, gold)

            # `methods` is collected over the whole file, so a method may have
            # no rows in this context; skip it instead of dividing by zero.
            if not per_query:
                continue
            n_queries = len(per_query)

            # MRR does not depend on K, so compute it once per (context, method).
            mrr = sum(reciprocal_rank(p, g) for p, g in per_query.values()) / n_queries

            # Compute metrics for each K
            for K in Ks:
                _ps, _rs, _fs = [], [], []
                for preds, gold in per_query.values():
                    _ps.append(precision_at_k(preds, gold, K))
                    _rs.append(recall_at_k(preds, gold, K))
                    _fs.append(f1_at_k(preds, gold, K))

                results.append({
                    'context': context,
                    'method':  method,
                    'K':       K,
                    'P@K':     sum(_ps) / n_queries,
                    'R@K':     sum(_rs) / n_queries,
                    'F1@K':    sum(_fs) / n_queries,
                    'MRR':     mrr
                })

    # 2.5 Turn into DataFrame
    res_df = pd.DataFrame(results)

    # 2.6 (Optional) Print "flat" summary
    pd.set_option('display.precision', 4)
    print("\n=== Flat results ===")
    print(res_df)

    # ─── 3. Convert to a nested table ─────────────────────────────────────────────

    pivot_metrics = res_df.pivot_table(
        index=['context', 'method'],
        columns='K',
        values=['P@K', 'R@K', 'F1@K']
    )

    # MRR is repeated on every K row, so the first value per group is enough.
    mrr_series = res_df.groupby(['context', 'method'])['MRR'].first()
    pivot_metrics[('MRR', '')] = mrr_series

    # Column order is derived from Ks so the two cannot drift apart.
    k_metrics = ['P@K', 'R@K', 'F1@K']
    ordered_cols = [(metric, K) for metric in k_metrics for K in Ks] + [('MRR', '')]
    pivot_metrics = pivot_metrics.reindex(columns=pd.MultiIndex.from_tuples(ordered_cols))

    pd.set_option('display.expand_frame_repr', False)

    # ─── 4. Collapse BM25 into a single "method-only" line on top ────────────────

    df_flat = pivot_metrics.reset_index()

    parts = []
    bm25_rows = df_flat[df_flat['method'] == 'bm25']
    if not bm25_rows.empty:
        # NOTE(review): keeps only the first bm25 row, which presumes bm25
        # results are the same in every context — confirm.
        bm25_one = bm25_rows.iloc[[0]].copy()
        bm25_one.loc[:, 'context'] = ''   # blank out its context
        parts.append(bm25_one)

    parts.append(df_flat[df_flat['method'].isin(['prob', 'rag'])].copy())

    df_collapsed = pd.concat(parts, ignore_index=True)
    df_collapsed = df_collapsed.set_index(['context', 'method'])

    # ─── 5. Print the final, "real" nested table ─────────────────────────────────

    # Round numeric values to 3 decimal places for console display
    df_to_print = df_collapsed.round(3)

    print("\n=== Final (collapsed) table with 3-decimal precision ===")
    print(df_to_print)

    # ─── 6. Create and save a figure from the final table with green shading ─────

    # 6.1 Flatten the (metric, K) columns to "metric_K" on a copy, so the
    #     frame printed above is left untouched.
    flat_cols = [
        'MRR' if metric == 'MRR' else f"{metric}_{k_val}"
        for metric, k_val in df_collapsed.columns
    ]
    df_fig = df_collapsed.copy()
    df_fig.columns = flat_cols

    # 6.2 Reset index so "context" and "method" become ordinary columns
    df_print = df_fig.reset_index()

    # 6.3 Round numeric values to 3 decimal places for the figure
    numeric_cols = [c for c in df_print.columns if c not in ['context', 'method']]
    df_print[numeric_cols] = df_print[numeric_cols].round(3)

    # 6.4 Determine output filename (same directory, suffix "_table.png")
    base, _ = os.path.splitext(path)
    fig_path = base + "_table.png"

    # 6.5 Create a matplotlib figure containing the table
    num_rows, num_cols = df_print.shape
    fig_height = max(2, 0.5 * num_rows)     # height proportional to number of rows
    fig_width  = max(6, 1.0 * num_cols)     # width proportional to number of columns

    fig, ax = plt.subplots(figsize=(fig_width, fig_height))
    try:
        ax.axis('off')

        table = ax.table(
            cellText=df_print.values,
            colLabels=df_print.columns,
            cellLoc='center',
            loc='center'
        )
        table.auto_set_font_size(False)
        table.set_fontsize(8)
        table.scale(1, 1.2)

        # 6.6 Apply green shading per numeric column, adjust text color for readability
        for j, col_name in enumerate(df_print.columns):
            if col_name in ['context', 'method']:
                continue
            col_values = df_print[col_name].astype(float)
            col_min, col_max = col_values.min(), col_values.max()
            # A constant column would give a zero range; fall back to 1.0 so
            # every cell normalises to 0 instead of dividing by zero.
            diff = col_max - col_min if col_max > col_min else 1.0

            for i, val in enumerate(col_values):
                norm = (val - col_min) / diff
                # Map the column range onto 0.3..1.0 of the Greens colormap:
                # the column minimum is still visibly tinted, and the column
                # maximum gets the darkest green (hence the text-colour
                # switch below).
                color = plt.cm.Greens(0.3 + 0.7 * norm)
                cell = table[(i + 1, j)]  # +1 to skip header row
                cell.set_facecolor(color)

                # Choose black or white text from the background's perceived
                # luminance (weighted sum of the R, G, B channels).
                r, g, b, _ = color
                luminance = 0.299*r + 0.587*g + 0.114*b
                text_color = 'white' if luminance < 0.5 else 'black'
                cell.get_text().set_color(text_color)

        plt.tight_layout()
        fig.savefig(fig_path, dpi=300, bbox_inches='tight')
    finally:
        # Release the figure even if building or saving the table fails.
        plt.close(fig)

    print(f"\nSaved shaded table figure to: {fig_path}")


# Entry point: call main() only when this code is executed directly
# (not when it is imported as a module).
if __name__ == "__main__":
    main()


=== Flat results ===
   context method   K     P@K     R@K    F1@K     MRR
0     atom   prob   1  0.2190  0.0863  0.1137  0.3817
1     atom   prob   3  0.1841  0.2380  0.1840  0.3817
2     atom   prob   5  0.1657  0.3714  0.2043  0.3817
3     atom   prob  10  0.1600  0.6447  0.2339  0.3817
4     atom   bm25   1  0.2190  0.0912  0.1168  0.3824
5     atom   bm25   3  0.1937  0.2294  0.1867  0.3824
6     atom   bm25   5  0.1810  0.3640  0.2158  0.3824
7     atom   bm25  10  0.1486  0.6271  0.2191  0.3824
8     atom    rag   1  0.2571  0.1135  0.1438  0.4224
9     atom    rag   3  0.2159  0.2818  0.2158  0.4224
10    atom    rag   5  0.2057  0.4369  0.2490  0.4224
11    atom    rag  10  0.1676  0.7171  0.2488  0.4224
12    wiki   prob   1  0.5048  0.2870  0.3368  0.6555
13    wiki   prob   3  0.3873  0.5829  0.4212  0.6555
14    wiki   prob   5  0.3048  0.7132  0.3869  0.6555
15    wiki   prob  10  0.2076  0.9063  0.3086  0.6555
16    wiki   bm25   1  0.2190  0.0912  0.1168  0.3824
17    

In [1]:
#!/usr/bin/env python3
import os
import pandas as pd
import matplotlib.pyplot as plt
from typing import List, Set, Dict

# ─── 1. Evaluation functions ────────────────────────────────────────────────────

def precision_at_k(preds: List[str], gold: Set[str], k: int) -> float:
    """Return P@k: number of gold hits in the first k predictions, divided by k.

    The denominator is always ``k``, even when fewer than ``k`` predictions
    are available. A non-positive ``k`` yields 0.0.
    """
    if k > 0:
        hits = 0
        for candidate in preds[:k]:
            if candidate in gold:
                hits += 1
        return hits / k
    return 0.0

def recall_at_k(preds: List[str], gold: Set[str], k: int) -> float:
    """Recall at K: fraction of distinct gold items recovered in the top-k.

    Args:
        preds: Ranked predictions, best first.
        gold: Relevant items for the query.
        k: Cut-off rank.

    Returns:
        A value in [0, 1]. Returns 0.0 when ``gold`` is empty or ``k`` is
        not positive.
    """
    # Guard k <= 0 explicitly (mirrors precision_at_k): a negative k would
    # otherwise make preds[:k] slice from the end and report spurious hits.
    if not gold or k <= 0:
        return 0.0
    # Count each gold item at most once so a prediction list containing
    # duplicates can never push recall above 1.0.
    recovered = set(preds[:k]).intersection(gold)
    return len(recovered) / len(gold)

def f1_at_k(preds: List[str], gold: Set[str], k: int) -> float:
    """Return F1@k, the harmonic mean of precision_at_k and recall_at_k.

    Yields 0.0 when precision and recall are both zero.
    """
    precision = precision_at_k(preds, gold, k)
    recall = recall_at_k(preds, gold, k)
    total = precision + recall
    if total > 0:
        return 2 * precision * recall / total
    return 0.0

def reciprocal_rank(preds: List[str], gold: Set[str]) -> float:
    """Return 1/rank of the first prediction found in ``gold`` (ranks start at 1).

    Returns 0.0 when no prediction is in ``gold``; averaging this value over
    queries gives MRR.
    """
    first_hit = next(
        (rank for rank, item in enumerate(preds, start=1) if item in gold),
        None,
    )
    return 0.0 if first_hit is None else 1.0 / first_hit

# ─── 2. Processing function ────────────────────────────────────────────────────

def process_run(name: str, csv_path: str) -> None:
    """Evaluate one run, print its collapsed metric table and save a shaded PNG.

    Computes mean P@K, R@K, F1@K (K in 1, 3, 5, 10) and MRR per
    (context, method), collapses bm25 to a single row, prints the table
    rounded to 3 decimals, and writes ``<name>_table.png`` next to the CSV.

    Args:
        name: Run label, used in the printed headings and the PNG file name.
        csv_path: Path to the run's ``full_details.csv``, or to a directory
            containing it. The CSV must provide the columns read below:
            ``context``, ``method``, ``query_id``, ``rank``, ``entity`` and
            ``is_gold``.

    Raises:
        FileNotFoundError: If the CSV cannot be found.
    """
    # 2.0 Ensure we have the full_details.csv
    if os.path.isdir(csv_path):
        csv_path = os.path.join(csv_path, "full_details.csv")
    if not os.path.exists(csv_path):
        raise FileNotFoundError(f"For '{name}', CSV not found at {csv_path}")

    # 2.1 Load data
    df = pd.read_csv(csv_path)

    # 2.2 Define parameters
    Ks       = [1, 3, 5, 10]
    contexts = df['context'].unique()
    methods  = df['method'].unique()

    # 2.3 Collect results
    results = []
    for context in contexts:
        df_ctx = df[df['context'] == context]
        for method in methods:
            df_m = df_ctx[df_ctx['method'] == method]

            # Build per-query predictions + gold sets
            per_query = {}
            for qid, group in df_m.groupby('query_id'):
                preds = group.sort_values('rank')['entity'].tolist()
                gold  = set(group.loc[group['is_gold'] == 1, 'entity'])
                per_query[qid] = (preds, gold)

            # `methods` is collected over the whole file, so a method may have
            # no rows in this context; skip it instead of dividing by zero.
            if not per_query:
                continue
            n_queries = len(per_query)

            # MRR does not depend on K, so compute it once per (context, method).
            mrr = sum(reciprocal_rank(p, g) for p, g in per_query.values()) / n_queries

            # Compute metrics for each K
            for K in Ks:
                _ps, _rs, _fs = [], [], []
                for preds, gold in per_query.values():
                    _ps.append(precision_at_k(preds, gold, K))
                    _rs.append(recall_at_k(preds, gold, K))
                    _fs.append(f1_at_k(preds, gold, K))

                results.append({
                    'context': context,
                    'method':  method,
                    'K':       K,
                    'P@K':     sum(_ps) / n_queries,
                    'R@K':     sum(_rs) / n_queries,
                    'F1@K':    sum(_fs) / n_queries,
                    'MRR':     mrr
                })

    # 2.4 Turn into DataFrame
    res_df = pd.DataFrame(results)

    # 2.5 Pivot into nested table
    pivot_metrics = res_df.pivot_table(
        index=['context', 'method'],
        columns='K',
        values=['P@K', 'R@K', 'F1@K']
    )
    # MRR is repeated on every K row, so the first value per group is enough.
    mrr_series = res_df.groupby(['context', 'method'])['MRR'].first()
    pivot_metrics[('MRR', '')] = mrr_series

    # Column order is derived from Ks so the two cannot drift apart.
    k_metrics = ['P@K', 'R@K', 'F1@K']
    ordered_cols = [(metric, K) for metric in k_metrics for K in Ks] + [('MRR', '')]
    pivot_metrics = pivot_metrics.reindex(columns=pd.MultiIndex.from_tuples(ordered_cols))

    # 2.6 Collapse BM25 to one line (blank context)
    df_flat = pivot_metrics.reset_index()
    parts = []
    bm25_rows = df_flat[df_flat['method'] == 'bm25']
    if not bm25_rows.empty:
        # NOTE(review): keeps only the first bm25 row, which presumes bm25
        # results are the same in every context — confirm.
        bm25_one = bm25_rows.iloc[[0]].copy()
        bm25_one.loc[:, 'context'] = ''
        parts.append(bm25_one)
    parts.append(df_flat[df_flat['method'].isin(['prob', 'rag'])].copy())
    df_collapsed = pd.concat(parts, ignore_index=True)
    df_collapsed = df_collapsed.set_index(['context', 'method'])

    # 2.7 Print table to console (3-decimal)
    df_to_print = df_collapsed.round(3)
    print(f"\n=== {name}: Final (collapsed) table with 3-decimal precision ===")
    print(df_to_print)

    # 2.8 Prepare for figure: flatten the (metric, K) columns to "metric_K"
    #     on a copy, reset index, round
    flat_cols = [
        'MRR' if metric == 'MRR' else f"{metric}_{k_val}"
        for metric, k_val in df_collapsed.columns
    ]
    df_fig = df_collapsed.copy()
    df_fig.columns = flat_cols
    df_print = df_fig.reset_index()
    numeric_cols = [c for c in df_print.columns if c not in ['context', 'method']]
    df_print[numeric_cols] = df_print[numeric_cols].round(3)

    # 2.9 Create and save shaded table figure
    fig_path = os.path.join(os.path.dirname(csv_path), f"{name}_table.png")
    num_rows, num_cols = df_print.shape
    fig_height = max(2, 0.5 * num_rows)
    fig_width  = max(6, 1.0 * num_cols)

    fig, ax = plt.subplots(figsize=(fig_width, fig_height))
    try:
        ax.axis('off')

        table = ax.table(
            cellText=df_print.values,
            colLabels=df_print.columns,
            cellLoc='center',
            loc='center'
        )
        table.auto_set_font_size(False)
        table.set_fontsize(8)
        table.scale(1, 1.2)

        # Apply green shading per numeric column with readable text
        for j, col_name in enumerate(df_print.columns):
            if col_name in ['context', 'method']:
                continue
            col_values = df_print[col_name].astype(float)
            col_min, col_max = col_values.min(), col_values.max()
            # A constant column would give a zero range; fall back to 1.0 so
            # every cell normalises to 0 instead of dividing by zero.
            diff = col_max - col_min if col_max > col_min else 1.0

            for i, val in enumerate(col_values):
                norm = (val - col_min) / diff
                # Column range is mapped onto 0.3..1.0 of the Greens colormap.
                color = plt.cm.Greens(0.3 + 0.7 * norm)
                cell = table[(i + 1, j)]  # +1 to skip header row
                cell.set_facecolor(color)

                # White text on dark cells, black on light ones.
                r, g, b, _ = color
                luminance = 0.299*r + 0.587*g + 0.114*b
                text_color = 'white' if luminance < 0.5 else 'black'
                cell.get_text().set_color(text_color)

        plt.tight_layout()
        fig.savefig(fig_path, dpi=300, bbox_inches='tight')
    finally:
        # Release the figure even if building or saving the table fails.
        plt.close(fig)

    print(f"Saved shaded table figure for '{name}' to: {fig_path}")


# ─── 3. Main: iterate over all runs ─────────────────────────────────────────────

def main():
    """Run ``process_run`` for every configured model run.

    A run whose CSV cannot be found is reported on stdout and skipped, so the
    remaining runs are still processed.
    """
    paths: Dict[str, str] = {
        "Llama-3-8B-Instruct": "/home/mhoveyda1/REASON/runs/Meta-Llama-3-8B-Instruct-20250601_15-41-54/full_details.csv",
        "Llama-3.3-70B-Instruct": "/home/mhoveyda1/REASON/runs/Llama-3.3-70B-Instruct-20250601_15-41-54/full_details.csv",
        "Mistral-v1-7B-Instruct": "/home/mhoveyda1/REASON/runs/Mistral-7B-Instruct-v0.1-20250601_15-41-54/full_details.csv",
        "Mistral-v1-8x7B-Instruct": "/home/mhoveyda1/REASON/runs/Mixtral-8x7B-Instruct-v0.1-20250601_15-41-54/full_details.csv",
        # Directory rather than file: process_run appends full_details.csv itself.
        "Olmo-2-32B": "/home/mhoveyda1/REASON/runs/OLMo-2-0325-32B-20250601_15-41-54",
        "Olmo-2-7B-Instruct": "/home/mhoveyda1/REASON/runs/OLMo-2-1124-7B-Instruct-20250601_15-41-54/full_details.csv",
    }

    for name in paths:
        run_location = paths[name]
        try:
            process_run(name, run_location)
        except FileNotFoundError as e:
            print(f"Error for '{name}': {e}")


# Entry point: call main() only when this code is executed directly
# (not when it is imported as a module).
if __name__ == "__main__":
    main()


=== Llama-3-8B-Instruct: Final (collapsed) table with 3-decimal precision ===
                  P@K                         R@K                        F1@K  \
                    1      3      5     10      1      3      5     10      1   
context method                                                                  
        bm25    0.219  0.194  0.181  0.149  0.091  0.229  0.364  0.627  0.117   
atom    prob    0.276  0.244  0.219  0.167  0.133  0.362  0.513  0.718  0.163   
        rag     0.257  0.251  0.206  0.165  0.113  0.312  0.433  0.708  0.144   
wiki    prob    0.514  0.422  0.331  0.209  0.291  0.627  0.749  0.898  0.340   
        rag     0.400  0.337  0.270  0.185  0.211  0.511  0.647  0.817  0.253   

                                       MRR  
                    3      5     10         
context method                              
        bm25    0.187  0.216  0.219  0.382  
atom    prob    0.259  0.273  0.245  0.456  
        rag     0.246  0.249  0.245  0.425  
wi

In [1]:
#!/usr/bin/env python3
import os
import pandas as pd
import matplotlib.pyplot as plt
from typing import List, Set, Dict

# ─── 1. Evaluation functions ────────────────────────────────────────────────────

def precision_at_k(preds: List[str], gold: Set[str], k: int) -> float:
    """Return P@k: number of gold hits in the first k predictions, divided by k.

    The denominator is always ``k``, even when fewer than ``k`` predictions
    are available. A non-positive ``k`` yields 0.0.
    """
    if k > 0:
        hits = 0
        for candidate in preds[:k]:
            if candidate in gold:
                hits += 1
        return hits / k
    return 0.0

def recall_at_k(preds: List[str], gold: Set[str], k: int) -> float:
    """Recall at K: fraction of distinct gold items recovered in the top-k.

    Args:
        preds: Ranked predictions, best first.
        gold: Relevant items for the query.
        k: Cut-off rank.

    Returns:
        A value in [0, 1]. Returns 0.0 when ``gold`` is empty or ``k`` is
        not positive.
    """
    # Guard k <= 0 explicitly (mirrors precision_at_k): a negative k would
    # otherwise make preds[:k] slice from the end and report spurious hits.
    if not gold or k <= 0:
        return 0.0
    # Count each gold item at most once so a prediction list containing
    # duplicates can never push recall above 1.0.
    recovered = set(preds[:k]).intersection(gold)
    return len(recovered) / len(gold)

def f1_at_k(preds: List[str], gold: Set[str], k: int) -> float:
    """Return F1@k, the harmonic mean of precision_at_k and recall_at_k.

    Yields 0.0 when precision and recall are both zero.
    """
    precision = precision_at_k(preds, gold, k)
    recall = recall_at_k(preds, gold, k)
    total = precision + recall
    if total > 0:
        return 2 * precision * recall / total
    return 0.0

def reciprocal_rank(preds: List[str], gold: Set[str]) -> float:
    """Return 1/rank of the first prediction found in ``gold`` (ranks start at 1).

    Returns 0.0 when no prediction is in ``gold``; averaging this value over
    queries gives MRR.
    """
    first_hit = next(
        (rank for rank, item in enumerate(preds, start=1) if item in gold),
        None,
    )
    return 0.0 if first_hit is None else 1.0 / first_hit

# ─── 2. Processing function ────────────────────────────────────────────────────

def process_run(name: str, csv_path: str) -> None:
    """
    Evaluate one run and save its metrics table as a shaded PNG.

    Loads the run's full_details.csv, computes mean P@K, R@K, F1@K (K in 1, 3, 5, 10)
    and MRR for every (context, method) pair that has rows, collapses the BM25 rows
    into a single line with a blank context, prints the table rounded to 3 decimals,
    and writes "<name>_table.png" next to the CSV with `name` as the figure title.

    Args:
        name: Display name of the run; used in the printed header, the figure
            title and the PNG file name.
        csv_path: Path to full_details.csv, or to a directory containing it.
            The CSV is read for the columns context, method, query_id, entity,
            rank and is_gold.

    Raises:
        FileNotFoundError: If the CSV cannot be found.
    """
    # 2.0 Ensure we have the full_details.csv
    if os.path.isdir(csv_path):
        csv_path = os.path.join(csv_path, "full_details.csv")
    if not os.path.exists(csv_path):
        raise FileNotFoundError(f"For '{name}', CSV not found at {csv_path}")

    # 2.1 Load data
    df = pd.read_csv(csv_path)

    # 2.2 Define parameters
    Ks           = [1, 3, 5, 10]
    metric_names = ['P@K', 'R@K', 'F1@K']
    contexts     = df['context'].unique()
    methods      = df['method'].unique()

    # 2.3 Collect results
    results = []
    for context in contexts:
        df_ctx = df[df['context'] == context]
        for method in methods:
            df_m = df_ctx[df_ctx['method'] == method]

            # Build per-query predictions + gold sets.
            # NOTE(review): gold is built only from the rows of this group, so a
            # gold entity that has no row for the query cannot lower recall —
            # confirm the CSV lists every gold entity per query.
            per_query = {}
            for qid, group in df_m.groupby('query_id'):
                preds = group.sort_values('rank')['entity'].tolist()
                gold  = set(group.loc[group['is_gold'] == 1, 'entity'])
                per_query[qid] = (preds, gold)

            # Not every method has to occur in every context; averaging over an
            # empty query set would divide by zero, so skip such pairs.
            if not per_query:
                continue
            n_queries = len(per_query)

            # MRR does not depend on K, so compute it once per (context, method).
            mrr = sum(
                reciprocal_rank(preds, gold) for preds, gold in per_query.values()
            ) / n_queries

            # Compute metrics for each K
            for K in Ks:
                _ps, _rs, _fs = [], [], []
                for preds, gold in per_query.values():
                    _ps.append(precision_at_k(preds, gold, K))
                    _rs.append(recall_at_k(preds, gold, K))
                    _fs.append(f1_at_k(preds, gold, K))

                results.append({
                    'context': context,
                    'method':  method,
                    'K':       K,
                    'P@K':     sum(_ps) / n_queries,
                    'R@K':     sum(_rs) / n_queries,
                    'F1@K':    sum(_fs) / n_queries,
                    'MRR':     mrr
                })

    if not results:
        # Nothing to pivot (empty CSV or no usable query ids).
        print(f"No evaluable (context, method) rows for '{name}'; nothing to report.")
        return

    # 2.4 Turn into DataFrame
    res_df = pd.DataFrame(results)

    # 2.5 Pivot into nested table: rows (context, method), columns (metric, K)
    pivot_metrics = res_df.pivot_table(
        index=['context', 'method'],
        columns='K',
        values=metric_names
    )
    # MRR is identical across the K rows of a pair, so the first one is enough.
    mrr_series = res_df.groupby(['context', 'method'])['MRR'].first()
    pivot_metrics[('MRR', '')] = mrr_series

    # Fix the column order: P@K, R@K, F1@K (each over Ks), then MRR.
    ordered_cols = [(metric, k) for metric in metric_names for k in Ks] + [('MRR', '')]
    pivot_metrics = pivot_metrics.reindex(columns=pd.MultiIndex.from_tuples(ordered_cols))

    # 2.6 Collapse BM25 to one line (blank context)
    df_flat = pivot_metrics.reset_index()
    bm25_rows = df_flat[df_flat['method'] == 'bm25']
    if not bm25_rows.empty:
        # Only the first BM25 row is kept and shown without a context.
        bm25_one = bm25_rows.iloc[[0]].copy()
        bm25_one.loc[:, 'context'] = ''
        keep_others = df_flat[df_flat['method'].isin(['prob', 'rag'])].copy()
        df_collapsed = pd.concat([bm25_one, keep_others], ignore_index=True)
    else:
        # A run without BM25 rows has nothing to collapse.
        df_collapsed = df_flat.copy()
    df_collapsed = df_collapsed.set_index(['context', 'method'])

    # 2.7 Print table to console (3-decimal)
    df_to_print = df_collapsed.copy().round(3)
    print(f"\n=== {name}: Final (collapsed) table with 3-decimal precision ===")
    print(df_to_print)

    # 2.8 Prepare for figure: flatten columns, reset index, round
    flat_cols = []
    for metric, k_val in df_collapsed.columns:
        flat_cols.append('MRR' if metric == 'MRR' else f"{metric}_{k_val}")
    df_collapsed.columns = flat_cols
    df_print = df_collapsed.reset_index()
    numeric_cols = [c for c in df_print.columns if c not in ['context', 'method']]
    df_print[numeric_cols] = df_print[numeric_cols].round(3)

    # 2.9 Create and save shaded table figure with model name as title
    fig_path = os.path.join(os.path.dirname(csv_path), f"{name}_table.png")
    num_rows, num_cols = df_print.shape
    fig_height = max(2, 0.5 * num_rows)
    fig_width  = max(6, 0.7 * num_cols)

    fig, ax = plt.subplots(figsize=(fig_width, fig_height))
    try:
        ax.axis('off')

        # Add title (model name) above the table
        ax.set_title(name, fontweight='bold', fontsize=12, pad=20)

        table = ax.table(
            cellText=df_print.values,
            colLabels=df_print.columns,
            cellLoc='center',
            loc='center'
        )
        table.auto_set_font_size(False)
        table.set_fontsize(8)
        table.scale(1, 1.2)

        # Apply green shading per numeric column with readable text
        for j, col_name in enumerate(df_print.columns):
            if col_name in ['context', 'method']:
                continue
            col_values = df_print[col_name].astype(float)
            col_min, col_max = col_values.min(), col_values.max()
            # A constant column would give a zero range; fall back to 1.0.
            diff = col_max - col_min if col_max > col_min else 1.0

            for i, val in enumerate(col_values):
                norm = (val - col_min) / diff
                # Map onto the 0.3-1.0 part of Greens so the column minimum is
                # still visibly tinted.
                color = plt.cm.Greens(0.3 + 0.7 * norm)
                # Table row 0 holds the column labels, hence the +1.
                cell = table[(i + 1, j)]
                cell.set_facecolor(color)

                # Pick white or black text from the background luminance.
                r, g, b, _ = color
                luminance = 0.299*r + 0.587*g + 0.114*b
                text_color = 'white' if luminance < 0.5 else 'black'
                cell.get_text().set_color(text_color)

        fig.tight_layout()
        fig.savefig(fig_path, dpi=300, bbox_inches='tight')
    finally:
        # Release the figure even when building or saving the table fails.
        plt.close(fig)

    print(f"Saved shaded table figure for '{name}' to: {fig_path}")


# ─── 3. Main: iterate over all runs ─────────────────────────────────────────────

def main():
    """Run process_run for every configured run; a missing CSV is printed, not raised."""
    # Earlier run locations, kept for reference:
    #     "Llama-3-8B-Instruct":     "/home/mhoveyda1/REASON/runs/Meta-Llama-3-8B-Instruct-20250601_15-41-54/full_details.csv",
    #     "Llama-3.3-70B-Instruct":   "/home/mhoveyda1/REASON/runs/Llama-3.3-70B-Instruct-20250601_15-41-54/full_details.csv",
    #     "Mistral-v1-7B-Instruct":   "/home/mhoveyda1/REASON/runs/Mistral-7B-Instruct-v0.1-20250601_15-41-54/full_details.csv",
    #     "Mistral-v1-8x7B-Instruct": "/home/mhoveyda1/REASON/runs/Mixtral-8x7B-Instruct-v0.1-20250601_15-41-54/full_details.csv",
    #     "Olmo-2-32B":               "/home/mhoveyda1/REASON/runs/OLMo-2-0325-32B-20250601_15-41-54",
    #     "Olmo-2-7B-Instruct":       "/home/mhoveyda1/REASON/runs/OLMo-2-1124-7B-Instruct-20250601_15-41-54/full_details.csv",

    # (display name, path to full_details.csv) for each run to evaluate
    run_table = [
        (
            "Llama-3-8B-Instruct",
            "/home/mhoveyda1/REASON/runs/correct_prevs_IMPORTANT/llama-3-8B-instr-20250601_10-20-16/full_details.csv",
        ),
        (
            "Llama-3.3-70B-Instruct",
            "/home/mhoveyda1/REASON/runs/correct_prevs_IMPORTANT/llama-3-70B-instr-20250601_10-20-16/full_details.csv",
        ),
    ]

    for run_label, details_csv in run_table:
        try:
            process_run(run_label, details_csv)
        except FileNotFoundError as err:
            print(f"Error for '{run_label}': {err}")


# Entry point: evaluate all configured runs when this cell/script is executed directly.
if __name__ == "__main__":
    main()


=== Llama-3-8B-Instruct: Final (collapsed) table with 3-decimal precision ===
                  P@K                         R@K                        F1@K  \
                    1      3      5     10      1      3      5     10      1   
context method                                                                  
        bm25    0.312  0.257  0.223  0.179  0.125  0.294  0.405  0.652  0.161   
atom    prob    0.343  0.277  0.247  0.194  0.158  0.359  0.512  0.747  0.194   
        rag     0.346  0.292  0.243  0.194  0.141  0.346  0.461  0.729  0.181   
wiki    prob    0.568  0.449  0.367  0.239  0.288  0.591  0.748  0.912  0.347   
        rag     0.484  0.380  0.311  0.213  0.232  0.501  0.645  0.814  0.283   

                                       MRR  
                    3      5     10         
context method                              
        bm25    0.242  0.256  0.255  0.469  
atom    prob    0.275  0.295  0.280  0.510  
        rag     0.280  0.283  0.279  0.507  
wi

In [14]:
import json
import os

# JSONL file holding one query record per line.
path_to_data = "/home/mhoveyda1/RSN_Z/test_top20_sample0_2025-05-22_12-23_filtered_with_wikidata_and_wikipedia_metadata_filtered_based_on_pred_maps_sampled_equi_110.jsonl"

# Parse every line of the file into a Python object.
with open(path_to_data, mode='r', encoding='utf-8') as f:
    data = list(map(json.loads, f))

In [21]:
# Check which Python type the first record's 'id' field has.
type(data[0]['id'])

int

In [17]:
# Look at the template string stored in the first record's metadata.
data[0]['metadata']['template']

'_'

In [4]:
#!/usr/bin/env python3
import os
import json
import pandas as pd
import matplotlib.pyplot as plt
from typing import List, Set, Dict

# ─── 1. Evaluation functions ────────────────────────────────────────────────────

def precision_at_k(preds: List[str], gold: Set[str], k: int) -> float:
    """Share of the first ``k`` predictions that are gold, with ``k`` as denominator.

    Returns 0.0 for a non-positive ``k``.
    """
    if k <= 0:
        return 0.0
    matched = [item for item in preds[:k] if item in gold]
    return len(matched) / k

def recall_at_k(preds: List[str], gold: Set[str], k: int) -> float:
    """Recall at K: fraction of gold items recovered in the top-k predictions.

    Args:
        preds: Ranked predictions, best first.
        gold: Relevant items for the query.
        k: Cut-off rank.

    Returns:
        A value in [0, 1]. 0.0 when ``gold`` is empty or ``k`` is not positive.
    """
    # Guard k as well as gold: with a negative k the slice preds[:k] would keep
    # all but the last |k| predictions rather than none.
    if k <= 0 or not gold:
        return 0.0
    # Distinct hits only; a prediction listed twice must not count twice,
    # otherwise recall could exceed 1.
    distinct_hits = set()
    for d in preds[:k]:
        if d in gold:
            distinct_hits.add(d)
    return len(distinct_hits) / len(gold)

def f1_at_k(preds: List[str], gold: Set[str], k: int) -> float:
    """Harmonic mean of P@k and R@k for one query; 0.0 if both are zero."""
    p_at_k = precision_at_k(preds, gold, k)
    r_at_k = recall_at_k(preds, gold, k)
    harmonic = 0.0
    if (p_at_k + r_at_k) > 0:
        harmonic = 2 * p_at_k * r_at_k / (p_at_k + r_at_k)
    return harmonic

def reciprocal_rank(preds: List[str], gold: Set[str]) -> float:
    """Return 1 / rank of the first gold prediction (1-based), or 0.0 if none is gold."""
    rank = 0
    for candidate in preds:
        rank += 1
        if candidate in gold:
            return 1.0 / rank
    return 0.0

# ─── 2. Core evaluation + plotting ──────────────────────────────────────────────

def evaluate_and_plot(
    run_name: str,
    df_raw: pd.DataFrame,
    output_dir: str,
    caption: str = None
) -> None:
    """
    Compute ranking metrics for one slice of a run and save them as a shaded PNG.

    Computes mean P@K, R@K, F1@K (K in 1, 3, 5, 10) and MRR for every
    (context, method) pair that has rows in ``df_raw``, collapses BM25 to one
    line with a blank context, prints the table rounded to 3 decimals, and
    saves "<run_name>[_<caption>]_table.png" into ``output_dir``.

    Args:
        run_name: Display name of the run; used in the header, title and file name.
        df_raw: Raw results; read for the columns context, method, query_id,
            entity, rank and is_gold.
        output_dir: Directory for the PNG; created if it does not exist.
        caption: Optional label (e.g. the template) appended to the printed
            header, the figure title and the file name.
    """
    # 2.1 Define Ks
    Ks = [1, 3, 5, 10]
    metric_names = ['P@K', 'R@K', 'F1@K']

    # 2.2 Extract unique contexts and methods
    contexts = df_raw['context'].unique()
    methods  = df_raw['method'].unique()

    # 2.3 Collect per-(context,method,K) metrics
    results = []
    for context in contexts:
        df_ctx = df_raw[df_raw['context'] == context]
        for method in methods:
            df_m = df_ctx[df_ctx['method'] == method]
            # Build per-query preds + gold sets (query_id -> (preds, gold)).
            # NOTE(review): gold is built only from the rows of this group, so a
            # gold entity that has no row for the query cannot lower recall —
            # confirm df_raw lists every gold entity per query.
            per_query = {}
            for qid, group in df_m.groupby('query_id'):
                preds = group.sort_values('rank')['entity'].tolist()
                gold  = set(group.loc[group['is_gold'] == 1, 'entity'])
                per_query[qid] = (preds, gold)

            # Skip pairs without data instead of recording all-zero rows: a
            # fabricated zero row for bm25 could otherwise be the one picked
            # as the collapsed BM25 line in step 2.6.
            if not per_query:
                continue
            n_queries = len(per_query)

            # MRR does not depend on K, so compute it once per (context, method).
            mrr = sum(
                reciprocal_rank(preds, gold) for preds, gold in per_query.values()
            ) / n_queries

            # Compute metrics for each K
            for K in Ks:
                _ps, _rs, _fs = [], [], []
                for preds, gold in per_query.values():
                    _ps.append(precision_at_k(preds, gold, K))
                    _rs.append(recall_at_k(preds, gold, K))
                    _fs.append(f1_at_k(preds, gold, K))
                results.append({
                    'context': context,
                    'method':  method,
                    'K':       K,
                    'P@K':     sum(_ps) / n_queries,
                    'R@K':     sum(_rs) / n_queries,
                    'F1@K':    sum(_fs) / n_queries,
                    'MRR':     mrr
                })

    header = f"{run_name}"
    if caption:
        header += f" ({caption})"

    if not results:
        # Nothing to pivot (empty frame or no usable query ids).
        print(f"No evaluable (context, method) rows for {header}; nothing to report.")
        return

    # 2.4 Build DataFrame of aggregated metrics
    res_df = pd.DataFrame(results)

    # 2.5 Pivot into nested table: rows (context, method), columns (metric, K)
    pivot_metrics = res_df.pivot_table(
        index=['context', 'method'],
        columns='K',
        values=metric_names
    )
    # MRR is identical across the K rows of a pair, so the first one is enough.
    mrr_series = res_df.groupby(['context', 'method'])['MRR'].first()
    pivot_metrics[('MRR', '')] = mrr_series

    # Fix the column order: P@K, R@K, F1@K (each over Ks), then MRR.
    ordered_cols = [(metric, k) for metric in metric_names for k in Ks] + [('MRR', '')]
    pivot_metrics = pivot_metrics.reindex(columns=pd.MultiIndex.from_tuples(ordered_cols))

    # 2.6 Collapse BM25 into a single "method-only" line on top
    df_flat = pivot_metrics.reset_index()
    if 'bm25' in df_flat['method'].values:
        bm25_one = df_flat[df_flat['method'] == 'bm25'].iloc[[0]].copy()
        bm25_one.loc[:, 'context'] = ''   # blank out its context
        keep_others = df_flat[df_flat['method'].isin(['prob', 'rag'])].copy()
        df_collapsed = pd.concat([bm25_one, keep_others], ignore_index=True)
    else:
        df_collapsed = df_flat.copy()
    df_collapsed = df_collapsed.set_index(['context', 'method'])

    # 2.7 Print final table to console (rounded to 3 decimals)
    df_to_print = df_collapsed.copy().round(3)
    print(f"\n=== {header}: Collapsed table with 3-decimal precision ===")
    print(df_to_print)

    # 2.8 Prepare DataFrame for plotting: flatten columns, reset index, round
    flat_cols = []
    for metric, k_val in df_collapsed.columns:
        flat_cols.append('MRR' if metric == 'MRR' else f"{metric}_{k_val}")
    df_collapsed.columns = flat_cols
    df_print = df_collapsed.reset_index()
    numeric_cols = [c for c in df_print.columns if c not in ['context', 'method']]
    df_print[numeric_cols] = df_print[numeric_cols].round(3)

    # 2.9 Plot and save shaded table
    fig_name = run_name if not caption else f"{run_name}_{caption}"
    # A path separator inside the name would point savefig at a sub-directory
    # that does not exist, so flatten separators into underscores.
    for sep in ('/', os.sep):
        fig_name = fig_name.replace(sep, '_')
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    fig_path = os.path.join(output_dir, f"{fig_name}_table.png")

    num_rows, num_cols = df_print.shape
    fig_height = max(2, 0.5 * num_rows)     # reduce multiplier if needed
    fig_width  = max(5, 0.7 * num_cols)     # narrower figure per column

    fig, ax = plt.subplots(figsize=(fig_width, fig_height))
    try:
        ax.axis('off')

        # Add title with run name and caption (if provided)
        title_text = run_name
        if caption:
            title_text += f" – {caption}"
        ax.set_title(title_text, fontweight='bold', fontsize=10, pad=12)

        table = ax.table(
            cellText=df_print.values,
            colLabels=df_print.columns,
            cellLoc='center',
            loc='center'
        )
        table.auto_set_font_size(False)
        table.set_fontsize(6)                 # smaller font
        table.scale(0.6, 1.1)                 # shrink horizontally, slight vertical padding

        # Apply green shading per numeric column with readable text
        for j, col_name in enumerate(df_print.columns):
            if col_name in ['context', 'method']:
                continue
            col_values = df_print[col_name].astype(float)
            col_min, col_max = col_values.min(), col_values.max()
            # A constant column would give a zero range; fall back to 1.0.
            diff = col_max - col_min if col_max > col_min else 1.0

            for i, val in enumerate(col_values):
                norm = (val - col_min) / diff
                # Map onto the 0.3-1.0 part of Greens so the column minimum is
                # still visibly tinted.
                color = plt.cm.Greens(0.3 + 0.7 * norm)
                # Table row 0 holds the column labels, hence the +1.
                cell = table[(i + 1, j)]
                cell.set_facecolor(color)

                # Pick white or black text from the background luminance.
                r, g, b, _ = color
                luminance = 0.299*r + 0.587*g + 0.114*b
                text_color = 'white' if luminance < 0.5 else 'black'
                cell.get_text().set_color(text_color)

        fig.tight_layout()
        fig.savefig(fig_path, dpi=300, bbox_inches='tight')
    finally:
        # Release the figure even when building or saving the table fails.
        plt.close(fig)
    print(f"Saved shaded table to: {fig_path}")


# ─── 3. Main: load mappings, iterate per run & template ─────────────────────────

def main():
    """
    Evaluate each configured run separately for every query template.

    Loads the query-metadata JSONL to build a query id -> template mapping, then
    for each run reads its full_details.csv, tags every row with its query's
    template, drops rows without one, and calls evaluate_and_plot once per
    template, writing the PNGs next to the run's CSV.

    Raises:
        FileNotFoundError: If the metadata JSONL does not exist. A missing run
            CSV is reported and skipped instead.
    """
    # 3.1 Paths to runs
    # paths: Dict[str, str] = {
    #     "Llama-3-8B-Instruct":     "/home/mhoveyda1/REASON/runs/Meta-Llama-3-8B-Instruct-20250601_15-41-54/full_details.csv",
    #     "Llama-3.3-70B-Instruct":   "/home/mhoveyda1/REASON/runs/Llama-3.3-70B-Instruct-20250601_15-41-54/full_details.csv",
    #     "Mistral-v1-7B-Instruct":   "/home/mhoveyda1/REASON/runs/Mistral-7B-Instruct-v0.1-20250601_15-41-54/full_details.csv",
    #     "Mistral-v1-8x7B-Instruct": "/home/mhoveyda1/REASON/runs/Mixtral-8x7B-Instruct-v0.1-20250601_15-41-54/full_details.csv",
    #     "Olmo-2-32B":               "/home/mhoveyda1/REASON/runs/OLMo-2-0325-32B-20250601_15-41-54",
    #     "Olmo-2-7B-Instruct":       "/home/mhoveyda1/REASON/runs/OLMo-2-1124-7B-Instruct-20250601_15-41-54/full_details.csv",
    # }
    paths: Dict[str, str] = {
        "Llama-3-8B-Instruct":     "/home/mhoveyda1/REASON/runs/correct_prevs_IMPORTANT/llama-3-8B-instr-20250601_10-20-16/full_details.csv",
        "Llama-3.3-70B-Instruct":   "/home/mhoveyda1/REASON/runs/correct_prevs_IMPORTANT/llama-3-70B-instr-20250601_10-20-16/full_details.csv",

    }

    # 3.2 Path to JSONL with query metadata
    path_to_data = "/home/mhoveyda1/RSN_Z/test_top20_sample0_2025-05-22_12-23_filtered_with_wikidata_and_wikipedia_metadata_filtered_based_on_pred_maps_sampled_equi_110.jsonl"
    if not os.path.exists(path_to_data):
        raise FileNotFoundError(f"Metadata JSONL not found at {path_to_data}")

    # 3.3 Load JSONL and build mapping: query_id -> template
    with open(path_to_data, 'r', encoding='utf-8') as f:
        # Skip blank lines (e.g. stray newlines between records) instead of
        # letting json.loads fail on them.
        data = [json.loads(line) for line in f if line.strip()]
    # Records lacking an id or a metadata.template entry are left out.
    id_to_template: Dict[int, str] = {
        entry['id']: entry['metadata']['template']
        for entry in data
        if 'id' in entry and 'metadata' in entry and 'template' in entry['metadata']
    }

    # 3.4 For each run, load CSV, annotate with template, then evaluate per-template
    for run_name, run_path in paths.items():
        # 3.4.1 Resolve full_details.csv if a directory is given
        if os.path.isdir(run_path):
            run_path = os.path.join(run_path, "full_details.csv")
        if not os.path.exists(run_path):
            print(f"Skipping '{run_name}': CSV not found at {run_path}")
            continue

        # 3.4.2 Load run's raw CSV
        df_raw = pd.read_csv(run_path)

        # 3.4.3 Map each row's query_id to its template (NaN if missing).
        # NOTE(review): the lookup is by exact key, so the CSV's query_id values
        # must have the same type as the JSONL 'id' values — confirm.
        df_raw['template'] = df_raw['query_id'].map(id_to_template)

        # 3.4.4 Drop any rows whose query_id has no template mapping
        df_raw = df_raw.dropna(subset=['template'])
        if df_raw.empty:
            # Without this message the run would be skipped silently.
            print(f"Skipping '{run_name}': no query_id matched a template in {path_to_data}")
            continue

        # 3.4.5 Identify unique templates in this run
        templates = df_raw['template'].unique()

        # 3.4.6 Output directory is the one holding the CSV
        output_dir = os.path.dirname(run_path)

        # 3.4.7 Evaluate & plot for each template
        for template in templates:
            df_subset = df_raw[df_raw['template'] == template]
            evaluate_and_plot(
                run_name=run_name,
                df_raw=df_subset,
                output_dir=output_dir,
                caption=template.replace(" ", "_")  # spaces -> underscores for the file name
            )


# Entry point: run the per-template evaluation when this cell/script is executed directly.
if __name__ == "__main__":
    main()


=== Llama-3-8B-Instruct (_): Collapsed table with 3-decimal precision ===
                  P@K                         R@K                        F1@K  \
                    1      3      5     10      1      3      5     10      1   
context method                                                                  
        bm25    0.357  0.214  0.200  0.200  0.154  0.255  0.374  0.747  0.194   
atom    prob    0.500  0.357  0.286  0.229  0.190  0.404  0.499  0.729  0.250   
        rag     0.429  0.333  0.257  0.229  0.190  0.362  0.469  0.878  0.241   
wiki    prob    0.643  0.524  0.386  0.236  0.363  0.684  0.778  0.927  0.430   
        rag     0.643  0.500  0.329  0.229  0.321  0.743  0.766  0.902  0.396   

                                       MRR  
                    3      5     10         
context method                              
        bm25    0.207  0.235  0.282  0.492  
atom    prob    0.338  0.326  0.312  0.615  
        rag     0.315  0.295  0.326  0.576  
wiki  