In [None]:
import os
from dotenv import load_dotenv
import google.generativeai as genai

# Load variables from a local .env file (if any) before reading credentials.
load_dotenv()

# --- Gemini configuration ---
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
EMBEDDING_MODEL = "models/text-embedding-004"  # model used for transaction embeddings
GENERATION_MODEL = "models/gemini-2.0-flash"  # model used for category suggestions
genai.configure(api_key=GEMINI_API_KEY)

# --- YNAB configuration ---
YNAB_BASE_URL = "https://api.ynab.com/v1"
YNAB_API_KEY = os.environ.get("YNAB_API_KEY")
YNAB_BUDGET_ID = os.environ.get("YNAB_BUDGET_ID")


In [66]:
import pandas as pd
import requests
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity # Or use numpy
import os
import json
import time

In [67]:
def get_ynab_headers():
    """Return the HTTP headers carrying the YNAB bearer token."""
    bearer_value = f"Bearer {YNAB_API_KEY}"
    return dict(Authorization=bearer_value)

def get_ynab_categories(budget_id, timeout=30):
    """Fetch the usable categories of a YNAB budget.

    Args:
        budget_id: Identifier of the budget whose categories are requested.
        timeout: Seconds to wait for the YNAB API before giving up. Without a
            timeout `requests` can block forever on a stalled connection.

    Returns:
        A tuple `(categories, category_map)` where `categories` is a list of
        "<group name>: <category name>" strings and `category_map` maps each
        of those strings to the category id.

    Raises:
        requests.exceptions.RequestException: on network errors, timeouts or
            non-2xx responses.
    """
    url = f"{YNAB_BASE_URL}/budgets/{budget_id}/categories"
    response = requests.get(url, headers=get_ynab_headers(), timeout=timeout)
    response.raise_for_status()  # Raise exception for bad status codes
    groups = response.json()['data']['category_groups']

    categories = []
    category_map = {}  # full category name -> category id
    for group in groups:
        # Hidden or deleted groups cannot be assigned to, so skip them entirely.
        if group.get('hidden', False) or group.get('deleted', False):
            continue
        for category in group.get('categories', []):
            if category.get('hidden', False) or category.get('deleted', False):
                continue
            full_name = f"{group['name']}: {category['name']}"
            categories.append(full_name)
            category_map[full_name] = category['id']
    return categories, category_map

In [None]:
# Download the category list once; later cells use both the names and the id map.
print(f"Fetching categories for budget ID: {YNAB_BUDGET_ID}...")
try:
    valid_categories, category_map = get_ynab_categories(YNAB_BUDGET_ID)
    print(f"Found {len(valid_categories)} valid categories.")
    if len(valid_categories) == 0:
        print("Error: No valid categories found. Check YNAB setup and budget ID.")
except Exception as fetch_error:
    # Report the problem in the cell output, then let the notebook stop here.
    print(f"Error fetching YNAB categories: {fetch_error}")
    raise

In [None]:
# Drop the two internal bookkeeping categories so they are never offered as
# choices. Filtering (instead of list.remove) keeps this cell re-runnable and
# does not raise ValueError when an entry is already absent.
INTERNAL_CATEGORIES = {
    'Internal Master Category: Uncategorized',
    'Internal Master Category: Deferred Income SubCategory',
}
valid_categories = [cat for cat in valid_categories if cat not in INTERNAL_CATEGORIES]
print('\n'.join(valid_categories))

In [None]:
# Transactions endpoint of the configured budget and the auth header sent with it.
API_ENDPOINT = "/".join([YNAB_BASE_URL, "budgets", str(YNAB_BUDGET_ID), "transactions"])
HEADERS = dict(Authorization=f"Bearer {YNAB_API_KEY}")

# --- Function to Fetch Transactions ---
def get_transactions(endpoint, headers, params, timeout=30):
    """Fetch transactions from the YNAB API.

    Which transactions come back is decided entirely by `params`; this
    function applies no filtering of its own.

    Args:
        endpoint: Full URL of the transactions endpoint.
        headers: HTTP headers (must include the Authorization header).
        params: Query-string parameters forwarded to the API.
        timeout: Seconds to wait for the API before giving up.

    Returns:
        The list of transaction dicts on success, an empty list when the
        response lacks the expected `data.transactions` structure, or `None`
        when the request or JSON decoding failed (details are printed).
    """
    print(f"Fetching transactions from budget: {YNAB_BUDGET_ID}...")
    try:
        response = requests.get(endpoint, headers=headers, params=params, timeout=timeout)
        response.raise_for_status()  # Raises HTTPError for bad responses (4XX, 5XX)

        print("Successfully fetched data.")
        data = response.json()

        if 'data' in data and 'transactions' in data['data']:
            return data['data']['transactions']

        print("Error: Unexpected response format from YNAB API.")
        print("Response:", data)
        return []

    except requests.exceptions.RequestException as e:
        print(f"Error during API request: {e}")
        error_response = getattr(e, 'response', None)
        if error_response is not None:
            print(f"Status Code: {error_response.status_code}")
            try:
                print(f"Response Body: {error_response.json()}")
            except ValueError:  # Response body isn't valid JSON
                print(f"Response Body: {error_response.text}")
            if error_response.status_code == 401:
                print("\nHint: Check if your YNAB_API_KEY is correct and hasn't expired.")
            elif error_response.status_code == 404:
                # Use the real config name; the old hint referenced an undefined
                # variable and crashed with NameError while reporting the error.
                print(f"\nHint: Check if your YNAB_BUDGET_ID ('{YNAB_BUDGET_ID}') is correct.")
        return None
    except ValueError:  # Includes JSONDecodeError
        print("Error: Could not decode JSON response from YNAB API.")
        return None

