# Plot Gabor filters

In [None]:
%matplotlib inline
import os
import keras
from keras import backend as K
from keras.preprocessing import image
import tensorflow as tf
from matplotlib import pyplot as plt
import cv2
import numpy as np


def calc_bandwidth(lambd, sigma):
    """Return the spatial-frequency bandwidth (octaves) of a Gabor filter.

    Inverse of ``calc_sigma``:
    b = log2((pi*sigma/lambd + k) / (pi*sigma/lambd - k)), k = sqrt(ln(2) / 2).
    Operates element-wise on NumPy arrays. The result is undefined
    (non-finite) when pi*sigma/lambd <= k.

    Args:
        lambd: wavelength of the sinusoidal carrier (pixels).
        sigma: standard deviation of the Gaussian envelope (pixels).
    """
    k = np.sqrt(np.log(2) / 2)
    ratio = np.pi * sigma / lambd
    upper, lower = ratio + k, ratio - k
    return np.log2(upper / lower)

def calc_sigma(lambd, bandwidth):
    """Return the Gaussian envelope width sigma for a Gabor filter.

    sigma = lambd * k / pi * (2**b + 1) / (2**b - 1), k = sqrt(ln(2) / 2).
    Operates element-wise on NumPy arrays; bandwidth 0 is a division by zero.

    Args:
        lambd: wavelength of the sinusoidal carrier (pixels).
        bandwidth: spatial-frequency bandwidth b (octaves).
    """
    octave_ratio = 2**bandwidth
    k = np.sqrt(np.log(2) / 2)
    scale = lambd * k / np.pi
    return scale * (octave_ratio + 1) / (octave_ratio - 1)

def calc_lambda(sigma, bandwidth):
    """Return the carrier wavelength lambda for a Gabor filter.

    lambda = sigma * pi / k * (2**b - 1) / (2**b + 1), k = sqrt(ln(2) / 2);
    the inverse of ``calc_sigma`` for a fixed bandwidth. Operates
    element-wise on NumPy arrays.

    Args:
        sigma: standard deviation of the Gaussian envelope (pixels).
        bandwidth: spatial-frequency bandwidth b (octaves).
    """
    octave_ratio = 2**bandwidth
    k = np.sqrt(np.log(2) / 2)
    scale = sigma * np.pi / k
    return scale * (octave_ratio - 1) / (octave_ratio + 1)

# Output locations: figures are written to <project_root>/results/figs, where
# project_root is the parent of the current working directory.
project_root = os.path.realpath(os.pardir)
fig_dir = os.path.join(project_root, "results", "figs")
os.makedirs(fig_dir, exist_ok=True)

In [None]:
# image_path = "/work/data/Lenna.png"
image_path = "/work/data/example_aeroplane_s_000021.png"
# image_path = "/work/data/4.2.03.tiff"  # Mandrill
# image_path = "/work/data/4.2.07.tiff"  # Peppers

# cv2.imread signals a missing/unreadable file by returning None, not by raising.
img = cv2.imread(image_path)
if img is None:
    raise FileNotFoundError(f"Could not read image: {image_path}")

print(np.amin(img), np.amax(img))
print(img.shape)

# Grey-scale and rescale to float32 in [0, 1]; astype() already returns a copy.
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY).astype('float32') / 255
print(type(img))
# Take the value range from the NumPy array: once wrapped by K.expand_dims,
# `img` is a backend tensor that np.amin/np.amax cannot reduce meaningfully.
img_min, img_max = np.amin(img), np.amax(img)
# Add batch and channel axes -> (1, H, W, 1), channels-last input for K.conv2d.
img = K.expand_dims(img, 0)
img = K.expand_dims(img, -1)
print(img.shape)
print(img_min, img_max)

