Skip to content

Commit

Permalink
Merge pull request #77 from jianzhnie/dev
Browse files Browse the repository at this point in the history
update datasets
  • Loading branch information
jianzhnie committed Jul 28, 2023
2 parents c46cd52 + cb39ead commit 97e34d1
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 6 deletions.
12 changes: 11 additions & 1 deletion chatllms/configs/data_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,12 @@ class DatasetAttr(object):
multi_turn: Optional[bool] = False

def __repr__(self) -> str:
rep = f'dataset_name: {self.dataset_name}, hf_hub_url: {self.hf_hub_url}, local_path: {self.local_path}, data_formate:{self.dataset_format} load_from_local: {self.load_from_local}, multi_turn: {self.multi_turn}'
rep = (f'dataset_name: {self.dataset_name} || '
f'hf_hub_url: {self.hf_hub_url} || '
f'local_path: {self.local_path} \n'
f'data_formate: {self.dataset_format} || '
f'load_from_local: {self.load_from_local} || '
f'multi_turn: {self.multi_turn}')
return rep

def __post_init__(self):
Expand Down Expand Up @@ -104,6 +109,11 @@ def init_for_training(self): # support mixing multiple datasets
if datasets_info[name]['local_path'] and os.path.exists(
datasets_info[name]['local_path']):
dataset_attr.load_from_local = True
else:
dataset_attr.load_from_local = False
raise Warning(
'You have set local_path for {} but it does not exist! Will load the data from {}'
.format(name, dataset_attr.hf_hub_url))

if 'columns' in datasets_info[name]:
dataset_attr.prompt_column = datasets_info[name][
Expand Down
8 changes: 5 additions & 3 deletions chatllms/data/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,12 +471,14 @@ def make_data_module(args):

for dataset_attr in args.datasets_list:
print('=' * 80)
print('DatasetAttr: {}...'.format(dataset_attr))
print('DatasetAttr: {}'.format(dataset_attr))

if dataset_attr.load_from_local:
dataset_path = dataset_attr.local_path
elif dataset_attr.hf_hub_url:
dataset_path = dataset_attr.hf_hub_url
else:
raise ValueError('Please set the dataset path or hf_hub_url.')

dataset = load_data(dataset_path,
eval_dataset_size=args.eval_dataset_size)
Expand All @@ -498,11 +500,11 @@ def make_data_module(args):
max_train_samples=args.max_train_samples,
)
if train_dataset:
print('loaded dataset:', dataset_attr.dataset_name,
print('loaded dataset:', dataset_attr.dataset_name, ' ',
'#train data size:', len(train_dataset))
train_datasets.append(train_dataset)
if eval_dataset:
print('loaded dataset:', dataset_attr.dataset_name,
print('loaded dataset:', dataset_attr.dataset_name, ' '
'#eval data size:', len(eval_dataset))
eval_datasets.append(eval_dataset)

Expand Down
2 changes: 0 additions & 2 deletions train_qlora.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,6 @@ def main():
args = argparse.Namespace(**vars(model_args), **vars(data_args),
**vars(training_args), **vars(lora_args),
**vars(quant_args))

print(args.datasets_list)
# init the logger before other steps
timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
if not os.path.exists(args.output_dir):
Expand Down

0 comments on commit 97e34d1

Please sign in to comment.