# Compare Model to Algorithmic Approaches

## Imports

In [1]:
from datetime import datetime
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
plt.rcParams["figure.figsize"] = (12, 8)
plt.rcParams["figure.dpi"] = 300
import torch
from torch.nn.parallel import DataParallel as DP, DistributedDataParallel as DDP

from tomopy.prep.stripe import remove_stripe_fw, remove_stripe_based_sorting
#from larix.methods.misc import INPAINT_EUCL_WEIGHTED, INPAINT_NDF, INPAINT_NM
from tomopy.misc.corr import inpainter_morph

from network.patch_visualizer import PatchVisualizer
from network.models import MaskedGAN
from network.models.generators import PatchUNet
from network.models.discriminators import PatchDiscriminator
from utils.data_io import loadTiff
from utils.tomography import reconstruct

# Number of threads for OpenMP. If too high, may cause error
%env OMP_NUM_THREADS=16

env: OMP_NUM_THREADS=16


[home:74834] mca_base_component_repository_open: unable to open mca_btl_openib: librdmacm.so.1: cannot open shared object file: No such file or directory (ignored)


## Functions
TomoPy edits sinograms in-place, so we must copy the input to avoid changing it.<br>
TomoPy also expects 3D volumes to have shape `(angles, det Y, det X)` so we must swap axes 0 and 1.

In [2]:
def fourier_wavelet(sino_volume, level=None, wname='db5', sigma=2, pad=True, ncore=None):
    inpt = sino_volume.copy().swapaxes(0, 1)
    output = remove_stripe_fw(inpt, level, wname, sigma, pad, ncore=ncore)
    return output.swapaxes(0, 1)


def remove_stripes_based_sorting(sino_volume, size=21, dim=1, ncore=None):
    inpt = sino_volume.copy().swapaxes(0, 1)
    output = remove_stripe_based_sorting(inpt, size, dim, ncore=ncore)
    return output.swapaxes(0, 1)


def remove_stripes_larix(sino_volume, mask, mode='NDF'):
    if mode == 'NDF':
        func = INPAINT_NDF
        kwargs = dict(regularisation_parameter=5000, edge_parameter=0, iterationsNumb=5000,
                      time_marching_parameter=0.000075, penalty_type=1)
    elif mode == 'EUCL':
        func = INPAINT_EUCL_WEIGHTED
        kwargs = dict(iterationsNumb=3, windowsize_half=2, method_type='random')
    elif mode == 'NM':
        func = INPAINT_NM
        kwargs = dict(SW_increment=2, iterationsNumb=150)
    else:
        raise ValueError(f"Mode {mode} not recognized.")
    inpainted = np.empty_like(sino_volume)
    for s in range(sino_volume.shape[0]):
        inpainted[s] = func(sino_volume[s], mask[s], **kwargs)
    return inpainted

def inpaint_stripes_tomopy(sino_volume, mask, inpainting_type='random'):
    func = inpainter_morph
    kwargs = dict(size=5,
                  iterations=5,
                  inpainting_type=inpainting_type,
                  method_type='2D')
    inpainted = np.empty_like(sino_volume)
    for s in range(sino_volume.shape[0]):
        inpainted[s] = func(sino_volume[s], mask[s].astype(bool), **kwargs)
    return inpainted

In [3]:
def root_mean_squared_error(data1, data2, axis=None):
    return np.sqrt(np.mean(np.square(data1 - data2), axis=axis))

In [4]:
def load_model(path, device=None):
    # Load model state dict from disk
    checkpoint = torch.load(path, map_location=device)
    # Initialize Generator and Discriminator
    gen = DP(PatchUNet())
    gen.load_state_dict(checkpoint['gen_state_dict'])
    disc = DP(PatchDiscriminator())
    disc.load_state_dict(checkpoint['disc_state_dict'])
    # Initialize Model
    model = MaskedGAN(gen, disc, mode='test', device=device)
    return model

## Setup

### Parameters

In [7]:
#i12 = Path('/dls/i12/data/2022/nt33730-1/processing/NoStripesNet')
i12 = Path('/media/algol/HD-LXU3/No_stripes_net_data/')
data_dir = i12/'data'/'wider_stripes'
model_file = i12/'pretrained_models'/'five_sample'/'4x4'/'val'/'five_sample_4x4_100.tar'
mask_file = i12/'stripe_masks.npz'
if torch.cuda.is_available():
    d = torch.device('cuda')
else:
    d = torch.device('cpu')

rng = np.random.default_rng()
sample_no = 0
print(f"Sample No.: {sample_no}")

Sample No.: 0


### Load Model & Visualizer

