# Example preprocessing of an IFS cube

> Author: *Valentin Christiaens*  
> Last update: *2022/04/04*

**Table of contents**

* [1. Loading and visualizing the data](#1.-Loading-and-visualizing-the-data)

* [2. Preprocessing](#2.-Preprocessing)
    - [2.1. Bad pixel correction](#2.1.-Bad-pixel-correction)
    - [2.2. Bad frame removal](#2.2.-Bad-frame-removal)

The purpose of this tutorial is to show how to further improve the quality of the IFS data cubes provided in phase 2 of the Exoplanet Data Challenge (delivered as is by each instrument's official pipeline) before applying your favourite post-processing algorithms to detect the injected planets.

Specifically, we show how to use the relevant functions of the `VIP` package in order to correct for remaining bad pixels and remove bad frames in the second SPHERE/IFS datacube used in the data challenge.

-----------

Let's first import the packages needed in this tutorial:

In [None]:
from hciplot import plot_frames, plot_cubes
from matplotlib import pyplot as plt
import numpy as np
from os.path import isfile
import pandas as pd
from vip_hci.fits import open_fits, write_fits

## 1. Loading and visualizing the data

In the 'dataset' folder of the `phase2` repository you can find a toy SPHERE/IFS coronagraphic cube acquired in pupil-stabilized mode on the source HIP39826 (a star with no reported directly imaged companion). The folder also contains the associated non-coronagraphic point spread function (PSF), wavelength vector and parallactic angles (the latter also including information on the airmass of the source).

Let's now load the data. Note that more info on opening and visualizing fits files with VIP in general is available [in the first VIP tutorial](https://vip.readthedocs.io/en/latest/tutorials/01_quickstart.html).

**Replace next box when zenodo link available**

In [None]:
path = '/Users/Valentin/Documents/Postdoc/EIDC/Cubes/data_test/'
datpath = path+'datasets/'

cubename = datpath+'image_cube_sphere2.fits'
angname = datpath+'parallactic_angles_sphere2.fits'

cube = open_fits(cubename)
pa = open_fits(angname)  # parallactic angles and airmass; reused when saving the trimmed angles
derot_angles = pa[0]

nch, nz, ny, nx = cube.shape

Each IFS spectral cube consists of 39 monochromatic images spread in wavelength between the Y and J bands ('YJ' mode) or Y and H bands ('YJH' mode). Here the IFS+ADI cube contains 65 such spectral cubes combined into a single master cube. The first column of the parallactic angle file contains the parallactic (derotation) angles, while the second column contains the airmass at which each spectral cube was obtained.
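
To make the layout described above concrete, here is a minimal sketch on synthetic arrays with the same number of channels (39) and spectral cubes (65). The frame size (32x32), the random values and all `toy_*` names are placeholders, and the angle array is assumed to have shape `(2, nz)`, as the indexing used in this notebook suggests:

```python
import numpy as np

# Synthetic stand-ins for the real files (values and frame size are made up)
n_ch, n_z, n_y, n_x = 39, 65, 32, 32
rng = np.random.default_rng(0)
toy_cube = rng.normal(size=(n_ch, n_z, n_y, n_x))
toy_pa = np.vstack([np.linspace(-30.0, 30.0, n_z),   # parallactic angles (deg)
                    np.linspace(1.10, 1.20, n_z)])    # airmass

# Axis 0: wavelength, axis 1: time, axes 2-3: image
adi_cube_first_channel = toy_cube[0]         # ADI sequence at the first wavelength
spectral_cube_first_epoch = toy_cube[:, 0]   # one 39-channel spectral cube

# Split the angle array into derotation angles and airmass
toy_derot_angles, toy_airmass = toy_pa
print(adi_cube_first_channel.shape, spectral_cube_first_epoch.shape)
```

Indexing the first axis therefore returns an ADI sequence at a single wavelength, which is the input expected by the per-channel loops used in the rest of this tutorial.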

The **master spectral cube** is already centered. Let's inspect the first and last wavelengths, using `hciplot.plot_cubes` (feel free to set the backend to 'bokeh' to read pixel values interactively):

In [None]:
plot_cubes(cube[0])#, backend='bokeh')

In [None]:
plot_cubes(cube[-1])

Inspection of these two cubes shows the presence of residual bad pixels (e.g. in frames 7 and 30 of cube 0), and the presence of frames very different from the others (e.g. frame 61 in the last cube).

In the next section, we will show how to better prepare the master cube by correcting bad pixels and trimming bad frames, in order to increase the performance of typical post-processing algorithms.

[Go to the top](#Table-of-contents)

## 2. Preprocessing

### 2.1. Bad pixel correction

In [None]:
from vip_hci.preproc import cube_fix_badpix_isolated, cube_fix_badpix_clump, cube_fix_badpix_with_kernel

Let's first identify static bad pixels. This is done with a sigma-filtering algorithm applied to the temporal mean frame at each wavelength:

In [None]:
cube_corr = cube.copy()
bpm_mask_static = np.zeros([nch, ny, nx])
for i in range(nch):
    cube_corr[i], bpm_mask_static[i] = cube_fix_badpix_isolated(cube[i], sigma_clip=4, num_neig=9, size=9, 
                                                                frame_by_frame=False, mad=False, 
                                                                verbose=True, full_output=True)
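
To illustrate the idea behind sigma filtering on the temporal mean frame, here is a simplified NumPy-only sketch. It is not VIP's implementation: the function `flag_static_bad_pixels`, its parameters and the global (rather than local) noise estimate are assumptions made for illustration.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def flag_static_bad_pixels(adi_cube, sigma=4.0, size=9):
    """Flag pixels of the temporal mean frame that deviate from their local median."""
    mean_frame = adi_cube.mean(axis=0)
    half = size // 2
    # Pad so that every pixel has a full (size x size) neighbourhood
    padded = np.pad(mean_frame, half, mode='reflect')
    windows = sliding_window_view(padded, (size, size))
    local_median = np.median(windows, axis=(-2, -1))
    residual = mean_frame - local_median
    # Simplification: a single global standard deviation sets the threshold
    return np.abs(residual) > sigma * residual.std()

rng = np.random.default_rng(1)
toy_adi_cube = rng.normal(10.0, 1.0, size=(20, 40, 40))
toy_adi_cube[:, 12, 25] += 50.0   # hot pixel present in every frame
toy_bpm = flag_static_bad_pixels(toy_adi_cube)
print(np.argwhere(toy_bpm))
```

Averaging over time first suppresses random noise, which makes pixels that are consistently too bright or too dark stand out against their neighbours.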

Let's now identify individual bad pixels in each frame (e.g. due to cosmic rays) with an iterative algorithm suited to correcting bad pixel clumps, and add them to the static bad pixel map. **Warning: this may take some time depending on your machine.** The next cell therefore only runs if no master bad pixel map has been saved yet; set `overwrite=True` to recompute it anyway:

In [None]:
overwrite = False

if not isfile(datpath+"master_bad_pixel_map.fits") or overwrite:
    bpm_mask = np.zeros([nch, nz, ny, nx])
    for i in range(nch):
        cube_corr[i], bpm_mask_indiv = cube_fix_badpix_clump(cube_corr[i], fwhm=6, sig=4,
                                                             mad=True, min_thr=(-0.5,2),
                                                             verbose=True, full_output=True) 
        for z in range(nz):
            bpm_mask[i,z] = bpm_mask_indiv[z]+bpm_mask_static[i]

    bpm_mask[bpm_mask > 1] = 1  # keep the map binary
    write_fits(datpath+"master_bad_pixel_map.fits", bpm_mask)
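
The combination of the per-frame and static maps can also be written with broadcasting. The sketch below uses small synthetic masks (all `toy_*` names and values are made up) and a logical OR, which keeps the result binary without a separate clipping step:

```python
import numpy as np

n_z, n_y, n_x = 4, 5, 5
toy_static = np.zeros((n_y, n_x))
toy_static[1, 1] = 1                 # pixel that is bad in all frames
toy_indiv = np.zeros((n_z, n_y, n_x))
toy_indiv[2, 3, 3] = 1               # cosmic ray hit in frame 2 only
toy_indiv[0, 1, 1] = 1               # pixel also present in the static map

# Broadcast the 2D static map along the time axis and merge the two maps
toy_combined = np.logical_or(toy_indiv, toy_static[np.newaxis]).astype(float)
print(toy_combined.sum())
```

The static pixel ends up flagged in all four frames and the cosmic ray in a single one, with no value exceeding 1.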

Let's now correct all identified bad pixels with a Gaussian kernel:

In [None]:
bpm_mask = open_fits(datpath+"master_bad_pixel_map.fits")
for i in range(nch):
    cube_corr[i] = cube_fix_badpix_with_kernel(cube[i], bpm_mask=bpm_mask[i], mode='gauss', fwhm=2)
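
As a rough picture of what a Gaussian-kernel correction does, the sketch below replaces each flagged pixel by the Gaussian-weighted mean of its good neighbours. This is a simplified illustration and not VIP's actual code: the helper names, the 5-pixel kernel size and the handling of edges are assumptions.

```python
import numpy as np

def gaussian_kernel(fwhm, size):
    """2D Gaussian weights on a (size x size) grid centred on the middle pixel."""
    sigma = fwhm / (2.0 * np.sqrt(2.0 * np.log(2.0)))
    ax = np.arange(size) - size // 2
    yy, xx = np.meshgrid(ax, ax, indexing='ij')
    return np.exp(-(xx**2 + yy**2) / (2.0 * sigma**2))

def fix_with_gaussian_kernel(frame, bpm, fwhm=2.0, size=5):
    """Replace flagged pixels by the weighted mean of unflagged neighbours."""
    kernel = gaussian_kernel(fwhm, size)
    half = size // 2
    # Zero out bad pixels so that they do not contribute to the weighted sum
    clean = np.where(bpm > 0, 0.0, frame)
    padded_img = np.pad(clean, half)
    padded_good = np.pad((bpm == 0).astype(float), half)
    fixed = frame.copy()
    for y, x in np.argwhere(bpm > 0):
        win_img = padded_img[y:y + size, x:x + size]
        win_good = padded_good[y:y + size, x:x + size]
        weight = (kernel * win_good).sum()
        if weight > 0:   # leave the pixel untouched if it has no good neighbour
            fixed[y, x] = (kernel * win_img).sum() / weight
    return fixed

toy_frame = np.full((9, 9), 3.0)
toy_bpm = np.zeros((9, 9))
for y, x in [(4, 4), (0, 0)]:
    toy_frame[y, x] = 1000.0
    toy_bpm[y, x] = 1
toy_fixed = fix_with_gaussian_kernel(toy_frame, toy_bpm)
print(toy_fixed[4, 4], toy_fixed[0, 0])
```

Normalising by the summed weights of the good neighbours prevents flagged pixels and frame edges from biasing the replacement value.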

Let's compare 2 frames before and after bad pixel correction:

In [None]:
%matplotlib inline
idx = 7
plot_frames((cube[0,idx], cube_corr[0,idx], bpm_mask[0,idx]*cube[0,idx]), 
            vmin=0, vmax=float(np.amax(cube[0,idx])))

In [None]:
idx = -2
plot_frames((cube[-1,idx], cube_corr[-1,idx], bpm_mask[-1,idx]*cube[-1,idx]), 
            vmin=0, vmax=float(np.amax(cube[-1,idx])))

Note that better bad pixel corrections may be obtained by tweaking the parameters of each routine used to identify them (such as `sigma_clip`, `size`, `sig` or `min_thr`).

[Go to the top](#Table-of-contents)

### 2.2. Bad frame removal

Let's now remove bad frames, based on the cross-correlation between each frame and the median of the ADI sequence (at each wavelength). We use the Structural Similarity (SSIM) index (REF), computed in an annular region beyond the coronagraphic mask. It is worth running a first pass of the algorithm with a percentile threshold, plotting the values for each cube, and then defining an absolute threshold based on the mean SSIM values (over all wavelengths):

In [None]:
from vip_hci.preproc import cube_detect_badfr_correlation

In [None]:
cube_detect_badfr_correlation?

In [None]:
ssim = np.zeros([nch, nz])
for i in range(nch):
    good_idx, bad_idx, ssim[i] = cube_detect_badfr_correlation(cube_corr[i], 
                                                                frame_ref=np.median(cube_corr[i], axis=0), 
                                                                crop_size=61, dist='ssim', percentile=10, 
                                                                mode='annulus', inradius=8, width=18, 
                                                                plot=True, verbose=True, full_output=True)

Note that other distances can be used, such as the Pearson correlation coefficient. 
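
For intuition, the Pearson variant can be sketched in a few lines of NumPy: each frame is compared with the median frame inside an annulus. The helper names and the centre convention below are assumptions of this sketch, and the synthetic cube only stands in for real data.

```python
import numpy as np

def annulus_mask(shape, inradius, width):
    """Boolean mask selecting an annulus around the frame centre."""
    ny, nx = shape
    yy, xx = np.indices(shape)
    # Assumption: the centre sits at ((ny-1)/2, (nx-1)/2)
    rad = np.hypot(yy - (ny - 1) / 2.0, xx - (nx - 1) / 2.0)
    return (rad >= inradius) & (rad < inradius + width)

def pearson_to_median(adi_cube, inradius=8, width=18):
    """Pearson correlation of each frame with the median frame, in an annulus."""
    mask = annulus_mask(adi_cube.shape[1:], inradius, width)
    ref = np.median(adi_cube, axis=0)[mask]
    return np.array([np.corrcoef(frame[mask], ref)[0, 1] for frame in adi_cube])

rng = np.random.default_rng(2)
pattern = rng.normal(size=(61, 61))                   # common speckle-like pattern
toy_adi_cube = pattern + 0.1 * rng.normal(size=(10, 61, 61))
toy_adi_cube[7] = rng.normal(size=(61, 61))           # one unrelated (bad) frame
toy_corr = pearson_to_median(toy_adi_cube)
print(int(np.argmin(toy_corr)))
```

The frame that does not share the common pattern gets by far the lowest correlation, which is the behaviour the bad frame detection relies on.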

Although a plot was made for the ADI cube at each wavelength, let's visualize this better by plotting the measured SSIM values for a few different wavelengths, and the average over all channels:

In [None]:
plt.plot(range(1,nz+1), ssim[0], 'bo', label='ch. 0', alpha=0.7)
plt.plot(range(1,nz+1), ssim[20], 'yo', label='ch. 20', alpha=0.7)
plt.plot(range(1,nz+1), ssim[-1], 'ro', label='last ch.', alpha=0.7)
plt.plot(range(1,nz+1), np.mean(ssim,axis=0), 'k', label='mean over all ch.')
plt.plot(range(1,nz+1), [np.percentile(np.mean(ssim,axis=0),10)]*nz, 'k:', alpha=0.5, label='10% threshold')
plt.plot(range(1,nz+1), [np.percentile(np.mean(ssim,axis=0),15)]*nz, 'k--', alpha=0.6, label='15% threshold')
plt.plot(range(1,nz+1), [np.percentile(np.mean(ssim,axis=0),20)]*nz, 'k-.', alpha=0.7, label='20% threshold')
plt.xlabel(r"Cube index")
plt.ylabel(r"SSIM")
plt.legend()

Overall the trends are very similar at all wavelengths: the end of the sequence suffered from much worse conditions, as the dropping SSIM values testify. 

You can finally set a percentile threshold to remove bad frames, depending on your post-processing algorithm (how sensitive it is to very different PSFs) and the regime in which the candidate you are looking for is located (speckle-dominated vs photon-noise-dominated).

Here, let's consider for example removing the 15% worst frames. The mean appears to closely follow the last channel, so we'll just get the good indices from the last spectral channel:

In [None]:
perc_thr = 15
good_idx, bad_idx = cube_detect_badfr_correlation(cube_corr[-1], frame_ref=np.median(cube_corr[-1], axis=0), 
                                                  crop_size=61, dist='ssim', percentile=perc_thr, 
                                                  mode='annulus', inradius=8, width=18, 
                                                  plot=True, verbose=True, full_output=False)
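
If you would rather select frames from the channel-averaged values than from a single channel, a possible sketch is shown below. It runs on a synthetic SSIM array (all `toy_*` names and values are made up), and the choice of a strict `>` comparison is an assumption that may differ from VIP's convention.

```python
import numpy as np

rng = np.random.default_rng(3)
n_ch, n_z = 39, 65
toy_ssim = 0.9 + 0.01 * rng.normal(size=(n_ch, n_z))
toy_ssim[:, -8:] -= 0.3              # degraded end of the sequence

toy_perc_thr = 15
mean_ssim = toy_ssim.mean(axis=0)    # average over all spectral channels
thr = np.percentile(mean_ssim, toy_perc_thr)
toy_good_idx = np.where(mean_ssim > thr)[0]
toy_bad_idx = np.where(mean_ssim <= thr)[0]
print(len(toy_good_idx), len(toy_bad_idx))
```

The same list of good indices can then be applied to every spectral channel, so that all channels keep the same frames.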

Finally, let's save the master cube and associated parallactic angles after bad frame removal:

In [None]:
write_fits(datpath+'image_cube_sphere2_ready.fits', cube_corr[:,good_idx])
write_fits(datpath+'parallactic_angles_sphere2_ready.fits', pa[:,good_idx])
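
Before moving on to post-processing, it can be worth checking that the trimmed cube and angles still match along the time axis. Here is a minimal sketch on placeholder arrays (shapes chosen to mimic this dataset, discarded indices made up):

```python
import numpy as np

toy_cube = np.zeros((39, 65, 8, 8))   # (channels, frames, y, x)
toy_pa = np.zeros((2, 65))            # (angles and airmass, frames)
toy_good_idx = [i for i in range(65) if i not in (3, 61, 64)]

# The same index list is applied along the time axis of both arrays
trimmed_cube = toy_cube[:, toy_good_idx]
trimmed_pa = toy_pa[:, toy_good_idx]
assert trimmed_cube.shape[1] == trimmed_pa.shape[1] == len(toy_good_idx)
print(trimmed_cube.shape, trimmed_pa.shape)
```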

[Go to the top](#Table-of-contents)