## Generalisation performance

This notebook loads the CSV files that contain the logits and predictions for the images in the test set.
These CSV files were generated by `CHAM/code/network/cc_generalisation_csv.py`.

author: Christina Funke

In [None]:
import pandas as pd
import seaborn as sns #  pip3 install seaborn==0.9.0
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from scipy.optimize import minimize
import matplotlib
from matplotlib.patches import Patch
from matplotlib.lines import Line2D

In [None]:
# define directories
# res_dir: root of the per-image result CSVs; linesearch_threshold reads
#          res_dir + "/imagelevel/<exp_name>/set<k>_contrast<c>.csv"
# fig_dir: target folder for saved figures (only referenced by the commented-out
#          plt.savefig call in plot_bar)
# main_stim_folder: root of the stimulus images; plot_bar loads the example images
#          from main_stim_folder + "/set<k>/contrast0/test/{open,closed}/"
# NOTE(review): main_stim_folder is an absolute path -- adapt it before running elsewhere.
res_dir = '../results/'
fig_dir = '../figures/'
main_stim_folder = '/gpfs01/bethge/share/christina_judy_share/cc'


# experiment names (used as sub-folder names below res_dir + "/imagelevel/");
# two ResNet-50 runs and one BagNet-32 run
exp_resnet_contrastrandom = 'resnet50_lr0.0003_numtrain14000_augment1_unique_batchsize64_optimizerAdam_contrastrandom_1292019_v0'
exp_resnet_contrast0_cropmargin = 'resnet50_lr0.0003_numtrain14000_augment1_unique_batchsize64_optimizerAdam_contrast0_reg0_otf0_cropmargin1_5152019_v0'
exp_bagnet32_cropmargin = 'bagnet32_lr0.0001_numtrain280000_augment1_pairs_batchsize8_optimizeradabound_contrast0_reg0_otf1_5142019_v3'

In [None]:
# labellist_ maps a set number to a human-readable description of that set:
# labellist_[k] is the description of set k, e.g. labellist_[1] describes set1.
# Index 0 holds the dummy entry '0' (presumably because sets are numbered from 1 -- TODO confirm).
# The "\n" inside the strings are line breaks for the tick labels.
labellist_ = ['0','i.i.d. to \ntraining','Line width \n1.25 px', 'Line width \n4.5 px', 'White',
            'Black-White\n-Black','No flankers', 'Noise, 2 lines', 'Line width \n7.5 px', 'Noise, 1 line',
            '3 edges', '6 edges', '9 edges', 'More edges', 'Curved (wrong lw)', 
            'Dashed (wrong lw)', 'Dashed flanker (wrong lw)', 'Curved', 'Dashed', 'w/ dashed \n flanker',
            'Diameter \n50 px', 'Diameter \n100 px', 'Diameter \n150 px', 'w/ curvy \nflankers', 'Asymmetric \nflankers',
            'Binarized']

# add the same numbers as used in the methods
# (same length and ordering as labellist_; '' means the set has no number)
labelnumberlist = ['','','(1)', '(2)', '(4)',
            '(5)','(7)', '', '(3)', '',
            '', '', '', '(6)', '', 
            '', '', '', '(13)', '(14)',
            '(10)', '(11)', '(12)', '(15)', '(8)',
            '(9)']
# description followed by its number, joined by a single space
# (plot_bar itself indexes labellist_, not labellist)
labellist = [x + ' ' + y for x, y in zip(labellist_, labelnumberlist)]

# set numbers that plot_bar marks with the "no flankers" symbol ('x') ...
highlight_sets_flanker = {6, 17, 18, 20, 21, 22}
# ... and with the "curvy lines" symbol ('o')
highlight_sets_curved = {17, 18, 19, 20, 21, 22, 23}

### Define colors

In [None]:
# Base colours: entries 0 (r) and 5 (b) of an 8-colour HLS palette with lightness 0.7.
# (The former hard-coded RGB values for r and b were overwritten immediately and
# have been removed.)
r = sns.hls_palette(8, l=0.7)[0]
b = sns.hls_palette(8, l=0.7)[5]

# custom_col: the same colour r for all six hue levels.
custom_col = [r] * 6
# custom_col2: six shades derived from b; level i subtracts i/12 from the first two
# channels and i/8 from the third channel.
custom_col2 = [[b[0] - i / 12, b[1] - i / 12, b[2] - i / 8] for i in range(6)]