# --- Main Execution ---
# Query parameters for the transactions request. The notebook never defined
# PARAMS, so a fresh kernel failed with NameError here. An empty dict applies
# no filter, which the later cells rely on: categorized rows are used as
# examples and uncategorized rows as targets. Add filters here if you want
# a narrower download.
PARAMS = {}

transactions_list = get_transactions(API_ENDPOINT, HEADERS, PARAMS)

if transactions_list is None:
    print("Failed to retrieve transactions.")
elif not transactions_list:
    print("No transactions found.")
else:
    print(f"Found {len(transactions_list)} transactions.")
    # Create Pandas DataFrame
    df = pd.DataFrame(transactions_list)

    # The amounts are treated as milliunits (amount * 1000) and converted to
    # standard currency units.
    if 'amount' in df.columns:
        df['amount'] = df['amount'] / 1000.0

    # Convert date strings to datetime objects
    if 'date' in df.columns:
        df['date'] = pd.to_datetime(df['date'])

    # Select and reorder columns for better readability
    columns_to_show = [
        'date', 'payee_name', 'memo', 'amount', 'cleared',
        'account_name', 'category_name', 'id', 'account_id',
        'transfer_account_id', 'import_id', 'deleted', 'approved'
        # Add or remove columns as needed
    ]
    # Keep only the columns that actually exist in the dataframe
    df_display = df[[col for col in columns_to_show if col in df.columns]]

    # You can now work with the DataFrame 'df' or 'df_display'
    # e.g., df.to_csv('uncategorized_transactions.csv', index=False)

In [None]:
# df_display.to_pickle("transactions.pkl")

In [None]:
# df = pd.read_pickle("transactions.pkl")

In [22]:
def get_embeddings(examples, task_type="RETRIEVAL_DOCUMENT"):
    """Embed a collection of texts with a single embedding request.

    Each text is cut to at most 1800 characters and stripped of surrounding
    whitespace before being sent; the returned dict is keyed by the ORIGINAL
    (unmodified) text.

    Args:
        examples: Iterable of strings to embed.
        task_type: Task type forwarded to the embedding call
            (RETRIEVAL_DOCUMENT for items to index, RETRIEVAL_QUERY for the
            item whose neighbours are being looked up).

    Returns:
        dict mapping each original text to its embedding.

    Raises:
        Any exception from preparation or the embedding call is re-raised
        after a diagnostic line has been printed.
    """
    char_limit = 1800  # Example limit, adjust as needed
    try:
        # original text -> text actually sent to the model
        prepared = {raw: raw[:char_limit].strip() for raw in examples}

        response = genai.embed_content(
            model=EMBEDDING_MODEL,
            content=list(prepared.values()),
            task_type=task_type,
        )

        # strict=True: a count mismatch between inputs and embeddings is an error.
        return dict(zip(prepared, response['embedding'], strict=True))
    except Exception:
        print(f"Error getting embedding for embeddings: {examples}")
        raise


def get_batch_embeddings(texts, task_type="RETRIEVAL_DOCUMENT", batch_size=250):
    """Generate embeddings for a list of texts, `batch_size` texts at a time.

    Entries that are not strings (e.g. None / NaN for a missing payee) or that
    are blank are skipped instead of crashing the run; their slot in the
    result is None.

    Args:
        texts: Sequence of texts (may contain None / non-string values).
        task_type: Task type forwarded to `get_embeddings`.
        batch_size: Number of texts handed to `get_embeddings` per call.

    Returns:
        A list aligned with `texts` holding each text's embedding, or None
        where no embedding was produced.

    Raises:
        Any exception raised while embedding a batch (after printing the
        index at which the failing batch starts).
    """
    all_embeddings = {}  # {text: embedding}
    for start in range(0, len(texts), batch_size):
        batch_texts = texts[start:start + batch_size]
        # Only real, non-blank strings can be embedded.
        valid_texts = [t for t in batch_texts if isinstance(t, str) and t.strip()]
        if not valid_texts:
            continue

        try:
            print(f"Processing batch {start // batch_size + 1}...")
            all_embeddings.update(get_embeddings(valid_texts, task_type))
            time.sleep(0.2)  # Avoid hitting rate limits aggressively
        except Exception as e:
            print(f"Error processing batch starting at index {start}: {e}")
            raise

    # Map back to the original order; skipped entries become None.
    results = [
        all_embeddings.get(text) if isinstance(text, str) else None
        for text in texts
    ]
    print(f"Generated {sum(1 for e in results if e)} embeddings out of {len(texts)} texts.")
    return results