# Gabor parameter grid for the kernel figure. Wider sweeps, kept for reference:
#   lambdas = [3, 4, 5, 6, 7, 8]               # 3 <= lambd <= W/2
#   sigmas = [1, 2, 3]
#   bandwidths = np.linspace(0.4, 2.6, num=3)  # ~1.5 <= bw <= ~3
#   bandwidths = np.linspace(1, 1.8, num=3)
# Only a single (lambda, bandwidth) combination is drawn here.
lambdas = [5]
bandwidths = [1.4]

n_thetas = 8  # orientations, evenly spaced over [0, pi)
n_psis = 4    # phase offsets over [0, 2*pi); 1, 2 or 4
n_gammas = 2  # spatial aspect ratios, evenly spaced over (0, 1]

thetas = np.linspace(0, np.pi, n_thetas, endpoint=False)
psis = np.linspace(0, 2 * np.pi, n_psis, endpoint=False)
gammas = np.linspace(1, 0, n_gammas, endpoint=False)

# Alternative sweep with sigma fixed:
#   sigmas = [2, 3, 4, 5]
#   bandwidths = np.linspace(1, 1.8, num=5)

convolve = True  # also plot the test image filtered by each kernel
size = (31, 31)  # Gabor kernel size (pixels)

# Panel grid: one column per orientation, one row per (gamma, psi) pair.
ncols = len(thetas)
nrows = len(gammas) * len(psis)

fontsize = 20
space = 0.15
width = 12
n_filters = ncols * nrows
print(f"Total Gabor filters: {n_filters}")

img_scale = 6  # symmetric display range (+/-) for filtered images

if convolve:
    nrows = nrows * 2  # each kernel row gets a filtered-image row beneath it

# Draw one figure per (bandwidth, wavelength) combination: a grid of Gabor
# kernels (one column per orientation, one row per (gamma, psi) pair) and, when
# `convolve` is set, the test image filtered by each kernel in the row below it.
i = 0
save_individually = len(bandwidths) * len(lambdas) > 1
# for sg, sigma in enumerate(sigmas):
for bw in bandwidths:
    for lm, lambd in enumerate(lambdas):
        sigma = calc_sigma(lambd, bw)
        # lambd = calc_lambda(sigma, bw)
        # squeeze=False keeps `axes` 2-D even for a single row or column.
        fig, axes = plt.subplots(nrows=nrows, ncols=ncols, sharex='row', sharey='row',
                                 figsize=(width, width*nrows/ncols), squeeze=False)
        # Panel counter, restarted per figure so row indices stay inside this grid.
        i = 0
        for gm, gamma in enumerate(gammas):
            for ps, psi in enumerate(psis):
                for th, theta in enumerate(thetas):
                    params = {'ksize': size, 'sigma': sigma, 'theta': theta, 'lambd': lambd, 'gamma': gamma, 'psi': psi}
                    gf = cv2.getGaborKernel(**params, ktype=cv2.CV_32F)

                    row, col = i // ncols, i % ncols
                    if convolve:
                        row *= 2  # leave the next row free for the filtered image
                    ax = axes[row, col]
                    ax.imshow(gf, cmap='gray', vmin=-1, vmax=1)
                    ax.set_xticks([])
                    ax.set_yticks([])

                    if i // ncols == 0:  # first panel row: label the orientation
                        if th == 0:
                            ax.set_title(r"$\theta = 0$", fontsize=fontsize)
                        else:
                            ax.set_title(r"$\theta = \frac{{{}}}{{{}}}\pi$".format(th, n_thetas), fontsize=fontsize)
                    if i % ncols == 0:  # first column: label phase and aspect ratio
                        if ps == 0:
                            ax.set_ylabel(r"$\psi = 0, \gamma = {}$".format(gamma), fontsize=fontsize)
                        else:
                            ax.set_ylabel(r"$\psi = \frac{{{}}}{{{}}}\pi, \gamma = {}$".format(ps, n_psis, gamma), fontsize=fontsize)

                    if convolve:
                        # Kernel as (rows, cols, in_channels=1, out_channels=1) for K.conv2d.
                        kernel = K.expand_dims(K.expand_dims(gf, -1), -1)
                        fimg = K.conv2d(img, kernel, padding='same')
                        # Evaluate in Keras' shared session instead of opening a
                        # fresh tf.Session for every kernel.
                        fimg = K.eval(fimg[0, :, :, 0])
                        f_ax = axes[row + 1, col]
                        f_ax.imshow(fimg, cmap='gray', vmin=-img_scale, vmax=img_scale)
                        f_ax.set_xticks([])
                        f_ax.set_yticks([])
                        print(np.amin(fimg), np.amax(fimg))
                    i += 1

        # Lay out and save each figure; keep the single-figure file name unchanged.
        fig.tight_layout()
        fig.subplots_adjust(wspace=space, hspace=space)
        if save_individually:
            fname = f"gabor_kernels_bw={bw}_lambd={lambd}.pdf"
        else:
            fname = "gabor_kernels.pdf"
        fig.savefig(os.path.join(fig_dir, fname), bbox_inches="tight")

