Conversation
|
The documentation is not available anymore as the PR was closed or merged. |
vwxyzjn
reviewed
Jun 6, 2023
Contributor
There was a problem hiding this comment.
The PR makes sense! Would it be possible to verify that it in fact can reproduce the exact same gradient with different number of gradient accumulation steps?
Here is a snippet:
import torch
from torch.utils.data import TensorDataset, DataLoader
import copy
# seed
torch.manual_seed(0)
# define toy inputs and labels
x = torch.tensor([1., 2., 3., 4.])
y = torch.tensor([2., 4., 6., 8.])
# define dataset and dataloader
dataset = TensorDataset(x, y)
dataloader = DataLoader(dataset, batch_size=2)
# define model, optimizer and loss function
model = torch.nn.Linear(1, 1)
# clone the model
model_clone = copy.deepcopy(model)
criterion = torch.nn.MSELoss()
accumulation_steps = 2
# loop over batches
for i, (inputs, labels) in enumerate(dataloader):
# reshape inputs and labels
inputs = inputs.view(-1, 1)
labels = labels.view(-1, 1)
# forward pass
outputs = model(inputs)
loss = criterion(outputs, labels) / accumulation_steps
# backward pass
loss.backward()
# check if accumulation is done
if (i + 1) % accumulation_steps == 0:
print("w/ accumulation, the final model grad is", model.weight.grad)
break
loss = criterion(model_clone(x.view(-1, 1)), y.view(-1, 1))
loss.backward()
print("w/o accumulation, the final model grad is", model_clone.weight.grad)w/ accumulation, the final model grad is tensor([[-27.4301]])
w/o accumulation, the final model grad is tensor([[-27.4301]])
younesbelkada
commented
Jun 6, 2023
Contributor
Author
younesbelkada
left a comment
There was a problem hiding this comment.
Yes definitelty! Will add a test for that, thanks a lot for the snippet and the pointer!
vwxyzjn
approved these changes
Jun 14, 2023
Contributor
vwxyzjn
left a comment
There was a problem hiding this comment.
LGTM! Thanks so much @younesbelkada
yxliu-TAMU
pushed a commit
to mincheolseong/ECEN743-GRPO-Project-Proposal
that referenced
this pull request
Apr 20, 2025
* add correct grad acc * add some tests but they fail * test should pass * style * fix
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
What does this PR do?
This PR correctly fixes the gradient accumulation, firstly introduced in #220
Fixes #321
cc @lvwerra