extend preprocess_data_dist to handle jsonl files #60

Open
wants to merge 76 commits into base: main

Conversation

adammoody
Contributor

@adammoody adammoody commented Aug 11, 2021

This extends tools/preprocess_dataset_dist.py to handle JSONL files as an input dataset.

It defines a new IndexedJSON class in tools/indexed_json.py that creates and uses an index for the JSONL file. The index enables random-access reads of records from the file. If the index file does not exist, it is constructed on the fly and stored alongside the JSONL file. For a file named dataset.jsonl, the index file is created as dataset.jsonl.idx.
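A minimal usage sketch of the class as described above (the constructor arguments and helper names here are assumptions, not necessarily the exact API):

# Hypothetical usage of IndexedJSON; the constructor signature is assumed.
from mpi4py import MPI
from tools.indexed_json import IndexedJSON

comm = MPI.COMM_WORLD
dset = IndexedJSON('dataset.jsonl', comm)   # builds dataset.jsonl.idx if it does not exist

offset, size = dset.index(0)    # byte offset and length of the first record in the jsonl file
sample = dset.getitem(0)        # random-access read of that record, parsed with json.loads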

To build the index, processes collectively read consecutive sections of the file and split records on newline characters. Scan collectives are used to compute the byte offsets. The offsets are written to a temporary file called dataset.jsonl.idxtmp which is then processed further to compute the record lengths and the final index file.
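As an illustration of the scan step (a hedged sketch with mpi4py, not the PR's exact code), an exclusive scan over per-rank record counts tells each rank where its entries belong in the shared temporary index:

# Sketch only; variable names are illustrative.
from mpi4py import MPI

comm = MPI.COMM_WORLD

# Byte offsets of the newlines this rank found in its consecutive section.
local_newline_offsets = [127, 412, 980]   # toy values
numrecs = len(local_newline_offsets)

# Exclusive prefix sum: total number of records on all lower ranks, which is
# where this rank's first entry lands in dataset.jsonl.idxtmp.
first_rec = comm.exscan(numrecs, op=MPI.SUM)
if first_rec is None:   # exscan is undefined on rank 0
    first_rec = 0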

An example SLURM script (not tested) to run this script with multiple nodes: bigscience-workshop/bigscience#4

TODO:

  • support torch.distributed
  • handle exceptions to avoid deadlocks
  • for small datasets, read full index by rank 0, bcast to all ranks, and store in memory
  • define a file header with version number, store bytes in network order
  • research existing formats that could be used instead

Scaling tests:
With this PR in place so that I can read JSONL files, I am able to test encoding with the oscar dataset. This also required the scatter PR #63, since otherwise the index list is too large to bcast to all ranks (there is not enough memory to hold the full list on every rank).

I ran some scaling tests using 40 procs/node. Here is the time to process all samples, which measures the cost to read, encode, and write each sample to its per-rank file. In particular, the startup and merge costs are not included here.

nodes     8   16   32   64
secs:  4643 2393 1445  769
MB/s    264  512  848 1594

So scaling is pretty solid so far.

The results above are not using --shuffle. If I add that, the cost jumps dramatically. For a test with 16 nodes, I get:

      no shuffle  --shuffle
secs:       2393      21388
MB/s         512         57

With --shuffle, it's 10x slower on 16 nodes. With additional timers, I can see the read cost is where that increased time comes from. I'm sure that's due to the random access through the json file. To fix that, one will have to pre-shuffle the json file to get good scaling.

Update: Regarding performance on a shuffled dataset, it may help to disable buffering in the IndexedJSON class.

        # Open the index and jsonl files with buffering disabled
        self.fh_idx = open(filename_idx, "rb", buffering=0)
        self.fh_json = open(filename, "rb", buffering=0)

I added a column for those numbers below.

      no shuffle  --shuffle  --shuffle+buffering=0
secs:       2393      21388                   3309
MB/s         512         57                    370

It's still slower than a non-shuffled dataset with buffering, but it is more tolerable now. To really use buffering=0, the code may need to be updated to handle short reads properly.
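For illustration, a short-read-safe wrapper could look like this (a sketch, not code from this PR):

# With buffering=0, fh.read(n) has os.read semantics and may return fewer than
# n bytes, so a read must loop until the requested size has arrived.
def read_fully(fh, size):
    chunks = []
    remaining = size
    while remaining > 0:
        data = fh.read(remaining)
        if not data:
            raise EOFError("unexpected end of file")
        chunks.append(data)
        remaining -= len(data)
    return b"".join(chunks)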

Running tests at higher node counts with --shuffle and buffering=0, I see things level off at 32 nodes.

nodes     8   16   32   64

no shuffle
secs:  4643 2393 1445  769
MB/s    264  512  848 1594

shuffle + buffering=0
secs:     - 3051 1663 1509
MB/s      -  402  737  812

I still suspect this is due to random access into the jsonl file, this time because of buffering in the underlying file system client rather than buffering within the Python runtime. So we're back to needing to pre-shuffle the jsonl file to do better.

@thomasw21 thomasw21 changed the title Extend preprocess_dataset_mpi to handle jsonl files WIP: Extend preprocess_dataset_mpi to handle jsonl files Aug 12, 2021
adammoody and others added 25 commits August 12, 2021 13:58
Co-authored-by: Thomas Wang <24695242+thomasw21@users.noreply.github.com>
@adammoody adammoody changed the title WIP: Extend preprocess_dataset_mpi to handle jsonl files WIP: Extend preprocess_data_dist to handle jsonl files Aug 19, 2021
@thomasw21
Member

By the way, you can open a PR against the other branch #55 instead, so I'd only see the diffs relative to that branch on GitHub. This would make the reviewing process easier.

Awesome scaling benchmarks! Shuffling can be done manually (the other preprocessing scripts don't do it) with terashuf on a single node: https://github.com/bigscience-workshop/bigscience/tree/master/data/oscar#:~:text=fast-shuffle.slurm-,terashuf,-is%20in%20%24six_ALL_CCFRWORK Essentially we probably want a copy of our shuffled dataset. This is useful since we might want to debug some spikes we're observing and check the related samples. Either way, it's okay to keep the live-shuffle feature and leave it up to the user to set up the correct configs. However, I would be in favor of logging a warning that adding shuffling will reduce performance and that offline shuffling may be preferable.
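Something along these lines is what I have in mind (a rough sketch; it assumes the script exposes a --shuffle flag and uses the standard logging module):

import argparse
import logging

logger = logging.getLogger(__name__)

parser = argparse.ArgumentParser()
parser.add_argument("--shuffle", action="store_true")
args, _ = parser.parse_known_args()

if args.shuffle:
    # Warn that live shuffling reads the jsonl in random order and is much slower.
    logger.warning("--shuffle can be much slower; consider shuffling the jsonl offline instead.")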

Sorry, I have a bit much on my plate currently; I'll try to look at your PRs this week or over the weekend.

@adammoody
Contributor Author

Thanks, @thomasw21, for the time and effort in reviewing these. I know that takes time.

Yes, if one has a pre-shuffled json file, this gives the best performance.

GitHub doesn't seem to let me base this PR on the parallel merge PR; I think I can only pick branches that are in the bigscience repo. However, in case it helps, I created a PR in my own repo of this branch, based on my pmerge branch, to show the diff.

adammoody#1

Anyway, this PR can obviously wait for review until the other PR is merged. I just updated it to enable torch.distributed in case someone finds it useful before it is merged.

I still have some ideas on a parallel shuffle, but that would be a future PR. In fact, I think something like the terashuf algorithm should parallelize pretty well. That would be the last serial part of the whole process.

@adammoody
Contributor Author

This has been refreshed after merging #55, so the actual differences are easier to see now.

@adammoody adammoody changed the title WIP: Extend preprocess_data_dist to handle jsonl files extend preprocess_data_dist to handle jsonl files Sep 1, 2021
@thomasw21 thomasw21 left a comment
Member

Thanks! I have a few questions though:

  • Seeing as we used jsonl because the original Megatron used jsonl, but we can now handle datasets, perhaps we never actually want to use jsonl anymore? This is to be discussed, I guess @stas00
  • Also, another approach would be to load the json using datasets directly and then run the script (see the sketch after this list). Typically you might need to expose a method convert_to_megatron_dst(dataset) or something. You wouldn't need another IndexJson, as it would be a dataset.
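A rough sketch of that alternative (the conversion entry point named here is hypothetical):

# Let datasets ingest the jsonl directly instead of indexing it ourselves.
from datasets import load_dataset

dataset = load_dataset("json", data_files="dataset.jsonl", split="train")
# convert_to_megatron_dst(dataset)   # hypothetical hook into the existing script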

@@ -163,6 +164,8 @@ def open(self, filename, truncate=None):

except Exception as e:
Member

Can we instead scope everything? It would allow the exception handling to be more specific, e.g. if truncate fails then we need to close the file.

@@ -176,12 +179,20 @@ def open(self, filename, truncate=None):
err = e

# Verify that all ranks successfully opened the file
if not self.alltrue(err is None):
Member

I'm overall curious why you need to close the file. Raising should destroy everything, no?

Contributor Author

Yes, good point. I think the raise should do the trick, since the file handle will not be returned and will go out of scope in that case. I'll simplify that code.


return f

def openread(self, filename):
Member

Isn't that just self.open(filename, truncate=None)? I don't see why there's such a function. I'll look further in the PR to understand the need for duplication. Is it because you have "rb" here?

Contributor Author

In the open-for-write function, I have that written as a two-phase process, where rank 0 creates and truncates the file, then the other ranks open the file afterwards. In open-for-read, all procs open the file simultaneously. I think it's useful to keep the two-phase step for creating the file, because create/truncate can be expensive on some file systems. However, I suspect this can be refactored so that openread is merged into open as a single call.
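For context, the two-phase open-for-write looks roughly like this (a sketch, not the PR's exact code; comm is an mpi4py communicator):

def open_for_write(comm, filename):
    # Phase 1: only rank 0 creates/truncates the file, since create and
    # truncate can be expensive on some parallel file systems.
    if comm.rank == 0:
        with open(filename, "wb"):
            pass
    comm.barrier()
    # Phase 2: all ranks open the now-existing file for writing.
    return open(filename, "r+b")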

self.allraise_if(err)

# Get value from rank 0
exists = self.bcast(exists, root=0)
Member

Technically you can have all ranks run os.path.exists(filename) and remove one bcast.

Also, side note: we have a lot of

        err = None

        # Rank 0 executes the existence check
        exists = False
        if self.rank == 0:
            try:
                do_something
            except Exception as e:
                err = e

        # Verify that the check succeeded
        self.allraise_if(err)

Can we make a helper and factor that code out somewhere? You could pass a method as an argument. (There might not even be a need for allraise_if anymore.)
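For example, such a helper could look roughly like this (a sketch; it reuses the bcast and allraise_if methods already in this file, and the name is made up):

def rank0_call(self, fn, default=None):
    # Run fn() on rank 0 only, re-raise any failure on all ranks, then
    # broadcast the result so every rank sees a consistent value.
    result, err = default, None
    if self.rank == 0:
        try:
            result = fn()
        except Exception as e:
            err = e
    self.allraise_if(err)
    return self.bcast(result, root=0)

# Usage: exists = self.rank0_call(lambda: os.path.exists(filename), default=False)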

Contributor Author
@adammoody adammoody Sep 7, 2021

Yes, I'll see if I can create a helper routine to account for that pattern. I tend to use this pattern for a couple of reasons.

One is that having rank 0 do the check and bcast the result tends to be more scalable than having all ranks do the check directly. For example, a stat call invokes communication between the client and the file system server, which is a remote process on networked file systems like Lustre/GPFS. With P procs, the direct method can induce O(P) messages and O(P) time at the server. Having rank 0 do the check and bcast the result requires one query to the server and then O(log P) time to spread the result, assuming a tree-based bcast.

A second benefit is that it guarantees that all procs see a consistent result. As an example, imagine that someone deletes the file while the existence check is running. With direct queries, some procs might see the file as existing while others see it as missing, which leads the procs to take different code branches later on. With a single query, everyone gets the same state. It might still be stale, but at least all procs work from a consistent state.

Member

Okay this makes sense to me. Thanks for the great explanation!

exists = self.bcast(exists, root=0)
return exists

def stat(self, filename, field):
Member

I'm not sure why we need this helper; can't all ranks run os.stat(filename)[field]? Granted, you're running that code a lot and the helper removes all the communication with the file system, but we could remove a lot of code by doing it directly. Same comment for exists.

entry = json.loads(buf)
return entry
except:
# TODO: throw exception instead?
Member

Probably.

# list of (offset, length) pairs of uint64

# Seek to the right spot in the index file for the given sample id.
header_size = 16
Member

You should compute that from variables like "INDEX_DTYPE_SIZE" and the header name.

Contributor Author

Yes, thanks. Good suggestion to define something like INDEX_DTYPE_SIZE.


# Seek to the right spot in the index file for the given sample id.
header_size = 16
offset_idx = header_size + idx * 16
Member

Suggested change
offset_idx = header_size + idx * 16
offset_idx = header_size + idx * ( 2 * INDEX_DTYPE_SIZE )

def __get__(self, idx):
return self.getitem(idx)

def index(self, idx):
Member

So this doesn't support slicing for example.

Contributor Author

Right, not yet anyway. It could be added if needed.

Member

This should be easy, right? You only need to read n times the byte size of an (offset, length) tuple. Essentially your index file is contiguous, so that's only one read. Then you can add up the lengths and keep the first offset to know how much to read from the jsonl, right? And the json lib should be able to handle jsonl formats (if not, I'm pretty sure we can find one easily).
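A sketch of that slicing idea (hypothetical code; it assumes the 16-byte header and little-endian uint64 (offset, length) layout discussed in this review, which may not match the final format):

import json
import struct

def getslice(self, start, stop):
    n = stop - start
    # One read covers the n contiguous (offset, length) index entries.
    self.fh_idx.seek(16 + start * 16)
    vals = struct.unpack("<%dQ" % (2 * n), self.fh_idx.read(n * 16))
    offsets, lengths = vals[0::2], vals[1::2]

    # One read of the jsonl covers all requested records, since they are contiguous.
    nbytes = offsets[-1] + lengths[-1] - offsets[0]
    self.fh_json.seek(offsets[0])
    buf = self.fh_json.read(nbytes)

    # Split the buffer back into records and parse each line as JSON.
    return [json.loads(buf[o - offsets[0] : o - offsets[0] + l])
            for o, l in zip(offsets, lengths)]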


return offset, size

def pread(self, offset, size):
Member

read_raw maybe?

Contributor Author

Sure, we could pick another name here. I had picked pread since it's reading a size from an offset, so it's like the libc pread call: https://man7.org/linux/man-pages/man2/pread.2.html

I guess it's similar to the python os.pread() call, though maybe I should flip the order of the parameters.
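For reference, the stdlib call takes its arguments in the opposite order (count, then offset):

import os

# os.pread(fd, n, offset): read up to n bytes starting at byte offset,
# without moving the file descriptor's position.
fd = os.open("dataset.jsonl", os.O_RDONLY)
buf = os.pread(fd, 4096, 0)
os.close(fd)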

@stas00
Member

stas00 commented Sep 7, 2021

Seeing as we used jsonl because original Megatron used jsonl, but we can now handle datasets, perhaps we never actually want to use jsonl anymore? This is to be discussed I guess @stas00

You probably want to take this to the group discussion, but my personal opinion is that we don't need the intermediary jsonl dump stage.

It has been useful:

  • for huge data-processing pipelines, where it makes each stage shorter and easier to fit into 20h SLURM job limits
  • for re-running json->megatron as we were tweaking things

These may still be important for users with fewer resources than we currently have, or if the user already has a jsonl file to start with. Therefore, perhaps, if it's not too complicated, the program could by default skip the intermediary stage and go directly to the final format, while those who need it could still do it in stages?

Update: I forgot that it is much faster to shuffle jsonl than to do the same in datasets, so that is another plus for using jsonl.

@adammoody
Contributor Author

adammoody commented Sep 7, 2021

Thanks ! I have a few questions though:

  • Seeing as we used jsonl because original Megatron used jsonl, but we can now handle datasets, perhaps we never actually want to use jsonl anymore? This is to be discussed I guess @stas00

  • also another approach would be to load a json using datasets directly, and then run the script. Typically you might need to expose a method convert_to_megatron_dst(dataset) or something. You wouldn't need another IndexJson as it would be a dataset.

Thanks again for your review, @thomasw21. I'll work through everything over time and respond as I go.

I ended up creating this because I had some jsonl files of data for which I had no HF dataset ready to use. I submitted it as a PR in case it might also help others in the same position.

It may have worked for me to load the jsonl file into an HF dataset, but I got a bit impatient and never let it finish. It was taking a while to process, and I'm not sure how long it would have taken. I could look at that in more detail.

@adammoody
Contributor Author

@thomasw21, on exploring the idea of using a regex finditer: the finditer slowed things down by about 25%. It does clean up the code, so the tradeoff could still be worth it. Anyway, that got me going down the path of taking a closer look at the performance. I have a few ideas, but they will require some reworking of the algorithm I'm using. I'll come back to this after I give that a shot.
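To illustrate the comparison (a sketch, not the PR's exact code), the two ways of collecting newline offsets from a chunk look roughly like this:

import re

# Manual scan with bytes.find; base is the chunk's starting byte offset in the file.
def newline_offsets_find(buf, base):
    offsets = []
    pos = buf.find(b"\n")
    while pos != -1:
        offsets.append(base + pos)
        pos = buf.find(b"\n", pos + 1)
    return offsets

# Equivalent scan with re.finditer (the cleaner but slower variant mentioned above).
def newline_offsets_finditer(buf, base):
    return [base + m.start() for m in re.finditer(b"\n", buf)]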

@thomasw21
Member

thomasw21 commented Sep 15, 2021

Hey! Sorry for the wait, I've been busy with other projects.

Okay, let's remove the regex finditer, as the alternative code isn't that complicated to understand, and 25% is huge IMO (I'm very surprised though ... would you mind sharing a code snippet?).

Also, seeing as I haven't gotten an answer from the group, I suggest we keep both; we'll remove them if we don't need them anymore.

@adammoody
Contributor Author

After reworking the implementation to scan for newlines, I've got something that is now close to saturating the read bandwidth of the file system. On 8 nodes and 320 procs, it scans the source JSON file at a rate of 120 GB/s, which is basically the peak read speed on my system. For comparison, the previous algorithm topped out around 7 GB/s when using 16 MB read buffers and reached as high as 38 GB/s with 128 MB read buffers.

The new implementation requires keeping the newline offsets in memory until the scan is complete, whereas the previous algorithm wrote those offsets to a temporary file as it scanned. The new method should be fine until one needs to process really large JSON files. Each newline requires 8 bytes, so the oscar file takes 300 million * 8 bytes = 2.4 GB to hold all offsets. The new method could also easily be extended to write those offsets to a temporary per-rank file if needed.
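For illustration, the in-memory bookkeeping might look like this (a sketch assuming the offsets are held in a numpy uint64 array; the PR may organize this differently):

import numpy as np

# Newline offsets for this rank's records, kept in memory as uint64
# (8 bytes each, so roughly 2.4 GB for ~300 million records).
newlines = np.array([127, 412, 980], dtype=np.uint64)   # toy values

# Each record starts one byte past the previous newline (cross-rank
# boundary handling is omitted in this toy example).
starts = np.zeros_like(newlines)
starts[1:] = newlines[:-1] + 1
lengths = newlines + 1 - starts              # length includes the trailing newline
index = np.stack([starts, lengths], axis=1)  # (offset, length) pairs for the index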

adammoody pushed a commit to adammoody/Megatron-DeepSpeed that referenced this pull request Oct 27, 2022
* Added GPT pretraining distillation and quantization examples

* updated the compressor initialization API to the latest one

* fixed API calls 

* fixed several compatibility issues of Kd/quantization with respect to the standard gpt training in both checkpointing and tensorboard visualization. 

* dir name typo

* Incorporated Conglong's suggestions.

Co-authored-by: yaozhewei <zheweiy@berkeley.edu>
Co-authored-by: Conglong Li <conglong.li@gmail.com>