In [1]:
# Enable IPython's autoreload extension in mode 2: imported modules are
# reloaded before each cell executes, so edits to library code take effect
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
from bsd_dataset import regions
from bsd_dataset.common.transforms import LogTransformPrecipitation
import pandas as pd
import xarray as xr
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
class Dataset:
    """Map-style dataset of daily precipitation fields cropped to one region.

    Each item pairs the ``precip`` variable of the input dataset with the
    ``precip`` variable of the target dataset for the same calendar day, both
    restricted to the region's latitude/longitude bounds.

    Opened xarray datasets are cached on the class, keyed by glob pattern, so
    several instances that point at the same files (e.g. train/val/test
    splits) share one handle, while instances that point at different
    directories each read their own data.

    Args:
        input_dir: Directory whose ``*.nc`` files form the input dataset.
        target_dir: Directory whose ``*.nc`` files form the target dataset.
        region: Object exposing ``get_latitudes()`` and ``get_longitudes(180)``.
            The sorted results are unpacked into ``slice(...)``, so each is
            presumably a pair of bounds -- TODO confirm against ``regions``.
        period: Positional arguments for ``pandas.date_range`` (e.g. a
            ``(start, end)`` pair); one sample per generated timestamp.
        transform: Optional callable applied to the input tensor.
        target_transform: Optional callable applied to the target tensor.
    """

    # First datasets opened by any instance. Kept as class attributes so code
    # that inspects ``Dataset.input_ds`` / ``Dataset.target_ds`` keeps working.
    input_ds = None
    target_ds = None

    # Opened datasets keyed by glob pattern, shared by every instance.
    _open_datasets = {}

    def __init__(self, input_dir, target_dir, region, period, transform=None, target_transform=None):
        # Hold per-instance handles: a single class-wide handle would make a
        # later instance with different directories silently read the files
        # of whichever instance was constructed first.
        self._input_ds = self._open(input_dir)
        self._target_ds = self._open(target_dir)
        if Dataset.input_ds is None:
            Dataset.input_ds = self._input_ds
        if Dataset.target_ds is None:
            Dataset.target_ds = self._target_ds

        # Sorted so the bounds unpack as slice(low, high); this assumes the
        # files store their coordinates in ascending order -- TODO confirm.
        self.lats = sorted(region.get_latitudes())
        self.lons = sorted(region.get_longitudes(180))
        self.period = pd.date_range(*period)
        self.transform = transform
        self.target_transform = target_transform

    @classmethod
    def _open(cls, directory):
        """Return the multi-file dataset for ``directory``, opening it at most once."""
        pattern = f'{directory}/*.nc'
        if pattern not in cls._open_datasets:
            cls._open_datasets[pattern] = xr.open_mfdataset(pattern)
        return cls._open_datasets[pattern]

    def __len__(self):
        """Return the number of samples, i.e. the number of timestamps in the period."""
        return len(self.period)

    def __getitem__(self, idx):
        """Return the ``(x, y, info)`` sample for the ``idx``-th day of the period.

        Returns:
            x: Input precipitation tensor, after ``transform`` if one was given.
            y: Target precipitation tensor, after ``target_transform`` if given.
            info: Dict with NaN masks (``x_mask``, ``y_mask``) and the
                coordinate vectors of both grids (``x_lat``, ``x_lon``,
                ``y_lat``, ``y_lon``).
        """
        date = str(self.period[idx].date())

        xdata = self._input_ds.precip.sel(time=date, latitude=slice(*self.lats), longitude=slice(*self.lons))
        ydata = self._target_ds.precip.sel(time=date, latitude=slice(*self.lats), longitude=slice(*self.lons))

        x = torch.tensor(xdata.values)
        if self.transform:
            x = self.transform(x)

        y = torch.tensor(ydata.values)
        if self.target_transform:
            y = self.target_transform(y)

        # Masks are taken from the transformed tensors, so they flag NaNs
        # present in the data as well as any produced by the transforms.
        info = {
            'x_mask': torch.isnan(x),
            'y_mask': torch.isnan(y),
            'x_lat': torch.tensor(xdata.latitude.values),
            'x_lon': torch.tensor(xdata.longitude.values),
            'y_lat': torch.tensor(ydata.latitude.values),
            'y_lon': torch.tensor(ydata.longitude.values)
        }

        return x, y, info

In [4]:
# Directories containing the *.nc files for the CHIRPS experiment's
# inputs and targets.
# NOTE(review): absolute, machine-specific paths -- consider deriving them
# from a configurable data root so the notebook runs elsewhere.
input_dir = '/home/data/BSDD/experiment-chirps/chirps-input'
target_dir = '/home/data/BSDD/experiment-chirps/chirps-target'

# (start, end) date pairs for the train / validation / test splits; each is
# expanded into a range of dates by Dataset.
train_period = ('1983-01-01', '2010-12-31')
val_period = ('2011-01-01', '2012-12-31')
test_period = ('2013-01-01', '2014-12-31')

# Region used to crop both grids, and the transform applied to the
# precipitation tensors.
region = regions.Europe
transform = LogTransformPrecipitation()

In [5]:
# Build the three splits in train -> val -> test order; they differ only in
# their date range, and the same transform serves inputs and targets.
train_set, val_set, test_set = (
    Dataset(input_dir, target_dir, region, split_period,
            transform=transform, target_transform=transform)
    for split_period in (train_period, val_period, test_period)
)

In [6]:
# Fetch the first training sample (first day of train_period) so its
# tensors and metadata can be inspected in the cells below.
x, y, info = train_set[0]

In [7]:
# Input field followed by its shape, one per line.
print(x, x.shape, sep='\n')

tensor([[nan, 0., 0.,  ..., 0., 0., 0.],
        [nan, 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [nan, nan, nan,  ..., 0., 0., 0.],
        [nan, nan, nan,  ..., 0., 0., 0.],
        [nan, nan, nan,  ..., 0., 0., 0.]])
torch.Size([80, 200])


In [8]:
# Target field followed by its shape, one per line.
print(y, y.shape, sep='\n')

tensor([[nan, nan, nan,  ..., 0., 0., 0.],
        [nan, nan, nan,  ..., 0., 0., 0.],
        [nan, nan, nan,  ..., 0., 0., 0.],
        ...,
        [nan, nan, nan,  ..., 0., 0., 0.],
        [nan, nan, nan,  ..., 0., 0., 0.],
        [nan, nan, nan,  ..., 0., 0., 0.]])
torch.Size([400, 1000])


In [9]:
# List every metadata entry with the shape of its tensor.
for k, v in info.items():
    print(k, v.shape, sep=': ')

x_mask: torch.Size([80, 200])
y_mask: torch.Size([400, 1000])
x_lat: torch.Size([80])
x_lon: torch.Size([200])
y_lat: torch.Size([400])
y_lon: torch.Size([1000])
