In [1]:
%load_ext autoreload
%autoreload 2

In [None]:
DATASET_NAME = '50salads'

from src.utils import (
    prepare_df,
    group_cases_by_trace,
    get_activity_run_lengths_by_case,
    get_sequences_by_case,
    normalize_sequences_for_evaluation,
    compute_evaluation_metrics,
    compute_kari_metrics
)

import pandas as pd
import pickle as pkl
from src.evaluation import compute_tas_metrics_asformer
from pprint import pprint

In [3]:
# load your DataFrame and softmax list
result = prepare_df('50salads')
if len(result) == 2:
    df, softmax_lst = result
else:
    df, softmax_lst, _ = result

# group by trace and inspect
trace_groups = group_cases_by_trace(df)
trace_groups

Unnamed: 0,case_list,trace_length
0,"[0, 1, 2, 3]",5687
1,"[32, 33, 34, 35]",6186
2,"[36, 37, 38, 39]",5840
3,"[28, 29, 30, 31]",5261
4,"[4, 5, 6, 7]",6208
5,"[16, 17, 18, 19]",6293
6,"[24, 25, 26, 27]",6046
7,"[8, 9, 10, 11]",6584
8,"[12, 13, 14, 15]",5558
9,"[20, 21, 22, 23]",5792


In [22]:
get_activity_run_lengths_by_case(df, '1', min_runs=3, include_preceding_sequence=True)

