# Prepare data for transformers

**Purpose:** We want to compare the metrics from the RNN maps with the predictions from the transformer model. Our transformer model currently takes inputs of shape `(rho, num observations, num inputs)`. To get a daily prediction at each site, we have to prepare this matrix.

**Date:** July 30, 2024


This notebook follows the `models/cherry_pick/data.py` and `models/all_pick/data.py` files.

In [2]:
# hydroDL module by Kuai Fang
import hydroDL.data.dbVeg
from hydroDL.data import dbVeg
from hydroDL.data import DataModel
from hydroDL.master import dataTs2Range
from hydroDL import kPath

import numpy as np
import torch

loading package hydroDL


In [47]:
# Dataset identifier handed to dbVeg.DataFrameVeg, and the window length
# (first axis of the (rho, num observations, num inputs) input) for dataTs2Range.
dataName, rho = "singleDaily-nadgrid", 45

In [80]:
# Load the dataset and wrap it in a DataModel (custom hydroDL classes).
df = dbVeg.DataFrameVeg(dataName)
dm = DataModel(X=df.x, XC=df.xc, Y=df.y)

# Apply min-max normalization as the default transform, then pull out the arrays.
dm.trans(mtdDefault="minmax")
dataTup = dm.getData()

Here, we fill the `y` matrix with ones. This is a quick, temporary, and hacky way to format the data for the transformer.

The `dataTs2Range` function in the following cell prepares the data by reformatting it to shape `(rho, num observations, num inputs)`. In our case, we do not want to be limited to the observations (the site-days with LFMC measurements). Rather, we want to keep ALL the remote sensing data in order to retrieve daily LFMC predictions over the whole study period.

In [81]:
# Replace the target array with ones of the same shape, passing the other three
# elements of the tuple through unchanged.
# The fourth element is bound to a real name rather than `_`: `_` conventionally
# means "discarded", and in IPython assigning it also stops `_` from tracking
# the last cell output.
x, xc, y, yc = dataTup
dataTup = (x, xc, np.ones(y.shape), yc)

In [91]:
# Cut the series into rho-long windows. The windowed x is shaped
# (rho, number of observations, number of input features), the layout the
# transformer expects; xc and yc have the observation axis first.
dataEnd, (iInd, jInd) = dataTs2Range(dataTup, rho, returnInd=True)  # iInd: day, jInd: site
x, xc, _, yc = dataEnd

# Convert the index lists to arrays so they support fancy indexing.
iInd = np.array(iInd)  # TODO: Temporary fix
jInd = np.array(jInd)  # TODO: Temporary fix

In [92]:
# Satellite variable names, one list per remote sensing source.
varS = ["VV", "VH", "vh_vv"]
varL = [f"SR_B{band}" for band in range(2, 7)] + ["ndvi", "ndwi", "nirv"]
varM = [f"MCD43A4_b{band}" for band in range(1, 8)]

In [93]:
iInd = np.array(iInd)  # TODO: Temporary fix
jInd = np.array(jInd)  # TODO: Temporary fix

# Positions of each source's variables along the feature axis of x.
iS = [df.varX.index(var) for var in varS]
iL = [df.varX.index(var) for var in varL]
iM = [df.varX.index(var) for var in varM]

# For each remote sensing source (varS / varL / varM) and each observation,
# list the days in the rho-window on which every variable of that source is non-NaN.
# nMat: number of such days per source, of shape (# observations, # sources).
pSLst, pLLst, pMLst = list(), list(), list()
nMat = np.zeros([yc.shape[0], 3])
for k in range(nMat.shape[0]):
    tempS = x[:, k, iS]
    pS = np.where(~np.isnan(tempS).any(axis=1))[0]
    tempL = x[:, k, iL]
    pL = np.where(~np.isnan(tempL).any(axis=1))[0]
    tempM = x[:, k, iM]
    pM = np.where(~np.isnan(tempM).any(axis=1))[0]
    pSLst.append(pS)
    pLLst.append(pL)
    pMLst.append(pM)
    nMat[k, :] = [len(pS), len(pL), len(pM)]

# Only keep an observation if every remote sensing source has at least 1 day of data.
indKeep = np.where((nMat > 0).all(axis=1))[0]
x = x[:, indKeep, :]
xc = xc[indKeep, :]
yc = yc[indKeep, :]
nMat = nMat[indKeep, :]
pSLst = [pSLst[k] for k in indKeep]
pLLst = [pLLst[k] for k in indKeep]
pMLst = [pMLst[k] for k in indKeep]

