Two stage training experiments #1297
Looks like I'll need to fix how I'm building the train config (this code was written for Marin two weeks ago). The code is still ready for review while I fix this.
i am 100% on board with disabling whatever lint rule made you do that
dlwh
left a comment
code looks good. if you could add some docs I'd be happy!
| BASQUE_TASKS = (EvalTaskConfig("xcopa_eu", num_fewshot=0, task_alias="xcopa_eu"),) | ||
|
|
||
| if __name__ == "__main__": | ||
| NUM_RARE_STEPS = 400.0 # 200M tokens |
can you define "rare" somewhere and link to it wherever it's mentioned
| Two stage training: pre-training on only C4, followed by fine-tuning on rare + wC4. | ||
| We cooldown/rewarmup the learning rate and reset the optimizer state in between checkpoints. | ||
| For a fair comparison, we keep the total number of training steps fixed across replay ratio. | ||
| When we increase the replay ratio, the second stage has more steps so we decrease the length of pre-training. |
link to some documentation on replay ratio
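A minimal sketch of the step accounting the docstring describes might help until the docs exist. It assumes, since the PR excerpt does not define the term, that the replay ratio is the fraction of second-stage steps drawn from the common data, so the rare data always occupies a fixed number of steps. `split_steps` and its parameter names are hypothetical and not taken from the PR.

```python
def split_steps(total_steps: int, rare_steps: int, replay_ratio: float) -> tuple[int, int]:
    """Return (stage1_steps, stage2_steps) for a fixed total step budget.

    ASSUMPTION: replay_ratio is the fraction of stage-2 steps that come from
    the common data, so stage 2 needs rare_steps / (1 - replay_ratio) steps.
    """
    if not 0.0 <= replay_ratio < 1.0:
        raise ValueError("replay_ratio must be in [0, 1)")
    stage2_steps = round(rare_steps / (1.0 - replay_ratio))
    if stage2_steps > total_steps:
        raise ValueError("second stage does not fit in the total step budget")
    # The total is held fixed, so a longer second stage shortens pre-training.
    stage1_steps = total_steps - stage2_steps
    return stage1_steps, stage2_steps


if __name__ == "__main__":
    for ratio in (0.0, 0.5, 0.75):
        print(ratio, split_steps(4000, 400, ratio))
```

Under that assumption, raising the replay ratio from 0.0 to 0.5 doubles the second stage from 400 to 800 steps and removes the same 400 steps from pre-training.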
|
|
||
| @dataclass | ||
| class TwoStageConfig: | ||
| """Configuration for two-stage training.""" | ||
|
|
||
| ### Data quantity | ||
| rare_data_name: str |
define what we mean here (probably in docstring above)
| self.rare_data = data_dict[self.rare_data_name] | ||
| self.common_data = data_dict[self.common_data_name] |
| if self.steps_per_eval is None: | ||
| self.steps_per_eval = self.num_train_steps // 20 | ||
|
|
||
| self.steps_per_eval = min(self.steps_per_eval, self.num_train_steps // 2) |
maybe do a warning if you are changing it
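The suggestion could look roughly like the sketch below. It pulls the defaulting and clamping logic from the diff into a standalone function so it can be run on its own; `resolve_steps_per_eval` is a hypothetical name, and the use of the standard `warnings` module (rather than a project logger) is an assumption.

```python
import warnings
from typing import Optional


def resolve_steps_per_eval(num_train_steps: int, steps_per_eval: Optional[int] = None) -> int:
    """Default steps_per_eval, then cap it, warning if the cap changes the value."""
    if steps_per_eval is None:
        # Same default as in the diff: evaluate about 20 times per run.
        steps_per_eval = num_train_steps // 20

    cap = num_train_steps // 2
    if steps_per_eval > cap:
        warnings.warn(
            f"steps_per_eval={steps_per_eval} exceeds num_train_steps // 2; "
            f"clamping to {cap}",
            stacklevel=2,
        )
        steps_per_eval = cap
    return steps_per_eval
```

The warning only fires when the value is actually reduced, so runs that rely on the default stay quiet.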
|
|
||
| def set_data_schedule_params(self): | ||
| """ | ||
| ┌──────────────────────────┬──────────┐ |
great ascii figure, but somewhere you need to give a high level explanation of what you're doing
| """Format total number of tokens in B/M/K notation.""" | ||
| tokens = self.total_tokens | ||
| if tokens >= 1_000_000_000: | ||
| return f"{tokens/1_000_000_000:.1f}B" | ||
| elif tokens >= 10_000_000: | ||
| return f"{int(tokens/1_000_000)}M" | ||
| elif tokens >= 1_000_000: | ||
| return f"{tokens/1_000_000:.1f}M" | ||
| elif tokens >= 1_000: | ||
| return f"{tokens/1_000:.1f}K" | ||
| return str(tokens) |
we should make a helper (and probably it should use humanfriendly.format_size and munge out the B/bytes)
if it's okay, will stick to this for now for backwards compatibility with experiments?
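A middle ground might be to lift the existing logic into a module-level helper without changing its output, so names of existing experiments stay the same. The sketch below mirrors the branches in the diff; `format_token_count` is a hypothetical name, and it deliberately does not use `humanfriendly`, since the exact strings that library would produce are not confirmed here.

```python
def format_token_count(tokens: int) -> str:
    """Format a token count in B/M/K notation, matching the method in the diff."""
    if tokens >= 1_000_000_000:
        return f"{tokens / 1_000_000_000:.1f}B"
    if tokens >= 10_000_000:
        # From 10M upward the fractional part is dropped, e.g. "200M".
        return f"{int(tokens / 1_000_000)}M"
    if tokens >= 1_000_000:
        return f"{tokens / 1_000_000:.1f}M"
    if tokens >= 1_000:
        return f"{tokens / 1_000:.1f}K"
    return str(tokens)


if __name__ == "__main__":
    for n in (999, 1_500, 5_000_000, 200_000_000, 1_500_000_000):
        print(n, format_token_count(n))
```

The config method could then delegate to this helper with `self.total_tokens`, leaving a later switch to a library-based formatter as a separate change.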
| return hash(self) == hash(other) | ||
|
|
||
|
|
||
| def two_stage_train_step(two_stage_config: TwoStageConfig) -> ExecutorStep: |
doc string. either this one or the config one should be pretty detailed

Description
Fixes #702 and #1110 (and removes the need for #1124)
Defines a two-stage config, a stylized controlled experimental setting which
We provide examples of using the framework in three ways