Skip to content

Commit

Permalink
Fix preprocess for components when type='index' (#5563)
Browse files Browse the repository at this point in the history
* fix preprocess for components when type='index'

* add changeset

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
  • Loading branch information
abidlabs and gradio-pr-bot committed Sep 14, 2023
1 parent 50d9747 commit ba64082
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 13 deletions.
5 changes: 5 additions & 0 deletions .changeset/fluffy-sloths-help.md
@@ -0,0 +1,5 @@
---
"gradio": minor
---

feat:Fix preprocess for components when type='index'
10 changes: 8 additions & 2 deletions gradio/components/checkboxgroup.py
Expand Up @@ -152,7 +152,9 @@ def update(
"__type__": "update",
}

def preprocess(self, x: list[str | int | float]) -> list[str | int | float]:
def preprocess(
self, x: list[str | int | float]
) -> list[str | int | float] | list[int | None]:
"""
Parameters:
x: list of selected choices
Expand All @@ -162,7 +164,11 @@ def preprocess(self, x: list[str | int | float]) -> list[str | int | float]:
if self.type == "value":
return x
elif self.type == "index":
return [[value for _, value in self.choices].index(choice) for choice in x]
choice_values = [value for _, value in self.choices]
return [
choice_values.index(choice) if choice in choice_values else None
for choice in x
]
else:
raise ValueError(
f"Unknown type: {self.type}. Please choose from: 'value', 'index'."
Expand Down
16 changes: 7 additions & 9 deletions gradio/components/dropdown.py
Expand Up @@ -205,8 +205,8 @@ def update(
}

def preprocess(
self, x: str | list[str]
) -> str | int | list[str] | list[int] | None:
self, x: str | int | float | list[str | int | float] | None
) -> str | int | float | list[str | int | float] | list[int | None] | None:
"""
Parameters:
x: selected choice(s)
Expand All @@ -216,19 +216,17 @@ def preprocess(
if self.type == "value":
return x
elif self.type == "index":
choice_values = [value for _, value in self.choices]
if x is None:
return None
elif self.multiselect:
assert isinstance(x, list)
return [
[value for _, value in self.choices].index(choice) for choice in x
choice_values.index(choice) if choice in choice_values else None
for choice in x
]
else:
if isinstance(x, str):
return (
[value for _, value in self.choices].index(x)
if x in self.choices
else None
)
return choice_values.index(x) if x in choice_values else None
else:
raise ValueError(
f"Unknown type: {self.type}. Please choose from: 'value', 'index'."
Expand Down
3 changes: 2 additions & 1 deletion gradio/components/radio.py
Expand Up @@ -167,7 +167,8 @@ def preprocess(self, x: str | int | float | None) -> str | int | float | None:
if x is None:
return None
else:
return [value for _, value in self.choices].index(x)
choice_values = [value for _, value in self.choices]
return choice_values.index(x) if x in choice_values else None
else:
raise ValueError(
f"Unknown type: {self.type}. Please choose from: 'value', 'index'."
Expand Down
24 changes: 23 additions & 1 deletion test/test_components.py
Expand Up @@ -496,7 +496,13 @@ def test_component_functions(self):
checkboxes_input = gr.CheckboxGroup(["a", "b", "c"])
assert checkboxes_input.preprocess(["a", "c"]) == ["a", "c"]
assert checkboxes_input.postprocess(["a", "c"]) == ["a", "c"]
assert checkboxes_input.serialize(["a", "c"], True) == ["a", "c"]
assert checkboxes_input.serialize(["a", "c"]) == ["a", "c"]

checkboxes_input = gr.CheckboxGroup(["a", "b"], type="index")
assert checkboxes_input.preprocess(["a"]) == [0]
assert checkboxes_input.preprocess(["a", "b"]) == [0, 1]
assert checkboxes_input.preprocess(["a", "b", "c"]) == [0, 1, None]

checkboxes_input = gr.CheckboxGroup(
value=["a", "c"],
choices=["a", "b", "c"],
Expand Down Expand Up @@ -565,6 +571,12 @@ def test_component_functions(self):
"interactive": None,
"root_url": None,
}

radio = gr.Radio(choices=["a", "b"], type="index")
assert radio.preprocess("a") == 0
assert radio.preprocess("b") == 1
assert radio.preprocess("c") is None

with pytest.raises(ValueError):
gr.Radio(["a", "b"], type="unknown")

Expand Down Expand Up @@ -606,6 +618,16 @@ def test_component_functions(self):
assert dropdown_input.preprocess("c full") == "c full"
assert dropdown_input.postprocess("c full") == "c full"

dropdown = gr.Dropdown(choices=["a", "b"], type="index")
assert dropdown.preprocess("a") == 0
assert dropdown.preprocess("b") == 1
assert dropdown.preprocess("c") is None

dropdown = gr.Dropdown(choices=["a", "b"], type="index", multiselect=True)
assert dropdown.preprocess(["a"]) == [0]
assert dropdown.preprocess(["a", "b"]) == [0, 1]
assert dropdown.preprocess(["a", "b", "c"]) == [0, 1, None]

dropdown_input_multiselect = gr.Dropdown(["a", "b", ("c", "c full")])
assert dropdown_input_multiselect.preprocess(["a", "c full"]) == ["a", "c full"]
assert dropdown_input_multiselect.postprocess(["a", "c full"]) == [
Expand Down

0 comments on commit ba64082

Please sign in to comment.