# 04. End-to-End MRS Analysis: Fitting and Absolute Quantification

This notebook demonstrates an end-to-end workflow for Magnetic Resonance Spectroscopy (MRS) data analysis. We will first use the `AdvancedLinearCombinationModel` to fit simulated MRS data and estimate metabolite amplitudes along with other spectral parameters like frequency shifts and linewidth broadenings. Subsequently, we will use the `AbsoluteQuantifier` class to calculate the absolute concentrations of these metabolites, taking into account relaxation effects, tissue corrections, and water referencing.

## Setup

Import necessary libraries and modules. This includes PyTorch for numerical operations, Matplotlib for plotting, and our custom `AdvancedLinearCombinationModel` and `AbsoluteQuantifier` classes from the `lcm_library`.

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import os
import sys
import warnings # To manage warnings if needed

# Setup sys.path to find the lcm_library
module_path = os.path.abspath('..')  # parent directory that contains lcm_library
if module_path not in sys.path:
    sys.path.append(module_path)
    print(f"Added {module_path} to sys.path")
else:
    print(f"{module_path} already in sys.path")

from lcm_library.advanced_model import AdvancedLinearCombinationModel
from lcm_library.quantification import AbsoluteQuantifier

%matplotlib inline

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

## Phase 1: Data Simulation and Preparation

We'll start by simulating MRS data. This includes defining ideal basis spectra for a few metabolites and then creating an observed spectrum by combining these basis spectra with known amplitudes, frequency shifts, linewidth changes, a baseline, and noise. This simulated data will serve as the input for our fitting model.

### 1.1. Define Simulation Parameters

In [None]:
NUM_POINTS = 1024       # Number of points in the spectrum
SW_HZ = 2000.0          # Spectral width in Hz
F0_MHZ = 123.25         # Spectrometer frequency in MHz (1H at nominal 3 T; 123.25 MHz is about 2.89 T)
DT_S = 1.0 / SW_HZ      # Dwell time in seconds

METABOLITE_NAMES_SIM = ['NAA', 'Cr', 'Cho'] # Example metabolites
NUM_METABOLITES_SIM = len(METABOLITE_NAMES_SIM)

# True parameters for simulation (ground truth for fitting)
TRUE_AMPLITUDES_SIM = np.array([12.0, 9.0, 5.0])
TRUE_SHIFTS_HZ_SIM = np.array([-1.0, 0.5, 1.0])  # Hz shift from basis spectrum position
TRUE_LW_HZ_ADDITIONAL_SIM = np.array([1.0, 1.5, 0.7]) # Additional Lorentzian broadening in Hz
TRUE_BASELINE_COEFFS_SIM = np.array([10.0, -4.0, 2.0]) # Coefficients for a 2nd degree polynomial

print(f"Dwell time (dt): {DT_S*1000:.2f} ms")

### 1.2. Simulate Ideal Basis Spectra

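These are the 'pure' metabolite spectra that `AdvancedLinearCombinationModel` uses as its basis set. Each is simulated as a single Lorentzian peak with unit amplitude and a 2 Hz intrinsic linewidth, placed at -300, -100 and +50 Hz from the reference frequency for NAA, Cr and Cho respectively. These positions are chosen for illustration rather than to match in-vivo chemical shifts. The basis spectra are assumed to be perfectly known (no shift or broadening beyond their defined lineshape) and are stored fftshifted, i.e. with zero frequency at the center of the array.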

In [None]:
def create_lorentzian_peak_freq_domain(num_points, dt_s, peak_center_hz_offset, amplitude, lw_hz):
    """Creates a frequency-domain Lorentzian peak, fftshifted and complex."""
    time_axis = np.arange(0, num_points) * dt_s
    fid = amplitude * np.exp(1j * 2 * np.pi * peak_center_hz_offset * time_axis) * np.exp(-time_axis * np.pi * lw_hz)
    spectrum_shifted = np.fft.fftshift(np.fft.fft(fid))
    return spectrum_shifted.astype(np.complex64)

basis_spectra_list_sim = []
basis_peak_hz_offsets_sim = [-300.0, -100.0, 50.0] # Distinct frequencies for NAA, Cr, Cho bases
basis_inherent_lw_hz_sim = 2.0 # Base linewidth for all basis spectra

for i in range(NUM_METABOLITES_SIM):
    peak = create_lorentzian_peak_freq_domain(NUM_POINTS, DT_S, 
                                              basis_peak_hz_offsets_sim[i], 
                                              1.0, # Basis spectra are normalized to amplitude 1
                                              basis_inherent_lw_hz_sim)
    basis_spectra_list_sim.append(peak)

basis_spectra_tensor_true = torch.tensor(np.array(basis_spectra_list_sim).T, dtype=torch.complex64).to(device)
print(f"Shape of true basis_spectra_tensor: {basis_spectra_tensor_true.shape}")

hz_axis_full_range_sim = np.linspace(-SW_HZ / 2, SW_HZ / 2 - SW_HZ/NUM_POINTS, NUM_POINTS)
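# ppm = Hz offset / spectrometer MHz, aligned point by point with the fftshifted data.
# Reversing this array would mirror the labels relative to the data; the conventional
# high-to-low ppm display is already produced by plt.xlim(max, min) in each plot.
ppm_axis_plot_sim = hz_axis_full_range_sim / F0_MHZ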

plt.figure(figsize=(10,4))
for i in range(NUM_METABOLITES_SIM):
    plt.plot(ppm_axis_plot_sim, basis_spectra_tensor_true[:, i].real.cpu(), label=METABOLITE_NAMES_SIM[i])
plt.title(f"Ideal Basis Spectra (Real Part)")
plt.xlabel(f"Chemical Shift (ppm, relative to {F0_MHZ} MHz as 0 ppm)")
plt.ylabel("Intensity")
plt.xlim(max(ppm_axis_plot_sim), min(ppm_axis_plot_sim)) 
plt.legend()
plt.grid(True, alpha=0.5)
plt.show()
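The 2 Hz intrinsic linewidth is set through the decay term `exp(-pi * lw_hz * t)`, whose Fourier transform is a Lorentzian with a full width at half maximum (FWHM) of `lw_hz` in Hz. The sketch below checks this numerically using the same FID convention as `create_lorentzian_peak_freq_domain`; the helper names are hypothetical, and it uses a longer acquisition (8192 points) than the notebook so the frequency resolution (about 0.24 Hz) is fine enough to measure the width.

```python
import numpy as np

def lorentzian_real_spectrum(num_points, dt_s, offset_hz, lw_hz):
    """Real part of the fftshifted spectrum of a decaying complex exponential,
    built with the same convention as create_lorentzian_peak_freq_domain."""
    t = np.arange(num_points) * dt_s
    fid = np.exp(1j * 2 * np.pi * offset_hz * t) * np.exp(-np.pi * lw_hz * t)
    spectrum = np.fft.fftshift(np.fft.fft(fid)).real
    freqs_hz = np.fft.fftshift(np.fft.fftfreq(num_points, d=dt_s))
    return freqs_hz, spectrum

def measured_fwhm_hz(freqs_hz, spectrum):
    """Full width at half maximum, linearly interpolated between samples."""
    k = int(np.argmax(spectrum))
    half = spectrum[k] / 2.0
    lo = k
    while spectrum[lo] > half:  # walk left until the curve drops to half max
        lo -= 1
    hi = k
    while spectrum[hi] > half:  # walk right until the curve drops to half max
        hi += 1
    # np.interp needs increasing x-values, so order each bracketing pair accordingly
    f_left = np.interp(half, [spectrum[lo], spectrum[lo + 1]], [freqs_hz[lo], freqs_hz[lo + 1]])
    f_right = np.interp(half, [spectrum[hi], spectrum[hi - 1]], [freqs_hz[hi], freqs_hz[hi - 1]])
    return f_right - f_left

freqs, spec = lorentzian_real_spectrum(8192, 0.0005, offset_hz=0.0, lw_hz=2.0)
print(f"Measured FWHM: {measured_fwhm_hz(freqs, spec):.2f} Hz")
```

The measured width comes out very close to 2 Hz; the small discrepancy comes from discrete sampling and interpolation.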

### 1.3. Simulate Observed MRS Spectrum

The observed spectrum is generated by applying the true amplitudes, shifts, and additional linewidths to the ideal basis spectra. A polynomial baseline and random noise are also added.

In [None]:
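# Integer arange scaled by DT_S guarantees exactly NUM_POINTS samples (a float-step arange can be off by one)
time_axis_torch_sim = torch.arange(NUM_POINTS, device=device, dtype=torch.float32) * DT_S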
final_metabolite_sum_complex_sim = torch.zeros(NUM_POINTS, dtype=torch.complex64, device=device)

# Convert basis to time domain for modifications
basis_spectra_time_domain_sim = torch.fft.ifft(torch.fft.ifftshift(basis_spectra_tensor_true, dim=0), dim=0)

for i in range(NUM_METABOLITES_SIM):
    basis_time = basis_spectra_time_domain_sim[:, i]
    decay = torch.exp(-time_axis_torch_sim * np.pi * TRUE_LW_HZ_ADDITIONAL_SIM[i])
    broadened_time = basis_time * decay
    phase_ramp = torch.exp(1j * 2 * np.pi * TRUE_SHIFTS_HZ_SIM[i] * time_axis_torch_sim)
    shifted_broadened_time = broadened_time * phase_ramp
    # Convert back to frequency domain and fftshift
    modified_freq_shifted = torch.fft.fftshift(torch.fft.fft(shifted_broadened_time), dim=0)
    final_metabolite_sum_complex_sim += TRUE_AMPLITUDES_SIM[i] * modified_freq_shifted

# Simulate baseline
norm_freq_axis_sim = torch.linspace(-1, 1, NUM_POINTS, device=device, dtype=torch.float32)
baseline_signal_sim = torch.zeros_like(norm_freq_axis_sim)
for d_idx, coeff in enumerate(TRUE_BASELINE_COEFFS_SIM):
    baseline_signal_sim += coeff * (norm_freq_axis_sim ** d_idx)

observed_spectrum_no_noise_complex_sim = final_metabolite_sum_complex_sim + baseline_signal_sim
observed_spectrum_no_noise_real_sim = observed_spectrum_no_noise_complex_sim.real

# Add noise
noise_std_sim = 1.0 
noise_sim = torch.normal(0, noise_std_sim, size=(NUM_POINTS,), device=device, dtype=torch.float32)
observed_spectrum_tensor = observed_spectrum_no_noise_real_sim + noise_sim

print(f"Shape of observed_spectrum_tensor: {observed_spectrum_tensor.shape}")

plt.figure(figsize=(10,4))
plt.plot(ppm_axis_plot_sim, observed_spectrum_tensor.cpu().numpy(), label="Simulated Observed Spectrum (with noise)")
plt.plot(ppm_axis_plot_sim, observed_spectrum_no_noise_real_sim.cpu().numpy(), label="Ground Truth (No Noise)", linestyle='--')
plt.title("Simulated Observed MRS Spectrum")
plt.xlabel(f"Chemical Shift (ppm, relative to {F0_MHZ} MHz as 0 ppm)")
plt.ylabel("Intensity")
plt.xlim(max(ppm_axis_plot_sim), min(ppm_axis_plot_sim))
plt.legend()
plt.grid(True, alpha=0.5)
plt.show()
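The loop above relies on two Fourier-pair facts: multiplying the FID by `exp(i*2*pi*shift*t)` moves the peak by `shift` Hz, and multiplying by `exp(-pi*extra_lw*t)` broadens a Lorentzian so that the linewidths add. The NumPy sketch below mirrors that loop with a hypothetical helper `shift_and_broaden`, on a grid chosen so the frequency resolution is exactly 1 Hz: a 50 Hz peak with 2 Hz FWHM, shifted by 7 Hz and broadened by 1.5 Hz, should become a 57 Hz peak with a 3.5 Hz FWHM.

```python
import numpy as np

NUM_POINTS_DEMO = 2000  # with DT_DEMO this gives exactly 1 Hz per frequency bin
DT_DEMO = 0.0005        # 2000 Hz spectral width, as in the notebook

def shift_and_broaden(spectrum_shifted, dt_s, shift_hz, extra_lw_hz):
    """Go to the time domain, apply a Lorentzian decay and a linear phase ramp,
    and return to the fftshifted frequency domain (same steps as the simulation loop)."""
    n = spectrum_shifted.shape[0]
    t = np.arange(n) * dt_s
    fid = np.fft.ifft(np.fft.ifftshift(spectrum_shifted))
    fid = fid * np.exp(-np.pi * extra_lw_hz * t) * np.exp(1j * 2 * np.pi * shift_hz * t)
    return np.fft.fftshift(np.fft.fft(fid))

t = np.arange(NUM_POINTS_DEMO) * DT_DEMO
base_fid = np.exp(1j * 2 * np.pi * 50.0 * t) * np.exp(-np.pi * 2.0 * t)  # 50 Hz, 2 Hz FWHM
base_spec = np.fft.fftshift(np.fft.fft(base_fid))
freqs_hz = np.fft.fftshift(np.fft.fftfreq(NUM_POINTS_DEMO, d=DT_DEMO))

modified = shift_and_broaden(base_spec, DT_DEMO, shift_hz=7.0, extra_lw_hz=1.5)
print("Original peak at", freqs_hz[np.argmax(base_spec.real)], "Hz")
print("Modified peak at", freqs_hz[np.argmax(modified.real)], "Hz")
```

Comparing the modified FID with an analytic exponential at 57 Hz decaying with a 3.5 Hz linewidth confirms that shift and broadening combine into a single Lorentzian.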

### 1.4. Define Fitting Mask

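We define a boolean mask that restricts the fit to a spectral window. Because this notebook's ppm axis places 0 ppm at the spectrometer reference frequency (rather than using the usual in-vivo chemical-shift reference), the conventional 0.2 to 4.2 ppm window does not apply here and would leave at least one basis peak outside the fit. We use -3.0 to 3.0 ppm, which brackets all three simulated basis peaks (at -300, -100 and +50 Hz, all within 2.5 ppm of 0).

In [None]:
ppm_min_fit_sim = -3.0  # Window chosen to contain every simulated basis peak
ppm_max_fit_sim = 3.0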

fitting_mask_numpy_sim = (ppm_axis_plot_sim <= ppm_max_fit_sim) & (ppm_axis_plot_sim >= ppm_min_fit_sim)
fitting_mask_tensor = torch.tensor(fitting_mask_numpy_sim, dtype=torch.bool, device=device)

print(f"Fitting mask covers {fitting_mask_tensor.sum().item()} points between {ppm_min_fit_sim:.2f} and {ppm_max_fit_sim:.2f} ppm.")

plt.figure(figsize=(10,2))
masked_spectrum_visualization_sim = torch.zeros_like(observed_spectrum_tensor)
masked_spectrum_visualization_sim[fitting_mask_tensor] = observed_spectrum_tensor[fitting_mask_tensor]
plt.plot(ppm_axis_plot_sim, masked_spectrum_visualization_sim.cpu().numpy(), label="Fitting Region Active")
plt.plot(ppm_axis_plot_sim, observed_spectrum_tensor.cpu().numpy(), label="Observed Spectrum", alpha=0.3)
plt.title("Selected Fitting Region")
plt.xlabel(f"Chemical Shift (ppm, relative to {F0_MHZ} MHz as 0 ppm)")
plt.yticks([])
plt.xlim(max(ppm_axis_plot_sim), min(ppm_axis_plot_sim))
plt.legend()
plt.show()
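Whatever window is chosen, it is worth checking that every basis peak falls inside it; a peak outside the mask contributes almost nothing to the loss and its amplitude is effectively unconstrained. The sketch below does this check with a hypothetical helper, assuming peak positions follow ppm = Hz offset / spectrometer MHz (0 ppm at the reference frequency, no axis reversal).

```python
import numpy as np

def peaks_in_ppm_window(peak_offsets_hz, f0_mhz, ppm_min, ppm_max):
    """For each peak, report whether its ppm position lies within [ppm_min, ppm_max]."""
    ppm = np.asarray(peak_offsets_hz, dtype=float) / f0_mhz  # Hz / MHz = ppm
    return (ppm >= ppm_min) & (ppm <= ppm_max)

offsets_hz = [-300.0, -100.0, 50.0]  # NAA, Cr, Cho basis peaks from section 1.2
print(peaks_in_ppm_window(offsets_hz, 123.25, -3.0, 3.0))
print(peaks_in_ppm_window(offsets_hz, 123.25, 0.2, 4.2))
```

Under these assumptions the symmetric window keeps all three peaks, while a 0.2 to 4.2 ppm window would keep only the +50 Hz peak.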

## Phase 2: Fitting with `AdvancedLinearCombinationModel`

Now we instantiate and use the `AdvancedLinearCombinationModel` to fit the simulated spectrum. We provide initial guesses for the parameters and define constraints for the fitting process.

In [None]:
initial_params_guess_fit = {
    METABOLITE_NAMES_SIM[0]: {'amp': 10.0, 'shift_hz': 0.0, 'lw_hz': 0.8},
    METABOLITE_NAMES_SIM[1]: {'amp': 7.0, 'shift_hz': 0.0, 'lw_hz': 1.2},
    METABOLITE_NAMES_SIM[2]: {'amp': 4.0, 'shift_hz': 0.0, 'lw_hz': 0.5},
    'baseline': {'coeff0': 0.0, 'coeff1': 0.0, 'coeff2': 0.0} 
}

constraints_fit = {
    'max_shift_hz': 3.0,      # Max allowable absolute shift in Hz 
    'min_lw_hz': 0.2,         # Min allowable additional linewidth in Hz
    'max_lw_hz': 6.0          # Max allowable additional linewidth in Hz
}

BASELINE_DEGREE_FIT = len(TRUE_BASELINE_COEFFS_SIM) - 1

advanced_model = AdvancedLinearCombinationModel(
    basis_spectra_tensor=basis_spectra_tensor_true, 
    metabolite_names=METABOLITE_NAMES_SIM,
    observed_spectrum_tensor=observed_spectrum_tensor, 
    dt=DT_S,
    fitting_mask=fitting_mask_tensor,
    initial_params=initial_params_guess_fit,
    constraints=constraints_fit,
    baseline_degree=BASELINE_DEGREE_FIT,
    device=device
)

print("AdvancedLinearCombinationModel instantiated.")
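The `constraints` dictionary bounds the frequency shifts and the additional linewidths. This notebook does not show how `AdvancedLinearCombinationModel` enforces those bounds (its `get_transformed_parameters()` method and `*_raw` attributes suggest some transform). A common approach in gradient-based fitting, sketched here purely as an assumption, is to optimize unconstrained "raw" values and map them into the allowed range with `tanh` or a sigmoid; all function names below are hypothetical.

```python
import numpy as np

# Hypothetical bounded reparameterization; not necessarily what the library does.
def shift_from_raw(raw, max_shift_hz):
    return max_shift_hz * np.tanh(raw)  # bounded by +/- max_shift_hz

def raw_from_shift(shift_hz, max_shift_hz):
    return np.arctanh(shift_hz / max_shift_hz)  # inverse, used to encode initial guesses

def lw_from_raw(raw, min_lw_hz, max_lw_hz):
    sigmoid = 1.0 / (1.0 + np.exp(-raw))
    return min_lw_hz + (max_lw_hz - min_lw_hz) * sigmoid  # bounded by [min, max]

def raw_from_lw(lw_hz, min_lw_hz, max_lw_hz):
    p = (lw_hz - min_lw_hz) / (max_lw_hz - min_lw_hz)
    return np.log(p / (1.0 - p))  # logit, inverse of the sigmoid mapping

constraints = {'max_shift_hz': 3.0, 'min_lw_hz': 0.2, 'max_lw_hz': 6.0}
raw_lw = raw_from_lw(0.8, constraints['min_lw_hz'], constraints['max_lw_hz'])
print(lw_from_raw(raw_lw, constraints['min_lw_hz'], constraints['max_lw_hz']))  # very close to 0.8
print(shift_from_raw(np.array([-50.0, 0.0, 50.0]), constraints['max_shift_hz']))
```

The appeal of this design is that a standard optimizer such as Adam can move the raw values freely while the physical parameters can never leave their bounds; initial guesses (like `'lw_hz': 0.8`) must lie strictly inside the bounds so the inverse transform is finite.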

### 2.1. Run the Fitting Process

In [None]:
advanced_model.fit(num_iterations=3500, lr=0.03, optim_type='adam', print_loss_every=500, weight_decay=1e-4)
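`fit()` optimizes all parameters jointly with Adam. For intuition about the "linear combination" part of the name: once shifts and linewidths are fixed, the metabolite amplitudes and baseline coefficients enter the model linearly, so they can be solved in closed form by ordinary least squares on the masked points. The sketch below is a separate illustration of that idea (the linear step of a variable-projection approach), not a description of what `AdvancedLinearCombinationModel.fit` does internally; the helpers are hypothetical and rebuild the simulation's basis in NumPy with the true shifts and total linewidths (2 Hz intrinsic plus the additional broadening).

```python
import numpy as np

def modified_basis(offsets_hz, lw_hz, shifts_hz, n, dt):
    """Real part of each shifted, broadened unit-amplitude basis spectrum (one column each)."""
    t = np.arange(n) * dt
    cols = []
    for f, lw, s in zip(offsets_hz, lw_hz, shifts_hz):
        fid = np.exp(1j * 2 * np.pi * (f + s) * t) * np.exp(-np.pi * lw * t)
        cols.append(np.fft.fftshift(np.fft.fft(fid)).real)
    return np.stack(cols, axis=1)

def solve_linear_params(observed, basis_real, mask, baseline_degree):
    """With nonlinear parameters fixed, solve amplitudes and baseline by least squares."""
    n = observed.shape[0]
    x = np.linspace(-1.0, 1.0, n)  # same normalized axis as the simulated baseline
    poly = np.stack([x ** d for d in range(baseline_degree + 1)], axis=1)
    design = np.concatenate([basis_real, poly], axis=1)[mask]
    coeffs, *_ = np.linalg.lstsq(design, observed[mask], rcond=None)
    k = basis_real.shape[1]
    return coeffs[:k], coeffs[k:]

n, dt = 1024, 1.0 / 2000.0
true_amps = np.array([12.0, 9.0, 5.0])
basis = modified_basis([-300.0, -100.0, 50.0], lw_hz=[3.0, 3.5, 2.7],
                       shifts_hz=[-1.0, 0.5, 1.0], n=n, dt=dt)
x = np.linspace(-1.0, 1.0, n)
observed = basis @ true_amps + (10.0 - 4.0 * x + 2.0 * x ** 2)  # noise-free for clarity
mask = np.abs(np.fft.fftshift(np.fft.fftfreq(n, d=dt)) / 123.25) <= 3.0
amps, bl = solve_linear_params(observed, basis, mask, baseline_degree=2)
print(np.round(amps, 3), np.round(bl, 3))
```

Without noise the least-squares step recovers the simulated amplitudes and baseline coefficients; with noise it returns their least-squares estimates given the fixed shifts and linewidths.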

### 2.2. Compare True vs. Fitted Parameters

In [None]:
fitted_metabolite_amplitudes = advanced_model.get_fitted_amplitudes()
fitted_shifts = advanced_model.get_fitted_shifts_hz()
fitted_lws = advanced_model.get_fitted_linewidths_hz()
fitted_bl_coeffs = advanced_model.get_fitted_baseline_coeffs()

print("--- Parameter Comparison ---")
print("\nMetabolite Amplitudes:")
for i, name in enumerate(METABOLITE_NAMES_SIM):
    print(f"  {name}: True = {TRUE_AMPLITUDES_SIM[i]:.2f}, Fitted = {fitted_metabolite_amplitudes.get(name, 0.0):.2f}")

print("\nFrequency Shifts (Hz):")
for i, name in enumerate(METABOLITE_NAMES_SIM):
    print(f"  {name}: True = {TRUE_SHIFTS_HZ_SIM[i]:.2f}, Fitted = {fitted_shifts.get(name, 0.0):.2f}")

print("\nAdditional Linewidths (Hz):")
for i, name in enumerate(METABOLITE_NAMES_SIM):
    print(f"  {name}: True Additional = {TRUE_LW_HZ_ADDITIONAL_SIM[i]:.2f}, Fitted Additional = {fitted_lws.get(name, 0.0):.2f}")

if fitted_bl_coeffs is not None:
    # These coefficients are only directly comparable to the simulated ones if the model
    # builds its polynomial on the same normalized [-1, 1] axis spanning all points.
    print("\nBaseline Coefficients:")
    for i in range(len(fitted_bl_coeffs)):
        if i < len(TRUE_BASELINE_COEFFS_SIM):
            true_str = f"{TRUE_BASELINE_COEFFS_SIM[i]:.2f}"
        else:
            true_str = "N/A (degree mismatch)"
        print(f"  Coeff {i}: True = {true_str}, Fitted = {fitted_bl_coeffs[i]:.2f}")
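A compact way to summarize the amplitude comparison is a signed percent error per metabolite. The hypothetical helper below assumes the fitted values come as a name-keyed dict, as the `.get(name, 0.0)` calls above imply; `float()` also converts 0-d tensors. The numbers in the demo are illustrative inputs, not results from this notebook.

```python
import numpy as np

def percent_errors(true_values, fitted_values, names):
    """Signed percent error, 100 * (fitted - true) / true, for each named parameter.
    Missing names yield NaN instead of silently defaulting to 0."""
    true_arr = np.asarray(true_values, dtype=float)
    fit_arr = np.asarray([float(fitted_values.get(n, np.nan)) for n in names], dtype=float)
    return dict(zip(names, 100.0 * (fit_arr - true_arr) / true_arr))

# Illustrative inputs only
errs = percent_errors([12.0, 9.0, 5.0], {'NAA': 11.4, 'Cr': 9.3}, ['NAA', 'Cr', 'Cho'])
for name, err in errs.items():
    print(f"  {name}: {err:+.1f} %")
```

In the notebook this would be called as `percent_errors(TRUE_AMPLITUDES_SIM, fitted_metabolite_amplitudes, METABOLITE_NAMES_SIM)`. Percent error is a poor fit for the frequency shifts, whose true values are close to 0 Hz; absolute error in Hz is more informative there.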

### 2.3. Visualization of Fit

In [None]:
with torch.no_grad():
    full_model_spectrum_tensor_fit = advanced_model.get_full_model_spectrum(real_part=True)
    
    fitted_baseline_signal_fit = torch.zeros(advanced_model.num_points, device=advanced_model.device)
    if advanced_model.baseline_coeffs_raw is not None and hasattr(advanced_model, 'baseline_poly_terms'):
        fitted_baseline_signal_fit = advanced_model.baseline_poly_terms @ advanced_model.baseline_coeffs_raw.detach()
        
    fitted_params_transformed = advanced_model.get_transformed_parameters()
    amps_transformed = fitted_params_transformed['amplitudes']
    shifts_transformed = fitted_params_transformed['shifts_hz']
    lws_transformed = fitted_params_transformed['linewidths_hz']
    
    basis_time_transformed = torch.fft.ifft(torch.fft.ifftshift(advanced_model.basis_spectra_freq_shifted, dim=0), dim=0)
    time_axis_model = advanced_model.time_axis  # 1-D time axis in seconds
    individual_metab_components_freq_list_fit = []
    for i in range(advanced_model.num_metabolites):
        metab_time = basis_time_transformed[:, i]
        decay_transformed = torch.exp(-time_axis_model * np.pi * lws_transformed[i])
        phase_ramp_transformed = torch.exp(1j * 2 * np.pi * shifts_transformed[i] * time_axis_model)
        modified_metab_time = metab_time * decay_transformed * phase_ramp_transformed
        modified_metab_freq_shifted = torch.fft.fftshift(torch.fft.fft(modified_metab_time), dim=0)
        individual_metab_components_freq_list_fit.append(amps_transformed[i] * modified_metab_freq_shifted.real) 

plt.figure(figsize=(12, 9))

plt.subplot(3,1,1)
plt.plot(ppm_axis_plot_sim, observed_spectrum_tensor.cpu().numpy(), label="Observed Spectrum", alpha=0.7)
plt.plot(ppm_axis_plot_sim, full_model_spectrum_tensor_fit.cpu().numpy(), label="Full Model Fit", color='red')
plt.plot(ppm_axis_plot_sim, fitted_baseline_signal_fit.cpu().numpy(), label="Fitted Baseline", color='green', linestyle=':')
plt.title("Observed Spectrum vs. Full Model Fit and Baseline")
plt.xlabel(f"Chemical Shift (ppm, relative to {F0_MHZ} MHz as 0 ppm)")
plt.ylabel("Intensity")
plt.xlim(max(ppm_axis_plot_sim), min(ppm_axis_plot_sim))
plt.legend()
plt.grid(True, alpha=0.4)

residuals_fit = observed_spectrum_tensor.cpu().numpy() - full_model_spectrum_tensor_fit.cpu().numpy()
plt.subplot(3,1,2)
plt.plot(ppm_axis_plot_sim, residuals_fit, label="Residual (Observed - Model)", color='blue')
plt.title("Residuals")
plt.xlabel(f"Chemical Shift (ppm, relative to {F0_MHZ} MHz as 0 ppm)")
plt.ylabel("Intensity")
plt.xlim(max(ppm_axis_plot_sim), min(ppm_axis_plot_sim))
plt.legend()
plt.grid(True, alpha=0.4)

plt.subplot(3,1,3)
data_minus_baseline_fit = observed_spectrum_tensor.cpu().numpy() - fitted_baseline_signal_fit.cpu().numpy()
plt.plot(ppm_axis_plot_sim, data_minus_baseline_fit, label="Data - Est. Baseline", color='lightgray', alpha=0.9)
sum_fitted_metabs_fit = torch.zeros_like(individual_metab_components_freq_list_fit[0])
for i, name in enumerate(METABOLITE_NAMES_SIM):
    if i < len(individual_metab_components_freq_list_fit):
        component_to_plot = individual_metab_components_freq_list_fit[i].cpu().numpy()
        sum_fitted_metabs_fit += individual_metab_components_freq_list_fit[i]
        plt.plot(ppm_axis_plot_sim, component_to_plot, label=f"Fitted {name}", linestyle='--')
plt.plot(ppm_axis_plot_sim, sum_fitted_metabs_fit.cpu().numpy(), label="Sum of Fitted Metabolites", color='purple', linestyle='-')
plt.title("Fitted Metabolite Components vs. Data (Baseline Subtracted)")
plt.xlabel(f"Chemical Shift (ppm, relative to {F0_MHZ} MHz as 0 ppm)")
plt.ylabel("Intensity")
plt.xlim(max(ppm_axis_plot_sim), min(ppm_axis_plot_sim))
plt.legend(fontsize='small')
plt.grid(True, alpha=0.4)

plt.tight_layout()
plt.show()
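Because the noise level of the simulation is known (`noise_std_sim = 1.0`), the residual panel can be summarized by one number: the residual standard deviation inside the fitting mask divided by the noise standard deviation. A ratio near 1 suggests the model explains everything except noise, while a ratio well above 1 points to unmodeled signal. The helper below is a hypothetical sketch of that check.

```python
import numpy as np

def residual_to_noise_ratio(observed, model, mask, noise_std):
    """Std of the residual inside the fitting mask divided by the known noise std."""
    residual = np.asarray(observed)[mask] - np.asarray(model)[mask]
    return float(np.std(residual) / noise_std)

# Self-check on synthetic data: a perfect model leaves only noise
rng = np.random.default_rng(0)
n = 10000
signal = 50.0 * np.sin(np.linspace(0.0, 20.0, n))
observed = signal + rng.normal(0.0, 1.0, n)
mask = np.ones(n, dtype=bool)
print(residual_to_noise_ratio(observed, signal, mask, 1.0))      # perfect model
print(residual_to_noise_ratio(observed, np.zeros(n), mask, 1.0)) # model missing the signal
```

In the notebook it could be applied as `residual_to_noise_ratio(observed_spectrum_tensor.cpu().numpy(), full_model_spectrum_tensor_fit.cpu().numpy(), fitting_mask_tensor.cpu().numpy(), noise_std_sim)`.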

## Phase 3: Absolute Quantification with `AbsoluteQuantifier`

With the metabolite amplitudes estimated from the fitting phase, we now proceed to calculate their absolute concentrations. This requires additional information such as water signal amplitude, proton counts for each metabolite, sequence parameters (TE, TR), relaxation times, and optionally, tissue fractions for correction.

### 3.1. Gather Inputs for `AbsoluteQuantifier`

In [None]:
# Use fitted metabolite amplitudes from Phase 2
metabolite_amplitudes_for_quant = fitted_metabolite_amplitudes
print(f"Metabolite amplitudes for quantification: {metabolite_amplitudes_for_quant}")

# Simulate or define a placeholder water amplitude
# In a real scenario, this would be measured from an unsuppressed water reference scan
simulated_water_amplitude = 8000.0 
print(f"Simulated water amplitude: {simulated_water_amplitude}")

# Define proton counts for each metabolite (fixed by molecular structure: protons contributing to the fitted resonance)
proton_counts_metabolites = {
    'NAA': 3,  # N-acetyl aspartate (CH3)
    'Cr': 3,   # Creatine (CH3)
    'Cho': 9   # Choline (N(CH3)3)
}
print(f"Proton counts: {proton_counts_metabolites}")

# Define sequence parameters (example values)
te_ms_quant = 30.0  # Echo Time in ms
tr_ms_quant = 2000.0 # Repetition Time in ms
print(f"TE: {te_ms_quant} ms, TR: {tr_ms_quant} ms")

# Define relaxation times (T1 and T2 in ms) for water and metabolites
# Illustrative values of the order reported in the literature; actual T1/T2 vary with field strength, tissue type, etc.
relaxation_times_quant = {
    'water': {'T1_ms': 1200.0, 'T2_ms': 80.0}, # Example for brain tissue water at 3T
    'NAA': {'T1_ms': 1400.0, 'T2_ms': 200.0},
    'Cr': {'T1_ms': 1000.0, 'T2_ms': 150.0},
    'Cho': {'T1_ms': 1100.0, 'T2_ms': 180.0}
    # Add other metabolites if they are in your basis set and you have their T1/T2
}
print(f"Relaxation times (ms): {relaxation_times_quant}")

# Define tissue fractions for voxel composition (optional, but recommended for accuracy)
tissue_fractions_quant = {'gm': 0.7, 'wm': 0.2, 'csf': 0.1} # Grey Matter, White Matter, CSF
print(f"Tissue fractions: {tissue_fractions_quant}")

# Define tissue-specific water content (fractions, 0-1)
water_conc_tissue_specific_fractions_quant = {'gm': 0.82, 'wm': 0.70, 'csf': 0.99}
print(f"Water content by tissue type: {water_conc_tissue_specific_fractions_quant}")

# Note: unless overridden at instantiation, AbsoluteQuantifier uses default_water_conc_tissue_mM
# (default: 35880.0 mM) and default_protons_water (default: 2). 35880 mM is the white-matter
# water concentration LCModel uses by default (pure water is about 55,510 mM).
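The relaxation times enter through attenuation factors describing how much of the fully relaxed signal survives the chosen TE and TR. A widely used approximation for a spin-echo-type acquisition is `exp(-TE/T2) * (1 - exp(-TR/T1))`; it ignores sequence-specific effects such as J-coupling evolution. Whether `AbsoluteQuantifier` uses exactly this expression is an assumption, and the helper below is hypothetical; it simply evaluates the factor for the inputs defined above.

```python
import math

def relaxation_attenuation(te_ms, tr_ms, t1_ms, t2_ms):
    """Fraction of fully relaxed signal remaining: T2 decay over TE times T1 recovery over TR."""
    return math.exp(-te_ms / t2_ms) * (1.0 - math.exp(-tr_ms / t1_ms))

relaxation_times = {
    'water': {'T1_ms': 1200.0, 'T2_ms': 80.0},
    'NAA': {'T1_ms': 1400.0, 'T2_ms': 200.0},
    'Cr': {'T1_ms': 1000.0, 'T2_ms': 150.0},
    'Cho': {'T1_ms': 1100.0, 'T2_ms': 180.0},
}
for name, rt in relaxation_times.items():
    r = relaxation_attenuation(30.0, 2000.0, rt['T1_ms'], rt['T2_ms'])
    print(f"{name}: {r:.3f}")
```

Water, with its short T2, is attenuated more strongly than the metabolites at TE = 30 ms, which is why correcting both signals matters for the final ratio.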

### 3.2. Instantiate `AbsoluteQuantifier` and Calculate Concentrations

In [None]:
quantifier = AbsoluteQuantifier() # Using default water concentration and proton count

absolute_concentrations, warnings_list = quantifier.calculate_concentrations(
    metabolite_amplitudes=metabolite_amplitudes_for_quant,
    water_amplitude=simulated_water_amplitude,
    proton_counts_metabolites=proton_counts_metabolites,
    te_ms=te_ms_quant,
    tr_ms=tr_ms_quant,
    relaxation_times=relaxation_times_quant,
    tissue_fractions=tissue_fractions_quant,
    water_conc_tissue_specific_fractions=water_conc_tissue_specific_fractions_quant
)
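For readers who want to see how these inputs can combine, the sketch below implements one commonly used water-referencing formulation: the water signal is a tissue-weighted sum over GM, WM and CSF (volume fraction times water content), while the metabolite signal is attributed to the non-CSF part of the voxel. It uses the pure-water concentration (about 55,510 mM) together with tissue water-content fractions; `AbsoluteQuantifier`'s default of 35880 mM suggests the library may follow a different convention, so treat this as an illustrative assumption rather than the library's algorithm. All names are hypothetical.

```python
import math

PURE_WATER_MM = 55510.0  # molar concentration of pure water, about 55.51 M

def relaxation_attenuation(te_ms, tr_ms, t1_ms, t2_ms):
    return math.exp(-te_ms / t2_ms) * (1.0 - math.exp(-tr_ms / t1_ms))

def water_referenced_concentration(metab_amp, water_amp, n_protons_metab, te_ms, tr_ms,
                                   metab_relax, water_relax, tissue_fractions,
                                   water_content, n_protons_water=2):
    """[M] = (S_M / S_W) * (n_W / n_M) * [W] * sum_i(f_i * alpha_i) * R_W / ((1 - f_CSF) * R_M)"""
    r_metab = relaxation_attenuation(te_ms, tr_ms, metab_relax['T1_ms'], metab_relax['T2_ms'])
    r_water = relaxation_attenuation(te_ms, tr_ms, water_relax['T1_ms'], water_relax['T2_ms'])
    # Voxel water weighted by tissue volume fraction and water content, with one water T1/T2
    water_term = sum(tissue_fractions[t] * water_content[t] for t in tissue_fractions) * r_water
    non_csf = 1.0 - tissue_fractions.get('csf', 0.0)  # metabolites assumed absent from CSF
    return ((metab_amp / water_amp) * (n_protons_water / n_protons_metab) * PURE_WATER_MM
            * water_term / (non_csf * r_metab))

naa = water_referenced_concentration(
    metab_amp=12.0,  # hypothetical fitted NAA amplitude
    water_amp=8000.0, n_protons_metab=3, te_ms=30.0, tr_ms=2000.0,
    metab_relax={'T1_ms': 1400.0, 'T2_ms': 200.0},
    water_relax={'T1_ms': 1200.0, 'T2_ms': 80.0},
    tissue_fractions={'gm': 0.7, 'wm': 0.2, 'csf': 0.1},
    water_content={'gm': 0.82, 'wm': 0.70, 'csf': 0.99},
)
print(f"NAA (sketch formula): {naa:.1f} mM")
```

The absolute number is only physically meaningful when the metabolite basis and the water reference share a common signal scale; with the placeholder water amplitude used here it is illustrative.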

### 3.3. Display Quantification Results

In [None]:
print("--- Absolute Quantification Results ---")
if absolute_concentrations:
    for metab, conc in absolute_concentrations.items():
        print(f"  {metab}: {conc:.2f} mM")
else:
    print("  No concentrations calculated or an error occurred.")

print("\n--- Warnings from Quantification ---")
if warnings_list:
    for warning in warnings_list:
        print(f"  - {warning}")
else:
    print("  No warnings generated.")

print("\nDiscussion of Results:")
print("The values above are estimated metabolite concentrations in millimolar (mM) units.")
print("They are corrected for T1 and T2 relaxation, the number of protons contributing to each signal,")
print("the water reference signal, and partial-volume effects based on tissue fractions and their water content.")
print("Their accuracy depends on the quality of the MRS data and the fit, the accuracy of the water reference,")
print("and the correctness of input parameters such as relaxation times and tissue composition.")
print("Because the water amplitude in this demo is a placeholder, the absolute numbers here are illustrative only.")
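To put a number on the dependence on relaxation times: the T2 correction divides the measured signal by `exp(-TE/T2)`, so if the assumed T2 differs from the true one, the reported concentration is scaled by `exp(-TE/T2_true) / exp(-TE/T2_assumed)`, with all other factors cancelling. The hypothetical helper below evaluates this ratio; for example, assuming T2 = 200 ms when the true value is 160 ms biases the result by a few percent at TE = 30 ms, and by considerably more at longer echo times.

```python
import math

def concentration_bias_from_t2(te_ms, t2_assumed_ms, t2_true_ms):
    """Ratio of reported to true concentration when the T2 used for correction is wrong."""
    return math.exp(-te_ms / t2_true_ms) / math.exp(-te_ms / t2_assumed_ms)

for te in (30.0, 144.0):
    ratio = concentration_bias_from_t2(te, t2_assumed_ms=200.0, t2_true_ms=160.0)
    print(f"TE = {te:.0f} ms: reported / true = {ratio:.3f}")
```

This is one reason short-TE acquisitions are less sensitive to uncertain T2 values.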

## Conclusion

This notebook has demonstrated a comprehensive pipeline for MRS data analysis:
1. **Data Simulation:** We created synthetic MRS data with known ground-truth parameters.
2. **Spectral Fitting:** We used `AdvancedLinearCombinationModel` to fit the data, estimating metabolite amplitudes, frequency shifts, and linewidths, along with a baseline.
3. **Absolute Quantification:** We then took the fitted amplitudes and, using `AbsoluteQuantifier`, converted them into absolute concentrations (mM). This step involved providing crucial parameters such as water reference amplitude, proton counts, sequence timings (TE, TR), relaxation times, and tissue composition details.

By combining these two components, `AdvancedLinearCombinationModel` for spectral fitting and `AbsoluteQuantifier` for concentration calculation, users can run an end-to-end analysis of MRS data. The workflow turns fitted signal amplitudes into concentrations while accounting for relaxation, proton counts, the water reference, and tissue composition.