In [63]:
import requests
import json
from time import sleep
from calculate_annotator_agreement import extract_score
import scipy.stats as stats
import matplotlib.pyplot as plt
import numpy as np

import os

url = "https://api.contextual.ai/v1/lmunit"

# Credentials must never be hardcoded in a notebook (they leak via version control and
# shared outputs). Read the LMUnit API key from the environment instead.
# NOTE(review): a key was previously committed in this cell — revoke/rotate it.
LMUNIT_API_KEY = os.environ.get("CONTEXTUAL_API_KEY", "")
if not LMUNIT_API_KEY:
    print("WARNING: CONTEXTUAL_API_KEY is not set; LMUnit API requests will be rejected.")

# Shared request headers for every LMUnit call (JSON in, JSON out, bearer auth).
headers = {
    "accept": "application/json",
    "content-type": "application/json",
    "authorization": f"Bearer {LMUNIT_API_KEY}",
}

#### Set Model and File Locations

In [64]:
# Dialogue JSON files to evaluate, one per domain (retail and airline).
dials = dict(
    tau_retail_eval_json="results/20250131_152422-tau-4o-retail/tau-gpt-4o_j.json",
    tau_air_eval_json="results/20250131_152503-tau-4o-airline/tau-gpt-4o_j.json",
)

#### Load and Filter Dialogue Data For Evaluation

In [65]:
def load_filter_dials(tau_retail_eval_json, tau_air_eval_json, batch_json_path):
    """Load retail and airline dialogues and keep only those listed in the batch file.

    Args:
        tau_retail_eval_json: Path to a JSON file whose top-level "dialogues" object maps
            dialogue ID -> dialogue (retail domain).
        tau_air_eval_json: Same, for the airline domain.
        batch_json_path: Path to a JSON file listing the IDs to keep under
            ``["tau"]["retail"]`` and ``["tau"]["airline"]``.

    Returns:
        dict mapping dialogue ID (str) -> dialogue; retail entries are inserted first,
        then airline entries, each in batch-file order.

    Raises:
        FileNotFoundError: if any of the three paths does not exist.
        KeyError: if a file lacks the expected "dialogues" / "tau" / domain keys.

    Batch IDs are compared as strings because JSON object keys are always strings, so an
    integer ID in the batch file could otherwise never match. IDs missing from their
    dialogue file are skipped with a warning. If the same ID is listed for both domains,
    the airline dialogue replaces the retail one under that key (as before), and a
    warning is issued so the lost dialogue is not silently dropped.
    """
    import os
    import warnings

    for p in (tau_retail_eval_json, tau_air_eval_json, batch_json_path):
        if not os.path.exists(p):
            raise FileNotFoundError(f"Couldn't find {p!r}")

    def _load_json(path):
        # Single place for reading JSON with an explicit encoding.
        with open(path, 'r', encoding='utf-8') as f:
            return json.load(f)

    tau_retail_dials = _load_json(tau_retail_eval_json)['dialogues']
    tau_air_dials = _load_json(tau_air_eval_json)['dialogues']
    batch_list = _load_json(batch_json_path)

    batch_dials = {}
    origin = {}  # dialogue ID -> domain that last supplied it, to detect cross-domain clashes
    for domain, domain_dials in (("retail", tau_retail_dials), ("airline", tau_air_dials)):
        for bid in batch_list["tau"][domain]:
            key = str(bid)
            if key not in domain_dials:
                warnings.warn(
                    f"{domain} dialogue {key!r} listed in the batch file was not found; skipping it."
                )
                continue
            if key in origin and origin[key] != domain:
                warnings.warn(
                    f"Dialogue ID {key!r} is listed for both {origin[key]} and {domain}; "
                    f"the {domain} dialogue replaces the {origin[key]} one."
                )
            batch_dials[key] = domain_dials[key]
            origin[key] = domain
    return batch_dials


#### Evaluate LMUnit Score on Dialogues

