Skip to content

Commit

Permalink
Made further fixes to model.py.
Browse files Browse the repository at this point in the history
  • Loading branch information
emilianavt committed Sep 24, 2020
1 parent 2cca2cb commit cfe2e2f
Showing 1 changed file with 18 additions and 14 deletions.
32 changes: 18 additions & 14 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,10 @@ def _forward_impl(self, x):
def forward(self, x):
return self._forward_impl(x)

def logit_arr(p, factor=16.0):
p = p.clamp(0.0000001, 0.9999999)
return torch.log(p / (1 - p)) / factor

# Landmark detection model
# Models:
# 0: "small", 0.5
Expand Down Expand Up @@ -162,16 +166,16 @@ def _forward_impl(self, x):

if self.inference:
t_main = x[:, 0:66].reshape((-1, 66, 28*28))
t_m = t_main.argmax(dim=1)
indices = t_m.unsqueeze(1)
t_conf = t_main.gather(1, indices).squeeze(1)
t_off_x = x[:, 66:132].reshape((-1, 66, 28*28)).gather(1, indices).squeeze(1)
t_off_y = x[:, 132:198].reshape((-1, 66, 28*28)).gather(1, indices).squeeze(1)
t_m = t_main.argmax(dim=2)
indices = t_m.unsqueeze(2)
t_conf = t_main.gather(2, indices).squeeze(2)
t_off_x = x[:, 66:132].reshape((-1, 66, 28*28)).gather(2, indices).squeeze(2)
t_off_y = x[:, 132:198].reshape((-1, 66, 28*28)).gather(2, indices).squeeze(2)
t_off_x = (223. * logit_arr(t_off_x) + 0.5).floor()
t_off_y = (223. * logit_arr(t_off_y) + 0.5).floor()
t_x = 223. * (t_m / 28.).floor() / 27. + t_off_x
t_y = 223. * t_m.remainder(28.).float() / 27. + t_off_y
x = (t_conf.mean(), torch.stack([t_x, t_y, t_conf], 1))
x = (t_conf.mean(1), torch.stack([t_x, t_y, t_conf], 2))

return x
def forward(self, x):
Expand Down Expand Up @@ -204,16 +208,16 @@ def _forward_impl(self, x):

if self.inference:
t_main = x[:, 0:30].reshape((-1, 30, 7*7))
t_m = t_main.argmax(dim=1)
indices = t_m.unsqueeze(1)
t_conf = t_main.gather(1, indices).squeeze(1)
t_off_x = x[:, 30:60].reshape((-1, 30, 7*7)).gather(1, indices).squeeze(1)
t_off_y = x[:, 60:90].reshape((-1, 30, 7*7)).gather(1, indices).squeeze(1)
t_off_x = 55. * logit_arr(t_off_x)
t_off_y = 55. * logit_arr(t_off_y)
t_m = t_main.argmax(dim=2)
indices = t_m.unsqueeze(2)
t_conf = t_main.gather(2, indices).squeeze(2)
t_off_x = x[:, 30:60].reshape((-1, 30, 7*7)).gather(2, indices).squeeze(2)
t_off_y = x[:, 60:90].reshape((-1, 30, 7*7)).gather(2, indices).squeeze(2)
t_off_x = 55. * logit_arr(t_off_x, factor=8.0)
t_off_y = 55. * logit_arr(t_off_y, factor=8.0)
t_x = 55. * (t_m / 7.).floor() / 6. + t_off_x
t_y = 55. * t_m.remainder(7.).float() / 6. + t_off_y
x = (t_conf.mean(), torch.stack([t_x, t_y, t_conf], 1))
x = (t_conf.mean(1), torch.stack([t_x, t_y, t_conf], 2))

return x
def forward(self, x):
Expand Down

0 comments on commit cfe2e2f

Please sign in to comment.