In [3]:
import numpy as np
import cv2

In [4]:
class Joint_bilateral_filter(object):

	def __init__(self, sigma_s, sigma_r, border_type='reflect'):

		self.border_type = border_type
		self.sigma_r = sigma_r
		self.sigma_s = sigma_s

	def spatial_kernel(self): 

		r = 3*self.sigma_s

		a = np.arange(-r, r+1)
		a = np.multiply(a, a)
		b = np.zeros((len(a), len(a)))

		for i in range(len(a)):
			for j in range(len(a)):
				b[i][j] = a[i] + a[j]



		b = b/(-2*(self.sigma_s**2))


		return np.exp(b)

	def range_kernel(self, guidance, x, y):

		r = 3*self.sigma_s
		a = np.arange(-r, r+1)
		b = np.zeros((len(a), len(a)))

		if len(guidance.shape) is 2:

			for i in range(len(a)):
				for j in range(len(a)):

					if abs(x-a[i]) < guidance.shape[0]:
						index_x = abs(x-a[i])
					else:
						index_x = guidance.shape[0] - abs(x-a[i]) - 2

					if abs(y-a[j]) < guidance.shape[0]:
						index_y = abs(y-a[j])
					else:
						index_y = guidance.shape[1] - abs(y-a[j]) - 2

					b[i][j] = (int(guidance[index_x][index_y] - int(guidance[x][y])))/255
					b[i][j] = np.dot(b[i][j], b[i][j])

		else:
			
			guidance_diff = np.zeros((len(a), len(a), guidance.shape[2]))
			
			for i in range(len(a)):
				for j in range(len(a)):

					if abs(x-a[i]) < guidance.shape[0]:
						index_x = abs(x-a[i])
					else:
						index_x = guidance.shape[0] - abs(x-a[i]) - 2

					if abs(y-a[j]) < guidance.shape[0]:
						index_y = abs(y-a[j])
					else:
						index_y = guidance.shape[1] - abs(y-a[j]) - 2

					guidance_diff[i][j] = (guidance[index_x][index_y] - guidance[x][y])/255
					b[i][j] = np.dot(guidance_diff[i][j], guidance_diff[i][j])

		b = b/(-2*(self.sigma_r**2))

		return np.exp(b)

	def ori_img(self, input, x, y):


		r = 3*self.sigma_s
		a = np.arange(-r, r+1)

		if len(input.shape) is 2:

			input_diff = np.zeros((len(a), len(a), 1))
			for i in range(len(a)):
				for j in range(len(a)):

					if abs(x-a[i]) < input.shape[0]:
						index_x = abs(x-a[i])
					else:
						index_x = input.shape[0] - abs(x-a[i]) - 2

					if abs(y-a[j]) < input.shape[0]:
						index_y = abs(y-a[j])
					else:
						index_y = input.shape[1] - abs(y-a[j]) - 2

					input_diff[i][j] = input[index_x][index_y]

		else:

			input_diff = np.zeros((len(a), len(a), input.shape[2]))

			for i in range(len(a)):
				for j in range(len(a)):

					if abs(x-a[i]) < input.shape[0]:
						index_x = abs(x-a[i])
					else:
						index_x = input.shape[0] - abs(x-a[i]) - 2

					if abs(y-a[j]) < input.shape[0]:
						index_y = abs(y-a[j])
					else:
						index_y = input.shape[1] - abs(y-a[j]) - 2


					input_diff[i][j] = input[index_x][index_y]


		return input_diff

	def joint_bilateral_filter(self, input, guidance):
		
		h_s = self.spatial_kernel()
		h_s_1 = h_s[:, :, np.newaxis]
		output = np.zeros(input.shape)

		r = 3*self.sigma_s
		a = np.arange(-r, r+1)

		

		print("bf_out", output[0][0])

		for x in range(0, input.shape[0]):
			for y in range(0, input.shape[1]):

				h_r = self.range_kernel(guidance, x, y)
				h_r_1 = h_r[:, :, np.newaxis]
				h = np.multiply(h_s_1, h_r_1)
				kernel = np.sum(h)
				

				d = np.sum(np.multiply(h, self.ori_img(input, x, y)), axis = 0)
				output[x][y] = np.sum(d, axis = 0)/kernel


		return output

In [5]:
img = cv2.imread('ex.png')
img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
guidance = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
img_1 = cv2.imread('ex_gt_bf.png')
img_2 = cv2.imread('ex_gt_jbf.png')

In [6]:
bf_gt = cv2.cvtColor(img_1, cv2.COLOR_BGR2RGB)	#somewhat blue
jbf_gt = cv2.cvtColor(img_2, cv2.COLOR_BGR2RGB)	

In [14]:
JBF = Joint_bilateral_filter(3, 0.1, border_type='reflect')

In [15]:
bf_out = JBF.joint_bilateral_filter(img_rgb, img_rgb).astype(np.uint8)

bf_out [0. 0. 0.]


In [26]:
for i in range(10):
    for j in range(10):
        print(bf_gt[i][j])

[225 135 121]
[225 135 121]
[225 134 121]
[225 134 119]
[225 134 119]
[225 134 118]
[225 134 117]
[224 133 115]
[224 133 114]
[224 132 113]
[225 134 121]
[225 135 121]
[225 134 120]
[225 134 119]
[225 134 119]
[225 134 118]
[225 133 116]
[225 133 115]
[224 133 114]
[224 132 112]
[225 134 120]
[225 134 120]
[225 134 119]
[225 133 118]
[225 133 117]
[225 133 117]
[225 133 115]
[225 132 114]
[224 132 113]
[224 132 112]
[225 134 118]
[225 133 117]
[225 133 117]
[225 133 116]
[225 133 115]
[225 133 115]
[225 132 114]
[225 132 112]
[224 132 112]
[224 131 110]
[225 133 116]
[225 132 115]
[225 132 115]
[225 132 114]
[225 132 113]
[225 132 113]
[225 132 113]
[225 131 111]
[225 131 110]
[224 131 110]
[225 132 112]
[225 132 113]
[225 132 113]
[225 132 113]
[225 131 112]
[225 131 112]
[225 131 111]
[225 131 110]
[225 131 109]
[225 130 109]
[225 131 111]
[225 131 111]
[225 131 111]
[225 131 111]
[225 131 110]
[225 131 110]
[225 131 110]
[225 131 109]
[225 130 108]
[225 130 108]
[225 131 111]
[225 1

In [25]:
for i in range(10):
    for j in range(10):
        print(bf_out[i][j])

[231 147 114]
[222 125 123]
[220 173 156]
