In [1]:
from dataclasses import dataclass, field
from typing import List, Dict, Optional
import math
import json
from datetime import datetime

import ipywidgets as widgets
from IPython.display import display, clear_output


In [15]:
@dataclass
class Item:
    """A single receipt line.

    ``item_price`` is the price of the whole line (the UI passes unit price
    times quantity), ``payer`` is who paid for it ("" when no payer was known
    at creation time), and ``owed_by`` lists the members sharing its cost.
    """

    item_name: str
    item_price: float
    payer: str
    owed_by: List[str]
    quantity: int = field(default=1)
    item_id: Optional[str] = None
    
@dataclass
class Bill:
    """The whole receipt plus the settings used to split it.

    ``members`` are the people on the bill and ``items`` its receipt lines.
    ``tax``, ``tip`` and ``fees`` are charged on top of the items and divided
    according to ``extras_split_mode`` ("equal" or "proportional").
    ``payer`` is the member who paid; ``payer_paid_everything`` selects the
    single-payer workflow.
    """

    # People and purchases (required, no defaults).
    members: List[str]
    items: List[Item]

    # Extras shared across members.
    tax: float = 0.0
    tip: float = 0.0
    fees: float = 0.0
    extras_split_mode: str = "equal"

    # Who paid.
    payer_paid_everything: bool = True
    payer: Optional[str] = None

    bill_name: Optional[str] = None
    
    

In [16]:
# The single notebook-global bill that every helper and widget callback below mutates.
bill = Bill(
    members=[],
    items=[],
    payer=None,
    bill_name=None,
)

In [17]:
def add_member(bill, name):
    """Add ``name`` (whitespace-trimmed and title-cased) to ``bill.members``.

    Returns:
        True if the member was appended; False if the cleaned name is empty
        or is already on the bill.
    """
    candidate = name.strip().title()
    is_new = bool(candidate) and candidate not in bill.members
    if is_new:
        bill.members.append(candidate)
    return is_new

def remove_member(bill, name):
    """Remove ``name`` from the bill and scrub every reference to it.

    The name is dropped from ``bill.members`` and from each item's
    ``owed_by``. If it was the bill's payer, ``bill.payer`` becomes None. If it
    was an item's payer, that item's ``payer`` becomes "" (the same "no payer"
    value ``add_item`` uses), so no item points at a removed member.

    Returns:
        (removed, invalid): ``removed`` is False and ``invalid`` is empty when
        ``name`` is not a member. Otherwise ``invalid`` lists the items whose
        ``owed_by`` is empty after the removal.
    """
    if name not in bill.members:
        return False, []
    bill.members.remove(name)

    invalid = []
    for it in bill.items:
        it.owed_by = [m for m in it.owed_by if m != name]
        if not it.owed_by:
            # Nobody owes this item any more after removing `name`.
            invalid.append(it)
        # Don't leave a stale per-item payer referring to a removed member.
        if it.payer == name:
            it.payer = ""

    if bill.payer == name:
        bill.payer = None
    return True, invalid

In [18]:
def add_item(bill, item_name, item_price, owed_by):
    """Validate a new receipt line and append it to ``bill.items``.

    Args:
        bill: the Bill to modify; it must already have members.
        item_name: display name; surrounding whitespace is stripped.
            None is treated like an empty name.
        item_price: anything ``float()`` accepts. It must be finite and > 0.
            The widget callback passes the line total (unit price x qty).
        owed_by: iterable of member names sharing the cost. If it is empty
            (or None), the cost is owed by every current member. Repeated
            names are collapsed to one.

    Returns:
        (True, new_item) on success. (False, None) if there are no members,
        the name is blank, the price is unparsable / non-positive /
        non-finite, or ``owed_by`` names someone who is not a member.
    """
    if not bill.members:
        return False, None

    clean_item_name = (item_name or "").strip()
    if not clean_item_name:
        return False, None

    try:
        price = float(item_price)
    except (TypeError, ValueError):
        return False, None
    # NaN compares False to everything, so `price <= 0` alone would accept it
    # (and inf); either one would poison every computed total.
    if not math.isfinite(price) or price <= 0:
        return False, None

    # Dedupe while preserving order: a repeated name would otherwise be charged
    # one share per occurrence when the price is split.
    owed_list = list(dict.fromkeys(owed_by or []))

    if not owed_list:
        owed_list = bill.members.copy()
    elif any(name not in bill.members for name in owed_list):
        return False, None

    # Record the payer on the item only in single-payer mode once one is chosen.
    payer_for_item = bill.payer if (bill.payer_paid_everything and bill.payer is not None) else ""

    new_item = Item(
        item_name=clean_item_name,
        item_price=price,
        payer=payer_for_item,
        owed_by=owed_list
    )

    bill.items.append(new_item)

    return True, new_item



