In [1]:
import numpy as np
import plotly.graph_objects as go
from ipywidgets import HBox, VBox, Button, FloatSlider, Layout, Output
from IPython.display import display, clear_output

# ============================================================
#  INITIAL PARAMETERS
# ============================================================
# Unit cell: edge lengths a, b (Å) and the angle gamma between them (degrees).
a_init, b_init, gamma_init = 10.0, 10.0, 90.0

# Real-space view: an nx × ny block of unit cells, each atom drawn as a
# Gaussian of width sigma (same length units as a and b).
nx = ny = 3
sigma = 0.5

# Diffraction: reflections with |Q| above qmax (Å⁻¹) are skipped, and Miller
# indices are enumerated only for |h|, |k| <= max_hk.
qmax_init = 5.0
max_hk = 8

# Basis atoms as fractional (u, v) coordinates, one row per atom.
atoms_frac_init = np.array([[0.0, 0.0], [0.5, 0.5]])
num_atoms = atoms_frac_init.shape[0]

# ============================================================
# GLOBAL STATE
# ============================================================
# Mutable working copies; the sliders and buttons below change these.
a, b, gamma, qmax = a_init, b_init, gamma_init, qmax_init

atoms_frac = np.array(atoms_frac_init, copy=True)
atom_mask = np.full(num_atoms, True, dtype=bool)  # True = atom contributes

# ============================================================
# GEOMETRY HELPERS
# ============================================================
def get_cell_vectors(a_val, b_val, gamma_deg):
    """Build the 2-D real-space lattice basis.

    The first vector points along +x with length ``a_val``; the second has
    length ``b_val`` and makes the angle ``gamma_deg`` (degrees) with the
    first, measured counter-clockwise.

    Returns:
        ``(a_vec, b_vec)``, two length-2 float arrays.
    """
    angle = np.deg2rad(gamma_deg)
    first = np.array([a_val, 0.0])
    second = b_val * np.array([np.cos(angle), np.sin(angle)])
    return first, second


def frac_to_cart_atoms(frac_coords, a_vec, b_vec):
    """Map fractional coordinates to 2-D Cartesian positions.

    Each row (u, v) of ``frac_coords`` becomes ``u * a_vec + v * b_vec``;
    the result has one (x, y) row per input row.
    """
    u_col = frac_coords[:, 0:1]  # keep a trailing axis so it broadcasts over (x, y)
    v_col = frac_coords[:, 1:2]
    return u_col * a_vec + v_col * b_vec


def tile_unit_cell(atoms_cart, nx_val, ny_val, a_vec, b_vec):
    """Replicate one cell's atoms over an ``nx_val`` × ``ny_val`` block.

    The copy for cell (i, j) is ``atoms_cart`` shifted by ``i*a_vec + j*b_vec``.
    Copies are stacked with i varying slowest, so cell (i, j) precedes
    (i, j + 1).
    """
    shifted_copies = [
        atoms_cart + i * a_vec + j * b_vec
        for i in range(nx_val)
        for j in range(ny_val)
    ]
    return np.vstack(shifted_copies)

# ============================================================
# REAL SPACE DENSITY — exact 3×3 bounded region
# ============================================================
def build_density(points, a_vec, b_vec, nx_val, ny_val, grid_size=256, pad=0.0,
                  sigma_val=None):
    """Sum isotropic Gaussians centred on ``points`` over a rectangular grid.

    The grid is the axis-aligned bounding box of the ``nx_val`` × ``ny_val``
    parallelogram spanned by ``a_vec`` and ``b_vec`` (corners 0, nx·a, ny·b,
    nx·a + ny·b), widened by ``pad`` on every side and sampled with
    ``grid_size`` points per axis.

    Args:
        points: iterable of (x, y) Cartesian positions, e.g. an (N, 2) array.
        a_vec, b_vec: real-space lattice vectors (length-2 arrays).
        nx_val, ny_val: number of cells along a and b.
        grid_size: number of samples per axis.
        pad: extra margin added to each side of the bounding box.
        sigma_val: Gaussian width. ``None`` (the default) falls back to the
            module-level ``sigma``, which is what this function always used.

    Returns:
        ``(rho, x, y)``: ``rho`` has shape (grid_size, grid_size) and is
        indexed ``rho[iy, ix]``; ``x`` and ``y`` are the 1-D sample
        coordinates along each axis.
    """
    if sigma_val is None:
        sigma_val = sigma

    # Corners of the nx × ny parallelogram; the grid is their bounding box.
    corners = np.array([
        np.zeros(2),
        nx_val * a_vec,
        ny_val * b_vec,
        nx_val * a_vec + ny_val * b_vec,
    ])
    x_min, y_min = corners.min(axis=0) - pad
    x_max, y_max = corners.max(axis=0) + pad

    x = np.linspace(x_min, x_max, grid_size)
    y = np.linspace(y_min, y_max, grid_size)

    # exp(-(dx² + dy²) / 2σ²) = exp(-dx²/2σ²) · exp(-dy²/2σ²), so each atom
    # contributes an outer product of two 1-D Gaussians: O(grid_size)
    # exponentials per atom instead of O(grid_size²).
    two_sigma_sq = 2 * sigma_val ** 2
    rho = np.zeros((grid_size, grid_size))
    for px, py in points:
        gx = np.exp(-((x - px) ** 2) / two_sigma_sq)
        gy = np.exp(-((y - py) ** 2) / two_sigma_sq)
        rho += np.outer(gy, gx)  # row index follows y, column index follows x

    return rho, x, y

