In [None]:
import numpy as np
import matplotlib.pyplot as plt
import jax

from chromatix.ops import init_plane_resample
import chromatix.functional as cx

import phantom   # our new file -- phantom.py
from simulate import apply_psf
from phantom import Material
from inputs.xscatter import get_wavelen, get_wavenum
from matdecomp import spbi_material_basis

# Materials available to the phantom, built as Material(name, composition, value).
# Each composition string lists elements with a number in parentheses; these
# numbers sum to ~100, so they are presumably weight percent
# -- NOTE(review): confirm against phantom.Material.
# The last argument looks like mass density in g/cm^3 (0.93 adipose ... 1.92 bone)
# -- TODO confirm in phantom.py.
# Later cells use Material.db(energies) (-> delta, beta) and Material.name.
adip = Material('adipose', 'H(11.2)C(61.9)N(1.7)O(25.1)', 0.93)
gland = Material('gland', 'H(10.2)C(18.4)N(3.2)O(67.6)', 1.04)
tissue = Material('tissue', 'H(10.2)C(14.3)N(3.4)O(70.8)Na(0.2)P(0.3)S(0.3)Cl(0.2)K(0.3)', 1.06)   
bone = Material('bone', 'H(3.4)C(15.5)N(4.2)O(43.5)Na(0.1)Mg(0.2)P(10.3)S(0.3)Ca(22.5)', 1.92)


In [None]:
# Global figure styling for every plot in this notebook: high-DPI output,
# compact font sizes, thin axis frames and a grayscale default colormap.
plt.rcParams.update({
    # output resolution and default colormap for imshow
    'figure.dpi': 300,
    'image.cmap': 'gray',
    # text sizes
    'font.size': 10,
    'axes.titlesize': 10,
    'axes.labelsize': 8,
    'xtick.labelsize': 8,
    'ytick.labelsize': 8,
    'legend.fontsize': 8,
    # axis frame
    'axes.linewidth': .5,
})

In [None]:
%%time 

############################################################
### PARAMETERS

energy = 20  # X-ray energy [keV] at which the delta, beta values are evaluated

# Phantom grid: dimensions and resolution
N = 256      # voxels along x and y
Nz = 50      # voxels along z (kept small to save memory!)
dx = 2e-6    # x-y voxel size [m]
dz = 10e-6   # z voxel size [m] (make it larger to reach a desired thickness)

thickness = Nz*dz  # total axial (z) thickness [m] - for mammo this should be a few cm

# Label in the binary volume -> Material object
mat_dict = {0:gland, 1:adip}
mat_frac = 0.5  # fraction of one material in the binary mixture (threshold level)

# Ellipsoidal structural envelope for a realistic projection outline.
### GLJ note - this is causing artifacts in the simulations when R > 0. Need to fix
### For now use struct = 1 or flat ellipsoid instead
ellipsoid = phantom.tmap_ellipsoid(N, 0.35*N, 0.42*N, angle=20)
struct2 = np.where(np.asarray(ellipsoid) > 1e-3, 1.0, 0.0)  # "flattens" the ellipsoid to 0/1
struct = struct2
# struct = 1   # alternative: no envelope at all


############################################################
### MAKE PHANTOM

# Seed the global NumPy RNG so the generated phantom texture is reproducible.
np.random.seed(500)

# The generated volume is a cube; keep only the first Nz slices along z.
vol = phantom.make_phantom(N, dx, alpha=4)[:Nz]  

# Binarize into two material labels according to the material fraction.
vol_mask = phantom.thresh_texture(vol, mat_frac)

# Map labels to delta/beta volumes at `energy`.
# Output shape: (Ne, Nz, N, N); Ne = 1 here because `energy` is a scalar.
vol_delta, vol_beta = phantom.make_db_vol(vol_mask, mat_dict, energy)

# Parallel-beam projections along z of the single energy channel [0].
# These are plain voxel sums (no dz factor) -- multiply by the voxel
# thickness before treating them as line integrals.
proj_delta_flat, proj_beta_flat = (np.sum(v[0], axis=0) for v in (vol_delta, vol_beta))

# Apply the structural envelope for a more realistic outline.
proj_delta, proj_beta = struct * proj_delta_flat, struct * proj_beta_flat


############################################################
### DISPLAY RESULTS

# Side-by-side projected delta and beta maps, each with its own colorbar.
fig, ax = plt.subplots(1, 2, figsize=[8,3], layout='constrained')
fig.suptitle(f'Energy = {energy} keV', fontweight='bold')

for panel_ax, panel_title, panel_img in zip(ax, (r'$\delta$', r'$\beta$'), (proj_delta, proj_beta)):
    panel_ax.set_title(panel_title)
    m = panel_ax.imshow(panel_img)
    fig.colorbar(m, ax=panel_ax)

plt.show()

