In [None]:
import logging
import os

import matplotlib
import matplotlib.pyplot as plt
from matplotlib import offsetbox
import numpy as np
import pandas as pd
import seaborn as sns
from scipy import stats
from statsmodels.stats.multicomp import MultiComparison

import divisivenormalization.analysis as analysis
import divisivenormalization.utils as helpers
from divisivenormalization.data import Dataset, MonkeySubDataset

# Notebook-wide setup: IPython configuration, logging verbosity and plot styling.
helpers.config_ipython()

logging.basicConfig(level=logging.INFO)

sns.set()
sns.set_style("ticks")

# Font sizes belong to seaborn's "context" definition, so they can be overridden
# through the ``rc`` argument of ``sns.set_context``.
font_size = 8
rc_dict = {
    "font.size": font_size,
    "axes.titlesize": font_size,
    "axes.labelsize": font_size,
    "xtick.labelsize": font_size,
    "ytick.labelsize": font_size,
    "legend.fontsize": font_size,
}
sns.set_context("paper", rc=rc_dict)

# ``sns.set_context`` only applies rc entries that are part of the context
# definition and silently drops all others, so the figure and export settings
# are set on matplotlib directly. The rcParam for tight bounding boxes is
# "savefig.bbox" ("savefig.bbox_inches" is not a valid key).
matplotlib.rcParams.update(
    {
        "figure.figsize": (helpers.cm2inch(8), helpers.cm2inch(8)),
        "figure.dpi": 300,
        "pdf.fonttype": 42,
        "savefig.transparent": True,
        "savefig.bbox": "tight",
    }
)


class args:
    """Static configuration namespace for this notebook (read as class attributes)."""

    # Number of top-ranked runs used for the analysis (num_best) and number of
    # next-ranked runs used for tuning the analysis (num_val)
    num_best = 10
    num_val = 10
    # Cache files written/read by the "split" cell
    fname_best_csv = "df_best.csv"
    fname_val_csv = "df_val.csv"
    # Directory arguments passed to the pickle and model-loading helpers
    weights_path = "weights"
    train_logs_path = "train_logs"
    # Bin width (in radians) for the binned orientation-difference plot
    orientation_binsize = np.deg2rad(10)
    stim_full_size = 140  # full size of stimulus w/o subsampling and cropping
    stim_subsample = 2
    # Single place to edit when the project lives somewhere else; all model
    # directories below are derived from it.
    project_root = "/projects/burg2021_learning-divisive-normalization"
    nonspecific_path = project_root + "/nonspecific_divisive_net"
    subunit_path = project_root + "/subunit_net"
    cnn3_path = project_root + "/cnn3"
    # Maps the key used for each surround model (3, 5, 7) to its directory
    surround_path_dict = {
        3: project_root + "/divisive_3x3_surround_net",
        5: project_root + "/divisive_5x5_surround_net",
        7: project_root + "/divisive_7x7_surround_net",
    }
    # Threshold passed to analysis.angles_circ_var
    oriented_threshold = 0.125



 ### Load data

In [None]:
results_df = pd.read_csv("results.csv")
# Save a simplified version of the csv file, sorted by validation set performance
df_plain = helpers.simplify_df(results_df)
df_plain.to_csv("results_plain.csv")

data_dict = Dataset.get_clean_data()
# The subsampling factor comes from the shared config (args.stim_subsample == 2)
# instead of repeating the literal here.
data = MonkeySubDataset(
    data_dict, seed=1000, train_frac=0.8, subsample=args.stim_subsample, crop=30
)



 ### Split into set of best models and validation models
 Use the 10 best models for the analysis. Use the models ranked 11-20 to tune the analysis algorithms.
 Split the csv files accordingly. Also, extract some weights to be used for later analysis and save
 them as pickle files. As this operation requires loading the models, we do it only if it was not done before.

In [None]:
# Best models (used for the analysis): reuse the cached csv if it exists,
# otherwise load each model once, evaluate it on the test set and pickle the
# weights needed by the cells below.
try:
    df_best = pd.read_csv(args.fname_best_csv)
    logging.info("loaded data from %s", args.fname_best_csv)