In [None]:
# Training configuration forwarded to gabornet.py (project_root and fig_dir are
# set up in the first cell).
data_set = 'pixel'
# stimulus_sets = ['static', 'jitter']
stimulus_set = 'jitter'

start_trial, num_trials = 1, 7
epochs = 20

# 0/1 flags passed on the command line.
save_loss = data_augmentation = fresh_data = 0
n_gpus = 1

# Parameter sweep: Gabor wavelengths (pixels) and bandwidths (octaves).
lambdas = list(range(3, 9))
bandwidths = np.linspace(1, 1.8, num=3)

In [None]:
import os
import subprocess
from tqdm import tqdm_notebook

# Train one gabornet.py model per (trial, bandwidth, wavelength) combination.
# The trial label encodes sigma to 2 significant digits; the results-loading
# cell rebuilds the same label to locate the saved metrics.
failed_runs = []
# for stimulus_set in tqdm_notebook(stimulus_sets, desc="Set"):
for trial in tqdm_notebook(range(start_trial, start_trial+num_trials), desc='Trial'):
    for bandwidth in tqdm_notebook(bandwidths, desc='$b$', leave=True):
        # for sigma in tqdm_notebook(sigmas, desc=r'$\sigma$', leave=True):
        for lambd in tqdm_notebook(lambdas, desc=r'$\lambda$', leave=True):
            sigma = calc_sigma(lambd, bandwidth)
            trial_label = f"{trial}_sigma={sigma:.2}_lambd={lambd}"
            # Arguments go as a list (no shell involved), and a non-zero exit
            # status is reported instead of being ignored.
            args = ["--data_set", f"{data_set}", "--stimulus_set", f"{stimulus_set}",
                    "--trial_label", trial_label, "--sigma", f"{sigma}", "--lambd", f"{lambd}",
                    "--data_augmentation", f"{data_augmentation}", "--fresh_data", f"{fresh_data}",
                    "--n_gpus", f"{n_gpus}", "--epochs", f"{epochs}", "--save_loss", f"{save_loss}"]
            result = subprocess.run(["python3", "gabornet.py", *args])
            if result.returncode != 0:
                failed_runs.append(trial_label)
                print(f"gabornet.py exited with code {result.returncode} for {trial_label}")

if failed_runs:
    print(f"{len(failed_runs)} training run(s) failed: {failed_runs}")

In [None]:
# Load accuracy scores and plot
%matplotlib inline
import json
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
import seaborn as sns


def calc_bandwidth(lambd, sigma):
    """Return the spatial-frequency bandwidth (octaves) of a Gabor filter.

    Inverse of ``calc_sigma``:
    b = log2((pi*sigma/lambd + k) / (pi*sigma/lambd - k)), k = sqrt(ln(2) / 2).
    Operates element-wise on NumPy arrays. The result is undefined
    (non-finite) when pi*sigma/lambd <= k.

    Args:
        lambd: wavelength of the sinusoidal carrier (pixels).
        sigma: standard deviation of the Gaussian envelope (pixels).
    """
    k = np.sqrt(np.log(2) / 2)
    ratio = np.pi * sigma / lambd
    upper, lower = ratio + k, ratio - k
    return np.log2(upper / lower)

