In [1]:
from fit.datamodules.tomo_rec import MNISTTomoFourierTargetDataModule
from fit.utils import convert2FC, fft_interpolate, psfft, convert_to_dft, PSNR
from fit.utils.tomo_utils import get_proj_coords_pol, get_img_coords_pol
from matplotlib import pyplot as plt

import torch

import numpy as np

from skimage.transform import iradon

from fit.utils.utils import denormalize, PSNR

from fit.modules import TRecTransformerModule

from matplotlib import gridspec

from tqdm import tqdm_notebook as tqdm

from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint

In [2]:
# Experiment configuration shared by every evaluation in this notebook.
num_angles = 7        # forwarded to the data module as `num_angles`
inner_circle = True   # forwarded to the data module as `inner_circle`
img_shape = 27        # passed to get_img_coords_pol as `img_shape`
batch_size = 8        # forwarded to the data module as `batch_size`

In [3]:
# Build the MNIST tomography data module from the configuration cell and
# prepare its datasets before asking for the test loader.
dm = MNISTTomoFourierTargetDataModule(
    root_dir='/data/mnist/',
    batch_size=batch_size,
    num_angles=num_angles,
    inner_circle=inner_circle,
)
dm.setup()

# Loader over the test split.
test_dl = dm.test_dataloader()

In [4]:
# Detector length: first entry of the detector shape in the ray-transform geometry.
ray_trafo = dm.gt_ds.get_ray_trafo()
det_len = ray_trafo.geometry.detector.shape[0]

In [5]:
# Coordinates of the projection samples, derived from the acquisition angles
# and the detector length.
proj_angles = dm.gt_ds.get_ray_trafo().geometry.angles
proj_r, proj_phi, src_flatten = get_proj_coords_pol(angles=proj_angles, det_len=det_len)

# Coordinates of the target image samples, plus the ordering returned with them.
target_r, target_phi, dst_flatten, order = get_img_coords_pol(
    img_shape=img_shape,
    det_len=det_len,
)

In [6]:
# Checkpoint callback handed to the trainer: keeps the single best checkpoint
# according to 'Train/avg_val_mse' (lower is better) and also the last one.
best_val_checkpoint = ModelCheckpoint(
    filepath=None,
    save_top_k=1,
    verbose=False,
    save_last=True,
    monitor='Train/avg_val_mse',
    mode='min',
    prefix='best_val_loss_',
)

# One trainer on a single GPU; it is reused for every `trainer.test` call below.
trainer = Trainer(
    max_epochs=20,
    gpus=1,
    checkpoint_callback=best_val_checkpoint,
    deterministic=True,
)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


# FIT - TRec (Ours)

In [7]:
# Checkpoint evaluated in this section (adjacent literals concatenate into one path).
best_path = (
    '/home/tbuchhol/HaarVAE/FIT_experiments/iccv_experiments/TRec/'
    'trec_mnist_v0.1.24_fbp_prod_dconv8_nh8_dq32_icTrue_ibf2_bfc5_nl4_me300_ta7_pc/'
    'lightning_logs/version_39398335/checkpoints/'
    'best_val_loss_-epoch=298-step=513980.ckpt'
)

In [8]:
# Echo the checkpoint path so the notebook output records which run is evaluated.
print(best_path)

/home/tbuchhol/HaarVAE/FIT_experiments/iccv_experiments/TRec/trec_mnist_v0.1.24_fbp_prod_dconv8_nh8_dq32_icTrue_ibf2_bfc5_nl4_me300_ta7_pc/lightning_logs/version_39398335/checkpoints/best_val_loss_-epoch=298-step=513980.ckpt


In [9]:
# Restore the module from `best_path`, supplying the projection/image
# coordinates and acquisition angles computed earlier.
# Flags for this run: encoder_only=False, use_fbp=True.
ray_angles = dm.gt_ds.get_ray_trafo().geometry.angles
model = TRecTransformerModule.load_from_checkpoint(
    best_path,
    y_coords_proj=proj_r,
    x_coords_proj=proj_phi,
    y_coords_img=target_r,
    x_coords_img=target_phi,
    angles=ray_angles,
    src_flatten_coords=src_flatten,
    dst_flatten_coords=dst_flatten,
    dst_order=order,
    encoder_only=False,
    use_fbp=True,
)

In [10]:
# Run the test loop on the data module; keep the first entry of the returned
# list, which holds the metrics dict.
test_results = trainer.test(model, datamodule=dm)
test_res = test_results[0]



