# script_preprocessing.py
import numpy as np
import os
import torch
from torch.nn.utils.rnn import pad_sequence
from module_utils import draw_strokes, save_svg2png
def main():
    """Preprocess the 'sheep market' sketch dataset.

    Loads the train/valid/test stroke sequences from ``sheep_market.npz``,
    renders a few randomly chosen sketches to PNG as a visual sanity check,
    zero-pads every sketch to the length of the longest one, and saves the
    padded per-split tensors to ``sheep_market_preprocessed.pt``.
    """
    data_file = 'sheep_market.npz'
    # allow_pickle is required: each split is an object array holding
    # variable-length [n_points x 3] stroke arrays.
    data_dict = np.load(data_file, allow_pickle=True, encoding='bytes')
    train_data = data_dict['train']
    train_end = len(train_data)
    validation_data = data_dict['valid']
    validation_end = train_end + len(validation_data)
    test_data = data_dict['test']
    all_data = np.concatenate((train_data, validation_data, test_data))

    # Render a handful of random sketches as PNGs for quick inspection.
    ex_sketch_idxs = np.random.randint(0, len(all_data), size=4)
    write_dir = "sketches"
    os.makedirs(write_dir, exist_ok=True)  # no-op if the directory exists
    for i, idx in enumerate(ex_sketch_idxs):
        cur_sketch = all_data[idx]
        svg_filepath = os.path.join(write_dir, f"real_ex_{i:04d}.svg")
        draw_strokes(cur_sketch, svg_filename=svg_filepath)
        # Convert SVG -> PNG, then drop the intermediate SVG.
        png_filepath = os.path.splitext(svg_filepath)[0] + ".png"
        save_svg2png(svg_filepath, png_filepath)
        os.remove(svg_filepath)

    # Pad every sketch to the global max length so each split becomes one
    # dense [n_sketches x max_len x 3] tensor (pad_sequence pads with 0).
    all_tensors = [torch.from_numpy(sketch) for sketch in all_data]
    padded_tensors = pad_sequence(all_tensors, batch_first=True)
    tensor_dict = {
        'train': padded_tensors[:train_end],
        'validation': padded_tensors[train_end:validation_end],
        'test': padded_tensors[validation_end:],
    }
    torch.save(tensor_dict, 'sheep_market_preprocessed.pt')
# Run the preprocessing pipeline only when executed as a script, not on import.
if __name__ == '__main__':
    main()