Skip to content

Commit

Permalink
Add files via upload
Browse files Browse the repository at this point in the history
  • Loading branch information
smcdonagh committed Jun 1, 2020
1 parent 2c4462e commit e5b8e89
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 0 deletions.
24 changes: 24 additions & 0 deletions loss/angular_error.py
@@ -0,0 +1,24 @@
import torch
import torch.nn as nn
from torch.autograd import Variable
from core.utils import *

class AngularError(nn.Module):
def __init__(self, conf, compute_acos, illuminant_key = 'illuminant',
gt_key = 'illuminant'):
super(AngularError, self).__init__()
self._conf = conf
self._illuminant_key = illuminant_key
self._gt_key = gt_key
self._compute_acos = compute_acos

def forward(self, outputs, data, model):
labels = Variable(data[self._gt_key])
pred = outputs[self._illuminant_key]

# angular_error_gradsafe computes differentiable angular error,
# arccos(x) is not differentiable at -1 and +1. We handle that,
# as well as 0 vector.
err = angular_error_gradsafe(pred, labels, compute_acos=self._compute_acos)

return err.mean()
57 changes: 57 additions & 0 deletions loss/ffcc.py
@@ -0,0 +1,57 @@
import torch
import math
import torch.nn as nn
from torch.autograd import Variable
from core.utils import *
from numpy.linalg import norm
import torch.nn.functional as F

# google FFCC loss
class Ffcc(nn.Module):
def __init__(self, conf, logistic_loss_epochs,
logistic_loss_mult=2.5, bvm_mult=2.5,
regularization_mult=0.5):
logistic_loss_mult = 2**logistic_loss_mult
bvm_mult = 2**bvm_mult

super(Ffcc, self).__init__()
self._conf = conf
self._bin_size = self._conf['log_uv_warp_histogram']['bin_size']

self._logistic_loss_epochs = logistic_loss_epochs
self._logistic_loss_mult = logistic_loss_mult
self._bvm_mult = bvm_mult
self._regularization_mult = regularization_mult

def forward(self, outputs, data, model):
labels = Variable(data['illuminant_log_uv'], requires_grad=False)
mu = outputs['mu']
sigma = outputs['sigma']

regularization_term = 0
for name, param in model.named_parameters():
if 'conv' not in name:
regularization_term += (param*param).sum()

# they actually use 2 losses, logistic regression for some epochs,
# then, BVM
if data['epoch'] < self._logistic_loss_epochs:
# logistic loss
gt_pdf = data['gt_pdf']
bin_probability_logits = outputs['bin_probability_logits'].squeeze(1)
logsoft = F.log_softmax(bin_probability_logits.view(bin_probability_logits.shape[0], -1), 1).view_as(bin_probability_logits)
logistic_loss_positive = (gt_pdf*logsoft).view(bin_probability_logits.shape[0], -1).sum(1)
data_term = -self._logistic_loss_mult*logistic_loss_positive.mean()
else:
# bivariate von mises
dif = (labels - mu).unsqueeze(-1)

sigma_inv = torch.inverse(sigma)
fitting_loss = torch.sum(torch.mul(torch.matmul(sigma_inv, dif), dif).squeeze(-1), 1)
logdet = batch_logdet2x2(sigma)
loss_bvm = 0.5*(fitting_loss + logdet + 2*math.log(2*math.pi))
loss_bvm_min = math.log(2*math.pi*outputs['bivariate_von_mises_epsilon']*self._bin_size*self._bin_size)
l = loss_bvm - loss_bvm_min
data_term = self._bvm_mult*l.mean()

return data_term + self._regularization_mult*regularization_term

0 comments on commit e5b8e89

Please sign in to comment.