Custom autograd fails with torchdeq in eval mode #4

Open

BurgerAndreas opened this issue Mar 31, 2024 · 1 comment
BurgerAndreas commented Mar 31, 2024

It's a very niche problem, but it tripped me up big time :')

Issue

With model.eval(), z_pred will not have tracked gradients (z_pred[-1].requires_grad == False).
For a custom torch.autograd.grad call, this leads to an error: RuntimeError: One of the differentiated Tensors does not require grad.

Minimal example


import torch

import torchdeq
from torchdeq import get_deq
from torchdeq.norm import apply_norm, reset_norm

class MyModel(torch.nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.layer = torch.nn.Linear(10, 10)

        # deq
        self.deq = get_deq()
        apply_norm(self.layer, 'weight_norm')

    def implicit_layer(self, z, pos):
        # pos enters the fixed-point function so that the energy depends on pos
        return self.layer(z) + pos.sum()

    def forward(self, x, pos):

        z = torch.zeros_like(x)

        reset_norm(self.layer)

        # fixed-point function of z, closed over pos
        f = lambda z: self.implicit_layer(z, pos)

        z_pred, info = self.deq(f, z)
        
        # if model.eval() -> z_pred[-1].requires_grad is False!
        # per-sample energy of shape (10, 1) to match the target below
        energy = z_pred[-1].sum(dim=-1, keepdim=True)
        forces = -1 * (
            torch.autograd.grad(
                energy,
                # differentiate with respect to pos
                # 'One of the differentiated Tensors appears to not have been used
                # in the graph' means pos was not used to compute the energy
                pos,
                grad_outputs=torch.ones_like(energy),
                create_graph=True,
                # allow_unused=True,
            )[0]
        )

        return energy, forces


def run(model, eval=False):

    if eval:
        model.eval()
    else:
        model.train()

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    for step in range(10):
        x = torch.randn(10, 10)
        # pos must require grad so we can differentiate the energy w.r.t. it
        pos = torch.randn(10, 3, requires_grad=True)
        energy, forces = model(x, pos)
        
        # loss
        optimizer.zero_grad()
        energy_target = torch.randn(10, 1)
        energy_loss = torch.nn.functional.mse_loss(energy, energy_target)
        force_target = torch.randn(10, 3)
        force_loss = torch.nn.functional.mse_loss(forces, force_target)
        loss = energy_loss + force_loss

        if not eval:
            loss.backward()
            optimizer.step()
    
    return True

if __name__ == '__main__':
    model = MyModel()
    success = run(model, eval=False)
    print(f'train success: {success}')
    success = run(model, eval=True)
    print(f'eval success: {success}')

With model.train() this works perfectly well. With model.eval() we get the error: RuntimeError: One of the differentiated Tensors does not require grad.

Desired behaviour

A flag such that z_pred[-1].requires_grad is always True, even under model.eval(), e.g.:
self.deq = get_deq(grad_in_eval=True)
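
In the meantime, one possible caller-side workaround, sketched here under the assumptions of the example above (model is a MyModel instance and pos has requires_grad=True; this is not something torchdeq provides): when the solver output comes back without gradients, re-apply the fixed-point function once with gradients enabled so the energy is connected to pos. In eval mode this differentiates only through that single extra application of f, not through the full implicit fixed point.

def energy_and_forces(model, x, pos):
    # compute energy and forces even under model.eval(), by re-applying the
    # fixed-point function once with gradients enabled when needed
    z = torch.zeros_like(x)
    reset_norm(model.layer)
    f = lambda z: model.implicit_layer(z, pos)
    z_pred, info = model.deq(f, z)

    z_star = z_pred[-1]
    if not z_star.requires_grad:
        # eval mode: the solver ran without grad, so re-attach the graph with
        # one extra application of f (gradients flow through this single
        # application only, not the full implicit fixed point)
        with torch.enable_grad():
            z_star = f(z_star.detach())

    energy = z_star.sum(dim=-1, keepdim=True)
    forces = -1 * torch.autograd.grad(
        energy,
        pos,
        grad_outputs=torch.ones_like(energy),
        # a higher-order graph is only needed when the force loss is trained on
        create_graph=model.training,
    )[0]
    return energy, forces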

@Gsunshine (Member)

Hi Andreas @BurgerAndreas ,

Thanks a lot for your interest! I think a quick fix is to keep self.deq in train mode while the other components of the model are in eval mode.

I appreciate the suggestion! I think we can add such a feature to the lib. Feel free to submit a PR.
I'll be back to close this issue soon.

Thanks,
Zhengyang
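
A minimal sketch of that quick fix, assuming self.deq is an ordinary torch.nn.Module whose training flag decides whether the solver output tracks gradients (as the reply above suggests):

model = MyModel()

# put everything in eval mode, then switch only the DEQ solver back to
# train mode so z_pred keeps requires_grad=True
model.eval()
model.deq.train()

x = torch.randn(10, 10)
pos = torch.randn(10, 3, requires_grad=True)
energy, forces = model(x, pos)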
