diff --git a/README.rst b/README.rst index af6a3ba..f0afe3a 100644 --- a/README.rst +++ b/README.rst @@ -258,6 +258,8 @@ Executes a plotting function and saves the resulting plot to specified formats u +----------------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ | ``teeplot_dpi`` | Resolution for rasterized components of saved plots, default is publication-quality 300 dpi. | +----------------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ +| ``teeplot_figsize`` | Optional ``(width, height)`` tuple in inches; resizes the current figure via ``set_size_inches`` after the plotter runs. | ++----------------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ | ``teeplot_oncollision`` | Strategy for handling filename collisions: "error", "fix", "ignore", or "warn", default "warn"; inferred from environment if not specified. | +----------------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ | ``teeplot_outattrs`` | Dict with additional key-value attributes to include in the output filename. | @@ -270,6 +272,8 @@ Executes a plotting function and saves the resulting plot to specified formats u +----------------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ | ``teeplot_postprocess`` | Actions to perform after plotting but before saving. Can be a string of code to ``exec`` or a callable function. If a string, it's executed with access to ``plt`` and ``sns`` (if installed), and the plotter return value as ``teed``. | +----------------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ +| ``teeplot_rc_context`` | Mapping of matplotlib rcParams applied via ``matplotlib.rc_context`` around the plotter, postprocess, and save steps. | ++----------------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ | ``teeplot_save`` | File formats to save the plots in. Defaults to global settings if ``True``, all output suppressed if ``False``. Default global setting is ``{" .png", ".pdf"}``. Supported: ".eps", ".png", ".pdf", ".pgf", ".ps", ".svg". | +----------------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ | ``teeplot_show`` | Dictates whether ``plt.show()`` should be called after plot is saved. If True, the plot is displayed using ``plt.show()``. Default behavior is to display if an interactive environment is detected (e.g., a notebook). | diff --git a/teeplot/teeplot.py b/teeplot/teeplot.py index 2409148..6bc57ed 100644 --- a/teeplot/teeplot.py +++ b/teeplot/teeplot.py @@ -4,6 +4,7 @@ import functools import os import pathlib +import types import typing import warnings import sys @@ -62,13 +63,15 @@ def tee( *args: typing.Any, teeplot_callback: bool = False, teeplot_dpi: int = 300, + teeplot_figsize: typing.Optional[typing.Tuple[float, float]] = None, teeplot_oncollision: typing.Optional[ typext.Literal["error", "fix", "ignore", "warn"]] = None, - teeplot_outattrs: typing.Dict[str, str] = {}, + teeplot_outattrs: typing.Mapping[str, str] = types.MappingProxyType({}), teeplot_outdir: str = "teeplots", teeplot_outinclude: typing.Iterable[str] = tuple(), teeplot_outexclude: typing.Iterable[str] = tuple(), teeplot_postprocess: typing.Union[str, typing.Callable] = "", + teeplot_rc_context: typing.Mapping[str, typing.Any] = types.MappingProxyType({}), teeplot_save: typing.Union[typing.Iterable[str], bool] = True, teeplot_show: typing.Optional[bool] = None, teeplot_subdir: str = '', @@ -93,11 +96,15 @@ def tee( Resolution for rasterized components of the saved plot in dots per inch. Default is publication-quality 300 dpi. + teeplot_figsize : Tuple[float, float], optional + Size of the saved plot in inches as (width, height). + + If provided, the current figure is resized after the plotter runs. teeplot_oncollision : Literal["error", "fix", "ignore", "warn"], optional Strategy for handling collisions between generated filenames. Default "ignore" if executing in interactive mode, else default "warn". - teeplot_outattrs : Dict[str, str], optional + teeplot_outattrs : Mapping[str, str], optional Additional attributes to include in the output filename. teeplot_outdir : str, default "teeplots" Base directory for saving plots. @@ -122,6 +129,9 @@ def tee( return value as the `teed` kwarg, second with the plotter return value as the `ax` kwarg, third with no args, and last with the plotter return value as a positional arg. + teeplot_rc_context : Mapping[str, Any], optional + Mapping of matplotlib rcParams to apply via `matplotlib.rc_context` + around the plotter, postprocess, and save steps. teeplot_save : Union[str, Iterable[str], bool], default True File formats to save the plots in. @@ -224,46 +234,50 @@ def tee( # ----- end argument parsing # ----- begin plotting - teed = plotter(*args, **{k: v for k, v in kwargs.items()}) + with matplotlib.rc_context(teeplot_rc_context): + teed = plotter(*args, **{k: v for k, v in kwargs.items()}) - if isinstance(teeplot_postprocess, abc.Callable): - while "make breakable": - try: - teeplot_postprocess(teed=teed) # first attempt - break - except TypeError: - pass - try: - teeplot_postprocess(ax=teed) # second attempt - break - except TypeError: - pass - try: - teeplot_postprocess() # third attempt - break - except TypeError: - pass + if teeplot_figsize is not None: + plt.gcf().set_size_inches(*teeplot_figsize) + + if isinstance(teeplot_postprocess, abc.Callable): + while "make breakable": + try: + teeplot_postprocess(teed=teed) # first attempt + break + except TypeError: + pass + try: + teeplot_postprocess(ax=teed) # second attempt + break + except TypeError: + pass + try: + teeplot_postprocess() # third attempt + break + except TypeError: + pass + try: + teeplot_postprocess(teed) # fourth attempt + break + except TypeError: + pass + raise TypeError( # give up + f"teeplot_postprocess={teeplot_postprocess} threw TypeError " + "or call signature incompatible with attempted invocations", + ) + elif teeplot_postprocess: + if not isinstance(teeplot_postprocess, str): + raise TypeError( + "teeplot_postprocess must be str or Callable, " + f"not {type(teeplot_postprocess)} {teeplot_postprocess}" + ) try: - teeplot_postprocess(teed) # fourth attempt - break - except TypeError: + import seaborn as sns + import seaborn + except ModuleNotFoundError: pass - raise TypeError( # give up - f"teeplot_postprocess={teeplot_postprocess} threw TypeError " - "or call signature incompatible with attempted invocations", - ) - elif teeplot_postprocess: - if not isinstance(teeplot_postprocess, str): - raise TypeError( - "teeplot_postprocess must be str or Callable, " - f"not {type(teeplot_postprocess)} {teeplot_postprocess}" - ) - try: - import seaborn as sns - import seaborn - except ModuleNotFoundError: - pass - exec(teeplot_postprocess) + exec(teeplot_postprocess) incl = [*teeplot_outinclude] attr_maker = lambda ext: { @@ -296,65 +310,66 @@ def tee( out_folder.mkdir(parents=True, exist_ok=True) def save_callback(): - for ext in save: - - if ext not in teeplot_save: - if teeplot_verbose > 1: - print(f"skipping {out_path}") - continue - - out_path = pathlib.Path( - kn.chop( - str(out_folder / out_filenamer(ext)), - mkdir=True, - ), - ) - - if out_path in _history: - if teeplot_oncollision == "error": - raise RuntimeError(f"teeplot already created file {out_path}") - elif teeplot_oncollision == "fix": - count = _history[out_path] - suffix = f"ext={ext}" - assert str(out_path).endswith(suffix) - out_path = str(out_path)[:-len(suffix)] + f"#={count}+" + suffix - elif teeplot_oncollision == "ignore": - pass - elif teeplot_oncollision == "warn": - warnings.warn( - f"teeplot already created file {out_path}, overwriting it", - ) - else: - raise ValueError( - "teeplot_oncollision must be one of 'error', 'fix', " - f"'ignore', or 'warn', not {teeplot_oncollision}", - ) - _history[out_path] += 1 - - if teeplot_verbose: - print(out_path) - plt.savefig( - str(out_path), - bbox_inches='tight', - transparent=teeplot_transparent, - dpi=teeplot_dpi, - # see https://matplotlib.org/2.1.1/users/whats_new.html#reproducible-ps-pdf-and-svg-output - **dict( - metadata={ - key: None - for key in { - ".png": [], - ".pdf": ["CreationDate"], - ".svg": ["Date"], - }.get(ext, []) - }, - ) if ext != ".pgf" else {}, - ) - - if teeplot_show or (teeplot_show is None and hasattr(sys, 'ps1')): - plt.show() - - return teed + with matplotlib.rc_context(teeplot_rc_context): + for ext in save: + + if ext not in teeplot_save: + if teeplot_verbose > 1: + print(f"skipping {out_path}") + continue + + out_path = pathlib.Path( + kn.chop( + str(out_folder / out_filenamer(ext)), + mkdir=True, + ), + ) + + if out_path in _history: + if teeplot_oncollision == "error": + raise RuntimeError(f"teeplot already created file {out_path}") + elif teeplot_oncollision == "fix": + count = _history[out_path] + suffix = f"ext={ext}" + assert str(out_path).endswith(suffix) + out_path = str(out_path)[:-len(suffix)] + f"#={count}+" + suffix + elif teeplot_oncollision == "ignore": + pass + elif teeplot_oncollision == "warn": + warnings.warn( + f"teeplot already created file {out_path}, overwriting it", + ) + else: + raise ValueError( + "teeplot_oncollision must be one of 'error', 'fix', " + f"'ignore', or 'warn', not {teeplot_oncollision}", + ) + _history[out_path] += 1 + + if teeplot_verbose: + print(out_path) + plt.savefig( + str(out_path), + bbox_inches='tight', + transparent=teeplot_transparent, + dpi=teeplot_dpi, + # see https://matplotlib.org/2.1.1/users/whats_new.html#reproducible-ps-pdf-and-svg-output + **dict( + metadata={ + key: None + for key in { + ".png": [], + ".pdf": ["CreationDate"], + ".svg": ["Date"], + }.get(ext, []) + }, + ) if ext != ".pgf" else {}, + ) + + if teeplot_show or (teeplot_show is None and hasattr(sys, 'ps1')): + plt.show() + + return teed if teeplot_callback: return save_callback, teed diff --git a/tests/test_tee.py b/tests/test_tee.py index 1f42527..af6e568 100644 --- a/tests/test_tee.py +++ b/tests/test_tee.py @@ -350,6 +350,118 @@ def test_outexclude(): ) +def test_figsize(): + + np.random.seed(1) + x, y = np.random.normal(size=(2, 5000)).cumsum(axis=1) + + tp.tee( + sns.lineplot, + x=x, + y=y, + sort=False, + lw=1, + teeplot_outattrs={ + 'figsize' : 'metadata', + }, + teeplot_subdir='mydirectory', + teeplot_figsize=(10, 6), + ) + + assert tuple(plt.gcf().get_size_inches()) == (10, 6) + + for ext in '.pdf', '.png': + assert os.path.exists( + os.path.join('teeplots', 'mydirectory', f'figsize=metadata+viz=lineplot+ext={ext}'), + ) + + +def test_figsize_none(): + + np.random.seed(1) + x, y = np.random.normal(size=(2, 5000)).cumsum(axis=1) + + plt.figure(figsize=(4, 3)) + tp.tee( + sns.lineplot, + x=x, + y=y, + sort=False, + lw=1, + teeplot_outattrs={ + 'figsizenone' : 'metadata', + }, + teeplot_subdir='mydirectory', + teeplot_figsize=None, + ) + + assert tuple(plt.gcf().get_size_inches()) == (4, 3) + + for ext in '.pdf', '.png': + assert os.path.exists( + os.path.join('teeplots', 'mydirectory', f'figsizenone=metadata+viz=lineplot+ext={ext}'), + ) + + +def test_rc_context(): + + captured = {} + + def lineplot(**kwargs): + captured['lines.linewidth'] = plt.rcParams['lines.linewidth'] + return sns.lineplot(**kwargs) + + np.random.seed(1) + x, y = np.random.normal(size=(2, 5000)).cumsum(axis=1) + + before = plt.rcParams['lines.linewidth'] + tp.tee( + lineplot, + x=x, + y=y, + sort=False, + teeplot_outattrs={ + 'rccontext' : 'metadata', + }, + teeplot_subdir='mydirectory', + teeplot_rc_context={'lines.linewidth': 7.5}, + ) + + assert captured['lines.linewidth'] == 7.5 + assert plt.rcParams['lines.linewidth'] == before + + for ext in '.pdf', '.png': + assert os.path.exists( + os.path.join('teeplots', 'mydirectory', f'rccontext=metadata+viz=lineplot+ext={ext}'), + ) + + +def test_rc_context_default(): + + np.random.seed(1) + x, y = np.random.normal(size=(2, 5000)).cumsum(axis=1) + + before = plt.rcParams['lines.linewidth'] + + tp.tee( + sns.lineplot, + x=x, + y=y, + sort=False, + teeplot_outattrs={ + 'rccontextdefault' : 'metadata', + }, + teeplot_subdir='mydirectory', + ) + + assert plt.rcParams['lines.linewidth'] == before + + for ext in '.pdf', '.png': + assert os.path.exists( + os.path.join('teeplots', 'mydirectory', f'rccontextdefault=metadata+viz=lineplot+ext={ext}'), + ) + + def test_callback(): saveit, ax = tp.tee(