# Cell Parameter Refinement with nanoBragg PyTorch

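This tutorial demonstrates how to use the differentiable cell parameters in nanoBragg PyTorch to refine unit cell parameters against diffraction data. To keep the example self-contained, the "observed" pattern is simulated from known reference parameters, so the refined values can be compared directly with the truth.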

## Overview

With the new general triclinic cell parameter support, you can now:
- Define arbitrary unit cells (not just cubic)
- Make cell parameters differentiable
- Use gradient-based optimization to refine parameters
- Handle any crystal system from triclinic to cubic

## Setup

First, let's import the necessary modules and set up our environment:

In [None]:
import os

# KMP_DUPLICATE_LIB_OK lets the process keep running when two copies of the
# Intel OpenMP runtime get loaded (for example one linked into NumPy/MKL and one
# bundled with PyTorch). Set it before importing those libraries so it is in
# place before either runtime initializes. Intel documents this as an unsafe
# workaround that can mask genuine library conflicts; prefer fixing the
# environment when you can.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

import torch
import numpy as np
import matplotlib.pyplot as plt

# Import nanoBragg PyTorch components
from nanobrag_torch.config import CrystalConfig
from nanobrag_torch.models.crystal import Crystal
from nanobrag_torch.models.detector import Detector
from nanobrag_torch.simulator import Simulator

# Run on the GPU when one is available; every tensor in this tutorial uses
# double precision.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dtype = torch.float64

print(f"Using device: {device}")
print(f"PyTorch version: {torch.__version__}")

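## 1. Creating a Target Crystal

Let's start by creating a target crystal whose parameters we'll try to recover through optimization. Its cell has a hexagonal metric (a = b, γ = 120°), but it is defined in space group P1. All six cell parameters are therefore treated as independent, exactly as for a general triclinic cell: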

In [None]:
# Reference ("true") cell that the refinement should recover. Its metric is
# hexagonal (a = b, gamma = 120 deg), but it is built in space group P1 below, so
# all six parameters are handled independently, as for a general triclinic cell.
true_params = dict(
    cell_a=281.0,      # Angstroms
    cell_b=281.0,      # Angstroms
    cell_c=165.2,      # Angstroms
    cell_alpha=90.0,   # degrees
    cell_beta=90.0,    # degrees
    cell_gamma=120.0,  # degrees (hexagonal-like)
)

true_config = CrystalConfig(
    space_group_name='P1',
    mosaic_spread_deg=0.0,
    mosaic_domains=1,
    N_cells=(5, 5, 5),
    **true_params,
)

true_crystal = Crystal(config=true_config, device=device, dtype=dtype)
detector = Detector(device=device, dtype=dtype)

print("True crystal parameters:")
print("\n".join(f"  {param}: {value}" for param, value in true_params.items()))

# Simulate the reference pattern once and detach it: it plays the role of fixed,
# measured data for the rest of the tutorial.
simulator = Simulator(true_crystal, detector, crystal_config=true_config, device=device, dtype=dtype)
observed_image = simulator.run().detach()

fig, ax = plt.subplots(figsize=(8, 8))
im = ax.imshow(observed_image.cpu().numpy(), cmap='viridis', origin='lower')
fig.colorbar(im, ax=ax, label='Intensity')
ax.set_title('Observed Diffraction Pattern')
ax.set_xlabel('Fast axis (pixels)')
ax.set_ylabel('Slow axis (pixels)')
plt.show()

## 2. Setting Up Differentiable Parameters

Now let's create initial guess parameters that are differentiable:

In [None]:
# Start from a deliberately perturbed initial guess (2-10% error per parameter).
# Each true value is multiplied by the factor listed for its name.
perturbation = {
    'cell_a': 0.9,       # 10% too small
    'cell_b': 1.1,       # 10% too large
    'cell_c': 0.95,      # 5% too small
    'cell_alpha': 1.05,  # 5% too large
    'cell_beta': 0.95,   # 5% too small
    'cell_gamma': 1.02,  # 2% too large
}

# Canonical parameter order used for the tensor below and by later cells.
param_names = ['cell_a', 'cell_b', 'cell_c', 'cell_alpha', 'cell_beta', 'cell_gamma']

# Leaf tensor handed to the optimizer in section 4, which updates it in place;
# after that section it holds the refined values.
initial_params = torch.tensor(
    [true_params[name] * perturbation[name] for name in param_names],
    device=device, dtype=dtype, requires_grad=True,
)

print("Initial guess parameters:")
# Look the true value up by name so the comparison never depends on dict order.
for name, guess in zip(param_names, initial_params.tolist()):
    true_val = true_params[name]
    error = (guess - true_val) / true_val * 100
    print(f"  {name}: {guess:.1f} (error: {error:.1f}%)")

## 3. Defining a Loss Function

We'll use mean squared error between the observed and simulated diffraction patterns:

In [None]:
def compute_loss(params, observed_image, detector, device, dtype):
    """Simulate a pattern for a trial cell and score it against the observation.

    Args:
        params: Six cell parameters in the order (a, b, c, alpha, beta, gamma);
            in this tutorial a 1-D tensor that may require grad. Lengths are in
            Angstroms and angles in degrees, as in ``true_params``.
        observed_image: Target pattern the simulation is compared with.
        detector: Detector model reused for every simulation.
        device: Device forwarded to the Crystal and Simulator.
        dtype: Floating-point dtype forwarded to the Crystal and Simulator.

    Returns:
        Tuple ``(loss, simulated_image)``: the mean-squared error between the
        simulated and observed images, and the simulated image itself.
    """
    # Expects exactly six entries.
    a, b, c, alpha, beta, gamma = params

    # Same non-cell settings (P1, no mosaicity, 5x5x5 cells) as the reference crystal.
    trial_config = CrystalConfig(
        space_group_name='P1',
        cell_a=a,
        cell_b=b,
        cell_c=c,
        cell_alpha=alpha,
        cell_beta=beta,
        cell_gamma=gamma,
        mosaic_spread_deg=0.0,
        mosaic_domains=1,
        N_cells=(5, 5, 5),
    )
    trial_crystal = Crystal(config=trial_config, device=device, dtype=dtype)

    simulated = Simulator(
        trial_crystal, detector, crystal_config=trial_config, device=device, dtype=dtype
    ).run()

    mse = torch.nn.functional.mse_loss(simulated, observed_image)
    return mse, simulated

# Evaluate the loss at the starting guess. Only the value is printed here, so
# skip building an autograd graph; otherwise the graph of a full simulation would
# stay alive in `initial_loss` for the rest of the session. The optimization
# loop computes its own gradients.
with torch.no_grad():
    initial_loss, initial_image = compute_loss(initial_params, observed_image, detector, device, dtype)
print(f"Initial loss: {initial_loss.item():.6f}")

## 4. Running Optimization

Now let's use gradient-based optimization to refine the cell parameters:

In [None]:
# Adam moves each parameter by roughly `learning_rate` per step, in that
# parameter's own units (Angstroms for lengths, degrees for angles). With
# lr=0.01 and 50 iterations a parameter can travel only about 0.5 units, far
# less than the ~28 Angstrom starting error on cell_a / cell_b. For larger
# starting errors, raise the learning rate, run more iterations, or optimize
# normalized (relative) parameters instead.
learning_rate = 0.01
n_iterations = 50
optimizer = torch.optim.Adam([initial_params], lr=learning_rate)

# Track optimization history: loss_history[k] is the loss at param_history[k].
loss_history = []
param_history = []

for iteration in range(n_iterations):
    optimizer.zero_grad()

    loss, simulated_image = compute_loss(initial_params, observed_image, detector, device, dtype)
    loss.backward()

    # Record the state that produced this loss, i.e. before the update below.
    loss_value = loss.item()
    loss_history.append(loss_value)
    param_history.append(initial_params.detach().clone().cpu().numpy())

    optimizer.step()

    if iteration % 10 == 0:
        print(f"Iteration {iteration:3d}: Loss = {loss_value:.6f}")

# Every entry recorded inside the loop precedes an optimizer step, so also
# evaluate and record the parameters produced by the last step. Without this,
# the reported "final" loss would belong to the second-to-last parameters.
with torch.no_grad():
    final_loss, simulated_image = compute_loss(initial_params, observed_image, detector, device, dtype)
loss_history.append(final_loss.item())
param_history.append(initial_params.detach().clone().cpu().numpy())

print(f"\nFinal loss: {loss_history[-1]:.6f}")
print(f"Loss reduction: {(1 - loss_history[-1]/loss_history[0])*100:.1f}%")

## 5. Visualizing Convergence

Let's visualize how the optimization progressed:

In [None]:
# Loss curve (left) and per-parameter convergence (right), side by side.
plt.figure(figsize=(10, 4))

plt.subplot(1, 2, 1)
plt.semilogy(loss_history)
plt.xlabel('Iteration')
plt.ylabel('Loss (MSE)')
plt.title('Optimization Loss Curve')
plt.grid(True)

# Plot parameter convergence
plt.subplot(1, 2, 2)
# One row per recorded snapshot, columns ordered as param_names.
param_history = np.array(param_history)
# Look the true values up by name so the columns line up with param_names
# regardless of the insertion order of true_params.
true_values = np.array([true_params[name] for name in param_names])
# Percent error of every snapshot, computed for all parameters at once.
relative_errors = (param_history - true_values) / true_values * 100

for i, name in enumerate(param_names):
    plt.plot(relative_errors[:, i], label=name)

plt.xlabel('Iteration')
plt.ylabel('Relative Error (%)')
plt.title('Parameter Convergence')
plt.legend()
plt.grid(True)
plt.axhline(y=0, color='k', linestyle='--', alpha=0.5)

plt.tight_layout()
plt.show()

# Print final parameters (initial_params was updated in place by the optimizer).
print("\nFinal refined parameters:")
final_params = initial_params.detach().cpu().numpy()
for name, refined_val, true_val in zip(param_names, final_params, true_values):
    error = (refined_val - true_val) / true_val * 100
    print(f"  {name}: {refined_val:.3f} (true: {true_val:.3f}, error: {error:+.2f}%)")

## 6. Comparing Diffraction Patterns

Let's visualize the difference between initial, refined, and true diffraction patterns:

In [None]:
# Simulate the refined parameters; initial_params holds the values left by the
# optimizer in section 4.
with torch.no_grad():
    _, refined_image = compute_loss(initial_params, observed_image, detector, device, dtype)
    refined_image = refined_image.cpu().numpy()

# Rebuild the starting guess from the first recorded snapshot and simulate it.
initial_params_copy = torch.tensor(param_history[0], device=device, dtype=dtype)
with torch.no_grad():
    _, initial_guess_image = compute_loss(initial_params_copy, observed_image, detector, device, dtype)
    initial_guess_image = initial_guess_image.cpu().numpy()

observed_np = observed_image.cpu().numpy()
initial_diff = initial_guess_image - observed_np
refined_diff = refined_image - observed_np

# A diverging colormap only reads correctly when zero sits at its neutral
# centre, so give both residual maps the same symmetric range
# (+/- the largest absolute difference). This also makes the two panels directly
# comparable. Fall back to +/-1 if the images are identical (range would be zero).
diff_limit = float(max(np.abs(initial_diff).max(), np.abs(refined_diff).max())) or 1.0

fig, axes = plt.subplots(2, 3, figsize=(15, 10))

# Row 1: simulated and observed images
image_panels = [
    (axes[0, 0], initial_guess_image, 'Initial Guess'),
    (axes[0, 1], refined_image, 'Refined'),
    (axes[0, 2], observed_np, 'True (Observed)'),
]
for ax, img, title in image_panels:
    im = ax.imshow(img, cmap='viridis', origin='lower')
    ax.set_title(title)
    plt.colorbar(im, ax=ax)

# Row 2: residuals on the shared symmetric scale
diff_panels = [
    (axes[1, 0], initial_diff, 'Initial - True'),
    (axes[1, 1], refined_diff, 'Refined - True'),
]
for ax, diff, title in diff_panels:
    im = ax.imshow(diff, cmap='RdBu_r', origin='lower', vmin=-diff_limit, vmax=diff_limit)
    ax.set_title(title)
    plt.colorbar(im, ax=ax)

# Hide the last subplot
axes[1, 2].axis('off')

for ax in axes.flat:
    # axis('off') leaves the Axes visible, so test axison to skip the blank panel.
    if ax.axison:
        ax.set_xlabel('Fast axis (pixels)')
        ax.set_ylabel('Slow axis (pixels)')

plt.tight_layout()
plt.show()

# Print RMS errors
initial_rms = np.sqrt(np.mean(initial_diff**2))
refined_rms = np.sqrt(np.mean(refined_diff**2))
print(f"\nRMS errors:")
print(f"  Initial guess: {initial_rms:.6f}")
print(f"  Refined:       {refined_rms:.6f}")
print(f"  Improvement:   {(1 - refined_rms/initial_rms)*100:.1f}%")

## Advanced Topics

### Constrained Optimization

In practice, you might want to add constraints to keep parameters in physically reasonable ranges:

In [None]:
def constrained_optimization(initial_params, observed_image, detector, device, dtype, n_iterations=50):
    """Refine cell parameters with Adam while keeping them physically reasonable.

    The optimizer works on an unconstrained copy of ``initial_params``. Before
    every loss evaluation, the lengths are clamped to at least 1 Angstrom and the
    angles to the range [20, 160] degrees.

    Args:
        initial_params: Six starting values (a, b, c, alpha, beta, gamma). It is
            copied and never modified. It may be a non-leaf tensor, e.g. the
            result of ``.clone()`` on a tensor that requires grad.
        observed_image: Target pattern, forwarded to ``compute_loss``.
        detector: Detector model, forwarded to ``compute_loss``.
        device: Forwarded to ``compute_loss``.
        dtype: Forwarded to ``compute_loss``.
        n_iterations: Number of Adam steps.

    Returns:
        The constrained parameters after the final optimizer step, detached
        from the autograd graph.
    """

    def _apply_constraints(raw):
        # clamp leaves in-range values untouched, so the starting guess is
        # simulated exactly as given. Entries outside the bounds receive zero
        # gradient from clamp.
        lengths = torch.clamp(raw[:3], min=1.0)
        angles = torch.clamp(raw[3:], min=20.0, max=160.0)
        return torch.cat([lengths, angles])

    # torch optimizers only accept leaf tensors; clone() of a tensor that
    # requires grad is not one, so build a fresh leaf from a detached copy.
    params = initial_params.detach().clone().requires_grad_(True)
    optimizer = torch.optim.Adam([params], lr=0.01)

    for iteration in range(n_iterations):
        optimizer.zero_grad()

        loss, _ = compute_loss(_apply_constraints(params), observed_image, detector, device, dtype)
        loss.backward()
        optimizer.step()

        if iteration % 10 == 0:
            print(f"Iteration {iteration}: Loss = {loss.item():.6f}")

    # Return the constrained values for the parameters after the last step (not
    # the ones evaluated before it). This also works when n_iterations == 0.
    with torch.no_grad():
        return _apply_constraints(params)

# Example usage (not run to save computation)
# constrained_params = constrained_optimization(initial_params.clone(), observed_image, detector, device, dtype)

### Using Different Optimizers

You can experiment with different PyTorch optimizers:

In [None]:
# Name -> factory that wraps a single parameter tensor in a configured optimizer.
# LBFGS.step() additionally needs a closure; see lbfgs_optimization below.
optimizers = dict(
    Adam=lambda tensor: torch.optim.Adam([tensor], lr=0.01),
    SGD=lambda tensor: torch.optim.SGD([tensor], lr=0.1, momentum=0.9),
    LBFGS=lambda tensor: torch.optim.LBFGS([tensor], lr=1, max_iter=20),
)

# L-BFGS re-evaluates the objective several times per step, so optimizer.step()
# must be given a closure that recomputes the loss and its gradient.
def lbfgs_optimization(params, observed_image, detector, device, dtype, n_iterations=10):
    """Refine cell parameters with L-BFGS.

    Args:
        params: Six cell parameters (a, b, c, alpha, beta, gamma). A leaf tensor
            with ``requires_grad=True`` is optimized in place. Anything else is
            first copied into a new leaf tensor, because torch optimizers reject
            non-leaf tensors and the loss needs a gradient path. A non-leaf
            includes e.g. ``initial_params.clone()``.
        observed_image: Target pattern, forwarded to ``compute_loss``.
        detector: Detector model, forwarded to ``compute_loss``.
        device: Forwarded to ``compute_loss``.
        dtype: Forwarded to ``compute_loss``.
        n_iterations: Number of outer ``step`` calls. Each may run up to
            ``max_iter=20`` inner L-BFGS iterations.

    Returns:
        The optimized parameter tensor.
    """
    if not (params.is_leaf and params.requires_grad):
        params = params.detach().clone().requires_grad_(True)

    optimizer = torch.optim.LBFGS([params], lr=1, max_iter=20)

    def closure():
        optimizer.zero_grad()
        loss, _ = compute_loss(params, observed_image, detector, device, dtype)
        loss.backward()
        return loss

    for i in range(n_iterations):
        # step() returns the loss from the closure's first evaluation in this step.
        loss = optimizer.step(closure)
        print(f"LBFGS step {i}: Loss = {loss.item():.6f}")

    return params

# Example usage (not run to save computation)
# lbfgs_params = lbfgs_optimization(initial_params.clone(), observed_image, detector, device, dtype)

## Summary

In this tutorial, we demonstrated:

1. **Creating a target crystal** in space group P1 with arbitrary (triclinic) unit cell parameters, and simulating its diffraction pattern
2. **Setting up differentiable parameters** using PyTorch tensors with `requires_grad=True`
3. **Defining a loss function** that compares simulated and observed diffraction patterns
4. **Running gradient-based optimization** to refine cell parameters
5. **Visualizing convergence** and comparing results

The key advantages of this approach are:
- **Automatic differentiation**: No need to derive gradients manually
- **General crystal systems**: Works for any unit cell from triclinic to cubic
- **GPU acceleration**: Can leverage CUDA for faster computation
- **Integration with ML**: Can be combined with neural networks for advanced applications

This differentiable simulation capability opens up many possibilities for crystallographic refinement, including joint refinement of multiple parameters, uncertainty quantification, and integration with machine learning models.