# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Build and load tensorFlow dataset Record wrapper"""
import copy
import math
import json
import os
from collections import defaultdict
import shutil
import sys
from time import time
from typing_extensions import TypeAlias
from typing import Dict, List, Optional, Union, Set, Tuple
from pathlib import Path
import pprint
from tabulate import tabulate
from termcolor import cprint
from tqdm.auto import tqdm
import numpy as np
import semver
import tensorflow as tf
import scaaml
from scaaml.utils import bytelist_to_hex
from scaaml.io.spell_check import find_misspellings
import scaaml.io.utils as siutils
from .shard import Shard
from .errors import DatasetExistsError
# Prevent importing Literal with older versions.
if sys.version_info >= (3, 8):
from typing import Literal
class Dataset():
"""Dataset class."""
# Valid split values (used also as directory names).
# Define split type, keep compatibility (Literal was introduced in
# Python3.8).
if sys.version_info >= (3, 8):
SPLIT_T: TypeAlias = Literal["train", "test", "holdout"]
else:
SPLIT_T: TypeAlias = str
TRAIN_SPLIT: SPLIT_T = "train"
TEST_SPLIT: SPLIT_T = "test"
HOLDOUT_SPLIT: SPLIT_T = "holdout"
SPLITS: Tuple[SPLIT_T, SPLIT_T,
SPLIT_T] = (TRAIN_SPLIT, TEST_SPLIT, HOLDOUT_SPLIT)
# Largest possible part number.
MAX_PART_NUMBER = 10
def __init__(
self,
root_path: str,
shortname: str,
architecture: str,
implementation: str,
algorithm: str,
version: int,
firmware_sha256: str,
description: str,
examples_per_shard: int,
measurements_info: Dict,
attack_points_info: Dict,
url: str,
firmware_url: str = "",
paper_url: str = "",
licence: str = "https://creativecommons.org/licenses/by/4.0/",
compression: str = "GZIP",
shards_list: Optional[Dict[str, List]] = None,
keys_per_group: Optional[Dict[str, Dict[int, int]]] = None,
keys_per_split: Optional[Dict[str, int]] = None,
examples_per_group: Optional[Dict[str, Dict[int, int]]] = None,
examples_per_split: Optional[Dict[str, int]] = None,
capture_info: Optional[dict] = None,
min_values: Optional[Dict[str, float]] = None,
max_values: Optional[Dict[str, float]] = None,
from_config: bool = False,
verbose: bool = True,
) -> None:
"""Class for saving and loading a database.
Args:
url: Where to download this dataset.
firmware_url: Where to download the firmware used during capture.
paper_url: Where to find the published paper.
licence: URL or the whole licence the dataset is published under.
from_config: This Dataset object has been created from a saved
config; root_path thus points to what should be self.path. When
True, self.path is set to root_path and self.root_path to the parent
of self.path. In this case it does not necessarily hold that
self.path.name == self.slug (the directory could have been renamed).
verbose: If True, print the dataset path.
Raises:
ValueError: If firmware_sha256 or firmware_url evaluates to False.
DatasetExistsError: If creating this object would overwrite the
corresponding config file.
"""
# Do not allow mutable default parameters.
if capture_info is None:
capture_info = {}
if min_values is None:
min_values = {}
if max_values is None:
max_values = {}
self.shortname = shortname
self.architecture = architecture
self.implementation = implementation
self.algorithm = algorithm
self.version = version
self.compression = compression
self.firmware_sha256 = firmware_sha256
self.description = description
self.url = url
self.firmware_url = firmware_url
self.paper_url = paper_url
self.licence = licence
self.capture_info = capture_info or {}
self.measurements_info = measurements_info
self.attack_points_info = attack_points_info
if not self.firmware_sha256:
raise ValueError("Firmware hash is required")
if not self.firmware_url:
raise ValueError("Firmware URL is required")
self.slug = (f"{shortname}_{algorithm}_{architecture}_"
f"v{implementation}_{version}")
if from_config:
self.path = Path(root_path)
self.root_path = str(self.path.parent)
else:
self.root_path = root_path
self.path = Path(self.root_path) / self.slug
# create directory -- check if it's empty
if Dataset._get_config_path(self.path).exists():
raise DatasetExistsError(dataset_path=self.path)
else:
# create path if needed
self.path.mkdir(parents=True)
Path(self.path / Dataset.TRAIN_SPLIT).mkdir()
Path(self.path / Dataset.TEST_SPLIT).mkdir()
Path(self.path / Dataset.HOLDOUT_SPLIT).mkdir()
if verbose:
cprint(f"Dataset path: {self.path}", "green")
# current shard tracking
self.shard_key: Optional[str] = None
self.prev_shard_key = None # track key change for counting
self.shard_path: Optional[str] = None
self.shard_split: Optional[str] = None
self.shard_part: Optional[int] = None
self.shard_relative_path: Optional[str] = None # for the shardlist
self.curr_shard: Optional[Shard] = None # current shard object
# [counters] - must be passed as param to allow reload.
# shards_list[split] is a list of shard info dictionaries (where split
# in Dataset.SPLITS).
self.shards_list = siutils.ddict(value=shards_list,
levels=1,
type_var=list)
# keys counting
# keys_per_group[split][group_id] contains the number (int) of keys
# belonging to the group (group_id is int)
self.keys_per_group = siutils.ddict(value=keys_per_group,
levels=2,
type_var=int)
self.keys_per_split = siutils.ddict(value=keys_per_split,
levels=1,
type_var=int)
# examples counting
# examples_per_group[split][gid] = cnt
self.examples_per_group = siutils.ddict(value=examples_per_group,
levels=2,
type_var=int)
self.examples_per_split = siutils.ddict(value=examples_per_split,
levels=1,
type_var=int)
self.examples_per_shard = examples_per_shard
# traces extreme values
self.min_values = min_values or {}
self.max_values = max_values or {}
for k in measurements_info.keys():
# init only if not existing
if k not in self.min_values:
self.min_values[k] = math.inf
self.max_values[k] = 0.
# write config if needed
if not from_config:
self._write_config()
@staticmethod
def get_dataset(*args, **kwargs):
"""Convenience method for getting a Dataset either by creating a new
dataset using the Dataset constructor or by calling Dataset.from_config.
Args: Same as scaaml.io.Dataset.__init__
Returns: A scaaml.io.Dataset object.
Raises: ValueError if the dataset version is higher than the scaaml
module used (via Dataset.from_config).
"""
try:
return Dataset(*args, **kwargs)
except DatasetExistsError as err:
return Dataset.from_config(dataset_path=err.dataset_path)
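# Usage sketch (argument values below are hypothetical, not taken from this
# file): get_dataset either creates a fresh dataset or reopens an existing
# one from its saved config.
#
#     ds = Dataset.get_dataset(
#         root_path="/tmp/datasets",
#         shortname="tinyaes",
#         architecture="cm4",
#         implementation="mbedtls",
#         algorithm="aes128",
#         version=1,
#         firmware_sha256="abc...",
#         firmware_url="https://example.com/firmware.bin",
#         description="example capture",
#         url="https://example.com/dataset",
#         examples_per_shard=256,
#         measurements_info={"trace1": {"type": "power", "len": 1024}},
#         attack_points_info={"key": {"len": 16, "max_val": 256}},
#     )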
@staticmethod
def _shard_name(shard_group: int, shard_key: str, shard_part: int) -> str:
"""Return filename of the shard. When updating this method also update
Dataset._shard_info_from_name.
Args:
shard_group: The group this shard belongs to.
shard_key: The key contained in this shard (hex encoded).
shard_part: The part this shard belongs to.
Returns: Lowercase filename of the shard (including .tfrec filetype).
"""
sname = f"{shard_group}_{shard_key}_{shard_part}.tfrec"
return sname.lower()
@staticmethod
def _shard_info_from_name(shard_name: str) -> Dict[str, Union[int, str]]:
"""Inverse of Dataset._shard_name. This method is used by
Dataset.cleanup_shards to count how many shards per group or part are
left.
Args:
shard_name: The filename of the shard as returned by
Dataset._shard_name or with a parent directory.
Returns: A dictionary representation of Dataset._shard_name kwargs.
"""
for dir_separator in ["\\", "/"]:
if dir_separator in shard_name:
shard_name = shard_name.split(dir_separator)[-1]
parts = shard_name.split("_")
kwargs: Dict[str, Union[int, str]] = {}
kwargs["shard_group"] = int(parts[0])
kwargs["shard_key"] = parts[1]
kwargs["shard_part"] = int(parts[2].split(".")[0])
return kwargs
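# Round-trip sketch (values are illustrative): _shard_info_from_name is the
# inverse of _shard_name, so the following holds:
#
#     name = Dataset._shard_name(shard_group=1, shard_key="0ABC", shard_part=2)
#     # name == "1_0abc_2.tfrec"
#     info = Dataset._shard_info_from_name(f"train/{name}")
#     # info == {"shard_group": 1, "shard_key": "0abc", "shard_part": 2}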
def new_shard(self, key: list, part: int, group: int, split: str,
chip_id: int):
"""Initiate a new key
Args:
key: the key that was used to create the measurements.
part: Indicate which part of a given key set of catpure this
shard represent. Capture are splitted into parts to easily
allow to restrict the number of traces used per key.
group: logical group the shard belong to. For example,
on AES a group represent a collection of shard that have distinct
byte values. It allows to balance the diversity of keys when using
a subset of the dataset.
split: the split the shard belongs to {train, test, holdout}
chip_id: indicate which chip was used for collecting the traces.
"""
# finalize previous shard if needed
if self.curr_shard:
self.close_shard()
if split not in Dataset.SPLITS:
raise ValueError(f"Invalid split, must be in {Dataset.SPLITS}")
if part < 0 or part > self.MAX_PART_NUMBER:
raise ValueError(f"Invalid part value -- must be in "
f"[0, Dataset.MAX_PART_NUMBER] "
f"(that is [0, {self.MAX_PART_NUMBER}]).")
self.shard_split = split
self.shard_part = part
self.shard_group = group
self.shard_key = bytelist_to_hex(key, spacer="")
self.shard_chip_id = chip_id
# shard name
fname = Dataset._shard_name(self.shard_group, self.shard_key,
self.shard_part)
self.shard_relative_path = f"{split}/{fname}"
self.shard_path = str(self.path / self.shard_relative_path)
# new shard
self.curr_shard = Shard(self.shard_path,
attack_points_info=self.attack_points_info,
measurements_info=self.measurements_info,
compression=self.compression)
def write_example(self, attack_points: Dict, measurement: Dict):
assert self.curr_shard is not None
self.curr_shard.write(attack_points, measurement)
def close_shard(self):
# close the shard
stats = self.curr_shard.close()
if stats["examples"] != self.examples_per_shard:
cprint(
f"This shard contains {stats['examples']}, expected "
f"{self.examples_per_shard}", "red")
# update min/max values
for k, v in stats["min_values"].items():
self.min_values[k] = min(self.min_values[k], v)
for k, v in stats["max_values"].items():
self.max_values[k] = max(self.max_values[k], v)
# update key stats only if key changed
if self.shard_key != self.prev_shard_key:
self.keys_per_split[self.shard_split] += 1
self.keys_per_group[self.shard_split][self.shard_group] += 1
self.prev_shard_key = self.shard_key
self.examples_per_split[self.shard_split] += stats["examples"]
self.examples_per_group[self.shard_split][
self.shard_group] += stats["examples"]
# record in shardlist
self.shards_list[self.shard_split].append({
"path": str(self.shard_relative_path),
"examples": stats["examples"],
"size": os.stat(self.shard_path).st_size,
"sha256": siutils.sha256sum(self.shard_path).lower(),
"group": self.shard_group,
"key": self.shard_key,
"part": self.shard_part,
"chip_id": self.shard_chip_id
})
# update config
self._write_config()
self.curr_shard = None
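# Capture-loop sketch (capture_trace is a hypothetical helper, not part of
# this module): shards are written by pairing new_shard with repeated
# write_example calls and a final close_shard.
#
#     ds.new_shard(key=key_bytes,  # list of ints, e.g. 16 AES key bytes
#                  part=0, group=0,
#                  split=Dataset.TRAIN_SPLIT, chip_id=0)
#     for _ in range(ds.examples_per_shard):
#         attack_points, measurement = capture_trace()  # hypothetical
#         ds.write_example(attack_points, measurement)
#     ds.close_shard()  # updates counters, shard list, and the config file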
@staticmethod
def download(url: str):
"Download dataset from a given url"
raise NotImplementedError("implement me using keras dl mechanism")
@staticmethod
def as_tfdataset(dataset_path: str,
split: str,
attack_points: List[Dict[str, Union[str, int]]],
traces: Union[List[str], str],
shards: Optional[int] = None,
parts: Optional[Union[List[int], int]] = None,
trace_start: int = 0,
trace_len: Optional[int] = None,
batch_size: int = 32,
prefetch: int = 4,
file_parallelism: Optional[int] = os.cpu_count(),
parallelism: Optional[int] = os.cpu_count(),
shuffle: int = 1000) -> Tuple[tf.data.Dataset, Dict, Dict]:
"""Dataset as a tf.data.Dataset.
Args:
dataset_path (str): The root path of the dataset.
split (str): Split, see Dataset.SPLITS.
attack_points (List[Dict[str, Union[str, int]]]): Attack point
information. Contains the attack point name, index, and type. For
example:
```python
[
{ "name": "key", "index": 1, "type": "byte" },
{ "name": "sub_bytes_out", "index": 0, "type": "byte" },
{ "name": "sub_bytes_out", "index": 1, "type": "byte" },
{ "name": "sub_bytes_out", "index": 2, "type": "byte" },
]
```
traces (Union[List[str], str]): Either a single trace name or a list
of trace names.
shards (Optional[int]): If specified limits the dataset to the first
`shards` shards.
parts (Optional[Union[List[int], int]]): Not implemented.
trace_start (int): Skip this many first points of each trace.
trace_len (Optional[int]): Return trace of this length (more
formally of length min(trace_len, original length - trace_start)).
If None then trace[trace_start:] is used.
batch_size (int): Number of examples in a single batch.
prefetch (int): Prefetch this many batches.
file_parallelism (Optional[int]): IO parallelism.
parallelism (Optional[int]): Parallelism of trace decoding and
processing.
shuffle (int): How many examples should be shuffled across shards
(note that shards are shuffled by default).
FIXME: restrict shards to a specific part if parts exist.
"""
if parts:
raise NotImplementedError("Implement part filtering")
# boxing
if isinstance(traces, str):
traces = [traces]
# loading info
dpath = Path(dataset_path)
dataset = Dataset.from_config(dataset_path)
if split not in dataset.keys_per_split:
raise ValueError("Unknown split -- see Dataset.summary() for list")
# TF_FEATURES construction: must contain all features and be global
tf_features = {} # what is decoded
for name, ipt in dataset.measurements_info.items():
tf_features[name] = tf.io.FixedLenFeature([ipt["len"]], tf.float32)
for name, ap in dataset.attack_points_info.items():
tf_features[name] = tf.io.FixedLenFeature([ap["len"]], tf.int64)
# decoding function
def from_tfrecord(tfrecord):
rec = tf.io.parse_single_example(tfrecord, tf_features)
return rec
# inputs construction
inputs = {} # model inputs
for name in traces:
ipt = dataset.measurements_info[name]
inputs[name] = ipt
inputs[name]["min"] = tf.constant(dataset.min_values[name])
inputs[name]["max"] = tf.constant(dataset.max_values[name])
delta = tf.constant(inputs[name]["max"] - inputs[name]["min"])
inputs[name]["delta"] = delta
# output construction
outputs = {} # model outputs
for attack_point in attack_points:
ap_name = attack_point["name"]
ap_index = attack_point["index"]
ap_type = attack_point["type"]
full_name = f"{ap_name}_{ap_index}"
# Add attack point info (len, max_val).
outputs[full_name] = dataset.attack_points_info[ap_name]
# Set the attack point name (keep backwards compatibility).
outputs[full_name]["ap"] = ap_name
# Set the byte/bit index (keep backwards compatibility).
outputs[full_name]["byte"] = ap_index
# Set the type of the attack point.
outputs[full_name]["type"] = ap_type
# processing function
# @tf.function
def process_record(rec):
"process the tf record to get it ready for learning"
x = {}
# normalize the traces
for name, data in inputs.items():
trace = rec[name]
# truncate if needed
if trace_start:
trace = trace[trace_start:]
if trace_len:
trace = trace[:trace_len]
# rescale
# trace = 2 * ((trace - data["min"]) / (data["delta"])) - 1
# reshape
# trace = tf.reshape(trace, (reshaped_trace_len, step_size))
# assign
x[name] = trace
inputs[name]["shape"] = trace.shape # (trace_len - trace_start)
# Encode the output for each ap/byte
y = {}
for name, data in outputs.items():
max_val = data["max_val"]
if max_val == 2:
# Binary classification.
v = rec[data["ap"]][data["byte"]]
else:
# Multiple classes classification.
v = tf.one_hot(rec[data["ap"]][data["byte"]], max_val)
y[name] = v
return (x, y)
# collect and truncate the shard list of a given split
# this is done before anything else to allow downloading only the
# first n shards
shards_list = dataset.shards_list[split]
if shards:
shards_list = shards_list[:shards]
shards_paths = [str(dpath / s["path"]) for s in shards_list]
num_shards = len(shards_paths)
# print(shards_paths)
# dataset creation
# with tf.device("/cpu:0"):
# shuffle the shard order
ds = tf.data.Dataset.from_tensor_slices(shards_paths)
ds = ds.repeat()
# shuffle shard order
ds = ds.shuffle(num_shards)
# This is the tricky part: we are using the interleave function to
# do the sampling as requested by the user. This is not the
# standard use of the function or an obvious way to do it, but
# it is by far the fastest and most compatible way to do so;
# for once we favor those factors over readability.
# deterministic=False is not an error, it is what allows us to
# create random batches.
ds = ds.interleave(
lambda x: tf.data.TFRecordDataset(
x, compression_type=dataset.compression), # noqa
cycle_length=file_parallelism,
block_length=1,
num_parallel_calls=file_parallelism,
deterministic=False)
# decode to records
ds = ds.map(from_tfrecord, num_parallel_calls=parallelism)
# process them
ds = ds.map(process_record, num_parallel_calls=parallelism)
# randomize only if > 0 -- no shuffle in test/validation
if shuffle:
ds = ds.shuffle(shuffle)
# batching with repeat
ds = ds.repeat()
ds = ds.batch(batch_size)
ds = ds.prefetch(prefetch)
return ds, inputs, outputs
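# Training-input sketch (the path and attack point choices are illustrative):
# as_tfdataset returns the batched tf.data.Dataset together with the input
# and output descriptions that can be used to build a model.
#
#     train_ds, inputs, outputs = Dataset.as_tfdataset(
#         dataset_path="/tmp/datasets/tinyaes_aes128_cm4_vmbedtls_1",
#         split=Dataset.TRAIN_SPLIT,
#         attack_points=[{"name": "key", "index": 0, "type": "byte"}],
#         traces="trace1",
#         batch_size=64,
#         shuffle=1000,
#     )
#     # train_ds yields (x, y) batches; x["trace1"] holds the traces and
#     # y["key_0"] holds one-hot (or binary) labels.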
@staticmethod
def summary(dataset_path):
"""Print a summary of the dataset"""
lst = [
"shortname", "description", "url", "architecture", "implementation",
"algorithm", "version", "compression"
]
conf_path = Dataset._get_config_path(dataset_path)
config = Dataset._load_config(conf_path)
cprint("[Dataset Summary]", "cyan")
cprint("Info", "yellow")
print(tabulate([[k, config.get(k, "")] for k in lst]))
cprint("\nAttack Points", "yellow")
d = [[k, v["len"], v["max_val"]]
for k, v in config["attack_points_info"].items()]
print(tabulate(d, headers=["ap", "len", "max_val"]))
cprint("\nMeasurements", "magenta")
d = [[k, v["type"], v["len"]]
for k, v in config["measurements_info"].items()]
print(tabulate(d, headers=["name", "type", "len"]))
cprint("\nContent", "green")
d = []
for split in config["keys_per_split"].keys():
d.append([
split,
len(config["shards_list"][split]),
config["keys_per_split"][split],
config["examples_per_split"][split],
])
print(tabulate(d, ["split", "num_shards", "num_keys", "num_examples"]))
@staticmethod
def inspect(dataset_path,
split: SPLIT_T,
shard_id: int,
num_example: int,
verbose: bool = True):
"""Display the content of a given shard.
Args:
dataset_path: Root path to the dataset.
split: The split to inspect.
shard_id: Index into the shards_list.
num_example: How many examples to return. If -1 or larger than
examples_per_shard, all examples are taken.
verbose: Print debugging output to stdout.
Returns: tf TakeDataset object.
"""
conf_path = Dataset._get_config_path(dataset_path)
config = Dataset._load_config(conf_path)
shard_path = Path(
dataset_path) / config["shards_list"][split][shard_id]["path"]
if verbose:
cprint(f"Reading shard {shard_path}", "cyan")
s = Shard(str(shard_path),
attack_points_info=config["attack_points_info"],
measurements_info=config["measurements_info"],
compression=config["compression"])
data = s.read(num=num_example)
if verbose:
print(data)
return data
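# Inspection sketch (the path is illustrative): summary prints the dataset
# metadata tables, inspect reads back examples from a single shard.
#
#     Dataset.summary("/tmp/datasets/tinyaes_aes128_cm4_vmbedtls_1")
#     examples = Dataset.inspect(
#         "/tmp/datasets/tinyaes_aes128_cm4_vmbedtls_1",
#         split=Dataset.TEST_SPLIT,
#         shard_id=0,
#         num_example=2,
#     )
#     for example in examples.as_numpy_iterator():
#         print(example.keys())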
def check(self,
deep_check: bool = True,
show_progressbar: bool = True,
key_ap: str = "key"):
"""Check the dataset integrity. Check integrity of metadata in config
and also that no key from the train is in the test.
Args:
deep_check: When checking that keys in test and train splits are
disjoint inspect train shards (set to True if a single train shard
may contain multiple different keys).
show_progressbar: Use tqdm to show a progressbar for different checks.
key_ap: The attack point that is checked for when checking
disjointness of splits.
Raises: ValueError if the dataset is inconsistent.
"""
if key_ap not in self.attack_points_info:
raise ValueError(f"{key_ap} is not an attack point.")
if show_progressbar:
pbar = tqdm
else:
# Redefine tqdm to the identity function returning the first unnamed
# parameter.
pbar = lambda *args, **kwargs: args[0] # pylint: disable=C3001
Dataset._check_metadata(config=self.get_config_dictionary())
Dataset._check_sha256sums(shards_list=self.shards_list,
dpath=Path(self.path),
pbar=pbar)
# Check shard metadata
for slist in self.shards_list.values():
for shard_info in slist:
Dataset._check_shard_metadata(shard_info=shard_info,
dataset_path=self.path)
# Ensure that no keys in the train split are present in the test split.
has_test: bool = Dataset.TEST_SPLIT in self.examples_per_split
has_train: bool = Dataset.TRAIN_SPLIT in self.examples_per_split
if has_test and has_train:
self._check_disjoint_keys(pbar=pbar,
key_ap=key_ap,
deep_check=deep_check)
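# Integrity-check sketch (the path is illustrative): check() raises
# ValueError on the first inconsistency; a quick run can skip parsing every
# train shard by setting deep_check=False.
#
#     ds = Dataset.from_config("/tmp/datasets/tinyaes_aes128_cm4_vmbedtls_1")
#     ds.check(deep_check=False, show_progressbar=False)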
def _check_disjoint_keys(self, pbar, key_ap: str, deep_check: bool = True):
"""Check that no key in the train split is present in the test split.
Args:
pbar: Either tqdm.tqdm or an identity function (in order not to
print).
key_ap: The attack point that is checked for when checking
disjointness of splits.
deep_check: When checking that keys in test and train splits are
disjoint inspect train shards (set to True if a single train shard
may contain multiple different keys).
Raises: ValueError if some key from train is present in test.
"""
seen_keys = set()
for i in range(len(self.shards_list[Dataset.TEST_SPLIT])):
for example in Dataset.inspect(dataset_path=self.path,
split=Dataset.TEST_SPLIT,
shard_id=i,
num_example=self.examples_per_shard,
verbose=False).as_numpy_iterator():
seen_keys.add(example[key_ap].astype(np.uint8).tobytes())
if deep_check:
Dataset._deep_check(
seen_keys=seen_keys,
dpath=self.path,
train_shards=self.shards_list[Dataset.TRAIN_SPLIT],
pbar=pbar,
examples_per_shard=self.examples_per_shard,
key_ap=key_ap)
else:
Dataset._shallow_check(
seen_keys=seen_keys,
train_shards=self.shards_list[Dataset.TRAIN_SPLIT],
pbar=pbar)
@staticmethod
def _check_sha256sums(shards_list, dpath: Path, pbar):
"""Check the metadata of this dataset.
Args:
shards_list: Dictionary with information about each shard.
Use get_config_dictionary()["shards_list"]
dpath: Root path of the dataset.
pbar: Either tqdm.tqdm or an identity function (in order not to
print).
Raises: ValueError if some hash does not match.
"""
for split, slist in shards_list.items():
for sinfo in pbar(slist, desc=f"Checking sha for {split}"):
shard_path = dpath / sinfo["path"]
sha_hash = siutils.sha256sum(shard_path)
if sha_hash != sinfo["sha256"]:
raise ValueError(sinfo["path"], "SHA256 mismatch")
@staticmethod
def _check_shard_metadata(shard_info: Dict, dataset_path: Path) -> None:
"""Checks shard metadata.
Args:
shard_info: Dictionary of the shard metadata.
dataset_path: Dataset path, so that we can check size of the shard
file.
Raises: ValueError if the metadata is inconsistent.
"""
# Check that only expected keys are present:
si_keys = {
"examples", # Checked by Dataset._check_metadata
"sha256", # Checked by Dataset._check_sha256sums
"path", # Checked by Dataset._check_sha256sums
"group", # Checked against path
"key", # Checked against path
"part", # Checked against path
"size", # Checked here
"chip_id", # Checked that it is a non-negative integer
}
if set(shard_info.keys()) != si_keys:
raise ValueError(f"Shard info keys are: {shard_info.keys()} "
f"expected: {si_keys}, in shard: {shard_info}")
# Check that the info corresponds to the filename:
file_info = Dataset._shard_info_from_name(shard_info["path"])
for key in ["group", "part"]:
# either shard_group or shard_part
shard_membership: str = f"shard_{key}"
if file_info[shard_membership] != shard_info[key]:
raise ValueError(f"{key} does not match filename, expected: "
f"{file_info[shard_membership]}, got: "
f"{shard_info[key]}, in shard: {shard_info}")
# Check key (in filename it is lower case, in info it is upper case)
if str(file_info["shard_key"]).lower() != shard_info["key"].lower():
raise ValueError(f"key does not match filename, expected: "
f"{file_info['shard_key']}, got: "
f"{shard_info['key']} (not case sensitive), in "
f"shard: {shard_info}")
# Check size of the file
size = os.stat(dataset_path / shard_info["path"]).st_size
if size != shard_info["size"]:
raise ValueError(f"Wrong size, got: {size}, expected: "
f"{shard_info['size']}, in shard: {shard_info}")
# Check chip_id is non-negative integer
chip_id = shard_info["chip_id"]
if not isinstance(chip_id, int) or chip_id < 0:
raise ValueError(f"Wrong chip_id, got: {chip_id}, of type: "
f"{type(chip_id)}, in shard: {shard_info}")
@staticmethod
def _check_metadata(config,
n_examples_in_each_shard_is_constant: bool = False):
"""Check the metadata of this dataset.
Args:
config: A dictionary representing the metadata.
n_examples_in_each_shard_is_constant: Check that each shard contains
exactly examples_per_shard examples.
Raises: ValueError if some metadata do not match.
"""
for split, expected_examples in config["examples_per_split"].items():
slist = config["shards_list"][split]
# checking we have the right number of shards
if len(slist) != expected_examples // config["examples_per_shard"]:
raise ValueError("Num shards in shard_list != "
"examples_per_split // examples_per_shard")
# Check that expected_examples is a multiple of examples_per_shard.
if expected_examples % config["examples_per_shard"]:
raise ValueError("expected_examples is not divisible by "
"examples_per_shard")
if expected_examples != sum(s["examples"] for s in slist):
raise ValueError(f"Mismatch in expected_examples, shards "
f"metadata do not agree in {split}.")
if n_examples_in_each_shard_is_constant:
# All shards have the same number of examples.
if any(s["examples"] != config["examples_per_shard"]
for s in slist):
raise ValueError(f"Not all shards in {split} contain the "
f"same number of examples.")
# Check examples_per_group sums to the right thing.
sum_examples_per_group = sum(
config["examples_per_group"][split].values())
if sum_examples_per_group != expected_examples:
raise ValueError(f"Wrong sum of examples_per_group in {split}")
# Check examples_per_group in individual groups.
# Dataset.check can be called either after creating a dataset (when
# all measurements are done) or after loading from a config. The
# JSON file-format only allows keys to be strings. When the dataset
# is created the group ids are integers, but when dataset is loaded
# they are strings. We check the case where all keys are strings.
examples_per_group: Dict[str, int] = defaultdict(int)
for shard in slist:
examples_per_group[str(shard["group"])] += shard["examples"]
examples_per_group_config = {
str(k): v
for k, v in config["examples_per_group"][split].items()
}
if examples_per_group != examples_per_group_config:
raise ValueError(f"Wrong examples_per_group in {split}")
actual_examples = 0
for sinfo in slist:
actual_examples += sinfo["examples"]
if sinfo["examples"] != config["examples_per_shard"]:
raise ValueError(f"Wrong number of examples, expected: "
f"{config['examples_per_shard']}, got: "
f"{sinfo['examples']}, in shard: {sinfo}")
if actual_examples != expected_examples:
raise ValueError("sum example don't match top_examples")
@staticmethod
def _shallow_check(seen_keys, train_shards, pbar):
"""Check just what is in self.shards_list info (do not parse all
shards).
Args:
seen_keys: Set of all keys that are present in the test split.
train_shards: Description of train shards
(self.shards_list[Dataset.TRAIN_SPLIT]).
pbar: Either tqdm.tqdm or an identity function (in order not to
print).
"""
for shard in pbar(train_shards, desc="Checking test key uniqueness"):
k = shard["key"].lower()
list_k = [int(k[2 * i:2 * i + 2], 16) for i in range(len(k) // 2)]
cur_key = np.array(list_k, dtype=np.uint8).tobytes()
if cur_key in seen_keys:
raise ValueError(
f"Duplicate key: {k} in test split, in {shard}")
@staticmethod
def _deep_check(seen_keys, dpath, train_shards, pbar,
examples_per_shard: int, key_ap: str):
"""Check all keys from all shards (parse all shards in the train split).
Args:
seen_keys: Set of all keys that are present in the test split.
dpath: Root path of this dataset.
train_shards: Description of train shards
(self.shards_list[Dataset.TRAIN_SPLIT]).
pbar: Either tqdm.tqdm or an identity function (in order not to
print).
examples_per_shard: Number of examples in each shard.
key_ap: The attack point that is checked for when checking
disjointness of splits.
"""
for i in pbar(range(len(train_shards)),
desc="Checking test key uniqueness"):
for example in Dataset.inspect(dataset_path=dpath,
split=Dataset.TRAIN_SPLIT,
shard_id=i,
num_example=examples_per_shard,
verbose=False).as_numpy_iterator():
cur_key = example[key_ap].astype(np.uint8).tobytes()
if cur_key in seen_keys:
raise ValueError(
f"Duplicate key: {cur_key} in test split, in "
f"{train_shards[i]}")
def get_config_dictionary(self):
"""Return dictionary of information about this dataset.
Raises: ValueError if saving this dictionary using json would cause
data loss. This can be caused by having different keys with the same
string representation:
d = {0: 1, "0": 2} # JSON key collision
l = json.loads(json.dumps(d))
assert l != d
Note that it is ok to have keys of a type other than string, since the
check is performed using Dataset._from_loaded_json.
"""
representation = {
"shortname": self.shortname,
"architecture": self.architecture,
"implementation": self.implementation,
"algorithm": self.algorithm,
"version": self.version,
"firmware_sha256": self.firmware_sha256,
"url": self.url,
"firmware_url": self.firmware_url,
"paper_url": self.paper_url,
"licence": self.licence,
"description": self.description,
"compression": self.compression,
"shards_list": self.shards_list,
"keys_per_group": self.keys_per_group,
"keys_per_split": self.keys_per_split,
"examples_per_group": self.examples_per_group,
"examples_per_shard": self.examples_per_shard,
"examples_per_split": self.examples_per_split,
"capture_info": self.capture_info,
"measurements_info": self.measurements_info,
"attack_points_info": self.attack_points_info,
"min_values": self.min_values,
"max_values": self.max_values,
# See scaaml.__version__ docstring for more information.
"scaaml_version": scaaml.__version__,
}
loaded = Dataset._from_loaded_json(
json.loads(json.dumps(representation)))
if loaded != representation:
pprint_file = self.path / f"info.{time()}.pprint"
pprint_file.write_text(pprint.pformat(representation))
raise ValueError(f"JSON representation causes data loss, saving "
f"into {pprint_file}")
return representation
@staticmethod
def _load_config(conf_path: Path) -> Dict:
"""Get config dictionary from a file. Use this function instead of an
json.loads, as this function returns correct types for group ids.
Args:
conf_path: Path object representing the dataset information (e.g.,
the return value of Dataset._get_config_path).
Returns: Dictionary representation of the Dataset.
"""
return Dataset._from_loaded_json(json.loads(conf_path.read_text()))
@staticmethod
def _from_loaded_json(loaded_dict: Dict) -> Dict:
"""Fix types in the datastructure loaded from JSON. Necessary as JSON
allows only string keys, but for instance group keys are integers in
Dataset.
Args:
loaded_dict: The datastructure returned by json.load on the info.json
file.
Returns: The same information with fixed types.
"""
fixed_dict = copy.deepcopy(loaded_dict)
find_misspellings(fixed_dict.keys()) # Check for misspellings of keys.
# Fix type of keys_per_group
fixed_dict["keys_per_group"] = {
split:
{int(group): n_examples for group, n_examples in keys_info.items()}
for split, keys_info in loaded_dict["keys_per_group"].items()
}
# Fix type of examples_per_group
fixed_dict["examples_per_group"] = {
split:
{int(group): n_examples for group, n_examples in ex_info.items()}
for split, ex_info in loaded_dict["examples_per_group"].items()
}
# Fix missing keys
if "licence" not in fixed_dict:
# Do not relicence
fixed_dict["licence"] = ""
for k in ["firmware_url", "paper_url"]:
if k not in fixed_dict:
fixed_dict[k] = ""
return fixed_dict
def _write_config(self):
"""Save configuration as json."""
with open(self._get_config_path(self.path), "w+",
encoding="utf-8") as f:
json.dump(self.get_config_dictionary(), f)
@staticmethod
def from_config(dataset_path: str, verbose: bool = True):
"""Load a dataset from a config file.
Args:
dataset_path: The path to the dataset.
verbose: Print config path and dataset path.