Stuck on training: Created a PretokDataset with rng seed 42 #311

Open
madroidmaq opened this issue Aug 17, 2023 · 22 comments

@madroidmaq
Contributor

When I try to train the model, I run into a problem. I don't know whether anyone else has hit the same thing, or how I should go about solving it.

When I execute the training command (below), the log always gets stuck at the line Created a PretokDataset with rng seed 42 and shows no further output for several hours.

Below are some key steps I performed along with my device information.

python train.py

The corresponding output is roughly as follows:

(base) jupyter@instance-20230817-103839:~/llama2/llama2.c$ python train.py
tokens per iteration will be: 131,072
breaks down as: 4 grad accum steps * 1 processes * 128 batch size * 256 max seq len
Initializing a new model from scratch
num decayed parameter tensors: 43, with 15,187,968 parameters
num non-decayed parameter tensors: 13, with 3,744 parameters
using fused AdamW: True
Created a PretokDataset with rng seed 42

The GPU information of my machine is roughly as follows:

Thu Aug 17 05:21:51 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.105.17   Driver Version: 525.105.17   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA A100-SXM...  On   | 00000000:00:04.0 Off |                    0 |
| N/A   32C    P0    52W / 400W |   1045MiB / 40960MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    0   N/A  N/A     40510      C   python                           1042MiB |
+-----------------------------------------------------------------------------+

The CPU and memory of my machine are roughly as follows: 12 vCPUs, 85 GB RAM.

@CatTimson

I have exactly the same issue.

"Created a PretokDataset with rng seed 42" and then nothing

@RahulSChand
Contributor

@CatTimson @madroidmaq can you try setting pin_memory=False in this line ?

@madroidmaq changed the title from "Stuck on training dataset" to "Stuck on training dataset: Created a PretokDataset with rng seed 42" on Aug 17, 2023
@madroidmaq
Contributor Author

#296 should be the same problem.

@madroidmaq
Contributor Author

madroidmaq commented Aug 17, 2023

@CatTimson @madroidmaq can you try setting pin_memory=False in this line ?

@RahulSChand After following your suggestion, my problem disappeared; I can now train and see the detailed metrics for each training step, thank you very much. It looks like this:

...
10405 | loss 1.2478 | lr 4.889483e-04 | 345.92ms | mfu 11.71%
10406 | loss 1.2216 | lr 4.889459e-04 | 346.45ms | mfu 11.71%
10407 | loss 1.2321 | lr 4.889436e-04 | 346.29ms | mfu 11.71%
10408 | loss 1.2726 | lr 4.889413e-04 | 345.98ms | mfu 11.71%
10409 | loss 1.2325 | lr 4.889389e-04 | 346.04ms | mfu 11.71%
...

I'd like to know why this tweak works, or what sources I should be looking at for this information. Looking forward to your reply.

@CatTimson

CatTimson commented Aug 18, 2023

@CatTimson @madroidmaq can you try setting pin_memory=False in this line ?

No, it did not resolve my issue. I was still seeing "Created a PretokDataset with rng seed 42" for about an hour before I cancelled the script.

What if the issue depends on software package versions? Anyone who can run the training successfully, would you mind sharing your software configuration?

Mine is :
Ubuntu 22.04
Nvidia driver 530, Nvidia A5000 card. CPU i5 10600KF or something like that
Python 3.10.10

@madroidmaq
Contributor Author

@CatTimson my device

[screenshot: device configuration]

@RahulSChand
Contributor

RahulSChand commented Aug 18, 2023

(quoting @madroidmaq's reply above, reporting that pin_memory=False fixed the hang and asking why the tweak works)

@madroidmaq since training gets stuck at the dataloader, I looked up PyTorch issues and found a similar issue that other people reported when pin_memory=True (link). You can read that thread to learn more about why it happens.

@RahulSChand
Contributor

@CatTimson what is your PyTorch version? Use print(torch.__version__)
PyTorch 2.0.1+cu117 works for me. You can get the same version in a new environment with pip install torch==2.0.0+cu117 --index-url https://download.pytorch.org/whl/cu117

@madroidmaq
Contributor Author

I also encountered the same problem on another device, but setting pin_memory=False did not solve it there, so I don't think this is a general solution.

@madroidmaq reopened this on Aug 18, 2023
@RahulSChand
Contributor

@madroidmaq can you remove the time.time() calls from the following lines?

t0 = time.time()
t1 = time.time()

Just put t0=0 and t1=0 & check?

Also, you can change the print below to print(..., flush=True) to see whether the issue occurs after the first loss.backward:

print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

I can't reproduce the error, so there is no way for me to check whether the suggestion actually helps, but it's worth a try.

@madroidmaq
Contributor Author

@RahulSChand Thank you for your reply. I modified the files according to your suggestion, but it still doesn't work; training still gets stuck after printing Created a PretokDataset with rng seed 42.

(base) jupyter@umn-20230612-000220:~/llama2/llama2.c$ git diff --cached
diff --git a/tinystories.py b/tinystories.py
index 690cb02..7e46ee8 100644
--- a/tinystories.py
+++ b/tinystories.py
@@ -235,7 +235,7 @@ class Task:
     def iter_batches(batch_size, device, num_workers=0, **dataset_kwargs):
         ds = PretokDataset(**dataset_kwargs)
         dl = torch.utils.data.DataLoader(
-            ds, batch_size=batch_size, pin_memory=True, num_workers=num_workers
+            ds, batch_size=batch_size, pin_memory=False, num_workers=num_workers
         )
         for x, y in dl:
             x = x.to(device, non_blocking=True)
diff --git a/train.py b/train.py
index b1972dc..ace7e02 100644
--- a/train.py
+++ b/train.py
@@ -246,7 +246,7 @@ if wandb_log and master_process:
 # training loop
 train_batch_iter = iter_batches(split="train")
 X, Y = next(train_batch_iter)  # fetch the very first batch
-t0 = time.time()
+t0 = 0
 local_iter_num = 0  # number of iterations in the lifetime of this process
 raw_model = model.module if ddp else model  # unwrap DDP container if needed
 running_mfu = -1.0
@@ -259,7 +259,7 @@ while True:
     # evaluate the loss on train/val sets and write checkpoints
     if iter_num % eval_interval == 0 and master_process:
         losses = estimate_loss()
-        print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
+        print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}",flush=True)
         if wandb_log:
             try:
                 wandb.log(
@@ -319,7 +319,7 @@ while True:
     optimizer.zero_grad(set_to_none=True)
 
     # timing and logging
-    t1 = time.time()
+    t1 = 0
     dt = t1 - t0
     t0 = t1
     if iter_num % log_interval == 0 and master_process:

The full output is:

(base) jupyter@umn-20230612-000220:~/llama2/llama2.c$ python train.py --vocab_source=custom --vocab_size=4096
Overriding: vocab_source = custom
Overriding: vocab_size = 4096
tokens per iteration will be: 131,072
breaks down as: 4 grad accum steps * 1 processes * 128 batch size * 256 max seq len
Initializing a new model from scratch
num decayed parameter tensors: 43, with 7,151,616 parameters
num non-decayed parameter tensors: 13, with 3,744 parameters
using fused AdamW: True
compiling the model... (takes a ~minute)
Created a PretokDataset with rng seed 42

@madroidmaq changed the title from "Stuck on training dataset: Created a PretokDataset with rng seed 42" to "Stuck on training: Created a PretokDataset with rng seed 42" on Aug 18, 2023
@CatTimson

My original configuration was:
nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2022 NVIDIA Corporation
Built on Wed_Jun__8_16:49:14_PDT_2022
Cuda compilation tools, release 11.7, V11.7.99
Build cuda_11.7.r11.7/compiler.31442593_0

2.0.0+cu117

Then I did a complete cleanup/reinstall of everything.

Now:
nvidia-smi
Fri Aug 18 20:23:40 2023
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.86.10              Driver Version: 535.86.10    CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA RTX A5000               On  | 00000000:01:00.0  On |                  Off |
| 30%   35C    P8              23W / 230W |    647MiB / 24564MiB |      5%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                             |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=========================================================================================|
|    0   N/A  N/A      2057      G   /usr/lib/xorg/Xorg                          212MiB |
|    0   N/A  N/A      2268      G   /usr/bin/gnome-shell                         85MiB |
|    0   N/A  N/A      5498      C   python                                      264MiB |
|    0   N/A  N/A      5669      G   ...ures=TFLiteLanguageDetectionEnabled       69MiB |
+---------------------------------------------------------------------------------------+
(py31010) tim@linux-ws01:~$ nvcc -V
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2022 NVIDIA Corporation
Built on Wed_Sep_21_10:33:58_PDT_2022
Cuda compilation tools, release 11.8, V11.8.89
Build cuda_11.8.r11.8/compiler.31833905_0

The result is exactly the same:
"Created a PretokDataset with rng seed 42"
and then nothing.

@kunwar-vikrant

I'm getting a similar issue with the same CUDA config; totally lost as to what could be wrong here!

@RahulSChand
Contributor

RahulSChand commented Aug 19, 2023

@CatTimson @kunwar-vikrant @madroidmaq I was able to reproduce the error when using a custom dataset. It happens because the data_dir path ./data/tok{vocab_size}/ doesn't contain any .bin files.

You can confirm if this is the case for you by adding a print(shard_filenames) before this line

rng.shuffle(shard_filenames)

If running train.py now prints [], then you either need to rerun your train_vocab/pretokenize steps for TinyStories, or check whether you have changed this hardcoded "story" key to fit your custom JSON data:

text = example["story"]

For example, my json data is

{
        "query": "My order hasn't arrived yet.",
        "response": "We apologize for the inconvenience. Can you please provide your order number so we can investigate?"
}

So I change "story" to "response" & then run the vocab/tokenization steps again.
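
For anyone who wants to run this check without editing tinystories.py, a minimal standalone sketch of the same idea (the data/tok{vocab_size} path mirrors the default llama2.c layout; adjust it if yours differs):

# Sanity check (sketch): list the pretokenized .bin shards that train.py expects.
# An empty list here would explain the hang at "Created a PretokDataset with rng seed 42".
import glob
import os

vocab_size = 4096  # use the same value you passed to pretokenize/train.py
data_dir = os.path.join("data", f"tok{vocab_size}")
shard_filenames = sorted(glob.glob(os.path.join(data_dir, "*.bin")))
print(shard_filenames)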

@madroidmaq
Contributor Author

(quoting @RahulSChand's comment above, which also notes: "you shouldn't get any errors if you are using TinyStories though")

@RahulSChand I re-tested following your method and found it is indeed as you said: [] is printed every time.

However, I am not training on my own dataset; I am using the custom train_vocab path on TinyStories. The exact commands I ran are as follows:

python tinystories.py download
python tinystories.py train_vocab --vocab_size=4096
python tinystories.py pretokenize --vocab_size=4096
python train.py --vocab_source=custom --vocab_size=4096

@RahulSChand
Contributor

@madroidmaq what is inside your data/tok4096/ folder? It should have a bunch of .bin files. If it's empty, train.py will be stuck; try rerunning the train_vocab/pretokenize steps and see whether they complete successfully.

@madroidmaq
Contributor Author

@RahulSChand You are right. I hit spm_train: command not found when training the custom vocabulary, so no corresponding .bin files were generated; I did not notice this problem at first. The detailed log output is as follows:

(base) jupyter@umn-20230612-000220:~/llama2/llama2.c$ python tinystories.py train_vocab --vocab_size=4096
Writing temporary file data/tiny.txt with 10 shards...
100%|█████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:14<00:00,  1.48s/it]
Size is: 739.57 MB
Will now train the vocab with:
bash train_vocab.sh data/tiny.txt data/tok4096 4096
OK? [y/N] 
y
Input: data/tiny.txt
Model Prefix: data/tok4096
Vocabulary Size: 4096
train_vocab.sh: line 114: spm_train: command not found
Delete the temporary file data/tiny.txt? [y/N] y
Deleted data/tiny.txt
Trained tokenizer is in data/tok4096.model
Done.

To solve the spm_train: command not found problem, the sentencepiece#installation instructions say you need to download and compile the source code to get the spm_train command-line tool, which is relatively troublesome. Instead, you can call it through the Python package, so I modified the script roughly as follows:

diff --git a/train_vocab.sh b/train_vocab.sh
index 7803af8..454cb12 100755
--- a/train_vocab.sh
+++ b/train_vocab.sh
@@ -111,16 +111,29 @@ echo "Vocabulary Size: $vocab_size"
 # --byte_fallback is true, default in spm is false
 # --normalization_rule_name is identity, default in spm is nmt_nfkc

-spm_train --input="$input" \
-          --model_prefix="$model_prefix" \
-          --model_type=bpe \
-          --vocab_size="$vocab_size" \
-          --self_test_sample_size=0 \
-          --input_format="text" \
-          --character_coverage=1.0 \
-          --num_threads="$(nproc)" \
-          --split_digits=true \
-          --allow_whitespace_only_pieces=true \
-          --byte_fallback=true \
-          --unk_surface=" \342\201\207 " \
-          --normalization_rule_name=identity \
+python3 << END
+import sentencepiece as spm
+import os
+
+input_path = "$input"
+model_prefix = "$model_prefix"
+vocab_size = "$vocab_size"
+num_threads = os.cpu_count()
+
+spm.SentencePieceTrainer.train(
+    f'--input={input_path} '
+    f'--model_prefix={model_prefix} '
+    '--model_type=bpe '
+    f'--vocab_size={vocab_size} '
+    '--self_test_sample_size=0 '
+    '--input_format=text '
+    '--character_coverage=1.0 '
+    f'--num_threads={num_threads} '
+    '--split_digits=true '
+    '--allow_whitespace_only_pieces=true '
+    '--byte_fallback=true '
+    '--unk_surface= \342\201\207 '
+    '--normalization_rule_name=identity'
+)
+END
+ \

@karpathy I'm not sure whether the above code change would be accepted; if so, I will submit a PR.
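
For reference, a minimal standalone sketch of the same idea outside the shell script, calling the sentencepiece Python API with keyword arguments (paths and vocab size are placeholders mirroring the values above; this is not the exact change proposed in the diff):

# Sketch: train a BPE tokenizer with the sentencepiece Python package instead of
# the spm_train CLI. Writes data/tok4096.model and data/tok4096.vocab.
import os

import sentencepiece as spm

spm.SentencePieceTrainer.train(
    input="data/tiny.txt",            # raw text corpus, one sentence per line
    model_prefix="data/tok4096",
    model_type="bpe",
    vocab_size=4096,
    self_test_sample_size=0,
    input_format="text",
    character_coverage=1.0,
    num_threads=os.cpu_count(),
    split_digits=True,
    allow_whitespace_only_pieces=True,
    byte_fallback=True,
    normalization_rule_name="identity",
)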

@RahulSChand
Contributor

@madroidmaq You can do apt install sentencepiece and the spm_train command should work. I don't think this change to the script is necessary.

@madroidmaq
Contributor Author

madroidmaq commented Aug 19, 2023

@RahulSChand You are correct, the spm_train command-line tool can be installed with apt install sentencepiece (or sudo apt install sentencepiece). However, running the vocab training with that version fails on the allow_whitespace_only_pieces flag:

(base) jupyter@umn-20230612-000220:~/llama2/llama2.c$ python tinystories.py train_vocab --vocab_size=4096
Writing temporary file data/tiny.txt with 10 shards...
100%|█████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:14<00:00,  1.49s/it]
Size is: 739.57 MB
Will now train the vocab with:
bash train_vocab.sh data/tiny.txt data/tok4096 4096
OK? [y/N] 
y
Input: data/tiny.txt
Model Prefix: data/tok4096
Vocabulary Size: 4096
Unknown/Invalid flag allow_whitespace_only_pieces

sentencepiece

Usage: spm_train [options] files

   --help (show help)  type: bool default: false
   --version (show version)  type: bool default: false
   --minloglevel (Messages logged at a lower level than this don't actually get logged anywhere)  type: int default: 0
   --input (comma separated list of input sentences)  type: std::string default: ""
   --input_format (Input format. Supported format is `text` or `tsv`.)  type: std::string default: ""
   --model_prefix (output model prefix)  type: std::string default: ""
   --model_type (model algorithm: unigram, bpe, word or char)  type: std::string default: "unigram"
   --vocab_size (vocabulary size)  type: int32 default: 8000
   --accept_language (comma-separated list of languages this model can accept)  type: std::string default: ""
   --self_test_sample_size (the size of self test samples)  type: int32 default: 0
   --character_coverage (character coverage to determine the minimum symbols)  type: double default: 0.9995
   --input_sentence_size (maximum size of sentences the trainer loads)  type: std::uint64_t default: 0
   --shuffle_input_sentence (Randomly sample input sentences in advance. Valid when --input_sentence_size > 0)  type: bool default: true
   --seed_sentencepiece_size (the size of seed sentencepieces)  type: int32 default: 1000000
   --shrinking_factor (Keeps top shrinking_factor pieces with respect to the loss)  type: double default: 0.75
   --num_threads (number of threads for training)  type: int32 default: 16
   --num_sub_iterations (number of EM sub-iterations)  type: int32 default: 2
   --max_sentencepiece_length (maximum length of sentence piece)  type: int32 default: 16
   --max_sentence_length (maximum length of sentence in byte)  type: int32 default: 4192
   --split_by_unicode_script (use Unicode script to split sentence pieces)  type: bool default: true
   --split_by_number (split tokens by numbers (0-9))  type: bool default: true
   --split_by_whitespace (use a white space to split sentence pieces)  type: bool default: true
   --split_digits (split all digits (0-9) into separate pieces)  type: bool default: false
   --treat_whitespace_as_suffix (treat whitespace marker as suffix instead of prefix.)  type: bool default: false
   --control_symbols (comma separated list of control symbols)  type: std::string default: ""
   --control_symbols_file (load control_symbols from file.)  type: std::string default: ""
   --user_defined_symbols (comma separated list of user defined symbols)  type: std::string default: ""
   --user_defined_symbols_file (load user_defined_symbols from file.)  type: std::string default: ""
   --required_chars (UTF8 characters in this flag are always used in the character set regardless of --character_coverage)  type: std::string default: ""
   --required_chars_file (load required_chars from file.)  type: std::string default: ""
   --byte_fallback (decompose unknown pieces into UTF-8 byte pieces)  type: bool default: false
   --vocabulary_output_piece_score (Define score in vocab file)  type: bool default: true
   --normalization_rule_name (Normalization rule name. Choose from nfkc or identity)  type: std::string default: "nmt_nfkc"
   --normalization_rule_tsv (Normalization rule TSV file. )  type: std::string default: ""
   --denormalization_rule_tsv (Denormalization rule TSV file.)  type: std::string default: ""
   --add_dummy_prefix (Add dummy whitespace at the beginning of text)  type: bool default: true
   --remove_extra_whitespaces (Removes leading, trailing, and duplicate internal whitespace)  type: bool default: true
   --hard_vocab_limit (If set to false, --vocab_size is considered as a soft limit.)  type: bool default: true
   --use_all_vocab (If set to true, use all tokens as vocab. Valid for word/char models.)  type: bool default: false
   --unk_id (Override UNK (<unk>) id.)  type: int32 default: 0
   --bos_id (Override BOS (<s>) id. Set -1 to disable BOS.)  type: int32 default: 1
   --eos_id (Override EOS (</s>) id. Set -1 to disable EOS.)  type: int32 default: 2
   --pad_id (Override PAD (<pad>) id. Set -1 to disable PAD.)  type: int32 default: -1
   --unk_piece (Override UNK (<unk>) piece.)  type: std::string default: "<unk>"
   --bos_piece (Override BOS (<s>) piece.)  type: std::string default: "<s>"
   --eos_piece (Override EOS (</s>) piece.)  type: std::string default: "</s>"
   --pad_piece (Override PAD (<pad>) piece.)  type: std::string default: "<pad>"
   --unk_surface (Dummy surface string for <unk>. In decoding <unk> is decoded to `unk_surface`.)  type: std::string default: " ⁇ "
   --train_extremely_large_corpus (Increase bit depth for unigram tokenization.)  type: bool default: false
   --random_seed (Seed value for random generator.)  type: int32 default: -1

The version installed via apt is 0.1.95-1, while the latest release is actually 0.1.99. The output when I execute sudo apt-get install --only-upgrade sentencepiece is as follows:

(base) jupyter@umn-20230612-000220:~/llama2/llama2.c$ sudo apt-get install --only-upgrade sentencepiece
Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
sentencepiece is already the newest version (0.1.95-1).

The sentencepiece package in my python environment is the latest version:

(base) jupyter@umn-20230612-000220:~/llama2/llama2.c$ python -c "import sentencepiece as spm; print(spm.__version__)"
0.1.99

So it seems I would still need to compile the sentencepiece project manually. All in all, thanks a lot for your guidance.

@CatTimson

I have discovered another bug:
If I only use a single JSON file (data00.json), training fails; if I use at least two JSON files (data00.json + data01.json), it works perfectly fine.

Example below (single data00.json):
tim@linux-ws01:~/AI/llama-2$ python3 tinystories.py pretokenize
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 100000/100000 [00:23<00:00, 4254.65it/s]
Saved data/TinyStories_all_data/data00.bin, average seqlen: 206.59
Done.
tim@linux-ws01:~/AI/llama-2$ python3 train.py
tokens per iteration will be: 131,072
breaks down as: 4 grad accum steps * 1 processes * 128 batch size * 256 max seq len
Initializing a new model from scratch
num decayed parameter tensors: 43, with 15,187,968 parameters
num non-decayed parameter tensors: 13, with 3,744 parameters
using fused AdamW: True
compiling the model... (takes a ~minute)
Created a PretokDataset with rng seed 42
Traceback (most recent call last):
  File "/home/tim/AI/llama-2/train.py", line 248, in <module>
    X, Y = next(train_batch_iter)  # fetch the very first batch
  File "/home/tim/AI/llama-2/tinystories.py", line 241, in iter_batches
    for x, y in dl:
  File "/home/tim/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 633, in __next__
    data = self._next_data()
  File "/home/tim/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 677, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "/home/tim/.local/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 32, in fetch
    data.append(next(self.dataset_iter))
  File "/home/tim/AI/llama-2/tinystories.py", line 199, in __iter__
    assert len(shard_filenames)>0, f"No bin files found in {bin_dir}"
AssertionError: No bin files found in data/TinyStories_all_data
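
A plausible explanation, assuming PretokDataset reserves the first shard for the validation split and gives the rest to the train split (that is how I read the surrounding code, not confirmed here): with a single data00.bin, the train split ends up empty, which is exactly the assertion that fires above. A sketch of that split logic:

# Hedged sketch of the shard split as I understand it (not verbatim from the repo).
import glob
import os

bin_dir = "data/TinyStories_all_data"
split = "train"

shard_filenames = sorted(glob.glob(os.path.join(bin_dir, "*.bin")))
# Assumed split: shard 0 is held out for validation, the rest are used for training.
shard_filenames = shard_filenames[1:] if split == "train" else shard_filenames[:1]
assert len(shard_filenames) > 0, f"No bin files found in {bin_dir}"  # fails with only one shard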

@mvuthegoat

I have a side question: when does the while loop in the def __iter__(self) function of PretokDataset break?

llama2.c/tinystories.py

Lines 206 to 223 in c7a2626

while True:
    rng.shuffle(shard_filenames)
    for shard in shard_filenames:
        # open the dataset for reading but keep it on disk with memmap
        m = np.memmap(shard, dtype=np.uint16, mode="r")
        num_batches = len(m) // self.max_seq_len
        num_batches -= 1  # drop the last partial batch
        assert num_batches > 0, "this shard is way too small? investigate."
        ixs = list(range(num_batches))
        rng.shuffle(ixs)
        for ix in ixs:
            start = ix * self.max_seq_len
            end = start + self.max_seq_len + 1
            # calling .astype will copy the data into a new numpy array, now in RAM
            chunk = torch.from_numpy((m[start:end]).astype(np.int64))
            x = chunk[:-1]
            y = chunk[1:]
            yield x, y

@RahulSChand
Contributor

@mvuthegoat It doesn't need to break. It keeps yielding as long as we keep calling next on the iterator (like this line in train.py).

X, Y = next(train_batch_iter)

This isn't a normal while loop; because it has a yield statement, it runs until either an internal break condition is met (there isn't one in the code you mentioned) or we stop calling next and the whole program terminates (when train.py finishes).
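
A toy sketch of that generator pattern, for illustration (unrelated to the repo's code, just the Python semantics):

# A generator with `while True` never breaks on its own; it only produces a value
# each time next() is called, then pauses at the yield until the next call.
def infinite_counter():
    i = 0
    while True:
        yield i
        i += 1

it = infinite_counter()
print(next(it))  # 0
print(next(it))  # 1
# The loop "ends" only when the consumer stops calling next() and the program exits.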
