Skip to content

Commit

Permalink
Improvements to State (#2100)
Browse files Browse the repository at this point in the history
* state

* state fix

* variable -> state

* fix

* added state tests

* formatting

* fix test

* formatting

* fix test

* added tests for bakcward compatibility

* formatting

* config fix

* additional doc

* doc fix

* formatting
  • Loading branch information
abidlabs committed Aug 29, 2022
1 parent 1ad5878 commit 4d58ae7
Show file tree
Hide file tree
Showing 20 changed files with 113 additions and 46 deletions.
4 changes: 2 additions & 2 deletions demo/blocks_flashcards/run.py
Expand Up @@ -21,15 +21,15 @@
flip_btn = gr.Button("Flip Card").style(full_width=True)
with gr.Column(visible=False) as answer_col:
back = gr.Textbox(label="Answer")
selected_card = gr.Variable()
selected_card = gr.State()
with gr.Row():
correct_btn = gr.Button(
"Correct",
).style(full_width=True)
incorrect_btn = gr.Button("Incorrect").style(full_width=True)

with gr.TabItem("Results"):
results = gr.Variable(value={})
results = gr.State(value={})
correct_field = gr.Markdown("# Correct: 0")
incorrect_field = gr.Markdown("# Incorrect: 0")
gr.Markdown("Card Statistics: ")
Expand Down
4 changes: 2 additions & 2 deletions demo/blocks_simple_squares/run.py
Expand Up @@ -5,11 +5,11 @@
with demo:
default_json = {"a": "a"}

num = gr.Variable(value=0)
num = gr.State(value=0)
squared = gr.Number(value=0)
btn = gr.Button("Next Square", elem_id="btn").style(rounded=False)

stats = gr.Variable(value=default_json)
stats = gr.State(value=default_json)
table = gr.JSON()

def increase(var, stats_history):
Expand Down
4 changes: 2 additions & 2 deletions demo/components_demos/run.py
Expand Up @@ -42,8 +42,8 @@
with gr.Blocks() as Timeseries_demo:
gr.Timeseries()

with gr.Blocks() as Variable_demo:
gr.Variable()
with gr.Blocks() as State_demo:
gr.State()

with gr.Blocks() as Button_demo:
gr.Button()
Expand Down
2 changes: 1 addition & 1 deletion demo/hangman/run.py
Expand Up @@ -4,7 +4,7 @@
secret_word = "gradio"

with gr.Blocks() as demo:
used_letters_var = gr.Variable([])
used_letters_var = gr.State([])
with gr.Row() as row:
with gr.Column():
input_letter = gr.Textbox(label="Enter letter")
Expand Down
8 changes: 5 additions & 3 deletions demo/kitchen_sink_random/run.py
Expand Up @@ -17,7 +17,7 @@


demo = gr.Interface(
lambda x: x,
lambda *args: args[0],
inputs=[
gr.Textbox(value=lambda: datetime.now(), label="Current Time"),
gr.Number(value=lambda: random.random(), label="Ranom Percentage"),
Expand Down Expand Up @@ -60,7 +60,7 @@
)
),
gr.Timeseries(value=lambda: os.path.join(file_dir, "time.csv")),
gr.Variable(value=lambda: random.choice(string.ascii_lowercase)),
gr.State(value=lambda: random.choice(string.ascii_lowercase)),
gr.Button(value=lambda: random.choice(["Run", "Go", "predict"])),
gr.ColorPicker(value=lambda: random.choice(["#000000", "#ff0000", "#0000FF"])),
gr.Label(value=lambda: random.choice(["Pedestrian", "Car", "Cyclist"])),
Expand Down Expand Up @@ -91,7 +91,9 @@
gr.Plot(value=random_plot),
gr.Markdown(value=lambda: f"### {random.choice(['Hello', 'Hi', 'Goodbye!'])}"),
],
outputs=None,
outputs=[
gr.State(value=lambda: random.choice(string.ascii_lowercase))
],
)

if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion demo/stream_audio/run.py
Expand Up @@ -4,7 +4,7 @@
with gr.Blocks() as demo:
inp = gr.Audio(source="microphone")
out = gr.Audio()
stream = gr.Variable()
stream = gr.State()

