In [1]:
%matplotlib notebook
%load_ext autoreload

In [4]:
import sys, os
os.environ['PKG_CONFIG_PATH'] = '/ocean/projects/asc170022p/mtragoza/mambaforge/envs/lung-project/lib/pkgconfig'

import numpy as np
import xarray as xr

sys.path.append('..')
import project

In [5]:
# Take the first case of the Emory 4DCT dataset and load its image volumes
# and the 'lung_regions' masks (see the "Loading ..." log lines this produces).
case = project.imaging.Emory4DCT('../data/Emory-4DCT')[0]
case.load_niftis()
case.load_masks(roi='lung_regions')
# Rename the mask array before handing it to the interactive viewer.
case.mask.name = 'regions'
project.visual.XArrayViewer(case.mask)

Loading ../data/Emory-4DCT/Case1Pack/NIFTI/case1_T00.nii.gz
Loading ../data/Emory-4DCT/Case1Pack/NIFTI/case1_T10.nii.gz
Loading ../data/Emory-4DCT/Case1Pack/NIFTI/case1_T20.nii.gz
Loading ../data/Emory-4DCT/Case1Pack/NIFTI/case1_T30.nii.gz
Loading ../data/Emory-4DCT/Case1Pack/NIFTI/case1_T40.nii.gz
Loading ../data/Emory-4DCT/Case1Pack/NIFTI/case1_T50.nii.gz
Loading ../data/Emory-4DCT/Case1Pack/NIFTI/case1_T60.nii.gz
Loading ../data/Emory-4DCT/Case1Pack/NIFTI/case1_T70.nii.gz
Loading ../data/Emory-4DCT/Case1Pack/NIFTI/case1_T80.nii.gz
Loading ../data/Emory-4DCT/Case1Pack/NIFTI/case1_T90.nii.gz
Loading ../data/Emory-4DCT/Case1Pack/TotalSegment/case1_T00/lung_regions.nii.gz
Loading ../data/Emory-4DCT/Case1Pack/TotalSegment/case1_T10/lung_regions.nii.gz
Loading ../data/Emory-4DCT/Case1Pack/TotalSegment/case1_T20/lung_regions.nii.gz
Loading ../data/Emory-4DCT/Case1Pack/TotalSegment/case1_T30/lung_regions.nii.gz
Loading ../data/Emory-4DCT/Case1Pack/TotalSegment/case1_T40/lung_regions.nii.gz


<IPython.core.display.Javascript object>

