In [2]:
from torch.utils.data import Dataset
from dataclasses import dataclass
import pandas as pd
import os
import numpy as np
import matplotlib.pyplot as plt
import yaml
from pathlib import Path
from tqdm import tqdm
from dataclasses import dataclass, field
from typing import Dict, Tuple, Union, Any
import pickle as pkl
from PIL import Image, ImageOps

In [3]:
EC_PREDICTORS = ('DOY', 'TOD', 'TA', 'P', 'RH', 'VPD', 'PA', 'CO2', 'SW_IN', 'SW_OUT', 'LW_IN', 'LW_OUT', 'NETRAD', 'PPFD_IN', 'PPFD_OUT',
                 'TS_1', 'TS_2', 'TS_3', 'TS_4', 'TS_5', 'G', 'H', 'LE', 'WS', 'WD', 'USTAR')

EC_TARGETS = ('NEE', 'GPP_DT', 'GPP_NT', 'RECO_DT', 'RECO_NT', 'FCH4')

DEFAULT_NORM = {
    'DOY': {'cyclic': True, 'norm_max': 366.0, 'norm_min': 0.0},
    'TOD': {'cyclic': True, 'norm_max': 24.0, 'norm_min': 0.0},
    'TA': {'cyclic': False, 'norm_max': 80.0, 'norm_min': -80.0},
    'P': {'cyclic': False, 'norm_max': 50.0, 'norm_min': 0.0},
    'RH': {'cyclic': False, 'norm_max': 100.0, 'norm_min': 0.0},
    'VPD': {'cyclic': False, 'norm_max': 110.0, 'norm_min': 0.0},
    'PA': {'cyclic': False, 'norm_max': 110.0, 'norm_min': 0.0},
    'CO2': {'cyclic': False, 'norm_max': 750.0, 'norm_min': 0.0},
    'SW_IN': {'cyclic': False, 'norm_max': 1500.0, 'norm_min': -1500.0},
    'SW_OUT': {'cyclic': False, 'norm_max': 500.0, 'norm_min': -500.0},
    'LW_IN': {'cyclic': False, 'norm_max': 1000.0, 'norm_min': -1000.0},
    'LW_OUT': {'cyclic': False, 'norm_max': 1000.0, 'norm_min': -1000.0},
    'NETRAD': {'cyclic': False, 'norm_max': 1000.0, 'norm_min': -1000.0},
    'PPFD_IN': {'cyclic': False, 'norm_max': 2500.0, 'norm_min': -2500.0},
    'PPFD_OUT': {'cyclic': False, 'norm_max': 1000.0, 'norm_min': -1000.0},
    'TS_1': {'cyclic': False, 'norm_max': 40.0, 'norm_min': -40.0},
    'TS_2': {'cyclic': False, 'norm_max': 40.0, 'norm_min': -40.0},
    'TS_3': {'cyclic': False, 'norm_max': 40.0, 'norm_min': -40.0},
    'TS_4': {'cyclic': False, 'norm_max': 40.0, 'norm_min': -40.0},
    'TS_5': {'cyclic': False, 'norm_max': 40.0, 'norm_min': -40.0},
    'USTAR': {'cyclic': False, 'norm_max': 4.0, 'norm_min': -4.0},
    'G': {'cyclic': False, 'norm_max': 700.0, 'norm_min': -700.0},
    'H': {'cyclic': False, 'norm_max': 700.0, 'norm_min': -700.0},
    'LE': {'cyclic': False, 'norm_max': 700.0, 'norm_min': -700.0},
    'WD': {'cyclic': True, 'norm_max': 360.0, 'norm_min': 0.0},
    'WS': {'cyclic': False, 'norm_max': 100.0, 'norm_min': -100.0},

    'NEE': {'cyclic': False, 'norm_max': 50.0, 'norm_min': -50.0},
    'GPP_DT': {'cyclic': False, 'norm_max': 40.0, 'norm_min': -40.0},
    'GPP_NT': {'cyclic': False, 'norm_max': 40.0, 'norm_min': -40.0},
    'RECO_DT': {'cyclic': False, 'norm_max': 30.0, 'norm_min': -30.0},
    'RECO_NT': {'cyclic': False, 'norm_max': 30.0, 'norm_min': -30.0},
    'FCH4': {'cyclic': False, 'norm_max': 800.0, 'norm_min': -800.0}
}

In [27]:

