In [1]:
# Re-import edited project modules (from the scripts folder) automatically before each cell runs.
%load_ext autoreload
%autoreload 2
# Qt backend: figures open in their own interactive windows, which the Slider/Button
# widgets used by interactive_plot need in order to respond.
%matplotlib qt
import os
import sys  # needed here: this cell runs before the import cell, so `sys` would otherwise be undefined

# Folder holding the project modules (analysis_routines, plotting_routines_xr, ...).
# Machine-specific default; set TFG_SCRIPTS_DIR to run the notebook elsewhere.
SCRIPTS_DIR = os.environ.get('TFG_SCRIPTS_DIR', '/home/manu/TFG_repo/scripts')
if SCRIPTS_DIR not in sys.path:  # keep the cell idempotent when it is re-run
    sys.path.append(SCRIPTS_DIR)

In [2]:
import numpy as np
import xarray as xr
import matplotlib.pyplot as plt
import mpl_interactions.ipyplot as iplt
from mpl_interactions import interactive_plot
from analysis_routines import * 
from plotting_routines_xr import *
from matplotlib.widgets import Slider, Button
from harmonic_analysis import *

In [3]:
# Analysis windows as [start, end] lists; each bound is also kept under its own name.
# period_1: 2018-11-16 11:00 -> 2018-11-24
date_0, date_1 = datetime(2018, 11, 16, 11), datetime(2018, 11, 24)
period_1 = [date_0, date_1]

# period_2: 2018-11-30 -> 2018-12-09
date_2, date_3 = datetime(2018, 11, 30), datetime(2018, 12, 9)
period_2 = [date_2, date_3]

# period_3: 2019-01-05 -> 2019-01-14
date_4, date_5 = datetime(2019, 1, 5), datetime(2019, 1, 14)
period_3 = [date_4, date_5]

In [5]:
# Inputs for the analysis:
#   fit_chain  - SHDR fit results (its D1 column is used below)
#   data_chain - processed chain time series (its temp/depth/date are used below)
#   G005       - auxiliary series from G05.csv, indexed by its 'date' column
fit_chain = load_SHDR_fit(
    'optimal_server_fit/AGL_20181116_fit_fci.csv'
)
data_chain = load_time_series_xr(
    'processed/AGL_20181116_chain_xrcompatible.nc'
)
G005 = pd.read_csv(
    data_dir / 'SHDR_fit/aux/G05.csv',
    parse_dates=True,
    index_col='date',
)

In [6]:
# NOTE(review): the third filter argument (1/5) is presumably the sampling frequency
# of the series (one sample every 5 s?) - confirm against lowpass_filter/bandpass_filter.

# Low-pass versions of D1 and G005.x with cut-offs 1/512 and 1/1024.
D1_filt512 = lowpass_filter(fit_chain.D1, data_chain.date,
                            1 / 5, 1 / 512)
G005_filt512 = lowpass_filter(G005.x, data_chain.date,
                              1 / 5, 1 / 512)

D1_filt1024 = lowpass_filter(fit_chain.D1, data_chain.date,
                             1 / 5, 1 / 1024)
G005_filt1024 = lowpass_filter(G005.x, data_chain.date,
                               1 / 5, 1 / 1024)

# Band-pass between periods of 14 h (low frequency) and 8 h (high frequency),
# i.e. the band the variables below label M2.
low, high = period_to_freq(14, 'h'), period_to_freq(8, 'h')
D1_bandpass_M2 = bandpass_filter(fit_chain.D1, data_chain.date,
                                 1 / 5, low, high)
G005_bandpass_M2 = bandpass_filter(G005.x, data_chain.date,
                                   1 / 5, low, high)

# Detrended and normalised band-passed series.
D1_dn = detrend_normalize(D1_bandpass_M2)
G005_dn = detrend_normalize(G005_bandpass_M2)

In [12]:
# Two instants on 2018-11-21: 08:00 and 16:00.
date_7, date_8 = (datetime(2018, 11, 21, hour) for hour in (8, 16))

In [15]:
def _format_date_axis(ax):
    """Put concise, auto-located date ticks (major and minor) on the x axis of `ax`."""
    locator = mdates.AutoDateLocator(minticks=4, maxticks=None)
    ax.xaxis.set_major_locator(locator)
    ax.xaxis.set_major_formatter(mdates.ConciseDateFormatter(locator))
    ax.xaxis.set_minor_locator(mdates.AutoDateLocator(minticks=6))


