diff --git a/src/forge/data_models/__init__.py b/src/forge/data_models/__init__.py
new file mode 100644
index 00000000..2e41cd71
--- /dev/null
+++ b/src/forge/data_models/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/src/forge/data_models/api.py b/src/forge/data_models/api.py
new file mode 100644
index 00000000..68c02f89
--- /dev/null
+++ b/src/forge/data_models/api.py
@@ -0,0 +1,137 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from abc import ABC, abstractmethod
+from typing import Any
+
+from forge.data_models.loss import LossOutput
+from forge.data_models.minibatch import Minibatch
+
+
+# TODO: This file should NOT be in the data_models folder/package
+
+
+class Store(ABC):
+    """
+    Abstract base class for a generic key-value store.
+
+    This class defines the interface for a storage backend that can save and retrieve
+    values using string keys. Subclasses should implement the actual storage logic,
+    which could be in-memory, on disk, remote (e.g., RDMA, Redis), or any other backend.
+
+    Example use cases include storing model weights, configuration objects, or any
+    other data that needs to be accessed by key.
+
+    Methods:
+        put(key: str, value: Any) -> None
+            Store a value under the specified key.
+
+        get(key: str) -> Any
+            Retrieve the value associated with the specified key.
+
+    Subclasses must implement both methods.
+    """
+
+    @abstractmethod
+    def put(self, key: str, value: Any) -> None:
+        """Store a value under a key."""
+        pass
+
+    @abstractmethod
+    def get(self, key: str) -> Any:
+        """Retrieve a value by key."""
+        pass
+
+
+class WeightsBuffer:
+    """
+    Concrete class for managing model weights using a generic key-value Store backend.
+    This class provides a simple interface to store and retrieve model weights
+    (or references to them) by delegating the actual storage logic to a Store instance.
+    The Store abstraction allows for flexible backends (e.g., in-memory, RDMA, file system, torchstore, etc.)
+    without changing the WeightsBuffer interface.
+    Example usage:
+        store = MyCustomStoreBackend()
+        buffer = WeightsBuffer(store)
+        buffer.put("model_weights", weights)
+        latest_weights = buffer.get("model_weights")
+    Args:
+        store (Store): An instance of a Store backend to use for storage.
+    """
+
+    def __init__(self, store: Store) -> None:
+        """
+        Initialize the WeightsBuffer with a given Store backend.
+        Args:
+            store (Store): The storage backend to use.
+        """
+        self.store = store
+
+    def put(self, key: str, weights):
+        """
+        Store the given weights under the specified key.
+        Args:
+            key (str): The key under which to store the weights.
+            weights: The weights object or reference to store.
+        """
+        self.store.put(key, weights)
+
+    def get(self, key: str):
+        """
+        Retrieve the weights stored under the specified key.
+        Args:
+            key (str): The key for which to retrieve the weights.
+        Returns:
+            The weights object or reference associated with the key.
+        """
+        return self.store.get(key)
+
+
+class Trainer(ABC):
+    """
+    Abstract base class for a reinforcement learning (RL) trainer.
+    This class defines the interface for any RL trainer implementation.
+    It standardizes the methods required for gradient accumulation, applying updates,
+    and snapshotting model weights. Subclasses should implement the actual logic
+    for these operations, which may vary depending on the underlying model,
+    framework, or distributed setup.
+    """
+
+    @abstractmethod
+    def accumulate_gradients(self, minibatch: Minibatch) -> LossOutput:
+        """
+        Accumulate gradients for the given minibatch.
+        This method is called once per minibatch during training. It should compute
+        the gradients for the minibatch and accumulate them (without applying them yet).
+
+        Args:
+            minibatch (Minibatch): The minibatch of data to use for gradient computation.
+        Returns:
+            LossOutput: The computed loss and any additional outputs needed for logging or analysis.
+        """
+        pass
+
+    @abstractmethod
+    def apply_gradients(self) -> None:
+        """
+        Apply accumulated gradients to the model parameters.
+        This method should update the model's parameters using the gradients that have
+        been accumulated so far (e.g., by calling an optimizer step). After this call,
+        the accumulated gradients should be cleared/reset.
+        """
+        pass
+
+    @abstractmethod
+    def snapshot_weights(self) -> WeightsBuffer:
+        """
+        Save the current model weights and return a buffer handle.
+        This method should capture the current state of the model's weights and store
+        them in a WeightsBuffer (which may be local or remote, depending on the implementation).
+        The returned buffer can be used to transfer weights between components or for checkpointing.
+        Returns:
+            WeightsBuffer: A handle or reference to the stored weights buffer.
+        """
+        pass
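Usage sketch (illustrative, not part of the diff): how `Store`, `WeightsBuffer`, and the rest of this API compose. `DictStore` below is a hypothetical stand-in for a real backend; the diff's own `InMemoryStore` plays the same role.

```python
from typing import Any

from forge.data_models.api import Store, WeightsBuffer


class DictStore(Store):
    """Hypothetical toy backend over a plain dict, for illustration only."""

    def __init__(self) -> None:
        self._data: dict[str, Any] = {}

    def put(self, key: str, value: Any) -> None:
        self._data[key] = value

    def get(self, key: str) -> Any:
        return self._data.get(key)


buffer = WeightsBuffer(DictStore())
buffer.put("model_weights", {"layer0.weight": [0.1, 0.2]})  # any weights object
assert buffer.get("model_weights") == {"layer0.weight": [0.1, 0.2]}
```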
diff --git a/src/forge/data_models/completion.py b/src/forge/data_models/completion.py
new file mode 100644
index 00000000..eca4f62f
--- /dev/null
+++ b/src/forge/data_models/completion.py
@@ -0,0 +1,31 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from dataclasses import dataclass
+from typing import Optional
+
+import torch
+from forge.data_models.prompt import Prompt
+
+
+@dataclass
+class Completion:
+    """A model-generated completion for a given prompt."""
+
+    # The original prompt.
+    prompt: Prompt
+
+    # The decoded text returned by the model.
+    text: str
+
+    # The encoded text (token ids) that was fed into the model.
+    prompt_ids: torch.Tensor
+
+    # The encoded text (token ids) that was generated by the model.
+    token_ids: torch.Tensor
+
+    # The log probabilities of the target tokens.
+    log_probs: Optional[torch.Tensor] = None
diff --git a/src/forge/data_models/distributed_metric.py b/src/forge/data_models/distributed_metric.py
new file mode 100644
index 00000000..5fe6f0fb
--- /dev/null
+++ b/src/forge/data_models/distributed_metric.py
@@ -0,0 +1,64 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+
+import torch
+import torch.distributed as dist
+
+
+@dataclass
+class DistributedMetric(ABC):
+    """Metrics that are calculated in a distributed fashion.
+
+    Metrics computed on each rank are wrapped in a DistributedMetric subclass
+    according to how they are to be aggregated. For example, an average log prob
+    can be wrapped as `Fraction(SumDistributedMetric((logp * mask).sum()),
+    SumDistributedMetric(mask.sum()))`, where `mask` indicates which tokens are valid.
+    """
+
+    # We need to pass a context argument for distribution setup in the future.
+    @abstractmethod
+    def reduce(self, group: dist.ProcessGroup | None = None) -> torch.Tensor:
+        pass
+
+    @abstractmethod
+    def local(self) -> torch.Tensor:
+        pass
+
+
+@dataclass
+class SumDistributedMetric(DistributedMetric):
+    def __init__(self, tensor: torch.Tensor) -> None:
+        self.tensor = tensor
+
+    def reduce(self, group: dist.ProcessGroup | None = None) -> torch.Tensor:
+        return _try_clone_and_reduce(self.tensor, op=dist.ReduceOp.SUM, group=group)
+
+    def local(self) -> torch.Tensor:
+        return self.tensor
+
+
+@dataclass
+class Fraction:
+    numerator: DistributedMetric
+    denominator: DistributedMetric
+
+    def reduce(self, group: dist.ProcessGroup | None = None) -> torch.Tensor:
+        return self.numerator.reduce(group) / self.denominator.reduce(group)
+
+    def local(self) -> torch.Tensor:
+        return self.numerator.local() / self.denominator.local()
+
+
+def _try_clone_and_reduce(
+    tensor: torch.Tensor, op: dist.ReduceOp, group: dist.ProcessGroup | None
+) -> torch.Tensor:
+    cloned = tensor.detach().clone()
+    if dist.is_initialized():
+        dist.all_reduce(cloned, op=op, group=group)
+    return cloned
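Worked example (illustrative, not part of the diff) of the aggregation pattern from the docstring. It runs single-process because `_try_clone_and_reduce` falls back to the local value when `torch.distributed` is not initialized:

```python
import torch

from forge.data_models.distributed_metric import Fraction, SumDistributedMetric

# Per-token log-probs and a validity mask (1 = real token, 0 = padding).
logp = torch.tensor([-0.5, -1.0, -2.0])
mask = torch.tensor([1.0, 1.0, 0.0])

mean_logp = Fraction(
    SumDistributedMetric((logp * mask).sum()),  # numerator: sum of valid log-probs
    SumDistributedMetric(mask.sum()),           # denominator: count of valid tokens
)

print(mean_logp.local())   # tensor(-0.7500)
print(mean_logp.reduce())  # same here; all-reduced across ranks under torch.distributed
```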
diff --git a/src/forge/data_models/experience.py b/src/forge/data_models/experience.py
new file mode 100644
index 00000000..34a183eb
--- /dev/null
+++ b/src/forge/data_models/experience.py
@@ -0,0 +1,68 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from dataclasses import dataclass
+from typing import Optional, Sequence
+
+import torch
+from forge.data_models.scored_completion import ScoredCompletion
+
+
+@dataclass
+class Experience:
+    """
+    The Experience data class to be used by the trainer.
+
+    Experiences are usually generated from a scored completion by running various post-processing steps.
+    """
+
+    # Concatenated prompt and sample token ids.
+    ids: torch.Tensor
+
+    # The mask for the target ids, 0 for prompt tokens, 1 for sample tokens.
+    mask: torch.Tensor
+
+    # The weight to apply to the loss of each target token. It's normally computed
+    # from the advantage and the reward.
+    weights: torch.Tensor
+
+    # The log probabilities of the target tokens; for the prompt part they are set to 0,
+    # for the generated part they are computed by the Generator/Sampler.
+    log_probs: Optional[torch.Tensor] = None
+
+    # TODO: add more fields as required
+    state: str = ""
+
+
+def from_scored_completion(scored_completion: ScoredCompletion) -> Experience:
+    """Converts a ScoredCompletion to an Experience."""
+    prompt_ids = scored_completion.completion.prompt_ids
+    token_ids = scored_completion.completion.token_ids
+    log_probs = scored_completion.completion.log_probs
+    ids = torch.cat([prompt_ids, token_ids])
+    mask = torch.cat(
+        [
+            torch.zeros(prompt_ids.shape, dtype=torch.float32),
+            torch.ones_like(token_ids, dtype=torch.float32),
+        ]
+    )
+    advantage = scored_completion.score
+    weights = mask * advantage
+    log_probs = torch.cat(
+        [
+            torch.zeros(prompt_ids.shape, dtype=torch.float32),
+            # TODO: this only works if log_probs is a 1-D tensor
+            log_probs,
+        ]
+    )
+    return Experience(ids=ids, mask=mask, weights=weights, log_probs=log_probs)
+
+
+def from_scored_completions(
+    scored_completions: Sequence[ScoredCompletion],
+) -> Sequence[Experience]:
+    """Converts a sequence of ScoredCompletions to a sequence of Experiences."""
+    return [from_scored_completion(sc) for sc in scored_completions]
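To make the masking/weighting convention concrete, a small worked example using the types in this diff (all token ids and scores below are made up):

```python
import torch

from forge.data_models.completion import Completion
from forge.data_models.experience import from_scored_completion
from forge.data_models.prompt import Prompt
from forge.data_models.scored_completion import ScoredCompletion

completion = Completion(
    prompt=Prompt.from_prompt("2 + 2 = ?"),
    text="4",
    prompt_ids=torch.tensor([101, 102, 103]),
    token_ids=torch.tensor([104]),
    log_probs=torch.tensor([-0.2]),
)
exp = from_scored_completion(ScoredCompletion(completion=completion, score=1.5))

print(exp.ids)        # tensor([101, 102, 103, 104])
print(exp.mask)       # tensor([0., 0., 0., 1.])   prompt=0, sample=1
print(exp.weights)    # tensor([0., 0., 0., 1.5])  mask * advantage
print(exp.log_probs)  # tensor([0., 0., 0., -0.2]) zeros over the prompt
```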
diff --git a/src/forge/data_models/loss.py b/src/forge/data_models/loss.py
new file mode 100644
index 00000000..9806938e
--- /dev/null
+++ b/src/forge/data_models/loss.py
@@ -0,0 +1,22 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from dataclasses import dataclass
+
+import torch
+from forge.data_models.distributed_metric import Fraction
+from forge.data_models.minibatch import Minibatch
+
+
+@dataclass
+class LossInput:
+    minibatch: Minibatch
+    trainer_logits: torch.Tensor
+
+
+@dataclass
+class LossOutput:
+    loss: Fraction
diff --git a/src/forge/data_models/minibatch.py b/src/forge/data_models/minibatch.py
new file mode 100644
index 00000000..1bcbc4ef
--- /dev/null
+++ b/src/forge/data_models/minibatch.py
@@ -0,0 +1,84 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from collections import defaultdict
+from dataclasses import dataclass
+from typing import Any, Sequence
+
+import torch
+from forge.data_models.experience import Experience
+
+
+@dataclass
+class Minibatch:
+    """The minibatch that the trainer will receive."""
+
+    # The input sequence token ids for the trainer forward pass.
+    input_ids: torch.Tensor
+
+    # The segment ids for the input sequence token ids. The same segment
+    # id represents the same sequence.
+    segment_ids: torch.Tensor
+
+    # The targets required for loss computation, usually concatenated prompt and
+    # sample token ids.
+    target_ids: torch.Tensor
+
+    # The mask for the target ids, 0 for prompt tokens, 1 for sample tokens.
+    target_mask: torch.Tensor
+
+    # The weight to apply to the loss of each target token. It's normally computed
+    # from the advantage and the reward.
+    target_weights: torch.Tensor
+
+    # The log probabilities of the target tokens; for the prompt part they are set
+    # to 0, for the generated part they are computed by the sampler.
+    target_log_probs: torch.Tensor
+
+
+def from_experiences(
+    exps: Sequence[Experience], max_seq_len: int, pad_val: int = 0
+) -> Minibatch:
+    """
+    Convert a list of experiences to a minibatch.
+    """
+
+    def pack_sequence(
+        tensors: Sequence[torch.Tensor],
+        pad_val: Any,
+        dtype: torch.dtype,
+        max_len: int,
+    ) -> torch.Tensor:
+        """Packs multiple tensors along the seq dim."""
+        seq = torch.cat(tensors)
+        pad_len = max_len - seq.size(0)
+        if pad_len < 0:
+            raise ValueError(
+                f"Sequence length {seq.size(0)} exceeds the maximum length {max_len}"
+            )
+        return torch.nn.functional.pad(seq, (0, pad_len), value=pad_val)[None, ...].to(
+            dtype
+        )
+
+    mini_batch = {}
+    exp_list = defaultdict(list)
+    for i, exp in enumerate(exps):
+        input_ids = exp.ids[:-1]
+        exp_list["input_ids"].append(input_ids)
+        exp_list["target_ids"].append(exp.ids[1:])
+        exp_list["segment_ids"].append(torch.ones_like(input_ids) * i)
+        exp_list["target_mask"].append(exp.mask[1:])
+        exp_list["target_weights"].append(exp.weights[1:])
+        exp_list["target_log_probs"].append(exp.log_probs[1:])
+
+    for k, v in exp_list.items():
+        _dtype = torch.int64
+        if k == "target_mask" or k == "target_weights" or k == "target_log_probs":
+            _dtype = torch.float32
+
+        mini_batch[k] = pack_sequence(v, pad_val, _dtype, max_seq_len)
+
+    return Minibatch(**mini_batch)
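A short sketch (illustrative, not part of the diff) of what `from_experiences` produces for two short experiences; the `make_exp` helper is hypothetical and the token values are made up:

```python
import torch

from forge.data_models.experience import Experience
from forge.data_models.minibatch import from_experiences


def make_exp(ids):
    t = torch.tensor(ids)
    ones = torch.ones(len(ids), dtype=torch.float32)
    return Experience(ids=t, mask=ones, weights=ones, log_probs=ones)


mb = from_experiences([make_exp([1, 2, 3]), make_exp([4, 5, 6, 7])], max_seq_len=8)

# Each sequence loses one token to the input/target shift, then the sequences
# are packed along the seq dim and right-padded to max_seq_len.
print(mb.input_ids)    # tensor([[1, 2, 4, 5, 6, 0, 0, 0]])
print(mb.target_ids)   # tensor([[2, 3, 5, 6, 7, 0, 0, 0]])
print(mb.segment_ids)  # tensor([[0, 0, 1, 1, 1, 0, 0, 0]])
```

One observable design consequence: with the default `pad_val=0`, padding positions in `segment_ids` are indistinguishable from segment 0, so consumers must rely on `target_mask` rather than segment ids to exclude padding.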
+ """ + + def pack_sequence( + tensors: Sequence[torch.Tensor], + pad_val: Any, + dtype: torch.dtype, + max_len: int, + ) -> torch.Tensor: + """Packs multiple tensors along the seq dim.""" + seq = torch.cat(tensors) + pad_len = max_len - seq.size(0) + if pad_len < 0: + raise ValueError( + f"Sequence lenth {seq.size(0)} exceeds the maximum length {max_len}" + ) + return torch.nn.functional.pad(seq, (0, pad_len), value=pad_val)[None, ...].to( + dtype + ) + + mini_batch = {} + exp_list = defaultdict(list) + for i, exp in enumerate(exps): + input_ids = exp.ids[:-1] + exp_list["input_ids"].append(input_ids) + exp_list["target_ids"].append(exp.ids[1:]) + exp_list["segment_ids"].append(torch.ones_like(input_ids) * i) + exp_list["target_mask"].append(exp.mask[1:]) + exp_list["target_weights"].append(exp.weights[1:]) + exp_list["target_log_probs"].append(exp.log_probs[1:]) + + for k, v in exp_list.items(): + _dtype = torch.int64 + if k == "target_mask" or k == "target_weights" or k == "target_log_probs": + _dtype = torch.float32 + + mini_batch[k] = pack_sequence(v, pad_val, _dtype, max_seq_len) + + return Minibatch(**mini_batch) diff --git a/src/forge/data_models/prompt.py b/src/forge/data_models/prompt.py new file mode 100644 index 00000000..741b097a --- /dev/null +++ b/src/forge/data_models/prompt.py @@ -0,0 +1,66 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from collections.abc import Sequence +from dataclasses import dataclass +from enum import Enum +from typing import Any + + +class Role(Enum): + SYSTEM = "system" + USER = "user" + ASSISTANT = "assistant" + NONE = "none" + + +@dataclass +class Message: + """A single message in a conversation.""" + + chunks: Sequence[str] + role: Role + + +@dataclass +class Prompt: + """A multi-turn prompt (conversation history).""" + + # Multi-turn messages, each turn is a message. + messages: Sequence[Message] + metadata: Any | None = None + + @classmethod + def from_prompt( + cls, prompt: str, system_instruction: str | None = None + ) -> "Prompt": + messages = prompt_to_messages(prompt, system_instruction) + return Prompt( + messages=messages, + ) + + +def prompt_to_messages( + prompt: str, system_instruction: str | None = None +) -> Sequence[Message]: + """Convert a prompt to a sequence of messages.""" + messages = [] + if system_instruction is not None: + messages.append(Message(chunks=[system_instruction], role=Role.SYSTEM)) + messages.append( + Message(chunks=[prompt], role=Role.USER), + ) + return messages + + +def to_prompt( + prompt: str, metadata: Any | None = None, system_instruction: str | None = None +) -> Prompt: + """Converts a prompt to a sequence of messages.""" + return Prompt( + messages=prompt_to_messages(prompt, system_instruction), + metadata=metadata, + ) diff --git a/src/forge/data_models/scored_completion.py b/src/forge/data_models/scored_completion.py new file mode 100644 index 00000000..c1b41b8a --- /dev/null +++ b/src/forge/data_models/scored_completion.py @@ -0,0 +1,13 @@ +from dataclasses import dataclass + +from forge.data_models.completion import Completion + + +@dataclass +class ScoredCompletion: + """A completion with an associated score (from a reward model or human).""" + + completion: Completion + score: float # akin to reward + + # TODO: add more fields as needed. 
diff --git a/src/forge/stores/__init__.py b/src/forge/stores/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/forge/stores/in_memory_store.py b/src/forge/stores/in_memory_store.py
new file mode 100644
index 00000000..cf5be8b7
--- /dev/null
+++ b/src/forge/stores/in_memory_store.py
@@ -0,0 +1,37 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from forge.data_models.api import Store
+
+
+class InMemoryStore(Store):
+    """
+    Simple in-memory key-value store implementation.
+    Stores values in a Python dictionary, keyed by strings.
+    Suitable for testing, prototyping, or single-process use cases.
+    """
+
+    def __init__(self):
+        self._store = {}
+
+    def put(self, key: str, value):
+        """
+        Store a value under the specified key.
+        Args:
+            key (str): The key under which to store the value.
+            value: The value to store.
+        """
+        self._store[key] = value
+
+    def get(self, key: str):
+        """
+        Retrieve the value associated with the specified key.
+        Args:
+            key (str): The key for which to retrieve the value.
+        Returns:
+            The value associated with the key, or None if not found.
+        """
+        return self._store.get(key, None)
diff --git a/src/forge/trainers/__init__.py b/src/forge/trainers/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/forge/trainers/huggingface_trainer.py b/src/forge/trainers/huggingface_trainer.py
new file mode 100644
index 00000000..56df4c6a
--- /dev/null
+++ b/src/forge/trainers/huggingface_trainer.py
@@ -0,0 +1,52 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+
+from forge.data_models.api import Trainer, WeightsBuffer
+from forge.data_models.distributed_metric import Fraction, SumDistributedMetric
+from forge.data_models.loss import LossOutput
+from forge.data_models.minibatch import Minibatch
+from forge.stores.in_memory_store import InMemoryStore
+
+
+class HuggingFaceTrainer(Trainer):
+    def __init__(self, model_path: str):
+        # TODO: model_path and other trainer-related configs should be passed in as a config object
+        super().__init__()
+        self.model_name = model_path
+
+        # TODO: Hard-coded implementation for the RFC; this needs to be injected via config
+        self._store = InMemoryStore()
+        self._weights_buffer = WeightsBuffer(self._store)
+
+    def accumulate_gradients(self, minibatch: Minibatch) -> LossOutput:
+        """
+        Accumulate gradients for the given minibatch.
+        """
+        return LossOutput(
+            loss=Fraction(
+                SumDistributedMetric(torch.zeros(1)), SumDistributedMetric(torch.ones(1))
+            )
+        )
+
+    def apply_gradients(self) -> None:
+        """
+        Apply accumulated gradients to the model parameters.
+        """
+        pass
+
+    def snapshot_weights(
+        self,
+    ) -> WeightsBuffer:
+        """
+        Save the current model weights and return a buffer handle.
+        The weights are stored in the WeightsBuffer configured at construction
+        time (here backed by an InMemoryStore).
+        Returns:
+            WeightsBuffer: A handle to the buffered weights.
+        """
+        return self._weights_buffer
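End-to-end smoke test (illustrative, not part of the diff) of the RFC surface as stubbed here. The model path string is arbitrary since this stub trainer never loads a model, and the stored weights dict is a placeholder:

```python
import torch

from forge.data_models.experience import Experience
from forge.data_models.minibatch import from_experiences
from forge.trainers.huggingface_trainer import HuggingFaceTrainer

trainer = HuggingFaceTrainer(model_path="any-model-path")  # stub: path unused

ones = torch.ones(4, dtype=torch.float32)
exp = Experience(ids=torch.tensor([1, 2, 3, 4]), mask=ones, weights=ones, log_probs=ones)
minibatch = from_experiences([exp], max_seq_len=8)

loss = trainer.accumulate_gradients(minibatch)  # placeholder loss for now
print(loss.loss.local())                        # tensor([0.])
trainer.apply_gradients()                       # no-op in the stub

weights = trainer.snapshot_weights()            # WeightsBuffer over InMemoryStore
weights.put("policy/step0", {"dummy": torch.zeros(1)})
print(weights.get("policy/step0"))
```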