In [1]:
# --- Imports ---
import torch
import torchvision.transforms as transforms
from torch.utils.data import Dataset, Subset
from PIL import Image
import numpy as np
import os
from tqdm.auto import tqdm
import time
import datetime
import csv
import random
import pandas as pd
import ipywidgets as widgets
from IPython.display import display, clear_output
import threading
import io

# --- Assuming these are accessible ---
# Make sure config.py is in the path or values are defined
import config

from importnb import Notebook

with Notebook():
    # Make sure JetbotDataset class definition is available
    from jetbot_dataset import JetbotDataset

In [2]:
# --- Configuration ---
# All paths and sizes come from the shared `config` module so this notebook
# stays in sync with the rest of the project.
NUM_PREV_FRAMES = config.NUM_PREV_FRAMES
DISPLAY_IMAGE_SIZE = config.IMAGE_SIZE  # Square edge length used when showing frames
AGGREGATE_DATA_DIR = config.DATA_DIR
AGGREGATE_CSV_PATH = config.CSV_PATH
OUTPUT_REWARD_CSV = os.path.join(config.OUTPUT_DIR, "interactive_reward_labels.csv")

# --- Display-only transform ---
# Resize-only pipeline (no tensor conversion, no normalization).
# NOTE(review): it is not passed to the dataset below and no other cell shown
# here references it.
transform_for_display = transforms.Compose(
    [transforms.Resize((DISPLAY_IMAGE_SIZE, DISPLAY_IMAGE_SIZE))]
)

# --- Load Dataset ---
# The full dataset is only used for its metadata and item access.
# No custom transform is supplied (transform=None).
# NOTE(review): later cells call tensor methods (.min(), .clamp(), .cpu()) on the
# returned image, so JetbotDataset presumably falls back to an internal
# ToTensor-style transform when transform is None -- confirm in jetbot_dataset.
full_dataset = JetbotDataset(
    data_dir=AGGREGATE_DATA_DIR,
    csv_path=AGGREGATE_CSV_PATH,
    num_prev_frames=NUM_PREV_FRAMES,
    image_size=DISPLAY_IMAGE_SIZE,
    transform=None,
)



Loaded combined CSV with columns: ['session_id', 'image_path', 'timestamp', 'action']
Total rows in CSV: 23081, Valid sequence start indices: 23037


In [3]:
# --- Mutable state shared by the labeling callbacks ---
# Rewards collected so far, keyed by the row index in the underlying dataframe
# (not by the position in the dataset), so keys stay meaningful even when the
# valid sequence indices are not contiguous.
reward_labels = dict()

# Position of the sequence shown first, expressed in dataset coordinates
# (0 .. len(full_dataset) - 1).
current_dataset_index = 0

# Auto-advance bookkeeping: the pending threading.Timer (None when idle) and
# the flag telling the timer callback whether it should keep going.
auto_advance_timer = None
is_auto_advancing = False

In [4]:
# --- Navigation and reward controls ---
# Selects which sequence is shown; only fires once the handle is released.
index_slider = widgets.IntSlider(
    description='Sequence Index:',
    value=current_dataset_index,
    min=0,
    max=len(full_dataset) - 1,
    step=1,
    continuous_update=False,
    layout=widgets.Layout(width='80%'),
)
# Reward in [-1, 1]; fires while dragging (set continuous_update=False to
# update only on release if the UI feels laggy).
reward_slider = widgets.FloatSlider(
    description='Reward:',
    value=0.0,
    min=-1.0,
    max=1.0,
    step=0.1,
    continuous_update=True,
    orientation='horizontal',
    readout=True,
    readout_format='.1f',
    layout=widgets.Layout(width='80%'),
)
prev_button = widgets.Button(description="<< Previous")
next_button = widgets.Button(description="Save & Next >>")
save_button = widgets.Button(description="Save All Rewards", button_style='success')

# --- Auto-advance controls ---
auto_advance_checkbox = widgets.Checkbox(
    description='Enable Auto-Advance',
    value=False,
    indent=False,
)
# Seconds to wait between automatic steps.
speed_slider = widgets.FloatSlider(
    description='Delay (s):',
    value=0.5,
    min=0.05,
    max=3.0,
    step=0.05,
    continuous_update=False,
    orientation='horizontal',
    readout=True,
    readout_format='.2f',
    layout=widgets.Layout(width='50%'),
)
start_stop_button = widgets.Button(
    description="Start Auto",
    button_style='info',
    icon='play',
)

# --- Frame display ---
# A single image widget whose bytes are replaced on every index change;
# drawn slightly larger than the frame itself.
image_widget = widgets.Image(
    format='jpeg',
    width=DISPLAY_IMAGE_SIZE + 50,
    height=DISPLAY_IMAGE_SIZE + 50,
)

# --- Text areas ---
# info_output shows per-frame details, status_output shows transient messages.
info_output = widgets.Output()
status_output = widgets.Output()

In [5]:
# --- Callback Functions ---
def get_data_for_index(dataset_idx):
    """Fetch the current frame and its scalar action from ``full_dataset``.

    The dataset item is expected to unpack into three values (current image,
    action tensor, previous frames); the previous frames are discarded.

    Args:
        dataset_idx: Position within ``full_dataset``.

    Returns:
        ``(image, action)`` where ``action`` is ``action_tensor.item()``, or
        ``(None, None)`` if the lookup failed. Failures are reported with
        ``print`` rather than raised.
    """
    try:
        frame, action_tensor, _unused_prev_frames = full_dataset[dataset_idx]
        result = (frame, action_tensor.item())
    except Exception as e:
        # Out-of-range lookups get their own message; everything else is
        # reported with the exception text.
        if isinstance(e, IndexError):
            print(f"Error: Index {dataset_idx} out of bounds for dataset.")
        else:
            print(f"Error getting data for index {dataset_idx}: {e}")
        result = (None, None)
    return result

def pil_to_widget_bytes(pil_image):
    """Encode a PIL image as JPEG bytes suitable for ``ipywidgets.Image.value``.

    Args:
        pil_image: A PIL image, or ``None``.

    Returns:
        The JPEG-encoded bytes, or ``None`` when ``pil_image`` is ``None``.
    """
    if pil_image is None:
        return None

    # Pillow's JPEG writer only accepts these modes; anything else (RGBA, P,
    # LA, ...) raises OSError on save, so convert such images to RGB first.
    jpeg_writable_modes = ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr")
    if pil_image.mode not in jpeg_writable_modes:
        pil_image = pil_image.convert("RGB")

    with io.BytesIO() as output_bytes:
        pil_image.save(output_bytes, format="JPEG")
        # getvalue() must be called before the buffer is closed by the `with`.
        return output_bytes.getvalue()

def update_display(dataset_idx):
    """Refresh the info text, reward slider and image for ``dataset_idx``.

    The frame is fetched through ``get_data_for_index_returning_tensor``, which
    hands back a tensor (treated here as being in [0, 1]) and a scalar action.
    If a reward was already stored for the frame, the reward slider is moved to
    that value without firing its change observer. If no reward is stored, the
    slider is left where it is.

    Args:
        dataset_idx: Position within ``full_dataset`` to show.
    """
    frame_01, action = get_data_for_index_returning_tensor(dataset_idx)

    # Wipe any transient status message from the previous frame.
    with status_output:
        clear_output()

    # Nothing to show: report it and blank the image.
    if frame_01 is None:
        with info_output:
            clear_output(wait=True)
            print("Could not load data for this index.")
        image_widget.value = b''
        return

    # Rewards are keyed by dataframe row, not by dataset position.
    actual_df_idx = full_dataset.valid_indices[dataset_idx]

    with info_output:
        clear_output(wait=True)
        print(f"Sequence Index: {dataset_idx}/{len(full_dataset)-1} (DataFrame Index: {actual_df_idx})")
        print(f"Action leading to this frame: {action:.4f}")
        if actual_df_idx not in reward_labels:
            print("No reward assigned yet.")
        else:
            stored_reward = reward_labels[actual_df_idx]
            # Detach the observer while restoring the slider so that showing a
            # stored value is not itself treated as a new label.
            reward_slider.unobserve(on_reward_change, names='value')
            reward_slider.value = stored_reward
            reward_slider.observe(on_reward_change, names='value')
            print(f"Assigned reward: {stored_reward:.1f}")

    try:
        # Resize the tensor first (on CPU), then convert it to a PIL image.
        to_display_size = transforms.Resize(
            (DISPLAY_IMAGE_SIZE, DISPLAY_IMAGE_SIZE),
            interpolation=transforms.InterpolationMode.BILINEAR,
            antialias=True,
        )
        resized_frame = to_display_size(frame_01.cpu())
        display_image = transforms.ToPILImage()(resized_frame).convert("RGB")
        image_widget.value = pil_to_widget_bytes(display_image)
    except Exception as e:
        # Any processing failure is shown in the status area and the image is
        # blanked rather than left showing a stale frame.
        with status_output:
            clear_output(wait=True)
            print(f"Error during image processing/display: {e}")
        image_widget.value = b''

