In [11]:
import math
import os
import warnings

import joblib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from IPython.display import clear_output
from scipy.ndimage import gaussian_filter

from data_loading import create_xarr, mad, create_label_df
from threshold_edge_detection import lowess_smooth, measure_thresholds
from utils import DateIter

In [2]:
# Input/output locations.
# The raw-data directory is machine specific, so it can be overridden with the
# IONO_RAW_DATA_DIR environment variable (keep the trailing slash); the default
# is the original author's path. The other paths are relative to the notebook's
# working directory.
parent_dir = os.environ.get(
    'IONO_RAW_DATA_DIR',
    'A:/Gits/Ionosphere/raw_data/multisource_data_NA_20m_T0_24hr/',
)
label_csv_path = '../ionospheric_disturbance_segmentation/labels/official_labels.csv'
data_out_path = 'processed_data/full_data.joblib'
label_out_path = 'labels/labels.joblib'

In [3]:
# Stack every raw input under parent_dir into one array, applying `mad` to each
# input (see data_loading for what `mad` computes).
full_xarr = create_xarr(parent_dir=parent_dir, expected_shape=(720, 300),
                        dtype=(np.uint16, np.float32), apply_fn=mad, plot=False)

# Hand-made disturbance labels.
label_df = create_label_df(csv_path=label_csv_path)

2641

In [4]:
# joblib.dump does not create missing folders, so make sure each output
# directory exists (a fresh checkout has neither processed_data/ nor labels/).
for _out_path in (data_out_path, label_out_path):
    _out_dir = os.path.dirname(_out_path)
    if _out_dir:  # dirname is '' for a bare filename; makedirs('') would fail
        os.makedirs(_out_dir, exist_ok=True)

joblib.dump(full_xarr, data_out_path)
joblib.dump(label_df, label_out_path)

# Alternative NetCDF export (requires the netCDF4 package):
# %pip install netCDF4
# full_xarr.to_netcdf('full_data.nc')

['labels/labels.joblib']

In [5]:
# Date-indexed access to the array saved above. run_edge_detect iterates it with
# iter_all()/iter_dates(), and plot_date reads single days via get_date().
# The label frame is currently not attached (argument commented out).
date_iter = DateIter(data_out_path) #, label_df=label_out_path)

In [12]:
def save_wrap(save_dir, fmt='%Y-%m-%d', ext='.png', **kwargs):
    """Return a callback that saves a figure into ``save_dir``, named after a date.

    Parameters
    ----------
    save_dir : str
        Output directory; created (including parents) if it does not exist.
    fmt : str
        ``strftime`` format that turns the date into the file stem.
    ext : str
        Extension appended to the stem.
    **kwargs
        Forwarded unchanged to ``savefig`` (e.g. ``dpi``, ``bbox_inches``).

    Returns
    -------
    callable
        ``wrapped(date, fig=None)``: saves ``fig`` (or pyplot's current figure
        when ``fig`` is None) to ``save_dir/<date formatted with fmt><ext>``
        and returns that path.
    """
    os.makedirs(save_dir, exist_ok=True)

    def wrapped(date, fig=None):
        date_str = pd.to_datetime(date).strftime(fmt)
        file_path = os.path.join(save_dir, date_str + ext)
        # An explicit figure avoids depending on pyplot's implicit "current figure".
        if fig is not None:
            fig.savefig(file_path, **kwargs)
        else:
            plt.savefig(file_path, **kwargs)
        return file_path

    return wrapped

In [13]:
def _trim_bounds(trim, length):
    """Return ``(start, stop)`` indices that cut a fraction of ``length`` off each end.

    ``trim`` is either one fraction applied to both ends or a ``(left, right)``
    tuple/list of fractions.
    """
    left, right = trim if isinstance(trim, (tuple, list)) else (trim, trim)
    # Use an explicit stop index rather than a negative one: ``seq[k:-0]`` is
    # empty, which would wipe out the whole array for a zero trim.
    return math.floor(left * length), length - math.floor(right * length)


