# Demo

**This is a demo for training and testing PsiDONet. Please refer to the README for more information about installation and file organization.**

**@author: Mathilde Galinier (megalinier@gmail.com)**

In [None]:
import os
import torch
from fundamental_functions.auxiliary_functions import create_path_save_name
from fundamental_functions.Train_Test_PsiDONet import PsiDONet_class
from fundamental_functions.tools import compute_quality_results

## 1. Train a model

### 1.1 Train conditions

In [None]:
# Settings describing the training scenario. They are bundled into
# `train_conditions` and handed, as one list, to create_path_save_name and
# PsiDONet_class further down.
# NOTE(review): units / exact meaning of each value are defined by the
# project code and README, not by this cell.
missing_angle = 30
step_angle = 1
size_image = 128
mu = 2e-6
L = 5

# Dataset name; also used below to build the dataset folder path.
dataset = 'Ellipses'

# Keep this ordering: the list is passed around positionally as a whole.
train_conditions = [missing_angle, step_angle, size_image, mu, L]

### 1.2 Choice of the hyperparameters

In [None]:
# Hyperparameters forwarded as keyword arguments to PsiDONet_class (and to
# create_path_save_name when building the results folder name).

# Model variant and network dimensions.
model_unrolling = 'PSIDONetO'  # 'PSIDONetO' or 'PSIDONetOplus'
nb_unrolledBlocks = 40
nb_repetBlock = 3
filter_size = size_image // 4  # tied to the image size chosen above

# Wavelet settings.
wavelet_type = 'haar'  # 'haar' or 'db2'
level_decomp = 3

# Optimisation settings.
learning_rate = 0.005
nb_epochs = 3
minibatch_size = 25
loss_type = 'MSE'  # 'MSE' or 'SSIM'
loss_domain = 'WAV'  # 'WAV' or 'IM'

# Numerical precision and validation-size limit (four minibatches).
precision_float = 32
size_val_limit = 4 * minibatch_size

### 1.3 Definition of the paths

In [None]:
# Free-form text passed to create_path_save_name (left empty here).
optionalText = ''

# All locations are resolved relative to the parent of the working directory.
path_main = os.path.join('..')

# Dataset folder: <main>/<dataset>_Datasets/Size_<size_image>
path_datasets = os.path.join(
    path_main,
    f'{dataset}_Datasets',
    f'Size_{size_image}',
)

# Results folder: <main>/PyTorch/Results/<name built from the settings above>
path_save = os.path.join(
    path_main,
    'PyTorch',
    'Results',
    create_path_save_name(
        train_conditions, optionalText, model_unrolling,
        learning_rate, nb_epochs, minibatch_size, loss_type, loss_domain,
        nb_unrolledBlocks, nb_repetBlock, filter_size,
        wavelet_type, level_decomp, precision_float, dataset,
    ),
)

# Bundle handed to PsiDONet_class through its `folders` argument.
paths = [path_main, path_datasets, path_save]

### 1.4 Train the model

In [None]:
# Build the network in 'train' mode from the settings defined in the previous
# cells, then launch training.
network = PsiDONet_class(
    train_conditions=train_conditions,
    folders=paths,
    mode='train',
    model_unrolling=model_unrolling,
    learning_rate=learning_rate,
    nb_epochs=nb_epochs,
    minibatch_size=minibatch_size,
    loss_type=loss_type,
    loss_domain=loss_domain,
    nb_unrolledBlocks=nb_unrolledBlocks,
    nb_repetBlock=nb_repetBlock,
    filter_size=filter_size,
    wavelet_type=wavelet_type,
    level_decomp=level_decomp,
    precision_float=precision_float,
    size_val_limit=size_val_limit,
    dataset=dataset,
)
network.train()


## 2. Test a model

### 2.1 Path to model to restore

In [None]:
# Model to restore for testing, located inside the results folder of the run.
# NOTE(review): 'MinOnVal' presumably holds the parameters with the lowest
# validation loss — confirm against the training code.
path_to_restore = os.path.join(
    path_save,
    'parameters',
    'MinOnVal',
)

### 2.2 Test trained model

