# Prepare data for transformers

**Purpose:** We want to compare metrics from the RNN maps with predictions from the transformer model. Our transformer model currently takes inputs of shape `(rho, num observations, num inputs)`. To get a daily prediction at each site, we have to prepare this matrix.

**Date:** July 30, 2024


This notebook follows the `models/cherry_pick/data.py` and `models/all_pick/data.py` files.

In [46]:
# hydroDL module by Kuai Fang
import hydroDL.data.dbVeg
from hydroDL.data import dbVeg
from hydroDL.data import DataModel
from hydroDL.master import dataTs2Range
from hydroDL import kPath

import numpy as np
import torch

In [47]:
dataName = "singleDaily-nadgrid"
rho = 45

In [80]:
# Load data with custom DataFrameVeg and DataModel classes
df = dbVeg.DataFrameVeg(dataName)
dm = DataModel(X=df.x, XC=df.xc, Y=df.y)
dm.trans(mtdDefault='minmax') # Min-max normalization
dataTup = dm.getData()
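
As a rough illustration of what min-max normalization does to an array with missing values, here is a minimal NumPy sketch. The helper `minmax_normalize` is hypothetical and is not the hydroDL implementation; `dm.trans` may handle per-variable settings and edge cases differently.

```python
import numpy as np

def minmax_normalize(arr):
    """Hypothetical helper: scale each input feature to [0, 1], ignoring NaNs.

    arr has shape (num_days, num_sites, num_inputs).
    """
    lo = np.nanmin(arr, axis=(0, 1), keepdims=True)
    hi = np.nanmax(arr, axis=(0, 1), keepdims=True)
    # Avoid division by zero for constant features.
    span = np.where(hi > lo, hi - lo, 1.0)
    return (arr - lo) / span, lo, hi

# Toy data: 2 days, 2 sites, 2 inputs, with one missing value
arr_toy = np.array([
    [[1.0, 10.0], [3.0, np.nan]],
    [[5.0, 30.0], [2.0, 20.0]],
])
arr_norm, lo_toy, hi_toy = minmax_normalize(arr_toy)
```

Missing values stay `NaN` after scaling, so later steps still have to check for them.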

Here, we fill the `y` matrix with all ones. This is a temporary, hacky way to quickly format the data for the transformer.

The `dataTs2Range` function called below prepares the data by reformatting it to shape `(rho, num observations, num inputs)`. In our case, we do not want to limit the samples to just `num observations`. Rather, we want to keep ALL the remote sensing data in order to retrieve daily LFMC predictions over the study period.
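
To make the idea concrete, the following is a minimal NumPy sketch of this kind of windowing, not the actual `dataTs2Range` code. The helper `window_all_days` is hypothetical. It assumes that one window is cut around every `(day, site)` pair whose target is not NaN, and that each window spans `2 * rho + 1` days (an assumption inferred from the 91-step first axis that the notebook reports for `rho = 45`).

```python
import numpy as np

def window_all_days(x, y, rho):
    """Hypothetical helper: cut a (2*rho + 1)-day window around every
    (day, site) pair whose target value is not NaN.

    x: (num_days, num_sites, num_inputs), y: (num_days, num_sites)
    Returns windows of shape (2*rho + 1, num_windows, num_inputs) plus the
    day index and site index of each window centre.
    """
    num_days = x.shape[0]
    i_ind, j_ind = np.where(~np.isnan(y))
    # Keep only centres whose full window fits inside the time axis.
    keep = (i_ind >= rho) & (i_ind < num_days - rho)
    i_ind, j_ind = i_ind[keep], j_ind[keep]
    offsets = np.arange(-rho, rho + 1)
    # Broadcasted fancy indexing gives (window_len, num_windows, num_inputs).
    windows = x[i_ind[None, :] + offsets[:, None], j_ind[None, :], :]
    return windows, i_ind, j_ind

# Toy data: 10 days, 2 sites, 3 inputs
x_toy = np.random.rand(10, 2, 3)

# Sparse targets: only one field observation, so only one window
y_sparse = np.full((10, 2), np.nan)
y_sparse[4, 0] = 1.0
w_sparse, _, _ = window_all_days(x_toy, y_sparse, rho=2)

# Targets filled with ones: every day at every site yields a window
y_dense = np.ones((10, 2))
w_dense, i_toy, j_toy = window_all_days(x_toy, y_dense, rho=2)
```

Under these assumptions, the sparse targets yield a single window while the all-ones targets yield one window per day and site, which is the effect the placeholder `y` is meant to achieve.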

In [81]:
x, xc, y, _ = dataTup
dataTup = (x, xc, np.ones(y.shape), _)

In [83]:
# Convert data to shape (rho window, number of observations, number of input features)
dataEnd, (iInd, jInd) = dataTs2Range(dataTup, rho, returnInd=True) # iInd: day, jInd: site
x, xc, _, yc = dataEnd

iInd = np.array(iInd) # TODO: Temporary fix
jInd = np.array(jInd) # TODO: Temporary fix

[   0    0    0 ... 1887 1887 1887]


In [77]:
x.shape

(91, 602330, 27)

In [30]:
# satellite variable names
varS = ['VV', 'VH', 'vh_vv']
varL = ['SR_B2', 'SR_B3', 'SR_B4', 'SR_B5', 'SR_B6', 'ndvi', 'ndwi', 'nirv']
varM = ["MCD43A4_b{}".format(x) for x in range(1, 8)]

In [31]:
iInd = np.array(iInd) # TODO: Temporary fix
jInd = np.array(jInd) # TODO: Temporary fix

iS = [df.varX.index(var) for var in varS]
iL = [df.varX.index(var) for var in varL]
iM = [df.varX.index(var) for var in varM]

# For each remote sensing source (e.g. Sentinel, MODIS) and each LFMC observation,
# create a list of days in the rho-window that have data
# nMat: number of days each satellite has data for, of shape (# observations, # satellites)
pSLst, pLLst, pMLst = list(), list(), list()
nMat = np.zeros([yc.shape[0], 3])
for k in range(nMat.shape[0]):
    tempS = x[:, k, iS]
    pS = np.where(~np.isnan(tempS).any(axis=1))[0]
    tempL = x[:, k, iL]
    pL = np.where(~np.isnan(tempL).any(axis=1))[0]
    tempM = x[:, k, iM]
    pM = np.where(~np.isnan(tempM).any(axis=1))[0]
    pSLst.append(pS)
    pLLst.append(pL)
    pMLst.append(pM)
    nMat[k, :] = [len(pS), len(pL), len(pM)]

# Only keep a sample if there is at least 1 day of data for each remote sensing source
indKeep = np.where((nMat > 0).all(axis=1))[0]
x = x[:, indKeep, :]
xc = xc[indKeep, :]
yc = yc[indKeep, :]
nMat = nMat[indKeep, :]
pSLst = [pSLst[k] for k in indKeep]
pLLst = [pLLst[k] for k in indKeep]
pMLst = [pMLst[k] for k in indKeep]

jInd = jInd[indKeep]
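
The valid-day test used in the cell above can be checked on a toy window. The sketch below uses made-up values and two hypothetical sensors (`cols_a`, `cols_b`). It also shows a vectorized way to count valid days per sample, offered as a possible alternative to the loop rather than as the notebook's method.

```python
import numpy as np

# Toy window for one sample: 4 days, 3 inputs.
# Columns 0-1 belong to hypothetical sensor A, column 2 to sensor B.
window = np.array([
    [0.1, 0.2, np.nan],
    [np.nan, 0.3, 0.5],
    [0.4, 0.5, np.nan],
    [0.6, 0.7, 0.8],
])
cols_a = [0, 1]
cols_b = [2]

# A day counts as valid for a sensor only if none of its bands is NaN.
valid_a = np.where(~np.isnan(window[:, cols_a]).any(axis=1))[0]
valid_b = np.where(~np.isnan(window[:, cols_b]).any(axis=1))[0]

# Vectorized count over all samples: x_toy has shape (4 days, 2 samples, 3 inputs).
x_toy = np.stack([window, window.copy()], axis=1)
x_toy[:, 1, 2] = np.nan  # second sample has no sensor-B data at all
n_a = (~np.isnan(x_toy[:, :, cols_a]).any(axis=2)).sum(axis=0)
n_b = (~np.isnan(x_toy[:, :, cols_b]).any(axis=2)).sum(axis=0)

# Keep samples that have at least one valid day for every sensor.
keep = np.where((n_a > 0) & (n_b > 0))[0]
```

The vectorized counts give the per-sample totals only. The per-sample lists of valid day positions (`pSLst`, `pLLst`, `pMLst`) still require a loop or a ragged structure because their lengths differ.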

In [32]:
x.shape

(91, 575602, 27)

In [35]:
(602330 - 575602)

26728