In [8]:
import numpy as np
import plotly.graph_objects as go
from ipywidgets import HBox, VBox, Button, FloatSlider, Layout
from IPython.display import display

# ============================================================
#  Initial parameters
# ============================================================
# Starting values; on_reset copies them back into the live state below.
a_init = 10.0         # cell length a (Å)
b_init = 10.0         # cell length b (Å)
gamma_init = 90.0     # degrees
qmax_init = 5.0       # Å^-1
# Largest |h|, |k| enumerated wherever this is passed to
# compute_structure_factors; reflections beyond it are not generated even
# when their |Q| is below Qmax.
max_hk = 8
sigma = 0.5           # Gaussian width (Å)

# Initial fractional coordinates (u, v), one row per atom: an atom at the
# cell corner and one at the cell centre.
atoms_frac_init = np.array([
    [0.0, 0.0],
    [0.5, 0.5]
])
num_atoms = len(atoms_frac_init)

# Repeats for display.
# NOTE(review): nx/ny are not referenced in the code shown; the tiling is
# hard-coded to 3×3 in tile_unit_cell_3x3 — confirm before removing.
nx = ny = 3   # fixed grid size

# ============================================================
#  Global state
# ============================================================
# Live state read by update_all and rewritten by the slider/button callbacks.
a = a_init
b = b_init
gamma = gamma_init
qmax = qmax_init
atoms_frac = atoms_frac_init.copy()          # copy, so in-place edits never touch atoms_frac_init
atom_mask = np.ones(num_atoms, dtype=bool)   # False removes that atom from the density and F(h, k)

# ============================================================
# Helper functions
# ============================================================
def get_cell_vectors(a_val, b_val, gamma_deg):
    """Return the 2-D real-space cell vectors (a, b) as length-2 arrays.

    ``a`` lies along +x with length ``a_val``; ``b`` has length ``b_val`` and is
    rotated ``gamma_deg`` degrees counter-clockwise from ``a``.
    """
    angle = np.deg2rad(gamma_deg)
    first = np.array([a_val, 0.0])
    second = b_val * np.array([np.cos(angle), np.sin(angle)])
    return first, second

def frac_to_cart_atoms(frac_coords, a_vec, b_vec):
    """Convert fractional (u, v) rows to Cartesian positions u·a + v·b.

    ``frac_coords`` is indexed as a 2-D array with one (u, v) row per atom; the
    result has one row per atom and one column per cell-vector component.
    """
    along_a = np.outer(frac_coords[:, 0], a_vec)
    along_b = np.outer(frac_coords[:, 1], b_vec)
    return along_a + along_b

def tile_unit_cell_3x3(atoms_cart_cell, a_vec, b_vec):
    """Replicate one cell's atoms into the 3×3 block of cells around the origin cell.

    Each image is shifted by i·a + j·b with i, j in {-1, 0, 1}. Rows are grouped
    by image (i varies slowest, then j), keeping the input atom order within
    each image.
    """
    shifts = (-1, 0, 1)
    images = [atoms_cart_cell + i * a_vec + j * b_vec
              for i in shifts
              for j in shifts]
    return np.vstack(images)

def compute_3x3_bounds(a_vec, b_vec, pad=1.0):
    """Return (x_min, x_max, y_min, y_max) enclosing the 3×3 block, widened by ``pad``.

    The block spans fractional coordinates -1..2 along both cell axes, so its
    extreme points are the four corners i·a + j·b with i, j in {-1, 2}.
    """
    ends = (-1, 2)
    corners = np.array([i * a_vec + j * b_vec for i in ends for j in ends])
    lower = corners.min(axis=0) - pad
    upper = corners.max(axis=0) + pad
    return lower[0], upper[0], lower[1], upper[1]

