# Context

Anatomical structures affect the optical flow detected during slow-wave events. One way to remove this influence is to remove, from each frame, the average activation pattern that is expected for the frame's total activation level. What remains is what is specific to the given slow-wave event.

To achieve this, one needs the expected value of every pixel $\omega_{ij}$ given the overall level $k$ of the frame $\omega$:

$$\mathbf{E}(\omega_{ij}\mid \mathbf{E}(\omega)=k)$$


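In the computation below, the overall level of a frame is summarized by its median rather than its mean. The median is more robust to small peaks, which would otherwise affect the overall predicted image and potentially introduce artifacts.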


# Imports

In [None]:
from skimage import io
import skimage
import numpy as np
import matplotlib.pyplot as plt
from scipy.ndimage import gaussian_filter, uniform_filter
import pickle

In [None]:
import imageio
from pathlib import Path
from matplotlib.pyplot import show
from argparse import ArgumentParser

from pyoptflow import HornSchunck, getimgfiles
from pyoptflow.plots import compareGraphs

In [None]:
from PIL import Image
import os
from scipy.signal import argrelextrema
from skimage import exposure

In [None]:
import matplotlib
import matplotlib.animation
from IPython.display import HTML
# Use an effectively unlimited embed size for the animations rendered with to_jshtml() further below.
matplotlib.rcParams['animation.embed_limit'] = 2**128

In [None]:
# Scratch check: clipping to [0, 255] before casting to uint8 turns 300 into 255.
np.array(np.clip([300],0,255), dtype=np.uint8)

### Import our custom utility methods

In [None]:
import sys
# Enable IPython's autoreload extension (mode 2) so edits to the local utils modules are picked up.
%reload_ext autoreload
%autoreload 2
# Make the parent directory importable so the local `utils` package can be imported below.
sys.path.append('..')

from utils.visualization_tools import *
import utils.visualization_tools
from utils.data_transformations import *
import utils.data_transformations
from utils.diverse import *
import utils.diverse

The following helper functions are available in each module:

In [None]:
# Print the helpers available in utils.diverse.
print_module_methods(utils.diverse)

In [None]:
# Print the helpers available in utils.visualization_tools.
print_module_methods(utils.visualization_tools)

In [None]:
# Print the helpers available in utils.data_transformations.
print_module_methods(utils.data_transformations)

# Methods for pipeline

In [None]:
import numpy as np
from scipy.interpolate import interp1d

def value_range_of_frame_means(filepaths, mean="average"):
    """Return the range of per-frame summary values over several tif stacks.

    Each file is loaded and passed to ``framewise_difference`` together with
    its pixelwise mean over axis 0 (presumably time; the exact transformation
    is defined in utils). Every resulting frame is then reduced to a single
    number, its mean or median. The smallest and largest of these numbers
    over all files are returned. Only one file is held in memory at a time,
    which keeps memory requirements limited at the cost of speed.

    Args:
        filepaths: List of filepaths to tif files.
        mean: "median" to summarize each frame by its median; any other
            value uses the arithmetic mean.

    Returns:
        (min, max) of the per-frame summary values. For an empty
        ``filepaths`` this is ``(inf, -inf)``.
    """
    reduce_frame = np.median if mean == "median" else np.mean

    prelim_min = float("inf")
    prelim_max = -float("inf")
    for filepath in filepaths:
        print(".", end="")
        frames = np.array(skimage.io.imread(os.fspath(filepath)), dtype=np.double)
        print(".", end="")
        pixelwise_mean = np.mean(frames, axis=0)
        print(".", end="")
        frames = framewise_difference(frames, pixelwise_mean, bigdata=True)
        del pixelwise_mean

        # Reduce every frame once and reuse the result for both extremes.
        frame_values = reduce_frame(frames, axis=(1, 2))
        prelim_min = min(prelim_min, np.min(frame_values))
        prelim_max = max(prelim_max, np.max(frame_values))

        # Release this file's data before the next file is read to keep the
        # peak memory at roughly one file.
        del frames, frame_values
    return prelim_min, prelim_max

def expected_images(filepaths, min_val, max_val, bins=100, mean="average"):
    """Compute the expected (average) image for each bin of frame brightness.

    Each file is loaded and passed to ``framewise_difference`` together with
    its pixelwise mean over axis 0 (presumably time; defined in utils).
    Every resulting frame is summarized by its mean or median. That value is
    assigned to one of ``bins`` equally wide bins between ``min_val`` and
    ``max_val``, and the frame is added to that bin's running sum.

    Args:
        filepaths: List of filepaths to tif files.
        min_val: Minimum per-frame value, e.g. from value_range_of_frame_means.
        max_val: Maximum per-frame value, e.g. from value_range_of_frame_means.
        bins: Number of bins between min_val and max_val for which the
            expected image is calculated.
        mean: "median" to summarize each frame by its median; any other
            value uses the arithmetic mean. Should match the setting used
            to obtain min_val / max_val.

    Returns:
        output_tensor: Array of shape (bins, height, width). Entry k is the
            pixelwise average of all frames that fall into bin k. Bins that
            received no frame are NaN.
        bin_edges: ``bins + 1`` bin edges in the units of the per-frame
            values. Bin k covers [bin_edges[k], bin_edges[k + 1]); the last
            bin also includes max_val.
        n_per_bin: Number of frames that fell into each bin.

    Raises:
        ValueError: If max_val <= min_val or filepaths is empty.
    """
    if max_val <= min_val:
        raise ValueError("max_val must be larger than min_val")
    reduce_frame = np.median if mean == "median" else np.mean

    n_per_bin = np.zeros(shape=[bins])
    bin_edges = np.linspace(min_val, max_val, bins + 1)
    output_tensor = None
    for filepath in filepaths:
        print(".", end="")
        frames = np.array(skimage.io.imread(os.fspath(filepath)), dtype=np.double)
        print(".", end="")
        pixelwise_mean = np.mean(frames, axis=0)
        print(".", end="")
        frames = framewise_difference(frames, pixelwise_mean, bigdata=True)
        del pixelwise_mean

        if output_tensor is None:
            output_tensor = np.zeros(
                shape=[bins, frames.shape[1], frames.shape[2]], dtype=np.double
            )

        # Reduce the frames exactly like value_range_of_frame_means does, so
        # the values are bit-identical to the ones min_val/max_val came from.
        frame_values = reduce_frame(frames, axis=(1, 2))

        # Report values outside the expected range (diagnostic only).
        for value in frame_values[(frame_values < min_val) | (frame_values > max_val)]:
            print(value)

        # Map values to bin indices. Clipping places max_val itself in the
        # last bin (it would otherwise land on index `bins` and be dropped).
        # Out-of-range values go to the edge bins instead of wrapping around
        # (negative index) or raising IndexError.
        bin_idx = ((frame_values - min_val) / (max_val - min_val) * bins).astype(int)
        np.clip(bin_idx, 0, bins - 1, out=bin_idx)

        n_per_bin += np.bincount(bin_idx, minlength=bins)
        for i, (b, frame) in enumerate(zip(bin_idx, frames)):
            if (i % 500) == 0:
                print("*", end="")
            output_tensor[b] += frame

        # Free this file's data before reading the next one.
        del frames, frame_values, bin_idx

    if output_tensor is None:
        raise ValueError("filepaths must contain at least one file")

    # Empty bins yield 0/0 = NaN; downstream cells detect and replace such frames.
    with np.errstate(invalid="ignore", divide="ignore"):
        output_tensor = output_tensor / n_per_bin[:, np.newaxis, np.newaxis]

    return output_tensor, bin_edges, n_per_bin

def interpolate_tensor(tensor, size, axis=0, smoothing=None):
    """Resize ``tensor`` along one axis by linear interpolation.

    Args:
        tensor: N-dimensional numpy array (3-d in this notebook).
        size: Desired output length along ``axis``.
        axis: Axis along which the tensor is resized. Negative values count
            from the end.
        smoothing: Sigma of a Gaussian applied along ``axis`` only, before
            resizing. ``None`` or 0 disables smoothing.

    Returns:
        Array with the same shape as ``tensor`` except ``size`` along
        ``axis``. The first and last samples along ``axis`` are preserved;
        samples in between are linearly interpolated.
    """
    if smoothing:
        # Per-axis sigma that smooths only along `axis`. gaussian_filter's
        # third positional parameter is the derivative *order*, not a
        # per-axis sigma, so the sigma must be given as a sequence.
        sigma = [0] * tensor.ndim
        sigma[axis] = smoothing
        tensor = gaussian_filter(tensor, sigma, mode="mirror")
    n = tensor.shape[axis]
    # Old and new sample positions span the same interval, so the first and
    # last samples map onto each other.
    x = np.linspace(0, n, n)
    x_new = np.linspace(0, n, size)
    return interp1d(x, tensor, axis=axis)(x_new)

# Load filepaths and mask

In [None]:
from pathlib import Path
# Source data lives in <parent of the working directory>/datasets/source_data.
source_folder = os.path.join(Path.cwd().parent, "datasets/source_data")

In [None]:
# The two tif recordings that are processed together below.
files = [
    os.path.join(source_folder, tif_name)
    for tif_name in ("runstart16_X1.tif", "runstart16_X2.tif")
]

In [None]:
# Optional boolean mask: True where the (first channel of the) mask image is non-zero.
# Stays None when the mask file does not exist.
mask = None
try:
    mask_image = np.array(Image.open(os.path.join(source_folder, "mask_runstart16_X.png")))
except FileNotFoundError:
    print("Mask not found")
else:
    # Grayscale PNGs load as 2-D arrays, RGB(A) PNGs as 3-D; use the first channel then.
    if mask_image.ndim == 3:
        mask_image = mask_image[:, :, 0]
    mask = mask_image != 0

In [None]:
# Show the mask only if it was loaded; plt.imshow(None) would raise.
if mask is not None:
    plt.imshow(mask)

# Do the numerics

In [None]:
# 1) Range of the per-frame medians across all recordings.
min_val, max_val = value_range_of_frame_means(files, mean="median")
# 2) Expected image for each of 25 median bins, plus bin bookkeeping.
(output_tensor,
 upper_bin_boundaries,
 n_per_bin) = expected_images(files, min_val, max_val, bins=25, mean="median")
# 3) Upsample the bin axis to 100 frames.
large = interpolate_tensor(output_tensor, 100)

In [None]:
fig, ax = plt.subplots(1, 2, figsize=(20, 4))

# Left: how many frames were averaged into each expected image.
ax[0].bar(np.arange(len(n_per_bin)), n_per_bin, .5)
ax[0].set_xlabel("Index of expected image (bin)")
ax[0].set_ylabel("Frames per bin")
ax[0].set_title("Frames per bin")

# Right: median of each interpolated expected image. The median matches the
# y label and the statistic used for binning. `large` has been interpolated
# to 100 frames, so the x axis is the interpolated index, not the bin index.
ax[1].plot(np.median(large, axis=(1, 2)))
ax[1].set_xlabel("Index of interpolated expected image")
ax[1].set_ylabel("Median of frame")
ax[1].set_title("Mapping to indices")

In [None]:
fig, ax = plt.subplots(1, 3, figsize=(20, 4))

# Each ROI is a run of rows at a single column of the expected images;
# every panel plots those pixels' values across the bins.
roi_panels = (
    (output_tensor[:, 50:60, 120], "Frontal ROI in left hemisphere"),
    (output_tensor[:, 100:150, 100], "Frontoparietal ROI in left hemisphere"),
    (output_tensor[:, 180:220, 130], "Occipital ROI in left hemisphere"),
)
for panel, (roi_trace, roi_title) in zip(ax, roi_panels):
    panel.plot(roi_trace)
    panel.set_title(roi_title)
    panel.set_xlabel("Index of expected image (bin)")
    panel.set_ylabel("Pixel value")

In [None]:
# Fit the per-pixel curves over the bins with pixelwise_polynomial from utils.
# NOTE(review): the meaning of the arguments 2 and 6 is defined in utils — confirm.
polynomial_fitted = pixelwise_polynomial(output_tensor, 2, 6)
polynomial_fitted = interpolate_tensor(polynomial_fitted, 100)
# Name the file after the actual number of bins. A hard-coded "10_bins" name
# mislabels (and may overwrite) a real 10-bin result with the 25-bin data.
np.save(f"{output_tensor.shape[0]}_bins_expected_images_polynomial_median.npy", polynomial_fitted)

In [None]:
%%capture
# Animate the normalized polynomial-fitted expected images; displayed with HTML(ani) in the next cell.
# NOTE(review): polynomial_fitted is interpolated to 100 frames above, but n_frames=99 — confirm intent.
ani = show_video(normalize(polynomial_fitted),n_frames=99)

In [None]:
# Display the animation created in the previous cell.
HTML(ani)

In [None]:
# Save the polynomial-fitted expected images; "25_bins" matches bins=25 passed to expected_images.
np.save("25_bins_expected_images_polynomial_median.npy",polynomial_fitted)

# Variation for better optical flow

In [None]:
# Frames of `large` that contain any NaN (e.g. from bins without frames) are overwritten with 1.
nan_in_frame = np.isnan(large).any(axis=(1, 2))
idx_nan = np.nonzero(nan_in_frame)
print(idx_nan)
large[idx_nan] = 1

In [None]:
# Normalize a copy of the interpolated expected images and apply the mask loaded above.
# NOTE(review): mask is None when the mask file was missing — confirm apply_mask handles that.
masked = apply_mask(normalize(large.copy()), mask)

In [None]:
# Keep the upper-quantile pixels of the masked frames and renormalize.
# NOTE(review): the meaning of .9 / .95 is defined by upper_decentile_pixels in utils — confirm.
upp_dec = normalize(upper_decentile_pixels(masked, .9, .95))

In [None]:
%%capture
# Animate the upper-quantile frames from index 250 on.
# NOTE(review): `large` is interpolated to 100 frames in "Do the numerics"; if upper_decentile_pixels
# keeps the frame count, upp_dec[250:] is empty — confirm the intended frame window.
ani = show_video(upp_dec[250:], n_frames = 100)

## Upper decile

# Adaptive histogram equalization and clipping

In [None]:
# Period of interest: frames 250..349 of the upper-quantile stack.
# NOTE(review): likely empty if upp_dec has the 100 frames of `large` — confirm the window.
poi = upp_dec[250:350]
# Gaussian smoothing with sigma 2 (scalar sigma, so every axis is smoothed), then renormalize.
poi = normalize(gaussian_filter(poi, 2))

In [None]:
%%capture
# Preview the smoothed period of interest with a display range of 0 to 0.99.
ani = show_video(poi, n_frames = len(poi), vmin=.0, vmax=.99)

In [None]:
# Horn–Schunck optical flow on the period of interest; x_comp / y_comp hold the two flow components.
# NOTE(review): the meaning of 99 is defined by horn_schunck in utils (presumably a frame count) — confirm.
x_comp, y_comp = horn_schunck(poi,99)

In [None]:
# Clipped adaptive equalization (per the section title) of the upper-quantile frames from index 250 on.
# NOTE(review): same possibly-empty slice as above — confirm the window.
adaptive = clipped_adaptive(upp_dec[250:])

In [None]:
%%capture
# Preview the adaptively equalized frames.
ani = show_video(adaptive, n_frames= 100)

In [None]:
%%capture
# Create the figure once; animate() below redraws it for each animation frame.
# NOTE(review): the initial background is large[1] while animate() draws poi[i + 1] — confirm intent.
fig, ax = display_combined(x_comp[0],y_comp[0], large[1])
# Offset into the flow fields: animation frame i shows flow field start + i.
start = 10

def animate(i):
    """Redraw flow field ``start + i`` over frame ``start + i + 1`` of ``poi``."""
    i += start
    print(".", end ="")    
    display_combined(x_comp[i],y_comp[i], poi[i+1], fig=fig, ax=ax)

In [None]:
# Render 85 animation frames (flow fields start .. start + 84) as embeddable JavaScript/HTML.
ani = matplotlib.animation.FuncAnimation(fig, animate, frames=85).to_jshtml()

In [None]:
# Period of interest taken directly from the normalized interpolated expected images.
# NOTE(review): `large` appears to have only 100 frames (interpolate_tensor(output_tensor, 100) in
# "Do the numerics"), so large[150:250] would be empty — confirm the intended frame window.
poi = normalize(large[150:250])
# Gaussian smoothing with sigma 2 along every axis.
poi = gaussian_filter(poi, 2)

In [None]:
# Horn–Schunck optical flow on the new period of interest (overwrites the previous x_comp / y_comp).
x_comp, y_comp = horn_schunck(poi,99)

In [None]:
%%capture
# Create the figure once; animate() below redraws it for each animation frame.
# NOTE(review): the initial background is large[1] while animate() draws poi[i + 1] — confirm intent.
fig, ax = display_combined(x_comp[0],y_comp[0], large[1])
# Offset into the flow fields: animation frame i shows flow field start + i.
start = 10

def animate(i):
    """Redraw flow field ``start + i`` (scaled by 10) over frame ``start + i + 1`` of ``poi``."""
    i += start
    print(".", end ="")    
    display_combined(x_comp[i]*10,y_comp[i]*10, poi[i+1], fig=fig, ax=ax, head_width=1)

In [None]:
# Render 85 animation frames (flow fields start .. start + 84) as embeddable JavaScript/HTML.
ani = matplotlib.animation.FuncAnimation(fig, animate, frames=85).to_jshtml()

In [None]:
# Display the flow animation rendered in the previous cell.
HTML(ani)

# Conclusion
