
Torch >= 2.2.0 inference issues on MPS #458

Open

davmacario opened this issue Mar 20, 2024 · 3 comments

Comments

@davmacario

davmacario commented Mar 20, 2024

When running

python sample.py --init_from=gpt2 --num_samples=2 --max_new_tokens=100

having set device = 'mps', on my M1 Pro MacBook (macOS 14.4) with Torch 2.2.0 or 2.2.1, I get this output:

Overriding: init_from = gpt2
Overriding: num_samples = 2
Overriding: max_new_tokens = 100
loading weights from pretrained gpt: gpt2
forcing vocab_size=50257, block_size=1024, bias=True
overriding dropout rate to 0.0
number of parameters: 123.65M
No meta.pkl found, assuming GPT-2 encodings...

!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
---------------

!!!!!!。!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
---------------

The character ! corresponds to token 0, meaning the model generates nothing but token 0.
This does not happen with Torch 2.1.x.
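This can be checked with tiktoken (the encoding sample.py falls back to, per the "No meta.pkl found" message above):

import tiktoken

enc = tiktoken.get_encoding("gpt2")
print(enc.decode([0]))  # prints '!': token 0 of the GPT-2 vocabulary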

This is probably a Torch bug, but I could use some help pinpointing the actual cause (and possibly submitting a bug report to Torch).
Let me know if anyone else has run into the same issue.
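One way to narrow this down would be to check whether the forward pass itself diverges on MPS, or only the sampling step. A minimal sketch, assuming it is run from the nanoGPT repo root so that from model import GPT resolves:

# Sketch: compare CPU vs. MPS logits for the same prompt to see whether the
# forward pass or the sampling step is where the outputs diverge.
import torch
from model import GPT  # nanoGPT's model.py

prompt = torch.tensor([[50256]])  # GPT-2 <|endoftext|> token as a trivial prompt

model = GPT.from_pretrained('gpt2', dict(dropout=0.0)).eval()
with torch.no_grad():
    logits_cpu, _ = model(prompt)

model = model.to('mps')
with torch.no_grad():
    logits_mps, _ = model(prompt.to('mps'))

# A large difference here would point at the forward pass on MPS, not the sampler.
print((logits_cpu - logits_mps.cpu()).abs().max())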

@davmacario davmacario changed the title Torch >= 2.0.0 inference issues on MPS Torch >= 2.2.0 inference issues on MPS Mar 20, 2024
davmacario added a commit to davmacario/MDI-LLM that referenced this issue Mar 22, 2024
@adriankobras

I had the same issue on an M1 Pro MacBook with Torch 2.2.0.

@sun1638650145

I encountered a similar issue as well. I reproduced a Transformer on mps and hit a similar error, but it worked fine on cpu. Upgrading to torch 2.3.0 seems to have fixed it (although I didn't see any mention of a bug fix in the release notes).
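For anyone verifying this workaround, a quick way to confirm the installed version and that the MPS backend is actually available (standard torch APIs):

import torch

print(torch.__version__)                  # expect 2.3.0 or later
print(torch.backends.mps.is_available())  # True if MPS can actually be used
print(torch.backends.mps.is_built())      # True if this build includes MPS support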

@davmacario (Author)

Alright, thanks @sun1638650145, I'll give it a try and possibly update the issue!