# --- Tensor-returning loader used by update_display ---
# Callers rely on the returned image being a tensor in the [0, 1] range.
def get_data_for_index_returning_tensor(dataset_idx):
    """Fetch a frame as a tensor in [0, 1] together with its scalar action.

    The dataset item is unpacked into (image tensor, action tensor, previous
    frames); the previous frames are ignored.

    The image may or may not have been normalized by the dataset. A heuristic
    decides: if the smallest value is below -0.1 the tensor is treated as
    [-1, 1]-normalized, clamped to that range and mapped to [0, 1]; otherwise
    it is returned unchanged.
    NOTE(review): a [-1, 1]-normalized frame with no value below -0.1 would be
    passed through unchanged by this heuristic.

    Args:
        dataset_idx: Position within ``full_dataset``.

    Returns:
        ``(tensor_01, action_value)``, or ``(None, None)`` if anything raised.
        Failures are reported with ``print`` rather than raised.
    """
    try:
        raw_frame, action_tensor, _unused_prev_frames = full_dataset[dataset_idx]
        action_value = action_tensor.item()

        looks_normalized = raw_frame.min() < -0.1
        frame_01 = (raw_frame.clamp(-1, 1) + 1) / 2 if looks_normalized else raw_frame
    except Exception as e:
        print(f"Error in get_data_for_index_returning_tensor for index {dataset_idx}: {e}")
        return None, None

    return frame_01, action_value


def save_current_reward(change=None):
    """Record the reward slider's value for the sequence currently selected.

    The value is stored in ``reward_labels`` under the dataframe row index that
    corresponds to the index slider's position. Nothing is printed or cleared
    here; status output is left to the navigation callbacks.

    Args:
        change: Ignored. Present so the function can also be called with a
            widget change payload.
    """
    dataframe_row = full_dataset.valid_indices[index_slider.value]
    reward_labels[dataframe_row] = reward_slider.value

def stop_auto_advance(change=None):
    """Cancel any pending auto-advance step and restore the manual controls.

    Safe to call when auto-advance is not running: the button and the manual
    controls are reset either way, but the "stopped" message is only shown if
    auto-advance was actually active. The auto-advance checkbox is deliberately
    left untouched.

    Args:
        change: Ignored. Present so the function can be used directly as a
            widget observer.
    """
    global is_auto_advancing, auto_advance_timer

    # Cancel the scheduled step, if there is one.
    pending_timer = auto_advance_timer
    if pending_timer is not None:
        pending_timer.cancel()
        auto_advance_timer = None

    # Remember the previous state, then clear the running flag.
    was_advancing, is_auto_advancing = is_auto_advancing, False

    # Put the start/stop button back into its "start" appearance.
    start_stop_button.description = "Start Auto"
    start_stop_button.button_style = 'info'
    start_stop_button.icon = 'play'

    # Manual navigation is available again.
    for control in (prev_button, next_button, index_slider):
        control.disabled = False

    if not was_advancing:
        return
    with status_output:
        clear_output(wait=True)
        print("Auto-advance stopped.")
