In [15]:
import os
import sys
sys.path.append(os.path.abspath('.'))
sys.path.append(os.path.abspath('..'))
sys.path.append(os.path.abspath('../run'))

import copy
import glob
import typing

import numpy as np
import pandas as pd
import scipy
from scipy import stats

from mpl_toolkits.mplot3d import Axes3D
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import patches
from matplotlib import path as mpath
import matplotlib.gridspec as gridspec

from simple_relational_reasoning.embeddings.visualizations import *


In [4]:
# Result files in the output folder that should not be part of the combined table.
IGNORE_LIST = ('saycam_split_text_test.csv', 'saycam_nan_test.csv')

def load_and_join_dataframes(folder, ext='.csv', ignore_list=IGNORE_LIST) -> pd.DataFrame:
    """Read every `*<ext>` file in `folder` (except `ignore_list`) and stack them row-wise.

    Files are read in sorted path order so the row order of the result does not
    depend on the filesystem's directory listing order.

    Args:
        folder: directory containing the result files.
        ext: file extension to match (including the dot).
        ignore_list: file basenames to skip.

    Returns:
        One DataFrame with all rows and a fresh 0..n-1 index.

    Raises:
        ValueError: if no matching files remain after applying `ignore_list`.
    """
    paths = sorted(glob.glob(os.path.join(folder, '*' + ext)))
    frames = [pd.read_csv(path) for path in paths if os.path.basename(path) not in ignore_list]
    if not frames:
        raise ValueError(f'No {ext!r} files to load in {folder!r} (ignored: {list(ignore_list)})')
    # ignore_index=True replaces the per-file indices with a single RangeIndex
    return pd.concat(frames, ignore_index=True)
    
combined_df = load_and_join_dataframes('../embedding_outputs')
# Drop the first column by position -- presumably the index that was written
# alongside each results CSV (TODO confirm against the writer).
combined_df = combined_df.drop(columns=combined_df.columns[0])
# Missing rotate_angle values become 0 (presumably "not rotated").
# Assign the result back rather than calling `fillna(inplace=True)` on
# `combined_df.rotate_angle`: that chained in-place call can act on a copy
# under pandas copy-on-write and leave combined_df unchanged.
combined_df['rotate_angle'] = combined_df['rotate_angle'].fillna(0).astype(int)

# TODO: break down the model name field into architecture, training, flipping, dino, etc.

In [5]:
# Preview the first rows of the combined results table.
combined_df.head(n=5)

Unnamed: 0,model_name,condition,acc_mean,acc_std,acc_sem,relation,two_reference_objects,adjacent_reference_objects,n_target_types,transpose_stimuli,n_habituation_stimuli,rotate_angle,seed,n_examples,extra_diagonal_margin
0,mobilenet-random,different_shapes,0.507812,0.499939,0.015623,above_below,0,0,1,0,1,0,34,1024,0
1,mobilenet-random,split_text,0.503906,0.499985,0.015625,above_below,0,0,1,0,1,0,34,1024,0
2,mobilenet-random,random_color,0.515625,0.499756,0.015617,above_below,0,0,1,0,1,0,34,1024,0
3,resnext-random,different_shapes,0.641602,0.47953,0.014985,above_below,0,0,1,0,1,0,34,1024,0
4,resnext-random,split_text,0.605469,0.48875,0.015273,above_below,0,0,1,0,1,0,34,1024,0


In [None]:
# Default bar styles, a nested dict indexed as
#   DEFAULT_PLOT_STYLES[style][field_name][field_value] -> style keyword value
# where
#   style       -- the bar style keyword ('color', 'hatch', ...),
#   field_name  -- the column that drives that style (model_name, relation, ...),
#   field_value -- one value of that column (resnext, mobilenet, ...).
# The plotting helpers below look up ['color'][field][value] and
# ['hatch'][field][value].
# TODO: fill in the values below -- while these dicts are empty, any plot that
# relies on the defaults raises KeyError on the first style lookup.
DEFAULT_PLOT_STYLES = {
    'color': {},
    'hatch': {},
}


def plot_single_bar(
    ax: 'matplotlib.axes.Axes',
    x: float,
    key: typing.Sequence[typing.Any],
    mean: pd.Series,
    std: pd.Series,
    plot_std: bool,
    bar_width: float, 
    bar_kwargs: typing.Dict[str, typing.Any],
    global_bar_kwargs: typing.Dict[str, typing.Union[str, int]],
):
    """Draw one bar at `x` and return the x position for the next bar.

    Args:
        ax: axes to draw on (only its `bar` method is used).
        x: x position of this bar, passed straight to `ax.bar`.
        key: one label per index level of `mean` / `std` (e.g. group value,
            color value[, hatch value]) identifying this bar's row.
        mean: bar heights, indexed by condition (a Series, e.g. the result of a
            grouped `.mean()`).
        std: error-bar sizes with the same index as `mean`; only read when
            `plot_std` is True.
        plot_std: whether to draw an error bar.
        bar_width: bar width; also the step to the next bar position.
        bar_kwargs: per-bar style (e.g. facecolor, hatch).
        global_bar_kwargs: style shared by all bars; a key also present in
            `bar_kwargs` is overridden by the per-bar value.

    Returns:
        `x + bar_width`.
    """
    # `.loc` needs a tuple to address one MultiIndex row; a list would instead
    # be read as several labels of the first index level.
    row = tuple(key)
    height = mean.loc[row]
    err = std.loc[row] if plot_std else None

    # Merge instead of splatting both dicts into the call, which raises
    # TypeError when they share a key.
    style = {**global_bar_kwargs, **bar_kwargs}
    ax.bar(x, height, yerr=err, width=bar_width, **style)
    # TODO: above bar texts would go here, if they exist
    return x + bar_width


