"""
Example 4. Finding camera parameters.
"""
import os
import argparse
import glob
import torch
import torch.nn as nn
import numpy as np
from skimage.io import imread, imsave
import tqdm
import imageio
import neural_renderer as nr

current_dir = os.path.dirname(os.path.realpath(__file__))
data_dir = os.path.join(current_dir, 'data')


class Model(nn.Module):
    def __init__(self, filename_obj, filename_ref=None):
        super(Model, self).__init__()

        # load .obj
        vertices, faces = nr.load_obj(filename_obj)
        self.register_buffer('vertices', vertices[None, :, :])
        self.register_buffer('faces', faces[None, :, :])

        # create textures (constant white; only the silhouette matters here)
        texture_size = 2
        textures = torch.ones(1, self.faces.shape[1], texture_size, texture_size, texture_size, 3,
                              dtype=torch.float32)
        self.register_buffer('textures', textures)

        # load the reference silhouette; not needed when this model is only
        # used to render the reference image itself
        if filename_ref is not None:
            image_ref = torch.from_numpy((imread(filename_ref).max(-1) != 0).astype(np.float32))
            self.register_buffer('image_ref', image_ref)
        else:
            self.register_buffer('image_ref', None)

        # camera position is the only trainable parameter
        self.camera_position = nn.Parameter(torch.from_numpy(np.array([6, 10, -14], dtype=np.float32)))

        # set up the renderer; its eye aliases the trainable camera position
        renderer = nr.Renderer(camera_mode='look_at')
        renderer.eye = self.camera_position
        self.renderer = renderer

    def forward(self):
        # L2 distance between the rendered and the reference silhouettes
        image = self.renderer(self.vertices, self.faces, mode='silhouettes')
        loss = torch.sum((image - self.image_ref[None, :, :]) ** 2)
        return loss
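

# A minimal usage sketch of Model on its own (assumes a CUDA device and the
# teapot data shipped with the examples; paths match the defaults in main()):
#
#   model = Model('data/teapot.obj', 'data/example4_ref.png').cuda()
#   loss = model()   # scalar silhouette loss, differentiable
#   loss.backward()  # gradients flow into model.camera_position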


def make_gif(filename):
    # collect the frames written by the optimization loop into an animation,
    # deleting each temporary frame as it is consumed; the context manager
    # closes the writer, so no explicit close() is needed
    with imageio.get_writer(filename, mode='I') as writer:
        for frame in sorted(glob.glob('/tmp/_tmp_*.png')):
            writer.append_data(imread(frame))
            os.remove(frame)


def make_reference_image(filename_ref, filename_obj):
    model = Model(filename_obj)
    model.cuda()

    # render the mesh from a fixed viewpoint to create the reference silhouette
    model.renderer.eye = nr.get_points_from_angles(2.732, 30, -15)
    images, _, _ = model.renderer.render(model.vertices, model.faces, torch.tanh(model.textures))
    image = images.detach().cpu().numpy()[0].transpose((1, 2, 0))  # [image_size, image_size, RGB]
    imsave(filename_ref, (255 * image).astype(np.uint8))


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('-io', '--filename_obj', type=str, default=os.path.join(data_dir, 'teapot.obj'))
    parser.add_argument('-ir', '--filename_ref', type=str, default=os.path.join(data_dir, 'example4_ref.png'))
    parser.add_argument('-or', '--filename_output', type=str, default=os.path.join(data_dir, 'example4_result.gif'))
    parser.add_argument('-mr', '--make_reference_image', type=int, default=0)
    parser.add_argument('-g', '--gpu', type=int, default=0)
    args = parser.parse_args()

    # select the requested CUDA device before any GPU tensors are created
    torch.cuda.set_device(args.gpu)

    if args.make_reference_image:
        make_reference_image(args.filename_ref, args.filename_obj)

    model = Model(args.filename_obj, args.filename_ref)
    model.cuda()

    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
    loop = tqdm.tqdm(range(1000))
    for i in loop:
        optimizer.zero_grad()
        loss = model()
        loss.backward()
        optimizer.step()

        # render an RGB frame from the current camera position for the GIF
        images, _, _ = model.renderer(model.vertices, model.faces, torch.tanh(model.textures))
        image = images.detach().cpu().numpy()[0].transpose((1, 2, 0))
        imsave('/tmp/_tmp_%04d.png' % i, (255 * image).astype(np.uint8))

        loop.set_description('Optimizing (loss %.4f)' % loss.item())
        if loss.item() < 70:
            break
    make_gif(args.filename_output)


if __name__ == '__main__':
    main()
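
# Example invocations (a sketch; the flags are the ones defined in main() above):
#
#   python example4.py         # optimize the camera against data/example4_ref.png
#   python example4.py -mr 1   # regenerate the reference silhouette first
#   python example4.py -g 1    # run on GPU 1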