{'4': ([63, 37, 45],
  [('17', '2', '3', '0'),
   ('17', '2', '3', '0', '1', '0'),
   ('17', '2', '3', '0', '1', '0', '1', '15')]),
 '5': ([63, 37, 45],
  [('17', '2', '3', '0'),
   ('17', '2', '3', '0', '1', '0'),
   ('17', '2', '3', '0', '1', '0', '1', '15')]),
 '6': ([63, 37, 45],
  [('17', '2', '3', '0'),
   ('17', '2', '3', '0', '1', '0'),
   ('17', '2', '3', '0', '1', '0', '1', '15')]),
 '7': ([63, 37, 45],
  [('17', '2', '3', '0'),
   ('17', '2', '3', '0', '1', '0'),
   ('17', '2', '3', '0', '1', '0', '1', '15')]),
 '24': ([86, 122, 143],
  [('17', '7', '8', '9', '6', '10', '0'),
   ('17', '7', '8', '9', '6', '10', '0', '1', '0'),
   ('17', '7', '8', '9', '6', '10', '0', '1', '0', '1', '4', '5')]),
 '25': ([86, 122, 143],
  [('17', '7', '8', '9', '6', '10', '0'),
   ('17', '7', '8', '9', '6', '10', '0', '1', '0'),
   ('17', '7', '8', '9', '6', '10', '0', '1', '0', '1', '4', '5')]),
 '26': ([86, 122, 143],
  [('17', '7', '8', '9', '6', '10', '0'),
   ('17', '7', '8', '9', '6', '1

## Evaluation

In [None]:
import pandas as pd
from src.evaluation import tas_metrics

## Results for 25% of the data

In [5]:
# Load recovery results
recovery_res = pd.read_csv('recovery_results_50salads_15.csv')

# Argmax
argmax_summary, argmax_per_vid = compute_tas_metrics_asformer(
    recovery_res, pred_col="argmax_activity", background=None,
    dataset_name=DATASET_NAME, return_per_video=True
)
print("Argmax:")
pprint(argmax_summary, sort_dicts=False, width=1)
print()

# SKTR
sktr_summary, sktr_per_vid = compute_tas_metrics_asformer(
    recovery_res, pred_col="sktr_activity", background=None,
    dataset_name=DATASET_NAME, return_per_video=True
)
print("SKTR:")
pprint(sktr_summary,  sort_dicts=False, width=1)


Argmax:
{'acc_micro': 82.47808913526251,
 'edit': 60.71391910996374,
 'f1@10': 69.07644654427946,
 'f1@25': 68.10491527181362,
 'f1@50': 59.971743157957576}

SKTR:
{'acc_micro': 82.6103171777789,
 'edit': 70.77625152625151,
 'f1@10': 80.19253397996623,
 'f1@25': 79.28806243525078,
 'f1@50': 68.72018411934054}


In [6]:
# Test cases
case_ids = ['20', '11', '5', '36', '14', '4', '30', '15', '3', '18']

gt_sequences = get_sequences_by_case(df, case_ids, 'concept:name')

gt_sequences

[['17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',
  '17',


In [7]:
with open("results.pkl", "rb") as f:
    loaded_results = pkl.load(f)

pred_sequences = []
for res in loaded_results:
    pred_sequences.append(res['labels'])

pred_sequences

[array([17, 17, 17, ..., 18, 18, 18], dtype=int64),
 array([17, 17, 17, ..., 18, 18, 18], dtype=int64),
 array([17, 17, 17, ..., 18, 18, 18], dtype=int64),
 array([17, 17, 17, ..., 18, 18, 18], dtype=int64),
 array([17, 17, 17, ..., 18, 18, 18], dtype=int64),
 array([17, 17, 17, ..., 18, 18, 18], dtype=int64),
 array([17, 17, 17, ..., 18, 18, 18], dtype=int64),
 array([17, 17, 17, ..., 18, 18, 18], dtype=int64),
 array([17, 17, 17, ..., 18, 18, 18], dtype=int64),
 array([17, 17, 17, ..., 18, 18, 18], dtype=int64)]

In [None]:
# Normalize
gt_norm, pred_norm = normalize_sequences_for_evaluation(gt_sequences, pred_sequences)

from src.evaluation import compute_tas_metrics_from_sequences

metrics = compute_tas_metrics_from_sequences(gt_norm, pred_norm, background=None)

print(metrics)


{'acc_micro': 76.85670209700113, 'edit': 70.19007207626133, 'f1@10': 75.16401278579981, 'f1@25': 74.0529016746887, 'f1@50': 64.3414317900242}


In [13]:
import pandas as pd

# Existing results
argmax = {
    'acc_micro': 82.47808913526251,
    'edit': 60.71391910996374,
    'f1@10': 69.07644654427946,
    'f1@25': 68.10491527181362,
    'f1@50': 59.971743157957576
}

sktr = {
    'acc_micro': 82.6103171777789,
    'edit': 70.77625152625151,
    'f1@10': 80.19253397996623,
    'f1@25': 79.28806243525078,
    'f1@50': 68.72018411934054
}

kari_new = {
    'acc_micro': 76.85670209700113,
    'edit': 70.19007207626133,
    'f1@10': 75.16401278579981,
    'f1@25': 74.0529016746887,
    'f1@50': 64.3414317900242
}

# Build DataFrame
df = pd.DataFrame([argmax, sktr, kari_new], index=["Argmax", "SKTR", "KARI"])
df = df.round(2)

print(df)


        acc_micro   edit  f1@10  f1@25  f1@50
Argmax      82.48  60.71  69.08  68.10  59.97
SKTR        82.61  70.78  80.19  79.29  68.72
KARI        76.86  70.19  75.16  74.05  64.34


## Results for complete 50 Salads

In [8]:
# Load your recovery results CSV
res_df = pd.read_csv('recovery_results_50salads_complete_15.csv')

# Compute comprehensive metrics
metrics = compute_evaluation_metrics(res_df, background=None, dataset_name=DATASET_NAME)

metrics

Computing evaluation metrics for 40 cases...
Computing SKTR metrics...
Computing argmax metrics...
Evaluation metrics computed successfully!


{'sktr': {'acc_micro': 79.52947607434194,
  'edit': 69.50820439261662,
  'f1@10': 78.99752033632461,
  'f1@25': 76.70559058468602,
  'f1@50': 68.23558107211079},
 'argmax': {'acc_micro': 82.20082415272054,
  'edit': 57.29812374180203,
  'f1@10': 67.4786017129303,
  'f1@25': 66.02423900039784,
  'f1@50': 58.52879216083951}}

In [9]:
# Compute metrics
case_order = ['30', '17', '9', '8', '20', '7', '23', '5', '28', '2', '1', '0', '13', '36', '33', '3',
 '14', '10', '31', '22', '34', '38', '37', '6', '24', '27', '21', '15', '11', '19', '16',
 '12', '32', '25', '35', '39', '26', '29', '4', '18']
 
results = compute_kari_metrics(
    pkl_file_path='kari_results_50salads_complete.pkl',
    df=df,
    case_id_order=case_order,
    method_name='kari',
    background=None
)

results

{'kari': {'acc_micro': 79.36548650239678,
  'edit': 76.21105357752282,
  'f1@10': 79.55645941148791,
  'f1@25': 78.49189068994244,
  'f1@50': 69.29259384569326}}

In [10]:
import pandas as pd

results = {
    "sktr": {
        "acc_micro": 79.52947607434194,
        "edit": 69.50820439261662,
        "f1@10": 78.99752033632461,
        "f1@25": 76.70559058468602,
        "f1@50": 68.23558107211079,
    },
    "argmax": {
        "acc_micro": 82.20082415272054,
        "edit": 57.29812374180203,
        "f1@10": 67.4786017129303,
        "f1@25": 66.02423900039784,
        "f1@50": 58.52879216083951,
    },
    "kari": {
        "acc_micro": 79.36548650239678,
        "edit": 76.21105357752282,
        "f1@10": 79.55645941148791,
        "f1@25": 78.49189068994244,
        "f1@50": 69.29259384569326,
    },
}

df = pd.DataFrame(results).T

# Reorder so 'argmax' is first
df = df.reindex(["argmax", "sktr", "kari"])

print(df.round(2))


        acc_micro   edit  f1@10  f1@25  f1@50
argmax      82.20  57.30  67.48  66.02  58.53
sktr        79.53  69.51  79.00  76.71  68.24
kari        79.37  76.21  79.56  78.49  69.29


## Analysis of results for complete 50 Salads

In [None]:
from src.utils import prepare_df, group_cases_by_trace

# Load ground-truth df and softmax list
result = prepare_df('50salads')
if len(result) == 2:
    df, softmax_lst = result
else:
    df, softmax_lst, _ = result

# group by trace and inspect
trace_groups = group_cases_by_trace(df)
trace_groups

Unnamed: 0,case_list,trace_length
0,"[0, 1, 2, 3]",5687
1,"[32, 33, 34, 35]",6186
2,"[36, 37, 38, 39]",5840
3,"[28, 29, 30, 31]",5261
4,"[4, 5, 6, 7]",6208
5,"[16, 17, 18, 19]",6293
6,"[24, 25, 26, 27]",6046
7,"[8, 9, 10, 11]",6584
8,"[12, 13, 14, 15]",5558
9,"[20, 21, 22, 23]",5792


In [18]:
# Load KARI results
with open('kari_results_50salads_complete.pkl', 'rb') as f:
    kari_results = pickle.load(f)

In [29]:
# Add KARI accuracies to trace_groups (handle existing column)
import numpy as np
import pandas as pd
import pickle

# Load KARI results
with open('kari_results_50salads_complete.pkl', 'rb') as f:
    kari_results = pickle.load(f)

# Get case order from the recovery results CSV
results_df = pd.read_csv("recovery_results_50salads_complete_15.csv")
case_order = results_df['case:concept:name'].astype(str).drop_duplicates().tolist()

print(f"Total cases in CSV: {len(case_order)}")
print(f"Total KARI results: {len(kari_results)}")

# Use positional alignment (kari_results[i] corresponds to case_order[i])
assert len(kari_results) == len(case_order), f"Mismatch: {len(kari_results)} KARI results vs {len(case_order)} cases"

# Extract KARI predictions and compute accuracies
kari_accuracies = {}
for i, (case_id, result) in enumerate(zip(case_order, kari_results)):
    # Get ground truth sequence for this case
    case_df = df[df['case:concept:name'] == case_id]
    gt_sequence = case_df['concept:name'].tolist()

    if not gt_sequence:
        kari_accuracies[case_id] = np.nan
        continue

    # Try both 'sequence' and 'labels' keys - user mentioned 'labels' might be correct
    kari_sequence = None

    # First try 'labels' key (user suggestion)
    if 'labels' in result:
        labels = result['labels']
        if isinstance(labels, (list, tuple, np.ndarray)):
            kari_sequence = [str(x) for x in labels]
        elif isinstance(labels, str):
            kari_sequence = labels.split()
        else:
            kari_sequence = [str(labels)]

    # If 'labels' didn't work or doesn't exist, try 'sequence' key
    if kari_sequence is None:
        if 'sequence' in result:
            sequence_str = result['sequence']
            if isinstance(sequence_str, str):
                kari_sequence = sequence_str.split()
            else:
                kari_sequence = [str(sequence_str)]

    if not kari_sequence:
        kari_accuracies[case_id] = np.nan
        continue

    # Compare complete sequences - KARI predictions vs ground truth
    min_len = min(len(gt_sequence), len(kari_sequence))
    if min_len > 0:
        # Convert to strings for comparison
        gt_str = [str(x) for x in gt_sequence[:min_len]]
        kari_str = [str(x) for x in kari_sequence[:min_len]]

        correct = sum(1 for gt, pred in zip(gt_str, kari_str) if gt == pred)
        accuracy = correct / min_len
        kari_accuracies[case_id] = round(accuracy, 2)
    else:
        kari_accuracies[case_id] = np.nan

# Add KARI accuracies to trace_groups (handle existing column)
def format_kari_acc_list(cases, acc_map):
    values = [acc_map.get(str(case_id), np.nan) for case_id in cases]
    return [np.nan if pd.isna(v) else float(v) for v in values]

# Handle existing column gracefully
if "kari_accuracies" in trace_groups.columns:
    # Column already exists, remove it first
    trace_groups = trace_groups.drop(columns=["kari_accuracies"])

# Insert KARI column right after SKTR
sktr_pos = trace_groups.columns.get_loc("sktr_accuracies")
trace_groups.insert(sktr_pos + 1, "kari_accuracies",
                   trace_groups["case_list"].apply(lambda cases: format_kari_acc_list(cases, kari_accuracies)))

print(f"Computed KARI accuracies for {len([v for v in kari_accuracies.values() if not pd.isna(v)])} cases")
print("Sample KARI accuracies:", list(kari_accuracies.values())[:10])

trace_groups


Total cases in CSV: 40
Total KARI results: 40
Computed KARI accuracies for 40 cases
Sample KARI accuracies: [0.64, 0.74, 0.84, 0.84, 0.88, 0.75, 0.86, 0.75, 0.69, 0.97]


Unnamed: 0,case_list,sktr_accuracies,kari_accuracies,argmax_accuracies,trace_length
0,"[0, 1, 2, 3]","[0.93, 0.94, 0.95, 0.94]","[0.97, 0.92, 0.97, 0.97]","[0.91, 0.93, 0.94, 0.94]",5687
1,"[4, 5, 6, 7]","[0.73, 0.73, 0.74, 0.74]","[0.75, 0.75, 0.75, 0.75]","[0.73, 0.73, 0.74, 0.74]",6208
2,"[8, 9, 10, 11]","[0.23, 0.76, 0.77, 0.83]","[0.84, 0.84, 0.84, 0.84]","[0.82, 0.82, 0.82, 0.83]",6584
3,"[12, 13, 14, 15]","[0.92, 0.93, 0.93, 0.92]","[0.91, 0.91, 0.91, 0.63]","[0.91, 0.92, 0.93, 0.92]",5558
4,"[16, 17, 18, 19]","[0.66, 0.69, 0.69, 0.69]","[0.65, 0.74, 0.71, 0.71]","[0.66, 0.68, 0.69, 0.69]",6293
5,"[20, 21, 22, 23]","[0.91, 0.92, 0.93, 0.94]","[0.88, 0.83, 0.93, 0.86]","[0.91, 0.91, 0.94, 0.94]",5792
6,"[24, 25, 26, 27]","[0.68, 0.68, 0.87, 0.86]","[0.85, 0.69, 0.85, 0.84]","[0.86, 0.87, 0.86, 0.86]",6046
7,"[28, 29, 30, 31]","[0.93, 0.94, 0.94, 0.94]","[0.69, 0.66, 0.64, 0.77]","[0.93, 0.94, 0.94, 0.94]",5261
8,"[32, 33, 34, 35]","[0.76, 0.75, 0.76, 0.76]","[0.88, 0.77, 0.81, 0.81]","[0.76, 0.75, 0.76, 0.76]",6186
9,"[36, 37, 38, 39]","[0.68, 0.69, 0.7, 0.71]","[0.6, 0.62, 0.8, 0.62]","[0.68, 0.68, 0.7, 0.71]",5840


In [None]:
# ['20', '11', '5', '36', '14', '4', '30', '15', '3', '18']


def compute_overall_accuracy(trace_groups, method_col):
    """Compute overall accuracy by averaging all individual case accuracies"""
    all_accuracies = []
    
    for _, row in trace_groups.iterrows():
        accuracies = row[method_col]
        if isinstance(accuracies, list):
            # Filter out NaN values
            valid_accuracies = [acc for acc in accuracies if not pd.isna(acc)]
            all_accuracies.extend(valid_accuracies)
    
    if all_accuracies:
        return round(np.mean(all_accuracies), 4)
    else:
        return np.nan

# Compute overall accuracies
overall_accuracies = {}
overall_accuracies['sktr'] = compute_overall_accuracy(trace_groups, 'sktr_accuracies')
overall_accuracies['kari'] = compute_overall_accuracy(trace_groups, 'kari_accuracies')
overall_accuracies['argmax'] = compute_overall_accuracy(trace_groups, 'argmax_accuracies')

print("Overall Accuracies:")
for method, accuracy in overall_accuracies.items():
    print(f"{method.upper()}: {accuracy}")

# Also compute per-group averages
print("\nPer-group averages:")
for _, row in trace_groups.iterrows():
    case_list = row['case_list']
    print(f"Cases {case_list}:")
    
    for method in ['sktr', 'kari', 'argmax']:
        col_name = f"{method}_accuracies"
        accuracies = row[col_name]
        if isinstance(accuracies, list):
            valid_accuracies = [acc for acc in accuracies if not pd.isna(acc)]
            if valid_accuracies:
                avg_acc = round(np.mean(valid_accuracies), 3)
                print(f"  {method.upper()}: {avg_acc}")
            else:
                print(f"  {method.upper()}: No valid data")
    print()

Overall Accuracies:
SKTR: 0.8018
KARI: 0.794
ARGMAX: 0.8262

Per-group averages:
Cases ['0', '1', '2', '3']:
  SKTR: 0.94
  KARI: 0.958
  ARGMAX: 0.93

Cases ['4', '5', '6', '7']:
  SKTR: 0.735
  KARI: 0.75
  ARGMAX: 0.735

Cases ['8', '9', '10', '11']:
  SKTR: 0.648
  KARI: 0.84
  ARGMAX: 0.822

Cases ['12', '13', '14', '15']:
  SKTR: 0.925
  KARI: 0.84
  ARGMAX: 0.92

Cases ['16', '17', '18', '19']:
  SKTR: 0.682
  KARI: 0.702
  ARGMAX: 0.68

Cases ['20', '21', '22', '23']:
  SKTR: 0.925
  KARI: 0.875
  ARGMAX: 0.925

Cases ['24', '25', '26', '27']:
  SKTR: 0.772
  KARI: 0.808
  ARGMAX: 0.862

Cases ['28', '29', '30', '31']:
  SKTR: 0.938
  KARI: 0.69
  ARGMAX: 0.938

Cases ['32', '33', '34', '35']:
  SKTR: 0.758
  KARI: 0.818
  ARGMAX: 0.758

Cases ['36', '37', '38', '39']:
  SKTR: 0.695
  KARI: 0.66
  ARGMAX: 0.692



In [37]:
with pd.option_context('display.max_rows', None, 'display.max_columns', None, 'display.width', None, 'display.max_colwidth', None):
    display(results_df[results_df['case:concept:name'] == 8])

Unnamed: 0,case:concept:name,step,sktr_activity,argmax_activity,ground_truth,all_probs,all_activities,is_correct,cumulative_accuracy,sktr_move_cost
18138,8,0,17,17,17,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.01, 0.01, 0.0, 0.01, 0.01, 0.0, 0.0, 0.0, 0.0, 0.0, 0.01, 0.0, 0.91, 0.02]","['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', '13', '14', '15', '16', '17', '18']",True,1.0,0.09
18139,8,1,17,17,17,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.01, 0.01, 0.0, 0.01, 0.01, 0.0, 0.0, 0.0, 0.01, 0.0, 0.01, 0.0, 0.91, 0.02]","['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', '13', '14', '15', '16', '17', '18']",True,1.0,0.09
18140,8,2,17,17,17,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.01, 0.01, 0.0, 0.01, 0.01, 0.0, 0.0, 0.0, 0.01, 0.0, 0.01, 0.0, 0.91, 0.02]","['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', '13', '14', '15', '16', '17', '18']",True,1.0,0.09
18141,8,3,17,17,17,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.01, 0.01, 0.0, 0.01, 0.01, 0.0, 0.0, 0.0, 0.01, 0.0, 0.01, 0.0, 0.91, 0.02]","['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', '13', '14', '15', '16', '17', '18']",True,1.0,0.09
18142,8,4,17,17,17,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.01, 0.01, 0.0, 0.01, 0.01, 0.0, 0.01, 0.0, 0.01, 0.0, 0.01, 0.0, 0.91, 0.02]","['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', '13', '14', '15', '16', '17', '18']",True,1.0,0.09
18143,8,5,17,17,17,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.01, 0.01, 0.0, 0.01, 0.01, 0.0, 0.01, 0.0, 0.01, 0.0, 0.01, 0.0, 0.91, 0.02]","['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', '13', '14', '15', '16', '17', '18']",True,1.0,0.09
18144,8,6,17,17,17,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.01, 0.01, 0.0, 0.01, 0.01, 0.0, 0.01, 0.0, 0.01, 0.0, 0.01, 0.0, 0.91, 0.02]","['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', '13', '14', '15', '16', '17', '18']",True,1.0,0.09
18145,8,7,17,17,17,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.01, 0.01, 0.0, 0.01, 0.01, 0.0, 0.01, 0.0, 0.01, 0.0, 0.01, 0.0, 0.91, 0.02]","['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', '13', '14', '15', '16', '17', '18']",True,1.0,0.09
18146,8,8,17,17,17,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.01, 0.01, 0.0, 0.01, 0.01, 0.0, 0.0, 0.0, 0.01, 0.0, 0.01, 0.0, 0.91, 0.02]","['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', '13', '14', '15', '16', '17', '18']",True,1.0,0.09
18147,8,9,17,17,17,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.01, 0.01, 0.0, 0.01, 0.01, 0.0, 0.0, 0.0, 0.01, 0.0, 0.01, 0.0, 0.91, 0.02]","['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', '13', '14', '15', '16', '17', '18']",True,1.0,0.09