except FileNotFoundError:
    df_best = df_plain[0 : args.num_best].copy()

    fev_lst = []
    for i in range(args.num_best):
        run_no = df_best.iloc[i]["run_no"]
        logging.info("load run no %s", run_no)
        model = helpers.load_dn_model(run_no, results_df, data, args.train_logs_path)

        fev = model.evaluate_fev_testset()
        fev_lst.append(fev)

        feve = model.evaluate_fev_testset_per_neuron()
        # NOTE(review): the two return values are not used in this notebook; the
        # call is kept in case it is needed for its side effects - confirm.
        var_explained, explainable_var = model.evaluate_ve_testset_per_neuron()
        helpers.pkl_dump(feve, run_no, "feve.pkl", args.weights_path)

        # get weights and normalization input
        (
            features_chanfirst,
            p,
            pooled,
            readout_feat,
            u,
            v,
            dn_exponent,
        ) = helpers.get_weights(model)

        norm_input = analysis.norm_input(pooled, p)

        helpers.pkl_dump(
            features_chanfirst, run_no, "features_chanfirst.pkl", args.weights_path
        )
        helpers.pkl_dump(p, run_no, "p.pkl", args.weights_path)
        helpers.pkl_dump(pooled, run_no, "pooled.pkl", args.weights_path)
        helpers.pkl_dump(norm_input, run_no, "norm_input.pkl", args.weights_path)
        helpers.pkl_dump(readout_feat, run_no, "readout_feat_w.pkl", args.weights_path)
        helpers.pkl_dump(u, run_no, "u.pkl", args.weights_path)
        helpers.pkl_dump(v, run_no, "v.pkl", args.weights_path)
        helpers.pkl_dump(dn_exponent, run_no, "dn_exponent.pkl", args.weights_path)

    # The csv is written only after all models were processed, so an interrupted
    # run is repeated from scratch instead of leaving a partial cache.
    df_best["fev"] = fev_lst
    df_best.to_csv(args.fname_best_csv)


# Validation models (used to tune the analysis): same caching scheme, but only
# the test-set FEV and the features are needed. The cache is read from the same
# configured file name it is written to.
try:
    df_val = pd.read_csv(args.fname_val_csv)
    logging.info("loaded data from %s", args.fname_val_csv)

except FileNotFoundError:
    df_val = df_plain[args.num_best : args.num_best + args.num_val].copy()

    fev_lst = []
    for i in range(args.num_val):
        run_no = df_val.iloc[i]["run_no"]
        logging.info("load run no %s", run_no)
        model = helpers.load_dn_model(run_no, results_df, data, args.train_logs_path)

        fev = model.evaluate_fev_testset()
        fev_lst.append(fev)

        features_chanfirst = helpers.get_weights(model)[0]
        helpers.pkl_dump(
            features_chanfirst, run_no, "features_chanfirst.pkl", args.weights_path
        )

    df_val["fev"] = fev_lst
    df_val.to_csv(args.fname_val_csv)



 ### Compare model performance
 *For this to work, you first have to run the cell "Get and save FEV performance on test set"
 in the cnn3, nonspecific_divisive_net, and subunit_net analysis jupyter notebooks.*

In [None]:
# Plot

# Test-set FEV of the best runs: this notebook's model comes from df_best, the
# three comparison models from the df_best.csv files of their own notebooks.
fev_dn = df_best.fev.values
fev_nonspecific, fev_subunit, fev_cnn3 = (
    pd.read_csv(os.path.join(model_dir, "df_best.csv")).fev.values
    for model_dir in (args.nonspecific_path, args.subunit_path, args.cnn3_path)
)

# Rows are ordered subunit, nonspecific DN, specific DN, CNN; values in percent.
fev_lst = 100 * np.array([fev_subunit, fev_nonspecific, fev_dn, fev_cnn3])
fev_stats = analysis.compute_fev_summary_stats(fev_lst)
fev = fev_stats["mean"]
sem = fev_stats["sem"]

plt.figure(figsize=(helpers.cm2inch(6), helpers.cm2inch(8)))
x = np.arange(len(fev))
plt.scatter(
    x,
    fev,
    color=["grey", "xkcd:Bluegreen", "xkcd:Blue", "grey"],
    marker="_",
    linewidths=[0.01] * 4,
)
plt.errorbar(x, fev, yerr=sem, fmt="none", color="xkcd:black")
plt.xticks(
    x,
    ["Subunit", "Nonspecific DN", "Specific DN", "Black-box CNN"],
    rotation=45,
    horizontalalignment="right",
)
# One y tick per model, placed at its mean
plt.yticks(ticks=fev, labels=np.round(fev, 1))
plt.ylabel("Absolute accuracy (% FEV)")
sns.despine(trim=True, offset=5)
plt.tight_layout()
plt.show()


print("Mean FEV", np.round(fev, 1))
print("SEM", np.round(sem, 1))

# Percentage scale: first row (subunit) maps to 0 %, last row (CNN) to 100 %.
percent = (fev - fev[0]) / (fev[-1] - fev[0]) * 100
print("Percentage scale", np.round(percent, 0))

