# Study on (tailcut) cleaning optimisation

Notes:
- This benchmark might not be optimal
- DL1 file prepared with ctapipe-stage1

The idea here is to define a benchmark to optimise cleaning independently of any reconstruction that would come **after**.    
This is to avoid optimising the cleaning as a function of the whole reconstruction, because:
- it can be tedious (you have to loop over the whole reconstruction)    
- optimising the cleaning before optimising the later stages of the reconstruction might lead to a cleaning that is well adapted to the reconstruction method chosen a priori but not good in absolute terms (a different/better reconstruction might then end up showing worse results)


This benchmark uses the ground-truth image in photo-electrons from MC simulations: it computes the distance between the cleaned image and the ground truth as a function of the cleaning method/parameters, and finds the minimum of this distance (averaged over many events).

This also allows studying the cleaning as a function of event information (such as energy, signal amplitude, ...).

Of course, this supposes that the calibration has been previously optimised.

In [None]:
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:90% !important; }</style>"))


In [None]:
import ctapipe
print(ctapipe.__version__)

In [None]:
from ctapipe.io import EventSource
from ctapipe.utils import datasets
from ctapipe.calib import CameraCalibrator
from ctapipe.image import tailcuts_clean, dilate
from ctapipe.visualization import CameraDisplay
from ctapipe.instrument import CameraGeometry
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
from scipy.stats import poisson
import os
from ctapipe.utils import get_dataset_path
from ctapipe.io import read_table
from ctapipe.instrument import SubarrayDescription  # for working with CTA instruments
from astropy.table import join

from ctapipe.utils.download import download_file_cached

import copy
import astropy.units as u
import tables
from astropy.table import Table, vstack
%matplotlib inline

In [None]:
ls ../../prepared_data/

In [None]:
remote_url = "http://cccta-dataserver.in2p3.fr/data/Prod5_Paranal_North_20deg_ctapipe_v0.10.5_DL1/"
filename = "gamma_20deg_0deg_run107___cta-prod5-paranal_desert-2147m-Paranal-dark_cone10_merged.DL1.h5"

In [None]:
filename = download_file_cached(filename, default_url=remote_url)

In [None]:
filename

In [None]:
subarray = SubarrayDescription.from_hdf(filename)
subarray.info()
subarray.peek()

In [None]:
subarray.tel_ids

In [None]:
telescope_types = subarray.telescope_types

In [None]:
def read_images_from_telescope_type(filename, telescope_type):
    """Stack the DL1 and simulated image tables of all telescopes of one type.

    The telescope ids are looked up on the module-level ``subarray``.  For
    each telescope, the DL1 image table and the simulation image table are
    read from ``filename`` and joined on ``event_id``, ``tel_id`` and
    ``obs_id``; the per-telescope results are then stacked vertically.

    Parameters
    ----------
    filename
        Path of the HDF5 file to read the tables from.
    telescope_type
        Telescope type passed to ``subarray.get_tel_ids_for_type``.

    Returns
    -------
    The table returned by ``vstack`` over all per-telescope joined tables.
    """
    def _joined_images(tel_id):
        # One telescope: reconstructed images matched with their simulated counterpart.
        dl1_images = read_table(filename, f"/dl1/event/telescope/images/tel_{tel_id:03d}")
        true_images = read_table(filename, f"/simulation/event/telescope/images/tel_{tel_id:03d}")
        return join(dl1_images, true_images, keys=['event_id', 'tel_id', 'obs_id'])

    selected_tel_ids = subarray.get_tel_ids_for_type(telescope_type)
    return vstack([_joined_images(tel_id) for tel_id in selected_tel_ids])

In [None]:
subarray.to_table()

In [None]:
telescope_types

In [None]:
# Use the first telescope type of the subarray for the exploratory cells below.
tel_type = telescope_types[0]

# Camera geometry is taken from the first telescope of that type.
geometry = subarray.tels[subarray.get_tel_ids_for_type(tel_type)[0]].camera.geometry
# Joined DL1 + simulated images for every telescope of that type.
image_table = read_images_from_telescope_type(filename, tel_type)

