[3/N] Core generator abstraction #159
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
@@ -0,0 +1,175 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from abc import ABC, abstractmethod
from typing import Any

from forge.data_models.completion import Completion
from forge.data_models.loss import LossOutput
from forge.data_models.minibatch import Minibatch
from forge.data_models.prompt import Prompt


# TODO: This file should NOT be in the data_models folder/package


class Store(ABC):
    """
    Abstract base class for a generic key-value store.

    This class defines the interface for a storage backend that can save and retrieve
    values using string keys. Subclasses should implement the actual storage logic,
    which could be in-memory, on disk, remote (e.g., RDMA, Redis), or any other backend.

    Example use cases include storing model weights, configuration objects, or any
    other data that needs to be accessed by key.

    Methods:
        put(key: str, value: Any) -> None
            Store a value under the specified key.

        get(key: str) -> Any
            Retrieve the value associated with the specified key.

    Subclasses must implement both methods.
    """

    @abstractmethod
    def put(self, key: str, value: Any) -> None:
        """Store a value under a key."""
        pass

    @abstractmethod
    def get(self, key: str) -> Any:
        """Retrieve a value by key."""
        pass
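For a concrete sense of the contract, here is a minimal sketch of an in-memory backend. `InMemoryStore` is a hypothetical name for illustration, not part of this PR; it reuses `Store` and `Any` from the module above.

# Hypothetical dict-backed Store, useful for tests and single-process runs.
class InMemoryStore(Store):
    def __init__(self) -> None:
        self._data: dict[str, Any] = {}

    def put(self, key: str, value: Any) -> None:
        self._data[key] = value

    def get(self, key: str) -> Any:
        return self._data[key]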
class WeightsBuffer:

[Review comment] I think this is a reasonable abstraction in forge, although I do wonder if this is something torchstore should support to begin with? These are all things torchstore should support OOB.

    """
    Concrete class for managing model weights using a generic key-value Store backend.

    This class provides a simple interface to store and retrieve model weights
    (or references to them) by delegating the actual storage logic to a Store instance.
    The Store abstraction allows for flexible backends (e.g., in-memory, RDMA, file
    system, torchstore, etc.) without changing the WeightsBuffer interface.

    Example usage:
        store = MyCustomStoreBackend()
        buffer = WeightsBuffer(store)
        buffer.put("model_weights", weights)
        latest_weights = buffer.get("model_weights")

    Args:
        store (Store): An instance of a Store backend to use for storage.
    """

    def __init__(self, store: Store):
        """
        Initialize the WeightsBuffer with a given Store backend.

        Args:
            store (Store): The storage backend to use.
        """
        self.store = store

    def put(self, key: str, weights) -> None:
        """
        Store the given weights under the specified key.

        Args:
            key (str): The key under which to store the weights.
            weights: The weights object or reference to store.
        """
        self.store.put(key, weights)

    def get(self, key: str):
        """
        Retrieve the weights stored under the specified key.

        Args:
            key (str): The key for which to retrieve the weights.

        Returns:
            The weights object or reference associated with the key.
        """
        return self.store.get(key)
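Tying the two abstractions together, a short usage sketch with the hypothetical InMemoryStore above (the key name and payload are illustrative):

import torch

store = InMemoryStore()
buffer = WeightsBuffer(store)
buffer.put("policy_weights", {"layer.weight": torch.zeros(2, 2)})
latest = buffer.get("policy_weights")  # delegated straight to the Store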
class Trainer(ABC):

[Review comment] +1

    """
    Abstract base class for a reinforcement learning (RL) trainer.

    This class defines the interface for any RL trainer implementation.
    It standardizes the methods required for gradient accumulation, applying updates,
    and snapshotting model weights. Subclasses should implement the actual logic
    for these operations, which may vary depending on the underlying model,
    framework, or distributed setup.
    """

    @abstractmethod
    def accumulate_gradients(self, minibatch: Minibatch) -> LossOutput:
        """
        Accumulate gradients for the given minibatch.

        This method is called once per minibatch during training. It should compute
        the gradients for the minibatch and accumulate them (without applying them yet).

        Args:
            minibatch (Minibatch): The minibatch of data to use for gradient computation.

        Returns:
            LossOutput: The computed loss and any additional outputs needed for logging
                or analysis.
        """
        pass

    @abstractmethod
    def apply_gradients(self) -> None:
        """
        Apply accumulated gradients to the model parameters.

        This method should update the model's parameters using the gradients that have
        been accumulated so far (e.g., by calling an optimizer step). After this call,
        the accumulated gradients should be cleared/reset.
        """
        pass

    @abstractmethod
    def snapshot_weights(self) -> WeightsBuffer:
        """
        Save the current model weights and return a buffer handle.

        This method should capture the current state of the model's weights and store
        them in a WeightsBuffer (which may be local or remote, depending on the
        implementation). The returned buffer can be used to transfer weights between
        components or for checkpointing.

        Returns:
            WeightsBuffer: A handle or reference to the stored weights buffer.
        """
        pass
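The intended call pattern is one accumulate per minibatch and one apply per accumulation window. A rough sketch of a driver loop; the function, the `minibatches` iterable, and the `grad_accum_steps` parameter are assumptions for illustration, not part of this PR:

from typing import Iterable


def train_epoch(
    trainer: Trainer, minibatches: Iterable[Minibatch], grad_accum_steps: int = 1
) -> None:
    # Accumulate gradients for each minibatch; step the optimizer (and reset
    # accumulated grads) every `grad_accum_steps` minibatches.
    for step, minibatch in enumerate(minibatches):
        loss_output = trainer.accumulate_gradients(minibatch)  # metrics for logging
        if (step + 1) % grad_accum_steps == 0:
            trainer.apply_gradients()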
class Generator(ABC):
    """
    Abstract base class for a model generator in RL or sequence modeling workflows.

    This class defines the interface for any generator implementation, which is
    responsible for producing completions (e.g., text, actions) given a prompt, and
    for updating its internal model weights. Subclasses should implement the actual
    logic for generation and weight updates, which may vary depending on the
    underlying model or framework.
    """

    @abstractmethod
    def generate(self, prompt: Prompt) -> list[Completion]:
        """
        Generate completions given a prompt.

        This method should use the current model to produce one or more completions
        (e.g., text outputs, actions) based on the provided prompt.

        Args:
            prompt (Prompt): The input prompt or context for generation.

        Returns:
            list[Completion]: A list of generated completions corresponding to the prompt.
        """
        pass

    @abstractmethod
    def update_weights(self, weights_handle: WeightsBuffer) -> None:
        """
        Update the weights of the model using the provided weights buffer.

        This method should update the generator's internal model parameters using
        the weights stored in the given WeightsBuffer (which may be local or remote).

        Args:
            weights_handle (WeightsBuffer): A handle or reference to the weights buffer
                containing the new model weights.
        """
        pass
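Together, snapshot_weights and update_weights define the weight-synchronization handoff between trainer and generator. A minimal sketch of that handoff (the function name is assumed for illustration):

def sync_generator_weights(trainer: Trainer, generator: Generator) -> None:
    # Snapshot the trainer's current weights into a WeightsBuffer handle,
    # then hand that handle to the generator to refresh its policy.
    weights_handle = trainer.snapshot_weights()
    generator.update_weights(weights_handle)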
@@ -0,0 +1,31 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from typing import Optional

import torch
from forge.data_models.prompt import Prompt


@dataclass
class Completion:
    """A model-generated completion for a given prompt."""

[Review comment on lines +15 to +16] i think we already have a dataclass for this. Not sure if its "Episode". But it makes sense to have a completion class.

    # The original prompt.
    prompt: Prompt

    # The decoded text returned by the model.
    text: str

    # The encoded text (token ids) that was fed into the model.
    prompt_ids: torch.Tensor

    # The encoded text (token ids) generated by the model.
    token_ids: torch.Tensor

    # The log probabilities of the target tokens.
    log_probs: Optional[torch.Tensor] = None
@@ -0,0 +1,64 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from abc import ABC, abstractmethod
from dataclasses import dataclass

import torch
import torch.distributed as dist


@dataclass
class DistributedMetric(ABC):

[Review comment] I think we should use https://github.com/meta-pytorch/forge/tree/main/src/forge/data/dataset_metrics But lets connect and see if you think it makes sense!

    """Metrics that are calculated in distributed fashion.

    Metrics computed on each rank are wrapped in a DistributedMetric according to
    how they are going to be aggregated. For example, the average log prob can be
    wrapped as `Fraction(SumDistributedMetric((logp * mask).sum()),
    SumDistributedMetric(mask.sum()))`, where `mask` indicates which tokens are valid.
    """

    # We need to pass a context argument for distribution setup in the future.
    @abstractmethod
    def reduce(self, group: dist.ProcessGroup | None = None) -> torch.Tensor:
        pass

    @abstractmethod
    def local(self) -> torch.Tensor:
        pass


@dataclass
class SumDistributedMetric(DistributedMetric):
    tensor: torch.Tensor

    def reduce(self, group: dist.ProcessGroup | None = None) -> torch.Tensor:
        return _try_clone_and_reduce(self.tensor, op=dist.ReduceOp.SUM, group=group)

    def local(self) -> torch.Tensor:
        return self.tensor


@dataclass
class Fraction:
    numerator: DistributedMetric
    denominator: DistributedMetric

    def reduce(self, group: dist.ProcessGroup | None = None) -> torch.Tensor:
        return self.numerator.reduce(group) / self.denominator.reduce(group)

    def local(self) -> torch.Tensor:
        return self.numerator.local() / self.denominator.local()


def _try_clone_and_reduce(
    tensor: torch.Tensor, op: dist.ReduceOp, group: dist.ProcessGroup | None
) -> torch.Tensor:
    # Clone so the caller's tensor is never mutated; all-reduce only when a
    # process group has been initialized (single-process runs are a no-op).
    cloned = tensor.detach().clone()
    if dist.is_initialized():
        dist.all_reduce(cloned, op=op, group=group)
    return cloned
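The Fraction pattern keeps numerator and denominator separate until after the cross-rank reduction, so a masked mean stays exact even when ranks hold different numbers of valid tokens. A small single-process illustration (values arbitrary), mirroring the docstring's average-log-prob example:

import torch

logp = torch.tensor([[-0.5, -1.0, -2.0]])  # per-token log probs
mask = torch.tensor([[1.0, 1.0, 0.0]])     # 1 marks valid tokens

avg_logp = Fraction(
    SumDistributedMetric((logp * mask).sum()),
    SumDistributedMetric(mask.sum()),
)
print(avg_logp.local())  # tensor(-0.7500): (-0.5 + -1.0) / 2
# avg_logp.reduce() would all-reduce both sums across ranks first.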
@@ -0,0 +1,68 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from typing import Optional, Sequence

import torch
from forge.data_models.scored_completion import ScoredCompletion


@dataclass
class Experience:

[Review comment] This is very similar to Completion. I wonder if we should just keep one. But its also a bit similar to "Episode".

[Review comment] Need to see how this plays with data packing. It could be that the Packer takes list[Experience]. But i dont think it makes sense for this class to have concat logic if we are having packing.

[Author reply] This is modeled in a way that Completion is what we receive from the generator and Experience is what we feed into the trainer (in between, it goes through scoring, post-processing, etc.).

    """
    The Experience data class to be used by the trainer.

    Experiences are usually generated from a scored completion by running various
    post-processing steps.
    """

    # Concatenated prompt and sample token ids.
    ids: torch.Tensor

    # The mask for the target ids: 0 for prompt tokens, 1 for sample tokens.
    mask: torch.Tensor

    # The weight to apply to the loss of each target token. It's normally computed
    # from the advantage and the reward.
    weights: torch.Tensor

    # The log probabilities of the target tokens; for the prompt part it's set to 0,
    # for the generation part it's computed by the Generator/Sampler.
    log_probs: Optional[torch.Tensor] = None

    # TODO: add more fields as required
    state: str = ""


def from_scored_completion(scored_completion: ScoredCompletion) -> Experience:
    """Converts a ScoredCompletion to an Experience."""
    prompt_ids = scored_completion.completion.prompt_ids
    token_ids = scored_completion.completion.token_ids
    log_probs = scored_completion.completion.log_probs
    ids = torch.cat([prompt_ids, token_ids])
    mask = torch.cat(
        [
            torch.zeros(prompt_ids.shape, dtype=torch.float32),
            torch.ones_like(token_ids, dtype=torch.float32),
        ]
    )
    advantage = scored_completion.score
    weights = mask * advantage
    # Note: assumes log_probs is present (non-None) for the generated tokens.
    log_probs = torch.cat(
        [
            torch.zeros(prompt_ids.shape, dtype=torch.float32),
            # TODO: this only works if sample.log_probs is 1
            log_probs,
        ]
    )
    return Experience(ids=ids, mask=mask, weights=weights, log_probs=log_probs)


def from_scored_completions(
    scored_completions: Sequence[ScoredCompletion],
) -> Sequence[Experience]:
    """Converts a sequence of ScoredCompletion to a sequence of Experiences."""
    return [from_scored_completion(sc) for sc in scored_completions]
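To make the layout concrete, here is what the converter's tensor math produces for a 3-token prompt and a 2-token sample with score 0.5 (values arbitrary):

import torch

prompt_ids = torch.tensor([11, 12, 13])  # 3 prompt tokens
token_ids = torch.tensor([21, 22])       # 2 sampled tokens
advantage = 0.5                          # scored_completion.score

mask = torch.cat(
    [
        torch.zeros(prompt_ids.shape, dtype=torch.float32),
        torch.ones_like(token_ids, dtype=torch.float32),
    ]
)
weights = mask * advantage
print(mask)     # tensor([0., 0., 0., 1., 1.])
print(weights)  # tensor([0.0000, 0.0000, 0.0000, 0.5000, 0.5000])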
@@ -0,0 +1,22 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass

import torch
from forge.data_models.distributed_metric import Fraction
from forge.data_models.minibatch import Minibatch


@dataclass
class LossInput:
    minibatch: Minibatch
    trainer_logits: torch.Tensor


@dataclass
class LossOutput:
    loss: Fraction
[Review comment on lines +14 to +22] its interesting to have a LossInput and LossOutput class. I am just afraid that these abstractions would add too much hierarchy, i.e. instead of […] But the LossOutput makes more sense to me because we may be outputting multiple numbers. We should double check, because if its only a couple, then having a dataclass might be an overkill.

[Review comment] This looks really close to the KVStore we're building in https://github.com/meta-pytorch/forge/pull/147

[Author reply] ack. Will rebase and remove this once the other PR gets merged.