From 422c6d7d62bbb392b9e3390ac98171ffabe63458 Mon Sep 17 00:00:00 2001
From: swap357
Date: Tue, 4 Apr 2023 23:33:35 -0700
Subject: [PATCH] add cpu support

---
 main.py | 23 +++++++++++++++--------
 1 file changed, 15 insertions(+), 8 deletions(-)

diff --git a/main.py b/main.py
index 410319e..c68f967 100644
--- a/main.py
+++ b/main.py
@@ -12,14 +12,22 @@
 tokenizer = None
 current_peft_model = None
 
+# Add an argument parser for --use_cpu
+parser = argparse.ArgumentParser(description="Simple LLaMA Finetuner")
+parser.add_argument("-s", "--share", action="store_true", help="Enable sharing of the Gradio interface")
+parser.add_argument("--use_cpu", action="store_true", help="Use CPU for training and inference")
+args = parser.parse_args()
+use_cpu = args.use_cpu
+
+
 def load_base_model():
     global model
     print('Loading base model...')
     model = transformers.LlamaForCausalLM.from_pretrained(
         'decapoda-research/llama-7b-hf',
-        load_in_8bit=True,
-        torch_dtype=torch.float16,
-        device_map={'':0}
+        load_in_8bit=(not use_cpu),
+        torch_dtype=(torch.float32 if use_cpu else torch.float16),
+        device_map={'':0} if not use_cpu else None
     )
 
 def load_tokenizer():
@@ -105,7 +113,9 @@ def generate_text(
             input_ids=input_ids,
             attention_mask=torch.ones_like(input_ids),
             generation_config=generation_config
-        )[0].cuda()
+        )[0]
+        if not use_cpu:
+            output = output.cuda()
 
     return tokenizer.decode(output, skip_special_tokens=True).strip()
 
@@ -205,7 +215,7 @@ def to_dict(text):
        # Enables mixed precision training using 16-bit floating point numbers (FP16).
        # This can speed up training and reduce GPU memory consumption without
        # sacrificing too much model accuracy.
-       fp16=True,
+       fp16=(not use_cpu),
 
        # The frequency (in terms of steps) of logging training metrics and statistics
        # like loss, learning rate, etc. In this case, it logs after every 20 steps.
@@ -444,8 +454,5 @@ def update_models_list():
     """)
 
 
 if __name__ == '__main__':
-    parser = argparse.ArgumentParser(description="Simple LLaMA Finetuner")
-    parser.add_argument("-s", "--share", action="store_true", help="Enable sharing of the Gradio interface")
-    args = parser.parse_args()
     demo.queue().launch(share=args.share)
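
The three from_pretrained() arguments in the first hunk move together, and the
switch is easiest to sanity-check in isolation. Below is a minimal sketch of
that selection logic, assuming only that torch is installed; the helper name
model_load_kwargs is hypothetical and simply mirrors the argument choices in
the diff above, it is not part of main.py.

import torch

def model_load_kwargs(use_cpu: bool) -> dict:
    # Mirrors the from_pretrained() arguments chosen by the patch:
    # GPU path: 8-bit quantized weights (bitsandbytes), fp16, pinned to CUDA device 0.
    # CPU path: no 8-bit quantization, fp32, no device_map so the model stays on CPU.
    return {
        'load_in_8bit': not use_cpu,
        'torch_dtype': torch.float32 if use_cpu else torch.float16,
        'device_map': {'': 0} if not use_cpu else None,
    }

print(model_load_kwargs(use_cpu=True))   # CPU run: python main.py --use_cpu
print(model_load_kwargs(use_cpu=False))  # GPU run (default): python main.py

The same flag also disables fp16 mixed precision in TrainingArguments and
skips the .cuda() call on generated output, so a CPU-only run never touches a
CUDA code path.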