Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow gr.Templates to take in arguments #2600

Merged
merged 16 commits into from
Nov 4, 2022
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ No changes to highlight.

## Full Changelog:
* Add `api_name` to `Blocks.__call__` by [@freddyaboulton](https://github.com/freddyaboulton) in [PR 2593](https://github.com/gradio-app/gradio/pull/2593)
* Allow `gr.Templates` to accept parameters to override the defaults by [@abidlabs](https://github.com/abidlabs) in [PR 2600](https://github.com/gradio-app/gradio/pull/2600)


## Contributors Shoutout:
No changes to highlight.
Expand Down
2 changes: 1 addition & 1 deletion gradio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
Dropdown,
File,
Gallery,
Highlight,
Highlightedtext,
HighlightedText,
Image,
Expand Down Expand Up @@ -60,7 +61,6 @@
from gradio.routes import mount_gradio_app
from gradio.templates import (
Files,
Highlight,
ImageMask,
ImagePaint,
List,
Expand Down
1 change: 1 addition & 0 deletions gradio/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -3982,6 +3982,7 @@ def get_component_instance(comp: str | dict | Component, render=True) -> Compone

DataFrame = Dataframe
Highlightedtext = HighlightedText
Highlight = HighlightedText
Checkboxgroup = CheckboxGroup
TimeSeries = Timeseries
Json = JSON
88 changes: 50 additions & 38 deletions gradio/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ class Text(components.Textbox):
is_template = True

def __init__(self, **kwargs):
super().__init__(lines=1, **kwargs)
defaults = dict(lines=1)
defaults.update(kwargs)
super().__init__(**defaults)


class TextArea(components.Textbox):
Expand All @@ -20,73 +22,80 @@ class TextArea(components.Textbox):
is_template = True

def __init__(self, **kwargs):
super().__init__(lines=7, **kwargs)
defaults = dict(lines=7)
defaults.update(kwargs)
super().__init__(**defaults)


class Webcam(components.Image):
"""
Sets: source="webcam"
Sets: source="webcam", interactive=True
"""

is_template = True

def __init__(self, **kwargs):
super().__init__(source="webcam", interactive=True, **kwargs)
defaults = dict(source="webcam", interactive=True)
defaults.update(kwargs)
super().__init__(**defaults)


class Sketchpad(components.Image):
"""
Sets: image_mode="L", source="canvas", shape=(28, 28), invert_colors=True
Sets: image_mode="L", source="canvas", shape=(28, 28), invert_colors=True, interactive=True
"""

is_template = True

def __init__(self, **kwargs):
super().__init__(
defaults = dict(
image_mode="L",
source="canvas",
shape=(28, 28),
invert_colors=True,
interactive=True,
**kwargs
)
defaults.update(kwargs)
super().__init__(**defaults)


class Paint(components.Image):
"""
Sets: source="canvas", tool="color-sketch"
Sets: source="canvas", tool="color-sketch", interactive=True
"""

is_template = True

def __init__(self, **kwargs):
super().__init__(
source="canvas", tool="color-sketch", interactive=True, **kwargs
)
defaults = dict(source="canvas", tool="color-sketch", interactive=True)
defaults.update(kwargs)
super().__init__(**defaults)


class ImageMask(components.Image):
"""
Sets: source="canvas", tool="sketch"
Sets: source="upload", tool="sketch", interactive=True
"""

is_template = True

def __init__(self, **kwargs):
super().__init__(source="upload", tool="sketch", interactive=True, **kwargs)
defaults = dict(source="upload", tool="sketch", interactive=True)
defaults.update(kwargs)
super().__init__(**defaults)


class ImagePaint(components.Image):
"""
Sets: source="upload", tool="color-sketch"
Sets: source="upload", tool="color-sketch", interactive=True
"""

is_template = True

def __init__(self, **kwargs):
super().__init__(
source="upload", tool="color-sketch", interactive=True, **kwargs
)
defaults = dict(source="upload", tool="color-sketch", interactive=True)
defaults.update(kwargs)
super().__init__(**defaults)


class Pil(components.Image):
Expand All @@ -97,7 +106,9 @@ class Pil(components.Image):
is_template = True

def __init__(self, **kwargs):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wouldn't it be better if we specified the kwargs the class supports as opposed to accepting arbitrary kwargs?

Right now you could do gr.Pil(type="numpy") which is confusing/defeats the purpose of the template. Also means devs won't get hints from their editors and makes it easy to make a typo and only be aware at runtime.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah yeah you're right. I think it's okay to override the template params but you're right it would be good to specify them explicitly

super().__init__(type="pil", **kwargs)
defaults = dict(type="pil")
defaults.update(kwargs)
super().__init__(**defaults)


class PlayableVideo(components.Video):
Expand All @@ -108,7 +119,9 @@ class PlayableVideo(components.Video):
is_template = True

def __init__(self, **kwargs):
super().__init__(format="mp4", **kwargs)
defaults = dict(format="mp4")
defaults.update(kwargs)
super().__init__(**defaults)


class Microphone(components.Audio):
Expand All @@ -119,7 +132,9 @@ class Microphone(components.Audio):
is_template = True

def __init__(self, **kwargs):
super().__init__(source="microphone", **kwargs)
defaults = dict(source="microphone")
defaults.update(kwargs)
super().__init__(**defaults)


class Mic(components.Audio):
Expand All @@ -130,7 +145,9 @@ class Mic(components.Audio):
is_template = True

def __init__(self, **kwargs):
super().__init__(source="microphone", **kwargs)
defaults = dict(source="microphone")
defaults.update(kwargs)
super().__init__(**defaults)


class Files(components.File):
Expand All @@ -141,7 +158,9 @@ class Files(components.File):
is_template = True

def __init__(self, **kwargs):
super().__init__(file_count="multiple", **kwargs)
defaults = dict(file_count="multiple")
defaults.update(kwargs)
super().__init__(**defaults)


class Numpy(components.Dataframe):
Expand All @@ -152,7 +171,9 @@ class Numpy(components.Dataframe):
is_template = True

def __init__(self, **kwargs):
super().__init__(type="numpy", **kwargs)
defaults = dict(type="numpy")
defaults.update(kwargs)
super().__init__(**defaults)


class Matrix(components.Dataframe):
Expand All @@ -165,9 +186,10 @@ class Matrix(components.Dataframe):
def __init__(self, **kwargs):
"""
Custom component
@param kwargs:
"""
super().__init__(type="array", **kwargs)
defaults = dict(type="array")
defaults.update(kwargs)
super().__init__(**defaults)


class List(components.Dataframe):
Expand All @@ -180,17 +202,7 @@ class List(components.Dataframe):
def __init__(self, **kwargs):
"""
Custom component
@param kwargs:
"""
super().__init__(type="array", col_count=1, **kwargs)


class Highlight(components.HighlightedText):
is_template = True

def __init__(self, **kwargs):
"""
Custom component
@param kwargs:
"""
super().__init__(**kwargs)
defaults = dict(type="array", col_count=1)
defaults.update(kwargs)
super().__init__(**defaults)
11 changes: 11 additions & 0 deletions test/test_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,17 @@ def test_static(self):
component = gr.Textbox("abc")
self.assertEqual(component.get_config().get("value"), "abc")

def test_override_template(self):
"""
override template
"""
component = gr.TextArea(value="abc")
self.assertEqual(component.get_config().get("value"), "abc")
self.assertEqual(component.get_config().get("lines"), 7)
component = gr.TextArea(value="abc", lines=4)
self.assertEqual(component.get_config().get("value"), "abc")
self.assertEqual(component.get_config().get("lines"), 4)


class TestNumber(unittest.TestCase):
def test_component_functions(self):
Expand Down