Skip to content

Two stage training experiments#1297

Merged
kothasuhas merged 9 commits into
mainfrom
suhas/two-stage-pr-2
May 19, 2025
Merged

Two stage training experiments#1297
kothasuhas merged 9 commits into
mainfrom
suhas/two-stage-pr-2

Conversation

@kothasuhas
Copy link
Copy Markdown
Contributor

@kothasuhas kothasuhas commented May 18, 2025

Description

Fixes #702 and #1110 (and removes the need for #1124)

Defines a two-stage config, a stylized controlled experimental setting which

  1. covers a large space of pre-training, mid-training, and post-training data schedules
  2. is simple to search over to answer questions of interest

We provide examples of using the framework in three ways

  1. fine-tuning: varying the replay ratio with no target data in pre-training
  2. mid-training: setting the mixture for two stages of training. there are a couple of different parameterizations, explained in the two-stage config documentation
  3. continual pre-training: training with replay ratio initialized from a pre-trained language model, where the goal is to measure accuracy instead of loss

@kothasuhas
Copy link
Copy Markdown
Contributor Author

Super unimportant, but I'm pretty proud of this comment :)) Unfortunately removed the greek characters because the linter was unhappy ://
Screenshot 2025-05-18 at 12 59 54 AM

@kothasuhas
Copy link
Copy Markdown
Contributor Author

Looks like I'll need to fix how I'm building the train config (this code was written for Marin two weeks ago). Code is still ready for review while i fix this

@dlwh
Copy link
Copy Markdown
Member

dlwh commented May 18, 2025

i am 100% on board with disabling whatever lint rule made you do that

Copy link
Copy Markdown
Member

@dlwh dlwh left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

code looks good. if you could add some docs I'd be happy!

BASQUE_TASKS = (EvalTaskConfig("xcopa_eu", num_fewshot=0, task_alias="xcopa_eu"),)

if __name__ == "__main__":
NUM_RARE_STEPS = 400.0 # 200M tokens
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you define "rare" somewhere and link to it wherever it's mentioned

Two stage training: pre-training on only C4, followed by fine-tuning on rare + wC4.
We cooldown/rewarmup the learning rate and reset the optimizer state in between checkpoints.
For a fair comparison, we keep the total number of training steps fixed across replay ratio.
When we increase the replay ratio, the second stage has more steps so we decrease the length of pre-training.
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

link to some documentation on replay ratio


@dataclass
class TwoStageConfig:
"""Configuration for two-stage training."""
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

more on what this means

"""Configuration for two-stage training."""

### Data quantity
rare_data_name: str
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

define what we mean here (probably in docstring above)

Comment on lines +89 to +90
self.rare_data = data_dict[self.rare_data_name]
self.common_data = data_dict[self.common_data_name]
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these could just be properties?

if self.steps_per_eval is None:
self.steps_per_eval = self.num_train_steps // 20

self.steps_per_eval = min(self.steps_per_eval, self.num_train_steps // 2)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe do a warning if you are changing it


def set_data_schedule_params(self):
"""
┌──────────────────────────┬──────────┐
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

great ascii figure, but somewhere you need to give a high level explanation of what you're doign

Comment on lines +233 to +243
"""Format total number of tokens in B/M/K notation."""
tokens = self.total_tokens
if tokens >= 1_000_000_000:
return f"{tokens/1_000_000_000:.1f}B"
elif tokens >= 10_000_000:
return f"{int(tokens/1_000_000)}M"
elif tokens >= 1_000_000:
return f"{tokens/1_000_000:.1f}M"
elif tokens >= 1_000:
return f"{tokens/1_000:.1f}K"
return str(tokens)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should make a helper (and probably it should use humanfriendly.format_size and munge out the B/bytes)

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if its okay, will stick to this for now for backwards compatibility w experiments?

return hash(self) == hash(other)


def two_stage_train_step(two_stage_config: TwoStageConfig) -> ExecutorStep:
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

doc string. either this one or the config one should be pretty detailed

@kothasuhas kothasuhas merged commit bfbc449 into main May 19, 2025
5 checks passed
This was referenced May 19, 2025
@rjpower rjpower deleted the suhas/two-stage-pr-2 branch January 21, 2026 00:21
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Studying how data ordering can improve pretraining

2 participants