In [1]:
import h5py
import os

import numpy as np
from matplotlib import pyplot as plt
from matplotlib import patches as patches
plt.rcParams.update(plt.rcParamsDefault)
import seaborn as sns
import pandas as pd

from interlacer import utils
from scripts import training_config
from visualizations import visualization_lib, plotting_utils

## Undersampling: Ablations + Comparison with Other Methods

In [2]:
def load_ref_imgs(h5_path, inputs, x=50, y=50, n=64):
    """Load reference k-space data from an HDF5 file and convert it to images.

    Args:
        h5_path: Path of the HDF5 file to read.
        inputs: If truthy, read the 'inputs' dataset; otherwise read 'outputs'.
        x: Start index of the patch along axis 1 of the image array.
        y: Start index of the patch along axis 2 of the image array.
        n: Patch extent along axes 1 and 2.

    Returns:
        Tuple ``(img, k, patch)``: the result of
        ``utils.convert_to_image_domain`` applied to the dataset, the raw
        dataset contents, and the crop ``img[:, x:x+n, y:y+n, :]``.
    """
    key = 'inputs' if inputs else 'outputs'

    # `[()]` reads the whole dataset into memory, so the file can be closed
    # as soon as the read is done instead of leaking the handle.
    with h5py.File(h5_path, 'r') as f:
        k = f[key][()]

    img = utils.convert_to_image_domain(k)
    patch = img[:, x:x+n, y:y+n, :]

    return img, k, patch

def load_imgs(h5_path, x=50, y=50, n=64):
    """Load a model's image and frequency outputs from an HDF5 results file.

    Args:
        h5_path: Path of the HDF5 file to read.
        x: Start index of the patch along axis 1 of the image array.
        y: Start index of the patch along axis 2 of the image array.
        n: Patch extent along axes 1 and 2.

    Returns:
        Tuple ``(img, k, patch)``: the 'img_outputs' dataset, the
        'freq_outputs' dataset, and the crop ``img[:, x:x+n, y:y+n, :]``.
    """
    # Both datasets are read fully into memory, so the file handle is
    # released before returning instead of being left open.
    with h5py.File(h5_path, 'r') as f:
        img = f['img_outputs'][()]
        k = f['freq_outputs'][()]

    patch = img[:, x:x+n, y:y+n, :]

    return img, k, patch

def load_ssims(h5_path):
    """Return the full contents of the 'ssim' dataset stored in `h5_path`."""
    # The dataset is read into memory before the file is closed.
    with h5py.File(h5_path, 'r') as f:
        return f['ssim'][()]

def load_psnrs(h5_path):
    """Return the full contents of the 'psnr' dataset stored in `h5_path`."""
    # The dataset is read into memory before the file is closed.
    with h5py.File(h5_path, 'r') as f:
        return f['psnr'][()]

def plot_imgs(axes, col, img, k, patch, gt_patch, idx, title=None, rect=None):
    """Fill column `col` of a four-row axes grid for example `idx`.

    Row 0 gets `title` and `plot_img` of `img`, row 1 `plot_img` of `patch`,
    row 2 `plot_img_diff` of `patch` against `gt_patch`, and row 3 `plot_k`
    of `k`. When `rect` is given as ``(x, y, n)``, a red, unfilled n-by-n
    rectangle anchored at ``(y, x)`` is added to the row-0 axes.
    """
    img_ax = axes[0][col]
    patch_ax = axes[1][col]
    diff_ax = axes[2][col]
    k_ax = axes[3][col]

    img_ax.set_title(title, fontsize=12)
    plotting_utils.plot_img(img, idx, ax=img_ax)
    plotting_utils.plot_img(patch, idx, ax=patch_ax)
    plotting_utils.plot_img_diff(patch, gt_patch, idx, ax=diff_ax)
    plotting_utils.plot_k(k, idx, ax=k_ax)

    if rect is None:
        return

    x, y, n = rect
    outline = patches.Rectangle(
        (y, x), n, n, linewidth=1, edgecolor='r', facecolor='none')
    img_ax.add_patch(outline)
    

