
Commit

Merge pull request #55 from h2oai/fixes48
Fixes #48
pseudotensor committed Apr 20, 2023
2 parents 39a46bc + af5d5ac commit 0915f72
Showing 2 changed files with 49 additions and 29 deletions.
generate.py: 70 changes (41 additions & 29 deletions)
@@ -3,9 +3,11 @@
 import inspect
 import sys
 import os
+import traceback
 import typing
 
-from utils import set_seed, flatten_list
+from utils import set_seed, flatten_list, clear_torch_cache
 
 SEED = 1236
 set_seed(SEED)

@@ -176,10 +178,7 @@ def get_response(*args, exi=0):
         import matplotlib.pyplot as plt
 
         for exi, ex in enumerate(examples):
-            if torch.cuda.is_available:
-                torch.cuda.empty_cache()
-                torch.cuda.ipc_collect()
-                gc.collect()
+            clear_torch_cache()
             print("")
             print("START" + "=" * 100)
             print("Question: %s %s" % (ex[0], ('input=%s' % ex[1] if ex[1] else '')))
@@ -204,7 +203,13 @@ def get_response(*args, exi=0):
                                 return_tensors="pt",
                                 truncation=True,
                                 max_length=cutoff_len)
-            score = torch.sigmoid(smodel(**inputs).logits[0]).cpu().detach().numpy()[0]
+            try:
+                score = torch.sigmoid(smodel(**inputs).logits[0]).cpu().detach().numpy()[0]
+            except torch.cuda.OutOfMemoryError as e:
+                print("GPU OOM: question: %s answer: %s exception: %s" % (prompt, res, str(e)), flush=True)
+                traceback.print_exc()
+                score = 0.0
+                clear_torch_cache()
             print("SCORE %s: %s" % (exi, score), flush=True)
             score_dump.append(ex + [prompt, res, score])
             # dump every score in case abort
@@ -761,7 +766,14 @@ def score_last_response(*args):
                                 return_tensors="pt",
                                 truncation=True,
                                 max_length=max_length_tokenize).to(smodel.device)
-            score = torch.sigmoid(smodel(**inputs).logits[0]).cpu().detach().numpy()[0]
+            try:
+                score = torch.sigmoid(smodel(**inputs).logits[0]).cpu().detach().numpy()[0]
+            except torch.cuda.OutOfMemoryError as e:
+                print("GPU OOM: question: %s answer: %s exception: %s" % (question, answer, str(e)), flush=True)
+                del inputs
+                traceback.print_exc()
+                clear_torch_cache()
+                return 'Response Score: GPU OOM'
             os.environ['TOKENIZERS_PARALLELISM'] = 'true'
             return 'Response Score: {:.1%}'.format(score)
         else:
@@ -872,24 +884,24 @@ def bot(*args, retry=False):
         if kwargs['auto_score']:
             submit_event = instruction.submit(**user_args, queue=stream_output, api_name='instruction').then(
                 **bot_args, api_name='instruction_bot',
-            ).then(**score_args, api_name='instruction_bot_score')
+            ).then(**score_args, api_name='instruction_bot_score').then(clear_torch_cache)
             submit_event2 = submit.click(**user_args, queue=stream_output, api_name='submit').then(
                 **bot_args, api_name='submit_bot',
-            ).then(**score_args, api_name='submit_bot_score')
+            ).then(**score_args, api_name='submit_bot_score').then(clear_torch_cache)
             submit_event3 = retry.click(**user_args, queue=stream_output, api_name='retry').then(
                 **retry_bot_args, api_name='retry_bot',
-            ).then(**score_args, api_name='retry_bot_score')
+            ).then(**score_args, api_name='retry_bot_score').then(clear_torch_cache)
             submit_event4 = undo.click(**undo_user_args, queue=stream_output, api_name='undo').then(**score_args, api_name='undo_score')
         else:
             submit_event = instruction.submit(**user_args, queue=stream_output, api_name='instruction').then(
                 **bot_args, api_name='instruction_bot',
-            )
+            ).then(clear_torch_cache)
             submit_event2 = submit.click(**user_args, queue=stream_output, api_name='submit').then(
                 **bot_args, api_name='submit_bot',
-            )
+            ).then(clear_torch_cache)
             submit_event3 = retry.click(**user_args, queue=stream_output, api_name='retry').then(
                 **retry_bot_args, api_name='retry_bot',
-            )
+            ).then(clear_torch_cache)
             submit_event4 = undo.click(**undo_user_args, queue=stream_output, api_name='undo')
         clear.click(lambda: None, None, text_output, queue=False, api_name='clear')
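
The wiring above relies on Gradio's event chaining: each handler returns an event object with a .then() method, so the cache cleanup runs as a follow-up step after the bot (and optional scoring) handlers finish rather than inside them. A minimal, self-contained sketch of that pattern outside this repo (the echo handler and component names are hypothetical, not part of h2oGPT):

# Minimal sketch: chain a GPU-cache cleanup after a Gradio handler with .then().
# Assumes gradio and torch are installed; echo() and the component names are
# illustrative only.
import gc

import gradio as gr
import torch


def clear_torch_cache():
    # Same idea as the utils.clear_torch_cache helper added in this commit.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
    gc.collect()


def echo(message):
    return "You said: %s" % message


with gr.Blocks() as demo:
    box = gr.Textbox(label="Instruction")
    out = gr.Textbox(label="Response")
    btn = gr.Button("Submit")
    # Run the handler first, then the cleanup with no inputs/outputs,
    # so cache clearing happens only after the response is produced.
    btn.click(echo, inputs=box, outputs=out).then(clear_torch_cache)

if __name__ == "__main__":
    demo.launch()

Keeping the cleanup in the chain (instead of inside each handler) means it also runs when streaming output is cancelled partway through.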

@@ -915,11 +927,7 @@ def load_model(model_name, lora_weights, model_state_old, prompt_type_old):
                 del model_state_old[1]
                 model_state_old[1] = None
 
-            if torch.cuda.is_available:
-                torch.cuda.empty_cache()
-                torch.cuda.ipc_collect()
-                gc.collect()
-
+            clear_torch_cache()
             if kwargs['debug']:
                 print("Pre-switch post-del GPU memory: %s" % torch.cuda.memory_allocated(), flush=True)
             all_kwargs['base_model'] = model_name.strip()
@@ -931,10 +939,7 @@ def load_model(model_name, lora_weights, model_state_old, prompt_type_old):
 
             all_kwargs['lora_weights'] = lora_weights.strip()
             model1, tokenizer1, device1 = get_model(**all_kwargs)
-            if torch.cuda.is_available:
-                torch.cuda.empty_cache()
-                torch.cuda.ipc_collect()
-                gc.collect()
+            clear_torch_cache()
 
             if kwargs['debug']:
                 print("Post-switch GPU memory: %s" % torch.cuda.memory_allocated(), flush=True)
@@ -955,7 +960,7 @@ def chatbot_list(x, model_used_in):
         prompt_update_args = dict(fn=dropdown_prompt_type_list, inputs=prompt_type, outputs=prompt_type)
         chatbot_update_args = dict(fn=chatbot_list, inputs=[text_output, model_used], outputs=text_output)
         if not os.environ.get("HUGGINGFACE_SPACES"):
-            load_model_event = load_model_button.click(**load_model_args).then(**prompt_update_args).then(**chatbot_update_args)
+            load_model_event = load_model_button.click(**load_model_args).then(**prompt_update_args).then(**chatbot_update_args).then(clear_torch_cache)
 
         def dropdown_model_list(list0, x):
             new_state = [list0[0] + [x]]
@@ -984,10 +989,11 @@ def dropdown_lora_list(list0, x):
         flag_btn.click(lambda *args: callback.flag(args), inputs_list + [text_output], None, preprocess=False,
                        api_name='flag')
         if kwargs['chat']:
 
             # don't pass text_output, don't want to clear output, just stop it
+            # FIXME: have to click once to stop output and second time to stop GPUs going
             stop_btn.click(lambda: None, None, None, cancels=[submit_event, submit_event2, submit_event3],
-                           queue=False, api_name='stop')
+                           queue=False, api_name='stop').then(clear_torch_cache)
 
     demo.queue(concurrency_count=1)
     favicon_path = "h2o-logo.svg"
@@ -1082,10 +1088,7 @@ def evaluate(
         # try to free-up original tokenizer (i.e. list was passed as reference)
         if model_state0 is not None and model_state0[1] is not None:
             model_state0[1] = None
-        if torch.cuda.is_available:
-            torch.cuda.empty_cache()
-            torch.cuda.ipc_collect()
-            gc.collect()
+        clear_torch_cache()
         model, tokenizer, device, base_model = model_state
     elif model_state0 is not None and len(model_state0) == 4 and model_state0[0] is not None:
         assert isinstance(model_state[0], str)
@@ -1242,7 +1245,16 @@ def generate(callback=None, **kwargs):
             for stopping_criteria1 in stopping_criteria0:
                 kwargs['stopping_criteria'].append(stopping_criteria1)
 
-            model.generate(**kwargs)
+            try:
+                model.generate(**kwargs)
+            except torch.cuda.OutOfMemoryError as e:
+                print("GPU OOM: prompt: %s inputs_decoded: %s exception: %s" % (prompt, inputs_decoded, str(e)), flush=True)
+                if kwargs['input_ids'] is not None:
+                    kwargs['input_ids'].cpu()
+                kwargs['input_ids'] = None
+                traceback.print_exc()
+                clear_torch_cache()
+                return
 
         for output in CallbackToGenerator(generate, callback=None, **gen_kwargs):
             decoded_output = decoder(output)
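Taken together, the generate.py changes apply one pattern at every GPU-heavy call site (scoring and generation): catch torch.cuda.OutOfMemoryError, log it with a traceback, drop references where needed, clear the cache, and fall back instead of crashing the worker. A condensed sketch of that pattern, with a hypothetical run_model() standing in for the real smodel(**inputs) and model.generate(**kwargs) calls:

# Condensed sketch of the OOM-guard pattern used above; run_model and
# fallback are hypothetical stand-ins (the scoring path falls back to 0.0,
# the generation path simply returns).
import gc
import traceback

import torch


def clear_torch_cache():
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
    gc.collect()


def guarded_call(run_model, fallback=None):
    try:
        return run_model()
    except torch.cuda.OutOfMemoryError as e:
        # Keep the server alive: report, free what we can, return a fallback.
        print("GPU OOM: %s" % str(e), flush=True)
        traceback.print_exc()
        clear_torch_cache()
        return fallback


# Example usage mirroring the scoring path:
# score = guarded_call(lambda: float(torch.sigmoid(smodel(**inputs).logits[0])), fallback=0.0)
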
utils.py: 8 changes (8 additions & 0 deletions)
@@ -1,4 +1,5 @@
 import os
+import gc
 import random
 import numpy as np
 import torch
@@ -29,3 +30,10 @@ def flatten_list(lis):
         else:
             new_lis.append(item)
     return new_lis
+
+
+def clear_torch_cache():
+    if torch.cuda.is_available:
+        torch.cuda.empty_cache()
+        torch.cuda.ipc_collect()
+        gc.collect()
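
One caveat on the new helper: torch.cuda.is_available is a function, so the bare reference above is always truthy and the CUDA branch is never skipped. In practice this is mostly harmless, since empty_cache() returns early when CUDA has not been initialized, but it is not what the guard intends. A corrected sketch of the same helper:

# Corrected sketch of clear_torch_cache (not the committed version):
# is_available must be called, otherwise the function object is always truthy.
import gc

import torch


def clear_torch_cache():
    if torch.cuda.is_available():   # note the call parentheses
        torch.cuda.empty_cache()    # release unused blocks held by the caching allocator
        torch.cuda.ipc_collect()    # reclaim GPU memory freed via CUDA IPC
    gc.collect()                    # always run Python garbage collection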
