Skip to content

Commit

Permalink
Merge pull request #12 from nickderobertis/labels
Browse files Browse the repository at this point in the history
Add ability to pass labels for inputs and have them show up in styled dfs and plots
  • Loading branch information
github-actions[bot] committed Oct 13, 2020
2 parents e511dd9 + 5dcd19b commit 6546ec8
Show file tree
Hide file tree
Showing 9 changed files with 344 additions and 256 deletions.
2 changes: 1 addition & 1 deletion conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
]

# Package version in the format (major, minor, release)
PACKAGE_VERSION_TUPLE = (0, 2, 4)
PACKAGE_VERSION_TUPLE = (0, 2, 5)

# Short description of the package
PACKAGE_SHORT_DESCRIPTION = "Python Sensitivity Analysis - Gradient DataFrames and Hex-Bin Plots"
Expand Down
499 changes: 253 additions & 246 deletions nbexamples/Sensitivity Analysis.ipynb

Large diffs are not rendered by default.

8 changes: 7 additions & 1 deletion sensitivity/df.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@


def sensitivity_df(sensitivity_values: Dict[str, Any], func: Callable,
result_name: str = 'Result', **func_kwargs) -> pd.DataFrame:
result_name: str = 'Result',
labels: Optional[Dict[str, str]] = None,
**func_kwargs) -> pd.DataFrame:
"""
Creates a DataFrame containing the results of sensitivity analysis.
Expand All @@ -26,6 +28,8 @@ def sensitivity_df(sensitivity_values: Dict[str, Any], func: Callable,
:param func: Function that accepts arguments with names matching the keys of sensitivity_values, and outputs a
scalar value.
:param result_name: Name for result shown in graph color bar label
:param labels: Optional dictionary where keys are arguments of the function and values are the displayed names
for these arguments in the styled DataFrames and plots
:param func_kwargs: Additional arguments to pass to func, regardless of the sensitivity values picked
:return: a DataFrame containing the results from sensitivity analysis on func
"""
Expand All @@ -41,6 +45,8 @@ def sensitivity_df(sensitivity_values: Dict[str, Any], func: Callable,
df = df.append(pd.DataFrame(pd.Series(base_param_dict)).T)
df.reset_index(drop=True, inplace=True)
df = df.convert_dtypes()
if labels:
df.rename(columns=labels, inplace=True)

return df

Expand Down
21 changes: 19 additions & 2 deletions sensitivity/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ class SensitivityAnalyzer:
same style as would be passed to df.style.format, e.g. '${:,.2f}' for USD formatting
:param color_map: matplotlib color map, default is RdYlGn (red, yellow, green). See
https://matplotlib.org/3.3.2/tutorials/colors/colormaps.html
:param labels: Optional dictionary where keys are arguments of the function and values are the displayed names
for these arguments in the styled DataFrames and plots
:return: Sensitivity analysis hex bin sub plot figure
Examples:
Expand Down Expand Up @@ -75,6 +77,7 @@ class SensitivityAnalyzer:
func_kwargs_dict: Optional[Dict[str, Any]] = None
num_fmt: Optional[str] = None
color_map: str = 'RdYlGn'
labels: Optional[Dict[str, str]] = None

def __post_init__(self):
if self.func_kwargs_dict is None:
Expand All @@ -83,6 +86,7 @@ def __post_init__(self):
self.sensitivity_values,
self.func,
result_name=self.result_name,
labels=self.labels,
**self.func_kwargs_dict
)

