-
Notifications
You must be signed in to change notification settings - Fork 5
/
generate-train-val-splits
53 lines (43 loc) · 1.84 KB
/
generate-train-val-splits
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
#!/usr/bin/env python
import argparse
import os
import pathlib
def main(args):
    """Split an annotation CSV into training and validation CSV files.

    Every non-blank row of the input must consist of exactly six
    comma-separated fields (filename, x0, y0, x1, y1, category), and the
    basename of ``filename`` must contain exactly one underscore. The part
    after that underscore is used as the background name. All rows sharing
    a background name are written to the same split, so no background ends
    up in both the training and the validation file.

    Backgrounds are kept in the order in which they first appear in the
    input; the leading ``100 - args.s`` percent go to training and the
    rest to validation. The results are written next to the input file as
    ``<name>_train.csv`` and ``<name>_valid.csv`` (existing files are
    overwritten), where ``<name>`` is the input file name without its
    extension.

    Args:
        args: Object with two attributes: ``csv_datafile`` (a
            ``pathlib.Path`` to the input CSV) and ``s`` (percentage of
            backgrounds to split off for validation).

    Raises:
        ValueError: If ``args.s`` is not within 0..100, if a non-blank row
            does not have exactly six fields, or if a file basename does
            not contain exactly one underscore.
    """
    if not 0.0 <= args.s <= 100.0:
        raise ValueError(
            "validation percentage must be between 0 and 100, got {}".format(
                args.s))

    # In order to make sure that we do not add images originating from the
    # same background into both splits, group the annotation rows by the
    # background they belong to.
    background_to_lines = {}
    with args.csv_datafile.open() as file_handle:
        for line in file_handle:
            record = line.rstrip()
            if not record:
                # Tolerate blank lines (e.g. a trailing empty line).
                continue
            # The unpacking doubles as a check that the row has six fields.
            filename, _x0, _y0, _x1, _y1, _category = record.split(',')
            _, background_name = os.path.basename(filename).split('_')
            # The last row of a file may lack a line terminator. Add one so
            # that it cannot fuse with whatever row is written after it.
            if not line.endswith("\n"):
                line += "\n"
            background_to_lines.setdefault(background_name, []).append(line)

    # Generate splits. Here we silently assume that all background images
    # end up the same number of times in the final dataset, so splitting
    # the backgrounds by count also splits the rows in the same proportion.
    backgrounds = list(background_to_lines)
    split = int(round(len(backgrounds) * (100.0 - args.s) / 100.0))

    dataset_name, _ = os.path.splitext(args.csv_datafile.name)
    output_dir = args.csv_datafile.parent
    subsets = (
        ("train", backgrounds[:split]),
        ("valid", backgrounds[split:]),
    )
    for suffix, selected in subsets:
        output_path = output_dir.joinpath(
            "{}_{}.csv".format(dataset_name, suffix))
        with output_path.open(mode="w") as file_handle:
            for background in selected:
                file_handle.writelines(background_to_lines[background])
def _validation_percentage(value):
    """Convert a command-line string to a float within the range 0..100."""
    try:
        percentage = float(value)
    except ValueError:
        raise argparse.ArgumentTypeError(
            "invalid float value: {!r}".format(value)) from None
    # The chained comparison is also False for NaN, which is rejected too.
    if not 0.0 <= percentage <= 100.0:
        raise argparse.ArgumentTypeError(
            "percentage must be between 0 and 100, got {!r}".format(value))
    return percentage


def parse_args(argv=None):
    """Parse the command-line arguments of this script.

    Args:
        argv: Optional list of argument strings. When None (the default),
            argparse reads the arguments from ``sys.argv``.

    Returns:
        An ``argparse.Namespace`` with ``csv_datafile`` (a ``pathlib.Path``)
        and ``s`` (validation percentage as a float, 20.0 by default).

    Raises:
        SystemExit: Raised by argparse when the arguments are invalid,
            including a ``-s`` value outside 0..100.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("csv_datafile",
                        type=pathlib.Path,
                        help="Input dataset 'multiclass.csv' file.")
    parser.add_argument("-s",
                        type=_validation_percentage,
                        default=20.0,
                        help="Percentage of data to split off for validation")
    return parser.parse_args(argv)
# Script entry point: parse the command line, then generate the splits.
if __name__ == "__main__":
    cli_args = parse_args()
    main(cli_args)