In [None]:
def residuals_after_cleaning(cleaned_image, true_image):
    """Return the element-wise difference ``cleaned_image - true_image``.

    Positive values mean the cleaned image holds more signal than the true
    image at that position, negative values mean it holds less.
    """
    difference = cleaned_image - true_image
    return difference

In [None]:
import copy

def add_residuals_to_table(image_table):
    """Add ``residuals`` and ``accuracy`` columns to ``image_table`` in place.

    The cleaned images are a deep copy of the ``image`` column in which every
    entry not selected by ``image_mask`` is set to 0.  ``residuals`` is the
    difference between these cleaned images and ``true_image``; ``accuracy``
    is the norm of the residuals computed along axis 1.

    Nothing is returned: the table passed in is modified.
    """
    rejected = ~image_table['image_mask']

    # Work on a copy so the original 'image' column is left untouched.
    cleaned = copy.deepcopy(image_table['image'])
    cleaned[rejected] = 0

    residuals = residuals_after_cleaning(cleaned, image_table['true_image'])
    image_table['residuals'] = residuals
    image_table['accuracy'] = np.linalg.norm(image_table['residuals'], axis=1)

In [None]:
add_residuals_to_table(image_table)
image_table[:3]

In [None]:
# Distribution of all residual values (flattened over rows and pixels),
# restricted to [-20, 20] and drawn with a logarithmic count axis.
plt.hist(image_table['residuals'].ravel(), log=True, bins=100, range=(-20, 20));
# Note: despite the label, the printed value is the mean of the *absolute* residuals.
print("residuals mean: ", np.mean(np.abs(image_table['residuals'])))

In [None]:
plt.hist(image_table['accuracy'], bins=100, range=(0, 100));

In [None]:
from matplotlib.colors import Normalize

def display_row(geometry, image_table, row_index=0):
    """Show one table row as up to three side-by-side camera displays.

    The first panel shows ``image`` with the pixels of ``image_mask``
    highlighted in red, the second shows ``true_image``.  If the row has a
    ``residuals`` column, the third panel shows it with the ``RdBu`` colormap
    and colour limits symmetric around zero.

    Parameters
    ----------
    geometry
        Camera geometry passed to ``CameraDisplay``.
    image_table
        Table holding the ``image``, ``image_mask`` and ``true_image`` columns.
    row_index
        Index of the row to display (default 0).

    Returns
    -------
    The array of three axes created by ``plt.subplots``.
    """
    _, axes = plt.subplots(1, 3, figsize=(20,5))
    ax_image, ax_true, ax_residuals = axes
    selected = image_table[row_index]

    # Panel 1: image with the cleaning mask overlaid.
    image_disp = CameraDisplay(geometry, selected['image'], ax=ax_image)
    image_disp.add_colorbar()
    image_disp.highlight_pixels(selected['image_mask'], color='red', alpha=0.3)
    image_disp.axes.set_title('image')

    # Panel 2: true image.
    true_disp = CameraDisplay(geometry, selected['true_image'], ax=ax_true)
    true_disp.add_colorbar()
    true_disp.axes.set_title('true_image')

    # Panel 3 is only filled once residuals have been added to the table.
    if 'residuals' not in selected.colnames:
        return axes

    residuals_disp = CameraDisplay(geometry, selected['residuals'], ax=ax_residuals, cmap='RdBu')
    # Symmetric limits keep zero at the centre of the diverging colormap.
    largest_abs = np.max(np.abs(selected['residuals']))
    residuals_disp.add_colorbar()
    residuals_disp.set_limits_minmax(-largest_abs, largest_abs)
    residuals_disp.axes.set_title('residuals')

    return axes

In [None]:
image_table[4]

In [None]:
display_row(geometry, image_table, 4);

## Find tailcut parameters that minimise residuals

