# In [10]:  (Jupyter input prompt left over from the notebook export)
import numpy as np
import plotly.graph_objects as go
from ipywidgets import HBox, VBox, Button, FloatSlider, Layout, Output, Label
from IPython.display import display, clear_output

# ============================================================
# GLOBALS / PLACEHOLDERS
# ============================================================

# Flipped to True by initialize_app() so repeated clicks do not rebuild the UI.
initialized = False
# Top-level container; shown immediately (still empty) so Voilà never stalls.
ui_container = VBox()

# Display the placeholder now; its .children are assigned later (the start
# screen at the bottom of this cell, the full app inside initialize_app()).
display(ui_container)


# ============================================================
# HELPER FUNCTIONS (pure, no widget state)
# ============================================================

def get_cell_vectors(a_val, b_val, gamma_deg):
    """Return the 2D real-space lattice vectors for lengths a, b and angle gamma.

    The first vector lies along +x with length ``a_val``; the second has
    length ``b_val`` and is rotated ``gamma_deg`` degrees counter-clockwise
    from the first.
    """
    gamma_rad = np.deg2rad(gamma_deg)
    first = np.array([a_val, 0.0])
    second = np.array([b_val * np.cos(gamma_rad), b_val * np.sin(gamma_rad)])
    return first, second

def frac_to_cart_atoms(frac_coords, a_vec, b_vec):
    """Convert fractional (u, v) coordinates into Cartesian positions.

    Each row ``(u, v)`` of ``frac_coords`` becomes ``u * a_vec + v * b_vec``.
    """
    u_col = frac_coords[:, 0:1]   # keep a trailing axis so it broadcasts against a_vec
    v_col = frac_coords[:, 1:2]
    return u_col * a_vec + v_col * b_vec

def tile_unit_cell(atoms_cart, nx_val, ny_val, a_vec, b_vec):
    """Replicate a unit cell's atoms over an ``nx_val`` x ``ny_val`` supercell.

    Parameters
    ----------
    atoms_cart : array of shape (n_atoms, dim)
        Cartesian atom positions in the reference cell.
    nx_val, ny_val : int
        Number of cell repeats along ``a_vec`` and ``b_vec``.
    a_vec, b_vec : array of shape (dim,)
        Lattice vectors.

    Returns
    -------
    ndarray of shape (nx_val * ny_val * n_atoms, dim)
        Cell ``(i, j)`` contributes ``atoms_cart + i*a_vec + j*b_vec``; cells
        are ordered with ``i`` outermost, then ``j``, then the atoms in their
        original order. When either repeat count is zero or negative an
        empty ``(0, dim)`` array is returned instead of raising.
    """
    atoms = np.atleast_2d(atoms_cart)
    a_vec = np.asarray(a_vec)
    b_vec = np.asarray(b_vec)
    nx_count = max(int(nx_val), 0)
    ny_count = max(int(ny_val), 0)

    # Cell indices in i-major order, shaped (n_cells, 1, 1) to broadcast
    # against (1, n_atoms, dim).
    ii, jj = np.meshgrid(np.arange(nx_count), np.arange(ny_count), indexing="ij")
    shift_a = ii.reshape(-1, 1, 1) * a_vec
    shift_b = jj.reshape(-1, 1, 1) * b_vec

    # Same addition order as atoms + i*a + j*b, so values match exactly.
    tiled = atoms[None, :, :] + shift_a + shift_b
    return tiled.reshape(-1, atoms.shape[-1])

