In [None]:
%matplotlib inline

import matplotlib
matplotlib.rcParams['figure.figsize'] = (6, 6)

import math
import cmath          # math functions for complex numbers
import numpy as np
import matplotlib.pyplot as plt

from matplotlib.colors import LogNorm

import scipy
import scipy.stats
import pandas as pd

from astropy.io import fits
import os.path

import time

import ipywidgets
from ipywidgets import interact

import sys
sys.path.append("/Users/jdecock/git/pub/jdhp-sap/sap-cta-data-pipeline")
import datapipe
import datapipe.denoising.wavelets_mrfilter as mrfilter

sys.path.append("/Users/jdecock/git/pub/jdhp-sap/sap-cta-data-pipeline/utils")
import common_functions as common

from datapipe.denoising import wavelets_mrfilter as wavelets_mod
from datapipe.benchmark import assess as assess_mod

In [None]:
# Name of the dataset to explore: leave exactly one of these lines uncommented.
# The name selects the FITS directory and the event files used in the cells below.
#dataset = "astri_gamma"
#dataset = "astri_gamma_cropped"
#dataset = "astri_proton"
#dataset = "astri_proton_cropped"
dataset = "gct"

In [None]:
# Where the FITS files are read from: the ramdisk copy when True, the copy
# stored in the home directory otherwise.
ramdisk = True

if ramdisk:
    ROOT_DIR = "/Volumes/ramdisk/data"
else:
    ROOT_DIR = "/Users/jdecock/data"

# FITS directory of each supported dataset, relative to ROOT_DIR.
DATASET_SUBDIRS = {
    "astri_gamma":          "astri_mini_array/fits/gamma",
    "astri_gamma_cropped":  "astri_mini_array/fits_cropped/gamma",
    "astri_proton":         "astri_mini_array/fits/proton",
    "astri_proton_cropped": "astri_mini_array/fits_cropped/proton",
    "gct":                  "gct/fits/proton",
}

# Fail here with a clear message instead of leaving FITS_FILE_PATH undefined
# (or stale from a previous run of this cell) when `dataset` is misspelled.
if dataset not in DATASET_SUBDIRS:
    raise ValueError("Unknown dataset {!r}; expected one of: {}".format(
        dataset, ", ".join(sorted(DATASET_SUBDIRS))))

FITS_FILE_PATH = os.path.join(ROOT_DIR, DATASET_SUBDIRS[dataset])

### Copy input files to the ramdisk to speed up processing

In [None]:
#!mkdir -p /Volumes/ramdisk/data/gct/fits/proton

In [None]:
#!ls /Users/jdecock/data/gct/fits/proton/group1run1000.simtel.gz_TEL001_EV003*.fits

In [None]:
#!cp /Users/jdecock/data/gct/fits/proton/group1run1000.simtel.gz_TEL001_EV003*.fits /Volumes/ramdisk/data/gct/fits/proton/

In [None]:
# List the directory tree currently stored on the ramdisk
# (shell escape: needs the external `tree` command).
!tree /Volumes/ramdisk/data/

### GUI

In [None]:
# Event files offered in the GUI's `file_path` dropdown, per dataset family.
# The cropped and non-cropped variants of a dataset share the same file names;
# only FITS_FILE_PATH differs.
if dataset in ("astri_gamma", "astri_gamma_cropped"):
    event_file_names = [
        "run1001.simtel.gz_TEL001_EV00507.fits",
        "run1001.simtel.gz_TEL001_EV01909.fits",
    ]
elif dataset in ("astri_proton", "astri_proton_cropped"):
    event_file_names = [
        "run10000.simtel.gz_TEL001_EV03118.fits",
        "run10000.simtel.gz_TEL001_EV08001.fits",
    ]
elif dataset == "gct":
    event_file_names = [
        "group1run1000.simtel.gz_TEL001_EV00304.fits",
        "group1run1000.simtel.gz_TEL001_EV00305.fits",
        "group1run1000.simtel.gz_TEL001_EV00307.fits",
        "group1run1000.simtel.gz_TEL001_EV00316.fits",
    ]
else:
    # Without this branch an unknown name would leave file_path_list undefined
    # (or stale from a previous run of this cell).
    raise ValueError("Unknown dataset {!r}".format(dataset))

file_path_list = [os.path.join(FITS_FILE_PATH, name) for name in event_file_names]

file_path_list

In [None]:
# mr_filter options
#
# Each entry is one raw option string. The list feeds the `mrfilter` dropdown of
# the GUI cell below, which forwards the selected string as `wavelets_cmd`.
# Commented-out entries are alternatives kept for reference.

option_list = [
    "-K -k -C1 -m3 -n4 -s2,2.5,3,3",
    "-K -k -C1 -m3 -n4 -s2,2,3,3",
    "-K -k -C1 -m3 -n5 -s2,2,3,3,3",
    "-K -k -C1 -m3 -n4 -s2,2,3,3 -I../astri_pixels_mask.fits -v",
    "-K    -C1 -m3 -n4 -s2,2,3,3",
    "-K -k -C1 -m3 -n4 -s3",
    "-K    -C1 -m3 -n4 -s3",      # Hard K-Sigma Thresholding, Poisson + Gaussian
    #
    "-K -C1 -m3  -s4 -n4 -t24 -f3",           # Suggested by Jean-Luc (TODO: try to adapt -s)
#    "-K -C1 -m3  -s5 -n4 -t28 -f3 -i10 -e0",  # Suggested by Jean-Luc (TODO: try to adapt -s)
    "-K -C1 -m3  -s3 -n4 -f3",
#    "-K -C1 -m3  -s2,2,3,3 -n4 -f3",
#    "-K -C1 -m3  -s3 -n4 -f2",  # Soft K-Sigma Thresholding, Poisson + Gaussian (discouraged by JL)
#    "-K -C2 -m1      -n4 -f2",  # False Discovery Rate
#    "-K     -m1      -n4 -f6",  # Wiener Filtering (min MSE between random process and desired process)
#    "-K -C1 -m10 -s3 -n4",      # Poisson with few events
#    "-K -C1 -m2  -s3 -n4",      # Poisson
    ]

