In [15]:
import os, sys
from os import listdir
from os.path import isfile, join
import numpy as np
import torch
import matplotlib.pyplot as plt

sys.path.append('/home/vanveen/ConvDecoder/')
from utils.data_io import load_h5
from utils.evaluate import calc_metrics
from utils.transform import fft_2d, ifft_2d, root_sum_squares, \
                            reshape_complex_vals_to_adj_channels, \
                            reshape_adj_channels_to_complex_vals, \
                            crop_center

In [17]:
def load_gt(file_id, dim=None):
    ''' load the ground-truth image for `file_id`.

    Takes the second value returned by `load_h5` (presumably k-space), converts it
    to a torch tensor, applies `ifft_2d` then `root_sum_squares`, and center-crops
    the result to `dim` x `dim` with `crop_center`.

    file_id: identifier passed straight through to `load_h5`
    dim: side length of the square center crop. Defaults to the notebook-level
         `DIM`, looked up at call time, so existing `load_gt(file_id)` calls keep working.
    returns: whatever `crop_center` returns (built from a torch tensor; TODO confirm type)
    '''
    if dim is None:
        # DIM is defined in a later cell; resolve it lazily so callers may pass
        # the crop size explicitly instead of relying on that hidden global
        dim = DIM
    _, ksp = load_h5(file_id)
    ksp = torch.from_numpy(ksp)
    img = root_sum_squares(ifft_2d(ksp))
    return crop_center(img, dim, dim)

def norm_im(im):
    ''' normalize img to [0,1] range via min-max scaling.

    Only uses `.min()`, `.max()` and elementwise arithmetic on `im`.
    A constant image (max == min) is mapped to all zeros instead of dividing by zero.
    '''
    lo, hi = im.min(), im.max()
    span = hi - lo
    if span == 0:
        # constant image: avoid 0/0 -> nan, return zeros of the same shape
        return im - lo
    # divide by the range, not by max alone: (im - min) / max only spans [0,1] when min == 0
    return (im - lo) / span

def plot_list(arr_list, clim=None, titles=None):
    ''' plot images side by side as a single row of grayscale panels.

    arr_list: sequence of images, each passed to `Axes.imshow`
    clim: optional color limits applied to every panel; None lets each panel autoscale
    titles: optional per-panel titles. Defaults to
            ['ground-truth', 'dc', 'fm', 'diff'], the ordering used in this notebook.
            Panels beyond the supplied titles are left untitled instead of raising IndexError.
    '''
    if titles is None:
        titles = ['ground-truth', 'dc', 'fm', 'diff']

    num_cols = len(arr_list)

    fig = plt.figure(figsize=(20, 20))

    for idx, arr in enumerate(arr_list):
        ax = fig.add_subplot(1, num_cols, idx + 1)
        ax.imshow(arr, cmap='gray', clim=clim)
        if idx < len(titles):
            ax.set_title(titles[idx], fontsize=20)
        ax.axis('off')

### compare old dd+ (pytorch==1.5) vs new dd+ (pytorch==1.7)

From the baseline directory `/bmrNAS/people/dvv/out_fastmri/`:
- (1) `old_pytorch1.5/orig_alpha_search/` (n=11): using dd+ pt==1.5 before code re-work 
- (2) `old_pytorch1.5/sf0.1/` (n=5): using dd+ pt==1.5 after code re-work, then converted back
    - not sure whether this was actually run in pt==1.5, or whether I took pt==1.7 and used the deprecated fft/ifft functions in a hacky way. either way, the code may be stashed in a branch somewhere.
- (3) `new_pytorch1.7/sf0.1/` (n=5): using dd+ pt==1.7, sf=0.1, with `fit.py`
- (4) `expmt_fm_loss/trials_best/` (n=11): using dd+ pt==1.7, sf=0.1, with `fit_feat_map_loss.py` instead of `fit.py`
    - i.e. looking at trial_id = `0000_10k`

