In [None]:
%matplotlib inline

import numpy as np

import matplotlib
import matplotlib.pyplot as plt
matplotlib.rcParams['figure.figsize'] = (6, 6)

import os

from IPython.display import display

import ipywidgets
from ipywidgets import interact

from ctapipe.image.hillas import HillasParameterizationError

import pywi
import pywicta
from pywicta.io import geometry_converter
from pywicta.io.images import image_generator
from pywicta.io.images import plot_ctapipe_image
from pywicta.io.images import plot_hillas_parameters_on_axes
from pywicta.io.images import print_hillas_parameters
from pywicta.io.images import hillas_parameters_to_df
from pywicta.io.images import get_mars_like_default_integrator_config
from pywicta.image.hillas_parameters import get_hillas_parameters
from pywicta.denoising import wavelets_mrtransform
from pywicta.denoising.wavelets_mrtransform import WaveletTransform
from pywicta.denoising import inverse_transform_sampling
from pywicta.denoising.inverse_transform_sampling import EmpiricalDistribution
from pywicta.denoising.rejection_criteria import CTAMarsCriteria

In [None]:
import warnings
warnings.filterwarnings('ignore')

In [None]:
%%javascript
// Raise the output-size threshold above which the notebook collapses a cell's
// output into a scrolling box, so the many figures produced below stay visible.
// NOTE(review): relies on the classic Notebook `IPython` JS global; presumably
// has no effect (or errors) under JupyterLab — confirm.
IPython.OutputArea.auto_scroll_threshold = 9999;
// see https://github.com/ipython/ipython/issues/2172

In [None]:
# Temporary files (written by the wavelet transform) go to a RAM disk when one
# is mounted at RAMDISK_PATH (macOS only); otherwise they are written to the
# current working directory.

RAMDISK_PATH = "/Volumes/ramdisk"

if not os.path.isdir(RAMDISK_PATH):
    print(RAMDISK_PATH + " IS NOT MOUNTED; RAMDISK WON'T BE USED FOR TEMPORARY FILES.")
    TMP_DIR = "."
else:
    print("Use ramdisk")
    TMP_DIR = RAMDISK_PATH

## Make the dataset

In [None]:
# Camera to study.
# Alternatives: "ASTRICam", "CHEC", "DigiCam", "FlashCam", "NectarCam".
CAM_ID = "LSTCam"

# Images to study, as "run<RUN>_<EVENT>_<TEL>" identifiers.
# Use an empty list to take the first NUM_IMAGES images of SIMTEL_FILE instead.
IMG_ID_LIST = ["run104_2081200_1",
               "run105_1244901_4"]

# Telescope and event ids forwarded to image_generator (tel_filter_list /
# ev_filter_list); the following cell extends them from IMG_ID_LIST when
# reading simtel files.
TEL_FILTER_LIST, EVENT_FILTER_LIST = [], []

In [None]:
if len(IMG_ID_LIST) > 0:

    # Use only the images defined in IMG_ID_LIST.
    # Each identifier must read "run<RUN>_<EVENT>_<TEL>" (e.g. "run104_2081200_1").

    USE_FITS = False
    FITS_DIR = "~/data/grid_prod3b_north/fits/lst/gamma"
    SIMTEL_DIR = "~/data/grid_prod3b_north/simtel/gamma"

    PATHS = []
    for img_id in IMG_ID_LIST:
        id_fields = img_id.split("_")
        # Validate explicitly: "xyz104"[3:] would otherwise be silently accepted.
        if len(id_fields) != 3 or not id_fields[0].startswith("run"):
            raise ValueError("Malformed image id {!r}: expected 'run<RUN>_<EVENT>_<TEL>'".format(img_id))
        run_str, event_str, tel_str = id_fields

        run_id = int(run_str[3:])
        event_id = int(event_str)
        tel_id = int(tel_str)

        if USE_FITS:
            # One FITS file per (run, telescope, event) image.
            path = "{}/gamma_20deg_0deg_run{}___cta-prod3-lapalma3-2147m-LaPalma.simtel.gz_TEL{:03d}_EV{}.fits".format(FITS_DIR, run_id, tel_id, event_id)
        else:
            # One simtel file per run; the wanted image is selected through the
            # telescope/event filters passed to image_generator.
            path = "{}/gamma_20deg_0deg_run{}___cta-prod3-lapalma3-2147m-LaPalma.simtel.gz".format(SIMTEL_DIR, run_id)
            # Skip ids already present so that re-running this cell (without
            # re-running the configuration cell) does not keep growing the lists.
            if tel_id not in TEL_FILTER_LIST:
                TEL_FILTER_LIST.append(tel_id)
            if event_id not in EVENT_FILTER_LIST:
                EVENT_FILTER_LIST.append(event_id)

        # Expand "~" here in case the downstream readers do not (harmless if
        # they do). Several images of the same run share one simtel file:
        # list it only once so it is not read several times.
        path = os.path.expanduser(path)
        if path not in PATHS:
            PATHS.append(path)

    # No image-count limit (presumably what None means for max_num_images):
    # the wanted images are selected by id.
    NUM_IMAGES = None