In [23]:
# Generate Embeddings (Batching is recommended for large datasets)

# print("Generating embeddings for all transactions...")
# # Prepare texts for embedding - use the combined field
# all_texts = df['payee_name'].tolist()
# # Use batch embedding function
# all_embeddings = get_batch_embeddings(all_texts, task_type="RETRIEVAL_DOCUMENT")
# df['embedding'] = all_embeddings

Generating embeddings for all transactions...
Processing batch 1...
Processing batch 2...
Processing batch 3...
Processing batch 4...
Processing batch 5...
Processing batch 6...
Processing batch 7...
Processing batch 8...
Processing batch 9...
Processing batch 10...
Processing batch 11...
Processing batch 12...
Processing batch 13...
Processing batch 14...
Processing batch 15...
Processing batch 16...
Processing batch 17...
Processing batch 18...
Processing batch 19...
Processing batch 20...
Processing batch 21...
Processing batch 22...
Processing batch 23...
Processing batch 24...
Processing batch 25...
Processing batch 26...
Processing batch 27...
Processing batch 28...
Processing batch 29...
Processing batch 30...
Processing batch 31...
Processing batch 32...
Processing batch 33...
Processing batch 34...
Processing batch 35...
Processing batch 36...
Processing batch 37...
Processing batch 38...
Processing batch 39...
Processing batch 40...
Processing batch 41...
Processing batch 42.

In [None]:
# df.to_pickle("transactions_embedded.pkl")

In [71]:
# Load the cached transactions-with-embeddings frame (the file name matches the
# commented-out `df.to_pickle(...)` cell above) so embeddings are not recomputed.
# NOTE: unpickling can execute arbitrary code; only load files you created yourself.
df = pd.read_pickle("transactions_embedded.pkl")

In [72]:
# chosen_category records the category picked in this session (None = not yet chosen).
df['chosen_category'] = None
# A row counts as categorized when its category is anything but 'Uncategorized'.
df['is_categorized'] = df['category_name'].ne('Uncategorized')

In [73]:
def _strip_code_fence(text):
    """Return `text` without a surrounding Markdown code fence (``` or ```json)."""
    cleaned = text.strip()
    if cleaned.startswith("```"):
        # removeprefix/removesuffix drop exact substrings. str.lstrip/rstrip take
        # a SET of characters and could eat leading/trailing payload characters.
        cleaned = cleaned.removeprefix("```").removeprefix("json")
        cleaned = cleaned.strip().removesuffix("```")
    return cleaned.strip()


def classify_with_gemini(target_transaction_text, examples, category_list):
    """Ask Gemini for likely categories of a transaction using few-shot examples.

    Args:
        target_transaction_text: dict with keys 'text' and 'amount' describing
            the transaction to classify.
        examples: iterable of dicts with keys 'text', 'amount' and 'category'
            (previously categorized transactions shown to the model).
        category_list: iterable of category names the model may choose from.

    Returns:
        The parsed JSON payload returned by the model. The prompt asks for a
        list of {"category": ..., "confidence": ...} dicts, but the payload is
        returned as parsed, without validation.

    Raises:
        Any exception raised by the generation call or by JSON parsing, after
        diagnostics have been printed.
    """
    model = genai.GenerativeModel(GENERATION_MODEL)

    category_lines = "\n".join([f'- "{cat}"' for cat in category_list])  # Format for prompt

    example_prompt_part = "\n".join(
        [f"- Text: '{ex['text']}', Amount: ${ex['amount']:.2f}, Assigned Category: '{ex['category']}'" for ex in examples]
    )

    prompt = f"""You are an expert financial transaction categorizer. Your goal is to assign the most likely category to the 'Target Transaction' based on its text (combined Payee and Memo) and amount. Use the provided examples of previously categorized transactions for guidance.

Available Categories:
{category_lines}

Examples of Categorized Transactions:
{example_prompt_part}

Target Transaction:
- Text: '{target_transaction_text['text']}', Amount: ${target_transaction_text['amount']:.2f}

Based ONLY on the information provided and the list of Available Categories, predict the top 3 most likely categories for the Target Transaction.
Also, provide your confidence level (choose one: High, Medium, Low) for each category. If you can't confidently suggest likely categories, then
return a list with fewer items or no items at all.

Respond in JSON format like this:
[
    {{"category": "Category A", "confidence": "High/Medium/Low"}},
    {{"category": "Category B", "confidence": "High/Medium/Low"}},
    {{"category": "Category C", "confidence": "High/Medium/Low"}}
]

Do not respond with any other text but the JSON payload.
"""

    response = None
    try:
        response = model.generate_content(prompt)
        return json.loads(_strip_code_fence(response.text))

    except Exception as e:
        print(f"Error during Gemini generation or parsing: {e}")
        print(f"Failed prompt text (first 100 chars): {prompt[:100]}...")
        # Reading response.text may itself fail; never let the diagnostics
        # mask the original error.
        raw_text = 'N/A'
        if response is not None:
            try:
                raw_text = response.text
            except Exception:
                raw_text = '<unavailable>'
        print(f"Raw response text: {raw_text}")
        raise

