In [1]:
from fit.datamodules.tomo_rec.TRecDataModule import LoDoPaBFourierTargetDataModule
from fit.utils import convert2FC, fft_interpolate, psfft, convert_to_dft, PSNR
from fit.utils.tomo_utils import get_proj_coords_pol, get_img_coords_pol
from matplotlib import pyplot as plt

import torch

import numpy as np

from skimage.transform import iradon

from fit.utils.utils import denormalize, PSNR

from fit.modules import TRecTransformerModule

from matplotlib import gridspec

from tqdm import tqdm_notebook as tqdm

from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint

In [2]:
# Experiment configuration used by the data module and geometry cells below.
batch_size, num_angles, img_shape = 1, 33, 111
# NOTE(review): `inner_circle` is not referenced by any later cell in this
# notebook — confirm whether it was meant to be passed to the data module/model.
inner_circle = True

In [3]:
# Build the LoDoPaB data module with the configuration above and take its
# test loader.
dm = LoDoPaBFourierTargetDataModule(
    batch_size=batch_size,
    gt_shape=img_shape,
    num_angles=num_angles,
)
dm.setup()
test_dl = dm.test_dataloader()

In [4]:
# Detector length: first dimension of the ray transform's detector shape.
ray_trafo = dm.gt_ds.get_ray_trafo()
det_len = ray_trafo.geometry.detector.shape[0]

In [5]:
# Polar (r, phi) coordinates for the projection data and for the target image,
# plus the flattening indices / ordering that are passed to the model below.
angles = dm.gt_ds.get_ray_trafo().geometry.angles
proj_r, proj_phi, src_flatten = get_proj_coords_pol(angles=angles, det_len=det_len)
target_r, target_phi, dst_flatten, order = get_img_coords_pol(
    img_shape=img_shape,
    det_len=det_len,
)

In [6]:
# Trainer used for evaluation only (every later cell calls `trainer.test`).
# `seed_everything` is imported in the first cell but was never called, so the
# `deterministic=True` runs had no fixed RNG seed; seed once here, before any
# test run. 42 is an arbitrary choice.
seed_everything(42)

trainer = Trainer(
    max_epochs=20,
    gpus=1,
    checkpoint_callback=ModelCheckpoint(
        filepath=None,
        save_top_k=1,
        verbose=False,
        save_last=True,
        monitor='Train/avg_val_mse',
        mode='min',
        prefix='best_val_loss_',
    ),
    deterministic=True,
)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


# FIT - TRec (Ours)

In [23]:
# Best-validation checkpoint ('best_val_loss_' prefix) of the FIT-TRec run.
# NOTE(review): absolute, user-specific path — only resolves on the original machine.
best_path = ('/home/tbuchhol/HaarVAE/FIT_experiments/iccv_experiments/TRec/'
             'trec_LoDoPaB111_v0.1.24_fbp_prod_dconv8_nh8_dq32_icTrue_ibf8_bfc5_nl4_me350_ta33_pc/'
             'lightning_logs/version_39398340/checkpoints/'
             'best_val_loss_-epoch=75-step=75999.ckpt')

In [24]:
# Record the evaluated checkpoint path in the cell output.
print(f'{best_path}')

/home/tbuchhol/HaarVAE/FIT_experiments/iccv_experiments/TRec/trec_LoDoPaB111_v0.1.24_fbp_prod_dconv8_nh8_dq32_icTrue_ibf8_bfc5_nl4_me350_ta33_pc/lightning_logs/version_39398340/checkpoints/best_val_loss_-epoch=75-step=75999.ckpt


In [25]:
# Restore the checkpoint, supplying the geometry/coordinate arguments computed
# above. Variant: not encoder-only, FBP input enabled.
model = TRecTransformerModule.load_from_checkpoint(
    best_path,
    y_coords_proj=proj_r,
    x_coords_proj=proj_phi,
    y_coords_img=target_r,
    x_coords_img=target_phi,
    angles=dm.gt_ds.get_ray_trafo().geometry.angles,
    src_flatten_coords=src_flatten,
    dst_flatten_coords=dst_flatten,
    dst_order=order,
    encoder_only=False,
    use_fbp=True,
)

In [26]:
# `trainer.test` returns a list of metric dicts (one test dataloader here);
# keep the first.
all_test_results = trainer.test(model, datamodule=dm)
test_res = all_test_results[0]



