Update vae.py #3761

Open
mzamini92 wants to merge 4 commits into deepspeedai:master from mzamini92:patch-1
Conversation

@mzamini92
Copy link

Since the `DSVAE` class already inherits from `torch.nn.Module`, there is no need to inherit from `CUDAGraph` as well; the `CUDAGraph` inheritance can be removed.
Instead of calling `self.vae.requires_grad_(requires_grad=False)`, you can use the `torch.no_grad()` context manager during initialization to disable gradient computation for the `self.vae` module.
The `_graph_replay_decoder`, `_graph_replay_encoder`, and `_graph_replay` methods can benefit from the `@torch.no_grad()` decorator.
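
A minimal sketch of the decorator part of this suggestion, using a stand-in `torch.nn.Linear` in place of the real VAE pipeline (the class body and method name here are illustrative, not DeepSpeed's actual implementation):

```python
import torch


class DSVAE(torch.nn.Module):
    """Illustrative wrapper: inherits only from torch.nn.Module."""

    def __init__(self):
        super().__init__()
        # Stand-in for the wrapped VAE; parameters frozen up front.
        self.vae = torch.nn.Linear(4, 4)
        self.vae.requires_grad_(False)

    @torch.no_grad()
    def _graph_replay_decoder(self, x):
        # Inference-only replay: no autograd graph is recorded inside
        # this method, even if the input requires gradients.
        return self.vae(x)


model = DSVAE()
x = torch.randn(2, 4, requires_grad=True)
out = model._graph_replay_decoder(x)
print(out.requires_grad)  # False: the decorator disables grad tracking
```

Note that `torch.no_grad()` around a method only suppresses graph recording for that call; it does not itself freeze parameters, so the two changes address different things (per-call autograd overhead vs. persistently marking `self.vae` as non-trainable).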
@mzamini92
Copy link
Author

@microsoft-github-policy-service agree

@loadams
Copy link
Collaborator

loadams commented Oct 31, 2024

@mzamini92 - since this PR sat for quite a while before being lost track of, do you think it is still worth merging? If so, could you resolve the merge conflicts? Or should we close this PR?
