Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add GPU support #74

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open

Conversation

AyanoClarke
Copy link

To run on GPU, two parameters are added:

  1. dtype: default is float32, which is the same as the original code FloatTensor; however, the dtype can be set to bfloat16 or float16 to reduce the memory and accelerate the training steps.
  2. device: default is cpu, but it can be set to cuda (for AMD or NVIDIA), mps (for Apple).

The tutorial (data 151673) is tested on Apple M1 chips and Nvidia's A100 device.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant