-
Notifications
You must be signed in to change notification settings - Fork 34
/
prepare_train_data.py
143 lines (125 loc) · 6.12 KB
/
prepare_train_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
from __future__ import division
import argparse
import scipy.misc
import numpy as np
from glob import glob
from joblib import Parallel, delayed
import os
# Command-line interface. NOTE: argparse `type=bool` is a well-known trap —
# bool() on any non-empty string (including "False") is True — so boolean
# flags go through an explicit parser that accepts the usual spellings.
def _str2bool(value):
    """Parse a CLI boolean: 'true'/'1'/'yes'/'t'/'y' (case-insensitive) -> True, else False."""
    return str(value).strip().lower() in ('true', '1', 'yes', 't', 'y')

parser = argparse.ArgumentParser()
parser.add_argument("--dataset_dir", type=str, required=True, help="where the dataset is stored")
parser.add_argument("--sparse_data_dir", type=str, default=None, help="sparse data directory")
parser.add_argument("--dataset_name", type=str, required=True, choices=["kitti_raw_eigen", "kitti_raw_stereo", "kitti_odom", "cityscapes"])
parser.add_argument("--dump_root", type=str, required=True, help="Where to dump the data")
parser.add_argument("--seq_length", type=int, required=True, help="Length of each training sequence")
parser.add_argument("--img_height", type=int, default=128, help="image height")
parser.add_argument("--img_width", type=int, default=416, help="image width")
parser.add_argument("--num_threads", type=int, default=4, help="number of threads to use")
parser.add_argument("--match_num", type=int, default=0, help="number of sampled match pairs")
# BUGFIX: was type=bool, which made e.g. `--skip_image False` evaluate True.
parser.add_argument("--skip_image", type=_str2bool, default=False, help="do not generate images")
parser.add_argument("--generate_test", type=_str2bool, default=False, help="generate test images")
args = parser.parse_args()
def concat_image_seq(seq):
    """Horizontally concatenate a sequence of equal-height images into one wide strip.

    Args:
        seq: iterable of arrays shaped (H, W) or (H, W, C) with matching H (and C).

    Returns:
        A single array of shape (H, sum(W_i)[, C]).

    Raises:
        ValueError: if `seq` is empty or heights/channels do not match
            (the previous manual loop raised UnboundLocalError on empty input).
    """
    # One C-level call instead of the original pairwise hstack loop,
    # which re-copied the accumulated strip on every iteration.
    return np.hstack(seq)
def dump_example(n, is_training):
    """Dump one example to disk: a horizontally concatenated image strip and,
    for training examples, a *_cam.txt file with intrinsics / pose / matches.

    Reads the module-level globals `args` and `data_loader` (set in main()).

    Args:
        n: example index within the loader's train or test split.
        is_training: if True, use the train split and also write camera metadata;
            otherwise use the test split and write only the image.
    """
    if is_training:
        frame_num = data_loader.num_train
        example = data_loader.get_train_example_with_idx(n)
    else:
        frame_num = data_loader.num_test
        example = data_loader.get_test_example_with_idx(n)
    # The loaders return the sentinel False for frames that cannot be used;
    # identity check avoids accidental matches from `==` comparisons.
    if example is False:
        return
    dump_dir = os.path.join(args.dump_root, example['folder_name'])
    # Safe under parallel workers; still raises if the path exists as a
    # non-directory, matching the original try/except OSError logic.
    os.makedirs(dump_dir, exist_ok=True)
    if n % 2000 == 0:
        print('Progress %d/%d....' % (n, frame_num))
    # Save the image strip (unless the caller only wants metadata).
    if not args.skip_image:
        image_seq = concat_image_seq(example['image_seq'])
        dump_img_file = os.path.join(dump_dir, '%s.jpg' % example['file_name'])
        # NOTE(review): scipy.misc.imsave was removed in SciPy >= 1.2; this
        # requires an old SciPy pin — consider porting to imageio/PIL.
        scipy.misc.imsave(dump_img_file, image_seq.astype(np.uint8))
    # Save camera info (training examples only).
    if is_training:
        intrinsics = example['intrinsics']
        fx = intrinsics[0, 0]
        fy = intrinsics[1, 1]
        cx = intrinsics[0, 2]
        cy = intrinsics[1, 2]
        dump_cam_file = os.path.join(dump_dir, '%s_cam.txt' % example['file_name'])
        with open(dump_cam_file, 'w') as f:
            # Row-major 3x3 intrinsics; skew terms written as literal zeros.
            f.write('%f,0.,%f,0.,%f,%f,0.,0.,1.\n' % (fx, cx, fy, cy))
            if 'pose' in example:
                # One comma-separated pose vector per line.
                for each_pose in example['pose']:
                    f.write(','.join(str(num) for num in each_pose) + '\n')
            if 'match' in example:
                # Each match row is 4 comma-separated values.
                for match in example['match']:
                    for i in range(match.shape[0]):
                        f.write(','.join(str(match[i, j]) for j in range(4)) + '\n')
