# Advanced MRS Fitting with Constraints using PyTorch

This notebook demonstrates the use of the `AdvancedLinearCombinationModel` from the `lcm_library`. This model leverages PyTorch for fitting Magnetic Resonance Spectroscopy (MRS) data, allowing for the optimization of metabolite amplitudes, frequency shifts, and linewidth broadenings, all while applying specified constraints. It also supports simultaneous fitting of a polynomial baseline.

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import os
import sys
import warnings # To manage warnings if needed

# Setup sys.path to find the lcm_library
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
    print(f"Added {module_path} to sys.path")
else:
    print(f"{module_path} already in sys.path")

from lcm_library.advanced_model import AdvancedLinearCombinationModel

%matplotlib inline

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

## 1. Data Simulation

We will simulate MRS data including a few metabolites, a baseline, and noise. This allows us to have known ground truth values for amplitudes, shifts, and linewidths to compare against the model's fitting results.

### 1.1. Define Simulation Parameters

In [None]:
NUM_POINTS = 1024       # Number of points in the spectrum
SW_HZ = 2000.0          # Spectral width in Hz
F0_MHZ = 123.25         # 1H frequency in MHz for a nominal "3 T" scanner (B0 ~ 2.89 T)
DT_S = 1.0 / SW_HZ      # Dwell time in seconds

METABOLITE_NAMES = ['MetA', 'MetB', 'MetC']
NUM_METABOLITES = len(METABOLITE_NAMES)

# True parameters for simulation (these are the 'ground truth' values)
TRUE_AMPLITUDES = np.array([10.0, 8.0, 7.0])
TRUE_SHIFTS_HZ = np.array([-1.5, 0.3, 1.2])  # Hz shift from basis spectrum position
TRUE_LW_HZ_ADDITIONAL = np.array([1.2, 1.8, 0.8]) # Additional Lorentzian broadening in Hz
TRUE_BASELINE_COEFFS = np.array([8.0, -3.0, 1.5]) # Coefficients for a 2nd degree polynomial

print(f"Dwell time (dt): {DT_S*1000:.2f} ms")

### 1.2. Simulate Ideal Basis Spectra

These are the 'pure' metabolite spectra that the `AdvancedLinearCombinationModel` will use as input. We'll simulate them as simple Lorentzian peaks at distinct frequencies. These basis spectra are assumed to be perfectly known (no inherent shift or broadening beyond their defined lineshape here). The basis spectra should be complex and fftshifted for compatibility with the model.
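Here is a quick sanity check of the lineshape convention used below. It is a NumPy sketch that does not depend on `lcm_library`. Multiplying a FID by `exp(-π·lw·t)` should produce an absorption-mode Lorentzian whose full width at half maximum (FWHM) is `lw` Hz. The long acquisition (N = 16384) is an assumption chosen only so the peak spans many bins. The helper `measure_fwhm_hz` is hypothetical and gives a crude bin-counting estimate.

```python
import numpy as np

def measure_fwhm_hz(spectrum_real, df_hz):
    """Crude FWHM estimate: number of bins above half the maximum, times the bin width."""
    half_max = spectrum_real.max() / 2.0
    return np.count_nonzero(spectrum_real > half_max) * df_hz

N = 16384            # long acquisition so the FID decays fully and the peak spans many bins
DT_S = 1.0 / 2000.0  # same dwell time as the notebook
LW_HZ = 5.0          # Lorentzian linewidth (FWHM) to test

t = np.arange(N) * DT_S
fid = np.exp(-np.pi * LW_HZ * t)                  # on-resonance peak with decay rate pi*lw
spectrum = np.fft.fftshift(np.fft.fft(fid))       # same FFT + fftshift as the notebook
freqs_hz = np.fft.fftshift(np.fft.fftfreq(N, d=DT_S))
df_hz = freqs_hz[1] - freqs_hz[0]

fwhm = measure_fwhm_hz(spectrum.real, df_hz)
print(f"bin width {df_hz:.4f} Hz, measured FWHM {fwhm:.3f} Hz (expected {LW_HZ} Hz)")
```

The measured width agrees with `LW_HZ` to within a bin or two. This is why `lw_hz` in the simulation can be read directly as a FWHM in hertz.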

In [None]:
def create_lorentzian_peak_freq_domain(num_points, dt_s, peak_center_hz_offset, amplitude, lw_hz):
    """Creates a frequency-domain Lorentzian peak, fftshifted and complex."""
    time_axis = np.arange(0, num_points) * dt_s
    fid = amplitude * np.exp(1j * 2 * np.pi * peak_center_hz_offset * time_axis) * np.exp(-time_axis * np.pi * lw_hz)
    spectrum_shifted = np.fft.fftshift(np.fft.fft(fid))
    return spectrum_shifted.astype(np.complex64)

basis_spectra_list_true = []
basis_peak_hz_offsets = [-250.0, -50.0, 100.0] 
basis_inherent_lw_hz = 1.5 

for i in range(NUM_METABOLITES):
    peak = create_lorentzian_peak_freq_domain(NUM_POINTS, DT_S, 
                                              basis_peak_hz_offsets[i], 
                                              1.0, 
                                              basis_inherent_lw_hz)
    basis_spectra_list_true.append(peak)

basis_spectra_tensor_true = torch.tensor(np.array(basis_spectra_list_true).T, dtype=torch.complex64).to(device)
print(f"Shape of true basis_spectra_tensor: {basis_spectra_tensor_true.shape}")

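hz_axis_full_range = np.linspace(-SW_HZ / 2, SW_HZ / 2 - SW_HZ/NUM_POINTS, NUM_POINTS)  # Hz of each fftshifted bin
# Note: reversing the array mirrors the labels, so a bin at +h Hz is labelled roughly -h/F0 ppm
ppm_axis_plot = (hz_axis_full_range / F0_MHZ)[::-1] 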

plt.figure(figsize=(10,4))
for i in range(NUM_METABOLITES):
    plt.plot(ppm_axis_plot, basis_spectra_tensor_true[:, i].real.cpu().numpy(), label=METABOLITE_NAMES[i])
plt.title("Ideal Basis Spectra (Real Part)")
plt.xlabel(f"Chemical Shift (ppm, relative to {F0_MHZ} MHz as 0 ppm)")
plt.ylabel("Intensity")
plt.xlim(max(ppm_axis_plot), min(ppm_axis_plot)) 
plt.legend()
plt.grid(True, alpha=0.5)
plt.show()
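The `[::-1]` in `ppm_axis_plot` deserves a closer look. The NumPy sketch below reuses the notebook's constants and makes two checks. First, it confirms that `hz_axis_full_range` matches `np.fft.fftshift(np.fft.fftfreq(N, dt))`, which is the frequency of each fftshifted bin. Second, it shows that reversing the array mirrors the ppm labels. For example, the peak simulated at +100 Hz ends up labelled near -0.82 ppm instead of +0.81 ppm. The notebook does not say whether this mirroring is intended; it may, for instance, follow an opposite FID sign convention. The helper `displayed_ppm` is hypothetical.

```python
import numpy as np

NUM_POINTS, SW_HZ, F0_MHZ = 1024, 2000.0, 123.25
DT_S = 1.0 / SW_HZ

# The notebook's axis construction ...
hz_axis_full_range = np.linspace(-SW_HZ / 2, SW_HZ / 2 - SW_HZ / NUM_POINTS, NUM_POINTS)
# ... equals the frequency of each bin after fftshift
hz_from_fftfreq = np.fft.fftshift(np.fft.fftfreq(NUM_POINTS, d=DT_S))
assert np.allclose(hz_axis_full_range, hz_from_fftfreq)

ppm_plain = hz_axis_full_range / F0_MHZ   # label = Hz / F0 for every bin
ppm_reversed = ppm_plain[::-1]            # the notebook's ppm_axis_plot

def displayed_ppm(peak_hz, ppm_axis):
    """ppm label attached to the bin closest to peak_hz on the given axis."""
    idx = int(np.argmin(np.abs(hz_axis_full_range - peak_hz)))
    return ppm_axis[idx]

for name, hz in zip(['MetA', 'MetB', 'MetC'], [-250.0, -50.0, 100.0]):
    print(f"{name}: {hz:+7.1f} Hz -> plain {displayed_ppm(hz, ppm_plain):+.2f} ppm, "
          f"reversed {displayed_ppm(hz, ppm_reversed):+.2f} ppm")
```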

### 1.3. Simulate Observed MRS Spectrum

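Next, we apply the true parameters (amplitudes, shifts, and additional linewidths) to the ideal basis spectra to generate the observed spectrum. Each basis spectrum goes through the same steps:

- convert it back to the time domain;
- multiply it by an exponential decay (the extra Lorentzian broadening) and a linear phase ramp (the frequency shift);
- Fourier transform it again and scale it by its amplitude.

We then add a polynomial baseline and Gaussian noise. Only the real part is kept, since that is what the model is fitted to.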
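These time-domain transformations can also be checked in isolation:

- multiplying a basis FID by `exp(-π·Δlw·t)` adds `Δlw` Hz to its Lorentzian FWHM, because Lorentzian linewidths add;
- multiplying by `exp(i·2π·Δf·t)` moves the peak by `Δf` Hz.

The NumPy sketch below uses MetB's simulated values: a basis peak at -50 Hz with a 1.5 Hz linewidth, plus Δf = 0.3 Hz and Δlw = 1.8 Hz. The long acquisition (N = 16384) is an assumption made only so that sub-hertz changes can be resolved. `peak_center_and_fwhm` is a hypothetical helper.

```python
import numpy as np

N, DT_S = 16384, 1.0 / 2000.0
t = np.arange(N) * DT_S
freqs_hz = np.fft.fftshift(np.fft.fftfreq(N, d=DT_S))
df_hz = freqs_hz[1] - freqs_hz[0]

def peak_center_and_fwhm(fid):
    """Return (frequency of the maximum, crude FWHM) of the real fftshifted spectrum."""
    spec = np.fft.fftshift(np.fft.fft(fid)).real
    center = freqs_hz[np.argmax(spec)]
    fwhm = np.count_nonzero(spec > spec.max() / 2.0) * df_hz
    return center, fwhm

BASIS_HZ, BASIS_LW_HZ = -50.0, 1.5   # MetB basis peak in the notebook
SHIFT_HZ, EXTRA_LW_HZ = 0.3, 1.8     # MetB's TRUE_SHIFTS_HZ / TRUE_LW_HZ_ADDITIONAL

basis_fid = np.exp(1j * 2 * np.pi * BASIS_HZ * t) * np.exp(-np.pi * BASIS_LW_HZ * t)
decay = np.exp(-np.pi * EXTRA_LW_HZ * t)            # extra Lorentzian broadening
phase_ramp = np.exp(1j * 2 * np.pi * SHIFT_HZ * t)  # frequency shift
modified_fid = basis_fid * decay * phase_ramp

c0, w0 = peak_center_and_fwhm(basis_fid)
c1, w1 = peak_center_and_fwhm(modified_fid)
print(f"basis:    center {c0:+.3f} Hz, FWHM {w0:.3f} Hz")
print(f"modified: center {c1:+.3f} Hz, FWHM {w1:.3f} Hz")
```

The center moves by about 0.3 Hz and the width grows from about 1.5 Hz to about 3.3 Hz. The multiplication order does not matter, because both factors are pointwise.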

In [None]:
# Build the time axis from integer indices; a float-step arange can gain or lose a point to rounding
time_axis_torch = torch.arange(NUM_POINTS, device=device, dtype=torch.float32) * DT_S
final_metabolite_sum_complex = torch.zeros(NUM_POINTS, dtype=torch.complex64, device=device)

basis_spectra_time_domain = torch.fft.ifft(torch.fft.ifftshift(basis_spectra_tensor_true, dim=0), dim=0)

for i in range(NUM_METABOLITES):
    basis_time = basis_spectra_time_domain[:, i]
    decay = torch.exp(-time_axis_torch * np.pi * TRUE_LW_HZ_ADDITIONAL[i])
    broadened_time = basis_time * decay
    phase_ramp = torch.exp(1j * 2 * np.pi * TRUE_SHIFTS_HZ[i] * time_axis_torch)
    shifted_broadened_time = broadened_time * phase_ramp
    modified_freq_shifted = torch.fft.fftshift(torch.fft.fft(shifted_broadened_time), dim=0)
    final_metabolite_sum_complex += TRUE_AMPLITUDES[i] * modified_freq_shifted

norm_freq_axis_sim = torch.linspace(-1, 1, NUM_POINTS, device=device, dtype=torch.float32)
baseline_signal = torch.zeros_like(norm_freq_axis_sim)
for d_idx, coeff in enumerate(TRUE_BASELINE_COEFFS):
    baseline_signal += coeff * (norm_freq_axis_sim ** d_idx)

observed_spectrum_no_noise_complex = final_metabolite_sum_complex + baseline_signal 
observed_spectrum_no_noise_real = observed_spectrum_no_noise_complex.real

noise_std = 1.5 
torch.manual_seed(0)  # fixed seed so the noise, and therefore the fit, is reproducible
noise = torch.normal(0, noise_std, size=(NUM_POINTS,), device=device, dtype=torch.float32)
observed_spectrum_tensor = observed_spectrum_no_noise_real + noise

print(f"Shape of observed_spectrum_tensor: {observed_spectrum_tensor.shape}")

plt.figure(figsize=(10,4))
plt.plot(ppm_axis_plot, observed_spectrum_tensor.cpu().numpy(), label="Simulated Observed Spectrum (with noise)")
plt.plot(ppm_axis_plot, observed_spectrum_no_noise_real.cpu().numpy(), label="Ground Truth (No Noise)", linestyle='--')
plt.title("Simulated Observed MRS Spectrum")
plt.xlabel(f"Chemical Shift (ppm, relative to {F0_MHZ} MHz as 0 ppm)")
plt.ylabel("Intensity")
plt.xlim(max(ppm_axis_plot), min(ppm_axis_plot))
plt.legend()
plt.grid(True, alpha=0.5)
plt.show()
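The loop that builds `baseline_signal` evaluates a polynomial on a normalized axis running from -1 to 1. Written as a design matrix, with one column per power of x (which is what `np.vander(..., increasing=True)` produces), the same baseline becomes a single matrix-vector product. Ordinary least squares then recovers the coefficients from a noise-free baseline. The sketch below shows both forms. It is an assumption, not confirmed here, that `AdvancedLinearCombinationModel.baseline_poly_terms` (used in Section 4.2) has the same axis normalization and column order.

```python
import numpy as np

NUM_POINTS = 1024
TRUE_BASELINE_COEFFS = np.array([8.0, -3.0, 1.5])   # c0 + c1*x + c2*x**2
x = np.linspace(-1.0, 1.0, NUM_POINTS)               # normalized frequency axis

# Loop form, as in the simulation cell
baseline_loop = np.zeros_like(x)
for d, coeff in enumerate(TRUE_BASELINE_COEFFS):
    baseline_loop += coeff * x**d

# Design-matrix form: columns are x**0, x**1, x**2
design = np.vander(x, N=len(TRUE_BASELINE_COEFFS), increasing=True)
baseline_matrix = design @ TRUE_BASELINE_COEFFS

# Least squares recovers the coefficients when the baseline is the only signal present
recovered, *_ = np.linalg.lstsq(design, baseline_matrix, rcond=None)
print(recovered)
```

Keeping the axis normalized to [-1, 1] keeps the columns of similar magnitude. This helps conditioning when the baseline is fitted jointly with the metabolites.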

### 1.4. Define Fitting Mask

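We'll define a mask so that only a specific region of the spectrum is fitted. Here we use -1.5 to 4.2 ppm. On the (reversed) ppm axis used in this notebook, the three simulated peaks are displayed near 2.0, 0.4 and -0.8 ppm, so this window contains all of them. A 0.2 to 4.2 ppm window would leave MetC outside the fit.

In [None]:
ppm_min_fit = -1.5 
ppm_max_fit = 4.2 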

fitting_mask_numpy = (ppm_axis_plot <= ppm_max_fit) & (ppm_axis_plot >= ppm_min_fit)
fitting_mask_tensor = torch.tensor(fitting_mask_numpy, dtype=torch.bool, device=device)

print(f"Fitting mask covers {fitting_mask_tensor.sum().item()} points between {ppm_min_fit:.2f} and {ppm_max_fit:.2f} ppm.")

plt.figure(figsize=(10,2))
masked_spectrum_visualization = torch.zeros_like(observed_spectrum_tensor)
masked_spectrum_visualization[fitting_mask_tensor] = observed_spectrum_tensor[fitting_mask_tensor]
plt.plot(ppm_axis_plot, masked_spectrum_visualization.cpu().numpy(), label="Fitting Region Active")
plt.plot(ppm_axis_plot, observed_spectrum_tensor.cpu().numpy(), label="Observed Spectrum", alpha=0.3)
plt.title("Selected Fitting Region")
plt.xlabel(f"Chemical Shift (ppm, relative to {F0_MHZ} MHz as 0 ppm)")
plt.yticks([])
plt.xlim(max(ppm_axis_plot), min(ppm_axis_plot))
plt.legend()
plt.show()
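A peak that falls outside the fitting mask presumably contributes only its far tails to the fitted region, so the data barely constrain its amplitude, shift and linewidth. The sketch below reuses the notebook's reversed ppm axis and reports which basis peaks fall inside a candidate window. `peaks_inside_window` is a hypothetical helper. A window of 0.2 to 4.2 ppm leaves out MetC, which is labelled near -0.82 ppm on this axis, whereas -1.5 to 4.2 ppm contains all three peaks.

```python
import numpy as np

NUM_POINTS, SW_HZ, F0_MHZ = 1024, 2000.0, 123.25
hz_axis = np.linspace(-SW_HZ / 2, SW_HZ / 2 - SW_HZ / NUM_POINTS, NUM_POINTS)
ppm_axis_plot = (hz_axis / F0_MHZ)[::-1]     # same (reversed) axis as the notebook
basis_peak_hz = {'MetA': -250.0, 'MetB': -50.0, 'MetC': 100.0}

def peaks_inside_window(ppm_min, ppm_max):
    """Names of basis peaks whose bin lies inside [ppm_min, ppm_max] on ppm_axis_plot."""
    mask = (ppm_axis_plot >= ppm_min) & (ppm_axis_plot <= ppm_max)
    inside = []
    for name, hz in basis_peak_hz.items():
        idx = int(np.argmin(np.abs(hz_axis - hz)))   # bin index of the peak
        if mask[idx]:
            inside.append(name)
    return inside

print(peaks_inside_window(0.2, 4.2))    # ['MetA', 'MetB']
print(peaks_inside_window(-1.5, 4.2))   # ['MetA', 'MetB', 'MetC']
```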

## 2. Instantiate `AdvancedLinearCombinationModel`

In [None]:
# Initial parameters for the model (these are the starting guesses for the fit)
initial_params_guess = {
    'MetA': {'amp': 7.0, 'shift_hz': 0.1, 'lw_hz': 1.0},
    'MetB': {'amp': 6.0, 'shift_hz': -0.1, 'lw_hz': 1.5},
    'MetC': {'amp': 5.0, 'shift_hz': 0.0, 'lw_hz': 0.8},
    'baseline': {'coeff0': 0.0, 'coeff1': 0.0, 'coeff2': 0.0} 
}

constraints_fit = {
    'max_shift_hz': 2.0,  # Max allowable shift in Hz 
    'min_lw_hz': 0.3,     # Min allowable additional linewidth in Hz
    'max_lw_hz': 5.0      # Max allowable additional linewidth in Hz
}

BASELINE_DEGREE_FIT = len(TRUE_BASELINE_COEFFS) - 1

advanced_model = AdvancedLinearCombinationModel(
    basis_spectra_tensor=basis_spectra_tensor_true, 
    metabolite_names=METABOLITE_NAMES,
    observed_spectrum_tensor=observed_spectrum_tensor, 
    dt=DT_S,
    fitting_mask=fitting_mask_tensor,
    initial_params=initial_params_guess,
    constraints=constraints_fit,
    baseline_degree=BASELINE_DEGREE_FIT,
    device=device
)

print("AdvancedLinearCombinationModel instantiated.")
print(f"Model parameters will be optimized on device: {next(advanced_model.parameters()).device}")

## 3. Run the Fitting Process

In [None]:
advanced_model.fit(num_iterations=3000, lr=0.025, optim_type='adam', print_loss_every=250, weight_decay=1e-5)

## 4. Display Results

### 4.1. Compare True vs. Fitted Parameters

In [None]:
fitted_amps = advanced_model.get_fitted_amplitudes()
fitted_shifts = advanced_model.get_fitted_shifts_hz()
fitted_lws = advanced_model.get_fitted_linewidths_hz()
fitted_bl_coeffs = advanced_model.get_fitted_baseline_coeffs()

print("--- Parameter Comparison ---")
print("\nMetabolite Amplitudes:")
for i, name in enumerate(METABOLITE_NAMES):
    print(f"  {name}: True = {TRUE_AMPLITUDES[i]:.2f}, Fitted = {fitted_amps.get(name, 0.0):.2f}")

print("\nFrequency Shifts (Hz):")
for i, name in enumerate(METABOLITE_NAMES):
    print(f"  {name}: True = {TRUE_SHIFTS_HZ[i]:.2f}, Fitted = {fitted_shifts.get(name, 0.0):.2f}")

print("\nAdditional Linewidths (Hz):")
for i, name in enumerate(METABOLITE_NAMES):
    print(f"  {name}: True Additional = {TRUE_LW_HZ_ADDITIONAL[i]:.2f}, Fitted Additional = {fitted_lws.get(name, 0.0):.2f}")

if fitted_bl_coeffs is not None:
    print("\nBaseline Coefficients:")
    for i in range(len(fitted_bl_coeffs)):
        true_str = f"{TRUE_BASELINE_COEFFS[i]:.2f}" if i < len(TRUE_BASELINE_COEFFS) else "N/A (degree mismatch)"
        print(f"  Coeff {i}: True = {true_str}, Fitted = {fitted_bl_coeffs[i]:.2f}")

### 4.2. Visualization of Fit

In [None]:
with torch.no_grad():
    full_model_spectrum_tensor = advanced_model.get_full_model_spectrum(real_part=True)
    
    fitted_baseline_signal = torch.zeros(advanced_model.num_points, device=advanced_model.device)
    if advanced_model.baseline_coeffs_raw is not None and hasattr(advanced_model, 'baseline_poly_terms'):
        fitted_baseline_signal = advanced_model.baseline_poly_terms @ advanced_model.baseline_coeffs_raw.detach()
        
    fitted_params_t = advanced_model.get_transformed_parameters()
    amps_t = fitted_params_t['amplitudes']
    shifts_t = fitted_params_t['shifts_hz']
    lws_t = fitted_params_t['linewidths_hz']
    
    basis_time_t = torch.fft.ifft(torch.fft.ifftshift(advanced_model.basis_spectra_freq_shifted, dim=0), dim=0)
    time_axis_t = advanced_model.time_axis  # 1-D time axis in seconds
    individual_metab_components_freq_list = []
    # Rebuild each component with the same time-domain decay and phase ramp used in the simulation
    for i in range(advanced_model.num_metabolites):
        metab_t = basis_time_t[:, i]
        decay_t = torch.exp(-time_axis_t * np.pi * lws_t[i])
        phase_ramp_t = torch.exp(1j * 2 * np.pi * shifts_t[i] * time_axis_t)
        mod_metab_t = metab_t * decay_t * phase_ramp_t
        mod_metab_f_shifted = torch.fft.fftshift(torch.fft.fft(mod_metab_t), dim=0)
        individual_metab_components_freq_list.append(amps_t[i] * mod_metab_f_shifted.real) 

plt.figure(figsize=(12, 9))

plt.subplot(3,1,1)
plt.plot(ppm_axis_plot, observed_spectrum_tensor.cpu().numpy(), label="Observed Spectrum", alpha=0.7)
plt.plot(ppm_axis_plot, full_model_spectrum_tensor.cpu().numpy(), label="Full Model Fit", color='red')
plt.plot(ppm_axis_plot, fitted_baseline_signal.cpu().numpy(), label="Fitted Baseline", color='green', linestyle=':')
plt.title("Observed Spectrum vs. Full Model Fit and Baseline")
plt.xlabel(f"Chemical Shift (ppm, relative to {F0_MHZ} MHz as 0 ppm)")
plt.ylabel("Intensity")
plt.xlim(max(ppm_axis_plot), min(ppm_axis_plot))
plt.legend()
plt.grid(True, alpha=0.4)

residuals = observed_spectrum_tensor.cpu().numpy() - full_model_spectrum_tensor.cpu().numpy()
plt.subplot(3,1,2)
plt.plot(ppm_axis_plot, residuals, label="Residual (Observed - Model)", color='blue')
plt.title("Residuals")
plt.xlabel(f"Chemical Shift (ppm, relative to {F0_MHZ} MHz as 0 ppm)")
plt.ylabel("Intensity")
plt.xlim(max(ppm_axis_plot), min(ppm_axis_plot))
plt.legend()
plt.grid(True, alpha=0.4)

plt.subplot(3,1,3)
data_minus_baseline = observed_spectrum_tensor.cpu().numpy() - fitted_baseline_signal.cpu().numpy()
plt.plot(ppm_axis_plot, data_minus_baseline, label="Data - Est. Baseline", color='lightgray', alpha=0.9)
sum_fitted_metabs = torch.zeros_like(individual_metab_components_freq_list[0])
for i, name in enumerate(METABOLITE_NAMES):
    if i < len(individual_metab_components_freq_list):
        component_to_plot = individual_metab_components_freq_list[i].cpu().numpy()
        sum_fitted_metabs += individual_metab_components_freq_list[i]
        plt.plot(ppm_axis_plot, component_to_plot, label=f"Fitted {name}", linestyle='--')
plt.plot(ppm_axis_plot, sum_fitted_metabs.cpu().numpy(), label="Sum of Fitted Metabolites", color='purple', linestyle='-')
plt.title("Fitted Metabolite Components vs. Data (Baseline Subtracted)")
plt.xlabel(f"Chemical Shift (ppm, relative to {F0_MHZ} MHz as 0 ppm)")
plt.ylabel("Intensity")
plt.xlim(max(ppm_axis_plot), min(ppm_axis_plot))
plt.legend(fontsize='small')
plt.grid(True, alpha=0.4)

plt.tight_layout()
plt.show()

## 5. Experiment with Constraints (Brief Discussion)

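The `constraints` dictionary in `AdvancedLinearCombinationModel` sets bounds on the frequency shifts and linewidths. The model enforces these bounds by passing unconstrained raw parameters through bounded transforms (`tanh`, sigmoid). As a result, the fitted values can approach the limits but cannot cross them.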

- **Tighter Constraints**: Suppose `max_shift_hz` were set to a very small value (e.g., 0.2 Hz) while a metabolite truly shifted by 1 Hz.
  - The model would still try to fit the shift, but the `tanh` transform caps the fitted value at ±0.2 Hz, so the parameter would end up pressed against that bound.
  - Because the gradient of a saturated `tanh` is close to zero, a parameter that reaches a bound also tends to stay there.
  - Tight bounds like this can help prevent overfitting to noise or spectral artifacts when prior knowledge suggests that shifts should be small.
- **Linewidth Constraints**: `min_lw_hz` prevents linewidths from becoming unrealistically small (e.g., zero or negative). `max_lw_hz` prevents them from becoming excessively broad, since an overly broad peak might absorb noise or other signals. Together, they keep the fitted linewidths within a physically plausible range.

The optimization may drive parameters to the boundaries of these constraints in three situations: the initial parameters are far from the true values, the data are very noisy, or the basis spectra are highly correlated. Parameters at a boundary can indicate that the fit is struggling, that the model is misspecified, or that the constraints are too restrictive for the data.

The raw parameters (`model.amplitudes_raw`, `model.shifts_hz_raw`, `model.linewidths_hz_raw`) show how close the internal unconstrained values are to the region where the sigmoid/tanh transforms saturate. If parameters consistently sit at their bounds, consider relaxing the constraints or reducing the model's complexity.
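The bounded transforms discussed above can be sketched as follows. The exact mapping used by `AdvancedLinearCombinationModel` is not shown in this notebook, which only names `tanh` and sigmoid. The forms below are therefore assumptions:

- shift = `max_shift_hz * tanh(raw)`;
- linewidth = `min_lw_hz + (max_lw_hz - min_lw_hz) * sigmoid(raw)`.

Any raw value maps inside the bounds. The derivative with respect to the raw parameter shrinks toward zero as the transform saturates. The inverse of the shift transform is one way to gauge how saturated a fitted value is.

```python
import numpy as np

MAX_SHIFT_HZ, MIN_LW_HZ, MAX_LW_HZ = 2.0, 0.3, 5.0   # constraints_fit from Section 2

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def shift_from_raw(raw):
    """Assumed mapping: unconstrained raw -> shift in (-MAX_SHIFT_HZ, MAX_SHIFT_HZ)."""
    return MAX_SHIFT_HZ * np.tanh(raw)

def dshift_draw(raw):
    """Derivative of shift_from_raw; tends to 0 as |raw| grows (saturation)."""
    return MAX_SHIFT_HZ * (1.0 - np.tanh(raw) ** 2)

def lw_from_raw(raw):
    """Assumed mapping: unconstrained raw -> linewidth in (MIN_LW_HZ, MAX_LW_HZ)."""
    return MIN_LW_HZ + (MAX_LW_HZ - MIN_LW_HZ) * sigmoid(raw)

def shift_to_raw(shift_hz):
    """Inverse of shift_from_raw, e.g. to initialize raw values or gauge saturation."""
    return np.arctanh(shift_hz / MAX_SHIFT_HZ)

raw = np.linspace(-10.0, 10.0, 201)
print("shift range:", shift_from_raw(raw).min(), shift_from_raw(raw).max())
print("lw range:   ", lw_from_raw(raw).min(), lw_from_raw(raw).max())
print("d shift/d raw at raw=0 and raw=5:", dshift_draw(0.0), dshift_draw(5.0))
print("raw value for a 1.9 Hz shift:", shift_to_raw(1.9))
```

A shift of 1.9 Hz under a 2.0 Hz bound already needs a raw value near 1.83. There, the gradient is roughly a tenth of its value at zero. This illustrates why fits that drift toward a bound slow down and stall there.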

## 6. Conclusion

This notebook demonstrated the setup and use of the `AdvancedLinearCombinationModel` for fitting simulated MRS data. Key features highlighted include:
- Initialization with basis spectra, observed data, and various fitting parameters (initial guesses, constraints).
- The `fit` method for optimizing model parameters using PyTorch's gradient-based optimizers.
- Getter methods to retrieve constrained, interpretable fitted parameters for amplitudes, frequency shifts, linewidths, and baseline coefficients.
- Visualization of the fitting results, including the overall fit, residuals, fitted baseline, and individual metabolite components.

This advanced model provides a flexible framework for MRS quantification where non-linear parameters like frequency shifts and linewidth changes are important, and where constraints are necessary to guide the optimization towards physiologically plausible solutions. The use of PyTorch allows for leveraging automatic differentiation and GPU acceleration for more complex modeling tasks in the future.