
Accelerated training with floating point fp16 #7

Closed
milliema opened this issue Jan 10, 2021 · 17 comments
Labels: question (Further information is requested), stale

Comments

@milliema

Thanks for the work!
I'd like to know whether the original code is also applicable to accelerated training, i.e. using automatic mixed precision such as fp16. I tried to adopt SAM in my own training code with apex fp16, but the loss becomes NaN and the computed grad norm is NaN. When I switch to fp32, training runs fine. Is SAM incompatible with fp16? What are your suggestions for making the code work with fp16? Thanks!

@rohitsingh02

The same thing is happening to me; I'm unable to use it with automatic mixed precision (PyTorch).

@alexriedel1

Hey, I just implemented AMP in this way and it seems to be working:

# first forward-backward pass
with torch.cuda.amp.autocast():
    preds_first = model(images)
    loss = criterion(preds_first, labels)  # use this loss for any training statistics

loss.mean().backward()
optimizer.first_step(zero_grad=True)

# second forward-backward pass
with torch.cuda.amp.autocast():
    preds_second = model(images)
    loss_second = criterion(preds_second, labels)

loss_second.mean().backward()
optimizer.second_step(zero_grad=True)

@milliema
Author

milliema commented Jan 15, 2021

(quoting alexriedel1's AMP snippet above)

Thanks for the reply!
I'm using apex fp16, which is a little different from torch.cuda.amp.
In my case, I only included the base optimizer in the apex initialization.

model, optimizer.base_optimizer = amp.initialize(model, optimizer.base_optimizer, opt_level="O1")

As for the backward passes, I keep the 1st step the same as before and only use the scaled loss for the 2nd backward.

# first forward-backward pass with the plain loss
loss = cal_loss(xx)
loss.backward()
optimizer.first_step(zero_grad=True)

# second forward-backward pass with the apex-scaled loss
loss = cal_loss(xx)
with amp.scale_loss(loss, optimizer.base_optimizer) as scaled_loss:
    scaled_loss.backward()
optimizer.second_step(zero_grad=True)

It works, but I'm not sure whether it's the best solution. If I use the scaled loss for the 1st backward, the loss always becomes NaN.
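
Putting the two snippets together, the whole loop looks roughly like this (model, criterion and train_loader are placeholder names from my own code):

from apex import amp

model, optimizer.base_optimizer = amp.initialize(model, optimizer.base_optimizer, opt_level="O1")

for images, labels in train_loader:
    # first forward-backward pass with the plain, unscaled loss
    loss = criterion(model(images), labels)
    loss.backward()
    optimizer.first_step(zero_grad=True)

    # second forward-backward pass with the apex-scaled loss on the base optimizer
    loss = criterion(model(images), labels)
    with amp.scale_loss(loss, optimizer.base_optimizer) as scaled_loss:
        scaled_loss.backward()
    optimizer.second_step(zero_grad=True)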

@alexriedel1

Is it also possible to initialize the full optimizer?
Are your model's forward outputs fp16 now?

@milliema
Author

@alexriedel1 I've tried to initialize amp with the full SAM optimizer, but it doesn't work.
Amp is supposed to patch the step function of the optimizer. However, the SAM optimizer doesn't use step but first_step/second_step instead, so I guess it's better to initialize the base optimizer.
I didn't check the forward outputs. The training speed is almost half of regular training, e.g. 1000 img/s for regular training vs. 500 img/s for SAM. Is it the same for you?

@alexriedel1

@milliema Yes, that's entirely expected: SAM needs two forward-backward passes through the network instead of the single pass of plain SGD, so it should take roughly double the time to train.

@davda54 added the question label Jan 15, 2021
@stale

stale bot commented Feb 5, 2021

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

stale bot added the stale label Feb 5, 2021
stale bot closed this as completed Feb 12, 2021
@jeongHwarr

@alexriedel1 Hello. I found your comment while looking for a way to apply AMP to the SAM optimizer. In the standard AMP recipe, as I understand it, the loss is scaled before backward and the gradients are later divided by the scale again, e.g. scaler.scale(loss).backward().

Why did you do the backward using the mean of the loss? Is there any problem with this? Is it okay not to use the scaler?

@davda54 reopened this Dec 6, 2021
stale bot removed the stale label Dec 6, 2021
@alexriedel1

(quoting jeongHwarr's question above)

I didn't fully implement the AMP method as proposed. I think using the scaler will be no problem.

Whether you reduce the loss with .mean() just depends on your loss function. For example, PyTorch's BCE loss already uses mean reduction by default.
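
For instance (just a quick sketch with placeholder tensors):

import torch

criterion = torch.nn.BCELoss()                 # reduction='mean' by default
preds = torch.rand(8, 1, requires_grad=True)   # placeholder predictions in [0, 1]
targets = torch.randint(0, 2, (8, 1)).float()  # placeholder binary targets

loss = criterion(preds, targets)  # already a scalar (mean over the batch)
loss.mean().backward()            # .mean() of a scalar is a no-op, same as loss.backward()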

@jeongHwarr

@alexriedel1 Ok, I got it. I think the problem when applying AMP to SAM is that I cannot use scaler.step(optimizer) together with optimizer.first_step(zero_grad=True). The gradients need to be unscaled when using AMP.
That unscaling is normally done inside scaler.step(optimizer), but if I call that, I can no longer call optimizer.first_step(zero_grad=True).

@maxmatical

maxmatical commented Dec 24, 2021

Yes, the original solution does not unscale the gradients, which would lead to the scaling factor interfering with the learning rate.

If you take a look at scaler.step(optimizer), under the hood it is doing:

  1. calling unscale_(optimizer)
  2. calling optimizer.step()

In theory you should be able to run something similar to the example here by doing the following during training:

# first pass
with torch.cuda.amp.autocast():
    out = model(input)
    loss = criterion(out, labels)

scaler.scale(loss).backward()
scaler.unscale_(optimizer)
optimizer.first_step(zero_grad=True)
scaler.update()

# 2nd pass
with torch.cuda.amp.autocast():
    out_2 = model(input)
    loss_2 = criterion(out_2, labels)

scaler.scale(loss_2).backward()
scaler.unscale_(optimizer)
optimizer.second_step(zero_grad=True)
scaler.update()

However, since you're not calling scaler.step(optimizer), you run the risk of stepping with inf/NaN gradients (unless the SAM steps already take care of this).

@stale

stale bot commented Jan 15, 2022

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

@twmht

twmht commented Jul 18, 2022

@maxmatical

Is model.no_sync() important in the first backward pass (for DDP)? It seems that you don't have that.
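
To be concrete, this is roughly what I mean (just a sketch reusing the names from your snippet, with model wrapped in DistributedDataParallel; I'm not claiming this is the right thing to do):

# first pass: skip the gradient all-reduce, since these gradients are only
# used locally to compute SAM's ascent step
with model.no_sync():
    with torch.cuda.amp.autocast():
        out = model(input)
        loss = criterion(out, labels)
    scaler.scale(loss).backward()

scaler.unscale_(optimizer)
optimizer.first_step(zero_grad=True)
scaler.update()

# second pass stays as in your example, with gradients synchronized as usual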

@ahmdtaha

ahmdtaha commented Aug 5, 2022

Here are my two cents on this issue.

TL;DR: use the following code and be ready to revert to regular single-step optimization momentarily.
I made the following changes to SAM inside sam.py:

    # (assuming `from torch.cuda.amp import autocast` at the top of sam.py;
    # `do_nothing_context_mgr` is a no-op context manager, defined below)
    @torch.no_grad()
    def first_step(self, zero_grad=False, mixed_precision=False):
        with autocast() if mixed_precision else do_nothing_context_mgr():
            grad_norm = self._grad_norm()
            for group in self.param_groups:
                scale = group["rho"] / (grad_norm + 1e-12)

                for p in group["params"]:
                    if p.grad is None:
                        continue
                    self.state[p]["old_p"] = p.data.clone()
                    e_w = (torch.pow(p, 2) if group["adaptive"] else 1.0) * p.grad * scale.to(p)
                    p.add_(e_w)  # climb to the local maximum "w + e(w)"

            if zero_grad:
                self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad=False, mixed_precision=False):
        with autocast() if mixed_precision else do_nothing_context_mgr():
            for group in self.param_groups:
                for p in group["params"]:
                    if p.grad is None:
                        continue
                    p.data = self.state[p]["old_p"]  # get back to "w" from "w + e(w)"

    @torch.no_grad()
    def step(self, closure=None):
        self.base_optimizer.step(closure)
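
where do_nothing_context_mgr is just a no-op context manager, e.g. (one possible definition; I didn't include it in the snippet above):

import contextlib

def do_nothing_context_mgr():
    # no-op stand-in for autocast() when mixed precision is disabled
    return contextlib.nullcontext()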

Using this PyTorch tutorial as a starting point, the proposed solution goes as follows:

def train(
    args, model, device, train_loader, optimizer, first_step_scaler, second_step_scaler, epoch
):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()

        enable_running_stats(model)
        # First forward step
        with autocast():
            output = model(data)
            loss = F.nll_loss(output, target)
        first_step_scaler.scale(loss).backward()

        # We unscale manually for two reasons: (1) SAM's first-step adds the gradient
        # to weights directly. So gradient must be unscaled; (2) unscale_ checks if any
        # gradient is inf and updates optimizer_state["found_inf_per_device"] accordingly.
        # We use optimizer_state["found_inf_per_device"] to decide whether to apply
        # SAM's first-step or not.
        first_step_scaler.unscale_(optimizer)

        optimizer_state = first_step_scaler._per_optimizer_states[id(optimizer)]

        # Check if any gradients are inf/nan
        inf_grad_cnt = sum(v.item() for v in optimizer_state["found_inf_per_device"].values())

        if inf_grad_cnt == 0:
            # if the gradients are valid, apply SAM's first step
            optimizer.first_step(zero_grad=True, mixed_precision=True)
            sam_first_step_applied = True
        else:
            # if the gradients are invalid (inf/NaN), skip SAM and revert to a single optimization step
            optimizer.zero_grad()
            sam_first_step_applied = False

        # Update the scaler with no impact on the model (weights or gradient). This update step
        # resets the optimizer_state["found_inf_per_device"]. So, it is applied after computing
        # inf_grad_cnt. Note that zero_grad() has no impact on the update() operation,
        # because update() leverages optimizer_state["found_inf_per_device"]
        first_step_scaler.update()

        disable_running_stats(model)
        # Second forward step
        with autocast():
            output = model(data)
            loss = F.nll_loss(output, target)
        second_step_scaler.scale(loss).backward()

        if sam_first_step_applied:
            # If sam_first_step was applied, apply the 2nd step
            optimizer.second_step(mixed_precision=True)

        second_step_scaler.step(optimizer)
        second_step_scaler.update()

where

base_optimizer = torch.optim.SGD  # define an optimizer for the "sharpness-aware" update
optimizer = SAM(model.parameters(), base_optimizer, lr=0.1, momentum=0.9)
first_step_scaler = GradScaler(2 ** 8)   # a small initial scale acts as a warm-up
second_step_scaler = GradScaler(2 ** 8)  # a small initial scale acts as a warm-up

How is this tested?

(1) A lot of debugging to make sure the code does what it is supposed to do; (2) training my model twice, once in full precision and once in mixed precision, then verifying that both loss curves are similar -- of course, not identical.

What is the main catch?

I found that the network can produce NaN predictions during inference while not crashing during training (forward and backward). That is, while the network $f_\theta$ has finite parameters $\theta$, it produces NaN for some -- not all -- inputs. When a network reaches this state (shown in the next figure), SAM's first step (gradient ascent) always generates NaN/inf gradients, which signals instability. Then, of course, SAM's second step also generates NaN/inf gradients. This instability is never observed explicitly during training because PyTorch's GradScaler skips the gradient-descent step whenever the gradients contain NaN. Accordingly, the network's parameters $\theta$ remain intact despite multiple backpropagation steps.

[Figure: Mixed-Precision-SAM -- illustration of the unstable state]

To get out of this unstable state, the proposed solution momentarily reverts to regular single-step stochastic gradient descent. A gradient-descent step is more likely to have valid -- non-NaN -- gradients compared to a gradient-ascent step. This pushes the network's parameters out of the unstable state. It is worth noting that loss surfaces are high-dimensional, i.e., my 2D drawing is for illustration purposes only.

Why does the network enter this state in the first place?

I don't know. Yet the network enters this state at an early training stage, which suggests poor initialization.

One thing I don't like about the proposed solution is that it is verbose. I wish someone would propose a more concise solution.

I found the following resources helpful while investigating this issue.
[1] https://github.com/pytorch/pytorch/blob/master/torch/cuda/amp/grad_scaler.py
[2] https://pytorch.org/docs/stable/notes/amp_examples.html

@mengjingyouling

(quoting alexriedel1's AMP snippet above)

Can it work well?

@alibalapour

alibalapour commented Oct 22, 2022

@ahmdtaha
Do you have any idea about how to implement gradient accumulation in your code?

@rtxbae

rtxbae commented Dec 1, 2022

@alibalapour @ahmdtaha Have you found out how to implement gradient accumulation in the code?
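
In case it helps the discussion, this untested sketch is roughly what I have in mind, using the modified sam.py from above but ignoring the GradScaler/NaN-skipping logic for clarity (micro_batches and accum_steps are made-up names):

micro_batches = []
for step, (data, target) in enumerate(train_loader):
    data, target = data.to(device), target.to(device)
    micro_batches.append((data, target))

    # first pass: accumulate gradients over several micro-batches for the ascent step
    loss = F.nll_loss(model(data), target) / accum_steps
    loss.backward()

    if (step + 1) % accum_steps == 0:
        optimizer.first_step(zero_grad=True)  # perturb weights using the accumulated gradients

        # second pass: replay the same micro-batches at the perturbed weights
        for data, target in micro_batches:
            loss = F.nll_loss(model(data), target) / accum_steps
            loss.backward()

        optimizer.second_step()  # restore the original weights
        optimizer.step()         # with the modified sam.py above, this calls base_optimizer.step()
        optimizer.zero_grad()
        micro_batches = []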