In [19]:
def remove_item(bill, index):
    """Delete the item at position ``index`` from ``bill.items``.

    Negative indices are rejected rather than counted from the end.

    Returns:
        True if an item was removed, False if ``index`` was out of range.
    """
    if not 0 <= index < len(bill.items):
        return False
    del bill.items[index]
    return True


def set_payer(bill, payer_name):
    """Make ``payer_name`` the bill's payer, but only if they are a member.

    Returns:
        True if ``bill.payer`` was updated, False (bill untouched) otherwise.
    """
    is_member = payer_name in bill.members
    if is_member:
        bill.payer = payer_name
    return is_member

In [21]:

# Debug area: refresh() clears it and prints the current bill state into it.
out = widgets.Output()
# Containers whose children refresh() rebuilds: one row (label + Remove button)
# per member / per item.
members_list_box = widgets.VBox([])
items_list_box = widgets.VBox([])

# --- PAYER UI ---
# Yes/No toggle mirrored into bill.payer_paid_everything (see refresh / on_toggle_payer_paid).
payer_paid_toggle = widgets.ToggleButtons(
    options=[("Yes", True), ("No", False)],
    value=bill.payer_paid_everything,
    description="Payer paid all?"
)

# Payer picker; refresh() replaces the options with "(select)" + current members
# and disables it when the toggle is "No".
payer_dropdown = widgets.Dropdown(
    options=["(select)"],
    value="(select)",
    description="Payer",
    layout=widgets.Layout(width="320px")
)

# --- ITEMS UI ---
item_name_input = widgets.Text(
    description="Item",
    placeholder="e.g., Milk",
    layout=widgets.Layout(width="320px")
)

# Unit price; on_add_item multiplies it by the quantity to get the line total.
item_price_input = widgets.FloatText(
    description="Unit $",
    value=0.0,
    layout=widgets.Layout(width="320px")
)

item_qty_input = widgets.BoundedIntText(
    description="Qty",
    value=1,
    min=1,
    max=999,
    layout=widgets.Layout(width="320px")
)

# Members sharing the new item; options are filled in by refresh().
# Leaving it empty makes add_item assign the item to every member.
owed_by_select = widgets.SelectMultiple(
    options=[],
    description="Owed by",
    layout=widgets.Layout(width="320px", height="110px")
)

add_item_btn = widgets.Button(description="Add item", button_style="info")

# --- MEMBERS UI ---
member_input = widgets.Text(
    description="Name",
    placeholder="Add member name",
    layout=widgets.Layout(width="320px")
)

add_member_btn = widgets.Button(description="Add member", button_style="success")


def _remove_member(name):
    """Remove-button callback: drop member ``name`` from the global bill and redraw.

    Items left with an empty ``owed_by`` are reported in ``out`` *after* the
    redraw. refresh() runs clear_output() on ``out``, so a note printed before
    it would be erased immediately and never seen.
    """
    _removed, invalid_items = remove_member(bill, name)
    refresh()
    if invalid_items:
        with out:
            print("\nNOTE: These items now have owed_by=[] and need fixing:")
            for it in invalid_items:
                print("-", it.item_name)


def _remove_item(index):
    """Remove-button callback: delete the item at ``index`` from the global bill and redraw.

    Delegates to remove_item(), which ignores out-of-range indices.
    """
    remove_item(bill, index)
    refresh()


def compute_owed_subtotals(bill):
    """Compute each member's share of the bill, in dollars rounded to cents.

    Each item's price is split evenly among the distinct current members in
    its ``owed_by``. If ``owed_by`` is empty, or names nobody still on the
    bill, the price is split among all members. Tax + tip + fees (only when
    their sum is positive) are then added:

    - equally, for ``extras_split_mode == "equal"``;
    - in proportion to each member's item subtotal, for "proportional"
      (falling back to equal when the subtotals are all ~0).

    Any other mode adds no extras.

    Args:
        bill: a Bill (reads ``members``, ``items``, ``tax``, ``tip``, ``fees``,
            ``extras_split_mode``).

    Returns:
        dict mapping member name -> amount owed, rounded to 2 decimals. The
        rounded amounts sum to the rounded total of the unrounded shares.
        Leftover cents go to the members with the largest rounding remainders.
    """
    members = bill.members
    owed = {m: 0.0 for m in members}

    for it in bill.items:
        # Dedupe and drop names that are no longer members, so no share is
        # double-charged or silently lost.
        group = [m for m in dict.fromkeys(it.owed_by) if m in owed] or members
        if not group:
            continue
        share = it.item_price / len(group)
        for m in group:
            owed[m] += share

    extras = float(bill.tax) + float(bill.tip) + float(bill.fees)
    if extras > 0 and members:
        mode = bill.extras_split_mode
        # Computed before any extras are added, so proportions use item subtotals only.
        base = sum(owed.values())
        if mode == "proportional" and base > 1e-9:
            for m in members:
                owed[m] += extras * (owed[m] / base)
        elif mode in ("equal", "proportional"):
            # "equal", or "proportional" with nothing to be proportional to.
            per = extras / len(members)
            for m in members:
                owed[m] += per

    return _round_to_cents(owed)


