
Kondo Gate backward skip, and some other changes. #1

Open
plugyawn wants to merge 5 commits into google-deepmind:main from plugyawn:main

Conversation

@plugyawn

@plugyawn plugyawn commented Apr 16, 2026

I noticed the implementation of the Kondo Gate doesn't actually skip the backward pass, as described in the paper (it instead masks the loss and still pays for the dense backward). In addition, the actual benefits on large-scale training weren't apparent due to the lack of caching.

So, this PR adds a few changes:

  • A backward-skipping Kondo Gate implementation that avoids the backward cost by screening the batch, compacting it, and then differentiating through only the kept subset. This incurs a small screening cost; I have attached the timings below. I'm sure it can be further amortized.
  • In the case that the learner and actor policies are the same, so training is on-policy, a second forward pass is unnecessary and can be skipped. This reduces the cost from $F + k(F + B)$ to just $F + kB$ (for cases where we're at least nearly on-policy, from what I understand).
  • The wallclock savings from the Kondo Gate alone were hard to notice at longer prompt_lengths because caching wasn't implemented. With caching, the trainer is a much larger chunk of the per-step timing, so skipping the backward is even more rewarding, Amdahl-wise.
[image] This is on a 4-vocab reversal task with prompt_length 12, averaged across 3 seeds, 5000 steps each. Base Kondo is the current implementation, with k=1.0; note that the wallclock includes logging-at-every-step overhead, so the actual difference will probably be bigger.
  • Base Kondo 100% / base PG: ~364s, 19.2M backward tokens
  • Backward-skip Kondo 50%: ~287s, 9.6M backward tokens
  • Backward-skip Kondo 70%: ~324s, 13.45M backward tokens

Base Kondo at 70%/50% still goes through all of the backward tokens, but is algorithmically roughly the same.
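The screen → compact → backward loop described above can be sketched roughly as follows. This is a plain-numpy toy with hypothetical names (`screened_backward`, `keep_frac`), not the PR's actual API; the real implementation would screen with the policy's forward pass and run autodiff only on the compacted rows:

```python
import numpy as np

def screened_backward(W, X, y, keep_frac=0.5):
    """Backward-skip sketch: (1) cheap forward-only screen, (2) compact the
    kept rows into a dense sub-batch, (3) backward on that subset only.
    Toy linear model with squared loss; names are illustrative."""
    # 1) screen: forward pass only, per-example loss as the gating score
    preds = X @ W
    per_ex = 0.5 * (preds - y) ** 2
    k = max(1, int(keep_frac * len(y)))
    keep = np.argsort(per_ex)[-k:]        # keep the k highest-loss rows
    # 2) compact: gather kept rows so the backward batch is dense
    Xk, yk = X[keep], y[keep]
    # 3) backward: gradient of the mean loss over the kept subset only
    grad = Xk.T @ (Xk @ W - yk) / k
    return grad, keep
```

At `keep_frac=1.0` this reduces to the ordinary full-batch gradient, so the gate is a strict generalization of the dense path.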

Edit: I'll add the ablations. The row-compaction does yield an approximation to the "true" gradient of the original egg implementation, but I think it's closer to the paper's spirit?
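To make the approximation point concrete, here is a toy check (plain numpy, hypothetical setup, not the repo's code): for a loss that decomposes per example, the masked-loss gradient and the compacted-subset gradient differ only by the normalization factor N/k, so any real divergence would come from terms that couple rows across the batch (e.g. batch-level loss normalization).

```python
import numpy as np

# Masked loss (dense backward) vs. compacted subset (sparse backward),
# on a toy linear model with per-example squared loss.
rng = np.random.default_rng(0)
N, k = 8, 3
X = rng.normal(size=(N, 4))
y = rng.normal(size=N)
W = rng.normal(size=4)
keep = np.argsort(0.5 * (X @ W - y) ** 2)[-k:]   # screened-in rows

mask = np.zeros(N)
mask[keep] = 1.0
grad_masked = X.T @ ((X @ W - y) * mask) / N              # dense backward
grad_compact = X[keep].T @ (X[keep] @ W - y[keep]) / k    # subset backward

# Identical up to the N/k rescaling when the loss is per-example.
assert np.allclose(grad_compact, grad_masked * N / k)
```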

Edit 2: On second thoughts, I might write this out as a MoE-like router over the training items. That should be cleaner.

The plots look a little too good, but they seem reproducible. I'll try with more configs to check.


The total step time across 5000-step runs drops from ~54ms to ~38ms on average on my M3 Pro for the default transformer config, due to reduced backward cost. However, across the run this amortizes to a ~21% reduction in wallclock for 50% Kondo over the same amount of data (including logging costs; estimated logging overhead is ~19ms per step, excluding which the 50% gate speedup goes to ~27%).

Across a 5000 step run, timings were:

  • Base Kondo 100% / base PG
    • total step: 53.93 ms
    • sample: 17.93 ms
    • screen: 0 ms
    • compact: 0 ms
    • train: 36.00 ms
    • backward tokens / step: 3840
    • total wall clock: 364.32 s
  • Backward-skip Kondo 70%
    • total step: 46.17 ms
    • sample: 18.42 ms
    • screen: 0.489 ms
    • compact: 0.078 ms
    • train: 27.19 ms
    • backward tokens / step: 2690
    • backward fraction: 0.7005
    • total wall clock: 323.89 s
  • Backward-skip Kondo 50%
    • total step: 37.76 ms
    • sample: 17.90 ms
    • screen: 0.467 ms
    • compact: 0.063 ms
    • train: 19.33 ms
    • backward tokens / step: 1920
    • backward fraction: 0.5
    • total wall clock: 287.17 s
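As a rough sanity check on the table above, an Amdahl-style model in which only the train portion scales with the kept fraction k (sample/screen/compact are fixed costs) reproduces the measured numbers within a few percent. The figures below are copied from the timings above:

```python
# Numbers (ms) copied from the per-step timing table above.
base_train = 36.00  # Base Kondo 100% train time

runs = [
    # (k, sample, screen, compact, train, total)
    (0.70, 18.42, 0.489, 0.078, 27.19, 46.17),
    (0.50, 17.90, 0.467, 0.063, 19.33, 37.76),
]

for k, sample, screen, compact, train, total in runs:
    pred_train = k * base_train                 # train scales with k
    assert abs(pred_train - train) / train < 0.10   # within 10%
    pred_total = sample + screen + compact + train
    assert abs(pred_total - total) / total < 0.05   # parts ≈ whole
```

The small residual (measured train slightly above k × base) is consistent with the fixed per-step overheads in the trainer that don't shrink with the batch.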

@google-cla

google-cla bot commented Apr 16, 2026

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

@plugyawn
Author

plugyawn commented Apr 16, 2026

Note that the compacted version is an approximation to the "true" gradient of the original implementation.

On second thoughts... I think it's maybe better to think of it as an MoE-like router over training items, systems-wise?

Edit: Yep, that's definitely cleaner.