# ============================================================
# RECIPROCAL SPACE & STRUCTURE FACTORS
# ============================================================
def get_reciprocal_vectors(a_vec, b_vec):
    """Return the 2-D reciprocal basis ``(a*, b*)`` of the lattice (a_vec, b_vec).

    Uses the 2π convention: a·a* = b·b* = 2π and a·b* = b·a* = 0. The signed
    cell area ``a_x b_y − a_y b_x`` is the divisor, so collinear input
    vectors (zero area) are not supported.
    """
    ax, ay = a_vec[0], a_vec[1]
    bx, by = b_vec[0], b_vec[1]
    signed_area = ax * by - ay * bx
    two_pi = 2 * np.pi
    a_star = two_pi * np.array([by, -bx]) / signed_area
    b_star = two_pi * np.array([-ay, ax]) / signed_area
    return a_star, b_star


def gaussian_form_factor(qmag, beta=0.02):
    """Isotropic Gaussian atomic form factor ``f(q) = exp(-beta * q**2)``.

    Equals 1 at q = 0 and works element-wise on NumPy arrays.
    """
    exponent = -beta * qmag
    exponent = exponent * qmag
    return np.exp(exponent)


def compute_structure_factors(a_vec, b_vec, atoms_frac_local, atom_mask_local,
                              qmax_val, max_hk_val=8):
    """Compute the discrete structure factors F(h, k) inside |Q| <= qmax_val.

    Every integer pair (h, k) with |h|, |k| <= max_hk_val is visited. Its
    scattering vector is Q = h a* + k b*, and reflections with
    |Q| > qmax_val are skipped. For the rest,

        F(h, k) = sum over active atoms j of f(|Q|) * exp(2πi (h u_j + k v_j))

    where f is ``gaussian_form_factor``. Reflections with |F|² < 1e-10 are
    dropped, e.g. systematic absences or the case of no active atoms.

    Args:
        a_vec, b_vec: real-space lattice vectors (length-2 arrays).
        atoms_frac_local: fractional (u, v) coordinates, one row per atom.
        atom_mask_local: one flag per atom; only atoms with a true flag
            contribute. Must have the same length as ``atoms_frac_local``.
        qmax_val: maximum |Q| to keep (same units as |a*|, |b*|).
        max_hk_val: bound on |h| and |k| during enumeration.

    Returns:
        ``(H, K, Qx, Qy, I, phase_deg)``: 1-D arrays with one entry per kept
        reflection, ordered by increasing h, then increasing k. ``I`` is |F|²
        and ``phase_deg`` is arg(F) in degrees.

    Raises:
        ValueError: if ``atoms_frac_local`` and ``atom_mask_local`` differ in
            length.
    """
    if len(atoms_frac_local) != len(atom_mask_local):
        raise ValueError(
            "atoms_frac_local and atom_mask_local must have the same length"
        )

    a_star, b_star = get_reciprocal_vectors(a_vec, b_vec)

    # Select active atoms from the arguments themselves. The previous version
    # looped over the module-level num_atoms, which tied this function to one
    # global atom count.
    active = [
        (u, v)
        for (u, v), on in zip(atoms_frac_local, atom_mask_local)
        if on
    ]

    H, K, Qx, Qy, Ivals, phases = [], [], [], [], [], []
    hk_range = range(-max_hk_val, max_hk_val + 1)

    for h in hk_range:
        for k in hk_range:
            Q = h * a_star + k * b_star
            qmag = np.linalg.norm(Q)
            if qmag > qmax_val:
                continue

            # The form factor depends only on |Q|, so evaluate it once per
            # reflection rather than once per atom.
            f = gaussian_form_factor(qmag)
            F = 0j
            for u, v in active:
                phase = 2 * np.pi * (h * u + k * v)
                F += f * np.exp(1j * phase)

            I = np.abs(F) ** 2
            if I < 1e-10:
                continue

            H.append(h)
            K.append(k)
            Qx.append(Q[0])
            Qy.append(Q[1])
            Ivals.append(I)
            phases.append(np.degrees(np.angle(F)))

    return (np.array(H), np.array(K),
            np.array(Qx), np.array(Qy),
            np.array(Ivals), np.array(phases))