def plot_single_panel(
    ax: matplotlib.axes.Axes,
    mean: pd.Series,
    std: pd.Series,
    plot_std: bool,
    orders_by_field: typing.Dict[str, typing.List[str]],
    plot_style_by_field: typing.Dict[str, typing.Dict[str, typing.Dict[str, typing.Any]]],
    group_bars_by: str, 
    color_bars_by: str, 
    hatch_bars_by: typing.Optional[str] = None,
    bar_width: float = 0.2, 
    bar_group_spacing: float = 0.5, 
    add_chance_hline: bool = True,
    global_bar_kwargs: typing.Dict[str, typing.Union[str, int]] = DEFAULT_BAR_KWARGS,
    text_kwargs: typing.Dict[str, typing.Any] = DEFAULT_TEXT_KWARGS,
    ylim: typing.Optional[typing.Tuple[float, float]] = DEFAULT_YLIM,
    ylabel: str = 'Accuracy', 
    chance_level: float = 0.5,
):
    """Draw one panel of grouped bars on `ax`.

    Bars are laid out left to right. There is one cluster per value of
    `group_bars_by`, in `orders_by_field` order. Within a cluster there is one
    bar per value of `color_bars_by`, or, when `hatch_bars_by` is given, one bar
    per (color value, hatch value) pair. Clusters are separated by
    `bar_group_spacing`, and each cluster gets one x tick label.

    Args:
        ax: axes to draw on.
        mean: bar heights, indexed by (group value, color value[, hatch value]).
            Assumes a MultiIndex in exactly that level order.
        std: error-bar sizes with the same index as `mean`. Only read when
            `plot_std` is True.
        plot_std: whether to draw error bars.
        orders_by_field: display order of the values of each field.
        plot_style_by_field: style lookup indexed as [style][field][value].
            'color' is always used; 'hatch' is used only with `hatch_bars_by`.
        group_bars_by: field whose values form the x-axis clusters.
        color_bars_by: field whose values set each bar's facecolor.
        hatch_bars_by: optional field whose values set each bar's hatch pattern.
        bar_width: width of every bar.
        bar_group_spacing: extra gap after each cluster.
        add_chance_hline: draw a dashed horizontal line at `chance_level`.
        global_bar_kwargs: style applied to every bar.
        text_kwargs: font settings for axis labels. Must contain 'fontsize'.
        ylim: y-axis limits, or None to keep matplotlib's autoscaling.
        ylabel: y-axis label.
        chance_level: y value of the chance line.
    """
    x = 0

    for group_by_value in orders_by_field[group_bars_by]:
        for color_by_value in orders_by_field[color_bars_by]:
            facecolor = plot_style_by_field['color'][color_bars_by][color_by_value]

            if hatch_bars_by is None:
                x = plot_single_bar(ax, x, (group_by_value, color_by_value), mean, std,
                    plot_std, bar_width,
                    dict(facecolor=facecolor), global_bar_kwargs)
                continue

            for hatch_by_value in orders_by_field[hatch_bars_by]:
                bar_kwargs = dict(
                    facecolor=facecolor,
                    hatch=plot_style_by_field['hatch'][hatch_bars_by][hatch_by_value],
                )
                # The hatch value is the innermost index level.
                hatch_key = (group_by_value, color_by_value, hatch_by_value)
                x = plot_single_bar(ax, x, hatch_key, mean, std,
                    plot_std, bar_width,
                    bar_kwargs, global_bar_kwargs)

        x += bar_group_spacing

    group_values = orders_by_field[group_bars_by]

    # Number of bars in each cluster.
    n_hatch = len(orders_by_field[hatch_bars_by]) if hatch_bars_by is not None else 1
    group_length = len(orders_by_field[color_bars_by]) * n_hatch
    # Put each tick under the middle of its cluster. This assumes bars are
    # centred on their x position, matplotlib's default align='center'.
    cluster_stride = bar_group_spacing + bar_width * group_length
    x_tick_locations = np.arange(len(group_values)) * cluster_stride + \
        bar_width * (group_length / 2 - 0.5)

    xtick_text_kwargs = text_kwargs.copy()
    if len(group_values) > 4:
        # Use smaller tick labels when many clusters share the axis.
        xtick_text_kwargs['fontsize'] -= 4
    ax.set_xticks(x_tick_locations)
    ax.set_xticklabels([plot_prettify(val) for val in group_values], fontdict=xtick_text_kwargs)
    ax.tick_params(axis='both', which='major', labelsize=text_kwargs['fontsize'] - 4)

    if add_chance_hline:
        # Read the limits from this axes. The pyplot "current" axes need not be
        # `ax` in a multi-panel figure.
        xlim = ax.get_xlim()
        ax.hlines(chance_level, *xlim, linestyle='--', alpha=0.5)
        # Pin the x limits so the line spans exactly the existing range.
        ax.set_xlim(*xlim)

    if ylim is not None:
        ax.set_ylim(*ylim)

    ax.set_xlabel(plot_prettify(group_bars_by), **text_kwargs)
    ax.set_ylabel(ylabel, **text_kwargs)


def add_legend_to_ax(ax: matplotlib.axes.Axes, 
    # TODO: this needs more args here
    orders_by_field: typing.Dict[str, typing.List[str]],
    plot_style_by_field: typing.Dict[str, typing.Dict[str, typing.Dict[str, typing.Any]]],
    color_bars_by: str,
    hatch_bars_by: typing.Optional[str] = None,
    text_kwargs: typing.Dict[str, typing.Any] = DEFAULT_TEXT_KWARGS,
    legend_loc: typing.Optional[str] = 'best', 
    legend_ncol: typing.Optional[int] = None):
    """Add a legend for the bar colors, and hatches if any, to `ax`.

    There is one entry per value of `color_bars_by`: a black-edged patch filled
    with that value's color. When `hatch_bars_by` is given, there is then one
    entry per hatch value: an unfilled black-edged patch showing that hatch
    pattern. No legend is drawn when there are no entries.

    Args:
        ax: axes that receives the legend.
        orders_by_field: display order of the values of each field.
        plot_style_by_field: style lookup indexed as [style][field][value].
        color_bars_by: field whose values define the color entries.
        hatch_bars_by: optional field whose values define the hatch entries.
        text_kwargs: must contain 'fontsize'. The legend uses fontsize - 4.
        legend_loc: passed to `ax.legend(loc=...)`.
        legend_ncol: number of legend columns. None means 1.
    """
    # Named `handles` so the module-level `patches` import is not shadowed here.
    handles = []

    for color_by_value in orders_by_field[color_bars_by]:
        handles.append(matplotlib.patches.Patch(
            facecolor=plot_style_by_field['color'][color_bars_by][color_by_value],
            edgecolor='black',
            label=plot_prettify(color_by_value),
        ))

    if hatch_bars_by is not None:
        for hatch_by_value in orders_by_field[hatch_bars_by]:
            # Unfilled swatch: a hatch entry describes only the pattern, so it
            # must not carry the facecolor of any particular color entry.
            handles.append(matplotlib.patches.Patch(
                facecolor='none',
                edgecolor='black',
                hatch=plot_style_by_field['hatch'][hatch_bars_by][hatch_by_value],
                label=plot_prettify(hatch_by_value),
            ))

    if handles:
        ncol = 1 if legend_ncol is None else legend_ncol
        ax.legend(handles=handles, loc=legend_loc, ncol=ncol, fontsize=text_kwargs['fontsize'] - 4)
    

