# Study on (tailcut) cleaning optimisation

Notes:
- This benchmark might not be optimal
- It could use the prepared file from `Preparation/integration_prep.ipynb`

The idea here is to define a benchmark to optimise cleaning independently of any reconstruction that would come **after**.    
This avoids optimising the cleaning as a function of the whole reconstruction, because:   
- it can be tedious (you have to loop over the whole reconstruction)    
- optimising the cleaning before optimising the later parts of the reconstruction might yield a cleaning that is well adapted to the reconstruction method chosen a priori but not good in absolute terms (a different/better reconstruction might then end up showing worse results).


This benchmark uses the ground-truth image in photo-electrons from MC simulations: it computes the distance between the cleaned image and the ground truth as a function of the cleaning method/parameters and finds the minimum of this distance (averaged over many events).

This also allows studying the cleaning as a function of event info (such as energy, signal amplitude, ...).

Of course, this supposes that the calibration has been previously optimised.

In [None]:
# Widen the notebook container so that the wide multi-panel figures fit on screen.
# `display` and `HTML` come from the public `IPython.display` module: importing
# `display` from `IPython.core.display` is deprecated.
from IPython.display import display, HTML
display(HTML("<style>.container { width:90% !important; }</style>"))


In [None]:
from ctapipe.io import event_source
from ctapipe.utils import datasets
from ctapipe.calib import CameraCalibrator
from ctapipe.image import tailcuts_clean, dilate
from ctapipe.visualization import CameraDisplay
from ctapipe.instrument import CameraGeometry
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
from scipy.stats import poisson
import os
from ctapipe.utils import get_dataset_path
import copy
import astropy.units as u
%matplotlib inline

In [None]:
def get_cleaned_image(camera_geometry, image, cleaning, **cleaning_params):
    """
    Return a copy of `image` in which the pixels rejected by `cleaning` are set to 0.

    camera_geometry: geometry object handed to `cleaning` as its first argument
    image: array of pixel values; it is left untouched, a deep copy is modified instead
    cleaning: callable invoked as `cleaning(camera_geometry, image, **cleaning_params)`;
        its result is used as a mask and the pixels where it is False are zeroed
        (assumes a boolean array)
    cleaning_params: keyword arguments forwarded to `cleaning`
    """
    keep_pixels = cleaning(camera_geometry, image, **cleaning_params)

    # work on a copy so that the caller's image is not altered
    zeroed = copy.deepcopy(image)
    zeroed[~keep_pixels] = 0
    return zeroed

In [None]:
def cleaning_accuracy(cleaned_image, true_image):
    """
    Distance between a cleaned image and the true image (lower is better).

    The value is the norm, as computed by `np.linalg.norm` with its default
    arguments, of the pixel-wise difference `cleaned_image - true_image`.
    """
    residual = cleaned_image - true_image
    return np.linalg.norm(residual)

In [None]:
def tailcut_cleaning_accuracy(camera_geometry, image, true_image, picture_threshold, boundary_threshold):
    """
    Clean `image` with `tailcuts_clean` and score the result against `true_image`.

    picture_threshold, boundary_threshold: handed to `tailcuts_clean` as
        `picture_thresh` and `boundary_thresh`, together with
        `keep_isolated_pixels=False` and `min_number_picture_neighbors=1`

    Returns the value of `cleaning_accuracy` for the cleaned image (lower is better).
    """
    cleaned_image = get_cleaned_image(
        camera_geometry,
        image,
        tailcuts_clean,
        picture_thresh=picture_threshold,
        boundary_thresh=boundary_threshold,
        keep_isolated_pixels=False,
        min_number_picture_neighbors=1,
    )
    return cleaning_accuracy(cleaned_image, true_image)

In [None]:
# NOTE: this cell uses `source` and `event`, which are only created in cells
# further down the notebook; run those cells first.
camera_name = 'FlashCam'

# ids of the telescopes equipped with the requested camera
tel_ids_with_camera_name = [
    tel_id
    for tel_id in source.subarray.tel.keys()
    if source.subarray.tel[tel_id].camera.camera_name == camera_name
]
tel_ids_with_camera_name

# among those, the telescopes that have data in the current event
event.r0.tels_with_data.intersection(tel_ids_with_camera_name)

