Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feat] get vizro ai customized text output #488

Merged
merged 29 commits into from
Jun 6, 2024
Merged
Show file tree
Hide file tree
Changes from 26 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
993ad1b
update helper
Anna-Xiong May 16, 2024
dee20fb
add get all output method
Anna-Xiong May 16, 2024
2dd6ddc
update get_outputs
Anna-Xiong May 16, 2024
b1c0929
update get_outputs
Anna-Xiong May 16, 2024
c7289b8
update get_plot_outputs
Anna-Xiong May 16, 2024
e6debe9
changelog
Anna-Xiong May 16, 2024
3b43d9b
Merge branch 'main' into feat/get_vizro_ai_customized_output
Anna-Xiong May 21, 2024
99354d0
Merge branch 'main' into feat/get_vizro_ai_customized_output
Anna-Xiong May 23, 2024
4b40f38
fix
Anna-Xiong May 23, 2024
a6c811d
changelog
Anna-Xiong May 23, 2024
db450e4
fix
Anna-Xiong May 23, 2024
cc6fea6
docstring
Anna-Xiong May 23, 2024
8a7e4ef
Merge branch 'main' into feat/get_vizro_ai_customized_output
Anna-Xiong May 27, 2024
1a708d3
Merge branch 'main' into feat/get_vizro_ai_customized_output
Anna-Xiong Jun 3, 2024
20ba6a5
add dataclass and update show fig utils
Anna-Xiong Jun 3, 2024
cbb2a64
refactor return to Plot dataclass
Anna-Xiong Jun 3, 2024
1c0fc50
fix
Anna-Xiong Jun 3, 2024
3ec92b9
changelog
Anna-Xiong Jun 3, 2024
8b2e0ca
Merge branch 'main' into feat/get_vizro_ai_customized_output
Anna-Xiong Jun 3, 2024
99133ca
update show fig rule
Anna-Xiong Jun 3, 2024
52638cc
update explain info msg
Anna-Xiong Jun 3, 2024
f5ef747
address pr comments
Anna-Xiong Jun 4, 2024
6a64951
renaming Plot dataclass and change show fig
Anna-Xiong Jun 4, 2024
4c69844
rename arguments
Anna-Xiong Jun 4, 2024
477ad82
fix
Anna-Xiong Jun 4, 2024
e5eaf7f
Merge branch 'main' into feat/get_vizro_ai_customized_output
Anna-Xiong Jun 4, 2024
3152209
address pr comments
Anna-Xiong Jun 4, 2024
6f25fd0
Merge branch 'main' into feat/get_vizro_ai_customized_output
Anna-Xiong Jun 6, 2024
fcc3304
Merge branch 'main' into feat/get_vizro_ai_customized_output
Anna-Xiong Jun 6, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
<!--
A new scriv changelog fragment.

Uncomment the section that is right (remove the HTML comment wrapper).
-->

<!--
### Highlights ✨

- A bullet item for the Highlights ✨ category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))

-->
<!--
### Removed

- A bullet item for the Removed category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))

-->

### Added

- Enable feature to get all possible outputs from VizroAI.plot() by specifying 'return_elements = True' ([#488](https://github.com/mckinsey/vizro/pull/488)). It returns a dataclass that contains the code string, figure object, business insights, and code explanation.
Anna-Xiong marked this conversation as resolved.
Show resolved Hide resolved

<!--
### Changed

- A bullet item for the Changed category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))

-->
<!--
### Deprecated

- A bullet item for the Deprecated category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))

-->
<!--
### Fixed

- A bullet item for the Fixed category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))

-->
<!--
### Security

- A bullet item for the Security category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))

-->
89 changes: 56 additions & 33 deletions vizro-ai/src/vizro_ai/_vizro_ai.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
from typing import Any, Dict, Optional, Union
from dataclasses import asdict
from typing import Any, Optional, Union

import pandas as pd
import plotly.graph_objects as go
Expand All @@ -10,9 +11,10 @@
from vizro_ai.task_pipeline._pipeline_manager import PipelineManager
from vizro_ai.utils.helper import (
DebugFailure,
Plot,
_debug_helper,
_display_markdown,
_exec_code_and_retrieve_fig,
_exec_fig_code_display_markdown,
_is_jupyter,
)