HBox(children=(HTML(value='Testing'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=…



bin_factor set to 1.

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'Mean PSNR': array(27.84994888305664, dtype=float32),
 'SEM PSNR': array(0.025329159572720528, dtype=float32)}
--------------------------------------------------------------------------------


In [11]:
# Mean test PSNR of this model, rounded to two decimals for reporting.
mean_psnr = test_res['Mean PSNR'].item()
np.round(mean_psnr, 2)

27.850000000000001

# Fourier Query Points

In [7]:
# Checkpoint evaluated in this section (adjacent literals concatenate into one path).
best_path = (
    '/home/tbuchhol/HaarVAE/FIT_experiments/iccv_experiments/TRec/'
    'trec_mnist_v0.1.24_zero_prod_dconv8_nh8_dq32_icTrue_ibf2_bfc5_nl4_me300_ta7_pc/'
    'lightning_logs/version_39413814/checkpoints/'
    'best_val_loss_-epoch=261-step=450377.ckpt'
)

In [8]:
# Echo the checkpoint path so the notebook output records which run is evaluated.
print(best_path)

/home/tbuchhol/HaarVAE/FIT_experiments/iccv_experiments/TRec/trec_mnist_v0.1.24_zero_prod_dconv8_nh8_dq32_icTrue_ibf2_bfc5_nl4_me300_ta7_pc/lightning_logs/version_39413814/checkpoints/best_val_loss_-epoch=261-step=450377.ckpt


In [9]:
# Restore the module from `best_path`, supplying the projection/image
# coordinates and acquisition angles computed earlier.
# Flags for this run: encoder_only=False, use_fbp=False.
ray_angles = dm.gt_ds.get_ray_trafo().geometry.angles
model = TRecTransformerModule.load_from_checkpoint(
    best_path,
    y_coords_proj=proj_r,
    x_coords_proj=proj_phi,
    y_coords_img=target_r,
    x_coords_img=target_phi,
    angles=ray_angles,
    src_flatten_coords=src_flatten,
    dst_flatten_coords=dst_flatten,
    dst_order=order,
    encoder_only=False,
    use_fbp=False,
)

In [10]:
# Run the test loop on the data module; keep the first entry of the returned
# list, which holds the metrics dict.
test_results = trainer.test(model, datamodule=dm)
test_res = test_results[0]



HBox(children=(HTML(value='Testing'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=…



bin_factor set to 1.

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'Mean PSNR': array(27.901988983154297, dtype=float32),
 'SEM PSNR': array(0.026088178157806396, dtype=float32)}
--------------------------------------------------------------------------------


In [11]:
# Mean test PSNR of this model, rounded to two decimals for reporting.
mean_psnr = test_res['Mean PSNR'].item()
np.round(mean_psnr, 2)

27.899999999999999

# Encoder Only

In [17]:
# Checkpoint evaluated in this section (adjacent literals concatenate into one path).
best_path = (
    '/home/tbuchhol/HaarVAE/FIT_experiments/iccv_experiments/TRec/'
    'trec_mnist_v0.1.24_eo_fbp_prod_dconv8_nh8_dq32_icTrue_ibf2_bfc5_nl4_me300_ta7_pc/'
    'lightning_logs/version_39398334/checkpoints/'
    'best_val_loss_-epoch=295-step=508823.ckpt'
)

In [18]:
# Echo the checkpoint path so the notebook output records which run is evaluated.
print(best_path)

/home/tbuchhol/HaarVAE/FIT_experiments/iccv_experiments/TRec/trec_mnist_v0.1.24_eo_fbp_prod_dconv8_nh8_dq32_icTrue_ibf2_bfc5_nl4_me300_ta7_pc/lightning_logs/version_39398334/checkpoints/best_val_loss_-epoch=295-step=508823.ckpt


In [19]:
# Restore the module from `best_path`, supplying the projection/image
# coordinates and acquisition angles computed earlier.
# Flags for this run: encoder_only=True, use_fbp=True.
ray_angles = dm.gt_ds.get_ray_trafo().geometry.angles
model = TRecTransformerModule.load_from_checkpoint(
    best_path,
    y_coords_proj=proj_r,
    x_coords_proj=proj_phi,
    y_coords_img=target_r,
    x_coords_img=target_phi,
    angles=ray_angles,
    src_flatten_coords=src_flatten,
    dst_flatten_coords=dst_flatten,
    dst_order=order,
    encoder_only=True,
    use_fbp=True,
)

In [20]:
# Run the test loop on the data module; keep the first entry of the returned
# list, which holds the metrics dict.
test_results = trainer.test(model, datamodule=dm)
test_res = test_results[0]

HBox(children=(HTML(value='Testing'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=…

bin_factor set to 1.

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'Mean PSNR': array(26.892974853515625, dtype=float32),
 'SEM PSNR': array(0.024121969938278198, dtype=float32)}
--------------------------------------------------------------------------------


In [21]:
# Mean test PSNR of this model, rounded to two decimals for reporting.
mean_psnr = test_res['Mean PSNR'].item()
np.round(mean_psnr, 2)

26.890000000000001

# ConvBlock Only

In [22]:
# Checkpoint evaluated in this section (adjacent literals concatenate into one path).
best_path = (
    '/home/tbuchhol/HaarVAE/FIT_experiments/iccv_experiments/TRec/'
    'trec_mnist_v0.1.24_baseline_dconv8_me300_ta7/'
    'lightning_logs/version_39398337/checkpoints/'
    'best_val_loss_-epoch=130-step=225188.ckpt'
)

In [23]:
# Echo the checkpoint path so the notebook output records which run is evaluated.
print(best_path)

/home/tbuchhol/HaarVAE/FIT_experiments/iccv_experiments/TRec/trec_mnist_v0.1.24_baseline_dconv8_me300_ta7/lightning_logs/version_39398337/checkpoints/best_val_loss_-epoch=130-step=225188.ckpt


In [24]:
# Restore the module from `best_path`, supplying the projection/image
# coordinates and acquisition angles computed earlier.
# Flags for this run: encoder_only=False, use_fbp=True, convblock_only=True.
ray_angles = dm.gt_ds.get_ray_trafo().geometry.angles
model = TRecTransformerModule.load_from_checkpoint(
    best_path,
    y_coords_proj=proj_r,
    x_coords_proj=proj_phi,
    y_coords_img=target_r,
    x_coords_img=target_phi,
    angles=ray_angles,
    src_flatten_coords=src_flatten,
    dst_flatten_coords=dst_flatten,
    dst_order=order,
    encoder_only=False,
    use_fbp=True,
    convblock_only=True,
)

In [25]:
# Run the test loop on the data module; keep the first entry of the returned
# list, which holds the metrics dict.
test_results = trainer.test(model, datamodule=dm)
test_res = test_results[0]

HBox(children=(HTML(value='Testing'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=…


--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'Mean PSNR': array(22.525243759155273, dtype=float32),
 'SEM PSNR': array(0.017983626574277878, dtype=float32)}
--------------------------------------------------------------------------------


In [26]:
# Mean test PSNR of this model, rounded to two decimals for reporting.
mean_psnr = test_res['Mean PSNR'].item()
np.round(mean_psnr, 2)

22.530000000000001

# FBP 

In [27]:
# Filtered back-projection baseline: reconstruct every test sinogram with
# skimage's `iradon` and score it with the same PSNR helper.
#
# NOTE: this cell relies on `model` from the previous section (the last
# checkpoint loaded) solely for its `circle` attribute -- presumably the
# inner-circle mask; confirm against TRecTransformerModule.
test_dl = dm.test_dataloader()

# Loop invariants, computed once instead of on each of the test samples:
# the projection angles in degrees (negated, as in the original call) and
# the CPU copy of the mask.
fbp_theta = -np.rad2deg(dm.gt_ds.get_ray_trafo().geometry.angles)
circle_mask = model.circle.cpu()
sino_ds = test_dl.dataset.ds

fbp_psnrs = []
for i in tqdm(range(len(sino_ds))):
    sino, img = sino_ds[i]
    # iradon receives the transposed sinogram; its output is transposed back
    # before being wrapped as a tensor.
    rec = torch.from_numpy(np.array(iradon(sino.numpy().T, theta=fbp_theta).T))
    # Undo the dataset normalisation on both images before comparing them.
    img = denormalize(img, dm.mean, dm.std)
    rec = denormalize(rec, dm.mean, dm.std)
    # Apply the same mask to reconstruction and ground truth.
    rec *= circle_mask
    img *= circle_mask
    # PSNR with the ground-truth value range as the data range.
    fbp_psnrs.append(PSNR(img, rec, img.max() - img.min()))

# Mean FBP PSNR over the test set, rounded to two decimals for reporting.
np.round(torch.mean(torch.stack(fbp_psnrs)).item(), 2)

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  after removing the cwd from sys.path.


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=10000.0), HTML(value='')))




17.870000000000001