else:

    # Use the N first images in the following files

    #SIMTEL_FILE = "~/data/astri_mini_array_konrad/simtel/astri_v2/gamma/gamma_20deg_180deg_run2203___cta-prod3-sst-astri_desert-2150m-Paranal-sst-astri2.simtel.gz"
    #SIMTEL_FILE = "~/data/gct_mini_array_konrad/simtel/gct/gamma/gamma_20deg_180deg_run2203___cta-prod3-sst-gct_desert-2150m-Paranal-sst-gct.simtel.gz"
    #SIMTEL_FILE = "~/data/sst1m_mini_array_konrad/simtel/sst1m/gamma/gamma_20deg_180deg_run2203___cta-prod3-sst-dc_desert-2150m-Paranal-sst-dc.simtel.gz"
    SIMTEL_FILE = "~/data/grid_prod3b_north/simtel/gamma/gamma_20deg_0deg_run104___cta-prod3-lapalma3-2147m-LaPalma.simtel.gz"
    PATHS = [os.path.expanduser(SIMTEL_FILE)]

    NUM_IMAGES = 30

print(PATHS)

## Define preselection cuts

In [None]:
# Optional Monte-Carlo preselection, forwarded to image_generator as
# `mc_rejection_criteria`. None disables it (presumably — verify in pywicta).
# Ready-to-use alternatives:
#   rejection_criteria = lambda ref_image: not 200 < np.nansum(ref_image) < 250
#       (True when the reference image sum lies outside ]200, 250[; presumably
#        True means "reject")
#   rejection_criteria = CTAMarsCriteria(cam_id=CAM_ID)
rejection_criteria = None

## Get images

In [None]:
integrator_config_dict = get_mars_like_default_integrator_config(CAM_ID)

# Read the selected images and index each one by a "run<RUN>_<EVENT>_<TEL>"
# key, the same format as IMG_ID_LIST. A later image with the same key
# replaces an earlier one.
image_dict = dict(
    ("run{}_{}_{}".format(*(int(img.meta[field]) for field in ('run_id', 'event_id', 'tel_id'))), img)
    for img in image_generator(PATHS,
                               max_num_images=NUM_IMAGES,
                               cam_filter_list=[CAM_ID],
                               tel_filter_list=TEL_FILTER_LIST,
                               ev_filter_list=EVENT_FILTER_LIST,
                               ctapipe_format=False,
                               mc_rejection_criteria=rejection_criteria,
                               **integrator_config_dict)
)

In [None]:
if len(IMG_ID_LIST) > 0:
    # Telescope and event ids are filtered independently when loading, so other
    # (event, telescope) combinations may have been read too: keep only the
    # requested images.
    image_dict = {k: v for k, v in image_dict.items() if k in IMG_ID_LIST}

    # Report requested images that were not loaded. print() rather than
    # warnings.warn() because warnings are globally ignored in this notebook.
    missing_img_ids = [img_id for img_id in IMG_ID_LIST if img_id not in image_dict]
    if missing_img_ids:
        print("WARNING: requested images not found: {}".format(", ".join(missing_img_ids)))

In [None]:
def _load_noise_distribution(cam_id, use_noise_distribution):
    """Return the empirical noise distribution of `cam_id`, or None when disabled.

    The CDF file path is printed so the reader sees which file was used
    (e.g. pywicta.denoising.cdf.LSTCAM_CDF_FILE).
    """
    if not use_noise_distribution:
        return None
    noise_cdf_file = inverse_transform_sampling.get_cdf_file_path(cam_id)
    print(noise_cdf_file)
    return EmpiricalDistribution(noise_cdf_file)