In [74]:
def find_nearest_neighbors(target_embedding, all_data, k):
    """Return the k categorized rows most similar to `target_embedding`.

    Similarity is the cosine similarity between `target_embedding` and each
    row's 'embedding'. Only rows with `is_categorized` set and a non-missing
    embedding are candidates.

    Args:
        target_embedding: Sequence of floats, or None.
        all_data: DataFrame with columns 'is_categorized', 'embedding',
            'payee_name', 'amount' and 'category_name'.
        k: Maximum number of neighbours to return.

    Returns:
        A list (most similar first) of dicts with keys 'text', 'amount' and
        'category'. Empty when there is no target embedding or no candidate.
    """
    if target_embedding is None or len(all_data) == 0:
        return []

    categorized_df = all_data[all_data['is_categorized'] & all_data['embedding'].notna()]
    if categorized_df.empty:
        return []

    target_vec = np.asarray(target_embedding, dtype=float).ravel()
    candidate_matrix = np.asarray(categorized_df['embedding'].tolist(), dtype=float)

    # Cosine similarity; zero-length vectors get similarity 0 instead of NaN.
    norms = np.linalg.norm(candidate_matrix, axis=1) * np.linalg.norm(target_vec)
    dots = candidate_matrix @ target_vec
    similarities = np.where(norms == 0, 0.0, dots / np.where(norms == 0, 1.0, norms))

    # Positions (within categorized_df) of the k highest similarities.
    nearest_positions = np.argsort(similarities)[::-1][:k]

    neighbors = []
    for pos in nearest_positions:
        # The positions refer to categorized_df, NOT all_data. Indexing all_data
        # here would return unrelated (possibly uncategorized) rows.
        neighbor_data = categorized_df.iloc[pos]
        neighbors.append({
            "text": neighbor_data['payee_name'],
            "amount": neighbor_data['amount'],
            "category": neighbor_data['category_name'],
        })
    return neighbors


In [96]:
NUM_NEIGHBORS = 25  # Number of similar examples to provide in the prompt

def determine_categories(row, all_data):
    """Store Gemini's category suggestions for `row` in `all_data['top3']`.

    Args:
        row: A row of `all_data` (needs 'embedding', 'payee_name', 'amount'
            and, when suggestions are stored, a `.name` index label).
        all_data: DataFrame holding every transaction; its 'top3' cell for
            this row is replaced with the list of valid suggestions.

    Returns:
        True when suggestions were requested and stored (the stored list may
        be empty), False when the row was skipped because it has no embedding
        or no similar categorized transactions were found.
    """
    target_embedding = row['embedding']
    if target_embedding is None:
        print("    Skipping due to missing embedding.")
        return False

    neighbors = find_nearest_neighbors(target_embedding, all_data, k=NUM_NEIGHBORS)
    if not neighbors:
        print("    Could not find similar categorized transactions.")
        return False

    target_info = {'text': row['payee_name'], 'amount': row['amount']}
    classification = classify_with_gemini(target_info, neighbors, valid_categories)

    # The model output is untrusted: tolerate a non-list payload and entries
    # that are not dicts or lack a known category.
    if not isinstance(classification, list):
        classification = []
    valid_options = [
        item for item in classification
        if isinstance(item, dict) and item.get('category') in valid_categories
    ]

    # Assign a fresh list to this row's cell. Writing in place
    # (`cell[:] = ...`) would modify a list object that may be shared by
    # several rows.
    all_data.at[row.name, 'top3'] = valid_options
    return True


In [76]:
def find_identical_uncategorized(target_text, current_df_index, full_df):
    """Return the other uncategorized rows whose payee_name equals `target_text`.

    The row labelled `current_df_index` (the one being labelled right now) is
    never part of the result.
    """
    same_payee = full_df['payee_name'] == target_text
    not_yet_categorized = full_df['is_categorized'] == False  # noqa: E712 (element-wise)
    is_other_row = full_df.index != current_df_index
    return full_df[same_payee & not_yet_categorized & is_other_row]

In [77]:
import math

