In [None]:
from math import comb

import cv2
import matplotlib.pyplot as plt
import numpy as np
import tqdm

In [None]:
def to_Y(image, N=1):
    """Build the neighbour matrix Y used by the EM weighted least-squares step.

    Every (2N+1)x(2N+1) window of ``image`` becomes one row of Y: the window
    flattened in row-major order with its centre pixel removed. Rows follow
    the row-major order of window positions. If ``image`` was padded by N
    pixels on every side, there is therefore exactly one row per pixel of the
    unpadded image.

    Parameters
    ----------
    image : 2-D array
        Image (typically already border-padded) to extract windows from.
    N : int, optional
        Neighbourhood radius; the window side is ``2 * N + 1``. Default 1 (3x3).

    Returns
    -------
    numpy.ndarray
        Array of shape ``(num_windows, (2N+1)**2 - 1)`` with the dtype of ``image``.

    Raises
    ------
    ValueError
        If ``N < 1`` or the image is smaller than one window.
    """
    if N < 1:
        raise ValueError(f"N must be >= 1, got {N}")
    size = 2 * N + 1
    windows = np.lib.stride_tricks.sliding_window_view(image, (size, size))
    # One row per window position (reshape copies, since the view is strided).
    flat = windows.reshape(-1, size * size)
    # Drop the centre pixel: it is the value being predicted, not a predictor.
    return np.delete(flat, (size * size) // 2, axis=1)

In [None]:
def generate_binomial_filter(n):
    """Return a normalised n x n 2-D binomial smoothing kernel.

    The kernel is the outer product of row ``n - 1`` of Pascal's triangle with
    itself, divided by its sum so the weights add up to 1. In this notebook it
    is used to smooth the residual (error) map in the EM E-step.

    Parameters
    ----------
    n : int
        Kernel side length (must be >= 1).

    Returns
    -------
    numpy.ndarray
        Float array of shape ``(n, n)`` summing to 1.

    Raises
    ------
    ValueError
        If ``n < 1``. The kernel would be empty and normalisation would divide by zero.
    """
    if n < 1:
        raise ValueError(f"kernel size must be >= 1, got {n}")
    # Exact integer binomial coefficients C(n-1, k), k = 0..n-1.
    binomial_coeffs = np.array([comb(n - 1, k) for k in range(n)], dtype=np.float64)
    kernel = np.outer(binomial_coeffs, binomial_coeffs)
    return kernel / kernel.sum()

In [None]:
def em_algorithm(image, N=1, max_iterations=100, seed=None, tol=1e-6):
    """Estimate a per-pixel probability map of linear neighbour correlation via EM.

    Each pixel is modelled as one of two things:
    * a linear combination of its 3x3 neighbours, with weights ``alpha`` and a
      Gaussian residual of std ``sigma``;
    * an outlier with uniform density ``p0 = 1 / (max - min)`` over the image range.

    E-step: compute the absolute prediction residual with the current
    ``alpha`` and smooth it with a 3x3 binomial kernel. Then compute its
    Gaussian likelihood ``P`` and the posterior weight ``w = P / (P + p0)``.

    M-step: update ``sigma`` as the w-weighted RMS residual, and ``alpha`` by
    w-weighted least squares on the neighbour matrix from ``to_Y``.

    Parameters
    ----------
    image : 2-D array
        Grayscale image. It is divided by 255, so 0..255-range input is assumed.
    N : int, optional
        Neighbourhood radius. Only ``N=1`` is supported, because ``to_Y`` is
        called with its default 3x3 window.
    max_iterations : int, optional
        Upper bound on EM iterations.
    seed : int or None, optional
        Seed for the random initial ``alpha``. With ``None`` the global
        ``np.random`` state is used, so an earlier ``np.random.seed`` still applies.
    tol : float, optional
        Stop when the Frobenius norm of the ``alpha`` update is below this.

    Returns
    -------
    numpy.ndarray
        ``P``: float map with the same shape as ``image``. It holds the Gaussian
        likelihood of each pixel's smoothed residual from the last E-step.
        It is all zeros if ``max_iterations == 0``.

    Raises
    ------
    TypeError
        If ``image`` is None (e.g. ``cv2.imread`` failed).
    ValueError
        If ``N != 1``, the image is not 2-D, or the image is constant
        (``p0`` would be undefined).
    """
    if N != 1:
        # to_Y builds 3x3 neighbourhoods; other N would fail reshaping the
        # 8 estimated weights into a (2N+1)x(2N+1) matrix.
        raise ValueError(f"only N=1 is supported, got N={N}")
    if image is None:
        raise TypeError("image is None - check that cv2.imread found the file")

    image = np.asarray(image, dtype=np.float64) / 255.0
    if image.ndim != 2:
        raise ValueError(f"expected a 2-D grayscale image, got shape {image.shape}")

    value_range = image.max() - image.min()
    if value_range == 0:
        raise ValueError("image is constant; outlier density 1/(max - min) is undefined")
    p0 = 1.0 / value_range  # uniform (outlier) density over the observed value range

    # Pad so every pixel has a full neighbourhood; same reflection mode as
    # cv2.filter2D's default border, so Y and the E-step residual agree.
    padded = cv2.copyMakeBorder(image, N, N, N, N, cv2.BORDER_REFLECT_101)
    Y = to_Y(padded)        # one row of neighbours per pixel of `image`
    y_s = image.ravel()     # the pixel values being predicted, same row order
    h = generate_binomial_filter(3)

    k = 2 * N + 1
    if seed is None:
        alpha = np.random.rand(k, k)
    else:
        alpha = np.random.default_rng(seed).random((k, k))
    alpha[N, N] = 0.0       # a pixel must not predict itself
    sigma = 0.005           # initial spread of the Gaussian residual model
    P = np.zeros_like(image)

    for _ in tqdm.tqdm(range(max_iterations)):
        # E-step (vectorised over all pixels).
        R = np.abs(cv2.filter2D(image, -1, alpha) - image)
        R = cv2.filter2D(R, -1, h)
        P = np.exp(-(R ** 2) / (2 * sigma ** 2)) / (sigma * np.sqrt(2 * np.pi))
        w = P / (P + p0)

        w_sum = w.sum()
        if w_sum == 0:
            # Every pixel is treated as an outlier; the M-step would be NaN/singular.
            break

        # M-step.
        # Floor sigma so a perfect fit cannot produce a 0/0 in the next E-step.
        sigma = max(np.sqrt((w * R ** 2).sum() / w_sum), np.finfo(np.float64).eps)
        Yw = Y * w.ravel()[:, None]
        # Weighted normal equations (Y^T W Y) a = Y^T W y, solved without explicit inverse.
        alpha_new = np.linalg.solve(Yw.T @ Y, Yw.T @ y_s)
        # Re-insert the zero centre weight to get back a k x k kernel.
        alpha_matrix = np.insert(alpha_new, alpha_new.size // 2, 0.0).reshape(k, k)

        if np.linalg.norm(alpha_matrix - alpha) < tol:
            break
        alpha = alpha_matrix

    return P

In [None]:
# cv2.imread returns None instead of raising when the file is missing/unreadable.
input_image = cv2.imread('image2.png', cv2.IMREAD_GRAYSCALE)
if input_image is None:
    raise FileNotFoundError("could not read 'image2.png'")
P = em_algorithm(input_image)

In [None]:
plt.imshow(P, cmap="gray")
plt.title("Probability map P (E-step likelihood)")
plt.show()

In [None]:
# Centred log-magnitude spectra: input image vs. probability map P.
fig, axes = plt.subplots(1, 2, figsize=(12, 6))
for ax, (title, data) in zip(axes, [("Input image", input_image), ("Probability map P", P)]):
    spectrum_shift = np.fft.fftshift(np.fft.fft2(data))
    # eps keeps exactly-zero bins from becoming log(0) = -inf.
    ax.imshow(20 * np.log(np.abs(spectrum_shift) + np.finfo(np.float64).eps), cmap='gray')
    ax.set_title(title)
    ax.axis('off')
plt.show()

In [None]:
# cv2.imread returns None instead of raising when the file is missing/unreadable.
input_image = cv2.imread('image1.jpg', cv2.IMREAD_GRAYSCALE)
if input_image is None:
    raise FileNotFoundError("could not read 'image1.jpg'")
P = em_algorithm(input_image)

In [None]:
# Centred log-magnitude spectra (orthonormal FFT scaling): input image vs. P.
fig, axes = plt.subplots(1, 2, figsize=(12, 6))
for ax, (title, data) in zip(axes, [("Input image", input_image), ("Probability map P", P)]):
    spectrum_shift = np.fft.fftshift(np.fft.fft2(data, norm='ortho'))
    # eps keeps exactly-zero bins from becoming log(0) = -inf.
    ax.imshow(20 * np.log(np.abs(spectrum_shift) + np.finfo(np.float64).eps), cmap='gray')
    ax.set_title(title)
    ax.axis('off')
plt.show()