In [None]:
import os
import numpy as np

def calc_bandwidth(lambd, sigma):
    """Spatial-frequency bandwidth (octaves) of a Gabor filter.

    Args:
        lambd: wavelength of the sinusoidal carrier (scalar or array).
        sigma: standard deviation (scale) of the Gaussian envelope.

    Returns:
        log2((pi*sigma/lambd + k) / (pi*sigma/lambd - k)) with k = sqrt(ln(2)/2).
        The result is non-finite when pi*sigma/lambd <= k.
    """
    half_power = np.sqrt(np.log(2) / 2)
    ratio = np.pi * sigma / lambd
    upper = ratio + half_power
    lower = ratio - half_power
    return np.log2(upper / lower)

def calc_sigma(lambd, bandwidth):
    """Gaussian scale sigma giving a Gabor filter of wavelength `lambd` the requested bandwidth.

    Inverse of calc_bandwidth with respect to sigma:
    sigma = lambd * k / pi * (2**b + 1) / (2**b - 1), with k = sqrt(ln(2)/2).
    Undefined (division by zero) for bandwidth == 0.
    """
    octave_ratio = 2**bandwidth
    half_power = np.sqrt(np.log(2) / 2)
    return lambd * half_power / np.pi * (octave_ratio + 1) / (octave_ratio - 1)

def calc_lambda(sigma, bandwidth):
    """Wavelength giving a Gabor filter of Gaussian scale `sigma` the requested bandwidth.

    Inverse of calc_bandwidth with respect to the wavelength:
    lambda = sigma * pi / k * (2**b - 1) / (2**b + 1), with k = sqrt(ln(2)/2).
    """
    octave_ratio = 2**bandwidth
    half_power = np.sqrt(np.log(2) / 2)
    return sigma * np.pi / half_power * (octave_ratio - 1) / (octave_ratio + 1)

# The project root is taken to be the parent of the current working directory
# (presumably the notebook runs from a subdirectory of the project — TODO confirm).
project_root = os.path.realpath(os.pardir)
# Figures are saved under <project_root>/results/figs.
fig_dir = os.path.join(project_root, "results", "figs")
# exist_ok=True replaces the separate isdir() check: no check-then-create race,
# and it is a no-op when the directory already exists.
os.makedirs(fig_dir, exist_ok=True)

In [None]:
# ---- Experiment parameters --------------------------------------------------
# Data set and stimulus condition used by the training sweep and the loaders.
data_set = 'pixel'
stimulus_set = 'jitter'  # alternatives: 'static', 'jitter'

# Trials start_trial, ..., start_trial + num_trials - 1 are run.
start_trial, num_trials = 1, 5

# Gabor grid: Gaussian scale sigma and bandwidth b (octaves). The wavelength is
# derived from these two (calc_lambda). Earlier sweeps used
# lambdas = [3, 4, 5, 6, 7, 8] and sigmas = [1, 2, 3, 4, 5].
sigmas = [0.2, 0.4, 0.6, 0.8, 1, 1.2, 1.4, 1.6, 1.8, 2, 2.2, 2.4, 2.6, 2.8, 3, 4, 5]
bandwidths = np.linspace(1, 1.8, num=5)

# Options forwarded to gabornet.py on its command line.
epochs = 20
save_loss = 0
data_augmentation = 0
fresh_data = 0
n_gpus = 1

In [None]:
# Run the model: train one network per (trial, sigma, bandwidth) by launching
# gabornet.py in a subprocess. The wavelength is derived from sigma and b.
import subprocess

from tqdm.auto import tqdm  # tqdm_notebook is deprecated; tqdm.auto picks the notebook widget

failed_runs = []  # (trial_label, exit status) of every run that did not exit cleanly
for trial in tqdm(range(start_trial, start_trial + num_trials), desc='Trial'):
    for sigma in tqdm(sigmas, desc=r'$\sigma$', leave=True):  # raw string: '\s' is an invalid escape
        for bandwidth in tqdm(bandwidths, desc='$b$', leave=True):
            lambd = calc_lambda(sigma, bandwidth)
            # Keep this label format in sync with the one rebuilt when loading results.
            trial_label = f"{trial}_sigma={float(sigma):.2}_lambd={float(lambd):.2}"
            args = (f"--data_set {data_set} --stimulus_set {stimulus_set} "
                    f"--trial_label {trial_label} --sigma {sigma} --lambd {lambd} "
                    f"--data_augmentation {data_augmentation} --fresh_data {fresh_data} "
                    f"--n_gpus {n_gpus} --epochs {epochs} --save_loss {save_loss}")
            # Argument list instead of a shell string; the exit status is kept so
            # failures are reported here rather than as missing result files later.
            completed = subprocess.run(['python3', 'gabornet.py', *args.split()])
            if completed.returncode != 0:
                failed_runs.append((trial_label, completed.returncode))