def _edge_line(data, min_line, time_labels, thresh, date):
    """Build the ``Time``/``Height`` frame drawn as the white edge overlay.

    ``thresh`` selects the source line: ``None`` -> ``min_line``; a ``dict`` ->
    the quantile column ``thresh[date]`` of ``data``; a ``float`` -> that
    quantile column for every date. Quantile columns must be members of ``qs``.
    """
    if thresh is None:
        return pd.DataFrame(
            min_line,
            index=time_labels,
            columns=['Height'],
        ).reset_index(names='Time')
    q = thresh[date] if isinstance(thresh, dict) else thresh
    return data[['Time', q]].rename(columns={q: 'Height'})


def run_edge_detect(
    dates,
    x_trim=.08333,
    y_trim=.08,
    sigma=4.2,  # 3.8 was good
    qs=[.4, .5, .6],
    occurence_n=60,
    i_max=30,
    plot=True,
    clear_every=100,
    plt_save_path=None,
    csv_save_path=None,
    thresh=None,
    date_source=None,
):
    """Detect the per-minute edge line for each requested date.

    For every date the array is trimmed on both axes, NaNs are set to 0, it is
    transposed and Gaussian-smoothed, and ``measure_thresholds`` produces the
    quantile lines (``qs``) and the ``min_line`` edge.

    Parameters
    ----------
    dates : 'all' or sequence
        ``'all'`` iterates every date; otherwise passed to ``iter_dates``.
    x_trim, y_trim : float or (float, float)
        Fraction cut from each end of axis 0 / axis 1 of the raw array.
    sigma : float
        Gaussian smoothing sigma, used for both axes.
    qs : list of float
        Quantiles forwarded to ``measure_thresholds``.
    occurence_n, i_max : int
        Forwarded to ``measure_thresholds`` (``occurrence_n`` / ``i_max``).
    plot : bool
        Display a heatmap + edge figure per date.
    clear_every : int
        Clear the cell output every ``clear_every`` dates (0 disables).
    plt_save_path : str, optional
        Directory to save each figure into (figures are drawn when set).
    csv_save_path : str, optional
        If set, the resulting frame is also written there as CSV.
    thresh : None, float or dict
        Which line to draw as the edge overlay: ``None`` -> ``min_line``;
        a float -> that quantile; a dict -> per-date quantile ``thresh[date]``.
    date_source : optional
        Object with ``iter_all()`` / ``iter_dates()``; defaults to the
        notebook-global ``date_iter``.

    Returns
    -------
    pandas.DataFrame
        One column per processed date (``min_line`` indexed by time); empty if
        no date had input.

    Raises
    ------
    ValueError
        If ``thresh`` is not None, a float or a dict.
    """
    # Fail fast, before any per-date work, on a threshold spec we cannot use.
    if thresh is not None and not isinstance(thresh, (dict, float)):
        raise ValueError(f'Threshold {thresh} of type {type(thresh)} is invalid')

    source = date_iter if date_source is None else date_source
    save_plt = save_wrap(plt_save_path) if plt_save_path is not None else None

    # isinstance guard: `==` on an array-like `dates` would be element-wise.
    if isinstance(dates, str) and dates == 'all':
        date_gen = source.iter_all()
    else:
        date_gen = source.iter_dates(dates, raise_missing=False)

    final_edge_list = []
    for i, (date, arr) in enumerate(date_gen):
        if arr is None:
            warnings.warn(f'Date {date} has no input')
            continue

        if clear_every and not i % clear_every:
            clear_output()

        # Axis 0 of the raw array is time and axis 1 is height: it is
        # transposed before smoothing and then plotted with heights as rows.
        x_start, x_stop = _trim_bounds(x_trim, arr.shape[0])
        y_start, y_stop = _trim_bounds(y_trim, arr.shape[1])
        arr = arr[x_start:x_stop, y_start:y_stop]

        heights = arr.coords['height']
        times = arr.coords['time']
        time_labels = (date + times).dt.strftime('%H:%M')

        smoothed = gaussian_filter(np.nan_to_num(arr, nan=0).T, sigma=(sigma, sigma))
        med_lines, min_line, _ = measure_thresholds(
            smoothed,
            qs=qs,
            occurrence_n=occurence_n,
            i_max=i_max,
        )

        data = pd.DataFrame(
            np.array(med_lines).T,
            index=time_labels,
            columns=qs,
        ).reset_index(names='Time')
        edge_line = _edge_line(data, min_line, time_labels, thresh, date)

        # NOTE(review): the returned frame always stores min_line, even when
        # `thresh` selects a quantile line for the plot — confirm intended.
        final_edge_list.append(
            pd.Series(min_line.squeeze(), index=times, name=date)
        )

        if plot or save_plt is not None:
            fig, ax = plt.subplots(1, 1, figsize=(15, 8))
            ax.set_title(f'| {date} |')
            sns.heatmap(
                pd.DataFrame(smoothed, index=heights, columns=times),
                robust=True,
                cbar=False,
                ax=ax,
            )
            ax.invert_yaxis()

            sns.lineplot(data=data, alpha=.75, dashes=False, ax=ax, palette='light:b')
            sns.lineplot(
                data=edge_line,
                x='Time',
                y='Height',
                color='white',
                alpha=1,
                dashes=False,
                ax=ax,
            )

            if save_plt is not None:
                save_plt(date)
            if plot:
                plt.show()
            # Always release the figure (the original called the non-existent
            # plt.clear()); long date ranges would otherwise pile up figures.
            plt.close(fig)

    final_edge_df = pd.concat(final_edge_list, axis=1) if final_edge_list else pd.DataFrame()
    if csv_save_path:
        final_edge_df.to_csv(csv_save_path)

    return final_edge_df

