Garbage collect on every train / ref model step #209
Changes to the `ReferenceModel` actor:

```diff
@@ -37,7 +37,7 @@ class ReferenceModel(ForgeActor):
     compile: Compile = field(default_factory=Compile)
     training: Training = field(
         default_factory=Training
-    )  # Only needed in order to correctly set a lower dtype
+    )  # Needed in order to set attrs like dtype, garbage collection freq, etc.

     # Populated in setup
     # TODO: Commented out since engine_config parsing extracts from class members
@@ -61,6 +61,7 @@ def __post_init__(self):
         """
         self.rank = current_rank().rank
         self.size = math.prod(current_size().values())
+        self.step = 0

         env = {
             "RANK": str(self.rank),
@@ -83,6 +84,7 @@ async def setup(self):

     @endpoint
     async def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
+        self.engine.gc_handler.run(self.step)
         model_parts = self.engine.model_parts
         parallel_dims = self.engine.parallel_dims
         input_ids = input_ids.to("cuda")
@@ -106,6 +108,7 @@ async def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
         with self.engine.maybe_enable_amp:
             with torch.inference_mode():
                 logits = model_parts[0](input_ids)
+        self.step += 1
         if isinstance(logits, DTensor):
             logits = logits.full_tensor()
         return logits
```
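For context on what `gc_handler.run(step)` does here: in torchtitan-style engines, the GC handler disables Python's automatic garbage collector at startup and then collects manually every `gc_freq` steps. Below is a minimal sketch of that pattern, not this repo's implementation; the class name, `gc_freq` attribute, and collection generation are assumptions.

```python
import gc


class StepGatedGC:
    """Minimal sketch of a step-gated GC handler (torchtitan-style).

    Assumption: the real handler exposes run(step) and a configurable
    frequency; the names here are illustrative, not taken from the PR.
    """

    def __init__(self, gc_freq: int = 1000) -> None:
        assert gc_freq > 0, "gc_freq must be a positive integer"
        self.gc_freq = gc_freq
        gc.disable()   # stop automatic collection at arbitrary allocation points
        gc.collect(1)  # one upfront pass over the younger generations

    def run(self, step: int) -> None:
        # Collect only at step boundaries, where a GC pause is cheapest and
        # all ranks pause together instead of creating stragglers mid-step.
        if step % self.gc_freq == 0:
            gc.collect(1)
```

Calling `run` on every `forward` / `train_step`, as this PR does, hands control of the actual collection frequency to the `Training` config, which is why the comment on the `training` field now mentions garbage collection frequency.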
Changes to the trainer:

```diff
@@ -73,7 +73,7 @@ def __post_init__(self):
                 f"{f.name} should be a {f.type} type or a dict like object"
             )

-        self.current_step = 1  # fragile contract.
+        self.step = 1  # fragile contract.
         self.num_training_steps = self.training.steps
         self.gradient_accumulation_steps = 1
         self.rank = current_rank().rank
@@ -100,7 +100,7 @@ async def setup(self):
         for key in {"loss", "state_dict_key", "use_dcp"}:
             engine_config.pop(key)  # Not part of job config
         self.engine = ForgeEngine(ForgeJobConfig(**engine_config))
-        self.engine.checkpointer.load(step=self.current_step)
+        self.engine.checkpointer.load(step=self.step)
         self.engine.optimizers.zero_grad()

     def forward_backward(
@@ -173,6 +173,7 @@ def forward_backward(
     def train_step(
         self, inputs: list[dict[str, Tensor]], targets: list[dict[str, Tensor]]
     ) -> float:
+        self.engine.gc_handler.run(self.step)
         local_inputs = inputs[self.engine.dp_rank]
         local_targets = targets[self.engine.dp_rank]
         batch_to_device(local_inputs, self.engine.device)
@@ -192,10 +193,10 @@ def train_step(
         self.engine.optimizers.zero_grad()
         self.engine.lr_schedulers.step()

-        self.current_step += 1
+        self.step += 1
         self.engine.checkpointer.save(
-            curr_step=self.current_step,
-            last_step=self.current_step == self.num_training_steps,
+            curr_step=self.step,
+            last_step=self.step == self.num_training_steps,
         )

         return loss.item()
```
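One detail the rename preserves, assuming the handler collects when `step % gc_freq == 0` as in the sketch above: the reference model initializes `self.step = 0`, so its very first `forward` triggers a collection, while the trainer keeps its `step = 1` start (the "fragile contract"), so its first collection lands at step `gc_freq`. A toy check of the two schedules:

```python
gc_freq = 4  # illustrative value; the real frequency comes from the Training config

ref_steps = range(0, 9)    # ReferenceModel: self.step starts at 0
train_steps = range(1, 9)  # Trainer: self.step starts at 1

print([s for s in ref_steps if s % gc_freq == 0])    # [0, 4, 8] -> collects on the first forward
print([s for s in train_steps if s % gc_freq == 0])  # [4, 8]   -> first collection at step 4
```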
|
|