In [66]:
# LMUnit natural-language unit tests, one list per score computed in eval_dials_lmunit.
# NOTE: each string is sent verbatim as the "unit_test" field, so rewording any of them
# (even the missing "the" in the last backend question) would likely shift scores and
# break comparability with previously saved runs.
conv_qs, backend_qs, policy_qs = (
    # conversation consistency: is the response grounded in, and continuing, the dialogue?
    [
        "Does the response directly relate to the dialogue history and the current user query?",
        "Does the response remain on-topic with the dialogue history and the user query?",
        "Does the response logically continue the progression of the dialogue?",
    ],
    # backend knowledge consistency: is the response faithful to the database results?
    [
        "Does the response accurately reflect the information in the database results?",
        "Does the response stay on-topic with the database results and the dialogue context?",
        "Does response logically incorporate and progress based on the database results?",
    ],
    # policy compliance: does the agent gather information before suggesting or booking?
    [
        "Does the response provide suggestions only when the database results are few enough to do so?",
        "Does the response request required, relevant information from the user before offering suggestions or booking services?",
        "Does the response avoid premature actions (i.e. make a booking or suggest a service) too early in the conversation, before the necessary information is gathered?",
    ],
)

In [67]:
def _lmunit_score(payload, max_retries=5, timeout=60):
    """POST one LMUnit unit test and return its score, or None if no score was obtained.

    Retries (at most ``max_retries`` attempts in total) on network / JSON-decoding errors
    and on "Too Many Requests" responses. For rate limits it sleeps for the server's
    "Retry after N seconds" hint (default 1s) plus random jitter. Any other response
    without a usable "score" is reported and treated as a failure immediately.

    Args:
        payload: dict with "query", "response" and "unit_test" keys.
        max_retries: maximum number of POST attempts.
        timeout: per-request timeout in seconds, so a stalled connection cannot hang the run.

    Returns:
        The score as a float, or None on failure. Callers should exclude None rather than
        count it as 0.
    """
    import random
    import re
    import time

    for attempt in range(1, max_retries + 1):
        try:
            r = requests.post(url, json=payload, headers=headers, timeout=timeout)
            data = r.json()
        except (requests.RequestException, ValueError) as e:
            # Transport failure or a non-JSON body (e.g. a gateway error page): back off and retry.
            print(f"    Exception during API call: {e}")
            time.sleep(2)
            continue

        if isinstance(data, dict) and "score" in data:
            try:
                return float(data["score"])
            except (TypeError, ValueError):
                print(f"    Error in API response: {data}")
                return None

        detail = data.get("detail") if isinstance(data, dict) else None
        if detail is not None and "Too Many Requests" in str(detail):
            retry_seconds = 1
            if isinstance(detail, str):
                match = re.search(r"Retry after\s*(\d+)", detail)
                if match:
                    retry_seconds = int(match.group(1))
            # Jitter avoids retrying in lock-step with other callers.
            sleep_time = retry_seconds + random.uniform(0.1, 1.0)
            print(f"    Rate limited. Sleeping for {sleep_time:.2f}s (retry {attempt}/{max_retries})")
            time.sleep(sleep_time)
            continue

        print(f"    Error in API response: {data}")
        return None

    print(f"    Maximum retries reached for API call.")
    return None


def _mean_lmunit_score(query, response, questions, call_delay=0.5, max_retries=5):
    """Run every unit-test question against (query, response) and return the mean score.

    Failed calls (None) are excluded from the mean instead of being counted as 0, which
    would drag a turn's score down because of a transport error rather than the response
    itself. Returns 0.0 only if no call succeeded (or ``questions`` is empty), and prints
    how many calls failed.
    """
    import time

    scores = []
    for q in questions:
        payload = {"query": query, "response": response, "unit_test": q}
        scores.append(_lmunit_score(payload, max_retries=max_retries))
        time.sleep(call_delay)  # presumably spacing calls to stay under the API rate limit
    valid = [s for s in scores if s is not None]
    if len(valid) < len(scores):
        print(f"    {len(scores) - len(valid)}/{len(scores)} unit-test calls failed and were excluded from the mean.")
    return sum(valid) / len(valid) if valid else 0.0


