add gradient accumulation support #2049
Conversation
Would it make sense to:
(relevant to #2003, which won't fit on smaller GPUs)
I think it's better to let the user decide the batch size and the corresponding learning rate.
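For context, one common heuristic for coupling these two settings (not something this PR prescribes) is to scale the learning rate linearly with the ray batch size; the numbers below are purely illustrative:

```python
# Linear LR scaling heuristic; baseline values are assumptions, not repo defaults.
base_lr = 1e-2          # assumed reference learning rate
base_num_rays = 4096    # assumed reference ray batch size

def scaled_lr(num_rays_per_batch: int) -> float:
    """Scale the learning rate linearly with the chosen ray batch size."""
    return base_lr * num_rays_per_batch / base_num_rays

print(scaled_lr(8192))  # -> 0.02
```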
nerfstudio/engine/trainer.py
Outdated
loss = functools.reduce(torch.add, loss_dict.values())
self.grad_scaler.scale(loss).backward()  # type: ignore
internal_step = 0
while True:
This would be clearer with a for loop, since internal_step is unnecessary except for counting iterations
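For reference, a minimal sketch of how the accumulation step could read with a for loop; the function and parameter names (`accumulation_steps`, `pipeline`, etc.) are illustrative assumptions, not necessarily the merged code:

```python
import functools

import torch

def train_iteration(pipeline, optimizer, grad_scaler, step, accumulation_steps=2):
    """Illustrative only: one outer train step made of `accumulation_steps` sub-steps."""
    optimizer.zero_grad()
    for _ in range(accumulation_steps):  # replaces `internal_step` / `while True`
        _, loss_dict, _ = pipeline.get_train_loss_dict(step=step)
        loss = functools.reduce(torch.add, loss_dict.values())
        # Scale so the accumulated gradient matches one full-size batch
        # (assumption; the PR may keep the raw sum instead).
        grad_scaler.scale(loss / accumulation_steps).backward()
    grad_scaler.step(optimizer)
    grad_scaler.update()
```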
I think memory budgeting is a larger PR in itself, since going from batch size to memory consumption is a pretty nontrivial thing. I'd be in favor of merging in grad accumulation and thinking about how to automatically set these parameters later.
Here are two experiments which use a ray batch size of 8192 and accumulate two 4096-ray steps.
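As a sanity check on why the two settings should be comparable, here is a toy example (a stand-in linear model, not the nerfstudio pipeline) showing that two half-scaled 4096-ray backward passes reproduce the gradient of a single 8192-ray pass:

```python
import torch

torch.manual_seed(0)
w = torch.randn(4, dtype=torch.float64, requires_grad=True)
rays = torch.randn(8192, 4, dtype=torch.float64)  # stand-in for an 8192-ray batch

def loss_fn(batch):
    return (batch @ w).pow(2).mean()

# One full 8192-ray step.
loss_fn(rays).backward()
full_grad = w.grad.clone()

# Two accumulated 4096-ray sub-steps, each scaled by 1/2 to match the mean.
w.grad = None
for chunk in rays.split(4096):
    (loss_fn(chunk) / 2).backward()

print(torch.allclose(full_grad, w.grad))  # -> True
```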
Looks good, thanks!
* add gradient accumulation support
* fix 2 blank lines
* fix possibly unbound variable
* update gradient accumulation step with for loop
* add assert and ignore pyright check

---------

Co-authored-by: Zhang Jian <zhangjian49@lenovo.com>
This modification does not hurt the current repo. The number of train steps is the same as before, so the optimizer and scheduler work the old way. Users can easily increase the batch size since the NeRF network has no batch-norm layers.
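A rough sketch of why the step count is unchanged, assuming a generic PyTorch optimizer and scheduler (illustrative, not the nerfstudio trainer): accumulation happens inside one outer iteration, so `optimizer.step()` and `scheduler.step()` are still called once per iteration.

```python
import torch

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

accumulation_steps = 2   # assumed setting
max_iterations = 30      # outer step count is unchanged by accumulation

for step in range(max_iterations):
    optimizer.zero_grad()
    for _ in range(accumulation_steps):
        rays = torch.randn(4096, 4)                      # 4096-ray sub-batch
        loss = model(rays).pow(2).mean() / accumulation_steps
        loss.backward()                                  # gradients accumulate
    optimizer.step()                                     # once per outer step
    scheduler.step()                                     # LR schedule sees the same step count

print(scheduler.last_epoch)  # -> 30, same as a run without accumulation
```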