From dc32c39573e0d69f36ed00b72fad2cbf1067645a Mon Sep 17 00:00:00 2001
From: sgwhat
Date: Thu, 11 Apr 2024 20:05:32 +0800
Subject: [PATCH 1/2] update webui windows performance

---
 modules/models.py          | 4 ----
 modules/text_generation.py | 5 ++++-
 2 files changed, 4 insertions(+), 5 deletions(-)

diff --git a/modules/models.py b/modules/models.py
index 24a52a845a..c15e5ce437 100644
--- a/modules/models.py
+++ b/modules/models.py
@@ -350,10 +350,6 @@ def bigdl_llm_loader(model_name):
         use_cache=shared.args.use_cache,
     )
 
-    if shared.args.device == "GPU":
-        import intel_extension_for_pytorch
-        model = model.to("xpu")
-
     tokenizer = AutoTokenizer.from_pretrained(path_to_model, trust_remote_code=shared.args.trust_remote_code)
 
     return model, tokenizer
diff --git a/modules/text_generation.py b/modules/text_generation.py
index 49dbe8027b..8f9aa813e1 100644
--- a/modules/text_generation.py
+++ b/modules/text_generation.py
@@ -78,7 +78,6 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escap
     is_stream = state['stream']
     if len(all_stop_strings) > 0 and not state['stream']:
         state = copy.deepcopy(state)
-        state['stream'] = True
 
     min_update_interval = 0
     if state.get('max_updates_second', 0) > 0:
@@ -375,6 +374,10 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings
         pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint(filtered_params)
         print()
 
+    if shared.args.device == "GPU":
+        import intel_extension_for_pytorch
+        shared.model = shared.model.to("xpu")
+
     t0 = time.time()
     try:
         if not is_chat and not shared.is_seq2seq:

From 964238a65dd9153c9ac13f1b40ff2278e29231f8 Mon Sep 17 00:00:00 2001
From: sgwhat
Date: Thu, 11 Apr 2024 20:17:33 +0800
Subject: [PATCH 2/2] update

---
 modules/text_generation.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/modules/text_generation.py b/modules/text_generation.py
index 8f9aa813e1..1dc8dcf499 100644
--- a/modules/text_generation.py
+++ b/modules/text_generation.py
@@ -387,8 +387,6 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings
         if not state['stream']:
             with torch.no_grad():
                 output = shared.model.generate(**generate_params)[0]
-                if cuda:
-                    output = output.cuda()
 
             starting_from = 0 if shared.is_seq2seq else len(input_ids[0])
             yield get_reply_from_output_ids(output, state, starting_from=starting_from)