def build_density(points, a_vec, b_vec, grid_size=256, sigma_val=None):
    """Sum an isotropic Gaussian at each point on a grid covering the 3×3 cell block.

    Parameters
    ----------
    points : iterable of (x, y) Cartesian positions (Å).
    a_vec, b_vec : real-space cell vectors; they set the grid extent via
        compute_3x3_bounds.
    grid_size : number of samples along each axis.
    sigma_val : Gaussian width (Å). Defaults to the module-level ``sigma`` so
        existing callers are unchanged.

    Returns
    -------
    (rho, x, y) : ``rho`` has shape (grid_size, grid_size) with rows following
        ``y`` and columns following ``x`` (the layout np.meshgrid(x, y) gives).
    """
    width = sigma if sigma_val is None else sigma_val
    two_s2 = 2 * width**2

    # The density grid is defined by the full 3×3 parallelogram boundaries.
    x_min, x_max, y_min, y_max = compute_3x3_bounds(a_vec, b_vec)
    x = np.linspace(x_min, x_max, grid_size)
    y = np.linspace(y_min, y_max, grid_size)

    rho = np.zeros((grid_size, grid_size))
    for px, py in points:
        # exp(-(dx² + dy²)/2σ²) factorises into an x-profile times a y-profile,
        # so each Gaussian costs 2·grid_size exponentials rather than grid_size².
        gx = np.exp(-(x - px) ** 2 / two_s2)
        gy = np.exp(-(y - py) ** 2 / two_s2)
        rho += np.outer(gy, gx)

    return rho, x, y

# ============================================================
#  Reciprocal space
# ============================================================
def get_reciprocal_vectors(a_vec, b_vec):
    """Return the 2-D reciprocal lattice vectors (a*, b*).

    They satisfy a·a* = b·b* = 2π and a·b* = b·a* = 0. Each is the other
    real-space vector rotated by 90°, scaled by 2π over the signed cell area
    a_x·b_y − a_y·b_x.
    """
    ax, ay = a_vec[0], a_vec[1]
    bx, by = b_vec[0], b_vec[1]
    area = ax*by - ay*bx   # signed 2-D "volume" of the cell
    two_pi = 2*np.pi
    a_star = two_pi * np.array([by, -bx]) / area
    b_star = two_pi * np.array([-ay, ax]) / area
    return a_star, b_star

def gaussian_form_factor(qmag, beta=0.02):
    """Return the model atomic form factor exp(-beta·|Q|²); equals 1 at Q = 0.

    Works elementwise when ``qmag`` is an array.
    """
    exponent = -beta * qmag * qmag
    return np.exp(exponent)

def compute_structure_factors(a_vec, b_vec, atoms_frac_local, atom_mask_local, qmax_val, max_hk_val):
    """Compute the 2-D structure factor F(h, k) for reflections with |Q| <= qmax_val.

    For every (h, k) with -max_hk_val <= h, k <= max_hk_val, excluding (0, 0):

        Q = h·a* + k·b*
        F = f(|Q|) · Σ_j exp(2πi (h·u_j + k·v_j))

    where f is ``gaussian_form_factor`` and the sum runs over atoms whose
    ``atom_mask_local`` entry is True. Reflections with |F|² < 1e-8 (for
    example systematic absences) are dropped.

    Parameters
    ----------
    a_vec, b_vec : real-space cell vectors (passed to get_reciprocal_vectors).
    atoms_frac_local : fractional coordinates, one (u, v) row per atom.
    atom_mask_local : booleans, same length as ``atoms_frac_local``; False
        excludes that atom.
    qmax_val : cutoff on |Q| (Å⁻¹); reflections with |Q| > qmax_val are skipped.
    max_hk_val : int bound on |h| and |k|. Reflections inside qmax_val whose
        |h| or |k| exceeds it are NOT generated. Because Q·a = 2πh and
        Q·b = 2πk, a bound of ceil(qmax_val · max(|a_vec|, |b_vec|) / 2π)
        covers every reflection inside the cutoff.

    Returns
    -------
    (H, K, Qx, Qy, I, phase_deg) : 1-D arrays with one entry per kept
        reflection, ordered by h and then k; ``phase_deg`` is the phase of F
        in degrees.
    """
    a_star, b_star = get_reciprocal_vectors(a_vec, b_vec)

    # All (h, k) pairs with h varying slowest, i.e. the order of nested loops
    # over h then k.
    indices = np.arange(-max_hk_val, max_hk_val + 1)
    hh, kk = np.meshgrid(indices, indices, indexing="ij")
    hh = hh.ravel()
    kk = kk.ravel()

    q = hh[:, None] * a_star + kk[:, None] * b_star     # (M, 2) scattering vectors
    qmag = np.linalg.norm(q, axis=1)
    in_range = ((hh != 0) | (kk != 0)) & (qmag <= qmax_val)
    hh, kk, q, qmag = hh[in_range], kk[in_range], q[in_range], qmag[in_range]

    frac = np.asarray(atoms_frac_local, dtype=float).reshape(-1, 2)
    active = frac[np.asarray(atom_mask_local, dtype=bool)]

    # The form factor depends only on |Q|, so it multiplies the atom sum once
    # instead of being re-evaluated for every atom.
    phase = 2*np.pi*(hh[:, None]*active[:, 0] + kk[:, None]*active[:, 1])   # (M, n_active)
    F = gaussian_form_factor(qmag) * np.exp(1j*phase).sum(axis=1)

    intensity = np.abs(F)**2
    strong = intensity >= 1e-8   # drops extinct reflections (and everything when no atom is active)

    return (hh[strong], kk[strong],
            q[strong, 0], q[strong, 1],
            intensity[strong], np.degrees(np.angle(F[strong])))

