Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement gsam in jax #8

Merged
merged 18 commits into from
Aug 19, 2022
Merged

implement gsam in jax #8

merged 18 commits into from
Aug 19, 2022

Conversation

juntang-zhuang
Copy link
Contributor

@juntang-zhuang juntang-zhuang commented Jul 16, 2022

Hi, @lucasb-eyer thanks for your review and comments. I reformated the files and squashed commits into a new PR (sorry I messed up the old PR and could not squash commits there). This PR includes:

  1. Put GSAM related configs into config.gsam and call gsam with l, grads = gsam_gradient(loss_fn=loss_fn, base_opt=opt, inputs=images, targets=labels, lr=learning_rate, **config["gsam"])
  2. Add big_vision/configs/proj/gsam/vit_1k_gsam_no_aug.py, the network used in GSAM paper used pool_type='gap' and rep_size=False, which is different from the default config.
  3. Fix format issues and squash commits.

Regarding reproducing the experiments, I wonder if it's possible for you to run the script (with 8x8 TPU cores to exactly match the paper)? I'm sorry I don't have access to TPU resources since I'm not affiliated with Google now, so I can't run experiments, though the checkpoints and the old version code that I used were kept in server. Thanks so much for your code review and help!

Copy link
Collaborator

@lucasb-eyer lucasb-eyer left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for addressing the comments. There's a few new ones regarding the config and my understanding of the schedule.

if lr_max == lr_min:
sam_rho = rho_max
else:
sam_rho = rho_min + (rho_max - rho_min) * (lr - lr_min) / (lr_max - lr_min)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From #4:

Lucas:

