In [None]:
# Render the dashboard for one TIC. Uncomment keywords below to override
# plot_all's defaults. abs_xlim and abs_ylim are each listed in two alternative
# forms -- enable at most one of each, or the call is a duplicate-keyword error.
plot_all(121150643,
#          large=True,
#          dpi=100,
#          bkspace=5.0,
#          abs_xlim='3p',
#          abs_offset=6,
#          abs_xlim=(1300, 2500),
#          abs_ylim='3%',
#          abs_ylim=(0.975, 1.01),
#          item_no=5
        )

### Run at startup

In [None]:
# Widen the notebook container to the full browser width.
# (IPython.display is the supported location; IPython.core.display.display
# is deprecated.)
from IPython.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

# NOTE: this silences *all* warnings for the rest of the kernel session,
# including library deprecation notices -- disable it when debugging.
import warnings
warnings.filterwarnings('ignore')

In [None]:
from astropy.io import fits
from astronet.preprocess import generate_input_records
from astronet.preprocess import preprocess
from light_curve_util import keplersplinev2
from light_curve_util import tess_io
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd


# Cache of already-loaded TCE tables, keyed by CSV path; filled lazily by get_tce.
tce_tables: dict = dict()

# TCE CSV files that get_tce searches, in this order; the first file that
# contains the requested TIC is the one used.
tce_files: list = [
    '../mnt/tess/astronet/tces-vetting-all.csv',
    '../mnt/tess/astronet/tces-vetting-v5-toi-train.csv',
    '../mnt/tess/astronet/tces-vetting-v5-toi-val.csv',
    '../mnt/tess/astronet/tces-vetting-v5-toi-test.csv',
]


def get_tce(tic, item_no):
    """Find the TCE row for ``tic`` in the TCE CSV files listed in ``tce_files``.

    Files are scanned in list order and each CSV is read at most once
    (cached in ``tce_tables``). The first file with a matching ``tic_id``
    is used.

    Args:
        tic: TIC id compared against the ``tic_id`` column.
        item_no: positional index of the row to keep when that file has
            several rows for ``tic``; ignored when there is exactly one.

    Returns:
        ``(tce, path)``: a single-row DataFrame and the CSV path it came from.

    Raises:
        ValueError: if none of the files contains ``tic``.
        IndexError: if ``item_no`` is out of range for the matching rows.
    """
    for path in tce_files:
        try:
            table = tce_tables[path]
        except KeyError:
            table = tce_tables[path] = pd.read_csv(path, header=0)

        matches = table.loc[table.tic_id == tic]
        n_matches = len(matches)
        if n_matches == 0:
            continue

        if n_matches > 1:
            # read_csv gives a unique RangeIndex, so positional selection picks
            # exactly the item_no-th matching row.
            matches = matches.iloc[[item_no]]
            if 'Source' in matches.columns:
                source_code = matches.Source.values.item()
                print('Source:', 'BLS' if source_code == 2 else 'TEV')
        return matches, path

    raise ValueError(f'no TCE data for {tic}')