def intensities_to_sizes(I):
    """Convert an intensity array into marker sizes in the range [2, 10].

    Sizes follow log10(I + 1e-6), shifted so the weakest entry maps to 2.
    When the entries are not all equal, the strongest maps to 10. An empty
    input gives an empty array.
    """
    if not I.size:
        return np.array([])
    log_i = np.log10(I + 1e-6)
    spread = log_i - log_i.min()
    top = spread.max()
    if top > 0:
        spread = spread / top
    return 2 + 8 * spread

# ============================================================
# INITIAL FIGURES (plain go.Figure)
# ============================================================
a_vec, b_vec = get_cell_vectors(a, b, gamma)
atoms_cart = frac_to_cart_atoms(atoms_frac, a_vec, b_vec)
# Tile only the atoms enabled in atom_mask, as update_all() does. This keeps
# the initial heatmap built from the same atoms as every later redraw.
pts = tile_unit_cell(atoms_cart[atom_mask], nx, ny, a_vec, b_vec)
rho, xg, yg = build_density(pts, a_vec, b_vec, nx, ny)

# Real-space panel: one heatmap trace of the density, without a colour bar.
real_fig = go.Figure()
real_fig.add_trace(go.Heatmap(
    z=rho, x=xg, y=yg,
    colorscale="Viridis", showscale=False
))

def draw_unit_cells(fig, a_vec_local, b_vec_local, nx_val=None, ny_val=None):
    """Outline each unit cell of the tiled block on ``fig``.

    For every cell (i, j), builds one closed path shape with corners
    c0 = i*a + j*b, c0 + a, c0 + a + b and c0 + b. Each shape has a white
    1-px line, a transparent fill and opacity 0.5. The list is assigned to
    ``fig.layout.shapes``, replacing any previous outlines outright,
    whatever their number.

    Args:
        fig: a figure with a ``layout`` (a plotly ``go.Figure`` in this file).
        a_vec_local, b_vec_local: lattice vectors (length-2 arrays).
        nx_val, ny_val: number of cells along a and b. ``None`` falls back
            to the module-level ``nx`` / ``ny`` (the previous behaviour),
            mirroring the explicit counts taken by tile_unit_cell and
            build_density.
    """
    if nx_val is None:
        nx_val = nx
    if ny_val is None:
        ny_val = ny

    shapes = []
    for i in range(nx_val):
        for j in range(ny_val):
            c0 = i * a_vec_local + j * b_vec_local
            c1 = c0 + a_vec_local
            c2 = c0 + a_vec_local + b_vec_local
            c3 = c0 + b_vec_local
            path = f"M {c0[0]},{c0[1]} L {c1[0]},{c1[1]} L {c2[0]},{c2[1]} L {c3[0]},{c3[1]} Z"
            shapes.append(dict(
                type="path", path=path,
                line=dict(color="white", width=1),
                fillcolor="rgba(0,0,0,0)",
                opacity=0.5
            ))
    # Direct assignment replaces the whole tuple. update_layout(shapes=...)
    # may instead merge into the existing shapes element by element, which
    # goes wrong once the number of cells changes.
    fig.layout.shapes = shapes

draw_unit_cells(real_fig, a_vec, b_vec)
real_fig.update_layout(
    width=500,
    height=500,
    title="Real Space",
    xaxis={"title": "x (Å)"},
    # Tie the y scale to x so the cell angle is drawn undistorted.
    yaxis={"title": "y (Å)", "scaleanchor": "x"},
)

# Diffraction figures: the same reflections are plotted twice, coloured by
# |F|² in one panel and by phase (degrees) in the other.
H, K, Qx, Qy, Ivals, phases = compute_structure_factors(
    a_vec, b_vec, atoms_frac, atom_mask, qmax, max_hk
)
sizes = intensities_to_sizes(Ivals)

