### Libs

In [1]:
import ipywidgets as widgets
from IPython.display import display, clear_output, HTML
from PIL import Image as PILImage
import os
import pandas as pd
import shutil
from PIL import Image
import cv2
import numpy as np
import io


### Configuration: input images, label names and output directory

In [2]:
# Input images: every PNG in the SAROS label directory, sorted so that the
# Previous/Next navigation order is the same on every run (os.listdir order is arbitrary).
_input_dir = r"Z:\Object_Detection\SAROS_LABEL"
image_files = sorted(
    os.path.join(_input_dir, _file)
    for _file in os.listdir(_input_dir)
    if _file.endswith('.png')
)

# Label names: one checkbox per entry, and the column order of each saved labels.xlsx
labels = ["skull", "shoulder", "humerus", "vertebrae_C", "thorax", "vertebrae_L",
          "forearm", "pelvis", "femur", "hand", "patella", "shin", "tarsal", "foot"]

# Directory that receives one sub-folder per saved image
output_dir = "SAROS"

# WARNING: while RESET_OUTPUT_DIR is True, re-running this cell deletes every
# result saved in a previous session. Set it to False to keep existing results.
RESET_OUTPUT_DIR = True
if RESET_OUTPUT_DIR:
    shutil.rmtree(output_dir, ignore_errors=True)
os.makedirs(output_dir, exist_ok=True)

### Functions

In [3]:
def pad_to_min_width(img, min_width=500):
    """Center ``img`` horizontally on a black (zero) canvas at least ``min_width`` pixels wide.

    Parameters
    ----------
    img : np.ndarray
        Array of shape (height, width, ...); any trailing channel axes are kept.
    min_width : int
        Minimum width of the returned canvas.

    Returns
    -------
    np.ndarray
        New array of shape (height, max(min_width, width), ...) with ``img``'s dtype.
    """
    _height, _width = img.shape[:2]
    _canvas_width = max(min_width, _width)
    _canvas = np.zeros((_height, _canvas_width) + img.shape[2:], dtype=img.dtype)
    # Integer split: an odd leftover pixel ends up on the right-hand side.
    _x_offset = (_canvas_width - _width) // 2
    _canvas[:, _x_offset:_x_offset + _width] = img
    return _canvas


def render_preview_jpeg(image_path, height=500, min_width=500):
    """Read ``image_path`` and return JPEG bytes scaled to ``height`` px and padded to ``min_width``.

    The aspect ratio is preserved; narrower results are centered on black padding.

    Raises
    ------
    IOError
        If OpenCV cannot read ``image_path``.
    """
    _img = cv2.imread(image_path)  # BGR image, or None if the file cannot be read
    if _img is None:
        raise IOError(f"cannot read {image_path}")
    _orig_height, _orig_width = _img.shape[:2]
    # Keep at least 1 px so cv2.resize never receives a zero width.
    _new_width = max(1, int(height / _orig_height * _orig_width))
    _resized = cv2.resize(_img, (_new_width, height), interpolation=cv2.INTER_LANCZOS4)
    _rgb = cv2.cvtColor(pad_to_min_width(_resized, min_width), cv2.COLOR_BGR2RGB)
    _buffer = io.BytesIO()
    Image.fromarray(_rgb).save(_buffer, format='JPEG')
    return _buffer.getvalue()


def apply_labels_from_xlsx(image_path):
    """Tick the label checkboxes from the ``<stem>_results.xlsx`` file next to ``image_path``.

    ``<stem>`` is ``image_path`` with its ``_reducted_image.png`` suffix removed. A checkbox
    is ticked when its label's value in the "True" row is > 0. Label columns missing from
    the file leave their checkbox unchanged.

    Returns
    -------
    bool
        True if the "True" row was found and applied. False if the file or the row is
        missing; in that case no checkbox is modified.
    """
    _label_file = f"{image_path.split('_reducted_image.png')[0]}_results.xlsx"
    try:
        _df = pd.read_excel(_label_file, index_col=0)
    except FileNotFoundError:
        return False
    # The row is labelled "True". pandas may parse that cell as the boolean True
    # (NOTE(review): depends on the file's writer), so match on the string form.
    _is_true_row = _df.index.astype(str) == "True"
    if not _is_true_row.any():
        return False
    _row = _df.loc[_is_true_row].iloc[0]
    for _checkbox in _checkboxes:
        # Skip the grid-padding Label placeholders; only Checkbox widgets hold labels.
        if isinstance(_checkbox, widgets.Checkbox) and _checkbox.description in _row.index:
            _checkbox.value = bool(_row[_checkbox.description] > 0)
    return True


