In [2]:
#!/usr/bin/env python
# coding: utf-8

# --- Imports ---
import torch
import torchvision.transforms as transforms
from torch.utils.data import Dataset, Subset
from PIL import Image
import numpy as np
import os
from tqdm.auto import tqdm
import time
import datetime
import csv
import random
import pandas as pd
import ipywidgets as widgets
from IPython.display import display, clear_output
import threading
import io
# --- REMOVED: ipyevents import ---

# --- Assuming these are accessible ---
import config
from importnb import Notebook
with Notebook():
    from jetbot_dataset import JetbotDataset, load_train_test_split

# --- Configuration ---
AGGREGATE_CSV_PATH = config.CSV_PATH
AGGREGATE_DATA_DIR = config.DATA_DIR
DISPLAY_IMAGE_SIZE = config.IMAGE_SIZE
NUM_PREV_FRAMES = config.NUM_PREV_FRAMES
# Destination of the "Save All Rewards" button.
OUTPUT_REWARD_CSV = os.path.join(config.OUTPUT_DIR, "slider_reward_labels.csv")

# --- Decide which dataset to use ---
USE_SUBSET = "train"  # Options: None, "train", "test"
DATASET_SPLIT_FILENAME = config.SPLIT_DATASET_FILENAME

# Fail fast on a typo, before the dataset and split file are loaded below.
# Any falsy value (e.g. None) still means "use the full dataset".
if USE_SUBSET and USE_SUBSET not in ("train", "test"):
    raise ValueError(
        f"Invalid USE_SUBSET value: {USE_SUBSET!r} (expected None, 'train' or 'test')."
    )

# --- Load Full Dataset (needed for accessing the original dataframe) ---
# No transform is applied here; this instance is also the source of the
# `dataframe` used when saving labels.
full_dataset_for_metadata = JetbotDataset(
    csv_path=AGGREGATE_CSV_PATH, data_dir=AGGREGATE_DATA_DIR,
    image_size=DISPLAY_IMAGE_SIZE, num_prev_frames=NUM_PREV_FRAMES,
    transform=None,
)

# --- Load the Dataset for Labeling (Full or Subset) ---
if USE_SUBSET:
    train_dataset_subset, test_dataset_subset = load_train_test_split(
        full_dataset_for_metadata, DATASET_SPLIT_FILENAME
    )
    if train_dataset_subset is None or test_dataset_subset is None:
        raise FileNotFoundError(
            f"Dataset split file '{DATASET_SPLIT_FILENAME}' not found or invalid."
        )
    # USE_SUBSET was validated above, so it is either "train" or "test" here.
    labeling_dataset = train_dataset_subset if USE_SUBSET == "train" else test_dataset_subset
    print(f"Using {USE_SUBSET} Subset ({len(labeling_dataset)} sequences)")
else:
    labeling_dataset = full_dataset_for_metadata
    print(f"Using Full Dataset ({len(labeling_dataset)} sequences)")

# --- Labeling state (module-level; read and updated by the callbacks below) ---
reward_labels = {}  # original dataframe index -> reward (clipped to [0, 1] when stored)
current_labeling_index = 0  # initial position within labeling_dataset
is_auto_advancing = False  # True while the auto-advance timer chain is active
auto_advance_timer = None  # pending threading.Timer, or None

# --- UI widgets ---
# Position within labeling_dataset; the upper bound is clamped to 0 when the dataset is empty.
index_slider = widgets.IntSlider(
    value=current_labeling_index,
    min=0,
    max=max(len(labeling_dataset) - 1, 0),
    step=1,
    description='Sequence Index:',
    continuous_update=False,
    layout=widgets.Layout(width='80%'),
)

# Reward for the displayed frame: range [0, 1], resolution 0.01.
# continuous_update=True reports value changes while the handle is being dragged.
reward_slider = widgets.FloatSlider(
    value=0.0,
    min=0.0,
    max=1.0,
    step=0.01,
    description='Reward:',
    continuous_update=True,
    orientation='horizontal',
    readout=True,
    readout_format='.2f',
    layout=widgets.Layout(width='80%'),
)

save_button = widgets.Button(description="Save All Rewards", button_style='success')

# Auto-advance controls: delay in seconds between steps, and a start/stop toggle.
speed_slider = widgets.FloatSlider(
    value=0.1,
    min=0.01,
    max=0.2,
    step=0.01,
    description='Delay (s):',
    continuous_update=False,
    orientation='horizontal',
    readout=True,
    readout_format='.2f',
    layout=widgets.Layout(width='50%'),
)
start_stop_button = widgets.Button(description="Start Auto", button_style='info', icon='play')