def add_to_stream(audio, instream):
if audio is None:
Expand Down
1 change: 1 addition & 0 deletions gradio/__init__.py
Expand Up @@ -35,6 +35,7 @@
Plot,
Radio,
Slider,
State,
StatusTracker,
Textbox,
TimeSeries,
Expand Down
14 changes: 12 additions & 2 deletions gradio/components.py
Expand Up @@ -2541,10 +2541,10 @@ def style(


@document()
class Variable(IOComponent, SimpleSerializable):
class State(IOComponent, SimpleSerializable):
"""
Special hidden component that stores session state across runs of the demo by the
same user. The value of the Variable is cleared when the user refreshes the page.
same user. The value of the State variable is cleared when the user refreshes the page.
Preprocessing: No preprocessing is performed
Postprocessing: No postprocessing is performed
Expand All @@ -2570,6 +2570,16 @@ def style(self):
return self


class Variable(State):
"""Variable was renamed to State. This class is kept for backwards compatibility."""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def get_block_name(self):
return "state"


@document("click", "style")
class Button(Clickable, IOComponent, SimpleSerializable):
"""
Expand Down
4 changes: 2 additions & 2 deletions gradio/inputs.py
Expand Up @@ -427,7 +427,7 @@ def __init__(
super().__init__(x=x, y=y, label=label, optional=optional)


class State(components.Variable):
class State(components.State):
"""
Special hidden component that stores state across runs of the interface.
Input type: Any
Expand All @@ -445,7 +445,7 @@ def __init__(
optional (bool): this parameter is ignored.
"""
warnings.warn(
"Usage of gradio.inputs is deprecated, and will not be supported in the future, please import this component as gr.Variable from gradio.components",
"Usage of gradio.inputs is deprecated, and will not be supported in the future, please import this component as gr.State() from gradio.components",
)
super().__init__(value=default, label=label)

Expand Down
45 changes: 28 additions & 17 deletions gradio/interface.py
Expand Up @@ -27,8 +27,8 @@
Interpretation,
IOComponent,
Markdown,
State,
StatusTracker,
Variable,
get_component_instance,
)
from gradio.documentation import document, set_documentation_group
Expand Down Expand Up @@ -213,17 +213,30 @@ def __init__(
else:
self.cache_examples = cache_examples or False

if "state" in inputs or "state" in outputs:
state_input_count = len([i for i in inputs if i == "state"])
state_output_count = len([o for o in outputs if o == "state"])
if state_input_count != 1 or state_output_count != 1:
raise ValueError(
"If using 'state', there must be exactly one state input and one state output."
)
default = utils.get_default_args(fn)[inputs.index("state")]
state_variable = Variable(value=default)
inputs[inputs.index("state")] = state_variable
outputs[outputs.index("state")] = state_variable
state_input_indexes = [
idx for idx, i in enumerate(inputs) if i == "state" or isinstance(i, State)
]
state_output_indexes = [
idx for idx, o in enumerate(outputs) if o == "state" or isinstance(o, State)
]

if len(state_input_indexes) == 0 and len(state_output_indexes) == 0:
pass
elif len(state_input_indexes) != 1 or len(state_output_indexes) != 1:
raise ValueError(
"If using 'state', there must be exactly one state input and one state output."
)
else:
state_input_index = state_input_indexes[0]
state_output_index = state_output_indexes[0]
if inputs[state_input_index] == "state":
default = utils.get_default_args(fn)[state_input_index]
state_variable = State(value=default)
else:
state_variable = inputs[state_input_index]

inputs[state_input_index] = state_variable
outputs[state_output_index] = state_variable

if cache_examples:
warnings.warn(
Expand All @@ -240,9 +253,7 @@ def __init__(
]

for component in self.input_components + self.output_components:
if not (
isinstance(component, IOComponent) or isinstance(component, Variable)
):
if not (isinstance(component, IOComponent)):
raise ValueError(
f"{component} is not a valid input/output component for Interface."
)
Expand Down Expand Up @@ -607,10 +618,10 @@ def __call__(self, *flag_data):

if self.examples:
non_state_inputs = [
c for c in self.input_components if not isinstance(c, Variable)
c for c in self.input_components if not isinstance(c, State)
]
non_state_outputs = [
c for c in self.output_components if not isinstance(c, Variable)
c for c in self.output_components if not isinstance(c, State)
]
self.examples_handler = Examples(
examples=examples,
Expand Down
4 changes: 2 additions & 2 deletions gradio/outputs.py
Expand Up @@ -158,7 +158,7 @@ def __init__(
super().__init__(x=x, y=y, label=label)


class State(components.Variable):
class State(components.State):
"""
Special hidden component that stores state across runs of the interface.
Output type: Any
Expand All @@ -170,7 +170,7 @@ def __init__(self, label: Optional[str] = None):
label (str): component name in interface (not used).
"""
warnings.warn(
"Usage of gradio.outputs is deprecated, and will not be supported in the future, please import your components from gradio.components",
"Usage of gradio.outputs is deprecated, and will not be supported in the future, please import this component as gr.State() from gradio.components",
)
super().__init__(label=label)

Expand Down
2 changes: 1 addition & 1 deletion gradio/utils.py
Expand Up @@ -198,7 +198,7 @@ def launch_counter() -> None:
pass


def get_default_args(func: Callable) -> Dict[str, Any]:
def get_default_args(func: Callable) -> List[Any]:
signature = inspect.signature(func)
return [
v.default if v.default is not inspect.Parameter.empty else None
Expand Down
2 changes: 1 addition & 1 deletion guides/2)building_interfaces/1)interface_state.md
Expand Up @@ -23,4 +23,4 @@ $demo_chatbot_demo

Notice how the state persists across submits within each page, but if you load this demo in another tab (or refresh the page), the demos will not share chat history.

The default value of `state` is None. If you pass a default value to the state parameter of the function, it is used as the default value of the state instead.
The default value of `state` is None. If you pass a default value to the state parameter of the function, it is used as the default value of the state instead. The `Interface` class only supports a single input and outputs state variable, though it can be a list with multiple elements. For more complex use cases, you can use Blocks, [which supports multiple `State` variables](/state_in_blocks/).
10 changes: 5 additions & 5 deletions guides/3)building_with_blocks/3)state_in_blocks.md
Expand Up @@ -8,8 +8,8 @@ Global state in Blocks works the same as in Interface. Any variable created outs

Gradio supports session **state**, where data persists across multiple submits within a page session, in Blocks apps as well. To reiterate, session data is *not* shared between different users of your model. To store data in a session state, you need to do three things:

1. Create a `gr.Variable()` object. If there is a default value to this stateful object, pass that into the constructor.
2. In the event listener, put the `Variable` object as an input and output.
1. Create a `gr.State()` object. If there is a default value to this stateful object, pass that into the constructor.
2. In the event listener, put the `State` object as an input and output.
3. In the event listener function, add the variable to the input parameters and the return value.

Let's take a look at a game of hangman.
Expand All @@ -19,11 +19,11 @@ $demo_hangman

Let's see how we do each of the 3 steps listed above in this game:

1. We store the used letters in `used_letters_var`. In the constructor of `Variable`, we set the initial value of this to `[]`, an empty list.
1. We store the used letters in `used_letters_var`. In the constructor of `State`, we set the initial value of this to `[]`, an empty list.
2. In `btn.click()`, we have a reference to `used_letters_var` in both the inputs and outputs.
3. In `guess_letter`, we pass the value of this `Variable` to `used_letters`, and then return an updated value of this `Variable` in the return statement.
3. In `guess_letter`, we pass the value of this `State` to `used_letters`, and then return an updated value of this `State` in the return statement.

With more complex apps, you will likely have many Variables storing session state in a single Blocks app.
With more complex apps, you will likely have many State variables storing session state in a single Blocks app.



4 changes: 2 additions & 2 deletions test/test_blocks.py
Expand Up @@ -232,7 +232,7 @@ def test_slider_random_value_config():


def test_io_components_attach_load_events_when_value_is_fn(io_components):

io_components = [comp for comp in io_components if not (comp == gr.State)]
interface = gr.Interface(
lambda *args: None,
inputs=[comp(value=lambda: None) for comp in io_components],
Expand All @@ -247,7 +247,7 @@ def test_io_components_attach_load_events_when_value_is_fn(io_components):

def test_blocks_do_not_filter_none_values_from_updates(io_components):

io_components = [c() for c in io_components if c not in [gr.Variable, gr.Button]]
io_components = [c() for c in io_components if c not in [gr.State, gr.Button]]
with gr.Blocks() as demo:
for component in io_components:
component.render()
Expand Down
43 changes: 43 additions & 0 deletions test/test_components.py
Expand Up @@ -1736,5 +1736,48 @@ def test_video_postprocess_converts_to_playable_format(test_file_dir):
assert processing_utils.video_is_playable(str(full_path_to_output))


class TestState:
def test_as_component(self):
state = gr.State(value=5)
assert state.preprocess(10) == 10
assert state.preprocess("abc") == "abc"
assert state.stateful

@pytest.mark.asyncio
async def test_in_interface(self):
def test(x, y=" def"):
return (x + y, x + y)

io = gr.Interface(test, ["text", "state"], ["text", "state"])
result = await io.call_function(0, ["abc"])
assert result[0][0] == "abc def"
result = await io.call_function(0, ["abc", result[0][0]])
assert result[0][0] == "abcabc def"

@pytest.mark.asyncio
async def test_in_blocks(self):
with gr.Blocks() as demo:
score = gr.State()
btn = gr.Button()
btn.click(lambda x: x + 1, score, score)

result = await demo.call_function(0, [0])
assert result[0] == 1
result = await demo.call_function(0, [result[0]])
assert result[0] == 2

@pytest.mark.asyncio
async def test_variable_for_backwards_compatibility(self):
with gr.Blocks() as demo:
score = gr.Variable()
btn = gr.Button()
btn.click(lambda x: x + 1, score, score)

result = await demo.call_function(0, [0])
assert result[0] == 1
result = await demo.call_function(0, [result[0]])
assert result[0] == 2


if __name__ == "__main__":
unittest.main()
2 changes: 2 additions & 0 deletions ui/packages/app/src/components/State/index.ts
@@ -0,0 +1,2 @@
export { default as Component } from "./State.svelte";
export const modes = ["static"];
2 changes: 0 additions & 2 deletions ui/packages/app/src/components/Variable/index.ts

This file was deleted.

2 changes: 1 addition & 1 deletion ui/packages/app/src/components/directory.ts
Expand Up @@ -28,11 +28,11 @@ export const component_map = {
radio: () => import("./Radio"),
row: () => import("./Row"),
slider: () => import("./Slider"),
state: () => import("./State"),
statustracker: () => import("./StatusTracker"),
tabs: () => import("./Tabs"),
tabitem: () => import("./TabItem"),
textbox: () => import("./Textbox"),
timeseries: () => import("./TimeSeries"),
variable: () => import("./Variable"),
video: () => import("./Video")
};

0 comments on commit 4d58ae7

Please sign in to comment.