In [1]:
# Interactive matplotlib backend for the figures below.
%matplotlib notebook
# Load the autoreload extension; later cells call `%autoreload` explicitly to
# pick up edits to `param_search` / `project` without restarting the kernel.
%load_ext autoreload
# Show the working directory: the relative paths used below (sys.path entries,
# experiment directories, checkpoints) are resolved against it.
%pwd

'/ocean/projects/asc170022p/mtragoza/lung-project/notebooks'

In [2]:
import sys
import pandas as pd
# Make the param_search checkout importable; the path is relative to the
# notebook's working directory shown above.
sys.path.append('../../param_search')
import param_search as ps

## Setup experiment

[[Setup](#Setup-experiment)] [[Submit](#Submit-jobs)] [[Monitor](#Monitor-jobs)] [[Analyze](#Analyze-results)]

In [11]:
# SLURM batch-script template, rendered once per job of the sweep.
# Every `{field}` (including `lr`, rendered with `.0e`, e.g. 1e-05, and
# `job_name`) must be available for each job, so the parameter space passed
# to ps.submit has to define all of them.
# `\\` in this (non-raw) string becomes a single backslash, i.e. a shell line
# continuation in the rendered script.
# NOTE(review): train.py is referenced as ../../../train.py, which presumably
# assumes each job runs from a per-job subdirectory of `work_dir` — confirm
# against ps.submit.
template = '''\
#!/bin/bash
#SBATCH --job-name={job_name}
#SBATCH --account=asc170022p
#SBATCH --partition=GPU-shared
#SBATCH --gres=gpu:v100-32:1
#SBATCH --mem=63000M
#SBATCH --time=48:00:00
#SBATCH -o %J.stdout
#SBATCH -e %J.stderr
#SBATCH --mail-type=all

hostname
pwd
module load anaconda3
conda activate /ocean/projects/asc170022p/mtragoza/mambaforge/envs/lung-project
nvidia-smi

python ../../../train.py \\
    --data_name emory \\
    --data_root /ocean/projects/asc170022p/shared/Data/4DLungCT/Emory/ \\
    --mask_roi {mask_roi} \\
    --mesh_version {mesh_version} \\
    --test_case {test_case} \\
    --test_phase {test_phase} \\
    --model_arch {model_arch} \\
    --input_anat {input_anat} \\
    --input_coords {input_coords} \\
    --conv_channels {conv_channels} \\
    --output_func {output_func} \\
    --trainer_task {trainer_task} \\
    --rho_value {rho_value} \\
    --interp_size {interp_size} \\
    --learning_rate {lr:.0e} \\
    --num_epochs {num_epochs} \\
    --test_every {test_every} \\
    --save_every {save_every} \\
    --save_prefix {job_name}

echo Done
'''

In [8]:
# Earlier supervised-training sweep; its name format matches the checkpoints
# of experiment 2024-12-08__emory__clamp used in the image grid below.
# Kept for reference only: it predates the current template and lacks fields
# the template requires (model_arch, input_anat, input_coords, trainer_task,
# lr, num_epochs, ...), so it must not be submitted as-is.
train_param_space = ps.ParamSpace(
    mask_roi='lung_regions2',
    mesh_version=11,
    test_case=[None], #['Case1Pack', 'Case2Pack', 'Case3Pack', 'Case4Pack', 'Case5Pack', 'Case6Pack', 'Case7Pack', 'Case8Deploy', 'Case9Pack'],
    test_phase=[0], # 10, 20, 30, 40, 50, 60, 70, 80, 90],
    rho_value=[0, 1000, 'anat'],
    conv_channels=[32],
    interp_size=[5],
    output_func=['relu', 'exp', 'softplus'],
)
train_name_format = 'train__emory__{test_case}__{test_phase}__{rho_value}__{output_func}'

## fitting (the active sweep submitted below)

param_space = ps.ParamSpace(
    mask_roi='lung_regions2',
    mesh_version=11,
    test_case=['Case10Pack'], #['Case1Pack', 'Case2Pack', 'Case3Pack', 'Case4Pack', 'Case5Pack', 'Case6Pack', 'Case7Pack', 'Case8Deploy', 'Case9Pack'],
    test_phase=[0], # 10, 20, 30, 40, 50, 60, 70, 80, 90],
    input_anat=[True, False],
    input_coords=[True],
    model_arch=['unet3d'],
    conv_channels=[32],
    output_func=['softplus'],
    trainer_task=['fit'],
    rho_value=[0, 1000, 'anat'],
    interp_size=[5],
    lr=[1e-5],
    num_epochs=1000,
    test_every=100,
    save_every=100,
)
# `:d` renders the boolean input flags as 0/1.
name_format = 'fit__emory__{input_anat:d}{input_coords:d}__{output_func}__{rho_value}__{test_case}__{test_phase}'

job_names = [name_format.format(**p) for p in param_space]
for name in job_names:
    print(name)
# Job names are used as --job-name and --save_prefix, so every point of the
# sweep must map to a distinct name or jobs would overwrite each other.
assert len(set(job_names)) == len(job_names), 'name_format does not distinguish all jobs'

print(len(param_space))

fit__emory__11__softplus__0__Case10Pack__0
fit__emory__11__softplus__1000__Case10Pack__0
fit__emory__11__softplus__anat__Case10Pack__0
fit__emory__01__softplus__0__Case10Pack__0
fit__emory__01__softplus__1000__Case10Pack__0
fit__emory__01__softplus__anat__Case10Pack__0
6


## Submit jobs

[[Setup](#Setup-experiment)] [[Submit](#Submit-jobs)] [[Monitor](#Monitor-jobs)] [[Analyze](#Analyze-results)]

In [9]:
# Experiment directories in chronological order; the newest one is the
# experiment this notebook currently submits to / analyzes.
*previous_expt_names, expt_name = (
    '2024-11-22__phantom__250',
    '2024-11-22__phantom__250__resub',
    '2024-11-30__emory__phase',
    '2024-12-02__emory__interp_size',
    '2024-12-03__emory__gpu_shared',
    '2024-12-07__emory__interface',
    '2024-12-08__emory__lung_regions2',
    '2024-12-08__emory__clamp',
    '2024-12-12__emory__fit',
)
expt_name

'2024-12-12__emory__fit'

In [12]:
import os

# Submit the sweep and record the submitted jobs in `<expt_name>.jobs`.
# The jobs table doubles as a guard: if it already exists, this experiment
# was submitted before, so re-running the notebook (e.g. Restart & Run All)
# does not launch duplicate GPU jobs or overwrite the table. Delete the file
# (or use a new expt_name, as in the '__resub' entries) to resubmit.
do_submit = True
jobs_file = f'{expt_name}.jobs'
if do_submit and os.path.exists(jobs_file):
    print(f'{jobs_file} already exists; skipping submission (delete it to resubmit)')
elif do_submit:
    jobs = ps.submit(template, name_format, param_space, work_dir=expt_name)
    jobs.to_csv(jobs_file)

100%|██████████| 6/6 [00:02<00:00,  2.71it/s]


  .replace('', float('nan')).map(pd.to_numeric)


## Monitor jobs

[[Setup](#Setup-experiment)] [[Submit](#Submit-jobs)] [[Monitor](#Monitor-jobs)] [[Analyze](#Analyze-results)]

In [None]:
jobs = pd.read_csv(f'{expt_name}.jobs', index_col=0)
status = ps.status(jobs, parse_stderr=True)
status

In [None]:
status['job_state'] = status['job_state'].fillna('DONE')
status['stderr'] = status['stderr'].fillna('N/A')
status.groupby(['job_state', 'stderr'])[['job_name']].count()

In [None]:
status.iloc[0]

## Analyze results

[[Setup](#Setup-experiment)] [[Submit](#Submit-jobs)] [[Monitor](#Monitor-jobs)] [[Analyze](#Analyze-results)]

In [None]:
m = ps.metrics(jobs, sep=',')
m

In [None]:
m.groupby(['job_name'])[['epoch']].max()

In [None]:
unfinished_jobs = d[d.epoch < 100]
unfinished_jobs

In [None]:
m.groupby(['job_name'])[['batch']].max()

In [None]:
m.columns

In [None]:
%autoreload

fig = ps.plot(
    m[(m.phase == 'train') & (m.epoch > 100)],
    x=['rho_value', 'output_func'],
    y=['u_error', 'e_pred_norm', 'e_anat_corr'],
    hue=None,
    legend=True,
    legend_kws=dict(bbox_to_anchor=(0, -0.2)),
    tight=True,
    height=2.25, width=2.75
)

In [None]:
fig = ps.plot(
    m[(m.phase == 'test') & (m.epoch > 100) & (m.rep == 'dofs')],
    x=['rho_value', 'output_func'],
    y=['u_error', 'e_pred_norm', 'e_anat_corr'],
    hue=None,
    legend=True,
    legend_kws=dict(bbox_to_anchor=(0, -0.2)),
    tight=True,
    height=2.25, width=2.75
)

In [None]:
fig = ps.plot(
    m[(m.phase == 'test') & (m.epoch > 100) & (m.rep == 'image')],
    x=['rho_value', 'output_func'],
    y=['u_error', 'e_pred_norm', 'e_anat_corr'],
    hue=None,
    legend=True,
    legend_kws=dict(bbox_to_anchor=(0, -0.2)),
    tight=True,
    height=2.25, width=2.75
)

In [None]:
fig = ps.plot(
    m[(m.phase == 'train') & (m.epoch > 100)],
    x=['rho_value', 'output_func'],
    y=['e_950_corr', 'e_dis0_corr', 'e_dis1_corr', 'e_dis2_corr'],
    hue=None,
    legend=True,
    legend_kws=dict(bbox_to_anchor=(0, -0.2)),
    tight=True,
    height=2.25, width=2.75
)
for ax in fig.axes:
    ax.set_ylim(-0.25, 0.25)

In [None]:
fig = ps.plot(
    m[(m.phase == 'test') & (m.epoch > 100) & (m.rep == 'dofs')],
    x=['rho_value', 'output_func'],
    y=['e_950_corr', 'e_dis0_corr', 'e_dis1_corr', 'e_dis2_corr'],
    hue=None,
    legend=True,
    legend_kws=dict(bbox_to_anchor=(0, -0.2)),
    tight=True,
    height=2.25, width=2.75
)
for ax in fig.axes:
    ax.set_ylim(-0.25, 0.25)

## Image grids

In [3]:
import sys, os
# NOTE(review): presumably required so that something imported via `project`
# can find the lung-project conda env's pkg-config files — confirm.
os.environ['PKG_CONFIG_PATH'] = '/ocean/projects/asc170022p/mtragoza/mambaforge/envs/lung-project/lib/pkgconfig'

# Make the parent directory importable so `project` can be imported.
sys.path.append('..')
import project

In [13]:
# Models compared in the image grid: the supervised 'train' model plus the
# per-case 'fit' models with and without the CT image as input.
param_space = ps.ParamSpace(
    mask_roi='lung_regions2',
    mesh_version=[11],
    model_arch=['unet3d'],
    conv_channels=[32],
    rho_value=[1000],
    lr=[1e-5],
) * (
    ps.ParamSpace(
        trainer_task=['train'],
        input_anat=[True],
        input_coords=[False],
        output_func=['softplus'],
        num_epochs=200,
    ) + ps.ParamSpace(
        trainer_task=['fit'],
        input_anat=[True, False],
        input_coords=[True],
        output_func=['softplus'],
        num_epochs=1000
    )
)
# name_formats[i] and expt_names[i] locate the checkpoints of param_space[i].
# NOTE(review): assumes param_space is ordered train, fit(input_anat=True),
# fit(input_anat=False) — confirm against ps.ParamSpace's * / + ordering.
name_formats = [
    'train__emory__{test_case}__{test_phase}__{rho_value}__{output_func}',
    'fit__emory__{input_anat:d}{input_coords:d}__{output_func}__{rho_value}__{test_case}__{test_phase}',
    'fit__emory__{input_anat:d}{input_coords:d}__{output_func}__{rho_value}__{test_case}__{test_phase}',
]
expt_names = [
    '2024-12-08__emory__clamp',
    '2024-12-12__emory__fit',
    '2024-12-12__emory__fit',
]
assert len(name_formats) == len(expt_names) == len(param_space)
len(param_space)

3

In [14]:
%autoreload
import numpy as np
import torch

n_models = 3
n_cases = 9

masks = []
d_masks = []
a_images = []
e_images = []

masks = [[] for i in range(n_models)]
d_masks = [[] for i in range(n_models)]
a_images = [[] for i in range(n_models)]
e_preds = [[] for i in range(n_models)]

for i in range(n_models):
    for j in range(n_cases):
        case_name = ('Case8Deploy' if j == 7 else f'Case{j+1}Pack')
        job_params = param_space[i]
        if job_params['trainer_task'] == 'train':
            job_params['test_case'] = None
        else:
            job_params['test_case'] = case_name
        job_params['test_phase'] = 0
        job_name = name_formats[i].format(**job_params)
        expt_name = expt_names[i]
        print(job_name)

        data_root = f'/ocean/projects/asc170022p/shared/Data/4DLungCT/Emory/'
        fixed_phase = 0
        case = project.imaging.Emory4DCTCase(data_root, case_name, project.imaging.ALL_PHASES)
    
        anat_file = case.nifti_file(fixed_phase)
        mask_file = case.totalseg_mask_file(fixed_phase, roi='lung_regions2')
        d_mask_file = case.medpseg_mask_file(fixed_phase, roi='findings')
        
        mask_nifti = project.data.load_nii_file(mask_file)
        mask = mask_nifti.get_fdata()
        resolution = mask_nifti.header.get_zooms()
        a_image = project.data.load_nii_file(anat_file).get_fdata()
        d_mask = project.data.load_nii_file(d_mask_file).get_fdata()
        print(mask.shape, a_image.shape, d_mask.shape)
    
        shape = a_image.shape
        if job_params['model_arch'] == 'unet3d':
            model = project.model.UNet3D(
                in_channels=1*job_params['input_anat'] + 3*job_params['input_coords'],
                out_channels=1,
                num_levels=3,
                num_conv_layers=2,
                conv_channels=job_params['conv_channels'],
                conv_kernel_size=3,
                output_func=job_params['output_func']
            ).cuda()
            
        elif job_params['model_arch'] == 'param_map':
            model = project.model.ParameterMap(
                shape=(1, shape[0]//2, shape[1]//2, shape[2]//2),
                upsample_mode='nearest',
                conv_kernel_size=3,
                output_func=job_params['output_func']
            ).cuda()
        
        epoch = job_params['num_epochs']
        model_path = f'{expt_name}/{job_name}/state/model_{epoch}.pt'
        model.load_state_dict(torch.load(model_path))
        
        binary_mask = torch.as_tensor(mask > 0, device='cuda')
        c_t = project.training.get_input_coords(binary_mask, resolution)
        a_t = torch.as_tensor(a_image, dtype=torch.float32, device='cuda').unsqueeze(0)
        if job_params['input_anat'] and job_params['input_coords']:
            input_t = torch.cat([a_t, c_t], dim=0)
        elif job_params['input_coords']:
            input_t = c_t
        elif job_params['input_anat']:
            input_t = a_t

        e_t = (model.forward(input_t[None,...])[0,0]*1000).clamp(min=1, max=1e12)
        e_pred = e_t.detach().cpu().numpy() * (mask > 0)
        
        masks[i].append(mask)
        a_images[i].append(a_image)
        e_preds[i].append(e_pred)
        d_masks[i].append(d_mask)
        
        del model


train__emory__None__0__1000__softplus
Loading /ocean/projects/asc170022p/shared/Data/4DLungCT/Emory/Case1Pack/TotalSegment/case1_T00/lung_regions2.nii.gz... (256, 256, 94)
Loading /ocean/projects/asc170022p/shared/Data/4DLungCT/Emory/Case1Pack/NIFTI/case1_T00.nii.gz... (256, 256, 94)
Loading /ocean/projects/asc170022p/shared/Data/4DLungCT/Emory/Case1Pack/medpseg/case1_T00_findings.nii.gz... (256, 256, 94)
(256, 256, 94) (256, 256, 94) (256, 256, 94)


  model.load_state_dict(torch.load(model_path))


train__emory__None__0__1000__softplus
Loading /ocean/projects/asc170022p/shared/Data/4DLungCT/Emory/Case2Pack/TotalSegment/case2_T00/lung_regions2.nii.gz... (256, 256, 94)
Loading /ocean/projects/asc170022p/shared/Data/4DLungCT/Emory/Case2Pack/NIFTI/case2_T00.nii.gz... (256, 256, 94)
Loading /ocean/projects/asc170022p/shared/Data/4DLungCT/Emory/Case2Pack/medpseg/case2_T00_findings.nii.gz... (256, 256, 94)
(256, 256, 94) (256, 256, 94) (256, 256, 94)
train__emory__None__0__1000__softplus
Loading /ocean/projects/asc170022p/shared/Data/4DLungCT/Emory/Case3Pack/TotalSegment/case3_T00/lung_regions2.nii.gz... (256, 256, 94)
Loading /ocean/projects/asc170022p/shared/Data/4DLungCT/Emory/Case3Pack/NIFTI/case3_T00.nii.gz... (256, 256, 94)
Loading /ocean/projects/asc170022p/shared/Data/4DLungCT/Emory/Case3Pack/medpseg/case3_T00_findings.nii.gz... (256, 256, 94)
(256, 256, 94) (256, 256, 94) (256, 256, 94)
train__emory__None__0__1000__softplus
Loading /ocean/projects/asc170022p/shared/Data/4DLungC

fit__emory__01__softplus__1000__Case2Pack__0
Loading /ocean/projects/asc170022p/shared/Data/4DLungCT/Emory/Case2Pack/TotalSegment/case2_T00/lung_regions2.nii.gz... (256, 256, 94)
Loading /ocean/projects/asc170022p/shared/Data/4DLungCT/Emory/Case2Pack/NIFTI/case2_T00.nii.gz... (256, 256, 94)
Loading /ocean/projects/asc170022p/shared/Data/4DLungCT/Emory/Case2Pack/medpseg/case2_T00_findings.nii.gz... (256, 256, 94)
(256, 256, 94) (256, 256, 94) (256, 256, 94)
fit__emory__01__softplus__1000__Case3Pack__0
Loading /ocean/projects/asc170022p/shared/Data/4DLungCT/Emory/Case3Pack/TotalSegment/case3_T00/lung_regions2.nii.gz... (256, 256, 94)
Loading /ocean/projects/asc170022p/shared/Data/4DLungCT/Emory/Case3Pack/NIFTI/case3_T00.nii.gz... (256, 256, 94)
Loading /ocean/projects/asc170022p/shared/Data/4DLungCT/Emory/Case3Pack/medpseg/case3_T00_findings.nii.gz... (256, 256, 94)
(256, 256, 94) (256, 256, 94) (256, 256, 94)
fit__emory__01__softplus__1000__Case4Pack__0
Loading /ocean/projects/asc170022

In [16]:
# Volume geometry used to lay out the image grid. The shape matches the
# loaded images printed above.
# NOTE(review): the voxel spacing is hard-coded — confirm it matches the NIfTI
# headers (`resolution` in the loading loop) for every case shown.
x_shape, y_shape, z_shape = (256, 256, 94)
x_res, y_res, z_res = (0.97, 0.97, 2.5)

# Physical extent of the volume along each axis.
x_extent = (x_shape - 1) * x_res
y_extent = (y_shape - 1) * y_res
z_extent = (z_shape - 1) * z_res

# Panel size: width/height follows the x:z physical extent ratio.
# (The y slice shown is chosen per case in the grid cell below.)
ax_height = 1.5
ax_width = ax_height * x_extent/z_extent
print((ax_height, ax_width))

e_cmap = project.visual.mre_color_map()

(1.5, 1.5958064516129031)


In [22]:
%autoreload

def get_y_slice(mask):
    """Return the y index at which the cumulative mass of `mask` reaches half its total.

    The mass of each y slice is the sum of `mask` over axes 0 and 2, so
    non-binary inputs act as weights. The result is the weighted median
    along y: the first index whose cumulative mass is >= half the total.

    Args:
        mask: 3-D array-like of non-negative weights.

    Returns:
        int y index, or the middle index (shape[1] // 2) if `mask` sums to zero.
    """
    weights = np.asarray(mask)
    cumulative = weights.sum(axis=(0, 2)).cumsum()
    total = cumulative[-1]
    if total <= 0:
        return weights.shape[1] // 2
    # side='left' -> first index whose cumulative mass reaches total / 2.
    return int(np.searchsorted(cumulative, total / 2, side='left'))


def get_vmax(e_pred, mask, percentile=50, scale=6):
    """Return a colour-scale maximum for an elasticity map.

    Args:
        e_pred: predicted elasticity array.
        mask: array of the same shape; only voxels with mask > 0 are used.
        percentile: percentile of the masked values to use (default: median).
        scale: factor applied to that percentile.

    Returns:
        `scale` times the `percentile`-th percentile of e_pred over mask > 0,
        or NaN if the mask selects no voxels.
    """
    masked = np.where(mask > 0, e_pred, np.nan)
    return np.nanpercentile(masked, percentile) * scale
    

# Display names indexed like e_preds / param_space.
model_names = [
    'Our method',
    'Baseline 1',
    'Baseline'
]
model_inds = [0, 2]

# Fixed image columns followed by one elasticity column per entry of model_inds.
columns = ['anat', 'mask', 'd_mask'] + model_inds
cases = range(5)

n_rows = len(cases)
n_cols = len(columns)

fig, axes, cbar_ax = project.visual.subplot_grid(n_rows, n_cols, ax_height, ax_width, space=0, pad=1)

for row_idx, case_idx in enumerate(cases):

    mask = masks[0][case_idx]
    a_image = a_images[0][case_idx]
    # Emphysema mask from a CT threshold: voxels below -910 (presumably HU)
    # that are not label 6 of the region mask. This is what the 'Emphysema'
    # column shows; the loaded medpseg masks (d_masks) are not used here.
    d_mask = (a_image < -910) * (mask != 6)

    # mask**2 weights each voxel by its label value squared when choosing the
    # slice. NOTE(review): confirm this weighting is intended.
    y_slice = get_y_slice(mask**2)

    for col_idx, column in enumerate(columns):
        ax = axes[row_idx, col_idx]

        if col_idx == 0:
            ax.set_ylabel(f'Case {case_idx+1}')

        if column == 'anat':
            if row_idx == 0:
                ax.set_title('CT image')
            project.visual.imshow(
                ax, a_image[:, y_slice, :],
                aspect=z_res/x_res,
                interpolation_stage='rgba',
                **project.visual.get_color_map('CT')
            )

        elif column == 'mask':
            if row_idx == 0:
                ax.set_title('Vessels/airways')
            project.visual.imshow(
                ax, mask[:, y_slice, :],
                aspect=z_res/x_res,
                interpolation_stage='rgba',
                **project.visual.get_color_map('regions')
            )

        elif column == 'd_mask':
            if row_idx == 0:
                ax.set_title('Emphysema')
            # 0 outside the lung, 1 lung, 2 emphysematous lung.
            project.visual.imshow(
                ax, ((1 + d_mask)*(mask > 0))[:, y_slice, :],
                aspect=z_res/x_res,
                interpolation_stage='rgba',
                **project.visual.get_color_map('regions')
            )

        else:  # elasticity map predicted by model `column`
            model_idx = column
            if row_idx == 0:
                ax.set_title(model_names[model_idx])
            # e_preds[model][case] was filled in case order, so index by case_idx
            # (not row_idx) to stay aligned with the mask/CT of this row.
            e_pred = e_preds[model_idx][case_idx]
            vmax = get_vmax(e_pred, mask)
            project.visual.imshow(
                ax, e_pred[:, y_slice, :],
                aspect=z_res/x_res,
                interpolation_stage='rgba',
                cmap=e_cmap, vmax=vmax, vmin=-vmax
            )

        ax.set_xticks([])
        ax.set_yticks([])

for ext in ['png', 'pdf']:
    fig.savefig(f'emory_lung_images.{ext}', bbox_inches='tight', dpi=400)

<IPython.core.display.Javascript object>