diff --git a/Pipfile b/Pipfile index b88d4c6..ee20996 100644 --- a/Pipfile +++ b/Pipfile @@ -23,6 +23,7 @@ pypandoc = "*" cruft = "*" pandas = "*" matplotlib = "*" +pd_utils = "*" [requires] python_version = "3.7" diff --git a/Pipfile.lock b/Pipfile.lock index 346baa7..3977b9c 100644 --- a/Pipfile.lock +++ b/Pipfile.lock @@ -1,7 +1,7 @@ { "_meta": { "hash": { - "sha256": "5edbf4f3a77ab84b73b1ae61a285cbf7860140e7243410d65cbfdca7d8f28200" + "sha256": "a2c3862be945fc5f7ea6873243e103731e6d032e6e0d531278253782a6c742cc" }, "pipfile-spec": 6, "requires": { @@ -522,12 +522,33 @@ "index": "pypi", "version": "==1.0.1" }, + "pandasql": { + "hashes": [ + "sha256:1eb248869086435a7d85281ebd9fe525d69d9d954a0dceb854f71a8d0fd8de69" + ], + "version": "==0.7.3" + }, "pathtools": { "hashes": [ "sha256:7c35c5421a39bb82e58018febd90e3b6e5db34c5443aaaf742b3f33d4655f1c0" ], "version": "==0.1.2" }, + "patsy": { + "hashes": [ + "sha256:5465be1c0e670c3a965355ec09e9a502bf2c4cbe4875e8528b0221190a8a5d40", + "sha256:f115cec4201e1465cd58b9866b0b0e7b941caafec129869057405bfe5b5e3991" + ], + "version": "==0.5.1" + }, + "pd-utils": { + "hashes": [ + "sha256:78c349031f2d954ee00601bfb0a825c6a54efca3e44a2a449767c40c851081f4", + "sha256:a83a46304c1d9ca0a0e9b5435d5629f785727d32dade7818da5c16ee466b4b88" + ], + "index": "pypi", + "version": "==0.1.0" + }, "pillow": { "hashes": [ "sha256:0a628977ac2e01ca96aaae247ec2bd38e729631ddf2221b4b715446fd45505be", @@ -709,6 +730,38 @@ ], "version": "==0.9.1" }, + "sas7bdat": { + "hashes": [ + "sha256:484c609d962442203c15bc719a638de992a23cd13bc1971a5af6dfb0daf9f797" + ], + "version": "==2.2.3" + }, + "scipy": { + "hashes": [ + "sha256:00af72998a46c25bdb5824d2b729e7dabec0c765f9deb0b504f928591f5ff9d4", + "sha256:0902a620a381f101e184a958459b36d3ee50f5effd186db76e131cbefcbb96f7", + "sha256:1e3190466d669d658233e8a583b854f6386dd62d655539b77b3fa25bfb2abb70", + "sha256:2cce3f9847a1a51019e8c5b47620da93950e58ebc611f13e0d11f4980ca5fecb", + "sha256:3092857f36b690a321a662fe5496cb816a7f4eecd875e1d36793d92d3f884073", + "sha256:386086e2972ed2db17cebf88610aab7d7f6e2c0ca30042dc9a89cf18dcc363fa", + "sha256:71eb180f22c49066f25d6df16f8709f215723317cc951d99e54dc88020ea57be", + "sha256:770254a280d741dd3436919d47e35712fb081a6ff8bafc0f319382b954b77802", + "sha256:787cc50cab3020a865640aba3485e9fbd161d4d3b0d03a967df1a2881320512d", + "sha256:8a07760d5c7f3a92e440ad3aedcc98891e915ce857664282ae3c0220f3301eb6", + "sha256:8d3bc3993b8e4be7eade6dcc6fd59a412d96d3a33fa42b0fa45dc9e24495ede9", + "sha256:9508a7c628a165c2c835f2497837bf6ac80eb25291055f56c129df3c943cbaf8", + "sha256:a144811318853a23d32a07bc7fd5561ff0cac5da643d96ed94a4ffe967d89672", + "sha256:a1aae70d52d0b074d8121333bc807a485f9f1e6a69742010b33780df2e60cfe0", + "sha256:a2d6df9eb074af7f08866598e4ef068a2b310d98f87dc23bd1b90ec7bdcec802", + "sha256:bb517872058a1f087c4528e7429b4a44533a902644987e7b2fe35ecc223bc408", + "sha256:c5cac0c0387272ee0e789e94a570ac51deb01c796b37fb2aad1fb13f85e2f97d", + "sha256:cc971a82ea1170e677443108703a2ec9ff0f70752258d0e9f5433d00dda01f59", + "sha256:dba8306f6da99e37ea08c08fef6e274b5bf8567bb094d1dbe86a20e532aca088", + "sha256:dc60bb302f48acf6da8ca4444cfa17d52c63c5415302a9ee77b3b21618090521", + "sha256:dee1bbf3a6c8f73b6b218cb28eed8dd13347ea2f87d572ce19b289d6fd3fbc59" + ], + "version": "==1.4.1" + }, "secretstorage": { "hashes": [ "sha256:15da8a989b65498e29be338b3b279965f1b8f09b9668bd8010da183024c8bff6", @@ -845,6 +898,38 @@ ], "version": "==1.1.3" }, + "sqlalchemy": { + "hashes": [ + "sha256:64a7b71846db6423807e96820993fa12a03b89127d278290ca25c0b11ed7b4fb" + ], + "version": "==1.3.13" + }, + "statsmodels": { + "hashes": [ + "sha256:071649014680bc7cad74d323878a41099db0bb1bb0d93e7d640a0d341b467da6", + "sha256:0a8ee8fc091d9ef1db68f01e6e0079acc0f41671dfbac463131939ca573f8c71", + "sha256:13e35799cd86ccbb9e94941b7199c75f7f5194ce3b36a11cb5af8ae8b791301c", + "sha256:1840b899a4483b520531d0b731fb57e11a9251e2ed6c471dda0e77716f7b7bd6", + "sha256:1c3591b8d34240447b54936c360d1556904c81058b10b2a28092267af683bd4f", + "sha256:28b869039bb0f905f81343e3c5f1a13a58ef7d758c4a5f60b9b469921dbcda6a", + "sha256:414d423e804769bc6959ae57dc36595976fb12732e7c3ed02bdc45e970592120", + "sha256:4fb440b25dff41ee6df21e6cf83063aec669313fea799c9f2cb4b9204723e79e", + "sha256:5c135ce37036e3791c229d30a13475ba0fb868015fdbd0a1878261b48026ba76", + "sha256:64fee746d1089808cbbbcc377910e93dce21646aa0e67fa7d54ee488df545524", + "sha256:67224a71c8c5fbf994d59198c10caa28eb6436dd4518b54468901bc6e91cdbfb", + "sha256:70bba2b4ce256e5b6d9193cb9ec5e7ebafc96f6334e01248eeddaa62ba6ef60d", + "sha256:70fd072beae7403343783d9850190052a5fa83029c4c5806429d8ee0b919d7b7", + "sha256:854c0fce335fb3271fd3786b94931443bb282de77c7082d735abfa0bfed73ab1", + "sha256:9be4907a8b8ac8d0e1dd143c905faf9c28a4072ed2b0dfcc87aca50aff9bfe6f", + "sha256:b1d01224b761c2d1fae2a89afb9ef039c7a63a6882f602128652baea437188f5", + "sha256:c8319b91f5892f36debefc4c259ef52457a2ce0a0b4486f5f999ad8d45977767", + "sha256:c8fe2a8d014c130dbacf49ef2f186404699d2aeb1cb8ead92d6a5779b1dd007c", + "sha256:d0cf4680939f34898c820f9b310cad05c4d7fa3d17d078eca3928a933331abd8", + "sha256:e213c84f0f32b984305855169344d53d594a09bce159a8699967ff592ee171a7", + "sha256:eb19fd8dcf7bb2b45b0835074face22b53bdfb6cc8d778fd072ca303c8351adb" + ], + "version": "==0.11.0" + }, "tornado": { "hashes": [ "sha256:349884248c36801afa19e342a77cc4458caca694b0eda633f5878e458a44cb2c", diff --git a/conf.py b/conf.py index eb6d8ed..5add2a7 100644 --- a/conf.py +++ b/conf.py @@ -17,7 +17,7 @@ ] # Package version in the format (major, minor, release) -PACKAGE_VERSION_TUPLE = (0, 1, 0) +PACKAGE_VERSION_TUPLE = (0, 1, 1) # Short description of the package PACKAGE_SHORT_DESCRIPTION = "Python Sensitivity Analysis - Gradient DataFrames and Hex-Bin Plots" @@ -61,6 +61,7 @@ # 'otherpackage>=1,<2' 'pandas', 'matplotlib', + 'pd_utils', ] # Add any third party packages you use in requirements for optional features of your package here diff --git a/sensitivity/df.py b/sensitivity/df.py index d4a5bbc..a08c172 100644 --- a/sensitivity/df.py +++ b/sensitivity/df.py @@ -1,7 +1,8 @@ -from typing import Dict, Any, Callable +from typing import Dict, Any, Callable, Sequence, Optional import itertools from copy import deepcopy import pandas as pd +import pd_utils from pandas.io.formats.style import Styler from sensitivity.colors import _get_color_map @@ -33,11 +34,35 @@ def sensitivity_df(sensitivity_values: Dict[str, Any], func: Callable, base_param_dict.update({result_name: result}) df = df.append(pd.DataFrame(pd.Series(base_param_dict)).T) df.reset_index(drop=True, inplace=True) + df = df.convert_dtypes() return df -def _style_sensitivity_df(df: pd.DataFrame, reverse_colors: bool = False) -> Styler: +def _two_variable_sensitivity_display_df(df: pd.DataFrame, col1: str, col2: str, + result_col: str = 'Result') -> pd.DataFrame: + selected_df = df[[col1, col2, result_col]] + wide_df = pd_utils.long_to_wide( + selected_df, + col1, + result_col, + colindex=col2, + colindex_only=True + ).set_index(col1) + + return wide_df + + +def _style_sensitivity_df(df: pd.DataFrame, col1: str, col2: Optional[str] = None, result_col: str = 'Result', + reverse_colors: bool = False, + col_subset: Optional[Sequence[str]] = None) -> Styler: + if col2 is not None: + caption = f'{result_col} - {col1} vs. {col2}' + else: + caption = f'{result_col} vs. {col1}' + color_str = _get_color_map(reverse_colors=reverse_colors) - return df.style.background_gradient(cmap=color_str) + return df.style.background_gradient( + cmap=color_str, subset=col_subset, axis=None + ).set_caption(caption) diff --git a/sensitivity/main.py b/sensitivity/main.py index 623f5ab..24ea7d1 100644 --- a/sensitivity/main.py +++ b/sensitivity/main.py @@ -1,10 +1,12 @@ +import itertools from dataclasses import dataclass -from typing import Dict, Any, Callable, Optional +from typing import Dict, Any, Callable, Optional, List, Union import numpy as np +from pandas.io.formats.style import Styler import matplotlib.pyplot as plt -from sensitivity.df import sensitivity_df, _style_sensitivity_df +from sensitivity.df import sensitivity_df, _style_sensitivity_df, _two_variable_sensitivity_display_df from sensitivity.hexbin import sensitivity_hex_plots, _hex_figure_from_sensitivity_df @@ -52,7 +54,7 @@ class SensitivityAnalyzer: >>> sa.df >>> >>> # Styled DataFrame - >>> sa.styled_df + >>> sa.styled_dfs >>> >>> # Hex-Bin Plot >>> sa.plot @@ -89,8 +91,50 @@ def plot(self) -> plt.Figure: ) @property - def styled_df(self): - return _style_sensitivity_df( - self.df, - reverse_colors=self.reverse_colors - ) + def styled_dfs(self) -> Union[Styler, List[Styler]]: + # Output a single Styler if only one or two variables + sensitivity_cols = list(self.sensitivity_values.keys()) + if len(sensitivity_cols) == 1: + return _style_sensitivity_df( + self.df, + sensitivity_cols[0], + reverse_colors=self.reverse_colors, + col_subset=[self.result_name], + result_col=self.result_name + ) + elif len(sensitivity_cols) == 2: + col1 = sensitivity_cols[0] + col2 = sensitivity_cols[1] + df = _two_variable_sensitivity_display_df( + self.df, + col1, + col2, + result_col=self.result_name + ) + return _style_sensitivity_df( + df, + col1, + col2=col2, + reverse_colors=self.reverse_colors, + result_col=self.result_name, + ) + elif len(sensitivity_cols) == 0: + raise ValueError('must pass sensitivity columns') + + # Length must be greater than 2, need to output multiple, one for each pair of variables + results = [] + for col1, col2 in itertools.combinations(sensitivity_cols, 2): + df = _two_variable_sensitivity_display_df( + self.df, + col1, + col2, + result_col=self.result_name + ) + results.append(_style_sensitivity_df( + df, + col1, + col2=col2, + reverse_colors=self.reverse_colors, + result_col=self.result_name, + )) + return results \ No newline at end of file diff --git a/tests/test_sensitivity_analyzer.py b/tests/test_sensitivity_analyzer.py index 27f094f..be77dc6 100644 --- a/tests/test_sensitivity_analyzer.py +++ b/tests/test_sensitivity_analyzer.py @@ -23,9 +23,9 @@ def test_create_df(self): sa = self.create_sa() assert_frame_equal(sa.df, EXPECT_DF, check_dtype=False) - def test_create_styled_df(self): + def test_create_styled_dfs(self): sa = self.create_sa() - sa.styled_df + sa.styled_dfs # TODO [#1]: determine how to test pandas Styler object beyond creation without error def test_create_plot(self):