Expand All @@ -23,7 +25,7 @@ class VizroAI:
"""Vizro-AI main class."""

pipeline_manager: PipelineManager = PipelineManager()
_return_all_text: bool = False
_return_all_text: bool = False # TODO deleted after adding new integration test

def __init__(self, model: Optional[Union[ChatOpenAI, str]] = None):
"""Initialization of VizroAI.
Expand Down Expand Up @@ -57,7 +59,7 @@ def _lazy_get_component(self, component_class: Any) -> Any: # TODO configure co

def _run_plot_tasks(
self, df: pd.DataFrame, user_input: str, max_debug_retry: int = 3, explain: bool = False
) -> Dict[str, Any]:
) -> Plot:
"""Task execution."""
chart_type_pipeline = self.pipeline_manager.chart_type_pipeline
chart_types = chart_type_pipeline.run(initial_args={"chain_input": user_input, "df": df})
Expand All @@ -76,18 +78,29 @@ def _run_plot_tasks(

pass_validation = validated_code_dict.get("debug_status")
code_string = validated_code_dict.get("code_string")
business_insights, code_explanation = None, None

if explain and pass_validation:
if not pass_validation:
raise DebugFailure(
"Chart creation failed. Retry debugging has reached maximum limit. Try to rephrase the prompt, "
"or try to select a different model. Fallout response is provided: \n\n" + code_string
)

fig_object = _exec_code_and_retrieve_fig(
code=code_string, local_args={"df": df}, show_fig=_is_jupyter(), is_notebook_env=_is_jupyter()
)
if explain:
business_insights, code_explanation = self._lazy_get_component(GetCodeExplanation).run(
chain_input=user_input, code_snippet=code_string
)

return {
"business_insights": business_insights,
"code_explanation": code_explanation,
"code_string": code_string,
}
return Plot(
code=code_string,
figure=fig_object,
business_insights=business_insights,
code_explanation=code_explanation,
)

return Plot(code=code_string, figure=fig_object)

def _get_chart_code(self, df: pd.DataFrame, user_input: str) -> str:
"""Get Chart code of vizro via english descriptions, English to chart translation.
Expand All @@ -99,41 +112,51 @@ def _get_chart_code(self, df: pd.DataFrame, user_input: str) -> str:
user_input: User questions or descriptions of the desired visual

"""
# TODO refine and update error handling
return self._run_plot_tasks(df, user_input, explain=False).get("code_string")

def plot(
self, df: pd.DataFrame, user_input: str, explain: bool = False, max_debug_retry: int = 3
) -> Union[go.Figure, Dict[str, Any]]:
# TODO retained for some chat application integration, need deprecation handling
return self._run_plot_tasks(df, user_input, explain=False).code

def plot( # pylint: disable=too-many-arguments # noqa: PLR0913
self,
df: pd.DataFrame,
user_input: str,
explain: bool = False,
max_debug_retry: int = 3,
return_elements: bool = False,
) -> Union[go.Figure, Plot]:
"""Plot visuals using vizro via english descriptions, english to chart translation.

Args:
df: The dataframe to be analyzed.
user_input: User questions or descriptions of the desired visual.
explain: Flag to include explanation in response.
max_debug_retry: Maximum number of retries to debug errors. Defaults to `3`.
return_elements: Flag to return plot dataclass that includes all components.
Anna-Xiong marked this conversation as resolved.
Show resolved Hide resolved

Returns:
Plotly Figure object or a dictionary containing data
go.Figure or Plot dataclass

