From f4b4eb9ccfd604f124f25e97123bcdf8a48f04bd Mon Sep 17 00:00:00 2001
From: "Jonathan C. McKinney"
Date: Wed, 14 Feb 2024 17:10:30 -0800
Subject: [PATCH] Fixes #1324

---
 docs/FAQ.md                        | 10 ++++--
 reqs_optional/reqs_constraints.txt |  2 +-
 requirements.txt                   |  5 ++-
 src/gradio_runner.py               | 56 +++++++++++++++++++++---------
 4 files changed, 53 insertions(+), 20 deletions(-)

diff --git a/docs/FAQ.md b/docs/FAQ.md
index 7f3bae35f..5c8d18ab4 100644
--- a/docs/FAQ.md
+++ b/docs/FAQ.md
@@ -1,8 +1,14 @@
 ## Frequently asked questions
 
-### nginx and k8s multi-pod support
+### Gradio clean-up of states
+
+While Streamlit handles [callbacks for state clean-up](https://github.com/streamlit/streamlit/issues/6166), Gradio does [not](https://github.com/gradio-app/gradio/issues/4016) without h2oGPT-driven changes. So if you want browser/tab closure to trigger clean-up, `https://h2o-release.s3.amazonaws.com/h2ogpt/gradio-4.19.0-py3-none-any.whl` is required instead of the PyPI version. This also helps if you have many users of your app and want to ensure databases are cleaned up.
+
+To use it, uncomment `https://h2o-release.s3.amazonaws.com/h2ogpt/gradio-4.19.0-py3-none-any.whl` in `requirements.txt`.
 
-Gradio 4.18.0 fails to support nginx or other proxies, so we use 4.17.0 for now. For more information, see: https://github.com/gradio-app/gradio/issues/7391.
+This will clean up model states if you use the UI to load/unload models when not passing `--base_model` on the CLI (e.g. on Windows), so you don't have to worry about memory leaks when a browser tab is closed. It will also clean up Chroma database states.
+
+### nginx and k8s multi-pod support
 
 Gradio 4.x.y fails to support k8s multi-pod use, so for that case please use gradio 3.50.2 and gradio_client 0.6.1 by commenting in/out the relevant lines in `requirements.txt` and `reqs_optional/reqs_constraints.txt`, and comment out `gradio_pdf` in `reqs_optional/requirements_optional_langchain.txt`. For more information, see: https://github.com/gradio-app/gradio/issues/6920.
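The mechanism this FAQ entry relies on appears in the `src/gradio_runner.py` changes further below: the h2o-release wheel adds a `callback` keyword to `gr.State` that stock Gradio lacks, so the code feature-detects it. A minimal sketch of the pattern, mirroring the names used in the diff (the demo app here is hypothetical, not h2oGPT itself):

```python
import gradio as gr

def state_done(state):
    # invoked by the patched Gradio when a user's session ends
    # (e.g. browser/tab closed), so per-user resources can be freed
    print("cleaning up session state:", state)

# stock Gradio's gr.State has no `callback` kwarg; feature-detect it
support_state_callbacks = hasattr(gr.State(), 'callback')
state_cb = dict(callback=state_done) if support_state_callbacks else {}

with gr.Blocks() as demo:
    session_state = gr.State(value={}, **state_cb)
```

Because the `callback` kwarg is only passed when detected, the same code runs unchanged against the stock PyPI build, where clean-up simply never fires.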
diff --git a/reqs_optional/reqs_constraints.txt b/reqs_optional/reqs_constraints.txt
index df0ee0692..551caded2 100644
--- a/reqs_optional/reqs_constraints.txt
+++ b/reqs_optional/reqs_constraints.txt
@@ -1,6 +1,6 @@
 # ensure doesn't drift, e.g. https://github.com/h2oai/h2ogpt/issues/1348
 torch==2.1.2
-gradio==4.18.0
+gradio==4.19.0
 gradio_client==0.10.0
 #gradio==3.50.2
 #gradio_client==0.6.1
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 6011009fb..75a39b22a 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,7 +1,10 @@
 # for generate (gradio server) and finetune
 datasets==2.16.1
 sentencepiece==0.1.99
-https://gradio-builds.s3.amazonaws.com/9b8810ff9af4d9a50032752af09cefcf2ef7a7ac/gradio-4.18.0-py3-none-any.whl
+# no websockets, more cloud friendly
+gradio==4.19.0
+# able to make gradio clean up states
+# https://h2o-release.s3.amazonaws.com/h2ogpt/gradio-4.19.0-py3-none-any.whl
 #gradio==3.50.2
 sse_starlette==1.8.2
 huggingface_hub==0.19.4
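Note that both the PyPI build and the h2o-release wheel identify themselves as 4.19.0, so the version string alone cannot distinguish them; the same `hasattr` feature detection used in the `src/gradio_runner.py` diff below serves as a quick runtime check (a sketch, not part of the patch):

```python
import gradio as gr

print(gr.__version__)                    # 4.19.0 for either build
print(hasattr(gr.State(), 'callback'))   # True only with the h2o-release wheel
```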
diff --git a/src/gradio_runner.py b/src/gradio_runner.py
index 1de945306..e9fc9e864 100644
--- a/src/gradio_runner.py
+++ b/src/gradio_runner.py
@@ -804,19 +804,29 @@ def click_stop():
     have_vision_models = kwargs['inference_server'].startswith('http') and is_vision_model(kwargs['base_model'])
 
     with demo:
+        support_state_callbacks = hasattr(gr.State(), 'callback')
+
         # avoid actual model/tokenizer here or anything that would be bad to deepcopy
         # https://github.com/gradio-app/gradio/issues/3558
+        def model_state_done(state):
+            if isinstance(state, dict) and 'model' in state and hasattr(state['model'], 'cpu'):
+                state['model'].cpu()
+                state['model'] = None
+                clear_torch_cache()
+
+        model_state_cb = dict(callback=model_state_done) if support_state_callbacks else {}
         model_state = gr.State(
-            dict(model='model', tokenizer='tokenizer', device=kwargs['device'],
-                 base_model=kwargs['base_model'],
-                 tokenizer_base_model=kwargs['tokenizer_base_model'],
-                 lora_weights=kwargs['lora_weights'],
-                 inference_server=kwargs['inference_server'],
-                 prompt_type=kwargs['prompt_type'],
-                 prompt_dict=kwargs['prompt_dict'],
-                 visible_models=visible_models_to_model_choice(kwargs['visible_models']),
-                 h2ogpt_key=None,  # only apply at runtime when doing API call with gradio inference server
-            )
+            value=dict(model='model', tokenizer='tokenizer', device=kwargs['device'],
+                       base_model=kwargs['base_model'],
+                       tokenizer_base_model=kwargs['tokenizer_base_model'],
+                       lora_weights=kwargs['lora_weights'],
+                       inference_server=kwargs['inference_server'],
+                       prompt_type=kwargs['prompt_type'],
+                       prompt_dict=kwargs['prompt_dict'],
+                       visible_models=visible_models_to_model_choice(kwargs['visible_models']),
+                       h2ogpt_key=None,  # only apply at runtime when doing API call with gradio inference server
+                       ),
+            **model_state_cb,
         )
 
         def update_langchain_mode_paths(selection_docs_state1):
@@ -831,11 +841,25 @@ def update_langchain_mode_paths(selection_docs_state1):
             return selection_docs_state1
 
         # Setup some gradio states for per-user dynamic state
+        def my_db_state_done(state):
+            if isinstance(state, dict):
+                for langchain_mode_db, db_state in state.items():
+                    scratch_data = state[langchain_mode_db]
+                    if langchain_mode_db in langchain_modes_intrinsic:
+                        if len(scratch_data) == length_db1() and hasattr(scratch_data[0], 'delete_collection') and scratch_data[1] == scratch_data[2]:
+                            # scratch if not logged in
+                            scratch_data[0].delete_collection()
+                            # try to free from memory
+                            scratch_data[0] = None
+                            del scratch_data[0]
+
+        my_db_state_cb = dict(callback=my_db_state_done) if support_state_callbacks else {}
+
         model_state2 = gr.State(kwargs['model_state_none'].copy())
-        model_options_state = gr.State([model_options0])
+        model_options_state = gr.State([model_options0], **model_state_cb)
         lora_options_state = gr.State([lora_options])
         server_options_state = gr.State([server_options])
-        my_db_state = gr.State(my_db_state0)
+        my_db_state = gr.State(my_db_state0, **my_db_state_cb)
         chat_state = gr.State({})
         if kwargs['enable_tts'] and kwargs['tts_model'].startswith('tts_models/'):
             from src.tts_coqui import get_role_to_wave_map
@@ -1629,8 +1653,8 @@ def show_llava(x):
                                        info="Whether to pass JSON to and get JSON back from LLM",
                                        visible=True)
         metadata_in_context = gr.components.Textbox(value='[]',
-                                                     label="Metadata keys to include in LLM context (all, auto, or [key1, key2, ...] where strings are quoted)",
-                                                     visible=True)
+                                                    label="Metadata keys to include in LLM context (all, auto, or [key1, key2, ...] where strings are quoted)",
+                                                    visible=True)
 
         embed = gr.components.Checkbox(value=True,
                                        label="Embed text",
@@ -6010,8 +6034,8 @@ def stop_audio_func():
             load_func = user_state_setup
             load_inputs = [my_db_state, requests_state, login_btn, login_btn]
             load_outputs = [my_db_state, requests_state, login_btn]
-            #auth = None
-            #load_func, load_inputs, load_outputs = None, None, None
+            # auth = None
+            # load_func, load_inputs, load_outputs = None, None, None
 
         app_js = wrap_js_to_lambda(
             len(load_inputs) if load_inputs else 0,
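For reference, the model clean-up registered above amounts to the following teardown logic (a sketch: `clear_torch_cache` approximates h2oGPT's helper of that name, and torch is assumed installed):

```python
import gc
import torch

def clear_torch_cache():
    gc.collect()                   # drop unreachable Python objects first
    if torch.cuda.is_available():
        torch.cuda.empty_cache()   # then release cached CUDA blocks

def model_state_done(state):
    # runs when a session's model_state is torn down by the patched Gradio
    if isinstance(state, dict) and hasattr(state.get('model'), 'cpu'):
        state['model'].cpu()       # move weights off the GPU
        state['model'] = None      # drop the reference
        clear_torch_cache()        # so CUDA memory is actually returned
```

Moving the model to CPU before dropping the reference matters: otherwise the CUDA allocator can keep the freed blocks reserved, and `empty_cache()` alone would not return that memory to the system.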