5 changes: 5 additions & 0 deletions src/forge/data_models/__init__.py
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
137 changes: 137 additions & 0 deletions src/forge/data_models/api.py
@@ -0,0 +1,137 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from abc import ABC, abstractmethod
from typing import Any

from forge.data_models.loss import LossOutput
from forge.data_models.minibatch import Minibatch


# TODO: This file should NOT be in the data_models folder/package


class Store(ABC):
"""
Abstract base class for a generic key-value store.

This class defines the interface for a storage backend that can save and retrieve
values using string keys. Subclasses should implement the actual storage logic,
which could be in-memory, on disk, remote (e.g., RDMA, Redis), or any other backend.

Example use cases include storing model weights, configuration objects, or any
other data that needs to be accessed by key.

Methods:
put(key: str, value: Any) -> None
Store a value under the specified key.

get(key: str) -> Any
Retrieve the value associated with the specified key.

Subclasses must implement both methods.
"""

@abstractmethod
def put(self, key: str, value: Any) -> None:
"""Store a value under a key."""
pass

@abstractmethod
def get(self, key: str) -> Any:
"""Retrieve a value by key."""
pass


class WeightsBuffer:
"""
Concrete class for managing model weights using a generic key-value Store backend.
This class provides a simple interface to store and retrieve model weights
(or references to them) by delegating the actual storage logic to a Store instance.
The Store abstraction allows for flexible backends (e.g., in-memory, RDMA, file system, torchstore etc.)
without changing the WeightBuffer interface.
Example usage:
store = MyCustomStoreBackend()
buffer = WeightBuffer(store)
buffer.put("model_weights", weights)
latest_weights = buffer.get("model_weights")
Args:
store (Store): An instance of a Store backend to use for storage.
"""

    def __init__(self, store: Store) -> None:
        """
        Initialize the WeightsBuffer with a given Store backend.

        Args:
            store (Store): The storage backend to use.
        """
        self.store = store

    def put(self, key: str, weights: Any) -> None:
        """
        Store the given weights under the specified key.

        Args:
            key (str): The key under which to store the weights.
            weights: The weights object or reference to store.
        """
        self.store.put(key, weights)

    def get(self, key: str) -> Any:
        """
        Retrieve the weights stored under the specified key.

        Args:
            key (str): The key for which to retrieve the weights.

        Returns:
            The weights object or reference associated with the key.
        """
        return self.store.get(key)


class Trainer(ABC):
"""
Abstract base class for a reinforcement learning (RL) trainer.
This class defines the interface for any RL trainer implementation.
It standardizes the methods required for gradient accumulation, applying updates,
and snapshotting model weights. Subclasses should implement the actual logic
for these operations, which may vary depending on the underlying model,
framework, or distributed setup.
"""

@abstractmethod
    def accumulate_gradients(self, minibatch: Minibatch) -> LossOutput:
"""
Accumulate gradients for the given minibatch.
This method is called once per minibatch during training. It should compute
the gradients for the minibatch and accumulate them (without applying them yet).

Args:
minibatch (Minibatch): The minibatch of data to use for gradient computation.
Returns:
LossOutput: The computed loss and any additional outputs needed for logging or analysis.
"""
pass

@abstractmethod
def apply_gradients(self) -> None:
"""
Apply accumulated gradients to the model parameters.
This method should update the model's parameters using the gradients that have
been accumulated so far (e.g., by calling an optimizer step). After this call,
the accumulated gradients should be cleared/reset.
"""
pass

@abstractmethod
def snapshot_weights(self) -> WeightsBuffer:
"""
Save the current model weights and return a buffer handle.
This method should capture the current state of the model's weights and store
        them in a WeightsBuffer (which may be local or remote, depending on the implementation).
The returned buffer can be used to transfer weights between components or for checkpointing.
Returns:
WeightsBuffer: A handle or reference to the stored weights buffer.
"""
pass
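
To make the Store/WeightsBuffer contract concrete, here is a minimal sketch of an in-memory backend. `InMemoryStore` is a hypothetical example and not part of this diff; it assumes only the interfaces defined above.

# A minimal sketch, assuming only the Store and WeightsBuffer interfaces above.
# InMemoryStore is a hypothetical example backend, not part of this diff.
from typing import Any

from forge.data_models.api import Store, WeightsBuffer


class InMemoryStore(Store):
    """Dict-backed Store, useful for tests and single-process runs."""

    def __init__(self) -> None:
        self._data: dict[str, Any] = {}

    def put(self, key: str, value: Any) -> None:
        self._data[key] = value

    def get(self, key: str) -> Any:
        return self._data[key]


buffer = WeightsBuffer(InMemoryStore())
buffer.put("model_weights", {"layer.weight": [0.0, 1.0]})
latest_weights = buffer.get("model_weights")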
31 changes: 31 additions & 0 deletions src/forge/data_models/completion.py
@@ -0,0 +1,31 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from typing import Optional

import torch
from forge.data_models.prompt import Prompt


@dataclass
class Completion:
"""A model-generated completion for a given prompt."""

# The original prompt.
prompt: Prompt

    # The decoded text returned by the model.
    text: str

    # The token ids of the prompt that was fed to the model.
    prompt_ids: torch.Tensor

    # The token ids generated by the model.
    token_ids: torch.Tensor

    # The log probabilities of the target tokens.
    log_probs: Optional[torch.Tensor] = None
64 changes: 64 additions & 0 deletions src/forge/data_models/distributed_metric.py
@@ -0,0 +1,64 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from abc import ABC, abstractmethod
from dataclasses import dataclass

import torch
import torch.distributed as dist


@dataclass
class DistributedMetric(ABC):
"""Metrics that are calculated in distributed fashion.

Metrics computed in each rank are going to be wrapped in DistributedMetric
according to how they are going to be aggregated. For example, average log prob
can be wrapped as `Fraction(Sum((logp * mask).sum()), Sum(mask.sum()))` where
`mask` indicates which token is valid.
"""

# We need to pass a context argument for distribution setup in the future.
@abstractmethod
def reduce(self, group: dist.ProcessGroup | None = None) -> torch.Tensor:
pass

@abstractmethod
def local(self) -> torch.Tensor:
pass


@dataclass
class SumDistributedMetric(DistributedMetric):
    tensor: torch.Tensor

def reduce(self, group: dist.ProcessGroup | None = None) -> torch.Tensor:
return _try_clone_and_reduce(self.tensor, op=dist.ReduceOp.SUM, group=group)

def local(self) -> torch.Tensor:
return self.tensor


@dataclass
class Fraction:
numerator: DistributedMetric
denominator: DistributedMetric

def reduce(self, group: dist.ProcessGroup | None = None) -> torch.Tensor:
return self.numerator.reduce(group) / self.denominator.reduce(group)

def local(self) -> torch.Tensor:
return self.numerator.local() / self.denominator.local()


def _try_clone_and_reduce(
tensor: torch.Tensor, op: dist.ReduceOp, group: dist.ProcessGroup | None
) -> torch.Tensor:
    cloned = tensor.detach().clone()
    # All-reduce across the group when torch.distributed is initialized;
    # otherwise fall back to the local (single-process) value.
    if dist.is_initialized():
        dist.all_reduce(cloned, op=op, group=group)
    return cloned
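
To show how these pieces compose, here is a small sketch of the masked average log prob pattern from the DistributedMetric docstring; `logp` and `mask` are placeholder tensors, not part of this diff.

# A minimal sketch of the masked-average pattern described in the
# DistributedMetric docstring; logp and mask are placeholder tensors.
import torch

from forge.data_models.distributed_metric import Fraction, SumDistributedMetric

logp = torch.tensor([-0.5, -1.0, -2.0])
mask = torch.tensor([1.0, 1.0, 0.0])  # the third token is padding

avg_logp = Fraction(
    SumDistributedMetric((logp * mask).sum()),
    SumDistributedMetric(mask.sum()),
)
print(avg_logp.local())  # (-0.5 + -1.0) / 2.0 = tensor(-0.7500)
# avg_logp.reduce() would all-reduce both sums first when torch.distributed
# is initialized, yielding the global average rather than the local one.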
68 changes: 68 additions & 0 deletions src/forge/data_models/experience.py
@@ -0,0 +1,68 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from typing import Optional, Sequence

import torch
from forge.data_models.scored_completion import ScoredCompletion


@dataclass
class Experience:
"""
The Experience data class to be used by the trainer.

    Experiences are usually generated from a scored completion by running various post-processing steps.
"""

# Concatenated prompt and sample token ids.
ids: torch.Tensor

# The mask for the target ids, 0 for prompt tokens, 1 for sample tokens.
mask: torch.Tensor

# The weight to apply to the loss of each target token. It's normally computed
# from the advantage and the reward.
weights: torch.Tensor

# The log probabilities of the target tokens, for prompt part it's set to 0,
# for generation part it's computed from the Generator/Sampler.
log_probs: Optional[torch.Tensor] = None

# TODO: add more fields as required
state: str = ""


def from_scored_completion(scored_completion: ScoredCompletion) -> Experience:
"""Converts a ScoredCompletion to an Experience."""
prompt_ids = scored_completion.completion.prompt_ids
token_ids = scored_completion.completion.token_ids
log_probs = scored_completion.completion.log_probs
ids = torch.cat([prompt_ids, token_ids])
mask = torch.cat(
[
torch.zeros(prompt_ids.shape, dtype=torch.float32),
torch.ones_like(token_ids, dtype=torch.float32),
]
)
advantage = scored_completion.score
weights = mask * advantage
    if log_probs is not None:
        log_probs = torch.cat(
            [
                torch.zeros(prompt_ids.shape, dtype=torch.float32),
                # TODO: this only works if sample.log_probs is 1
                log_probs,
            ]
        )
return Experience(ids=ids, mask=mask, weights=weights, log_probs=log_probs)


def from_scored_completions(
scored_completions: Sequence[ScoredCompletion],
) -> Sequence[Experience]:
"""Converts a sequence of ScoredCompletion to a sequence of Experiences."""
return [from_scored_completion(sc) for sc in scored_completions]
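
For reference, a hedged sketch of building an Experience from a scored completion. It assumes ScoredCompletion carries a `completion` and a scalar `score`, as used by from_scored_completion above; Prompt construction is outside this diff, so that value is elided.

# A sketch assuming ScoredCompletion(completion=..., score=...) matches its
# usage in from_scored_completion above; the prompt value is elided because
# Prompt construction is not shown in this diff.
import torch

from forge.data_models.completion import Completion
from forge.data_models.experience import from_scored_completion
from forge.data_models.scored_completion import ScoredCompletion

completion = Completion(
    prompt=...,  # a Prompt instance; construction elided
    text="4",
    prompt_ids=torch.tensor([1, 2, 3]),
    token_ids=torch.tensor([4]),
    log_probs=torch.tensor([-0.1]),
)
exp = from_scored_completion(ScoredCompletion(completion=completion, score=1.0))
assert exp.ids.tolist() == [1, 2, 3, 4]
assert exp.mask.tolist() == [0.0, 0.0, 0.0, 1.0]  # prompt tokens masked out
assert exp.weights.tolist() == [0.0, 0.0, 0.0, 1.0]  # mask * advantage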
22 changes: 22 additions & 0 deletions src/forge/data_models/loss.py
@@ -0,0 +1,22 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass

import torch
from forge.data_models.distributed_metric import Fraction
from forge.data_models.minibatch import Minibatch


@dataclass
class LossInput:
minibatch: Minibatch
trainer_logits: torch.Tensor


@dataclass
class LossOutput:
loss: Fraction
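
Finally, a sketch of how a Trainer implementation might be driven, assuming only the abstract interface from api.py; the concrete `trainer` and `minibatches` are hypothetical stand-ins, not part of this diff.

# A minimal driver sketch for the Trainer interface in api.py; any concrete
# trainer and minibatch source are hypothetical stand-ins.
from typing import Sequence

from forge.data_models.api import Trainer, WeightsBuffer
from forge.data_models.minibatch import Minibatch


def train_step(trainer: Trainer, minibatches: Sequence[Minibatch]) -> WeightsBuffer:
    # Accumulate gradients over several minibatches, then apply them once.
    for minibatch in minibatches:
        loss_output = trainer.accumulate_gradients(minibatch)
        print(loss_output.loss.local())  # local (unreduced) loss for logging
    trainer.apply_gradients()
    # Snapshot the updated weights, e.g. to hand them to a generator.
    return trainer.snapshot_weights()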