-
Notifications
You must be signed in to change notification settings - Fork 2.6k
/
info_utils.py
130 lines (93 loc) 路 4.93 KB
/
info_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import enum
import os
from typing import Optional
from huggingface_hub.utils import insecure_hashlib
from .. import config
from .logging import get_logger
# Module-level logger, named after this module via `__name__`.
logger = get_logger(__name__)
class VerificationMode(enum.Enum):
    """Enumeration selecting which verification checks are run.

    `BASIC_CHECKS` is the default mode: it performs only rudimentary checks so that
    generating/downloading a dataset for the first time is not slowed down.

    Available modes:

    | Mode                     | Verification checks                                                          |
    |--------------------------|------------------------------------------------------------------------------|
    | `ALL_CHECKS`             | Split checks, uniqueness of the keys yielded in case of the GeneratorBuilder |
    |                          | and the validity (number of files, checksums, etc.) of downloaded files      |
    | `BASIC_CHECKS` (default) | Same as `ALL_CHECKS` but without checking downloaded files                   |
    | `NO_CHECKS`              | None                                                                         |
    """

    ALL_CHECKS = "all_checks"
    BASIC_CHECKS = "basic_checks"
    NO_CHECKS = "no_checks"
class ChecksumVerificationException(Exception):
    """Base class for errors raised while verifying the checksums of downloaded files."""
class UnexpectedDownloadedFile(ChecksumVerificationException):
    """Raised when downloaded files were recorded that are absent from the expected checksums."""
class ExpectedMoreDownloadedFiles(ChecksumVerificationException):
    """Raised when files listed in the expected checksums were not recorded as downloaded."""
class NonMatchingChecksumError(ChecksumVerificationException):
    """Raised when the checksum of a downloaded file doesn't match the expected checksum."""
def verify_checksums(expected_checksums: Optional[dict], recorded_checksums: dict, verification_name=None):
    """Check that the recorded checksums agree with the expected ones.

    Args:
        expected_checksums (`dict`, *optional*):
            Expected checksum entries, keyed by URL. When `None`, nothing is verified:
            an info message is logged and the function returns.
        recorded_checksums (`dict`):
            Checksum entries actually recorded, keyed the same way.
        verification_name (`str`, *optional*):
            Label appended (as `" for <name>"`) to the mismatch error and the success log.

    Raises:
        ExpectedMoreDownloadedFiles: If some expected keys are missing from `recorded_checksums`.
        UnexpectedDownloadedFile: If `recorded_checksums` has keys that are not expected.
        NonMatchingChecksumError: If the entries of at least one shared key differ.
    """
    if expected_checksums is None:
        logger.info("Unable to verify checksums.")
        return

    expected_urls = set(expected_checksums)
    recorded_urls = set(recorded_checksums)

    # Key sets must match exactly before the individual entries are compared.
    missing = expected_urls - recorded_urls
    if missing:
        raise ExpectedMoreDownloadedFiles(str(missing))

    unexpected = recorded_urls - expected_urls
    if unexpected:
        raise UnexpectedDownloadedFile(str(unexpected))

    bad_urls = []
    for url, expected in expected_checksums.items():
        if expected != recorded_checksums[url]:
            bad_urls.append(url)

    if verification_name is None:
        for_verification_name = ""
    else:
        for_verification_name = " for " + verification_name

    if bad_urls:
        raise NonMatchingChecksumError(
            f"Checksums didn't match{for_verification_name}:\n"
            f"{bad_urls}\n"
            "Set `verification_mode='no_checks'` to skip checksums verification and ignore this error"
        )
    logger.info("All the checksums matched successfully" + for_verification_name)
class SplitsVerificationException(Exception):
    """Base class for errors raised while verifying dataset splits."""
class UnexpectedSplits(SplitsVerificationException):
    """Raised when some recorded splits are not among the expected splits."""
class ExpectedMoreSplits(SplitsVerificationException):
    """Raised when some expected splits are missing from the recorded splits."""
class NonMatchingSplitsSizesError(SplitsVerificationException):
    """Raised when the recorded size of a split doesn't match its expected size."""
def verify_splits(expected_splits: Optional[dict], recorded_splits: dict):
    """Check that the recorded splits agree with the expected ones.

    Args:
        expected_splits (`dict`, *optional*):
            Expected split entries, keyed by split name; each value must expose `num_examples`.
            When `None`, nothing is verified: an info message is logged and the function returns.
        recorded_splits (`dict`):
            Split entries actually recorded, keyed the same way.

    Raises:
        ExpectedMoreSplits: If some expected split names are missing from `recorded_splits`.
        UnexpectedSplits: If `recorded_splits` has split names that are not expected.
        NonMatchingSplitsSizesError: If `num_examples` differs for at least one split.
    """
    if expected_splits is None:
        logger.info("Unable to verify splits sizes.")
        return

    expected_names = set(expected_splits)
    recorded_names = set(recorded_splits)

    # Split names must match exactly before the sizes are compared.
    missing = expected_names - recorded_names
    if missing:
        raise ExpectedMoreSplits(str(missing))

    unexpected = recorded_names - expected_names
    if unexpected:
        raise UnexpectedSplits(str(unexpected))

    bad_splits = []
    for name, expected in expected_splits.items():
        recorded = recorded_splits[name]
        if expected.num_examples != recorded.num_examples:
            bad_splits.append({"expected": expected, "recorded": recorded})

    if bad_splits:
        raise NonMatchingSplitsSizesError(str(bad_splits))
    logger.info("All the splits matched successfully.")
def get_size_checksum_dict(path: str, record_checksum: bool = True) -> dict:
    """Return the size of a file and, optionally, its sha256 checksum.

    Args:
        path (`str`): Path of the file to inspect.
        record_checksum (`bool`, defaults to `True`):
            Whether to compute the sha256 hex digest. When `False`, the file is not read
            and the returned checksum is `None`.

    Returns:
        `dict`: `{"num_bytes": <size in bytes>, "checksum": <hex digest or None>}`.
    """
    checksum = None
    if record_checksum:
        hasher = insecure_hashlib.sha256()
        with open(path, "rb") as stream:
            # Feed the hash in 1 MiB chunks rather than reading the whole file at once.
            while True:
                chunk = stream.read(1 << 20)
                if not chunk:
                    break
                hasher.update(chunk)
        checksum = hasher.hexdigest()
    return {"num_bytes": os.path.getsize(path), "checksum": checksum}
def is_small_dataset(dataset_size):
    """Tell whether a dataset is smaller than `config.IN_MEMORY_MAX_SIZE`.

    Args:
        dataset_size (int): Dataset size in bytes.

    Returns:
        bool: `True` only when both `dataset_size` and `config.IN_MEMORY_MAX_SIZE` are
        truthy and `dataset_size` is strictly smaller than the limit; `False` otherwise.
    """
    # A missing or zero size is never considered small.
    if not dataset_size:
        return False
    max_size = config.IN_MEMORY_MAX_SIZE
    # A falsy limit yields `False` as well.
    if not max_size:
        return False
    return dataset_size < max_size