##### results
- (1) > (4) on both SSIM and PSNR, at n=5 and at n=11
- (2) > (1), but we only have five samples for (2)
- (2) and (3) are comparable
- mean SSIM, PSNR (in that order)
    - (1)
        - n=5: 0.7884340675904544 31.730754973365197
        - n=11: 0.7569765714478369 30.87466793360071
    - (2) 0.796467652242989 32.034361010080524
    - (3) 0.793504546953909 31.93507930075753
    - (4)
        - n=5: 0.7647340218936469 30.56831878262691
        - n=11: 0.7399233683613828 29.98914482917978

# TODO
- try to re-create (3) by following the current codebase, i.e. same process for (4) but using `fit.py`
    - i.e. compare to files in `path_3` across n=5
- if that fails, look at old notebooks and branches. try to re-create (3) first, but fall back to (2) if necessary
- once I re-create decent results, run this over a larger validation set and send the results to Arjun

In [49]:
# Baseline output directories for the comparisons (1)-(4) described above.
# Every path keeps its trailing slash because filenames are appended later by plain
# string formatting, e.g. '{}{}_dc.npy'.format(path, file_id).
path_base = '/bmrNAS/people/dvv/out_fastmri/'
path_1 = path_base + 'old_pytorch1.5/orig_alpha_search/'
path_2 = path_base + 'old_pytorch1.5/sf0.1/'
path_3 = path_base + 'new_pytorch1.7/sf0.1/'
# no leading '/': path_base already ends with one (previously produced '...out_fastmri//expmt...')
path_4 = path_base + 'expmt_fm_loss/trials_best/'

# presumably the full n=11 list behind the (1)/(4) results above
# file_id_list = ['1000273', '1000325', '1000464', \
#                 '1000537', '1000818', '1001140', '1001219', \
#                 '1001338', '1001598', '1001533', '1001798']
# n=5 subset, matching the sample count available for (2) and (3)
file_id_list = ['1000273', '1000325', '1000464', '1000537', '1000818']

In [50]:
# Side length of the square center crop applied in load_gt (crop_center(..., DIM, DIM)).
DIM = 320
nf = len(file_id_list)

# Per-file metrics for the "old" (s_o, p_o) and "new" (s_n, p_n) reconstructions.
# Each slot is filled inside the loop below, so np.empty's uninitialized values are not read
# as long as the loop completes.
s_o_list, p_o_list, s_n_list, p_n_list = np.empty(nf), np.empty(nf), np.empty(nf), np.empty(nf)

for idx, file_id in enumerate(file_id_list):
    
    img_gt = load_gt(file_id)

    # "old" baseline is currently (3) from path_3; the commented line loads (1) from path_1 instead.
#     img_old = np.load('{}{}_iter10000_alpha0.npy'.format(path_1, file_id))
    img_old = np.load('{}{}_dc.npy'.format(path_3, file_id))
#     img_new = np.load('{}0000_10k_{}_dc.npy'.format(path_4, file_id))

    # NOTE(review): the name on the right-hand side below is not defined anywhere in this notebook
    # view, so this raises NameError on a fresh kernel unless it was defined elsewhere. It looks like
    # a placeholder for the "new" reconstruction (cf. the commented path_4 load above). TODO fill in.
    img_new = XXXXXXXXXXXXXXXXXXXXXXXXXXX

    # 3rd/4th return values of calc_metrics are stored as SSIM / PSNR (per the results notes
    # above); the ground truth is converted with np.array (TODO confirm calc_metrics needs numpy).
    _, _, s_o_list[idx], p_o_list[idx] = calc_metrics(img_old, np.array(img_gt))
    _, _, s_n_list[idx], p_n_list[idx] = calc_metrics(img_new, np.array(img_gt))

    # absolute difference of the normalized images; only consumed by the commented-out plot below
    img_diff = np.abs(norm_im(img_old) - norm_im(img_new))

#     plot_list([img_gt, img_old, img_new, img_diff])

# Mean metrics across files: first line "old", second line "new".
# NOTE(review): the saved output below equals the n=11 figures for (1) and (4) in the results
# notes, not the current n=5 / path_3 configuration. It appears stale; re-run to refresh.
print(s_o_list.mean(), p_o_list.mean()) 
print(s_n_list.mean(), p_n_list.mean())

0.7569765714478369 30.87466793360071
0.7399233683613828 29.98914482917978
