Skip to content

Commit

Permalink
Fixes #1324
Browse files Browse the repository at this point in the history
  • Loading branch information
pseudotensor committed Feb 15, 2024
1 parent bd5721b commit f4b4eb9
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 20 deletions.
10 changes: 8 additions & 2 deletions docs/FAQ.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
## Frequently asked questions

### nginx and k8 multi-pod support
### Gradio clean-up of states

While Streamlit handles [callbacks for state clean-up](https://github.com/streamlit/streamlit/issues/6166), Gradio does [not](https://github.com/gradio-app/gradio/issues/4016) without h2oGPT-driven changes. So if you want browser/tab closure to trigger clean-up, `https://h2o-release.s3.amazonaws.com/h2ogpt/gradio-4.19.0-py3-none-any.whl` is required instead of the PyPI version. This also helps if many users are using your app and you want to ensure databases are cleaned up.

To use, uncomment `https://h2o-release.s3.amazonaws.com/h2ogpt/gradio-4.19.0-py3-none-any.whl` in `requirements.txt`.

Gradio 4.18.0 fails to support nginx or other proxies, so we use 4.17.0 for now. For more information, see: https://github.com/gradio-app/gradio/issues/7391.
This will clean up model states if the UI is used to load/unload models when not passing `--base_model` on the CLI (as on Windows), so one does not have to worry about memory leaks when a browser tab is closed. It will also clean up Chroma database states.

### nginx and k8 multi-pod support

Gradio 4.x.y fails to support k8 multi-pod use, so for that case please use gradio 3.50.2 and gradio_client 0.6.1 by commenting-in/out relevant lines in `requirements.txt`, `reqs_optional/reqs_constraints.txt`, and comment-out `gradio_pdf` in `reqs_optional/requirements_optional_langchain.txt`. For more information, see: https://github.com/gradio-app/gradio/issues/6920.

Expand Down
2 changes: 1 addition & 1 deletion reqs_optional/reqs_constraints.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# ensure doesn't drift, e.g. https://github.com/h2oai/h2ogpt/issues/1348
torch==2.1.2
gradio==4.18.0
gradio==4.19.0
gradio_client==0.10.0
#gradio==3.50.2
#gradio_client==0.6.1
5 changes: 4 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
# for generate (gradio server) and finetune
datasets==2.16.1
sentencepiece==0.1.99
https://gradio-builds.s3.amazonaws.com/9b8810ff9af4d9a50032752af09cefcf2ef7a7ac/gradio-4.18.0-py3-none-any.whl
# no websockets, more cloud friendly
gradio==4.19.0
# able to make gradio clean-up states
# https://h2o-release.s3.amazonaws.com/h2ogpt/gradio-4.19.0-py3-none-any.whl
#gradio==3.50.2
sse_starlette==1.8.2
huggingface_hub==0.19.4
Expand Down
56 changes: 40 additions & 16 deletions src/gradio_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -804,19 +804,29 @@ def click_stop():
have_vision_models = kwargs['inference_server'].startswith('http') and is_vision_model(kwargs['base_model'])

with demo:
support_state_callbacks = hasattr(gr.State(), 'callback')

# avoid actual model/tokenizer here or anything that would be bad to deepcopy
# https://github.com/gradio-app/gradio/issues/3558
def model_state_done(state):
    """Gradio state-deletion callback: release a session's model.

    When the per-session model state holds a real model object (anything
    exposing ``.cpu()``), move it off the GPU, drop the reference, and
    clear the torch CUDA cache so closed browser sessions free VRAM.
    Non-dict or placeholder states are ignored.
    """
    if not isinstance(state, dict):
        return
    model = state.get('model')
    if hasattr(model, 'cpu'):
        model.cpu()
        state['model'] = None
        clear_torch_cache()

model_state_cb = dict(callback=model_state_done) if support_state_callbacks else {}
model_state = gr.State(
dict(model='model', tokenizer='tokenizer', device=kwargs['device'],
base_model=kwargs['base_model'],
tokenizer_base_model=kwargs['tokenizer_base_model'],
lora_weights=kwargs['lora_weights'],
inference_server=kwargs['inference_server'],
prompt_type=kwargs['prompt_type'],
prompt_dict=kwargs['prompt_dict'],
visible_models=visible_models_to_model_choice(kwargs['visible_models']),
h2ogpt_key=None, # only apply at runtime when doing API call with gradio inference server
)
value=dict(model='model', tokenizer='tokenizer', device=kwargs['device'],
base_model=kwargs['base_model'],
tokenizer_base_model=kwargs['tokenizer_base_model'],
lora_weights=kwargs['lora_weights'],
inference_server=kwargs['inference_server'],
prompt_type=kwargs['prompt_type'],
prompt_dict=kwargs['prompt_dict'],
visible_models=visible_models_to_model_choice(kwargs['visible_models']),
h2ogpt_key=None, # only apply at runtime when doing API call with gradio inference server
),
**model_state_cb,
)

def update_langchain_mode_paths(selection_docs_state1):
Expand All @@ -831,11 +841,25 @@ def update_langchain_mode_paths(selection_docs_state1):
return selection_docs_state1

# Setup some gradio states for per-user dynamic state
def my_db_state_done(state):
    """Gradio state-deletion callback: drop per-session scratch databases.

    For each intrinsic langchain mode in the per-user db state, delete the
    scratch collection of sessions that were not logged in (detected by
    slots 1 and 2 of the db tuple comparing equal — TODO confirm these are
    user hashes) and drop the in-memory reference so the database object
    can be garbage collected.
    """
    if not isinstance(state, dict):
        return
    # iterate the values directly instead of re-indexing state[key] each pass
    for langchain_mode_db, db_state in state.items():
        if langchain_mode_db not in langchain_modes_intrinsic:
            continue
        if len(db_state) == length_db1() and \
                hasattr(db_state[0], 'delete_collection') and \
                db_state[1] == db_state[2]:
            # scratch if not logged in
            db_state[0].delete_collection()
            # try to free from memory
            db_state[0] = None
            del db_state[0]

my_db_state_cb = dict(callback=my_db_state_done) if support_state_callbacks else {}

model_state2 = gr.State(kwargs['model_state_none'].copy())
model_options_state = gr.State([model_options0])
model_options_state = gr.State([model_options0], **model_state_cb)
lora_options_state = gr.State([lora_options])
server_options_state = gr.State([server_options])
my_db_state = gr.State(my_db_state0)
my_db_state = gr.State(my_db_state0, **my_db_state_cb)
chat_state = gr.State({})
if kwargs['enable_tts'] and kwargs['tts_model'].startswith('tts_models/'):
from src.tts_coqui import get_role_to_wave_map
Expand Down Expand Up @@ -1629,8 +1653,8 @@ def show_llava(x):
info="Whether to pass JSON to and get JSON back from LLM",
visible=True)
metadata_in_context = gr.components.Textbox(value='[]',
label="Metadata keys to include in LLM context (all, auto, or [key1, key2, ...] where strings are quoted)",
visible=True)
label="Metadata keys to include in LLM context (all, auto, or [key1, key2, ...] where strings are quoted)",
visible=True)

embed = gr.components.Checkbox(value=True,
label="Embed text",
Expand Down Expand Up @@ -6010,8 +6034,8 @@ def stop_audio_func():
load_func = user_state_setup
load_inputs = [my_db_state, requests_state, login_btn, login_btn]
load_outputs = [my_db_state, requests_state, login_btn]
#auth = None
#load_func, load_inputs, load_outputs = None, None, None
# auth = None
# load_func, load_inputs, load_outputs = None, None, None

app_js = wrap_js_to_lambda(
len(load_inputs) if load_inputs else 0,
Expand Down

0 comments on commit f4b4eb9

Please sign in to comment.