if failed_runs:
    print(f"{len(failed_runs)} run(s) failed:")
    for failed_label, status in failed_runs:
        print(f"  {failed_label}: exit status {status}")

In [None]:
# Load accuracy scores and plot
%matplotlib inline
import json
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
import seaborn as sns


# data_set = 'pixel'
stimulus_set = 'jitter'  # 'static'  # 'jitter'
noise_types = ['Original', 'Salt-and-pepper', 'Additive', 'Single-pixel']  # 'Original'
test_conditions = ['Same', 'Diff', 'NoPix']
results_dir = os.path.join(project_root, 'results', data_set, stimulus_set)

rows = []
test_rows = []

# for stimulus_set in stimulus_sets:
for trial in range(1, 1+num_trials):
    for noise_type in noise_types:
#         for lambd in lambdas:
        for sigma in sigmas:
            for bandwidth in bandwidths:
#                 sigma = calc_sigma(lambd, bandwidth)
                lambd = calc_lambda(sigma, bandwidth)

                trial_label = f"{trial}_sigma={float(sigma):.2}_lambd={float(lambd):.2}"
                model_name = f"{noise_type}_{trial_label}"
                # print(model_name)

                acc_scores = np.load(os.path.join(results_dir, f'{model_name}_ACC.npy'))
                valacc_scores = np.load(os.path.join(results_dir, f'{model_name}_VALACC.npy'))
                if save_loss:
                    loss = np.load(os.path.join(results_dir, f'{model_name}_LOSS.npy'))
                    valloss = np.load(os.path.join(results_dir, f'{model_name}_VALLOSS.npy'))
                else:
                    loss = np.zeros(epochs)
                    valloss = np.zeros(epochs)

                with open(os.path.join(results_dir, f'{model_name}_CONDVALACC.json'), "r") as jf:
                    cond_acc = json.load(jf)
                if save_loss:
                    with open(os.path.join(results_dir, f'{model_name}_CONDVALLOSS.json'), "r") as jf:
                        cond_loss = json.load(jf)
                else:
                    cond_loss = {condition: 0 for condition in test_conditions}

                for condition in test_conditions:
                    test_rows.append({'Trial': trial, 'Noise': noise_type, 'Condition': condition, 
                                      'Scale': sigma, 'Wavelength': lambd, 'Bandwidth': bandwidth,
                                      'Loss': cond_loss[condition], 'Accuracy': cond_acc[condition]})
                for epoch in range(epochs):
                    rows.append({'Trial': trial, 'Noise': noise_type, 'Epoch': epoch+1, 'Evaluation': 'Testing', 
                                 'Scale': sigma, 'Wavelength': lambd, 'Bandwidth': bandwidth, 
                                 'Loss': valloss[epoch], 'Accuracy': valacc_scores[epoch]})

                    rows.append({'Trial': trial, 'Noise': noise_type, 'Epoch': epoch+1, 'Evaluation': 'Training', 
                                 'Scale': sigma, 'Wavelength': lambd, 'Bandwidth': bandwidth, 
                                 'Loss': loss[epoch], 'Accuracy': acc_scores[epoch]})

scores = pd.DataFrame(rows, columns=['Trial', 'Noise', 'Epoch', 'Evaluation', 'Scale', 'Wavelength', 'Bandwidth', 'Loss', 'Accuracy'])
test_scores = pd.DataFrame(test_rows, columns=['Trial', 'Noise', 'Condition', 'Scale', 'Wavelength', 'Bandwidth', 'Loss', 'Accuracy'])

symbols = {'Orientation': '$\theta$',
           'Phase': '$\psi$',
           'Aspect': '$\gamma$',
           'Scale': '$\sigma$',
           'Wavelength': '$\lambda$',
           'Bandwidth': '$b$'}

units = {'Orientation': 'radians',
         'Phase': 'radians',
         'Aspect': '',
         'Scale': '',  # Check
         'Wavelength': 'pixels',  # 'pixels/cycle'
         'Bandwidth': 'octaves'}  # Check

if not save_loss:
    scores.drop(columns='Loss', inplace=True)
    test_scores.drop(columns='Loss', inplace=True)