def main():
    """Instantiate the dataset loader, dump all examples in parallel, then
    write a random ~90/10 train/val split to dump_root/train.txt and val.txt."""
    os.makedirs(args.dump_root, exist_ok=True)
    global data_loader
    # Loader imports stay local: each pulls in dataset-specific dependencies.
    # `choices` on --dataset_name guarantees exactly one branch is taken.
    if args.dataset_name == 'kitti_odom':
        from kitti.kitti_odom_loader import kitti_odom_loader
        data_loader = kitti_odom_loader(args.dataset_dir,
                                        args.sparse_data_dir,
                                        args.match_num,
                                        img_height=args.img_height,
                                        img_width=args.img_width,
                                        seq_length=args.seq_length)
    elif args.dataset_name == 'kitti_raw_eigen':
        from kitti.kitti_raw_loader import kitti_raw_loader
        data_loader = kitti_raw_loader(args.dataset_dir,
                                       split='eigen',
                                       match_num=args.match_num,
                                       img_height=args.img_height,
                                       img_width=args.img_width,
                                       seq_length=args.seq_length)
    elif args.dataset_name == 'kitti_raw_stereo':
        from kitti.kitti_raw_loader import kitti_raw_loader
        data_loader = kitti_raw_loader(args.dataset_dir,
                                       split='stereo',
                                       match_num=args.match_num,
                                       img_height=args.img_height,
                                       img_width=args.img_width,
                                       seq_length=args.seq_length)
    elif args.dataset_name == 'cityscapes':
        from cityscapes.cityscapes_loader import cityscapes_loader
        data_loader = cityscapes_loader(args.dataset_dir,
                                        split='train',
                                        match_num=args.match_num,
                                        img_height=args.img_height,
                                        img_width=args.img_width,
                                        seq_length=args.seq_length)
    if args.generate_test:
        Parallel(n_jobs=args.num_threads)(delayed(dump_example)(n, is_training=False) for n in range(data_loader.num_test))
    else:
        Parallel(n_jobs=args.num_threads)(delayed(dump_example)(n, is_training=True) for n in range(data_loader.num_train))
    # Split into train/val. Fixed seed keeps the split reproducible across runs.
    np.random.seed(8964)
    subfolders = os.listdir(args.dump_root)
    # BUGFIX: the split files were previously built with raw string
    # concatenation (args.dump_root + 'train.txt'), which writes outside
    # dump_root whenever it lacks a trailing separator — inconsistent with
    # the os.path.join used for the glob below.
    train_file = os.path.join(args.dump_root, 'train.txt')
    val_file = os.path.join(args.dump_root, 'val.txt')
    with open(train_file, 'w') as tf, open(val_file, 'w') as vf:
        for s in subfolders:
            if not os.path.isdir(os.path.join(args.dump_root, s)):
                continue
            imfiles = glob(os.path.join(args.dump_root, s, '*.jpg'))
            frame_ids = [os.path.basename(fi).split('.')[0] for fi in imfiles]
            for frame in frame_ids:
                # ~10% of frames go to validation, the rest to training.
                if np.random.random() < 0.1:
                    vf.write('%s %s\n' % (s, frame))
                else:
                    tf.write('%s %s\n' % (s, frame))


if __name__ == '__main__':
    main()