In [1]:
from hydroDL import kPath # package by Kuai Fang, kPath contains req paths

import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict
import pandas as pd
import json
import os

loading package hydroDL


### Set Up

In [2]:
# module by Kuai Fang
from hydroDL.data import dbVeg
from hydroDL.data import DataModel
from hydroDL.master import dataTs2Range

import numpy as np
import torch

# Input-variable names, grouped per satellite product
# (presumably S = Sentinel, L = Landsat, M = MODIS -- TODO confirm).
varS = ['VV', 'VH', 'vh_vv']
varL = ['SR_B2', 'SR_B3', 'SR_B4', 'SR_B5', 'SR_B6', 'ndvi', 'ndwi', 'nirv']
varM = [f"MCD43A4_b{band}" for band in range(1, 8)]  # bands 1 through 7

# Days per satellite (same S / L / M order as above).
bS, bL, bM = 8, 6, 10

def prepare_data(dataName, rho):
    """
    Load a dataset, min-max normalize it, and window it into per-observation arrays.

    For every observation, the positions inside the rho-window that have complete
    (non-NaN) data are recorded separately for the `varS` and the `varM` variable
    groups (Sentinel and MODIS, per the comments in this notebook). Observations
    that have no usable day for either group are dropped. The Landsat variables
    (`varL`) are NOT examined by this function.

    Args:
        dataName (str): The name or path of the dataset, passed to dbVeg.DataFrameVeg.
        rho (int): The time window size for each observation.

    Returns:
        dict: A dictionary with the following keys:
            - 'dataFrame': The dbVeg.DataFrameVeg object built from `dataName`.
            - 'dataModel': The DataModel on which min-max normalization was applied.
            - 'day_indices' (np.array): Day index of each kept observation.
            - 'site_indices' (np.array): Site index of each kept observation.
            - 'sat_avail_per_obs' (np.array): Float matrix of shape
              (# kept observations, 2); column 0 is the number of window days with
              data for `varS`, column 1 the same for `varM`.
            - 's_days_per_obs' (list of np.array): Per kept observation, the window
              positions with complete `varS` data.
            - 'm_days_per_obs' (list of np.array): Per kept observation, the window
              positions with complete `varM` data.
            - 'rho' (int): The `rho` argument, passed through unchanged.
            - 'x' (np.array): Windowed inputs for the kept observations, indexed as
              (window position, observation, input feature).
            - 'xc' (np.array): Constant variables for the kept observations.
            - 'yc' (np.array): Target values (LFMC, per the original notes) for the
              kept observations.

    Raises:
        ValueError: If a name in `varS` or `varM` is not found in `df.varX`
            (assuming `df.varX` is a list -- TODO confirm).
    """
    # Load data with the custom DataFrameVeg and DataModel classes
    df = dbVeg.DataFrameVeg(dataName)
    dm = DataModel(X=df.x, XC=df.xc, Y=df.y)
    dm.trans(mtdDefault='minmax')  # Min-max normalization
    dataTup = dm.getData()

    # Window the time series per observation; iInd: day, jInd: site
    dataEnd, (iInd, jInd) = dataTs2Range(dataTup, rho, returnInd=True)
    x, xc, _, yc = dataEnd

    # TODO: temporary fix -- make the indices NumPy arrays so that they can be
    # filtered with the integer array `indKeep` below.
    iInd = np.array(iInd)
    jInd = np.array(jInd)

    # Column positions of each variable group inside the feature axis of x
    iS = [df.varX.index(var) for var in varS]
    iM = [df.varX.index(var) for var in varM]

    def _days_with_data(obs, cols):
        """Window positions of observation `obs` where no column in `cols` is NaN."""
        window = x[:, obs, cols]
        return np.where(~np.isnan(window).any(axis=1))[0]

    # For each remote sensing source (Sentinel, MODIS) and each observation, list the
    # days in the rho-window that have data.
    # nMat: number of days each source has data for, shape (# observations, # sources)
    pSLst, pMLst = [], []
    nMat = np.zeros([yc.shape[0], 2])
    for k in range(nMat.shape[0]):
        pS = _days_with_data(k, iS)
        pM = _days_with_data(k, iM)
        pSLst.append(pS)
        pMLst.append(pM)
        nMat[k, :] = [len(pS), len(pM)]

    # Only keep observations with at least 1 day of data for every source
    indKeep = np.where((nMat > 0).all(axis=1))[0]
    x = x[:, indKeep, :]
    xc = xc[indKeep, :]
    yc = yc[indKeep, :]
    nMat = nMat[indKeep, :]
    pSLst = [pSLst[k] for k in indKeep]
    pMLst = [pMLst[k] for k in indKeep]

    jInd = jInd[indKeep]
    iInd = iInd[indKeep]

    data = {
        'dataFrame' : df,
        'dataModel' : dm,
        'day_indices' : iInd,
        'site_indices' : jInd,
        'sat_avail_per_obs' : nMat,
        's_days_per_obs' : pSLst,
        'm_days_per_obs': pMLst,
        'rho' : rho,
        'x' : x,
        'xc' : xc,
        'yc' : yc
    }

    return data

