In [None]:
import cv2
import numpy as np
import matplotlib.pyplot as plt
import math

#### 读取光场图像

In [None]:
image_file = 'chessboard.png'
N = 16            # angular resolution: N x N sub-aperture views
h, w = 400, 700   # spatial resolution of each sub-aperture view
img = cv2.imread(image_file)
if img is None:  # cv2.imread returns None instead of raising on a missing/unreadable file
    raise FileNotFoundError(f"could not read light field image {image_file!r}")
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
height, width, _ = img.shape
# Pixel (x*N + i, y*N + j) belongs to view (i, j), so the image must be exactly N*h x N*w.
assert (height, width) == (N * h, N * w), (
    f"expected a {N * h}x{N * w} light field image, got {height}x{width}")

# Extract sub-aperture images with different viewpoints from the light field image:
# view (i, j) takes every N-th pixel starting at offset (i, j).
sub_aperture = np.zeros((N, N, h, w, 3), dtype=np.uint8)
for i in range(N):
    for j in range(N):
        sub_aperture[i, j] = img[i::N, j::N]

# Tile the N x N views into one (N*h, N*w, 3) mosaic for display.
img_concate = np.zeros_like(img)
for i in range(N):
    for j in range(N):
        img_concate[i*h:(i+1)*h, j*w:(j+1)*w] = sub_aperture[i, j]
# cv2.imwrite expects BGR channel order; img_concate is RGB.
cv2.imwrite('./img_concate.png', cv2.cvtColor(img_concate, cv2.COLOR_RGB2BGR))


plt.axis("off")
plt.title("Light Field Images")
plt.imshow(img_concate)

#### 重对焦

$$I_{\text{refocus}}(s,t,c) = \frac{1}{uv}\int_u\int_v L(u,v,s+d\cdot u,t+d\cdot v,c)\,\text{d}v\,\text{d}u$$

In [None]:
N = 16


def refocus(d, lf=None):
    """Synthetically refocus a light field by shift-and-add.

    Every sub-aperture view (i, j) is shifted by (+d*i, -d*j) pixels
    (rows, columns). Source coordinates are rounded down and clamped to the
    image border, and all views are averaged. This is a discrete form of the
    refocusing integral.

    Args:
        d: disparity (pixels of shift per view index); selects the focal
            plane. d = 0 is a plain average of all views.
        lf: light field indexed [i, j, row, col, channel]. Defaults to the
            notebook-level ``sub_aperture``.

    Returns:
        uint8 array of shape (rows, cols, channels) holding the floored mean
        of the shifted views.
    """
    if lf is None:
        lf = sub_aperture
    n_u, n_v, height, width = lf.shape[:4]
    rows = np.arange(height)
    cols = np.arange(width)
    acc = np.zeros(lf.shape[2:], dtype=np.float64)
    for i in range(n_u):
        # Clamp both ends: a negative index would otherwise wrap around and
        # sample pixels from the opposite edge of the view.
        src_rows = np.clip(np.floor(rows + d * i).astype(np.intp), 0, height - 1)
        for j in range(n_v):
            src_cols = np.clip(np.floor(cols - d * j).astype(np.intp), 0, width - 1)
            acc += lf[i, j][src_rows[:, None], src_cols[None, :]]
    acc //= n_u * n_v
    return acc.astype(np.uint8)

In [None]:
# Refocus at the top of the image (d = 0 is a plain average of all views).
disp = 0
image_top = refocus(disp)

plt.axis("off")
plt.title("Refocus at the top")
plt.imshow(image_top)
# cv2.imwrite expects BGR; the refocused image is RGB (the light field was converted on load).
cv2.imwrite('./top.png', cv2.cvtColor(image_top, cv2.COLOR_RGB2BGR))

In [None]:
# Refocus at the middle of the image; the disparity selects the focal plane.
disp = 0.6

image_mid = refocus(disp)
plt.axis("off")
plt.title("Refocus at the middle")
plt.imshow(image_mid)
# cv2.imwrite expects BGR; the refocused image is RGB (the light field was converted on load).
cv2.imwrite('./mid.png', cv2.cvtColor(image_mid, cv2.COLOR_RGB2BGR))

In [None]:
# Refocus at the bottom of the image; the disparity selects the focal plane.
disp = 1.3

image_bottom = refocus(disp)
plt.axis("off")
plt.title("Refocus at the bottom")
plt.imshow(image_bottom)
# cv2.imwrite expects BGR; the refocused image is RGB (the light field was converted on load).
cv2.imwrite('./down.png', cv2.cvtColor(image_bottom, cv2.COLOR_RGB2BGR))