# scores.rename(columns={'Noise': 'Mask'}, inplace=True)
scores.loc[:, 'Accuracy'] *= 100  # Convert to percentage
# test_scores.rename(columns={'Noise': 'Mask'}, inplace=True)
test_scores.loc[:, 'Accuracy'] *= 100  # Convert to percentage

In [None]:
# Preview the first rows of the per-condition test scores.
test_scores.head(n=5)

In [None]:
# Print the (min, max) of each swept/derived Gabor parameter, one line per column.
for column in ('Scale', 'Wavelength', 'Bandwidth'):
    values = test_scores[column]
    print(values.min(), values.max())

## Rerun

In [None]:
# Number CUDA devices in PCI bus order and expose GPUs 1 and 0 (in that order);
# child processes started from this kernel inherit these variables.
os.environ.update({"CUDA_DEVICE_ORDER": "PCI_BUS_ID",
                   "CUDA_VISIBLE_DEVICES": "1,0"})

In [None]:
# Same CUDA settings via IPython's %env magic, but exposing only GPU 1.
# NOTE(review): conflicts with the os.environ cell above ("1,0"); whichever cell
# runs last wins — keep only one. (Comments must not share a line with %env,
# or they become part of the value.)
%env CUDA_DEVICE_ORDER=PCI_BUS_ID
%env CUDA_VISIBLE_DEVICES=1

In [None]:
# Test scores of a single model: sigma=2.4, b=1.8, 'Original' noise type, trial 2.
test_scores.query(
    'Scale == 2.4 and Bandwidth == 1.8'
    ' and Noise == "Original" and Trial == 2'
)

In [None]:
# Training vs. testing accuracy per epoch: sigma=2.4, b=1.8, 'Original' noise type, trial 2.
sns.lineplot(
    data=scores.query('Scale == 2.4 and Bandwidth == 1.8 and Noise == "Original" and Trial == 2'),
    x='Epoch', y='Accuracy', hue='Evaluation',
)

In [None]:
# Test scores of a single model: sigma=2.4, b=1.4, 'Single-pixel' noise type, trial 5.
test_scores.query(
    'Scale == 2.4 and Bandwidth == 1.4'
    ' and Noise == "Single-pixel" and Trial == 5'
)

In [None]:
# Training vs. testing accuracy per epoch: sigma=2.4, b=1.4, 'Single-pixel' noise type, trial 5.
sns.lineplot(
    data=scores.query('Scale == 2.4 and Bandwidth == 1.4 and Noise == "Single-pixel" and Trial == 5'),
    x='Epoch', y='Accuracy', hue='Evaluation',
)

In [None]:
# Test scores over all trials for sigma=4, b=1.4 and the 'Single-pixel' noise type.
test_scores.query(
    'Scale == 4 and Bandwidth == 1.4'
    ' and Noise == "Single-pixel"'
)

In [None]:
# 'Single-pixel' test scores for every scale below 1 (sigma < 1).
test_scores.query(
    'Scale < 1'
    ' and Noise == "Single-pixel"'
)

# Plot full data set

In [None]:
# Reference values marked with dotted lines on the trend plots below.
display_lam = 5
display_bw = 1.4
display_sig = 3

subfig_size = (7, 6)
# Configure the theme in ONE call. Every sns.set() resets each setting it is not
# given back to seaborn's defaults, so the former chain of set_context('paper') /
# set(style=...) / set(font=...) / set(font_scale=2) silently discarded the
# 'paper' context. Also, set_style() drops 'font.serif' because it is not a
# seaborn style key, so the Times/Palatino preference never took effect.
# This keeps the theme that was actually in effect (notebook context,
# font_scale=2, white style, serif family) and applies the serif list through
# `rc`, which sns.set forwards directly to matplotlib's rcParams.
sns.set(context='notebook', style='white', font='serif', font_scale=2,
        rc={"font.serif": ["Times", "Palatino", "serif"]})
palette = plt.rcParams["axes.prop_cycle"].by_key()["color"]

# 'Original' noise type under the 'Same' condition (relabelled as condition
# "None" below) vs. all other noise-type rows.
orig_scores = test_scores.query('Noise == "Original" and Condition == "Same"')
noise_scores = test_scores.drop(test_scores.query('Noise == "Original"').index)
# Duplicate those 'Original' rows under each remaining noise type so they are
# plotted alongside that noise type's conditions.
new_scores = [orig_scores.replace("Original", noise).replace("Same", "None")
              for noise in noise_types[1:]]
