
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation #9

Closed
udonuser opened this issue May 8, 2018 · 5 comments



udonuser commented May 8, 2018

Hi, I tried to use your s2conv/so3conv in a multi-model setup like the following
(the model includes your s2conv/so3conv):

def train(epoch):
    model.train()
    for batch_idx, (image,target) in enumerate(train_loader):
        image = image.to(device)
        optimizer.zero_grad()
       
        # multi model
        re_image1 = model(image)
        re_image2 = model(image)
        loss = re_image1.abs().mean() + re_image2.abs().mean()

        loss.backward()
        optimizer.step()

Then I got the following error:

  File "main.py", line 66, in <module>
    main()
  File "main.py", line 62, in main
    train(epoch)
  File "main.py", line 53, in train
    loss.backward()
  File "/home/hayashi/.python-venv/lib/python3.5/site-packages/torch/tensor.py", line 93, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/home/hayashi/.python-venv/lib/python3.5/site-packages/torch/autograd/__init__.py", line 89, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation

There is no error when I use a mono-model like the following:

def train(epoch):
    model.train()
    for batch_idx, (image,target) in enumerate(train_loader):
        image = image.to(device)
        optimizer.zero_grad()
       
        # mono model
        image1 = model(image)
        loss = image1.abs().mean() 

        loss.backward()
        optimizer.step()

So I think this error is not actually caused by an in-place operation.
Do you know the details of this error?

P.S.
I found that this error doesn't occur when I use a past version of your s2conv/so3conv
(maybe the one for PyTorch v0.3.1).
If you can, please republish the past version of s2cnn (for PyTorch v0.3.1).

@mariogeiger (Collaborator)

Hi,
No, I have never observed this error; we always used a mono-model.
Did you try to simplify the model to see if the error still occurs? For instance, using only s2conv or only so3conv?


udonuser commented May 8, 2018

Yes, I get the same error using only s2conv in the following code:

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST

from s2cnn import s2_near_identity_grid, S2Convolution


def S2conv2d(in_c, out_c, in_b, out_b):
    grid = s2_near_identity_grid(n_alpha=2 * in_b)
    return S2Convolution(in_c, out_c, in_b, out_b, grid)

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv1 = S2conv2d(1, 5, 14, 7)

    def forward(self, x):
        return self.conv1(x)


def main():
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    
    WORKERS=1

    img_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])

    train_loader = DataLoader(MNIST('./data', train=True, transform=img_transform, download=True),
                              batch_size=256, num_workers=WORKERS,  pin_memory=True, shuffle=True)

    model = Model().to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3,
                                momentum=0.9, weight_decay=5e-4)

    def train():
        model.train()
        for batch_idx, (image,target) in enumerate(train_loader):
            image = image.to(device)
            optimizer.zero_grad()

            # multi model
            output1 = model(image)
            output2 = model(image)
            loss = (output1 + output2).mean()

            loss.backward()
            optimizer.step()
            print("OK")
            break
    train()


if __name__ == "__main__":
    main()

@mariogeiger (Collaborator)

The problem comes from s2_rft, where we use torch.einsum. It can be reproduced with the following code:

import torch

# Minimal reproduction: the same tensor x feeds two einsum calls,
# then backward is called through one of them.
x = torch.randn(3, 3, requires_grad=True)
z1 = torch.einsum("ij,jk->ik", (x, torch.randn(3, 3)))
z2 = torch.einsum("ij,jk->ik", (x, torch.randn(3, 3)))
z1.sum().backward()

@mariogeiger (Collaborator)

I can fix it with torch.einsum("ij,jk->ik", (x.clone(), torch.randn(3, 3)))
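
For reference, a minimal sketch of the same reproduction with that workaround applied (assuming the same PyTorch version as above; the final print line is only there to illustrate that gradients still reach x through the clone):

import torch

# Same reproduction as above, but x is cloned before each einsum,
# as suggested in the previous comment.
x = torch.randn(3, 3, requires_grad=True)
z1 = torch.einsum("ij,jk->ik", (x.clone(), torch.randn(3, 3)))
z2 = torch.einsum("ij,jk->ik", (x.clone(), torch.randn(3, 3)))
z1.sum().backward()           # no longer raises the in-place error
print(x.grad is not None)     # True: gradients flow back to x through the clone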


udonuser commented May 8, 2018

Thank you so much!