#### 焦点堆栈

In [None]:
# Synthesize a focal stack from the light field: the three refocused images
# computed earlier plus nine more disparities.
focal_stack = [image_top, image_mid, image_bottom]
disp_list = [0.2, 0.7, 1.2, 0.3, 0.9, 1.4, 0.4, 1.1, 1.5]
for disp in disp_list:
    focal_stack.append(refocus(disp))

# Display the focal stack as a 4 x 3 grid. Column k holds focal_stack[k::3]
# (with the current disparities: 0-0.4, 0.6-1.1 and 1.2-1.5), so every image
# in the stack is shown exactly once.
n_cols = 3
columns = [np.concatenate(focal_stack[k::n_cols], axis=0) for k in range(n_cols)]
final_stack_img = np.concatenate(columns, axis=1)
plt.axis('off')
plt.title("Focal stack")
plt.imshow(final_stack_img)
# cv2.imwrite expects BGR; the stack images are RGB.
cv2.imwrite('./focal_stack.png', cv2.cvtColor(final_stack_img, cv2.COLOR_RGB2BGR))

#### 全对焦图像生成

- 对于每一张图像，首先从RGB图像中抽取亮度通道：$$I_{\text{luminance}}(s,t,d) = \text{get\_luminance}(I(s,t,c,d))$$ 
- 利用一个标准差为$\sigma_1$的高斯卷积核从图像中提取低频信息：$$I_{\text{low-freq}}(s,t,d) = G_{\sigma_1}(s,t)*I_{\text{luminance}}(s,t,d)$$
- 将原始图像减去低频信息得到图像的高频信息：$$I_{\text{high-freq}}(s,t,d) = I_{\text{luminance}}(s,t,d) - I_{\text{low-freq}}(s,t,d)$$
- 利用一个标准差为$\sigma_2$的高斯卷积核从图像的高频信息中估计出图像的锐度权重：$$\omega_{\text{sharpness}}(s,t,d) = G_{\sigma_2}(s,t) * (I_{\text{high-freq}}(s,t,d))^2$$
- 基于图像的锐度所计算的权重，可以生成一张全对焦图像：$$I_{\text{all-in-focus}} = \frac{\sum_d \omega_{\text{sharpness}}(s,t,d)I(s,t,c,d)}{\sum_d \omega_{\text{sharpness}}(s,t,d)}$$


In [None]:
# All-in-focus image: per-pixel sharpness-weighted average of the focal stack.


def sharpness_weight(image_rgb, low_ksize=13, weight_ksize=15):
    """Return the per-pixel sharpness weight of an RGB uint8 image.

    Pipeline: luminance -> 2-D Gaussian low-pass -> high-frequency residual
    -> Gaussian-smoothed squared residual. Kernel sizes must be odd; with a
    sigma argument of 0, OpenCV derives sigma from the kernel size.
    """
    lum = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2GRAY).astype(np.float64)
    # GaussianBlur filters along both axes. cv2.filter2D with the (k, 1) column
    # vector returned by cv2.getGaussianKernel would only blur vertically.
    low_freq = cv2.GaussianBlur(lum, (low_ksize, low_ksize), 0)
    high_freq = lum - low_freq
    return cv2.GaussianBlur(high_freq ** 2, (weight_ksize, weight_ksize), 0)


weights = [sharpness_weight(image) for image in focal_stack]

numerator = np.zeros(focal_stack[0].shape, dtype=np.float64)
denominator = np.zeros(focal_stack[0].shape[:2] + (1,), dtype=np.float64)
for image, weight in zip(focal_stack, weights):
    weight_3d = weight[..., np.newaxis]  # broadcast the weight over colour channels
    numerator += weight_3d * image
    denominator += weight_3d
# In perfectly flat regions every weight can be 0; fall back to the plain
# stack mean there instead of producing NaN from 0/0.
stack_mean = np.mean(focal_stack, axis=0)
fused = np.divide(numerator, denominator, out=stack_mean, where=denominator > 0)
all_in_focus = np.clip(fused, 0, 255).astype(np.uint8)
plt.axis("off")
plt.title("All-in-focus imaging")
plt.imshow(all_in_focus)
# cv2.imwrite expects BGR; the fused image is RGB.
cv2.imwrite('./refocus.png', cv2.cvtColor(all_in_focus, cv2.COLOR_RGB2BGR))