Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ability to pass labels for inputs and have them show up in styled dfs and plots #12

Merged
merged 1 commit into from
Oct 13, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
]

# Package version in the format (major, minor, release)
PACKAGE_VERSION_TUPLE = (0, 2, 4)
PACKAGE_VERSION_TUPLE = (0, 2, 5)

# Short description of the package
PACKAGE_SHORT_DESCRIPTION = "Python Sensitivity Analysis - Gradient DataFrames and Hex-Bin Plots"
Expand Down
499 changes: 253 additions & 246 deletions nbexamples/Sensitivity Analysis.ipynb

Large diffs are not rendered by default.

8 changes: 7 additions & 1 deletion sensitivity/df.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@


def sensitivity_df(sensitivity_values: Dict[str, Any], func: Callable,
result_name: str = 'Result', **func_kwargs) -> pd.DataFrame:
result_name: str = 'Result',
labels: Optional[Dict[str, str]] = None,
**func_kwargs) -> pd.DataFrame:
"""
Creates a DataFrame containing the results of sensitivity analysis.
Expand All @@ -26,6 +28,8 @@ def sensitivity_df(sensitivity_values: Dict[str, Any], func: Callable,
:param func: Function that accepts arguments with names matching the keys of sensitivity_values, and outputs a
scalar value.
:param result_name: Name for result shown in graph color bar label
:param labels: Optional dictionary where keys are arguments of the function and values are the displayed names
for these arguments in the styled DataFrames and plots
:param func_kwargs: Additional arguments to pass to func, regardless of the sensitivity values picked
:return: a DataFrame containing the results from sensitivity analysis on func
"""
Expand All @@ -41,6 +45,8 @@ def sensitivity_df(sensitivity_values: Dict[str, Any], func: Callable,
df = df.append(pd.DataFrame(pd.Series(base_param_dict)).T)
df.reset_index(drop=True, inplace=True)
df = df.convert_dtypes()
if labels:
df.rename(columns=labels, inplace=True)

return df

Expand Down
21 changes: 19 additions & 2 deletions sensitivity/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ class SensitivityAnalyzer:
same style as would be passed to df.style.format, e.g. '${:,.2f}' for USD formatting
:param color_map: matplotlib color map, default is RdYlGn (red, yellow, green). See
https://matplotlib.org/3.3.2/tutorials/colors/colormaps.html
:param labels: Optional dictionary where keys are arguments of the function and values are the displayed names
for these arguments in the styled DataFrames and plots
:return: Sensitivity analysis hex bin sub plot figure
Examples:
Expand Down Expand Up @@ -75,6 +77,7 @@ class SensitivityAnalyzer:
func_kwargs_dict: Optional[Dict[str, Any]] = None
num_fmt: Optional[str] = None
color_map: str = 'RdYlGn'
labels: Optional[Dict[str, str]] = None

def __post_init__(self):
if self.func_kwargs_dict is None:
Expand All @@ -83,6 +86,7 @@ def __post_init__(self):
self.sensitivity_values,
self.func,
result_name=self.result_name,
labels=self.labels,
**self.func_kwargs_dict
)

