
ZeroDivisionError in backward #26

Closed
furybubu opened this issue Jul 6, 2018 · 6 comments

Comments

@furybubu

furybubu commented Jul 6, 2018

Hi, I am getting an error when I add the amp procedure to a working CNN like this:

self.optimizer.zero_grad()
outputs = self.model(maps)
loss = self.criterion(outputs, labels.float())

# add automatic mixed precision support from apex
with self.amp_handle.scale_loss(loss, self.optimizer) as scaled_loss:
    scaled_loss.backward()

self.optimizer.step()
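For completeness, self.amp_handle above is created earlier with the apex amp API, roughly like this (just a sketch of the setup line; my real code wraps it in a class):

from apex import amp

# Created once before training; the training step then wraps backward()
# inside amp_handle.scale_loss(...) as shown above.
amp_handle = amp.init()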
And here is the error I get:

    scaled_loss.backward()
  File "/usr/lib/python3.5/contextlib.py", line 66, in __exit__
    next(self.gen)
  File "/usr/local/lib/python3.5/dist-packages/apex-0.1-py3.5-linux-x86_64.egg/apex/amp/handle.py", line 53, in scale_loss
    optimizer.param_groups, loss_scale)
  File "/usr/local/lib/python3.5/dist-packages/apex-0.1-py3.5-linux-x86_64.egg/apex/amp/scaler.py", line 21, in unscale_and_update
    1. / scale,
ZeroDivisionError: float division by zero

Any suggestion would be appreciated.

@mcarilli
Contributor

mcarilli commented Jul 6, 2018

@carlc-nv is the primary developer of Amp; I've let him know.

@cbcase
Contributor

cbcase commented Jul 6, 2018

Hi @furybubu, thanks for reporting this issue.

A couple questions:

  • When you say "working CNN," you mean that the model trains acceptably in fp32?
  • Could you share a little more about the details of the model, optimizer, and dataset?
  • If you are OK sharing more about the model, could you do the following:
    • add verbose=True to the amp init call (i.e., amp_handle = amp.init(verbose=True, ...))
    • Run one iteration of the model
    • Share the output

That last step will log exactly what casts amp is inserting into the model.
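For concreteness, the change in the first bullet is just the init call below (a sketch assuming default settings; keep whatever other arguments you already pass):

from apex import amp

# verbose=True makes amp log every cast it inserts (Float->Half and
# Half->Float), so we can see exactly which ops run in fp16.
amp_handle = amp.init(verbose=True)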

The specific issue you are observing is that the "fp16 loss scale" is becoming increasingly small until it becomes zero. This suggests to me there is a different fp16-related issue, since the loss scale decreases only when there is an inf or a NaN in the gradient -- and that should not happen for many iterations in a row (which it would have to for the loss scale to get all the way to zero).
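To illustrate what I mean by the loss scale shrinking: a dynamic loss scaler behaves roughly like the simplified sketch below (an illustration of the idea, not apex's actual scaler code). The scale only moves down on an overflow iteration, so getting anywhere near zero requires inf/NaN gradients over and over.

import math

def update_loss_scale(grads, scale, backoff=2.0):
    """Simplified dynamic loss scaling: shrink the scale only when the
    gradients contain an inf or NaN."""
    if any(math.isinf(g) or math.isnan(g) for g in grads):
        scale /= backoff  # only overflow iterations shrink the scale
    return scale

# Starting from 2**15, it takes many consecutive overflow iterations to
# drive the scale toward zero; once it gets there, unscaling gradients
# with 1. / scale is exactly the division that blows up.
scale = 2.0 ** 15
for step in range(12):
    scale = update_loss_scale([float("inf")], scale)
    print(step, scale)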

@furybubu
Author

furybubu commented Jul 6, 2018

Hi @cbcase,
Yes, by "working CNN" I meant a CNN that trains in fp32 rather than mixed precision. I cannot share much about my model, but I will give you as much as I can:
My model has 5 conv layers interspersed with maxpool layers and a couple of fully connected layers at the end. I use the Adam optimizer, nothing too fancy.
The dataset is pretty large, about a million volumetric samples. I use a batch size of 20 split across 2 GPUs (10 each) with DataParallel. My model trains beautifully when I do not enable mixed precision training.

I will try to rerun it with the verbose flag to see if I get more clues in the output.
Thanks!

@furybubu
Author

furybubu commented Jul 6, 2018

So I basically get things like this:
Float->Half (conv3d) Float->Half (conv3d) Float->Half (conv3d) Float->Half (conv3d) Float->Half (conv3d) Float->Half (conv3d) Float->Half (conv3d) Float->Half (conv3d) Float->Half (conv3d) Float->Half (conv3d) Float->Half (conv3d) Float->Half (conv3d) Float->Half (conv3d) Float->Half (conv3d) Float->Half (conv3d) Float->Half (conv3d) Float->Half (conv3d) Float->Half (conv3d) Float->Half (conv3d) Float->Half (linear) Float->Half (conv3d) Float->Half (linear) Float->Half (conv3d) Float->Half (conv3d) Float->Half (linear) Float->Half (linear) Float->Half (linear) Float->Half (linear) Float->Half (linear) Float->Half (linear) Float->Half (linear) Float->Half (linear) Float->Half (linear) Float->Half (linear) Half->Float (mse_loss)
And then the same error.

@cbcase
Contributor

cbcase commented Jul 11, 2018

Hi @furybubu,

I'm looking at adding better debugging support for when there are mixed precision issues. If you're interested in being a guinea pig, I've pushed work-in-progress changes to this branch: https://github.com/NVIDIA/apex/tree/amp_debug. You can check it out and install it in the usual way.

Right now, there's just one function handle.run_debug(model, loss_fn) that will print out a "debug report" of sorts. The input arguments are:

  • model: your PyTorch model Module
  • loss_fn: a function that, when invoked, will return the loss on a fixed input / output pair

Here's what that looks like in practice:

Sample original code:

data, target = load_data()  # however you load data
output = model(data)
loss = criterion(output, target)
...

To run debug:

data, target = load_data()
def loss_fn():
    output = model(data)
    return criterion(output, target)
handle.run_debug(model, loss_fn)

The debug script will do three things:

  1. Run forward / backward in mixed precision (without any loss scale) and print out any observed inf / nan values and in which module they occur. I believe this can help us diagnose the issue you are seeing.
  2. Print the gradient norm and absolute max value for each model parameter. This is probably not so useful in your case, though it may make it easier to interpret where the overflow values occur.
  3. Find the largest possible loss scale without overflow and compare the gradients computed in fp32 and with mixed precision. This can help identify bugs in mixed precision code (i.e., apex).

Let us know if you're able to try this out and what you learn! In particular, I would be interested to hear:

  • What modules / parameter names do you see overflows on the first iteration?
  • Same thing, but take a model that has been trained in fp32 for a bit (so the parameters are no longer at their initial values)

@yuribd

yuribd commented Jan 28, 2019

Hi all!

I wonder if other people have reported a similar issue, and what the solution was.
I observe the same issue in my case (see below). At the same time, using FP16_Optimizer with dynamic_loss_scale=True works just fine.

     36             if p.grad is not None:
     37                 self._has_overflow = scale_check_overflow(p.grad.data,
---> 38                                                           1. / scale)
     39             if self._has_overflow:
     40                 break

ZeroDivisionError: float division by zero

That's using the approach suggested here.

With that approach, the loss scale drops gradually from 2^15 down to 8 and then it breaks:

Overflowed with loss scale 16384.0.  Reducing loss scale and replaying
Overflowed with loss scale 8192.0.  Reducing loss scale and replaying
Overflowed with loss scale 4096.0.  Reducing loss scale and replaying
Overflowed with loss scale 2048.0.  Reducing loss scale and replaying
Overflowed with loss scale 1024.0.  Reducing loss scale and replaying
Overflowed with loss scale 512.0.  Reducing loss scale and replaying
Overflowed with loss scale 256.0.  Reducing loss scale and replaying
Overflowed with loss scale 128.0.  Reducing loss scale and replaying
Overflowed with loss scale 64.0.  Reducing loss scale and replaying
Overflowed with loss scale 32.0.  Reducing loss scale and replaying
Overflowed with loss scale 16.0.  Reducing loss scale and replaying
Overflowed with loss scale 8.0.  Reducing loss scale and replaying
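For reference, the FP16_Optimizer path I mentioned above is wired up roughly like this (a minimal sketch with a toy model, not my actual code; import path per the older apex.fp16_utils module):

import torch
import torch.nn as nn
from apex.fp16_utils import FP16_Optimizer  # older apex utility

# Toy stand-in model just to show the wiring (not my real network).
model = nn.Linear(16, 1).cuda().half()
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)

data = torch.randn(8, 16).cuda().half()
target = torch.randn(8, 1).cuda().half()

optimizer.zero_grad()
loss = criterion(model(data), target)
optimizer.backward(loss)  # replaces loss.backward(); applies the dynamic loss scale
optimizer.step()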