def interactive_plot(data_chain, fit_chain, period=(None, None), extra_variable=None):
    """Browse measured temperature profiles and their fitted profile with a slider.

    NOTE: this definition shadows ``mpl_interactions.interactive_plot``, which is
    imported at the top of the notebook.

    Parameters
    ----------
    data_chain :
        Object exposing ``temp`` (one row of temperatures per time, plotted
        against ``depth``), ``depth`` and ``date``. ``temp`` and ``date`` are
        sliced with ``.loc`` (presumably an xarray Dataset - confirm).
    fit_chain :
        Table of fit parameters sliceable with ``.loc`` by date, with a ``D1``
        column drawn as the dashed horizontal (MLD) line. It is passed to
        ``fit_function`` together with the profile index.
    period : sequence of two bounds, optional
        ``(start, end)`` used as ``.loc[start:end]`` on both inputs; ``None``
        leaves that side open. Defaults to the whole record.
    extra_variable : series, list/tuple of series, or None, optional
        Time series drawn in panels above the profile, one panel per series,
        each with a vertical line marking the selected time. Each series is
        sliced with ``.loc`` and must have one value per entry of the sliced
        ``data_chain.date``.

    Returns
    -------
    tuple
        ``(slider, next_button, prev_button)``. Keep references to them:
        Matplotlib widgets stop responding once they are garbage-collected.

    Raises
    ------
    ValueError
        If ``period`` selects no rows of ``fit_chain``.
    """
    window = slice(*period)

    # Restrict both inputs to the requested window.
    fit_period = fit_chain.loc[window]
    mld = fit_period.D1.to_numpy()
    temp_period = data_chain.temp.loc[window].data
    date_period = data_chain.date.loc[window].data
    date_str = np.datetime_as_string(date_period, unit='s')
    z = data_chain.depth.data

    n_profiles = len(mld)
    if n_profiles == 0:
        raise ValueError('no fitted profiles inside period {!r}'.format(period))

    # Depth grid on which the fitted profile is evaluated.
    zz_ = np.linspace(0, 200, 300)

    def f_yy(zz, i):
        return fit_function(zz, fit_period, i)

    # Accept None, a single series, or a list/tuple of series.
    if extra_variable is None:
        extras = []
    elif isinstance(extra_variable, (list, tuple)):
        extras = list(extra_variable)
    else:
        extras = [extra_variable]

    if extras:
        n_extra = len(extras)
        fig = plt.figure(figsize=(10, 10))
        # The extra panels share one unit of height; the profile panel gets two
        # units and sits in the wide middle column.
        gs = plt.GridSpec(n_extra + 1, 3,
                          height_ratios=[1 / n_extra] * n_extra + [2],
                          width_ratios=(1, 4, 1))
        extra_axes = [fig.add_subplot(gs[k, :]) for k in range(n_extra)]
        ax = fig.add_subplot(gs[n_extra, 1])
        for ax_e in extra_axes:
            _format_date_axis(ax_e)
    else:
        fig, ax = plt.subplots(figsize=(5, 5.75))
        extra_axes = []

    xlim = [np.nanmin(temp_period) - 0.1, np.nanmax(temp_period) + 0.1]
    ax.set_xlim(*xlim)

    line, = ax.plot(f_yy(zz_, 0), zz_)
    points, = ax.plot(temp_period[0], z, 'o', mfc='None', mec='tab:red')
    mld_line, = ax.plot(xlim, [mld[0], mld[0]], c='grey', ls='--')
    title = ax.text(0.7, 0.1, date_str[0],
                    bbox={'facecolor': 'w', 'alpha': 0.5, 'pad': 5},
                    transform=ax.transAxes, ha='center')

    # Each extra panel: the series plus a vertical marker at the selected time,
    # spanning the y-range the panel had right after the series was drawn.
    markers = []
    for ax_e, series in zip(extra_axes, extras):
        ax_e.plot(date_period, np.asarray(series.loc[window]))
        ylim_e = ax_e.get_ylim()
        vline, = ax_e.plot([date_period[0], date_period[0]], ylim_e)
        markers.append((vline, ylim_e))

    fig.subplots_adjust(left=0.15, right=0.90, bottom=0.15, top=0.95)

    axii = fig.add_axes([0.15, 0.05, 0.8, 0.03])
    ax.set_ylim(200, 0)  # depth increases downwards
    ax.set_ylabel('Profundidad (dbar)')
    ax.set_xlabel('Temperatura (ºC)')

    ii_slider = Slider(
        ax=axii,
        label='Index',
        valmin=0,
        valmax=n_profiles - 1,
        valstep=1,
        valinit=0,
    )

    def update(val):
        # Slider values may arrive as floats; indexing needs an int.
        i = int(round(val))
        line.set_data(f_yy(zz_, i), zz_)
        points.set_data(temp_period[i], z)
        mld_line.set_data(xlim, [mld[i], mld[i]])
        title.set_text('{}'.format(date_str[i]))
        for vline, ylim_e in markers:
            vline.set_data([date_period[i], date_period[i]], ylim_e)
        fig.canvas.draw_idle()

    ii_slider.on_changed(update)

    nextax = fig.add_axes([0.8, 0.525, 0.1, 0.04])
    button_1 = Button(nextax, 'Next', hovercolor='0.975')

    prevax = fig.add_axes([0.8, 0.425, 0.1, 0.04])
    button_2 = Button(prevax, 'Prev', hovercolor='0.975')

    def step(delta):
        # Slider.set_val does not clamp, so keep the index inside
        # [0, n_profiles - 1]: -1 would silently wrap to the last profile
        # and n_profiles would raise IndexError.
        target = int(round(ii_slider.val)) + delta
        ii_slider.set_val(min(max(target, 0), n_profiles - 1))

    button_1.on_clicked(lambda event: step(+1))
    button_2.on_clicked(lambda event: step(-1))

    plt.show()
    return ii_slider, button_1, button_2

In [19]:
# NOTE(review): `period` is defined here but the call below passes period_3 -
# confirm which window is intended.
period = [datetime(2018, 11, 19, 20), datetime(2018, 11, 21, 20)]

# Keep the returned widgets bound to names so they stay responsive.
slider, button_1, button_2 = interactive_plot(
    data_chain,
    fit_chain,
    period=period_3,
    extra_variable=[fit_chain.D1, G005_filt1024],
)


In [None]:
# Narrower window on 2018-11-21: 10:00 to 12:15.
date_7_precise, date_8_precise = (
    datetime(2018, 11, 21, 10),
    datetime(2018, 11, 21, 12, 15),
)

In [None]:
# Low-pass (1/1024) G005 series restricted to period_1; left as the last expression
# so Jupyter renders it with its rich, truncated display instead of a plain-text dump.
G005_filt1024.loc[slice(*period_1)]