def find_partial_match_payee(target_text, current_df_index, df, *, match_frac=0.5):
    """Find other uncategorized rows whose payee shares a prefix or suffix with `target_text`.

    The first and the last `ceil(len(target_text) * match_frac)` characters of
    `target_text` (at least one character) are compared against the start and
    the end of every payee name. Rows with a missing payee never match.

    Args:
        target_text: Payee name of the transaction being labelled.
        current_df_index: Index label of that transaction; it is excluded.
        df: DataFrame with 'payee_name' and 'is_categorized' columns.
        match_frac: Fraction of `target_text` that must match (default 0.5).

    Returns:
        The matching rows of `df`; an empty DataFrame with df's columns when
        `target_text` is empty.
    """
    if not target_text:
        print("Target text cannot be empty.")
        return pd.DataFrame(columns=df.columns)  # Return empty DataFrame

    # Always compare at least one character: a length of 0 would turn
    # `target_text[-0:]` into the whole string.
    match_len = max(1, math.ceil(len(target_text) * match_frac))

    target_prefix = target_text[:match_len]
    target_suffix = target_text[-match_len:]

    payees = df['payee_name']
    # Only consider payees at least as long as the required match length.
    long_enough = payees.str.len() >= match_len

    # na=False: a missing payee is "no match" rather than NaN in the mask.
    prefix_match = payees.str.startswith(target_prefix, na=False) & long_enough
    suffix_match = payees.str.endswith(target_suffix, na=False) & long_enough

    return df[
        (prefix_match | suffix_match)
        & (df.index != current_df_index)
        & (df['is_categorized'] == False)  # noqa: E712 (element-wise comparison)
    ]

In [78]:
# Initialize the column that holds the suggestion list of each transaction.
# Every row gets its OWN empty list: `[[]] * len(df)` would put the same list
# object in every row, so an in-place update of one row would show up in all.
df['top3'] = pd.Series([[] for _ in range(len(df))], index=df.index, dtype='object')

In [None]:
import ipywidgets as widgets
from IPython.display import display, clear_output
import pandas as pd
import functools

# Work queue: index labels of rows that are neither categorized nor approved.
uncategorized_indices = df.index[~df['is_categorized'] & ~df['approved']].tolist()
total_to_process = len(uncategorized_indices)
# Position of the transaction currently shown, within uncategorized_indices.
current_index_pos = 0

# --- UI Widgets ---

# Progress display: a text label above a bar.
progress_label = widgets.Label("Progress")
progress_bar = widgets.IntProgress(
    value=0,
    min=0,
    max=max(1, total_to_process),  # keep max >= 1 even when nothing is pending
    bar_style='info',
    orientation='horizontal',
    layout=dict(width='100%'),
)
progress_box = widgets.VBox([progress_label, progress_bar])

# Details of the transaction currently being labelled.
transaction_info = widgets.HTML(value="", layout=dict(border='1px solid black', padding='5px'))

# Shown only while suggestions are being requested.
loading_label = widgets.Label("🧠 Asking Gemini for suggestions...")
loading_indicator = widgets.VBox([loading_label], layout=dict(display='none'))

# --- Stage 1: pick a category for the current transaction ---
stage1_label = widgets.HTML("<b>Stage 1: Choose Category for Current Transaction</b>")
suggestion_buttons_box = widgets.VBox([])  # filled with one button per suggestion
dropdown_options = [("Select...", None)] + [(cat, cat) for cat in sorted(valid_categories)]
category_dropdown = widgets.Dropdown(
    options=dropdown_options,
    description="Or Choose:",
    disabled=True,
)
stage1_box = widgets.VBox(
    [stage1_label, suggestion_buttons_box, category_dropdown],
    layout=dict(margin='10px 0 0 0', display='none'),  # hidden until needed
)

# --- Stage 2: optionally apply the choice to similar transactions ---
stage2_label = widgets.HTML("<b>Stage 2: Apply to Similar Transactions?</b>")
similar_transactions_header = widgets.HTML("")
similar_transactions_checkboxes = widgets.VBox([])  # filled with one checkbox per candidate
apply_checked_button = widgets.Button(
    description="Apply to Checked Transactions",
    button_style="success",
    icon='check',
)
skip_to_next_button = widgets.Button(
    description="Apply to Current Only (Skip Others)",
    icon='forward',
)
stage2_buttons = widgets.HBox([apply_checked_button, skip_to_next_button])
stage2_box = widgets.VBox(
    [stage2_label, similar_transactions_header, similar_transactions_checkboxes, stage2_buttons],
    layout=dict(margin='10px 0 0 0', display='none'),  # hidden until needed
)

# --- Controls available in every stage ---
cancel_button = widgets.Button(description="Cancel All", button_style="danger", icon='stop')
skip_button = widgets.Button(description="Skip", icon='forward')
actions_box = widgets.HBox([cancel_button, skip_button])
status_output = widgets.Output()