# Frame display (JPEG bytes); the widget box is DISPLAY_IMAGE_SIZE + 50 px per side.
image_widget = widgets.Image(
    format='jpeg',
    width=DISPLAY_IMAGE_SIZE + 50,
    height=DISPLAY_IMAGE_SIZE + 50,
)

# Text areas: per-frame details, and status / error messages.
info_output = widgets.Output()
status_output = widgets.Output()


# --- Helper Function ---
def get_original_dataframe_index(current_dataset, current_index_in_dataset):
    """Map a position in a (possibly nested) Subset back to an original dataframe index.

    Each ``Subset`` layer is unwrapped through its ``indices`` until the base
    dataset is reached; the resulting position is then looked up in that base
    dataset's ``valid_indices``.

    Args:
        current_dataset: A base dataset exposing ``valid_indices`` and
            ``dataframe``, or any nesting of ``torch.utils.data.Subset`` around one.
        current_index_in_dataset: Position within ``current_dataset``.

    Returns:
        The matching entry of the base dataset's ``valid_indices`` as an ``int``,
        or ``None`` if the position is out of range (including negative) at any
        level, or the base dataset lacks ``valid_indices`` / ``dataframe``.
    """
    index = current_index_in_dataset
    dataset = current_dataset
    while isinstance(dataset, Subset):
        # Reject negatives explicitly: Python indexing would silently wrap them.
        if not 0 <= index < len(dataset.indices):
            return None
        # int() also unwraps 0-d tensors when `indices` is a tensor.
        index = int(dataset.indices[index])
        dataset = dataset.dataset
    if not hasattr(dataset, 'valid_indices') or not hasattr(dataset, 'dataframe'):
        return None
    if not 0 <= index < len(dataset.valid_indices):
        return None
    # Normalize to a plain int so it is a consistent dict key / iloc position.
    return int(dataset.valid_indices[index])

# --- Callback Functions ---
def get_data_for_labeling_index(dataset_idx):
    """Safely fetch the frame and action at a position in ``labeling_dataset``.

    Args:
        dataset_idx: Position within ``labeling_dataset``.

    Returns:
        ``(image, action)``:
        - ``image`` is the sample's image as an RGB PIL image; a ``torch.Tensor``
          is first converted with ``ToPILImage``.
        - ``action`` is ``action_tensor.item()``.

        Returns ``(None, None)`` if the index is out of range or loading fails;
        in both cases the error is printed.
    """
    # Check both bounds: a negative index would otherwise wrap to the end.
    if not 0 <= dataset_idx < len(labeling_dataset):
        print(f"Error: Index {dataset_idx} out of bounds.")
        return None, None
    try:
        # The third element of each sample is not used for labeling.
        current_img_pil, action_tensor, _ = labeling_dataset[dataset_idx]
        if isinstance(current_img_pil, torch.Tensor):
            current_img_pil = transforms.ToPILImage()(current_img_pil.cpu())
        return current_img_pil.convert("RGB"), action_tensor.item()
    except Exception as e:
        # Best-effort for the UI: report and return (None, None) instead of raising.
        print(f"Error getting data for labeling index {dataset_idx}: {e}")
        return None, None

def pil_to_widget_bytes(pil_image):
    """Encode a PIL image as JPEG bytes suitable for an ``ipywidgets.Image`` value.

    Pillow's JPEG writer rejects modes such as RGBA and P, so any image whose
    mode is not RGB or L is converted to RGB before encoding.

    Args:
        pil_image: A PIL image, or ``None``.

    Returns:
        The JPEG-encoded bytes, or ``None`` when ``pil_image`` is ``None``.
    """
    if pil_image is None:
        return None
    if pil_image.mode not in ("RGB", "L"):
        pil_image = pil_image.convert("RGB")
    with io.BytesIO() as output_bytes:
        pil_image.save(output_bytes, format="JPEG")
        return output_bytes.getvalue()

def save_current_reward():
    """Store the reward slider's value for the frame selected by the index slider.

    The value is clipped to [0, 1] and written into ``reward_labels``, keyed by
    the frame's original dataframe index. If the selected position cannot be
    mapped to a dataframe index, nothing is stored.
    """
    df_key = get_original_dataframe_index(labeling_dataset, index_slider.value)
    if df_key is None:
        return
    reward_labels[df_key] = np.clip(reward_slider.value, 0.0, 1.0)