def auto_advance_step():
    """Perform one auto-advance step: save, move to the next index, re-arm.

    Invoked by a ``threading.Timer``. Each call:

    1. returns immediately if auto-advance has been switched off;
    2. saves the reward for the frame currently shown;
    3. moves the index slider forward by one (which triggers the slider's
       change observer) and schedules the next step, or, when the last
       sequence has been reached, stops auto-advance and reports it.
    """
    global is_auto_advancing, auto_advance_timer

    if not is_auto_advancing:
        return  # Stopped since this step was scheduled.

    # Save the reward for the frame that is on screen *before* moving on.
    save_current_reward()

    next_idx = index_slider.value + 1

    if next_idx >= len(full_dataset):
        # Stop first: stop_auto_advance() writes its own "stopped" message, and
        # printing the end-of-dataset notice afterwards keeps that notice
        # visible instead of having it overwritten.
        stop_auto_advance()
        with status_output:
            clear_output(wait=True)
            print("Reached end of dataset. Auto-advance stopped.")
        return

    # Changing the slider triggers its observer, which redraws the display.
    index_slider.value = next_idx

    # The redraw takes time; if a stop was requested meanwhile, do not re-arm
    # a timer that nothing would cancel.
    if not is_auto_advancing:
        return

    auto_advance_timer = threading.Timer(speed_slider.value, auto_advance_step)
    auto_advance_timer.start()

def toggle_auto_advance(b):
    """Click handler for the start/stop button.

    If auto-advance is running it is stopped. Otherwise it is started, provided
    the "Enable Auto-Advance" checkbox is ticked; if it is not, a hint is shown
    and nothing else happens.

    Args:
        b: The button that was clicked (unused).
    """
    global is_auto_advancing, auto_advance_timer

    # Running -> stop.
    if is_auto_advancing:
        stop_auto_advance()
        return

    # Idle, but the feature has not been enabled.
    if not auto_advance_checkbox.value:
        with status_output:
            clear_output(wait=True)
            print("Please enable 'Enable Auto-Advance' checkbox first.")
        return

    # Idle -> start: flip the flag and switch the button to its "stop" look.
    is_auto_advancing = True
    start_stop_button.description = "Stop Auto"
    start_stop_button.button_style = 'warning'
    start_stop_button.icon = 'stop'

    # Manual navigation is locked while the timer drives the index.
    for control in (prev_button, next_button, index_slider):
        control.disabled = True

    with status_output:
        clear_output(wait=True)
        print(f"Auto-advance started with {speed_slider.value:.2f}s delay...")

    # Label the frame that is on screen right now before the first timed step.
    save_current_reward()

    # Kick off the timer chain; each step schedules the following one.
    auto_advance_timer = threading.Timer(speed_slider.value, auto_advance_step)
    auto_advance_timer.start()


def on_prev_button_clicked(b):
    """Click handler for "<< Previous": halt auto-advance and step back one.

    Does nothing beyond stopping auto-advance when already at index 0.

    Args:
        b: The button that was clicked (unused).
    """
    stop_auto_advance()
    target_index = index_slider.value - 1
    if target_index < 0:
        return  # Already at the first sequence.
    index_slider.value = target_index

def on_next_button_clicked(b):
    """Click handler for "Save & Next >>".

    Stops auto-advance, stores the reward for the sequence on screen, then
    moves forward by one. At the last sequence the reward is still stored and a
    notice is shown instead of moving.

    Args:
        b: The button that was clicked (unused).
    """
    stop_auto_advance()
    save_current_reward()  # Label the current sequence before leaving it.

    target_index = index_slider.value + 1
    if target_index >= len(full_dataset):
        with status_output:
            clear_output(wait=True)
            print("Already at the last sequence.")
        return
    index_slider.value = target_index


