apps/grpo/qwen3_1_7b.yaml (2 additions, 0 deletions)
@@ -48,6 +48,7 @@ trainer:
max_norm: 1.0
steps: 1000000
dtype: bfloat16
+gc_freq: 1
compile:
enable: false
parallelism:
@@ -83,6 +84,7 @@ ref_model:
hf_assets_path: hf://${model}
training:
dtype: bfloat16
+gc_freq: 1
compile:
enable: false
parallelism:
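The gc_freq: 1 entries above tell both the trainer and the reference model to run a manual garbage-collection pass on every step; a value of, say, 50 would collect only every 50th step. As a rough illustration of how this knob could sit on the Training sub-config that both actors consume (field names and defaults below are assumptions for illustration, not copied from the repo):

```python
from dataclasses import dataclass


@dataclass
class Training:
    """Hypothetical sketch of the training sub-config these actors read.

    Only the fields relevant to this diff are shown; names and defaults
    are assumptions, not taken from the actual forge codebase.
    """

    dtype: str = "bfloat16"
    gc_freq: int = 1000      # run gc.collect() every N steps; 1 means every step
    steps: int = 1_000_000   # total training steps
```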
src/forge/actors/reference_model.py (4 additions, 1 deletion)
@@ -37,7 +37,7 @@ class ReferenceModel(ForgeActor):
compile: Compile = field(default_factory=Compile)
training: Training = field(
default_factory=Training
-) # Only needed in order to correctly set a lower dtype
+) # Needed in order to set attrs like dtype, garbage collection freq, etc.

# Populated in setup
# TODO: Commented out since engine_config parsing extracts from class members
@@ -61,6 +61,7 @@ def __post_init__(self):
"""
self.rank = current_rank().rank
self.size = math.prod(current_size().values())
+self.step = 0

env = {
"RANK": str(self.rank),
@@ -83,6 +84,7 @@ async def setup(self):

@endpoint
async def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
+self.engine.gc_handler.run(self.step)
Contributor: Can you explain what this is doing internally?
model_parts = self.engine.model_parts
parallel_dims = self.engine.parallel_dims
input_ids = input_ids.to("cuda")
@@ -106,6 +108,7 @@ async def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
with self.engine.maybe_enable_amp:
with torch.inference_mode():
logits = model_parts[0](input_ids)
+self.step += 1
if isinstance(logits, DTensor):
logits = logits.full_tensor()
return logits
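For readers wondering what the new gc_handler.run(self.step) call does (per the question above): torchtitan-style engines typically disable Python's automatic garbage collector at startup and instead collect manually every gc_freq steps, so GC pauses never land at unpredictable points inside a step. A minimal sketch of that pattern, with illustrative names rather than the actual ForgeEngine handler:

```python
import gc


class GCHandler:
    """Step-frequency garbage collection (sketch, not the real engine code).

    Automatic collection is disabled up front; run() then triggers an
    explicit collection every `gc_freq` steps.
    """

    def __init__(self, gc_freq: int = 1000) -> None:
        self.gc_freq = gc_freq
        gc.disable()   # turn off automatic collection
        gc.collect(1)  # collect the younger generations once at startup

    def run(self, step: int) -> None:
        # Only collect on the configured step frequency.
        if step > 1 and step % self.gc_freq == 0:
            gc.collect(1)
```

With gc_freq: 1 from the configs above, both actors collect on every step; incrementing self.step after each forward pass is what keeps this counter advancing on the reference model, which previously had no step counter at all.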
src/forge/actors/trainer.py (6 additions, 5 deletions)
@@ -73,7 +73,7 @@ def __post_init__(self):
f"{f.name} should be a {f.type} type or a dict like object"
)

-self.current_step = 1 # fragile contract.
+self.step = 1 # fragile contract.
self.num_training_steps = self.training.steps
self.gradient_accumulation_steps = 1
self.rank = current_rank().rank
@@ -100,7 +100,7 @@ async def setup(self):
for key in {"loss", "state_dict_key", "use_dcp"}:
engine_config.pop(key) # Not part of job config
self.engine = ForgeEngine(ForgeJobConfig(**engine_config))
-self.engine.checkpointer.load(step=self.current_step)
+self.engine.checkpointer.load(step=self.step)
self.engine.optimizers.zero_grad()

def forward_backward(
@@ -173,6 +173,7 @@ def forward_backward(
def train_step(
self, inputs: list[dict[str, Tensor]], targets: list[dict[str, Tensor]]
) -> float:
+self.engine.gc_handler.run(self.step)
local_inputs = inputs[self.engine.dp_rank]
local_targets = targets[self.engine.dp_rank]
batch_to_device(local_inputs, self.engine.device)
@@ -192,10 +193,10 @@ def train_step(
self.engine.optimizers.zero_grad()
self.engine.lr_schedulers.step()

-self.current_step += 1
+self.step += 1
self.engine.checkpointer.save(
Contributor: question: Why do we do CP save on every step?

Member Author: This is a Titan checkpointer impl detail. What actually happens is that it checks whether it should save, which is determined by the checkpoint frequency attr found in our config. If it shouldn't checkpoint, it just returns. See here.

A much, much better name would be maybe_save IMO.

Contributor: Curious, why isn't it a "finally" step that's done by default?
-curr_step=self.current_step,
-last_step=self.current_step == self.num_training_steps,
+curr_step=self.step,
+last_step=self.step == self.num_training_steps,
)

return loss.item()
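On the checkpoint question above: calling save() every step is cheap because the Titan checkpointer gates the actual write on the configured interval and returns early otherwise (the "maybe_save" behavior the author describes). A rough sketch of that gating, assumed from the discussion rather than copied from the Titan code:

```python
class IntervalCheckpointer:
    """Illustrative interval-gated save; names are hypothetical."""

    def __init__(self, interval: int, enabled: bool = True) -> None:
        self.interval = interval  # save every N steps, per the job config
        self.enabled = enabled

    def save(self, curr_step: int, last_step: bool = False) -> None:
        # No-op on most steps; write only when a checkpoint is due,
        # and always on the final training step.
        if not self.enabled:
            return
        if not last_step and curr_step % self.interval != 0:
            return
        self._write(curr_step)

    def _write(self, curr_step: int) -> None:
        ...  # serialization details omitted in this sketch
```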