def _round_to_cents(amounts):
    """Round dollar amounts to cents so the results still add up (largest-remainder method)."""
    raw_cents = {k: v * 100 for k, v in amounts.items()}
    # The small epsilon absorbs float noise (e.g. 332.9999999 -> 333 cents).
    floors = {k: math.floor(c + 1e-6) for k, c in raw_cents.items()}
    total_cents = math.floor(sum(raw_cents.values()) + 0.5 + 1e-6)
    leftover = total_cents - sum(floors.values())
    # Largest remainders get the leftover cents; sorted() is stable, so ties go to earlier members.
    ranked = sorted(amounts, key=lambda k: raw_cents[k] - floors[k], reverse=True)
    for k in ranked[:max(leftover, 0)]:
        floors[k] += 1
    return {k: floors[k] / 100 for k in amounts}


def compute_payments_to_payer(bill):
    """Work out how much each member must send to the bill's single payer.

    Returns:
        (owed, payer, payments). ``owed`` is compute_owed_subtotals(bill).
        ``payer`` is ``bill.payer``, or None when no payer is set.
        ``payments`` maps every member to the amount they transfer to the
        payer (0.0 for the payer themself), or is ``{}`` when there is no payer.
    """
    owed = compute_owed_subtotals(bill)
    payer = bill.payer
    if payer:
        payments = {}
        for member, amount in owed.items():
            payments[member] = 0.0 if member == payer else amount
        return owed, payer, payments
    return owed, None, {}


def refresh():
    """Redraw every widget from the current state of the global ``bill``.

    Steps:
    - rebuild the member rows and item rows (each with a Remove button);
    - sync the payer toggle and dropdown, and the "Owed by" choices;
    - clear ``out`` and print a debug dump of the bill into it.
    """
    # Members list
    member_rows = []
    for m in bill.members:
        rm_btn = widgets.Button(description="Remove", layout=widgets.Layout(width="90px"))
        # `name=m` binds the current member; a plain closure would see only the last one.
        rm_btn.on_click(lambda _b, name=m: _remove_member(name))
        member_rows.append(widgets.HBox([widgets.Label(m, layout=widgets.Layout(width="180px")), rm_btn]))
    members_list_box.children = member_rows

    # Payer toggle/dropdown
    bill.payer_paid_everything = bool(payer_paid_toggle.value)

    # Snapshot the payer *before* replacing the dropdown options. When a
    # Dropdown's options change, ipywidgets may reset its selection. That fires
    # on_payer_change, which would clear bill.payer, e.g. right after a member
    # is added. Restoring the snapshot keeps the user's choice.
    wanted_payer = bill.payer if bill.payer in bill.members else None
    payer_dropdown.options = ["(select)"] + bill.members
    bill.payer = wanted_payer
    payer_dropdown.value = "(select)" if wanted_payer is None else wanted_payer

    payer_dropdown.disabled = (bill.payer_paid_everything is False)

    # Owed-by options
    owed_by_select.options = bill.members

    # Items list
    item_rows = []
    for i, it in enumerate(bill.items):
        # item_price is the line total; `or 1` guards a missing/zero quantity.
        qty = getattr(it, "quantity", 1) or 1
        unit = it.item_price / qty
        label = f"[{i}] {it.item_name} x{qty} | unit ${unit:.2f} | line ${it.item_price:.2f} | owed_by={it.owed_by}"
        rm = widgets.Button(description="Remove", icon="trash", layout=widgets.Layout(width="90px"))
        rm.on_click(lambda _b, idx=i: _remove_item(idx))
        item_rows.append(widgets.HBox([widgets.Label(label, layout=widgets.Layout(width="520px")), rm]))
    items_list_box.children = item_rows

    # Debug dump (wipes anything previously printed to `out`).
    with out:
        clear_output()
        print("Current Bill State")
        print("members:", bill.members)
        print("payer_paid_everything:", bill.payer_paid_everything)
        print("payer:", bill.payer)
        print("items:", [(it.item_name, it.item_price, getattr(it, "quantity", 1), it.owed_by) for it in bill.items])
        print("tax/tip/fees:", bill.tax, bill.tip, bill.fees, "| extras_split_mode:", bill.extras_split_mode)


def on_add_member(_):
    """'Add member' click handler: add the typed name, clear the box on success, redraw."""
    if add_member(bill, member_input.value):
        member_input.value = ""
    refresh()