interactive(children=(SelectionSlider(description='phase', options=((0, 0), (10, 1), (20, 2), (30, 3), (40, 4)…

<project.visual.XArrayViewer at 0x14f261b47f40>

In [6]:
# Notebook-global grid size and voxel spacing of the loaded case; later cells
# read these names directly.
shape = case.shape
resolution = case.resolution

def spatial_coordinates(shape, resolution):
    """Build a dense grid of physical voxel positions.

    Args:
        shape: Sequence whose first three entries are the voxel counts
            along x, y and z.
        resolution: Sequence whose first three entries are the spacing
            between voxels along x, y and z.

    Returns:
        Array of shape ``(shape[0], shape[1], shape[2], 3)``; the last axis
        holds ``index * spacing`` for each of the three axes.
    """
    axes = [np.arange(shape[k]) * resolution[k] for k in range(3)]
    grid = np.meshgrid(*axes, indexing='ij')
    return np.stack(grid, axis=-1)

# Physical coordinates of every voxel of the case grid; reused by later cells.
coords = spatial_coordinates(shape, resolution)

# Inspect the coordinate components interactively.
project.visual.XArrayViewer(
    project.utils.as_xarray(coords, dims=['x', 'y', 'z', 'component'], name='coords')
)

coords


<IPython.core.display.Javascript object>

interactive(children=(SelectionSlider(description='z', options=((0, 0), (1, 1), (2, 2), (3, 3), (4, 4), (5, 5)…

<project.visual.XArrayViewer at 0x14f2513fe4d0>

In [7]:
def frequency_coordinates(shape):
    """Build a dense grid of discrete Fourier sample frequencies.

    Args:
        shape: Sequence whose first three entries are the grid sizes along
            x, y and z.

    Returns:
        Array of shape ``(shape[0], shape[1], shape[2], 3)``; the last axis
        holds the ``np.fft.fftfreq`` frequency (default unit sample spacing)
        of each axis, in FFT storage order.
    """
    axes = [np.fft.fftfreq(shape[k]) for k in range(3)]
    return np.stack(np.meshgrid(*axes, indexing='ij'), axis=-1)
    
# Frequency grid matching the case grid; reused when building filters below.
freqs = frequency_coordinates(shape)

# Inspect the frequency components interactively.
project.visual.XArrayViewer(
    project.utils.as_xarray(freqs, dims=['x', 'y', 'z', 'component'], name='freqs')
)

freqs


<IPython.core.display.Javascript object>

interactive(children=(SelectionSlider(description='z', options=((0, 0), (1, 1), (2, 2), (3, 3), (4, 4), (5, 5)…

<project.visual.XArrayViewer at 0x14f25117caf0>

In [8]:
def butterworth_filter(freqs, cutoff, power):
    """Evaluate a radially symmetric low-pass response on a frequency grid.

    Computes ``1 / (1 + (|f| / cutoff) ** power)`` where ``|f|`` is the
    Euclidean norm over the last axis of ``freqs``.

    Args:
        freqs: Array whose last axis holds the frequency components.
        cutoff: Frequency at which the response equals 0.5.
        power: Exponent controlling how sharply the response falls off.

    Returns:
        Array with the shape of ``freqs`` minus its last axis; the value is
        1 at zero frequency and decreases with increasing ``|f|``.
    """
    normalized_radius = np.linalg.norm(freqs, axis=-1) / cutoff
    return 1 / (1 + normalized_radius**power)


def bandpass_filter(freqs, cutoff1, cutoff2, power):
    """Combine two Butterworth responses into a band-pass response.

    Args:
        freqs: Array whose last axis holds the frequency components.
        cutoff1: Cutoff of the high-pass part (one minus a low-pass response).
        cutoff2: Cutoff of the low-pass part.
        power: Exponent shared by both responses.

    Returns:
        Element-wise product of the high-pass and low-pass responses; it is
        exactly zero at zero frequency.
    """
    highpass = 1 - butterworth_filter(freqs, cutoff1, power)
    lowpass = butterworth_filter(freqs, cutoff2, power)
    return highpass * lowpass


# Five example band-pass filters: the upper cutoff c is log-spaced between
# 1e-3 and 1 (in np.fft.fftfreq units) and the lower cutoff is c/2.
filters = np.stack([
    bandpass_filter(freqs, c/2, c, power=2) for c in np.logspace(-3, 0, 5)
])

# Browse the filters along the leading 'f' axis on a fixed 0..1 grey scale.
project.visual.XArrayViewer(
    project.utils.as_xarray(
        filters,
        dims=['f', 'x', 'y', 'z'],
        name='filter'
    ),
    cmap='Greys_r', vmin=0, vmax=1,
)

filter


<IPython.core.display.Javascript object>

interactive(children=(SelectionSlider(description='f', options=((0, 0), (1, 1), (2, 2), (3, 3), (4, 4)), value…

<project.visual.XArrayViewer at 0x14f25117c520>

In [9]:
def random_texture(filter_):
    """Synthesize a random texture by filtering complex white noise.

    Two standard-normal arrays (real part first, then imaginary part) are
    drawn from NumPy's global random state, multiplied by ``filter_`` and
    passed through an inverse FFT. The real part of the result is rescaled
    linearly so that its minimum is -1 and its maximum is 1.

    Args:
        filter_: Frequency response array; its shape sets the output shape.

    Returns:
        Real array with the shape of ``filter_`` and values in [-1, 1].
    """
    size = filter_.shape
    real_part = np.random.normal(0, 1, size)
    imag_part = np.random.normal(0, 1, size)
    spectrum = (real_part + imag_part * 1j) * filter_
    field = np.fft.ifftn(spectrum).real
    # Shift to a zero minimum, scale to a unit maximum, then map [0, 1] -> [-1, 1].
    field -= field.min()
    field /= field.max()
    return field * 2 - 1

# One random texture per example filter, stacked along a leading axis.
texture = np.stack([
    random_texture(f) for f in filters
])

# Textures are normalized to [-1, 1], so use that as the fixed display range.
project.visual.XArrayViewer(
    project.utils.as_xarray(
        texture,
        dims=['f', 'x', 'y', 'z'],
        name='texture'
    ),
    cmap='Greys_r', vmin=-1, vmax=1,
)

texture


<IPython.core.display.Javascript object>

interactive(children=(SelectionSlider(description='f', options=((0, 0), (1, 1), (2, 2), (3, 3), (4, 4)), value…

<project.visual.XArrayViewer at 0x14f250f27af0>

In [10]:
import scipy.stats

def draw_target(coords, center, radius, rotation):
    """Rasterize an ellipsoid onto a coordinate grid.

    Args:
        coords: Array whose last axis holds the position of each grid point.
        center: Position of the ellipsoid center.
        radius: Semi-axis length(s); broadcast against the last axis.
        rotation: Matrix applied (on the right) to the centered positions
            before they are divided by ``radius``.

    Returns:
        Integer array (1 inside or on the ellipsoid, 0 outside) with the
        shape of ``coords`` minus its last axis.
    """
    offsets = coords - center
    scaled = (offsets @ rotation) / radius
    inside = np.linalg.norm(scaled, ord=2, axis=-1) <= 1
    return inside.astype(int)

def draw_random_target(coords, radius_min, radius_max, pad=0.2, iso=False):
    """Rasterize one randomly placed, sized and oriented ellipsoid.

    Random draws come from NumPy's global random state in this order:
    center, radii, orientation.

    Args:
        coords: Array of shape (X, Y, Z, 3) holding grid point positions.
        radius_min: Lower bound of the uniform radius distribution.
        radius_max: Upper bound of the uniform radius distribution.
        pad: Fraction of the coordinate range, per axis, kept free of
            ellipsoid centers at both ends.
        iso: If true, draw a single radius shared by all axes; otherwise
            draw three independent radii.

    Returns:
        Integer array of shape (X, Y, Z): 1 inside the ellipsoid, 0 outside.
    """
    spatial_axes = (0, 1, 2)
    lower = coords.min(axis=spatial_axes)
    upper = coords.max(axis=spatial_axes)
    extent = upper - lower
    center = np.random.uniform(lower + extent * pad, upper - extent * pad)
    n_radii = 1 if iso else 3
    radius = np.random.uniform(radius_min, radius_max, n_radii)
    # Random orthogonal 3x3 matrix used as the ellipsoid orientation.
    rotation = scipy.stats.ortho_group.rvs(3)
    return draw_target(coords, center, radius, rotation)

# Draw one example ellipsoid with radii between 5 and 25 (in coords units).
target = draw_random_target(coords, radius_min=5, radius_max=25)

# Inspect the resulting binary volume.
project.visual.XArrayViewer(
    project.utils.as_xarray(
        target,
        dims=['x', 'y', 'z'],
        name='mask',
    )
)

mask


<IPython.core.display.Javascript object>

interactive(children=(SelectionSlider(description='z', options=((0, 0), (1, 1), (2, 2), (3, 3), (4, 4), (5, 5)…

<project.visual.XArrayViewer at 0x14f2512a7ca0>

In [11]:
def add_non_overlapping_targets(regions, resolution, max_targets, radius_min, radius_max):
    """Insert random ellipsoidal targets into the label-1 region of a volume.

    Candidates are drawn one at a time and accepted only if every voxel they
    cover currently has label 1, so accepted targets never touch label 0 or
    each other. The k-th accepted target ends up with label ``1 + k``.
    Sampling stops after ``max_targets`` acceptances or ``10 * max_targets``
    attempts, whichever comes first.

    Args:
        regions: Integer label volume of shape (X, Y, Z); it is copied, not
            modified.
        resolution: Voxel spacing used to build the coordinate grid.
        max_targets: Maximum number of targets to insert.
        radius_min: Lower bound of the target radius distribution.
        radius_max: Upper bound of the target radius distribution.

    Returns:
        New label volume with the accepted targets added. The numbers of
        accepted targets and of attempts are printed before returning.
    """
    coords = spatial_coordinates(regions.shape, resolution)
    regions = np.array(regions)
    max_attempts = max_targets * 10
    target_count = 0
    attempt_count = 0
    while target_count < max_targets and attempt_count < max_attempts:
        attempt_count += 1
        target = draw_random_target(coords, radius_min, radius_max)
        # Reject candidates that cover any voxel whose label is not 1.
        if np.any(np.logical_and(target, regions != 1)):
            continue
        target_count += 1
        # Covered voxels hold 1, so adding k relabels them to 1 + k.
        regions += target * target_count
    print(target_count, attempt_count)
    return regions

# Binary mask (1 where the phase-0 'lung_regions' mask is positive), then up
# to five targets inserted inside it; targets get labels 2, 3, ...
regions = (case.mask.sel(phase=0, roi='lung_regions').values > 0).astype(int)
regions = add_non_overlapping_targets(regions, resolution, max_targets=5, radius_min=5, radius_max=30)

# Inspect the resulting label volume.
project.visual.XArrayViewer(
    project.utils.as_xarray(
        regions,
        dims=['x', 'y', 'z'],
        name='regions',
    )
)

4 50
regions


<IPython.core.display.Javascript object>

interactive(children=(SelectionSlider(description='z', options=((0, 0), (1, 1), (2, 2), (3, 3), (4, 4), (5, 5)…

<project.visual.XArrayViewer at 0x14f251267610>

In [12]:
# One boolean indicator volume per label 0 .. regions.max(), stacked on axis 0.
n_regions = int(regions.max()) + 1
region_id = np.arange(n_regions)
region_indicator = (regions[None,...] == region_id[:,None,None,None])

# sample latent variables
# (one uniform [0, 1) value per region; later cells reuse the same values for
# the intensity bias and the texture cutoff)

latent = np.random.rand(n_regions)
print(latent)

# assign stiffness to regions
# (latent is mapped linearly onto [min_log_kpa, max_log_kpa])

min_log_kpa = -1
max_log_kpa =  1
log_kpa = latent * (max_log_kpa - min_log_kpa) + min_log_kpa

# 10**log_kpa is in kPa; the factor 1000 converts it to Pa.
mu = 10**log_kpa * 1000
# Region 0 is the background and gets zero stiffness.
mu[0] = 0
print(mu)

# Expand the per-region values into a volume; the indicators are mutually
# exclusive, so the sum picks each voxel's own region value.
mu = (region_indicator * mu[:,None,None,None]).sum(axis=0)

project.visual.XArrayViewer(
    project.utils.as_xarray(mu, dims=['x', 'y', 'z'], name='elast'), vmin=-1e4, vmax=1e4
)

[0.27393626 0.61440195 0.26534991 0.27250627 0.94626688 0.65019659]
[   0.         1693.57294702  339.3906097   350.76199555 7807.8911471
 1997.06951467]
elast


<IPython.core.display.Javascript object>

interactive(children=(SelectionSlider(description='z', options=((0, 0), (1, 1), (2, 2), (3, 3), (4, 4), (5, 5)…

<project.visual.XArrayViewer at 0x14f250f74f70>

In [13]:
# map latent variables to texture bias and filter parameters

# anat_range is the full width of the intended intensity range; bias_range is
# the width of the band from which per-region mean intensities are drawn.
anat_range = 2000
bias_range = 1000
bias_midpt = 0
# NOTE(review): bias_max is not read by any cell visible in this notebook.
bias_max = bias_midpt + bias_range/2
bias_min = bias_midpt - bias_range/2

# Map latent linearly onto [bias_min, bias_min + bias_range).
bias = latent * bias_range + bias_min
# Background (region 0) is pinned to the bottom of the intensity range.
bias[0] = bias_midpt - anat_range/2
bias

array([-1000.        ,   114.40195384,  -234.65008898,  -227.4937339 ,
         446.26687501,   150.1965911 ])

In [14]:
# Map latent linearly onto a log10 cutoff in [log_cutoff_min, log_cutoff_max),
# i.e. a texture cutoff frequency between 1e-3 and 1.
log_cutoff_max =  0
log_cutoff_min = -3
log_cutoff_range = log_cutoff_max - log_cutoff_min

log_cutoff = latent * log_cutoff_range + log_cutoff_min
cutoff = 10**log_cutoff
cutoff

array([0.00663451, 0.06969568, 0.00625244, 0.0065693 , 0.689923  ,
       0.08924621])

In [15]:
# One full-volume texture per region, each band-limited by that region's
# cutoff. Note this rebinds `texture` from the earlier example cell.
texture = np.zeros_like(region_indicator, dtype=float)
for i in range(n_regions):
    filter_ = bandpass_filter(freqs, cutoff[i]/2, cutoff[i], power=2)
    texture[i] = random_texture(filter_)

# The background region gets no texture.
texture[0] = 0

project.visual.XArrayViewer(
    project.utils.as_xarray(texture, dims=['f', 'x', 'y', 'z'], name='texture'), cmap='Greys_r'
)

texture


<IPython.core.display.Javascript object>

interactive(children=(SelectionSlider(description='f', options=((0, 0), (1, 1), (2, 2), (3, 3), (4, 4), (5, 5)…

<project.visual.XArrayViewer at 0x14f2662a99f0>

In [16]:
# Texture amplitude is chosen so that bias +/- amplitude stays inside the
# anat_range-wide window centered on bias_midpt.
texture_range = anat_range - bias_range
anat = texture * texture_range/2 + bias[:,None,None,None]
# Keep, for each voxel, only the value of the region it belongs to.
anat = (region_indicator * anat).sum(axis=0)
print(anat.min(), anat.max())

project.visual.XArrayViewer(
    project.utils.as_xarray(anat, dims=['x', 'y', 'z'], name='anat')
)

-1000.0 778.5972596960692
anat


<IPython.core.display.Javascript object>

interactive(children=(SelectionSlider(description='z', options=((0, 0), (1, 1), (2, 2), (3, 3), (4, 4), (5, 5)…

<project.visual.XArrayViewer at 0x14f250f049a0>

In [17]:
# Random 3x3 mixing matrix with normally distributed entries (std phase_sigma).
phase_sigma = 2
phase = np.random.normal(0, phase_sigma, (3, 3))
# Largest grid dimension times the per-axis spacing; normalizes coordinates.
extent = np.max(shape) * np.array(resolution)

# Smooth sinusoidal displacement field with three components per voxel,
# bounded to [-1, 1] by the sine.
disp_bc = np.sin(2 * np.pi * (coords/extent) @ phase)

project.visual.XArrayViewer(
    project.utils.as_xarray(disp_bc, dims=['x', 'y', 'z', 'c'], name='displacement'), vmin=-2, vmax=2
)

displacement


<IPython.core.display.Javascript object>

interactive(children=(SelectionSlider(description='z', options=((0, 0), (1, 1), (2, 2), (3, 3), (4, 4), (5, 5)…

<project.visual.XArrayViewer at 0x14f2506b1150>

In [None]:
import pygalmesh

# Mesh the labelled volume with pygalmesh, passing the voxel spacing so the
# mesh is generated in the same physical units as `coords`.
mesh = pygalmesh.generate_from_array(
    regions.astype(np.uint16),
    voxel_size=resolution,
    max_cell_circumradius=10.0,
    min_facet_angle=15,
    max_facet_distance=1.0,
    odt=True,
    lloyd=True
)
# NOTE(review): presumably drops points not referenced by any cell -- confirm
# in project.meshing.
mesh = project.meshing.remove_unused_points(mesh)

# Round-trip through a file: written via the meshio-based helper, read back
# via the FEniCS-based helper.
# NOTE(review): cell_blocks=[1] presumably selects which cell block(s) to
# write -- confirm in project.meshing.
mesh_file = 'lung.xdmf'
project.meshing.save_mesh_meshio(mesh_file, mesh, cell_blocks=[1])
mesh = project.meshing.load_mesh_fenics(mesh_file)
mesh

In [None]:
# Emit blank lines to separate the output above from the output below.
print('\n'*4)

In [None]:
import torch

# convert to FEM basis coefficients
# (the three image-space fields are each interpolated onto the mesh with the
# same radius/sigma settings, derived from mesh_radius)
mesh_radius = 10
device = 'cpu'
dtype = torch.float32
pde_ = project.pde.LinearElasticPDE(mesh)
print('Interpolating FEM coefficients...')
# Displacement: move the component axis to the front, (X, Y, Z, 3) -> (3, X, Y, Z).
# NOTE(review): pde_.V / pde_.S are presumably the vector / scalar function
# spaces of the PDE -- confirm in project.pde.
u_dofs = project.interpolation.image_to_dofs(
    torch.as_tensor(disp_bc, dtype=dtype, device=device).permute(3, 0, 1, 2),
    resolution, pde_.V,
    radius=int(mesh_radius),
    sigma=mesh_radius/2
)
mu_dofs = project.interpolation.image_to_dofs(
    torch.as_tensor(mu, dtype=dtype, device=device),
    resolution, pde_.S,
    radius=int(mesh_radius),
    sigma=mesh_radius/2
)
anat_dofs = project.interpolation.image_to_dofs(
    torch.as_tensor(anat, dtype=dtype, device=device),
    resolution, pde_.S,
    radius=int(mesh_radius),
    sigma=mesh_radius/2
)
# NOTE(review): looks like a CT-number-to-density conversion
# (-1000 -> 0, 0 -> 1000) -- confirm intended units.
rho_dofs = (1 + anat_dofs/1000) * 1000

In [None]:
# solve FEM for simulated displacement dofs
print('Solving FEM model...')
# Each input gets a leading axis of length 1 and the first entry of the
# result is taken back out.
# NOTE(review): presumably forward() expects a batch dimension -- confirm.
u_sim_dofs = pde_.forward(
    u_dofs.unsqueeze(0),
    mu_dofs.unsqueeze(0),
    rho_dofs.unsqueeze(0),
)[0]

In [None]:
# convert to displacement image
print('Converting displacement to image...')
# Sample the solution back onto the voxel grid and move the leading axis to
# the end so the layout matches disp_bc.
disp_sim = project.interpolation.dofs_to_image(
    u_sim_dofs, pde_.V, disp_bc.shape[:3], resolution
).permute(1,2,3,0)

project.visual.XArrayViewer(
    project.utils.as_xarray(
        disp_sim, dims=['x', 'y', 'z', 'c'], name='displacement'
    )
)

In [None]:
%autoreload
import torch
import nibabel as nib

from project import meshing, interpolation, pde

def generate_phantom(
    mask,
    resolution,
    max_targets=5,
    radius_min=5,
    radius_max=25,
    log_kpa_min=-1,
    log_kpa_max=1,
    bias_midpoint=0,
    bias_range=500,
    anat_range=2000,
    log_cutoff_min=-3,
    log_cutoff_max=0,
    phase_sigma=1.0,
    mesh_radius=10,
    mesh_file='phantom.xdmf',
    random_seed=None
):
    """Generate one synthetic phantom from a binary mask.

    The pipeline seeds NumPy's global random state, inserts random
    ellipsoidal targets into the mask, samples one latent value per region
    and maps it to stiffness, mean intensity and texture cutoff, builds a
    random sinusoidal displacement field, meshes the labelled volume and
    solves the linear elastic FEM model on that mesh.

    Args:
        mask: Integer array of shape (X, Y, Z) with 1 inside the region
            where targets may be placed and 0 in the background.
        resolution: Voxel spacing along x, y and z.
        max_targets: Maximum number of targets to insert.
        radius_min: Lower bound of the target radius distribution.
        radius_max: Upper bound of the target radius distribution.
        log_kpa_min: log10 of the minimum region stiffness in kPa.
        log_kpa_max: log10 of the maximum region stiffness in kPa.
        bias_midpoint: Center of the anatomical intensity range.
        bias_range: Width of the band of per-region mean intensities.
        anat_range: Full width of the anatomical intensity range.
        log_cutoff_min: log10 of the minimum texture cutoff frequency.
        log_cutoff_max: log10 of the maximum texture cutoff frequency.
        phase_sigma: Standard deviation of the entries of the random 3x3
            matrix that shapes the displacement field.
        mesh_radius: Maximum cell circumradius for meshing; also sets the
            radius (and, halved, the sigma) of the image-to-mesh
            interpolation.
        mesh_file: Path the mesh is written to and read back from.
        random_seed: Seed for ``np.random.seed``; ``None`` reseeds
            unpredictably.

    Returns:
        Tuple ``(mu, anat, disp_bc, disp_sim, mask, mesh)``: stiffness map
        in Pa, anatomical image, prescribed displacement field, simulated
        displacement image, ``uint16`` region label volume (0 background,
        1 host region, 2+ targets) and the mesh as loaded by
        ``meshing.load_mesh_fenics``.
    """
    print(f'Setting random seed to {random_seed}')
    np.random.seed(random_seed)

    # Use the grid size of the input mask everywhere below. The previous
    # implementation read a notebook-global `shape` for the frequency grid
    # and the displacement extent, which breaks for any other grid.
    shape = mask.shape

    # define coordinates over spatial domain
    print('Defining spatial domain...')
    coords = spatial_coordinates(shape, resolution)

    # define spatial regions for target and background
    regions = add_non_overlapping_targets(mask, resolution, max_targets, radius_min, radius_max)
    n_regions = 1 + int(regions.max())
    region_id = np.arange(n_regions)
    region_indicator = (regions[None,...] == region_id[:,None,None,None])

    # sample latent variables for each region
    print('Sampling latent variables...')
    latent = np.random.rand(n_regions)
    print(latent)

    # map latent variables to stiffness values
    print('Generating stiffness map..')
    log_kpa_range = (log_kpa_max - log_kpa_min)
    log_kpa = latent * log_kpa_range + log_kpa_min
    mu = np.power(10, log_kpa) * 1000 # log kPa -> Pa
    mu[0] = 0 # background stiffness
    print(mu)

    # assign stiffness to each spatial region
    # (the background is already zero here; indexing the resulting volume
    # with [0] would address the x = 0 slice, not the background region)
    mu = (region_indicator * mu[:,None,None,None]).sum(axis=0)

    # map latent variables to anatomical texture parameters
    print('Generating anatomical image...')
    bias_min = bias_midpoint - bias_range/2
    bias = latent * bias_range + bias_min

    log_cutoff_range = (log_cutoff_max - log_cutoff_min)
    log_cutoff = latent * log_cutoff_range + log_cutoff_min
    cutoff = np.power(10, log_cutoff)

    freq_coords = frequency_coordinates(shape)
    texture = np.zeros_like(region_indicator, dtype=float)
    for i in range(n_regions):
        filter_ = bandpass_filter(freq_coords, cutoff[i]/2, cutoff[i], power=2)
        texture[i] = random_texture(filter_)

    # texture amplitude fills what is left of the range around each bias
    texture_range = (anat_range - bias_range)
    anat = texture * texture_range/2 + bias[:,None,None,None]
    # here axis 0 still indexes regions, so this sets the whole background
    # region to the bottom of the intensity range
    anat_min = bias_midpoint - anat_range/2
    anat[0] = anat_min

    # assign anatomical textures to each region
    anat = (region_indicator * anat).sum(axis=0)

    # generate random displacement boundary condition
    print('Generating displacement BC...')
    phase = np.random.normal(0, phase_sigma, (3, 3))
    extent = np.max(shape) * np.array(resolution)
    disp_bc = np.sin(2 * np.pi * (coords/extent) @ phase)

    # generate mesh using pygalmesh
    print('Generating mesh...')
    mask = regions.astype(np.uint16)
    mesh = pygalmesh.generate_from_array(
        mask, resolution,
        max_cell_circumradius=float(mesh_radius),
        min_facet_angle=15,
        max_facet_distance=1.0,
        odt=True, lloyd=True
    )
    mesh = meshing.remove_unused_points(mesh)

    # save mesh using meshio, then read with fenics
    meshing.save_mesh_meshio(mesh_file, mesh, cell_blocks=[1])
    mesh = meshing.load_mesh_fenics(mesh_file)

    # convert to FEM basis coefficients
    print('Interpolating FEM coefficients...')
    pde_ = pde.LinearElasticPDE(mesh)
    device = 'cpu'
    dtype = torch.float32
    u_dofs = interpolation.image_to_dofs(
        torch.as_tensor(disp_bc, dtype=dtype, device=device).permute(3, 0, 1, 2),
        resolution, pde_.V,
        radius=int(mesh_radius),
        sigma=mesh_radius/2
    )
    mu_dofs = interpolation.image_to_dofs(
        torch.as_tensor(mu, dtype=dtype, device=device),
        resolution, pde_.S,
        radius=int(mesh_radius),
        sigma=mesh_radius/2
    )
    anat_dofs = interpolation.image_to_dofs(
        torch.as_tensor(anat, dtype=dtype, device=device),
        resolution, pde_.S,
        radius=int(mesh_radius),
        sigma=mesh_radius/2
    )
    # NOTE(review): looks like a CT-number-to-density conversion
    # (-1000 -> 0, 0 -> 1000) -- confirm intended units.
    rho_dofs = (1 + anat_dofs/1000) * 1000

    # solve FEM for simulated displacement dofs
    print('Solving FEM model...')
    u_sim_dofs = pde_.forward(
        u_dofs.unsqueeze(0),
        mu_dofs.unsqueeze(0),
        rho_dofs.unsqueeze(0),
    )[0]

    # convert to displacement image
    print('Converting displacement to image...')
    disp_sim = interpolation.dofs_to_image(
        u_sim_dofs, pde_.V, disp_bc.shape[:3], resolution
    ).permute(1,2,3,0)

    return mu, anat, disp_bc, disp_sim, mask, mesh

# Build one phantom from the binary phase-0 mask with default parameters.
# Note that `mask` is rebound to the returned uint16 region labels.
mask = (case.mask.sel(phase=0, roi='lung_regions').values > 0).astype(int)
mu, anat, disp_bc, disp_sim, mask, mesh = generate_phantom(mask, resolution)

# Only the last expression of a cell is displayed automatically.
# NOTE(review): confirm the first three viewers actually render.
project.visual.XArrayViewer(project.utils.as_xarray(mu, dims=['x', 'y', 'z'], name='elast'))
project.visual.XArrayViewer(project.utils.as_xarray(anat, dims=['x', 'y', 'z'], name='anat'))
project.visual.XArrayViewer(project.utils.as_xarray(disp_bc, dims=['x', 'y', 'z', 'c'], name='disp'))
project.visual.XArrayViewer(project.utils.as_xarray(disp_sim, dims=['x', 'y', 'z', 'c'], name='disp'))

In [None]:
import pathlib

class PhantomSet(object):
    """A numbered collection of phantoms stored under one root directory."""

    def __init__(self, data_root, num_phantoms):
        """Create ``num_phantoms`` phantom handles with ids 0 .. num_phantoms-1.

        Nothing is generated or written here; see ``generate``.
        """
        self.data_root = pathlib.Path(data_root)
        self.phantoms = [
            Phantom(data_root, phantom_id=i) for i in range(num_phantoms)
        ]

    def generate(self, *args, **kwargs):
        """Generate every phantom, forwarding all arguments to ``Phantom.generate``."""
        # parents=True so that a nested, not-yet-existing data_root works too.
        self.data_root.mkdir(parents=True, exist_ok=True)
        for phantom in self.phantoms:
            phantom.generate(*args, **kwargs)

    def get_examples(self, mesh_radius):
        """Describe each phantom as a dict of its name and file paths.

        Args:
            mesh_radius: Mesh radius used to select the mesh file name.

        Returns:
            List with one dict per phantom holding the keys ``name``,
            ``anat_file``, ``disp_file``, ``mask_file``, ``mesh_file`` and
            ``mesh_radius``.
        """
        return [
            {
                'name': phantom.phantom_name,
                'anat_file': phantom.anat_file(),
                'disp_file': phantom.disp_file(),
                'mask_file': phantom.mask_file(),
                'mesh_file': phantom.mesh_file(mesh_radius),
                'mesh_radius': mesh_radius
            }
            for phantom in self.phantoms
        ]


class Phantom(object):
    """One synthetic phantom, identified by an integer id, and its files.

    All files live in ``<data_root>/phantom<id>/`` and are prefixed with
    ``phantom<id>``.
    """

    def __init__(self, data_root, phantom_id):
        self.data_root = pathlib.Path(data_root)
        self.phantom_id = int(phantom_id)

    @property
    def phantom_name(self):
        """Name used for the phantom directory and as file name prefix."""
        return f'phantom{self.phantom_id}'

    @property
    def phantom_dir(self):
        """Directory that holds all files of this phantom."""
        return self.data_root / self.phantom_name

    def anat_file(self):
        """Path of the anatomical image."""
        return self.phantom_dir / f'{self.phantom_name}_anat.nii.gz'

    def disp_file(self):
        """Path of the simulated displacement image."""
        return self.phantom_dir / f'{self.phantom_name}_disp.nii.gz'

    def elast_file(self):
        """Path of the elasticity (stiffness) image."""
        return self.phantom_dir / f'{self.phantom_name}_elast.nii.gz'

    def mask_file(self):
        """Path of the region label image."""
        return self.phantom_dir / f'{self.phantom_name}_mask.nii.gz'

    def mesh_file(self, mesh_radius):
        """Path of the mesh generated with the given ``mesh_radius``."""
        return self.phantom_dir / f'{self.phantom_name}_mesh{mesh_radius}.xdmf'

    def generate(self, mask_file, mesh_radius, **kwargs):
        """Generate this phantom and write its images and mesh to disk.

        The phantom id is used as the random seed, so a given id always
        yields the same phantom for the same inputs.

        Args:
            mask_file: NIfTI file; voxels with a positive value form the
                region in which targets are placed. Its header zooms are
                used as the voxel spacing.
            mesh_radius: Forwarded to ``generate_phantom`` and used to pick
                the mesh file name.
            **kwargs: Further keyword arguments for ``generate_phantom``.
        """
        # parents=True so the phantom can be generated on its own, without
        # data_root having been created beforehand.
        self.phantom_dir.mkdir(parents=True, exist_ok=True)

        nifti = nib.load(mask_file)
        mask = (nifti.get_fdata() > 0).astype(int)
        resolution = nifti.header.get_zooms()

        elast, anat, disp_bc, disp, mask, mesh = generate_phantom(
            mask=mask,
            resolution=resolution,
            mesh_radius=mesh_radius,
            mesh_file=self.mesh_file(mesh_radius),
            random_seed=self.phantom_id,
            **kwargs
        )

        # Coordinate axes come from the generated volume itself rather than
        # from a notebook-global `shape`, which may describe another grid.
        shape = mask.shape
        x = np.arange(shape[0]) * resolution[0]
        y = np.arange(shape[1]) * resolution[1]
        z = np.arange(shape[2]) * resolution[2]

        self.elast = xr.DataArray(
            data=elast,
            dims=['x', 'y', 'z'],
            coords=dict(x=x, y=y, z=z),
            name='elasticity'
        )
        self.anat = xr.DataArray(
            data=anat,
            dims=['x', 'y', 'z'],
            coords=dict(x=x, y=y, z=z),
            name='anatomy'
        )
        self.disp = xr.DataArray(
            data=disp,
            dims=['x', 'y', 'z', 'component'],
            coords=dict(x=x, y=y, z=z, component=['x', 'y', 'z']),
            name='displacement'
        )
        self.mask = xr.DataArray(
            data=mask,
            dims=['x', 'y', 'z'],
            coords=dict(x=x, y=y, z=z),
            name='mask'
        )
        # The saved affine only encodes the voxel spacing; the orientation
        # and origin of the input NIfTI are not carried over.
        affine = np.diag(list(resolution) + [1])

        print(f'Saving {self.elast_file()}')
        nib.save(nib.Nifti1Image(self.elast.data, affine), self.elast_file())

        print(f'Saving {self.anat_file()}')
        nib.save(nib.Nifti1Image(self.anat.data, affine), self.anat_file())

        print(f'Saving {self.disp_file()}')
        nib.save(nib.Nifti1Image(self.disp.data, affine), self.disp_file())

        print(f'Saving {self.mask_file()}')
        nib.save(nib.Nifti1Image(self.mask.data, affine), self.mask_file())

# Generate 100 phantoms (ids 0..99, each seeded with its own id) from the
# phase-0 lung mask of case 1. Each one runs meshing and an FEM solve, so
# expect this cell to be slow.
phantom_set = PhantomSet(data_root='../data/lung_phantom/', num_phantoms=100)
phantom_set.generate(mask_file='../data/Emory-4DCT/Case1Pack/TotalSegment/case1_T00/lung_regions.nii.gz', mesh_radius=10.0)

In [None]:
# File paths of every phantom; mesh_radius must match the value used during
# generation because it is part of the mesh file name.
examples = phantom_set.get_examples(mesh_radius=10.0)
examples

In [None]:
%autoreload
# Wrap the example descriptions in the project's dataset class.
# NOTE(review): presumably Dataset reads the dict keys produced by
# get_examples -- confirm in project.data.
dataset = project.data.Dataset(examples)
dataset

In [None]:
# Smoke test: load the first generated phantom through the dataset.
dataset.load_example(0)