-
Notifications
You must be signed in to change notification settings - Fork 9
/
data_loader.py
120 lines (97 loc) · 4.34 KB
/
data_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import torch
from torch.utils.data import Dataset, DataLoader
import os
import sys
import cv2
import numpy as np
import warnings
warnings.filterwarnings("ignore")
sys.path.append("../renderer/")
import nmr_test as nmr
import neural_renderer
class MyDataset(Dataset):
    """Dataset that composites a rendered mesh over recorded photos.

    Every entry of ``data_dir`` is opened with ``np.load`` and must provide the
    arrays ``'img'``, ``'veh_trans'`` and ``'cam_trans'`` (presumably ``.npz``
    archives). For each sample the mesh (``vertices``/``faces``) is rendered
    with ``nmr.NeuralRenderer`` from the pose derived from the transforms. The
    normalised rendering is then pasted into the resized photo wherever the
    sample's mask ``<mask_dir>/<sample stem>.png`` is non-zero.

    Args:
        data_dir: Directory of sample archives. Every directory entry is used.
        img_size: Side length, in pixels, of the square images produced.
        texture_size: Edge length T of the initial all-ones texture of shape
            ``(1, n_faces, T, T, T, 3)``.
        faces: Face array or tensor. ``faces.shape[0]`` is taken as the face count.
        vertices: Vertex array or tensor of the mesh.
        distence: Optional threshold on the *squared* camera-to-vehicle distance
            (row 0 of the adjusted transforms). Samples above it are skipped.
            ``None`` keeps every sample.
        mask_dir: Directory holding one ``.png`` mask per sample.
        ret_mask: Stored for API compatibility. It is not read by this class;
            the mask is always returned.

    NOTE(review): ``__getitem__`` creates CUDA tensors and mutates shared
    renderer state. Use it with ``num_workers=0`` (CUDA does not support
    forked DataLoader workers).
    """

    # Added to index 2 of the vehicle's row-0 transform before rendering and
    # distance checks. NOTE(review): presumably a z lift of the look-at point
    # above the vehicle origin — confirm units and intent.
    VEH_Z_OFFSET = 0.2

    def __init__(self, data_dir, img_size, texture_size, faces, vertices,
                 distence=None, mask_dir='', ret_mask=False):
        self.data_dir = data_dir
        # sorted(): os.listdir order is arbitrary. Sorting keeps the
        # index -> file mapping reproducible across runs and machines.
        self.files = [
            name
            for name in sorted(os.listdir(data_dir))
            if distence is None or self._within_distance(name, distence)
        ]
        print(len(self.files))
        self.img_size = img_size
        textures = np.ones(
            (1, faces.shape[0], texture_size, texture_size, texture_size, 3), 'float32')
        self.textures = torch.from_numpy(textures).cuda(device=0)
        # as_tensor accepts numpy arrays (sharing memory, like from_numpy)
        # as well as torch tensors, e.g. from a torch-based obj loader.
        self.faces_var = torch.as_tensor(faces[None, :, :]).cuda(device=0)
        self.vertices_var = torch.as_tensor(vertices[None, :, :]).cuda(device=0)
        self.mask_renderer = nmr.NeuralRenderer(img_size=self.img_size).cuda()
        self.mask_dir = mask_dir
        self.ret_mask = ret_mask

    def set_textures(self, textures):
        """Replace the texture tensor used by subsequent ``__getitem__`` renders.

        NOTE(review): the caller must supply a tensor on the renderer's device
        with the same layout as the initial ``(1, n_faces, T, T, T, 3)`` texture.
        """
        self.textures = textures

    def __getitem__(self, index):
        """Return ``(index, total_img, imgs_pred, mask)`` for sample ``index``.

        ``imgs_pred`` is the rendering divided by its maximum (left unscaled
        if that maximum is 0). ``mask`` is a float32 ``(img_size, img_size)``
        tensor of 0/1 values. ``total_img`` is the resized photo with
        ``255 * imgs_pred`` substituted where ``mask`` is 1. The leading batch
        axis is squeezed from ``total_img`` and ``imgs_pred``.

        Raises:
            FileNotFoundError: if the sample's mask image cannot be read.
        """
        name = self.files[index]
        with np.load(os.path.join(self.data_dir, name)) as data:
            img = data['img']
            cam_trans, veh_trans = self._world_transforms(data)

        imgs_pred = self._render(cam_trans, veh_trans)
        img = self._photo_tensor(img)

        mask_file = os.path.join(self.mask_dir, os.path.splitext(name)[0] + '.png')
        mask = cv2.imread(mask_file)
        if mask is None:  # cv2.imread signals a missing or unreadable file with None
            raise FileNotFoundError(f'cannot read mask image: {mask_file}')
        mask = cv2.resize(mask, (self.img_size, self.img_size))
        # NOTE(review): .cuda() uses the current device while img uses device 0;
        # confirm these agree when a non-default device is selected.
        mask = torch.from_numpy(self._binary_mask(mask)).cuda()

        total_img = img * (1 - mask) + 255 * imgs_pred * mask
        return index, total_img.squeeze(0), imgs_pred.squeeze(0), mask

    def __len__(self):
        """Number of samples kept after the optional distance filter."""
        return len(self.files)

    @classmethod
    def _world_transforms(cls, data):
        """Read ``cam_trans``/``veh_trans`` from ``data`` and adjust them for rendering.

        The first three entries of row 0 of ``veh_trans`` are added to the same
        entries of ``cam_trans``. NOTE(review): this presumably converts a
        vehicle-relative camera position to absolute coordinates — confirm.
        Afterwards ``VEH_Z_OFFSET`` is added to ``veh_trans[0, 2]``. Both arrays
        are modified in place and returned as ``(cam_trans, veh_trans)``.
        """
        veh_trans = data['veh_trans']
        cam_trans = data['cam_trans']
        # Plain assignment (not +=) so numpy casts silently exactly as before.
        cam_trans[0, :3] = cam_trans[0, :3] + veh_trans[0, :3]
        veh_trans[0, 2] = veh_trans[0, 2] + cls.VEH_Z_OFFSET
        return cam_trans, veh_trans

    @staticmethod
    def _squared_distance(cam_trans, veh_trans):
        """Sum of squared differences between row 0 of both transforms (no sqrt)."""
        return np.sum((cam_trans - veh_trans)[0, :] ** 2)

    def _within_distance(self, name, max_sq_dist):
        """True if sample ``name``'s squared camera-vehicle distance is <= ``max_sq_dist``."""
        with np.load(os.path.join(self.data_dir, name)) as data:
            cam_trans, veh_trans = self._world_transforms(data)
        return self._squared_distance(cam_trans, veh_trans) <= max_sq_dist

    def _render(self, cam_trans, veh_trans):
        """Render the mesh from the pose given by the transforms, scaled to max 1."""
        eye, camera_direction, camera_up = nmr.get_params(cam_trans, veh_trans)
        # The camera is configured by mutating the wrapped renderer in place.
        renderer = self.mask_renderer.renderer.renderer
        renderer.eye = eye
        renderer.camera_direction = camera_direction
        renderer.camera_up = camera_up
        imgs_pred = self.mask_renderer.forward(self.vertices_var, self.faces_var, self.textures)
        peak = torch.max(imgs_pred)
        # An all-zero render (presumably mesh not visible) would give 0/0 = NaN,
        # and NaN * 0 would then turn the entire composite into NaN.
        if peak != 0:
            imgs_pred = imgs_pred / peak
        return imgs_pred

    def _photo_tensor(self, img):
        """Convert an ``(H, W, C)`` photo into a ``(1, C, img_size, img_size)`` CUDA tensor.

        The channel order is reversed (``[:, :, ::-1]``) before resizing.
        """
        img = cv2.resize(img[:, :, ::-1], (self.img_size, self.img_size))
        # Channels-first with a leading batch axis, copied into C-contiguous memory.
        img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))[None])
        return torch.from_numpy(img).cuda(device=0)

    @staticmethod
    def _binary_mask(mask):
        """Collapse an ``(H, W, C)`` image into a float32 ``(H, W)`` 0/1 mask.

        A pixel is 1.0 when any channel is non-zero. A three-argument
        ``np.logical_or(a, b, c)`` would treat ``c`` as the *output* buffer and
        ignore it as an input, hence the explicit reduction over all channels.
        """
        return np.any(mask, axis=2).astype('float32')
if __name__ == '__main__':
    # Smoke test: build the dataset for one vehicle mesh and print batch shapes.
    obj_file = 'audi_et_te.obj'
    vertices, faces, _textures = neural_renderer.load_obj(filename_obj=obj_file, load_texture=True)
    # NOTE(review): mask_dir is left at its default '', so masks are looked up
    # in the working directory — confirm this matches the data layout.
    dataset = MyDataset('../data/phy_attack/train/', 608, 4, faces, vertices)
    loader = DataLoader(
        dataset=dataset,
        batch_size=3,
        shuffle=True,
        # num_workers stays at 0: __getitem__ does CUDA work, and CUDA cannot
        # be used from forked DataLoader worker processes.
    )
    # MyDataset.__getitem__ returns (index, total_img, imgs_pred, mask).
    for index, total_img, imgs_pred, mask in loader:
        print(index.size(), total_img.size(), imgs_pred.size(), mask.size())