In [1]:
import os
import json
import numpy as np
import pandas as pd
import pickle as pkl
from torch.utils.data import Dataset, DataLoader
import torch
from time import time

In [2]:
# Root directory holding one sub-directory per site.
DATA_DIR = os.path.join('data', 'processed', 'v2')
# os.listdir returns entries in arbitrary order and also lists plain files;
# keep only directories and sort them so the site list is reproducible.
SITES = sorted(
    entry for entry in os.listdir(DATA_DIR)
    if os.path.isdir(os.path.join(DATA_DIR, entry))
)

In [3]:
class FluxDataset(Dataset):
    """Sliding-window dataset over per-site tables with sparse MODIS imagery.

    Every directory below ``data_dir`` that has one of ``sites`` as a path
    component and contains a ``data.csv`` is loaded together with its
    ``modis.pkl`` and ``meta.json``.  One sample is a window of
    ``context_length`` consecutive rows of a single site's table.

    Args:
        data_dir: Directory that is walked recursively for site folders.
        sites: Iterable of site directory names to include.
        context_length: Number of consecutive rows per sample.
        targets: Column names returned as the prediction targets.
        device: Device the returned tensors are moved to.

    Attributes:
        data: List of ``(meta, df, modis_data)`` tuples, one per loaded site,
            ordered by sorted directory traversal.
        lookup_table: List of ``(site_index, row_end)`` pairs, one per sample;
            ``row_end`` is the exclusive end row of the window.

    NOTE(review): ``remove_columns`` is fixed, so a target column that is not
    in that list stays in the feature tensor — confirm this is intended.
    NOTE(review): tensors are moved to ``device`` inside ``__getitem__``; a
    CUDA device combined with DataLoader worker processes is likely to fail —
    confirm before using ``device='cuda'`` with ``num_workers > 0``.
    """

    def __init__(self, data_dir, sites, context_length=48, targets=['GPP_NT_VUT_REF'], device='cpu'):
        self.data_dir = data_dir
        self.sites = sites
        self.data = []
        self.context_length = context_length
        self.targets = targets
        self.remove_columns = ['timestamp', 'NEE_VUT_REF', 'GPP_NT_VUT_REF', 'RECO_NT_VUT_REF']
        self.device = device

        wanted_sites = set(sites)
        for root, dirs, files in os.walk(self.data_dir):
            # Sorting in place makes os.walk descend in a deterministic order,
            # so sample indices do not depend on the filesystem.
            dirs.sort()

            # Compare whole path components: a substring test would let site
            # 'A1' also select a directory named 'A10'.
            if wanted_sites.isdisjoint(os.path.normpath(root).split(os.sep)):
                continue

            if 'data.csv' in files:
                self.data.append(self._load_site(root))

        # One entry per window: (index into self.data, exclusive end row).
        self.lookup_table = [
            (site_idx, row_end)
            for site_idx, (_, df, _) in enumerate(self.data)
            for row_end in range(self.context_length, len(df) + 1)
        ]

    @staticmethod
    def _load_site(root):
        """Read ``data.csv``, ``modis.pkl`` and ``meta.json`` from ``root``.

        Returns:
            Tuple ``(meta, df, modis_data)``; ``df['timestamp']`` is parsed to
            datetimes.

        Raises:
            FileNotFoundError: If ``modis.pkl`` or ``meta.json`` is missing.
        """
        df = pd.read_csv(os.path.join(root, 'data.csv'))
        df['timestamp'] = pd.to_datetime(df['timestamp'])
        # SECURITY: pickle.load can execute arbitrary code; only point this
        # dataset at trusted data directories.
        with open(os.path.join(root, 'modis.pkl'), 'rb') as f:
            modis_data = pkl.load(f)
        with open(os.path.join(root, 'meta.json'), 'r') as f:
            meta = json.load(f)
        return meta, df, modis_data

    def __len__(self):
        """Return the total number of windows across all loaded sites."""
        return len(self.lookup_table)

    def __getitem__(self, idx):
        """Build the sample for window ``idx``.

        Returns:
            Tuple ``(row_values, mask, modis_data, targets)``:

            * ``row_values`` – values of every column not in
              ``remove_columns``, with NaN replaced by ``-1.0``.
            * ``mask`` – boolean tensor, ``True`` where ``row_values`` was NaN.
            * ``modis_data`` – list of ``(row_offset, tensor)`` pairs for the
              rows whose timestamp is a key of the site's MODIS mapping; the
              stored array is cropped with ``[:, 1:9, 1:9]``.
            * ``targets`` – values of the ``targets`` columns.
        """
        site_num, row_max = self.lookup_table[idx]
        row_min = row_max - self.context_length

        _, df, modis = self.data[site_num]
        rows = df.iloc[row_min:row_max]

        modis_data = []
        for i, ts in enumerate(rows['timestamp']):
            pixels = modis.get(ts)
            if pixels is not None:
                # presumably the crop keeps the central 8x8 pixels of a larger
                # tile — TODO confirm against the MODIS preprocessing
                modis_data.append((i, torch.tensor(pixels[:, 1:9, 1:9]).to(self.device)))

        targets = torch.tensor(rows[self.targets].values).to(self.device)
        row_values = torch.tensor(rows.drop(columns=self.remove_columns).values)
        mask = row_values.isnan().to(self.device)
        row_values = row_values.nan_to_num(-1.0).to(self.device) # just needs a numeric value, doesn't matter what

        return row_values, mask, modis_data, targets



def custom_collate_fn(batch):
    """Collate FluxDataset samples into one batch.

    Args:
        batch: Sequence of ``(row_values, mask, modis_data, targets)`` tuples.

    Returns:
        Tuple ``(row_values, mask, modis_data, targets)`` where the three
        tensor fields are stacked along a new leading batch dimension and
        ``modis_data`` is a list holding each sample's own entry unchanged.
    """
    features, nan_masks, modis_lists, labels = zip(*batch)

    return (
        # tabular fields share a shape across samples, so they can be stacked
        torch.stack(features, dim=0),
        torch.stack(nan_masks, dim=0),
        # the MODIS entry is a variable-length list per sample: keep as a list
        list(modis_lists),
        torch.stack(labels, dim=0),
    )

def FluxDataLoader(data_dir, sites, context_length=48, targets=['GPP_NT_VUT_REF'], device='cpu', **kwargs):
    """Create a DataLoader over a freshly built FluxDataset.

    Args:
        data_dir, sites, context_length, targets, device: Passed to
            ``FluxDataset`` unchanged.
        **kwargs: Passed to ``torch.utils.data.DataLoader`` (for example
            ``batch_size``, ``shuffle``, ``num_workers``).

    Returns:
        A DataLoader that batches samples with ``custom_collate_fn``.
    """
    dataset = FluxDataset(
        data_dir,
        sites,
        context_length=context_length,
        targets=targets,
        device=device,
    )
    loader = DataLoader(dataset, collate_fn=custom_collate_fn, **kwargs)
    return loader
    

# Shuffled loader over every site in SITES, batches of 32, 8 worker processes.
dl = FluxDataLoader(
    DATA_DIR,
    SITES,
    batch_size=32,
    shuffle=True,
    num_workers=8,
)


In [6]:
# Pull a single batch from the loader; custom_collate_fn returns the tuple
# (row_values, mask, modis_data, targets), bound here to r, m, a, t.
r, m, a, t = next(iter(dl))