Hi, awesome research!
I actually have been training on an M2 (Pro) Mac mini using this codebase by just adding:
from torch._dynamo import config
config.suppress_errors = True
I still get a lot of error logs printed at startup (2.5 screen heights of them) but then the training goes on.
I have been wondering if this impacts performance a lot. It has to.
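For context, a minimal sketch of where that flag sits (assuming the model is compiled with torch.compile, as train.py does when compilation is enabled; the Linear model here is just a stand-in):

```python
import torch
from torch._dynamo import config

# Fall back to eager execution on Dynamo errors instead of raising;
# this is also what produces the long error logs at startup on MPS.
config.suppress_errors = True

model = torch.nn.Linear(4, 4)
model = torch.compile(model)  # compilation is lazy; nothing is traced until the first call
```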
Thanks for the amazing, clean library. I had a few issues running it on MPS, I thought I'd share how I got it to work. May not be suitable for a PR given some of it is hacky, but could evolve into one if there is interest and additional input from the community.
Minimal diff to get it to run: master...louismullie:llama2.c:master
1. Disable AutoCast
AutoCast is not yet supported on MPS. See: pytorch/pytorch#88415
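A sketch of how one might gate autocast by device (the helper name autocast_ctx is mine, not from the repo):

```python
from contextlib import nullcontext
import torch

def autocast_ctx(device_type: str, dtype=torch.bfloat16):
    # torch.amp.autocast is not supported on MPS (pytorch/pytorch#88415),
    # so return a no-op context manager there (and on CPU, for simplicity).
    if device_type == "cuda":
        return torch.amp.autocast(device_type="cuda", dtype=dtype)
    return nullcontext()
```

The training loop then uses `with autocast_ctx(device_type):` unchanged on every backend.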
2. Disable fused optimizer
fused=True is not yet supported on MPS; no GitHub issue found.
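A sketch of guarding the fused flag (assuming AdamW as in train.py; the helper name is mine):

```python
import torch

def make_adamw(params, lr: float, device_type: str):
    # fused=True is only implemented for CUDA; on MPS (and CPU, here)
    # build the optimizer without the fused kernel.
    use_fused = device_type == "cuda"
    return torch.optim.AdamW(params, lr=lr, fused=use_fused)
```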
3. Enable CPU fallback for unsupported PyTorch ops
Set environment:
PYTORCH_ENABLE_MPS_FALLBACK=1
and expect a slowdown. See: pytorch/pytorch#77754, pytorch/pytorch#77764
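One gotcha worth noting: the flag has to be in the environment before the first `import torch`, or it is ignored. A sketch (placing it at the very top of train.py is my assumption; exporting it in the shell works just as well):

```python
import os

# Must come before `import torch`, or the fallback flag has no effect.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import torch  # unsupported MPS ops now fall back to CPU (slower, but runs)
```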
4. Suppress Dynamo warnings
Set
torch._logging.set_logs(dynamo=logging.ERROR)
to suppress.
5. Try to detect MPS everywhere CUDA can't be detected
Self-explanatory.
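A sketch of the device-selection order implied here (an assumed helper, not the repo's exact code):

```python
import torch

def pick_device() -> str:
    # Prefer CUDA, then MPS, then CPU; everywhere train.py checks for
    # CUDA, the MPS branch is added as the next candidate.
    if torch.cuda.is_available():
        return "cuda"
    if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
        return "mps"
    return "cpu"
```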
7. Remove use of complex operations in rotary embedding
See: pytorch/pytorch#95976
This required reimplementing the functions precompute_freqs_cis, reshape_for_broadcast, and apply_rotary_emb in train.py to split out the complex operations. Essentially:
Separate the computation of the cosine and sine parts in precompute_freqs_cis().
Split the real and imaginary parts of the input tensors in apply_rotary_emb() before applying the transformations.
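A sketch of that split (function names and the exact tensor layout are assumptions; x is taken as (bsz, seqlen, n_heads, head_dim) with interleaved real/imaginary channel pairs, matching the complex version's view_as_complex layout):

```python
import torch

def precompute_freqs_cos_sin(dim: int, end: int, theta: float = 10000.0):
    # Same frequencies as precompute_freqs_cis, but return cos and sin
    # separately instead of a complex cis tensor.
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: dim // 2].float() / dim))
    t = torch.arange(end, dtype=torch.float32)
    freqs = torch.outer(t, freqs)            # (end, dim // 2)
    return torch.cos(freqs), torch.sin(freqs)

def apply_rotary_emb_real(x, cos, sin):
    # x: (bsz, seqlen, n_heads, head_dim); cos/sin: (seqlen, head_dim // 2).
    # Treat consecutive channel pairs as (real, imag) and apply the rotation
    # (a + bi)(cos + i sin) using real arithmetic only.
    xr, xi = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)
    cos = cos.view(1, x.shape[1], 1, -1)     # reshape_for_broadcast equivalent
    sin = sin.view(1, x.shape[1], 1, -1)
    out_r = xr * cos - xi * sin
    out_i = xr * sin + xi * cos
    return torch.stack((out_r, out_i), dim=-1).flatten(-2).type_as(x)
```

On CPU (where complex tensors are supported) the result can be checked against the complex formulation directly, which is a useful sanity test before running on MPS.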
Unfortunately, this results in a ~10X slowdown in the apply_rotary_emb function, and I haven't had much luck in optimizing it further. Maybe someone else has some ideas!
8. NB: must use PyTorch nightly
With the above modifications, the code runs as is for training and inference in this conda environment: