### pyFAI Levenberg-Marquardt geometry optimization seeded by MAXIMA-ViT/MAXIMA-Swin

This notebook is a proof of concept for a workflow that loads a 2D diffraction pattern (`.tiff`), infers a 6D detector geometry (dist, poni1, poni2, rot1, rot2, rot3) with a MAXIMA-ViT/MAXIMA-Swin model, and refines it using pyFAI's `GeometryRefinement` class.

In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import fabio
import yaml

from pyFAI.geometry import Geometry
from pyFAI.geometryRefinement import GeometryRefinement
from pyFAI.azimuthalIntegrator import AzimuthalIntegrator
from skimage.feature import peak_local_max
from scipy.optimize import minimize

from src.utils import load_model, image_to_tensor, get_calibrant, get_detector

In [None]:
# Input files (placeholder paths — point these at real files before running).
CONFIG_PATH = "path/to/cfg.yaml"
MODEL_PATH = "path/to/model.pth"
IMAGE_PATH = "path/to/pattern.tiff"

# Experiment description; the aliases are resolved by src.utils.get_calibrant /
# src.utils.get_detector.
CALIBRANT_ALIAS = "alpha_Al2O3"
WAVELENGTH = 0.512126e-10  # metres; TODO: read from the experiment's HDF5 metadata
DETECTOR_ALIAS = "Eiger2Cdte_1M"

In [None]:
def visualize_results(image, initial_geo, refined_geo, calibrant):
    """
    Plot the initial guess vs the refined result as 2-theta ring overlays.

    Each panel shows ``log1p(image)`` with contours of that geometry's 2-theta
    map drawn at the calibrant's expected ring positions. When a geometry is
    correct, the contours sit on top of the diffraction rings.

    Args:
        image: 2D detector image to display. Its shape also fixes the grid on
            which the 2-theta map is evaluated, so contours line up with it.
        initial_geo: pyFAI ``Geometry`` for the left panel (model prediction).
        refined_geo: pyFAI ``Geometry`` for the right panel (refinement result).
        calibrant: object whose ``get_2th()`` yields expected ring positions.
            NOTE(review): assumed to be in radians, matching the '2th_rad' unit
            used for the 2-theta map below.
    """
    fig, (ax_initial, ax_refined) = plt.subplots(1, 2, figsize=(16, 8))
    log_img = np.log1p(image)

    def make_ai(geo):
        # Rebuild an integrator from the geometry's 6 parameters, wavelength and detector.
        return AzimuthalIntegrator(
            dist=geo.dist,
            poni1=geo.poni1,
            poni2=geo.poni2,
            rot1=geo.rot1,
            rot2=geo.rot2,
            rot3=geo.rot3,
            wavelength=geo.wavelength,
            detector=geo.detector,
        )

    def draw_rings(ax, ai, cal):
        # 2-theta (rad) at each pixel centre, on the same grid as the displayed image.
        tth_array = ai.center_array(shape=image.shape, unit='2th_rad')
        tth_max = tth_array.max()  # hoisted: evaluated once, not once per ring
        # contour() requires strictly increasing levels, so sort and de-duplicate.
        # Keep only rings below the largest 2-theta present on the detector.
        levels = np.unique([r for r in cal.get_2th() if r < tth_max])
        if levels.size == 0:
            return  # no calibrant ring falls on the detector; nothing to draw
        ax.contour(tth_array, levels=levels, colors='blue', linewidths=0.7, alpha=0.9)

    panels = (
        (ax_initial, initial_geo, "Initial ViT Prediction"),
        (ax_refined, refined_geo, "Refined NLLS Result"),
    )
    for ax, geo, title in panels:
        ax.imshow(log_img, cmap='inferno', origin='lower')
        ax.set_title(title)
        draw_rings(ax, make_ai(geo), calibrant)

    plt.tight_layout()
    plt.show()