percent_sem = sem / (fev[-1] - fev[0]) * 100
print("Percentage scale SEM", np.round(percent_sem, 1))
print()

ci = fev_stats["conf_int"]
model_names = ["fev_subunit", "fev_nonspecific", "fev_dn", "fev_cnn3"]
print("Confidence intervals:")
for c, shapiro_reject, name in zip(ci, fev_stats["shapiro_reject"], model_names):
    # Interval bounds on the percentage scale and half-widths on both scales
    percent = ((c - fev[0]) / (fev[-1] - fev[0]) * 100).squeeze()
    p_pm = (percent[1] - percent[0]) / 2
    c = np.array(c).squeeze()
    c_pm = (c[1] - c[0]) / 2
    print()
    print(name)
    print(
        "Confidence interval:",
        np.round(c, 1),
        "Plus/minus:",
        np.round(c_pm, 1),
        "Percentage scale:",
        np.round(percent, 1),
        "Plus/minus (percentage):",
        np.round(p_pm, 0),
    )



In [None]:
# Statistical tests


def _load_top_feve(model_dir, top_idx=0):
    """Load the pickled per-neuron FEVE of one run of another model class.

    The run number is taken from row `top_idx` of `model_dir`/df_best.csv and
    the pickle is read from the weights directory inside `model_dir`.
    """
    run_no = pd.read_csv(os.path.join(model_dir, "df_best.csv")).iloc[top_idx].run_no
    return helpers.pkl_load(
        run_no, "feve.pkl", os.path.join(model_dir, args.weights_path)
    )


def _pairwise_wilcoxon(feve_by_model, title):
    """Run Holm-corrected pairwise Wilcoxon tests between models and print them.

    `feve_by_model` maps a model name to its per-neuron FEVE values. Returns the
    long-format DataFrame (columns neuron_id, feve, model) that was tested.
    """
    neuron_id_lst, feve_lst, model_lst = [], [], []
    for model_name, values in feve_by_model.items():
        neuron_id_lst.extend(range(len(values)))
        feve_lst.extend(values)
        model_lst.extend([model_name] * len(values))
    long_df = pd.DataFrame(
        dict(neuron_id=neuron_id_lst, feve=feve_lst, model=model_lst)
    )

    # stats.wilcoxon is a paired test: the values of every model must list the
    # neurons in the same order.
    mod = MultiComparison(long_df.feve, long_df.model)
    res = mod.allpairtest(stats.wilcoxon, method="holm")
    print(title)
    print(res[0])
    print("Corrected p values")
    print(res[1][2])
    return long_df


top_idx = 0

# Center models: this notebook's best run plus the best runs of the other
# model classes.
feve = {
    "dn": helpers.pkl_load(
        df_best.iloc[top_idx].run_no, "feve.pkl", args.weights_path
    )
}
feve["dn-non-specific"] = _load_top_feve(args.nonspecific_path, top_idx)
feve["subunit"] = _load_top_feve(args.subunit_path, top_idx)
feve["cnn3"] = _load_top_feve(args.cnn3_path, top_idx)

# Surround models: the same center model plus one model per surround size
# (keys "dn3", "dn5", "dn7").
feve_surr = {"dn": feve["dn"]}
for surround_size, path in args.surround_path_dict.items():
    feve_surr["dn" + str(surround_size)] = _load_top_feve(path, top_idx)


# Compare center models
feve_df = _pairwise_wilcoxon(feve, "Center models")

# Compare surround models
feve_surr_df = _pairwise_wilcoxon(feve_surr, "\n\nSurround models")



 ### Plot distribution of divisive normalization exponents

In [None]:
# Collect the DN exponents of all best models, split by whether the
# corresponding feature has an orientation (non-NaN angle) or not.
ori, unori = [], []
for i in range(args.num_best):
    run_no = df_best.iloc[i].run_no
    features = helpers.pkl_load(run_no, "features_chanfirst.pkl", args.weights_path)
    angles = analysis.angles_circ_var(features, threshold=args.oriented_threshold)
    exponents = helpers.pkl_load(run_no, "dn_exponent.pkl", args.weights_path).squeeze()
    is_unoriented = np.isnan(angles)
    ori.extend(exponents[~is_unoriented].flatten())
    unori.extend(exponents[is_unoriented].flatten())

