# Optical-flow-driven (OF-driven) 3D low-pass filtering

For a tomogram, compute the optical flow (OF) between adjacent slices and apply low-pass filtering along each Cartesian direction, guided by the motion vectors. The process is repeated until convergence, i.e., until the energy of the modification of the tomogram falls below a threshold.

In each direction, the i-th slice is warped into an OF-compensated version of itself, which is averaged with the (i+1)-th slice (the compensated slice is initialized with the pixels of the (i+1)-th slice). The resulting slice replaces the (i+1)-th one, so all slices except the first are filtered. The process is then repeated in the opposite direction.

In [None]:
!pip install mrcfile
import mrcfile
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
#%matplotlib notebook
from ipywidgets import *
import cv2
from google.colab.patches import cv2_imshow
import math

In [None]:
# Mount Google Drive.
# Later cells copy the input tomogram from, and write results to,
# drive/Shareddrives/MissingWedge, so the drive must be mounted first.
from google.colab import drive
drive.mount('/content/drive')

## Configuration

In [None]:
tomogram_name = "epfl1_subset1"  # Input file name (without .mrc), copied from the shared drive below
max_iterations = 5 # Maximum number of filtering iterations, passed to filter()
w = 3              # Window size (winsize) used in Farneback
l = 2              # Number of pyramid levels used in Farneback

In [None]:
# Copy the input tomogram from the shared drive into the working directory.
!cp drive/Shareddrives/MissingWedge/tomograms/{tomogram_name}.mrc .

In [None]:
# Border mode used by make_prediction's cv2.remap when a motion vector points
# outside the reference slice: replicate the nearest edge pixel.
ofca_extension_mode = cv2.BORDER_REPLICATE

def make_prediction(reference: np.ndarray, MVs: np.ndarray) -> np.ndarray:
    """Warp ``reference`` with the dense motion field ``MVs``.

    Output pixel (row, col) is sampled (bilinearly) from ``reference`` at
    (row + MVs[row, col, 1], col + MVs[row, col, 0]); samples falling outside
    the image are filled according to ``ofca_extension_mode``.

    Args:
        reference: 2-D image to sample from.
        MVs: (height, width, 2) array of per-pixel (dx, dy) displacements,
            e.g. the flow returned by cv2.calcOpticalFlowFarneback.

    Returns:
        The warped image, with the same height and width as ``MVs``.
    """
    rows, cols = MVs.shape[:2]
    # Integer pixel grid: grid_y holds row indices, grid_x column indices.
    grid_y, grid_x = np.mgrid[0:rows, 0:cols]
    absolute_positions = MVs + np.stack((grid_x, grid_y), axis=-1)
    sampling_map = absolute_positions.astype('float32')  # remap wants float32 maps
    return cv2.remap(reference, sampling_map, None,
                     interpolation=cv2.INTER_LINEAR,
                     borderMode=ofca_extension_mode)

def x_x1(tomogram, w=5, l=3):
  """Filter along the x axis (last axis, ``tomogram[:, :, x]``), x -> x+1.

  For each pair of adjacent slices (x, x+1):
  - estimate the Farneback optical flow from slice x+1 (prev) to slice x (next);
  - warp slice x onto slice x+1 with make_prediction;
  - store the average of the warped slice and slice x+1 as output slice x+1.
  Every pair is computed from the unmodified input slices.

  Args:
    tomogram: 3-D array.
    w: Farneback averaging window size (``winsize``).
    l: Number of Farneback pyramid levels (``levels``).

  Returns:
    A float64 array with the same shape as ``tomogram``. Slice x=0 has no
    predecessor and is copied unfiltered from the input.
  """
  print("x -> x+1", end=' ')
  # Start from a float64 copy of the input (astype always copies), so the
  # slice that cannot be filtered keeps its values instead of being zero.
  # Zeros there would bias that face towards 0 when XYZ_iteration averages
  # the six directional results.
  x_x1_tomogram = tomogram.astype(np.float64)
  for x in range(tomogram.shape[2] - 1):
    print('.', end='')
    slice_x = tomogram[:, :, x].astype(np.float64)
    slice_x1 = tomogram[:, :, x + 1].astype(np.float64)
    flow = cv2.calcOpticalFlowFarneback(prev=slice_x1, next=slice_x, flow=None, pyr_scale=0.5, levels=l, winsize=w, iterations=3, poly_n=5, poly_sigma=1.2, flags=0)
    prediction_x1 = make_prediction(slice_x, flow)
    x_x1_tomogram[:, :, x + 1] = (prediction_x1 + slice_x1) / 2
  print()
  return x_x1_tomogram