In [3]:
def remap_title(title):
    """Map an '<architecture> <input domain>' key to its display name.

    Raises KeyError when `title` is not one of the four known keys.
    """
    config_keys = ('CONV_RESIDUAL FREQ',
                   'CONV_RESIDUAL IMAGE',
                   'ALTERNATING_RESIDUAL FREQ',
                   'INTERLACER_RESIDUAL FREQ')
    display_names = ('Frequency', 'Image', 'Alternating', 'Interleaved')
    return dict(zip(config_keys, display_names))[title]

# Comparison figure: 4 rows x 7 columns, one column per plot_imgs call below
# (ground truth, input, U-Net, CascadeNet, KIKI, and two models from exp_dir).
fig, axes = plt.subplots(4, 7, figsize=(14,8), gridspec_kw = {'wspace':0.05, 'hspace':0.05})
fig.subplots_adjust(hspace=1, wspace=1)
exp_dir = '/data/vision/polina/users/nmsingh/dev/fouriernetworks/training/melba/undersample_ablations'
baseline_dirs = os.listdir(exp_dir)
baseline_dirs.sort()
# Keep only entries 0 and 3 of the sorted directory listing.
# NOTE(review): presumably the Alternating and Interleaved runs; the column
# titles are derived from each run's config below - confirm.
baseline_dirs = [baseline_dirs[i] for i in [0,3]]

ref_h5_path = '/data/vision/polina/users/nmsingh/dev/fouriernetworks/testsets/brain_undersample_test.h5'
# Example index passed to the plotting helpers, and the patch window
# (start offsets x, y and side length n) used for the crops and rectangles.
ind = 3
x = 130
y = 80
n = 64


# Column 0: reference read from the 'outputs' dataset (inputs=False).
gt_img, gt_k, gt_patch = load_ref_imgs(ref_h5_path, inputs=False, x=x, y=y, n=n)
plot_imgs(axes, 0, gt_img, gt_k, gt_patch, gt_patch, ind, title='Ground Truth')

# Column 1: reference read from the 'inputs' dataset (inputs=True).
img, k, patch = load_ref_imgs(ref_h5_path, inputs=True, x=x, y=y, n=n)
plot_imgs(axes, 1, img, k, patch, gt_patch, ind, title='Input')
# Kept under its own name; not read again in this cell.
input_k = k

# Columns 5 and 6: the two selected runs from exp_dir.
for exp_i in range(2):
    # exp_i-2 is -2 then -1, i.e. entries 0 and 1 of the two-element list.
    exp = baseline_dirs[exp_i-2]
    model_dir = os.path.join(exp_dir,exp)

    # The column title comes from the run's config (architecture + input domain).
    config_path = visualization_lib.get_config_path(model_dir)
    config = training_config.TrainingConfig(os.path.join(model_dir, config_path))
    config.read_config()
    title = remap_title(config.architecture+' '+config.input_domain)
    h5_path = os.path.join(model_dir,'brain_undersample_test_output.h5')
    
    img, k, patch = load_imgs(h5_path, x=x, y=y, n=n)
    i = exp_i+5
    plot_imgs(axes, i, img, k, patch, gt_patch, ind, title=title, rect=(x,y,n))

plt.subplots_adjust(wspace=None, hspace=None)

# Columns 2-4: comparison methods read from the baselines_matched directory.
unet_h5_path = '/data/vision/polina/users/nmsingh/dev/fouriernetworks/training/melba/baselines_matched/unet_brain_undersample_test_output.h5'
img, k, patch = load_imgs(unet_h5_path, x=x, y=y, n=n)
plot_imgs(axes, 2, img, k, patch, gt_patch, ind, 'U-Net', rect=(x,y,n))    

cascade_h5_path = '/data/vision/polina/users/nmsingh/dev/fouriernetworks/training/melba/baselines_matched/cascade_brain_undersample_test_output.h5'
img, k, patch = load_imgs(cascade_h5_path, x=x, y=y, n=n)
plot_imgs(axes, 3, img, k, patch, gt_patch, ind, 'CascadeNet', rect=(x,y,n))    
# Kept under its own name; not read again in this cell.
cascade_k = k