recomb_scores = pd.concat([noise_scores, *new_scores])
# recomb_scores = pd.concat([noise_scores, orig_scores.replace("Same", "None")]) #.query(f"Bandwidth == {display_bw}")

## Plot training end points

In [None]:
# Final test accuracy per noise type and condition, one panel per (sigma, b).
g = sns.catplot(x="Noise", y="Accuracy", row='Scale', col='Bandwidth', hue="Condition", kind="bar", data=noise_scores)
g.set(ylim=(0, 100))
# Walk the facets by their actual row/column values and compare numerically.
# The old query compared the float Scale/Bandwidth columns with quoted strings,
# so it could not match any rows. Average only the Accuracy column, because
# DataFrame.mean() raises on the string columns in pandas >= 2.0.
for row_sigma, ax_row in zip(g.row_names, g.axes):
    for col_bw, ax in zip(g.col_names, ax_row):
        ax.axhline(y=10, linestyle='--', color='#e74c3c')  # 10% reference line
        same_config = (orig_scores['Scale'] == row_sigma) & (orig_scores['Bandwidth'] == col_bw)
        ax.axhline(y=orig_scores.loc[same_config, 'Accuracy'].mean(), linestyle=':', color=palette[4])
g.set_xticklabels(labels=g.axes.flat[-1].get_xticklabels(), rotation=45)
g.set_titles(template=r'$\sigma$ = {row_name} | $b$ = {col_name}')

## Plot performance trends

In [None]:
# Accuracy vs. bandwidth b: one row of panels per scale sigma, one column per
# noise type, hue = test condition. Dashed red line at 10% accuracy, dotted
# line at the display bandwidth.
g = sns.relplot(
    data=recomb_scores, kind='line',
    x='Bandwidth', y='Accuracy', hue='Condition',
    row='Scale', col='Noise',
    palette=[*palette[:3], palette[7]], aspect=1.3,
)
g.set(ylim=(0,100))
g.set(xticks=bandwidths)
g.set_xlabels(r'$b$ [octaves]')
g.set_titles(template=r'$\sigma$ = {row_name} | {col_name}')

for ax in g.axes.flat:
    ax.axhline(y=10, linestyle='--', linewidth=3, color='#e74c3c')
    ax.axvline(x=display_bw, linestyle=':', linewidth=3, color=palette[8])

# Saving disabled:
# g.savefig(os.path.join(fig_dir, "sup_fig_bandwidth.pdf"), bbox_inches="tight")

In [None]:
aspect = 1.3
width = 5.5
height = width / aspect  # NOTE(review): not used below; the plot passes height=5.2

# Accuracy vs. scale sigma: one row of panels per bandwidth, one column per
# noise type. Dashed red line at 10% accuracy, dotted line at the display sigma.
g = sns.relplot(
    data=recomb_scores, kind='line',
    x='Scale', y='Accuracy', hue='Condition',
    row='Bandwidth', col='Noise',
    palette=[*palette[:3], palette[7]], aspect=aspect, height=5.2,
)
g.set(ylim=(0,100))
g.set_xlabels(r'$\sigma$')
g.set_titles(template=r'$b$ = {row_name} | {col_name}')

for ax in g.axes.flat:
    ax.axhline(y=10, linestyle='--', linewidth=3, color='#e74c3c')
    ax.axvline(x=display_sig, linestyle=':', linewidth=3, color=palette[8])

# Saving disabled:
# g.savefig(os.path.join(fig_dir, "sup_fig_lambda.pdf"), bbox_inches="tight")

Does normalisation need to be applied to the outputs of the convolutions?

## Gabor filters
* At which precise value of $\sigma$ does the transition occur?
* How does this relate to the size of a pixel?
* $\rightarrow$ Use a finer grid and extend $\sigma < 1$

## Difference of Gaussians filters
* What happens when the Gabor filters are replaced with DoGs?
* Does this work better for Additive noise?
* Does this generally preserve more of the network's performance?
* $\rightarrow$ Run a grid search on DoG parameters

Consider hard-coding both filter types (Gabor and DoG) in the first two layers.

From https://micro.magnet.fsu.edu/primer/java/digitalimaging/processing/diffgaussians/:

Blurring an image using a Gaussian kernel suppresses only high-frequency spatial information. Subtracting one image from the other preserves spatial information that lies between the range of frequencies that are preserved in the two blurred images. Thus, the difference of gaussians is equivalent to a band-pass filter that discards all but a handful of spatial frequencies that are present in the original grayscale image. 