# --- Overall layout, top to bottom ---
main_ui_box = widgets.VBox([
    progress_box,
    transaction_info,
    loading_indicator,
    stage1_box,   # visible during stage 1
    stage2_box,   # visible during stage 2
    actions_box,  # Cancel / Skip
    status_output,
])

# --- Core Logic Functions ---

@status_output.capture(clear_output=True)
def move_to_next_transaction():
    """Show the transaction that follows the current queue position."""
    following_pos = current_index_pos + 1
    update_ui_for_transaction(following_pos)


@status_output.capture(clear_output=True)
def apply_category_and_advance(category, base_index, indices_to_update):
    """Label the given rows with `category`, report it, then show the next transaction.

    Rows that are missing from `df` or already categorized are left untouched.
    The printed counts are based on the number of distinct indices requested.
    """
    global df
    unique_indices = list(set(indices_to_update))  # drop duplicate indices

    for row_label in unique_indices:
        if row_label not in df.index:
            continue
        if df.loc[row_label, 'is_categorized']:
            continue
        df.loc[row_label, 'chosen_category'] = category
        df.loc[row_label, 'category_name'] = category  # keep the main category in sync
        df.loc[row_label, 'is_categorized'] = True

    requested = len(unique_indices)
    if base_index in unique_indices:
        print(f"Categorized current transaction (Index {base_index}) as '{category}'.")
        extra = requested - 1
        if extra > 0:
            plural = 's' if extra > 1 else ''
            print(f"Also applied to {extra} similar transaction{plural}.")
    else:
        # Not reachable from the current callers; kept as a safety net.
        plural = 's' if requested > 1 else ''
        print(f"Categorized {requested} similar transaction{plural} as '{category}'.")

    move_to_next_transaction()


@status_output.capture(clear_output=True)
def transition_to_stage2(selected_category, base_index):
    """Offer to apply `selected_category` to transactions similar to `base_index`.

    Looks for other uncategorized rows with an identical payee (pre-checked)
    or a partially matching payee (unchecked). When any exist, Stage 2 is shown
    with one checkbox per candidate; otherwise the category is applied to the
    current transaction only and the UI advances immediately.

    Args:
        selected_category: Category chosen in Stage 1.
        base_index: Index label (in `df`) of the transaction being labelled.
    """
    transaction = df.loc[base_index]

    # Disable Stage 1 controls so the choice cannot be changed mid-way.
    category_dropdown.disabled = True
    for btn in suggestion_buttons_box.children:
        if isinstance(btn, widgets.Button):
            btn.disabled = True

    payee = transaction['payee_name']
    identical_df = find_identical_uncategorized(payee, base_index, df)
    similar_df = find_partial_match_payee(payee, base_index, df)

    if identical_df.empty and similar_df.empty:
        # No similar transactions found - apply to current and move on immediately
        print(f"No similar uncategorized transactions found for Index {base_index}.")
        apply_category_and_advance(selected_category, base_index, [base_index])
        return

    stage1_box.layout.display = 'none'  # Hide Stage 1

    def make_checkbox(idx, row, checked):
        """Build the checkbox describing one candidate transaction."""
        return widgets.Checkbox(
            value=checked,
            description=f"Row {idx}: {row['date']} / {row['account_name']} / {row['payee_name']} / ${row['amount']:.2f} / {row['memo']}",
            indent=False,
            layout={'width': '95%'},  # Allow wrapping
        )

    checkbox_widgets = []
    checkbox_indices = []  # index label of each checkbox, same order
    seen = set()           # fast duplicate check between the two candidate sets

    # Exact payee matches are pre-selected.
    for idx, row in identical_df.iterrows():
        checkbox_widgets.append(make_checkbox(idx, row, True))
        checkbox_indices.append(idx)
        seen.add(idx)

    # Partial matches are offered but not selected by default.
    for idx, row in similar_df.iterrows():
        if idx in seen:
            continue
        checkbox_widgets.append(make_checkbox(idx, row, False))
        checkbox_indices.append(idx)
        seen.add(idx)

    similar_transactions_header.value = f"Found {len(checkbox_indices)} similar transaction(s) that should be '{selected_category}':"
    similar_transactions_checkboxes.children = checkbox_widgets

    # Store data needed by Stage 2 handlers (as attributes on the container)
    stage2_box.selected_category = selected_category
    stage2_box.base_index = base_index
    stage2_box.checkbox_widgets = checkbox_widgets
    stage2_box.checkbox_indices = checkbox_indices

    def handle_apply_checked(b):
        """Apply the category to the current row plus every checked candidate."""
        checked_indices = [
            idx
            for idx, cb in zip(stage2_box.checkbox_indices, stage2_box.checkbox_widgets)
            if cb.value
        ]
        apply_category_and_advance(
            stage2_box.selected_category,
            stage2_box.base_index,
            [stage2_box.base_index] + checked_indices,
        )

    def handle_skip_others(b):
        """Apply the category to the current row only."""
        apply_category_and_advance(
            stage2_box.selected_category,
            stage2_box.base_index,
            [stage2_box.base_index],
        )

    # Replace the handlers registered by a previous call through the public
    # on_click(..., remove=True) API; otherwise each click would fire once per
    # transaction that has gone through Stage 2.
    for button, attr_name, handler in (
        (apply_checked_button, '_stage2_apply_handler', handle_apply_checked),
        (skip_to_next_button, '_stage2_skip_handler', handle_skip_others),
    ):
        previous_handler = getattr(stage2_box, attr_name, None)
        if previous_handler is not None:
            button.on_click(previous_handler, remove=True)
        button.on_click(handler)
        setattr(stage2_box, attr_name, handler)

    # Display Stage 2
    stage2_box.layout.display = 'flex'