kiki_h5_path = '/data/vision/polina/users/nmsingh/dev/fouriernetworks/training/melba/baselines_matched/kiki_brain_undersample_test_output.h5'
img, k, patch = load_imgs(kiki_h5_path, x=x, y=y, n=n)
plot_imgs(axes, 4, img, k, patch, gt_patch, ind, 'KIKI', rect=(x,y,n))    

plt.savefig('undersample_sota.pdf',bbox_inches='tight')

In [4]:
# Uniform undersampling

def remap_title(title):
    """Return the display name for an '<architecture> <input domain>' key.

    Raises KeyError for keys outside the four listed combinations.
    """
    display_name_by_key = dict([
        ('CONV_RESIDUAL FREQ', 'Frequency'),
        ('CONV_RESIDUAL IMAGE', 'Image'),
        ('ALTERNATING_RESIDUAL FREQ', 'Alternating'),
        ('INTERLACER_RESIDUAL FREQ', 'Interleaved'),
    ])
    return display_name_by_key[title]

# Comparison figure for the 8x uniform test set: 4 rows x 7 columns, one
# column per plot_imgs call below.
fig, axes = plt.subplots(4, 7, figsize=(14,8), gridspec_kw = {'wspace':0.05, 'hspace':0.05})
fig.subplots_adjust(hspace=1, wspace=1)
exp_dir = '/data/vision/polina/users/nmsingh/dev/fouriernetworks/training/melba/uniform_undersample_ablations_8x'
baseline_dirs = os.listdir(exp_dir)
baseline_dirs.sort()
# Keep only entries 0 and 3 of the sorted directory listing.
baseline_dirs = [baseline_dirs[i] for i in [0,3]]

ref_h5_path = '/data/vision/polina/users/nmsingh/dev/fouriernetworks/testsets/brain_uniform_undersample_8x_test.h5'
# Example index passed to the plotting helpers.
ind = 1

# No x/y/n are passed in this cell, so the loaders use their default patch
# window, and no rect is passed to plot_imgs, so no rectangle is drawn.
gt_img, gt_k, gt_patch = load_ref_imgs(ref_h5_path, inputs=False)
plot_imgs(axes, 0, gt_img, gt_k, gt_patch, gt_patch, ind, title='Ground Truth')

img, k, patch = load_ref_imgs(ref_h5_path, inputs=True)
plot_imgs(axes, 1, img, k, patch, gt_patch, ind, title='Input')
# Kept under its own name; not read again in this cell.
input_k = k

# Columns 5 and 6: the two selected runs from exp_dir.
for exp_i in range(2):
    # exp_i-2 is -2 then -1, i.e. entries 0 and 1 of the two-element list.
    exp = baseline_dirs[exp_i-2]
    model_dir = os.path.join(exp_dir,exp)
    # Prints the directory of each run being plotted.
    print(model_dir)
    # The column title comes from the run's config (architecture + input domain).
    config_path = visualization_lib.get_config_path(model_dir)
    config = training_config.TrainingConfig(os.path.join(model_dir, config_path))
    config.read_config()
    title = remap_title(config.architecture+' '+config.input_domain)
    h5_path = os.path.join(model_dir,'uniform_undersample_ablations_8x_test_output.h5')
    
    img, k, patch = load_imgs(h5_path)
    i = exp_i+5
    plot_imgs(axes, i, img, k, patch, gt_patch, ind, title=title)

plt.subplots_adjust(wspace=None, hspace=None)

# Columns 2-4: comparison methods read from the baselines_matched_uniform_8x directory.
unet_h5_path = '/data/vision/polina/users/nmsingh/dev/fouriernetworks/training/melba/baselines_matched_uniform_8x/unet_brain_undersample_test_output.h5'
img, k, patch = load_imgs(unet_h5_path)
plot_imgs(axes, 2, img, k, patch, gt_patch, ind, 'U-Net')    

