From af5d5ac55ecd833737852b5c33a1317800092583 Mon Sep 17 00:00:00 2001
From: "Jonathan C. McKinney"
Date: Wed, 19 Apr 2023 17:57:41 -0700
Subject: [PATCH] Fixes #48

---
 generate.py | 70 +++++++++++++++++++++++++++++++----------------------
 utils.py    |  8 ++++++
 2 files changed, 49 insertions(+), 29 deletions(-)

diff --git a/generate.py b/generate.py
index 070939b8f..dbf67d3f1 100644
--- a/generate.py
+++ b/generate.py
@@ -3,9 +3,11 @@
 import inspect
 import sys
 import os
+import traceback
 import typing
 
-from utils import set_seed, flatten_list
+from utils import set_seed, flatten_list, clear_torch_cache
+
 SEED = 1236
 set_seed(SEED)
 
@@ -176,10 +178,7 @@ def get_response(*args, exi=0):
         import matplotlib.pyplot as plt
 
         for exi, ex in enumerate(examples):
-            if torch.cuda.is_available:
-                torch.cuda.empty_cache()
-                torch.cuda.ipc_collect()
-                gc.collect()
+            clear_torch_cache()
             print("")
             print("START" + "=" * 100)
             print("Question: %s %s" % (ex[0], ('input=%s' % ex[1] if ex[1] else '')))
@@ -204,7 +203,13 @@ def get_response(*args, exi=0):
                                     return_tensors="pt",
                                     truncation=True,
                                     max_length=cutoff_len)
-                score = torch.sigmoid(smodel(**inputs).logits[0]).cpu().detach().numpy()[0]
+                try:
+                    score = torch.sigmoid(smodel(**inputs).logits[0]).cpu().detach().numpy()[0]
+                except torch.cuda.OutOfMemoryError as e:
+                    print("GPU OOM: question: %s answer: %s exception: %s" % (prompt, res, str(e)), flush=True)
+                    traceback.print_exc()
+                    score = 0.0
+                    clear_torch_cache()
                 print("SCORE %s: %s" % (exi, score), flush=True)
                 score_dump.append(ex + [prompt, res, score])
                 # dump every score in case abort
@@ -761,7 +766,14 @@ def score_last_response(*args):
                                 return_tensors="pt",
                                 truncation=True,
                                 max_length=max_length_tokenize).to(smodel.device)
-            score = torch.sigmoid(smodel(**inputs).logits[0]).cpu().detach().numpy()[0]
+            try:
+                score = torch.sigmoid(smodel(**inputs).logits[0]).cpu().detach().numpy()[0]
+            except torch.cuda.OutOfMemoryError as e:
+                print("GPU OOM: question: %s answer: %s exception: %s" % (question, answer, str(e)), flush=True)
+                del inputs
+                traceback.print_exc()
+                clear_torch_cache()
+                return 'Response Score: GPU OOM'
             os.environ['TOKENIZERS_PARALLELISM'] = 'true'
             return 'Response Score: {:.1%}'.format(score)
         else:
@@ -872,24 +884,24 @@ def bot(*args, retry=False):
         if kwargs['auto_score']:
             submit_event = instruction.submit(**user_args, queue=stream_output, api_name='instruction').then(
                 **bot_args, api_name='instruction_bot',
-            ).then(**score_args, api_name='instruction_bot_score')
+            ).then(**score_args, api_name='instruction_bot_score').then(clear_torch_cache)
             submit_event2 = submit.click(**user_args, queue=stream_output, api_name='submit').then(
                 **bot_args, api_name='submit_bot',
-            ).then(**score_args, api_name='submit_bot_score')
+            ).then(**score_args, api_name='submit_bot_score').then(clear_torch_cache)
             submit_event3 = retry.click(**user_args, queue=stream_output, api_name='retry').then(
                 **retry_bot_args, api_name='retry_bot',
-            ).then(**score_args, api_name='retry_bot_score')
+            ).then(**score_args, api_name='retry_bot_score').then(clear_torch_cache)
             submit_event4 = undo.click(**undo_user_args, queue=stream_output, api_name='undo').then(**score_args, api_name='undo_score')
         else:
             submit_event = instruction.submit(**user_args, queue=stream_output, api_name='instruction').then(
                 **bot_args, api_name='instruction_bot',
-            )
+            ).then(clear_torch_cache)
             submit_event2 = submit.click(**user_args, queue=stream_output, api_name='submit').then(
                 **bot_args, api_name='submit_bot',
-            )
+            ).then(clear_torch_cache)
             submit_event3 = retry.click(**user_args, queue=stream_output, api_name='retry').then(
                 **retry_bot_args, api_name='retry_bot',
-            )
+            ).then(clear_torch_cache)
             submit_event4 = undo.click(**undo_user_args, queue=stream_output, api_name='undo')
 
         clear.click(lambda: None, None, text_output, queue=False, api_name='clear')
@@ -915,11 +927,7 @@ def load_model(model_name, lora_weights, model_state_old, prompt_type_old):
                 del model_state_old[1]
                 model_state_old[1] = None
 
-            if torch.cuda.is_available:
-                torch.cuda.empty_cache()
-                torch.cuda.ipc_collect()
-                gc.collect()
-
+            clear_torch_cache()
             if kwargs['debug']:
                 print("Pre-switch post-del GPU memory: %s" % torch.cuda.memory_allocated(), flush=True)
             all_kwargs['base_model'] = model_name.strip()
@@ -931,10 +939,7 @@ def load_model(model_name, lora_weights, model_state_old, prompt_type_old):
             all_kwargs['lora_weights'] = lora_weights.strip()
             model1, tokenizer1, device1 = get_model(**all_kwargs)
 
-            if torch.cuda.is_available:
-                torch.cuda.empty_cache()
-                torch.cuda.ipc_collect()
-                gc.collect()
+            clear_torch_cache()
             if kwargs['debug']:
                 print("Post-switch GPU memory: %s" % torch.cuda.memory_allocated(), flush=True)
 
@@ -955,7 +960,7 @@ def chatbot_list(x, model_used_in):
         prompt_update_args = dict(fn=dropdown_prompt_type_list, inputs=prompt_type, outputs=prompt_type)
         chatbot_update_args = dict(fn=chatbot_list, inputs=[text_output, model_used], outputs=text_output)
         if not os.environ.get("HUGGINGFACE_SPACES"):
-            load_model_event = load_model_button.click(**load_model_args).then(**prompt_update_args).then(**chatbot_update_args)
+            load_model_event = load_model_button.click(**load_model_args).then(**prompt_update_args).then(**chatbot_update_args).then(clear_torch_cache)
 
         def dropdown_model_list(list0, x):
             new_state = [list0[0] + [x]]
@@ -984,10 +989,11 @@ def dropdown_lora_list(list0, x):
             flag_btn.click(lambda *args: callback.flag(args), inputs_list + [text_output], None,
                            preprocess=False, api_name='flag')
 
         if kwargs['chat']:
+            # don't pass text_output, don't want to clear output, just stop it
             # FIXME: have to click once to stop output and second time to stop GPUs going
             stop_btn.click(lambda: None, None, None, cancels=[submit_event, submit_event2, submit_event3],
-                           queue=False, api_name='stop')
+                           queue=False, api_name='stop').then(clear_torch_cache)
 
         demo.queue(concurrency_count=1)
         favicon_path = "h2o-logo.svg"
@@ -1082,10 +1088,7 @@ def evaluate(
         # try to free-up original tokenizer (i.e. list was passed as reference)
         if model_state0 is not None and model_state0[1] is not None:
             model_state0[1] = None
-        if torch.cuda.is_available:
-            torch.cuda.empty_cache()
-            torch.cuda.ipc_collect()
-            gc.collect()
+        clear_torch_cache()
         model, tokenizer, device, base_model = model_state
     elif model_state0 is not None and len(model_state0) == 4 and model_state0[0] is not None:
         assert isinstance(model_state[0], str)
@@ -1242,7 +1245,16 @@ def generate(callback=None, **kwargs):
             for stopping_criteria1 in stopping_criteria0:
                 kwargs['stopping_criteria'].append(stopping_criteria1)
 
-            model.generate(**kwargs)
+            try:
+                model.generate(**kwargs)
+            except torch.cuda.OutOfMemoryError as e:
+                print("GPU OOM: prompt: %s inputs_decoded: %s exception: %s" % (prompt, inputs_decoded, str(e)), flush=True)
+                if kwargs['input_ids'] is not None:
+                    kwargs['input_ids'].cpu()
+                    kwargs['input_ids'] = None
+                traceback.print_exc()
+                clear_torch_cache()
+                return
 
         for output in CallbackToGenerator(generate, callback=None, **gen_kwargs):
             decoded_output = decoder(output)
diff --git a/utils.py b/utils.py
index d26304a15..04d4498e8 100644
--- a/utils.py
+++ b/utils.py
@@ -1,4 +1,5 @@
 import os
+import gc
 import random
 import numpy as np
 import torch
@@ -29,3 +30,10 @@ def flatten_list(lis):
         else:
             new_lis.append(item)
     return new_lis
+
+
+def clear_torch_cache():
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+        torch.cuda.ipc_collect()
+        gc.collect()
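
Taken together, the patch applies one recovery pattern everywhere a CUDA call can
throw: catch torch.cuda.OutOfMemoryError (defined in recent PyTorch; older releases
raise a plain RuntimeError on OOM), log it, drop references to the offending tensors,
clear the allocator cache, and return a fallback value instead of killing the Gradio
worker. A minimal standalone sketch of that pattern follows; score_fn, guarded_score,
and fallback are illustrative names for this sketch, not identifiers from the repo:

    import gc
    import traceback

    import torch


    def clear_torch_cache():
        # Same utility added to utils.py: hand cached blocks back to the
        # CUDA driver, reclaim CUDA IPC memory, then run the Python GC.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
            gc.collect()


    def guarded_score(score_fn, *args, fallback=0.0):
        # Trade one failed request for a cleaned-up GPU rather than a dead
        # worker: log the traceback, clear the cache, report a fallback.
        try:
            return score_fn(*args)
        except torch.cuda.OutOfMemoryError:
            traceback.print_exc()
            clear_torch_cache()
            return fallback

Clearing the cache after OOM matters because PyTorch's caching allocator holds freed
blocks for reuse within the process; empty_cache() returns them to the driver so the
next request starts against a fresh pool, and ipc_collect() reclaims memory kept
alive for tensors shared with other processes.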