def calc_sigma(lambd, bandwidth):
    """Return the Gaussian envelope width sigma for a Gabor filter.

    sigma = lambd * k / pi * (2**b + 1) / (2**b - 1), k = sqrt(ln(2) / 2).
    Operates element-wise on NumPy arrays; bandwidth 0 is a division by zero.

    Args:
        lambd: wavelength of the sinusoidal carrier (pixels).
        bandwidth: spatial-frequency bandwidth b (octaves).
    """
    octave_ratio = 2**bandwidth
    k = np.sqrt(np.log(2) / 2)
    scale = lambd * k / np.pi
    return scale * (octave_ratio + 1) / (octave_ratio - 1)

# data_set = 'pixel'
stimulus_set = 'jitter'  # 'static'  # 'jitter'
noise_types = ['Original', 'Salt-and-pepper', 'Additive', 'Single-pixel']
test_conditions = ['Same', 'Diff', 'NoPix']
results_dir = os.path.join(project_root, 'results', data_set, stimulus_set)


def _load_model_results(model_name):
    """Load the saved metrics for one trained model from ``results_dir``.

    Returns ``(acc, valacc, loss, valloss, cond_acc, cond_loss)``: per-epoch
    training/validation accuracy and loss arrays, plus dicts of test accuracy
    and loss keyed by test condition. When ``save_loss`` is falsy the loss
    files are not read and zero placeholders are returned instead.
    """
    def path(suffix):
        return os.path.join(results_dir, f'{model_name}_{suffix}')

    acc = np.load(path('ACC.npy'))
    valacc = np.load(path('VALACC.npy'))
    with open(path('CONDVALACC.json'), "r") as jf:
        cond_acc = json.load(jf)
    if save_loss:
        loss = np.load(path('LOSS.npy'))
        valloss = np.load(path('VALLOSS.npy'))
        with open(path('CONDVALLOSS.json'), "r") as jf:
            cond_loss = json.load(jf)
    else:
        loss = np.zeros(epochs)
        valloss = np.zeros(epochs)
        cond_loss = {condition: 0 for condition in test_conditions}
    return acc, valacc, loss, valloss, cond_acc, cond_loss


rows = []
test_rows = []

# for stimulus_set in stimulus_sets:
# Load the same trial numbers the training cell ran (it starts at start_trial).
for trial in range(start_trial, start_trial + num_trials):
    for noise_type in noise_types:
        for lambd in lambdas:
            for bandwidth in bandwidths:
                sigma = calc_sigma(lambd, bandwidth)
                # Must match the label format used at training time (sigma to 2 s.f.).
                trial_label = f"{trial}_sigma={sigma:.2}_lambd={lambd}"
                model_name = f"{noise_type}_{trial_label}"
                (acc_scores, valacc_scores, loss, valloss,
                 cond_acc, cond_loss) = _load_model_results(model_name)

                run_info = {'Trial': trial, 'Noise Type': noise_type, 'Sigma': sigma,
                            'Lambda': lambd, 'Bandwidth': bandwidth}
                for condition in test_conditions:
                    test_rows.append({**run_info, 'Condition': condition,
                                      'Loss': cond_loss[condition], 'Accuracy': cond_acc[condition]})
                for epoch in range(epochs):
                    rows.append({**run_info, 'Evaluation': 'Testing', 'Epoch': epoch+1,
                                 'Loss': valloss[epoch], 'Accuracy': valacc_scores[epoch]})
                    rows.append({**run_info, 'Evaluation': 'Training', 'Epoch': epoch+1,
                                 'Loss': loss[epoch], 'Accuracy': acc_scores[epoch]})