## Apply the model to more difficult data
### GANs
* Is the model more robust to adversarial attacks (e.g. GAN-generated adversarial examples)?
* When it fails, are the adversarial perturbations more noticeable to a human observer?
* https://github.com/hindupuravinash/the-gan-zoo

### Sketches
* Are outline drawings enough?
* Can it cope without colour?
* $\rightarrow$ Use Ella's SketchNet stimuli

In [None]:
# Check GPU status from the notebook by running nvidia-smi in a shell.
!nvidia-smi

## Plot trends across Wavelength, $\lambda$

In [None]:
# Accuracy vs. wavelength lambda: one row of panels per bandwidth, one column
# per noise type; saved as the supplementary wavelength figure.
aspect = 1.3
g = sns.relplot(x='Wavelength', y='Accuracy', row='Bandwidth', col='Noise', hue='Condition',
                kind='line', data=recomb_scores, palette=[*palette[:3], palette[7]],
                aspect=aspect, height=5.2)
g.set(ylim=(0, 100))
g.set_xlabels(r'$\lambda$ [pixels]')
# No explicit xticks: the wavelength is now derived from (sigma, b) via
# calc_lambda, so there is no fixed wavelength grid, and `lambdas` (commented
# out in the parameter cell) is undefined — `g.set(xticks=lambdas)` raised
# NameError on a fresh kernel. Matplotlib chooses the ticks instead.
g.set_titles(template=r'$b$ = {row_name} | {col_name}')
for ax in g.axes.flat:
    ax.axhline(y=10, linestyle='--', linewidth=3, color='#e74c3c')
    ax.axvline(x=display_lam, linestyle=':', linewidth=3, color=palette[8])

g.savefig(os.path.join(fig_dir, "sup_fig_lambda.pdf"), bbox_inches="tight")

## Plot trends across Bandwidth, $b$

In [None]:
# Accuracy vs. bandwidth b: one row of panels per scale sigma, one column per
# noise type; saved as the supplementary bandwidth figure.
# Rows used to be faceted by 'Wavelength', but the wavelength is derived from
# (sigma, b) by calc_lambda. Each wavelength value therefore generally belongs
# to a single (sigma, b) pair, so each panel held one point and no trend could
# be drawn. Faceting by the swept 'Scale' gives a curve over all bandwidths.
g = sns.relplot(x='Bandwidth', y='Accuracy', row='Scale', col='Noise', hue='Condition',
                kind='line', data=recomb_scores, palette=[*palette[:3], palette[7]], aspect=1.3)
g.set(ylim=(0, 100))
g.set(xticks=bandwidths)
g.set_xlabels(r'$b$ [octaves]')
g.set_titles(template=r'$\sigma$ = {row_name} | {col_name}')

for ax in g.axes.flat:
    ax.axhline(y=10, linestyle='--', linewidth=3, color='#e74c3c')
    ax.axvline(x=display_bw, linestyle=':', linewidth=3, color=palette[8])

g.savefig(os.path.join(fig_dir, "sup_fig_bandwidth.pdf"), bbox_inches="tight")

## Plot $\lambda$ trends for $b=1.4$

In [None]:
# One accuracy-vs-wavelength figure per noise type at the display bandwidth,
# saved as "<noise>_trend.pdf".
for noise in noise_types[1:]:
    fig, ax = plt.subplots(figsize=subfig_size)
    sns.lineplot(x='Wavelength', y='Accuracy', hue='Condition', linewidth=3, ax=ax,
                 data=recomb_scores.query(f"(Noise == '{noise}' or Noise == 'Original') and Bandwidth == {display_bw}"),
                 legend='brief', palette=[*palette[:3], palette[7]])
    ax.set_ylim(0, 100)
    ax.axhline(y=10, linestyle='--', linewidth=3, color='#e74c3c')
    ax.axvline(x=display_lam, linestyle=':', linewidth=3, color=palette[8])
    ax.set_xlabel(r'$\lambda$ [pixels]')
    # No `ax.set(xticks=lambdas)`: `lambdas` is undefined now that wavelengths
    # are derived from (sigma, b), and that line raised NameError.
    handles, labels = ax.get_legend_handles_labels()
    # Depending on the seaborn version, the legend may begin with a title entry
    # labelled by the hue variable. Drop it by label instead of blindly slicing
    # off the first entry, which would remove a real condition on newer versions.
    entries = [(h, l) for h, l in zip(handles, labels) if l != 'Condition']
    lgd = ax.legend(loc=4, framealpha=0.5, fontsize=18, frameon=True, fancybox=True,
                    handles=[h for h, _ in entries], labels=[l for _, l in entries],
                    bbox_to_anchor=(1.0, 0.6))
    fig.tight_layout()
    # `bbox_extra_artists` is the savefig option for including the legend in the
    # tight bbox; the old `additional_artists` keyword is rejected by modern matplotlib.
    fig.savefig(os.path.join(fig_dir, f"{noise}_trend.pdf"), bbox_inches="tight", bbox_extra_artists=[lgd])