def x1_x(tomogram, w=5, l=3):
  """Filter along the x axis (last axis, ``tomogram[:, :, x]``), x+1 -> x.

  For each pair of adjacent slices (x, x+1):
  - estimate the Farneback optical flow from slice x (prev) to slice x+1 (next);
  - warp slice x+1 onto slice x with make_prediction;
  - store the average of the warped slice and slice x as output slice x.
  Every pair is computed from the unmodified input slices.

  Args:
    tomogram: 3-D array.
    w: Farneback averaging window size (``winsize``).
    l: Number of Farneback pyramid levels (``levels``).

  Returns:
    A float64 array with the same shape as ``tomogram``. The last x slice
    has no successor and is copied unfiltered from the input.
  """
  print("x+1 -> x", end=' ')
  # Start from a float64 copy of the input (astype always copies), so the
  # slice that cannot be filtered keeps its values instead of being zero.
  # Zeros there would bias that face towards 0 when XYZ_iteration averages
  # the six directional results.
  x1_x_tomogram = tomogram.astype(np.float64)
  for x in range(tomogram.shape[2] - 1):
    print('.', end='')
    slice_x = tomogram[:, :, x].astype(np.float64)
    slice_x1 = tomogram[:, :, x + 1].astype(np.float64)
    flow = cv2.calcOpticalFlowFarneback(prev=slice_x, next=slice_x1, flow=None, pyr_scale=0.5, levels=l, winsize=w, iterations=3, poly_n=5, poly_sigma=1.2, flags=0)
    prediction_x = make_prediction(slice_x1, flow)
    x1_x_tomogram[:, :, x] = (prediction_x + slice_x) / 2
  print()
  return x1_x_tomogram

def y_y1(tomogram, w=5, l=3):
  """Filter along the y axis (axis 1, ``tomogram[:, y, :]``), y -> y+1.

  For each pair of adjacent slices (y, y+1):
  - estimate the Farneback optical flow from slice y+1 (prev) to slice y (next);
  - warp slice y onto slice y+1 with make_prediction;
  - store the average of the warped slice and slice y+1 as output slice y+1.
  Every pair is computed from the unmodified input slices.

  Args:
    tomogram: 3-D array.
    w: Farneback averaging window size (``winsize``).
    l: Number of Farneback pyramid levels (``levels``).

  Returns:
    A float64 array with the same shape as ``tomogram``. Slice y=0 has no
    predecessor and is copied unfiltered from the input.
  """
  print("y -> y+1", end=' ')
  # Start from a float64 copy of the input (astype always copies), so the
  # slice that cannot be filtered keeps its values instead of being zero.
  # Zeros there would bias that face towards 0 when XYZ_iteration averages
  # the six directional results.
  y_y1_tomogram = tomogram.astype(np.float64)
  for y in range(tomogram.shape[1] - 1):
    print('.', end='')
    slice_y = tomogram[:, y, :].astype(np.float64)
    slice_y1 = tomogram[:, y + 1, :].astype(np.float64)
    flow = cv2.calcOpticalFlowFarneback(prev=slice_y1, next=slice_y, flow=None, pyr_scale=0.5, levels=l, winsize=w, iterations=3, poly_n=5, poly_sigma=1.2, flags=0)
    prediction_y1 = make_prediction(slice_y, flow)
    y_y1_tomogram[:, y + 1, :] = (prediction_y1 + slice_y1) / 2
  print()
  return y_y1_tomogram