@status_output.capture(clear_output=True)
def update_ui_for_transaction(index_pos):
    """Show Stage 1 for the first still-uncategorized transaction at or after `index_pos`.

    Updates the global queue position, the progress display and the
    transaction details, requests category suggestions and builds one button
    per suggestion. When the queue is exhausted the whole UI is replaced by a
    "Finished" banner. If suggestions cannot be obtained, the error is printed
    and the dropdown remains available for a manual choice.

    Args:
        index_pos: Position within `uncategorized_indices` to start from.
    """
    global current_index_pos
    from html import escape  # local import: only needed for the details panel

    current_index_pos = index_pos  # Update the global pointer

    # Skip entries that were categorized meanwhile by a bulk "apply to similar".
    while current_index_pos < total_to_process:
        df_index = uncategorized_indices[current_index_pos]
        if not df.loc[df_index, 'is_categorized']:
            break
        current_index_pos += 1

    # --- Check for Completion ---
    if current_index_pos >= total_to_process:
        main_ui_box.children = [widgets.HTML(f"<h2>Finished! Processed {total_to_process} transactions.</h2>")]
        return

    # --- Get Current Transaction ---
    df_index = uncategorized_indices[current_index_pos]
    transaction = df.loc[df_index]

    # --- Update Progress Bar ---
    # Count only rows of this session's queue. Uncategorized rows outside the
    # queue (e.g. already approved ones) must not distort the progress.
    completed = int(df.loc[uncategorized_indices, 'is_categorized'].sum())
    progress_bar.value = completed
    progress_label.value = f'Processing: {100.0 * completed / total_to_process:.2f}% complete ({completed}/{total_to_process})'

    # --- Display Transaction Details ---
    def field(name):
        """Return a transaction field as HTML-safe text (payee/memo are free text)."""
        return escape(str(transaction[name]))

    transaction_info.value = f"""\
<b>Index:</b> {escape(str(df_index))}<br>
<b>Date:</b> {field('date')}<br>
<b>Account:</b> {field('account_name')}<br>
<b>Payee:</b> {field('payee_name')}<br>
<b>Memo:</b> {field('memo')}<br>
<b>Amount:</b> ${transaction['amount']:.2f}
"""

    # --- Reset UI Elements for Stage 1 ---
    status_output.clear_output()
    stage1_box.layout.display = 'flex'  # Show Stage 1
    stage2_box.layout.display = 'none'  # Hide Stage 2
    suggestion_buttons_box.children = []

    category_dropdown.value = None
    category_dropdown.disabled = False

    # --- Ask for suggestions ---
    loading_indicator.layout.display = 'flex'
    try:
        has_suggestions = determine_categories(transaction, df)
    except Exception as exc:
        # Keep the UI usable: report the problem and fall back to manual choice.
        print(f"Could not get suggestions for Index {df_index}: {exc}")
        has_suggestions = False
    finally:
        loading_indicator.layout.display = 'none'

    # Only trust the stored list when it was refreshed for THIS transaction.
    suggestions = df.loc[df_index, 'top3'] if has_suggestions else []

    # --- Populate Suggestion Buttons ---
    new_buttons = []
    for suggestion in suggestions:
        cat_name = suggestion.get('category', None)
        if cat_name is None:
            print("Skipping None category")
            continue

        score = suggestion.get('confidence', None)
        label = f"{cat_name}" + (f" ({score})" if score is not None else "")
        button = widgets.Button(description=label, button_style='info', layout={'width': 'auto'}, icon='tag')

        # partial binds the category name and row index to this button's click
        button.on_click(functools.partial(handle_category_selection, category_name=cat_name, index=df_index))
        new_buttons.append(button)

    if not new_buttons:
        new_buttons.append(widgets.Label("No automatic suggestions."))

    suggestion_buttons_box.children = new_buttons

# --- Event Handlers ---


