diff --git a/fooof/core/reports.py b/fooof/core/reports.py index 40bc8aa3a..0c75c3a28 100644 --- a/fooof/core/reports.py +++ b/fooof/core/reports.py @@ -95,7 +95,7 @@ def save_report_fg(fg, file_name, file_path=None, add_settings=True): # Initialize figure _ = plt.figure(figsize=REPORT_FIGSIZE) - grid = gridspec.GridSpec(n_rows, 2, wspace=0.4, hspace=0.25, height_ratios=height_ratios) + grid = gridspec.GridSpec(n_rows, 2, wspace=0.35, hspace=0.25, height_ratios=height_ratios) # First / top: text results ax0 = plt.subplot(grid[0, :]) @@ -108,15 +108,15 @@ def save_report_fg(fg, file_name, file_path=None, add_settings=True): # Aperiodic parameters plot ax1 = plt.subplot(grid[1, 0]) - plot_fg_ap(fg, ax1) + plot_fg_ap(fg, ax1, custom_styler=None) # Goodness of fit plot ax2 = plt.subplot(grid[1, 1]) - plot_fg_gf(fg, ax2) + plot_fg_gf(fg, ax2, custom_styler=None) # Peak center frequencies plot ax3 = plt.subplot(grid[2, :]) - plot_fg_peak_cens(fg, ax3) + plot_fg_peak_cens(fg, ax3, custom_styler=None) # Third - Model settings if add_settings: diff --git a/fooof/plts/fg.py b/fooof/plts/fg.py index edd950e18..310d95d1c 100644 --- a/fooof/plts/fg.py +++ b/fooof/plts/fg.py @@ -44,7 +44,7 @@ def plot_fg(fg, save_fig=False, file_name=None, file_path=None, **plot_kwargs): raise NoModelError("No model fit results are available, can not proceed.") fig = plt.figure(figsize=plot_kwargs.pop('figsize', PLT_FIGSIZES['group'])) - gs = gridspec.GridSpec(2, 2, wspace=0.4, hspace=0.25, height_ratios=[1, 1.2]) + gs = gridspec.GridSpec(2, 2, wspace=0.35, hspace=0.35, height_ratios=[1, 1.2]) # Apply scatter kwargs to all subplots scatter_kwargs = plot_kwargs @@ -52,15 +52,15 @@ def plot_fg(fg, save_fig=False, file_name=None, file_path=None, **plot_kwargs): # Aperiodic parameters plot ax0 = plt.subplot(gs[0, 0]) - plot_fg_ap(fg, ax0, **scatter_kwargs) + plot_fg_ap(fg, ax0, **scatter_kwargs, custom_styler=None) # Goodness of fit plot ax1 = plt.subplot(gs[0, 1]) - plot_fg_gf(fg, ax1, **scatter_kwargs) + plot_fg_gf(fg, ax1, **scatter_kwargs, custom_styler=None) # Center frequencies plot ax2 = plt.subplot(gs[1, :]) - plot_fg_peak_cens(fg, ax2, **plot_kwargs) + plot_fg_peak_cens(fg, ax2, **plot_kwargs, custom_styler=None) @savefig diff --git a/fooof/plts/settings.py b/fooof/plts/settings.py index 4b7f1050e..c6d82c138 100644 --- a/fooof/plts/settings.py +++ b/fooof/plts/settings.py @@ -6,9 +6,9 @@ ################################################################################################### # Define default figure sizes -PLT_FIGSIZES = {'spectral' : (10, 8), +PLT_FIGSIZES = {'spectral' : (8.5, 6.5), 'params' : (7, 6), - 'group' : (12, 10)} + 'group' : (9, 7)} # Define defaults for colors for plots, based on what is plotted PLT_COLORS = {'data' : 'black', @@ -45,8 +45,8 @@ ## Define default values for plot aesthetics # These are all custom style arguments -TITLE_FONTSIZE = 20 -LABEL_SIZE = 16 -TICK_LABELSIZE = 16 +TITLE_FONTSIZE = 18 +LABEL_SIZE = 14 +TICK_LABELSIZE = 12 LEGEND_SIZE = 12 LEGEND_LOC = 'best' diff --git a/fooof/plts/style.py b/fooof/plts/style.py index dc72a1142..7f9566ffb 100644 --- a/fooof/plts/style.py +++ b/fooof/plts/style.py @@ -169,7 +169,14 @@ def apply_custom_style(ax, **kwargs): ax.legend(prop={'size': kwargs.pop('legend_size', LEGEND_SIZE)}, loc=kwargs.pop('legend_loc', LEGEND_LOC)) - plt.tight_layout() + # Apply tight layout to the figure object, if matplotlib is new enough + # If available, `.set_layout_engine` should be equivalent to + # `plt.tight_layout()`, but seems to raise fewer warnings... + try: + fig = plt.gcf() + fig.set_layout_engine('tight') + except: + plt.tight_layout() def apply_style(ax, axis_styler=apply_axis_style, line_styler=apply_line_style, @@ -192,10 +199,10 @@ def apply_style(ax, axis_styler=apply_axis_style, line_styler=apply_line_style, Each of these sub-functions can be replaced by passing in replacement callables. """ - axis_styler(ax, **kwargs) - line_styler(ax, **kwargs) - collection_styler(ax, **kwargs) - custom_styler(ax, **kwargs) + axis_styler(ax, **kwargs) if axis_styler is not None else None + line_styler(ax, **kwargs) if line_styler is not None else None + collection_styler(ax, **kwargs) if collection_styler is not None else None + custom_styler(ax, **kwargs) if custom_styler is not None else None def style_plot(func, *args, **kwargs): diff --git a/fooof/plts/templates.py b/fooof/plts/templates.py index 9b1e341b6..b637e5958 100644 --- a/fooof/plts/templates.py +++ b/fooof/plts/templates.py @@ -10,6 +10,7 @@ from fooof.core.modutils import safe_import, check_dependency from fooof.plts.utils import check_ax, set_alpha +from fooof.plts.settings import TITLE_FONTSIZE, LABEL_SIZE, TICK_LABELSIZE plt = safe_import('.pyplot', 'matplotlib') @@ -46,14 +47,14 @@ def plot_scatter_1(data, label=None, title=None, x_val=0, ax=None): ax.scatter(x_data, data, s=36, alpha=set_alpha(len(data))) if label: - ax.set_ylabel(label, fontsize=16) + ax.set_ylabel(label, fontsize=LABEL_SIZE) ax.set(xticks=[x_val], xticklabels=[label]) if title: - ax.set_title(title, fontsize=20) + ax.set_title(title, fontsize=TITLE_FONTSIZE) - ax.tick_params(axis='x', labelsize=16) - ax.tick_params(axis='y', labelsize=12) + ax.tick_params(axis='x', labelsize=TICK_LABELSIZE) + ax.tick_params(axis='y', labelsize=TICK_LABELSIZE) ax.set_xlim([-0.5, 0.5]) @@ -89,12 +90,12 @@ def plot_scatter_2(data_0, label_0, data_1, label_1, title=None, ax=None): plot_scatter_1(data_1, label_1, x_val=1, ax=ax1) if title: - ax.set_title(title, fontsize=20) + ax.set_title(title, fontsize=TITLE_FONTSIZE) ax.set(xlim=[-0.5, 1.5], xticks=[0, 1], xticklabels=[label_0, label_1]) - ax.tick_params(axis='x', labelsize=16) + ax.tick_params(axis='x', labelsize=TICK_LABELSIZE) @check_dependency(plt, 'matplotlib') @@ -121,13 +122,13 @@ def plot_hist(data, label, title=None, n_bins=25, x_lims=None, ax=None): ax.hist(data[~np.isnan(data)], n_bins, range=x_lims, alpha=0.8) - ax.set_xlabel(label, fontsize=16) - ax.set_ylabel('Count', fontsize=16) + ax.set_xlabel(label, fontsize=LABEL_SIZE) + ax.set_ylabel('Count', fontsize=LABEL_SIZE) if x_lims: ax.set_xlim(x_lims) if title: - ax.set_title(title, fontsize=20) + ax.set_title(title, fontsize=TITLE_FONTSIZE) - ax.tick_params(axis='both', labelsize=12) + ax.tick_params(axis='both', labelsize=TICK_LABELSIZE)