def y1_y(tomogram, w=5, l=3):
  """Filter along the y axis (axis 1, ``tomogram[:, y, :]``), y+1 -> y.

  For each pair of adjacent slices (y, y+1):
  - estimate the Farneback optical flow from slice y (prev) to slice y+1 (next);
  - warp slice y+1 onto slice y with make_prediction;
  - store the average of the warped slice and slice y as output slice y.
  Every pair is computed from the unmodified input slices.

  Args:
    tomogram: 3-D array.
    w: Farneback averaging window size (``winsize``).
    l: Number of Farneback pyramid levels (``levels``).

  Returns:
    A float64 array with the same shape as ``tomogram``. The last y slice
    has no successor and is copied unfiltered from the input.
  """
  print("y+1 -> y", end=' ')
  # Start from a float64 copy of the input (astype always copies), so the
  # slice that cannot be filtered keeps its values instead of being zero.
  # Zeros there would bias that face towards 0 when XYZ_iteration averages
  # the six directional results.
  y1_y_tomogram = tomogram.astype(np.float64)
  for y in range(tomogram.shape[1] - 1):
    print('.', end='')
    slice_y = tomogram[:, y, :].astype(np.float64)
    slice_y1 = tomogram[:, y + 1, :].astype(np.float64)
    flow = cv2.calcOpticalFlowFarneback(prev=slice_y, next=slice_y1, flow=None, pyr_scale=0.5, levels=l, winsize=w, iterations=3, poly_n=5, poly_sigma=1.2, flags=0)
    prediction_y = make_prediction(slice_y1, flow)
    y1_y_tomogram[:, y, :] = (prediction_y + slice_y) / 2
  print()
  return y1_y_tomogram

def z_z1(tomogram, w=5, l=3):
  """Filter along the z axis (axis 0, ``tomogram[z, :, :]``), z -> z+1.

  For each pair of adjacent slices (z, z+1):
  - estimate the Farneback optical flow from slice z+1 (prev) to slice z (next);
  - warp slice z onto slice z+1 with make_prediction;
  - store the average of the warped slice and slice z+1 as output slice z+1.
  Every pair is computed from the unmodified input slices.

  Args:
    tomogram: 3-D array.
    w: Farneback averaging window size (``winsize``).
    l: Number of Farneback pyramid levels (``levels``).

  Returns:
    A float64 array with the same shape as ``tomogram``. Slice z=0 has no
    predecessor and is copied unfiltered from the input.
  """
  print("z -> z+1", end=' ')
  # Start from a float64 copy of the input (astype always copies), so the
  # slice that cannot be filtered keeps its values instead of being zero.
  # Zeros there would bias that face towards 0 when XYZ_iteration averages
  # the six directional results.
  z_z1_tomogram = tomogram.astype(np.float64)
  for z in range(tomogram.shape[0] - 1):
    print('.', end='')
    slice_z = tomogram[z, :, :].astype(np.float64)
    slice_z1 = tomogram[z + 1, :, :].astype(np.float64)
    flow = cv2.calcOpticalFlowFarneback(prev=slice_z1, next=slice_z, flow=None, pyr_scale=0.5, levels=l, winsize=w, iterations=3, poly_n=5, poly_sigma=1.2, flags=0)
    prediction_z1 = make_prediction(slice_z, flow)
    z_z1_tomogram[z + 1, :, :] = (prediction_z1 + slice_z1) / 2
  print()
  return z_z1_tomogram