def update_display(labeling_idx):
    """Show the frame at ``labeling_idx`` and sync the reward slider to any saved reward.

    The reward slider is moved only when a reward was already saved for this
    frame and it differs from the slider by more than half a slider step.
    Otherwise the slider keeps its current value, so the last-set reward carries
    over to unlabeled frames. ``on_reward_change`` is detached while the slider
    is moved programmatically, so that update is not saved back.

    Args:
        labeling_idx: Position within ``labeling_dataset``.
    """
    # reward_slider is only mutated here (never rebound), so no `global` is needed.
    if labeling_idx < 0 or labeling_idx >= len(labeling_dataset):
        with info_output:
            clear_output(wait=True)
            print(f"Invalid index: {labeling_idx}")
        image_widget.value = b''
        return

    current_img_pil, action = get_data_for_labeling_index(labeling_idx)
    original_df_idx = get_original_dataframe_index(labeling_dataset, labeling_idx)

    with status_output:
        clear_output(wait=True)

    if current_img_pil is None or original_df_idx is None:
        with info_output:
            clear_output(wait=True)
            print(f"Could not load data/index for labeling index: {labeling_idx}.")
        image_widget.value = b''
        return

    # --- Sync reward slider from the saved value only ---
    reward_info_text = "No reward assigned yet."
    if original_df_idx in reward_labels:
        saved_val = reward_labels[original_df_idx]
        reward_info_text = f"Assigned reward: {saved_val:.2f}"
        # Only move the slider when the difference is visible at its resolution.
        if abs(reward_slider.value - saved_val) > (reward_slider.step / 2.0):
            # Detach the observer so this programmatic change is not re-saved.
            try:
                reward_slider.unobserve(on_reward_change, names='value')
                reward_slider.value = saved_val
            finally:
                reward_slider.observe(on_reward_change, names='value')

    with info_output:
        clear_output(wait=True)
        print(f"Labeling Index: {labeling_idx}/{len(labeling_dataset)-1} (Original DF Index: {original_df_idx})")
        print(f"Action leading to this frame: {action:.4f}")
        print(reward_info_text)

    # --- Display image ---
    try:
        if hasattr(current_img_pil, 'resize'):
            display_image = current_img_pil.resize(
                (DISPLAY_IMAGE_SIZE, DISPLAY_IMAGE_SIZE), Image.Resampling.LANCZOS
            )
            image_widget.value = pil_to_widget_bytes(display_image)
        else:
            image_widget.value = b''
            # Report in the UI: a bare print from a widget callback or timer
            # thread may not appear in the notebook.
            with status_output:
                print("Error: Invalid image object.")
    except Exception as e:
        with status_output:
            clear_output(wait=True)
            print(f"Error display: {e}")
        image_widget.value = b''

def stop_auto_advance(change=None):
    """Cancel any pending auto-advance timer and restore manual navigation.

    Args:
        change: Unused (presumably accepted so this can also serve as a
            one-argument callback).

    A "stopped" message is written to the status area only when auto-advance
    was actually running.
    """
    global is_auto_advancing, auto_advance_timer
    if auto_advance_timer is not None:
        auto_advance_timer.cancel()
        auto_advance_timer = None
    was_running, is_auto_advancing = is_auto_advancing, False

    # Put the toggle button back in its "start" look and unlock the index slider.
    start_stop_button.description = "Start Auto"
    start_stop_button.button_style = 'info'
    start_stop_button.icon = 'play'
    index_slider.disabled = False

    if not was_running:
        return
    with status_output:
        clear_output(wait=True)
        print("Auto-advance stopped.")

def auto_advance_step():
    """Run one auto-advance tick: save the current reward, then show the next frame.

    Runs from a ``threading.Timer``. While auto-advance remains active, it
    schedules the next tick after ``speed_slider.value`` seconds. Auto-advance
    is stopped at the end of the dataset or on any error.
    """
    # is_auto_advancing is only read here, so only the timer needs `global`.
    global auto_advance_timer

    if not is_auto_advancing:
        return

    try:
        # Save the reward (read from reward_slider) for the frame that was
        # on screen during the last delay.
        save_current_reward()

        next_idx = index_slider.value + 1
        if next_idx >= len(labeling_dataset):
            # Stop first: stop_auto_advance() clears the status area, which
            # would otherwise immediately replace this message.
            stop_auto_advance()
            with status_output:
                print("End of dataset.")
            return

        index_slider.value = next_idx  # triggers on_index_change -> update_display
        if is_auto_advancing:  # re-check: a stop may have been requested meanwhile
            auto_advance_timer = threading.Timer(speed_slider.value, auto_advance_step)
            auto_advance_timer.start()
    except Exception as e:
        # Stop first (it clears the status area), then report where the user can see it.
        stop_auto_advance()
        with status_output:
            print(f"DEBUG: ERROR in auto_advance_step: {e}")