In [None]:
def tailcut_cleaning_analyse(source, camera_name, picture_threshold=(2, 10, 5), min_boundary_threshold=0):
    """
    Scan a grid of tailcut thresholds and score the cleaning for every event.

    source: event_source object (iterated once, event by event)
    camera_name: only telescopes whose camera has this name are cleaned and scored
    picture_threshold: sequence of (min, max, number of steps) defining the
        picture thresholds to test
    min_boundary_threshold: lowest boundary threshold to test; for each picture
        threshold, the boundary thresholds go from `min_boundary_threshold`
        up to that picture threshold

    Events are calibrated with the notebook-level calibrator `cal`, which must
    be defined before this function is called.

    Returns (thresholds, all_diff, event_info):
    thresholds: array of shape (2, n, n) with the picture thresholds (index 0)
        and the boundary thresholds (index 1) of the grid, n being the number of steps
    all_diff: array of shape (number of events, n, n); for each event, the
        cleaning accuracy summed over the selected telescopes (lower is better)
    event_info: dict with keys 'event_energy' (in TeV), 'event_multiplicity'
        and 'event_dl1_amplitude', one entry per event

    Raises IndexError if no telescope of the subarray carries `camera_name`.
    """
    ptmin, ptmax, ptnumber = picture_threshold
    pt_values = np.linspace(ptmin, ptmax, ptnumber)

    # Column j of the grid holds picture threshold pt_values[j]; along that column
    # the boundary threshold goes from min_boundary_threshold up to pt_values[j].
    picture_grid = np.empty((ptnumber, ptnumber))
    boundary_grid = np.empty((ptnumber, ptnumber))
    for j in range(ptnumber):
        picture_grid[:, j] = pt_values[j]
        boundary_grid[:, j] = np.linspace(min_boundary_threshold, pt_values[j], ptnumber)

    all_diff = []
    event_energy = []
    event_dl1_amplitude = []
    event_multiplicity = []

    tel_ids_with_camera_name = [
        tel_id
        for tel_id in source.subarray.tel.keys()
        if source.subarray.tel[tel_id].camera.camera_name == camera_name
    ]
    # the selected telescopes all have the same camera name, so the geometry
    # of the first one is used for all of them
    geom = source.subarray.tel[tel_ids_with_camera_name[0]].camera.geometry

    for event in source:
        cal(event)
        event_energy.append(event.mc.energy.to(u.TeV).value)
        event_multiplicity.append(len(event.r0.tels_with_data))

        # the telescope selection does not depend on the thresholds:
        # compute it once per event instead of once per grid point
        selected_tel_ids = event.r0.tels_with_data.intersection(tel_ids_with_camera_name)

        diff = np.zeros((ptnumber, ptnumber))
        for tel_id in selected_tel_ids:
            image = event.dl1.tel[tel_id].image
            true_image = event.mc.tel[tel_id].true_image
            for i in range(ptnumber):
                for j in range(ptnumber):
                    diff[i, j] += tailcut_cleaning_accuracy(
                        geom, image, true_image, picture_grid[i, j], boundary_grid[i, j]
                    )
        all_diff.append(diff)

        # Total DL1 amplitude over all telescopes with data. The DL1 image is
        # indexed by pixel (it is handed as such to the cleaning above), so the
        # whole image is summed; `image[0]` would only pick the first pixel.
        amp = 0
        for tel_id in event.r0.tels_with_data:
            amp += event.dl1.tel[tel_id].image.sum()
        event_dl1_amplitude.append(amp)

    all_diff = np.array(all_diff)

    event_info = {'event_energy': np.array(event_energy) * u.TeV,
                  'event_multiplicity': np.array(event_multiplicity),
                  'event_dl1_amplitude': np.array(event_dl1_amplitude)}

    return np.array([picture_grid, boundary_grid]), all_diff, event_info

In [None]:
def find_best_threshold(thresholds, all_diff):
    """
    Return the (picture, boundary) thresholds with the smallest summed distance.

    thresholds: pair of grids, picture thresholds first, boundary thresholds second
    all_diff: distances per event; they are summed over the first axis before
        the minimum is looked for
    """
    summed = all_diff.sum(axis=0)
    best = np.unravel_index(summed.argmin(), summed.shape)
    return thresholds[0][best], thresholds[1][best]

In [None]:
def plot_cleaning_analysis(thresholds, all_diff, ax=None, **kwargs):
    """
    Draw the distance summed over events as a filled contour map of the threshold grid.

    thresholds: pair of grids, picture thresholds first (x axis), boundary
        thresholds second (y axis)
    all_diff: distances per event; they are summed over the first axis
    ax: axes to draw on; the current pyplot axes are used when None
    kwargs: extra keyword arguments forwarded to `ax.tricontourf`

    The best thresholds, as given by `find_best_threshold`, are printed.
    Returns the axes that were drawn on.
    """
    ax = plt.gca() if ax is None else ax

    x = thresholds[0].ravel()
    y = thresholds[1].ravel()
    z = all_diff.sum(axis=0).ravel()
    im = ax.tricontourf(x, y, z, 20, **kwargs)
    ax.set_xlabel('picture threshold')
    ax.set_ylabel('boundary threshold')
    # attach the colorbar to the axes actually drawn on rather than relying on
    # pyplot's current figure/axes, which may differ from `ax`
    ax.figure.colorbar(im, ax=ax)
    print("Best thresholds = ", find_best_threshold(thresholds, all_diff))
    return ax