# Simulation functions

In [None]:
%%time

def simulate_projection(proj_beta, proj_delta, dx, det_N, det_dx, energy, R, 
                        I0=None, det_psf=None, det_fwhm=5e-6, n_medium=1, N_pad=100, key=jax.random.PRNGKey(3)):
    """
    Simulate a single-energy propagation-based X-ray phase-contrast (XPCI) projection.

    A plane wave normalized to unit peak intensity passes through a thin sample
    described by the projected beta/delta maps. It is propagated a distance
    ``R`` with the transfer-function method and optionally blurred by a detector
    PSF. It is then resampled onto the detector grid, normalized by its first
    pixel, and optionally corrupted with Poisson noise.

    Parameters
    ----------
    proj_beta : ndarray
        2D line integral of the imaginary part of the refractive index
        (∫ beta dz) at the X-ray energy ``energy``. A plain voxel sum must be
        multiplied by the voxel thickness ``dz`` first.
    proj_delta : ndarray
        2D line integral of the real part of the refractive index decrement
        (∫ delta dz) at the same energy. Must have the same shape as ``proj_beta``.
    dx : float
        Pixel size of the input projections (phantom resolution) in meters.
    det_N : int
        Number of detector pixels along one dimension (square detector).
    det_dx : float
        Detector pixel size in meters.
    energy : float
        X-ray energy in keV. **Must match the energy used to generate
        ``proj_beta`` and ``proj_delta``.**
    R : float
        Propagation distance (object-to-detector) in meters.
    I0 : float, optional
        Mean incident photon count per detector pixel. If given, Poisson noise
        is applied and ``I0`` must be > 0. Default None (noise-free image).
    det_psf : optional
        PSF specification forwarded to ``apply_psf(..., psf=det_psf)``. This
        notebook passes the string ``'gaussian'``. If None, no blur is applied.
    det_fwhm : float, optional
        Full-width at half maximum (FWHM) of the detector PSF in meters.
        Default is 5e-6.
    n_medium : float, optional
        Refractive index of the propagation medium (e.g., air ~ 1.0). Default is 1.
    N_pad : int, optional
        Padding passed to the transfer-function propagation step. Default is 100.
    key : jax.random.PRNGKey, optional
        PRNG key for the Poisson noise (only used when ``I0`` is given). The
        default key is created once, when the function is defined, so every
        call relying on it reuses the same random stream. Pass a distinct key
        per image when independent noise realizations are needed.

    Returns
    -------
    img : array
        Simulated detector intensity image on the ``(det_N, det_N)`` detector
        grid. After PSF blur and resampling, and before noise, it is divided by
        its first (top-left corner) pixel. This acts as a flat-field
        normalization provided that corner pixel lies outside the object.

    Raises
    ------
    ValueError
        If the projections differ in shape, if the detector FOV
        (``det_N * det_dx``) exceeds the phantom FOV
        (``proj_beta.shape[0] * dx``), or if ``I0`` is given but not positive.

    Notes
    -----
    - The PSF blur is applied on the phantom-resolution grid (pixel size
      ``dx``), before resampling to the detector grid.
    - This function assumes square images and detectors.
    """
    # Explicit checks (not `assert`) so validation survives `python -O`.
    if proj_beta.shape != proj_delta.shape:
        raise ValueError(
            f'proj_beta and proj_delta must have the same shape, '
            f'got {proj_beta.shape} and {proj_delta.shape}'
        )
    if det_N * det_dx > proj_beta.shape[0] * dx:
        raise ValueError('Detector FOV must be <= phantom FOV')
    if I0 is not None and I0 <= 0:
        # I0 <= 0 would make the noise step divide by zero / produce NaNs.
        raise ValueError('I0 must be positive; use I0=None for a noise-free image')

    det_shape = (det_N, det_N)

    # Incident plane wave on the phantom grid, scaled to unit peak intensity.
    field = cx.plane_wave(
        shape=proj_beta.shape,
        dx=dx,
        spectrum=get_wavelen(energy),
        spectral_density=1.0,
    )
    field = field / field.intensity.max()**0.5
    # Background intensity handed to the propagator as `cval` (presumably the
    # fill value for the padded region -- see chromatix transfer_propagate).
    cval = field.intensity.max()

    # Thin-sample approximation. The last argument (1.0) is presumably the sample
    # thickness factor: the inputs are already integrated along z.
    exit_field = cx.thin_sample(field, proj_beta[None, ..., None, None], proj_delta[None, ..., None, None], 1.0)
    det_field = cx.transfer_propagate(exit_field, R, n_medium, N_pad, cval=cval, mode='same')

    det_img = det_field.intensity.squeeze()
    if det_psf is not None:
        det_img = apply_psf(det_img, dx, psf=det_psf, fwhm=det_fwhm, kernel_width=0.1)

    # Resample from the phantom grid (pixel size dx) onto the detector grid.
    det_resample_func = init_plane_resample(det_shape, (det_dx, det_dx), resampling_method='linear')
    img = det_resample_func(det_img[..., None, None], field.dx.ravel()[:1])[..., 0, 0]

    # Normalize by the top-left pixel, assumed to see the unattenuated beam.
    img = img / img.ravel()[0]

    if I0 is not None:
        # Poisson counting statistics with mean I0*img, rescaled back to intensity.
        img = jax.random.poisson(key, I0 * img, img.shape) / I0

    return img