Expand All @@ -100,7 +104,7 @@ def plot(self, **kwargs) -> plt.Figure:
color_map=self.color_map,
)
config_dict.update(**kwargs)
sensitivity_cols = list(self.sensitivity_values.keys())
sensitivity_cols = self.sensitivity_cols
return _hex_figure_from_sensitivity_df(
self.df,
sensitivity_cols,
Expand All @@ -125,7 +129,7 @@ def styled_dfs(self, disp: bool = True, **kwargs) -> Union[Styler, Dict[Sequence
)
config_dict.update(**kwargs)
# Output a single Styler if only one or two variables
sensitivity_cols = list(self.sensitivity_values.keys())
sensitivity_cols = self.sensitivity_cols
if len(sensitivity_cols) == 1:
output[tuple(sensitivity_cols)] = _style_sensitivity_df(
self.df,
Expand Down Expand Up @@ -188,6 +192,19 @@ def styled_dfs(self, disp: bool = True, **kwargs) -> Union[Styler, Dict[Sequence

return output

@property
def sensitivity_cols(self) -> List[str]:
sensitivity_cols = list(self.sensitivity_values.keys())
if self.labels:
new_sensitivity_cols: List[str] = []
for col in sensitivity_cols:
if col in self.labels:
new_sensitivity_cols.append(self.labels[col])
else:
new_sensitivity_cols.append(col)
sensitivity_cols = new_sensitivity_cols
return sensitivity_cols


def _display_header(text: str):
html_str = f'<h2>{text}</h2>'
Expand Down
17 changes: 15 additions & 2 deletions tests/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,19 @@

INPUT_FILES_FOLDER = Path(os.path.join('tests', 'input_data'))
DF_STYLED_PATH = os.path.join(INPUT_FILES_FOLDER, 'df_styled.html')
DF_LABELED_PATH = os.path.join(INPUT_FILES_FOLDER, 'df_labeled.html')
DF_STYLED_NUM_FMT_PATH = os.path.join(INPUT_FILES_FOLDER, 'df_styled_num_fmt.html')
DF_STYLE_UUID = '1ee5ad65-4cac-42e3-8133-7ae800cb23ad'
DEFAULT_PLOT_PATH = INPUT_FILES_FOLDER / 'default_plot.png'
PLOT_THREE_PATH = INPUT_FILES_FOLDER / 'plot_three.png'
PLOT_OPTIONS_PATH = INPUT_FILES_FOLDER / 'plot_options.png'
RESULT_NAME = 'my_res'
TWO_VALUE_LABELS = {
'value1': 'Formatted 1',
'value2': 'Formatted 2'
}
THREE_VALUE_LABELS = deepcopy(TWO_VALUE_LABELS)
THREE_VALUE_LABELS['value3'] = 'Formatted 3'
EXPECT_DF_TWO_VALUE = pd.DataFrame(
[
(1, 4, 10),
Expand All @@ -24,6 +31,7 @@
],
columns=['value1', 'value2', RESULT_NAME]
)
EXPECT_DF_TWO_VALUE_LABELS = EXPECT_DF_TWO_VALUE.rename(columns=TWO_VALUE_LABELS)

EXPECT_DF_THREE_VALUE = pd.DataFrame(
[
Expand All @@ -47,6 +55,7 @@
SENSITIVITY_VALUES_THREE_VALUE['value3'] = [6, 7]



def add_5_to_values(value1, value2):
return value1 + value2 + 5

Expand All @@ -55,11 +64,15 @@ def add_10_to_values(value1, value2, value3=5):
return value1 + value2 + value3 + 10


def assert_styled_matches(styler: Styler, file_path: str = DF_STYLED_PATH):
def assert_styled_matches(styler: Styler, file_path: str = DF_STYLED_PATH, generate: bool = False):
compare_html = styler.set_uuid(DF_STYLE_UUID).render()

if generate:
Path(file_path).write_text(compare_html)

with open(file_path, 'r') as f:
expect_html = f.read()

compare_html = styler.set_uuid(DF_STYLE_UUID).render()
assert compare_html == expect_html


Expand Down
28 changes: 28 additions & 0 deletions tests/input_data/df_labeled.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
<style type="text/css" >
#T_1ee5ad65-4cac-42e3-8133-7ae800cb23adrow0_col0{
background-color: #a50026;
color: #f1f1f1;
background-color: #a50026;
color: #f1f1f1;
}#T_1ee5ad65-4cac-42e3-8133-7ae800cb23adrow0_col1,#T_1ee5ad65-4cac-42e3-8133-7ae800cb23adrow1_col0{
background-color: #feffbe;
color: #000000;
background-color: #feffbe;
color: #000000;
}#T_1ee5ad65-4cac-42e3-8133-7ae800cb23adrow1_col1{
background-color: #006837;
color: #f1f1f1;
background-color: #006837;
color: #f1f1f1;
}</style><table id="T_1ee5ad65-4cac-42e3-8133-7ae800cb23ad" ><caption>my_res - Formatted 1 vs. Formatted 2</caption><thead> <tr> <th class="blank level0" ></th> <th class="col_heading level0 col0" >4</th> <th class="col_heading level0 col1" >5</th> </tr> <tr> <th class="index_name level0" >Formatted 1</th> <th class="blank" ></th> <th class="blank" ></th> </tr></thead><tbody>
<tr>
<th id="T_1ee5ad65-4cac-42e3-8133-7ae800cb23adlevel0_row0" class="row_heading level0 row0" >1</th>
<td id="T_1ee5ad65-4cac-42e3-8133-7ae800cb23adrow0_col0" class="data row0 col0" >10.000000</td>
<td id="T_1ee5ad65-4cac-42e3-8133-7ae800cb23adrow0_col1" class="data row0 col1" >11.000000</td>
</tr>
<tr>
<th id="T_1ee5ad65-4cac-42e3-8133-7ae800cb23adlevel0_row1" class="row_heading level0 row1" >2</th>
<td id="T_1ee5ad65-4cac-42e3-8133-7ae800cb23adrow1_col0" class="data row1 col0" >11.000000</td>
<td id="T_1ee5ad65-4cac-42e3-8133-7ae800cb23adrow1_col1" class="data row1 col1" >12.000000</td>
</tr>
</tbody></table>
Binary file modified tests/input_data/plot_options.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
14 changes: 13 additions & 1 deletion tests/test_df.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
from pandas.testing import assert_frame_equal

from sensitivity.df import sensitivity_df
from tests.base import EXPECT_DF_TWO_VALUE, SENSITIVITY_VALUES_TWO_VALUE, add_5_to_values, RESULT_NAME
from tests.base import EXPECT_DF_TWO_VALUE, SENSITIVITY_VALUES_TWO_VALUE, add_5_to_values, RESULT_NAME, \
TWO_VALUE_LABELS, EXPECT_DF_TWO_VALUE_LABELS


def test_create_sensitivity_df():
Expand All @@ -13,3 +14,14 @@ def test_create_sensitivity_df():
)

assert_frame_equal(df, EXPECT_DF_TWO_VALUE, check_dtype=False)


def test_labeled_sensitivity_df():
df = sensitivity_df(
SENSITIVITY_VALUES_TWO_VALUE,
add_5_to_values,
result_name=RESULT_NAME,
labels=TWO_VALUE_LABELS
)

assert_frame_equal(df, EXPECT_DF_TWO_VALUE_LABELS, check_dtype=False)
11 changes: 8 additions & 3 deletions tests/test_sensitivity_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from sensitivity import SensitivityAnalyzer
from tests.base import EXPECT_DF_TWO_VALUE, SENSITIVITY_VALUES_TWO_VALUE, add_5_to_values, RESULT_NAME, \
SENSITIVITY_VALUES_THREE_VALUE, add_10_to_values, EXPECT_DF_THREE_VALUE, assert_styled_matches, \
DF_STYLED_NUM_FMT_PATH, assert_graph_matches, PLOT_THREE_PATH, PLOT_OPTIONS_PATH
DF_STYLED_NUM_FMT_PATH, assert_graph_matches, PLOT_THREE_PATH, PLOT_OPTIONS_PATH, TWO_VALUE_LABELS, DF_LABELED_PATH


class TestSensitivityAnalyzer:
Expand Down Expand Up @@ -47,6 +47,11 @@ def test_create_styled_dfs_with_num_fmt(self):
assert_styled_matches(result, DF_STYLED_NUM_FMT_PATH)
assert_styled_matches(result2, DF_STYLED_NUM_FMT_PATH)

def test_create_styled_dfs_with_labels(self):
sa = self.create_sa(labels=TWO_VALUE_LABELS)
result = sa.styled_dfs()
assert_styled_matches(result, DF_LABELED_PATH)

def test_create_styled_dfs_three_values(self):
sa = self.create_sa(
sensitivity_values=SENSITIVITY_VALUES_THREE_VALUE,
Expand All @@ -71,9 +76,9 @@ def test_create_plot_with_options(self):
options = dict(
grid_size=2, color_map='viridis', reverse_colors=True
)
sa = self.create_sa(**options)
sa = self.create_sa(labels=TWO_VALUE_LABELS, **options)
result = sa.plot()
assert_graph_matches(result, file_path=PLOT_OPTIONS_PATH)
sa = self.create_sa()
sa = self.create_sa(labels=TWO_VALUE_LABELS)
result = sa.plot(**options)
assert_graph_matches(result, file_path=PLOT_OPTIONS_PATH)

0 comments on commit 6546ec8

Please sign in to comment.