Merged

34 commits
37de2be
only compute lengths in the token dataset when it's not already prese…
aldo-pareja Oct 6, 2024
bf7e86f
Refactor padding function to support position_ids for FlashAttention
aldo-pareja Oct 6, 2024
77a3965
logging the global gradnorm now
aldo-pareja Oct 6, 2024
be0cc71
fixing deepspeed because it's not working with the scheduler we want
aldo-pareja Oct 7, 2024
014e5e4
fixing accelerate lr_scheduler
aldo-pareja Oct 7, 2024
bf5b25e
fixing accelerate lr_scheduler
aldo-pareja Oct 7, 2024
4ac74f4
samples seen was broken because now the samples are a single line
aldo-pareja Oct 7, 2024
18182e1
find packing is wrong because when flash attention is supported paddi…
aldo-pareja Oct 7, 2024
70abd41
black formatting
aldo-pareja Oct 7, 2024
538e506
it should not fail on granite 8b models anymore
aldo-pareja Oct 7, 2024
208f396
linting
aldo-pareja Oct 7, 2024
5ed04dc
linting
aldo-pareja Oct 7, 2024
eda2641
bug on padding when creating the multipack sampler
aldo-pareja Oct 7, 2024
d8c3ac1
linter
aldo-pareja Oct 8, 2024
377d9a2
linter
aldo-pareja Oct 8, 2024
5c05b0a
Change old padding-free and granite flags to use_dolomite
Maxusmusti Oct 8, 2024
9c73f27
Add safeguards and checks for flash attention when enabled/disabled
Maxusmusti Oct 8, 2024
6a21d8d
Rework flash attention checks for better modularity
Maxusmusti Oct 8, 2024
4c431a2
Fix arg name
Maxusmusti Oct 8, 2024
8fff855
Update transformers to a version with Granite model class
Maxusmusti Oct 9, 2024
4288b28
Adding stateguards for dolomite and granite and model path check
Maxusmusti Oct 10, 2024
7a6f567
Missing update
Maxusmusti Oct 10, 2024
8e7c86a
Clean up early validation checks and move to utils
Maxusmusti Oct 10, 2024
3cd8597
Fix spelling mistake
Maxusmusti Oct 10, 2024
710ae92
Include AMD in flash attn check
Maxusmusti Oct 14, 2024
27959c7
Re-add is_padding_free with deprecation warning
Maxusmusti Oct 16, 2024
ef19e26
Make use_dolomite default false
Maxusmusti Oct 16, 2024
f777d45
this is needed because the tag <MASK> is too common and some datasets…
aldo-pareja Oct 16, 2024
041a856
added a warning in case the special tokens used for data processing a…
aldo-pareja Oct 16, 2024
f03427b
added a warning in case the special tokens used for data processing a…
aldo-pareja Oct 16, 2024
4d095c3
Update valid data filter
Maxusmusti Oct 16, 2024
a45e82b
Fix ruff formatting
Maxusmusti Oct 16, 2024
d5fe4d8
Apply review feedback
Maxusmusti Oct 22, 2024
0b46e83
Added comments
Maxusmusti Oct 23, 2024
2 changes: 1 addition & 1 deletion requirements.txt
@@ -5,7 +5,7 @@ py-cpuinfo
# we set this to be above 0a0 so that it doesn't
# replace custom pytorch images with the 2.3.0
torch>=2.3.0a0
transformers>=4.41.2
transformers>=4.45.2
accelerate>=0.34.2
datasets>=2.15.0
numba
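The transformers bump to >=4.45.2 is what makes the Granite model class (added to the allow-list in main_ds.py below) importable. A minimal, hedged sanity check, not part of this PR:

import transformers
from transformers import GraniteForCausalLM  # available from transformers 4.45 onward

# If this import succeeds, the pinned version is new enough for the Granite check below.
print(transformers.__version__)
print(GraniteForCausalLM.__name__)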
3 changes: 2 additions & 1 deletion src/instructlab/training/config.py
@@ -171,8 +171,9 @@ class TrainingArgs(BaseModel):
save_samples: int
learning_rate: float
warmup_steps: int
is_padding_free: bool
random_seed: int = 42
use_dolomite: bool = False
is_padding_free: bool = False # TODO: deprecate
checkpoint_at_epoch: bool = True
accelerate_full_state_at_epoch: bool = True

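A hedged sketch of how the reworked flags read from the caller's side. It assumes Pydantic v2's model_construct so the many required fields not shown in this hunk can be omitted; illustrative only, not code from the PR:

from instructlab.training.config import TrainingArgs

# model_construct() bypasses validation so only the new fields are shown;
# a real TrainingArgs also needs model/data paths and the other required fields.
args = TrainingArgs.model_construct(
    random_seed=42,
    use_dolomite=False,     # new switch for the Dolomite/padding-free path
    is_padding_free=False,  # retained only for backward compatibility, to be deprecated
)
print(args.use_dolomite, args.is_padding_free)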
25 changes: 21 additions & 4 deletions src/instructlab/training/data_process.py
@@ -199,7 +199,7 @@ def print_masked_samples(data, tokenizer, is_pretrain, num_proc):
def get_masked_and_orig_text(sample):
labels = sample["labels"]
input_ids = sample["input_ids"]
mask_id = get_sp_token(tokenizer, "<MASK>")[0]
mask_id = get_sp_token(tokenizer, "<|MASK|>")[0]
Review comment (Member): Will this affect existing models? Or is this purely for training-time?

Reply (Collaborator): Yeah only relevant during training

label = [mask_id if tk == -100 else tk for tk in labels]
text = tokenizer.decode(label)
orig_text = tokenizer.decode(input_ids)
@@ -239,7 +239,7 @@ def main(args: DataProcessArgs):

# Adding after tokenizer setup as these are temp tokens, not to be saved
tokenizer.add_special_tokens(
{"additional_special_tokens": ["<|pretrain|>", "<|/pretrain|>", "<MASK>"]}
{"additional_special_tokens": ["<|pretrain|>", "<|/pretrain|>", "<|MASK|>"]}
)

try:
@@ -347,9 +347,26 @@ def main(args: DataProcessArgs):
)

# extract only labels and messages formatted into a new dataset
data_with_labels = data_with_labels.select_columns(["labels", "input_ids"])
data_with_labels = data_with_labels.map(
lambda x: {
"len": len(x["input_ids"]),
},
num_proc=NUM_PROC,
)
data_with_labels = data_with_labels.select_columns(["labels", "input_ids", "len"])
# MASK and both pretrain tokens should not be in the final tokens, those are special tokens added only for data processing purposes.
max_id = len(tokenizer) - 3
final_valid_data = data_with_labels.filter(
lambda x: all(tk < max_id for tk in x["labels"]), num_proc=NUM_PROC
)
# Dropping samples that could break training due to oob ids
if len(final_valid_data) < len(data_with_labels):
dropped_samples = len(data_with_labels) - len(final_valid_data)
print(
f"\033[93mWarning: {dropped_samples} samples were dropped because they contained token IDs greater than or equal to {max_id}.\033[0m"
)
# use path to get the stem of the file
data_with_labels.to_json(Path(args.data_output_path) / f"data.jsonl")
final_valid_data.to_json(Path(args.data_output_path) / "data.jsonl")


if __name__ == "__main__":
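Each record written to data.jsonl now carries a precomputed "len" field next to "labels" and "input_ids", and samples whose labels contain any of the three temporary special-token IDs (anything >= len(tokenizer) - 3) are dropped before writing. A hedged illustration of the resulting schema, with made-up values:

import json

# One record of the processed data.jsonl (token IDs are made up for illustration):
record = {
    "input_ids": [101, 2023, 2003, 102],
    "labels": [-100, -100, 2003, 102],  # labels with IDs >= len(tokenizer) - 3 were filtered upstream
    "len": 4,  # precomputed here so TokenDataset does not need to re-map the dataset
}
with open("data.jsonl", "w", encoding="utf-8") as f:
    f.write(json.dumps(record) + "\n")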
64 changes: 31 additions & 33 deletions src/instructlab/training/main_ds.py
@@ -41,8 +41,10 @@
StreamablePopen,
add_noisy_embeddings,
apply_gradient_checkpointing,
check_flash_attn_enabled,
check_valid_train_args,
convert_loss_to_reduce_sum,
ensure_loadable_granite_checkpoint,
ensure_loadable_dolomite_checkpoint,
get_projection_layer_names,
load_latest_full_state,
prepare_peft_model,
@@ -84,7 +86,7 @@ def setup_optimizer(args, model):
return optimizer


def setup_model(args, tokenizer, train_loader, grad_accum):
def setup_model(args, tokenizer, train_loader, grad_accum, flash_enabled):
bnb_config = None
if args.lora_r > 0 and args.lora_quant_bits == 4:
# Third Party
@@ -102,15 +104,11 @@ def setup_model(args, tokenizer, train_loader, grad_accum):
"torch_dtype": torch.bfloat16,
"quantization_config": bnb_config,
}
if not args.disable_flash_attn:
if flash_enabled:
base_model_args["attn_implementation"] = "flash_attention_2"
elif args.is_granite:
raise RuntimeError(
"ERROR: Trying to use padding-free transformer without flash attention is not supported"
)

if args.is_granite:
with ensure_loadable_granite_checkpoint(
if args.use_dolomite:
with ensure_loadable_dolomite_checkpoint(
args.model_name_or_path, args.output_dir
) as path:
base_model_args["pretrained_model_name_or_path"] = path
@@ -165,9 +163,10 @@ def setup_model(args, tokenizer, train_loader, grad_accum):
"Starcoder2ForCausalLM",
"GemmaForCausalLM",
"MixtralForCausalLM",
"GraniteForCausalLM",
], f"Model class name: {model.__class__.__name__} is not supported."

model = convert_loss_to_reduce_sum(model, is_granite=args.is_granite)
model = convert_loss_to_reduce_sum(model, use_dolomite=args.use_dolomite)
model = add_noisy_embeddings(model, noise_alpha=args.NEFTune_alpha)

# handling of gradient checkpointing
@@ -212,15 +211,15 @@ def setup_model(args, tokenizer, train_loader, grad_accum):
target_modules=args.lora_target_modules,
)
model = prepare_peft_model(
model, peft_config, gradient_checkpointing=not args.is_granite
model, peft_config, gradient_checkpointing=not args.use_dolomite
)

elif not args.is_granite:
elif not args.use_dolomite:
model.gradient_checkpointing_enable()

# granite gradient checkpointing is handled uniformly
# for both lora and full here
if args.is_granite:
if args.use_dolomite:
block_name = model._no_split_modules[0]
apply_gradient_checkpointing(
model,
@@ -252,6 +251,9 @@ def make_inputs_require_grad(module, input, output):
deepcopy(train_loader),
lr_scheduler,
)
# Necessary so that Accelerate does not step once per GPU
# see https://github.com/huggingface/accelerate/blob/127818fc27ebe5cb236357fff59ff1748326d643/src/accelerate/scheduler.py#L69
lr_scheduler.split_batches = True
return model, lr_scheduler, optimizer, accelerator


@@ -381,8 +383,8 @@ def train(
num_loss_counted_tokens = float(
torch.tensor([batch.pop("num_loss_counted_tokens")])
)
micro_batch_size = float(len(batch["input_ids"]))
if not args.is_granite:
micro_batch_size = float(torch.tensor([batch.pop("num_samples")]))
if not args.use_dolomite:
for k in batch:
batch[k] = batch[k].to(local_rank)
output = model(
Expand Down Expand Up @@ -453,7 +455,7 @@ def train(
"batch_size": int(micro_batch_size),
"total_loss": float(log_loss / num_loss_counted_tokens),
"samples_seen": samples_seen,
# "gradnorm": global_grad_norm,
"gradnorm": global_grad_norm,
# "weight_norm": weight_norm,
}
)
@@ -535,6 +537,8 @@ def main(args):
torch.distributed.all_reduce(tensor)
torch.distributed.barrier()

flash_enabled = check_flash_attn_enabled(args.disable_flash_attn, args.use_dolomite)

dataset = setup_dataset(
args.data_path,
mock=args.mock_data,
Expand All @@ -547,7 +551,7 @@ def main(args):
avg_sample_len=dataset.get_lengths().mean(),
effective_batch_size=args.effective_batch_size,
max_batch_len_per_gpu=args.max_batch_len,
is_padding=not args.is_granite,
is_padding=not (args.use_dolomite or flash_enabled),
dataset=dataset,
seed=args.seed,
)
@@ -570,7 +574,8 @@ def main(args):
dataset,
tokenizer.pad_token_id,
num_workers=8,
is_granite=args.is_granite,
use_dolomite=args.use_dolomite,
flash_enabled=flash_enabled,
max_batch_len=args.max_batch_len,
packing_max_batch_len=packing_max_batch_len,
samples_per_gpu=args.samples_per_gpu,
@@ -589,7 +594,8 @@ def main(args):
dataset,
tokenizer.pad_token_id,
num_workers=8,
is_granite=args.is_granite,
use_dolomite=args.use_dolomite,
flash_enabled=flash_enabled,
max_batch_len=args.max_batch_len,
packing_max_batch_len=packing_max_batch_len,
samples_per_gpu=args.samples_per_gpu,
@@ -613,7 +619,7 @@ def main(args):
)

model, lr_scheduler, optimizer, accelerator = setup_model(
args, tokenizer, train_loader, grad_accum
args, tokenizer, train_loader, grad_accum, flash_enabled
)

load_latest_full_state(args=args, accelerator=accelerator)
@@ -639,11 +645,7 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
"""
Wrapper around the main training job that calls torchrun.
"""
# early validation logic here
if train_args.max_batch_len < train_args.max_seq_len:
raise ValueError(
f"the `max_batch_len` cannot be less than `max_seq_len`: {train_args.max_batch_len=} < {train_args.max_seq_len=}"
)
check_valid_train_args(train_args)

if train_args.process_data:
dp.main(
@@ -697,14 +699,10 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
if train_args.mock_len:
command.append(f"--mock_len={train_args.mock_len}")

if train_args.is_padding_free:
command.append("--is_granite")
if train_args.use_dolomite:
command.append("--use_dolomite")

if train_args.disable_flash_attn:
if train_args.is_padding_free:
raise RuntimeError(
"ERROR: Trying to use padding-free transformer without flash attention is not supported"
)
command.append("--disable_flash_attn")

if train_args.lora:
@@ -888,7 +886,7 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
default="SHARD_GRAD_OP",
help="Sharding strategy to be used for FSDP distributed training.",
)
parser.add_argument("--is_granite", action="store_true")
parser.add_argument("--use_dolomite", action="store_true")
parser.add_argument("--lora_r", type=int, default=0) # set to > 0 to activate lora
parser.add_argument("--lora_alpha", type=int, default=32)
parser.add_argument("--lora_dropout", type=float, default=0.1)
@@ -977,7 +975,7 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
--save_samples=250000 \
--log_level="INFO" \
--fsdp_sharding_strategy="SHARD_GRAD_OP" \
--is_granite \
--use_dolomite \
--max_batch_len 70000 \
--seed=42
"""
25 changes: 16 additions & 9 deletions src/instructlab/training/token_dataset.py
@@ -17,12 +17,15 @@
class TokenDataset(Dataset):
def __init__(self, data_path):
self.data = load_dataset("json", data_files=data_path, split="train")
self.lengths = np.array(
self.data.map(
lambda x: {"len": len(x["input_ids"])},
num_proc=8,
)["len"]
)
if "len" not in self.data.column_names:
self.lengths = np.array(
self.data.map(
lambda x: {"len": len(x["input_ids"])},
num_proc=8,
)["len"]
)
else:
self.lengths = np.array(self.data["len"])

def __len__(self):
return len(self.data)
@@ -87,15 +90,19 @@ def setup_dataloader(
dataset: Dataset,
pad_token_id: int,
num_workers: int = 8,
is_granite=False,
use_dolomite=False,
flash_enabled=True,
max_batch_len=60000,
packing_max_batch_len=60000,
samples_per_gpu=None,
sampler="multipack",
seed=47,
) -> DataLoader:
collate_fn = make_collate_fn(
pad_token_id, is_granite=is_granite, max_batch_len=max_batch_len
pad_token_id,
use_dolomite=use_dolomite,
flash_enabled=flash_enabled,
max_batch_len=max_batch_len,
)
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
@@ -108,7 +115,7 @@
num_replicas=world_size,
rank=rank,
seed=seed,
padding=not is_granite,
padding=not flash_enabled,
)
sampler = {"batch_sampler": sampler}
elif sampler == "distributed":
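A hedged usage sketch of the renamed setup_dataloader arguments. The import path, setup_dataset signature, and argument values are assumptions; setup_dataloader reads RANK and WORLD_SIZE, so this only runs under torchrun or with those variables exported:

from instructlab.training.token_dataset import setup_dataloader, setup_dataset

dataset = setup_dataset("data.jsonl")  # reuses the precomputed "len" column if present
loader = setup_dataloader(
    dataset,
    pad_token_id=0,          # assumption: the real caller passes tokenizer.pad_token_id
    use_dolomite=False,      # renamed from is_granite
    flash_enabled=True,      # multipack padding is skipped when flash attention is on
    max_batch_len=60000,
    packing_max_batch_len=60000,
    samples_per_gpu=8,
)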