"""
output_dict = self._run_plot_tasks(df, user_input, explain=explain, max_debug_retry=max_debug_retry)
code_string = output_dict.get("code_string")
business_insights = output_dict.get("business_insights")
code_explanation = output_dict.get("code_explanation")
vizro_plot = self._run_plot_tasks(
df=df, user_input=user_input, explain=explain, max_debug_retry=max_debug_retry
)

if code_string.startswith("Failed to debug code"):
raise DebugFailure(
"Chart creation failed. Retry debugging has reached maximum limit. Try to rephrase the prompt, "
"or try to select a different model. Fallout response is provided: \n\n" + code_string
if not explain:
logger.info(
"Flag explain is set to False. business_insights and code_explanation will not be included in "
"the output dataclass."
)

else:
_display_markdown(
code_snippet=vizro_plot.code,
biz_insights=vizro_plot.business_insights,
code_explain=vizro_plot.code_explanation,
)

# TODO Tentative for integration test
# TODO Tentative for integration test, will be updated/removed for new tests
if self._return_all_text:
output_dict = asdict(vizro_plot)
output_dict["code_string"] = vizro_plot.code
return output_dict
if not explain:
return _exec_code_and_retrieve_fig(code=code_string, local_args={"df": df}, is_notebook_env=_is_jupyter())
if explain:
return _exec_fig_code_display_markdown(
df=df, code_snippet=code_string, biz_insights=business_insights, code_explain=code_explanation
)

return vizro_plot if return_elements else vizro_plot.figure
32 changes: 21 additions & 11 deletions vizro-ai/src/vizro_ai/utils/helper.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Helper Functions For Vizro AI."""

import traceback
from dataclasses import dataclass, field
from typing import Callable, Dict, Optional

import pandas as pd
Expand All @@ -9,6 +10,16 @@
from .safeguard import _safeguard_check


@dataclass
class Plot:
Anna-Xiong marked this conversation as resolved.
Show resolved Hide resolved
"""Data class about a vizro ai plot."""
Anna-Xiong marked this conversation as resolved.
Show resolved Hide resolved

code: str
figure: go.Figure
business_insights: Optional[str] = field(default=None)
code_explanation: Optional[str] = field(default=None)


# Taken from rich.console. See https://github.com/Textualize/rich.
def _is_jupyter() -> bool: # pragma: no cover
"""Checks if we're running in a Jupyter notebook."""
Expand Down Expand Up @@ -49,22 +60,29 @@ def _debug_helper(


def _exec_code_and_retrieve_fig(
code: str, local_args: Optional[Dict] = None, is_notebook_env: bool = True
code: str, local_args: Optional[Dict] = None, show_fig: bool = False, is_notebook_env: bool = True
) -> go.Figure:
"""Execute code in notebook with correct namespace and return fig object.

Args:
code: code string to be executed
local_args: additional local arguments
show_fig: boolean flag indicating if fig will be rendered automatically
is_notebook_env: boolean flag indicating if code is run in Jupyter notebook

Returns:
go.Figure
Plotly go figure
Anna-Xiong marked this conversation as resolved.
Show resolved Hide resolved

"""
from IPython import get_ipython

if show_fig and "\nfig.show()" not in code:
code += "\nfig.show()"
elif not show_fig:
code = code.replace("fig.show()", "")

namespace = get_ipython().user_ns if is_notebook_env else globals()

if local_args:
namespace.update(local_args)
_safeguard_check(code)
Expand All @@ -75,21 +93,14 @@ def _exec_code_and_retrieve_fig(
return dashboard_ready_fig


def _exec_fig_code_display_markdown(
df: pd.DataFrame, code_snippet: str, biz_insights: str, code_explain: str
) -> go.Figure:
# TODO change default test str to other
def _display_markdown(code_snippet: str, biz_insights: str, code_explain: str) -> None:
Anna-Xiong marked this conversation as resolved.
Show resolved Hide resolved
"""Display chart and Markdown format description in jupyter and returns fig object.

Args:
df: The dataframe to be analyzed.
code_snippet: code string to be executed
biz_insights: business insights to be displayed in markdown cell
code_explain: code explanation to be displayed in markdown cell

Returns:
go.Figure

"""
try:
# pylint: disable=import-outside-toplevel
Expand All @@ -100,7 +111,6 @@ def _exec_fig_code_display_markdown(
markdown_code = f"```\n{code_snippet}\n```"
output_text = f"<h4>Insights:</h4>\n\n{biz_insights}\n<br><br><h4>Code:</h4>\n\n{code_explain}\n{markdown_code}"
display(Markdown(output_text))
return _exec_code_and_retrieve_fig(code_snippet, local_args={"df": df}, is_notebook_env=_is_jupyter())


class DebugFailure(Exception):
Expand Down
Loading