def z1_z(tomogram, w=5, l=3):
  """Filter along the z axis (axis 0, ``tomogram[z, :, :]``), z+1 -> z.

  For each pair of adjacent slices (z, z+1):
  - estimate the Farneback optical flow from slice z (prev) to slice z+1 (next);
  - warp slice z+1 onto slice z with make_prediction;
  - store the average of the warped slice and slice z as output slice z.
  Every pair is computed from the unmodified input slices.

  Args:
    tomogram: 3-D array.
    w: Farneback averaging window size (``winsize``).
    l: Number of Farneback pyramid levels (``levels``).

  Returns:
    A float64 array with the same shape as ``tomogram``. The last z slice
    has no successor and is copied unfiltered from the input.
  """
  print("z+1 -> z", end=' ')
  # Start from a float64 copy of the input array itself (not its raw
  # ``.data`` buffer). The slice that cannot be filtered then keeps its
  # values instead of being zero, which would bias that face towards 0
  # when XYZ_iteration averages the six directional results.
  z1_z_tomogram = tomogram.astype(np.float64)
  for z in range(tomogram.shape[0] - 1):
    print('.', end='')
    slice_z = tomogram[z, :, :].astype(np.float64)
    slice_z1 = tomogram[z + 1, :, :].astype(np.float64)
    flow = cv2.calcOpticalFlowFarneback(prev=slice_z, next=slice_z1, flow=None, pyr_scale=0.5, levels=l, winsize=w, iterations=3, poly_n=5, poly_sigma=1.2, flags=0)
    prediction_z = make_prediction(slice_z1, flow)
    # Write only slice z. The former ``[z:, :, :]`` broadcast the average
    # into every later slice, which cost quadratic work and left the last
    # slice holding the (n-2)-th average.
    z1_z_tomogram[z, :, :] = (prediction_z + slice_z) / 2
  print()
  return z1_z_tomogram

def XYZ_iteration(tomogram, w=5, l=3):
  """Run one full filtering pass over ``tomogram``.

  Applies the six directional filters (x -> x+1, x+1 -> x, y -> y+1,
  y+1 -> y, z -> z+1, z+1 -> z), in that order, to the same input. Returns
  their voxel-wise mean as a float64 array.

  Args:
    tomogram: 3-D array to filter.
    w: Farneback window size forwarded to every directional filter.
    l: Number of Farneback pyramid levels forwarded to every directional filter.
  """
  directional_filters = (x_x1, x1_x, y_y1, y1_y, z_z1, z1_z)
  accumulated = directional_filters[0](tomogram, w, l).astype(np.float64)
  for directional_filter in directional_filters[1:]:
    accumulated = accumulated + directional_filter(tomogram, w, l)
  return accumulated / 6

def average_energy(tomogram):
  """Return the mean absolute value of ``tomogram``'s elements.

  Despite the name, this is an L1 measure (sum of |v| divided by the number
  of elements), not a mean of squares. The computation is done in float64,
  so unsigned or small integer inputs do not overflow.
  """
  magnitudes = np.abs(tomogram.astype(np.float64))
  total = np.float64(magnitudes.sum())
  return total / tomogram.size

def R_ener(tomogram):
  """Return the square root of the total energy of ``tomogram``.

  The energy is taken as the sum of squared element values, computed in
  float64 so integer inputs (e.g. uint8) cannot overflow. The result is
  therefore the Euclidean (L2) norm of the flattened array.

  The previous body called an undefined ``energy`` helper and raised
  NameError on every call.
  """
  values = tomogram.astype(np.float64)
  return math.sqrt(np.sum(values * values))

def write(image, fn):
  """Save ``image`` to the file ``fn`` with cv2.imwrite.

  Returns:
    True if OpenCV reports the image was written, False otherwise.
    cv2.imwrite reports failure (e.g. a missing directory or an unsupported
    extension) only through this return value, so it is passed back to the
    caller instead of being discarded. Callers that ignore it are unaffected.
  """
  return cv2.imwrite(fn, image)

