In [1]:
import ipywidgets as widgets
import json
from IPython.display import display, clear_output

# One entry per supported source; each starts with no weighted labels.
# Every entry gets its own label_weights dict so sources never share state.
preferences = [
    {"source_name": name, "label_weights": {}}
    for name in ("reddit", "x")
]

def validate_label(source, label):
    """Check that a label uses the prefix its source requires.

    Reddit labels must start with "r/" and X labels with "#", and in both
    cases something must follow the prefix. The source name is compared
    case-insensitively so that both the dropdown spelling ("Reddit") and the
    stored spelling ("reddit") are validated the same way.

    Args:
        source: Name of the source the label belongs to.
        label: Label text entered by the user.

    Returns:
        False if the source has a required prefix and the label either lacks
        it or consists of the prefix alone; True otherwise. Sources without
        a known prefix rule accept any label.
    """
    required_prefixes = {"reddit": "r/", "x": "#"}
    prefix = required_prefixes.get(str(source).lower())
    if prefix is None:
        return True
    # A bare prefix such as "r/" names nothing, so require at least one
    # character after it.
    return label.startswith(prefix) and len(label) > len(prefix)

def validate_weight(weight, new_source, new_label):
    """Check that a weight can be stored without exceeding the total budget.

    A weight is accepted when it parses as a number, is a multiple of 0.1 in
    the range 0.1 to 1, and the combined weight of every label across all
    sources in ``preferences`` stays at or below 1 once it is stored. When
    ``new_label`` already exists under ``new_source`` its current weight is
    replaced rather than added on top.

    All arithmetic is done in whole tenths. Summing the raw floats is not
    reliable: 0.1 + 0.2 + 0.3 + 0.3 + 0.1 evaluates to 1.0000000000000002,
    which would reject a set of weights that totals exactly 1.

    Args:
        weight: Candidate weight; anything ``float()`` can convert.
        new_source: ``source_name`` of the entry the label belongs to. A name
            that matches no entry simply means there is nothing to replace.
        new_label: Label the weight is being assigned to.

    Returns:
        True if the weight may be stored, False otherwise. Unparseable input
        (including None) yields False rather than raising.
    """
    tolerance = 1e-9

    try:
        weight = float(weight)
    except (TypeError, ValueError):
        return False

    # NaN fails both comparisons and infinity fails the upper bound, so this
    # also protects the round() call below from non-finite input.
    if not (0.1 - tolerance <= weight <= 1 + tolerance):
        return False

    new_tenths = round(weight * 10)
    # Allow for representation error (e.g. 0.30000000000000004) while still
    # rejecting values such as 0.15 that are not multiples of 0.1.
    if abs(weight * 10 - new_tenths) > tolerance:
        return False

    # Total of every stored weight except the one being replaced.
    used_tenths = sum(
        round(existing * 10)
        for source in preferences
        for label, existing in source["label_weights"].items()
        if not (source["source_name"] == new_source and label == new_label)
    )
    return used_tenths + new_tenths <= 10

def update_output(change):
    """Keep the label placeholder in step with the selected source.

    Registered as an observer on the source dropdown. The ``change`` payload
    is ignored; the current dropdown value is read directly.
    """
    reddit_selected = source_dropdown.value == "Reddit"
    label_input.placeholder = "r/..." if reddit_selected else "#..."

def on_submit(b):
    """Handle a click on the Submit button.

    Reads the source, label and weight from the input widgets, validates the
    label and then the weight, and on success stores the weight under the
    label in the matching ``preferences`` entry. The output area is replaced
    with either an error message or a summary of the current preferences and
    their total weight. The button argument ``b`` is not used.
    """
    def reject(message):
        # Replace whatever the output area shows with a single error line.
        output.clear_output()
        with output:
            print(message)

    selected_source = source_dropdown.value
    source = selected_source.lower()
    label = label_input.value
    raw_weight = weight_input.value

    # The label check uses the dropdown spelling; everything else uses the
    # lower-case name stored in preferences.
    if not validate_label(selected_source, label):
        reject(f"Invalid label format for {source}. Please try again.")
        return

    if not validate_weight(raw_weight, source, label):
        reject("Invalid weight. Please enter a number between 0.1 and 1, divisible by 0.1, and ensure total weights across all sources don't exceed 1.")
        return

    weight = float(raw_weight)
    target = next(filter(lambda entry: entry["source_name"] == source, preferences))
    target["label_weights"][label] = weight

    # Only sources that have at least one label are shown in the summary.
    populated = [entry for entry in preferences if entry["label_weights"]]

    output.clear_output()
    with output:
        print(f"Added: {label} with weight {weight} for {source}")
        print("\nCurrent preferences:")
        print(json.dumps(populated, indent=2))

        # Sum per source first, then across sources.
        total_weight = 0
        for entry in preferences:
            total_weight += sum(entry["label_weights"].values())
        print(f"\nTotal weight across all sources: {total_weight:.1f}")

def on_finalize(b):
    """Handle a click on the Finalize button.

    Writes every ``preferences`` entry that has at least one label to
    ``my_preferences.json`` and echoes the saved data in the output area.
    The button argument ``b`` is not used.
    """
    # Sources without any labels are left out of the saved file.
    non_empty = list(filter(lambda entry: entry["label_weights"], preferences))

    with open('my_preferences.json', 'w') as handle:
        json.dump(non_empty, handle, indent=4)

    saved_view = json.dumps(non_empty, indent=2)
    output.clear_output()
    with output:
        print("Preferences saved to my_preferences.json:")
        print(saved_view)

# Source selector; changing it swaps the label placeholder.
source_dropdown = widgets.Dropdown(description="Source:", options=["Reddit", "X"])
source_dropdown.observe(update_output, names="value")

# Free-text label and numeric weight inputs.
label_input = widgets.Text(description="Label:", placeholder="r/...")
weight_input = widgets.FloatText(description="Weight:", value=0.1, step=0.1)

# Action buttons, each wired to its click handler.
submit_button = widgets.Button(description="Submit")
submit_button.on_click(on_submit)
finalize_button = widgets.Button(description="Finalize")
finalize_button.on_click(on_finalize)

# Shared area where the handlers print their messages.
output = widgets.Output()

display(
    source_dropdown,
    label_input,
    weight_input,
    submit_button,
    finalize_button,
    output,
)

Dropdown(description='Source:', options=('Reddit', 'X'), value='Reddit')

Text(value='', description='Label:', placeholder='r/...')

FloatText(value=0.1, description='Weight:', step=0.1)

Button(description='Submit', style=ButtonStyle())

Button(description='Finalize', style=ButtonStyle())

Output()