# attenuations.py (52 lines, 2.12 KB) — GitHub page chrome and line-number gutter removed during extraction cleanup
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
import torch.nn.functional as F
class JND(nn.Module):
    """Just Noticeable Difference (JND) heatmap estimator.

    Estimates a per-pixel perceptual visibility threshold by combining a
    luminance-adaptation term and a contrast-masking term, following
    https://ieeexplore.ieee.org/document/7885108
    """

    def __init__(self, preprocess=lambda x: x):
        """
        Args:
            preprocess: callable applied to the input image (expected in [0,1])
                before the JND computation (e.g. a de-normalization); defaults
                to the identity.
        """
        super().__init__()
        # Sobel kernels for horizontal / vertical gradients (contrast masking).
        kernel_x = [[-1., 0., 1.], [-2., 0., 2.], [-1., 0., 1.]]
        kernel_y = [[1., 2., 1.], [0., 0., 0.], [-1., -2., -1.]]
        # Low-pass kernel for local background luminance (weights sum to 32).
        kernel_lum = [[1, 1, 1, 1, 1], [1, 2, 2, 2, 1], [1, 2, 0, 2, 1], [1, 2, 2, 2, 1], [1, 1, 1, 1, 1]]
        # These are fixed constants, not trainable weights: register them as
        # buffers (instead of requires_grad=False Parameters) so they follow
        # .to()/.cuda(), keep the same state_dict keys, and stay out of
        # .parameters() / optimizer groups. torch.FloatTensor is deprecated
        # in favor of torch.tensor(..., dtype=...).
        self.register_buffer(
            "weight_x", torch.tensor(kernel_x, dtype=torch.float32).unsqueeze(0).unsqueeze(0))
        self.register_buffer(
            "weight_y", torch.tensor(kernel_y, dtype=torch.float32).unsqueeze(0).unsqueeze(0))
        self.register_buffer(
            "weight_lum", torch.tensor(kernel_lum, dtype=torch.float32).unsqueeze(0).unsqueeze(0))
        self.preprocess = preprocess

    def jnd_la(self, x, alpha=1.0, eps=1e-3):
        """Luminance-adaptation threshold.

        Args:
            x: single-channel luma tensor, values in [0, 255].
            alpha: scale on the returned threshold.
            eps: added inside the sqrt; keeps the gradient finite at 0.

        Returns:
            Tensor of the same shape as the conv output (b, 1, h, w).
        """
        la = F.conv2d(x, self.weight_lum, padding=2) / 32  # local mean luminance
        # Piecewise model: visibility threshold is higher in dark regions
        # (la <= 127) and grows slowly again in bright regions.
        dark = 17 * (1 - torch.sqrt(la / 127 + eps)) + 3
        bright = 3 / 128 * (la - 127) + 3
        return alpha * torch.where(la <= 127, dark, bright)

    def jnd_cm(self, x, beta=0.117):
        """Contrast-masking threshold.

        Args:
            x: single-channel luma tensor, values in [0, 255].
            beta: scale on the returned threshold.

        Returns:
            Tensor of shape (b, 1, h, w).
        """
        grad_x = F.conv2d(x, self.weight_x, padding=1)
        grad_y = F.conv2d(x, self.weight_y, padding=1)
        cm = torch.sqrt(grad_x**2 + grad_y**2)  # Sobel gradient magnitude
        cm = 16 * cm**2.4 / (cm**2 + 26**2)
        return beta * cm

    def heatmaps(self, x, clc=0.3):
        """JND heatmap for an RGB image.

        Args:
            x: image tensor with 3 channels at dim -3, values in [0, 1].
            clc: overlap-correction coefficient between the two terms.

        Returns:
            Heatmap of shape (b, 1, h, w), floored at 5/255.
        """
        x = 255 * self.preprocess(x)
        # ITU-R BT.601 luma from the RGB channels.
        x = 0.299 * x[..., 0:1, :, :] + 0.587 * x[..., 1:2, :, :] + 0.114 * x[..., 2:3, :, :]
        la = self.jnd_la(x)
        cm = self.jnd_cm(x)
        # Combine with overlap correction, floor at 5 gray levels, rescale to [0,1].
        return torch.clamp_min(la + cm - clc * torch.minimum(la, cm), 5) / 255  # b 1 h w