Skip to content

Commit

Permalink
Merge branch 'drwalton-master'
Browse files Browse the repository at this point in the history
  • Loading branch information
kaanaksit committed Jul 30, 2022
2 parents e0153b9 + 8913c4a commit 7de1766
Show file tree
Hide file tree
Showing 8 changed files with 45 additions and 50 deletions.
2 changes: 1 addition & 1 deletion odak/learn/perception/blur_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class BlurLoss():


def __init__(self, device=torch.device("cpu"),
alpha=0.08, real_image_width=0.2, real_viewing_distance=0.7, mode="quadratic", blur_source=False):
alpha=0.2, real_image_width=0.2, real_viewing_distance=0.7, mode="quadratic", blur_source=False):
"""
Parameters
----------
Expand Down
2 changes: 1 addition & 1 deletion odak/learn/perception/foveation.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def make_pooling_size_map_pixels(gaze_location, image_pixel_size, alpha=0.3, rea
major_axis = (torch.tan(angle_max) - torch.tan(angle_min)) / \
real_viewing_distance
minor_axis = 2 * distance_to_pixel * torch.tan(pooling_rad*0.5)
area = math.pi * major_axis * minor_axis
area = math.pi * major_axis * minor_axis * 0.25
# Should be +ve anyway, but check to ensure we don't take sqrt of negative number
area = torch.abs(area)
pooling_real = torch.sqrt(area)
Expand Down
21 changes: 9 additions & 12 deletions odak/learn/perception/metamer_mse_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from .metameric_loss import MetamericLoss
from .color_conversion import ycrcb_2_rgb, rgb_2_ycrcb
from .spatial_steerable_pyramid import pad_image_for_pyramid


class MetamerMSELoss():
Expand All @@ -16,7 +17,7 @@ class MetamerMSELoss():


def __init__(self, device=torch.device("cpu"),
alpha=0.08, real_image_width=0.2, real_viewing_distance=0.7, mode="quadratic",
alpha=0.2, real_image_width=0.2, real_viewing_distance=0.7, mode="quadratic",
n_pyramid_levels=5, n_orientations=2):
"""
Parameters
Expand Down Expand Up @@ -66,6 +67,9 @@ def gen_metamer(self, image, gaze):
The generated metamer image
"""
image = rgb_2_ycrcb(image)
image_size = image.size()
image = pad_image_for_pyramid(image, self.metameric_loss.n_pyramid_levels)

target_stats = self.metameric_loss.calc_statsmaps(
image, gaze=gaze, alpha=self.metameric_loss.alpha)
target_means = target_stats[::2]
Expand Down Expand Up @@ -101,6 +105,8 @@ def match_level(input_level, target_mean, target_std):
metamer = self.metameric_loss.pyramid_maker.reconstruct_from_pyramid(
noise_pyramid)
metamer = ycrcb_2_rgb(metamer)
# Crop to remove any padding
metamer = metamer[:image_size[0], :image_size[1], :image_size[2], :image_size[3]]
return metamer

def __call__(self, image, target, gaze=[0.5, 0.5]):
Expand All @@ -123,17 +129,8 @@ def __call__(self, image, target, gaze=[0.5, 0.5]):
The computed loss.
"""
# Pad image and target if necessary
min_divisor = 2 ** self.metameric_loss.n_pyramid_levels
height = image.size(2)
width = image.size(3)
required_height = math.ceil(height / min_divisor) * min_divisor
required_width = math.ceil(width / min_divisor) * min_divisor
if required_height > height or required_width > width:
# We need to pad!
pad = torch.nn.ReflectionPad2d(
(0, 0, required_height-height, required_width-width))
image = pad(image)
target = pad(target)
image = pad_image_for_pyramid(image, self.metameric_loss.n_pyramid_levels)
target = pad_image_for_pyramid(target, self.metameric_loss.n_pyramid_levels)

if target is not self.target or self.target is None:
self.target_metamer = self.gen_metamer(target, gaze)
Expand Down
17 changes: 4 additions & 13 deletions odak/learn/perception/metameric_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import math

from .color_conversion import ycrcb_2_rgb, rgb_2_ycrcb
from .spatial_steerable_pyramid import SpatialSteerablePyramid
from .spatial_steerable_pyramid import SpatialSteerablePyramid, pad_image_for_pyramid
from .radially_varying_blur import RadiallyVaryingBlur
from .foveation import make_radial_map