cascade_h5_path = '/data/vision/polina/users/nmsingh/dev/fouriernetworks/training/melba/baselines_matched_uniform_8x/cascade_brain_undersample_test_output.h5'
img, k, patch = load_imgs(cascade_h5_path)
plot_imgs(axes, 3, img, k, patch, gt_patch, ind, 'CascadeNet')    
# Kept under its own name; not read again in this cell.
cascade_k = k

kiki_h5_path = '/data/vision/polina/users/nmsingh/dev/fouriernetworks/training/melba/baselines_matched_uniform_8x/kiki_brain_undersample_test_output.h5'
img, k, patch = load_imgs(kiki_h5_path)
plot_imgs(axes, 4, img, k, patch, gt_patch, ind, 'KIKI')    

plt.savefig('uniform_undersample_8x_sota.pdf',bbox_inches='tight')

/data/vision/polina/users/nmsingh/dev/fouriernetworks/training/melba/uniform_undersample_ablations_8x/MRI-uniform_undersample-0.875-None-None-None-None-None-ALTERNATING_RESIDUAL-3-64-1-10-compimage-L1-0.1-FREQ-FREQ-3-piece-True-5000-4
/data/vision/polina/users/nmsingh/dev/fouriernetworks/training/melba/uniform_undersample_ablations_8x/MRI-uniform_undersample-0.875-None-None-None-None-None-INTERLACER_RESIDUAL-3-64-1-10-compimage-L1-0.1-FREQ-FREQ-3-piece-True-5000-4


In [5]:
def append_list(df,ssims,architecture,statistic):
    """Return `df` extended with one row per entry of `ssims`.

    Args:
        df: DataFrame to extend; it is not modified.
        ssims: Indexable sequence of values; entry i becomes row i's 'Value'.
        architecture: Value written to the 'Architecture' column of every new row.
        statistic: Value written to the 'Statistic' column of every new row.

    Returns:
        A new DataFrame with a fresh 0..N-1 index holding the rows of `df`
        followed by the new rows, or `df` itself when `ssims` is empty.
    """
    values = [ssims[i] for i in range(len(ssims))]
    if not values:
        return df

    # Build all new rows at once and concatenate a single time:
    # DataFrame.append no longer exists in pandas >= 2.0, and appending row
    # by row copies the whole frame on every iteration.
    new_rows = pd.DataFrame({'Value': values,
                             'Architecture': architecture,
                             'Statistic': statistic})
    return pd.concat([df, new_rows], ignore_index=True)


def load_allnetwork_stats(exp_dir,abl_str,sota_str):
    """Collect SSIM values for three comparison methods and two ablation models.

    Args:
        exp_dir: Directory containing one sub-directory per ablation model.
        abl_str: File name of the results HDF5 file inside each ablation
            model directory.
        sota_str: Sub-directory of the hard-coded training root that holds
            the U-Net, CascadeNet and KIKI result files.

    Returns:
        DataFrame with columns 'Value', 'Architecture' and 'Statistic'. Rows
        are appended in the order U-Net, CascadeNet, KIKI, Alternating,
        Interleaved, and 'Statistic' is always 'SSIM'.
    """
    baseline_dirs = sorted(os.listdir(exp_dir))
    # Reorder the sorted listing so that position i lines up with
    # network_list[i].
    # NOTE(review): assumes exp_dir holds exactly the four ablation runs and
    # that they sort as Alternating, Freq, Image, Interleaved - confirm.
    baseline_dirs = [baseline_dirs[i] for i in [1,2,0,3]]

    network_list = ['Freq Conv', 'Image Conv', 'Alternating', 'Interleaved', 'U-Net', 'CascadeNet', 'KIKI']

    stats = pd.DataFrame(columns=['Value','Architecture','Statistic'])

    base_dir = '/data/vision/polina/users/nmsingh/dev/fouriernetworks/training/melba/'
    # Labels match network_list and the titles used in the image figures
    # ('CascadeNet', not 'Cascade').
    comparison_files = [
        ('U-Net', 'unet_brain_undersample_test_output.h5'),
        ('CascadeNet', 'cascade_brain_undersample_test_output.h5'),
        ('KIKI', 'kiki_brain_undersample_test_output.h5'),
    ]
    for architecture, results_fn in comparison_files:
        ssim = load_ssims(os.path.join(base_dir, sota_str, results_fn))
        stats = append_list(stats, ssim, architecture, 'SSIM')

    # Only positions 2 and 3 (Alternating, Interleaved) are included here.
    for exp_i in range(2,4):
        h5_path = os.path.join(exp_dir, baseline_dirs[exp_i], abl_str)
        ssim = load_ssims(h5_path)
        stats = append_list(stats, ssim, network_list[exp_i], 'SSIM')

    return stats