def eval_dials_lmunit(batch_dials, call_delay=0.5, turn_delay=3, dialogue_delay=5, max_retries=5):
    """Score every turn of every dialogue with LMUnit along three dimensions.

    Each turn's query is the conversation history plus the current customer utterance.
    The conversation-consistency tests (``conv_qs``) use that query as-is. The backend
    (``backend_qs``) and policy (``policy_qs``) tests additionally append the turn's
    "db" content as "Database result: ...".

    Args:
        batch_dials: dict mapping dialogue ID -> list of turn dicts. Each turn needs
            "conversation_history" and "user" strings, and may have "lex_response" or
            "response" (the evaluated reply) and "db".
        call_delay: seconds slept after each API call (default matches the original 0.5s).
        turn_delay: seconds slept after each turn (default 3s).
        dialogue_delay: seconds slept after each dialogue (default 5s).
        max_retries: per-call attempt limit passed to the API helper.

    Returns:
        dict mapping dialogue ID -> list (one entry per turn) of dicts with keys
        "conv_consistency", "backend_consistency" and "policy_completeness": the mean of
        the successful unit-test scores, rounded to 2 decimals.
    """
    import time

    lmunit_scores = {}
    for dial_id, turns in batch_dials.items():
        print(f"Processing dialogue {dial_id}...")
        lmunit_scores[dial_id] = []

        for turn_idx, turn in enumerate(turns):
            print(f"  Turn {turn_idx+1}/{len(turns)}")
            history = turn["conversation_history"] + "\nCustomer: " + turn["user"]
            response = turn.get("lex_response", turn.get("response", ""))
            db_content = json.dumps(turn.get("db", {}))
            history_with_db = f"{history}\nDatabase result: {db_content}"

            conv_score = _mean_lmunit_score(history, response, conv_qs, call_delay, max_retries)
            backend_score = _mean_lmunit_score(history_with_db, response, backend_qs, call_delay, max_retries)
            policy_score = _mean_lmunit_score(history_with_db, response, policy_qs, call_delay, max_retries)

            lmunit_scores[dial_id].append({
                "conv_consistency": round(conv_score, 2),
                "backend_consistency": round(backend_score, 2),
                "policy_completeness": round(policy_score, 2)
            })

            print(f"    Scores - Conv: {round(conv_score, 2)}, Backend: {round(backend_score, 2)}, Policy: {round(policy_score, 2)}")
            time.sleep(turn_delay)

        time.sleep(dialogue_delay)

    return lmunit_scores

#### Mock Scores for Testing the Display (No API Calls)


In [68]:
def mock_eval_dials_lmunit(batch_dials, low=0.6, high=1.0, seed=None):
    """Generate random placeholder scores shaped like ``eval_dials_lmunit`` output.

    Lets the display code be exercised without any API calls.

    Args:
        batch_dials: dict mapping dialogue ID -> sequence of turns (only its length is used).
        low: lower bound of the uniform scores (default keeps the original 0.6).
        high: upper bound of the uniform scores (default keeps the original 1.0).
            NOTE(review): the real LMUnit run in this notebook printed values around
            2.7-4.9, so e.g. ``low=1.0, high=5.0`` gives more realistic mocks; confirm
            the real scale.
        seed: if given, draws from a private ``random.Random(seed)``. This makes the mock
            reproducible and leaves the global random state untouched. If None, the
            module-level ``random`` functions are used exactly as before.

    Returns:
        dict mapping dialogue ID -> list (one per turn) of dicts with keys
        "conv_consistency", "backend_consistency" and "policy_completeness", each rounded
        to 2 decimals.
    """
    import random

    rng = random if seed is None else random.Random(seed)
    return {
        dial_id: [
            {
                "conv_consistency": round(rng.uniform(low, high), 2),
                "backend_consistency": round(rng.uniform(low, high), 2),
                "policy_completeness": round(rng.uniform(low, high), 2),
            }
            for _ in turns
        ]
        for dial_id, turns in batch_dials.items()
    }

#### Score Display Helper and LMUnit Evaluation Run

In [69]:
from IPython.display import display, HTML
import pandas as pd

