# Homework 5: Multimodal

In [None]:
import polars as pl
from polars import DataFrame
from pathlib import Path

## Load Labeled Traces

In [None]:
# Load the labeled traces (one JSON object per line) into a Polars DataFrame
traces_path = Path("reference_files") / "labeled_traces.jsonl"
df: DataFrame = pl.read_ndjson(traces_path)

# Quick sanity check: overall size and the available columns
print("Shape:", df.shape)
print("\nColumns:", df.columns)
df.head()

In [None]:
# Display schema to see nested structure
df.schema  # last expression renders the polars Schema (column name -> dtype) as the cell output

## Pipeline State Taxonomy

| # | State | Description |
|---|--------------------|-------------------------------------------|
| 1 | `ParseRequest`     | LLM interprets the user's message         |
| 2 | `PlanToolCalls`    | LLM decides which tools to invoke         |
| 3 | `GenCustomerArgs`  | LLM constructs arguments for customer DB  |
| 4 | `GetCustomerProfile` | Executes customer-profile tool         |
| 5 | `GenRecipeArgs`    | LLM constructs arguments for recipe DB    |
| 6 | `GetRecipes`       | Executes recipe-search tool               |
| 7 | `GenWebArgs`       | LLM constructs arguments for web search   |
| 8 | `GetWebInfo`       | Executes web-search tool                  |
| 9 | `ComposeResponse`  | LLM drafts the final answer               |
|10 | `DeliverResponse`  | Agent sends the answer                    |

In [None]:
# Pipeline State Taxonomy in execution order; list position defines the state number.
state_order = [
    'ParseRequest',        # 1
    'PlanToolCalls',       # 2
    'GenCustomerArgs',     # 3
    'GetCustomerProfile',  # 4
    'GenRecipeArgs',       # 5
    'GetRecipes',          # 6
    'GenWebArgs',          # 7
    'GetWebInfo',          # 8
    'ComposeResponse',     # 9
    'DeliverResponse',     # 10
]

# State name -> 1-based taxonomy number (used later to measure jumps between states).
state_numbers = dict(zip(state_order, range(1, len(state_order) + 1)))
state_numbers

## Extract State Sequences

In [None]:
import re

def extract_state_sequence(messages):
    """
    Extract numbered state sequence with USER_N and ASSISTANT_N labels.

    Messages are classified in order:
    - If the content contains ``TOOL_CALL[<name>]``, the first ``<name>`` found is
      appended. This check runs before the role is considered.
    - Otherwise a user message becomes "USER_1", "USER_2", ... and an
      assistant/agent message becomes "ASSISTANT_1", "ASSISTANT_2", ...
      Role matching is case-insensitive.
    - Messages with any other role (e.g. "system") are skipped.

    Args:
        messages: List of message dicts with 'role' and 'content' keys. A null
            (None) role or content is treated as an empty string instead of
            raising.

    Returns:
        List of state names/turns in order:
        - "USER_1", "USER_2", etc. for user messages
        - "ASSISTANT_1", "ASSISTANT_2", etc. for assistant messages without TOOL_CALL
        - State name (e.g., "ParseRequest") for TOOL_CALL messages
    """
    states = []
    pattern = r'TOOL_CALL\[([^\]]+)\]'
    user_count = 0
    assistant_count = 0

    for msg in messages:
        # Null JSON fields arrive as None; re.search(None) raises TypeError and
        # None.lower() raises AttributeError, so normalise to "" first.
        role = (msg['role'] or '').lower()
        content = msg['content'] or ''

        # Check if this is a TOOL_CALL message (only the first match is recorded)
        match = re.search(pattern, content)
        if match:
            states.append(match.group(1))
        elif role == 'user':
            user_count += 1
            states.append(f'USER_{user_count}')
        elif role in ('assistant', 'agent'):
            assistant_count += 1
            states.append(f'ASSISTANT_{assistant_count}')

    return states


# Derive each conversation's state sequence from its raw message list
df_with_states: DataFrame = df.with_columns(
    state_sequence=pl.col('messages').map_elements(
        extract_state_sequence,
        return_dtype=pl.List(pl.String),
    )
)

