# custom_dataset.py
#
# Example skeleton for writing a user-defined dataset for NequIP.
# (Web-scrape artifacts — page chrome and the rendered line-number
# gutter — removed; the code below is the file's actual content.)
from typing import Dict, List, Callable, Union, Optional
import numpy as np
import logging
import torch
from nequip.data import AtomicData
from nequip.utils.savenload import atomic_write
from nequip.data.transforms import TypeMapper
from nequip.data import AtomicDataset
class ExampleCustomDataset(AtomicDataset):
    """Example skeleton for a user-defined NequIP dataset.

    See https://pytorch-geometric.readthedocs.io/en/latest/notes/create_dataset.html#creating-larger-datasets.
    If you don't need downloading or pre-processing, just don't define any of
    the relevant methods/properties.
    """

    def __init__(
        self,
        root: str,
        custom_option1,
        custom_option2="default",
        type_mapper: Optional[TypeMapper] = None,
    ):
        # Initialize the AtomicDataset, which runs .download() (if present) and .process()
        # See https://pytorch-geometric.readthedocs.io/en/latest/notes/create_dataset.html#creating-larger-datasets
        # This will only run download and preprocessing if cached dataset files aren't found
        super().__init__(root=root, type_mapper=type_mapper)
        # If the processed paths didn't exist, `self.process()` has been called at
        # this point (if it is defined) and has set `self.mydata`. If the cache
        # DID exist, `process()` was skipped and the attribute was never created,
        # so guard with getattr instead of reading a possibly-missing attribute:
        if getattr(self, "mydata", None) is None:
            self.mydata = torch.load(self.processed_paths[0])
        # if you didn't define `process()`, this is where you would unconditionally
        # load your data.

    def len(self) -> int:
        """Return the number of frames in the dataset."""
        # Placeholder: a real implementation returns the true frame count.
        return 42

    @property
    def raw_file_names(self) -> List[str]:
        """Return a list of filenames for the raw data.

        Need to be simple filenames to be looked for in `self.raw_dir`.
        """
        return ["data.dat"]

    @property
    def raw_dir(self) -> str:
        """Directory in which the raw files are searched for."""
        return "/path/to/dataset-folder/"

    @property
    def processed_file_names(self) -> List[str]:
        """Like `self.raw_file_names`, but for the files generated by `self.process()`.

        Should not be paths, just file names. These will be stored in
        `self.processed_dir`, which is set by NequIP in `AtomicDataset` based on
        `self.root` and a hash of the dataset options provided to `__init__`.
        """
        return ["processed-data.pth"]

    # def download(self):
    #     """Optional method to download raw data before preprocessing if the `raw_paths` do not exist."""
    #     pass

    def process(self):
        """Load raw data, pre-process it, and cache the result to disk."""
        # load things from the raw data:
        # whatever is appropriate for your format
        data = np.load(self.raw_dir + "/" + self.raw_file_names[0])
        # if any pre-processing is necessary, do it and cache the results to
        # `self.processed_paths` as you defined above:
        with atomic_write(self.processed_paths[0], binary=True) as f:
            # e.g., anything that takes a file `f` will work
            torch.save(data, f)
        # ^ use atomic writes to avoid race conditions between
        # different trainings that use the same dataset
        # since those separate trainings should all produce the same results,
        # it doesn't matter if they overwrite each other's cached
        # datasets. It only matters that they don't simultaneously try
        # to write the _same_ file, corrupting it.
        logging.info("Cached processed data to disk")
        # optionally, save the processed data on the Dataset object
        # to avoid a roundtrip from disk in `__init__` (see above)
        self.mydata = data

    def get(self, idx: int) -> AtomicData:
        """Return the data frame with a given index as an `AtomicData` object."""
        build_an_AtomicData_here = None
        return build_an_AtomicData_here

    def statistics(
        self,
        fields: List[Union[str, Callable]],
        modes: List[str],
        stride: int = 1,
        unbiased: bool = True,
        # NOTE: default changed from `{}` to `None` — a mutable default
        # argument is shared across calls; `None` is the safe sentinel.
        kwargs: Optional[Dict[str, dict]] = None,
    ) -> List[tuple]:
        """Optional method to compute statistics over an entire dataset.

        This must correctly handle `self._indices` for subsets!!!

        If not provided, options like `avg_num_neighbors: auto`,
        `per_species_rescale_scales: dataset_*`, and others that compute dataset
        statistics will not work. This only needs to support the statistics
        modes that are necessary for what you need to run (i.e. if you do not
        use `dataset_per_species_*` statistics, you do not need to implement
        them).

        See `AtomicInMemoryDataset` for full documentation and example
        implementation.
        """
        raise NotImplementedError