In [15]:
# run_edge_detect(
#     [('2022-10-27','2022-10-31')], 
#     csv_save_path=None,
#     plot=True,
#     plt_save_path='A:/Gits/Ionosphere/ionospheric_edge_detection/outputs/October2022Edges'
# ).head()

In [16]:
# Edge detection over 2018-10-31 .. 2019-05-31 using a single 0.4 quantile as the
# threshold line; plotting and CSV export are both switched off here.
today = pd.Timestamp('today').strftime('%Y-%m-%d')
full_edge_path = f'processed_data/Nov2018-May2019_run_{today}.csv'

final_edge_df = run_edge_detect(
    [('2018-10-31', '2019-05-31')],
    qs=[.4],
    thresh=.4,
    plot=False,
    clear_every=10,
    # use 'all' as the first argument to process every date, and add
    # csv_save_path=full_edge_path to also write the result to disk
)
final_edge_df.index

TimedeltaIndex(['0 days 12:59:00', '0 days 13:00:00', '0 days 13:01:00',
                '0 days 13:02:00', '0 days 13:03:00', '0 days 13:04:00',
                '0 days 13:05:00', '0 days 13:06:00', '0 days 13:07:00',
                '0 days 13:08:00',
                ...
                '0 days 22:51:00', '0 days 22:52:00', '0 days 22:53:00',
                '0 days 22:54:00', '0 days 22:55:00', '0 days 22:56:00',
                '0 days 22:57:00', '0 days 22:58:00', '0 days 22:59:00',
                '0 days 23:00:00'],
               dtype='timedelta64[ns]', length=602, freq=None)

In [17]:
# Put every day's edge on a common 12:00-23:59 one-minute grid. Minutes missing
# from final_edge_df (or already NaN there) become 0.
# NOTE(review): 0 is indistinguishable from a real height downstream — consider
# keeping NaN if consumers can handle it.
nov_to_may_df = final_edge_df.reindex(
    index=pd.timedelta_range(start='12:00:00', end='23:59:00', freq='1min'),
).fillna(0)