def toggle_auto_advance(b):
    """Button callback: stop auto-advance if it is running, otherwise start it.

    Starting does the following:
    - switches the button to its "Stop Auto" look;
    - disables the index slider;
    - schedules the first ``auto_advance_step`` after ``speed_slider.value`` seconds.

    Args:
        b: The clicked button (unused).
    """
    global is_auto_advancing, auto_advance_timer
    if is_auto_advancing:
        stop_auto_advance()
        return

    is_auto_advancing = True
    start_stop_button.description = "Stop Auto"
    start_stop_button.button_style = 'warning'
    start_stop_button.icon = 'stop'
    index_slider.disabled = True
    with status_output:
        clear_output(wait=True)
        print("Auto-advance starting...")
    try:
        auto_advance_timer = threading.Timer(speed_slider.value, auto_advance_step)
        auto_advance_timer.start()
    except Exception as e:
        # Stop first (it clears the status area), then report there; a bare
        # print from a button callback may not appear in the notebook.
        stop_auto_advance()
        with status_output:
            print(f"Error starting timer: {e}")


def on_index_change(change):
    """Observer for ``index_slider``: redraw the UI for the newly selected position.

    This fires for programmatic changes too (auto-advance sets
    ``index_slider.value``). Auto-advance is intentionally not stopped here.
    """
    is_value_change = change['type'] == 'change' and change['name'] == 'value'
    if not is_value_change:
        return
    update_display(change['new'])

# --- Reward slider observer ---
def on_reward_change(change):
    """Observer for ``reward_slider``: store the new reward right away.

    Skipped while auto-advance is running, because ``auto_advance_step`` saves
    the slider value itself just before moving to the next frame.
    """
    if change['type'] != 'change' or change['name'] != 'value':
        return
    if is_auto_advancing:
        return
    save_current_reward()

def on_save_button_clicked(b):
    """Button callback: stop auto-advance and write all collected rewards to CSV.

    Writes one row per labeled frame, sorted by original dataframe index. Each
    row carries the frame's ``session_id``, ``image_path`` and ``action``, taken
    by position (``iloc``) from ``full_dataset_for_metadata.dataframe``.
    Labels whose index falls outside that dataframe are skipped. The output
    directory is created if missing. Progress and errors are reported in
    ``status_output``.

    Args:
        b: The clicked button (unused).
    """
    stop_auto_advance()
    with status_output:
        clear_output(wait=True)
        if not reward_labels:
            print("No rewards assigned yet.")
            return
        print(f"Preparing to save {len(reward_labels)} labels...")

        original_dataframe = full_dataset_for_metadata.dataframe
        n_rows = len(original_dataframe)
        # Positional indices only; also guard negatives, which iloc would wrap.
        df_indices = [i for i in sorted(reward_labels) if 0 <= i < n_rows]
        if not df_indices:
            print("No valid entries to save.")
            return

        try:
            # One vectorized positional lookup instead of an iloc call per label.
            rows = original_dataframe.iloc[df_indices]
            save_df = pd.DataFrame({
                'dataframe_index': df_indices,
                'session_id': rows['session_id'].to_numpy(),
                'image_path': rows['image_path'].to_numpy(),
                'action': rows['action'].to_numpy(),
                'assigned_reward': [reward_labels[i] for i in df_indices],
            })
        except Exception as e:
            print(f"Error preparing rows: {e}")
            return

        try:
            # to_csv does not create missing parent directories.
            out_dir = os.path.dirname(OUTPUT_REWARD_CSV)
            if out_dir:
                os.makedirs(out_dir, exist_ok=True)
            save_df.to_csv(OUTPUT_REWARD_CSV, index=False, float_format='%.8f')
            print(f"Saved {len(save_df)} labels to {OUTPUT_REWARD_CSV}")
        except Exception as e:
            print(f"Error saving CSV: {e}")

# --- Wire callbacks to widgets ---
save_button.on_click(on_save_button_clicked)
start_stop_button.on_click(toggle_auto_advance)
index_slider.observe(on_index_change, names='value')
# update_display() detaches and re-attaches this observer when it moves the slider.
reward_slider.observe(on_reward_change, names='value')

# --- Arrange layout: frame info, image, sliders, controls, status ---
auto_advance_controls = widgets.HBox(children=[speed_slider, start_stop_button])
ui = widgets.VBox(children=[
    info_output,
    image_widget,
    index_slider,
    reward_slider,
    auto_advance_controls,
    save_button,
    status_output,
])

# --- Initial display ---
if len(labeling_dataset) == 0:
    print("Cannot display UI because the labeling dataset is empty.")
else:
    update_display(current_labeling_index)
    display(ui)

Loaded combined CSV with columns: ['session_id', 'image_path', 'timestamp', 'action']
Total rows in CSV: 23081, Valid sequence start indices: 23037
Using train Subset (16900 sequences)


VBox(children=(Output(), Image(value=b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00\x00\x01\x00\x01\x00\x00\xf…