In [None]:
# Presumably evaluates the model restored from `path_to_restore` on the test
# data, using the same settings as for training.
# NOTE(review): `test` is neither imported nor defined in the visible part of
# this file (only PsiDONet_class is imported from Train_Test_PsiDONet), so this
# call raises NameError unless `test` is provided elsewhere — confirm the
# intended entry point before running this cell.
test(train_conditions=train_conditions,\
        folders=paths,\
        model_unrolling=model_unrolling,\
        minibatch_size=minibatch_size,\
        nb_unrolledBlocks=nb_unrolledBlocks,\
        nb_repetBlock=nb_repetBlock,\
        filter_size=filter_size,\
        wavelet_type=wavelet_type,\
        level_decomp=level_decomp,\
        precision_float=precision_float,\
        dataset=dataset, \
        path_to_restore = path_to_restore)  

## 3. Show results

### 3.1 Compute quality assessment on validation set

In [None]:
# Compare the validation ground-truth images (<datasets>/val/Images) with the
# restored ones (<save>/valset_restoredImages) and keep the five values
# returned by compute_quality_results.
# The progress message now names the validation set; it previously said
# "test set", which was a copy-paste slip from the test-set cell.
print('--------------------------------------------------------------------------------------------------------------------------------')
print('Evaluating the results on validation set...')
relative_err_mean, MSE_mean, SSIM_mean, PSNR_mean, HaarPSI_mean = compute_quality_results(
    os.path.join(path_datasets, 'val', 'Images'),
    os.path.join(path_save, 'valset_restoredImages'),
    precision_float,
)
print('--------------------------------------------------------------------------------------------------------------------------------')

### 3.2 Compute quality assessment on test set

In [None]:
# Compare the test ground-truth images (<datasets>/test/Images) with the
# restored ones (<save>/testset_restoredImages) and keep the five values
# returned by compute_quality_results.
print('--------------------------------------------------------------------------------------------------------------------------------')
print('Evaluating the results on test set...')
(
    relative_err_mean,
    MSE_mean,
    SSIM_mean,
    PSNR_mean,
    HaarPSI_mean,
) = compute_quality_results(
    os.path.join(path_datasets, 'test', 'Images'),
    os.path.join(path_save, 'testset_restoredImages'),
    precision_float,
)
print('--------------------------------------------------------------------------------------------------------------------------------')

### 3.3 Visualisation 

In [None]:
# Side-by-side view of one test image: ground truth, filtered back-projection
# (FBP) of the angle-restricted sinogram, and the network's restored image,
# each reconstruction annotated with its relative error.
from skimage.transform import iradon
# Imported through the package, consistently with the imports at the top of the
# file (a bare `from tools import ...` only works if the package folder itself
# happens to be on sys.path).
from fundamental_functions.tools import compute_angles
import scipy.io as sio
import numpy as np
import matplotlib.pyplot as plt

# choose an image 
num = 10556

# Paths 
angles              = compute_angles(missing_angle, step_angle)
path_im_groundtruth = os.path.join(path_datasets,'test','Images','im_reduced_'+str(size_image)+'x'+str(size_image)+'_'+str(num)+'.mat')
path_sino           = os.path.join(path_datasets, 'test','Sinograms','sino_angles_0_1_179_'+str(num)+'.mat')
path_im_restored    = os.path.join(path_save,'testset_restoredImages','im_reduced_'+str(size_image)+'x'+str(size_image)+'_'+str(num)+'.mat')

# Load images
im_groundtruth = sio.loadmat(path_im_groundtruth)['im_reduced']
# Keep only the sinogram columns selected by `angles`.
sino           = sio.loadmat(path_sino)['mnc'][:,angles]
# FBP of the restricted sinogram; the [1:-1,1:-1] slice drops a one-pixel
# border — presumably so the result matches the ground-truth size (confirm).
im_fbp         = iradon(sino, theta=angles, circle=False)[1:-1,1:-1]
im_restored    = sio.loadmat(path_im_restored)['image']

# Compute relative errors: ||reconstruction - groundtruth|| / ||groundtruth||
err_fbp        = np.linalg.norm(im_fbp-im_groundtruth)/np.linalg.norm(im_groundtruth)
err_restored   = np.linalg.norm(im_restored-im_groundtruth)/np.linalg.norm(im_groundtruth)

# Plot the three images in one row
plt.figure(1,figsize=(10, 10))
#
plt.subplot(131)
plt.imshow(im_groundtruth)
plt.axis('off')
plt.title('Groundtruth')
#
plt.subplot(132)
# The FBP image is clipped to [0, 1] for display only; the error above uses
# the unclipped values.
plt.imshow(np.clip(im_fbp,0,1))
plt.axis('off')
plt.title('FBP, RE: %.3f'%(err_fbp))
#
plt.subplot(133)
plt.imshow(im_restored)
plt.axis('off')
plt.title('Restored: RE: %.3f'%(err_restored))
#
plt.show()