def _parse_filter_thresholds(filter_thresholds):
    """Parse a comma-separated string such as "3.,0.2" into a list of floats.

    Blank entries (e.g. the trailing comma left while typing in the text
    widget) are skipped. Raises ValueError on a non-numeric entry.
    """
    return [float(threshold_str)
            for threshold_str in filter_thresholds.split(",")
            if threshold_str.strip()]


def _plot_images(image_list, title_list, hillas_list, geom1d, metadata_dict):
    """Plot the overview of the 1D input, reference and cleaned images, then
    each of them individually plus the residual (cleaned - reference).

    `image_list` must hold the input, reference and cleaned images, in this order.
    """
    input_1d, reference_1d, cleaned_1d = image_list

    # NOTE(review): plot_list also handles Hillas parameters (hillas_list);
    # unclear whether it tolerates a failed parameterization — confirm.
    pywicta.io.images.plot_list(image_list,
                                geom_list=[geom1d] * len(image_list),
                                title_list=title_list,
                                hillas_list=hillas_list,
                                metadata_dict=metadata_dict)

    disp = plot_ctapipe_image(cleaned_1d,
                              geom1d,
                              title='Wavelet cleaned ({} PE)'.format(np.sum(cleaned_1d)),
                              norm='lin',
                              plot_axis=False)
    try:
        plot_hillas_parameters_on_axes(disp.axes,
                                       cleaned_1d,
                                       geom1d,
                                       hillas_implementation=2)
    except HillasParameterizationError as err:
        # Presumably raised when nothing survived the cleaning; keep displaying the rest.
        print("Cannot draw the Hillas ellipse of the cleaned image: {}".format(err))

    plot_ctapipe_image(input_1d,
                       geom1d,
                       title='Signal + NSB ({} PE)'.format(np.sum(input_1d)),
                       norm='lin',
                       plot_axis=False)

    plot_ctapipe_image(reference_1d,
                       geom1d,
                       title='Signal only ({} PE)'.format(np.sum(reference_1d)),
                       norm='log',
                       plot_axis=False)

    residual_1d = cleaned_1d - reference_1d
    plot_ctapipe_image(residual_1d,
                       geom1d,
                       title='Wavelet cleaned - Signal only ({} PE)'.format(np.sum(residual_1d)),
                       norm='lin',
                       plot_axis=False)


def _plot_wavelet_planes(calibrated_image, cam_id, geom1d, threshold_list,
                         type_of_filtering, detect_only_positive_structures,
                         noise_distribution):
    """Plot every wavelet plane of the 2D input image.

    Pixels left non-zero by the threshold filtering are highlighted. The last
    plane (one more plane than there are thresholds) is shown without highlighting.
    """
    in_planes = wavelets_mrtransform.wavelet_transform(calibrated_image,
                                                       number_of_scales=len(threshold_list) + 1,
                                                       tmp_files_directory=TMP_DIR,
                                                       noise_distribution=noise_distribution)

    filtered_in_planes = wavelets_mrtransform.filter_planes(in_planes,
                                                            method=type_of_filtering,
                                                            thresholds=threshold_list,
                                                            detect_only_positive_structures=detect_only_positive_structures)

    last_plane_index = len(in_planes) - 1
    for plane_index, (plane, filtered_plane) in enumerate(zip(in_planes, filtered_in_planes)):
        if plane_index < last_plane_index:
            significant_pixels_mask = (geometry_converter.image_2d_to_1d(filtered_plane, cam_id) != 0)
        else:
            significant_pixels_mask = None

        plot_ctapipe_image(geometry_converter.image_2d_to_1d(plane, cam_id),
                           geom1d,
                           title='Plane {}'.format(plane_index),
                           norm='lin',
                           highlight_mask=significant_pixels_mask,
                           plot_axis=False)