def on_index_change(change):
    """Observer for the index slider: redraw the display for the new index.

    Args:
        change: Change payload from the widget; only ``'change'`` events on the
            ``'value'`` trait are acted on.
    """
    if change['type'] != 'change' or change['name'] != 'value':
        return
    update_display(change['new'])

def on_reward_change(change):
    """Observer for the reward slider: store the new value straight away.

    While auto-advance is running the value is not stored here; the
    auto-advance loop saves it as part of each step instead.

    Args:
        change: Change payload from the widget; only ``'change'`` events on the
            ``'value'`` trait are acted on.
    """
    is_value_change = change['type'] == 'change' and change['name'] == 'value'
    if is_value_change and not is_auto_advancing:
        save_current_reward()


def on_save_button_clicked(b):
    """Click handler for "Save All Rewards": write the labels to a CSV file.

    Stops auto-advance, then writes one row per labelled dataframe index to
    ``OUTPUT_REWARD_CSV`` with the columns ``dataframe_index``, ``session_id``,
    ``image_path``, ``action`` and ``assigned_reward``. Rows are ordered by
    dataframe index. Labels whose index cannot be found in
    ``full_dataset.dataframe`` are skipped with a warning. All messages go to
    the status output area.

    Args:
        b: The button that was clicked (unused).
    """
    stop_auto_advance()
    with status_output:
        clear_output(wait=True)
        if not reward_labels:
            print("No rewards have been assigned yet to save.")
            return

        items_to_save = []
        # Keys of reward_labels are positions in the original dataframe.
        for df_idx in sorted(reward_labels.keys()):
            reward = reward_labels[df_idx]
            try:
                original_row = full_dataset.dataframe.iloc[df_idx]
            except IndexError:
                # The dataframe no longer has this row (e.g. it was reloaded).
                print(f"Warning: Could not find original data for DataFrame index {df_idx}. Skipping.")
                continue
            items_to_save.append({
                'dataframe_index': df_idx,  # Original index in aggregate CSV
                'session_id': original_row['session_id'],
                'image_path': original_row['image_path'],
                'action': original_row['action'],
                'assigned_reward': reward,
            })

        if not items_to_save:
            print("No valid reward entries to save.")
            return

        save_df = pd.DataFrame(items_to_save)
        try:
            # Create the target directory on first use so a fresh checkout can
            # save without manual setup.
            output_dir = os.path.dirname(OUTPUT_REWARD_CSV)
            if output_dir:
                os.makedirs(output_dir, exist_ok=True)
            save_df.to_csv(OUTPUT_REWARD_CSV, index=False)
            # Report the number of rows actually written; skipped labels are
            # not counted.
            print(f"Successfully saved {len(save_df)} reward labels to {OUTPUT_REWARD_CSV}")
        except Exception as e:
            print(f"Error saving rewards: {e}")

In [6]:
# --- Link Widgets ---
index_slider.observe(on_index_change, names='value')
reward_slider.observe(on_reward_change, names='value') # Observe reward changes
prev_button.on_click(on_prev_button_clicked)
next_button.on_click(on_next_button_clicked)
save_button.on_click(on_save_button_clicked)
start_stop_button.on_click(toggle_auto_advance)
# Link checkbox change to potentially stop auto-advance if unchecked while running
auto_advance_checkbox.observe(stop_auto_advance, names='value')

In [7]:
# --- Render the starting sequence ---
# Populate the image and info areas before the UI is shown.
update_display(current_dataset_index)

# --- Assemble the control rows ---
manual_controls = widgets.HBox(children=[prev_button, next_button])
auto_advance_controls = widgets.HBox(
    children=[auto_advance_checkbox, speed_slider, start_stop_button]
)

# --- Stack everything vertically ---
# Frame info and the image come first, followed by the sliders, the two
# control rows, the save button and finally the status line.
ui = widgets.VBox(
    children=[
        info_output,
        image_widget,
        index_slider,
        reward_slider,
        manual_controls,
        auto_advance_controls,
        save_button,
        status_output,
    ]
)

# --- Show the UI ---
display(ui)

VBox(children=(Output(), Image(value=b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00\x00\x01\x00\x01\x00\x00\xf…