In [None]:
class PeakOptimizer:
    """
    Tune peak-picking hyper-parameters to minimise the pyFAI refinement residual.

    The tuned vector is ``[min_distance, threshold_rel, tolerance_degrees]``.
    The first two are passed to ``skimage.feature.peak_local_max``. The third is
    the maximum 2-theta distance (degrees) for assigning a detected peak to its
    nearest calibrant ring.

    Every evaluation refines from ``initial_geometry`` with
    ``GeometryRefinement.refine2()``. The best evaluation seen so far is kept in
    the ``best_*`` attributes.

    NOTE(review): residuals obtained from *different* peak sets are compared
    directly. The search may therefore favour settings that keep only a handful
    of well-fitting peaks; treat the chosen parameters as a heuristic.
    """

    # Penalties for failed evaluations. Smaller means "closer to a valid fit",
    # which gives Nelder-Mead a direction to move in.
    _PENALTY_FAILURE = 1e6      # peak picking/refinement raised or gave a non-finite residual
    _PENALTY_FEW_PEAKS = 1e5    # fewer than 5 peaks detected
    _PENALTY_FEW_LABELED = 1e4  # fewer than 6 peaks matched to a calibrant ring

    def __init__(self, image, initial_geometry, calibrant, exclude_border=300):
        """
        Args:
            image: 2D detector image searched for peaks.
            initial_geometry: pyFAI ``Geometry`` used both to label peaks with
                ring indices and as the starting point of every refinement.
            calibrant: calibrant providing expected ring positions
                (``get_2th()``) and ``wavelength``.
            exclude_border: forwarded to ``peak_local_max``; peaks this close
                (in pixels) to the image edge are ignored.
        """
        self.image = image
        self.geo = initial_geometry
        self.calibrant = calibrant
        self.exclude_border = exclude_border

        # Best evaluation so far (updated by _objective).
        self.best_params = None  # (min_dist, thresh, tol_deg)
        self.best_error = float('inf')
        self.best_peaks = None  # labelled data rows: [row, col, ring_index]
        self.best_refiner = None
        self.best_geometry = None
        # scipy OptimizeResult from the most recent optimize() call.
        self.result = None

    def _objective(self, params):
        """
        Score one ``[min_distance, threshold_rel, tolerance_degrees]`` vector.

        Returns the residual reported by ``refine2()`` for the peaks selected by
        ``params``, or one of the class penalty constants when the evaluation
        fails. Updates the ``best_*`` attributes whenever the residual improves.
        """
        # peak_local_max needs an integer min_distance >= 1.
        min_dist = int(max(1, round(params[0])))
        # threshold_rel is a fraction of the image maximum; keep it in [0.001, 1].
        thresh = np.clip(params[1], 0.001, 1.0)
        # Ring-assignment tolerance in degrees, floored at 0.1.
        tol_deg = max(0.1, params[2])

        try:
            peaks = peak_local_max(
                self.image,
                min_distance=min_dist,
                threshold_rel=thresh,
                exclude_border=self.exclude_border,
            )
        except Exception:
            return self._PENALTY_FAILURE

        if len(peaks) < 5:
            return self._PENALTY_FEW_PEAKS

        tth_expected = np.asarray(self.calibrant.get_2th(), dtype=float)
        if tth_expected.size == 0:
            # No calibrant ring available at this wavelength: nothing to label against.
            return self._PENALTY_FEW_LABELED

        # Label each peak with its nearest expected ring, comparing 2-theta values.
        tth_measured = self.geo.tth(peaks[:, 0], peaks[:, 1])
        diff_matrix = np.abs(tth_measured[:, None] - tth_expected[None, :])  # (n_peaks, n_rings)
        ring_indices = diff_matrix.argmin(axis=1)
        min_diffs = diff_matrix.min(axis=1)
        mask = min_diffs < np.deg2rad(tol_deg)

        if np.count_nonzero(mask) < 6:
            return self._PENALTY_FEW_LABELED

        # Labelled data rows: [row (dim1), col (dim2), ring_index].
        data = np.column_stack((peaks[mask], ring_indices[mask]))

        try:
            refiner = GeometryRefinement(
                data=data,
                dist=self.geo.dist,
                poni1=self.geo.poni1,
                poni2=self.geo.poni2,
                rot1=self.geo.rot1,
                rot2=self.geo.rot2,
                rot3=self.geo.rot3,
                pixel1=self.geo.detector.pixel1,
                pixel2=self.geo.detector.pixel2,
                detector=self.geo.detector,
                wavelength=self.calibrant.wavelength,
                calibrant=self.calibrant,
            )
            error = refiner.refine2()
        except Exception:
            # Best-effort search: a failed refinement just scores badly.
            return self._PENALTY_FAILURE

        if not np.isfinite(error):
            # NaN/inf would derail Nelder-Mead's simplex ordering.
            return self._PENALTY_FAILURE

        if error < self.best_error:
            self._record_best(error, (min_dist, thresh, tol_deg), data, refiner)

        return error

    def _record_best(self, error, params, data, refiner):
        """Cache an improved evaluation plus a standalone Geometry of its result."""
        self.best_error = error
        self.best_params = params
        self.best_peaks = data
        self.best_refiner = refiner
        self.best_geometry = Geometry(
            dist=refiner.dist,
            poni1=refiner.poni1,
            poni2=refiner.poni2,
            rot1=refiner.rot1,
            rot2=refiner.rot2,
            rot3=refiner.rot3,
            wavelength=self.calibrant.wavelength,
            detector=self.geo.detector,
        )

    def optimize(self, initial_guess=(5, 0.1, 1.0)):
        """
        Run Nelder-Mead over ``[min_distance, threshold_rel, tolerance_degrees]``.

        Nelder-Mead is used because the objective is non-smooth; it is piecewise
        constant in the rounded ``min_distance``.

        Args:
            initial_guess: starting ``[min_distance, threshold_rel, tolerance_degrees]``.

        Returns:
            The best ``GeometryRefinement`` found, or ``None`` if every evaluation
            hit a penalty. The scipy result is stored in ``self.result``.
        """
        self.result = minimize(
            self._objective,
            x0=np.asarray(initial_guess, dtype=float),
            method='Nelder-Mead',
            tol=1e-4,
            options={'maxiter': 50, 'disp': True},
        )
        return self.best_refiner

    def get_best_geometry(self):
        """Return the pyFAI Geometry of the best refinement, or None."""
        return self.best_geometry

    def get_best_refiner(self):
        """Return the best GeometryRefinement, or None."""
        return self.best_refiner

