From cb39ead13142bb3e67092600f67e55199886f75c Mon Sep 17 00:00:00 2001 From: jianzhnie Date: Fri, 28 Jul 2023 09:32:08 +0800 Subject: [PATCH] update datasets --- chatllms/configs/data_args.py | 12 +++++++++++- chatllms/data/data_utils.py | 8 +++++--- train_qlora.py | 2 -- 3 files changed, 16 insertions(+), 6 deletions(-) diff --git a/chatllms/configs/data_args.py b/chatllms/configs/data_args.py index 9682923..bc60d7a 100644 --- a/chatllms/configs/data_args.py +++ b/chatllms/configs/data_args.py @@ -17,7 +17,12 @@ class DatasetAttr(object): multi_turn: Optional[bool] = False def __repr__(self) -> str: - rep = f'dataset_name: {self.dataset_name}, hf_hub_url: {self.hf_hub_url}, local_path: {self.local_path}, data_formate:{self.dataset_format} load_from_local: {self.load_from_local}, multi_turn: {self.multi_turn}' + rep = (f'dataset_name: {self.dataset_name} || ' + f'hf_hub_url: {self.hf_hub_url} || ' + f'local_path: {self.local_path} \n' + f'data_formate: {self.dataset_format} || ' + f'load_from_local: {self.load_from_local} || ' + f'multi_turn: {self.multi_turn}') return rep def __post_init__(self): @@ -104,6 +109,11 @@ def init_for_training(self): # support mixing multiple datasets if datasets_info[name]['local_path'] and os.path.exists( datasets_info[name]['local_path']): dataset_attr.load_from_local = True + else: + dataset_attr.load_from_local = False + raise Warning( + 'You have set local_path for {} but it does not exist! Will load the data from {}' + .format(name, dataset_attr.hf_hub_url)) if 'columns' in datasets_info[name]: dataset_attr.prompt_column = datasets_info[name][ diff --git a/chatllms/data/data_utils.py b/chatllms/data/data_utils.py index 7651e9b..f0987a2 100644 --- a/chatllms/data/data_utils.py +++ b/chatllms/data/data_utils.py @@ -471,12 +471,14 @@ def make_data_module(args): for dataset_attr in args.datasets_list: print('=' * 80) - print('DatasetAttr: {}...'.format(dataset_attr)) + print('DatasetAttr: {}'.format(dataset_attr)) if dataset_attr.load_from_local: dataset_path = dataset_attr.local_path elif dataset_attr.hf_hub_url: dataset_path = dataset_attr.hf_hub_url + else: + raise ValueError('Please set the dataset path or hf_hub_url.') dataset = load_data(dataset_path, eval_dataset_size=args.eval_dataset_size) @@ -498,11 +500,11 @@ def make_data_module(args): max_train_samples=args.max_train_samples, ) if train_dataset: - print('loaded dataset:', dataset_attr.dataset_name, + print('loaded dataset:', dataset_attr.dataset_name, ' ', '#train data size:', len(train_dataset)) train_datasets.append(train_dataset) if eval_dataset: - print('loaded dataset:', dataset_attr.dataset_name, + print('loaded dataset:', dataset_attr.dataset_name, ' ' '#eval data size:', len(eval_dataset)) eval_datasets.append(eval_dataset) diff --git a/train_qlora.py b/train_qlora.py index 2d02e26..d7ad402 100644 --- a/train_qlora.py +++ b/train_qlora.py @@ -34,8 +34,6 @@ def main(): args = argparse.Namespace(**vars(model_args), **vars(data_args), **vars(training_args), **vars(lora_args), **vars(quant_args)) - - print(args.datasets_list) # init the logger before other steps timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) if not os.path.exists(args.output_dir):