fig_amp = go.Figure()
fig_phase = go.Figure()
for _fig, _title, _marker in (
    (fig_amp, "Diffraction |F|²", {
        "color": Ivals,
        "colorscale": "Viridis",
        "colorbar": {"title": "|F|²"},
    }),
    (fig_phase, "Phase Map", {
        "color": phases,
        "colorscale": "Twilight",
        # Fixed colour range covering every value np.angle can return.
        "cmin": -180,
        "cmax": 180,
        "colorbar": {"title": "Phase (°)"},
    }),
):
    _fig.add_trace(go.Scatter(
        x=Qx, y=Qy, mode="markers",
        marker={"size": sizes, "showscale": True, **_marker},
    ))
    _fig.update_layout(
        width=500,
        height=500,
        title=_title,
        xaxis={"title": "Qx (Å⁻¹)", "scaleanchor": "y"},
        yaxis={"title": "Qy (Å⁻¹)"},
    )

# ============================================================
# OUTPUT WIDGETS FOR VOILA
# ============================================================
out_real = Output()
out_amp = Output()
out_phase = Output()


def redraw_all_outputs():
    """Re-display each figure inside its own Output widget.

    ``clear_output(wait=True)`` postpones clearing until new output arrives,
    which avoids a blank panel between frames.
    """
    panels = ((out_real, real_fig), (out_amp, fig_amp), (out_phase, fig_phase))
    for target, figure in panels:
        with target:
            clear_output(wait=True)
            display(figure)

# ============================================================
# UPDATE FUNCTION
# ============================================================
def update_all():
    """Recompute every panel from the current global state and redraw it.

    Reads the module globals a, b, gamma, qmax, atoms_frac, atom_mask, nx,
    ny and max_hk. Rebinds the globals a_vec and b_vec. Mutates the traces
    (and, through draw_unit_cells, the shapes) of real_fig, fig_amp and
    fig_phase in place, then re-displays all three via redraw_all_outputs().
    """
    global a_vec, b_vec

    # Lattice vectors for the current a, b, gamma.
    a_vec, b_vec = get_cell_vectors(a, b, gamma)

    # Real space: tile only the atoms whose mask entry is True.
    atoms_cart = frac_to_cart_atoms(atoms_frac, a_vec, b_vec)
    pts = tile_unit_cell(atoms_cart[atom_mask], nx, ny, a_vec, b_vec)

    rho_new, x_new, y_new = build_density(pts, a_vec, b_vec, nx, ny)
    heatmap = real_fig.data[0]
    heatmap.z = rho_new
    heatmap.x = x_new
    heatmap.y = y_new
    draw_unit_cells(real_fig, a_vec, b_vec)

    # Reciprocal space: only Q coordinates, intensities and phases are
    # plotted, so the (h, k) index arrays are discarded.
    _, _, Qx, Qy, Ivals, phases = compute_structure_factors(
        a_vec, b_vec, atoms_frac, atom_mask, qmax, max_hk
    )
    sizes = intensities_to_sizes(Ivals)

    # Each scatter panel needs a trace at index 0 before it can be updated.
    # Both are created with one at start-up, so adding a trace is only a
    # guard against a figure whose traces were cleared.
    for fig, colour in ((fig_amp, Ivals), (fig_phase, phases)):
        if not fig.data:
            fig.add_trace(go.Scatter(mode="markers"))
        scatter = fig.data[0]
        scatter.x = Qx
        scatter.y = Qy
        scatter.marker.size = sizes
        scatter.marker.color = colour

    redraw_all_outputs()

# ============================================================
# SLIDERS FOR FRACTIONAL ATOM COORDS
# ============================================================
def make_u_cb(i):
    """Return an observer that stores a new fractional u (x) for atom ``i`` and redraws."""
    def cb(change):
        atoms_frac[i, 0] = change['new']
        update_all()
    return cb


def make_v_cb(i):
    """Return an observer that stores a new fractional v (y) for atom ``i`` and redraws."""
    def cb(change):
        atoms_frac[i, 1] = change['new']
        update_all()
    return cb


def _frac_slider(label, start):
    """Build one 0–1 fractional-coordinate slider, 300 px wide."""
    return FloatSlider(
        description=label,
        min=0, max=1, step=0.01,
        value=start,
        layout=Layout(width='300px'),
    )


