In [1]:
import sys
sys.path.append('..')
import os
import json

import matplotlib
%matplotlib inline
import matplotlib.pyplot as plt

import seaborn as sns
sns.set()

import numpy as np

import torch
import torch.nn.functional as F

from models import get_net
from utils.sr_utils import *
from utils.common_utils import np_to_pil, pil_to_np, get_fname

from BayTorch.freq_to_bayes import MCDropoutVI, MeanFieldVI
from BayTorch.inference.losses import uceloss
from BayTorch.inference.utils import uncert_regression_gal
import BayTorch.visualize as V

# Render all figure text through LaTeX (requires a working LaTeX installation).
matplotlib.rcParams['text.usetex'] = True
# The preamble has to be a single string: passing a list was deprecated in
# Matplotlib 3.3 and is no longer accepted as a list by later releases. A plain
# string without commas is also accepted by older versions.
matplotlib.rcParams['text.latex.preamble'] = r'\usepackage{bm}'

In [22]:
# --- Experiment configuration -------------------------------------------------
# Directory that contains one sub-directory per logged run.
# (Alternative location used previously: '/home/toelle/logs')
path_log_dir = '/media/fastdata/toelle/logs_sr'

# Arguments forwarded to the image loader.
imsize = -1
factor = 4
enforce_div32 = 'CROP'

# Which image and which run-name suffix to evaluate.
img_name = 'xray'
to_compare = 'gauss1sq2'

# Ground-truth image file for every supported `img_name`.
IMAGE_FILES = {
    'xray': 'data/bayesian/BACTERIA-1351146-0006.jpg',
    'oct': 'data/bayesian/CNV-9997680-30.png',
    'us': 'data/bayesian/081_HC.jpg',
    'ct': 'data/bayesian/gt_ct.png',
    'mri': 'data/bayesian/gt_mri.png',
}

# Fail loudly on an unknown image name. An if/elif chain without `else` would
# leave `fname` undefined, or silently stale from an earlier kernel run.
if img_name not in IMAGE_FILES:
    raise ValueError('unknown img_name %r; expected one of %s' % (img_name, sorted(IMAGE_FILES)))
fname = IMAGE_FILES[img_name]

# Run-directory name templates; '%s' is replaced by `img_name` and
# `to_compare` is appended to each. Order must match `labels` below.
run_templates = [
    'none_sr_%s_',
    'mean_field_sr_%s_gp_',
    'mean_field_sr_%s_smp_',
    'mc_dropout_sr_%s_2d_',
    'mc_dropout_sr_%s_g2d_',
    'sgld_sr_%s_',
    'sgld_sr_%s_nll_',
    'sgld_paper_sr_%s_',
    'sgld_paper_sr_%s_nll_',
]
runs = [(template % img_name) + to_compare for template in run_templates]

# Human-readable (LaTeX) label for each entry of `runs`.
labels = [
    r'DIP',
    r'FFG GP',
    r'FFG SMP',
    r'MCD 2d',
    r'MCD g2d',
    r'SGLD+MSE $\sigma^2=\eta$',
    r'SGLD+NLL $\sigma^2=\eta$',
    r'SGLD+MSE $\sigma=\eta$',
    r'SGLD+NLL $\sigma=\eta$',
]

# `zip(runs, labels)` truncates silently, so make a length mismatch an error.
if len(runs) != len(labels):
    raise ValueError('runs and labels differ in length: %d vs %d' % (len(runs), len(labels)))
# Load the image stored at `fname` with the configured size, downsampling factor
# and divisibility handling. The result is indexed by string keys further down
# ('LR_pil', 'HR_pil', 'HR_np'); the call also prints the HR and LR resolutions.
# NOTE(review): `load_LR_HR_imgs_sr` presumably comes from the star import of
# utils.sr_utils -- confirm.
imgs = load_LR_HR_imgs_sr(fname, imsize, factor, enforce_div32)

HR and LR resolutions: (256, 256), (64, 64)


In [23]:
# Names of all entries in the log directory. os.listdir returns them in
# arbitrary order, so sort to make the run matching below reproducible.
files = sorted(os.listdir(path_log_dir))

In [24]:
from skimage.metrics import structural_similarity

# Baseline reconstructions built from the LR/HR PIL images.
imgs['bicubic_np'], imgs['sharp_np'], imgs['nearest_np'] = get_baselines(imgs['LR_pil'], imgs['HR_pil'])

# SSIM between the ground truth and the bicubic baseline. Arrays are indexed
# channel-first here: with more than one channel the channel axis is moved to
# the end and SSIM is computed in multichannel mode, otherwise the single
# channel is compared directly as a 2-D image.
# NOTE(review): `multichannel=` was superseded by `channel_axis=` in newer
# scikit-image releases, and no `data_range` is passed -- confirm both against
# the installed version and the dtype/value range of the images.
ssim_bicubic = (
    structural_similarity(
        np.moveaxis(imgs['HR_np'], 0, -1),
        np.moveaxis(imgs['bicubic_np'], 0, -1),
        multichannel=True,
    )
    if imgs['bicubic_np'].shape[0] > 1
    else structural_similarity(imgs['HR_np'][0], imgs['bicubic_np'][0])
)
print(img_name, ssim_bicubic, sep='\n')

