In [1]:
# Restore `notebook_data` persisted with IPython's `%store` magic and expose
# each of its keys as a global variable of this notebook.
# NOTE(review): this injects names implicitly (hidden state); later cells
# presumably rely on it (e.g. `hkl` is never assigned in live code) —
# confirm which keys are expected.
%store -r notebook_data
globals().update(notebook_data)

In [2]:
import matplotlib.pyplot as plt
import numpy as np
from scipy.spatial.transform import Rotation
from scipy.fft import fftn, ifftshift, fftshift

from pynx.scattering.fhkl import Fhkl_thread

import cdiutils
# NOTE(review): presumably applies cdiutils' default matplotlib settings.
cdiutils.update_plot_params()

#### TODO: move all the simulation-related functions of this notebook into the `cdiutils.simulations` submodule

In [3]:
def read_lammps_positions(
        file_path: str,
        separate_atoms: bool = False,
        centre: bool =True
) -> np.ndarray | dict:
    """
    Reads atom positions from a LAMMPS .xyz or .lmp (dump) file.

    Args:
        file_path (str): path to the input file.
        separate_atoms (bool, optional): if True, separates positions by
            atom type. If False, returns all positions in one array.
            Defaults to False.

    Returns:
        dict or np.ndarray:
            - If separate_atoms=True, returns a dictionary {atom_type:
                positions_array}.
            - If separate_atoms=False, returns a single array of all
            positions.

    Raises:
        ValueError: if the file type cannot be detected or the file
            cannot be parsed.
    """
    # initialise storage for positions and atom types
    positions = []
    atom_types = []

    with open(file_path, 'r') as file:
        if file_path.endswith("xyz"):
            # Parse .xyz file
            file.readline()  # nb of atoms
            file.readline()  # comment line
            
            # parse atom positions
            for line in file:
                split_line = line.strip().split()
                atom_type = split_line[0]  # Atom type is the first column
                position = list(map(float, split_line[1:4]))  # x, y, z pos
                atom_types.append(atom_type)
                positions.append(position)

        elif file_path.endswith("lmp"):
            raise ValueError(
                "Reading LAMMPS dump files is not implemented yet. "
                "Please convert to .xyz format."
            )

    # convert to arrays
    positions = np.array(positions)

    # Atom should centred with respect to the centroid of ALL atoms
    if centre:
        positions -= np.mean(positions, axis=0, keepdims=True)

    if separate_atoms:
        # group positions by atom type
        separated_positions = {atom_type: [] for atom_type in atom_types}
        for atom_type, position in zip(atom_types, positions):
            separated_positions[atom_type].append(position)

        # convert lists to arrays
        for atom_type in separated_positions:
            separated_positions[atom_type] = np.array(
                separated_positions[atom_type]
            ).T
            # print(np.mean(separated_positions[atom_type], axis=1))
            # separated_positions[atom_type] -= np.mean(
            #     separated_positions[atom_type], axis=1, keepdims=True
            # )

        return dict(separated_positions)
    else:
        # return all positions in a single array
        return positions.T


In [4]:
# Atomic structure to simulate. Alternative structures (uncomment one to
# use it instead of the active `path` below):
# path = "/data/id01/inhouse/polack/shared/for_clement/exp_minimized.xyz"
# path = (
#     "/data/id01/inhouse/corentin/lammps/Pt/nanosculpt/"
#     "div10_oriented_v2.5/NP_Pt_nanosculpt_div10_oriented_v2.5_minimized.xyz"
# )
# path = (
#     "/data/id01/inhouse/corentin/lammps/Pt/nanosculpt/div10_clean_forces_v2/"
#     "NP_Pt_nanosculpt_div10_clean_forces_v2_minimized.xyz"
# )
path = (
    "/data/id01/inhouse/polack/scratch/ewoks_03_simulation/exp_minimized.xyz"
)

positions = read_lammps_positions(path, centre=True)

print(f"Shape: {positions.shape}")

Shape: (3, 1938657)


A simple sanity check: plot the centred atomic positions.

In [None]:
# Quick visual check of the centred atomic positions.
# NOTE(review): the limits presumably assume positions in Å — confirm.
AXIS_HALF_WIDTH = 200

fig, ax = plt.subplots(1, 1, subplot_kw={"projection": "3d"}, layout="tight")

# scatter plot of the positions
scatter_params = {
    "s": 1.5,
    "antialiased": False, "depthshade": True, "alpha": 0.2,
    "edgecolor": "k", "linewidth": 0.01, "c": "teal"
}
ax.scatter(*positions, **scatter_params)