@status_output.capture(clear_output=True)
def handle_category_selection(button_or_change, category_name=None, index=None):
    """Handle a category choice coming from a suggestion button or the dropdown.

    Args:
        button_or_change: The clicked `widgets.Button`, or the change dict
            passed to observers of the dropdown's `value`.
        category_name: Category bound to the clicked button (button clicks only).
        index: Row index bound to the clicked button (button clicks only).

    Raises:
        RuntimeError: when called with anything that is neither a Button nor
            a change dict.
    """
    selected_category = None
    base_index = None

    if isinstance(button_or_change, widgets.Button):  # Clicked a suggestion button
        if button_or_change.disabled:
            print("Event fired again even though the button is disabled!")
            return
        selected_category = category_name
        base_index = index
    elif isinstance(button_or_change, dict):  # Changed dropdown value
        if button_or_change['type'] == 'change' and button_or_change['name'] == 'value':
            selected_category = button_or_change['new']
            base_index = uncategorized_indices[current_index_pos]  # Get current index
    else:
        raise RuntimeError("Expected event source", button_or_change)

    # Ignore "Select..." (None) and anything that is not a known category.
    # The controls stay enabled in that case; disabling them first, as before,
    # left the user unable to pick a category after choosing "Select...".
    if (selected_category is None
            or selected_category not in valid_categories
            or base_index is None):
        return

    # Valid choice: lock the Stage 1 controls, then continue with Stage 2.
    category_dropdown.disabled = True
    for btn in suggestion_buttons_box.children:
        if isinstance(btn, widgets.Button):
            btn.disabled = True

    transition_to_stage2(selected_category, base_index)


@status_output.capture(clear_output=True)
def on_cancel_click(button):
    """Replace the whole UI with a cancellation notice when "Cancel All" is clicked."""
    cancelled_banner = widgets.HTML("<h2>Operation cancelled by user.</h2>")
    main_ui_box.children = [cancelled_banner]


@status_output.capture(clear_output=True)
def on_skip_click(button):
    """Leave the current transaction unlabelled and show the next one."""
    next_pos = current_index_pos + 1
    update_ui_for_transaction(next_pos)

# --- Wire up the static controls ---
# Suggestion buttons are wired per transaction inside update_ui_for_transaction;
# the dropdown shares the same handler.
category_dropdown.observe(handle_category_selection, names='value')
cancel_button.on_click(on_cancel_click)
skip_button.on_click(on_skip_click)

# --- Start the session ---
if total_to_process <= 0:
    display(widgets.HTML("<h2>No uncategorized transactions found to process.</h2>"))
else:
    display(main_ui_box)
    update_ui_for_transaction(0)  # begin with the first queued transaction

VBox(children=(VBox(children=(Label(value='Progress'), IntProgress(value=0, bar_style='info', layout=Layout(wi…

In [98]:
# Rows that received a category in this session, are not yet approved and have
# a transaction id to address them by.
transactions_to_update_df = df[
    (df['chosen_category'].notna()) &
    (df['chosen_category'] != '') &
    (df['approved'] == False) &
    (df['id'].notna())
].copy()  # Use .copy() to avoid SettingWithCopyWarning

print(f"Found {len(transactions_to_update_df)} transactions with chosen categories to update in YNAB.")

# Validate every chosen category up front. An explicit error (instead of
# `assert`, which is skipped under `python -O`) lists all unknown names at once.
unknown_categories = sorted(set(transactions_to_update_df['chosen_category']) - set(category_map))
if unknown_categories:
    raise ValueError(f"Chosen categories missing from category_map: {unknown_categories}")

transactions_payload = [
    {
        "id": transaction_id,
        "category_id": category_map[category_name],
        "approved": True,
    }
    for transaction_id, category_name in zip(
        transactions_to_update_df['id'], transactions_to_update_df['chosen_category']
    )
]

if not transactions_payload:
    # Do not send a request with an empty transaction list.
    print("Nothing to update; skipping YNAB API request.")
else:
    update_url = f"{YNAB_BASE_URL}/budgets/{YNAB_BUDGET_ID}/transactions"
    headers = {
        'Authorization': f'Bearer {YNAB_API_KEY}',
        'Content-Type': 'application/json'
    }
    payload = json.dumps({"transactions": transactions_payload})

    print("Sending update request to YNAB API...")
    # timeout: never hang forever on a stalled connection
    response = requests.patch(update_url, headers=headers, data=payload, timeout=60)
    response.raise_for_status()  # Raise HTTPError for bad responses

    print("YNAB API Update Response:")
    print(f"Status Code: {response.status_code}")

    # Report how many transaction ids the API says it updated.
    response_data = response.json().get('data', {})
    updated_ids = response_data.get('transaction_ids', [])

    print(f"Successfully updated {len(updated_ids)} transactions in YNAB.")


Found 485 transactions with chosen categories to update in YNAB.
Sending update request to YNAB API...
YNAB API Update Response:
Status Code: 200
Successfully updated 487 transactions in YNAB.