xray
0.9302735008840118


In [25]:
# Suffixes under which repetitions of the same experiment are logged.
run_suffixes = ['', '_1', '_2', '_3']
# Number of trailing entries of each metric history that the statistics use.
last_n = 100

# Print summary statistics (PSNR, SSIM, LR/HR MSE) for every configured run.
for run, label in zip(runs, labels):
    # Check whether this run and its repetitions (_1, _2, ...) exist in the log dir.
    candidates = [run + suffix for suffix in run_suffixes]
    dir_names = [file for file in files if file in candidates]
    print(dir_names)
    psnrs = []
    ssims = []

    mses_LR = []
    mses_HR = []
    print(label)

    # Nothing logged for this run: say so explicitly instead of relying on an
    # exception raised by max() on an empty array.
    if not dir_names:
        print('no exp yet')
        continue

    try:
        for dir_name in dir_names:
            # train_vals.pt holds the metric histories saved during training.
            # NOTE(review): torch.load unpickles the file -- only load logs you trust.
            train_data = torch.load('%s/%s/train_vals.pt' % (path_log_dir, dir_name), map_location='cpu')

            psnrs.append(train_data['psnr_HR_gt_sm'])
            ssims.append(train_data['ssim_HR_gt_sm'])
            mses_LR.append(train_data['mse_LR'])
            mses_HR.append(train_data['mse_HR'])

        # One row per run directory; columns are the logged history entries.
        psnrs = np.array(psnrs)
        mses_LR = np.array(mses_LR)
        mses_HR = np.array(mses_HR)
        ssims = np.array(ssims)

        # NOTE(review): every "std" below is the standard deviation, over the
        # last `last_n` entries, of the mean across run directories -- not the
        # spread between repetitions. Confirm this is the intended statistic.
        print('max psnr: %.2f' % psnrs.max())
        print('max psnr last %d: %.2f' % (last_n, psnrs[:, -last_n:].max()))
        print('mean psnr: %.2f' % psnrs[:, -last_n:].mean())
        print('std psnr: %.4f' % np.std(psnrs[:, -last_n:].mean(axis=0)))
        print('\n')

        print('max ssim: %.2f' % ssims.max())
        print('max ssim last %d: %.2f' % (last_n, ssims[:, -last_n:].max()))
        print('mean ssim: %.2f' % ssims[:, -last_n:].mean())
        print('std ssim: %.4f' % np.std(ssims[:, -last_n:].mean(axis=0)))
        print('\n')

        print('mean mse LR: %.5f' % mses_LR[:, -last_n:].mean())
        print('std: %.4f' % np.std(mses_LR[:, -last_n:].mean(axis=0)))
        print('\n')

        print('mean mse HR: %.5f' % mses_HR[:, -last_n:].mean())
        print('std: %.4f' % np.std(mses_HR[:, -last_n:].mean(axis=0)))
        print('\n')

        print('\n\n')

    except FileNotFoundError:
        # The run directory exists but its training values are not written yet.
        print('no exp yet')
    except Exception as err:
        # Stay best-effort so one broken run does not abort the overview, but
        # report the real cause (missing key, histories of unequal length, ...)
        # instead of hiding it. KeyboardInterrupt is deliberately not caught.
        print('no exp yet (%s: %s)' % (type(err).__name__, err))

['none_sr_xray_gauss1sq2']
DIP
max psnr: 30.64
max psnr last 100: 30.62
mean psnr: 30.62
std psnr: 0.0001


max ssim: 0.93
max ssim last 100: 0.93
mean ssim: 0.93
std ssim: 0.0000


mean mse LR: 0.00000
std: 0.0000


mean mse HR: 0.00087
std: 0.0000





['mean_field_sr_xray_gp_gauss1sq2']
FFG GP
max psnr: 30.80
max psnr last 100: 30.65
mean psnr: 30.65
std psnr: 0.0013


max ssim: 0.93
max ssim last 100: 0.93
mean ssim: 0.93
std ssim: 0.0000


mean mse LR: 0.00001
std: 0.0000


mean mse HR: 0.00090
std: 0.0000





['mean_field_sr_xray_smp_gauss1sq2']
FFG SMP
max psnr: 30.92
max psnr last 100: 30.13
mean psnr: 30.11
std psnr: 0.0108


max ssim: 0.93
max ssim last 100: 0.93
mean ssim: 0.92
std ssim: 0.0001


mean mse LR: 0.00019
std: 0.0000


mean mse HR: 0.00109
std: 0.0000





['mc_dropout_sr_xray_2d_gauss1sq2']
MCD 2d
max psnr: 30.75
max psnr last 100: 30.74
mean psnr: 30.72
std psnr: 0.0071


max ssim: 0.93
max ssim last 100: 0.93
mean ssim: 0.93
std ssim: 0.0001


mean mse LR: 0.