HBox(children=(HTML(value='Testing'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=…



bin_factor set to 1.

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'Mean PSNR': array(30.976335525512695, dtype=float32),
 'SEM PSNR': array(0.03192515671253204, dtype=float32)}
--------------------------------------------------------------------------------


In [27]:
# Mean test PSNR to two decimals. Built-in `round` on the Python float from
# `.item()` yields a plain float, so the output avoids the np.float64 repr
# noise (e.g. `21.899999999999999`) that `np.round` produced in this notebook.
round(test_res['Mean PSNR'].item(), 2)

30.98

# Fourier Query Points

In [7]:
# Best-validation checkpoint ('best_val_loss_' prefix) of the run without FBP input.
# NOTE(review): absolute, user-specific path — only resolves on the original machine.
best_path = ('/home/tbuchhol/HaarVAE/FIT_experiments/iccv_experiments/TRec/'
             'trec_LoDoPaB111_v0.1.24_zero_prod_dconv8_nh8_dq32_icTrue_ibf8_bfc5_nl4_me350_ta33_pc/'
             'lightning_logs/version_39413815/checkpoints/'
             'best_val_loss_-epoch=339-step=339999.ckpt')

In [8]:
# Record the evaluated checkpoint path in the cell output.
print(f'{best_path}')

/home/tbuchhol/HaarVAE/FIT_experiments/iccv_experiments/TRec/trec_LoDoPaB111_v0.1.24_zero_prod_dconv8_nh8_dq32_icTrue_ibf8_bfc5_nl4_me350_ta33_pc/lightning_logs/version_39413815/checkpoints/best_val_loss_-epoch=339-step=339999.ckpt


In [9]:
# Restore the checkpoint, supplying the geometry/coordinate arguments computed
# above. Variant: not encoder-only, FBP input disabled.
model = TRecTransformerModule.load_from_checkpoint(
    best_path,
    y_coords_proj=proj_r,
    x_coords_proj=proj_phi,
    y_coords_img=target_r,
    x_coords_img=target_phi,
    angles=dm.gt_ds.get_ray_trafo().geometry.angles,
    src_flatten_coords=src_flatten,
    dst_flatten_coords=dst_flatten,
    dst_order=order,
    encoder_only=False,
    use_fbp=False,
)

In [10]:
# `trainer.test` returns a list of metric dicts (one test dataloader here);
# keep the first.
all_test_results = trainer.test(model, datamodule=dm)
test_res = all_test_results[0]



HBox(children=(HTML(value='Testing'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=…



bin_factor set to 1.

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'Mean PSNR': array(21.899662017822266, dtype=float32),
 'SEM PSNR': array(0.04256994277238846, dtype=float32)}
--------------------------------------------------------------------------------


In [11]:
# Mean test PSNR to two decimals. Built-in `round` on the Python float from
# `.item()` yields a plain float, so the output avoids the np.float64 repr
# noise (e.g. `21.899999999999999`) that `np.round` produced in this notebook.
round(test_res['Mean PSNR'].item(), 2)

21.899999999999999

# Encoder Only

In [12]:
# Best-validation checkpoint ('best_val_loss_' prefix) of the encoder-only run.
# NOTE(review): absolute, user-specific path — only resolves on the original machine.
best_path = ('/home/tbuchhol/HaarVAE/FIT_experiments/iccv_experiments/TRec/'
             'trec_LoDoPaB111_v0.1.24_eo_fbp_prod_dconv8_nh8_dq32_icTrue_ibf8_bfc5_nl4_me350_ta33_pc/'
             'lightning_logs/version_39398338/checkpoints/'
             'best_val_loss_-epoch=80-step=80999.ckpt')

In [13]:
# Record the evaluated checkpoint path in the cell output.
print(f'{best_path}')

/home/tbuchhol/HaarVAE/FIT_experiments/iccv_experiments/TRec/trec_LoDoPaB111_v0.1.24_eo_fbp_prod_dconv8_nh8_dq32_icTrue_ibf8_bfc5_nl4_me350_ta33_pc/lightning_logs/version_39398338/checkpoints/best_val_loss_-epoch=80-step=80999.ckpt


In [14]:
# Restore the checkpoint, supplying the geometry/coordinate arguments computed
# above. Variant: encoder-only, FBP input enabled.
model = TRecTransformerModule.load_from_checkpoint(
    best_path,
    y_coords_proj=proj_r,
    x_coords_proj=proj_phi,
    y_coords_img=target_r,
    x_coords_img=target_phi,
    angles=dm.gt_ds.get_ray_trafo().geometry.angles,
    src_flatten_coords=src_flatten,
    dst_flatten_coords=dst_flatten,
    dst_order=order,
    encoder_only=True,
    use_fbp=True,
)

In [15]:
# `trainer.test` returns a list of metric dicts (one test dataloader here);
# keep the first.
all_test_results = trainer.test(model, datamodule=dm)
test_res = all_test_results[0]



HBox(children=(HTML(value='Testing'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=…



bin_factor set to 1.

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'Mean PSNR': array(30.743947982788086, dtype=float32),
 'SEM PSNR': array(0.0313936248421669, dtype=float32)}
--------------------------------------------------------------------------------


In [16]:
# Mean test PSNR to two decimals. Built-in `round` on the Python float from
# `.item()` yields a plain float, so the output avoids the np.float64 repr
# noise (e.g. `30.739999999999998`) that `np.round` produced in this notebook.
round(test_res['Mean PSNR'].item(), 2)

30.739999999999998

# ConvBlock Only

In [17]:
# Best-validation checkpoint ('best_val_loss_' prefix) of the ConvBlock-only baseline run.
# NOTE(review): absolute, user-specific path — only resolves on the original machine.
best_path = ('/home/tbuchhol/HaarVAE/FIT_experiments/iccv_experiments/TRec/'
             'trec_LoDoPaB111_v0.1.24_baseline_dconv8_me350_ta33/'
             'lightning_logs/version_39398342/checkpoints/'
             'best_val_loss_-epoch=221-step=221999.ckpt')

In [18]:
# Record the evaluated checkpoint path in the cell output.
print(f'{best_path}')

/home/tbuchhol/HaarVAE/FIT_experiments/iccv_experiments/TRec/trec_LoDoPaB111_v0.1.24_baseline_dconv8_me350_ta33/lightning_logs/version_39398342/checkpoints/best_val_loss_-epoch=221-step=221999.ckpt


In [19]:
# Restore the checkpoint, supplying the geometry/coordinate arguments computed
# above. Variant: ConvBlock only (`convblock_only=True`), FBP input enabled.
model = TRecTransformerModule.load_from_checkpoint(
    best_path,
    y_coords_proj=proj_r,
    x_coords_proj=proj_phi,
    y_coords_img=target_r,
    x_coords_img=target_phi,
    angles=dm.gt_ds.get_ray_trafo().geometry.angles,
    src_flatten_coords=src_flatten,
    dst_flatten_coords=dst_flatten,
    dst_order=order,
    encoder_only=False,
    use_fbp=True,
    convblock_only=True,
)

In [20]:
# `trainer.test` returns a list of metric dicts (one test dataloader here);
# keep the first.
all_test_results = trainer.test(model, datamodule=dm)
test_res = all_test_results[0]



HBox(children=(HTML(value='Testing'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=…


--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'Mean PSNR': array(30.701017379760742, dtype=float32),
 'SEM PSNR': array(0.03536752611398697, dtype=float32)}
--------------------------------------------------------------------------------


In [21]:
# Mean test PSNR to two decimals. Built-in `round` on the Python float from
# `.item()` yields a plain float, so the output avoids the np.float64 repr
# noise (e.g. `30.699999999999999`) that `np.round` produced in this notebook.
round(test_res['Mean PSNR'].item(), 2)

30.699999999999999

# FBP baseline (skimage `iradon`)

In [22]:
# FBP baseline: reconstruct every test sinogram with skimage's filtered
# back-projection (`iradon`) and score it with the same denormalisation,
# circular mask and PSNR used for the learned models.
test_dl = dm.test_dataloader()

# Loop invariants, evaluated once instead of once per sample.
# `iradon` takes angles in degrees; the sign flip is kept from the original
# (presumably matching the ray transform's rotation direction — confirm).
theta = -np.rad2deg(dm.gt_ds.get_ray_trafo().geometry.angles)
# NOTE(review): the mask comes from whichever `model` was loaded last (the
# ConvBlock-only model on a top-to-bottom run) — confirm it is identical for
# every model compared in this notebook.
circle = model.circle.cpu()

fbp_psnrs = []
for i in tqdm(range(len(test_dl.dataset.ds))):
    sino, img = test_dl.dataset.ds[i]
    # Transposed into and out of `iradon`, presumably to match its
    # (detector, angle) axis layout — TODO confirm the dataset's sinogram layout.
    rec = torch.from_numpy(np.array(iradon(sino.numpy().T, theta=theta).T))
    img = denormalize(img, dm.mean, dm.std)
    rec = denormalize(rec, dm.mean, dm.std)
    rec *= circle
    img *= circle
    fbp_psnrs.append(PSNR(img, rec, img.max() - img.min()))

# Built-in `round` returns a plain float, avoiding the np.float64 repr noise
# (`26.890000000000001`) that `np.round` produced here.
round(torch.mean(torch.stack(fbp_psnrs)).item(), 2)

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  after removing the cwd from sys.path.


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=3553.0), HTML(value='')))




26.890000000000001