In [2]:
import glob
import os

# NOTE: the star imports stay *before* numpy/pandas/tifffile so that np, pd and tf
# always bind to those libraries even if a project module exports the same names.
from Data_Gen import *
from utils import *
from PFITRE_Net import PFITRE_net

import numpy as np
import pandas as pd
import tifffile as tf

%matplotlib inline

# Generate training data with missing-angle artifacts only

In [None]:
## Stage 1: ground truth -> sinogram with missing projection angles -> solver reconstruction.
## For every input image, four files sharing the same stem are written:
##   gt/<stem>.tiff     augmented (resized) ground-truth image
##   sino/<stem>.tiff   forward-projected sinogram
##   recon/<stem>.tiff  reconstruction from that sinogram by recon_by_solver
##   angle/<stem>.csv   projection angles in degrees
## sorted(): glob returns files in arbitrary order; sorting makes the processing order reproducible.
image_infile = sorted(glob.glob('./demo/img_training/*.tiff'))
out_dir = "./training_dataset/demo/"

if not image_infile:
    raise FileNotFoundError("No .tiff images found under './demo/img_training/'; check the working directory.")

## tifffile and pandas do not create missing parent folders, so build the output tree up front.
for sub_dir in ("gt", "sino", "recon", "angle"):
    os.makedirs(os.path.join(out_dir, sub_dir), exist_ok=True)

for image_path in tqdm(image_infile):
    image = tf.imread(image_path)
    fn = os.path.splitext(os.path.basename(image_path))[0]  # file stem, e.g. "img01" for "img01.tiff"

    ## Data augmentation: resize to target_size=320; rotation is disabled here (rotate=False)
    image2d = Data_augment(image, resize=True, target_size=320, rotate=False, rot_angle=0)

    ## Generate a random projection angle list in radians
    ## NOTE(review): the angle list is random; seed the RNG used by angle_list_gen
    ## (presumably NumPy's - confirm) if a reproducible dataset is required.
    theta = angle_list_gen(miss_angle='Random', step=1, rand_int=5, rot_angle=0)  #50,10

    ## Forward-project the image to obtain its sinogram
    sinogram = Gen_sino(image2d, theta, padding=True)

    ## Reconstruct with the linear solver
    recon = recon_by_solver(sinogram, theta, padding=True, cor_shift=0)

    angle_list = theta / np.pi * 180  # radians -> degrees
    table_angle = pd.DataFrame({'Theta': angle_list})

    out_gt_fn = os.path.join(out_dir, "gt", fn + '.tiff')
    out_sino_fn = os.path.join(out_dir, "sino", fn + '.tiff')
    out_recon_fn = os.path.join(out_dir, "recon", fn + '.tiff')
    out_angle_fn = os.path.join(out_dir, "angle", fn + '.csv')

    ## imwrite is the current tifffile writer; imsave is only a deprecated alias of it.
    tf.imwrite(out_gt_fn, img_2d_to_3d(image2d))
    tf.imwrite(out_sino_fn, sinogram)
    tf.imwrite(out_recon_fn, img_2d_to_3d(recon))
    table_angle.to_csv(out_angle_fn)

# Generate a paired dataset for transfer learning, adding noise, out-of-field-of-view and misalignment artifacts

In [None]:
## Stage 2: same ground truths, but the sinograms additionally carry noise, out-of-FOV
## truncation and alignment shifts, and the "recon" input is produced by recon_ADMM_NN
## using the pretrained PFITRE model. Output layout matches stage 1:
##   gt/<stem>.tiff, sino/<stem>.tiff, recon/<stem>.tiff, angle/<stem>.csv (degrees)
## sorted(): glob returns files in arbitrary order; sorting makes the processing order reproducible.
image_infile = sorted(glob.glob('./demo/img_training/*.tiff'))
out_dir = "./training_dataset/demo_2ndstep/"

if not image_infile:
    raise FileNotFoundError("No .tiff images found under './demo/img_training/'; check the working directory.")

## The model weight can be downloaded here: https://drive.google.com/file/d/1rqop4dAZ5QSjZluPkQnnMj5Qkmn5gtKo/view?usp=drive_link
## The path where model weight is saved
model_weights_path = './mdl_weight/ckpt_PFITRE.pth'
if not os.path.isfile(model_weights_path):
    raise FileNotFoundError(
        f"Model weights not found at '{model_weights_path}'. Download them from "
        "https://drive.google.com/file/d/1rqop4dAZ5QSjZluPkQnnMj5Qkmn5gtKo/view?usp=drive_link"
    )

## tifffile and pandas do not create missing parent folders, so build the output tree up front.
for sub_dir in ("gt", "sino", "recon", "angle"):
    os.makedirs(os.path.join(out_dir, sub_dir), exist_ok=True)

Model = PFITRE_net()

## load weight to the model
Model = NN_load(Model, model_weights_path)

for image_path in tqdm(image_infile):
    image = tf.imread(image_path)
    fn = os.path.splitext(os.path.basename(image_path))[0]  # file stem, e.g. "img01" for "img01.tiff"

    ## Data augmentation: resize to target_size=320; rotation is disabled here (rotate=False)
    image2d = Data_augment(image, resize=True, target_size=320, rotate=False, rot_angle=0)

    ## Generate a random projection angle list in radians
    ## NOTE(review): angles and noise levels are random; seed the RNG used by the helpers
    ## (presumably NumPy's - confirm) if a reproducible dataset is required.
    theta = angle_list_gen(miss_angle='Random', step=1, rand_int=5, rot_angle=0)  #50,10

    ## Forward-project the image to obtain its sinogram
    sinogram = Gen_sino(image2d, theta, padding=True)

    ## Introduce further artifacts onto the sinogram:
    ## Poisson, Gaussian or both ('Mixed') noise with random levels
    sinogram = noisy_sino(sinogram, noise_type='Mixed',  noise_level_Gaussian='Random', noise_level_Poisson='Random')
    ## Out-of-field-of-view artifact
    sinogram = out_FOV(sinogram)
    ## Alignment artifact (shift along x)
    sinogram = sino_shiftX(sinogram)

    ## Reconstruct with ADMM using the pretrained PFITRE model (not the plain linear solver).
    ## x_list is returned by recon_ADMM_NN but is not saved by this cell.
    recon, x_list = recon_ADMM_NN(sinogram, theta, Model, iter_num=8, ADMM_rho_const=15, cor_shift=0, padding=True, initial=None)

    angle_list = theta / np.pi * 180  # radians -> degrees
    table_angle = pd.DataFrame({'Theta': angle_list})

    out_gt_fn = os.path.join(out_dir, "gt", fn + '.tiff')
    out_sino_fn = os.path.join(out_dir, "sino", fn + '.tiff')
    out_recon_fn = os.path.join(out_dir, "recon", fn + '.tiff')
    out_angle_fn = os.path.join(out_dir, "angle", fn + '.csv')

    ## imwrite is the current tifffile writer; imsave is only a deprecated alias of it.
    tf.imwrite(out_gt_fn, img_2d_to_3d(image2d))
    tf.imwrite(out_sino_fn, sinogram)
    tf.imwrite(out_recon_fn, img_2d_to_3d(recon))
    table_angle.to_csv(out_angle_fn)