In [None]:
# Enhanced TAS Comparison Demo
import pandas as pd
from src.utils import compute_comprehensive_tas_comparison, print_tas_comparison

# Load recovery results
results_df = pd.read_csv('recovery_results_50salads_complete_15.csv')

# Compute comprehensive TAS comparison
print("🔍 Computing comprehensive TAS comparison...")
all_metrics = compute_comprehensive_tas_comparison(
    results_df=results_df,
    kari_pkl_path='kari_results_50salads_complete.pkl'
)

# Display enhanced professional table format
print("🎯 Enhanced Professional Table Format:")
print_tas_comparison(all_metrics, sort_by=None, highlight_best=True)

# Quick summary
print("\n📊 Quick Summary:")
print("-" * 50)
sktr_f1 = all_metrics['sktr']['f1@10']
argmax_f1 = all_metrics['argmax']['f1@10']
kari_f1 = all_metrics['kari']['f1@10']

print(f"SKTR F1@10:   {sktr_f1:.3f}")
print(f"Argmax F1@10: {argmax_f1:.3f}")
print(f"KARI F1@10:   {kari_f1:.3f}")

# Show improvements
if sktr_f1 > argmax_f1:
    improvement = sktr_f1 - argmax_f1
    print(f"✅ SKTR improves over Argmax by {improvement:.3f} points")