In [6]:
base_dir = '/data/vision/polina/users/nmsingh/dev/fouriernetworks/training/melba/'

# One stats table per test set. Each call takes the ablation directory, the
# results file name inside every ablation run, and the directory holding the
# comparison-method results.
stats_rand = load_allnetwork_stats(
    os.path.join(base_dir, 'undersample_ablations'),
    'brain_undersample_test_output.h5',
    'baselines_matched',
)
stats_4x = load_allnetwork_stats(
    os.path.join(base_dir, 'uniform_undersample_ablations'),
    'uniform_undersample_ablations_test_output.h5',
    'baselines_matched_uniform',
)
stats_8x = load_allnetwork_stats(
    os.path.join(base_dir, 'uniform_undersample_ablations_8x'),
    'uniform_undersample_ablations_8x_test_output.h5',
    'baselines_matched_uniform_8x',
)


In [7]:
plt.figure()
colors = ['#69f97a','#95a4ee','#fcb335','#d499ff','#5900b3']
hue_order = ['U-Net', 'CascadeNet', 'KIKI','Alternating', 'Interleaved']
this_cmap = sns.color_palette(colors, len(colors))
sns.set_palette(this_cmap)

fig, axes = plt.subplots(1, 3, figsize=(12,4), gridspec_kw={'wspace':0.2, 'hspace':0.05})
xlabels = ['4x Uniform', '8x Uniform', 'Random']

# Shared boxplot settings; the three panels differ only in axes and data.
box_kwargs = dict(x="Architecture", y="Value", width=0.8, dodge=False, hue="Architecture")
for panel, panel_stats in ((2, stats_rand), (0, stats_4x), (1, stats_8x)):
    sns.boxplot(ax=axes[panel], data=panel_stats, **box_kwargs)

for i, ax in enumerate(axes):
    ax.set_ylim(0.75, 1)
    for side in ('right', 'top'):
        ax.spines[side].set_visible(False)
    ax.set_xticks([])

    # Only the left-most panel keeps its y label and ticks.
    ax.set_ylabel('SSIM' if i == 0 else '')
    if i != 0:
        ax.set_yticks([])

    # Only the right-most panel keeps a legend.
    if i == 2:
        ax.legend(framealpha=0.0, frameon=False, handletextpad=0.3, markerscale=3,
                  bbox_to_anchor=(0.6, -0.1, 0.5, 0.5))
    else:
        ax.get_legend().remove()

    ax.set_xlabel(xlabels[i])

plt.subplots_adjust(wspace=100)
plt.savefig('undersample_bar.pdf', bbox_inches='tight')

## Ablations

