Fixes #48 #55

Merged: 1 commit, Apr 20, 2023
generate.py: 70 changes (41 additions & 29 deletions)
@@ -3,9 +3,11 @@
import inspect
import sys
import os
import traceback
import typing

from utils import set_seed, flatten_list
from utils import set_seed, flatten_list, clear_torch_cache

SEED = 1236
set_seed(SEED)

@@ -176,10 +178,7 @@ def get_response(*args, exi=0):
import matplotlib.pyplot as plt

for exi, ex in enumerate(examples):
if torch.cuda.is_available:
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
gc.collect()
clear_torch_cache()
print("")
print("START" + "=" * 100)
print("Question: %s %s" % (ex[0], ('input=%s' % ex[1] if ex[1] else '')))
@@ -204,7 +203,13 @@ def get_response(*args, exi=0):
return_tensors="pt",
truncation=True,
max_length=cutoff_len)
score = torch.sigmoid(smodel(**inputs).logits[0]).cpu().detach().numpy()[0]
try:
score = torch.sigmoid(smodel(**inputs).logits[0]).cpu().detach().numpy()[0]
except torch.cuda.OutOfMemoryError as e:
print("GPU OOM: question: %s answer: %s exception: %s" % (prompt, res, str(e)), flush=True)
traceback.print_exc()
score = 0.0
clear_torch_cache()
print("SCORE %s: %s" % (exi, score), flush=True)
score_dump.append(ex + [prompt, res, score])
# dump every score in case abort
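
The example-scoring loop above now guards the reward-model forward pass: a CUDA out-of-memory error is logged, the example is scored 0.0, and cached GPU memory is freed instead of the whole run aborting. A standalone sketch of the same pattern, not part of the diff (smodel, tokenizer, and clear_torch_cache are the repo's names; safe_score and its arguments are illustrative):

import traceback
import torch
from utils import clear_torch_cache  # helper added in utils.py below

def safe_score(smodel, tokenizer, prompt, response, cutoff_len=2048):
    """Score a (prompt, response) pair, falling back to 0.0 on GPU OOM."""
    inputs = tokenizer(prompt, response,
                       return_tensors="pt",
                       truncation=True,
                       max_length=cutoff_len).to(smodel.device)
    try:
        score = torch.sigmoid(smodel(**inputs).logits[0]).cpu().detach().numpy()[0]
    except torch.cuda.OutOfMemoryError as e:
        # log the failure, release cached GPU memory, and keep the loop going
        print("GPU OOM: question: %s answer: %s exception: %s" % (prompt, response, str(e)), flush=True)
        traceback.print_exc()
        score = 0.0
        clear_torch_cache()
    return score
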
@@ -761,7 +766,14 @@ def score_last_response(*args):
return_tensors="pt",
truncation=True,
max_length=max_length_tokenize).to(smodel.device)
score = torch.sigmoid(smodel(**inputs).logits[0]).cpu().detach().numpy()[0]
try:
score = torch.sigmoid(smodel(**inputs).logits[0]).cpu().detach().numpy()[0]
except torch.cuda.OutOfMemoryError as e:
print("GPU OOM: question: %s answer: %s exception: %s" % (question, answer, str(e)), flush=True)
del inputs
traceback.print_exc()
clear_torch_cache()
return 'Response Score: GPU OOM'
os.environ['TOKENIZERS_PARALLELISM'] = 'true'
return 'Response Score: {:.1%}'.format(score)
else:
@@ -872,24 +884,24 @@ def bot(*args, retry=False):
if kwargs['auto_score']:
submit_event = instruction.submit(**user_args, queue=stream_output, api_name='instruction').then(
**bot_args, api_name='instruction_bot',
).then(**score_args, api_name='instruction_bot_score')
).then(**score_args, api_name='instruction_bot_score').then(clear_torch_cache)
submit_event2 = submit.click(**user_args, queue=stream_output, api_name='submit').then(
**bot_args, api_name='submit_bot',
).then(**score_args, api_name='submit_bot_score')
).then(**score_args, api_name='submit_bot_score').then(clear_torch_cache)
submit_event3 = retry.click(**user_args, queue=stream_output, api_name='retry').then(
**retry_bot_args, api_name='retry_bot',
).then(**score_args, api_name='retry_bot_score')
).then(**score_args, api_name='retry_bot_score').then(clear_torch_cache)
submit_event4 = undo.click(**undo_user_args, queue=stream_output, api_name='undo').then(**score_args, api_name='undo_score')
else:
submit_event = instruction.submit(**user_args, queue=stream_output, api_name='instruction').then(
**bot_args, api_name='instruction_bot',
)
).then(clear_torch_cache)
submit_event2 = submit.click(**user_args, queue=stream_output, api_name='submit').then(
**bot_args, api_name='submit_bot',
)
).then(clear_torch_cache)
submit_event3 = retry.click(**user_args, queue=stream_output, api_name='retry').then(
**retry_bot_args, api_name='retry_bot',
)
).then(clear_torch_cache)
submit_event4 = undo.click(**undo_user_args, queue=stream_output, api_name='undo')
clear.click(lambda: None, None, text_output, queue=False, api_name='clear')
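
Every chat event chain now ends with .then(clear_torch_cache), so cached GPU memory is released once the bot (and the optional scorer) has finished. A minimal Gradio sketch of that chaining pattern, not part of the diff (echo_bot and the component layout are placeholders; .then(clear_torch_cache) and queue(concurrency_count=1) follow the diff):

import gradio as gr
from utils import clear_torch_cache  # helper added in this PR

def echo_bot(message, history):
    # placeholder handler; the real bot streams tokens from the model
    return history + [(message, "echo: " + message)]

with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    # run the bot, then release cached GPU memory once the event completes
    msg.submit(echo_bot, [msg, chatbot], chatbot).then(clear_torch_cache)

demo.queue(concurrency_count=1)
demo.launch()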

@@ -915,11 +927,7 @@ def load_model(model_name, lora_weights, model_state_old, prompt_type_old):
del model_state_old[1]
model_state_old[1] = None

if torch.cuda.is_available:
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
gc.collect()

clear_torch_cache()
if kwargs['debug']:
print("Pre-switch post-del GPU memory: %s" % torch.cuda.memory_allocated(), flush=True)
all_kwargs['base_model'] = model_name.strip()
@@ -931,10 +939,7 @@ def load_model(model_name, lora_weights, model_state_old, prompt_type_old):

all_kwargs['lora_weights'] = lora_weights.strip()
model1, tokenizer1, device1 = get_model(**all_kwargs)
if torch.cuda.is_available:
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
gc.collect()
clear_torch_cache()

if kwargs['debug']:
print("Post-switch GPU memory: %s" % torch.cuda.memory_allocated(), flush=True)
@@ -955,7 +960,7 @@ def chatbot_list(x, model_used_in):
prompt_update_args = dict(fn=dropdown_prompt_type_list, inputs=prompt_type, outputs=prompt_type)
chatbot_update_args = dict(fn=chatbot_list, inputs=[text_output, model_used], outputs=text_output)
if not os.environ.get("HUGGINGFACE_SPACES"):
load_model_event = load_model_button.click(**load_model_args).then(**prompt_update_args).then(**chatbot_update_args)
load_model_event = load_model_button.click(**load_model_args).then(**prompt_update_args).then(**chatbot_update_args).then(clear_torch_cache)

def dropdown_model_list(list0, x):
new_state = [list0[0] + [x]]
@@ -984,10 +989,11 @@ def dropdown_lora_list(list0, x):
flag_btn.click(lambda *args: callback.flag(args), inputs_list + [text_output], None, preprocess=False,
api_name='flag')
if kwargs['chat']:

# don't pass text_output, don't want to clear output, just stop it
# FIXME: have to click once to stop output and second time to stop GPUs going
stop_btn.click(lambda: None, None, None, cancels=[submit_event, submit_event2, submit_event3],
queue=False, api_name='stop')
queue=False, api_name='stop').then(clear_torch_cache)

demo.queue(concurrency_count=1)
favicon_path = "h2o-logo.svg"
Expand Down Expand Up @@ -1082,10 +1088,7 @@ def evaluate(
# try to free-up original tokenizer (i.e. list was passed as reference)
if model_state0 is not None and model_state0[1] is not None:
model_state0[1] = None
if torch.cuda.is_available:
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
gc.collect()
clear_torch_cache()
model, tokenizer, device, base_model = model_state
elif model_state0 is not None and len(model_state0) == 4 and model_state0[0] is not None:
assert isinstance(model_state[0], str)
Expand Down Expand Up @@ -1242,7 +1245,16 @@ def generate(callback=None, **kwargs):
for stopping_criteria1 in stopping_criteria0:
kwargs['stopping_criteria'].append(stopping_criteria1)

model.generate(**kwargs)
try:
model.generate(**kwargs)
except torch.cuda.OutOfMemoryError as e:
print("GPU OOM: prompt: %s inputs_decoded: %s exception: %s" % (prompt, inputs_decoded, str(e)), flush=True)
if kwargs['input_ids'] is not None:
kwargs['input_ids'].cpu()
kwargs['input_ids'] = None
traceback.print_exc()
clear_torch_cache()
return

for output in CallbackToGenerator(generate, callback=None, **gen_kwargs):
decoded_output = decoder(output)
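
The streaming generation path gets the same guard: if model.generate raises a CUDA OOM, the input_ids reference is dropped, the error is logged, and the cache is cleared so the server keeps serving. A condensed sketch, not part of the diff (guarded_generate and gen_kwargs are illustrative stand-ins for the call assembled inside generate()):

import traceback
import torch
from utils import clear_torch_cache  # helper added in this PR

def guarded_generate(model, gen_kwargs):
    """Run model.generate, recovering from GPU OOM instead of crashing the server."""
    try:
        return model.generate(**gen_kwargs)
    except torch.cuda.OutOfMemoryError as e:
        print("GPU OOM during generate: %s" % str(e), flush=True)
        if gen_kwargs.get('input_ids') is not None:
            gen_kwargs['input_ids'] = None  # drop the GPU tensor reference so the cache clear can free it
        traceback.print_exc()
        clear_torch_cache()
        return None
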
utils.py: 8 changes (8 additions & 0 deletions)
@@ -1,4 +1,5 @@
import os
import gc
import random
import numpy as np
import torch
@@ -29,3 +30,10 @@ def flatten_list(lis):
else:
new_lis.append(item)
return new_lis


def clear_torch_cache():
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
gc.collect()
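
A small usage sketch of the new helper, not part of the diff (run_forward, score_many, and the loop are placeholders for any memory-hungry model call):

import torch
from utils import clear_torch_cache

def run_forward(model, tokenizer, prompt):
    """Placeholder for any memory-hungry model call."""
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        return model(**inputs).logits.cpu()

def score_many(model, tokenizer, prompts):
    outputs = []
    for prompt in prompts:
        outputs.append(run_forward(model, tokenizer, prompt))
        # release cached CUDA allocations between calls so fragmentation and
        # stale allocations are less likely to trigger spurious OOMs
        clear_torch_cache()
    return outputs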