# trailing `;` keeps the (limits) return value out of the cell output
ax.set(
    xlabel="$x$", ylabel="$y$", zlabel="$z$",
    xlim=(-AXIS_HALF_WIDTH, AXIS_HALF_WIDTH),
    ylim=(-AXIS_HALF_WIDTH, AXIS_HALF_WIDTH),
    zlim=(-AXIS_HALF_WIDTH, AXIS_HALF_WIDTH),
    title="Atomic positions (centred)",
);


(-200.0, 200.0)

#### Some parameters
- Here we simulate an X-ray scattering experiment. For that, we need to define where to look in reciprocal space. Therefore, we need:
    * a Bragg reflection (`hkl`)
    * a lattice parameter

- We also need to specify how fine the reciprocal-space grid should be. This is essentially defined by:
    * the number of steps (`step_nb`): the higher it is, the finer the reciprocal-space sampling, i.e. the higher the oversampling
    * the reciprocal-space step (`dq = 1 / step_nb`, in reciprocal lattice units)
    * the lattice parameter, which converts reciprocal lattice units into inverse length (1/Å here)

In [None]:
"""
#hkl = [1, 1, 1]
hkl = [0, 0, 2]
#hkl = [0, 2, 0]
"""
original_vector = np.array(hkl) / np.linalg.norm(hkl)
target_vector = np.array([0, 1, 0], dtype=float)
target_vector /= np.linalg.norm(target_vector)

rotation_axis = np.cross(original_vector, target_vector)
rotation_axis /= np.linalg.norm(rotation_axis)
angle = np.arccos(np.dot(original_vector, target_vector)) 

rotation = Rotation.from_rotvec(angle * rotation_axis)
rotated_hkl = rotation.apply(np.column_stack(hkl))

# extract rotated x, y, z arrays
rotated_hkl = rotated_hkl.T
# rotated_hkl = np.array([0, 0, 2])
# rotated_hkl = np.array([0, 2, 0])

rotated_hkl

In [None]:
lattice_parameter = 3.92  # converts r.l.u. into 1/Å below (presumably Å)
# d-spacing of the reflection: a / |hkl|
d_spacing = lattice_parameter / np.linalg.norm(hkl)

step_nb = 400  # grid steps per reciprocal lattice unit (800 also tried)
dq = np.array([1, 1, 1]) / step_nb  # grid step in r.l.u.

q_size = (200, 200, 200)  # (128, 128, 128) is a cheaper alternative

# Build the reciprocal-space grid around the rotated reflection. The ranges
# are laid out in r.l.u. and then divided by the lattice parameter, so the
# actual grid step is dq / lattice_parameter (not dq).
q_ranges = [
    (rotated_hkl[i] + (np.arange(q_size[i]) - (q_size[i] / 2)) * dq[i])
    / lattice_parameter
    for i in range(3)
]
q_step = dq / lattice_parameter
q_grid = np.meshgrid(*q_ranges, indexing="ij")

# Kinematic scattering amplitude of all atoms on the q-grid (pynx, GPU)
scattered_amp, dt = Fhkl_thread(
    *q_grid,
    *positions,
    occ=None,
    gpu_name="NVIDIA A40",
    language=""
)

diffraction_pattern = np.abs(scattered_amp)**2

# voxel_size must match the actual grid step so that the axes agree with
# data_centre, which is expressed in the same units as q_grid.
fig, axes = cdiutils.plot.plot_volume_slices(
    diffraction_pattern,
    norm="log",
    voxel_size=tuple(q_step),
    data_centre=[np.mean(q) for q in q_grid],
    convention="xu",
    show=False,
)

cdiutils.plot.add_labels(axes, space="rcp", convention="xu")
display(fig)

print(
    f"Elapsed time for the scattering computation: {dt} s\n"
    "Spatial resolution in the direct space: ",
    *(1 / np.ptp(q) * .1 for q in q_grid),
    "nm"
)


In [None]:
def find_bragg_peak_centre(
    diffraction_pattern: np.ndarray,
    q_grid: tuple,
    search_radius: float = 0.1
) -> tuple:
    """
    Find the centre of mass of the Bragg peak in reciprocal space.

    The centre of mass is computed over the WHOLE pattern, each voxel being
    weighted by its intensity.

    Args:
        diffraction_pattern: 3D diffraction intensity.
        q_grid: sequence of q-coordinate arrays (qx, qy, qz), each
            broadcastable to the shape of ``diffraction_pattern`` (e.g. the
            output of ``np.meshgrid(..., indexing="ij")``).
        search_radius: currently NOT used; kept for backward
            compatibility. The whole pattern contributes to the centre of
            mass.

    Returns:
        tuple: (qx_centre, qy_centre, qz_centre) of the peak, in the units
        of ``q_grid``.

    Raises:
        ValueError: if the total intensity is zero, negative or not
            finite, in which case the centre of mass is undefined.
    """
    intensity = np.asarray(diffraction_pattern)
    total_intensity = np.sum(intensity)
    if not np.isfinite(total_intensity) or total_intensity <= 0:
        raise ValueError(
            "Cannot compute the centre of mass: the total intensity must be "
            f"a positive finite number, got {total_intensity}."
        )
    return tuple(np.sum(intensity * q) / total_intensity for q in q_grid)

