In [None]:
import yaml
import pickle
import os.path as osp
import subprocess
from datetime import timedelta
from urllib.parse import urlparse
import numpy as np
import matplotlib.pyplot as plt
from utils import time_intp, str2time, filter_nan_values, read_pkl, read_yml, retrieve_url

## Setup

In [None]:
# Download the FMDA dictionary for the study period; the local copy is loaded further below.
fmda_url = "https://demo.openwfm.org/web/data/fmda/dicts/fmda_nw_202401-05_f05.pkl"
fmda_local_path = "data/fmda_nw_202401-05_f05.pkl"
retrieve_url(fmda_url, fmda_local_path)

In [None]:
# Filtering parameters (min_fm, max_fm, zero_lag_threshold, max_intp_time, ...) come from this YAML file
params_file = "params_data.yaml"
data_params = read_yml(params_file)
data_params

In [None]:
# NOTE(review): read_pkl presumably unpickles the file, and unpickling can execute arbitrary
# code. This file is downloaded from demo.openwfm.org above, so only load it from a source you trust.
dat = read_pkl("data/fmda_nw_202401-05_f05.pkl")

## Interpolation

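Interpolate the RAWS fuel moisture observations (`fm`) onto the HRRR time grid and store the result as `y`. Any filters are applied after interpolation.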

In [None]:
# Interpolate each case's RAWS fuel moisture onto its HRRR time grid and store the result as 'y'.
cases = list(dat)
flags = np.zeros(len(cases))
for case in cases:
    raws = dat[case]['RAWS']
    raws_times = str2time(raws['time_raws'])
    hrrr_times = str2time(dat[case]["HRRR"]['time'])
    dat[case]['y'] = time_intp(raws_times, raws['fm'], hrrr_times)

## Divide into 720-Hour (~1 Month) Portions

In [None]:
hours = 720  # portion length in hours (~1 month)
# Split every case into consecutive `hours`-long portions of (time, X, y), keyed "<case>_portion_<n>".
dat2 = {}
portion_length = timedelta(hours=hours)  # loop-invariant; previously 720 was hard-coded here
for key, data in dat.items():
    print(key)
    time_hrrr = str2time(data["HRRR"]['time'])
    if len(time_hrrr) == 0:
        # Nothing to split; indexing time_hrrr[0] below would raise IndexError.
        continue
    X_array = data['HRRR']['f01']['temp']
    y_array = data['y']

    current_time = time_hrrr[0]
    end_time = time_hrrr[-1]
    portion_index = 1
    # Use `<=` so a final timestamp that lands exactly on a portion boundary is kept
    # (with `<` the loop stopped before creating the portion containing it).
    while current_time <= end_time:
        next_time = current_time + portion_length
        # Half-open window [current_time, next_time)
        mask = (time_hrrr >= current_time) & (time_hrrr < next_time)
        dat2[f"{key}_portion_{portion_index}"] = {
            'time': time_hrrr[mask],
            'X': X_array[mask],
            'y': y_array[mask],
        }
        current_time = next_time
        portion_index += 1

In [None]:
# Show the keys of the partitioned dictionary, formatted "<case>_portion_<n>"
dat2.keys()

## Filters

In [None]:
# Useful Cases:
    # NV040_202401: more RAWS observations than HRRR times; interpolation should shorten the series
    # NV026_202401: RAWS observations every 10 minutes; interpolation should shorten the series
    # CGVC1_202401: missing only a few observations; interpolation should lengthen the series
    # YNWC1_202401: only 2 observations; should be filtered out entirely

In [None]:
def flag_lag_stretches(x, lag = 1, threshold = None):
    """Return True if ``x`` has a run of more than ``threshold`` consecutive zero differences.

    ``lag`` is passed to ``np.diff`` as ``n``, so it is the *order* of the
    difference, not a time lag. ``lag=1`` flags constant stretches.
    ``lag=2`` flags stretches with a zero second difference, i.e. values
    changing linearly, such as long linear-interpolation gaps.

    Differences are rounded to 8 decimals before being compared with zero.
    NaN differences never count as zero.

    Parameters
    ----------
    x : array-like
        1-D series to check.
    lag : int, default 1
        Order of the difference passed to ``np.diff``.
    threshold : int or None, default None
        Maximum allowed run length of zero differences. If None, the value of
        ``data_params['zero_lag_threshold']`` is read at *call* time.
        The old default was evaluated once, when the function was defined,
        so later updates to ``data_params`` were silently ignored.

    Returns
    -------
    bool
        True if some run of zero differences is longer than ``threshold``.
    """
    if threshold is None:
        threshold = data_params['zero_lag_threshold']
    diffs = np.round(np.diff(np.asarray(x), n=lag), 8)
    run_length = 0
    for is_zero in diffs == 0:
        if is_zero:
            run_length += 1
            if run_length > threshold:
                return True
        else:
            run_length = 0
    return False

In [None]:
# Re-interpolate every case and flag those failing the data-quality filters.
cases = list([*dat.keys()])
flags = np.zeros(len(cases))
data_params['max_intp_time'] = 48
data_params['zero_lag_threshold'] = 48
for i, case in enumerate(cases):
    print("~"*50)
    print(f"Case: {case}")
    time_raws=str2time(dat[case]['RAWS']['time_raws'])
    time_hrrr=str2time(dat[case]["HRRR"]['time'])
    fm = dat[case]['RAWS']['fm']
    ynew = time_intp(time_raws,fm,time_hrrr)
    dat[case]['y'] = ynew
    # Thresholds are passed explicitly. A default bound when the function was defined
    # would ignore the values just assigned above, and the lag=2 check must use
    # max_intp_time, not zero_lag_threshold.
    if flag_lag_stretches(ynew, threshold=data_params['zero_lag_threshold']):
        print(f"Flagging case {case} for zero lag stretches greater than `zero_lag_threshold` param {data_params['zero_lag_threshold']}")
        flags[i]=1
    if flag_lag_stretches(ynew, lag=2, threshold=data_params['max_intp_time']):
        print(f"Flagging case {case} for constant linear stretches greater than `max_intp_time` param {data_params['max_intp_time']}")
        flags[i]=1
    # Bounds are exclusive: values equal to min_fm or max_fm are flagged too
    if np.any(ynew>=data_params['max_fm']) or np.any(ynew<=data_params['min_fm']):
        print(f"Flagging case {case} for FMC outside param range {data_params['min_fm'],data_params['max_fm']}. FMC range for {case}: {ynew.min(),ynew.max()}")
        flags[i]=1

In [None]:
# Collect the names of cases whose flag was set by the filter loop
flagged_cases = []
for case_name, flag in zip(cases, flags):
    if flag == 1:
        flagged_cases.append(case_name)
print(flagged_cases)

In [None]:
# Number of flagged cases
len(flagged_cases)

In [None]:
# Total number of cases checked
len(cases)

In [None]:
# Fraction of cases flagged. It is computed from the current results, because the
# previously hard-coded counts (168 / 235) go stale whenever the data or thresholds change.
len(flagged_cases) / len(cases)

In [None]:
# Try partitioned dict: number of portions before discarding short ones
len(dat2.keys())

In [None]:
def discard_keys_with_short_y(input_dict, min_len=720):
    """Return a new dict keeping only entries whose ``'y'`` has at least ``min_len`` elements.

    Parameters
    ----------
    input_dict : dict
        Mapping of key -> dict containing a sized ``'y'`` entry.
    min_len : int, default 720
        Minimum required ``len(value['y'])``. The default of 720 matches the
        portion length in hours used when splitting the data into portions.

    Returns
    -------
    dict
        A filtered copy; ``input_dict`` itself is not modified.
    """
    return {key: value for key, value in input_dict.items() if len(value['y']) >= min_len}

# Discard shorter stretches: drop portions with fewer than 720 'y' values (e.g. a final partial portion)
dat2 = discard_keys_with_short_y(dat2)

In [None]:
# Number of portions remaining after discarding short ones
len(dat2.keys())

In [None]:
# Flag the month-long portions that fail the data-quality filters (stricter thresholds).
cases = list([*dat2.keys()])
flags = np.zeros(len(cases))
data_params['max_intp_time'] = 8
data_params['zero_lag_threshold'] = 8
for i, case in enumerate(cases):
    print("~"*50)
    print(f"Case: {case}")
    ynew = dat2[case]['y']
    # Thresholds are passed explicitly. A default bound when the function was defined
    # would ignore the values just assigned above, and the lag=2 check must use
    # max_intp_time, not zero_lag_threshold.
    if flag_lag_stretches(ynew, threshold=data_params['zero_lag_threshold']):
        print(f"Flagging case {case} for zero lag stretches greater than `zero_lag_threshold` param {data_params['zero_lag_threshold']}")
        flags[i]=1
    if flag_lag_stretches(ynew, lag=2, threshold=data_params['max_intp_time']):
        print(f"Flagging case {case} for constant linear stretches greater than `max_intp_time` param {data_params['max_intp_time']}")
        flags[i]=1
    # Bounds are exclusive: values equal to min_fm or max_fm are flagged too
    if np.any(ynew>=data_params['max_fm']) or np.any(ynew<=data_params['min_fm']):
        print(f"Flagging case {case} for FMC outside param range {data_params['min_fm'],data_params['max_fm']}. FMC range for {case}: {ynew.min(),ynew.max()}")
        flags[i]=1

In [None]:
# Portion count again; the flag loop above only sets flags and does not remove entries from dat2
len(dat2.keys())

In [None]:
# Names of dat2 portions whose flag was set by the loop above
flagged_cases = [cases[idx] for idx in np.flatnonzero(flags == 1)]
len(flagged_cases)

In [None]:
# Fraction of portions flagged. It is computed from the current results, because the
# previously hard-coded counts (477 / 1175) go stale whenever the data or thresholds change.
len(flagged_cases) / len(cases)