Expand All @@ -18,7 +18,7 @@ class MetamericLoss():
"""


def __init__(self, device=torch.device('cpu'), alpha=0.08, real_image_width=0.2,
def __init__(self, device=torch.device('cpu'), alpha=0.2, real_image_width=0.2,
real_viewing_distance=0.7, n_pyramid_levels=5, mode="quadratic",
n_orientations=2, use_l2_foveal_loss=True, fovea_weight=20.0, use_radial_weight=False,
use_fullres_l0=False):
Expand Down Expand Up @@ -213,17 +213,8 @@ def __call__(self, image, target, gaze=[0.5, 0.5], image_colorspace="RGB", visua
raise Exception(
"MetamericLoss ERROR: Input and target must have same number of channels.")
# Pad image and target if necessary
min_divisor = 2**self.n_pyramid_levels
height = image.size(2)
width = image.size(3)
required_height = math.ceil(height/min_divisor)*min_divisor
required_width = math.ceil(width/min_divisor)*min_divisor
if required_height > height or required_width > width:
# We need to pad!
pad = torch.nn.ReflectionPad2d(
(0, 0, required_height-height, required_width-width))
image = pad(image)
target = pad(target)
image = pad_image_for_pyramid(image, self.n_pyramid_levels)
target = pad_image_for_pyramid(target, self.n_pyramid_levels)
if image.size(1) == 3 and image_colorspace == "RGB":
image = rgb_2_ycrcb(image)
target = rgb_2_ycrcb(target)
Expand Down
15 changes: 3 additions & 12 deletions odak/learn/perception/metameric_loss_uniform.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import math

from .color_conversion import ycrcb_2_rgb, rgb_2_ycrcb
from .spatial_steerable_pyramid import SpatialSteerablePyramid
from .spatial_steerable_pyramid import SpatialSteerablePyramid, pad_image_for_pyramid
from .radially_varying_blur import RadiallyVaryingBlur
from .foveation import make_radial_map

Expand Down Expand Up @@ -132,17 +132,8 @@ def __call__(self, image, target, image_colorspace="RGB", visualise_loss=False):
raise Exception(
"MetamericLoss ERROR: Input and target must have same number of channels.")
# Pad image and target if necessary
min_divisor = 2**self.n_pyramid_levels
height = image.size(2)
width = image.size(3)
required_height = math.ceil(height/min_divisor)*min_divisor
required_width = math.ceil(width/min_divisor)*min_divisor
if required_height > height or required_width > width:
# We need to pad!
pad = torch.nn.ReflectionPad2d(
(0, 0, required_height-height, required_width-width))
image = pad(image)
target = pad(target)
image = pad_image_for_pyramid(image, self.n_pyramid_levels)
target = pad_image_for_pyramid(target, self.n_pyramid_levels)
if image.size(1) == 3 and image_colorspace == "RGB":
image = rgb_2_ycrcb(image)
target = rgb_2_ycrcb(target)
Expand Down
2 changes: 1 addition & 1 deletion odak/learn/perception/radially_varying_blur.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class RadiallyVaryingBlur():
def __init__(self):
self.lod_map = None

def blur(self, image, alpha=0.08, real_image_width=0.2, real_viewing_distance=0.7, centre=None, mode="quadratic"):
def blur(self, image, alpha=0.2, real_image_width=0.2, real_viewing_distance=0.7, centre=None, mode="quadratic"):
"""
Apply the radially varying blur to an image.
Expand Down
32 changes: 26 additions & 6 deletions odak/learn/perception/spatial_steerable_pyramid.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,33 @@
from .steerable_pyramid_filters import get_steerable_pyramid_filters
import torch
import numpy as np
import math

def pad_image_for_pyramid(image, n_pyramid_levels):
"""
Pads an image to the extent necessary to compute a steerable pyramid of the input image.
This involves padding so both height and width are divisible by 2**n_pyramid_levels.
Uses reflection padding.
Parameters
----------
image: torch.tensor
Image to pad, in NCHW format
n_pyramid_levels: int
Number of levels in the pyramid you plan to construct.
"""
min_divisor = 2 ** n_pyramid_levels
height = image.size(2)
width = image.size(3)
required_height = math.ceil(height / min_divisor) * min_divisor
required_width = math.ceil(width / min_divisor) * min_divisor
if required_height > height or required_width > width:
# We need to pad!
pad = torch.nn.ReflectionPad2d(
(0, 0, required_height-height, required_width-width))
return pad(image)
return image


class SpatialSteerablePyramid():
Expand Down Expand Up @@ -108,12 +134,8 @@ def construct_pyramid(self, image, n_levels, multiple_highpass=False):
level0 = {}
level0['h'] = torch.nn.functional.conv2d(
self.pad_h0(image), self.filt_h0)
#plt.imshow(level0['h'][0,0,...], cmap="gray", vmin=0, vmax=1)
# plt.show()
lowpass = torch.nn.functional.conv2d(self.pad_l0(image), self.filt_l0)
level0['l'] = lowpass.clone()
#np.save("lowpass_filtered.npy", level0['l'][0,...].permute(1,2,0).numpy())
# quit()
bands = []
for filt_b in self.band_filters:
bands.append(torch.nn.functional.conv2d(
Expand All @@ -138,8 +160,6 @@ def construct_pyramid(self, image, n_levels, multiple_highpass=False):
self.pad_b(lowpass), filt_b))
level['b'] = bands
if multiple_highpass:
#downsampled = torch.nn.functional.interpolate(image, scale_factor=0.5, mode="area")
#level['h'] = torch.nn.functional.conv2d(self.pad_h0(downsampled), self.filt_h0)
level['h'] = torch.nn.functional.conv2d(
self.pad_h0(lowpass), self.filt_h0)
pyramid.append(level)
Expand Down
4 changes: 0 additions & 4 deletions odak/learn/perception/steerable_pyramid_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,6 @@ def crop_filter(filter, r, normalise=True):
sum_l0 = torch.sum(filters["l0"])
filters["l0"] = crop_filter(filters["l0"], 2, normalise=False)
filters["l0"] *= sum_l0 / torch.sum(filters["l0"])
# l0_sum = torch.sum(filters["l0"])
# filters["l0"] = crop_filter(filters["l0"], r)
# filters["l0"] /= torch.sum(filters["l0"])
# filters["l0"] *= l0_sum
for b in range(len(filters["b"])):
filters["b"][b] = crop_filter(filters["b"][b], r, normalise=True)
return filters
Expand Down

0 comments on commit 7de1766

Please sign in to comment.