def build_density(points, a_vec, b_vec, nx_val, ny_val, grid_size=256, pad=0.0, sigma=0.5):
    """Render point positions as a sum of isotropic Gaussians on a square grid.

    The grid spans the axis-aligned bounding box of the parallelogram with
    corners 0, ``nx_val*a_vec``, ``ny_val*b_vec`` and their sum, widened by
    ``pad`` on every side, sampled with ``grid_size`` points per axis.

    Returns
    -------
    (rho, x, y)
        ``rho`` has shape ``(grid_size, grid_size)`` and is indexed
        ``[y, x]`` (default ``np.meshgrid`` layout); ``x`` and ``y`` are the
        1D grid coordinates. Each point adds
        ``exp(-r**2 / (2*sigma**2))`` to ``rho``.
    """
    span_a = nx_val * a_vec
    span_b = ny_val * b_vec
    corners = np.array([[0.0, 0.0], span_a, span_b, span_a + span_b])

    lower = corners.min(axis=0) - pad
    upper = corners.max(axis=0) + pad

    x = np.linspace(lower[0], upper[0], grid_size)
    y = np.linspace(lower[1], upper[1], grid_size)
    X, Y = np.meshgrid(x, y)

    two_sigma_sq = 2*sigma**2
    rho = np.zeros_like(X)
    for px, py in points:
        rho += np.exp(-((X-px)**2 + (Y-py)**2)/two_sigma_sq)
    return rho, x, y

def get_reciprocal_vectors(a_vec, b_vec):
    """Return the 2D reciprocal lattice vectors ``(a_star, b_star)``.

    They satisfy ``a_star·a_vec = b_star·b_vec = 2π`` and
    ``a_star·b_vec = b_star·a_vec = 0``. Collinear inputs (zero cell
    area) are not guarded against and lead to a division by zero.
    """
    ax, ay = a_vec[0], a_vec[1]
    bx, by = b_vec[0], b_vec[1]
    signed_area = ax*by - ay*bx
    two_pi = 2*np.pi
    a_star = two_pi*np.array([by, -bx])/signed_area
    b_star = two_pi*np.array([-ay, ax])/signed_area
    return a_star, b_star

def gaussian_form_factor(qmag, beta=0.02):
    """Isotropic Gaussian atomic form factor ``f(q) = exp(-beta * q**2)``.

    Works elementwise on scalars or arrays of ``|Q|``; equals 1 at q = 0.
    """
    exponent = -beta*qmag*qmag
    return np.exp(exponent)

def compute_structure_factors(a_vec, b_vec, atoms_frac_local, atom_mask_local,
                              qmax_val, max_hk_val=8):
    """Compute kinematic structure factors for every (h, k) with |Q| <= qmax_val.

    For each Miller pair with ``-max_hk_val <= h, k <= max_hk_val``::

        Q      = h*a_star + k*b_star
        F(h,k) = f(|Q|) * sum_j exp(2*pi*i*(h*u_j + k*v_j))

    where the sum runs over atoms whose ``atom_mask_local`` entry is true and
    ``f`` is ``gaussian_form_factor``. Reflections with ``|F|**2 < 1e-10``
    (e.g. systematic absences) are dropped.

    Parameters
    ----------
    a_vec, b_vec : array of shape (2,)
        Real-space lattice vectors.
    atoms_frac_local : array of shape (n_atoms, 2)
        Fractional (u, v) atom coordinates.
    atom_mask_local : sequence of bool, at least n_atoms long
        Which atoms contribute to F.
    qmax_val : float
        Cutoff on |Q|.
    max_hk_val : int
        Largest |h| and |k| searched.

    Returns
    -------
    (H, K, Qx, Qy, I, phases)
        1D arrays with one entry per surviving reflection, ordered by h then
        k. ``H``/``K`` are integer arrays (also when empty); ``phases`` is
        ``arg(F)`` in degrees.
    """
    a_star, b_star = get_reciprocal_vectors(a_vec, b_vec)

    # All (h, k) pairs in h-major order (same order as nested h/k loops).
    hk_range = np.arange(-max_hk_val, max_hk_val + 1)
    h_grid, k_grid = np.meshgrid(hk_range, hk_range, indexing="ij")
    H = h_grid.ravel()
    K = k_grid.ravel()

    Q = H[:, None] * a_star + K[:, None] * b_star      # (n_hk, 2)
    qmag = np.linalg.norm(Q, axis=1)
    within = qmag <= qmax_val
    H, K, Q, qmag = H[within], K[within], Q[within], qmag[within]

    # Only atoms switched on in the mask contribute.
    atoms = np.asarray(atoms_frac_local, dtype=float).reshape(-1, 2)
    mask = np.asarray(atom_mask_local, dtype=bool)[:atoms.shape[0]]
    active = atoms[mask]

    # Phase of every active atom for every reflection: (n_refl, n_active).
    phase = 2*np.pi*(H[:, None]*active[:, 0] + K[:, None]*active[:, 1])
    # The form factor depends only on |Q|, so it scales the whole atom sum
    # instead of being recomputed per atom.
    F = gaussian_form_factor(qmag) * np.exp(1j*phase).sum(axis=1)
    I = np.abs(F)**2

    keep = I >= 1e-10
    return (H[keep], K[keep], Q[keep, 0], Q[keep, 1],
            I[keep], np.degrees(np.angle(F[keep])))

