TheseusLayer.to improvements: Return itself and support cuda #622

Closed
tvercaut opened this issue Nov 2, 2023 · 2 comments · Fixed by #623

Comments


tvercaut commented Nov 2, 2023

🚀 Feature

It would be great to be able to call

theseus_optim = th.TheseusLayer(optimizer).to(torch.device('cuda'))

as well as

theseus_optim = th.TheseusLayer(optimizer, device=torch.device('cuda'))

Currently this fails for two reasons:

  • TheseusLayer.to returns None, so the one-liner has to be split into two statements: theseus_optim = th.TheseusLayer(optimizer); theseus_optim.to(torch.device('cuda'))
  • Using 'cuda' fails: tensors placed on torch.device('cuda') report their device as cuda:0, so the device check in Objective.update treats them as mismatched and raises:
        if tensor.device != self.device or tensor.dtype != self.dtype:
            raise ValueError(
                f"Attempted to update variable {var_name} with a "
                f"({tensor.device},{tensor.dtype}) tensor, which is inconsistent "
                f"with objective's expected ({self.device},{self.dtype})."
            )
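The device mismatch behind the second point can be reproduced with plain PyTorch, without Theseus. This is a minimal sketch: on a machine without CUDA it hard-codes cuda:0 as the device a tensor would report, which is an assumption standing in for a real GPU allocation.

```python
import torch

requested = torch.device("cuda")
print(requested.index)  # None: a bare "cuda" device carries no index

# Tensors always report an explicit index. With a GPU we allocate one to see it;
# without CUDA we assume cuda:0 so the comparison can still be demonstrated.
if torch.cuda.is_available():
    tensor_device = torch.zeros(1, device=requested).device
else:
    tensor_device = torch.device("cuda:0")

# torch.device equality compares both type and index, so cuda:0 != cuda
print(tensor_device != requested)  # True, which is why the check above raises
```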

Motivation

This would make the use of Theseus more convenient.

Pitch

See above

Alternatives

See above

Additional context

# Applies to() with given args to all tensors in the objective
def to(self, *args, **kwargs):
    super().to(*args, **kwargs)
    self.objective.to(*args, **kwargs)
    # (no return statement, so the call evaluates to None)
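A minimal sketch of the "return itself" part, assuming the override keeps the structure quoted above. ToyObjective and ToyLayer are hypothetical stand-ins for th.Objective and th.TheseusLayer, not the change merged in #623; the only point is that returning self, as torch.nn.Module.to does, makes the one-liner work.

```python
import torch
import torch.nn as nn


class ToyObjective:
    """Hypothetical stand-in for th.Objective: holds a single tensor."""

    def __init__(self):
        self.data = torch.zeros(3)

    def to(self, *args, **kwargs):
        # Forward the same arguments torch.Tensor.to accepts (device, dtype, ...)
        self.data = self.data.to(*args, **kwargs)
        return self


class ToyLayer(nn.Module):
    """Hypothetical layer mirroring the TheseusLayer.to override quoted above."""

    def __init__(self, objective):
        super().__init__()
        self.objective = objective  # plain attribute, not a registered submodule

    def to(self, *args, **kwargs):
        super().to(*args, **kwargs)
        self.objective.to(*args, **kwargs)
        # Returning self makes the call chainable, matching nn.Module.to
        return self


# The one-liner from the feature request now yields the layer instead of None
layer = ToyLayer(ToyObjective()).to(torch.device("cpu"))
print(type(layer).__name__)  # ToyLayer
```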

Other issues with to() have been discussed in:

luisenp (Contributor) commented Nov 3, 2023

This is a good suggestion, thanks for bringing it up! One open question is the expected behavior when passing torch.device("cuda"). Should we automatically map it to "cuda:0" on our side? As you say, this seems to be what happens for torch tensors, but for some reason torch.device() doesn't do this automatically. Not sure if this is an oversight on PyTorch's side or a deliberate choice.
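One way the mapping could be done on the library side is sketched below, assuming the goal is for the requested device to compare equal to the .device of tensors allocated on it. resolve_device is a hypothetical helper, not a Theseus or PyTorch API, and not necessarily what #623 does; it resolves a bare "cuda" to the current CUDA device rather than always hard-coding index 0.

```python
import torch


def resolve_device(device) -> torch.device:
    """Hypothetical helper: give a bare "cuda" device an explicit index."""
    device = torch.device(device)
    if device.type == "cuda" and device.index is None:
        # Tensors allocated on a bare "cuda" device go to the current CUDA
        # device, which is index 0 unless torch.cuda.set_device() changed it.
        # torch.cuda.current_device() requires a working CUDA setup.
        device = torch.device("cuda", torch.cuda.current_device())
    return device


print(resolve_device("cpu"))     # cpu
print(resolve_device("cuda:1"))  # cuda:1 (already indexed, returned unchanged)
```

Resolving against the current device, instead of always using cuda:0, keeps the result consistent with where PyTorch actually places tensors when a user has switched devices.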

luisenp (Contributor) commented Nov 3, 2023

Added feature in #623.

tvercaut closed this as completed Nov 8, 2023