In [None]:
def _show_example_image(img, rect, cropmargin, ylabel=None):
    """Draw one example stimulus in a new frameless axes.

    ARGS:
        img: PIL image to show
        rect: [left, bottom, width, height] of the new axes, as passed to plt.axes
        cropmargin: if True, 16 px are cut off at every border before plotting
        ylabel (str or None): optional row label placed to the left of the image
    """
    plt.axes(rect, facecolor="k", frameon=0)
    if cropmargin:
        plt.imshow(np.array(img)[16:-16, 16:-16, :], interpolation="none")
    else:
        plt.imshow(img, interpolation="none")

    plt.xticks([])
    plt.yticks([])
    if ylabel is not None:
        plt.ylabel(ylabel, rotation=0, va="center", ha="right")


def plot_bar(
    exp_name, df0, sets, df0_opt=pd.DataFrame(), highlight="star", title="", contrast="all", cropmargin=False,
):
    """
    show generalisation performance

    Creates a new figure with one group of bars per set and places one example
    image with a closed contour and one with an open contour below each group.

    example: plot_bar(exp_name, df, sets=[1, 4, 5], contrast='all')

    Reads the notebook-level names custom_col, custom_col2, labellist_,
    highlight_sets_flanker, highlight_sets_curved and main_stim_folder.

    ARGS:
        exp_name: name of the experiment (only used by the commented-out savefig call)
        df0: generalisation performance with threshold == 0
             (needs the columns "set", "contrast" and "pc")
        sets: list of sets shown (must be a list, its order defines the bar order)
        df0_opt: generalisation performance with optimal threshold
                 (same columns as df0; an empty DataFrame means "not shown")
        highlight: how are the flankers and curvy stimuli highlighted:
                'star' (symbols at x axis) or 'bar' (bars above the plot);
                '' hides the legend and draws no markers
        title (str): title of the plot (currently unused, see commented-out plt.title)
        contrast: which contrast levels will be shown (for example 0, 0.2,... or 'all')
        cropmargin: crop the 16 px margin of the stimuli shown under the plot
    """
    plt.figure(figsize=(len(sets), 3))

    # which contrast levels?
    if contrast != "all":
        if not df0_opt.empty:
            df0_opt = df0_opt.loc[df0_opt["contrast"] == contrast]
        df0 = df0.loc[df0["contrast"] == contrast]

    # plot data (bars of df0_opt are drawn first, bars of df0 on top of them)
    if not df0_opt.empty:
        df1_opt = df0_opt.loc[df0_opt["set"].isin(sets)]
        sns.barplot(x="set", y="pc", data=df1_opt, hue="contrast", order=sets, palette=custom_col)

    df1 = df0.loc[df0["set"].isin(sets)]
    g = sns.barplot(x="set", y="pc", data=df1, hue="contrast", order=sets, palette=custom_col2)

    if contrast != "all":
        # make bars thinner
        for bar in g.patches:
            bar.set_x(bar.get_x() + 0.5 / 4)
            bar.set_width(0.5)

        # mark training set with hatches; both barplots share the axes, so when
        # df0_opt was drawn the first len(sets) patches belong to it
        if not df0_opt.empty:
            g.patches[len(sets)].set_hatch("//")
        else:
            g.patches[0].set_hatch("//")

    # add legend
    net_string = "ResNet50"
    legend_elements = [
        Line2D([0], [0], marker="o", markersize=8, color="k", fillstyle="none", markeredgewidth=2, linestyle="", label="curvy lines",),
        Line2D([0], [0], marker="x", markersize=8, color="k", markeredgewidth=2, label="no flankers", linestyle="",),
    ]
    if not df0_opt.empty:
        legend_elements = [
            Patch(facecolor=custom_col2[0], label=net_string),
            Patch(facecolor=custom_col[0], label=net_string + " w/ optimised decision criterion",),
        ] + legend_elements

    legend = plt.legend(handles=legend_elements, loc="upper left", ncol=4, bbox_to_anchor=(0, 1.2))
    if highlight == "":
        legend.set_visible(False)

    # labels, title, axis
    # labels sorted by acc: every set except the first one gets its position
    # in `sets` appended in brackets
    labels = [
        labellist_[x] if rank == 0 else labellist_[x] + ' (' + str(rank) + ')'
        for rank, x in enumerate(sets)
    ]

    plt.plot([-0.5, len(sets)], [0.5, 0.5], "--", color="k")  # line showing chance performance
    plt.ylim(0.45, 1)
    plt.xlabel(" ")
    if len(sets) <= 10:
        pos_label = -1.55
    else:
        pos_label = -0.87  # -0.85
    g.set_xticklabels(labels, rotation=45, va="center", ha="center", position=(0, pos_label))
    g.tick_params(axis="x", which="both", length=0)
    plt.ylabel("Accuracy")
    g.spines["top"].set_visible(False)
    g.spines["right"].set_visible(False)
    # plt.title(title, fontsize = 15, y=1.25)

    shown_sets = set(sets)

    # highlight curvy and no flanker condition
    if highlight == "bar":
        # custom_col_flanker / custom_col_curved are not defined in this notebook;
        # use them if the user defined them, otherwise fall back to two grey levels
        # instead of failing with a NameError.
        flanker_color = globals().get("custom_col_flanker", "0.55")
        curved_color = globals().get("custom_col_curved", "0.8")

        # highlight all plots that are in the list highlight_sets_flanker
        for highlight_set in highlight_sets_flanker.intersection(shown_sets):
            highlight_index = sets.index(highlight_set)
            plt.axvspan(
                highlight_index - 0.5, highlight_index + 0.5, ymin=1, ymax=1.1, lw=0, color=flanker_color, clip_on=False,
            )
        plt.text(8.5, 1.015, "no flankers", fontsize=10)

        for highlight_set in highlight_sets_curved.intersection(shown_sets):
            highlight_index = sets.index(highlight_set)
            plt.axvspan(
                highlight_index - 0.5, highlight_index + 0.5, ymin=1.1, ymax=1.2, lw=0, color=curved_color, clip_on=False,
            )
        plt.text(10.2, 1.07, "curvy", fontsize=10)

    if highlight == "star":
        for highlight_set in highlight_sets_flanker.intersection(shown_sets):
            highlight_index = sets.index(highlight_set)
            plt.plot(
                [highlight_index + 0.15], [0.45], "x", markersize=8, color="k", markeredgewidth=2, clip_on=False,
            )

        for highlight_set in highlight_sets_curved.intersection(shown_sets):
            highlight_index = sets.index(highlight_set)
            plt.plot(
                [highlight_index - 0.15], [0.45], "o", markersize=8, color="k", fillstyle="none", markeredgewidth=2, clip_on=False,
            )

    # add example images under bars
    # position of the bar axes in figure coordinates; the image axes are
    # centred below the n-th of len(sets) equally wide slots
    xl, yl, xh, yh = np.array(g.get_position()).ravel()

    if len(sets) <= 10:
        size = 0.45
        pos_closed = -0.4
        pos_open = -0.9
    else:
        size = 0.23  # 0.2 to fit exactly to size of bar
        pos_closed = -0.14
        pos_open = -0.4

    for n, set_num in enumerate(sets, start=1):
        stim_folder = main_stim_folder + "/set" + str(set_num) + "/contrast0/"
        img_open = Image.open(stim_folder + "/test/open/test4.png")
        img_closed = Image.open(stim_folder + "/test/closed/test1.png")

        xp = xl + (xh - xl) * n / len(sets) - (xh - xl) / (2 * len(sets))
        first = n == 1  # only the left-most column gets the row labels

        # closed images
        _show_example_image(
            img_closed, [xp - size * 0.5, pos_closed, size, size], cropmargin,
            ylabel="closed \ncontour" if first else None,
        )

        # open images
        _show_example_image(
            img_open, [xp - size * 0.5, pos_open, size, size], cropmargin,
            ylabel="open \ncontour" if first else None,
        )

    # plt.savefig(fig_dir + 'bar_' + exp_name + '_sets' + str(sets).replace(' ','') + '.svg', bbox_inches = "tight", dpi=1000)