def refine_q_grid(
    positions: tuple,
    hkl: np.ndarray,
    lattice_parameter_guess: float,
    q_size: tuple = (200, 200, 200),
    step_nb: int = 400,
    max_iterations: int = 5,
    convert_to_angstrom: bool = True,
    convergence_fraction: float = 0.1,
    gpu_name: str = "NVIDIA A40",
):
    """
    Iteratively re-centre the q-grid on the simulated Bragg peak.

    At each iteration a q-grid is built around the current centre, the
    kinematic scattering is computed with ``Fhkl_thread`` and the
    intensity-weighted centre of mass of the pattern becomes the centre of
    the next grid. The lattice parameter consistent with that peak,
    |hkl| / |q_peak| (no 2π factor in this convention), is also updated.

    Args:
        positions: atomic positions, unpacked as (x, y, z) into
            ``Fhkl_thread``.
        hkl: Miller indices (possibly rotated) of the reflection.
        lattice_parameter_guess: initial lattice parameter guess in
            metres. Will be converted to Angstrom if convert_to_angstrom is
            True.
        q_size: number of grid points along each axis.
        step_nb: the grid step is 1 / step_nb along each axis, in absolute
            reciprocal-space units (1/angstrom or 1/m), not in r.l.u.
        max_iterations: maximum number of refinement iterations (>= 1).
        convert_to_angstrom: if True, converts the lattice parameter guess
            from metres to Angstrom.
        convergence_fraction: the refinement stops when the peak shift is
            below this fraction of the grid step. Defaults to 0.1.
        gpu_name: GPU name passed to ``Fhkl_thread``.

    Returns:
        tuple: (
            refined_q_grid, scattered_amp, refined_lattice_parameter
        ) from the last iteration.

    Raises:
        ValueError: if max_iterations is smaller than 1.
    """
    if max_iterations < 1:
        raise ValueError("max_iterations must be at least 1.")

    unit = "m"
    rcp_space_unit = "1/m"
    if convert_to_angstrom:
        lattice_parameter_guess *= 1e10
        unit = "angstrom"
        rcp_space_unit = "1/angstrom"

    hkl = np.asarray(hkl, dtype=float)
    current_lattice = lattice_parameter_guess
    current_center = hkl / current_lattice
    dq = np.array([1, 1, 1]) / step_nb  # step size in reciprocal space

    for iteration in range(max_iterations):
        print(f"Refinement iteration {iteration + 1}")
        print(f"Current lattice parameter: {current_lattice:.4f}")
        print(f"Current center: {current_center}")

        # build q-grid around current center
        q_ranges = [
            current_center[i] + (np.arange(q_size[i]) - (q_size[i] / 2)) * dq[i]
            for i in range(3)
        ]
        q_grid = np.meshgrid(*q_ranges, indexing="ij")

        # calculate the kinematic scattering
        scattered_amp, _ = Fhkl_thread(
            *q_grid, *positions, occ=None, gpu_name=gpu_name, language=""
        )

        diffraction_pattern = np.abs(scattered_amp) ** 2

        # find the actual peak center
        peak_center = find_bragg_peak_centre(diffraction_pattern, q_grid)
        print(f"Found peak center at: {peak_center}")

        # calculate shift from the current grid center
        expected_center = [np.mean(q) for q in q_grid]
        shift = [pc - ec for pc, ec in zip(peak_center, expected_center)]
        print(f"Shift from grid center: {shift} {rcp_space_unit}")

        # here |q| = |hkl| / a (no 2π factor), so the lattice parameter
        # consistent with the found peak is |hkl| / |q_peak|
        current_lattice = np.linalg.norm(hkl) / np.linalg.norm(peak_center)

        # centre the next grid on the peak itself; rescaling it by the
        # lattice correction would move the grid away from the peak
        current_center = list(peak_center)

        print(f"Refined lattice parameter: {current_lattice:.4f} {unit}")
        print(f"Shift magnitude: {np.linalg.norm(shift):.6f} {rcp_space_unit}")
        print("-" * 50)

        # converged if the shift is below a fraction of the grid step
        if np.linalg.norm(shift) < dq[0] * convergence_fraction:
            print("Converged!")
            break
    else:
        print(f"Warning: not converged after {max_iterations} iterations.")

    return q_grid, scattered_amp, current_lattice