fig_width = helpers.cm2inch(8)
fig_height = helpers.cm2inch(8 / 8 * 6)
plt.figure(figsize=(fig_width, fig_height))
plt.hist([ori, unori], bins=15, lw=0, color=["xkcd:blue", "xkcd:lightblue"])
plt.legend(["Ori.", "Unori."], loc="upper left")
plt.xlim(left=0)
plt.xlabel("Values of exponents $n_l$")
plt.ylabel("No. of exponents $n_l$")
sns.despine(trim=True, offset=5)
plt.tight_layout()
plt.show()

# Mean over all exponents, oriented and unoriented together
n = ori + unori
print("mean", np.round(np.mean(n), 1))



 ### Validate feature orientation sorting
 Use the models ranked 11-20 for this (in terms of validation set performance). Then perform the
 actual analysis on the best 10 models. Filters marked with a red frame are considered unoriented by our algorithm.

In [None]:
# One row of subplots per validation model, one column per feature
# (32 columns are hard-coded here).
fig, axes = plt.subplots(args.num_val, 32, figsize=(32, args.num_val))
for i, axrow in zip(range(args.num_val), axes):
    run_no = df_val.iloc[i]["run_no"]
    features = helpers.pkl_load(run_no, "features_chanfirst.pkl", args.weights_path)
    angles = analysis.angles_circ_var(features, threshold=args.oriented_threshold)
    # Sort by angle; np.argsort places NaN angles (unoriented features) last.
    idx_pref = np.argsort(angles)
    features = features[idx_pref]
    angles = angles[idx_pref]

    for ax, feat, angle in zip(axrow, features, angles):
        # Symmetric gray scale around zero
        vmax = np.max(np.abs(feat))
        vmin = -vmax
        ax.imshow(feat, vmax=vmax, vmin=vmin, cmap="gray")
        ax.tick_params(
            which="both", bottom=False, labelbottom=False, left=False, labelleft=False
        )

        # A red frame marks features without an assigned orientation
        if np.isnan(angle):
            for spine in ax.spines.values():
                spine.set_color("red")

plt.suptitle("threshold " + str(args.oriented_threshold))
plt.tight_layout()
plt.show()



 ### Matrix plot showing the structure of DN for best model

In [None]:
# best model
top_idx = 0
run_no = df_best.iloc[top_idx].run_no
norm_input = helpers.pkl_load(run_no, "norm_input.pkl", args.weights_path)
features = helpers.pkl_load(run_no, "features_chanfirst.pkl", args.weights_path)

angles = analysis.angles_circ_var(features, args.oriented_threshold)
# np.argsort places NaN angles (unoriented features) last
idx_pref = np.argsort(angles)
# Move the first filter of the sorted order to position 18.
# NOTE(review): the hard-coded 19 presumably equals the number of oriented
# features of this particular model - confirm before reusing with another run.
idx_pref = np.concatenate((idx_pref[1:19], idx_pref[0:1], idx_pref[19:]))

# sort features, angles and both axes of the normalization input consistently
features = features[idx_pref]
angles = angles[idx_pref]
norm_input = norm_input[idx_pref][:, idx_pref]

# sort unoriented according to norm input (dark color to the right)
unor_mask = np.isnan(angles)
norm_input_unor = norm_input[unor_mask][:, unor_mask]
norm_input_unor = np.sum(norm_input_unor, axis=1)
idx_unor = np.argsort(norm_input_unor)
# Apply the permutation within the unoriented block only: first to features and
# angles, then to the rows and finally to the columns of norm_input.
features[unor_mask] = features[unor_mask][idx_unor]
angles[unor_mask] = angles[unor_mask][idx_unor]
norm_input[unor_mask] = norm_input[unor_mask][idx_unor]
norm_input[:, unor_mask] = norm_input[:, unor_mask][:, idx_unor]

oriented_bools = np.logical_not(unor_mask)
angle_diff = analysis.angle_diff(angles)

# matrix plot (everything is already sorted, hence the identity permutation;
# 32 is the hard-coded number of features)
figsize = (helpers.cm2inch(5), helpers.cm2inch(5))
fig = analysis.plot_contribution_matrix_chan_first(
    norm_input,
    features,
    index_permutation_lst=np.arange(32),
    angle_difference=angle_diff,
    oriented_bools=oriented_bools,
    figsize=figsize,
)
plt.show()

# color bar, scaled to the largest input among oriented feature pairs
vmax = np.max(norm_input[oriented_bools][:, oriented_bools])
vmin = 0
fig, ax = plt.subplots(figsize=(6, 1))
fig.subplots_adjust(bottom=0.5)
cmap = matplotlib.cm.Blues
norm = matplotlib.colors.Normalize(vmin=vmin, vmax=vmax)
# Last tick sits slightly below vmax so that it stays inside the bar after rounding
ticks = np.round(np.linspace(vmin, vmax - 0.01, 4), 2)
cb1 = matplotlib.colorbar.ColorbarBase(
    ax, cmap=cmap, norm=norm, orientation="horizontal", ticks=ticks
)
cb1.set_label("Normalization input")
# plt.show() instead of fig.show(): consistent with the other cells and works
# with non-GUI (inline) backends.
plt.show()



 ### Similarly oriented features contribute more strongly

In [None]:
# For each of the best models, sum the normalization input over similarly and
# over dissimilarly oriented feature pairs.
sim_input_lst, dissim_input_lst = [], []
for i in range(args.num_best):
    run_no = df_best.iloc[i].run_no
    features = helpers.pkl_load(run_no, "features_chanfirst.pkl", args.weights_path)
    norm_input = helpers.pkl_load(run_no, "norm_input.pkl", args.weights_path)

    angles = analysis.angles_circ_var(features, args.oriented_threshold)
    angles_diff = analysis.angle_diff(angles)
    # The first returned mask (unoriented pairs) is not needed here.
    _, sim_mask, dissim_mask = analysis.orientation_masks(angles_diff)

    sim_input_lst.append(np.sum(norm_input[sim_mask]))
    dissim_input_lst.append(np.sum(norm_input[dissim_mask]))

# Per-model ratio of similar to dissimilar input and its summary statistics
fractions = [s / d for s, d in zip(sim_input_lst, dissim_input_lst)]
mean = np.average(fractions)
conf_int = analysis.compute_confidence_interval(fractions)

print("Similar normalization input divided by dissimilar input", np.round(mean, 2))
print("Confidence interval", np.round(conf_int, 2))
print("Plus/minus", np.round(mean - conf_int[0], 2))
# Paired test across models
print(stats.wilcoxon(sim_input_lst, dissim_input_lst))
print("Cohen's d", np.round(analysis.cohens_d(sim_input_lst, dissim_input_lst), 1))



 ### Cosine similarity

In [None]:
# Same comparison as for orientation, but feature pairs are split by the sign
# of their cosine similarity.
sim_input_lst, dissim_input_lst = [], []
for i in range(args.num_best):
    run_no = df_best.iloc[i].run_no
    features = helpers.pkl_load(run_no, "features_chanfirst.pkl", args.weights_path)
    norm_input = helpers.pkl_load(run_no, "norm_input.pkl", args.weights_path)

    # Pairwise cosine similarity between all features of this model. The inner
    # indices get their own names so they do not shadow the model index `i`.
    num_features = features.shape[0]
    cos_sim = np.zeros((num_features, num_features))
    for row, feat_a in enumerate(features):
        for col, feat_b in enumerate(features):
            cos_sim[row, col] = analysis.cosine_similarity(feat_a, feat_b)

    crit = 0
    sim_mask = cos_sim > crit
    dissim_mask = cos_sim <= crit
    sim_input_lst.append(np.sum(norm_input[sim_mask]))
    dissim_input_lst.append(np.sum(norm_input[dissim_mask]))

sim_err = stats.sem(sim_input_lst, ddof=0)
dissim_err = stats.sem(dissim_input_lst, ddof=0)

plt.figure(figsize=(helpers.cm2inch(4), helpers.cm2inch(6)))
# x and y are passed by keyword: newer seaborn versions no longer accept them
# as positional arguments.
ax = sns.barplot(
    x=np.arange(2),
    y=[np.average(sim_input_lst), np.average(dissim_input_lst)],
    yerr=[sim_err, dissim_err],
    palette=["xkcd:blue", "grey"],
)
plt.xlabel("Cosine similarity")
plt.ylabel("Normalization input (a.u.)")
# Labels match the masks above: "similar" is strictly positive, everything
# else (including exactly zero) counts as "dissimilar".
plt.xticks(plt.xticks()[0], [r"> 0", r"$\leq$ 0"])
sns.despine(trim=True, offset=5)
plt.tight_layout()
plt.show()

fractions = [s / d for s, d in zip(sim_input_lst, dissim_input_lst)]
mean = np.average(fractions)
conf_int = analysis.compute_confidence_interval(fractions)
print("Similar normalization input divided by dissimilar input", np.round(mean, 2))
print("Confidence interval", np.round(conf_int, 2))
print("Plus/minus", np.round(mean - conf_int[0], 2))
print(stats.wilcoxon(sim_input_lst, dissim_input_lst))
print("Cohens'd", np.round(analysis.cohens_d(sim_input_lst, dissim_input_lst), 1))



 ### Plot normalization input vs. orientation difference (binned)

