# Wavelet Transform interactive notebook

## TODO list

Here are a few things to improve in this notebook:

- [ ] **improve the colormap and/or use the radius to distinguish background from signal**
- [ ] add a colorbar
- [ ] add the possibility to read the image IDs to plot from a file (automatically generated by other notebooks)
- [ ] add an option to highlight pixels

## Import required modules and set some variables

This notebook requires PyWI and PyWI-CTA for the I/O and for the signal processing. It also requires Bokeh to display images (as a much faster alternative to Matplotlib).
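
Before running the import cell, it can help to check that these dependencies are installed. The following is a minimal sketch using only the standard library; the helper name `find_missing_modules` is hypothetical and the module list simply mirrors the imports used in this notebook.

```python
import importlib.util


def find_missing_modules(module_names):
    """Return the top-level module names that cannot be found in the current environment."""
    return [name for name in module_names if importlib.util.find_spec(name) is None]


# Top-level packages imported by this notebook
REQUIRED_MODULES = ["numpy", "bokeh", "ipywidgets", "astropy", "pywi", "pywicta"]

missing = find_missing_modules(REQUIRED_MODULES)
if missing:
    print("Missing modules:", ", ".join(missing))
else:
    print("All required modules are available.")
```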

In [None]:
import numpy as np

#from bokeh.plotting import figure, output_notebook, show    # For static images

import bokeh

from bokeh.io import push_notebook, output_notebook, show    # For animations
from bokeh.plotting import figure, ColumnDataSource          # For animations

from bokeh.models import LogColorMapper, LogTicker, ContinuousColorMapper, ContinuousTicker, ColorBar
from bokeh.models import HoverTool

from bokeh.models.annotations import Title

import matplotlib as mpl
import math

import os

import ipywidgets
from ipywidgets import interact

import astropy.units as u

import pywi
import pywicta
from pywicta.io import geometry_converter
from pywicta.io.images import image_generator
from pywicta.io.images import plot_ctapipe_image
from pywicta.io.images import plot_hillas_parameters_on_axes
from pywicta.io.images import get_mars_like_default_integrator_config
from pywicta.image.hillas_parameters import get_hillas_parameters
from pywicta.denoising.rejection_criteria import CTAMarsCriteria

from pywicta.denoising import wavelets_mrtransform
from pywicta.denoising.wavelets_mrtransform import WaveletTransform
from pywicta.denoising import inverse_transform_sampling
from pywicta.denoising.inverse_transform_sampling import EmpiricalDistribution

In [None]:
import warnings
warnings.filterwarnings('ignore')

### Check whether a ramdisk can be used to speed up wavelet transform processing

A ramdisk is a file storage space located in RAM.
PyWI uses intermediate files to compute the wavelet transform; putting those files in a ramdisk makes the process much faster.
Use [the following script](https://github.com/jeremiedecock/pywi-cta/blob/master/utils/ramdisk_macosx.sh) to create a ramdisk (MacOSX only):

    ramdisk_macosx.sh create 32

This creates a ramdisk reachable from `/Volumes/ramdisk`.

On most Linux distributions, a ramdisk is already mounted by default at `/dev/shm`, so Linux users can directly set `RAMDISK_PATH = "/dev/shm"` in the following cell.
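
As an alternative to editing `RAMDISK_PATH` by hand, the choice can be automated by trying the two locations mentioned above in order. This is only a sketch: the helper `choose_tmp_dir` is hypothetical and is not part of PyWI or PyWI-CTA.

```python
import os


def choose_tmp_dir(candidates, fallback="."):
    """Return the first candidate that is an existing directory, else `fallback`."""
    for path in candidates:
        if os.path.isdir(path):
            return path
    return fallback


# Candidate ramdisk locations taken from the text above (MacOSX, then Linux)
RAMDISK_CANDIDATES = ["/Volumes/ramdisk", "/dev/shm"]

tmp_dir = choose_tmp_dir(RAMDISK_CANDIDATES)
print("Temporary files directory:", tmp_dir)
```

If neither directory exists, the helper falls back to the current directory, which matches the behaviour of the cell below.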

In [None]:
RAMDISK_PATH = "/Volumes/ramdisk"

if os.path.isdir(RAMDISK_PATH):
    print("Use ramdisk")
    TMP_DIR = RAMDISK_PATH
else:
    print(RAMDISK_PATH + " IS NOT MOUNTED; RAMDISK WON'T BE USED FOR TEMPORARY FILES.")
    TMP_DIR = "."

### Set the number of scales for the Wavelet transform algorithm

In [None]:
WT_NUM_SCALES = 3

## Make the dataset

The next cells define the list of images to use in this notebook.

Images can be fetched from FITS files or from Simtel files.
FITS files are much lighter and much faster to process than Simtel files, but they are specific to PyWI-CTA, so you first have to generate them from Simtel files using [the following script](). Also, unlike a Simtel file, a FITS file contains only one "image" (i.e. a single event viewed by a single telescope).

The other things to configure are:

1. `CAM_ID` to define the camera to use:  ASTRICam, CHEC, DigiCam, FlashCam, NectarCam or LSTCam.
2. `IMG_ID_LIST` to fetch specific images e.g. set `IMG_ID_LIST = [ "run104_2081200_1", "run105_1244901_4"]` if you only want the image of the event `2081200` from telescope `1` in run `104` and the image of the event `1244901` from telescope `4` in run `105`.

Alternatively, one can set `TEL_FILTER_LIST` and/or `EVENT_FILTER_LIST` to set the list of desired telescopes (e.g. `TEL_FILTER_LIST = [1, 3, 4]`) and the list of desired events (e.g. `EVENT_FILTER_LIST = [2081200, 1244901]`).
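
The image ID format described above (`run<run>_<event>_<telescope>`) can be illustrated with a small sketch. The helpers `parse_image_id` and `format_image_id` are hypothetical names introduced here for illustration only.

```python
def parse_image_id(img_id):
    """Split an image ID such as "run104_2081200_1" into (run_id, event_id, tel_id)."""
    run_part, event_part, tel_part = img_id.split("_")
    if not run_part.startswith("run"):
        raise ValueError("Unexpected image ID: {!r}".format(img_id))
    return int(run_part[len("run"):]), int(event_part), int(tel_part)


def format_image_id(run_id, event_id, tel_id):
    """Build the image ID string from its three integer components."""
    return "run{}_{}_{}".format(run_id, event_id, tel_id)


for img_id in ["run104_2081200_1", "run105_1244901_4"]:
    print(img_id, "->", parse_image_id(img_id))
```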

In [None]:
#CAM_ID = "ASTRICam"
#CAM_ID = "CHEC"
#CAM_ID = "DigiCam"
#CAM_ID = "FlashCam"
#CAM_ID = "NectarCam"
CAM_ID = "LSTCam"

IMG_ID_LIST = [
    "run104_2081200_1",
    "run105_1244901_4",
]
IMG_ID_LIST = []

TEL_FILTER_LIST = []
EVENT_FILTER_LIST = []

In [None]:
if len(IMG_ID_LIST) > 0:
    
    # Use only the images defined in IMG_ID_LIST
    # Take them from FITS_DIR if USE_FITS is True else take them from SIMTEL_DIR
    
    USE_FITS = True
    FITS_DIR = "dataset"
    SIMTEL_DIR = "dataset"
    #FITS_DIR = "~/data/grid_prod3b_north/fits/lst/gamma"
    #SIMTEL_DIR = "~/data/grid_prod3b_north/simtel/gamma"

    PATHS = []
    for img_id in IMG_ID_LIST:
        run_id, event_id, tel_id = img_id.split("_")
        
        run_id = int(run_id[3:])
        event_id = int(event_id)
        tel_id = int(tel_id)
        
        if USE_FITS:
            PATHS.append("{}/gamma_20deg_0deg_run{}___cta-prod3-lapalma3-2147m-LaPalma.simtel.gz_TEL{:03d}_EV{}.fits".format(FITS_DIR, run_id, tel_id, event_id))
        else:
            PATHS.append("{}/gamma_20deg_0deg_run{}___cta-prod3-lapalma3-2147m-LaPalma.simtel.gz".format(SIMTEL_DIR, run_id))
            TEL_FILTER_LIST.append(tel_id)
            EVENT_FILTER_LIST.append(event_id)
            
    NUM_IMAGES = None     # The maximum number of images to load
                
else:
    
    # Use the first NUM_IMAGES images of the following file
    
    #SIMTEL_FILE = "~/data/astri_mini_array_konrad/simtel/astri_v2/gamma/gamma_20deg_180deg_run2203___cta-prod3-sst-astri_desert-2150m-Paranal-sst-astri2.simtel.gz"
    #SIMTEL_FILE = "~/data/gct_mini_array_konrad/simtel/gct/gamma/gamma_20deg_180deg_run2203___cta-prod3-sst-gct_desert-2150m-Paranal-sst-gct.simtel.gz"
    #SIMTEL_FILE = "~/data/sst1m_mini_array_konrad/simtel/sst1m/gamma/gamma_20deg_180deg_run2203___cta-prod3-sst-dc_desert-2150m-Paranal-sst-dc.simtel.gz"
    SIMTEL_FILE = "~/data/grid_prod3b_north/simtel/gamma/gamma_20deg_0deg_run104___cta-prod3-lapalma3-2147m-LaPalma.simtel.gz"
    PATHS = [SIMTEL_FILE]
    
    NUM_IMAGES = 30     # The maximum number of images to load
    
print(PATHS)

## Define preselection cuts

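`rejection_criteria` is the function used to apply a preselection cut: it takes an image and returns `True` when that image must be rejected. Set it to `None` to disable the cut.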
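
A custom cut can be written as a plain callable. The sketch below builds one that rejects images whose total signal falls outside a given window, in the spirit of the commented-out lambda in the next cell. It is written in pure Python so that it runs on its own; the factory name `make_npe_window_criteria` is hypothetical, and in the notebook `np.nansum` would be the natural way to compute the sum.

```python
import math


def make_npe_window_criteria(min_npe, max_npe):
    """Return a callable that rejects images whose total signal is outside (min_npe, max_npe)."""
    def rejection_criteria(ref_image):
        # Sum all pixels of the 2D image, skipping NaN values as np.nansum does
        total = math.fsum(v for row in ref_image for v in row if not math.isnan(v))
        return not (min_npe < total < max_npe)   # True means "reject this image"
    return rejection_criteria


criteria = make_npe_window_criteria(200, 250)
print(criteria([[100.0, 120.0], [float("nan"), 5.0]]))   # total signal is 225
print(criteria([[1.0, 2.0]]))                            # total signal is 3
```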

In [None]:
#rejection_criteria = lambda ref_image: not 200 < np.nansum(ref_image) < 250
rejection_criteria = CTAMarsCriteria(cam_id=CAM_ID)
#rejection_criteria = None

## Get images

In [None]:
integrator_config_dict = get_mars_like_default_integrator_config(CAM_ID)

image_dict = {"run{}_{}_{}".format(int(image.meta['run_id']),
                                   int(image.meta['event_id']),
                                   int(image.meta['tel_id'])): image
              for image
              in image_generator(PATHS,
                                 max_num_images=NUM_IMAGES,
                                 cam_filter_list=[CAM_ID],
                                 tel_filter_list=TEL_FILTER_LIST,
                                 ev_filter_list=EVENT_FILTER_LIST,
                                 ctapipe_format=False,
                                 mc_rejection_criteria=rejection_criteria,
                                 **integrator_config_dict)}

## Make the image list for the widget

In [None]:
if len(IMG_ID_LIST) > 0:
    image_dict = {k: v for k, v in image_dict.items() if k in IMG_ID_LIST}

## Plot

In [None]:
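# Copy the palettes into independent lists so that editing palette_2 leaves palette_1 untouched
palette_1 = list(bokeh.palettes.Viridis256)
palette_2 = list(bokeh.palettes.Viridis256)

# Use light grey for the lowest values (background) in palette_2
palette_2[0] = "#eeeeee"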

In [None]:
color_mapper = bokeh.models.mappers.LinearColorMapper(palette=palette_1,
                                                      low=0,
                                                      high=0)

In [None]:
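# HillasParameterizationError is caught below; it is defined in ctapipe (a PyWI-CTA dependency)
from ctapipe.image.hillas import HillasParameterizationError

def plot_hillas_parameters(geom,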
                           image=None,
                           hillas_params=None,
                           hillas_implementation=2):
    """Compute the shower ellipse and direction lines to draw on the Bokeh figure."""
    
    centroid = (0, 0)
    angle = 0.
    length = 0.
    width = 0.
    lines_xs = [[0, 0], [0, 0]]
    lines_ys = [[0, 0], [0, 0]]

    try:
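        # Compute the Hillas parameters only when they are not supplied by the caller
        if hillas_params is None and image is not None:
            hillas_params = get_hillas_parameters(geom, image, implementation=hillas_implementation)

        # Extract the ellipse parameters whether they were computed above or passed in
        if hillas_params is not None: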
            centroid = (hillas_params.cen_x.value, hillas_params.cen_y.value)
            length = hillas_params.length.value
            width = hillas_params.width.value
            angle = hillas_params.psi.to(u.rad).value

            #print("centroid:", centroid)
            #print("length:",   length)
            #print("width:",    width)
            #print("angle:",    angle)

            p0_x = centroid[0]
            p0_y = centroid[1]

            p1_x = p0_x + math.cos(angle)
            p1_y = p0_y + math.sin(angle)

            p2_x = p0_x + math.cos(angle + math.pi)
            p2_y = p0_y + math.sin(angle + math.pi)

            lines_xs = [[p1_x, p2_x], [0, p0_x]]
            lines_ys = [[p1_y, p2_y], [0, p0_y]]

    except HillasParameterizationError as err:
        print(err)
    
    return centroid, angle, length, width, lines_xs, lines_ys

In [None]:
# Make the cleaning class
wavelet = WaveletTransform()

# Get empirical noise distribution
noise_cdf_file = inverse_transform_sampling.get_cdf_file_path(CAM_ID)  # pywicta.denoising.cdf.LSTCAM_CDF_FILE
print(noise_cdf_file)
empirical_noise_distribution = EmpiricalDistribution(noise_cdf_file)

# Prepare plots
geom1d = geometry_converter.get_geom1d(CAM_ID)

# Prepare Bokeh plot
if geom1d.pix_type == 'hexagonal':
    radius = math.sqrt(geom1d.pix_area.value[0]/(2. * math.sqrt(3.))) # inradius r of a regular hexagon: area = 2*sqrt(3)*r^2 (see: https://fr.wikipedia.org/wiki/Hexagone#Calcul_de_l'aire)
elif geom1d.pix_type == 'rectangular':
    radius = math.sqrt(geom1d.pix_area.value[0]) / 2.
else:
    raise NotImplementedError("Unknown camera type {}".format(geom1d.pix_type))
    

hover = HoverTool(
            tooltips=[
                ("PE", "@pe"),
            ]
        )

In [None]:
# Initialize Bokeh (the display library) #################

TOOLS = "crosshair,pan,wheel_zoom,box_zoom,reset,tap,previewsave,box_select,poly_select,lasso_select"

CAM_SIZE = max(geom1d.pix_x.value.max(), geom1d.pix_y.value.max()) * 1.05

FIGURE_SIZE = 600
#fig = figure(plot_width=FIGURE_SIZE, plot_height=FIGURE_SIZE, tools=TOOLS)
fig = figure(plot_width=FIGURE_SIZE,
             plot_height=FIGURE_SIZE,
             x_range=(-CAM_SIZE, CAM_SIZE),
             y_range=(-CAM_SIZE, CAM_SIZE))

output_notebook()

#colors = data_to_colors(img)

title = Title()
title.text = "-"
fig.title = title

## add a circle renderer with a size, color, and alpha
#circles = fig.circle(geom1d.pix_x.value,
#                     geom1d.pix_y.value,
#                     #size=5,              # The size (diameter) values for the markers **in screen space units** (i.e. aspect changes with figure size or zoom).
#                     radius=radius,        # The radius values for circle markers (**in "data space" units**, by default).
#                     fill_color=["#ffffff" for pix in geom1d.pix_x],
#                     line_color=["#000000" for pix in geom1d.pix_x],
#                     alpha=1.)

source = ColumnDataSource(
             data=dict(
                 x=geom1d.pix_x.value,
                 y=geom1d.pix_y.value,
                 fill_color=["#ffffff" for pix in geom1d.pix_x],
                 line_color=["#000000" for pix in geom1d.pix_x],
                 radius=[radius for pix in geom1d.pix_x],
                 pe=[0. for pix in geom1d.pix_x],
             )
         )

circles = fig.circle("x",
                     "y",
                     radius="radius",        # The radius values for circle markers (**in "data space" units**, by default).
                     fill_color=bokeh.transform.transform('pe', color_mapper), # "fill_color",
                     line_color=bokeh.transform.transform('pe', color_mapper), # "line_color",
                     source=source)

shower_lines = fig.multi_line(xs=[[0, 0],[0, 0]],
                              ys=[[0, 0],[0, 0]],
                              color=["red", "green"],
                              line_width=2,
                              alpha=0.75)

shower_ellipse = fig.ellipse(x=0.,
                             y=0.,
                             width=0.,
                             height=0.,
                             angle=0.,
                             color="red",
                             alpha=0.5)

fig.add_tools(hover)

# show the results
handle = show(fig, notebook_handle=True)

# Interactive widget ################

@interact(plot_type=["raw", "ref.", "cleaned", "diff", "1st plane", "2nd plane", "3rd plane"],
          image_key=list(image_dict.keys()),
          type_of_filtering = list(pywi.processing.filtering.hard_filter.AVAILABLE_TYPE_OF_FILTERING),
          wt_threshold_1=(0., 10.),
          wt_threshold_2=(0., 10.),
          clusters_threshold=0.,
          last_scale_treatment = list(pywi.processing.transform.mrtransform_wrapper.AVAILABLE_LAST_SCALE_OPTIONS),
          detect_only_positive_structures = False,
          use_noise_distribution=True,
          kill_isolated_pixels=True)
def compute_hillas_and_display(plot_type,
                               image_key,
                               type_of_filtering = 'hard_filtering',
                               wt_threshold_1=3.,
                               wt_threshold_2=0.2,
                               clusters_threshold=0.2,
                               last_scale_treatment = 'drop',
                               detect_only_positive_structures = False,
                               use_noise_distribution=True,
                               kill_isolated_pixels=False):
    
    # GET IMAGES ###########################
    
    image = image_dict[image_key]
    image.meta['npe'] = np.nansum(image.reference_image)
    cam_id = image.meta['cam_id']

    filter_thresholds = [wt_threshold_1, wt_threshold_2]
    number_of_scales = len(filter_thresholds) + 1
    
    noise_distribution = empirical_noise_distribution if use_noise_distribution else None
    
    plot_hillas_params = True
    show_background = True
    
    if plot_type == "raw":

        img_1d = geometry_converter.image_2d_to_1d(image.input_image, cam_id)
        highlight_pixel_mask_1d = np.ones_like(img_1d, dtype="bool")
        plot_hillas_params = False
        show_background = True

    elif plot_type == "ref.":

        img_1d = geometry_converter.image_2d_to_1d(image.reference_image, cam_id)
        highlight_pixel_mask_1d = np.ones_like(img_1d, dtype="bool")
        
        #reference_image_mask = np.full(reference_image_1d.shape, False)
        #reference_image_mask[reference_image_1d > 0] = True
        
        shower_centroid, shower_angle, shower_length, shower_width, shower_lines_xs, shower_lines_ys = plot_hillas_parameters(geom1d, image=img_1d)
        show_background = False

    elif plot_type == "cleaned":

        cleaned_img = wavelet.clean_image(image.input_image,
                                          type_of_filtering = type_of_filtering,
                                          filter_thresholds = filter_thresholds,
                                          clusters_threshold = clusters_threshold,
                                          last_scale_treatment = last_scale_treatment,
                                          detect_only_positive_structures = detect_only_positive_structures,
                                          kill_isolated_pixels = kill_isolated_pixels,
                                          noise_distribution = noise_distribution,
                                          tmp_files_directory = TMP_DIR)
    
        if np.nanmax(cleaned_img) == 0:

            img_1d = None
            plot_hillas_params = False
            
        else:
            
            img_1d = geometry_converter.image_2d_to_1d(cleaned_img, cam_id)
            highlight_pixel_mask_1d = np.ones_like(img_1d, dtype="bool")
            shower_centroid, shower_angle, shower_length, shower_width, shower_lines_xs, shower_lines_ys = plot_hillas_parameters(geom1d, image=img_1d)
        
        show_background = False

    elif plot_type == "diff":

        cleaned_img = wavelet.clean_image(image.input_image,
                                          type_of_filtering = type_of_filtering,
                                          filter_thresholds = filter_thresholds,
                                          clusters_threshold = clusters_threshold,
                                          last_scale_treatment = last_scale_treatment,
                                          detect_only_positive_structures = detect_only_positive_structures,
                                          kill_isolated_pixels = kill_isolated_pixels,
                                          noise_distribution = noise_distribution,
                                          tmp_files_directory = TMP_DIR)
        
        img_1d = geometry_converter.image_2d_to_1d(image.reference_image - cleaned_img, cam_id)
        img_1d = np.abs(img_1d)
        highlight_pixel_mask_1d = np.ones_like(img_1d, dtype="bool")
        plot_hillas_params = False
        show_background = False
        
    elif plot_type == "1st plane":

        in_planes = wavelets_mrtransform.wavelet_transform(image.input_image,
                                                           number_of_scales=number_of_scales,
                                                           tmp_files_directory=TMP_DIR,
                                                           noise_distribution=noise_distribution)

        highlight_pixel_mask = wavelets_mrtransform.filter_planes(in_planes,
                                                                  method=type_of_filtering,
                                                                  thresholds=filter_thresholds,
                                                                  detect_only_positive_structures=detect_only_positive_structures)
        highlight_pixel_mask_1d = geometry_converter.image_2d_to_1d(highlight_pixel_mask[0], cam_id)
        
        img_1d = geometry_converter.image_2d_to_1d(in_planes[0], cam_id)
        plot_hillas_params = False
        show_background = True
        
    elif plot_type == "2nd plane":

        in_planes = wavelets_mrtransform.wavelet_transform(image.input_image,
                                                           number_of_scales=number_of_scales,
                                                           tmp_files_directory=TMP_DIR,
                                                           noise_distribution=noise_distribution)

        highlight_pixel_mask = wavelets_mrtransform.filter_planes(in_planes,
                                                                  method=type_of_filtering,
                                                                  thresholds=filter_thresholds,
                                                                  detect_only_positive_structures=detect_only_positive_structures)
        highlight_pixel_mask_1d = geometry_converter.image_2d_to_1d(highlight_pixel_mask[1], cam_id)
        
        img_1d = geometry_converter.image_2d_to_1d(in_planes[1], cam_id)
        plot_hillas_params = False
        show_background = True
        
    elif plot_type == "3rd plane":

        in_planes = wavelets_mrtransform.wavelet_transform(image.input_image,
                                                           number_of_scales=number_of_scales,
                                                           tmp_files_directory=TMP_DIR,
                                                           noise_distribution=noise_distribution)

        highlight_pixel_mask = wavelets_mrtransform.filter_planes(in_planes,
                                                                  method=type_of_filtering,
                                                                  thresholds=filter_thresholds,
                                                                  detect_only_positive_structures=detect_only_positive_structures)
        highlight_pixel_mask_1d = geometry_converter.image_2d_to_1d(highlight_pixel_mask[2], cam_id)
        
        img_1d = geometry_converter.image_2d_to_1d(in_planes[2], cam_id)
        plot_hillas_params = False
        show_background = True
    
    # Update the plot
    if img_1d is None:
        circles.data_source.data['fill_color'] = ["#ffffff" for color in geom1d.pix_x]
        circles.data_source.data['line_color'] = ["#aaaaaa" for color in geom1d.pix_x]
        circles.data_source.data['radius'] = [radius for pixel in geom1d.pix_x]
    else:
        circles.data_source.data['pe'] = img_1d
        
        pix_radius = np.full_like(geom1d.pix_x.value, radius)
        pix_radius[highlight_pixel_mask_1d == 0] = radius/2.
        circles.data_source.data['radius'] = pix_radius
            
        color_mapper.low = img_1d.min()
        color_mapper.high = img_1d.max()
        
        if show_background:
            color_mapper.palette = palette_1
        else:
            color_mapper.palette = palette_2
        
    if not plot_hillas_params:
        shower_centroid = (0, 0)
        shower_angle = 0.
        shower_length = 0.
        shower_width = 0.
        shower_lines_xs = [[0, 0], [0, 0]]
        shower_lines_ys = [[0, 0], [0, 0]]
    
    shower_lines.data_source.data['xs'] = shower_lines_xs
    shower_lines.data_source.data['ys'] = shower_lines_ys

    shower_ellipse.glyph.x = shower_centroid[0]
    shower_ellipse.glyph.y = shower_centroid[1]
    shower_ellipse.glyph.angle = shower_angle
    shower_ellipse.glyph.width = shower_length
    shower_ellipse.glyph.height =  shower_width
        
    #print(image.meta)
    
    if plot_type == "raw":
        plot_desc = "calibrated image"
    elif plot_type == "ref.":
        plot_desc = "MC image"
    elif plot_type == "cleaned":
        plot_desc = "Wavelet clean"
    elif plot_type == "diff":
        plot_desc = "MC image - Wavelet clean"
    else:
        plot_desc = plot_type
        
    title.text = "Run{} Ev{} Tel{} {:0.3f}TeV {}NPE {}".format(image.meta['run_id'],
                                                               image.meta['event_id'],
                                                               image.meta['tel_id'],
                                                               image.meta['mc_energy'][0],
                                                               image.meta['npe'],
                                                               plot_desc)
    
    fig.title = title
    
    push_notebook(handle=handle)