nov_to_may_df.to_csv('Nov2018_to_May2019_edges.csv')
nov_to_may_df

Unnamed: 0,2018-10-31,2018-11-01,2018-11-02,2018-11-03,2018-11-04,2018-11-05,2018-11-06,2018-11-07,2018-11-08,2018-11-09,...,2019-05-22,2019-05-23,2019-05-24,2019-05-25,2019-05-26,2019-05-27,2019-05-28,2019-05-29,2019-05-30,2019-05-31
0 days 12:00:00,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0 days 12:01:00,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0 days 12:02:00,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0 days 12:03:00,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0 days 12:04:00,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
0 days 23:55:00,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0 days 23:56:00,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0 days 23:57:00,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0 days 23:58:00,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [18]:
# Per-date threshold choices: 'Date' becomes the key and the remaining column
# (presumably a single one, so squeeze() yields a Series) the chosen quantile.
# Each value must be one of the qs passed below, because run_edge_detect looks
# the quantile up as a column of its per-date frame.
# NOTE(review): the keys must compare equal to the dates yielded by date_iter.
thresh_df = pd.read_excel('Edges_2020_MLW.xlsx', sheet_name='Sheet2')
thresh_dict = thresh_df.set_index('Date').squeeze().to_dict()

final_edge_df = run_edge_detect(
    sorted(thresh_dict),
    qs=[.4, .5, .6],
    thresh=thresh_dict,
    plot=False,
    clear_every=10,
)

NameError: name 'thresh_dict' is not defined

In [None]:
def plot_date(date_iter, edge_df, date, side_trim=.08):
    """Show one date's input array as a heatmap with its edge line overlaid.

    Parameters
    ----------
    date_iter
        Object exposing ``get_date(date)`` (e.g. the notebook's ``DateIter``).
    edge_df : pandas.DataFrame or str
        Edge table with one column per date, or the path to a CSV of such a
        table whose first column is the index.
    date
        Anything ``pd.to_datetime`` accepts; selects the array and edge column.
    side_trim : float
        Fraction of each axis cut from both ends before plotting.

    Raises
    ------
    ValueError
        If ``date_iter.get_date`` returns None for ``date``.
    """
    if isinstance(edge_df, str):
        edge_df = pd.read_csv(edge_df, index_col=0)
        edge_df.columns = pd.to_datetime(edge_df.columns)

    date = pd.to_datetime(date)
    arr = date_iter.get_date(date)
    # NOTE(review): assumes get_date signals a missing date with None, the way
    # iter_dates yields None arrays — confirm in utils.DateIter.
    if arr is None:
        raise ValueError(f'Date {date} has no input')

    n_x, n_y = arr.shape[0], arr.shape[1]
    x_cut = math.floor(side_trim * n_x)
    y_cut = math.floor(side_trim * n_y)
    # Explicit stop indices: ``arr[k:-0]`` would be empty when side_trim is 0.
    arr = arr[x_cut:n_x - x_cut, y_cut:n_y - y_cut]

    edge = edge_df[date]

    fig, ax = plt.subplots(1, 1, figsize=(15, 8))
    ax.set_title(date)
    ax.axis('off')
    sns.heatmap(arr.T, robust=True, cbar=False, ax=ax)
    ax.invert_yaxis()

    # NOTE(review): the edge is plotted against its own index (times) while the
    # heatmap uses cell positions; also run_edge_detect trims with x_trim=.08333
    # by default vs side_trim=.08 here — verify the overlay actually lines up.
    sns.lineplot(data=edge, color='white', alpha=1, dashes=False, ax=ax)
    plt.show()
    return

# `side_trim` was never defined in this notebook (NameError); pass the value explicitly.
plot_date(date_iter, final_edge_df, '2020-01-20', side_trim=.08)