if kari_f1 > sktr_f1:
    improvement = kari_f1 - sktr_f1
    print(f"✅ KARI improves over SKTR by {improvement:.3f} points")

print("\n✅ All TAS statistics computed successfully!")
print("="*50)


🔍 Computing comprehensive TAS comparison...
Computing comprehensive TAS comparison for 40 cases...
Computing SKTR and argmax metrics...
Computing evaluation metrics for 40 cases...
Computing SKTR metrics...
Computing argmax metrics...
Evaluation metrics computed successfully!
Computing KARI metrics...
✅ Comprehensive TAS comparison completed!
SKTR F1@10: 78.998
Argmax F1@10: 67.479
KARI F1@10: 79.556
🎯 Enhanced Professional Table Format:

TAS comparison (original order)
        acc_micro   edit  f1@10  f1@25  f1@50
argmax      82.20  57.30  67.48  66.02  58.53
sktr        79.53  69.51  79.00  76.71  68.24
kari        79.37  76.21  79.56  78.49  69.29

Best per metric:
  acc_micro: argmax (82.20)
  edit: kari (76.21)
  f1@10: kari (79.56)
  f1@25: kari (78.49)
  f1@50: kari (69.29)

📊 Quick Summary:
--------------------------------------------------
SKTR F1@10:   78.998
Argmax F1@10: 67.479
KARI F1@10:   79.556
✅ SKTR improves over Argmax by 11.519 points
✅ KARI improves over SKTR by 0.