### Test functions

In [3]:
def test_total_sites(splits_dict):
    total_sites = []
    for fold in range(5):
        num_sites = len(splits_dict[f'trainSite_k{fold}5']) + len(splits_dict[f'testSite_k{fold}5'])
        total_sites.append(num_sites)
        
    assert all(x == total_sites[0] for x in total_sites)

def test_obs_duplicates(splits_dict):
    for fold in range(5):
        train_ind = set(splits_dict[f'trainInd_k{fold}5'])
        test_qual_ind = set(splits_dict[f'testInd_k{fold}5'])
        test_poor_ind = set(splits_dict[f'testInd_underThresh'])

        assert len(train_ind.intersection(test_qual_ind)) == 0
        assert len(train_ind.intersection(test_poor_ind)) == 0
        assert len(test_qual_ind.intersection(test_poor_ind)) == 0

def test_site_duplicates(splits_dict):
    for fold in range(5):
        train_sites = set(splits_dict[f'trainSite_k{fold}5'])
        test_qual_sites = set(splits_dict[f'testSite_k{fold}5'])
        test_poor_sites = set(splits_dict[f'testSite_underThresh'])

        assert len(train_sites.intersection(test_qual_sites)) == 0
        assert len(train_sites.intersection(test_poor_sites)) == 0
        assert len(test_qual_sites.intersection(test_poor_sites)) == 0

def test_site_from_obs_duplicates(splits_dict, jInd):
    for fold in range(5):
        train_obs = splits_dict[f'trainInd_k{fold}5']
        test_qual_obs = splits_dict[f'testInd_k{fold}5']
        test_poor_obs = splits_dict[f'testInd_underThresh']

        train_sites = set(jInd[train_obs])
        test_qual_sites = set(jInd[test_qual_obs])
        test_poor_sites = set(jInd[test_poor_obs])

        assert len(train_sites.intersection(test_qual_sites)) == 0
        assert len(train_sites.intersection(test_poor_sites)) == 0
        assert len(test_qual_sites.intersection(test_poor_sites)) == 0

def sanity_check_splits(dataset, rho, split_version):
    """
    Run every split consistency check against one stored split definition.

    The split definition is read from
    `<kPath.dirVeg>/model/attention/<split_version>/subset.json`.

    Args:
        dataset (str): Dataset name forwarded to `prepare_data`.
        rho (int): Window size forwarded to `prepare_data`.
        split_version (str): Name of the sub-directory holding `subset.json`.

    Raises:
        AssertionError: If any of the four checks fails.
    """
    # NOTE: prepare_data runs on every call; only its 'site_indices' entry is used here.
    site_of_obs = prepare_data(dataset, rho)['site_indices']

    subset_file = os.path.join(
        kPath.dirVeg, 'model', 'attention', split_version, 'subset.json'
    )
    with open(subset_file) as fh:
        splits = json.load(fh)

    test_total_sites(splits)
    test_obs_duplicates(splits)
    test_site_duplicates(splits)
    test_site_from_obs_duplicates(splits, site_of_obs)

### Tests

In [4]:
# Shared inputs for every sanity_check_splits call below.
dataset = 'singleDaily-modisgrid-new-const'  # dataset name handed to prepare_data
rho = 45  # time window size for each observation

In [5]:
# Check the split stored under model/attention/dataset/subset.json.
# NOTE(review): the recorded output of this cell is a bare AssertionError, so at least
# one split check fails for this version; the output does not show which one.
split_version = 'dataset'
sanity_check_splits(dataset, rho, split_version)

AssertionError: 

In [None]:
# Check the split stored under model/attention/stratified/subset.json.
split_version = 'stratified'
sanity_check_splits(dataset, rho, split_version)

In [None]:
# Run the split checks for the `_s0` and `_s1` variants of the stratified split.
for split_version in ('stratified_s0', 'stratified_s1'):
    sanity_check_splits(dataset, rho, split_version)

In [None]:
# Run the split checks for the `_s0` and `_s1` variants of the random split.
# NOTE(review): this cell previously checked 'random_s1' twice, so 'random_s0' was never
# validated. This assumes a 'random_s0' split exists next to 'random_s1', mirroring
# 'stratified_s0' / 'stratified_s1' -- confirm.
for split_version in ('random_s0', 'random_s1'):
    sanity_check_splits(dataset, rho, split_version)

In [6]:
# Check the split stored under model/attention/sherlock_stratified/subset.json.
split_version = 'sherlock_stratified'
sanity_check_splits(dataset, rho, split_version)

In [7]:
# Check the split stored under model/attention/sherlock_dataset/subset.json.
# NOTE(review): the recorded output of this cell is a bare AssertionError, so at least
# one split check fails for this version; the output does not show which one.
split_version = 'sherlock_dataset'
sanity_check_splits(dataset, rho, split_version)

AssertionError: 