# In [None]:
import os
import numpy as np
import vgi
import time
import torch 
from vgi.ct import createCircleMask
from gmmcbct import Composer, CTtransform, show, intensity_adjust, evaluate, strev

# ---- Inputs -----------------------------------------------------------------
# Input files live in `data_dir` and are named '<data_name><suffix>.npy'.
data_dir = 'testdata/'
out_dir = 'output/fdk'
data_name = '025_128_v2'
#data_name = '025_256'
# Suffix selecting the *_fdk_<s_angles> / *_proj_<s_angles> input files
# (presumably the number of projection angles). Both paths are built from it so
# changing this one value switches the two input files together.
s_angles = '30'
gt_path = os.path.join(data_dir, data_name + '_gt.npy')
pre_recon_path = os.path.join(data_dir, data_name + '_fdk_' + s_angles + '.npy')
proj_path = os.path.join(data_dir, data_name + '_proj_' + s_angles + '.npy')

gt = np.load(gt_path)
pre_vol = np.load(pre_recon_path)
# A 4-D pre-reconstruction presumably carries a leading batch axis; keep only
# the first entry along axis 0.
if pre_vol.ndim == 4:
    pre_vol = pre_vol[0]
proj = np.load(proj_path)
# Swap the first two axes so the leading axis is the one unpacked as n_angles below.
proj = np.swapaxes(proj, 0, 1)
print('gt', gt.shape, vgi.metric(gt))

# Circular mask over the in-plane (last two) axes; it is multiplied onto each
# reconstruction, so its shape must match the volume's last two dimensions.
# Derived from gt instead of hard-coding 128 so switching data_name (e.g. to a
# 256 dataset) stays consistent; for 128x128 slices this is exactly
# createCircleMask([128, 128], 63).
mask_shape = list(gt.shape[-2:])
mask = createCircleMask(mask_shape, min(mask_shape) // 2 - 1)
#pre_vol = vgi.normalize(pre_vol)
#gt = gt * mask
#pre_vol = pre_vol * mask
show(pre_vol)
print('pre_vol', pre_vol.shape, vgi.metric(pre_vol))
print('proj', proj.shape, vgi.metric(proj))
out_shape = gt.shape
proj_shape = proj.shape
n_angles, n_det_rows, n_det_cols = proj_shape

# Output path next to the inputs (not referenced in this cell; the sweep below
# writes its results under out_dir).
out_path = f'{data_dir}{data_name}_ge.npy'

# ---- CT geometry / batching -------------------------------------------------
# NOTE(review): `create128f` is a fixed preset named for 128; confirm it still
# matches the data if data_name is switched to a non-128 dataset.
#ct_trans = None
ct_trans = CTtransform.create128f(n_angles=n_angles)
max_batch_size = 1
volume_batch_size = 4
slice_batch_size = 128
min_size = 1.5

# ---- Loss -------------------------------------------------------------------
# Other options tried: 'SSIML1', 'SSIM', 'L1'.
loss = 'MSE'

# ---- Initialisation, clipping and binarisation ------------------------------
init_opt = True
n_randoms = 5
clip_min, clip_max = 0.0, None
sigma = 0.8
# One reconstruction is run per threshold in this sweep
# (earlier runs used a single threshold of 0.6, or just [0.6, 0.55]).
bin_thres_set = [0.6, 0.55, 0.50, 0.45, 0.4, 0.35, 0.3, 0.25]
bin_threshold_ratio = 1.0
min_vx = 4

# ---- Optimisation schedule and logging --------------------------------------
verbose = 1
n_log = 50
epoches = 1
fo_rounds = 100   # passed to reconstructErrMap as opt_rounds (1 for a quick run)
bo_rounds = 1     # passed to reconstructErrMap as rounds (0 to disable)

min_decline = 1e-7  # a huge value such as 100000. was used to disable this check
min_init_size = 1.5
rep, rep_vx, rep_s = 2, 27, 0.5

# Composer configured with the output volume shape, the projection data as
# target, the CT transform and the batching/loss settings defined above.
composer = Composer(
    shape=out_shape,
    target=proj,
    data_trans=ct_trans,
    volume_batch_size=volume_batch_size,
    slice_batch_size=slice_batch_size,
    max_batch_size=max_batch_size,
    loss=loss,
)
print('voxel location range', composer.min_p, composer.max_p)
print('data boundary', composer.data_min_p, composer.data_max_p)
print('projection value range', composer.val_range)

# Baseline score of the loaded pre-reconstruction (the *_fdk_* file) against gt.
ev = evaluate(pre_vol, gt, adjust=True)
print('pre_vol ev:', strev(ev))

# Create the output directory up front: np.save does not create parent
# directories, and without this the script would only fail after the first
# (long) reconstruction had already finished.
os.makedirs(out_dir, exist_ok=True)

# One row per threshold: [bin_threshold, len(_G), elapsed seconds, evaluation].
ev_all = []
for bin_threshold in bin_thres_set:
    print(' ----------------------------- ')
    print('bin_threshold:', bin_threshold)

    # perf_counter is monotonic and meant for measuring durations.
    t_s = time.perf_counter()
    composer.setMinSize(min_size)
    vol, _G = composer.reconstructErrMap(
        pre_vol, gt=gt, init_opt=init_opt, n_randoms=n_randoms,
        epoches=epoches, rounds=bo_rounds,
        min_init_size=min_init_size,
        clip_min=clip_min, clip_max=clip_max, sigma=sigma, min_vx=min_vx,
        rep=rep, rep_vx=rep_vx, rep_s=rep_s,
        bin_threshold=bin_threshold, bin_threshold_ratio=bin_threshold_ratio,
        min_decline=min_decline, opt_rounds=fo_rounds, clamp=True,
        verbose=verbose, n_log=n_log,
    )
    #show(vol)
    G = vgi.toNumpy(_G)
    n_g = len(_G)
    t = time.perf_counter() - t_s
    print(data_name, ', total time:%0.2f' % t)
    vol = vol * mask
    ev = evaluate(vol, gt, adjust=True)
    ev_all.append([bin_threshold, n_g, t, ev])
    print('Finished n:', n_g, ', ev:', strev(ev))

    # Threshold in percent for the file tag. round(), not int(): int() truncates
    # float error, e.g. int(0.29 * 100) == 28 since 0.29 * 100 == 28.999999999999996.
    nbinthres = round(bin_threshold * 100)
    out_name = data_name + '_bt%d' % nbinthres
    out_rec_path = os.path.join(out_dir, out_name + '_rec.npy')
    out_g_path = os.path.join(out_dir, out_name + '_G.npy')
    np.save(out_rec_path, vol)
    np.save(out_g_path, G)
    
# Tab-separated summary of the sweep: one line per threshold.
print('=======================================')
print('\t'.join(('thres', 'n', 'time', 'MAE', 'MSE', 'SSIM', 'PSNR')))
for ev_i in ev_all:
    thres_i, n_i, t_i, metrics_i = ev_i
    print(format(thres_i, '0.2f'), n_i, format(t_i, '0.1f'), strev(metrics_i), sep='\t')