In [9]:
def plot_ablations(exp_dir, ref_h5_path, h5_results_fn, ind, x, y, n):
    """Draw a 4x6 grid comparing ground truth, input and four ablation models.

    Args:
        exp_dir: Directory containing one sub-directory per ablation model.
        ref_h5_path: HDF5 file read through `load_ref_imgs` for the ground
            truth (column 0) and the input (column 1).
        h5_results_fn: File name of the results file inside each model directory.
        ind: Example index forwarded to `plot_imgs`.
        x, y, n: Patch window forwarded to the loaders and used for the
            rectangle drawn on every column except the ground truth.
    """
    fig, axes = plt.subplots(4, 6, figsize=(12,8), gridspec_kw={'wspace':0.05, 'hspace':0.05})
    fig.subplots_adjust(hspace=1, wspace=1)

    # Sorted listing, re-ordered to positions 1, 2, 0, 3.
    listing = sorted(os.listdir(exp_dir))
    model_dirs = [os.path.join(exp_dir, listing[pos]) for pos in (1, 2, 0, 3)]

    window = dict(x=x, y=y, n=n)
    box = (x, y, n)

    gt_img, gt_k, gt_patch = load_ref_imgs(ref_h5_path, inputs=False, **window)
    plot_imgs(axes, 0, gt_img, gt_k, gt_patch, gt_patch, ind, title='Ground Truth')

    in_img, in_k, in_patch = load_ref_imgs(ref_h5_path, inputs=True, **window)
    plot_imgs(axes, 1, in_img, in_k, in_patch, gt_patch, ind, title='Input', rect=box)

    # Model columns start at 2, right after ground truth and input.
    for column, model_dir in enumerate(model_dirs, start=2):
        config_file = visualization_lib.get_config_path(model_dir)
        config = training_config.TrainingConfig(os.path.join(model_dir, config_file))
        config.read_config()
        column_title = remap_title(config.architecture+' '+config.input_domain)

        out_img, out_k, out_patch = load_imgs(
            os.path.join(model_dir, h5_results_fn), **window)
        plot_imgs(axes, column, out_img, out_k, out_patch, gt_patch, ind,
                  title=column_title, rect=box)

    plt.subplots_adjust(wspace=None, hspace=None)

In [10]:
# Four-colour palette installed as the seaborn default for the ablation plots.
colors = ['#ff9999', '#99d6ff', '#d499ff', '#5900b3']
this_cmap = sns.color_palette(colors, n_colors=len(colors))
sns.set_palette(this_cmap)

In [11]:
def plot_ablations_scatter(exp_dir, h5_results_fn,ax=None,ylabel=False,show_legend=False,exp_title=None):
    """Scatter the SSIM values of the four ablation models against subject rank.

    Subjects are ordered by the SSIM of the last model (position 3 after the
    re-ordering below), so that model's points are non-decreasing.

    Args:
        exp_dir: Directory containing one sub-directory per ablation model.
        h5_results_fn: File name of the results file inside each model directory.
        ax: Axes to draw on. When None, a new 20x12 figure is created and
            drawn with plain matplotlib; the remaining arguments are ignored.
        ylabel: Whether to label the y axis 'SSIM' (only used with `ax`).
        show_legend: Whether to keep a legend on `ax`.
        exp_title: Title set on `ax`.
    """
    baseline_dirs = sorted(os.listdir(exp_dir))
    # NOTE(review): assumes exactly four model directories whose sorted order
    # maps to the labels below through positions 1, 2, 0, 3 - confirm.
    baseline_dirs = [baseline_dirs[i] for i in [1,2,0,3]]

    ssims = [load_ssims(os.path.join(exp_dir, exp, h5_results_fn))
             for exp in baseline_dirs]

    hue_order = ['Frequency','Image','Alternating','Interleaved']
    legend = ['Freq Conv', 'Image Conv', 'Alternating', 'Interleaved']

    sort_inds = np.argsort(ssims[3])
    # Size the x axis from the data rather than assuming 100 subjects.
    subjects = np.arange(len(sort_inds))

    if(ax is None):
        plt.figure(figsize=(20,12))

        for i in range(4):
            plt.scatter(subjects, ssims[i][sort_inds])
        plt.legend(legend)

        plt.xlabel('Subject')
        plt.ylabel('SSIM')
        return

    for i in range(4):
        # x and y are passed by keyword: seaborn >= 0.12 rejects them as
        # positional arguments. No `hue` is assigned, so no palette or
        # hue_order is passed either.
        sns.scatterplot(x=subjects, y=ssims[i][sort_inds], label=hue_order[i],
                        ax=ax, edgecolor=None, s=10)

    ax.set_xlabel('Subject')
    ax.set_ylabel('SSIM' if ylabel else '')

    if(show_legend):
        handles, labels = ax.get_legend_handles_labels()
        ax.legend(handles,labels,framealpha=0.0,handletextpad=0.,markerscale=3,bbox_to_anchor=(0.6, 0, 0.5, 0.5))
    else:
        # Remove the legend seaborn attached, if there is one.
        auto_legend = ax.get_legend()
        if auto_legend is not None:
            auto_legend.remove()

    ax.set_ylim(bottom=0.5)
    ax.set_title(exp_title)
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)

        