scores = pd.DataFrame(rows, columns=['Trial', 'Noise Type', 'Evaluation', 'Sigma', 'Lambda', 'Bandwidth', 'Epoch', 'Loss', 'Accuracy'])
test_scores = pd.DataFrame(test_rows, columns=['Trial', 'Noise Type', 'Condition', 'Sigma', 'Lambda', 'Bandwidth', 'Loss', 'Accuracy'])

test_scores.rename(columns={'Noise Type': 'Noise'}, inplace=True)
test_scores.loc[:, 'Accuracy'] *= 100  # Convert to percentage

In [None]:
# Preview the first rows of the per-condition test scores.
test_scores.head(5)

In [None]:
# Range of Gaussian envelope widths (sigma) produced by the parameter sweep.
sigma_values = test_scores['Sigma']
print(sigma_values.min(), sigma_values.max())

# Plot full data set

In [None]:
# (lambda, bandwidth) combination highlighted in the summary figures below.
display_lam = 5
display_bw = 1.4

subfig_size = (7, 6)
# Configure the seaborn theme in a single call. Every sns.set() call resets all
# settings it is not given, so chaining set_context('paper') with several
# sns.set(...) calls ended in the default "notebook" context; that context is
# kept here so figure scaling is unchanged. "font.serif" goes through rc because
# sns.set_style() only honours seaborn style keys and silently drops it.
sns.set(context="notebook", style="white", font="serif", font_scale=2,
        rc={"font.serif": ["Times", "Palatino", "serif"]})
palette = plt.rcParams["axes.prop_cycle"].by_key()["color"]

# Unperturbed reference: 'Original' models evaluated on the 'Same' condition.
orig_scores = test_scores.query('Noise == "Original" and Condition == "Same"')
noise_scores = test_scores.query('Noise != "Original"')
# Duplicate the reference rows once per noise type (relabelled Condition "None")
# so they can be plotted alongside that noise type's conditions.
new_scores = [orig_scores.replace("Original", noise).replace("Same", "None") for noise in noise_types[1:]]
recomb_scores = pd.concat([noise_scores, *new_scores])
# recomb_scores = pd.concat([noise_scores, orig_scores.replace("Same", "None")]) #.query(f"Bandwidth == {display_bw}")

## Plot training end points

In [None]:
# Test accuracy per noise type and condition; one facet per (bandwidth, lambda).
g = sns.catplot(x="Noise", y="Accuracy", row='Bandwidth', col='Lambda', hue="Condition", kind="bar", data=test_scores)
g.set(ylim=(0, 100))
# Dashed reference at 10% accuracy (presumably chance level -- TODO confirm class count).
for facet_ax in g.axes.flat:
    facet_ax.axhline(y=10, linestyle='--', color='#e74c3c')

## Plot performance trends

In [None]:
# Accuracy vs. wavelength; rows are bandwidths, columns are noise types.
g = sns.relplot(x='Lambda', y='Accuracy', row='Bandwidth', col='Noise', hue='Condition', kind='line', data=test_scores)
g.set(ylim=(0, 100))
for facet_ax in g.axes.flat:
    facet_ax.axhline(y=10, linestyle='--', color='#e74c3c')
g.set_xlabels(r'$\lambda$ [pixels]')

# Plot noise conditions only, showing the original-condition accuracy as reference lines

