Skip to content

Commit

Permalink
Autoformat utils
Browse files Browse the repository at this point in the history
  • Loading branch information
JorisVincent committed Jun 14, 2022
1 parent da4f0d2 commit b739247
Showing 1 changed file with 86 additions and 52 deletions.
138 changes: 86 additions & 52 deletions stimuli/utils/utils.py
Expand Up @@ -2,9 +2,10 @@
Provides some functionality for creating and manipulating visual stimuli
represented as numpy arrays.
"""
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt


def write_array_to_image(filename, arr):
"""
Expand All @@ -20,7 +21,7 @@ def write_array_to_image(filename, arr):
"""
if Image:
imsize = arr.shape
im = Image.new('L', (imsize[1], imsize[0]))
im = Image.new("L", (imsize[1], imsize[0]))
im.putdata(arr.flatten())
im.save(filename)

Expand All @@ -47,9 +48,9 @@ def luminance2munsell(lum_values, reference_white):
"""

x = lum_values / float(reference_white)
idx = x <= (6. / 29) ** 3
y1 = 841. / 108 * x[idx] + 4. / 29
y2 = x[~idx] ** (1. / 3)
idx = x <= (6.0 / 29) ** 3
y1 = 841.0 / 108 * x[idx] + 4.0 / 29
y2 = x[~idx] ** (1.0 / 3)
y = np.empty(x.shape)
y[idx] = y1
y[~idx] = y2
Expand All @@ -76,8 +77,8 @@ def munsell2luminance(munsell_values, reference_white):
terms'," J. Opt. Soc. Am. 66, 866-867 (1976)
"""
lum_values = (munsell_values + 1.6) / 11.6
idx = lum_values <= 6. / 29
lum_values[idx] = (lum_values[idx] - 4. / 29) / 841 * 108
idx = lum_values <= 6.0 / 29
lum_values[idx] = (lum_values[idx] - 4.0 / 29) / 841 * 108
lum_values[~idx] **= 3
return lum_values * reference_white

Expand All @@ -102,7 +103,7 @@ def degrees_to_pixels(degrees, ppd):
return (np.round(degrees * ppd)).astype(int)

# This is the 'super correct' conversion, but it makes very little difference in practice
#return (np.tan(np.radians(degrees / 2.)) / np.tan(np.radians(.5)) * ppd).astype(int)
# return (np.tan(np.radians(degrees / 2.)) / np.tan(np.radians(.5)) * ppd).astype(int)


def pixels_to_degrees(pixels, ppd):
Expand Down Expand Up @@ -150,7 +151,7 @@ def compute_ppd(screen_size, resolution, distance):
"""

ppmm = resolution / screen_size
mmpd = 2 * np.tan(np.radians(.5)) * distance
mmpd = 2 * np.tan(np.radians(0.5)) * distance
return ppmm * mmpd


Expand Down Expand Up @@ -183,14 +184,17 @@ def pad_array(arr, amount, pad_value=0):
assert amount.amin() >= 0
if len(arr.shape) != 2:
raise NotImplementedError(
"pad_array currently only works for 2D arrays")
"pad_array currently only works for 2D arrays"
)
if amount.sum() == 0:
return arr

output_shape = [x + y.sum() for x, y in zip(arr.shape, amount)]
output = np.ones(output_shape, dtype=arr.dtype) * pad_value
output[amount[0][0]:output_shape[0] - amount[0][1],
amount[1][0]:output_shape[1] - amount[1][1]] = arr
output[
amount[0][0] : output_shape[0] - amount[0][1],
amount[1][0] : output_shape[1] - amount[1][1],
] = arr
return output


Expand Down Expand Up @@ -220,7 +224,7 @@ def center_array(arr, shape, pad_value=0):
assert (y_pad % 2 == 0) and (x_pad % 2 == 0)
assert x_pad > 0 and y_pad > 0
out = np.ones(shape, dtype=arr.dtype) * pad_value
out[y_pad / 2: -y_pad / 2, x_pad / 2: -x_pad / 2] = arr
out[y_pad / 2 : -y_pad / 2, x_pad / 2 : -x_pad / 2] = arr
return out


Expand Down Expand Up @@ -276,9 +280,15 @@ def smooth_window(shape, plateau, min_val, max_val, width):
y = np.arange(shape[0])[:, np.newaxis]
distance = np.ones(shape) * width
if len(plateau) == 2:
plateau_points = (plateau[0], (plateau[0][0], plateau[1][1]), plateau[1],
(plateau[1][0], plateau[0][1]))
distance[plateau[0][0]: plateau[1][0], plateau[0][1]: plateau[1][1]] = 0
plateau_points = (
plateau[0],
(plateau[0][0], plateau[1][1]),
plateau[1],
(plateau[1][0], plateau[0][1]),
)
distance[
plateau[0][0] : plateau[1][0], plateau[0][1] : plateau[1][1]
] = 0
else:
plateau_points = plateau
for i in range(len(plateau_points)):
Expand Down Expand Up @@ -307,28 +317,31 @@ def dist_to_segment(y, x, p1, p2): # x3,y3 is the point
if sl == 0:
return np.sqrt(dist_squared(y, x, p1))
t = ((y - p1[0]) * (p2[0] - p1[0]) + (x - p1[1]) * (p2[1] - p1[1])) / sl
dist = dist_squared(y, x, (p1[0] + t * (p2[0] - p1[0]), p1[1] + t * (p2[1] - p1[1])))
dist = dist_squared(
y, x, (p1[0] + t * (p2[0] - p1[0]), p1[1] + t * (p2[1] - p1[1]))
)
dist[t > 1] = dist_squared(y, x, p2)[t > 1]
dist[t < 0] = dist_squared(y, x, p1)[t < 0]
return np.sqrt(dist)