# This version was observed never to stop: the energy does not decrease
# asymptotically towards 0 or any other value. Pass max_iterations to bound it.
def filter2(tomogram, threshold=0.1, max_iterations=None, w=5, l=3, z=100):
  """Iterate XYZ_iteration until the per-pass change drops to ``threshold``.

  Axis-0 slice ``z`` is shown and saved as a PNG on the shared drive:
  - before filtering as ``ejemplo_000.png``;
  - after each pass (debug mode only) as ``ejemplo_<i>.png``.

  Args:
    tomogram: 3-D array to filter.
    threshold: Stop once average_energy(tomogram - filtered_tomogram), the
      mean absolute change per voxel, is <= threshold.
    max_iterations: Optional cap on the number of passes (None = no cap).
    w: Farneback window size.
    l: Number of Farneback pyramid levels.
    z: Index of the axis-0 slice to display and save.

  Returns:
    The tomogram produced by the last pass. At least one pass is always run.
  """
  normalized_image = cv2.normalize(tomogram[z, :, :], None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)
  write(normalized_image, f"drive/Shareddrives/MissingWedge/ejemplo_{0:03d}.png")
  cv2_imshow(normalized_image)
  i = 1
  while True:
    filtered_tomogram = XYZ_iteration(tomogram, w, l)
    # Computed unconditionally: the stopping test needs it even when debug
    # output is disabled (python -O).
    energy_decrease = average_energy(tomogram - filtered_tomogram)
    if __debug__:
      normalized_image = cv2.normalize(filtered_tomogram[z, :, :], None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)
      cv2_imshow(normalized_image)
      write(normalized_image, f"drive/Shareddrives/MissingWedge/ejemplo_{i:03d}.png")
      print("tomogram average energy =", average_energy(tomogram))
      print("filtered tomogram energy =", average_energy(filtered_tomogram))
      print("energy decrease =", energy_decrease)
    if energy_decrease <= threshold:
      return filtered_tomogram
    if max_iterations is not None and i >= max_iterations:
      return filtered_tomogram
    tomogram = filtered_tomogram
    i += 1

# This version stops as soon as the energy of the difference between the
# current tomogram and its filtered version starts to increase.
def filter(tomogram, epsilon=0.1, max_iterations=20, w=5, l=3, z=100):
  """Apply XYZ_iteration repeatedly until convergence, divergence or a cap.

  NOTE: the name shadows the builtin ``filter``. It is kept because
  notebook cells call it by this name.

  Each pass filters the whole tomogram with XYZ_iteration and measures the
  change as average_energy(tomogram - filtered_tomogram), the mean absolute
  change per voxel. The loop stops when:
    * ``max_iterations`` passes have run: the last filtered tomogram is
      returned;
    * the change drops below ``epsilon``: the tomogram from before that
      pass is returned;
    * the change is not smaller than the previous pass's change: the
      tomogram from before that (non-improving) pass is returned.

  Args:
    tomogram: 3-D array to filter.
    epsilon: Convergence threshold on the mean absolute change per voxel.
    max_iterations: Maximum number of passes. At least one pass always runs.
    w: Farneback window size.
    l: Number of Farneback pyramid levels.
    z: Index of the axis-0 slice displayed in debug mode.

  Returns:
    The filtered tomogram selected by the stopping rules above.
  """
  if __debug__:
    print(f"epsilon={epsilon}, max_iterations={max_iterations}, w={w}, l={l}")
    img = tomogram[z, :, :]
    print(f"min={np.min(img)} max={np.max(img)}")
    cv2_imshow(cv2.normalize(img, None, 0, 255, cv2.NORM_MINMAX))
  prev_avg_energy_difference = float("inf")
  i = 1  # index of the pass being run (1-based)
  while True:
    if __debug__:
      print("iteration =", i)
    filtered_tomogram = XYZ_iteration(tomogram, w, l)
    if __debug__:
      img = filtered_tomogram[z, :, :]
      print(f"min={np.min(img)} max={np.max(img)}")
      cv2_imshow(cv2.normalize(img, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8))
    avg_energy_difference = average_energy(tomogram - filtered_tomogram)
    if __debug__:
      print("tomogram average energy =", average_energy(tomogram))
      print("filtered_tomogram average energy =", average_energy(filtered_tomogram))
      print("average energy (tomogram - filtered_tomogram) =", avg_energy_difference)
    # Stop after exactly max_iterations passes. The former ``i > max_iterations``
    # ran one extra pass.
    if i >= max_iterations:
      if __debug__:
        print("Maximum number of iterations reached")
      return filtered_tomogram
    if avg_energy_difference < epsilon:
      if __debug__:
        print("Minimum epsilon reached")
      return tomogram
    if avg_energy_difference >= prev_avg_energy_difference:
      if __debug__:
        print("Differences rising ... finishing")
      return tomogram
    prev_avg_energy_difference = avg_energy_difference
    tomogram = filtered_tomogram
    i += 1