# Single-energy demo run. `energy` must equal the energy at which
# proj_beta / proj_delta were built in the phantom cell above.
energy = 20        # [keV]
R = 10e-2          # object-to-detector distance [m]
dx = 2e-6          # phantom pixel size [m]
det_dx = 5e-6      # detector pixel size [m]
det_N = int(dx*N // det_dx)   # whole detector pixels fitting inside the phantom FOV
print(f'phantom FOV = {dx*N*1e6} um,  det FOV = {det_dx*det_N*1e6} um')

# The projections are plain voxel sums along z. Scale them so the sample is
# 1 cm thick in total (effective voxel thickness = 1 cm / Nz). Note that this
# overrides the `dz` / `thickness` values from the parameters cell.
sim_thickness = 1e-2     # [m]
c = sim_thickness / Nz   # thickness scaling!
img = simulate_projection(
    proj_beta*c, proj_delta*c,
    dx=dx, det_N=det_N, det_dx=det_dx, energy=energy, R=R,
    det_psf='gaussian', det_fwhm=5e-6,
)

plt.imshow(img)
plt.colorbar()
plt.show()

# Multiple-energy simulation (for material decomposition)

## First — simulate images with different $E$

In [None]:
%%time 

############################################################
### Inputs

# Imaging system
energies = np.array([20, 25])  # [keV]
R = 1e-2                       # propagation distance [m]

# Materials for labels 0 / 1 of the binary phantom mask
# mat_dict = {0:bone, 1:tissue}
mat_dict = {0: gland, 1: adip}
mat_frac = 0.5

# x-y plane (transverse)
N = 256
dx = 2e-6

# z axis (propagation direction). Only two of Nz, thickness, dz are free;
# dz is derived from the other two.
Nz = 100
thickness = 1e-2         # breast thickness [m]
dz = thickness / Nz

# Detector -- its FOV (det_N * det_dx) must be <= the phantom FOV (N * dx)
det_N = 100
det_dx = 5e-6

# Phantom structure: 1 inside the thresholded ellipsoid map, 0 outside
ellipse_mask = phantom.tmap_ellipsoid(N, 0.35*N, 0.42*N, angle=20) > 1e-3
struct = np.zeros([N, N])
struct[ellipse_mask] = 1

############################################################
### Simulate -- loop over the two energies

# make binary phantom material mask (seeded for reproducibility)
np.random.seed(33)
vol = phantom.make_phantom(N, dx, alpha=4)[:Nz]  
vol_mask = phantom.thresh_texture(vol, mat_frac)

# Give each energy its own PRNG key for the Poisson noise. Relying on
# simulate_projection's default key would reuse the same random stream for
# every energy, so the noise in the two images would not be independent.
# That is unrealistic for two separate exposures and skews the noise study of
# the material decomposition below.
noise_keys = jax.random.split(jax.random.PRNGKey(33), len(energies))

imgs = []
for energy, noise_key in zip(energies, noise_keys):
    vol_delta, vol_beta = phantom.make_db_vol(vol_mask, mat_dict, energy)

    # Voxel sums along z of the single energy channel [0] (no dz factor yet).
    proj_delta_flat = np.sum(vol_delta[0], axis=0)
    proj_beta_flat = np.sum(vol_beta[0], axis=0)

    # Apply the structural envelope and the voxel thickness dz (the "dz"
    # element of the summation), turning the sums into line integrals.
    proj_delta = dz * struct * proj_delta_flat
    proj_beta = dz * struct * proj_beta_flat

    img = simulate_projection(proj_beta, proj_delta, dx, det_N, det_dx, energy, R,
                              det_psf='gaussian', det_fwhm=5e-6, I0=1e5, key=noise_key)
    imgs.append(img)

imgs = np.array(imgs)

###########################################################
## DISPLAY RESULTS

# Shared intensity scale (vmin, vmax) so the energies are directly comparable.
kw = dict(vmin=imgs.min(), vmax=imgs.max(), cmap='gray')

n_energies = len(energies)
fig, ax = plt.subplots(1, n_energies, figsize=[3.1*n_energies, 3], layout='constrained')
fig.suptitle('Raw output images', fontweight='bold')
for i, (e_kev, energy_img) in enumerate(zip(energies, imgs)):
    m = ax[i].imshow(energy_img, **kw)
    ax[i].set_title(f'{e_kev} keV')
fig.colorbar(m, ax=ax)
plt.show()


## Second — call material decomposition function `spbi_material_basis()`

In [None]:
%%time

# Define the material basis: the materials assigned to labels 0 and 1.
mat1, mat2 = mat_dict[0], mat_dict[1]

# Convert each material's delta/beta spectra into the inputs of the Schaff
# code: negated deltas, and attenuation mu = 2 * k * beta (k = wavenumber).
ds1, bs1 = mat1.db(energies)
ds2, bs2 = mat2.db(energies)
ds = -np.array([ds1, ds2])
mus = 2 * np.array([bs1, bs2]) * get_wavenum(energies)

# Decompose. ds and mus are stacked as (material, energy); don't forget to
# transpose them for spbi_material_basis!
bmis = spbi_material_basis(imgs, R, det_dx, ds.T, mus.T)

# Impose physical constraint
bmis = bmis.clip(0, None)   # what happens if you skip this line? why include it? 

# Show -- how do the intensity scales compare to `thickness`? 
print(f'true thickness = {thickness} m')

# Diverging colormap from 0 to 2*thickness, so `thickness` sits at the midpoint.
bmi_kw = dict(vmin=0, vmax=2*thickness, cmap='bwr')

fig, ax = plt.subplots(1, 2, figsize=[7,3], layout='constrained')
fig.suptitle('Basis material images (projected thickness)', fontweight='bold')
for i, (panel, mat) in enumerate(zip(ax, (mat1, mat2))):
    m = panel.imshow(bmis[i], **bmi_kw)
    panel.set_title(mat.name)
    panel.axis('off')
    fig.colorbar(m, ax=panel)
plt.show()


## Visually compare the material decomposition to the ground truth
Note that this is at the phantom resolution, so it will not only be perfectly accurate, it will also be a much sharper image!

In [None]:
# Ground-truth basis images: number of voxels of each label along z, times dz,
# masked by the structural envelope (labels 0/1 correspond to mat1/mat2).
gt_bmi1, gt_bmi2 = (dz * struct * np.sum(vol_mask == label, axis=0) for label in (0, 1))

fig, ax = plt.subplots(1, 2, figsize=[7,3], layout='constrained')
fig.suptitle('True material projections', fontweight='bold')
for i, (mat, gt_img) in enumerate(zip((mat1, mat2), (gt_bmi1, gt_bmi2))):
    m = ax[i].imshow(gt_img, **bmi_kw)   # same display scale as the decomposition
    ax[i].set_title(mat.name)
    ax[i].axis('off')
    fig.colorbar(m, ax=ax[i])
plt.show()

# TO-DO:
- Clean up code: maybe write a single function that goes directly from `energies` and `R` -> `bmis` (basis material images). This will make your life easier later. Or maybe just move the `simulate_projection` function to `simulate.py`.
- Come up with a single metric, a "figure of merit" (FOM), to assess how good the results are. For example, calculate the ground-truth fraction $F_{gt}$ and the estimated fraction from material decomposition $F_{md}$; your FOM could then be something like their % difference.
- Note that if you use `I0=None` (no noise), then `R=0` should give a nearly perfect material decomposition. Do not use `I0=0`, which divides by zero in the noise step. It's when you add noise (Poisson counting statistics from your `I0` incident photons) that phase contrast (i.e. `R>0`) becomes useful. You can leave `I0=1e5` for now. To see the effect, try re-running the material decomposition code above with `R=0`. 

Measure the FOM for some different **imaging system** input parameters:
1. Energies `energies` should be in range 15–35 keV. They should always be two different energies.
2. Distances `R` should be in range 0–20 cm. (Smaller might be better for now).
3. Try "small" (2 micron), "medium" (5 micron), and "large" (10 micron) `det_dx`.

After this, try varying different **material** parameters:
1. Test $F_{gt}$ from 0.2 to 0.8 (or look up realistic ranges, maybe following the BIRADs classification into four breast density types)
2. Test `thickness` from 1 to 5 cm (or look up realistic ranges, 1 cm might be small but I am curious)
3. Test different $\alpha$ (textures, formerly the $\beta$ parameter)

Explore methods for neatly presenting the key findings. Show some example images and BMIs, but mainly focus on showing trends in your FOM, or how $F_{md}$ varies as a function of $F_{gt}$.

To get error bars, you can measure the FOM for many different random seeds. Just make sure to re-measure $F_{gt}$ *after* you generate the phantom for a given random seed.

You can also try exploring different `I0`, which varies the noise level. Larger values (more counts) should reduce the noise level, while smaller values (fewer counts) will make your images noisier. I expect smaller `I0` will make larger `R` perform relatively better.