def truncate_at_first_failure(state_sequence, first_failure_state):
    """
    Cut a state sequence just after the failure state.

    Args:
        state_sequence: List of state strings.
        first_failure_state: State to cut at. It may be None or absent from the
            sequence, in which case nothing is cut.

    Returns:
        The prefix of ``state_sequence`` up to and including the first occurrence
        of ``first_failure_state``. If that state does not occur, the original
        list object is returned unchanged.
    """
    # NOTE(review): if the same state name appears earlier as a *successful*
    # call, this cuts at that earlier occurrence. Confirm that a state cannot
    # succeed before it fails within one conversation.
    if first_failure_state not in state_sequence:
        return state_sequence
    cut = state_sequence.index(first_failure_state) + 1
    return state_sequence[:cut]


# Add the truncated sequence (cut at the first failure), then the length of the
# full and truncated sequences for comparison.
df_with_states: DataFrame = (
    df_with_states
    .with_columns(
        truncated_sequence=pl.struct('state_sequence', 'first_failure_state').map_elements(
            lambda row: truncate_at_first_failure(row['state_sequence'], row['first_failure_state']),
            return_dtype=pl.List(pl.String),
        )
    )
    .with_columns(
        seq_len_all=pl.col('state_sequence').list.len(),
        seq_len_working=pl.col('truncated_sequence').list.len(),
    )
)

# Display examples
print("Sample state sequences with numbered USER/ASSISTANT turns:")
df_with_states.select([
    'conversation_id',
    'truncated_sequence',
    'last_success_state',
    'first_failure_state',
    'seq_len_all',
    'seq_len_working',
]).head(5)

# State Pairs

These are interesting: the top pairs are the ones we mostly care about, but there is a long tail as well.

The tail includes some transitions you might not expect:
* PlanToolCalls -> DeliverResponse (giving up immediately eh!)
* GenCustomerArgs -> ComposeResponse

...why is the agent ever allowed to generate tool arguments and then not call the tool?

In [None]:
# Count occurrences of each (last_success_state, first_failure_state) pairing.
# polars' group_by does not preserve order, and sorting on `count` alone leaves
# ties in arbitrary order. The state names are added as tie-breakers so the
# ranking is reproducible across runs.
state_pairs: DataFrame = (
    df_with_states
    .group_by(["last_success_state", "first_failure_state"])
    .agg(pl.len().alias("count"))
    .sort(
        ["count", "last_success_state", "first_failure_state"],
        descending=[True, False, False],
    )
    .with_columns(
        # Taxonomy distance from the last success to the first failure. The
        # result is null when either state is null or not listed in state_numbers.
        (
            pl.col("first_failure_state").replace_strict(state_numbers, default=None)
            - pl.col("last_success_state").replace_strict(state_numbers, default=None)
        ).alias("state_jump")
    )
)

print(f"Total unique pairings: {len(state_pairs)}")
state_pairs

## State Transition Heatmap

The heatmap shows failures clustering around two places: generating the arguments for the recipe-search tool, and handling searches that return no results. Further right (later in the pipeline), failures are less common.

In [None]:
import altair as alt
# Heatmap of how often each (last success -> first failure) pairing occurs.
# Both axes are ordered by `state_order` (pipeline order) rather than alphabetically.
heatmap = (
    alt.Chart(state_pairs)
    .mark_rect()
    .encode(
        x=alt.X(
            'first_failure_state:N',
            title='First Failure State',
            sort=state_order,
            axis=alt.Axis(labelAngle=-45),
        ),
        y=alt.Y(
            'last_success_state:N',
            title='Last Success State',
            sort=state_order,
        ),
        color=alt.Color(
            'count:Q',
            title='Count',
            scale=alt.Scale(scheme='viridis'),
        ),
        tooltip=[
            alt.Tooltip('first_failure_state:N', title='First Failure'),
            alt.Tooltip('last_success_state:N', title='Last Success'),
            alt.Tooltip('count:Q', title='Count'),
        ],
    )
    .properties(
        title='State Transition Prevalence: Last Success → First Failure',
        width=500,
        height=400,
    )
)

heatmap

In [None]:
from polars import DataFrame

# TODO - make this highlight the error step
# TODO - make this put out the row number in the dataframe - the viewer page doesnt take the id,