## Functions to optimize threshold

In [None]:
def get_error(x, logits, labels):
    """
    get the prediction error (= 1 - accuracy) depending on the bias/threshold for the logits

    A sample is predicted as class 1 if its logit is strictly greater than x,
    otherwise as class 0.

    ARGS:
        x: threshold applied to the logits
        logits: array-like of logits (numpy array, list, ...)
        labels: array-like of the true labels (0 or 1), same length as logits

    RETURNS:
        fraction of samples whose thresholded prediction differs from the label

    RAISES:
        ZeroDivisionError: if logits is empty
    """
    # accept any array-like input, not only numpy arrays
    logits = np.asarray(logits)
    labels = np.asarray(labels)

    predictions = (logits > x).astype(int)
    # count in numpy instead of iterating over the array with the builtin sum
    n_correct = int(np.count_nonzero(predictions == labels))
    return 1 - n_correct / len(logits)


def linesearch_threshold(exp_name, set_num, contrast):
    """
    Find the optimal threshold to reach the optimal performance by a simple line search.
    This function returns optimal accuracy and threshold.
    example: linesearch_threshold(exp_name, 5, 0)

    The logits and labels are read from
    res_dir + "/imagelevel/<exp_name>/set<set_num>_contrast<contrast>.csv"
    (columns "logits" and "label").

    ARGS:
        exp_name: name of the experiment (sub-folder of res_dir + "/imagelevel/")
        set_num: number of the set
        contrast: contrast level

    RETURNS:
        (accuracy, threshold) of the best threshold found; (0, 0) if the csv
        file does not exist. If several thresholds reach the same accuracy the
        smallest one is returned.
    """
    csv_path = res_dir + "/imagelevel/" + exp_name + "/set" + str(set_num) + "_contrast" + str(contrast) + ".csv"

    # skip nonexistent files; only the missing file is tolerated, every other
    # problem (missing column, malformed file, ...) is raised instead of being
    # reported as "file not found"
    try:
        df_csv = pd.read_csv(csv_path)
    except FileNotFoundError:
        print("file not found")
        return 0, 0

    logits = df_csv["logits"].values
    labels = df_csv["label"].values

    # take the interval in which 95% of the data lies and select 100 points in this interval
    searchspace = np.linspace(np.percentile(logits, 2.5), np.percentile(logits, 97.5), 100)
    maxacc = 0
    maxth = 0
    for th in searchspace:
        acc = 1 - get_error(th, logits, labels)
        if acc > maxacc:  # strict comparison: keep the first threshold reaching the maximum
            maxacc = acc
            maxth = th

    return maxacc, maxth  # return accuracy, threshold