In [None]:
def analyse_all_cameras(filename,
                        cal=None,
                        max_events=None,
                        **kwargs_tailcut_analysis):
    """
    Find the best tailcut thresholds for every camera type present in a file.

    filename: path of the file opened with `event_source`
    cal: calibrator, accepted for interface compatibility and not used here.
        NOTE(review): `tailcut_cleaning_analyse` calibrates the events with the
        notebook-level `cal`, so this argument is not forwarded.
    max_events: value assigned to `source.max_events`
    kwargs_tailcut_analysis: keyword arguments forwarded to
        `tailcut_cleaning_analyse`; `picture_threshold` defaults to [2, 20, 10]
        and `min_boundary_threshold` to -10 when not given

    Returns a dict mapping each camera name to the thresholds returned by
    `find_best_threshold`.
    """
    source = event_source(filename, back_seekable=True)
    source.max_events = max_events

    camera_names = {
        source.subarray.tel[tel_id].camera.camera_name
        for tel_id in source.subarray.tel.keys()
    }

    # values this function used to hard-code; the caller can now override them
    analysis_kwargs = {'picture_threshold': [2, 20, 10],
                       'min_boundary_threshold': -10}
    analysis_kwargs.update(kwargs_tailcut_analysis)

    # results go into a separate dict: the set of names being iterated over
    # can neither be indexed nor modified during the loop
    best_thresholds = {}
    for cam_name in camera_names:
        thresholds, all_diff, _ = tailcut_cleaning_analyse(source,
                                                           cam_name,
                                                           **analysis_kwargs)
        best = find_best_threshold(thresholds, all_diff)
        best_thresholds[cam_name] = best
        print("Best thresholds for camera {0} are: {1}".format(cam_name, best))

    return best_thresholds

In [None]:
# Directory holding the simtel input files (machine-specific path: adapt it to your setup)
input_dir = '/Users/thomasvuillaume/Work/CTA/Data/DL0/Simtel/'
# Name of the simulation file used in this study
gamma_diffuse = 'gamma_20deg_0deg_run100___cta-prod3-lapalma3-2147m-LaPalma_cone10.simtel.gz'

In [None]:
# Open the simulation file as an event source.
# NOTE(review): `back_seekable=True` is presumably there so that the source can be
# iterated more than once (it is read again by the analysis below) - confirm
filename = os.path.join(input_dir, gamma_diffuse)
source = event_source(input_url = filename, back_seekable=True)
# cap the number of events read from the source
source.max_events = 60

In [None]:
from ctapipe.calib.camera.calibrator import NeighborPeakWindowSum

In [None]:
# Calibrator used by every `cal(event)` call of this notebook; only the subarray is passed.
# Alternative kept for reference (disabled): image_extractor = NeighborPeakWindowSum()
cal = CameraCalibrator(subarray=source.subarray)

In [None]:
# Take the first event of the source as a working example
event=next(iter(source))

## Visualise simple cleaning example

In [None]:
# Show which telescopes have data in this event
event.r0.tels_with_data

In [None]:
# Run the calibrator on the example event; the cells below read its images from `event.dl1`
cal(event)

In [None]:
# Tailcut parameters of this example
tailcut_params = {'picture_thresh':4,
                  'boundary_thresh':1,
                  'keep_isolated_pixels':False,
                  'min_number_picture_neighbors':1,
                 }
cal(event)

# For every telescope with data, show the calibrated image, the cleaned image
# and the true image side by side.
for tel_id in event.r0.tels_with_data:
    image = event.dl1.tel[tel_id].image
    # geometry looked up once, through the same `subarray.tel` accessor as in
    # the rest of the notebook, and reused for the three displays
    camera_geometry = source.subarray.tel[tel_id].camera.geometry
    cleaned_image = get_cleaned_image(camera_geometry, image, tailcuts_clean, **tailcut_params)

    fig, axes = plt.subplots(1, 3, figsize=(35,10))
    print(tel_id)
    CameraDisplay(camera_geometry, image, ax=axes[0])
    axes[0].set_title("Calibrated image")
    CameraDisplay(camera_geometry, cleaned_image, ax=axes[1])
    axes[1].set_title("Cleaned image")
    CameraDisplay(camera_geometry, event.mc.tel[tel_id].true_image, ax=axes[2])
    axes[2].set_title("True image")
    plt.show()

