Created HDF5 based dataset #227
Conversation
Hi @peastman, thanks very much for posting this! I am away from code development right now and unfortunately will not be able to take a closer look for a little while, but at a glance this looks like a promising start.
Did you find this to be true in practice? I've had trouble getting anything but slowdowns from adding in the multiprocessing.
Thanks! There's no hurry. Whenever you get around to looking at it. I did find some benefit to using more workers. My GPU utilization increased from 45% with one worker to 55% with three workers. By increasing the batch size I was able to increase it further, up to about 85%.
Can I bump this issue? The inability to work with large datasets is really a showstopper for me.
```python
self.file_name = file_name
self.r_max = extra_fixed_fields["r_max"]
self.index = None
self.num_molecules = 0
```
`num_molecules` -> `num_frames`
```python
if mode == "rms":
    total += np.sum(values * values)
    count += len(values.flatten())
else:
    length = len(values.flatten())
    if mode == "per_atom_mean_std":
        values /= len(data[0][i])
    sample_mean = np.mean(values)
    new_mean = (total[0] * count + sample_mean * length) / (count + length)
    total[1] += length * (sample_mean - total[0]) * (sample_mean - new_mean)
    total[0] = new_mean
    count += length
```
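The update above is the standard one-pass mean/M2 merge (as in Chan et al.'s parallel variance formula), folding each sample in via its mean and length. A standalone sketch (hypothetical `merge_sample` helper, not code from this PR) that reduces exactly to Welford's online algorithm when each chunk is a single value, so its numerics can be checked against NumPy:

```python
import numpy as np

def merge_sample(total, count, sample_mean, length):
    """Fold one chunk (summarized by its mean and length) into running stats.

    total = [running_mean, running_M2]; mirrors the update in the diff above.
    With length == 1 per value this reduces to Welford's online algorithm.
    """
    new_mean = (total[0] * count + sample_mean * length) / (count + length)
    total[1] += length * (sample_mean - total[0]) * (sample_mean - new_mean)
    total[0] = new_mean
    return total, count + length
```

Feeding in single values (`length=1`) reproduces the running mean and population variance `total[1] / count` to within floating-point error.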
See https://github.com/mir-group/pytorch_runstats / https://pytorch-runstats.readthedocs.io/en/latest/index.html, which can do this for `rms` and `mean` already, with unit tests and correct numerics. We already depend on this library. It's easy to do an OK std with two passes using this (https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Two-pass_algorithm), or maybe to implement Welford's online algorithm in torch_runstats (https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm)? Or you can just return `NaN` for the std if you only ever use the mean for now.
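The two-pass route from the linked article streams naturally over chunked data that can be read twice from disk. A minimal sketch (hypothetical helper, plain NumPy — not the torch_runstats API):

```python
import numpy as np

def two_pass_mean_std(chunks):
    # chunks: a re-iterable of 1-D arrays (e.g. successive HDF5 reads).
    # Pass 1 computes the mean; pass 2 accumulates squared deviations,
    # giving the population standard deviation.
    n = sum(len(c) for c in chunks)
    mean = sum(float(np.sum(c)) for c in chunks) / n
    ss = sum(float(np.sum((np.asarray(c) - mean) ** 2)) for c in chunks)
    return mean, float(np.sqrt(ss / n))
```

For a list of arrays this matches `np.std` on the concatenated data; for an HDF5-backed iterable, the two passes simply read the file twice.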
Hi @peastman, I've just brought this branch up-to-date with
For the simple subset of statistics implemented here, I think this can be done without too much work. The only problem is that it would call [...]. I think the best solution for that is to go to the root of the problem: if you want to implement an alternative neighborlist backend (torch-nl or matscipy?), that could solve it. But as long as this implementation correctly errors out on cases it does not support, we can also just merge it for now and come back to it later.
Thanks! I'll try to get the changes in shortly.
Great, thanks @peastman!
After the updates, I now get an exception when I try to train a model with an HDF5 dataset. Here's the output.
and then after many repetitions of similar lines,
Any idea what would cause that?
Yeah, we've seen this before when there are atom indexes in the
Where would the invalid indices come from? Nothing in this class sets them. The AtomicData objects get created like this:

```python
args = {
    "pos": data[1][i],
    "r_max": self.r_max,
    AtomicDataDict.ATOM_TYPE_KEY: data[0][i],
    AtomicDataDict.TOTAL_ENERGY_KEY: data[2][i],
}
if self.has_forces:
    args[AtomicDataDict.FORCE_KEY] = data[3][i]
return AtomicData.from_points(**args)
```
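If the crash really is from out-of-range per-atom indices, a cheap guard before constructing each AtomicData object can localize the bad frame. This is a hypothetical helper, not part of the PR; `num_types` would be whatever the model config defines:

```python
import numpy as np

def check_type_indices(atom_types, num_types):
    # Fail fast with the offending values instead of a cryptic downstream error.
    atom_types = np.asarray(atom_types)
    bad = (atom_types < 0) | (atom_types >= num_types)
    if bad.any():
        raise ValueError(
            f"invalid atom type indices {np.unique(atom_types[bad]).tolist()}; "
            f"expected values in [0, {num_types})"
        )
```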
I made most of the changes, and also added unit tests. The one thing I haven't done is switch
The tests are failing because h5py isn't installed in the test environment. I can fix it in either of two ways: 1) add h5py as a dependency in setup.py so it will always get installed, or 2) make it skip those tests if it isn't installed. Which do you prefer? I suggest 1, but I don't want to add a dependency without checking first.
Online one-pass stddev might be tricky, but I can look later at how bad it would be to implement Welford's algorithm in the simple case.
I don't want to add it as a mandatory dependency, so why don't we
It already uses Welford's algorithm. It's mainly a matter of upstreaming the code to pytorch_runstats, though that will also add some complexity since pytorch_runstats supports features that aren't needed here.
Sounds good.
Ah I see, great--- you mean
This should be ready to merge, unless you need other changes first.
Looks good to me, no need to wait on statistics stuff to do a first merge--- I'll get this merged later today. Thanks @peastman!
Merged! If I manage to get around to implementing Welford's into runstats, we can revisit migrating to that and making this a "first class" dataset in terms of feature support in another PR later. Thanks for your efforts on this @peastman!
Thank you!
This is a rough draft of the HDF5 based dataset discussed in #214. It can handle arbitrarily large datasets, even ones that are too large to fit in memory. Are you interested in developing this into a supported feature?
Description
I added a HDF5Dataset class that extends AtomicDataset. It loads data from one or more HDF5 files with the format described at https://github.com/torchmd/torchmd-net/blob/main/torchmdnet/datasets/hdf.py. The only data fields it currently supports are positions, atom types, energies, and forces. Others could easily be added.
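As a hedged illustration of that layout — this is my reading of the linked hdf.py (one HDF5 group per molecule holding `types`, `pos`, `energy`, and optionally `forces` datasets, with frames along axis 0); the linked file is authoritative:

```python
import h5py
import numpy as np

def write_example_dataset(path, n_frames=2, n_atoms=3):
    # Assumed layout: one HDF5 group per molecule, frames along axis 0.
    with h5py.File(path, "w") as f:
        g = f.create_group("molecule_0")
        g["types"] = np.array([1, 6, 8])                # per-atom type codes
        g["pos"] = np.zeros((n_frames, n_atoms, 3))     # per-frame positions
        g["energy"] = np.zeros(n_frames)                # per-frame energies
        g["forces"] = np.zeros((n_frames, n_atoms, 3))  # optional
    return path
```

Multiple molecules become multiple groups, and multiple files can be listed in the dataset config.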
There's a lot that could be cleaned up. I've only implemented the features that were needed for my own work. That's especially true of `statistics()`, which only computes a subset of the available statistics. It could be extended to compute the rest of them, but in the long term it might be better to rework the main statistics routine to support streaming calculations so it can work with any dataset class.

The class constructs AtomicData objects as needed in `get()`, which adds overhead. You can partly compensate for that by increasing `dataloader_num_workers` to use multiple processes.

Computing statistics on large datasets can be slow. I use `dataset_statistics_stride` to reduce the cost.

Motivation and Context
All of the existing dataset classes require all the data to be held in memory at once. This limits how much data can be used for training.
Resolves: #214.
How Has This Been Tested?
I used it to train a model and it generally seems to work, but I haven't done extensive testing or written unit tests yet.
Types of changes
Checklist:
- My code has been formatted with `black`.
- The option documentation (`docs/options`) has been updated with new or changed options.
- `CHANGELOG.md` has been updated.