In [12]:
undersample_exp_dir = '/data/vision/polina/users/nmsingh/dev/fouriernetworks/training/melba/undersample_ablations'
undersample_ref_h5_path = '/data/vision/polina/users/nmsingh/dev/fouriernetworks/testsets/brain_undersample_test.h5'
undersample_h5_results_fn = 'brain_undersample_test_output.h5'

# Example index and patch window (start offsets x, y and side length n).
ind, x, y, n = 78, 110, 110, 84

plot_ablations(
    undersample_exp_dir,
    undersample_ref_h5_path,
    undersample_h5_results_fn,
    ind, x, y, n,
)
plt.savefig('undersample_imgs.pdf', bbox_inches='tight')

In [13]:
motion_exp_dir = '/data/vision/polina/users/nmsingh/dev/fouriernetworks/training/melba/motion_ablations'
motion_ref_h5_path = '/data/vision/polina/users/nmsingh/dev/fouriernetworks/testsets/brain_motion_test.h5'
motion_h5_results_fn = 'brain_motion_test_output.h5'

# Example index and patch window (start offsets x, y and side length n).
ind, x, y, n = 1, 110, 110, 84

plot_ablations(
    motion_exp_dir,
    motion_ref_h5_path,
    motion_h5_results_fn,
    ind, x, y, n,
)
plt.savefig('motion_imgs.pdf', bbox_inches='tight')

In [14]:
noise_exp_dir = '/data/vision/polina/users/nmsingh/dev/fouriernetworks/training/melba/noise_ablations'
noise_ref_h5_path = '/data/vision/polina/users/nmsingh/dev/fouriernetworks/testsets/brain_noise_test.h5'
noise_h5_results_fn = 'brain_noise_test_output.h5'

# Example index and patch window (start offsets x, y and side length n).
ind, x, y, n = 1, 80, 50, 84

plot_ablations(
    noise_exp_dir,
    noise_ref_h5_path,
    noise_h5_results_fn,
    ind, x, y, n,
)
plt.savefig('noise_imgs.pdf', bbox_inches='tight')

In [15]:
us_motion_exp_dir = '/data/vision/polina/users/nmsingh/dev/fouriernetworks/training/melba/undersample_motion_ablations'
us_motion_ref_h5_path = '/data/vision/polina/users/nmsingh/dev/fouriernetworks/testsets/brain_undersample_motion_test.h5'
us_motion_h5_results_fn = 'brain_undersample_motion_test_output.h5'

# Example index and patch window (start offsets x, y and side length n).
ind, x, y, n = 3, 110, 110, 84

plot_ablations(
    us_motion_exp_dir,
    us_motion_ref_h5_path,
    us_motion_h5_results_fn,
    ind, x, y, n,
)
plt.savefig('undersample_motion_imgs.pdf', bbox_inches='tight')

In [16]:
fig, axes = plt.subplots(1, 4, figsize=(16,4), gridspec_kw={'wspace':0.2, 'hspace':0.2})

# One panel per experiment, left to right.
scatter_panels = [
    (motion_exp_dir, motion_h5_results_fn, 'Motion'),
    (noise_exp_dir, noise_h5_results_fn, 'Noise'),
    (undersample_exp_dir, undersample_h5_results_fn, 'Undersampling'),
    (us_motion_exp_dir, us_motion_h5_results_fn, 'Undersampled Motion'),
]
last_panel = len(scatter_panels) - 1
for panel_idx, (panel_dir, panel_fn, panel_title) in enumerate(scatter_panels):
    # The first panel carries the y label; the last one carries the legend.
    plot_ablations_scatter(
        panel_dir, panel_fn,
        ax=axes[panel_idx],
        ylabel=(panel_idx == 0),
        show_legend=(panel_idx == last_panel),
        exp_title=panel_title,
    )
plt.savefig('scatterplots.pdf', bbox_inches='tight')