Will Haste work without CUDA? #2

Closed
shamoons opened this issue Mar 10, 2020 · 2 comments

shamoons commented Mar 10, 2020

Can this be used without CUDA? Specifically, I have my laptop for local development and my HPC for actual training. Will the Haste LSTM work on my local as well? (Slower, which is fine, but that way I can maintain one codebase)


sharvil commented Mar 10, 2020

At the moment it's CUDA-only, but we'll definitely get to a CPU implementation.

sharvil added a commit that referenced this issue Mar 13, 2020
These implementations are written using the PyTorch Python API which
allows them to run on any supported device (including CPU). Each
layer will automatically choose either the CUDA or generic
implementation depending on which device is currently selected for
the layer.

Issue: #2
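
A minimal sketch of that dispatch pattern (hypothetical class and method names, not Haste's actual internals): the layer inspects the device of its input and routes to either a hand-written CUDA path or a generic PyTorch path that runs on any supported device, including CPU.

import torch
import torch.nn as nn

class DeviceDispatchingLayer(nn.Module):
    # Hypothetical illustration of per-device dispatch; not Haste's real code.
    def forward(self, x):
        if x.is_cuda:
            return self._forward_cuda(x)
        return self._forward_generic(x)

    def _forward_generic(self, x):
        # Pure PyTorch ops: these run on any supported device, including CPU.
        return torch.tanh(x)

    def _forward_cuda(self, x):
        # In Haste this path would invoke the custom CUDA kernels; here we
        # simply reuse the generic ops as a stand-in.
        return self._forward_generic(x)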

sharvil commented Mar 13, 2020

You can now use these layers without CUDA. When the layer and input tensor are on the CPU, the layer uses the CPU implementation; when they're on the GPU, it uses the fast CUDA implementation.

Here's an example of an LSTM on CPU:

import torch
import haste_pytorch as haste

batch_size = 128
seq_len = 256
input_size = 128
hidden_size = 256

# Input is shaped (seq_len, batch_size, input_size); both the tensor and the
# layer live on the CPU, so the CPU implementation is used.
x = torch.rand(seq_len, batch_size, input_size).cpu()
lstm = haste.LSTM(input_size, hidden_size).cpu()
y, state = lstm(x)
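
Following the same pattern, moving both the layer and the input tensor to a GPU selects the CUDA implementation. A sketch, assuming a CUDA-capable device is available:

import torch
import haste_pytorch as haste

batch_size = 128
seq_len = 256
input_size = 128
hidden_size = 256

# Moving both the input and the layer to the GPU selects the CUDA kernels.
x = torch.rand(seq_len, batch_size, input_size).cuda()
lstm = haste.LSTM(input_size, hidden_size).cuda()
y, state = lstm(x)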

Note that all Haste RNN features (e.g. Zoneout, DropConnect) are supported by the CPU implementation as well.
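
For example, here's a sketch of the same CPU usage with regularization enabled, assuming the dropout and zoneout constructor arguments (taken here to apply DropConnect and Zoneout respectively):

import torch
import haste_pytorch as haste

# Assumed arguments: `dropout` applies DropConnect to the recurrent weights,
# `zoneout` applies Zoneout to the hidden state.
lstm = haste.LSTM(128, 256, dropout=0.05, zoneout=0.1).cpu()

x = torch.rand(256, 128, 128).cpu()  # (seq_len, batch_size, input_size)
y, state = lstm(x)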