def load_image(index):
    """Show image ``index`` of ``image_files`` and pre-fill its label checkboxes.

    - Reports the position ("Image i of n") in ``_text_output``.
    - Puts a 500 px high, black-padded JPEG preview into ``_image_widget``.
    - Ticks checkboxes from the matching ``*_results.xlsx`` when it exists.

    If no labels are found, the checkboxes keep their current state (for example,
    from the previous image) and a note is printed so the user notices.

    Raises
    ------
    IndexError
        If ``index`` is out of range for ``image_files``.
    """
    _image_path = image_files[index]

    # Report first: clear_output would otherwise wipe an error message printed earlier.
    with _text_output:
        clear_output(wait=True)
        print(f"Image {index+1:3} of {len(image_files)}")

    try:
        _image_widget.value = render_preview_jpeg(_image_path)
    except Exception as _e:
        # Best effort: keep the UI usable even if one image cannot be rendered.
        with _text_output:
            print(f"Error loading image: {_e}")

    _loaded = apply_labels_from_xlsx(_image_path)
    with _text_output:
        if _loaded:
            print("Loaded labels from xlsx")
        else:
            print("No labels found in xlsx; checkboxes unchanged")

def on_save_button_clicked(b):
    """Save-button callback: persist the checkbox state of the image on screen.

    ``b`` is the clicked button, supplied by ipywidgets and unused here.
    """
    # current_index is only read, so no ``global`` declaration is required.
    save_labels(current_index)
    
def build_label_frame(selected_labels, label_names):
    """Return a one-row DataFrame (index ``["True"]``, columns ``label_names``) of 1/0 flags."""
    _selected = set(selected_labels)
    return pd.DataFrame(
        [[1 if _label in _selected else 0 for _label in label_names]],
        index=["True"],
        columns=label_names,
    )


def save_labels(index):
    """Save the current checkbox selection for image ``index``.

    Let ``<name>`` be the image file name up to "_reducted". This function replaces the
    folder ``<output_dir>/<name>`` with a fresh one containing:

    - ``reducted_image.png``: a copy of the source image.
    - ``labels.xlsx``: one row "True" with a 1/0 column for each entry of ``labels``.

    A confirmation is printed to ``_text_output`` only after both files are written.
    """
    _image_path = image_files[index]
    _selected_labels = [
        _checkbox.description
        for _checkbox in _checkboxes
        # Skip the grid-padding Label placeholders.
        if isinstance(_checkbox, widgets.Checkbox) and _checkbox.value
    ]

    _name = os.path.basename(_image_path).split("_reducted")[0]
    _export_dir = os.path.join(output_dir, _name)

    # Start from an empty folder so no stale files from an earlier save remain.
    shutil.rmtree(_export_dir, ignore_errors=True)
    os.makedirs(_export_dir)

    shutil.copy(_image_path, os.path.join(_export_dir, "reducted_image.png"))
    build_label_frame(_selected_labels, labels).to_excel(os.path.join(_export_dir, "labels.xlsx"))

    # Report only after both files exist, so the message never claims a failed save.
    with _text_output:
        print(f"SavedImage: {_image_path}, Labels: {_selected_labels}")
    
def on_next_button_clicked(b):
    """Next-button callback: advance one image, wrapping from the last back to the first.

    ``b`` is the clicked button, supplied by ipywidgets and unused here.
    """
    global current_index
    _candidate = current_index + 1
    current_index = _candidate if _candidate < len(image_files) else 0
    load_image(current_index)
    display_widgets()