def display_scores(scores_dict, show_averages=True):
    """Render per-turn LMUnit scores as a table, optionally followed by per-dialogue means.

    Args:
        scores_dict: dict mapping dialogue ID -> list of per-turn score dicts with keys
            "conv_consistency", "backend_consistency" and "policy_completeness".
        show_averages: when True and there is at least one turn, show headed sections:
            the per-turn table, then a per-dialogue mean table with an extra "OVERALL" row.
            That row is the mean of the per-dialogue means (not of all turns). Otherwise
            only the bare per-turn table is shown.

    Returns:
        The per-turn DataFrame, including a row-wise "Average Score" column.
    """
    metric_labels = {
        "conv_consistency": "Conversation Consistency",
        "backend_consistency": "Backend Consistency",
        "policy_completeness": "Policy Completeness",
    }

    records = [
        {
            "Dialogue ID": dial_id,
            "Turn": turn_number,
            **{label: turn_scores[key] for key, label in metric_labels.items()},
            "Average Score": round(sum(turn_scores[key] for key in metric_labels) / 3, 2),
        }
        for dial_id, dial_turns in scores_dict.items()
        for turn_number, turn_scores in enumerate(dial_turns, start=1)
    ]
    df = pd.DataFrame(records)

    if not show_averages or len(df) == 0:
        display(df)
        return df

    display(HTML("<h3>LMUnit Evaluation Results</h3>"))
    display(df)

    display(HTML("<h3>Summary Statistics</h3>"))
    value_columns = [*metric_labels.values(), "Average Score"]
    summary_df = df.groupby("Dialogue ID")[value_columns].mean().round(2)
    summary_df.loc["OVERALL"] = summary_df.mean().round(2)
    display(summary_df)

    return df

# Load the human-eval batch, score it with LMUnit (or reuse saved scores), save, and display.
import os

batch_path = "datasets/main_human_eval/batch.json"
lmunit_scores_path = "lmunit_scores.json"
# The LMUnit run makes several API calls per turn plus fixed sleeps, so it is slow.
# Set to False to reload previously saved scores from lmunit_scores_path instead.
RERUN_LMUNIT = True

filtered_dials = load_filter_dials(
    dials["tau_retail_eval_json"],
    dials["tau_air_eval_json"],
    batch_path,
)

if RERUN_LMUNIT or not os.path.exists(lmunit_scores_path):
    real_results = eval_dials_lmunit(filtered_dials)
    # save results to file
    with open(lmunit_scores_path, "w", encoding="utf-8") as f:
        json.dump(real_results, f, indent=2)
else:
    with open(lmunit_scores_path, "r", encoding="utf-8") as f:
        real_results = json.load(f)

# Display the results (assigned so the returned frame is not rendered a second time).
results_df = display_scores(real_results)

Processing dialogue 6...
  Turn 1/8
    Scores - Conv: 4.38, Backend: 2.99, Policy: 3.65
  Turn 2/8
    Scores - Conv: 4.69, Backend: 4.75, Policy: 4.51
  Turn 3/8
    Scores - Conv: 4.85, Backend: 3.72, Policy: 4.55
  Turn 4/8
    Scores - Conv: 4.22, Backend: 4.19, Policy: 2.93
  Turn 5/8
    Scores - Conv: 4.76, Backend: 2.7, Policy: 3.28
  Turn 6/8
    Scores - Conv: 4.86, Backend: 3.86, Policy: 3.71
  Turn 7/8
    Scores - Conv: 4.6, Backend: 4.41, Policy: 3.48
  Turn 8/8
    Scores - Conv: 4.92, Backend: 4.06, Policy: 3.64
Processing dialogue 10...
  Turn 1/5
    Scores - Conv: 4.33, Backend: 2.92, Policy: 3.85
  Turn 2/5
    Rate limited. Sleeping for 1.61s (retry 1/5)
    Rate limited. Sleeping for 1.38s (retry 1/5)
    Scores - Conv: 4.81, Backend: 4.27, Policy: 3.79
  Turn 3/5
    Rate limited. Sleeping for 1.38s (retry 1/5)
    Rate limited. Sleeping for 1.57s (retry 1/5)
    Rate limited. Sleeping for 1.39s (retry 1/5)
    Scores - Conv: 4.67, Backend: 4.77, Policy: 4.63
  