Skip to content

Commit

Permalink
show more info about datasets (InternLM#464)
Browse files Browse the repository at this point in the history
* show more info about datasets

* update log_dataset.py

* update log_dataset.py
  • Loading branch information
amulil committed Mar 14, 2024
1 parent 9f976c6 commit 30bdd1f
Showing 1 changed file with 22 additions and 6 deletions.
28 changes: 22 additions & 6 deletions xtuner/tools/log_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@
def parse_args():
    """Parse the command-line arguments of the dataset logging tool.

    Arguments are read from ``sys.argv`` by ``argparse``.

    Returns:
        argparse.Namespace: holds ``config`` (the positional config file
        name or path) and ``show`` (one of ``text``, ``masked_text``,
        ``input_ids``, ``labels`` or ``all``; defaults to ``text``).
    """
    parser = argparse.ArgumentParser(description='Log processed dataset.')
    parser.add_argument('config', help='config file name or path.')
    # choose which kind of dataset style to show
    parser.add_argument(
        '--show',
        default='text',
        choices=['text', 'masked_text', 'input_ids', 'labels', 'all'],
        help='which kind of dataset style to show')
    args = parser.parse_args()
    return args

Expand All @@ -24,12 +30,22 @@ def main():
else:
train_dataset = BUILDER.build(cfg.train_dataloader.dataset)

print('#' * 20 + ' text ' + '#' * 20)
print(tokenizer.decode(train_dataset[0]['input_ids']))
print('#' * 20 + ' input_ids ' + '#' * 20)
print(train_dataset[0]['input_ids'])
print('#' * 20 + ' labels ' + '#' * 20)
print(train_dataset[0]['labels'])
if args.show == 'text' or args.show == 'all':
print('#' * 20 + ' text ' + '#' * 20)
print(tokenizer.decode(train_dataset[0]['input_ids']))
if args.show == 'masked_text' or args.show == 'all':
print('#' * 20 + ' text(masked) ' + '#' * 20)
masked_text = ' '.join(
['[-100]' for i in train_dataset[0]['labels'] if i == -100])
unmasked_text = tokenizer.decode(
[i for i in train_dataset[0]['labels'] if i != -100])
print(masked_text + ' ' + unmasked_text)
if args.show == 'input_ids' or args.show == 'all':
print('#' * 20 + ' input_ids ' + '#' * 20)
print(train_dataset[0]['input_ids'])
if args.show == 'labels' or args.show == 'all':
print('#' * 20 + ' labels ' + '#' * 20)
print(train_dataset[0]['labels'])


if __name__ == '__main__':
Expand Down

0 comments on commit 30bdd1f

Please sign in to comment.