def shift_pixels(img, shift):
"""
Shift image by specified number of pixels. The pixels pushed on the edge will reappear on the other side (wrap around)
Shift image by specified number of pixels. The pixels pushed on the edge will reappear on the other side (wrap around)
Parameters
----------
img : 2D array representing the image to be shifted
shift: (x,y) tuple specifying the number of pixels to shift. Positive x specifies shift in the right direction
and positive y shift downwards
Parameters
----------
img : 2D array representing the image to be shifted
shift: (x,y) tuple specifying the number of pixels to shift. Positive x specifies shift in the right direction
and positive y shift downwards
Returns
-------
img : shifted image
Returns
-------
img : shifted image
"""
return np.roll(img, shift, (1, 0))


def get_circle_indices(n_numbers, grid_shape):

height, width = grid_shape
Expand All @@ -339,7 +352,7 @@ def get_circle_indices(n_numbers, grid_shape):
xx_min = np.abs(xx.min())
xx += xx_min
xx_max = xx.max()
xx = xx / xx_max * (width-1)
xx = xx / xx_max * (width - 1)

yy = np.sin(x)
yy_min = np.abs(yy.min())
Expand Down Expand Up @@ -370,76 +383,97 @@ def get_circle_mask(shape, center, radius):
xx, yy = np.mgrid[:height, :width]
grid_radii = (xx - x_c) ** 2 + (yy - y_c) ** 2

circle_mask = grid_radii < (radius ** 2)
circle_mask = grid_radii < (radius**2)

return circle_mask


def get_annulus_mask(shape, center, inner_radius, outer_radius):
"""
Get an annulus shaped mask
Get an annulus shaped mask
Parameters
-------
shape: (height, width) of the mask in pixels
radius: radius of the circle in pixels
center: width of the annulus in pixels
Parameters
-------
shape: (height, width) of the mask in pixels
radius: radius of the circle in pixels
center: width of the annulus in pixels
Returns
-------
mask: 2D boolean numpy array
"""
Returns
-------
mask: 2D boolean numpy array
"""

mask1 = get_circle_mask(shape, center, inner_radius)
mask2 = get_circle_mask(shape, center, outer_radius)
mask = np.logical_xor(mask1, mask2)

return mask


def pad_img(img, padding, ppd, val):
"""
padding: degrees visual angle (top, bottom, left, right)
"""
padding_px = np.array(degrees_to_pixels(padding, ppd), dtype=np.int32)
padding_top, padding_bottom, padding_left, padding_right = padding_px
return np.pad(img, ((int(padding_top), int(padding_bottom)), (int(padding_left), int(padding_right))), 'constant', constant_values=((val,val),(val,val)))
return np.pad(
img,
(
(int(padding_top), int(padding_bottom)),
(int(padding_left), int(padding_right)),
),
"constant",
constant_values=((val, val), (val, val)),
)


def pad_img_to_shape(img, shape, val=0):
"""
shape: shape of the resulting image in pixels (height, width)
"""
height_px, width_px = shape
height_img_px, width_img_px = img.shape
if height_img_px >= height_px or width_img_px >= width_px:
if height_img_px > height_px or width_img_px > width_px:
raise ValueError("the image is bigger than the size after padding")

padding_vertical_top = int((height_px - height_img_px) // 2)
padding_vertical_bottom = int(height_px - height_img_px - padding_vertical_top)
padding_vertical_bottom = int(
height_px - height_img_px - padding_vertical_top
)

padding_horizontal_left = int((width_px - width_img_px) // 2)
padding_horizontal_right = int(width_px - width_img_px - padding_horizontal_left)
padding_horizontal_right = int(
width_px - width_img_px - padding_horizontal_left
)

return np.pad(img, ((padding_vertical_top, padding_vertical_bottom), (padding_horizontal_left, padding_horizontal_right)), 'constant', constant_values=val)


return np.pad(
img,
(
(padding_vertical_top, padding_vertical_bottom),
(padding_horizontal_left, padding_horizontal_right),
),
"constant",
constant_values=val,
)


def compare_plots(plots):
M = len(plots)
for i, (plot_name, plot) in enumerate(plots.items()):
plt.subplot(1,M,i+1)
plt.subplot(1, M, i + 1)
plt.title(plot_name)
plt.imshow(plot, cmap='gray')
plt.imshow(plot, cmap="gray")
plt.show()


def plot_stim(stim, mask=False):
if not mask:
plt.imshow(stim['img'], cmap='gray')
plt.imshow(stim["img"], cmap="gray")
else:
plt.subplot(1,2,1)
plt.imshow(stim['img'], cmap='gray')
plt.subplot(1,2,2)
plt.imshow(stim['mask'], cmap='gray')
plt.subplot(1, 2, 1)
plt.imshow(stim["img"], cmap="gray")
plt.subplot(1, 2, 2)
plt.imshow(stim["mask"], cmap="gray")

plt.tight_layout()
plt.show()

0 comments on commit b739247

Please sign in to comment.