def intensities_to_sizes(I):
    """Map intensities to marker sizes in [2, 10] on a log scale.

    ``log10(I + 1e-6)`` is shifted so its minimum is 0 and, when the range
    is non-zero, scaled so its maximum is 1; the result is ``2 + 8 * scaled``.
    An empty input yields an empty array.
    """
    if I.size == 0:
        return np.array([])
    log_i = np.log10(I + 1e-6)
    shifted = log_i - log_i.min()
    peak = shifted.max()
    scaled = shifted / peak if peak > 0 else shifted
    return 2 + 8 * scaled

# ============================================================
# UNIT CELL OVERLAY HELPERS
# ============================================================

def draw_unit_cells(fig, a_vec, b_vec, nx=3, ny=3):
    """Overlay white outlines of an ``nx`` x ``ny`` grid of unit cells on ``fig``.

    Each cell (i, j) becomes a closed SVG path through its four corners,
    starting at ``i*a_vec + j*b_vec``. The full list of outlines is passed
    to ``fig.update_layout(shapes=...)`` in one call.
    """
    def _outline(origin):
        corners = (origin, origin + a_vec, origin + a_vec + b_vec, origin + b_vec)
        commands = ("M", "L", "L", "L")
        segments = [f"{cmd} {pt[0]},{pt[1]}" for cmd, pt in zip(commands, corners)]
        return dict(
            type="path",
            path=" ".join(segments) + " Z",
            line=dict(color="white", width=1),
            fillcolor="rgba(0,0,0,0)",
            opacity=0.5,
        )

    outlines = [_outline(i*a_vec + j*b_vec) for i in range(nx) for j in range(ny)]
    fig.update_layout(shapes=outlines)


# ============================================================
# MAIN INITIALIZATION FUNCTION (called by button)
# ============================================================

def initialize_app(button=None):
    """Build the interactive explorer inside ``ui_container`` (first call only).

    Used as the "Initialize App" button's click handler; ``button`` is the
    clicked widget and is not used. Later calls return immediately. If
    construction raises, the ``initialized`` flag is cleared again so the
    button can be retried instead of becoming permanently inert.
    """
    global initialized

    if initialized:
        return
    initialized = True
    try:
        _build_app()
    except Exception:
        initialized = False
        raise