def get_examples_for_top_pairings(
    df: DataFrame,
    state_pairs: DataFrame,
    top_n: int = 5,
    examples_per_pairing: int = 2,
):
    """
    Get example conversations for the most prevalent state pairings.

    Args:
        df: Original dataframe with conversation data. It must have
            'last_success_state' and 'first_failure_state' columns.
        state_pairs: Dataframe with state pairing counts ('last_success_state',
            'first_failure_state', 'count'). Its first ``top_n`` rows are used
            in their current order.
        top_n: Number of top pairings to show examples for.
        examples_per_pairing: Maximum number of example conversations per pairing.

    Returns:
        Dictionary mapping (last_success, first_failure, count) to a list of
        example conversations (row dicts from ``df``).
    """
    examples: dict[tuple, list[dict]] = {}

    for row in state_pairs.head(top_n).iter_rows(named=True):
        last_success = row["last_success_state"]
        first_failure = row["first_failure_state"]

        # eq_missing treats null == null as a match. group_by produces a group
        # for null keys, and plain `==` would never match those rows.
        matching = df.filter(
            pl.col("last_success_state").eq_missing(last_success)
            & pl.col("first_failure_state").eq_missing(first_failure)
        ).head(examples_per_pairing)

        examples[(last_success, first_failure, row["count"])] = matching.to_dicts()

    return examples

# TODO - make this render something nicer than plain text - can we generate markdown in jupyter from python and have it render?
def print_example_conversations(examples):
    """
    Print each pairing's example conversations as plain text.

    Args:
        examples: Mapping of (last_success, first_failure, count) to a list of
            conversation dicts. Each dict has 'conversation_id' and 'messages'
            (a list of dicts with 'role' and 'content').
    """
    rule = "=" * 80
    for key, conversations in examples.items():
        last_success, first_failure, count = key
        print(f"\n{rule}")
        print(f"Pairing: {last_success} → {first_failure} (Count: {count})")
        print(rule)

        for number, conversation in enumerate(conversations, start=1):
            print(f"\n--- Example {number} ---")
            print(f"Conversation ID: {conversation['conversation_id']}")
            print("\nMessages:")
            for message in conversation["messages"]:
                print(f"  [{message['role'].upper()}]: {message['content']}")
            print()

# "PlanToolCalls" -> "GenRecipeArgs"	10

Looking at a few of these, perhaps the issue is that users are expressing requirements that don't align with the tool's arguments?

> Error: unable to generate appropriate recipe search parameters

> Error: Unable to generate recipe search parameters

> Error: Unable to generate recipe arguments from the provided request.

> Error: unable to generate recipe arguments due to missing or incompatible input data.

> Error: unable to generate appropriate recipe arguments based on the provided dietary and serving information.

In [None]:
# Select the pairing by name rather than by row position, so this cell keeps
# showing PlanToolCalls -> GenRecipeArgs even if the ranking changes.
pairing = state_pairs.filter(
    (pl.col("last_success_state") == "PlanToolCalls")
    & (pl.col("first_failure_state") == "GenRecipeArgs")
)
examples = get_examples_for_top_pairings(df, pairing, top_n=1, examples_per_pairing=5)
print_example_conversations(examples)


# "GetCustomerProfile" -> "GetRecipes"	9

* Several of these are the same query (` I need a gluten-free dinner idea for four`), and it looks like there aren't any matching recipes in the DB.
    * Not sure this is a _failure_: it's correctly saying it doesn't have any recipes for that requirement.
* `7e827a7a-6114-460e-844d-9afcad1410b8` looks like the retriever may have been broken at query time.

In [None]:
# Select the pairing by name rather than by row position, so this cell keeps
# showing GetCustomerProfile -> GetRecipes even if the ranking changes.
pairing = state_pairs.filter(
    (pl.col("last_success_state") == "GetCustomerProfile")
    & (pl.col("first_failure_state") == "GetRecipes")
)
examples = get_examples_for_top_pairings(df, pairing, top_n=1, examples_per_pairing=5)
print_example_conversations(examples)

# "GenCustomerArgs"	-> "GetRecipes"	8

These look like the same `GetRecipes` failure modes; it's just that, for whatever reason, the planner didn't schedule the `GetCustomerProfile` tool call first for these.

In [None]:
# Select the pairing by name rather than by row position, so this cell keeps
# showing GenCustomerArgs -> GetRecipes even if the ranking changes.
pairing = state_pairs.filter(
    (pl.col("last_success_state") == "GenCustomerArgs")
    & (pl.col("first_failure_state") == "GetRecipes")
)
examples = get_examples_for_top_pairings(df, pairing, top_n=1, examples_per_pairing=5)
print_example_conversations(examples)