def plot_all(tic,
             large=False,
             dpi=100,
             bkspace=None,
             abs_xlim=None,
             abs_offset=0,
             abs_ylim=None,
             item_no=0):
    """Plot a diagnostic dashboard for one TIC's TCE.

    Looks up the TCE via ``get_tce``, then plots the SAP flux from the FITS
    file, the spline fit, the phase-folded light curve, the per-aperture
    folded curves, and the feature views produced by
    ``generate_input_records._process_tce``. Report PNGs (page 1 and page 3)
    are shown at the end if they exist.

    Args:
        tic: TIC id to look up in the TCE CSV files.
        large: if True, every panel gets its own 16x9 figure; otherwise all
            panels are laid out on one 7x3 grid figure.
        dpi: resolution of the created figures.
        bkspace: passed as ``fixed_bkspace`` to the spline fit and as the
            bkspace argument to detrending / ``_process_tce``.
        abs_xlim: x limits for the absolute-time panels: a ``(min, max)``
            tuple, or ``'3p'`` to show 3.5 periods starting ``abs_offset``
            periods after the earliest FITS timestamp.
        abs_offset: number of periods to skip when ``abs_xlim == '3p'``.
        abs_ylim: y limits for the absolute-time panels: a ``(min, max)``
            tuple, or ``'3%'`` to use the 3rd-97th percentile of the FITS
            SAP flux.
        item_no: which matching row to use when the TIC has several TCEs.
    """
    tess_data_dir = '../mnt/tess/lc-v'
    reports_dir = '../mnt/tess/rpt/png-v'

    tce, tces_file = get_tce(tic, item_no)
    # NOTE(review): _process_tce below presumably reads its settings from
    # generate_input_records.FLAGS, which is why they are injected here.
    generate_input_records.FLAGS = generate_input_records.parser.parse_args([
        '--tess_data_dir', tess_data_dir,
        '--output_dir', '/dev/null',
        '--input_tce_csv_file', tces_file,
        '--vetting_features', 'y',
    ])

    cmap = plt.get_cmap('tab20')
    o = cmap(2)
    g = cmap(4)
    r = cmap(6)
    n = cmap(8)
    k = cmap(14)

    plotrows = 7
    plotcols = 3

    if large:
        fsize = (16, 9)
    else:
        # One shared grid figure; splt() fills it panel by panel.
        fsize = (16, 4 * (plotrows + 1))
        plt.figure(figsize=fsize, dpi=dpi)

    period = tce.Period.values.item()
    epoc = tce.Epoc.values.item()
    duration = tce.Duration.values.item()

    print(f'Epoc: {epoc}\nPeriod: {period}\nDuration: {duration}')

    def config_abs_plot(title):
        """Add legend and title, then apply abs_xlim / abs_ylim to the axes."""
        plt.legend()
        plt.title(title)
        if abs_xlim:
            if abs_xlim == '3p':
                minx = min(td) + abs_offset * period
                maxx = minx + 3.5 * period
                plt.xlim(minx, maxx)
            else:
                plt.xlim(*abs_xlim)
        if abs_ylim:
            if abs_ylim == '3%':
                # Percentiles come from the FITS SAP flux (fs), NaNs excluded.
                finite_fs = fs[~np.isnan(fs)]
                plt.ylim(np.percentile(finite_fs, 3), np.percentile(finite_fs, 97))
            else:
                plt.ylim(*abs_ylim)

    nplotted = 0

    def splt(c=1):
        """Start the next panel; ``c`` grid cells wide (used with 1 or 3)."""
        nonlocal nplotted
        if large:
            plt.figure(figsize=fsize, dpi=dpi)
        else:
            # With c == plotcols the subplot grid has one column, so the
            # panel spans the whole row.
            plt.subplot(plotrows, plotcols // c, (nplotted // c) + 1)
        nplotted += c

    # --- Panel: raw SAP flux from the FITS file -------------------------
    file_names = tess_io.tess_filenames(tic, tess_data_dir)
    # Copy the columns so they remain valid after the file is closed.
    with fits.open(file_names) as hdul:
        td = np.array(hdul[1].data['TIME'])
        fs = np.array(hdul[1].data['SAP_FLUX'])

    splt()
    plt.plot(td, fs, '-', alpha=0.6, color=g, label='SAP')
    config_abs_plot('fits data')

    # --- Panel: preprocessed flux with spline fit -----------------------
    splt()
    ut, uf, ap = preprocess.read_and_process_light_curve(
        tic, tess_data_dir, 'SAP_FLUX',
        {
            's': 'SAP_FLUX_SML',
            'm': 'SAP_FLUX_MID',
            'l': 'SAP_FLUX_LAG',
        },
    )
    input_mask = preprocess.get_spline_mask(ut, period, epoc, duration)
    sf, mdata = keplersplinev2.choosekeplersplinev2(
        ut, uf, input_mask=input_mask, return_metadata=True,
        fixed_bkspace=bkspace,
    )

    plt.plot(ut, uf, '-', alpha=0.6, color=g, label='SAP')
    plt.plot(ut[input_mask], sf[input_mask], 'x', markersize=3, color=k, label='spline (OOT)')
    plt.plot(ut[~input_mask], sf[~input_mask], 'o', markersize=3, color=o, label='spline')
    config_abs_plot(f'raw | bkspace: {mdata.bkspace}')

    # --- Panel: detrended, phase-folded light curve ---------------------
    ut, nf, _ = preprocess.detrend_and_filter(tic, ut, uf, period, epoc, duration, bkspace)
    sft, sff, sfn, sftm = preprocess.phase_fold_and_sort_light_curve(ut, nf, input_mask, period, epoc)

    splt()
    plt.plot(sft, sff, 'o', markersize=3, alpha=0.6, c=o, label='spline')
    # Values outside [-0.5, 1.5] are replaced with 0 before taking the
    # minimum that serves as the low level of the OOT mask line.
    sff_filtered = np.where((sff > 1.5) | (sff < -0.5), 0, sff)
    if len(sff_filtered):
        mask = np.where(sftm, 1, min(sff_filtered))
        plt.plot(sft, mask, '-', markersize=1, alpha=0.6, c=r, label='OOT')
        title = f'phase folded ({int(max(sfn) + 1)} folds) / filtered'
    else:
        # np.where keeps the length of sff, so this branch is reached only
        # when sff itself is empty.
        if len(sfn):
            title = f'phase folded ({int(max(sfn) + 1)} folds) | WARNING: filtering removed all data'
        else:
            title = 'phase folded | WARNING: filtering removed all data'
    plt.legend()
    plt.title(title)

    # --- Panel: per-aperture folded curves ------------------------------
    splt()
    if ap:
        # Each aperture's flux goes through the same detrend + fold as above.
        for ap_key, ap_label in (('s', 'sm'), ('m', 'med'), ('l', 'lg')):
            _, nf_ap, _ = preprocess.detrend_and_filter(
                tic, ut, ap[ap_key][1], period, epoc, duration, bkspace)
            sft_ap, sff_ap, _, _ = preprocess.phase_fold_and_sort_light_curve(
                ut, nf_ap, input_mask, period, epoc)
            plt.plot(sft_ap, sff_ap, '.', markersize=1, alpha=0.6, label=ap_label)
        plt.legend()
        plt.title('multi-aperture')

    # --- Generated model-input features ---------------------------------
    row = next(tce.iterrows())[1]

    ex = generate_input_records._process_tce(row, bkspace)
    feats = ex.features.feature

    if feats['star_rad_present'].int64_list.value[0] > 0:
        print(f'star_rad: {feats["star_rad"].float_list.value[0]}')
    else:
        print('no star_rad')
    print(f'Transit_Depth: {feats["Transit_Depth"].float_list.value[0]}')
    print(f'local_scale: {feats["local_scale"].float_list.value[0]}')

    def plot_w_scatter(view, std, title, mask=None, mask_is_filter=False):
        """Plot feature ``view`` with feature ``std`` and an optional mask.

        The mask feature is drawn negated. With ``mask_is_filter``, view and
        std are restricted to samples where the mask feature is > 0 (the
        mask line itself is still drawn at full length).
        """
        splt()
        view_vals = np.array(feats[view].float_list.value)
        std_vals = np.array(feats[std].float_list.value)
        if mask:
            msk = -np.array(feats[mask].float_list.value)
            if mask_is_filter:
                keep = msk < 0
                view_vals = view_vals[keep]
                std_vals = std_vals[keep]
        plt.plot(std_vals, color=o)
        plt.plot(view_vals, color=n)
        if mask:
            plt.plot(msk, color=r, alpha=0.5, linestyle='--')
        plt.title(title)

    def plot_norm(view, title):
        """Plot feature ``view`` on a fixed [-1, 1] y range."""
        splt()
        plt.plot(feats[view].float_list.value, color=n)
        plt.ylim(-1, 1)
        plt.title(title)

    def plot_segments(view, tag, num_chans=14, min_chans=None, max_chans=None):
        """Plot a full-row panel of sample segments stored in feature ``view``.

        The flat feature is reshaped to ``(-1, num_chans)``, sliced to columns
        ``[min_chans:max_chans]``, and read as (values, mask) column pairs.
        """
        splt(c=3)
        img = feats[view].float_list.value
        img = np.reshape(img, (-1, num_chans))[:, min_chans:max_chans]
        n_transits = img.shape[1] // 2
        for i in range(n_transits):
            seg_vals = img[:, 2 * i]
            seg_mask = img[:, 2 * i + 1] > 0
            # Masked-out samples become NaN so matplotlib leaves gaps.
            plt.plot(np.where(seg_mask, seg_vals, np.nan), marker='.')
        plt.title(f'{n_transits} sample segments, densest first, ties broken at random {tag}')

    splt()
    plt.plot(feats['local_aperture_s'].float_list.value, label='small')
    plt.plot(feats['local_aperture_m'].float_list.value, label='med')
    plt.plot(feats['local_aperture_l'].float_list.value, label='large')
    plt.title('multi-aperture')
    plt.legend()

    # Leave this cell empty so the following panels start on a new grid row.
    splt()

    plot_w_scatter('global_view_0.3', 'global_std_0.3', 'global @ 0.3')
    plot_w_scatter('global_view_5.0', 'global_std_5.0', 'global @ 5.0')

    sec_phase = feats['secondary_phase'].float_list.value[0]
    loc_scale = feats['local_scale'].float_list.value[0]
    sec_scale = feats['secondary_scale'].float_list.value[0]
    plot_w_scatter('global_view', 'global_std', 'global', 'global_transit_mask')
    plot_w_scatter('local_view', 'local_std', f'local / {loc_scale:0.3}')
    plot_w_scatter('secondary_view', 'secondary_std', f'secondary ({sec_phase:0.2}) / {sec_scale:0.5}')
    plot_w_scatter('local_view_half_period', 'local_view_half_period_std', 'local half period')
    plot_w_scatter('local_view_even', 'local_std_even', 'local even', 'local_mask_even', True)
    plot_w_scatter('local_view_odd', 'local_std_odd', 'local odd', 'local_mask_odd', True)

    # Leave this cell empty so the full-row segment panels start on a new row.
    splt()

    plot_segments('sample_segments_local_view', 'odd', num_chans=16, min_chans=0, max_chans=8)
    plot_segments('sample_segments_local_view', 'even', num_chans=16, min_chans=8, max_chans=16)

    # --- Pre-rendered report pages, if present --------------------------
    def show_report_page(page):
        """Show ``{reports_dir}/{tic}.page{page}.png`` in its own figure."""
        im = plt.imread(f'{reports_dir}/{tic}.page{page}.png')
        plt.figure(figsize=(12, 8), dpi=dpi)
        plt.axis('off')
        plt.imshow(im)

    try:
        show_report_page(1)
    except FileNotFoundError:
        print('-- no report file --')
    else:
        # Page 3 is only looked for when page 1 exists.
        try:
            show_report_page(3)
        except FileNotFoundError:
            print('-- no page 3 --')