In [None]:
df = pd.read_csv('results/sktr_kari_argmax_50salads_results.csv')
df.head()

Unnamed: 0,case:concept:name,sktr_activity,kari_activity,argmax_activity,ground_truth
0,30,17,17,17,17
1,30,17,17,17,17
2,30,17,17,17,17
3,30,17,17,17,17
4,30,17,17,17,17


In [None]:
kari_sequences = pd.read_csv('results/kari_50salads_sequences_predictions.csv')
kari_sequences.head()

Unnamed: 0,case:concept:name,kari_activity
0,30,17
1,30,11
2,30,13
3,30,4
4,30,5


In [None]:
gt_sequences = pd.read_csv('data/ground_truth_50salads_sequences.csv')
gt_sequences.head()

Unnamed: 0,case:concept:name,ground_truth
0,30,17
1,30,0
2,30,1
3,30,11
4,30,12


In [None]:
import pandas as pd

# Load cleaned, per-step sequences
gt = pd.read_csv("data/ground_truth_50salads_sequences.csv", usecols=["case:concept:name", "ground_truth"])
kp = pd.read_csv("results/kari_50salads_sequences_predictions.csv", usecols=["case:concept:name", "kari_activity"])

# Build per-case sequences (order preserved within each case)
gt_seq = gt.groupby("case:concept:name", sort=False)["ground_truth"].agg(list)
kp_seq = kp.groupby("case:concept:name", sort=False)["kari_activity"].agg(list)