@dataclass
class CarbonSenseConfig:
    '''Configuration for CarbonSenseV2 dataloader and preprocessor

    targets - variable selection for targets. Must be a subset of EC_TARGETS
    targets_max_qc - maximum QC flag (inclusive) to allow for target values. A lower value will result
                     in fewer usable samples, but they will be of higher quality
    predictors - variable selection for predictors. Must be a subset of EC_PREDICTORS
    predictors_max_qc - similar to targets_max_qc, but applied to predictor variables
    normalization_config - dictionary object used for normalizing variables. Custom dictionaries can
                           be supplied, but should be based on the DEFAULT_NORM template

    suffix - file suffix for preprocessed tabular data files. This is used by the proprocessing script to
             normalize and filter variables based on max QC values. If the data loader is passed a config with
             a suffix string and does not locate the appropriate files, it will run the proprocessor automatically.
             For example, if suffix = '_proc', then the preprocessor will make a 'data_proc.csv' for each site.
             **NOTE** if a suffix file is already present, it will effectively ignore all max QC, normalization,
             and variable selection parameters, as it will assume these operations are already complete. This can
             be circumvented by passing force_overwrite = True.
    force_overwrite - whether or not to rerun the preprocessing even if the suffix file is found.

    context_window_length - how many timesteps should be included with each example. This is the 'context window'.
                            For example, if context_window_length = 4, then every sample will contain 4
                            consecutive timesteps worth of data, for both predictors and targets.
    image_timestep_threshold - allows the dataloader to look back past the beginning of the context window for any
                               MODIS / phenocam imagery. For example, if image_timestep_threshold = 48, then the
                               data loader will look back up to 48 timesteps (24 hours) for images if none are found
                               in the context window. This is useful for models which only have a context window of 1,
                               but still want to have imagery with most samples.
                               **NOTE** the batch provided by the dataloader may lie about when the images were taken.
    '''

    # Preprocessor arguments
    targets: Tuple[str] = EC_TARGETS
    targets_max_qc: int = 1
    predictors: Tuple[str] = EC_PREDICTORS
    predictors_max_qc: int = 1

    normalize_predictors: bool = True
    normalize_targets: bool = False
    normalization_config: Dict = field(default_factory = lambda: (DEFAULT_NORM))
    suffix: str = '_proc'
    force_overwrite: bool = False

    # Dataloader arguments
    context_window_length: int = 16
    image_timestep_threshold: int = 0
    use_modis: bool = True
    use_phenocam: bool = True
    phenocam_resolution: Tuple = (512, 512)




def preprocess_data(
        data_dir: Union[str, os.PathLike],
        config: CarbonSenseConfig):
    if config.suffix == '':
        print('ERROR: Cannot overwrite original data')
        return
    data_path = Path(data_dir)
    sites = os.listdir(data_path / 'site_data')

    print('Preprocessing site data...')
    for site in tqdm(sites):
        site_path = data_path / 'site_data' / site
        outfile = site_path / f'data{config.suffix}.csv'
        if os.path.exists(outfile) and not config.force_overwrite:
            continue
        df = pd.read_csv(site_path / 'data.csv')

        # Delete values with a QC flag higher than the maximum specified in the config
        for pred in config.predictors:
            if pred == 'DOY' or pred == 'TOD':
                continue
            df.loc[df[f'{pred}_QC'] > config.predictors_max_qc, pred] = np.nan
        for targ in config.targets:
            df.loc[df[f'{targ}_QC'] > config.targets_max_qc, targ] = np.nan

        # Filter variables (and get rid of QC columns)
        df = df[['timestamp'] + list(config.predictors) + list(config.targets)]

        # Min-max normalization
        if config.normalize_predictors:
            for pred in config.predictors:
                vmax = config.normalization_config[pred]['norm_max']
                vmin = config.normalization_config[pred]['norm_min']
                vmid = (vmax + vmin) / 2
                vrange = vmax - vmin
                cyclic = config.normalization_config[pred]['cyclic']
                if cyclic:
                    vrange /= 2

                df.loc[~df[pred].between(vmin, vmax), pred] = np.nan
                df[pred] = (df[pred] - vmid) / vrange

        if config.normalize_targets:
            for targ in config.targets:
                vmax = config.normalization_config[targ]['norm_max']
                vmin = config.normalization_config[targ]['norm_min']
                vmid = (vmax + vmin) / 2
                vrange = vmax - vmin
                cyclic = config.normalization_config[targ]['cyclic']
                if cyclic:
                    vrange /= 2

                df.loc[~df[targ].between(vmin, vmax), targ] = np.nan
                df[targ] = (df[targ] - vmid) / vrange
            
        df.to_csv(outfile, index=False)
    return


@dataclass
class CarbonSenseBatch:
    sites: Tuple[str] # one value for each sample
    columns: Tuple[str] # common mapping for all samples in the batch
    timestamps: Union[Tuple, np.ndarray]
    ec_values: np.ndarray # all eddy covariance data: (batch, context_window, values)
    modis: Tuple # all modis data: (batch, (timestamp, ndarray))
    phenocam: Tuple # all phenocam data: (batch, (timestamp, ndarray))

    def to(self, device: Any):
        '''
        .to(device) is provided with this dataclass as a shortcut to individually moving
        every piece of data in the class.
        '''
        raise NotImplementedError()