In [8]:
model = load_model(model_file, device=d)
v = PatchVisualizer(data_dir, model, sample_no=sample_no, mask_file=mask_file)

### Load Clean & Stripe Volumes

In [None]:
cleans = []
stripes = []
start_time = datetime.now()
for idx in v.clean_idxs:
    cleans.append(v.get_sinogram(idx, 'clean').astype(np.float32))
    stripes.append(v.get_sinogram(idx, 'stripe').astype(np.float32))
cleans = np.asarray(cleans)
stripes = np.asarray(stripes)
mask = np.abs(cleans - stripes).astype(bool, copy=False)
print(f"Loading finished in {datetime.now() - start_time}s")
print(f"{cleans.shape=}, {cleans.dtype=}")
print(f"{stripes.shape=}, {stripes.dtype=}")
print(f"{mask.shape=}, {mask.dtype=}")

## Calculate RMSEs

### RMSE from Clean to Stripe (i.e. control)

In [8]:
# calculating RMSE for the whole image
rmse_control = root_mean_squared_error(cleans, stripes)
# calculating RMSE for the stripes region ONLY
rmse_stripes = root_mean_squared_error(cleans[mask == True], stripes[mask == True])
print(f"{rmse_stripes = }")

rmse_stripes = 4886.932


In [9]:
# calculating RMSE for the whole sinogram but not the stripes region
rmse_nostripes = root_mean_squared_error(cleans[mask == False], stripes[mask == False])
print(f"{rmse_nostripes = }")

rmse_nostripes = 0.0


### RMSE of Fourier Wavelet

In [10]:
start_time = datetime.now()
fw = fourier_wavelet(stripes, level=None, wname='db5', sigma=0.6, pad=True, ncore=16)
print(f"Time: {datetime.now() - start_time}")
rmse_fw = root_mean_squared_error(cleans, fw)
print(f"{rmse_fw = }")
rmse_fw_stripes = root_mean_squared_error(cleans[mask == True], fw[mask == True])
print(f"{rmse_fw_stripes = }")
rmse_fw_nostripes = root_mean_squared_error(cleans[mask == False], fw[mask == False])
print(f"{rmse_fw_nostripes = }")

Time: 0:02:03.385654
rmse_fw = 2066.9531
rmse_fw_stripes = 3013.0122
rmse_fw_nostripes = 2017.4863


### RMSE of Sorting algorithm

In [11]:
start_time = datetime.now()
vo_sorting = remove_stripes_based_sorting(stripes, size=31, dim=1, ncore=16)
print(f"Time: {datetime.now() - start_time}")
rmse_vo = root_mean_squared_error(cleans, vo_sorting)
print(f"{rmse_vo = }")
rmse_vo_stripes = root_mean_squared_error(cleans[mask == True], vo_sorting[mask == True])
print(f"{rmse_vo_stripes = }")
rmse_vo_nostripes = root_mean_squared_error(cleans[mask == False], vo_sorting[mask == False])
print(f"{rmse_vo_nostripes = }")

Time: 0:02:26.708789
rmse_vo = 677.88464
rmse_vo_stripes = 3220.4485
rmse_vo_nostripes = 206.97923


### RMSE of Algorthmic Inpainting

In [12]:
start_time = datetime.now()
larix_euclidian = remove_stripes_larix(stripes, mask, mode='EUCL')
print(f"Time: {datetime.now() - start_time}")
rmse_euc = root_mean_squared_error(cleans, larix_euclidian)
print(f"{rmse_euc=}")
rmse_euc_stripes = root_mean_squared_error(cleans[mask == True], larix_euclidian[mask == True])
print(f"{rmse_euc_stripes=}")
rmse_euc_nostripes = 0 # for some reason doing RMSE for mask==False causes the kernel to die
print(f"{rmse_euc_nostripes=}")

Time: 0:18:59.963909
rmse_euc=231.20004
rmse_euc_stripes=1151.0647
rmse_euc_nostripes=0


### RMSE of cGAN Inpainting

In [13]:
nsn = np.load(i12/'processed'/'model_output.npz')['synth']
assert nsn.shape == cleans.shape
rmse_nsn = root_mean_squared_error(cleans, nsn)
print(f"{rmse_nsn = }")
rmse_nsn_stripe = root_mean_squared_error(cleans[mask == True], nsn[mask == True])
print(f"{rmse_nsn_stripe = }")
rmse_nsn_nostripes = 0 # for some reason doing RMSE for mask==False causes the kernel to die
print(f"{rmse_nsn_nostripes = }")

rmse_nsn = 204.40088148354633
rmse_nsn_stripe = 1017.6448093905587
rmse_nsn_nostripes = 0