This makes me wonder (sorry I haven't read the GSAM paper), do you really want to linearly interpolate rho, or would you ideally want to apply the same scheduling function as the learning-rate, e.g. cosine for example?

Juntang:

Sorry for the confusion. I want to apply the same scheduler but with a different scale / upper_lower bound.
In the paper I only used linear lr scheduler for experiments, and in theory (and proofs part of paper) the two schedules are assumed to be both of inverse sqrt.

Ah this is really unfortunate, there should be a much cleaner way to implement this eg using a squashed version of sched_fns from the trainer!
But if you don't want to change the code to do this, then you should put an assert config.schedule.decay_type == "linear", "GSAM only implemented for linear lr schedule" into the train.py and add a little comment here in the code that goes something like

# Ideally, we'd use the same schedule as the lr here, just stretched to a different min/max.
# However, here we hard-code the linear scheduler only for convenience.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi, sorry I did not explain this clearly. Suppose learning rate is lr(t) for step t, and there's an effective rho(t) for each step t. The code restricts rho(t) to be linear w.r.t lr(t), however rho(t) is not linear w.r.t t. If we change lr(t) to be some non-linear schedule such as cosine, the code here will generate a rho(t) also in the shape of cosine, except lr_max != rho_max and lr_min != rho_min.

I tried to use a separate sched_fn for rho(t), but it seems some schedules such as cosine does not have the option to specify a non-zero min value rho_min.

I wonder if you have any suggestions for a neater version using sched_fn with configurable min value, or we keep the schedule code here?

big_vision/trainers/proj/gsam/train.py Show resolved Hide resolved
big_vision/configs/proj/gsam/vit_1k_gsam_no_aug.py Outdated Show resolved Hide resolved
config.wd = 0.3 # default is 0.0001; paper used 0.3, effective wd=0.3*lr
config.schedule = dict(
warmup_steps=10_000,
decay_type='linear',
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe append a short inline comment # only linear supported

# config.optax = dict(beta2_cap=0.95)

config.lr = 0.003
config.wd = 0.3 # default is 0.0001; paper used 0.3, effective wd=0.3*lr
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I understand you correctly, this is actually not correct anymore. We changed the code to always use "decoupled" values now. So you should specify here the effective wd you want, which is independent of the lr value (eg I think you want 0.001 here? as in 0.3 * 0.003 ≈ 0.001?)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for pointing it out. Since the old version code uses lr * wd as the effective wd, and lr changes with a schedule, the effective wd also has a schedule. Switching to the new configuration, is effective wd schedule available? I'm concerned if the effective wd schedule is disabled, using the same hyper-param might not be able to reproduce.

@lucasb-eyer
Copy link
Collaborator

Regarding running experiments, I could give it a try at some point, but definitely impossible to do so this week. I would run exactly the config you provide, and you need to tell me exactly which number in which table of the paper it's supposed to reproduce.

@juntang-zhuang
Copy link
Contributor Author

Regarding running experiments, I could give it a try at some point, but definitely impossible to do so this week. I would run exactly the config you provide, and you need to tell me exactly which number in which table of the paper it's supposed to reproduce.

Thanks a lot! If the effective wd schedule is not figured out, I might need to find some way to either implement the old versioned weight decay schedule, or tune the hyper-param with the new setting. I wonder if you could point Ting to the docs on how to run this repository internally, and I'll submit codes from external, so we could re-run some experiments to reproduce?

@lucasb-eyer
Copy link
Collaborator

lucasb-eyer commented Aug 5, 2022

hey, sorry I got distracted by something urgent to finish, will get back to this in one of the next two weeks and am optimistic we can get it to work well :)

edit: however, you did not yet tell me which exact number from the paper the config should be reproducing?

@juntang-zhuang
Copy link
Contributor Author

Thanks for the response. Sorry about the missing number, it's supposed to reproduce the 76.8 for ViT-B/32 in Table 1 of https://openreview.net/pdf?id=edONMAnhLu- .

I'm not fully sure about the new wdecay and lr scheduler. In the old version, lr scheduler is a single function (here lr scheduler func seems to be chained with a bunch of other schedulers); in the old version, wdecay is multiplied by lr, so wdecay is actually a scheduler rather than constant, is the new wdecay set to a constant?

Copy link
Collaborator

@lucasb-eyer lucasb-eyer left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi again, I am ow giving it a try, and there were a few more issues remaining. I have written them as comments, as well as given instructions on how to fix them. I am now able to actually run the trainer and the config, and will train it over night and see if it already reproduces the result or not.

I'll try a couple weight decay values to see what's the right one, but FYI, the weight decay is still following the schedule of the lr in the new code (linear decay in this case), it's just that the base lr is not multiplied to it.


def get_config(arg=None):
"""Config for training."""
arg = bvcc.parse_arg(arg, variant='B/32', runlocal=False, aug='')
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

aug='' is not used anymore and should be removed.

This configuration makes use of the "arg" to get_config to select which model
to run, so a few examples are given below:

Run training of a B/16 model:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All of these example commands need to be updated to this config file.

rho_min=0.1,
alpha=0.6,
adaptive_perturbation=False,
minimize_fp=True,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here we need to add two more parameters:

lr_max=config.get_ref('lr'),
lr_min=config.schedule.get_ref('linear_end'),

opt_cpu = jax.jit(tx.init, backend="cpu")(params_cpu)
sched_fns_cpu = [jax.jit(sched_fn, backend="cpu") for sched_fn in sched_fns]

@partial(jax.pmap, axis_name="batch", donate_argnums=(0, 1))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You need to add , static_broadcasted_argnums=(5,)) here or this will not work: step is a scalar, so we need to tell pmap that, or it expects it to be replicated. So the final line should look like:

  @partial(jax.pmap, axis_name="batch", donate_argnums=(0, 1),
           static_broadcasted_argnums=(5,))
  def update_fn...

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wait no that's not what we should do, or it will recompile a new function every step 😅 Instead, we should indeed replicate the step we're passing, for example by passing flax.jax_utils.replicate(step) at call-site.

However, this is creating a synchronization point, blocks prefetching, and creates a transfer at each step. Instead, we should really use the step number which is already replicated inside the optimizer. I'll find out how exactly tomorrow.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I lost track of this. What we need to do is to not pass any step at all to the function, but instead get the step like this, around line 208:

