Why "no required computing gradients"? #4

Closed
HRLTY opened this issue May 19, 2017 · 7 comments

Comments

@HRLTY

HRLTY commented May 19, 2017

I used the same calc_GradientPenalty method as yours with the latest master branch of PyTorch ('0.1.12+625850c'), but it fails at penalty.backward() with the error

"RuntimeError: there are no graph nodes that require computing gradients"

I set requires_grad=True on the interpolates variable.
Thanks!
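Setting requires_grad=True on the input is not enough on its own: the tensor returned by torch.autograd.grad can only feed penalty.backward() if it is itself part of a graph, which requires create_graph=True and every operation on the path supporting higher-order gradients. One way to check this is to inspect requires_grad and grad_fn (the attribute that replaced .creator in later releases) on the returned gradient. A minimal sketch in current PyTorch, assuming a small Linear/ReLU discriminator and a hypothetical helper named input_gradient (neither is from the thread):

```python
import torch
import torch.nn as nn


def input_gradient(net, x, create_graph):
    """Hypothetical helper: d(sum(net(x)))/dx, optionally kept in the autograd graph."""
    x = x.detach().requires_grad_(True)
    out = net(x)
    (grad,) = torch.autograd.grad(out.sum(), x, create_graph=create_graph)
    return grad


torch.manual_seed(0)
# Toy Linear-only discriminator (an assumption for illustration).
net_d = nn.Sequential(nn.Linear(2, 16), nn.ReLU(), nn.Linear(16, 1))
x = torch.randn(8, 2)

g_plain = input_gradient(net_d, x, create_graph=False)
g_graph = input_gradient(net_d, x, create_graph=True)
print(g_plain.requires_grad, g_graph.requires_grad)  # False True

# Only the graph-carrying gradient can be turned into a penalty and backpropagated.
penalty = ((g_graph.norm(2, dim=1) - 1) ** 2).mean()
penalty.backward()
print(net_d[0].weight.grad is not None)  # True
```

If the gradient comes back with requires_grad False even though create_graph=True was passed, some layer in the discriminator does not support double backward in that PyTorch build.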

@caogang
Owner

caogang commented May 19, 2017

Can you show your full code? This problem may be caused by the type of network you use, so please paste your code.

@HRLTY
Author

HRLTY commented May 19, 2017

I pasted it below. I am checking the variables' requires_grad attributes. The variable returned by autograd.grad has requires_grad set to False (FYI, accessing the creator of the gradients variable raises an AttributeError). I believe that's the cause of the later backward error. However, why does it return a variable with requires_grad False?

from __future__ import print_function
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.autograd import Variable


parser = argparse.ArgumentParser()
parser.add_argument('--dataset', required=True, help='cifar10 | lsun | imagenet | folder | lfw ')
parser.add_argument('--dataroot', required=True, help='path to dataset')
parser.add_argument('--workers', type=int, help='number of data loading workers', default=2)
parser.add_argument('--batchSize', type=int, default=64, help='input batch size')
parser.add_argument('--imageSize', type=int, default=64, help='the height / width of the input image to network')
parser.add_argument('--nz', type=int, default=100, help='size of the latent z vector')
parser.add_argument('--ngf', type=int, default=64)
parser.add_argument('--ndf', type=int, default=64)
parser.add_argument('--niter', type=int, default=25, help='number of epochs to train for')
parser.add_argument('--lr', type=float, default=0.0001, help='learning rate, default=0.0001')
parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
parser.add_argument('--cuda', action='store_true', help='enables cuda')
parser.add_argument('--ngpu', type=int, default=1, help='number of GPUs to use')
parser.add_argument('--netG', default='', help="path to netG (to continue training)")
parser.add_argument('--netD', default='', help="path to netD (to continue training)")
parser.add_argument('--outf', default='.', help='folder to output images and model checkpoints')
parser.add_argument('--manualSeed', type=int, help='manual seed')
parser.add_argument('--lamda',type=float, default=10, help='wgan-improved param')
opt = parser.parse_args()
print(opt)

try:
    os.makedirs(opt.outf)
except OSError:
    pass

if opt.manualSeed is None:
    opt.manualSeed = random.randint(1, 10000)
print("Random Seed: ", opt.manualSeed)
random.seed(opt.manualSeed)
torch.manual_seed(opt.manualSeed)
if opt.cuda:
    torch.cuda.manual_seed_all(opt.manualSeed)

cudnn.benchmark = True

if torch.cuda.is_available() and not opt.cuda:
    print("WARNING: You have a CUDA device, so you should probably run with --cuda")

if opt.dataset in ['imagenet', 'folder', 'lfw']:
    # folder dataset
    dataset = dset.ImageFolder(root=opt.dataroot,
                               transform=transforms.Compose([
                                   transforms.Scale(opt.imageSize),
                                   transforms.CenterCrop(opt.imageSize),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                               ]))
elif opt.dataset == 'lsun':
    dataset = dset.LSUN(db_path=opt.dataroot, classes=['bedroom_train'],
                        transform=transforms.Compose([
                            transforms.Scale(opt.imageSize),
                            transforms.CenterCrop(opt.imageSize),
                            transforms.ToTensor(),
                            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                        ]))
elif opt.dataset == 'cifar10':
    dataset = dset.CIFAR10(root=opt.dataroot, download=True,
                           transform=transforms.Compose([
                               transforms.Scale(opt.imageSize),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                           ])
    )
assert dataset
dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batchSize,
                                         shuffle=True, num_workers=int(opt.workers))

ngpu = int(opt.ngpu)
nz = int(opt.nz)
ngf = int(opt.ngf)
ndf = int(opt.ndf)
nc = 3


# custom weights initialization called on netG and netD
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

# def calGradientPenalty(real, fake, netD):
#     alpha = torch.rand(real.size(0), 1)
#     alpha = alpha.expand(real.size())
#     if opt.cuda:
#         alpha = alpha.cuda()
#     interpolates = alpha * real + (1-alpha) * fake
#     if opt.cuda:
#         interpolates = interpolates.cuda()
#     interpolates = Variable(interpolates, requires_grad = True)
#     disc_inter = netD(interpolates)
#     grad = torch.ones(disc_inter.size())
#     gradients = torch.autograd.grad(outputs=disc_inter, inputs = interpolates,
#     grad_outputs=grad.cuda() if opt.cuda else grad, create_graph = True,
#     retain_graph = True, only_inputs = True)[0]
#
#     grad_penalty = ((gradients.norm(2, dim=1) - 1) **2).mean() * opt.lamda
#
#     return grad_penalty

def calGradientPenalty(real_data, fake_data, netD):
    alpha = torch.rand(real_data.size(0), 1, 1, 1)  # one interpolation coefficient per sample
    alpha = alpha.expand(real_data.size())
    alpha = alpha.cuda() if opt.cuda else alpha

    interpolates = alpha * real_data + ((1 - alpha) * fake_data)

    if opt.cuda:
        interpolates = interpolates.cuda()
    interpolates = torch.autograd.Variable(interpolates, requires_grad=True)
    print(interpolates.requires_grad)
    disc_interpolates = netD(interpolates)
    print(disc_interpolates.creator)
    gradients = torch.autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                              grad_outputs=torch.ones(disc_interpolates.size()).cuda() if opt.cuda else torch.ones(
                                  disc_interpolates.size()),
                              create_graph=True, retain_graph=True, only_inputs=True)[0]
    #gradients.requires_grad = True
    print(gradients.creator)
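    # flatten each sample's gradient so the norm is taken over C x H x W, not just the channel dim
    gradient_penalty = ((gradients.contiguous().view(gradients.size(0), -1).norm(2, dim=1) - 1) ** 2).mean() * opt.lamda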
    print(gradient_penalty.requires_grad)
    return gradient_penalty

class _netG(nn.Module):
    def __init__(self, ngpu):
        super(_netG, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d(     nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d(ngf * 2,     ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d(    ngf,      nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. (nc) x 64 x 64
        )

    def forward(self, input):
        if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
            output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
        else:
            output = self.main(input)
        return output


netG = _netG(ngpu)
netG.apply(weights_init)
if opt.netG != '':
    netG.load_state_dict(torch.load(opt.netG))
print(netG)


class _netD(nn.Module):
    def __init__(self, ngpu):
        super(_netD, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is (nc) x 64 x 64
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            #nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            #nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            #nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            #nn.Sigmoid()
        )

    def forward(self, input):
        if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
            output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
        else:
            output = self.main(input)

        return output.view(-1, 1)


netD = _netD(ngpu)
netD.apply(weights_init)
if opt.netD != '':
    netD.load_state_dict(torch.load(opt.netD))
print(netD)

criterion = nn.BCELoss()

input = torch.FloatTensor(opt.batchSize, 3, opt.imageSize, opt.imageSize)
noise = torch.FloatTensor(opt.batchSize, nz, 1, 1)
fixed_noise = torch.FloatTensor(opt.batchSize, nz, 1, 1).normal_(0, 1)
label = torch.FloatTensor(opt.batchSize)
real_label = 1
fake_label = 0

if opt.cuda:
    netD.cuda()
    netG.cuda()
    criterion.cuda()
    input, label = input.cuda(), label.cuda()
    noise, fixed_noise = noise.cuda(), fixed_noise.cuda()

input = Variable(input)
label = Variable(label)
noise = Variable(noise)
fixed_noise = Variable(fixed_noise)

# setup optimizer
optimizerD = optim.Adam(netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.9))
optimizerG = optim.Adam(netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.9))

for epoch in range(opt.niter):
    for i, data in enumerate(dataloader, 0):
        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        # train with real
        netD.zero_grad()
        real_cpu, _ = data
        batch_size = real_cpu.size(0)
        input.data.resize_(real_cpu.size()).copy_(real_cpu)
        label.data.resize_(batch_size).fill_(real_label)

        output = netD(input)
        errD_real = -1 * torch.mean(output) #criterion(output, label)
        errD_real.backward()
        D_x = output.data.mean()

        # train with fake
        noise.data.resize_(batch_size, nz, 1, 1)
        noise.data.normal_(0, 1)
        fake = netG(noise)
        label.data.fill_(fake_label)
        output = netD(fake.detach())
        errD_fake = torch.mean(output)#criterion(output, label)
        errD_fake.backward()
        D_G_z1 = output.data.mean()

        penalty = calGradientPenalty(input.data, fake.data, netD)
        penalty.backward()
        errD = errD_real + errD_fake + penalty  # penalty is already scaled by opt.lamda in calGradientPenalty
        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.data.fill_(real_label)  # fake labels are real for generator cost
        output = netD(fake)
        errG = -1 * torch.mean(output) #criterion(output, label)
        errG.backward()
        D_G_z2 = output.data.mean()
        optimizerG.step()

        print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f'
              % (epoch, opt.niter, i, len(dataloader),
                 errD.data[0], errG.data[0], D_x, D_G_z1, D_G_z2))
        if i % 100 == 0:
            vutils.save_image(real_cpu,
                    '%s/real_samples.png' % opt.outf,
                    normalize=True)
            fake = netG(fixed_noise)
            vutils.save_image(fake.data,
                    '%s/fake_samples_epoch_%03d.png' % (opt.outf, epoch),
                    normalize=True)

    # do checkpointing
    torch.save(netG.state_dict(), '%s/netG_epoch_%d.pth' % (opt.outf, epoch))
    torch.save(netD.state_dict(), '%s/netD_epoch_%d.pth' % (opt.outf, epoch))

@caogang
Owner

caogang commented May 19, 2017

Currently, ConvNd doesn't support computing higher-order gradients (this is in progress in PyTorch), so this code won't work. To test the code, you can replace the ConvNd layers in the Discriminator with Linear layers.
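Current PyTorch releases do support double backward through Conv2d, so the penalty can be computed for convolutional discriminators. For 4D image batches, the interpolation coefficient should be drawn once per sample and the input gradient flattened before taking the per-sample norm. A minimal sketch, assuming a hypothetical helper named gradient_penalty and a toy two-layer critic (not the repository's code):

```python
import torch
import torch.nn as nn


def gradient_penalty(net_d, real, fake, lam=10.0):
    """Hypothetical WGAN-GP penalty for image batches of shape (N, C, H, W)."""
    n = real.size(0)
    # One interpolation coefficient per sample, broadcast over C, H and W.
    alpha = torch.rand(n, 1, 1, 1, device=real.device)
    interpolates = (alpha * real + (1 - alpha) * fake).detach().requires_grad_(True)
    d_out = net_d(interpolates)
    # create_graph=True keeps the input gradient differentiable w.r.t. D's weights.
    (grads,) = torch.autograd.grad(
        outputs=d_out,
        inputs=interpolates,
        grad_outputs=torch.ones_like(d_out),
        create_graph=True,
    )
    # Per-sample L2 norm over every non-batch dimension.
    grad_norm = grads.reshape(n, -1).norm(2, dim=1)
    return lam * ((grad_norm - 1) ** 2).mean()


torch.manual_seed(0)
# Toy convolutional critic (assumption): 32x32 -> 16x16 -> 1x1, one score per sample.
net_d = nn.Sequential(
    nn.Conv2d(3, 8, 4, 2, 1),
    nn.LeakyReLU(0.2),
    nn.Conv2d(8, 1, 16, 1, 0),
    nn.Flatten(),
)
real = torch.randn(4, 3, 32, 32)
fake = torch.randn(4, 3, 32, 32)
penalty = gradient_penalty(net_d, real, fake)
penalty.backward()  # double backward through Conv2d
print(penalty.item(), net_d[0].weight.grad is not None)
```

Flattening before the norm matters because calling norm(2, dim=1) directly on a (N, C, H, W) gradient reduces over channels only and leaves an (N, H, W) tensor.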

@HRLTY
Author

HRLTY commented May 19, 2017

OK, thanks. Do you have an estimate of when ConvNd support will land in the master branch?
I was also trying to run gan_toy.py, but backward() fails there too. Do you know why?

Traceback (most recent call last):
  File "gan_toy.py", line 270, in <module>
    gradient_penalty.backward()
  File "/home/shu.zhang/ruihuang/Dev/pytorchEnv/lib/python2.7/site-packages/torch/autograd/variable.py", line 145, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, retain_variables)
  File "/home/shu.zhang/ruihuang/Dev/pytorchEnv/lib/python2.7/site-packages/torch/autograd/__init__.py", line 98, in backward
    variables, grad_variables, retain_graph)
  File "/home/shu.zhang/ruihuang/Dev/pytorchEnv/lib/python2.7/site-packages/torch/autograd/function.py", line 90, in apply
    return self._forward_cls.backward(self, *args)
  File "/home/shu.zhang/ruihuang/Dev/pytorchEnv/lib/python2.7/site-packages/torch/nn/_functions/linear.py", line 23, in backward
    grad_input = torch.mm(grad_output, weight)
  File "/home/shu.zhang/ruihuang/Dev/pytorchEnv/lib/python2.7/site-packages/torch/autograd/variable.py", line 531, in mm
    return self._static_blas(Addmm, (output, 0, 1, self, matrix), False)
  File "/home/shu.zhang/ruihuang/Dev/pytorchEnv/lib/python2.7/site-packages/torch/autograd/variable.py", line 524, in _static_blas
    return cls.apply(*(args[:1] + args[-2:] + (alpha, beta, inplace)))
  File "/home/shu.zhang/ruihuang/Dev/pytorchEnv/lib/python2.7/site-packages/torch/autograd/_functions/blas.py", line 24, in forward
    matrix1, matrix2, out=output)
TypeError: torch.addmm received an invalid combination of arguments - got (int, torch.cuda.ByteTensor, int, torch.cuda.ByteTensor, torch.cuda.FloatTensor, out=torch.cuda.ByteTensor), but expected one of:
 * (torch.cuda.ByteTensor source, torch.cuda.ByteTensor mat1, torch.cuda.ByteTensor mat2, *, torch.cuda.ByteTensor out)
 * (torch.cuda.ByteTensor source, torch.cuda.sparse.ByteTensor mat1, torch.cuda.ByteTensor mat2, *, torch.cuda.ByteTensor out)
 * (int beta, torch.cuda.ByteTensor source, torch.cuda.ByteTensor mat1, torch.cuda.ByteTensor mat2, *, torch.cuda.ByteTensor out)
 * (torch.cuda.ByteTensor source, int alpha, torch.cuda.ByteTensor mat1, torch.cuda.ByteTensor mat2, *, torch.cuda.ByteTensor out)
 * (int beta, torch.cuda.ByteTensor source, torch.cuda.sparse.ByteTensor mat1, torch.cuda.ByteTensor mat2, *, torch.cuda.ByteTensor out)
 * (torch.cuda.ByteTensor source, int alpha, torch.cuda.sparse.ByteTensor mat1, torch.cuda.ByteTensor mat2, *, torch.cuda.ByteTensor out)
 * (int beta, torch.cuda.ByteTensor source, int alpha, torch.cuda.ByteTensor mat1, torch.cuda.ByteTensor mat2, *, torch.cuda.ByteTensor out)
      didn't match because some of the arguments have invalid types: (int, torch.cuda.ByteTensor, int, torch.cuda.ByteTensor, torch.cuda.FloatTensor, out=torch.cuda.ByteTensor)
 * (int beta, torch.cuda.ByteTensor source, int alpha, torch.cuda.sparse.ByteTensor mat1, torch.cuda.ByteTensor mat2, *, torch.cuda.ByteTensor out)
      didn't match because some of the arguments have invalid types: (int, torch.cuda.ByteTensor, int, torch.cuda.ByteTensor, torch.cuda.FloatTensor, out=torch.cuda.ByteTensor)

@caogang
Owner

caogang commented May 19, 2017

Sorry, this is an existing bug. My fix for it has been approved in PyTorch; in the meantime, you can apply the following patch to clear this error:

torch/nn/_functions/thnn/activation.py

         else:
+            mask = input > ctx.threshold
+            grad_input = mask.type_as(grad_output) * grad_output
-            grad_input = grad_output.masked_fill(input > ctx.threshold, 0)
         return grad_input, None, None, None
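The patch computes the threshold gradient by multiplying grad_output by a mask converted to grad_output's type instead of calling masked_fill, so the result is an ordinary differentiable product. A minimal sketch of that idea in current PyTorch, assuming a standalone, hypothetical helper named masked_threshold_backward (not PyTorch's internal function), checked against autograd's own gradient of F.threshold:

```python
import torch
import torch.nn.functional as F


def masked_threshold_backward(grad_output, inp, threshold):
    """Hypothetical helper: gradient of F.threshold via mask multiplication.

    Multiplying by a float mask is differentiable w.r.t. grad_output, so the
    result stays in the graph when grad_output requires grad (double backward).
    """
    mask = inp > threshold                           # where the input passed through
    return mask.to(grad_output.dtype) * grad_output  # current spelling of mask.type_as(grad_output)


torch.manual_seed(0)
x = torch.randn(6, requires_grad=True)
g = torch.randn(6)
y = F.threshold(x, 0.1, 0.0)  # x where x > 0.1, else 0.0
(autograd_grad,) = torch.autograd.grad(y, x, grad_outputs=g)
manual_grad = masked_threshold_backward(g, x.detach(), 0.1)
print(torch.allclose(autograd_grad, manual_grad))  # True

# The masked product can itself be differentiated w.r.t. grad_output.
g_req = torch.randn(6, requires_grad=True)
masked_threshold_backward(g_req, x.detach(), 0.1).sum().backward()
print(torch.equal(g_req.grad, (x.detach() > 0.1).float()))  # True
```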

@caogang
Owner

caogang commented May 19, 2017

This is a bug in PyTorch: pytorch/pytorch#1517

@HRLTY HRLTY closed this as completed May 20, 2017
@clu5

clu5 commented Jun 21, 2017

I think I have the same bug. When I try to run the MNIST GAN, I get this stack trace:

Generator (
  (block_1): Sequential (
    (0): ConvTranspose2d(256, 128, kernel_size=(5, 5), stride=(1, 1))
    (1): ReLU (inplace)
  )
  (block_2): Sequential (
    (0): ConvTranspose2d(128, 64, kernel_size=(5, 5), stride=(1, 1))
    (1): ReLU (inplace)
  )
  (deconv_out): ConvTranspose2d(64, 1, kernel_size=(8, 8), stride=(2, 2))
  (preprocess): Sequential (
    (0): Linear (128 -> 4096)
    (1): ReLU (inplace)
  )
  (sigmoid): Sigmoid ()
)
Discriminator (
  (main): Sequential (
    (0): Linear (784 -> 4096)
    (1): ReLU (inplace)
    (2): Linear (4096 -> 4096)
    (3): ReLU (inplace)
    (4): Linear (4096 -> 4096)
    (5): ReLU (inplace)
    (6): Linear (4096 -> 4096)
    (7): ReLU (inplace)
    (8): Linear (4096 -> 4096)
    (9): ReLU (inplace)
    (10): Linear (4096 -> 1)
  )
)
torch.Size([50, 4096])
torch.Size([50, 256, 4, 4])
torch.Size([50, 128, 8, 8])
torch.Size([50, 128, 7, 7])
torch.Size([50, 64, 11, 11])
torch.Size([50, 1, 28, 28])
torch.Size([50, 1, 28, 28])
torch.Size([50, 4096])
torch.Size([50, 256, 4, 4])
torch.Size([50, 128, 8, 8])
torch.Size([50, 128, 7, 7])
torch.Size([50, 64, 11, 11])
torch.Size([50, 1, 28, 28])
torch.Size([50, 1, 28, 28])
torch.Size([50, 4096])
torch.Size([50, 256, 4, 4])
torch.Size([50, 128, 8, 8])
torch.Size([50, 128, 7, 7])
torch.Size([50, 64, 11, 11])
torch.Size([50, 1, 28, 28])
torch.Size([50, 1, 28, 28])
torch.Size([50, 4096])
torch.Size([50, 256, 4, 4])
torch.Size([50, 128, 8, 8])
torch.Size([50, 128, 7, 7])
torch.Size([50, 64, 11, 11])
torch.Size([50, 1, 28, 28])
torch.Size([50, 1, 28, 28])
torch.Size([50, 4096])
torch.Size([50, 256, 4, 4])
torch.Size([50, 128, 8, 8])
torch.Size([50, 128, 7, 7])
torch.Size([50, 64, 11, 11])
torch.Size([50, 1, 28, 28])
torch.Size([50, 1, 28, 28])
torch.Size([50, 4096])
torch.Size([50, 256, 4, 4])
torch.Size([50, 128, 8, 8])
torch.Size([50, 128, 7, 7])
torch.Size([50, 64, 11, 11])
torch.Size([50, 1, 28, 28])
torch.Size([50, 1, 28, 28])
2.076378107070923
Discriminator cost: [ nan]
Generator cost: [ nan]
Wasserstein distance: [ nan]
Traceback (most recent call last):
  File "gan_mnist.py", line 145, in <module>
    D_real.backward(m_one)
  File "/home/clu/anaconda3/lib/python3.5/site-packages/torch/autograd/variable.py", line 152, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, retain_variables)
  File "/home/clu/anaconda3/lib/python3.5/site-packages/torch/autograd/__init__.py", line 98, in backward
    variables, grad_variables, retain_graph)
RuntimeError: there are no graph nodes that require computing gradients
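The trace does not show the root cause, but this error means that nothing in D_real's graph requires gradients. One common way to reach it in WGAN training loops (an assumption here, not a diagnosis of this script) is freezing the discriminator's parameters for the generator step and not unfreezing them before the next discriminator step. A minimal sketch in current PyTorch, assuming a toy Linear discriminator and a hypothetical helper named set_requires_grad:

```python
import torch
import torch.nn as nn


def set_requires_grad(module, flag):
    """Hypothetical helper: freeze or unfreeze all parameters of a module."""
    for p in module.parameters():
        p.requires_grad_(flag)


torch.manual_seed(0)
net_d = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
real = torch.randn(3, 4)    # data tensors do not require grad
m_one = torch.tensor(-1.0)  # gradient seed, as in D_real.backward(m_one)

# Frozen D (e.g. left over from a generator step): nothing in the graph needs grad.
set_requires_grad(net_d, False)
d_real = net_d(real).mean()
print(d_real.requires_grad)  # False: d_real.backward(m_one) would raise RuntimeError

# Unfreeze D before its own update; now backward succeeds.
set_requires_grad(net_d, True)
d_real = net_d(real).mean()
d_real.backward(m_one)
print(net_d[0].weight.grad is not None)  # True
```

In current PyTorch the corresponding message reads "element 0 of tensors does not require grad and does not have a grad_fn" rather than the older wording above.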


3 participants