### Notebook Purpose

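Prototype data-consistency (DC) regularization of intermediate layers: penalize each hidden layer's feature maps against a copy of the masked k-space downsampled to that layer's size.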

In [1]:
import os, sys
import h5py
import numpy as np
import torch
import torchvision
from torch import nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import copy
import PIL

sys.path.append('/home/vanveen/ConvDecoder/')
from utils.data_io import load_h5, get_mask, num_params
from include.decoder_conv import init_convdecoder
from include.decoder_conv import get_scale_factor, get_net_input, get_hidden_size
from include.fit import fit
from include.subsample import MaskFunc
from utils.evaluate import calc_metrics
from utils.transform import fft_2d, ifft_2d, root_sum_squares, \
                            reshape_complex_vals_to_adj_channels, \
                            reshape_adj_channels_to_complex_vals, \
                            crop_center

# Choose the tensor type used to cast network inputs: CUDA floats on GPU 0
# (with cuDNN enabled and autotuning on) when a GPU is present, CPU floats otherwise.
if not torch.cuda.is_available():
    dtype = torch.FloatTensor
else:
    torch.cuda.set_device(0)
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
    dtype = torch.cuda.FloatTensor

# STATUS

### Done
- extracted feature maps
    - specifically, the collapsed channels of the original feature maps, i.e. 160 --> 30 via a 1x1 conv
    - we'll compute the MSE over dimensions [2*nc, x_, y_]
- applied a mask to the feat_map
    - ksp_m_down, shape [x_, y_], has many zero values; feat_map doesn't have zero values
    - we compute mse(feat_map, ksp_m_down) for each layer
    - do we want to penalize the zero values in ksp_m_down? I don't think so
        - we just don't have a prior at those indices; the true value is surely not zero
        - if we penalized them perfectly, we'd recreate ksp_masked. We want to give the network the expressive freedom to recreate ksp_orig
- normalized ksp_m_down according to feat_map so that we preserve the distribution of feat_map
- used a weighted MSE so that each pixel gets the same weight
    - the last layer has 2000x as many pixels as the first layer

### Possible implementation TODOs
- try applying feat_map_loss only in the last x iterations
- try weighting earlier/later layers more heavily
- currently we downsample ksp_masked to the size of each hidden layer's feat_map
    - instead, we could upsample each feat_map and compute its loss against the original-size ksp_masked
        - this would encourage the network to learn upsampling consistent with upsample_mode='nearest'
- could use different downsampling methods to create ksp_m_down
    - currently using the most expressive method, i.e. bicubic
    - other options
        - nearest, i.e. the inverse of the method used by dd+ for upsampling
        - bilinear

### Scaling this up over many settings

- make a dataframe where each row is one config setting and each column is one sample
- rows, i.e. config attributes, correspond to one run_id
    - num_iter
        - total
        - iteration at which we turn on fm_loss, e.g. 0.5 * total and 0.8 * total (the time estimate below assumes 3 such options)
    - alpha_fm = 10 ** exp
    - weighting: early, late, or all
    - downsampling method: nearest, bilinear, bicubic
- each entry will be one sample's scores, stored as a tuple (ssim, psnr)

#### Time estimate (hours)

Total runs = [3 file_ids] * [3 weighting_options] * [3 downsamp_options] * [3 iter_start_fm_loss] * [6 alpha_fm] = 486

In [102]:
x = 81*6 # total runs
x = x / 2 # total hours if 10000 iter per run
x = x / 2 # split up on two gpu's
x = x * (2000 / 10000) # if i do 2000 iter
x

24.3

In [90]:
DIM = 320        # side length of the center crop used for display and metrics
SCALE_FAC = 0.1  # multiplies the masked k-space before fitting
NUM_ITER = 1     # fitting iterations per run

# full list: '1000273', '1000325', '1000464', '1000007', '1000537', '1000818',
#            '1001140', '1001219', '1001338', '1001598', '1001533', '1001798'
file_id_list = ['1000273']
file_id_list.sort()

# full list: -1, -2, -3, -4, -5, -10; i.e. loss = loss_ksp + 10**exp * loss_feat_map
exp_list = [-1]

imgs_run_list = []  # one [img_masked, img_est, img_dc, img_gt] entry per (file_id, exp) run
scores = {}         # (file_id, exp) -> {'est': (ssim, psnr), 'dc': (ssim, psnr)}

path_out = '/bmrNAS/people/dvv/out_fastmri/expmt_fm_loss/'


def ksp_to_img(ksp):
    """Root-sum-of-squares of ifft_2d(ksp), detached and center-cropped to DIM x DIM."""
    return crop_center(root_sum_squares(ifft_2d(ksp)).detach(), DIM, DIM)


for file_id in file_id_list:

    f, ksp_orig = load_h5(file_id)
    ksp_orig = torch.from_numpy(ksp_orig)

    mask = get_mask(ksp_orig)

    for exp in exp_list:

        net, net_input, ksp_orig_, hidden_size = init_convdecoder(ksp_orig, mask)

        ksp_masked = SCALE_FAC * ksp_orig_ * mask
        img_masked = ifft_2d(ksp_masked)

        net, mse_wrt_ksp, mse_wrt_img = fit(
            ksp_masked=ksp_masked, img_masked=img_masked,
            net=net, net_input=net_input, mask2d=mask, num_iter=NUM_ITER,
            alpha_fm=10**exp)

        # Inference only: skip building an autograd graph.
        with torch.no_grad():
            img_out = net(net_input.type(dtype))
        # The net may also return feature maps, i.e. (img_out, feat_maps); keep only the image.
        img_out = img_out[0] if isinstance(img_out, tuple) else img_out

        img_out = reshape_adj_channels_to_complex_vals(img_out[0])
        ksp_est = fft_2d(img_out)
        # Data consistency: keep the measured samples where mask is set, the estimate elsewhere.
        ksp_dc = torch.where(mask, ksp_masked, ksp_est)

        img_masked = ksp_to_img(ksp_masked)
        img_est = ksp_to_img(ksp_est)
        img_dc = ksp_to_img(ksp_dc)
        # NOTE(review): img_gt comes from the unscaled ksp_orig while the estimates are fit to
        # SCALE_FAC-scaled k-space -- confirm calc_metrics is insensitive to this scale difference.
        img_gt = ksp_to_img(ksp_orig)
        imgs_run_list.append([img_masked, img_est, img_dc, img_gt])

        _, _, ssim_est, psnr_est = calc_metrics(np.array(img_est), np.array(img_gt))
        _, _, ssim_dc, psnr_dc = calc_metrics(np.array(img_dc), np.array(img_gt))
        # Record every run's scores so they survive the loop instead of being overwritten.
        scores[(file_id, exp)] = {'est': (ssim_est, psnr_est), 'dc': (ssim_dc, psnr_dc)}

        # optional: save outputs (include exp so runs for the same file don't overwrite each other)
        # np.save('{}{}_exp{}_est.npy'.format(path_out, file_id, exp), img_est)
        # np.save('{}{}_exp{}_dc.npy'.format(path_out, file_id, exp), img_dc)
        # np.save('{}{}_exp{}_gt.npy'.format(path_out, file_id, exp), img_gt)