def intensities_to_sizes(I):
    """Map intensities to marker sizes in [3, 13] on a logarithmic scale.

    log10(I + 1e-6) is shifted so the weakest reflection maps to 3 and the
    strongest to 13. If all values are equal, every size is 3. An empty input
    gives an empty array.
    """
    if not I.size:
        return np.array([])
    log_i = np.log10(I + 1e-6)
    log_i = log_i - log_i.min()
    span = log_i.max()
    if span > 0:
        log_i = log_i / span
    return 3 + 10 * log_i

# ============================================================
# Initial computation
# ============================================================
# Initial real-space data used to seed the heatmap below; update_all
# recomputes all of this (and rebinds the a_vec/b_vec globals) on every change.
a_vec, b_vec = get_cell_vectors(a, b, gamma)
atoms_cart = frac_to_cart_atoms(atoms_frac, a_vec, b_vec)
active_atoms = atoms_cart[atom_mask]   # keep only atoms switched on in atom_mask
pts = tile_unit_cell_3x3(active_atoms, a_vec, b_vec)
rho, xg, yg = build_density(pts, a_vec, b_vec)

# ============================================================
# Real-space figure
# ============================================================
# Trace 0 (the heatmap) is overwritten in place by update_all.
real_fig = go.FigureWidget()
real_fig.add_trace(go.Heatmap(z=rho, x=xg, y=yg,
                              colorscale="Viridis", showscale=False))

def draw_unit_cells_3x3(fig, a_vec_local, b_vec_local):
    """Replace all shapes on ``fig`` with outlines of the 3×3 block of unit cells.

    One closed parallelogram path is added per cell, with origin i·a + j·b for
    i, j in {-1, 0, 1} (i varying slowest). Outlines are white, unfilled and
    drawn at opacity 0.7.
    """
    fig.layout.shapes = ()
    cell_origins = ((i, j) for i in (-1, 0, 1) for j in (-1, 0, 1))
    for i, j in cell_origins:
        origin = i*a_vec_local + j*b_vec_local
        after_a = origin + a_vec_local
        corners = (origin, after_a, after_a + b_vec_local, origin + b_vec_local)
        path = "M " + " L ".join(f"{c[0]},{c[1]}" for c in corners) + " Z"
        fig.add_shape(
            type="path",
            path=path,
            line=dict(color="white", width=1),
            fillcolor="rgba(0,0,0,0)",
            opacity=0.7,
        )

# Overlay the initial cell outlines; update_all redraws them on every change.
draw_unit_cells_3x3(real_fig, a_vec, b_vec)

real_fig.update_layout(
    width=500, height=500,
    title="Real Space (3×3 Cells)",
    xaxis=dict(title="x (Å)", scaleanchor="y"),   # lock x scale to y so the cell angle is not distorted
    yaxis=dict(title="y (Å)")
)

# ============================================================
# Diffraction plots
# ============================================================
# Structure factors for the initial state (H and K are not used further in
# this cell). update_all overwrites both scatter traces later.
H,K,Qx,Qy,Ivals,phase_deg = compute_structure_factors(
    a_vec,b_vec,atoms_frac,atom_mask,qmax,max_hk
)

# Marker size encodes log-intensity (see intensities_to_sizes).
sizes = intensities_to_sizes(Ivals)

# |F|² map: one marker per reflection at (Qx, Qy), coloured by intensity.
fig_amp = go.FigureWidget(
    data=[go.Scatter(
        x=Qx, y=Qy, mode="markers",
        marker=dict(size=sizes, color=Ivals, colorscale="Viridis",
                    showscale=True, colorbar=dict(title="|F|²"))
    )],
    layout=go.Layout(width=500, height=500,
                     title="Diffraction |F|²",
                     xaxis=dict(title="Qx (Å⁻¹)", scaleanchor="y"),
                     yaxis=dict(title="Qy (Å⁻¹)"))
)