In [None]:
# Iterate over the best models and, for each one, sum the normalization input
# within bins of orientation difference between feature pairs.
contrib_model_lst = []
num_featpair_model_lst = []

for i in range(args.num_best):
    model_df = df_best.iloc[i]
    run_no = model_df.run_no
    norm_input = helpers.pkl_load(run_no, "norm_input.pkl", args.weights_path)
    features = helpers.pkl_load(run_no, "features_chanfirst.pkl", args.weights_path)
    a = analysis.angles_circ_var(features, args.oriented_threshold)
    a_diff = analysis.angle_diff(a)
    # Pairs with a NaN angle difference are excluded from every bin below.
    unor_mask = np.isnan(a_diff)
    # NOTE(review): unor_contrib is computed but not used in this cell.
    unor_contrib = np.sum(norm_input[unor_mask])

    mask_lst = []
    contrib_lst = []
    num_featpair_lst = []
    # Bin layout: lower edges at 0, binsize, 2 * binsize, ... (in radians),
    # with enough bins to cover 90 degrees.
    nbins = int(np.ceil(np.deg2rad(90) / args.orientation_binsize))
    crit_angles = args.orientation_binsize * np.arange(nbins)
    for idx, angle_crit in enumerate(crit_angles):
        # first bins (last bin follows as it is cornercase):
        # half-open interval [angle_crit, angle_crit + binsize)
        if idx != (crit_angles.shape[0] - 1):
            mask = np.logical_and(
                (a_diff < angle_crit + args.orientation_binsize),
                np.logical_not(unor_mask),
            )
        else:
            # last bin: include also the maximal difference
            # (closed upper bound, extended by one extra bin width)
            mask = np.logical_and(
                (a_diff <= angle_crit + 2 * args.orientation_binsize),
                np.logical_not(unor_mask),
            )

        # lower bound of the bin (inclusive)
        mask = np.logical_and(mask, (a_diff >= angle_crit))
        mask_lst.append(mask)
        contrib = np.sum(norm_input[mask])
        contrib_lst.append(contrib)
        num_featurepairs = np.sum(mask)
        num_featpair_lst.append(num_featurepairs)

    contrib_model_lst.append(contrib_lst)
    num_featpair_model_lst.append(num_featpair_lst)

# Rows: models, columns: orientation-difference bins
contrib_model_arr = np.array(contrib_model_lst)
contrib_avg = np.average(contrib_model_arr, 0)
contrib_std = np.std(contrib_model_arr, 0)
num_feat_pair_avg = np.average(num_featpair_model_lst, 0)

# Lower bin edges in degrees (not used by the plot below)
a_diff_bins = np.arange(contrib_avg.shape[0]) * args.orientation_binsize
a_diff_bins = np.rad2deg(a_diff_bins)
a_diff_bins = np.round(a_diff_bins, 0)
# Bin centers in degrees, used as x positions
xval = np.rad2deg(args.orientation_binsize) * np.arange(nbins) + (
    np.rad2deg(args.orientation_binsize) / 2
)

plt.figure(figsize=(helpers.cm2inch(8), helpers.cm2inch(8)))
plt.plot(xval, contrib_avg, "-ok")
# Shaded band: mean +/- one standard deviation across models
plt.fill_between(
    xval, contrib_avg - contrib_std, contrib_avg + contrib_std, color="grey", alpha=0.4
)
# Ticks at the bin edges
xticks = np.rad2deg(args.orientation_binsize) * np.arange(nbins + 1)
plt.xticks(ticks=xticks)
plt.xlabel("Orientation difference (deg)")
plt.ylabel("Normalization input (a.u.)")
sns.despine(trim=True, offset=5)
plt.tight_layout()
plt.show()

# Per-model ratio of the first bin to the last bin
fractions = contrib_model_arr[:, 0] / contrib_model_arr[:, -1]
mean = np.average(fractions)
conf_int = analysis.compute_confidence_interval(fractions)
print("First bin divided by last bin", np.round(mean, 2))
print("Confidence interval", np.round(conf_int, 2))
print("Plus/minus", np.round(mean - conf_int[0], 2))
# Paired test across models: first bin vs. last bin
print(stats.wilcoxon(contrib_model_arr[:, 0], contrib_model_arr[:, -1]))
print(
    "Cohens'd",
    np.round(analysis.cohens_d(contrib_model_arr[:, 0], contrib_model_arr[:, -1]), 1),
)



 ### Linear regression analysis
 on all normalization input vs. orientation difference pairs of best model

