Skip to content

Commit

Permalink
pass in kwargs to Blocks.load() (#2669)
Browse files Browse the repository at this point in the history
* pass in kwargs to Blocks.load()

* added test

* changelog
  • Loading branch information
abidlabs committed Nov 18, 2022
1 parent 36f1951 commit 1c244cc
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 4 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
No changes to highlight.

## Bug Fixes:
No changes to highlight.
* Passes kwargs into `gr.Interface.load()` by [@abidlabs](https://github.com/abidlabs) in [PR 2669](https://github.com/gradio-app/gradio/pull/2669)

## Documentation Changes:
No changes to highlight.
Expand Down
2 changes: 1 addition & 1 deletion gradio/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def load(
demo = gr.Interface.load("models/EleutherAI/gpt-neo-1.3B", description=description, examples=examples)
demo.launch()
"""
return super().load(name=name, src=src, api_key=api_key, alias=alias)
return super().load(name=name, src=src, api_key=api_key, alias=alias, **kwargs)

@classmethod
def from_pipeline(cls, pipeline: transformers.Pipeline, **kwargs) -> Interface:
Expand Down
12 changes: 10 additions & 2 deletions test/test_external.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,18 +49,26 @@ def test_audio_to_audio(self):
def test_question_answering(self):
model_type = "image-classification"
interface = gr.Blocks.load(
name="lysandre/tiny-vit-random", src="models", alias=model_type
name="lysandre/tiny-vit-random",
src="models",
alias=model_type,
)
assert interface.__name__ == model_type
assert isinstance(interface.input_components[0], gr.components.Image)
assert isinstance(interface.output_components[0], gr.components.Label)

def test_text_generation(self):
model_type = "text_generation"
interface = gr.Interface.load("models/gpt2", alias=model_type)
interface = gr.Interface.load(
"models/gpt2", alias=model_type, description="This is a test description"
)
assert interface.__name__ == model_type
assert isinstance(interface.input_components[0], gr.components.Textbox)
assert isinstance(interface.output_components[0], gr.components.Textbox)
assert any(
"This is a test description" in d["props"].get("value", "")
for d in interface.get_config_file()["components"]
)

def test_summarization(self):
model_type = "summarization"
Expand Down

0 comments on commit 1c244cc

Please sign in to comment.