(h2ollm) ubuntu@cloudvm:~/h2o-llm$ CUDA_VISIBLE_DEVICES=1 python finetune.py --base_model=decapoda-research/llama-13b-hf --llama_flash_attn=True
3%|██▉ | 10/377 [08:44<5:20:03, 52.33s/it]
(h2ollm) ubuntu@cloudvm:~/h2o-llm$ CUDA_VISIBLE_DEVICES=1 python finetune.py --base_model=decapoda-research/llama-13b-hf --llama_flash_attn=False --train_8bit=False
0%|▎ | 1/377 [00:21<2:16:32, 21.79s/it]
(h2ollm) ubuntu@cloudvm:~/h2o-llm$ CUDA_VISIBLE_DEVICES=1 python finetune.py --base_model=decapoda-research/llama-13b-hf --llama_flash_attn=True
1%|█▏ | 4/377 [03:33<5:29:20, 52.98s/it]
(h2ollm) ubuntu@cloudvm:~/h2o-llm$ CUDA_VISIBLE_DEVICES=1 python finetune.py --base_model=decapoda-research/llama-13b-hf --llama_flash_attn=True --train_8bit=False
0%|▎ | 1/377 [00:19<2:04:22, 19.85s/it]
(h2ollm) ubuntu@cloudvm:~/h2o-llm$ CUDA_VISIBLE_DEVICES=1 python finetune.py --base_model=decapoda-research/llama-13b-hf --llama_flash_attn=False --train_8bit=False
2%|██ | 7/377 [02:31<2:14:38, 21.83s/it]
(h2ollm) ubuntu@cloudvm:~/h2o-llm$ git diff
diff --git a/finetune.py b/finetune.py
index e112138..4477fb2 100644
--- a/finetune.py
+++ b/finetune.py
@@ -643,7 +643,7 @@ def train(
         model = torch.compile(model)
     # WIP (not generally replacing layers until pytorch 2.1)
     if not llama_flash_attn:
-        torch.backends.cuda.enable_flash_sdp(True)
+        torch.backends.cuda.enable_flash_sdp(False)

     if gpus > 1 and not ddp:
         assert trainer.is_model_parallel
1%|█▍ | 5/377 [01:50<2:15:45, 21.90s/it]
(h2ollm) ubuntu@cloudvm:~/h2o-llm$ git diff
diff --git a/finetune.py b/finetune.py
index e112138..2ea48d4 100644
--- a/finetune.py
+++ b/finetune.py
@@ -640,10 +640,10 @@ def train(
         ).__get__(model, type(model))

     if torch.__version__ >= "2" and sys.platform != "win32":
-        model = torch.compile(model)
+        # model = torch.compile(model)
     # WIP (not generally replacing layers until pytorch 2.1)
     if not llama_flash_attn:
-        torch.backends.cuda.enable_flash_sdp(True)
+        torch.backends.cuda.enable_flash_sdp(False)

     if gpus > 1 and not ddp:
         assert trainer.is_model_parallel
1%|▉ | 3/377 [01:06<2:16:33, 21.91s/it]
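For context on what the two diffs above toggle: PyTorch 2.0 exposes global switches for its scaled-dot-product-attention (SDP) backends under torch.backends.cuda, and enable_flash_sdp is one of them. The sketch below is not from finetune.py; it is only a minimal illustration, assuming a CUDA build of torch 2.0.x, of how the flash backend can be disabled and how to check which backends are active.

# Minimal sketch (not finetune.py code): toggling and inspecting the
# PyTorch 2.0 SDP backends that the diffs above flip.
import torch
import torch.nn.functional as F

torch.backends.cuda.enable_flash_sdp(False)         # the switch the diff sets to False
torch.backends.cuda.enable_mem_efficient_sdp(True)  # other backends stay available
torch.backends.cuda.enable_math_sdp(True)

print("flash enabled:", torch.backends.cuda.flash_sdp_enabled())
print("mem-efficient enabled:", torch.backends.cuda.mem_efficient_sdp_enabled())
print("math enabled:", torch.backends.cuda.math_sdp_enabled())

# The same selection can be scoped with a context manager instead of global state.
q = k = v = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
with torch.backends.cuda.sdp_kernel(enable_flash=False,
                                    enable_math=True,
                                    enable_mem_efficient=True):
    out = F.scaled_dot_product_attention(q, k, v)

With flash SDP disabled, scaled_dot_product_attention silently falls back to the memory-efficient or math kernels, so the diff only changes which attention kernel gets picked, not the model code; that is consistent with the near-identical s/it numbers above.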
Related: #31, #86
So nothing really matters: if you can do 16-bit, do 16-bit; otherwise do 8-bit. No flash attention specifics are needed.
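For illustration only (not the finetune.py code path; the flag name and loading calls below are a hypothetical sketch of the trade-off, using the base model from the runs above): 16-bit here roughly means plain fp16 weights, while 8-bit means bitsandbytes int8 loading via transformers.

# Hypothetical sketch of the 16-bit vs 8-bit choice discussed above;
# finetune.py wires this up through its own flags (e.g. --train_8bit).
import torch
from transformers import AutoModelForCausalLM

base_model = "decapoda-research/llama-13b-hf"  # model used in the runs above
train_8bit = False  # prefer 16-bit when the model fits in VRAM

if train_8bit:
    # 8-bit: bitsandbytes int8 quantization (requires bitsandbytes), roughly
    # half the weight memory of fp16, typically at some per-step speed cost.
    model = AutoModelForCausalLM.from_pretrained(
        base_model, load_in_8bit=True, device_map="auto"
    )
else:
    # 16-bit: plain fp16 weights.
    model = AutoModelForCausalLM.from_pretrained(
        base_model, torch_dtype=torch.float16, device_map="auto"
    )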
Torch==2.0.0
8-bit with flash-attn package:
(h2ollm) ubuntu@cloudvm:~/h2o-llm$ CUDA_VISIBLE_DEVICES=1 python finetune.py --base_model=decapoda-research/llama-13b-hf --llama_flash_attn=True
3%|██▉ | 10/377 [08:44<5:20:03, 52.33s/it]
16-bit with defaults:
(h2ollm) ubuntu@cloudvm:~/h2o-llm$ CUDA_VISIBLE_DEVICES=1 python finetune.py --base_model=decapoda-research/llama-13b-hf --llama_flash_attn=False --train_8bit=False
0%|▎ | 1/377 [00:21<2:16:32, 21.79s/it]
Torch==2.0.1
8-bit with flash-attn package:
(h2ollm) ubuntu@cloudvm:~/h2o-llm$ CUDA_VISIBLE_DEVICES=1 python finetune.py --base_model=decapoda-research/llama-13b-hf --llama_flash_attn=True
1%|█▏ | 4/377 [03:33<5:29:20, 52.98s/it]
16-bit with flash-attn package:
(h2ollm) ubuntu@cloudvm:~/h2o-llm$ CUDA_VISIBLE_DEVICES=1 python finetune.py --base_model=decapoda-research/llama-13b-hf --llama_flash_attn=True --train_8bit=False
0%|▎ | 1/377 [00:19<2:04:22, 19.85s/it]
16-bit with default (flash attention):
(h2ollm) ubuntu@cloudvm:~/h2o-llm$ CUDA_VISIBLE_DEVICES=1 python finetune.py --base_model=decapoda-research/llama-13b-hf --llama_flash_attn=False --train_8bit=False
2%|██ | 7/377 [02:31<2:14:38, 21.83s/it]
16-bit with defaults, but with flash attention disabled (the enable_flash_sdp(False) diff above):
1%|█▍ | 5/377 [01:50<2:15:45, 21.90s/it]
16-bit with defaults, but with flash attention disabled and torch.compile() commented out (the second diff above):
1%|▉ | 3/377 [01:06<2:16:33, 21.91s/it]
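To make the comparison easier to read, the per-iteration timings reported by the progress bars above can be collected in one place (numbers copied verbatim from the runs; nothing new is measured here):

# Seconds per iteration, copied from the progress bars above.
secs_per_it = {
    "torch 2.0.0, 8-bit, flash-attn package": 52.33,
    "torch 2.0.0, 16-bit, defaults": 21.79,
    "torch 2.0.1, 8-bit, flash-attn package": 52.98,
    "torch 2.0.1, 16-bit, flash-attn package": 19.85,
    "torch 2.0.1, 16-bit, defaults (flash attention)": 21.83,
    "torch 2.0.1, 16-bit, flash attention disabled": 21.90,
    "torch 2.0.1, 16-bit, flash attention disabled, no torch.compile": 21.91,
}

baseline = secs_per_it["torch 2.0.1, 16-bit, defaults (flash attention)"]
for name, s in secs_per_it.items():
    print(f"{name:63s} {s:6.2f} s/it  ({s / baseline:.2f}x the 16-bit default)")

The 16-bit runs without the flash-attn package land within about 0.1 s/it of each other regardless of the flash SDP and torch.compile toggles; 16-bit with the flash-attn package is slightly faster (19.85 s/it), while the 8-bit runs with the flash-attn package are roughly 2.4x slower per iteration.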