## Plot bar charts for $\lambda=5$

In [None]:
def _nearest_wavelength_rows(df, target):
    """Return the rows of `df` whose Wavelength is closest to `target` within each Bandwidth.

    Wavelengths are derived from (sigma, bandwidth), so `target` is generally not
    hit exactly. All rows of the closest sweep point are kept (every trial, noise
    type and condition shares that computed wavelength).
    """
    gap = (df['Wavelength'] - target).abs()
    return df[gap == gap.groupby(df['Bandwidth']).transform('min')]


# The old selection `Wavelength == "5"` compared the float column with a string,
# and the derived wavelengths never equal display_lam exactly, so it selected
# nothing. Use the sweep point nearest to display_lam for each bandwidth instead.
display_data = _nearest_wavelength_rows(noise_scores, display_lam)
# Mean accuracy of the 'Original'/'Same' rows at the same points; only the
# Accuracy column is averaged (DataFrame.mean() rejects string columns in pandas >= 2.0).
orig_mean_acc = _nearest_wavelength_rows(orig_scores, display_lam)['Accuracy'].mean()

for noise in noise_types[1:]:
    fig, ax = plt.subplots(figsize=subfig_size)
    sns.barplot(x="Condition", y="Accuracy", hue='Bandwidth',
                data=display_data.query(f"Noise == '{noise}'"),
                capsize=0.2, color='grey', edgecolor=".1", ax=ax)
    ax.set_ylim(0, 100)
    # Reference lines spanning the three condition bars (x = -0.5 .. 2.5).
    ax.hlines(10, -0.5, 2.5, linestyle='--', linewidth=3, colors='#e74c3c')
    ax.hlines(orig_mean_acc, -0.5, 2.5, linestyle=':', linewidth=3, colors=palette[7])
    ax.set_xlabel('')
    # No trailing comma here: previously it turned `lgd` into a 1-tuple.
    lgd = ax.legend(loc=1, framealpha=0.5, fontsize=18, ncol=1, frameon=True, fancybox=True,
                    title_fontsize=18, title="Bandwidth")
    fig.tight_layout()
    # `bbox_extra_artists` replaces the `additional_artists` keyword, which modern matplotlib rejects.
    fig.savefig(os.path.join(fig_dir, f"{noise}_display_bar.pdf"), bbox_inches="tight", bbox_extra_artists=[lgd])

In [None]:
# Grouped bar chart (noise type x condition) for the sweep point at the display
# bandwidth whose derived wavelength is closest to display_lam. The old query
# compared the float columns with strings and could not match; the derived
# wavelengths also never equal display_lam exactly.
at_display_bw = noise_scores.query('Bandwidth == @display_bw')
wavelength_gap = (at_display_bw['Wavelength'] - display_lam).abs()
display_lam_nearest = at_display_bw.loc[wavelength_gap.idxmin(), 'Wavelength']
# Kept as a global query string so later cells can reuse the same selection.
query_string = 'Wavelength == @display_lam_nearest and Bandwidth == @display_bw'
display_data = noise_scores.query(query_string)
ax = sns.barplot(x="Noise", y="Accuracy", hue="Condition",
                 data=display_data, capsize=0.2, palette=("Greys"), edgecolor=".1")
ax.set_ylim(0, 100)
ax.axhline(y=10, xmin=0.02, xmax=0.98, linestyle='--', linewidth=3, color='#e74c3c')
# Average only Accuracy: DataFrame.mean() rejects string columns in pandas >= 2.0.
orig_mean_acc = orig_scores.query(query_string)['Accuracy'].mean()
ax.hlines(orig_mean_acc, -0.5, 2.5, linestyle=':', linewidth=3, colors=palette[7]);

In [None]:
# 'Original'/'Same' test-score rows matching the display selection defined above.
orig_scores.query(expr=query_string)