def _display_hillas_tables(image_list, title_list, hillas_list, cam_id):
    """Display, as a DataFrame, the Hillas parameters of each image flagged in `hillas_list`."""
    for image_1d, title, print_hillas in zip(image_list, title_list, hillas_list):
        if not print_hillas:
            continue
        print("\n{}:".format(title))
        try:
            df = hillas_parameters_to_df(image_1d,
                                         cam_id=cam_id,
                                         implementation=2)
        except HillasParameterizationError as err:
            print("Hillas parameterization failed: {}".format(err))
            continue
        display(df)


# The helpers above must be defined first: `interact` calls the decorated
# function once, immediately, when the decorator is applied.
@interact(image_key=image_dict.keys(),
          type_of_filtering = list(pywi.processing.filtering.hard_filter.AVAILABLE_TYPE_OF_FILTERING),
          filter_thresholds="3.,0.2",
          clusters_threshold=0.,
          last_scale_treatment = list(pywi.processing.transform.mrtransform_wrapper.AVAILABLE_LAST_SCALE_OPTIONS),
          detect_only_positive_structures = False,
          use_noise_distribution=True,
          kill_isolated_pixels=True)
def compute_hillas_and_display(image_key,
                               type_of_filtering = 'hard_filtering',
                               filter_thresholds="7.,4.",
                               clusters_threshold=0.2,
                               last_scale_treatment = 'mask',
                               detect_only_positive_structures = False,
                               use_noise_distribution=True,
                               kill_isolated_pixels=False):
    """Clean one image with the wavelet filter and display the results.

    Plots the input, reference and cleaned images, the residual (cleaned minus
    reference) and every wavelet plane, then shows the Hillas parameters of the
    reference and cleaned images.

    Note: the widgets start from the `interact` decorator values
    (filter_thresholds="3.,0.2", clusters_threshold=0., kill_isolated_pixels=True).
    These differ from the Python defaults used when this function is called directly.

    Parameters
    ----------
    image_key : str
        Key of `image_dict` ("run<RUN>_<EVENT>_<TEL>").
    type_of_filtering, detect_only_positive_structures :
        Forwarded to `WaveletTransform.clean_image` and to the plane display.
    clusters_threshold, last_scale_treatment, kill_isolated_pixels :
        Forwarded to `WaveletTransform.clean_image`.
    filter_thresholds : str
        Comma-separated per-scale thresholds, e.g. "3.,0.2".
    use_noise_distribution : bool
        If True, use the camera's empirical noise distribution.
    """
    image = image_dict[image_key]
    calibrated_image = image.input_image
    reference_image = image.reference_image
    cam_id = image.meta['cam_id']

    # Validate the free-text thresholds first: the widget calls this function
    # on every keystroke, so partial input must not raise.
    try:
        threshold_list = _parse_filter_thresholds(filter_thresholds)
    except ValueError:
        print("Invalid filter thresholds {!r}: expected comma-separated numbers such as \"3.,0.2\".".format(filter_thresholds))
        return
    if not threshold_list:
        print("At least one filter threshold is required (e.g. \"3.,0.2\").")
        return

    noise_distribution = _load_noise_distribution(cam_id, use_noise_distribution)

    wavelet = WaveletTransform()
    cleaned_image = wavelet.clean_image(calibrated_image,
                                        type_of_filtering = type_of_filtering,
                                        filter_thresholds = threshold_list,
                                        clusters_threshold = clusters_threshold,
                                        last_scale_treatment = last_scale_treatment,
                                        detect_only_positive_structures = detect_only_positive_structures,
                                        kill_isolated_pixels = kill_isolated_pixels,
                                        noise_distribution = noise_distribution,
                                        tmp_files_directory = TMP_DIR)

    # Convert to 1D pixel arrays matching the 1D camera geometry used for plotting.
    image_list = [geometry_converter.image_2d_to_1d(image_2d, cam_id)
                  for image_2d in (calibrated_image, reference_image, cleaned_image)]
    title_list = ["Input image", "Reference image", "Cleaned image"]
    hillas_list = [False, True, True]   # no Hillas parameters for the noisy input image
    geom1d = geometry_converter.get_geom1d(cam_id)

    _plot_images(image_list, title_list, hillas_list, geom1d, image.meta)
    _plot_wavelet_planes(calibrated_image, cam_id, geom1d, threshold_list,
                         type_of_filtering, detect_only_positive_structures,
                         noise_distribution)
    _display_hillas_tables(image_list, title_list, hillas_list, cam_id)