From 4d58ae79b3d709f49bcf87155c42cbff2721202f Mon Sep 17 00:00:00 2001 From: Abubakar Abid Date: Mon, 29 Aug 2022 09:53:05 -0700 Subject: [PATCH] Improvements to `State` (#2100) * state * state fix * variable -> state * fix * added state tests * formatting * fix test * formatting * fix test * added tests for bakcward compatibility * formatting * config fix * additional doc * doc fix * formatting --- demo/blocks_flashcards/run.py | 4 +- demo/blocks_simple_squares/run.py | 4 +- demo/components_demos/run.py | 4 +- demo/hangman/run.py | 2 +- demo/kitchen_sink_random/run.py | 8 ++-- demo/stream_audio/run.py | 2 +- gradio/__init__.py | 1 + gradio/components.py | 14 +++++- gradio/inputs.py | 4 +- gradio/interface.py | 45 ++++++++++++------- gradio/outputs.py | 4 +- gradio/utils.py | 2 +- .../1)interface_state.md | 2 +- .../3)state_in_blocks.md | 10 ++--- test/test_blocks.py | 4 +- test/test_components.py | 43 ++++++++++++++++++ .../Variable.svelte => State/State.svelte} | 0 ui/packages/app/src/components/State/index.ts | 2 + .../app/src/components/Variable/index.ts | 2 - ui/packages/app/src/components/directory.ts | 2 +- 20 files changed, 113 insertions(+), 46 deletions(-) rename ui/packages/app/src/components/{Variable/Variable.svelte => State/State.svelte} (100%) create mode 100644 ui/packages/app/src/components/State/index.ts delete mode 100644 ui/packages/app/src/components/Variable/index.ts diff --git a/demo/blocks_flashcards/run.py b/demo/blocks_flashcards/run.py index a378f271962c..4e54281e1797 100644 --- a/demo/blocks_flashcards/run.py +++ b/demo/blocks_flashcards/run.py @@ -21,7 +21,7 @@ flip_btn = gr.Button("Flip Card").style(full_width=True) with gr.Column(visible=False) as answer_col: back = gr.Textbox(label="Answer") - selected_card = gr.Variable() + selected_card = gr.State() with gr.Row(): correct_btn = gr.Button( "Correct", @@ -29,7 +29,7 @@ incorrect_btn = gr.Button("Incorrect").style(full_width=True) with gr.TabItem("Results"): - results = gr.Variable(value={}) + results = gr.State(value={}) correct_field = gr.Markdown("# Correct: 0") incorrect_field = gr.Markdown("# Incorrect: 0") gr.Markdown("Card Statistics: ") diff --git a/demo/blocks_simple_squares/run.py b/demo/blocks_simple_squares/run.py index 33519c743be1..ba335cbb7611 100644 --- a/demo/blocks_simple_squares/run.py +++ b/demo/blocks_simple_squares/run.py @@ -5,11 +5,11 @@ with demo: default_json = {"a": "a"} - num = gr.Variable(value=0) + num = gr.State(value=0) squared = gr.Number(value=0) btn = gr.Button("Next Square", elem_id="btn").style(rounded=False) - stats = gr.Variable(value=default_json) + stats = gr.State(value=default_json) table = gr.JSON() def increase(var, stats_history): diff --git a/demo/components_demos/run.py b/demo/components_demos/run.py index 08fa8866dd4f..65a65a23dce3 100644 --- a/demo/components_demos/run.py +++ b/demo/components_demos/run.py @@ -42,8 +42,8 @@ with gr.Blocks() as Timeseries_demo: gr.Timeseries() -with gr.Blocks() as Variable_demo: - gr.Variable() +with gr.Blocks() as State_demo: + gr.State() with gr.Blocks() as Button_demo: gr.Button() diff --git a/demo/hangman/run.py b/demo/hangman/run.py index ddd638f60bc6..83fea7f578f9 100644 --- a/demo/hangman/run.py +++ b/demo/hangman/run.py @@ -4,7 +4,7 @@ secret_word = "gradio" with gr.Blocks() as demo: - used_letters_var = gr.Variable([]) + used_letters_var = gr.State([]) with gr.Row() as row: with gr.Column(): input_letter = gr.Textbox(label="Enter letter") diff --git a/demo/kitchen_sink_random/run.py b/demo/kitchen_sink_random/run.py index 42718218133c..41c92d8a8db7 100644 --- a/demo/kitchen_sink_random/run.py +++ b/demo/kitchen_sink_random/run.py @@ -17,7 +17,7 @@ demo = gr.Interface( - lambda x: x, + lambda *args: args[0], inputs=[ gr.Textbox(value=lambda: datetime.now(), label="Current Time"), gr.Number(value=lambda: random.random(), label="Ranom Percentage"), @@ -60,7 +60,7 @@ ) ), gr.Timeseries(value=lambda: os.path.join(file_dir, "time.csv")), - gr.Variable(value=lambda: random.choice(string.ascii_lowercase)), + gr.State(value=lambda: random.choice(string.ascii_lowercase)), gr.Button(value=lambda: random.choice(["Run", "Go", "predict"])), gr.ColorPicker(value=lambda: random.choice(["#000000", "#ff0000", "#0000FF"])), gr.Label(value=lambda: random.choice(["Pedestrian", "Car", "Cyclist"])), @@ -91,7 +91,9 @@ gr.Plot(value=random_plot), gr.Markdown(value=lambda: f"### {random.choice(['Hello', 'Hi', 'Goodbye!'])}"), ], - outputs=None, + outputs=[ + gr.State(value=lambda: random.choice(string.ascii_lowercase)) + ], ) if __name__ == "__main__": diff --git a/demo/stream_audio/run.py b/demo/stream_audio/run.py index 8fcd3c2affc1..e0d29f8d8a14 100644 --- a/demo/stream_audio/run.py +++ b/demo/stream_audio/run.py @@ -4,7 +4,7 @@ with gr.Blocks() as demo: inp = gr.Audio(source="microphone") out = gr.Audio() - stream = gr.Variable() + stream = gr.State() def add_to_stream(audio, instream): if audio is None: diff --git a/gradio/__init__.py b/gradio/__init__.py index 9cb437e67373..74ad5190e232 100644 --- a/gradio/__init__.py +++ b/gradio/__init__.py @@ -35,6 +35,7 @@ Plot, Radio, Slider, + State, StatusTracker, Textbox, TimeSeries, diff --git a/gradio/components.py b/gradio/components.py index fa5e6f6c67cf..038856d3c4f7 100644 --- a/gradio/components.py +++ b/gradio/components.py @@ -2541,10 +2541,10 @@ def style( @document() -class Variable(IOComponent, SimpleSerializable): +class State(IOComponent, SimpleSerializable): """ Special hidden component that stores session state across runs of the demo by the - same user. The value of the Variable is cleared when the user refreshes the page. + same user. The value of the State variable is cleared when the user refreshes the page. Preprocessing: No preprocessing is performed Postprocessing: No postprocessing is performed @@ -2570,6 +2570,16 @@ def style(self): return self +class Variable(State): + """Variable was renamed to State. This class is kept for backwards compatibility.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def get_block_name(self): + return "state" + + @document("click", "style") class Button(Clickable, IOComponent, SimpleSerializable): """ diff --git a/gradio/inputs.py b/gradio/inputs.py index ac710899148f..eb8b3402808c 100644 --- a/gradio/inputs.py +++ b/gradio/inputs.py @@ -427,7 +427,7 @@ def __init__( super().__init__(x=x, y=y, label=label, optional=optional) -class State(components.Variable): +class State(components.State): """ Special hidden component that stores state across runs of the interface. Input type: Any @@ -445,7 +445,7 @@ def __init__( optional (bool): this parameter is ignored. """ warnings.warn( - "Usage of gradio.inputs is deprecated, and will not be supported in the future, please import this component as gr.Variable from gradio.components", + "Usage of gradio.inputs is deprecated, and will not be supported in the future, please import this component as gr.State() from gradio.components", ) super().__init__(value=default, label=label) diff --git a/gradio/interface.py b/gradio/interface.py index 673d8a99b44b..08d0d75ea00d 100644 --- a/gradio/interface.py +++ b/gradio/interface.py @@ -27,8 +27,8 @@ Interpretation, IOComponent, Markdown, + State, StatusTracker, - Variable, get_component_instance, ) from gradio.documentation import document, set_documentation_group @@ -213,17 +213,30 @@ def __init__( else: self.cache_examples = cache_examples or False - if "state" in inputs or "state" in outputs: - state_input_count = len([i for i in inputs if i == "state"]) - state_output_count = len([o for o in outputs if o == "state"]) - if state_input_count != 1 or state_output_count != 1: - raise ValueError( - "If using 'state', there must be exactly one state input and one state output." - ) - default = utils.get_default_args(fn)[inputs.index("state")] - state_variable = Variable(value=default) - inputs[inputs.index("state")] = state_variable - outputs[outputs.index("state")] = state_variable + state_input_indexes = [ + idx for idx, i in enumerate(inputs) if i == "state" or isinstance(i, State) + ] + state_output_indexes = [ + idx for idx, o in enumerate(outputs) if o == "state" or isinstance(o, State) + ] + + if len(state_input_indexes) == 0 and len(state_output_indexes) == 0: + pass + elif len(state_input_indexes) != 1 or len(state_output_indexes) != 1: + raise ValueError( + "If using 'state', there must be exactly one state input and one state output." + ) + else: + state_input_index = state_input_indexes[0] + state_output_index = state_output_indexes[0] + if inputs[state_input_index] == "state": + default = utils.get_default_args(fn)[state_input_index] + state_variable = State(value=default) + else: + state_variable = inputs[state_input_index] + + inputs[state_input_index] = state_variable + outputs[state_output_index] = state_variable if cache_examples: warnings.warn( @@ -240,9 +253,7 @@ def __init__( ] for component in self.input_components + self.output_components: - if not ( - isinstance(component, IOComponent) or isinstance(component, Variable) - ): + if not (isinstance(component, IOComponent)): raise ValueError( f"{component} is not a valid input/output component for Interface." ) @@ -607,10 +618,10 @@ def __call__(self, *flag_data): if self.examples: non_state_inputs = [ - c for c in self.input_components if not isinstance(c, Variable) + c for c in self.input_components if not isinstance(c, State) ] non_state_outputs = [ - c for c in self.output_components if not isinstance(c, Variable) + c for c in self.output_components if not isinstance(c, State) ] self.examples_handler = Examples( examples=examples, diff --git a/gradio/outputs.py b/gradio/outputs.py index 0bdd0b77a6d7..ef9d4c6ed34a 100644 --- a/gradio/outputs.py +++ b/gradio/outputs.py @@ -158,7 +158,7 @@ def __init__( super().__init__(x=x, y=y, label=label) -class State(components.Variable): +class State(components.State): """ Special hidden component that stores state across runs of the interface. Output type: Any @@ -170,7 +170,7 @@ def __init__(self, label: Optional[str] = None): label (str): component name in interface (not used). """ warnings.warn( - "Usage of gradio.outputs is deprecated, and will not be supported in the future, please import your components from gradio.components", + "Usage of gradio.outputs is deprecated, and will not be supported in the future, please import this component as gr.State() from gradio.components", ) super().__init__(label=label) diff --git a/gradio/utils.py b/gradio/utils.py index 56771042c732..b0fc1e6500e8 100644 --- a/gradio/utils.py +++ b/gradio/utils.py @@ -198,7 +198,7 @@ def launch_counter() -> None: pass -def get_default_args(func: Callable) -> Dict[str, Any]: +def get_default_args(func: Callable) -> List[Any]: signature = inspect.signature(func) return [ v.default if v.default is not inspect.Parameter.empty else None diff --git a/guides/2)building_interfaces/1)interface_state.md b/guides/2)building_interfaces/1)interface_state.md index 9b8ecb5cb9e2..d02ad0d649c5 100644 --- a/guides/2)building_interfaces/1)interface_state.md +++ b/guides/2)building_interfaces/1)interface_state.md @@ -23,4 +23,4 @@ $demo_chatbot_demo Notice how the state persists across submits within each page, but if you load this demo in another tab (or refresh the page), the demos will not share chat history. -The default value of `state` is None. If you pass a default value to the state parameter of the function, it is used as the default value of the state instead. \ No newline at end of file +The default value of `state` is None. If you pass a default value to the state parameter of the function, it is used as the default value of the state instead. The `Interface` class only supports a single input and outputs state variable, though it can be a list with multiple elements. For more complex use cases, you can use Blocks, [which supports multiple `State` variables](/state_in_blocks/). \ No newline at end of file diff --git a/guides/3)building_with_blocks/3)state_in_blocks.md b/guides/3)building_with_blocks/3)state_in_blocks.md index c58216c7f914..32342c29fe81 100644 --- a/guides/3)building_with_blocks/3)state_in_blocks.md +++ b/guides/3)building_with_blocks/3)state_in_blocks.md @@ -8,8 +8,8 @@ Global state in Blocks works the same as in Interface. Any variable created outs Gradio supports session **state**, where data persists across multiple submits within a page session, in Blocks apps as well. To reiterate, session data is *not* shared between different users of your model. To store data in a session state, you need to do three things: -1. Create a `gr.Variable()` object. If there is a default value to this stateful object, pass that into the constructor. -2. In the event listener, put the `Variable` object as an input and output. +1. Create a `gr.State()` object. If there is a default value to this stateful object, pass that into the constructor. +2. In the event listener, put the `State` object as an input and output. 3. In the event listener function, add the variable to the input parameters and the return value. Let's take a look at a game of hangman. @@ -19,11 +19,11 @@ $demo_hangman Let's see how we do each of the 3 steps listed above in this game: -1. We store the used letters in `used_letters_var`. In the constructor of `Variable`, we set the initial value of this to `[]`, an empty list. +1. We store the used letters in `used_letters_var`. In the constructor of `State`, we set the initial value of this to `[]`, an empty list. 2. In `btn.click()`, we have a reference to `used_letters_var` in both the inputs and outputs. -3. In `guess_letter`, we pass the value of this `Variable` to `used_letters`, and then return an updated value of this `Variable` in the return statement. +3. In `guess_letter`, we pass the value of this `State` to `used_letters`, and then return an updated value of this `State` in the return statement. -With more complex apps, you will likely have many Variables storing session state in a single Blocks app. +With more complex apps, you will likely have many State variables storing session state in a single Blocks app. diff --git a/test/test_blocks.py b/test/test_blocks.py index 7e8b762a804e..c9aba9449b64 100644 --- a/test/test_blocks.py +++ b/test/test_blocks.py @@ -232,7 +232,7 @@ def test_slider_random_value_config(): def test_io_components_attach_load_events_when_value_is_fn(io_components): - + io_components = [comp for comp in io_components if not (comp == gr.State)] interface = gr.Interface( lambda *args: None, inputs=[comp(value=lambda: None) for comp in io_components], @@ -247,7 +247,7 @@ def test_io_components_attach_load_events_when_value_is_fn(io_components): def test_blocks_do_not_filter_none_values_from_updates(io_components): - io_components = [c() for c in io_components if c not in [gr.Variable, gr.Button]] + io_components = [c() for c in io_components if c not in [gr.State, gr.Button]] with gr.Blocks() as demo: for component in io_components: component.render() diff --git a/test/test_components.py b/test/test_components.py index e1e16fccf04d..ca9383e0311b 100644 --- a/test/test_components.py +++ b/test/test_components.py @@ -1736,5 +1736,48 @@ def test_video_postprocess_converts_to_playable_format(test_file_dir): assert processing_utils.video_is_playable(str(full_path_to_output)) +class TestState: + def test_as_component(self): + state = gr.State(value=5) + assert state.preprocess(10) == 10 + assert state.preprocess("abc") == "abc" + assert state.stateful + + @pytest.mark.asyncio + async def test_in_interface(self): + def test(x, y=" def"): + return (x + y, x + y) + + io = gr.Interface(test, ["text", "state"], ["text", "state"]) + result = await io.call_function(0, ["abc"]) + assert result[0][0] == "abc def" + result = await io.call_function(0, ["abc", result[0][0]]) + assert result[0][0] == "abcabc def" + + @pytest.mark.asyncio + async def test_in_blocks(self): + with gr.Blocks() as demo: + score = gr.State() + btn = gr.Button() + btn.click(lambda x: x + 1, score, score) + + result = await demo.call_function(0, [0]) + assert result[0] == 1 + result = await demo.call_function(0, [result[0]]) + assert result[0] == 2 + + @pytest.mark.asyncio + async def test_variable_for_backwards_compatibility(self): + with gr.Blocks() as demo: + score = gr.Variable() + btn = gr.Button() + btn.click(lambda x: x + 1, score, score) + + result = await demo.call_function(0, [0]) + assert result[0] == 1 + result = await demo.call_function(0, [result[0]]) + assert result[0] == 2 + + if __name__ == "__main__": unittest.main() diff --git a/ui/packages/app/src/components/Variable/Variable.svelte b/ui/packages/app/src/components/State/State.svelte similarity index 100% rename from ui/packages/app/src/components/Variable/Variable.svelte rename to ui/packages/app/src/components/State/State.svelte diff --git a/ui/packages/app/src/components/State/index.ts b/ui/packages/app/src/components/State/index.ts new file mode 100644 index 000000000000..b40b2cd14fcb --- /dev/null +++ b/ui/packages/app/src/components/State/index.ts @@ -0,0 +1,2 @@ +export { default as Component } from "./State.svelte"; +export const modes = ["static"]; diff --git a/ui/packages/app/src/components/Variable/index.ts b/ui/packages/app/src/components/Variable/index.ts deleted file mode 100644 index 259f81f19185..000000000000 --- a/ui/packages/app/src/components/Variable/index.ts +++ /dev/null @@ -1,2 +0,0 @@ -export { default as Component } from "./Variable.svelte"; -export const modes = ["static"]; diff --git a/ui/packages/app/src/components/directory.ts b/ui/packages/app/src/components/directory.ts index 24e7132d592f..21754bf4751f 100644 --- a/ui/packages/app/src/components/directory.ts +++ b/ui/packages/app/src/components/directory.ts @@ -28,11 +28,11 @@ export const component_map = { radio: () => import("./Radio"), row: () => import("./Row"), slider: () => import("./Slider"), + state: () => import("./State"), statustracker: () => import("./StatusTracker"), tabs: () => import("./Tabs"), tabitem: () => import("./TabItem"), textbox: () => import("./Textbox"), timeseries: () => import("./TimeSeries"), - variable: () => import("./Variable"), video: () => import("./Video") };