implement_gsam_jax #4

Closed

juntang-zhuang
Contributor

Implement the GSAM algorithm proposed in "Surrogate Gap Minimization Improves Sharpness-Aware Training" (ICLR 2022), which is an improvement over SAM (Sharpness-Aware Minimization).

When config.rho_max == config.rho_min and config.alpha == 0.0, the GSAM algorithm reduces to SAM.

@akolesnikoff
Collaborator

Hi,

Thank you for your contribution. As stated in the readme, we normally do not accept external contributions, but we are happy to make an exception for open-source implementations of published projects developed in big_vision.

However, according to the codebase principles, project-specific code should not add complexity to the core library parts, such as the main train loop. Thus, standalone projects are expected to fork the main train loop into big_vision/trainers/proj/<project name>/... and apply the necessary modifications there. We plan to submit an example of how this works soon (~2 weeks from now). Could you wait for the example and then update this pull request accordingly?

@juntang-zhuang
Contributor Author

Thanks a lot for the clarification! I will re-format and re-submit later according to the examples.

@lucasb-eyer
Collaborator

hey, we now have an example of a project-specific trainer here: https://github.com/google-research/big_vision/tree/main/big_vision/trainers/proj/distill

If you are still interested in submitting gsam (we would like it!), could you sync to head and instead of modifying the core train.py, fork it into trainers/proj/gsam/train.py and do the modifications there?

Sorry for the delay on our side!

@juntang-zhuang
Contributor Author

Thanks a lot for the example! I have moved all changes to big_vision/trainers/proj/gsam, please let me know if it looks good.

Collaborator

@lucasb-eyer left a comment

Thanks for your patience!

Would it be possible to add an example config? Ideally, one which produces some reference run from the paper. It would live in configs/proj/gsam/whatever.py? You would probably fork it off https://github.com/google-research/big_vision/blob/main/big_vision/configs/vit_i1k.py.

Also, in an ideal world, you would actually run this config, show that it matches a number in the paper, and link the result here or at the top of the config. Is that still possible, or can't you do that anymore?

# limitations under the License.

"""Training loop example.
This is a basic variant of a training loop, good starting point for fancy ones.
Collaborator

You should probably update this to something like "Trainer that implements SAM/GSAM optimizers"?


if config.get("GSAM", False):
  # Get the current learning rate.
  learning_rate = sched_fns_cpu[0](step)
Collaborator

I highly doubt this is what you want. Note that you're calling a function that's been jit'ed onto the CPU from within a function that's pmap'ed onto GPU/TPU, so we get a transfer at every single step here.

Why not call sched_fn[0](step) instead?
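
As a rough illustration of the suggestion above, the sketch below evaluates a schedule inside the compiled update so it is traced together with the rest of the step. It is only a sketch under assumptions: `make_linear_schedule` and `update_fn` are hypothetical stand-ins (not big_vision code), `sched_fns` is assumed to be the un-jitted counterpart of `sched_fns_cpu`, and `jax.jit` is used in place of `pmap` to keep the example small.

```python
import jax
import jax.numpy as jnp


def make_linear_schedule(base_lr, total_steps):
  """Hypothetical schedule built from jnp ops only, so it can be traced."""
  def sched_fn(step):
    progress = jnp.clip(step / total_steps, 0.0, 1.0)
    return base_lr * (1.0 - progress)
  return sched_fn


# Assumed name: a list of plain (not CPU-jitted) schedule functions.
sched_fns = [make_linear_schedule(base_lr=1e-3, total_steps=100)]


@jax.jit
def update_fn(params, grads, step):
  # The schedule is traced as part of this function, so the learning rate
  # is computed on the same device as the parameter update.
  learning_rate = sched_fns[0](step)
  new_params = jax.tree_util.tree_map(
      lambda p, g: p - learning_rate * g, params, grads)
  return new_params, learning_rate
```

Because the schedule is a pure function of `step`, nothing about it requires a separate CPU-compiled copy inside the update step.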

big_vision/trainers/proj/gsam/train.py
return getattr(u, config.get("loss", "sigmoid_xent"))(
logits=logits, labels=labels)

if config.get("GSAM", False):
Collaborator

Since here we're specifically in the gsam/train.py, we can remove this config variable and if statement, and always execute the GSAM branch.


ALPHA = config.get("alpha", 0.05)
ADAPTIVE_PERTURBATION = config.get("adaptive_perturbation", False)
MINIMIZE_FP = config.get("minimize_fp", True)
Collaborator

Each of these is actually only used exactly once below, so in our code style, we would not assign them to any variable, but just inline them where they are used, see comment below.

a - g_clean_projection_norm * b, g_clean, g_robust_normalized)

# Get GSAM gradient.
g_gsam = jax.tree_multimap( lambda a, b: a - b * alpha,
Collaborator

There's an awkward space that slipped in front of lambda, please remove.

a - g_robust_projection_norm * b, g_robust, g_clean_normalized)

# Get GSAM gradient.
g_gsam = jax.tree_multimap( lambda a, b: a + b * alpha,
Collaborator

Same awkward space here.
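
For readers following the two fragments above, here is a minimal sketch of the first variant (`a - b * alpha`), reconstructed from the quoted lines rather than copied from the pull request. The helpers `tree_dot` and `tree_normalize`, the function name `gsam_gradient`, and the `eps` value are assumptions made for illustration. The sketch uses `jax.tree_util.tree_map`, since newer JAX releases no longer provide `jax.tree_multimap`.

```python
import jax
import jax.numpy as jnp


def tree_dot(a, b):
  """Inner product of two pytrees with identical structure (hypothetical helper)."""
  products = jax.tree_util.tree_map(lambda x, y: jnp.sum(x * y), a, b)
  return sum(jax.tree_util.tree_leaves(products))


def tree_normalize(g, eps=1e-12):
  """Scales a pytree to unit global L2 norm (hypothetical helper; eps is assumed)."""
  norm = jnp.sqrt(tree_dot(g, g))
  return jax.tree_util.tree_map(lambda x: x / (norm + eps), g)


def gsam_gradient(g_clean, g_robust, alpha):
  """Combines the clean and perturbed gradients as in the first fragment."""
  g_robust_normalized = tree_normalize(g_robust)
  # Length of the component of g_clean that lies along g_robust.
  g_clean_projection_norm = tree_dot(g_clean, g_robust_normalized)
  # Part of g_clean that is orthogonal to g_robust.
  g_clean_residual = jax.tree_util.tree_map(
      lambda a, b: a - g_clean_projection_norm * b,
      g_clean, g_robust_normalized)
  # Get GSAM gradient.
  return jax.tree_util.tree_map(
      lambda a, b: a - b * alpha, g_robust, g_clean_residual)
```

With `alpha=0.0` the residual term vanishes and the function returns `g_robust` unchanged, which is consistent with the statement in the pull-request description that GSAM reduces to SAM in that case.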

# Per-worker perturbation.
if adaptive_perturbation:
param_sam = jax.tree_multimap(lambda a, b: a + jnp.abs(a) * sam_rho * b / (g_clean_length + eps),
base_opt.target, g_clean)
Collaborator

Misaligned line continuation

base_opt.target, g_clean)
else:
param_sam = jax.tree_multimap(lambda a, b: a + sam_rho * b / (g_clean_length + eps),
base_opt.target, g_clean)
Collaborator

Same here
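
The perturbation step under review can be sketched as a single helper, with the continuation lines aligned. This is an illustrative reconstruction, not the code from the pull request: the function name `perturb_params`, the `adaptive` flag, and the `eps` default are assumptions, and a plain `params` pytree stands in for `base_opt.target`.

```python
import jax
import jax.numpy as jnp


def perturb_params(params, g_clean, sam_rho, adaptive=False, eps=1e-12):
  """Moves params along the clean gradient, scaled to length sam_rho."""
  leaves = jax.tree_util.tree_leaves(g_clean)
  # Global L2 norm of the clean gradient across all leaves.
  g_clean_length = jnp.sqrt(sum(jnp.sum(jnp.square(g)) for g in leaves))

  def step(a, b):
    # The adaptive variant rescales the step elementwise by |param|.
    scale = jnp.abs(a) if adaptive else 1.0
    return a + scale * sam_rho * b / (g_clean_length + eps)

  return jax.tree_util.tree_map(step, params, g_clean)
```

Folding both branches into one `step` function removes the duplicated `tree_map` call, at the cost of a Python-level `if` that is resolved at trace time.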

if lr_max == lr_min:
  sam_rho = rho_max
else:
  sam_rho = rho_min + (rho_max - rho_min) * (lr - lr_min) / (lr_max - lr_min)
Collaborator

This makes me wonder (sorry, I haven't read the GSAM paper): do you really want to linearly interpolate rho, or would you ideally want to apply the same scheduling function as the learning rate, e.g. cosine?

Contributor Author

Sorry for the confusion. I want to apply the same scheduler, but with a different scale and different upper/lower bounds.
In the paper I only used a linear lr scheduler for the experiments; in theory (and in the proofs in the paper), both schedules are assumed to follow an inverse sqrt.
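
To make the point above concrete: the expression in the diff is an affine map of the current learning rate, so rho inherits whatever shape the lr schedule has and is only rescaled to the range [rho_min, rho_max]. The sketch below wraps that expression in a function and pairs it with a cosine schedule; `rho_from_lr` and `cosine_lr` are hypothetical names, and the cosine schedule is an assumption used purely for illustration.

```python
import math


def rho_from_lr(lr, lr_max, lr_min, rho_max, rho_min):
  """Maps the current learning rate onto [rho_min, rho_max] (same formula as the diff)."""
  if lr_max == lr_min:
    return rho_max
  return rho_min + (rho_max - rho_min) * (lr - lr_min) / (lr_max - lr_min)


def cosine_lr(step, total_steps, lr_max, lr_min):
  """Hypothetical cosine decay from lr_max to lr_min."""
  progress = step / total_steps
  return lr_min + 0.5 * (lr_max - lr_min) * (1.0 + math.cos(math.pi * progress))


# rho follows the cosine shape of the lr schedule, rescaled to its own bounds.
rhos = [
    rho_from_lr(cosine_lr(s, 100, 1e-3, 0.0), 1e-3, 0.0, 0.6, 0.1)
    for s in range(0, 101, 25)
]
```

Under these assumptions rho starts at `rho_max`, ends at `rho_min`, and traces the same cosine curve in between, so no separate rho scheduler is needed.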

@lucasb-eyer
Collaborator

Also, once you're done, could you squash all the commits into just a single one?