In [None]:
g = sns.catplot(x="Noise", y="Accuracy", row='Bandwidth', col='Lambda', hue="Condition", kind="bar", data=noise_scores)
g.set(ylim=(0, 100))
# On every facet draw the 10% reference line and the mean accuracy of the
# unperturbed ('Original'/'Same') models with the same lambda and bandwidth.
# Facets are matched via their actual values (g.row_names / g.col_names), and the
# comparison is numeric: quoting the values ('Lambda == "5"') would compare
# against strings and select no rows.
for bw_value, row in zip(g.row_names, g.axes):
    for lam_value, ax in zip(g.col_names, row):
        ax.axhline(y=10, linestyle='--', color='#e74c3c')
        matching = orig_scores[(orig_scores['Lambda'] == lam_value) & (orig_scores['Bandwidth'] == bw_value)]
        # Average only the numeric Accuracy column (DataFrame.mean() over the
        # string Noise/Condition columns raises in recent pandas).
        ax.axhline(y=matching['Accuracy'].mean(), linestyle=':', color=palette[4])

## Plot trends across $\lambda$

In [None]:
# Supplementary figure: accuracy vs. wavelength, faceted by bandwidth (rows)
# and noise type (columns), with the original condition as its own line.
aspect = 1.3
width = 5.5
height = width / aspect  # NOTE: unused; the facet height is fixed at 5.2 below
condition_palette = [*palette[:3], palette[7]]
g = sns.relplot(x='Lambda', y='Accuracy', row='Bandwidth', col='Noise', hue='Condition',
                kind='line', data=recomb_scores, palette=condition_palette,
                aspect=aspect, height=5.2)
g.set(ylim=(0, 100))
g.set_xlabels(r'$\lambda$ [pixels]')
g.set(xticks=lambdas)
g.set_titles(template=r'$b$ = {row_name} | {col_name}')
for facet_ax in g.axes.flat:
    # 10% reference line and the highlighted display wavelength.
    facet_ax.axhline(y=10, linestyle='--', linewidth=3, color='#e74c3c')
    facet_ax.axvline(x=display_lam, linestyle=':', linewidth=3, color=palette[8])

g.savefig(os.path.join(fig_dir, "sup_fig_lambda.pdf"), bbox_inches="tight")

## Plot trends across $b$

In [None]:
# Supplementary figure: accuracy vs. bandwidth, faceted by wavelength (rows)
# and noise type (columns), with the original condition as its own line.
condition_palette = [*palette[:3], palette[7]]
g = sns.relplot(x='Bandwidth', y='Accuracy', row='Lambda', col='Noise', hue='Condition',
                kind='line', data=recomb_scores, palette=condition_palette, aspect=1.3)
g.set(ylim=(0, 100))
g.set(xticks=bandwidths)
g.set_xlabels(r'$b$ [octaves]')
g.set_titles(template=r'$\lambda$ = {row_name} | {col_name}')
for facet_ax in g.axes.flat:
    # 10% reference line and the highlighted display bandwidth.
    facet_ax.axhline(y=10, linestyle='--', linewidth=3, color='#e74c3c')
    facet_ax.axvline(x=display_bw, linestyle=':', linewidth=3, color=palette[8])

g.savefig(os.path.join(fig_dir, "sup_fig_bandwidth.pdf"), bbox_inches="tight")

## Plot $\lambda$ trends for $b=1.4$

In [None]:
# One line plot per noise type: accuracy vs. wavelength at the display bandwidth.
for noise in noise_types[1:]:
    fig, ax = plt.subplots(figsize=subfig_size)  # figsize=(12, 8)
    sns.lineplot(x='Lambda', y='Accuracy', hue='Condition', linewidth=3, ax=ax,
                 data=recomb_scores.query(f"(Noise == '{noise}' or Noise == 'Original') and Bandwidth == {display_bw}"),
                 legend='brief', palette=[*palette[:3], palette[7]])
    ax.set_ylim(0, 100)
    ax.axhline(y=10, linestyle='--', linewidth=3, color='#e74c3c')
    ax.axvline(x=display_lam, linestyle=':', linewidth=3, color=palette[8])
    ax.set_xlabel(r'$\lambda$ [pixels]')
    ax.set(xticks=lambdas)
    handles, labels = ax.get_legend_handles_labels()
    # NOTE(review): [1:] drops the first legend entry, assuming seaborn inserted the
    # hue name ("Condition") as a pseudo-entry (older seaborn). Newer seaborn puts it
    # in the legend title, so the first entry would be a real condition -- confirm.
    lgd = ax.legend(loc=4, framealpha=0.5, fontsize=18, frameon=True, fancybox=True,
                    handles=handles[1:], labels=labels[1:], bbox_to_anchor=(1.0, 0.6))
    fig.tight_layout()
    # bbox_extra_artists is the savefig keyword that makes bbox_inches="tight"
    # account for the legend ("additional_artists" is not a savefig parameter).
    fig.savefig(os.path.join(fig_dir, f"{noise}_trend.pdf"), bbox_inches="tight", bbox_extra_artists=[lgd])

