In [None]:
# Public IPython display API (``IPython.core.display.display`` is a deprecated alias).
from IPython.display import display, HTML

# Widen the notebook container to the full browser width so the 2x3 plot grids fit.
display(HTML("<style>.container { width:100% !important; }</style>"))

import warnings

# NOTE(review): this silences *all* warnings for the rest of the kernel session,
# including deprecation and numerical warnings raised by the pipeline below.
warnings.filterwarnings('ignore')

In [None]:
plot_all(
    379464439,
    large=False,     # True: one full-size figure per panel instead of a 2x3 grid
    abs_xlim=None,   # '3p': show 3.5 periods; or an explicit (xmin, xmax) pair
    abs_offset=0,    # with abs_xlim='3p': number of periods to skip from the start
    abs_ylim=None,   # '3%': 3rd-97th percentile of SAP flux; or e.g. (0.98, 1.01)
)

### Run at startup

This cell imports the dependencies and defines `plot_all`; run it before the plotting cell above.

In [None]:
from astropy.io import fits
from astronet.preprocess import generate_input_records
from astronet.preprocess import preprocess
from light_curve_util import keplersplinev2
from light_curve_util import tess_io
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd


def plot_all(tic,
             large=False,
             dpi=100,
             use_actual_preprocess=True,
             include_old_detrending=False,
             abs_xlim=None,
             abs_offset=0,
             abs_ylim=None,
             tess_data_dir='/mnt/tess/lc',
             reports_dir='/mnt/tess/rpt/png',
             tces_file='/mnt/tess/astronet/tces-v3.csv'):
    """Plot the detrending / preprocessing pipeline for a single TCE.

    Panels are drawn either into one 2x3 grid figure (``large=False``) or as
    separate full-size figures (``large=True``):

    1. SAP flux read directly from the FITS light curve (optionally KSPSAP too).
    2. The processed SAP light curve with the fitted Kepler spline, split by
       the spline input mask.
    3. The phase-folded, detrended light curve plus a trace of the filter mask.
    4-6. Global, local and secondary views, either recomputed here with
       ``preprocess.*_view`` (``use_actual_preprocess=False``) or taken from the
       example built by ``generate_input_records._process_tce``.

    Finally the pre-rendered report image ``{reports_dir}/TIC{tic}.png`` is shown.

    Args:
        tic: TIC id; must match exactly one ``tic_id`` row of ``tces_file``.
        large: If True, draw every panel as its own figure.
        dpi: Figure resolution.
        use_actual_preprocess: Source of panels 4-6 (see above).
        include_old_detrending: Also overlay the KSPSAP_FLUX-based curves.
        abs_xlim: x-limits for panels 1-2. ``None`` (auto), ``'3p'`` (3.5
            periods starting ``abs_offset`` periods after the first timestamp),
            or an ``(xmin, xmax)`` pair.
        abs_offset: Number of periods to skip when ``abs_xlim == '3p'``.
        abs_ylim: y-limits for panels 1-2. ``None`` (auto), ``'3%'`` (3rd-97th
            percentile of the non-NaN FITS SAP flux), or a ``(ymin, ymax)`` pair.
        tess_data_dir: Directory holding the TESS light-curve FITS files.
        reports_dir: Directory holding the ``TIC{tic}.png`` report images.
        tces_file: TCE CSV with ``tic_id``, ``Period``, ``Epoc``, ``Duration`` columns.

    Raises:
        ValueError: If ``tces_file`` does not contain exactly one row for ``tic``.
    """
    tab20 = plt.get_cmap('tab20')
    b, b2, o, o2 = tab20(0), tab20(1), tab20(2), tab20(3)
    g, r, n, k = tab20(4), tab20(6), tab20(8), tab20(14)

    fsize = (16, 9)

    tce_table = pd.read_csv(tces_file, header=0)
    tce = tce_table[tce_table.tic_id == tic]
    if len(tce) != 1:
        raise ValueError(
            f'Expected exactly one TCE for TIC {tic} in {tces_file}, found {len(tce)}.')
    period = tce.Period.values.item()
    epoc = tce.Epoc.values.item()
    duration = tce.Duration.values.item()

    print(f'Epoc: {epoc}\nPeriod: {period}\nDuration: {duration}')

    def new_panel(position):
        """Start the next panel: a new figure in large mode, else grid cell `position`."""
        if large:
            plt.figure(figsize=fsize, dpi=dpi)
        else:
            plt.subplot(2, 3, position)

    def config_abs_plot(title):
        """Apply legend, title and the requested absolute-time axis limits."""
        plt.legend()
        plt.title(title)
        if abs_xlim:
            if abs_xlim == '3p':
                # nanmin so NaN timestamps (if any) cannot corrupt the limit.
                minx = np.nanmin(td) + abs_offset * period
                plt.xlim(minx, minx + 3.5 * period)
            else:
                plt.xlim(*abs_xlim)
        if abs_ylim:
            if abs_ylim == '3%':
                finite_fs = fs[~np.isnan(fs)]
                plt.ylim(np.percentile(finite_fs, 3), np.percentile(finite_fs, 97))
            else:
                plt.ylim(*abs_ylim)

    # Copy the columns out so the FITS file can be closed immediately.
    file_names = tess_io.tess_filenames(tic, tess_data_dir)
    with fits.open(file_names) as hdul:
        td = np.array(hdul[1].data["TIME"])
        fd = np.array(hdul[1].data["KSPSAP_FLUX"])
        fs = np.array(hdul[1].data["SAP_FLUX"])

    # Panel 1: raw FITS flux.
    if not large:
        plt.figure(figsize=fsize, dpi=dpi)
    new_panel(1)
    if include_old_detrending:
        plt.plot(td, fd, '-', c=b, label='KSPSAP')
    plt.plot(td, fs, '-', alpha=0.6, color=g, label='SAP')
    config_abs_plot('fits data')

    # Panel 2: processed light curve and Kepler spline fit.
    new_panel(2)
    if include_old_detrending:
        kt, kf = preprocess.read_and_process_light_curve(tic, tess_data_dir, 'KSPSAP_FLUX')
    ut, uf = preprocess.read_and_process_light_curve(tic, tess_data_dir, 'SAP_FLUX')

    input_mask = preprocess.get_spline_mask(ut, period, epoc, duration)
    sf, mdata = keplersplinev2.choosekeplersplinev2(
        ut, uf, input_mask=input_mask, verbose=True, return_metadata=True)

    if include_old_detrending:
        plt.plot(kt, kf, '-', alpha=0.6, color=b, label='KSPSAP')
    plt.plot(ut, uf, '-', alpha=0.6, color=g, label='SAP')
    plt.plot(ut[input_mask], sf[input_mask], 'x', markersize=3, color=k, label='spline (OOT)')
    plt.plot(ut[~input_mask], sf[~input_mask], 'o', markersize=3, color=o, label='spline')
    config_abs_plot(f'raw | bkspace: {mdata.bkspace}')

    # Panel 3: phase-folded detrended curve plus the filter mask.
    new_panel(3)
    if include_old_detrending:
        ft, ff = preprocess.phase_fold_and_sort_light_curve(kt, kf, period, epoc)

    ut, nf, fm = preprocess.detrend_and_filter(tic, ut, uf, period, epoc, duration)
    sft, sff = preprocess.phase_fold_and_sort_light_curve(ut, nf, period, epoc)

    if include_old_detrending:
        plt.plot(ft, ff, '-', markersize=3, c=b, label='original')
    plt.plot(sft, sff, 'o', markersize=3, alpha=0.6, c=o, label='spline')
    # Drop extreme outliers (and NaNs, which fail both comparisons) before
    # choosing the low level of the mask trace; the mask is drawn at 1 where
    # set and at that minimum elsewhere.
    sff_arr = np.asarray(sff)
    sff_in_range = sff_arr[(sff_arr >= -0.5) & (sff_arr <= 1.5)]
    if sff_in_range.size:
        fmt, fm_folded = preprocess.phase_fold_and_sort_light_curve(ut, fm, period, epoc)
        mask_trace = np.where(fm_folded, 1, np.min(sff_in_range))
        plt.plot(fmt, mask_trace, '-', markersize=1, alpha=0.6, c=r, label='OOT')
        title = 'phase folded / filtered'
    else:
        title = 'phase folded | WARNING: filtering removed all data'
    plt.legend()
    plt.title(title)

    if not use_actual_preprocess:
        # Panels 4-6: views recomputed directly from the folded curves.
        new_panel(4)
        if include_old_detrending:
            gv = preprocess.global_view(tic, ft, ff, period)
        sgv = preprocess.global_view(tic, sft, sff, period)
        if include_old_detrending:
            plt.plot(gv, '-', markersize=3, color=b, label='original')
        plt.plot(sgv, 'o-', markersize=3, color=o, label='spline')
        plt.legend()
        plt.title('global view')

        new_panel(5)
        if include_old_detrending:
            lv = preprocess.local_view(tic, ft, ff, period, duration, new_binning=False)
            plt.plot(lv, '-', markersize=3, c=b2, label='original median')
            lvs = preprocess.local_view(tic, sft, sff, period, duration, new_binning=False)
            plt.plot(lvs, 'o-', markersize=3, c=o2, label='spline median')
            lv = preprocess.local_view(tic, ft, ff, period, duration, new_binning=True)
            plt.plot(lv, '-', markersize=3, c=b, label='original robust')
        lvs = preprocess.local_view(tic, sft, sff, period, duration, new_binning=True)
        plt.plot(lvs, 'o-', markersize=3, c=o, label='spline robust')
        plt.legend()
        plt.title('local view')

        new_panel(6)
        if include_old_detrending:
            sv = preprocess.secondary_view(tic, ft, ff, period, duration)
        svs = preprocess.secondary_view(tic, sft, sff, period, duration)
        if include_old_detrending:
            plt.plot(sv, '-', markersize=3, color=b, label='original robust')
        plt.plot(svs, 'o-', markersize=3, color=o, label='spline robust')
        plt.legend()
        plt.title('secondary view')
    else:
        # Panels 4-6: the views exactly as the record-generation pipeline builds them.
        generate_input_records.FLAGS = generate_input_records.parser.parse_args([
            '--tess_data_dir', tess_data_dir,
            '--output_dir', '/dev/null',
            '--input_tce_csv_file', tces_file,
        ])

        _, row = next(tce.iterrows())

        try:
            ex = generate_input_records._process_tce(row)
            title = 'preprocess'
        except Exception as err:
            # Report why the primary path failed instead of hiding it, then retry;
            # the extra True argument presumably selects the old detrending
            # (matches the panel title) - confirm against _process_tce.
            print(f'preprocess failed ({err!r}); retrying with old detrending')
            title = 'old detrending'
            ex = generate_input_records._process_tce(row, True)

        for position, view in ((4, 'global'), (5, 'local'), (6, 'secondary')):
            new_panel(position)
            plt.plot(ex.features.feature[f'{view}_view'].float_list.value, color=n)
            plt.title(f'{title} {view}')

    # Pre-rendered report image for side-by-side comparison.
    plt.figure(figsize=(12, 8), dpi=dpi)
    plt.axis('off')
    plt.imshow(plt.imread(f'{reports_dir}/TIC{tic}.png'))