/
convert_gt.py
105 lines (97 loc) · 3.49 KB
/
convert_gt.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import argparse
import csv
import json
from pathlib import Path
from tqdm import tqdm
from data_preparation.loader import iam, vnondb
def _parse_args():
    """Parse command-line options for the ground-truth conversion script."""
    parser = argparse.ArgumentParser(
        description="Convert ground truth to be fully contained in JSON files"
    )
    parser.add_argument(
        "-d",
        "--data",
        dest="data",
        required=True,
        type=Path,
        help="Path to the handwriting data containing the strokes",
    )
    parser.add_argument(
        "-s",
        "--segmentations",
        dest="segmentations",
        required=True,
        type=Path,
        nargs="+",
        help="Path to the JSON file containing the segmentations",
    )
    parser.add_argument(
        "-t",
        "--type",
        dest="data_type",
        type=str,
        choices=["iam", "vnondb"],
        default="iam",
        help="Which type of data is given",
    )
    parser.add_argument(
        "-o",
        "--out-dir",
        dest="out_dir",
        required=True,
        type=Path,
        help="Output directory to save the JSON file with their ground truths",
    )
    return parser.parse_args()


def _load_drawings(data_path, data_type):
    """Load all drawings for the given dataset type, keyed by drawing key.

    Raises:
        ValueError: If ``data_type`` is not a supported dataset type.
            (Defensive — argparse ``choices`` already restricts the value.)
    """
    if data_type == "iam":
        drawings = iam.get_all_drawings(data_path)
    elif data_type == "vnondb":
        drawings = vnondb.get_all_drawings(data_path)
    else:
        raise ValueError(f"Data type {data_type} is not supported")
    return {d.key: d for d in drawings}


def _convert_segmentation(seg_path, drawings_dict, out_dir):
    """Convert one segmentation JSON file into per-drawing JSON ground truths.

    Writes an index TSV (``<dataset-name>.tsv``) listing one JSON file per
    drawing, each stored under ``<out_dir>/<dataset-name>/<key>.json``.
    Segmentation entries with no matching drawing are reported and skipped.

    Returns:
        Path: The directory the individual JSON files were written to.
    """
    with open(seg_path, "r", encoding="utf-8") as fd:
        segmentation = json.load(fd)
    seg_dir = out_dir / seg_path.stem
    seg_dir.mkdir(parents=True, exist_ok=True)
    # GT file that serves as an index given as <dataset-name>.tsv
    # The `with` block guarantees the TSV is closed even if conversion of an
    # individual entry raises (the original leaked the handle on error).
    with open(out_dir / f"{seg_path.stem}.tsv", "w", encoding="utf-8") as gt_fd:
        writer = csv.writer(
            gt_fd, delimiter="\t", quoting=csv.QUOTE_NONE, quotechar=None
        )
        for key, seg in segmentation.items():
            drawing = drawings_dict.get(key)
            if drawing is None:
                print(f"No drawing found for segmentation: {key} - SKIPPING")
                continue
            out_segmentation = dict(
                key=key,
                **seg,
                # For convenience (it's already present, but text is clearer)
                text=seg["ctc_spike_symbols"],
                # x_start, y_start, x_end, y_end
                bbox=[
                    drawing.bbox.left,
                    drawing.bbox.top,
                    drawing.bbox.right,
                    drawing.bbox.bottom,
                ],
                # Points have information about x, y, time, index, stroke
                points=[vars(p) for p in drawing.all_points()],
            )
            # Individual segmentation files go into a directory with the same name,
            # i.e. <dataset-name>/<file-name>.json
            single_seg_path = seg_dir / f"{drawing.key}.json"
            with open(single_seg_path, "w", encoding="utf-8") as fd:
                json.dump(out_segmentation, fd, indent=2)
            writer.writerow([single_seg_path.relative_to(out_dir)])
    return seg_dir


def main():
    """Convert ground-truth segmentations into self-contained JSON files.

    Loads the stroke drawings for the chosen dataset, then for every given
    segmentation JSON file merges each entry with its drawing's bounding box
    and point data, writing one JSON file per drawing plus a TSV index per
    segmentation file into the output directory.
    """
    options = _parse_args()
    drawings_dict = _load_drawings(options.data, options.data_type)
    options.out_dir.mkdir(parents=True, exist_ok=True)
    pbar = tqdm(total=len(options.segmentations), leave=False, dynamic_ncols=True)
    for seg_path in options.segmentations:
        pbar.set_description(desc=f"Converting {seg_path}")
        seg_dir = _convert_segmentation(seg_path, drawings_dict, options.out_dir)
        pbar.write(f"✔ {seg_path} -> {seg_dir}")
        pbar.update()
    pbar.close()
# Run the conversion only when executed as a script, not on import.
if __name__ == "__main__":
    main()