In [None]:
# Read the run configuration and choose where the model will execute.
with open(CONFIG_PATH, "r") as config_file:
    config = yaml.safe_load(config_file)

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
print(f"Using device: {device}")

In [None]:
# Load the diffraction pattern as float32 and bound its dynamic range.
# NOTE(review): the 300 ceiling suppresses zingers/hot pixels; the 30 floor also
# flattens the background. Both bounds look dataset-specific — confirm per run.
image = np.clip(fabio.open(IMAGE_PATH).data.astype(np.float32), a_min=30.0, a_max=300.0)

# Model input resolution (falls back to 1056 when the config omits it).
image_size = config['model'].get('image_size', 1056)

# Convert to the model's tensor format, add a leading batch axis, move to device.
tensor = image_to_tensor(image, image_size).unsqueeze(0).to(device)

In [None]:
# load model architecture/weights of choice
# NOTE(review): load_model comes from src.utils. From this call site it is only
# known to take (MODEL_PATH, config) and return an object with .to()/.eval().

model = load_model(MODEL_PATH, config)
model.to(device)  # move the model onto the device chosen earlier (CUDA if available)
model.eval()  # switch to inference mode before running predictions

In [None]:
# Infer the 6 geometry parameters with the model and build the initial pyFAI Geometry.
# NOTE(review): assumes the model outputs [dist, poni1, poni2, rot1, rot2, rot3]
# in pyFAI units (m, m, m, rad, rad, rad) — confirm against the training targets.
param_names = ["dist", "poni1", "poni2", "rot1", "rot2", "rot3"]

with torch.no_grad():
    prediction = model(tensor).cpu().numpy().flatten()

if prediction.size != len(param_names):
    raise ValueError(
        f"Expected {len(param_names)} geometry parameters from the model, got {prediction.size}"
    )

calibrant = get_calibrant(CALIBRANT_ALIAS, WAVELENGTH)
detector = get_detector(DETECTOR_ALIAS)

initial_geometry = Geometry(
    **dict(zip(param_names, prediction)),
    wavelength=calibrant.wavelength,
    detector=detector,
)

print("\n--- Initial ViT Prediction ---")
for name, val in zip(param_names, prediction):
    print(f"{name:<10}: {val:.6f}")

In [None]:
# Search the peak-picking hyper-parameters and keep the best GeometryRefinement.
optimizer = PeakOptimizer(
    image=image,
    initial_geometry=initial_geometry,
    calibrant=calibrant,
)

refiner = optimizer.optimize()

# optimize() returns None when every evaluation was penalised; fail clearly here
# instead of with an AttributeError when the refined parameters are read.
if refiner is None:
    raise RuntimeError(
        "Peak optimisation produced no successful refinement; "
        "try a different initial_guess or a smaller exclude_border."
    )

In [None]:
# Extract the refined parameters and wrap them in a standalone Geometry.
# The refinement was run at calibrant.wavelength, and so was the initial
# geometry. Use the same wavelength here so both overlays are comparable.
refined_dist = refiner.dist
refined_poni1 = refiner.poni1
refined_poni2 = refiner.poni2
refined_rot1 = refiner.rot1
refined_rot2 = refiner.rot2
refined_rot3 = refiner.rot3

final_geometry = Geometry(
    dist=refined_dist,
    poni1=refined_poni1,
    poni2=refined_poni2,
    rot1=refined_rot1,
    rot2=refined_rot2,
    rot3=refined_rot3,
    wavelength=calibrant.wavelength,
    detector=detector,
)

In [None]:
# Side-by-side table of the model prediction vs the refined parameters.
# Delta is the absolute change |initial - refined|.
refined_params = [refined_dist, refined_poni1, refined_poni2, refined_rot1, refined_rot2, refined_rot3]

print(f"{'Parameter':<10} | {'Initial (ViT)':<15} | {'Refined (NLLS)':<15} | {'Delta':<15}")
print("-" * 65)
for name, init, final in zip(param_names, prediction, refined_params):
    delta = abs(init - final)
    # Fixed-width fields keep columns aligned for negative values too.
    print(f"{name:<10} | {init:<15.6f} | {final:<15.6f} | {delta:<15.6f}")

In [None]:
# Overlay the calibrant rings for the initial and refined geometries.
visualize_results(
    image=image,
    initial_geo=initial_geometry,
    refined_geo=final_geometry,
    calibrant=calibrant,
)