Expand All @@ -100,7 +104,7 @@ def plot(self, **kwargs) -> plt.Figure:
color_map=self.color_map,
)
config_dict.update(**kwargs)
sensitivity_cols = list(self.sensitivity_values.keys())
sensitivity_cols = self.sensitivity_cols
return _hex_figure_from_sensitivity_df(
self.df,
sensitivity_cols,
Expand All @@ -125,7 +129,7 @@ def styled_dfs(self, disp: bool = True, **kwargs) -> Union[Styler, Dict[Sequence
)
config_dict.update(**kwargs)
# Output a single Styler if only one or two variables
sensitivity_cols = list(self.sensitivity_values.keys())
sensitivity_cols = self.sensitivity_cols
if len(sensitivity_cols) == 1:
output[tuple(sensitivity_cols)] = _style_sensitivity_df(
self.df,
Expand Down Expand Up @@ -188,6 +192,19 @@ def styled_dfs(self, disp: bool = True, **kwargs) -> Union[Styler, Dict[Sequence

return output

@property
def sensitivity_cols(self) -> List[str]:
sensitivity_cols = list(self.sensitivity_values.keys())
if self.labels:
new_sensitivity_cols: List[str] = []
for col in sensitivity_cols:
if col in self.labels:
new_sensitivity_cols.append(self.labels[col])
else:
new_sensitivity_cols.append(col)
sensitivity_cols = new_sensitivity_cols
return sensitivity_cols


def _display_header(text: str):
html_str = f'<h2>{text}</h2>'
Expand Down
17 changes: 15 additions & 2 deletions tests/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,19 @@

INPUT_FILES_FOLDER = Path(os.path.join('tests', 'input_data'))
DF_STYLED_PATH = os.path.join(INPUT_FILES_FOLDER, 'df_styled.html')
DF_LABELED_PATH = os.path.join(INPUT_FILES_FOLDER, 'df_labeled.html')
DF_STYLED_NUM_FMT_PATH = os.path.join(INPUT_FILES_FOLDER, 'df_styled_num_fmt.html')
DF_STYLE_UUID = '1ee5ad65-4cac-42e3-8133-7ae800cb23ad'
DEFAULT_PLOT_PATH = INPUT_FILES_FOLDER / 'default_plot.png'
PLOT_THREE_PATH = INPUT_FILES_FOLDER / 'plot_three.png'
PLOT_OPTIONS_PATH = INPUT_FILES_FOLDER / 'plot_options.png'
RESULT_NAME = 'my_res'
TWO_VALUE_LABELS = {
'value1': 'Formatted 1',
'value2': 'Formatted 2'
}
THREE_VALUE_LABELS = deepcopy(TWO_VALUE_LABELS)
THREE_VALUE_LABELS['value3'] = 'Formatted 3'
EXPECT_DF_TWO_VALUE = pd.DataFrame(
[
(1, 4, 10),
Expand All @@ -24,6 +31,7 @@
],
columns=['value1', 'value2', RESULT_NAME]
)
EXPECT_DF_TWO_VALUE_LABELS = EXPECT_DF_TWO_VALUE.rename(columns=TWO_VALUE_LABELS)

EXPECT_DF_THREE_VALUE = pd.DataFrame(
[
Expand All @@ -47,6 +55,7 @@
SENSITIVITY_VALUES_THREE_VALUE['value3'] = [6, 7]



def add_5_to_values(value1, value2):
return value1 + value2 + 5

Expand All @@ -55,11 +64,15 @@ def add_10_to_values(value1, value2, value3=5):
return value1 + value2 + value3 + 10


def assert_styled_matches(styler: Styler, file_path: str = DF_STYLED_PATH):
def assert_styled_matches(styler: Styler, file_path: str = DF_STYLED_PATH, generate: bool = False):
compare_html = styler.set_uuid(DF_STYLE_UUID).render()

if generate:
Path(file_path).write_text(compare_html)

with open(file_path, 'r') as f:
expect_html = f.read()

compare_html = styler.set_uuid(DF_STYLE_UUID).render()
assert compare_html == expect_html


Expand Down
28 changes: 28 additions & 0 deletions tests/input_data/df_labeled.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
<style type="text/css" >
#T_1ee5ad65-4cac-42e3-8133-7ae800cb23adrow0_col0{
background-color: #a50026;
color: #f1f1f1;
background-color: #a50026;
color: #f1f1f1;
}#T_1ee5ad65-4cac-42e3-8133-7ae800cb23adrow0_col1,#T_1ee5ad65-4cac-42e3-8133-7ae800cb23adrow1_col0{
background-color: #feffbe;
color: #000000;
background-color: #feffbe;
color: #000000;
}#T_1ee5ad65-4cac-42e3-8133-7ae800cb23adrow1_col1{
background-color: #006837;
color: #f1f1f1;
background-color: #006837;
color: #f1f1f1;
}</style><table id="T_1ee5ad65-4cac-42e3-8133-7ae800cb23ad" ><caption>my_res - Formatted 1 vs. Formatted 2</caption><thead> <tr> <th class="blank level0" ></th> <th class="col_heading level0 col0" >4</th> <th class="col_heading level0 col1" >5</th> </tr> <tr> <th class="index_name level0" >Formatted 1</th> <th class="blank" ></th> <th class="blank" ></th> </tr></thead><tbody>
<tr>
<th id="T_1ee5ad65-4cac-42e3-8133-7ae800cb23adlevel0_row0" class="row_heading level0 row0" >1</th>
<td id="T_1ee5ad65-4cac-42e3-8133-7ae800cb23adrow0_col0" class="data row0 col0" >10.000000</td>
<td id="T_1ee5ad65-4cac-42e3-8133-7ae800cb23adrow0_col1" class="data row0 col1" >11.000000</td>
</tr>
<tr>
<th id="T_1ee5ad65-4cac-42e3-8133-7ae800cb23adlevel0_row1" class="row_heading level0 row1" >2</th>
<td id="T_1ee5ad65-4cac-42e3-8133-7ae800cb23adrow1_col0" class="data row1 col0" >11.000000</td>
<td id="T_1ee5ad65-4cac-42e3-8133-7ae800cb23adrow1_col1" class="data row1 col1" >12.000000</td>
</tr>
</tbody></table>
Binary file modified tests/input_data/plot_options.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
14 changes: 13 additions & 1 deletion tests/test_df.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
from pandas.testing import assert_frame_equal

from sensitivity.df import sensitivity_df
from tests.base import EXPECT_DF_TWO_VALUE, SENSITIVITY_VALUES_TWO_VALUE, add_5_to_values, RESULT_NAME
from tests.base import EXPECT_DF_TWO_VALUE, SENSITIVITY_VALUES_TWO_VALUE, add_5_to_values, RESULT_NAME, \
TWO_VALUE_LABELS, EXPECT_DF_TWO_VALUE_LABELS


def test_create_sensitivity_df():
Expand All @@ -13,3 +14,14 @@ def test_create_sensitivity_df():
)

assert_frame_equal(df, EXPECT_DF_TWO_VALUE, check_dtype=False)


def test_labeled_sensitivity_df():
df = sensitivity_df(
SENSITIVITY_VALUES_TWO_VALUE,
add_5_to_values,
result_name=RESULT_NAME,
labels=TWO_VALUE_LABELS
)

assert_frame_equal(df, EXPECT_DF_TWO_VALUE_LABELS, check_dtype=False)
11 changes: 8 additions & 3 deletions tests/test_sensitivity_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from sensitivity import SensitivityAnalyzer
from tests.base import EXPECT_DF_TWO_VALUE, SENSITIVITY_VALUES_TWO_VALUE, add_5_to_values, RESULT_NAME, \
SENSITIVITY_VALUES_THREE_VALUE, add_10_to_values, EXPECT_DF_THREE_VALUE, assert_styled_matches, \
DF_STYLED_NUM_FMT_PATH, assert_graph_matches, PLOT_THREE_PATH, PLOT_OPTIONS_PATH
DF_STYLED_NUM_FMT_PATH, assert_graph_matches, PLOT_THREE_PATH, PLOT_OPTIONS_PATH, TWO_VALUE_LABELS, DF_LABELED_PATH


class TestSensitivityAnalyzer:
Expand Down Expand Up @@ -47,6 +47,11 @@ def test_create_styled_dfs_with_num_fmt(self):
assert_styled_matches(result, DF_STYLED_NUM_FMT_PATH)
assert_styled_matches(result2, DF_STYLED_NUM_FMT_PATH)

def test_create_styled_dfs_with_labels(self):
sa = self.create_sa(labels=TWO_VALUE_LABELS)
result = sa.styled_dfs()
assert_styled_matches(result, DF_LABELED_PATH)

def test_create_styled_dfs_three_values(self):
sa = self.create_sa(
sensitivity_values=SENSITIVITY_VALUES_THREE_VALUE,
Expand All @@ -71,9 +76,9 @@ def test_create_plot_with_options(self):
options = dict(
grid_size=2, color_map='viridis', reverse_colors=True
)
sa = self.create_sa(**options)
sa = self.create_sa(labels=TWO_VALUE_LABELS, **options)
result = sa.plot()
assert_graph_matches(result, file_path=PLOT_OPTIONS_PATH)
sa = self.create_sa()
sa = self.create_sa(labels=TWO_VALUE_LABELS)
result = sa.plot(**options)
assert_graph_matches(result, file_path=PLOT_OPTIONS_PATH)