In [1]:
import pyarrow.parquet as pq

import transformers
import numpy as np
import json


# Relative path of the parquet file that the "Load dataset" cell reads with pq.read_table.
DATASET = "../../dataset_processed_v3/sharegpt_v3.2.train.parquet"
# Identifier handed to transformers.AutoTokenizer.from_pretrained
# (presumably a Hugging Face Hub repo id — TODO confirm it is not a local directory).
TOKENIZER_NAME = "imone/LLaMA2_13B_with_EOT_token"

In [2]:
# Read the whole parquet file at DATASET into an in-memory Arrow table
dataset = pq.read_table(source=DATASET)

# Build the tokenizer identified by TOKENIZER_NAME; it is used later to decode token ids
tokenizer = transformers.AutoTokenizer.from_pretrained(TOKENIZER_NAME)

In [3]:
# Summary statistics over the "total_length" column
lengths = np.array(dataset.column("total_length"))

# Smallest, mean and largest value of the column
print(lengths.min(), lengths.mean(), lengths.max())

# Row indices ordered by ascending length; show the 100 shortest rows
shortest_first = lengths.argsort()
print(shortest_first[:100])

18 1911.0328626542228 4096
[21436 27808 12549 74425 34001 22231 12524  5758 53903 36293 32039 34018
 12534 61012 36128 67877 59158 19842 39765 25296 32036  3156 54101 69927
 10372 41171 77743 28537 24479 39592 50125 12681  8676 76826 64141 71009
 70235 59193  1666 22873 45708 45607 56562 22487 27296 22017 48507 28541
   909 20975 20955 35385 55246 20921 48195  1823  7913 50270 77569 26721
 37832 12089 55575 32153 63153 54966 21406 77440 19843 63211 73818 21762
 45065 63244  4330 11884   256 39383 10885  2080 59535 60511 63385  7756
 56739 23108 75592  5754 40388  8506 67092 13778 47203 66085 61831 33750
 52926  9651 52932 25011]


In [4]:
# Inspect a single row: dump it raw, then show its decoded text split by the mask
sample_index = 21204

# to_pydict() returns one list per column; only one row was taken, so each list has one entry
sample = dataset.take([sample_index]).to_pydict()

token_ids = np.array(sample["0_tokens"])
# Used below as a boolean selector over token_ids
# (presumably True marks supervised tokens — TODO confirm against the dataset builder).
masks     = np.array(sample["0_masks"])

print(sample)

# First the tokens where the mask is False, then a divider, then the tokens where it is True.
# A blank line is appended after each end-of-turn marker to make the turns readable.
for position, keep in enumerate((~masks, masks)):
    if position:
        print("=================")
    decoded = tokenizer.decode(token_ids[keep].tolist(), spaces_between_special_tokens=False)
    print(decoded.replace("<|end_of_turn|>", "<|end_of_turn|>\n\n"))

{'total_length': [961], 'num_seqs': [1], '0_tokens': [[1, 402, 7982, 29941, 4911, 29901, 7251, 32000, 402, 7982, 29941, 4007, 22137, 29901, 15043, 29991, 1128, 508, 306, 6985, 366, 9826, 29973, 32000, 402, 7982, 29941, 4911, 29901, 437, 366, 1073, 10377, 5847, 32000, 402, 7982, 29941, 4007, 22137, 29901, 3869, 29892, 306, 437, 29991, 5641, 9998, 15028, 363, 4124, 20334, 300, 653, 3497, 2184, 29889, 739, 338, 263, 9608, 322, 3564, 8688, 304, 1653, 263, 23533, 29899, 517, 29899, 412, 261, 1158, 310, 15446, 322, 19383, 11266, 9799, 297, 263, 13235, 934, 1788, 29889, 512, 916, 3838, 29892, 5641, 9998, 6511, 366, 304, 3787, 322, 6232, 2066, 297, 263, 27189, 1705, 1891, 8214, 29892, 1728, 337, 5890, 373, 263, 6555, 1891, 1923, 29889, 910, 508, 367, 5407, 363, 8324, 988, 848, 4225, 304, 367, 3625, 1584, 565, 278, 2441, 1923, 5771, 1623, 470, 338, 12522, 1623, 29892, 408, 1532, 408, 363, 8324, 988, 4160, 864, 304, 6232, 2919, 2066, 1728, 5528, 1038, 292, 278, 21544, 310, 6555, 1891, 8635, 322,