# One row of (u, v) sliders per atom; the callback factories bind the atom
# index at creation time, so each slider writes to its own atom.
slider_boxes = []
for k in range(num_atoms):
    u_s = _frac_slider(f"Atom {k} x", atoms_frac[k, 0])
    v_s = _frac_slider(f"Atom {k} y", atoms_frac[k, 1])
    u_s.observe(make_u_cb(k), "value")
    v_s.observe(make_v_cb(k), "value")
    slider_boxes.append(HBox([u_s, v_s], layout=Layout(margin='2px 0')))

# ============================================================
# UNIT CELL & QMAX SLIDERS
# ============================================================
# The γ range (30–150°) keeps sin γ >= 0.5, so the cell area that
# get_reciprocal_vectors divides by never reaches zero.
a_slider = FloatSlider(description="a (Å)", min=3, max=20, step=0.1, value=a_init)
b_slider = FloatSlider(description="b (Å)", min=3, max=20, step=0.1, value=b_init)
g_slider = FloatSlider(description="γ (°)", min=30, max=150, step=1, value=gamma_init)
q_slider = FloatSlider(description="Qmax", min=1, max=15, step=0.1, value=qmax_init)


def on_cell_change(change):
    """Observer for the a, b and γ sliders.

    Re-reads all three cell sliders (not just the one that fired) into the
    globals a, b and gamma, then redraws. ``change`` is unused.
    """
    global a, b, gamma
    a, b, gamma = a_slider.value, b_slider.value, g_slider.value
    update_all()


def on_q_change(change):
    """Observer for the Qmax slider: store the new cutoff and redraw."""
    global qmax
    qmax = q_slider.value
    update_all()


for _cell_slider in (a_slider, b_slider, g_slider):
    _cell_slider.observe(on_cell_change, 'value')
q_slider.observe(on_q_change, 'value')

unitcell_box = HBox([a_slider, b_slider, g_slider, q_slider],
                    layout=Layout(margin='4px 0'))

# ============================================================
# BUTTONS
# ============================================================
# Action buttons; their click handlers are registered further down.
reset_btn, toggle_btn = Button(description="Reset All"), Button(description="Toggle Atom 1")

def on_reset(btn):
    """Restore cell, Qmax, atom positions and atom mask to their initial values.

    ``btn`` is the clicked Button, which ``Button.on_click`` passes in; it is
    unused here.

    Sliders are set from the ``*_init`` constants, not from the globals.
    Every cell-slider assignment fires on_cell_change, which re-reads *all
    three* cell sliders into a, b and gamma. Reading the globals back after
    the first assignment therefore picked up not-yet-reset slider values,
    and Reset could leave b and/or gamma unchanged.
    """
    global a, b, gamma, qmax

    # Move every widget first. Each change fires its observer and a redraw.
    # That work is redundant but harmless, because the state is reset
    # authoritatively below.
    a_slider.value = a_init
    b_slider.value = b_init
    g_slider.value = gamma_init
    q_slider.value = qmax_init
    for k, box in enumerate(slider_boxes):
        u_slider, v_slider = box.children
        u_slider.value = atoms_frac_init[k, 0]
        v_slider.value = atoms_frac_init[k, 1]

    # Authoritative state reset, independent of which observers fired above.
    # atoms_frac and atom_mask are mutated in place, so they need no `global`.
    a = a_init
    b = b_init
    gamma = gamma_init
    qmax = qmax_init
    atoms_frac[:] = atoms_frac_init
    atom_mask[:] = True

    update_all()

def on_toggle(btn):
    """Flip whether atom 1 contributes to both panels, then redraw.

    ``btn`` is the clicked Button (unused).
    """
    atom_mask[1] = np.logical_not(atom_mask[1])
    update_all()

for _button, _handler in ((reset_btn, on_reset), (toggle_btn, on_toggle)):
    _button.on_click(_handler)

button_row = HBox([reset_btn, toggle_btn], layout=Layout(margin='4px 0'))

# ============================================================
# FINAL UI LAYOUT
# ============================================================
fig_row = HBox(
    [out_real, out_amp, out_phase],
    layout=Layout(justify_content='space-around'),
)

# Figures on top, then the buttons, the cell/Qmax sliders, and one slider
# row per atom.
ui = VBox([fig_row, button_row, unitcell_box, *slider_boxes])

display(ui)

# Initial draw: fill the three Output widgets.
update_all()


VBox(children=(HBox(children=(Output(), Output(), Output()), layout=Layout(justify_content='space-around')), H…