step = bv_optax.get_count(opt)
learning_rate = schd_fns[0](step) * config.lr

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

However, it turns out there's a minor issue with get_count so that it can't be called inside a compiled function. I have a fix for it, but let's not roll too much into this PR, you could leave this as it is currently, and I'll fix it myself after the PR is merged.

Get the GSAM gradient (https://openreview.net/pdf?id=edONMAnhLu-) of the loss function.
Args:
loss_fn: the loss function.
base_opt: the base optimizer.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[1/2] This (base_opt.target used below) does not work anymore with optax. Although it looks like you really use base_opt only for getting to the params, so you can replace the argument by an actual params argument, and then use that everywhere where you currently use base_opt.target in this function.

logits=logits, labels=labels)

learning_rate = sched_fns[0](step)
l, grads = gsam_gradient(loss_fn=loss_fn, base_opt=opt, inputs=images, targets=labels,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[2/2] and then here you would pass params=params instead of base_opt=opt.

@lucasb-eyer
Copy link
Collaborator

oh, and you have a bunch of small issues like wrong indentations, trailing spaces, etc. It would be helpful if you could run pylint with this config over it, then I don't need to fix these later on.

@lucasb-eyer
Copy link
Collaborator

and another minor nitpick: could you rename the config from ...1k... to ...i1k...? Because we never call ImageNet 1k, but always i1k in the whole codebase. I assume you made a typo.


# Per-worker perturbation.
if adaptive_perturbation:
param_sam = jax.tree_multimap(lambda a, b: a + jnp.abs(a) * sam_rho * b / (g_clean_length + eps),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

jax.tree_multimap does not exist anymore. It's now just jax.tree_map.

@lucasb-eyer
Copy link
Collaborator

Here is training_loss of running this config, sweeping over wd=0.0009 (=0.3*0.003, should be exact same as in paper), 0.001 (nicer number close to previous one), and 0.3 (just in case). The loss is crazy, accuracy is and stays at random (not shown):
image

However, I find the fact that it starts at 693.15, roughly 100x the standard starting-loss of i1k (log1000=6.907) somewhat suspicious. I noticed the config is using sigmoid_xent loss, your paper does not mention the words "softmax" or "sigmoid" ; could it be that you trained with softmax_xent and have sigmoid_xent here in the config by mistake? I'll try a run with that instead, but please take another careful read over the config and see if you can find other sources of this.

Another thing, the config does not contain the config.init_head_bias, which we often, but not always, use. Could this also be a mistake? (I'll also schedule an experiment about this).

@juntang-zhuang
Copy link
Contributor Author

juntang-zhuang commented Aug 9, 2022

Thanks a lot for the experiments, seems the config is not correct. I'll discuss it with Ting and see if we can directly compare the config file with the one we used for experiments.

@lucasb-eyer
Copy link
Collaborator

So far, no luck with any of (sigmoid->softmax, head-bias init, ) made it any better.

Then, I also tried the follwing things:

  1. Disable weight-decay altogether, to check whether I can at least overfit. Nope, still an exploding loss, so the issue seems unrelated to wd(?)
  2. Model with cls-token and mlp-head (repr_size=True), as this was original vit. A complete disaster :)

So, I tried all the ideas I had regarding configuration, and at this point wonder if maybe there's a bug in the implementation. Could you please try on your side? Note that you don't need TPU access to run big_vision, it works great on GPUs too, we did update the README with instructions about that. Let me know when you figure out a setting/code change such that the loss does not explode in the first hundreds of steps anymore, and I can then try longer runs for you again. (I'll also ping Ting my runs internally).

@lucasb-eyer
Copy link
Collaborator

I forgot to mention, but I also tried a run with adam 1t momentum not in bfloat16, but in regular float32, and it makes no difference. Note this bfloat16 really just affects the 1st momentum buffer, nothing else.

@lucasb-eyer
Copy link
Collaborator

Ting shared with me your exact runs from the paper numbers, so I could dig in a bit more. Carefully replicating exactly the config that was run, I still get similar behaviour, though slightly less extreme ("only" going up to hundreds, not millions):
image

At this point, I feel like this must be a bug in the code. It seems to go wrong after ~500 steps, potentially you can even run that on CPUs to debug?

@juntang-zhuang
Copy link
Contributor Author

Thanks a lot for the feedback and experiments, I'll dig it out with Ting, and will post the working version here. Sorry for all the trouble with this PR.

@lucasb-eyer
Copy link
Collaborator

Sorry for all the trouble with this PR

No worries, I will be happy and thankful to have up-to-date GSAM and SAM in the codebase!

@evcu
Copy link

evcu commented Aug 18, 2022

I also tried to run this with alpha=0, and it looks slightly better at the start, but still explodes after 1-2k step.

@lucasb-eyer
Copy link
Collaborator

I just noticed in one of your changes a few days ago, you did find a bug:

    learning_rate = sched_fns[0](step)   # Wrong
    learning_rate = sched_fns[0](step) * config["lr"]   # Your fix

This looks very promising! So I patched it in and tried another run on top of the last one I mentioned here. It looks a lot better! It doesn't explode, and reaches 75.2/81.8/61.0 validation/real/v2 accuracy after 90 epochs. This not yet the expected 76.8/82.7/63.0 we're trying to reproduce, but it's getting much closer 🥳

However, the missing 1.6% are still significant, so we should find them before merging this. I carefully compared configs (already before, but once again) and didn't find a new discrepancy.
With alpha=0 I should get SAM, right? Were the SAM and Vanilla numbers in Table1 also produced by you, or copied from somewhere? If produced by you, I could also run SAM and Vanilla and see if I can reproduce them, it would give us an indication where the remaining mistake can be.

Here are a few metrics, do they all look reasonable to you?
image

@juntang-zhuang
Copy link
Contributor Author

juntang-zhuang commented Aug 19, 2022

@lucasb-eyer Thanks so much for running experiments! I'm also running an experiment on ViT-S/32, but takes much longer on my GPU machine, will also post results here after it finishes.

The results for SAM are copied from https://arxiv.org/abs/2106.01548 table 2. For the gap of 1.6%, it might come from

  • in the paper it trains for 300 epochs (here's 90) for ViT,
  • a bug related to point 2 below
  • I used 8x8 TPU cores for most experiments, for SAM-family a larger TPU core number typically increases performance.

In previous updates, I made a few changes that potentially make a difference, including the following:

  1. pass the absolute learning rate learning_rate = sched_fns[0](step) * config["lr"] instead of learning_rate = sched_fns[0] (step)
  2. in config.gsam sets absolute values to lr_max=config.get_ref('lr') and lr_min=config.schedule.get_ref('linear_end') * config.get_ref('lr')
  3. in config.schedule set linear_end=0.01 (rather than linear_end=0.00003)
  4. pass flax.jax_utils.replicate(step) when calling update_fn

(I'm not sure if 4 is necessary, just following my old code after meeting with Ting.)

For 1, it's my fault that I did not realize bv_optax defines the learning rate schedule in a relative manner, while all my code last year assumes the lr are all absolute values. This causes a bug in my previous PR, that I passed absolute lr to denominator, but relative lr to the denominator, which results in about 300x larger perturbation amplitude. Such a big perturbation would crash the network. In current version this should be fixed.

For 2 and 3, it's also caused by my mistake with lr schedule. To reproduce the paper results, the absolute learning rate is a linear decay with max_lr=0.003 and min_lr=0.00003. Switching to the relative ratio schedule, should be linear_end=0.01.

I have merged the changes above in the latest PR, let me know if you have time to take a look. I'm also reproducing a ViT-S/32 results with my machine, it's a bit slow but will post it here once I get results. Thanks again for your help with this!

@lucasb-eyer
Copy link
Collaborator

No need to blame yourself alone, I also should have noticed ALL of these during review and testing, but didn't :) Happy you found them now! Let me start some runs right away, for 300ep, and report back later today.

I actually ran all experiments on 8x8, but am curious why TPU topology would influence the results?

Copy link
Collaborator

@lucasb-eyer lucasb-eyer left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have good news. Running for 300ep largely closes the remaining gap. Here are my results:

setting wd val real v2
your paper 0.0009 76.8 82.7 63.0
gsam 0.0009 77.18 82.77 63.24
gsam 0.001 77.35 83.04 64.03
gsam (a=0) 0.0009 76.02 81.56 62.31
sam (a=0, rho=0.15) 0.0009 75.56 81.12 60.97
sam for vit/mixer paper 0.0009 73.6 80.3 60.0

I am relatively sure wd=0.0009 is what you ran, but back then it was expressed differently in our configs, and the number you used was prettier. So I also ran 0.001 which is very close and a pretty number too =)

I only left a few more small comments about to code to address, and after that we can merge!

Note: we have further refactored the code a little bit since, but it is fine for you to submit the code as-is, and I will cleanup/update and test once more on my side afterwards, you've done more than enough already!
image

rho_min=0.1,
alpha=0.6,
adaptive_perturbation=False,
minimize_fp=True,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Those two (adaptive_perturbation and minimize_fp) are set to their default values. From the doc-comment and paper, it does not seem like something a regular user would tune (contrary to rho and alpha), so let's remove them fromt he config?

perturbation is element-wise multiplied by abs(p).
minimize_fp: if True, min(f_p, h), original GSAM;
if False, min(f, h), where f is the clean loss.
f_p is the perturbed loss, h is the surrogate gap.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The doc comments of both adaptive_perturbation and minimize_fp both explain what they do in very technical terms, but it would be good to have a short high-level recommendation at the end as to when or why one would want to change them.

For example (the example is clearly wrong, because I don't understand them, but just to show the spirit of what I'm looking for):

    adaptive_perturbation: if False, same perturbation as SAM,
        treat all parameters as a single vector,
        perturbation norm is calculated as the norm of the whole vector;
        if True, for each parameter tensor p,
        perturbation is element-wise multiplied by abs(p).
        Try setting this to False when you use least-squares loss instead of KL-based ones.
    minimize_fp: if True, min(f_p, h), original GSAM;
        if False, min(f, h), where f is the clean loss.
        f_p is the perturbed loss, h is the surrogate gap.
        You probably want to leave this at its default unless you know what you're doing.

opt_cpu = jax.jit(tx.init, backend="cpu")(params_cpu)
sched_fns_cpu = [jax.jit(sched_fn, backend="cpu") for sched_fn in sched_fns]

@partial(jax.pmap, axis_name="batch", donate_argnums=(0, 1))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I lost track of this. What we need to do is to not pass any step at all to the function, but instead get the step like this, around line 208:

step = bv_optax.get_count(opt)
learning_rate = schd_fns[0](step) * config.lr

return getattr(u, config.get("loss", "sigmoid_xent"))(
logits=logits, labels=labels)

learning_rate = sched_fns[0](step) * config["lr"]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since this is a ConfiDict, it can be the slightly nicer config.lr.


learning_rate = sched_fns[0](step) * config["lr"]
l, grads = gsam_gradient(loss_fn=loss_fn, params=params, inputs=images,
targets=labels, lr=learning_rate, **config["gsam"])
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here, slightly simpler **config.gsam

opt_cpu = jax.jit(tx.init, backend="cpu")(params_cpu)
sched_fns_cpu = [jax.jit(sched_fn, backend="cpu") for sched_fn in sched_fns]

@partial(jax.pmap, axis_name="batch", donate_argnums=(0, 1))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

However, it turns out there's a minor issue with get_count so that it can't be called inside a compiled function. I have a fix for it, but let's not roll too much into this PR, you could leave this as it is currently, and I'll fix it myself after the PR is merged.

@juntang-zhuang
Copy link
Contributor Author

juntang-zhuang commented Aug 19, 2022

Cool, I'm really excited to see the updated results, they outperform numbers in the paper!
I have updated PR according to your comments, except the step is passed to update_fn rather than read out from opt.

One minor thing is, GSAM reduces to SAM requires alpha=0 and rho_max=rho_min in the gsam_gradient function, basically SAM uses a constant perturbation rho_t, GSAM scales rho_t proportional to learning rate schedule. It might not be a good idea to set constant by setting rho_max=rho_min, maybe using a bv_optax style schedule function is a better idea for code style consistency.

For TPU number, it's because that GSAM / SAM performs per-worker perturbation based on per-worker gradient in gsam_gradient, more workers will have more different perturbations, so the model effectively see more neighbors in the parameter space.

@lucasb-eyer
Copy link
Collaborator

Thanks for your comments. My "SAM" run with rho_max=rho_min=0.15 just finished, and it's quite a bit better than the paper number too. From my reading of the code, when rho_max=rho_min then we do use a constant rho value independent of learning-rate (schedule), no?
image

And yes, making it use the actual schedule_fn from optax would be ideal, then we could simply use SAM with all kinds of schedules, and we don't need to manually specify lr_min/lr_max in the config anymore. That would be a lot better, but I thought that I already asked a lot from you, so didn't want to ask for that too :) If you want to do it, that's great, otherwise I may do it at some point, or maybe never, if we never need it. But this is the largest argument against having it in the core trainer for now.

@lucasb-eyer
Copy link
Collaborator

Regarding the perturbations per host, I noticed that the model souping paper states that not syncing may have a significant disadvantage:
image

so it may be worth implementing. Do I understand correctly that it basically means doing jax.lax.pmean(g_clean)?

Copy link
Collaborator

@lucasb-eyer lucasb-eyer left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for your patience overall!
I'll merge it now, and will update the trainer according to the latest refactors early next week, such that it actually works :)

@lucasb-eyer lucasb-eyer merged commit 136deda into google-research:main Aug 19, 2022
@lucasb-eyer
Copy link
Collaborator

I also just realized that we should add a pointer to this from the README. I'll do so early next week too.

@juntang-zhuang
Copy link
Contributor Author

Thanks so much for your help with the debug and PR!

Regarding the rho_t schedule, yes it is constant when rho_max=rho_min, I implemented it in a way that rho_t follows the same schedule as lr_t (except they have difference value scales). It might be better to pass rho_t as another sched_fn, but I'm not familiar with the chain style fn in bv_optax, so I'm not confident to implement correctly and matching the existing code base.

For per-worker perturbation, the model soup paper seems to contradict the original SAM paper https://arxiv.org/pdf/2010.01412.pdf section 4.1. It defines m-sharpness where m is the per-worker number of examples. A smaller m (hence a larger worker number when total batchsize is fixed) improves generalization.

I'm not quite sure about model soup implementations. In my implementation (and SAM), the process is:

  1. per-worker gradient g_clean (not synced) and per-worker perturbation param_sam
    param_sam = jax.tree_map(lambda a, b: a + \
  2. per-worker gradient g_gsam at (per-worker) perturbed model weights param_sam
    g_gsam = jax.tree_map(lambda a, b: a - b * alpha,
  3. average g_gsam across workers in
    l, grads = jax.lax.pmean((l, grads), axis_name="batch")
    note the returned grads here is g_gsam (not g_clean) in the gsam_gradient function.
  4. all workers update with the same value of globally averaged gsam in optimizer.

I'm not quite sure with model soup, but I suspect if it draws an opposite conclusion from SAM paper, it might come from a different implementation. For example, if it switches the order of 3 and 4, first performs per-worker parameter update with per-worker g_gsam, then average model weights across workers, this might harm performance compared to synced perturbation.

If want to perform synced perturbation, we can add g_clean = jax.pmean(g_clean) after

l_clean, g_clean = jax.value_and_grad(loss_fn)(params, inputs, targets)
so that param_sam is the same for all workers

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants