diff --git a/python/mlc_chat/gradio.py b/python/mlc_chat/gradio.py index 91c3bd6d2b..5975a8681d 100644 --- a/python/mlc_chat/gradio.py +++ b/python/mlc_chat/gradio.py @@ -30,6 +30,7 @@ def _parse_args(): is provided, it will be set to 0 by default.", ) args.add_argument("--port", type=int, default=7860, help="The port number to run gradio.") + args.add_argument("--host", type=str, default="127.0.0.1", help="The local host to run gradio.") args.add_argument( "--share", action="store_true", @@ -147,8 +148,7 @@ def gradio_stats(self): def launch_gradio( - artifact_path: str = "dist", device: str = "auto", port: int = 7860, share: bool = False -): + artifact_path: str = "dist", device: str = "auto", port: int = 7860, share: bool = False, host: str = "127.0.0.1"): r"""Launch the gradio interface with a given port, creating a publically sharable link if specified.""" # create a gradio module @@ -230,9 +230,9 @@ def launch_gradio( stats_button.click(mod.gradio_stats, [], [stats_output]) # launch to the web - demo.launch(share=share, enable_queue=True, server_port=port) + demo.launch(share=share, enable_queue=True, server_port=port,server_name=host) if __name__ == "__main__": ARGS = _parse_args() - launch_gradio(ARGS.artifact_path, ARGS.device, ARGS.port, ARGS.share) + launch_gradio(ARGS.artifact_path, ARGS.device, ARGS.port, ARGS.share, ARGS.host)