In [None]:
# Re-centre the reciprocal-space grid on the simulated Bragg peak,
# starting from an initial lattice parameter guess.
initial_lattice_guess = 3.92e-10  # in metres; refine_q_grid converts to Å
step_nb = 1600
q_size = (200, 200, 200)

refinement_kwargs = dict(q_size=q_size, step_nb=step_nb, max_iterations=10)
refined_q_grid, scattered_amp, refined_lattice = refine_q_grid(
    positions, rotated_hkl, initial_lattice_guess, **refinement_kwargs
)
print(f"Final refined lattice parameter: {refined_lattice:.4f}")

# refine_q_grid steps its grid by 1 / step_nb along every axis
dq = np.full(3, 1 / step_nb)
refined_intensity = np.abs(scattered_amp) ** 2

fig, axes = cdiutils.plot.plot_volume_slices(
    refined_intensity,
    norm="log",
    voxel_size=tuple(dq),
    data_centre=[np.mean(q) for q in refined_q_grid],
    convention="xu",
    show=False
)
cdiutils.plot.add_labels(axes, space="rcp", convention="xu")
display(fig)

In [None]:
# NOTE: `dt` is the timing of the initial (non-refined) simulation cell;
# refine_q_grid does not return the timing of its own computations.
# The resolution is computed per axis from the extent of each q-grid.
print(
    f"Elapsed time for the initial scattering computation: {dt} s\n"
    "Spatial resolution in the direct space: ",
    *(1 / np.ptp(q) * .1 for q in refined_q_grid),
    "nm"
)
    

In [None]:
# Centre of mass of the refined peak -> d-spacing (no 2π: d = 1 / |q|)
peak_intensity = np.abs(scattered_amp)**2
q_com = find_bragg_peak_centre(peak_intensity, refined_q_grid)
hkl_norm = np.sqrt(sum(index ** 2 for index in hkl))
d_spacing = 1 / np.linalg.norm(q_com)
print(
    f"d-spacing and associated lattice parameter are: {d_spacing} and "
    f"{d_spacing * hkl_norm} A."
)

In [None]:
# Direct-space voxel size (in nm) of the FFT of the refined q-grid. For N
# samples spaced by Δq, the transform has a voxel size of 1 / (N Δq);
# np.ptp(q) only spans (N - 1) Δq, hence the (N - 1) / N factor.
# The 1e-1 factor converts Å into nm (q presumably in 1/Å).
voxel_size = tuple(
    1e-1 * (q.shape[axis] - 1) / (q.shape[axis] * np.ptp(q))
    for axis, q in enumerate(refined_q_grid)
)
# Alternative using the complex conjugate of the amplitude:
# obj = fftshift(fftn(ifftshift(np.conj(scattered_amp))))
obj = fftshift(fftn(ifftshift(scattered_amp)))

support = cdiutils.utils.make_support(np.abs(obj), isosurface=0.75)
plot_params = {
    "support": support, "convention": "xu", "voxel_size": voxel_size,
    "data_centre": (0, 0, 0), "show": True,
}
cdiutils.plot.plot_volume_slices(np.abs(obj), **plot_params)
# NOTE(review): with "show": True the phase figure may be rendered both by
# plot_volume_slices and by display(fig) below — confirm.
fig, axes = cdiutils.plot.plot_volume_slices(
    np.angle(obj), cmap="cet_CET_C9s_r", vmin=-np.pi, vmax=np.pi, **plot_params
)
cdiutils.plot.add_labels(axes, convention="xu")
display(fig)

In [None]:
# Apodise the reconstructed object, then extract its structural properties.
# A new name is used so that re-running this cell does not apodise `obj`
# repeatedly (assumes `apodize` returns a new array — TODO confirm it does
# not modify its input in place).
obj_apodized = cdiutils.process.PostProcessor.apodize(obj)
struct_properties = cdiutils.process.PostProcessor.get_structural_properties(
    obj_apodized,
    isosurface=0.65,
    # q_com carries no 2π factor (d = 1 / |q_com|); the g-vector is
    # presumably expected with it
    g_vector=tuple(q * (2*np.pi) for q in q_com),
    hkl=hkl,
    voxel_size=voxel_size,
    phase_factor=1
)
for k in struct_properties:
    print(k)

In [None]:
for k in ("amplitude", "phase", "het_strain_from_dspacing", "lattice_parameter"):
    #plot_params.update(cdiutils.plot.get_plot_configs(k))
    plot_params["vmin"], plot_params["vmax"] = None, None
    cdiutils.plot.plot_volume_slices(struct_properties[k], **plot_params)

print(np.nanmean(struct_properties["lattice_parameter"]))