In [None]:
# Open the input tomogram copied above and report its voxel dtype and shape.
# NOTE(review): the file is never closed here; it must stay open because the
# next cell reads tomogram_MRC.data.
tomogram_MRC = mrcfile.open(f'{tomogram_name}.mrc')
print(tomogram_MRC.data.dtype, tomogram_MRC.data.shape)

In [None]:
# Working copy of the volume with 8-bit voxels (astype already copies, so the
# extra np.copy is redundant but harmless).
# NOTE(review): astype(np.uint8) wraps/truncates values outside 0..255 if the
# MRC data is not already uint8 — confirm against the dtype printed above.
tomogram = np.copy(tomogram_MRC.data.astype(np.uint8))
# Burn a saturated 50x50 square into axis-0 slice 100, presumably as a visual
# marker to follow the filter's effect. It is filtered and saved along with
# the data.
tomogram[100, 10:60, 10:60] = 255

In [None]:
# Run the OF-driven filter with the parameters from the configuration cell.
filtered_tomogram = filter(tomogram.astype(np.uint8), max_iterations=max_iterations, w=w, l=l)

In [None]:
# Save the filtered volume as float32 MRC on the shared drive. The file name
# records the parameters used. (The former trailing ``mrc.data`` expression
# was a no-op and has been removed.)
with mrcfile.new(f'drive/Shareddrives/MissingWedge/{tomogram_name}_filtered__max_iterations={max_iterations}__w={w}__l={l}.mrc', overwrite=True) as mrc:
  mrc.set_data(filtered_tomogram.astype(np.float32))

In [None]:
# Reopen the saved filtered volume and the input volume for visual comparison.
filtered_tomogram_MRC = mrcfile.open(f"drive/Shareddrives/MissingWedge/{tomogram_name}_filtered__max_iterations={max_iterations}__w={w}__l={l}.mrc")
# Use the configured tomogram_name instead of a hard-coded file name, so the
# comparison always uses the same input that was filtered.
tomogram_MRC = mrcfile.open(f'{tomogram_name}.mrc')

In [None]:
from ipywidgets import IntSlider, interactive

def g(z=0):
  """Show axis-0 slice ``z`` of the original and of the filtered tomogram."""
  # NOTE(review): astype(np.uint8) mirrors the preprocessing cell; it wraps
  # values outside 0..255 if the original data is not 8-bit.
  cv2_imshow(cv2.normalize(tomogram_MRC.data[z, :, :].astype(np.uint8), None, 0, 255, cv2.NORM_MINMAX))
  cv2_imshow(cv2.normalize(filtered_tomogram_MRC.data[z, :, :], None, 0, 255, cv2.NORM_MINMAX))

# A bare integer abbreviation (z=100) makes ipywidgets build a slider from
# -100 to 300, which can index outside the volume. Bound the slider to the
# valid slice range and keep 100 as the starting slice when it exists.
num_slices = filtered_tomogram_MRC.data.shape[0]
interactive_plot = interactive(g, z=IntSlider(value=min(100, num_slices - 1), min=0, max=num_slices - 1))
interactive_plot

In [None]:
# Block forever, presumably to keep the Colab runtime from going idle.
# Sleeping blocks just the same as ``pass`` but does not spin a CPU core at 100%.
import time
while True:
  time.sleep(60)