diff --git a/apps/sft/main.py b/apps/sft/main.py index 9781dad5c..d15c5b086 100644 --- a/apps/sft/main.py +++ b/apps/sft/main.py @@ -92,12 +92,22 @@ def setup_data(self): ), ) + # dataset = sft_iterable_dataset( + # model_transform=tokenizer, + # message_transform=AlpacaToMessages(), + # path="yahma/alpaca-cleaned", + # split="train", + # ) + dataset = sft_iterable_dataset( + path="arrow", + data_dir="/mnt/mffuse/forge/alpaca-cleaned/train", model_transform=tokenizer, message_transform=AlpacaToMessages(), - path="yahma/alpaca-cleaned", + data_files={"train": "data-00000-of-00001.arrow"}, split="train", ) + packer = TextPacker(padding_idx=0) dataset = PackedDataset( dataset=dataset,