In [1]:
from fit.datamodules.tomo_rec.TRecDataModule import KanjiFourierTargetDataModule
from fit.utils import convert2FC, fft_interpolate, psfft, convert_to_dft, PSNR
from fit.utils.tomo_utils import get_proj_coords_pol, get_img_coords_pol
from matplotlib import pyplot as plt

import torch

import numpy as np

from skimage.transform import iradon

from fit.utils.utils import denormalize, PSNR

from fit.modules import TRecTransformerModule

from matplotlib import gridspec

from tqdm import tqdm_notebook as tqdm

from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint

In [2]:
# Experiment configuration: batch size and number of projection angles feed the
# data module, `img_shape` feeds the image-coordinate helper.
# NOTE(review): `inner_circle` is not referenced by any cell in this notebook.
batch_size, num_angles, img_shape, inner_circle = 1, 33, 63, True

In [3]:
# Kanji data module with the configured batch size / number of angles.
# NOTE: absolute, machine-specific data path — adjust for your environment.
dm = KanjiFourierTargetDataModule(
    root_dir='/data/kkanji2/',
    batch_size=batch_size,
    num_angles=num_angles,
)
dm.setup()

test_dl = dm.test_dataloader()

In [4]:
# Detector length: first entry of the detector shape of the data set's ray transform geometry.
det_len = (dm.gt_ds
           .get_ray_trafo()
           .geometry.detector.shape[0])

In [5]:
# Coordinates of the projection (source) samples and of the image (target)
# samples, as returned by the `*_coords_pol` helpers, plus the flattening /
# ordering information the model needs.
proj_r, proj_phi, src_flatten = get_proj_coords_pol(
    angles=dm.gt_ds.get_ray_trafo().geometry.angles,
    det_len=det_len,
)
target_r, target_phi, dst_flatten, order = get_img_coords_pol(
    img_shape=img_shape,
    det_len=det_len,
)

In [6]:
# Lightning trainer; this notebook only ever calls `trainer.test` with it.
# NOTE(review): `max_epochs` and the checkpoint callback only matter for
# fitting and look like leftovers from the training script. `filepath` /
# `prefix` are old-style ModelCheckpoint arguments; confirm against the
# installed Lightning version.
trainer = Trainer(
    max_epochs=20,
    gpus=1,
    checkpoint_callback=ModelCheckpoint(
        filepath=None,
        save_top_k=1,
        verbose=False,
        save_last=True,
        monitor='Train/avg_val_mse',
        mode='min',
        prefix='best_val_loss_',
    ),
    deterministic=True,
)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


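# FIT - TRec (Ours)

Note: the cells in this section were not executed in the saved notebook (`In [None]`), so no test PSNR is recorded for this model.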

In [7]:
# Checkpoint of the FIT - TRec (FBP input) run evaluated in this section.
# NOTE: absolute, user-specific path — adjust for your machine.
best_path = ('/home/tbuchhol/HaarVAE/FIT_experiments/iccv_experiments/TRec/'
             'trec_Kanji_v0.1.24_fbp_prod_dconv8_nh8_dq32_icTrue_ibf2_bfc5_nl4_me120_ta33_pc/'
             'lightning_logs/version_39398130/checkpoints/'
             'best_val_loss_-epoch=65-step=412499.ckpt')

In [8]:
# Echo the checkpoint path so the saved notebook records which run was evaluated.
print(f'{best_path}')

/home/tbuchhol/HaarVAE/FIT_experiments/iccv_experiments/TRec/trec_Kanji_v0.1.24_fbp_prod_dconv8_nh8_dq32_icTrue_ibf2_bfc5_nl4_me120_ta33_pc/lightning_logs/version_39398130/checkpoints/best_val_loss_-epoch=65-step=412499.ckpt


In [None]:
# Restore the trained module. The extra keyword arguments are forwarded by
# Lightning's `load_from_checkpoint` to the module constructor: the coordinate
# grids computed above plus the variant flags for this section.
model = TRecTransformerModule.load_from_checkpoint(
    best_path,
    # (r, phi) grids returned by get_proj_coords_pol
    y_coords_proj=proj_r,
    x_coords_proj=proj_phi,
    # (r, phi) grids returned by get_img_coords_pol
    y_coords_img=target_r,
    x_coords_img=target_phi,
    angles=dm.gt_ds.get_ray_trafo().geometry.angles,
    src_flatten_coords=src_flatten,
    dst_flatten_coords=dst_flatten,
    dst_order=order,
    # variant flags: full model (not encoder only), FBP input
    encoder_only=False,
    use_fbp=True,
)

In [None]:
# `trainer.test` returns one metrics dict per test dataloader; keep the only one.
test_res = trainer.test(model=model, datamodule=dm)[0]

In [None]:
# Mean test PSNR, rounded to two decimals for reporting.
np.around(test_res['Mean PSNR'].item(), decimals=2)

# Fourier Query Points

In [7]:
# Checkpoint of the "Fourier Query Points" run evaluated in this section.
# NOTE: absolute, user-specific path — adjust for your machine.
best_path = ('/home/tbuchhol/HaarVAE/FIT_experiments/iccv_experiments/TRec/'
             'trec_Kanji_v0.1.24_zero_prod_dconv8_nh8_dq32_icTrue_ibf2_bfc5_nl4_me120_ta33_pc/'
             'lightning_logs/version_39413813/checkpoints/'
             'best_val_loss_-epoch=119-step=749999.ckpt')

In [8]:
# Echo the checkpoint path so the saved notebook records which run was evaluated.
print(f'{best_path}')

/home/tbuchhol/HaarVAE/FIT_experiments/iccv_experiments/TRec/trec_Kanji_v0.1.24_zero_prod_dconv8_nh8_dq32_icTrue_ibf2_bfc5_nl4_me120_ta33_pc/lightning_logs/version_39413813/checkpoints/best_val_loss_-epoch=119-step=749999.ckpt


In [9]:
# Restore the trained module. The extra keyword arguments are forwarded by
# Lightning's `load_from_checkpoint` to the module constructor: the coordinate
# grids computed above plus the variant flags for this section.
model = TRecTransformerModule.load_from_checkpoint(
    best_path,
    # (r, phi) grids returned by get_proj_coords_pol
    y_coords_proj=proj_r,
    x_coords_proj=proj_phi,
    # (r, phi) grids returned by get_img_coords_pol
    y_coords_img=target_r,
    x_coords_img=target_phi,
    angles=dm.gt_ds.get_ray_trafo().geometry.angles,
    src_flatten_coords=src_flatten,
    dst_flatten_coords=dst_flatten,
    dst_order=order,
    # variant flags: full model (not encoder only), no FBP input
    encoder_only=False,
    use_fbp=False,
)

In [10]:
# `trainer.test` returns one metrics dict per test dataloader; keep the only one.
test_res = trainer.test(model=model, datamodule=dm)[0]



HBox(children=(HTML(value='Testing'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=…



bin_factor set to 1.

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'Mean PSNR': array(25.990751266479492, dtype=float32),
 'SEM PSNR': array(0.028939001262187958, dtype=float32)}
--------------------------------------------------------------------------------


In [11]:
# Mean test PSNR, rounded to two decimals for reporting.
np.around(test_res['Mean PSNR'].item(), decimals=2)

25.989999999999998

# Encoder Only

In [8]:
# Checkpoint of the "Encoder Only" run evaluated in this section.
# NOTE: absolute, user-specific path — adjust for your machine.
best_path = ('/home/tbuchhol/HaarVAE/FIT_experiments/iccv_experiments/TRec/'
             'trec_Kanji_v0.1.24_eo_fbp_prod_dconv8_nh8_dq32_icTrue_ibf2_bfc5_nl4_me120_ta33_pc/'
             'lightning_logs/version_39398332/checkpoints/'
             'best_val_loss_-epoch=97-step=612499.ckpt')

In [9]:
# Echo the checkpoint path so the saved notebook records which run was evaluated.
print(f'{best_path}')

/home/tbuchhol/HaarVAE/FIT_experiments/iccv_experiments/TRec/trec_Kanji_v0.1.24_eo_fbp_prod_dconv8_nh8_dq32_icTrue_ibf2_bfc5_nl4_me120_ta33_pc/lightning_logs/version_39398332/checkpoints/best_val_loss_-epoch=97-step=612499.ckpt


In [10]:
# Restore the trained module. The extra keyword arguments are forwarded by
# Lightning's `load_from_checkpoint` to the module constructor: the coordinate
# grids computed above plus the variant flags for this section.
model = TRecTransformerModule.load_from_checkpoint(
    best_path,
    # (r, phi) grids returned by get_proj_coords_pol
    y_coords_proj=proj_r,
    x_coords_proj=proj_phi,
    # (r, phi) grids returned by get_img_coords_pol
    y_coords_img=target_r,
    x_coords_img=target_phi,
    angles=dm.gt_ds.get_ray_trafo().geometry.angles,
    src_flatten_coords=src_flatten,
    dst_flatten_coords=dst_flatten,
    dst_order=order,
    # variant flags: encoder only, FBP input
    encoder_only=True,
    use_fbp=True,
)

In [11]:
# `trainer.test` returns one metrics dict per test dataloader; keep the only one.
test_res = trainer.test(model=model, datamodule=dm)[0]



HBox(children=(HTML(value='Testing'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=…



bin_factor set to 1.

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'Mean PSNR': array(30.493745803833008, dtype=float32),
 'SEM PSNR': array(0.03007417730987072, dtype=float32)}
--------------------------------------------------------------------------------


In [12]:
# Mean test PSNR, rounded to two decimals for reporting.
np.around(test_res['Mean PSNR'].item(), decimals=2)

30.489999999999998

# ConvBlock Only

In [13]:
# Checkpoint of the "ConvBlock Only" baseline run evaluated in this section.
# NOTE: absolute, user-specific path — adjust for your machine.
best_path = ('/home/tbuchhol/HaarVAE/FIT_experiments/iccv_experiments/TRec/'
             'trec_Kanji_v0.1.24_baseline_dconv8_me120_ta33/'
             'lightning_logs/version_39398108/checkpoints/'
             'best_val_loss_-epoch=65-step=412499.ckpt')

In [14]:
# Echo the checkpoint path so the saved notebook records which run was evaluated.
print(f'{best_path}')

/home/tbuchhol/HaarVAE/FIT_experiments/iccv_experiments/TRec/trec_Kanji_v0.1.24_baseline_dconv8_me120_ta33/lightning_logs/version_39398108/checkpoints/best_val_loss_-epoch=65-step=412499.ckpt


In [15]:
# Restore the trained module. The extra keyword arguments are forwarded by
# Lightning's `load_from_checkpoint` to the module constructor: the coordinate
# grids computed above plus the variant flags for this section.
model = TRecTransformerModule.load_from_checkpoint(
    best_path,
    # (r, phi) grids returned by get_proj_coords_pol
    y_coords_proj=proj_r,
    x_coords_proj=proj_phi,
    # (r, phi) grids returned by get_img_coords_pol
    y_coords_img=target_r,
    x_coords_img=target_phi,
    angles=dm.gt_ds.get_ray_trafo().geometry.angles,
    src_flatten_coords=src_flatten,
    dst_flatten_coords=dst_flatten,
    dst_order=order,
    # variant flags: full model, FBP input, conv block only
    encoder_only=False,
    use_fbp=True,
    convblock_only=True,
)

In [16]:
# `trainer.test` returns one metrics dict per test dataloader; keep the only one.
test_res = trainer.test(model=model, datamodule=dm)[0]



HBox(children=(HTML(value='Testing'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=…


--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'Mean PSNR': array(26.915298461914062, dtype=float32),
 'SEM PSNR': array(0.02448676899075508, dtype=float32)}
--------------------------------------------------------------------------------


In [17]:
# Mean test PSNR, rounded to two decimals for reporting.
np.around(test_res['Mean PSNR'].item(), decimals=2)

26.920000000000002

# FBP (filtered back-projection) baseline

In [18]:
# Filtered back-projection baseline: reconstruct every test sinogram with
# skimage's `iradon`, denormalize, apply the circular mask and compute the PSNR
# against the ground-truth image. Displays the mean PSNR rounded to 2 decimals.
# `tqdm_notebook` is deprecated (it emitted a warning here); use tqdm.auto.
from tqdm.auto import tqdm as tqdm_auto

test_dl = dm.test_dataloader()
fbp_ds = test_dl.dataset.ds

# Loop-invariant values, computed once instead of once per test sample.
fbp_theta = -np.rad2deg(dm.gt_ds.get_ray_trafo().geometry.angles)
# NOTE(review): the mask is taken from whichever `model` was loaded last (the
# ConvBlock-only checkpoint when run top to bottom), so this baseline depends
# on a learned-model cell having run. Presumably `circle` depends only on the
# image geometry — confirm.
fbp_circle = model.circle.cpu()

fbp_psnrs = []
for i in tqdm_auto(range(len(fbp_ds))):
    sino, img = fbp_ds[i]
    # iradon expects (detector, angle) columns, hence the transposes.
    rec = torch.from_numpy(iradon(sino.numpy().T, theta=fbp_theta).T)
    # Mask out-of-place so tensors handed back by the dataset / denormalize
    # are never modified.
    img = denormalize(img, dm.mean, dm.std) * fbp_circle
    rec = denormalize(rec, dm.mean, dm.std) * fbp_circle
    fbp_psnrs.append(PSNR(img, rec, img.max() - img.min()))

np.round(torch.mean(torch.stack(fbp_psnrs)).item(), 2)

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  after removing the cwd from sys.path.


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=5000.0), HTML(value='')))




22.059999999999999