def _build_app():
    """Create the figures, sliders and buttons and install them in ``ui_container``."""
    global atoms_frac, atoms_frac_init, atom_mask
    global a, b, gamma, qmax
    global out_real, out_amp, out_phase
    global real_fig, fig_amp, fig_phase
    global a_slider, b_slider, g_slider, q_slider
    global slider_boxes
    global ui_container

    # ---------- initial parameters ----------
    a_init = 10.0
    b_init = 10.0
    gamma_init = 90.0
    qmax_init = 5.0

    atoms_frac_init = np.array([[0.0, 0.0], [0.5, 0.5]])
    atoms_frac = atoms_frac_init.copy()
    atom_mask = np.ones(len(atoms_frac_init), dtype=bool)

    a, b, gamma, qmax = a_init, b_init, gamma_init, qmax_init

    # ---------- output widgets ----------
    out_real = Output()
    out_amp = Output()
    out_phase = Output()

    # True while on_reset moves several sliders at once: each slider change
    # still updates the state, but the redraw is deferred to a single call.
    suspend_redraw = False

    # ========================================================
    # UPDATE FUNCTION
    # ========================================================
    def update_all(*args):
        """Recompute density and diffraction from the globals and redisplay all figures."""
        if suspend_redraw:
            return

        a_vec, b_vec = get_cell_vectors(a, b, gamma)
        atoms_cart = frac_to_cart_atoms(atoms_frac, a_vec, b_vec)
        active_atoms = atoms_cart[atom_mask]
        pts = tile_unit_cell(active_atoms, 3, 3, a_vec, b_vec)

        rho, xg, yg = build_density(pts, a_vec, b_vec, 3, 3)

        real_fig.data[0].z = rho
        real_fig.data[0].x = xg
        real_fig.data[0].y = yg

        # Redraw the unit-cell overlay for the new lattice vectors.
        draw_unit_cells(real_fig, a_vec, b_vec)

        H, K, Qx, Qy, Ivals, phases = compute_structure_factors(
            a_vec, b_vec, atoms_frac, atom_mask, qmax, 8
        )
        sizes = intensities_to_sizes(Ivals)

        fig_amp.data[0].x = Qx
        fig_amp.data[0].y = Qy
        fig_amp.data[0].marker.size = sizes
        fig_amp.data[0].marker.color = Ivals

        fig_phase.data[0].x = Qx
        fig_phase.data[0].y = Qy
        fig_phase.data[0].marker.size = sizes
        fig_phase.data[0].marker.color = phases

        for out, fig in ((out_real, real_fig), (out_amp, fig_amp), (out_phase, fig_phase)):
            with out:
                clear_output(wait=True)
                display(fig)

    # ========================================================
    # INITIAL FIGURES
    # ========================================================
    a_vec, b_vec = get_cell_vectors(a, b, gamma)
    atoms_cart = frac_to_cart_atoms(atoms_frac, a_vec, b_vec)
    pts = tile_unit_cell(atoms_cart[atom_mask], 3, 3, a_vec, b_vec)
    rho, xg, yg = build_density(pts, a_vec, b_vec, 3, 3)

    real_fig = go.Figure(go.Heatmap(z=rho, x=xg, y=yg, colorscale="Viridis"))
    real_fig.update_layout(
        width=480, height=480,
        title="Real Space",
        xaxis=dict(scaleanchor="y", title="x (Å)"),
        yaxis=dict(title="y (Å)")
    )

    # Unit-cell overlay on the initial draw.
    draw_unit_cells(real_fig, a_vec, b_vec)

    H, K, Qx, Qy, Ivals, phases = compute_structure_factors(a_vec, b_vec, atoms_frac, atom_mask, qmax)
    sizes = intensities_to_sizes(Ivals)

    fig_amp = go.Figure(go.Scatter(
        x=Qx, y=Qy, mode="markers",
        marker=dict(size=sizes, color=Ivals, colorscale="Viridis", showscale=True)
    ))
    fig_amp.update_layout(
        width=480, height=480,
        title="Diffraction |F|²",
        xaxis=dict(scaleanchor="y", title="Qx (Å⁻¹)"),
        yaxis=dict(title="Qy (Å⁻¹)")
    )

    fig_phase = go.Figure(go.Scatter(
        x=Qx, y=Qy, mode="markers",
        marker=dict(size=sizes, color=phases,
                    colorscale="Twilight", cmin=-180, cmax=180, showscale=True)
    ))
    fig_phase.update_layout(
        width=480, height=480,
        title="Phase Map",
        xaxis=dict(scaleanchor="y", title="Qx (Å⁻¹)"),
        yaxis=dict(title="Qy (Å⁻¹)")
    )

    # ========================================================
    # SLIDERS (atom fractional positions)
    # ========================================================
    slider_boxes = []
    num_atoms = len(atoms_frac)

    def make_coord_cb(i, axis):
        """Return an observer that writes the new slider value into atoms_frac[i, axis]."""
        def cb(change):
            atoms_frac[i, axis] = change["new"]
            update_all()
        return cb

    for k in range(num_atoms):
        u_s = FloatSlider(description=f"Atom {k} x",
                          min=0, max=1, step=0.01, value=atoms_frac[k, 0],
                          layout=Layout(width='260px'))
        v_s = FloatSlider(description=f"Atom {k} y",
                          min=0, max=1, step=0.01, value=atoms_frac[k, 1],
                          layout=Layout(width='260px'))

        # Bind k through the factory argument (avoids late-binding closures).
        u_s.observe(make_coord_cb(k, 0), "value")
        v_s.observe(make_coord_cb(k, 1), "value")

        slider_boxes.append(HBox([u_s, v_s]))

    # ========================================================
    # LATTICE PARAMETER SLIDERS
    # ========================================================
    a_slider = FloatSlider(description="a (Å)", min=3, max=20, step=0.1, value=a)
    b_slider = FloatSlider(description="b (Å)", min=3, max=20, step=0.1, value=b)
    g_slider = FloatSlider(description="γ (°)", min=30, max=150, step=1, value=gamma)
    q_slider = FloatSlider(description="Qmax", min=1, max=15, step=0.1, value=qmax)

    def on_cell_change(change):
        # Re-reads all three cell sliders, whichever one fired.
        global a, b, gamma
        a = a_slider.value
        b = b_slider.value
        gamma = g_slider.value
        update_all()

    def on_q_change(change):
        global qmax
        qmax = q_slider.value
        update_all()

    a_slider.observe(on_cell_change, "value")
    b_slider.observe(on_cell_change, "value")
    g_slider.observe(on_cell_change, "value")
    q_slider.observe(on_q_change, "value")

    # ========================================================
    # BUTTONS
    # ========================================================
    toggle_btn = Button(description="Toggle Atom 1")
    reset_btn = Button(description="Reset")

    def on_toggle(btn):
        if len(atom_mask) > 1:
            atom_mask[1] = not atom_mask[1]
        update_all()

    def on_reset(btn):
        nonlocal suspend_redraw
        global a, b, gamma, qmax

        suspend_redraw = True
        try:
            # Drive the widgets from the *_init values, NOT from the globals:
            # each assignment can fire on_cell_change, which re-reads all three
            # cell sliders into a/b/gamma and would otherwise clobber the
            # not-yet-reset b and gamma with their old slider values.
            a_slider.value = a_init
            b_slider.value = b_init
            g_slider.value = gamma_init
            q_slider.value = qmax_init

            for k, box in enumerate(slider_boxes):
                u_s, v_s = box.children
                u_s.value = atoms_frac_init[k, 0]
                v_s.value = atoms_frac_init[k, 1]
        finally:
            suspend_redraw = False

        # Re-assert the canonical state regardless of which observers fired.
        a, b, gamma, qmax = a_init, b_init, gamma_init, qmax_init
        atoms_frac[:] = atoms_frac_init
        atom_mask[:] = True

        update_all()

    toggle_btn.on_click(on_toggle)
    reset_btn.on_click(on_reset)

    # ========================================================
    # FINAL UI
    # ========================================================
    fig_row = HBox([out_real, out_amp, out_phase])
    control_row = HBox([toggle_btn, reset_btn])
    cell_row = HBox([a_slider, b_slider, g_slider, q_slider])

    ui_container.children = [
        Label("2D Lattice and Diffraction Explorer"),
        fig_row,
        control_row,
        cell_row,
    ] + slider_boxes

    update_all()


# ============================================================
# INITIALIZE BUTTON (safe in Voilà: the heavy setup runs only on click)
# ============================================================

# Start-screen button. initialize_app() later replaces ui_container.children
# with a list that does not include this button, so it disappears once the
# app has been built.
init_button=Button(
    description="Initialize App",
    button_style="success",
    layout=Layout(width="200px")
)
# ipywidgets passes the clicked Button to the handler (initialize_app's `button`).
init_button.on_click(initialize_app)

# Fill the already-displayed placeholder container with the start screen.
ui_container.children=[
    Label("Click to start the interactive FFT app"),
    init_button
]


# Out[10]: VBox()  (text repr of the displayed container captured by the notebook export)