# Phase map: same reflections and sizes, coloured by phase in degrees. The
# colour range is pinned to ±180° with the cyclic "Twilight" scale so that
# phases of -180° and +180° get the same colour.
fig_phase = go.FigureWidget(
    data=[go.Scatter(
        x=Qx, y=Qy, mode="markers",
        marker=dict(size=sizes, color=phase_deg, colorscale="Twilight",
                    showscale=True, cmin=-180, cmax=180,
                    colorbar=dict(title="Phase (degrees)"))
    )],
    layout=go.Layout(width=500, height=500,
                     title="Diffraction Phase",
                     xaxis=dict(title="Qx (Å⁻¹)", scaleanchor="y"),
                     yaxis=dict(title="Qy (Å⁻¹)"))
)

# ============================================================
# Update function
# ============================================================
def update_all():
    """Recompute real- and reciprocal-space data from the global state and redraw.

    Reads the globals ``a``, ``b``, ``gamma``, ``qmax``, ``atoms_frac`` and
    ``atom_mask``; rebinds the globals ``a_vec``/``b_vec``; and writes the
    results into ``real_fig``, ``fig_amp`` and ``fig_phase`` in place.
    """
    global a_vec, b_vec
    a_vec, b_vec = get_cell_vectors(a, b, gamma)

    atoms_cart = frac_to_cart_atoms(atoms_frac, a_vec, b_vec)
    active_atoms = atoms_cart[atom_mask]

    # 3×3 block density
    pts = tile_unit_cell_3x3(active_atoms, a_vec, b_vec)
    rho_new, x_new, y_new = build_density(pts, a_vec, b_vec)

    # Send z/x/y as one front-end update so the heatmap never pairs new data
    # with stale axes.
    with real_fig.batch_update():
        real_fig.data[0].z = rho_new
        real_fig.data[0].x = x_new
        real_fig.data[0].y = y_new

    draw_unit_cells_3x3(real_fig, a_vec, b_vec)

    # Reciprocal space update.
    # Since Q·a = 2πh and Q·b = 2πk, every reflection with |Q| <= qmax has
    # |h|, |k| <= qmax·max(|a|, |b|)/2π. Deriving the index range from that
    # bound keeps the Qmax slider effective for any cell; the fixed global
    # max_hk silently cut the pattern off (about 5 Å⁻¹ along a* for a = 10 Å).
    cell_len = max(np.linalg.norm(a_vec), np.linalg.norm(b_vec))
    hk_limit = int(np.ceil(qmax * cell_len / (2 * np.pi)))

    _, _, Qx, Qy, Ivals, phase_vals = compute_structure_factors(
        a_vec, b_vec, atoms_frac, atom_mask, qmax, hk_limit
    )
    sizes = intensities_to_sizes(Ivals)

    # Batch each trace update so positions, colours and sizes reach the front
    # end together; separately they briefly have mismatched lengths.
    with fig_amp.batch_update():
        fig_amp.data[0].x = Qx
        fig_amp.data[0].y = Qy
        fig_amp.data[0].marker.color = Ivals
        fig_amp.data[0].marker.size = sizes

    with fig_phase.batch_update():
        fig_phase.data[0].x = Qx
        fig_phase.data[0].y = Qy
        fig_phase.data[0].marker.color = phase_vals
        fig_phase.data[0].marker.size = sizes

# ============================================================
# Sliders: fractional coordinates
# ============================================================
# One HBox row of (x, y) sliders per atom, collected for the final layout.
slider_boxes = []

# Slider handles kept so on_reset can move them back to the initial values.
u_sliders = []
v_sliders = []

for k in range(num_atoms):
    # The "x"/"y" sliders edit fractional coordinates u (along a) and
    # v (along b) in [0, 1], not Cartesian positions.
    u_s = FloatSlider(description=f"Atom {k} x",
                      min=0,max=1,step=0.01,value=atoms_frac[k,0],
                      layout=Layout(width='320px'))
    v_s = FloatSlider(description=f"Atom {k} y",
                      min=0,max=1,step=0.01,value=atoms_frac[k,1],
                      layout=Layout(width='320px'))

    u_sliders.append(u_s)
    v_sliders.append(v_s)

    # Factories capture the atom index by value; a callback defined directly
    # in this loop would late-bind `k` and always edit the last atom.
    def make_cb_u(idx):
        def cb(change):
            atoms_frac[idx,0] = change['new']   # in-place write into the global array
            update_all()
        return cb

    def make_cb_v(idx):
        def cb(change):
            atoms_frac[idx,1] = change['new']
            update_all()
        return cb

    # Observers run whenever the value changes, including programmatic
    # assignments such as those in on_reset.
    u_s.observe(make_cb_u(k),'value')
    v_s.observe(make_cb_v(k),'value')

    slider_boxes.append(HBox([u_s, v_s], layout=Layout(margin='2px 0')))