class CarbonSenseDataset(Dataset):
    def __init__(self,
                 data_dir: Union[str, os.PathLike],
                 config: CarbonSenseConfig):
        self.data_path = Path(data_dir)
        self.config = config

        preprocess_data(data_dir, self.config)

        self.window_len = self.config.context_window_length
        self.datafile = f'data{self.config.suffix}.csv'
        self.sites = sorted(os.listdir(self.data_path / 'site_data'))

        self.data = []
        print('Indexing sites...')
        for site in tqdm(self.sites):
            df = pd.read_csv(self.data_path / 'site_data' / site / self.datafile, skiprows=range(1,self.window_len))
            idx = df[df[list(self.config.targets)].notnull().any(axis=1)].index.to_list()
            self.data.extend([(site, i) for i in idx])

    
    def __len__(self):
        return len(self.data)

    def _load_image(self, filename):
        with Image.open(self.data_path / 'phenocam' / filename) as img:
            return img.convert("RGB")
    
    def __getitem__(self, idx):
        # Find the relevant file
        site, index = self.data[idx]
        ec_file = self.data_path / 'site_data' / site / self.datafile
        df = pd.read_csv(ec_file, skiprows=range(1,1+index), nrows=self.window_len)

        # Get timestamps
        ec_timestamps = df['timestamp'].to_list()
        if self.config.image_timestep_threshold > 0:
            img_idx_range = range(1,max(2,1+index-self.config.image_timestep_threshold))
            df_ts = pd.read_csv(ec_file, skiprows=img_idx_range, nrows=self.window_len+self.config.image_timestep_threshold)
            ec_timestamps = df_ts['timestamp'].to_list()
        
        # EC data
        ec_data = df.drop(columns='timestamp').values
        ec_cols = tuple(df.drop(columns='timestamp').columns)

        # MODIS data
        modis_data = []
        modis_file = self.data_path / 'site_data' / site / 'modis.pkl'
        if self.config.use_modis and os.path.exists(modis_file):
            with open(modis_file, 'rb') as f:
                all_modis_data = pkl.load(f)
            modis_data = [(ts, im) for ts, im in all_modis_data.items() if ts in ec_timestamps]

        # Phenocam data
        phenocam_data = []
        phenocam_file = self.data_path / 'site_data' / site / 'phenocam.pkl'
        if self.config.use_phenocam and os.path.exists(phenocam_file):
            with open(phenocam_file, 'rb') as f:
                phenocam_map = pkl.load(f)
            filtered_phenocam_map = [(ts, f) for ts, f in phenocam_map.items() if ts in ec_timestamps]
            for timestamp, files in filtered_phenocam_map:
                images = []
                for file in files:
                    resized_img = self._load_image(file)
                    images.append(np.array(resized_img, dtype=np.float32)/255.0)
                    # TODO: Resize images?
                phenocam_data.append((timestamp, images))


        return site, ec_cols, ec_timestamps, ec_data, modis_data, phenocam_data


config = CarbonSenseConfig(targets=['NEE'])
ds = CarbonSenseDataset('data/carbonsense_v2', config)


Preprocessing site data...


100%|██████████| 417/417 [00:00<00:00, 91275.69it/s]


Indexing sites...


100%|██████████| 417/417 [01:28<00:00,  4.73it/s]


In [28]:
site, ec_cols, ec_timestamps, ec_data, modis_data, phenocam_data = ds.__getitem__(40343385)

<PIL.Image.Image image mode=RGB size=1296x960 at 0x7FF6AEBB2790>
<PIL.Image.Image image mode=RGB size=2592x1944 at 0x7FF6AEBB2520>


In [32]:
phenocam_data

[(201903061200,
  [array([[[0.21176471, 0.26666668, 0.4117647 ],
           [0.21176471, 0.26666668, 0.4117647 ],
           [0.21176471, 0.26666668, 0.4117647 ],
           ...,
           [0.41568628, 0.54901963, 0.59607846],
           [0.4117647 , 0.54509807, 0.5921569 ],
           [0.40784314, 0.5411765 , 0.5882353 ]],
   
          [[0.21176471, 0.26666668, 0.4117647 ],
           [0.21176471, 0.26666668, 0.4117647 ],
           [0.21176471, 0.26666668, 0.4117647 ],
           ...,
           [0.41568628, 0.54901963, 0.59607846],
           [0.4117647 , 0.54509807, 0.5921569 ],
           [0.40784314, 0.5411765 , 0.5882353 ]],
   
          [[0.21176471, 0.26666668, 0.4117647 ],
           [0.21176471, 0.26666668, 0.4117647 ],
           [0.21176471, 0.26666668, 0.4117647 ],
           ...,
           [0.4117647 , 0.54509807, 0.5921569 ],
           [0.4117647 , 0.54509807, 0.5921569 ],
           [0.4117647 , 0.54509807, 0.5921569 ]],
   
          ...,
   
          [[0.164705

In [6]:
for i in range(len(ds)):
    if ds.data[i][0] == 'US-xDJ':
        print(i)
        break

40343385


In [None]:
# For Tuesday
# I finished basic EC data loading, now I need to work on:
# - MODIS image loading
# - Phenocam loading
# - collate function and general structure (dataclass for data delivery?)