def get_dataframe_optimized(exp_name, set_nums, contrasts):
    """
    Summarize the results. The datasets and contrast levels can be specified
    example: get_dataframe_optimized(exp_name, [1, 2], [0,1])

    ARGS:
        exp_name: name of the experiment
        set_nums: list of set numbers
        contrasts: list of contrast levels

    RETURNS:
        DataFrame with one row per (set, contrast) combination and the columns
        "set", "contrast", "pc" (accuracy at the optimal threshold) and
        "threshold" (the optimal threshold), as returned by linesearch_threshold
    """
    columns = ["set", "contrast", "pc", "threshold"]

    # Collect all rows first and build the DataFrame once: DataFrame.append was
    # removed in pandas 2.0 and copied the whole frame on every call.
    rows = []
    for set_num in set_nums:
        for contrast in contrasts:
            acc, th = linesearch_threshold(exp_name, set_num, contrast)
            rows.append({"set": set_num, "contrast": contrast, "pc": acc, "threshold": th})

    # passing `columns` keeps the four columns even if there are no rows
    return pd.DataFrame(rows, columns=columns)


### Define which sets to show

In [None]:
# reduced selection of sets that is shown in the figures
sets_reduced = [1, 25, 4, 5, 2, 3, 8, 13, 24, 6, 20, 21, 22, 18, 19, 23]

# sort by performance of contrast0:
# "pc" is the accuracy at the line-searched threshold of the contrast-0 ResNet experiment,
# sets are ordered from highest to lowest accuracy
df_opt = get_dataframe_optimized(exp_resnet_contrast0_cropmargin, sets_reduced, [0])
sets_sorted = list(df_opt.sort_values("pc", ascending=False)["set"])
# move set 1 to the front regardless of its accuracy
sets_sorted.insert(0, sets_sorted.pop(sets_sorted.index(1))) # make sure that iid to training is at the left 

# Make figures

### Main figure (part A and B)

In [None]:
# ResNet-50 experiment (contrast 0), sets in the order defined by sets_sorted
sets = sets_sorted
# accuracy at the line-searched threshold for contrast 0 only
df_opt = get_dataframe_optimized(exp_resnet_contrast0_cropmargin, sets, [0])
# NOTE: df_opt is passed as the first data argument (df0) of plot_bar; df0_opt keeps
# its empty default, so a single group of bars is drawn
plot_bar(exp_resnet_contrast0_cropmargin, df_opt, sets=sets, highlight="star", title="", contrast=0, cropmargin=True)
plt.show()

### Appendix (different contrast levels)

In [None]:
# ResNet-50 experiment "contrastrandom", evaluated at six contrast levels
sets = sets_sorted
df_opt = get_dataframe_optimized(exp_resnet_contrastrandom, sets, [0, 0.2, 0.4, 0.6, 0.8, 1])
# contrast="all": no filtering, one bar per contrast level (hue) for every set;
# cropmargin keeps its default (False), so the example images are shown uncropped
plot_bar(exp_resnet_contrastrandom, df_opt, sets=sets, highlight="star", title="", contrast="all")
plt.show()

### Appendix (BagNet)

In [None]:
# BagNet-32 experiment, sets in the same order as in the ResNet figures
sets = sets_sorted
# all six contrast levels are evaluated here, but plot_bar(contrast=0) only shows contrast 0
df_opt = get_dataframe_optimized(exp_bagnet32_cropmargin, sets, [0, 0.2, 0.4, 0.6, 0.8, 1])
# NOTE(review): the legend text inside plot_bar is hard-coded to "ResNet50"; it is only
# drawn when df0_opt is passed, which is not the case here
plot_bar(exp_bagnet32_cropmargin, df_opt, sets=sets, highlight='star', title="", contrast=0, cropmargin=True)