In [None]:
def thresholds_grid(image_table, pt_array=np.linspace(3, 12, 10)):
    """Scan tailcut thresholds and score each pair against the true images.

    For every picture threshold in ``pt_array``, boundary thresholds are
    sampled with ``np.linspace(0, pt, len(pt_array))``.  For each pair, all
    images of ``image_table`` are cleaned with ``tailcuts_clean`` using the
    module-level ``geometry``; the ``image_mask`` column is overwritten and
    the ``residuals``/``accuracy`` columns are recomputed.  The score of a
    pair is the L2 norm of all flattened residuals divided by the number of
    residual values.

    Note that ``image_table`` is modified in place: after the call it holds
    the mask and residuals of the last pair scanned.

    Returns
    -------
    Three arrays of equal length: picture thresholds, boundary thresholds
    and the corresponding scores.
    """
    n_boundary_steps = len(pt_array)
    scanned_picture, scanned_boundary, scores = [], [], []

    for pt in pt_array:
        for bt in np.linspace(0, pt, n_boundary_steps):
            scanned_picture.append(pt)
            scanned_boundary.append(bt)

            image_table['image_mask'] = [
                tailcuts_clean(geometry, image, picture_thresh=pt, boundary_thresh=bt)
                for image in image_table['image']
            ]
            add_residuals_to_table(image_table)

            # Global score over all pixels of all images (rather than the
            # mean of the per-image 'accuracy' column).
            flat_residuals = image_table['residuals'].ravel()
            scores.append(np.linalg.norm(flat_residuals, ord=2) / flat_residuals.shape[0])

    return np.array(scanned_picture), np.array(scanned_boundary), np.array(scores)


def best_thresholds(picture_threshold, boundary_threshold, accuracy):
    """Return the threshold pair with the smallest ``accuracy`` value.

    Parameters
    ----------
    picture_threshold, boundary_threshold
        Sequences of the scanned picture and boundary thresholds.
    accuracy
        Score of each threshold pair, aligned element by element with the
        two threshold sequences; lower is better.

    Returns
    -------
    tuple
        ``(picture_threshold, boundary_threshold)`` at the position of the
        minimum of ``accuracy`` (the first one if the minimum is not unique).
    """
    # Use the `accuracy` argument, not a module-level variable, so the
    # function does not depend on notebook state.
    best_index = np.argmin(accuracy)
    return picture_threshold[best_index], boundary_threshold[best_index]

In [None]:
def plot_threshold_heatmap(picture_threshold, boundary_threshold, accuracy):
    """Draw a filled contour map of ``accuracy`` over the threshold plane.

    Parameters
    ----------
    picture_threshold, boundary_threshold
        Coordinates of the scanned points (x and y axis respectively).
    accuracy
        Value at each scanned point, shown as the colour scale.

    Returns
    -------
    The matplotlib axes holding the plot.
    """
    _, ax = plt.subplots(figsize=(10, 6))
    # Plot the `accuracy` argument, not a module-level variable, so the
    # figure always matches the data passed in.
    im = ax.tricontourf(picture_threshold, boundary_threshold, accuracy)
    cbar = plt.colorbar(im)
    ax.set_xlabel('picture threshold')
    ax.set_ylabel('boundary threshold')
    cbar.set_label('accuracy')
    # Same scale on both axes so the two thresholds are directly comparable.
    ax.axis('equal')
    return ax

In [None]:
# Repeat the threshold scan for every telescope type of the subarray.
for tel_type in telescope_types:
    print(f"---- {tel_type} ----")
    # thresholds_grid reads the module-level `geometry` rather than taking it
    # as an argument, so it has to be rebound here for each telescope type.
    geometry = subarray.tels[subarray.get_tel_ids_for_type(tel_type)[0]].camera.geometry
    # Only the first 1000 rows of the stacked table are kept for the scan.
    image_table = read_images_from_telescope_type(filename, tel_type)[:1000]
    add_residuals_to_table(image_table)
    print("Example:")
    display_row(geometry, image_table, 0)
    plt.show()
    
    # Picture thresholds are scanned over 10 values between 4 and 20.
    pt, bt, acc = thresholds_grid(image_table, np.linspace(4, 20, 10))
    print(f"best threshold for {tel_type}: {best_thresholds(pt, bt, acc)}")
    plot_threshold_heatmap(pt, bt, acc)
    plt.show()