# Compare only cases present in both
common = gt_seq.index.intersection(kp_seq.index)

matches = kp_seq.loc[common].reset_index(drop=True) == gt_seq.loc[common].reset_index(drop=True)
accuracy_percent = float(matches.mean() * 100)

print(f"Cases compared: {len(common)}")
print(f"Exact sequence matches: {int(matches.sum())}")
print(f"Accuracy: {accuracy_percent:.2f}%")

# Optional: see which cases mismatched
mismatched_cases = list(common[~matches.to_numpy()])
print('Mismatched case IDs:', mismatched_cases[:10])


Cases compared: 40
Exact sequence matches: 3
Accuracy: 7.50%
Mismatched case IDs: [30, 17, 9, 8, 20, 7, 23, 5, 28, 1]


In [None]:
import pandas as pd

gt = pd.read_csv("data/ground_truth_50salads_sequences.csv", usecols=["case:concept:name", "ground_truth"])
kp = pd.read_csv("results/kari_50salads_sequences_predictions.csv", usecols=["case:concept:name", "kari_activity"])

# Build per-case sequences
gt_seq = gt.groupby("case:concept:name", sort=False)["ground_truth"].agg(list)
kp_seq = kp.groupby("case:concept:name", sort=False)["kari_activity"].agg(list)

# Compare only cases present in both
common = gt_seq.index.intersection(kp_seq.index)

def is_subsequence(short, long):
    i = 0
    for x in long:
        if i < len(short) and x == short[i]:
            i += 1
            if i == len(short):
                return True
    return i == len(short)

# Exact equality
exact = (kp_seq.loc[common] == gt_seq.loc[common])

# Ground truth is subsequence of Kari (order preserved) → correct or correct+extras
superset_or_equal = pd.Series(
    [is_subsequence(gt_seq.loc[c], kp_seq.loc[c]) for c in common],
    index=common
)

print(f"Cases compared: {len(common)}")
print(f"Exact matches: {int(exact.sum())} ({100*exact.mean():.2f}%)")
print(f"Correct-or-with-additions: {int(superset_or_equal.sum())} ({100*superset_or_equal.mean():.2f}%)")
print(f"Additional-only (correct but extra): {int((superset_or_equal & ~exact).sum())}")

Cases compared: 40
Exact matches: 3 (7.50%)
Correct-or-with-additions: 4 (10.00%)
Additional-only (correct but extra): 1
