5 changes: 5 additions & 0 deletions src/forge/data_models/__init__.py
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
175 changes: 175 additions & 0 deletions src/forge/data_models/api.py
@@ -0,0 +1,175 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from abc import ABC, abstractmethod
from typing import Any, Dict, Tuple

from forge.data_models.completion import Completion

from forge.data_models.loss import LossOutput
from forge.data_models.minibatch import Minibatch
from forge.data_models.prompt import Prompt


# TODO: This file should NOT be in the data_models folder/package


class Store(ABC):
Contributor: This looks really close to the KVStore we're building in https://github.com/meta-pytorch/forge/pull/147

Contributor Author: ack. Will rebase and remove this once the other PR gets merged.

"""
Abstract base class for a generic key-value store.

This class defines the interface for a storage backend that can save and retrieve
values using string keys. Subclasses should implement the actual storage logic,
which could be in-memory, on disk, remote (e.g., RDMA, Redis), or any other backend.

Example use cases include storing model weights, configuration objects, or any
other data that needs to be accessed by key.

Methods:
put(key: str, value: Any) -> None
Store a value under the specified key.

get(key: str) -> Any
Retrieve the value associated with the specified key.

Subclasses must implement both methods.
"""

@abstractmethod
def put(self, key: str, value: Any) -> None:
"""Store a value under a key."""
pass

@abstractmethod
def get(self, key: str) -> Any:
"""Retrieve a value by key."""
pass
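For illustration, a minimal in-memory backend satisfying this interface could look like the sketch below (InMemoryStore is a hypothetical name, not part of this PR; the import path assumes the module location in this diff):

```python
from typing import Any

from forge.data_models.api import Store


class InMemoryStore(Store):
    """A dict-backed Store, useful for tests and single-process runs."""

    def __init__(self) -> None:
        self._data: dict[str, Any] = {}

    def put(self, key: str, value: Any) -> None:
        # Overwrite any existing value stored under the same key.
        self._data[key] = value

    def get(self, key: str) -> Any:
        # Raises KeyError if the key was never stored.
        return self._data[key]
```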


class WeightsBuffer:
Contributor: I think this is a reasonable abstraction in forge, although I do wonder if this is something torchstore should support to begin with? The listed backends ("e.g., in-memory, RDMA, file system, torchstore etc.") are all things torchstore should support OOB.

"""
Concrete class for managing model weights using a generic key-value Store backend.
This class provides a simple interface to store and retrieve model weights
(or references to them) by delegating the actual storage logic to a Store instance.
The Store abstraction allows for flexible backends (e.g., in-memory, RDMA, file system, torchstore etc.)
without changing the WeightBuffer interface.
Example usage:
store = MyCustomStoreBackend()
buffer = WeightBuffer(store)
buffer.put("model_weights", weights)
latest_weights = buffer.get("model_weights")
Args:
store (Store): An instance of a Store backend to use for storage.
"""

def __init__(self, store: Store) -> None:
"""
Initialize the WeightsBuffer with a given Store backend.
Args:
store (Store): The storage backend to use.
"""
self.store = store

def put(self, key: str, weights):
"""
Store the given weights under the specified key.
Args:
key (str): The key under which to store the weights.
weights: The weights object or reference to store.
"""
self.store.put(key, weights)

def get(self, key: str):
"""
Retrieve the weights stored under the specified key.
Args:
key (str): The key for which to retrieve the weights.
Returns:
The weights object or reference associated with the key.
"""
return self.store.get(key)
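Under those assumptions, the round trip through WeightsBuffer is a thin delegation to the store; a short usage sketch with a hypothetical in-memory store like the one sketched above:

```python
from forge.data_models.api import WeightsBuffer

store = InMemoryStore()
buffer = WeightsBuffer(store)

# Any object (e.g., a state dict or a reference/handle to one) can be stored.
buffer.put("model_weights", {"layer0.weight": [0.1, 0.2]})
latest_weights = buffer.get("model_weights")
```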


class Trainer(ABC):
Contributor: +1

"""
Abstract base class for a reinforcement learning (RL) trainer.
This class defines the interface for any RL trainer implementation.
It standardizes the methods required for gradient accumulation, applying updates,
and snapshotting model weights. Subclasses should implement the actual logic
for these operations, which may vary depending on the underlying model,
framework, or distributed setup.
"""

@abstractmethod
def accumulate_gradients(self, minibatch: Minibatch) -> LossOutput:
"""
Accumulate gradients for the given minibatch.
This method is called once per minibatch during training. It should compute
the gradients for the minibatch and accumulate them (without applying them yet).

Args:
minibatch (Minibatch): The minibatch of data to use for gradient computation.
Returns:
LossOutput: The computed loss and any additional outputs needed for logging or analysis.
"""
pass

@abstractmethod
def apply_gradients(self) -> None:
"""
Apply accumulated gradients to the model parameters.
This method should update the model's parameters using the gradients that have
been accumulated so far (e.g., by calling an optimizer step). After this call,
the accumulated gradients should be cleared/reset.
"""
pass

@abstractmethod
def snapshot_weights(self) -> WeightsBuffer:
"""
Save the current model weights and return a buffer handle.
This method should capture the current state of the model's weights and store
them in a WeightsBuffer (which may be local or remote, depending on the implementation).
The returned buffer can be used to transfer weights between components or for checkpointing.
Returns:
WeightsBuffer: A handle or reference to the stored weights buffer.
"""
pass
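The intended calling pattern is gradient accumulation across the minibatches of a step followed by a single update; a sketch under that assumption (train_step and its arguments are hypothetical stand-ins for a concrete Trainer implementation):

```python
from forge.data_models.api import Trainer, WeightsBuffer
from forge.data_models.minibatch import Minibatch


def train_step(trainer: Trainer, minibatches: list[Minibatch]) -> WeightsBuffer:
    # Accumulate gradients over every minibatch of the step without applying them yet.
    for minibatch in minibatches:
        loss_output = trainer.accumulate_gradients(minibatch)  # LossOutput, e.g. for logging
    # Apply the accumulated gradients once (e.g., an optimizer step) and reset them.
    trainer.apply_gradients()
    # Snapshot the updated weights so other components (e.g., a Generator) can sync to them.
    return trainer.snapshot_weights()
```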


class Generator(ABC):
"""
Abstract base class for a model generator in RL or sequence modeling workflows.
This class defines the interface for any generator implementation, which is responsible
for producing completions (e.g., text, actions) given a prompt, and for updating its
internal model weights. Subclasses should implement the actual logic for generation
and weight updates, which may vary depending on the underlying model or framework.
"""

@abstractmethod
def generate(self, prompt: Prompt) -> list[Completion]:
"""
Generate completions given a prompt.
This method should use the current model to produce one or more completions
(e.g., text outputs, actions) based on the provided prompt.
Args:
prompt (Prompt): The input prompt or context for generation.
Returns:
list[Completion]: A list of generated completions corresponding to the prompt.
"""
pass

@abstractmethod
def update_weights(self, weights_handle: WeightsBuffer) -> None:
"""
Update the weights of the model using the provided weights buffer.
This method should update the generator's internal model parameters using
the weights stored in the given WeightsBuffer (which may be local or remote).
Args:
weights_handle (WeightsBuffer): A handle or reference to the weights buffer
containing the new model weights.
"""
pass
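Putting the two interfaces together, syncing weights from a trainer to a generator reduces to passing the snapshot handle across; a sketch (sync_and_generate and its arguments are hypothetical stand-ins for concrete implementations):

```python
from forge.data_models.api import Generator, Trainer
from forge.data_models.completion import Completion
from forge.data_models.prompt import Prompt


def sync_and_generate(
    trainer: Trainer, generator: Generator, prompts: list[Prompt]
) -> list[Completion]:
    # Hand a fresh weight snapshot to the generator before sampling.
    generator.update_weights(trainer.snapshot_weights())
    completions: list[Completion] = []
    for prompt in prompts:
        completions.extend(generator.generate(prompt))
    return completions
```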
31 changes: 31 additions & 0 deletions src/forge/data_models/completion.py
@@ -0,0 +1,31 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from typing import Optional

import torch
from forge.data_models.prompt import Prompt


@dataclass
class Completion:
"""A model-generated completion for a given prompt."""
Comment on lines +15 to +16, @felipemello1 (Contributor), Sep 15, 2025: I think we already have a dataclass for this. Not sure if it's "Episode". But it makes sense to have a completion class.


# The original prompt.
prompt: Prompt

# the decoded text returned by the model
text: str

# the encoded text (token ids) that were fed into the model
prompt_ids: torch.Tensor

# the encoded text (token ids) that were generated by the model
token_ids: torch.Tensor

# the log probabilities of the target tokens
log_probs: Optional[torch.Tensor] = None
64 changes: 64 additions & 0 deletions src/forge/data_models/distributed_metric.py
@@ -0,0 +1,64 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from abc import ABC, abstractmethod
from dataclasses import dataclass

import torch
import torch.distributed as dist


@dataclass
class DistributedMetric(ABC):
Contributor: I think we should use https://github.com/meta-pytorch/forge/tree/main/src/forge/data/dataset_metrics. But let's connect and see if you think it makes sense!

"""Metrics that are calculated in distributed fashion.

Metrics computed in each rank are going to be wrapped in DistributedMetric
according to how they are going to be aggregated. For example, average log prob
can be wrapped as `Fraction(Sum((logp * mask).sum()), Sum(mask.sum()))` where
`mask` indicates which token is valid.
"""

# We need to pass a context argument for distribution setup in the future.
@abstractmethod
def reduce(self, group: dist.ProcessGroup | None = None) -> torch.Tensor:
pass

@abstractmethod
def local(self) -> torch.Tensor:
pass


@dataclass
class SumDistributedMetric(DistributedMetric):
def __init__(self, tensor: torch.Tensor) -> None:
self.tensor = tensor

def reduce(self, group: dist.ProcessGroup | None = None) -> torch.Tensor:
return _try_clone_and_reduce(self.tensor, op=dist.ReduceOp.SUM, group=group)

def local(self) -> torch.Tensor:
return self.tensor


@dataclass
class Fraction:
numerator: DistributedMetric
denominator: DistributedMetric

def reduce(self, group: dist.ProcessGroup | None = None) -> torch.Tensor:
return self.numerator.reduce(group) / self.denominator.reduce(group)

def local(self) -> torch.Tensor:
return self.numerator.local() / self.denominator.local()


def _try_clone_and_reduce(
tensor: torch.Tensor, op: dist.ReduceOp, group: dist.ProcessGroup | None
) -> torch.Tensor:
cloned = tensor.detach().clone()
if dist.is_initialized():
dist.all_reduce(cloned, op=op, group=group)
return cloned
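As a concrete version of the docstring's example, a masked average log prob can be expressed as a Fraction of two SumDistributedMetrics; a sketch with made-up per-rank tensors:

```python
import torch

from forge.data_models.distributed_metric import Fraction, SumDistributedMetric

# Per-token log probs and a validity mask on this rank (values are illustrative).
logp = torch.tensor([-0.5, -1.0, -2.0])
mask = torch.tensor([1.0, 1.0, 0.0])

avg_logp = Fraction(
    SumDistributedMetric((logp * mask).sum()),
    SumDistributedMetric(mask.sum()),
)

print(avg_logp.local())   # tensor(-0.7500) on this rank
print(avg_logp.reduce())  # ratio of all-reduced sums when torch.distributed is initialized
```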
68 changes: 68 additions & 0 deletions src/forge/data_models/experience.py
@@ -0,0 +1,68 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from typing import Optional, Sequence

import torch
from forge.data_models.scored_completion import ScoredCompletion


@dataclass
class Experience:
Contributor: This is very similar to Completion. I wonder if we should just keep one. But it's also a bit similar to "Episode".

Contributor: Need to see how this plays with data packing. It could be that the Packer takes list[Experience]. But I don't think it makes sense for this class to have concat logic if we are having packing.

Contributor Author: This is modeled in a way that Completion is what we receive from the generator and Experience is what we feed into the trainer (in between it goes through scoring, post-processing, etc.).

"""
The Experience data class to be used by the trainer.

Experiences are usually generated from a scored completion by running various post-processing steps.
"""

# Concatenated prompt and sample token ids.
ids: torch.Tensor

# The mask for the target ids, 0 for prompt tokens, 1 for sample tokens.
mask: torch.Tensor

# The weight to apply to the loss of each target token. It's normally computed
# from the advantage and the reward.
weights: torch.Tensor

# The log probabilities of the target tokens; set to 0 for the prompt part and
# computed by the Generator/Sampler for the generated part.
log_probs: Optional[torch.Tensor] = None

# TODO: add more fields as required
state: str = ""


def from_scored_completion(scored_completion: ScoredCompletion) -> Experience:
"""Converts a ScoredCompletion to an Experience."""
prompt_ids = scored_completion.completion.prompt_ids
token_ids = scored_completion.completion.token_ids
log_probs = scored_completion.completion.log_probs
ids = torch.cat([prompt_ids, token_ids])
mask = torch.cat(
[
torch.zeros(prompt_ids.shape, dtype=torch.float32),
torch.ones_like(token_ids, dtype=torch.float32),
]
)
advantage = scored_completion.score
weights = mask * advantage
log_probs = torch.cat(
[
torch.zeros(prompt_ids.shape, dtype=torch.float32),
# TODO: this only works if sample.log_probs is 1
log_probs,
]
)
return Experience(ids=ids, mask=mask, weights=weights, log_probs=log_probs)


def from_scored_completions(
scored_completions: Sequence[ScoredCompletion],
) -> Sequence[Experience]:
"""Converts a sequence of ScoredCompletion to a sequence of Experiences."""
return [from_scored_completion(sc) for sc in scored_completions]
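For intuition, a small worked example of the mask/weights construction in from_scored_completion, using plain tensors rather than the dataclasses (a 2-token prompt, 3 generated tokens, and an assumed advantage of 0.5):

```python
import torch

prompt_ids = torch.tensor([11, 12])      # 2 prompt tokens
token_ids = torch.tensor([21, 22, 23])   # 3 generated tokens
advantage = 0.5

ids = torch.cat([prompt_ids, token_ids])  # tensor([11, 12, 21, 22, 23])
mask = torch.cat(
    [
        torch.zeros(prompt_ids.shape, dtype=torch.float32),
        torch.ones_like(token_ids, dtype=torch.float32),
    ]
)  # tensor([0., 0., 1., 1., 1.])
weights = mask * advantage  # tensor([0.0000, 0.0000, 0.5000, 0.5000, 0.5000])
```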
22 changes: 22 additions & 0 deletions src/forge/data_models/loss.py
@@ -0,0 +1,22 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass

import torch
from forge.data_models.distributed_metric import Fraction
from forge.data_models.minibatch import Minibatch


@dataclass
class LossInput:
minibatch: Minibatch
trainer_logits: torch.Tensor


@dataclass
class LossOutput:
loss: Fraction
Comment on lines +14 to +22, Contributor: It's interesting to have a LossInput and a LossOutput class. I am just afraid that these abstractions would add too much hierarchy, i.e. instead of loss(logits, targets, mask), we have loss(LossInput(Minibatch(something_else))).

But LossOutput makes more sense to me because we may be outputting multiple numbers. We should double-check, because if it's only a couple, then having a dataclass might be overkill.