def on_toggle_payer_paid(change):
    """Observer for the 'Payer paid all?' toggle: store the new flag on the bill, then redraw."""
    bill.payer_paid_everything = True if change["new"] else False
    refresh()


def on_payer_change(change):
    """Observer for the payer dropdown.

    "(select)" clears the payer. Any other value is passed to set_payer,
    which ignores non-members. Always redraws.
    """
    selected = change["new"]
    if selected != "(select)":
        set_payer(bill, selected)
    else:
        bill.payer = None
    refresh()


def on_add_item(_):
    """'Add item' click handler.

    Multiplies unit price by quantity to get the line total, then passes it to
    add_item together with the selected "Owed by" members. On success it
    records the quantity on the new item and resets the inputs. Always redraws.
    """
    quantity = int(item_qty_input.value or 1)
    line_total = float(item_price_input.value or 0.0) * quantity
    selected_members = list(owed_by_select.value)

    added, item = add_item(bill, item_name_input.value, line_total, selected_members)
    if added:
        item.quantity = quantity
        # Reset the form for the next entry.
        item_name_input.value = ""
        item_price_input.value = 0.0
        item_qty_input.value = 1

    refresh()


# --- EXTRAS + SUMMARY UI ---
# Extras inputs; on_calculate copies these values into bill.tax / tip / fees.
tax_input = widgets.FloatText(description="Tax", value=bill.tax, layout=widgets.Layout(width="320px"))
tip_input = widgets.FloatText(description="Tip", value=bill.tip, layout=widgets.Layout(width="320px"))
fees_input = widgets.FloatText(description="Fees", value=bill.fees, layout=widgets.Layout(width="320px"))

# How extras are divided; the two values are the modes compute_owed_subtotals handles.
extras_mode = widgets.ToggleButtons(
    options=[("Equal", "equal"), ("Proportional", "proportional")],
    value=bill.extras_split_mode,
    description="Extras split"
)

calc_btn = widgets.Button(description="Calculate totals", button_style="primary")
# Where on_calculate prints the per-member totals.
summary_out = widgets.Output()


def on_calculate(_):
    """'Calculate totals' click handler.

    Copies tax/tip/fees and the split mode from the widgets into the bill and
    checks that the bill is ready. It then prints each member's share, and what
    they owe the payer, into ``summary_out``.
    """
    bill.tax = float(tax_input.value or 0.0)
    bill.tip = float(tip_input.value or 0.0)
    bill.fees = float(fees_input.value or 0.0)
    bill.extras_split_mode = extras_mode.value

    with summary_out:
        clear_output()

        # The first failing check determines the message shown.
        problem = None
        if not bill.members:
            problem = "Add members first."
        elif not bill.items:
            problem = "Add at least one item."
        elif bill.payer_paid_everything and not bill.payer:
            problem = "Select a payer first."
        elif not bill.payer_paid_everything:
            problem = "Multi payer mode not implemented yet. Switch 'Payer paid all?' to Yes."
        if problem is not None:
            print(problem)
            return

        owed, payer, payments = compute_payments_to_payer(bill)

        print("Owed totals (each person’s share):")
        for member in bill.members:
            print(f"  {member}: ${owed[member]:.2f}")

        if payer:
            print("\nPay the payer:")
            for member in (m for m in bill.members if m != payer):
                print(f"  {member} -> {payer}: ${payments[member]:.2f}")


# Hook up events: buttons use on_click, value widgets use observe on "value".
add_member_btn.on_click(on_add_member)
payer_paid_toggle.observe(on_toggle_payer_paid, names="value")
payer_dropdown.observe(on_payer_change, names="value")
add_item_btn.on_click(on_add_item)
calc_btn.on_click(on_calculate)

# Layout: each widget appears exactly once, grouped into four numbered sections,
# with the debug area `out` at the bottom.
ui = widgets.VBox([
    widgets.HTML("<h3>1) Members</h3>"),
    widgets.HBox([member_input, add_member_btn]),
    members_list_box,

    widgets.HTML("<h3>2) Who paid?</h3>"),
    payer_paid_toggle,
    payer_dropdown,

    widgets.HTML("<h3>3) Items</h3>"),
    item_name_input,
    item_price_input,
    item_qty_input,
    owed_by_select,
    add_item_btn,
    items_list_box,

    widgets.HTML("<h3>4) Extras + Summary</h3>"),
    tax_input,
    tip_input,
    fees_input,
    extras_mode,
    calc_btn,
    summary_out,

    widgets.HTML("<hr>"),
    out
])

display(ui)
# Initial render: fill member/item rows, payer dropdown options and the debug dump.
refresh()


VBox(children=(HTML(value='<h3>1) Members</h3>'), HBox(children=(Text(value='', description='Name', layout=Lay…