Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions fooof/core/reports.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def save_report_fg(fg, file_name, file_path=None, add_settings=True):

# Initialize figure
_ = plt.figure(figsize=REPORT_FIGSIZE)
grid = gridspec.GridSpec(n_rows, 2, wspace=0.4, hspace=0.25, height_ratios=height_ratios)
grid = gridspec.GridSpec(n_rows, 2, wspace=0.35, hspace=0.25, height_ratios=height_ratios)

# First / top: text results
ax0 = plt.subplot(grid[0, :])
Expand All @@ -108,15 +108,15 @@ def save_report_fg(fg, file_name, file_path=None, add_settings=True):

# Aperiodic parameters plot
ax1 = plt.subplot(grid[1, 0])
plot_fg_ap(fg, ax1)
plot_fg_ap(fg, ax1, custom_styler=None)

# Goodness of fit plot
ax2 = plt.subplot(grid[1, 1])
plot_fg_gf(fg, ax2)
plot_fg_gf(fg, ax2, custom_styler=None)

# Peak center frequencies plot
ax3 = plt.subplot(grid[2, :])
plot_fg_peak_cens(fg, ax3)
plot_fg_peak_cens(fg, ax3, custom_styler=None)

# Third - Model settings
if add_settings:
Expand Down
8 changes: 4 additions & 4 deletions fooof/plts/fg.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,23 +44,23 @@ def plot_fg(fg, save_fig=False, file_name=None, file_path=None, **plot_kwargs):
raise NoModelError("No model fit results are available, can not proceed.")

fig = plt.figure(figsize=plot_kwargs.pop('figsize', PLT_FIGSIZES['group']))
gs = gridspec.GridSpec(2, 2, wspace=0.4, hspace=0.25, height_ratios=[1, 1.2])
gs = gridspec.GridSpec(2, 2, wspace=0.35, hspace=0.35, height_ratios=[1, 1.2])

# Apply scatter kwargs to all subplots
scatter_kwargs = plot_kwargs
scatter_kwargs['all_axes'] = True

# Aperiodic parameters plot
ax0 = plt.subplot(gs[0, 0])
plot_fg_ap(fg, ax0, **scatter_kwargs)
plot_fg_ap(fg, ax0, **scatter_kwargs, custom_styler=None)

# Goodness of fit plot
ax1 = plt.subplot(gs[0, 1])
plot_fg_gf(fg, ax1, **scatter_kwargs)
plot_fg_gf(fg, ax1, **scatter_kwargs, custom_styler=None)

# Center frequencies plot
ax2 = plt.subplot(gs[1, :])
plot_fg_peak_cens(fg, ax2, **plot_kwargs)
plot_fg_peak_cens(fg, ax2, **plot_kwargs, custom_styler=None)


@savefig
Expand Down
10 changes: 5 additions & 5 deletions fooof/plts/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
###################################################################################################

# Define default figure sizes
PLT_FIGSIZES = {'spectral' : (10, 8),
PLT_FIGSIZES = {'spectral' : (8.5, 6.5),
'params' : (7, 6),
'group' : (12, 10)}
'group' : (9, 7)}

# Define defaults for colors for plots, based on what is plotted
PLT_COLORS = {'data' : 'black',
Expand Down Expand Up @@ -45,8 +45,8 @@

## Define default values for plot aesthetics
# These are all custom style arguments
TITLE_FONTSIZE = 20
LABEL_SIZE = 16
TICK_LABELSIZE = 16
TITLE_FONTSIZE = 18
LABEL_SIZE = 14
TICK_LABELSIZE = 12
LEGEND_SIZE = 12
LEGEND_LOC = 'best'
17 changes: 12 additions & 5 deletions fooof/plts/style.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,14 @@ def apply_custom_style(ax, **kwargs):
ax.legend(prop={'size': kwargs.pop('legend_size', LEGEND_SIZE)},
loc=kwargs.pop('legend_loc', LEGEND_LOC))

plt.tight_layout()
# Apply tight layout to the figure object, if matplotlib is new enough
# If available, `.set_layout_engine` should be equivalent to
# `plt.tight_layout()`, but seems to raise fewer warnings...
try:
fig = plt.gcf()
fig.set_layout_engine('tight')
except:
plt.tight_layout()


def apply_style(ax, axis_styler=apply_axis_style, line_styler=apply_line_style,
Expand All @@ -192,10 +199,10 @@ def apply_style(ax, axis_styler=apply_axis_style, line_styler=apply_line_style,
Each of these sub-functions can be replaced by passing in replacement callables.
"""

axis_styler(ax, **kwargs)
line_styler(ax, **kwargs)
collection_styler(ax, **kwargs)
custom_styler(ax, **kwargs)
axis_styler(ax, **kwargs) if axis_styler is not None else None
line_styler(ax, **kwargs) if line_styler is not None else None
collection_styler(ax, **kwargs) if collection_styler is not None else None
custom_styler(ax, **kwargs) if custom_styler is not None else None


def style_plot(func, *args, **kwargs):
Expand Down
21 changes: 11 additions & 10 deletions fooof/plts/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from fooof.core.modutils import safe_import, check_dependency
from fooof.plts.utils import check_ax, set_alpha
from fooof.plts.settings import TITLE_FONTSIZE, LABEL_SIZE, TICK_LABELSIZE

plt = safe_import('.pyplot', 'matplotlib')

Expand Down Expand Up @@ -46,14 +47,14 @@ def plot_scatter_1(data, label=None, title=None, x_val=0, ax=None):
ax.scatter(x_data, data, s=36, alpha=set_alpha(len(data)))

if label:
ax.set_ylabel(label, fontsize=16)
ax.set_ylabel(label, fontsize=LABEL_SIZE)
ax.set(xticks=[x_val], xticklabels=[label])

if title:
ax.set_title(title, fontsize=20)
ax.set_title(title, fontsize=TITLE_FONTSIZE)

ax.tick_params(axis='x', labelsize=16)
ax.tick_params(axis='y', labelsize=12)
ax.tick_params(axis='x', labelsize=TICK_LABELSIZE)
ax.tick_params(axis='y', labelsize=TICK_LABELSIZE)

ax.set_xlim([-0.5, 0.5])

Expand Down Expand Up @@ -89,12 +90,12 @@ def plot_scatter_2(data_0, label_0, data_1, label_1, title=None, ax=None):
plot_scatter_1(data_1, label_1, x_val=1, ax=ax1)

if title:
ax.set_title(title, fontsize=20)
ax.set_title(title, fontsize=TITLE_FONTSIZE)

ax.set(xlim=[-0.5, 1.5],
xticks=[0, 1],
xticklabels=[label_0, label_1])
ax.tick_params(axis='x', labelsize=16)
ax.tick_params(axis='x', labelsize=TICK_LABELSIZE)


@check_dependency(plt, 'matplotlib')
Expand All @@ -121,13 +122,13 @@ def plot_hist(data, label, title=None, n_bins=25, x_lims=None, ax=None):

ax.hist(data[~np.isnan(data)], n_bins, range=x_lims, alpha=0.8)

ax.set_xlabel(label, fontsize=16)
ax.set_ylabel('Count', fontsize=16)
ax.set_xlabel(label, fontsize=LABEL_SIZE)
ax.set_ylabel('Count', fontsize=LABEL_SIZE)

if x_lims:
ax.set_xlim(x_lims)

if title:
ax.set_title(title, fontsize=20)
ax.set_title(title, fontsize=TITLE_FONTSIZE)

ax.tick_params(axis='both', labelsize=12)
ax.tick_params(axis='both', labelsize=TICK_LABELSIZE)