## Plot bar charts for $\lambda=5$

In [None]:
# One bar chart per noise type at the display wavelength, bars split by bandwidth.
# Values are interpolated unquoted so they compare numerically with the numeric
# Lambda/Bandwidth columns; a quoted value is a string and matches no rows.
# query_string = f"Lambda == {display_lam} and Bandwidth == {display_bw}"
query_string = f"Lambda == {display_lam}"
display_data = noise_scores.query(query_string)
# Mean accuracy of the unperturbed models at this wavelength (all bandwidths).
# It is loop-invariant, and only the numeric column is averaged (DataFrame.mean()
# over the string Noise/Condition columns raises in recent pandas).
orig_mean_acc = orig_scores.query(query_string)['Accuracy'].mean()
# Reference lines span every condition bar group.
x_lo, x_hi = -0.5, len(test_conditions) - 0.5
for noise in noise_types[1:]:
    fig, ax = plt.subplots(figsize=subfig_size)
    bp = sns.barplot(x="Condition", y="Accuracy", hue='Bandwidth',
                     data=display_data.query(f"Noise == '{noise}'"),
                     capsize=0.2, color='grey', edgecolor=".1", ax=ax)
    ax.set_ylim(0, 100)
    ax.hlines(10, x_lo, x_hi, linestyle='--', linewidth=3, colors='#e74c3c')
    ax.hlines(orig_mean_acc, x_lo, x_hi, linestyle=':', linewidth=3, colors=palette[7])
    ax.set_xlabel('')
    # No trailing comma here: `lgd` must be the Legend itself, not a 1-tuple.
    lgd = bp.legend(loc=1, framealpha=0.5, fontsize=18, ncol=1, frameon=True, fancybox=True,
                    title_fontsize=18, title="Bandwidth")
    fig.tight_layout()
    # bbox_extra_artists (not "additional_artists") includes the legend in the tight bbox.
    fig.savefig(os.path.join(fig_dir, f"{noise}_display_bar.pdf"), bbox_inches="tight", bbox_extra_artists=[lgd])

In [None]:
# Grouped bar chart of all noise types at the display (lambda, bandwidth).
# Values are interpolated unquoted so they compare numerically with the numeric
# columns. NOTE: Bandwidth is matched by exact float equality.
query_string = f"Lambda == {display_lam} and Bandwidth == {display_bw}"
display_data = noise_scores.query(query_string)
ax = sns.barplot(x="Noise", y="Accuracy", hue="Condition",
                 data=display_data, capsize=0.2, palette="Greys", edgecolor=".1")
ax.set_ylim(0, 100)
ax.axhline(y=10, xmin=0.02, xmax=0.98, linestyle='--', linewidth=3, color='#e74c3c')
# Mean accuracy of the unperturbed models; average only the numeric column
# (DataFrame.mean() over the string Noise/Condition columns raises in recent pandas).
orig_mean_acc = orig_scores.query(query_string)['Accuracy'].mean()
ax.hlines(orig_mean_acc, -0.5, len(noise_types[1:]) - 0.5, linestyle=':', linewidth=3, colors=palette[7])

In [None]:
# Show the unperturbed ('Original'/'Same') rows selected by the current query_string.
orig_scores.query(expr=query_string)