In [None]:
#!mr_filter -h

In [None]:
@interact(kill_pix=True,
          hist=False,
          log_scale=True,
          ellipses=True,
          lateral_hst=["None", "Wavelet", "Tailcut"],
          ref_angle=False,
          file_path=file_path_list,
          mrfilter=option_list)
def gui(kill_pix, hist, log_scale, ellipses, lateral_hst, ref_angle, file_path, mrfilter):
    """Load one benchmark FITS file and hand it to ``common.plot_gui``.

    Every argument is bound to a widget by ``@interact``:

    kill_pix    -- forwarded as ``kill_isolated_pixels``
    hist        -- forwarded as ``plot_histogram``
    log_scale   -- forwarded as ``plot_log_scale``
    ellipses    -- forwarded as ``plot_ellipse_shower``
    lateral_hst -- forwarded as ``_plot_perpendicular_hit_distribution``
    ref_angle   -- forwarded as ``use_ref_angle_for_perpendicular_hit_distribution``
    file_path   -- FITS file to load, one entry of ``file_path_list``
    mrfilter    -- option string forwarded as ``wavelets_cmd``, one entry of
                   ``option_list`` (inside this function the name shadows the
                   ``mrfilter`` module alias imported at the top of the notebook)
    """
    images, metadata = datapipe.io.images.load_benchmark_images(file_path)
    input_img, reference_img, pixels_position = (
        images[key] for key in ("input_image", "reference_image", "pixels_position")
    )

    figure = plt.figure(figsize=(14, 14))

    # NOTE(review): the first dropdown entry is the *string* "None", not the
    # None object -- confirm that common.plot_gui treats it as "no histogram".
    plot_options = dict(
        wavelets_cmd=mrfilter,
        kill_isolated_pixels=kill_pix,
        plot_histogram=hist,
        plot_log_scale=log_scale,
        plot_ellipse_shower=ellipses,
        _plot_perpendicular_hit_distribution=lateral_hst,  # "None", "Tailcut" or "Wavelet"
        use_ref_angle_for_perpendicular_hit_distribution=ref_angle,
        notebook=True,
    )

    common.plot_gui(figure,
                    input_img,
                    reference_img,
                    pixels_position,
                    metadata,
                    **plot_options)

## Tweak thresholds

In [None]:
# Event used for the threshold-tweaking GUI below. The path is built from
# ROOT_DIR so that it follows the `ramdisk` switch instead of always pointing
# at the ramdisk; with ramdisk=True it is the same path as before.
file_path = os.path.join(ROOT_DIR, "gct/fits/proton",
                         "group1run1000.simtel.gz_TEL001_EV00307.fits")
fits_images_dict, fits_metadata_dict = datapipe.io.images.load_benchmark_images(file_path)
input_img = fits_images_dict["input_image"]
reference_img = fits_images_dict["reference_image"]
pixels_position = fits_images_dict["pixels_position"]

# Cleaner instance reused on every slider update of the GUI below.
wavelets = wavelets_mod.WaveletTransform()

# Option template: one "{:0.2f}" placeholder per value passed to -s.
# Five-scale variant (needs five sliders):
#mrfilter = "-K -k -C1 -m3 -n5 -s{:0.2f},{:0.2f},{:0.2f},{:0.2f},{:0.2f}"
mrfilter_option_string = "-K -k -C1 -m3 -n4 -s{:0.2f},{:0.2f},{:0.2f},{:0.2f}"

In [None]:
#@interact(s1=(0., 8., 0.1), s2=(0., 8., 0.1), s3=(0., 8., 0.1), s4=(0., 8., 0.1), s5=(0., 8., 0.1))
#def gui(s1, s2, s3, s4, s5):

@interact(s1=(0., 6.1, 0.1), s2=(0., 6.1, 0.1), s3=(0., 6.1, 0.1), s4=(0., 6.1, 0.1))
def gui(s1, s2, s3, s4):
    """Clean ``input_img`` with the four slider values and show the result.

    ``s1`` .. ``s4`` are substituted, in that order, into the ``-s`` list of
    ``mrfilter_option_string``. The ``psi`` Hillas parameter of the reference
    image and of the cleaned image, and their absolute difference, are printed
    before the cleaned image is drawn.

    NOTE(review): this definition rebinds the name ``gui`` used by the
    benchmark viewer defined earlier in the notebook.
    """
    fig, ax = plt.subplots(figsize=(8, 8))

    # Five-scale variant: mrfilter.format(s1, s2, s3, s4, s5)
    option_string = mrfilter_option_string.format(s1, s2, s3, s4)
    cleaned_img = wavelets.clean_image(input_img,
                                       kill_isolated_pixels=False,
                                       verbose=True,
                                       raw_option_string=option_string)

    hillas = datapipe.image.hillas_parameters.get_hillas_parameters
    psi_ref = hillas(reference_img).psi
    psi_cln = hillas(cleaned_img).psi
    for label, value in (("ref:", psi_ref),
                         ("cln:", psi_cln),
                         ("delta:", abs(psi_ref - psi_cln))):
        print(label, value)

    # Alternative rendering:
    #common.plot_image_meter(axis=ax, image_array=cleaned_img, pixels_position=pixels_position, title="")
    ax.imshow(cleaned_img, interpolation='nearest', cmap='gnuplot2')