Skip to content

Commit

Permalink
Polish plot module.
Browse files Browse the repository at this point in the history
  • Loading branch information
dohlee committed Nov 25, 2019
1 parent 53701ea commit a907c23
Showing 1 changed file with 77 additions and 22 deletions.
99 changes: 77 additions & 22 deletions src/dohlee/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,11 +77,13 @@ def wrapper(*args, **kwargs):
if 'ylim' in kwargs:
ax.set_ylim(kwargs['ylim'])
del kwargs['ylim']

xlabel = kwargs.get('xlabel', None)
if 'xlabel' in kwargs:
ax.set_xlabel(kwargs['xlabel'])
del kwargs['xlabel']

ylabel = kwargs.get('ylabel', None)
if 'ylabel' in kwargs:
ax.set_ylabel(kwargs['ylabel'])
del kwargs['ylabel']

file_path = kwargs.get('file', None)
Expand All @@ -108,14 +110,29 @@ def wrapper(*args, **kwargs):
if 'legend_title' in kwargs:
del kwargs['legend_title']

ax_result = func(*args, ax=ax, **kwargs)
despine = kwargs.get('despine', None)
if 'despine' in kwargs:
del kwargs['despine']

grid = kwargs.get('grid', None)
if 'grid' in kwargs:
del kwargs['grid']

# Post-process plotting results.
ax = func(*args, ax=ax, **kwargs)
if rotate_xticklabels:
ax_result.set_xticklabels(
ax_result.get_xticklabels(),
ax.set_xticklabels(
ax.get_xticklabels(),
ha='right',
rotation=rotate_xticklabels,
)

if xlabel:
ax.set_xlabel(xlabel)

if ylabel:
ax.set_ylabel(ylabel)

if not xticklabels:
ax.set_xticks([])
ax.set_xlabel(ax.get_xlabel(), labelpad=7.0)
Expand All @@ -124,12 +141,26 @@ def wrapper(*args, **kwargs):
ax.set_yticks([])
ax.set_ylabel(ax.get_ylabel(), labelpad=7.0)

if legend_size is not None or legend_title is not None:
if legend_title is not None:
handles, labels = ax.get_legend_handles_labels()
ax.legend(handles, labels, prop={'size': legend_size}, title=legend_title)
ax.legend(handles, labels, title=legend_title)

if legend_size is not None:
handles, labels = ax.get_legend_handles_labels()
ax.legend(handles, labels, prop={'size': legend_size})

if (grid is not None) and (grid != False):
if isinstance(grid, dict):
ax.grid(**grid)
else:
ax.grid()

if despine is not None:
for d in despine:
ax.spines[d].set_visible(False)

_try_save(file_path)
return ax_result
return ax
return wrapper


Expand All @@ -140,7 +171,7 @@ def set_suptitle(title):


# Set plot preference which looks good to me.
def set_style(style='white', palette='deep', context='talk', font='FreeSans', font_scale=1.00):
def set_style(style='white', palette='deep', context='talk', font='Helvetica Neue', scale=1.0, font_scale=1.00):
"""Set plot preference in a way that looks good to me.
"""
import matplotlib.font_manager as font_manager
Expand All @@ -159,7 +190,6 @@ def set_style(style='white', palette='deep', context='talk', font='FreeSans', fo
font=font,
font_scale=font_scale,
)
scale = 1.3
plt.rc('axes', linewidth=1.33, labelsize=14)
plt.rc('xtick', labelsize=10 * scale)
plt.rc('ytick', labelsize=10 * scale)
Expand All @@ -172,9 +202,9 @@ def set_style(style='white', palette='deep', context='talk', font='FreeSans', fo
plt.rc('ytick.major', size=5 * scale, width=1.33)
plt.rc('ytick.minor', size=5 * scale, width=1.33)

plt.rc('legend', fontsize=7 * scale)
plt.rc('legend', fontsize=12 * scale, frameon=False)
plt.rc('grid', color='grey', linewidth=0.5, alpha=0.33)
plt.rc('font', family='Helvetica Neue')
plt.rc('font', family=font)

color_palette = [
'#005AC8',
Expand All @@ -193,24 +223,39 @@ def set_style(style='white', palette='deep', context='talk', font='FreeSans', fo

mpl.rcParams['axes.prop_cycle'] = cycler.cycler(color=color_palette)

def set_paper():
"""
"""
set_style(scale=1.0)

def set_talk():
"""
"""
set_style(scale=1.3)

def set_presentation():
"""
"""
set_talk()

def get_axis(figsize=None, transpose=False, dpi=300):
def set_poster():
"""
"""
set_style(scale=1.5)

def get_axis(figsize=None, dpi=300):
"""Get plot axis with predefined/user-defined width and height.
>>> ax = get_axis()
>>> ax = get_axis(figsize=(7.2, 4.45))
:param float scale: Figure size scale. Width and height will be scale with this value.
:param tuple figsize: Use user-defined width and height. If this is given, `scale` parameter will be ignored.
:param bool transpose: Swap width and height.
"""
w, h = 7.2, 4.45 # Nature double-column preset inches.
w, h = 5, 5
if figsize is not None:
w, h = figsize

if transpose:
w, h = h, w

fig = plt.figure(figsize=(w, h))
ax = fig.add_subplot(111)
return ax
Expand Down Expand Up @@ -298,7 +343,7 @@ def histogram(data, ax=None, **kwargs):


@_my_plot
def boxplot(data, x, y, hue=None, ax=None, strip=False, **kwargs):
def boxplot(data, x, y, hue=None, ax=None, strip=False, box_kwargs={}, strip_kwargs={}):
"""Draw a boxplot.
>>> boxplot(data, x='species', y='sepal_length', strip=True)
Expand All @@ -309,10 +354,20 @@ def boxplot(data, x, y, hue=None, ax=None, strip=False, **kwargs):
:param axis ax: (Optional) Matplotlib axis to draw the plot on.
:param bool strip: (default=False) Draw overlapped stripplot.
"""
fliersize = 0 if strip else 5
sns.boxplot(data=data, x=x, y=y, hue=hue, linewidth=1.33, flierprops={'marker': '.'}, fliersize=fliersize, ax=ax, **kwargs)

# Set default values for box keyword arguments.
box_kwargs['linewidth'] = box_kwargs.get('linewidth', 1.33)
box_kwargs['flierprops'] = box_kwargs.get('flierprops', {'marker': '.'})
box_kwargs['fliersize'] = box_kwargs.get('fliersize', 0 if strip else 5)
box_kwargs['saturation'] = box_kwargs.get('saturation', 1.0)

# Set default values for strip keyword arguments.
strip_kwargs['color'] = strip_kwargs.get('color', 'k')
strip_kwargs['size'] = strip_kwargs.get('size', 5)

sns.boxplot(data=data, x=x, y=y, hue=hue, ax=ax, **box_kwargs)
if strip:
sns.stripplot(data=data, x=x, y=y, hue=hue, jitter=.03, color='k', size=5, ax=ax)
sns.stripplot(data=data, x=x, y=y, hue=hue, ax=ax, **strip_kwargs)

return ax

Expand Down

0 comments on commit a907c23

Please sign in to comment.