
Fix dtype- and device-related client issues #98

Merged 8 commits into main on Nov 29, 2022
Conversation

@borzunov (Collaborator) commented on Nov 29, 2022

This PR:

  1. Makes inference/forward/backward calls on the client remember the dtype and device of the source tensors, then move and cast the outputs back to that same dtype/device (see the sketch after this list). This way:

    • Users don't need to change the code that launches RemoteSequential in order to run it on a different device.
    • model.generate() now supports both CPU and GPU.
    • See the draft GPU-based Colab notebook for running inference/generate/forward/backward through the public swarm with BLOOM-176B.
  2. Sets low_cpu_mem_usage=True by default and the client's request timeout to 20 seconds.

  3. Removes redundant casts to float32 left over in Dmitry's code.

  4. (minor) Improves error messages.

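A minimal sketch of the client-side behavior described in item 1, assuming a PyTorch client. The helper name `forward_with_source_placement`, the `remote_forward` callable, and the float32-on-CPU wire format are illustrative assumptions, not the actual Petals implementation:

```python
import torch

def forward_with_source_placement(remote_forward, hidden_states: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper (not the actual Petals code): remember the dtype and
    device of the user's input tensors, run the remote call, and cast the
    outputs back so the caller gets tensors matching what they passed in."""
    src_dtype, src_device = hidden_states.dtype, hidden_states.device

    # Assume the swarm expects a fixed dtype/device (here: float32 on CPU);
    # cast the inputs accordingly before sending them out.
    outputs = remote_forward(hidden_states.to(dtype=torch.float32, device="cpu"))

    # Move/cast the outputs back to the source dtype/device, so e.g. GPU
    # inputs produce GPU outputs without any changes to the calling code.
    return outputs.to(dtype=src_dtype, device=src_device)
```

With this pattern, the same client code works whether its tensors live on CPU or a CUDA device, which is what lets model.generate() run on either device without modification.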
@borzunov merged commit ab41223 into main on Nov 29, 2022