Merged
14 changes: 13 additions & 1 deletion apps/sft/llama3_8b.yaml
@@ -31,7 +31,19 @@ training:
max_norm: 1.0
steps: 1000
compile: false
dataset: "c4"

validation:
local_batch_size: 1
freq: -1 # Change to a positive number to enable validation
steps: 200 # Max steps to run validation. Validation disabled if negative.

dataset:
path: yahma/alpaca-cleaned
split: train[:95%]

dataset_val:
path: yahma/alpaca-cleaned
split: train[95%:]

parallelism:
data_parallel_replicate_degree: 1
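
For reference, a minimal sketch of this validation section with validation actually enabled might look like the following (the freq value is illustrative, not part of this PR):

```yaml
validation:
  local_batch_size: 1
  freq: 100    # run validation every 100 training steps
  steps: 200   # maximum number of validation batches per run
```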
86 changes: 75 additions & 11 deletions apps/sft/main.py
@@ -18,9 +18,11 @@
from forge.data.datasets.packed import PackedDataset, TextPacker
from forge.data.datasets.sft_dataset import AlpacaToMessages, sft_iterable_dataset
from forge.data.tokenizer import HuggingFaceModelTokenizer
from forge.data.utils import batch_to_device, CROSS_ENTROPY_IGNORE_IDX

from omegaconf import DictConfig, OmegaConf
from torch import nn

from torchdata.stateful_dataloader import StatefulDataLoader
from torchtitan.components.loss import LossFunction
from torchtitan.components.lr_scheduler import LRSchedulersContainer
@@ -30,6 +32,7 @@
from torchtitan.experiments.forge.job_config import ForgeJobConfig
from tqdm import tqdm


# stubs for now
Checkpointer = Any
Dataloader = Any
@@ -63,7 +66,16 @@ def __init__(self, job_config: ForgeJobConfig):
self.metric_logger = None # TODO: fix this

def setup(self):
self.train_dataloader = self.setup_data()
self.train_dataloader = self.setup_data(
self.job_config.dataset,
batch_size=self.job_config.training.local_batch_size,
)

self.val_dataloader = self.setup_data(
self.job_config.dataset_val,
batch_size=self.job_config.validation.local_batch_size,
)

# self.train_dataloader = self.setup_data(
# self.train_config.train_dataset_config,
# self.train_config.train_dataloader_config,
@@ -79,7 +91,7 @@ def setup(self):
# self.profiler = self.setup_profiler(self.train_config.profiler_config)
# self.logger = self.setup_logger(self.train_config.logger_config)

def setup_data(self):
def setup_data(self, dataset_config, batch_size):
tokenizer = HuggingFaceModelTokenizer(
tokenizer_json_path=os.path.join(
self.job_config.model.hf_assets_path, "tokenizer.json"
@@ -95,8 +107,8 @@ def setup_data(self):
dataset = sft_iterable_dataset(
model_transform=tokenizer,
message_transform=AlpacaToMessages(),
path="yahma/alpaca-cleaned",
split="train",
path=dataset_config.path,
split=dataset_config.split,
Comment on lines +110 to +111

Contributor:
One thing we should think about is how to support additional args beyond those we've already hardcoded. E.g. in #50 we also need to pass data_files. (This is more of a config system question, so it's OK to punt on it for now, but one path is to use something like instantiate for this; see this section in the torchtune docs for an example.)

Member Author:
I can support passing file paths. Which one (data_files or path) should it prioritize? For example, what should happen if the user passes both data_files and path?

(One possible approach is sketched after this hunk.)

)
packer = TextPacker(padding_idx=0)
dataset = PackedDataset(
@@ -106,7 +118,7 @@
)
dataloader = StatefulDataLoader(
dataset=dataset,
batch_size=self.job_config.training.local_batch_size,
batch_size=batch_size,
collate_fn=partial(
collate_packed, mask_fn=packer.create_block_mask, device=self.device
),
@@ -119,7 +131,10 @@
return dataloader
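
Following up on the data_files discussion above, one lightweight option is sketched below. It reuses names defined in setup_data (tokenizer, dataset_config) and assumes that sft_iterable_dataset forwards unrecognized keyword arguments to the underlying Hugging Face load_dataset call; treat it as a sketch, not the PR's implementation.

```python
# Sketch only: pass any extra dataset-config keys (e.g. data_files from #50)
# straight through to the dataset constructor.
extra_kwargs = {
    k: v for k, v in dict(dataset_config).items() if k not in ("path", "split")
}
dataset = sft_iterable_dataset(
    model_transform=tokenizer,
    message_transform=AlpacaToMessages(),
    path=dataset_config.path,
    split=dataset_config.split,
    **extra_kwargs,  # e.g. data_files=... when reading local files
)
```

On the precedence question: load_dataset already accepts both arguments, with path selecting the dataset or loader and data_files selecting the files to read, so the config could simply pass both through and leave validation of odd combinations to the loader.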

def forward_backward(
self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor
self,
input_dict: dict[str, torch.Tensor],
labels: torch.Tensor,
do_backward: bool = True,
) -> torch.Tensor:
model_parts = self.model_parts
parallel_dims = self.parallel_dims
@@ -145,14 +160,16 @@
targets, losses = (
(labels, []) if self.pp_has_last_stage else (None, None)
)
if do_backward:
pp_schedule_fn = self.pp_schedule.step
else:
pp_schedule_fn = self.pp_schedule.eval
if self.pp_has_first_stage:
self.pp_schedule.step(
pp_schedule_fn(
inputs, target=targets, losses=losses, input_batch=inputs
)
else:
self.pp_schedule.step(
target=targets, losses=losses, input_batch=inputs
)
pp_schedule_fn(target=targets, losses=losses, input_batch=inputs)

# accumulate losses across pipeline microbatches
# TODO: PP+FSDP unexpectedly puts the loss back to the CPU
@@ -170,7 +187,8 @@
loss = self.loss_fn(pred, labels)
# need to free before bwd to avoid peaking memory
del pred
loss.backward()
if do_backward:
Contributor:
This won't work with pipeline parallel, right? Since there the backward happens inside of step(), I think we will need to handle that case differently.

Member Author:
Updated the code to use eval when not doing backward.

loss.backward()

return loss

@@ -214,6 +232,52 @@ def train(self) -> None:
last_step=self.current_step == self.num_training_steps,
)

if (
self.job_config.validation.freq > 0
and self.job_config.validation.steps > 0
and self.current_step % self.job_config.validation.freq == 0
):
self.validate(self.job_config.validation.steps)

def validate(self, max_steps: int) -> None:
for m in self.model_parts:
m.eval()
total_val_loss = torch.tensor(0.0, device=self.device)
total_val_tokens = torch.tensor(0.0, device=self.device)
with torch.no_grad():
val_pbar = tqdm(self.val_dataloader, desc="Validation", leave=False)
for batch_idx, batch in enumerate(val_pbar):
if batch_idx >= max_steps:
break
batch_to_device(batch, self.device)
current_num_tokens = (batch["labels"] != CROSS_ENTROPY_IGNORE_IDX).sum()
# Compute loss
labels = batch.pop("labels")
loss = self.forward_backward(batch, labels, do_backward=False)
val_loss = loss * current_num_tokens
total_val_loss += val_loss
total_val_tokens += current_num_tokens
# Update progress bar description with current average loss
avg_loss_so_far = (
(total_val_loss / total_val_tokens).item()
if total_val_tokens > 0
else float("inf")
)
val_pbar.set_description(
f"Running validation Loss: {avg_loss_so_far:.4f}"
)
# Aggregate validation metrics across all ranks
torch.distributed.all_reduce(total_val_loss)
torch.distributed.all_reduce(total_val_tokens)
avg_val_loss = (
(total_val_loss / total_val_tokens).item()
if total_val_tokens > 0
else float("inf")
)
for m in self.model_parts:
m.train()
print(f"\nValidation loss: {avg_val_loss}")

def cleanup(self) -> None:
if self.checkpointer:
self.checkpointer.close()
31 changes: 31 additions & 0 deletions src/forge/data/utils.py
@@ -7,6 +7,10 @@
from enum import Enum
from typing import Any, Literal, Optional, Union

import torch

from torch.nn.attention.flex_attention import BlockMask

CROSS_ENTROPY_IGNORE_IDX = -100

Role = Literal[
@@ -182,3 +186,30 @@ def mask_messages(
message.masked = True
elif masking_strategy == MaskingStrategy.TRAIN_ON_ASSISTANT:
message.masked = message.role != "assistant"


def batch_to_device(batch: dict, device: torch.device) -> None:
"""Function that takes a dictionary (or nested dictionary) of tensors and sets them
all to the same device. This utility is intended to be used for batches of data to be
moved to device, the update is inplace.

Args:
batch (dict): dict of Tensors or more nested dicts of tensors.
device (torch.device): torch device to move the tensors to.

Raises:
ValueError: if the batch contains values other than nested dicts, ``torch.Tensor``, or ``BlockMask``.

"""
for k, v in batch.items():
if isinstance(v, dict):
batch_to_device(v, device)
elif isinstance(v, torch.Tensor):
batch[k] = v.to(device)
elif isinstance(v, BlockMask):
batch[k] = v.to(device)
else:
raise ValueError(
f"""To use batch_to_device, all elements in the batch must be a dict, Tensor, or BlockMask with flexattention enabled.
Got key "{k}" with value of type {type(v)}"""
)
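
A small usage sketch of batch_to_device (the batch contents here are illustrative):

```python
import torch

from forge.data.utils import batch_to_device

# Move a nested batch of tensors onto one device in place.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch = {
    "tokens": torch.randint(0, 100, (2, 16)),
    "labels": torch.randint(0, 100, (2, 16)),
    "extras": {"position_ids": torch.arange(16).repeat(2, 1)},
}
batch_to_device(batch, device)
assert batch["tokens"].device.type == device.type
assert batch["extras"]["position_ids"].device.type == device.type
```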