Skip to content

Commit

Permalink
Merge pull request #36 from sgwhat/win-perf-opt
Browse files — browse the repository at this point in the history
Windows performance optimization
  • Branch information:
sgwhat committed Apr 11, 2024
2 parents: af95b6c + 8a220d5 — commit 5d9ad12
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 7 deletions.
4 changes: 0 additions & 4 deletions modules/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,10 +350,6 @@ def ipex_llm_loader(model_name):
use_cache=shared.args.use_cache,
)

if shared.args.device == "GPU":
import intel_extension_for_pytorch
model = model.half().to("xpu")

tokenizer = AutoTokenizer.from_pretrained(path_to_model, trust_remote_code=shared.args.trust_remote_code)

return model, tokenizer
Expand Down
7 changes: 4 additions & 3 deletions modules/text_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,6 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escap
is_stream = state['stream']
if len(all_stop_strings) > 0 and not state['stream']:
state = copy.deepcopy(state)
state['stream'] = True

min_update_interval = 0
if state.get('max_updates_second', 0) > 0:
Expand Down Expand Up @@ -375,6 +374,10 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings
pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint(filtered_params)
print()

if shared.args.device == "GPU":
import intel_extension_for_pytorch
shared.model = shared.model.to("xpu")

t0 = time.time()
try:
if not is_chat and not shared.is_seq2seq:
Expand All @@ -384,8 +387,6 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings
if not state['stream']:
with torch.no_grad():
output = shared.model.generate(**generate_params)[0]
if cuda:
output = output.cuda()

starting_from = 0 if shared.is_seq2seq else len(input_ids[0])
yield get_reply_from_output_ids(output, state, starting_from=starting_from)
Expand Down

0 comments on commit 5d9ad12

Please sign in to comment.