# Explore the CREDIT data pipeline

In [1]:
import glob
import yaml

## Read the example config file

In [2]:
# Location of the example YAML config used for the data checks
config_name = '/glade/u/home/ksha/miles-credit/config/example_for_data_checks.yml'

# Parse the YAML config; later cells read paths and options from `conf`
with open(config_name) as config_file:
    conf = yaml.safe_load(config_file)

## The `ERA5_and_Forcing_Dataset` class

```python
from credit.data import ERA5_and_Forcing_Dataset
```

`ERA5_and_Forcing_Dataset` reads ERA5, forcing, and static variables from storage and converts them to `xarray.Dataset`.

In [3]:
from credit.data import ERA5_and_Forcing_Dataset

### Build the PyTorch dataset

**Get ERA5 file names**

In [4]:
# Glob pattern (from the config) that matches the ERA5 files
ERA5_save_loc = conf['data']['save_loc']
# Sort so the matched files come back in a deterministic (lexicographic) order
all_ERA_files = sorted(glob.glob(ERA5_save_loc))

# Fail early with a clear message: glob returns an empty list (no error)
# when nothing matches, which would otherwise surface later as an obscure
# failure inside the dataset.
if not all_ERA_files:
    raise FileNotFoundError(f"No ERA5 files matched the pattern: {ERA5_save_loc!r}")

In [5]:
# hourly ERA5 data saved as yearly *.zarr files; preview the first two paths
all_ERA_files[:2]

['/glade/derecho/scratch/schreck/STAGING/TOTAL_1979-01-01_1979-12-31_staged.zarr',
 '/glade/derecho/scratch/schreck/STAGING/TOTAL_1980-01-01_1980-12-31_staged.zarr']

**Get forcing and static filenames**

In [6]:
# Forcing and static-variable file paths, both taken from the `data` section of the config
forcing_name, static_name = (
    conf['data']['save_loc_forcing'],
    conf['data']['save_loc_static'],
)

**Get the dataset**

In [7]:
# Build the dataset without any transform so the raw samples can be inspected below
ERA5_dataset = ERA5_and_Forcing_Dataset(
    filenames=all_ERA_files,
    filename_forcing=forcing_name,
    filename_static=static_name,
    history_len=2, # the number of input forecast lead times
    forecast_len=2, # the targeted forecast lead time, e.g. 0 for the next-hour training
    transform=None, # no transform is applied at this stage
    skip_periods=None, # works like array[::skip]
    one_shot=True, # True: returns the last forecast lead time target, None: returns all forecast lead times
    max_forecast_len=None)

### Explore the produced samples

In [8]:
samples = next(iter(ERA5_dataset))

In [9]:
samples.keys()

dict_keys(['historical_ERA5_images', 'target_ERA5_images', 'datetime_index', 'index'])

**Input dataset**

In [10]:
samples['historical_ERA5_images']

**Target dataset**

In [11]:
samples['target_ERA5_images']

**Datetime info**

In [12]:
samples['datetime_index']

array([283996800, 284000400, 284004000, 284007600, 284011200])

## The `Normalize_ERA5_and_Forcing` and `ToTensor_ERA5_and_Forcing` transforms

In [13]:
from credit.transforms import Normalize_ERA5_and_Forcing, ToTensor_ERA5_and_Forcing

### Z-score normalization

In [14]:
# build the normalization transform from the config
transform_scaler = Normalize_ERA5_and_Forcing(conf)

# let the dataset roll out a sample
samples = next(iter(ERA5_dataset))

# apply the transform to z-score the sample
samples_norm = transform_scaler(samples)

In [15]:
samples_norm.keys()

dict_keys(['historical_ERA5_images', 'target_ERA5_images'])

In [16]:
samples_norm['historical_ERA5_images']

### Convert xarray to tensor

In [17]:
# build the xarray-to-tensor transform from the config
to_tensor_scaler = ToTensor_ERA5_and_Forcing(conf)

# convert the normalized sample (`samples_norm`) to tensors
tensor_norm = to_tensor_scaler(samples_norm)

In [18]:
tensor_norm.keys()

dict_keys(['x_surf', 'x', 'forcing_static', 'y_surf', 'y'])

In [19]:
# Print the shape of every tensor in the converted sample, one per line
for key in ('x', 'x_surf', 'forcing_static', 'y_surf', 'y'):
    print(tensor_norm[key].shape)

torch.Size([2, 4, 15, 640, 1280])
torch.Size([2, 7, 640, 1280])
torch.Size([2, 3, 640, 1280])
torch.Size([1, 7, 640, 1280])
torch.Size([1, 4, 15, 640, 1280])


## Build the DataLoader & roll out an example batch

In [20]:
import torch
from torchvision import transforms as tforms
from torch.utils.data.distributed import DistributedSampler

In [21]:
# Fresh transform instances for the DataLoader pipeline
transform_scaler = Normalize_ERA5_and_Forcing(conf)
to_tensor_scaler = ToTensor_ERA5_and_Forcing(conf)

# Chain them in order: normalize first, then convert to tensors
transforms = tforms.Compose([transform_scaler, to_tensor_scaler])

In [22]:
# shuffle the samples if training (this flag is passed to the sampler below)
shuffle = False
# 1 GPU scenario
rank = 0
world_size = 1

# dataset: same arguments as the earlier dataset, now with the transform pipeline attached
ERA5_dataset = ERA5_and_Forcing_Dataset(
    filenames=all_ERA_files,
    filename_forcing=forcing_name,
    filename_static=static_name,
    history_len=2,
    forecast_len=2,
    transform=transforms, # <--------- add transforms to the Dataset
    skip_periods=None,
    one_shot=True,
    max_forecast_len=None)

# PyTorch distributed sampler (a single replica in this 1 GPU scenario)
sampler = DistributedSampler(
    ERA5_dataset,
    num_replicas=world_size,
    rank=rank,
    seed=42,
    shuffle=shuffle,
    drop_last=True
)

# DataLoader: sample ordering comes from the sampler, so `shuffle` stays False here
dataloader = torch.utils.data.DataLoader(
    ERA5_dataset, # <------------------------- our dataset
    batch_size=32, # <------------------------ 32 samples per batch
    shuffle=False,
    sampler=sampler, # <----------------------- our sampler
)

In [23]:
example_batch = next(iter(dataloader))

In [24]:
example_batch.keys()

dict_keys(['x_surf', 'x', 'forcing_static', 'y_surf', 'y', 'index'])

In [25]:
# Print the shape of every tensor in the batch, one per line
for key in ('x', 'x_surf', 'forcing_static', 'y_surf', 'y'):
    print(example_batch[key].shape)

torch.Size([32, 2, 4, 15, 640, 1280])
torch.Size([32, 2, 7, 640, 1280])
torch.Size([32, 2, 3, 640, 1280])
torch.Size([32, 1, 7, 640, 1280])
torch.Size([32, 1, 4, 15, 640, 1280])