# ============================================================
# Sliders: unit cell + Qmax
# ============================================================
# Cell lengths (Å), cell angle γ (°) and the |Q| cutoff. The 30–150° range
# keeps the cell area, the divisor in get_reciprocal_vectors, away from zero.
a_slider = FloatSlider(description="a (Å)",
                       min=2,max=20,step=0.1,value=a,
                       layout=Layout(width='240px'))
b_slider = FloatSlider(description="b (Å)",
                       min=2,max=20,step=0.1,value=b,
                       layout=Layout(width='240px'))
g_slider = FloatSlider(description="γ (°)",
                       min=30,max=150,step=1,value=gamma,
                       layout=Layout(width='240px'))
q_slider = FloatSlider(description="Qmax (Å⁻¹)",
                       min=1,max=15,step=0.1,value=qmax,
                       layout=Layout(width='240px'))

def on_cell_change(change):
    # Shared observer for the a, b and γ sliders. All three values are
    # re-read from the widgets, so setting one slider in code also copies the
    # other two sliders' current values into the globals.
    global a,b,gamma
    a = a_slider.value
    b = b_slider.value
    gamma = g_slider.value
    update_all()

def on_q_change(change):
    # Observer for the Qmax slider: store the new cutoff and redraw.
    global qmax
    qmax = q_slider.value
    update_all()

a_slider.observe(on_cell_change,'value')
b_slider.observe(on_cell_change,'value')
g_slider.observe(on_cell_change,'value')
q_slider.observe(on_q_change,'value')

unitcell_box = HBox([a_slider,b_slider,g_slider,q_slider],
                    layout=Layout(margin='4px 0'))

# ============================================================
# Buttons
# ============================================================
reset_btn = Button(description="Reset all")
toggle_btn = Button(description="Toggle Atom 1")

def on_reset(btn):
    """Restore the initial cell, Qmax, atom positions and atom mask, then redraw.

    The sliders are driven from the ``*_init`` constants rather than from the
    globals. Assigning a slider value fires its observer synchronously, and
    on_cell_change re-reads all three cell sliders. Resetting ``a`` first would
    therefore copy the b/γ sliders' stale values into the globals, and
    re-assigning those globals to their sliders would leave b and γ un-reset.
    """
    global a, b, gamma, qmax

    # Reset atom state first, so any redraw triggered by the observers below
    # already uses the initial positions and the full atom mask.
    atoms_frac[:] = atoms_frac_init
    atom_mask[:] = True

    a_slider.value = a_init
    b_slider.value = b_init
    g_slider.value = gamma_init
    q_slider.value = qmax_init

    for k in range(num_atoms):
        u_sliders[k].value = atoms_frac_init[k, 0]
        v_sliders[k].value = atoms_frac_init[k, 1]

    # Observers fire only on actual changes, so set the globals explicitly to
    # cover sliders that were already at their initial values.
    a, b, gamma, qmax = a_init, b_init, gamma_init, qmax_init
    update_all()

def on_toggle(btn):
    """Flip whether atom 1 contributes to the density and F(h, k), then redraw."""
    atom_mask[1] = not atom_mask[1]
    update_all()

reset_btn.on_click(on_reset)
toggle_btn.on_click(on_toggle)

button_box = HBox([reset_btn, toggle_btn], layout=Layout(margin='4px 0'))

# ============================================================
# Layout
# ============================================================
# Three plots side by side, with the controls stacked underneath.
fig_row = HBox([real_fig, fig_amp, fig_phase],
               layout=Layout(justify_content='space-between'))

ui = VBox([fig_row, button_box, unitcell_box] + slider_boxes)
display(ui)

# Redraw once from the current global state after the widgets are shown.
update_all()


VBox(children=(HBox(children=(FigureWidget({
    'data': [{'colorscale': [[0.0, '#440154'], [0.111111111111111…