# iInd (day) and jInd (site) are parallel per-observation indices; filter BOTH
# so each (iInd[k], jInd[k]) pair still describes observation k of x / xc / yc.
iInd = iInd[indKeep]
jInd = jInd[indKeep]

In [3]:
# Scratch check: join ten ones and ten zeros end to end into one length-20 array.
a, b = np.ones(10), np.zeros(10)
np.concatenate((a, b))

array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0.])

In [5]:
# Two random 1-D vectors of length 10 (values in [0, 1)).
# Previously `np.random.rand(10).shape()` called the shape tuple, which raises
# TypeError; the intent is the array itself.
a = np.random.rand(10)
b = np.random.rand(10)

In [15]:
# Place a and b side by side as the two columns of one array
# (an (n, 2) array when both are 1-D of length n).
np.column_stack((a, b))

array([[0.04216492, 0.38305902],
       [0.66129926, 0.12741888],
       [0.6074255 , 0.56743243],
       [0.01837205, 0.88331768],
       [0.21977481, 0.37993062],
       [0.32463266, 0.95190576],
       [0.05082871, 0.43167692],
       [0.72797715, 0.49819248],
       [0.65345978, 0.59968182],
       [0.69760439, 0.08436677]])

In [1]:
import os

In [4]:
# Load the saved array from the vegetation data directory
# (presumably daily transformer LFMC predictions, per the file name).
lfmc_daily = np.load(
    os.path.join(kPath.dirVeg, "transformer_lfmc_daily.npy")
)

In [14]:
site = 45
# Number of rows whose column 2 (presumably the site id) equals `site`.
np.count_nonzero(lfmc_daily[:, 2] == site)

333

In [12]:
# Sorted distinct values of column 2 (presumably the site ids).
np.unique(lfmc_daily.T[2])

array([ 45.,  46.,  54.,  58.,  61.,  66.,  68.,  72.,  74.,  76.,  81.,
        82.,  83.,  86.,  90.,  91.,  92.,  93.,  95.,  96.,  97.,  98.,
       104., 106., 109., 112., 117., 118., 119., 120., 121., 122., 123.,
       124., 125., 127., 129., 131., 133., 135., 138., 140., 142., 144.,
       145., 146., 147., 148., 149., 151., 152., 154., 157., 158., 159.,
       161., 162., 163., 164., 165., 166., 167., 176., 177., 178., 179.,
       180., 181., 182., 185., 186., 187., 188., 189., 191., 192., 193.,
       194., 198., 199., 200., 201., 202., 203., 204., 205., 206., 207.,
       208., 209., 210., 211., 213., 215., 217., 223., 225., 228., 236.,
       237., 243., 244., 245., 246., 251., 252., 253., 254., 255., 256.,
       257., 269., 270., 271., 272., 274., 275., 276., 280., 282., 286.,
       287., 289., 290., 291., 292., 293., 294., 295., 296., 297., 298.,
       299., 300., 305., 307., 308., 309., 310., 311., 312., 313., 314.,
       315., 316., 317., 318., 319., 320., 322., 32

In [15]:
# Sorted distinct values of column 1 (presumably a day index — TODO confirm).
np.unique(lfmc_daily.T[1])

array([  0.,   1.,   2.,   3.,   4.,   5.,   6.,   7.,   8.,   9.,  10.,
        11.,  12.,  13.,  14.,  15.,  16.,  17.,  18.,  19.,  20.,  21.,
        22.,  23.,  25.,  26.,  27.,  28.,  29.,  30.,  31.,  32.,  33.,
        34.,  35.,  36.,  37.,  38.,  39.,  40.,  41.,  42.,  43.,  44.,
        45.,  46.,  47.,  48.,  49.,  50.,  51.,  52.,  53.,  54.,  55.,
        56.,  57.,  58.,  59.,  60.,  61.,  62.,  63.,  64.,  65.,  66.,
        67.,  68.,  69.,  70.,  71.,  72.,  73.,  74.,  75.,  76.,  77.,
        78.,  79.,  80.,  81.,  82.,  83.,  84.,  85.,  86.,  87.,  88.,
        89.,  90.,  91.,  92.,  93.,  94.,  95.,  96.,  97.,  98.,  99.,
       100., 101., 102., 103., 104., 105., 106., 107., 108., 109., 110.,
       111., 112., 113., 114., 115., 116., 117., 118., 119., 120., 121.,
       122., 123., 124., 125., 126., 127., 128., 129., 130., 131., 132.,
       133., 134., 135., 136., 137., 138., 139., 140., 141., 142., 143.,
       144., 145., 146., 147., 148., 149., 150., 15