5 changes: 5 additions & 0 deletions src/forge/data_models/__init__.py
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
31 changes: 31 additions & 0 deletions src/forge/data_models/completion.py
@@ -0,0 +1,31 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from typing import Optional

import torch
from forge.data_models.prompt import Prompt


@dataclass
class Completion:

[Contributor comment] Currently our Episode is a combination of Completion, ScoredCompletion, and Experience. I think keeping them flat makes customization and logging a bit easier, but either way can work.

    """A model-generated completion for a given prompt."""

    # The original prompt.
    prompt: Prompt

    # The decoded text returned by the model.
    text: str

    # The encoded text (token ids) that was fed into the model.
    prompt_ids: torch.Tensor

    # The encoded text (token ids) that was generated by the model.
    token_ids: torch.Tensor

    # The log probabilities of the target tokens.
    log_probs: Optional[torch.Tensor] = None
64 changes: 64 additions & 0 deletions src/forge/data_models/distributed_metric.py
@@ -0,0 +1,64 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from abc import ABC, abstractmethod
from dataclasses import dataclass

import torch
import torch.distributed as dist


@dataclass
class DistributedMetric(ABC):

[Contributor comment] I like this concept and it could be combined with our MetricLogger. But if we're going to handle this in a generic way, I don't think we want to deal with process groups, otherwise you have to set them up with a bunch of services that don't need them. Either we should use these abstractions and return metrics to a MetricService/controller to aggregate, or handle it without formal abstractions, in a service-specific way, as we do now.

    """Metrics that are calculated in a distributed fashion.

    Metrics computed on each rank are wrapped in a DistributedMetric that
    encodes how they should be aggregated. For example, average log prob can
    be wrapped as `Fraction(Sum((logp * mask).sum()), Sum(mask.sum()))`, where
    `mask` indicates which tokens are valid.
    """

    # We may need to pass a context argument for distributed setup in the future.
    @abstractmethod
    def reduce(self, group: dist.ProcessGroup | None = None) -> torch.Tensor:
        pass

    @abstractmethod
    def local(self) -> torch.Tensor:
        pass


@dataclass
class SumDistributedMetric(DistributedMetric):
    def __init__(self, tensor: torch.Tensor) -> None:
        self.tensor = tensor

    def reduce(self, group: dist.ProcessGroup | None = None) -> torch.Tensor:
        return _try_clone_and_reduce(self.tensor, op=dist.ReduceOp.SUM, group=group)

    def local(self) -> torch.Tensor:
        return self.tensor


@dataclass
class Fraction:
    numerator: DistributedMetric
    denominator: DistributedMetric

    def reduce(self, group: dist.ProcessGroup | None = None) -> torch.Tensor:
        return self.numerator.reduce(group) / self.denominator.reduce(group)

    def local(self) -> torch.Tensor:
        return self.numerator.local() / self.denominator.local()


def _try_clone_and_reduce(
    tensor: torch.Tensor, op: dist.ReduceOp, group: dist.ProcessGroup | None
) -> torch.Tensor:
    cloned = tensor.detach().clone()
    if dist.is_initialized():
        dist.all_reduce(cloned, op=op, group=group)
    return cloned
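
For illustration, a minimal usage sketch (not part of the diff; it assumes a single process, so reduce() falls back to the local value) composing these pieces into the masked average log prob described in the docstring:

import torch

from forge.data_models.distributed_metric import Fraction, SumDistributedMetric

logp = torch.tensor([-0.5, -1.2, -0.3])
mask = torch.tensor([1.0, 1.0, 0.0])  # 1.0 marks valid tokens

# Average log prob as a ratio of two sums, so it aggregates correctly across
# ranks: all-reduce numerator and denominator separately, then divide.
avg_logp = Fraction(
    numerator=SumDistributedMetric((logp * mask).sum()),
    denominator=SumDistributedMetric(mask.sum()),
)

print(avg_logp.local())  # tensor(-0.8500)
# avg_logp.reduce(group) would all-reduce both terms first when
# torch.distributed is initialized.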
68 changes: 68 additions & 0 deletions src/forge/data_models/experience.py
@@ -0,0 +1,68 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from typing import Optional, Sequence

import torch
from forge.data_models.scored_completion import ScoredCompletion


@dataclass
class Experience:
    """The Experience data class to be used by the trainer.

    Experiences are usually generated from a scored completion by running
    various post-processing steps.
    """

    # Concatenated prompt and sample token ids.
    ids: torch.Tensor

    # The mask for the target ids: 0 for prompt tokens, 1 for sample tokens.
    mask: torch.Tensor

    # The weight to apply to the loss of each target token. It's normally
    # computed from the advantage and the reward.
    weights: torch.Tensor

    # The log probabilities of the target tokens; for the prompt part it's set
    # to 0, for the generation part it's computed from the Generator/Sampler.
    log_probs: Optional[torch.Tensor] = None

    # TODO: add more fields as required
    state: str = ""


def from_scored_completion(scored_completion: ScoredCompletion) -> Experience:
    """Converts a ScoredCompletion to an Experience."""
    prompt_ids = scored_completion.completion.prompt_ids
    token_ids = scored_completion.completion.token_ids
    log_probs = scored_completion.completion.log_probs
    ids = torch.cat([prompt_ids, token_ids])
    mask = torch.cat(
        [
            torch.zeros(prompt_ids.shape, dtype=torch.float32),
            torch.ones_like(token_ids, dtype=torch.float32),
        ]
    )
    advantage = scored_completion.score
    weights = mask * advantage
    log_probs = torch.cat(
        [
            torch.zeros(prompt_ids.shape, dtype=torch.float32),
            # TODO: this only works if sample.log_probs is 1
            log_probs,
        ]
    )
    return Experience(ids=ids, mask=mask, weights=weights, log_probs=log_probs)


def from_scored_completions(
    scored_completions: Sequence[ScoredCompletion],
) -> Sequence[Experience]:
    """Converts a sequence of ScoredCompletions to a sequence of Experiences."""
    return [from_scored_completion(sc) for sc in scored_completions]
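
For illustration, a minimal sketch (not part of the diff; the token ids and log probs are made up) of turning a scored completion into a trainer-ready Experience:

import torch

from forge.data_models.completion import Completion
from forge.data_models.experience import from_scored_completion
from forge.data_models.prompt import Prompt
from forge.data_models.scored_completion import ScoredCompletion

completion = Completion(
    prompt=Prompt.from_prompt("What is 2 + 2?"),
    text="4",
    prompt_ids=torch.tensor([101, 102, 103]),  # made-up token ids
    token_ids=torch.tensor([104]),
    log_probs=torch.tensor([-0.1]),
)
scored = ScoredCompletion(completion=completion, score=1.0)

exp = from_scored_completion(scored)
# ids is the concatenation of prompt and sample ids; weights = mask * score.
assert exp.ids.tolist() == [101, 102, 103, 104]
assert exp.weights.tolist() == [0.0, 0.0, 0.0, 1.0]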
22 changes: 22 additions & 0 deletions src/forge/data_models/loss.py
@@ -0,0 +1,22 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass

import torch
from forge.data_models.distributed_metric import Fraction
from forge.data_models.minibatch import Minibatch


@dataclass
class LossInput:

[Contributor comment] This is roughly what we're doing now, except it's trainer_logits + target_minibatch, which is a subset of the minibatch you routed for the loss.

    minibatch: Minibatch
    trainer_logits: torch.Tensor


@dataclass
class LossOutput:
    loss: Fraction
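
For illustration, a hedged sketch (not part of the diff) of how a loss function might consume LossInput and report its result as a Fraction, so the mean loss aggregates correctly across data-parallel ranks. The REINFORCE-style weighting is an assumption for the example, not the repo's actual loss:

import torch

from forge.data_models.distributed_metric import Fraction, SumDistributedMetric
from forge.data_models.loss import LossInput, LossOutput


def policy_gradient_loss(inp: LossInput) -> LossOutput:
    mb = inp.minibatch
    # trainer_logits assumed [batch, seq, vocab]; take per-token log probs.
    logps = torch.log_softmax(inp.trainer_logits, dim=-1)
    # Log prob of each target token under the trainer's current policy.
    target_logps = logps.gather(-1, mb.target_ids.unsqueeze(-1)).squeeze(-1)
    # Weight each token's loss; the mask zeroes out prompt and padding tokens.
    per_token = -mb.target_weights * target_logps * mb.target_mask
    return LossOutput(
        loss=Fraction(
            numerator=SumDistributedMetric(per_token.sum()),
            denominator=SumDistributedMetric(mb.target_mask.sum()),
        )
    )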
84 changes: 84 additions & 0 deletions src/forge/data_models/minibatch.py
@@ -0,0 +1,84 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Sequence

import torch
from forge.data_models.experience import Experience


@dataclass
class Minibatch:
    """The minibatch that the trainer will receive."""

    # The input sequence token ids for the trainer forward pass.
    input_ids: torch.Tensor

    # The segment ids for the input sequence token ids. The same segment id
    # represents the same sequence.
    segment_ids: torch.Tensor

    # The targets required for loss computation, usually concatenated prompt
    # and sample token ids.
    target_ids: torch.Tensor

    # The mask for the target ids: 0 for prompt tokens, 1 for sample tokens.
    target_mask: torch.Tensor

    # The weight to apply to the loss of each target token. It's normally
    # computed from the advantage and the reward.
    target_weights: torch.Tensor

    # The log probabilities of the target tokens; for the prompt part it's set
    # to 0, for the generation part it's computed from the sampler.
    target_log_probs: torch.Tensor


def from_experiences(
    exps: Sequence[Experience], max_seq_len: int, pad_val: int = 0
) -> Minibatch:
    """Converts a list of experiences to a minibatch."""

    def pack_sequence(
        tensors: Sequence[torch.Tensor],
        pad_val: Any,
        dtype: torch.dtype,
        max_len: int,
    ) -> torch.Tensor:
        """Packs multiple tensors along the sequence dim."""
        seq = torch.cat(tensors)
        pad_len = max_len - seq.size(0)
        if pad_len < 0:
            raise ValueError(
                f"Sequence length {seq.size(0)} exceeds the maximum length {max_len}"
            )
        return torch.nn.functional.pad(seq, (0, pad_len), value=pad_val)[None, ...].to(
            dtype
        )

    mini_batch = {}
    exp_list = defaultdict(list)
    for i, exp in enumerate(exps):
        input_ids = exp.ids[:-1]
        exp_list["input_ids"].append(input_ids)
        exp_list["target_ids"].append(exp.ids[1:])
        exp_list["segment_ids"].append(torch.ones_like(input_ids) * i)
        exp_list["target_mask"].append(exp.mask[1:])
        exp_list["target_weights"].append(exp.weights[1:])
        exp_list["target_log_probs"].append(exp.log_probs[1:])

    for k, v in exp_list.items():
        _dtype = torch.int64
        if k in ("target_mask", "target_weights", "target_log_probs"):
            _dtype = torch.float32

        mini_batch[k] = pack_sequence(v, pad_val, _dtype, max_seq_len)

    return Minibatch(**mini_batch)
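
For illustration, a minimal sketch (not part of the diff) of packing two experiences into a single padded row. Note each experience contributes len(ids) - 1 positions, because inputs are ids[:-1] and targets are ids[1:]:

import torch

from forge.data_models.experience import Experience
from forge.data_models.minibatch import from_experiences


def toy_experience(n: int) -> Experience:
    ones = torch.ones(n, dtype=torch.float32)
    return Experience(ids=torch.arange(n), mask=ones, weights=ones, log_probs=ones)


mb = from_experiences([toy_experience(4), toy_experience(6)], max_seq_len=16)
# 3 + 5 = 8 packed positions, padded with pad_val out to max_seq_len.
assert mb.input_ids.shape == (1, 16)
assert mb.segment_ids[0, :8].tolist() == [0, 0, 0, 1, 1, 1, 1, 1]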
66 changes: 66 additions & 0 deletions src/forge/data_models/prompt.py
@@ -0,0 +1,66 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from collections.abc import Sequence
from dataclasses import dataclass
from enum import Enum
from typing import Any


class Role(Enum):
    SYSTEM = "system"
    USER = "user"
    ASSISTANT = "assistant"
    NONE = "none"


@dataclass
class Message:
    """A single message in a conversation."""

    chunks: Sequence[str]
    role: Role


@dataclass
class Prompt:

[Contributor comment] In general we should match OpenAI requests here as the type. This way we can use the same formatting for local and API judge calls.

    """A multi-turn prompt (conversation history)."""

    # Multi-turn messages; each turn is a message.
    messages: Sequence[Message]
    metadata: Any | None = None

    @classmethod
    def from_prompt(
        cls, prompt: str, system_instruction: str | None = None
    ) -> "Prompt":
        messages = prompt_to_messages(prompt, system_instruction)
        return Prompt(
            messages=messages,
        )


def prompt_to_messages(
    prompt: str, system_instruction: str | None = None
) -> Sequence[Message]:
    """Converts a prompt string to a sequence of messages."""
    messages = []
    if system_instruction is not None:
        messages.append(Message(chunks=[system_instruction], role=Role.SYSTEM))
    messages.append(
        Message(chunks=[prompt], role=Role.USER),
    )
    return messages


def to_prompt(
    prompt: str, metadata: Any | None = None, system_instruction: str | None = None
) -> Prompt:
    """Converts a prompt string to a Prompt."""
    return Prompt(
        messages=prompt_to_messages(prompt, system_instruction),
        metadata=metadata,
    )
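
For illustration, a minimal sketch (not part of the diff) of building a prompt with a system instruction:

from forge.data_models.prompt import Prompt, Role

p = Prompt.from_prompt(
    "Summarize the plot of Hamlet.",
    system_instruction="You are a concise assistant.",
)
assert [m.role for m in p.messages] == [Role.SYSTEM, Role.USER]
assert p.messages[1].chunks == ["Summarize the plot of Hamlet."]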
19 changes: 19 additions & 0 deletions src/forge/data_models/scored_completion.py
@@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass

from forge.data_models.completion import Completion


@dataclass
class ScoredCompletion:
    """A completion with an associated score (from a reward model or human)."""

    completion: Completion
    score: float  # akin to reward

    # TODO: add more fields as needed.