In [None]:
# Tailcut parameters of this example
tailcut_params = {'picture_thresh':2,
                  'boundary_thresh':1,
                  'keep_isolated_pixels':False,
                  'min_number_picture_neighbors':1,
                 }
cal(event)

# For every telescope with data, compare a cleaning applied to a time-weighted
# image with the plain tailcut cleaning and with the true image.
for tel_id in event.r0.tels_with_data:
    geom = source.subarray.tel[tel_id].camera.geometry
    # copies, because `image` is modified in place below
    image = event.dl1.tel[tel_id].image.copy()
    pulse_time = event.dl1.tel[tel_id].peak_time.copy()

    # plain tailcut cleaning, computed before `image` is modified
    cleaned_image = get_cleaned_image(geom, image, tailcuts_clean, **tailcut_params)
#     pulse_time -= np.median(pulse_time)
#     print(pulse_time.mean())
#     image[pulse_time>20] = 0

    # weight each pixel by 1/(1 + peak time), rescale so that the weighted image
    # has the same maximum as the original one, then clean the weighted image
    # and apply the resulting mask to the original image
    image_max = image.max()
    weighted_image = image/(1+pulse_time)
    weighted_image *= image_max/weighted_image.max()
    pixels = tailcuts_clean(geom, weighted_image, **tailcut_params)
    image[~pixels] = 0

    fig, axes = plt.subplots(1, 3, figsize=(35,10))
    print(tel_id)
    # named `camera_display` so that the IPython `display` function imported at
    # the top of the notebook is not overwritten
    camera_display = CameraDisplay(geom, image, ax=axes[0])
    # this panel shows `image` after the time-weighted mask has been applied
    axes[0].set_title("Time-weighted cleaned image")
    camera_display.add_colorbar(ax=axes[0])
    camera_display = CameraDisplay(geom, cleaned_image, ax=axes[1])
    axes[1].set_title("Cleaned image")
    camera_display.add_colorbar(ax=axes[1])
    camera_display = CameraDisplay(geom, event.mc.tel[tel_id].true_image, ax=axes[2])
    axes[2].set_title("True image")
    # bound to its own axes, like the two other colorbars
    camera_display.add_colorbar(ax=axes[2])
    plt.show()

In [None]:
# NOTE: `tel_id` and `geom` are the values left over from the last iteration of
# the loop in the previous cell, so only that telescope is scored here.
image = event.dl1.tel[tel_id].image
true_image = event.mc.tel[tel_id].true_image
# cleaning accuracy for a picture threshold of 8 and a boundary threshold of 3
tailcut_cleaning_accuracy(geom, image, true_image, 8, 3)

# Run the analysis

In [None]:
# Scan the threshold grid for the telescopes carrying an 'LSTCam' camera
thresholds, all_diff, event_info = tailcut_cleaning_analyse(
    source,
    'LSTCam',
    picture_threshold=[2, 20, 10],
    min_boundary_threshold=-10,
)

In [None]:
# Quick check of the outputs: array shapes and the available event-info keys
print(thresholds.shape, all_diff.shape)
print(event_info.keys())

## Visualise the result and find the best tailcut thresholds

In [None]:
# Map of the distance summed over all events on the (picture, boundary) threshold grid;
# the best thresholds are printed by `plot_cleaning_analysis`
fig, ax = plt.subplots(figsize=(18,10))
ax = plot_cleaning_analysis(thresholds, all_diff, ax=ax)

## One can also perform the analysis as a function of event info

### By total event amplitude (in p.e.)

In [None]:
# Distribution of the per-event DL1 amplitude (log-scaled counts, 30 bins)
plt.hist(event_info['event_dl1_amplitude'], log=True, bins=30);

In [None]:
# Select the events whose total DL1 amplitude exceeds 2000
mask = event_info['event_dl1_amplitude'] > 2000

In [None]:
# Same map, restricted to the events selected by `mask`
ax = plot_cleaning_analysis(thresholds, all_diff[mask])

#### Or by event energy

In [None]:
# Select the events with an energy below 1 (`event_info['event_energy']` is stored in TeV)
mask = event_info['event_energy'].value < 1

In [None]:
# Same map, restricted to the events selected by `mask`
ax = plot_cleaning_analysis(thresholds, all_diff[mask])

# Finally one can make an analysis to find the best threshold for each camera type

In [None]:
# Pass the calibrator explicitly (second positional argument of
# `analyse_all_cameras`) and give the tailcut settings as plain keyword arguments.
best_thresh = analyse_all_cameras(filename,
                                  cal,
                                  max_events=30,
                                  picture_threshold=[2, 14, 10],
                                  min_boundary_threshold=-4)