In [None]:
# Best model only: one data point per feature pair.
run_no = df_best.iloc[0].run_no
norm_input = helpers.pkl_load(run_no, "norm_input.pkl", args.weights_path)
features = helpers.pkl_load(run_no, "features_chanfirst.pkl", args.weights_path)
a = analysis.angles_circ_var(features, args.oriented_threshold)

pair_diff = analysis.angle_diff(a).flatten()
pair_input = norm_input.flatten()

# remove nan's (unoriented filters) and convert the differences to degrees
has_orientation = ~np.isnan(pair_diff)
norm_input = pair_input[has_orientation]
a_diff = np.rad2deg(pair_diff[has_orientation])

reg = stats.linregress(a_diff, norm_input)
print(reg)

plt.figure(figsize=(helpers.cm2inch(8), helpers.cm2inch(8)))
plt.plot(a_diff, norm_input, ".", color="xkcd:grey", label="Data")
# Fitted line over the observed range of orientation differences
x = np.linspace(np.min(a_diff), np.max(a_diff))
plt.plot(x, x * reg.slope + reg.intercept, color="xkcd:blue", label="Linear fit")
plt.legend()
plt.xlabel("Orientation difference (deg)")
plt.ylabel("Normalization input (a.u.)")
sns.despine(trim=True, offset=5)
plt.tight_layout()
plt.show()



 ### Plot normalization input for each feature

In [None]:
# best model
run_no = df_best.iloc[0]["run_no"]
# shape: out-chan, in-chan
contrib = helpers.pkl_load(run_no, "norm_input.pkl", args.weights_path)
features = helpers.pkl_load(run_no, "features_chanfirst.pkl", args.weights_path)
angles = analysis.angles_circ_var(features, args.oriented_threshold)

# use only oriented features (those with a non-NaN angle)
mask = ~np.isnan(angles)
contrib = contrib[mask][:, mask]
features = features[mask]
angles = angles[mask]

a_diff = analysis.angle_diff(angles)
# NOTE(review): angle differences are treated as radians in other cells of this
# notebook, while the criterion passed here is 45 - confirm the unit that
# analysis.orientation_masks expects.
_, sim_mask, dissim_mask = analysis.orientation_masks(a_diff, angle_crit=45)

# Sum contributions over in channel, for each out channel
same_val = np.sum(contrib * sim_mask, -1)
diff_val = np.sum(contrib * dissim_mask, -1)

plt.figure(figsize=(helpers.cm2inch(8), helpers.cm2inch(8)))
plt.scatter(diff_val, same_val, color="black", s=5)  # dots for figure post-processing
# Plot identity
max_val = np.max([diff_val, same_val])
identity = np.linspace(0, max_val)
plt.plot(identity, identity, "Grey")

# Plot features: draw each feature as a small image at its (x, y) position
for image, xy in zip(features, zip(diff_val, same_val)):
    # normalize features to symmetric color scale
    vmax = np.max(np.abs(image))
    norm = matplotlib.colors.Normalize(-vmax, vmax)
    thumbnail = offsetbox.OffsetImage(image, cmap=plt.cm.gray_r, norm=norm, zoom=0.8)
    plt.gca().add_artist(offsetbox.AnnotationBbox(thumbnail, xy, frameon=False))

plt.xlabel("Norm. input from dissimilarly oriented features")
plt.ylabel("Norm. input from similarly oriented features")
axis_ticks = np.arange(0, 1.21, 0.3)
plt.xticks(axis_ticks)
plt.yticks(axis_ticks)
sns.despine(trim=True, offset=5)
plt.show()



 ### Histogram of feature readout weights
 For the ten best-performing models on the validation set

In [None]:
# Stack the feature readout weights of all best models along a new first axis.
readout_feat_lst = [
    helpers.pkl_load(df_best.iloc[i].run_no, "readout_feat_w.pkl", args.weights_path)
    for i in range(args.num_best)
]
readout_features = np.array(readout_feat_lst)

# normalize over channels (axis 1), then avg over neurons (last axis)
rfs_chan_norm = np.linalg.norm(readout_features, axis=1, keepdims=True)
readout_features = np.mean(readout_features / rfs_chan_norm, axis=-1).flatten()

hist_size = (helpers.cm2inch(8), helpers.cm2inch(8 / 8 * 6))
plt.figure(figsize=hist_size)
plt.hist(readout_features, bins=15, color="Grey", edgecolor="Grey", linewidth=1)
plt.xlim(left=0)
sns.despine(trim=True, offset=5)
plt.xlabel("Avg. feature readout weight (a.u.)")
plt.ylabel("No. of features")
plt.tight_layout()
plt.show()