def on_reset_button_clicked(b):
    """Reset-button callback: untick every label checkbox.

    Only ``widgets.Checkbox`` entries are changed. ``_checkboxes`` also contains
    ``widgets.Label('')`` grid placeholders, whose ``value`` is a string trait, so
    assigning ``False`` to them would raise a TraitError.

    ``b`` is the clicked button, supplied by ipywidgets and unused here.
    """
    for _checkbox in _checkboxes:
        if isinstance(_checkbox, widgets.Checkbox):
            _checkbox.value = False

def on_prev_button_clicked(b):
    """Previous-button callback: step back one image, wrapping from the first to the last.

    ``b`` is the clicked button, supplied by ipywidgets and unused here.
    """
    global current_index
    _candidate = current_index - 1
    current_index = _candidate if _candidate >= 0 else len(image_files) - 1
    load_image(current_index)
    display_widgets()
        

### Create widgets

In [4]:
# Enlarge the label checkboxes and their text via CSS. The rules apply to
# widgets tagged with the "custom-checkbox" class below.
_custom_css = """
<style>
    .custom-checkbox input[type="checkbox"] {
        transform: scale(2.0);
        margin-left: 20px;

    }
    .custom-checkbox label {
        font-size: 18px;
    }
</style>
"""
display(HTML(_custom_css))

# Widget that shows the current image (load_image fills it with JPEG bytes)
_image_widget = widgets.Image(format='jpeg', width=500, height=500)

# One checkbox per label, styled with the CSS class above
_checkboxes = []
for _label in labels:
    _checkbox = widgets.Checkbox(value=False, description=_label)
    _checkbox.add_class("custom-checkbox")
    _checkboxes.append(_checkbox)

# Pad the last grid row with empty Label placeholders so every row has
# _num_columns cells. NOTE: these placeholders are appended to _checkboxes
# (at the end), so code that reads or writes checkbox values must skip them.
_num_columns = 3
_num_labels = len(labels)
_num_empty = (_num_columns - (_num_labels % _num_columns)) % _num_columns
for _ in range(_num_empty):
    _checkboxes.append(widgets.Label(''))

# Arrange the cells in rows of _num_columns (3 columns x 5 rows for the 14 labels)
_rows = [widgets.HBox(_checkboxes[_i:_i+_num_columns]) for _i in range(0, len(_checkboxes), _num_columns)]
_checkbox_layout = widgets.VBox(_rows)

# Button to go to the next image (does not save)
_next_button = widgets.Button(description="Next")

# Button to go to the previous image (does not save)
_prev_button = widgets.Button(description="Previous")

# Button to untick all label checkboxes
_reset_button = widgets.Button(description="Reset")

# Button to save the labels of the current image
_save_button = widgets.Button(description="Save")

# Link buttons to their callbacks
_next_button.on_click(on_next_button_clicked)
_prev_button.on_click(on_prev_button_clicked)
_save_button.on_click(on_save_button_clicked)
_reset_button.on_click(on_reset_button_clicked)

# Output areas: _image_output hosts the UI (filled by display_widgets),
# _text_output receives the status messages
_image_output = widgets.Output()
_text_output = widgets.Output()

### Main

In [5]:
# Index of the image currently on screen; the Previous/Next callbacks update it.
current_index = 0

# Fail early with a clear message instead of an IndexError inside load_image.
if not image_files:
    raise FileNotFoundError(f"No .png files found in {_input_dir}")

# Render the first image and pre-fill its checkboxes
load_image(current_index)

def display_widgets():
    """(Re)draw the labelling UI inside ``_image_output``.

    The UI is a centered column with three parts, top to bottom:

    - the image;
    - the checkbox grid;
    - the Previous / Reset / Save / Next button row.

    The output area is cleared first, so repeated calls replace the UI rather than stack it.
    """
    _image_row = widgets.HBox(
        [_image_widget],
        layout=widgets.Layout(justify_content='center'),
    )
    _button_row = widgets.HBox(
        [_prev_button, _reset_button, _save_button, _next_button],
        layout=widgets.Layout(justify_content='center'),
    )
    _ui = widgets.VBox(
        [_image_row, _checkbox_layout, _button_row],
        layout=widgets.Layout(align_items='center'),
    )
    with _image_output:
        clear_output(wait=True)
        display(_ui)

        
# Build the UI, then show it above the status-message area
display_widgets()
display(_image_output, _text_output)

Output()

Output()