def multiple_bar_plots(df: pd.DataFrame, *,
    # what and how to plot arguments
    filter_dict: typing.Dict[str, typing.Union[str, typing.Sequence[str]]],
    group_bars_by: str, 
    color_bars_by: str, 
    panel_by: typing.Optional[str] = None, 
    hatch_bars_by: typing.Optional[str] = None,
    plot_std: bool = True,
    sem: bool = True,
    orders_by_field: typing.Dict[str, typing.Sequence[str]] = DEFAULT_ORDERS,
    # plot style arguments
    plot_style_by_field: typing.Dict[str, typing.Dict[str, typing.Dict[str, typing.Any]]] = DEFAULT_PLOT_STYLES, 
    ax: typing.Union[matplotlib.axes.Axes, typing.Sequence[matplotlib.axes.Axes], None] = None,
    figsize: typing.Optional[typing.Tuple[float, float]] = None,
    layout: typing.Optional[typing.Tuple[int, int]] = None,
    bar_width: float = 0.2, 
    bar_group_spacing: float = 0.5, 
    add_chance_hline: bool = True,
    global_bar_kwargs: typing.Dict[str, typing.Union[str, int]] = DEFAULT_BAR_KWARGS,
    text_kwargs: typing.Dict[str, typing.Any] = DEFAULT_TEXT_KWARGS,
    ylim: typing.Tuple[float, float] = DEFAULT_YLIM,
    ylabel: str = 'Accuracy', 
    legend_ax_index: typing.Optional[int] = None, 
    legend_loc: typing.Optional[str] ='best', 
    legend_ncol: typing.Optional[int] = None,
    # plot saving arguments
    save_path: typing.Optional[str] = None, 
    save_should_print: bool = False, 
    ):
    """Plot mean accuracy as grouped bars, optionally with one panel per value of `panel_by`.

    `df` is filtered and grouped by `filter_and_group` on
    [panel_by,] group_bars_by, color_bars_by[, hatch_bars_by]. Bar heights are the
    per-group mean of `acc_mean`. Error bars are the per-group mean of `acc_sem`,
    or of `acc_std` when `sem` is False.

    Args:
        df: results table with the grouping columns and acc_mean / acc_sem / acc_std.
        filter_dict: forwarded to `filter_and_group` to select rows.
        group_bars_by: field whose values form the x-axis clusters.
        color_bars_by: field whose values set the bar facecolors.
        panel_by: optional field whose values each get their own subplot.
        hatch_bars_by: optional field whose values set the bar hatch patterns.
        plot_std: whether to draw error bars.
        sem: use acc_sem (True) or acc_std (False) for the error bars.
        orders_by_field: display order of values per field. Fields missing here
            are ordered by sorting their observed values. The argument is deep-copied
            and never mutated.
        plot_style_by_field: style lookup indexed as [style][field][value].
        ax: axes to draw on; None creates a new figure. For panel plots, pass an
            array or sequence with one axes per panel value. A 2-D array is
            flattened in row-major order.
        figsize: size of a newly created figure.
        layout: (n_rows, n_cols) of the subplot grid. Required when `panel_by` is
            given and `ax` is None, and it must have one cell per panel value.
        bar_width, bar_group_spacing, add_chance_hline, text_kwargs, ylim:
            forwarded to `plot_single_panel`.
        global_bar_kwargs: style applied to every bar. None means no extra style.
        ylabel: y-axis label. In panel plots it is drawn only on the first column.
        legend_ax_index: in panel plots, the flattened index of the panel that
            gets the legend. None means no legend. Single-panel plots always call
            `add_legend_to_ax`.
        legend_loc, legend_ncol: forwarded to `add_legend_to_ax`.
        save_path: if given, `save_plot` is called before the figure is shown.
        save_should_print: forwarded to `save_plot` as `should_print`.

    Raises:
        ValueError: if `panel_by` is given, `ax` is None, and `layout` is missing
            or has the wrong number of cells.
    """
    if global_bar_kwargs is None:
        global_bar_kwargs = dict()

    group_by_fields = [group_bars_by, color_bars_by]
    if panel_by is not None:
        group_by_fields.insert(0, panel_by)
    if hatch_bars_by is not None:
        group_by_fields.append(hatch_bars_by)

    grouped_df = filter_and_group(df, filter_dict, group_by_fields)

    mean = grouped_df.acc_mean.mean()
    std = (grouped_df.acc_sem if sem else grouped_df.acc_std).mean()

    orders_by_field = copy.deepcopy(orders_by_field)
    for field in group_by_fields:
        # TODO: filter out values that don't exist in this data subset
        if field not in orders_by_field:
            orders_by_field[field] = list(sorted(mean.index.unique(level=field))) 

    if ax is None:
        if panel_by is None:
            _, ax = plt.subplots(1, 1, figsize=figsize)

        else:
            if layout is None:
                raise ValueError('layout must be specified if panel_by is specified')

            if np.prod(layout) != len(orders_by_field[panel_by]):
                raise ValueError('layout must have the same number of cells as the number of unique values of panel_by')

            _, ax = plt.subplots(*layout, figsize=figsize)

    if panel_by is None:
        plot_single_panel(ax, mean, std, 
            plot_std, orders_by_field, plot_style_by_field,
            group_bars_by, color_bars_by, hatch_bars_by,
            bar_width, bar_group_spacing, add_chance_hline,
            global_bar_kwargs, text_kwargs, 
            ylim, ylabel)

        add_legend_to_ax(ax, orders_by_field, plot_style_by_field, color_bars_by, hatch_bars_by, 
            text_kwargs, legend_loc, legend_ncol)

    else:
        # plt.subplots returns a 2-D array when both layout dims exceed 1, so
        # flatten it so that panel i maps to one axes in row-major order.
        panel_axes = np.atleast_1d(ax).ravel()

        # The y label goes only on the first column of the grid.
        if layout is not None:
            n_cols = layout[1]
        elif np.ndim(ax) == 2:
            n_cols = np.shape(ax)[1]
        else:
            # A flat sequence of caller-provided axes is treated as a single row.
            n_cols = len(panel_axes)

        for i, panel_value in enumerate(orders_by_field[panel_by]):
            panel_ax = panel_axes[i]
            # xs drops the panel level, leaving (group, color[, hatch]) keys.
            plot_single_panel(panel_ax,
                mean.xs(panel_value, level=panel_by), std.xs(panel_value, level=panel_by),
                plot_std, orders_by_field, plot_style_by_field,
                group_bars_by, color_bars_by, hatch_bars_by,
                bar_width, bar_group_spacing, add_chance_hline,
                global_bar_kwargs, text_kwargs, 
                ylim, ylabel if i % n_cols == 0 else '')

            if i == legend_ax_index:
                add_legend_to_ax(panel_ax, orders_by_field, plot_style_by_field, color_bars_by, hatch_bars_by, 
                    text_kwargs, legend_loc, legend_ncol)

            panel_ax.set_title(f'{panel_by} = {panel_value}')

    # TODO: deal with panel titles
    # TODO: consider if we want to do the above-bar text things again
    if save_path is not None:
        save_plot(save_path, should_print=save_should_print)
    
    plt.show()