# Spread of the averaged weights relative to their mean
coeff_of_variation = np.std(readout_features) / np.mean(readout_features)
print("Coefficient of variation", np.round(coeff_of_variation, 1))



 ### Model performance vs. size of normalization pool
 *For this to work, you first have to run the cell "Get and save FEV performance on test set"
 in the divisive_3x3_surround_net, divisive_5x5_surround_net, and divisive_7x7_surround_net analysis jupyter notebooks.*

In [None]:
# fev_dict maps the normalization pool size (1 = no surround) to the test-set
# FEV values of that model's best runs; max_valset_fev holds the first entry
# (best run on the validation set) of each, in the order of pool_sizes.
fev_dict = {}
max_valset_fev = np.empty((len(args.surround_path_dict) + 1,))

# models with surround
for i, (surround_size, path) in enumerate(args.surround_path_dict.items()):
    fev_vals = pd.read_csv(os.path.join(path, "df_best.csv")).fev.values
    fev_dict[surround_size] = fev_vals
    max_valset_fev[i + 1] = fev_vals[0]

# model w/o surround (cache file written by the "split" cell of this notebook)
fev_vals = pd.read_csv(args.fname_best_csv).fev.values
fev_dict[1] = fev_vals
max_valset_fev[0] = fev_vals[0]

# Pool sizes in the same order as max_valset_fev: no surround first, then the
# surround models in the order of args.surround_path_dict.
pool_sizes = np.array([1, *args.surround_path_dict])

fig, ax = plt.subplots(figsize=(8 / 2.54, 8 / 2.54))
for size in pool_sizes:
    # Remaining top runs (rank 2 .. num_best); the best run is drawn separately.
    y = np.array(fev_dict[size][1 : args.num_best])
    ax.scatter(
        size * np.ones(len(y)),
        y * 100,
        color="k",
        marker="o",
        s=5,
        # only one legend entry for all small dots
        label="Top 10 runs\nval. set" if size == 1 else None,
    )

# best runs on val set: larger dot
ax.scatter(
    pool_sizes,
    max_valset_fev * 100,
    color="k",
    marker="o",
    s=25,
    label="Best run\nval. set",
)

# used for analysis; blue dot
ax.scatter(
    1,
    fev_dict[1][0] * 100,
    color="xkcd:Blue",
    marker="o",
    s=50,
    label="Used for\nanalysis",
)

ax.set_ylabel("Accuracy (% FEV on test set)")
ax.set_yticks(np.arange(47, 50.1))

ax.set_xticks(pool_sizes)
norm_pool_size_deg = (pool_sizes * 5 + 12) / 35  # px / 35ppd
xlabels = np.round(norm_pool_size_deg, 2)
ax.set_xticklabels(xlabels.astype(str))
ax.set_xlabel("Size of normalization pool (deg)")

ax.legend(frameon=False, loc="lower left")

sns.despine(trim=True, offset=5)
fig.tight_layout()
plt.show()



 ### Spatial normalization pool
 *For this to work, you first have to run the cell "Get and save FEV performance on test set"
 in the divisive_3x3_surround_net, divisive_5x5_surround_net, and divisive_7x7_surround_net analysis jupyter notebooks.*

In [None]:
# Keyword arguments that remove all ticks and tick labels from an axes
hide_ticks = dict(
    which="both", bottom=False, labelbottom=False, left=False, labelleft=False
)

for surround_size, path in args.surround_path_dict.items():
    # Absolute values of the pickled `u` weights of the first run listed in
    # this surround model's df_best.csv
    run_no = pd.read_csv(os.path.join(path, "df_best.csv")).run_no.values[0]
    u = np.abs(helpers.pkl_load(run_no, "u.pkl", os.path.join(path, args.weights_path)))

    # Grid of maps: the column indexes the second-to-last axis of u, the row
    # indexes the last axis.
    no_rows, no_columns = 2, 32
    fig, axes = plt.subplots(no_rows, no_columns, figsize=(no_columns * 1, no_rows * 1))
    for (r, c), ax in np.ndenumerate(axes):
        ax.imshow(u[:, :, 0, c, r], cmap="Greys", vmin=0)
        ax.tick_params(**hide_ticks)
    plt.show()

    # Average over the last two axes of u
    uavg = np.average(u, axis=(-2, -1))
    plt.imshow(uavg[:, :, 0], cmap="Greys", vmin=0)
    plt.tick_params(**hide_ticks)
    plt.show()

