# Explore the CREDIT data pipeline

In [1]:
import glob

## The `ERA5_and_Forcing_Dataset` class

In [2]:
from credit.data import ERA5_and_Forcing_Dataset

### Build the PyTorch dataset

**Get ERA5 file names**

In [3]:
# Yearly ERA5 zarr stores staged on Derecho scratch. Sorting the names orders them
# chronologically, because each name embeds its YYYY-MM-DD start date.
ERA5_save_loc = '/glade/derecho/scratch/schreck/STAGING/TOTAL_*'
all_ERA_files = sorted(glob.glob(ERA5_save_loc))
# glob() silently returns [] when nothing matches (wrong path, purged scratch, no access).
# Fail here with a clear message instead of letting an empty file list surface as a
# confusing error further down the pipeline.
if not all_ERA_files:
    raise FileNotFoundError(f'No ERA5 files matched {ERA5_save_loc!r}')

In [4]:
# Hourly ERA5 data are saved as yearly *.zarr stores; show the first two file names
all_ERA_files[:2]

['/glade/derecho/scratch/schreck/STAGING/TOTAL_1979-01-01_1979-12-31_staged.zarr',
 '/glade/derecho/scratch/schreck/STAGING/TOTAL_1980-01-01_1980-12-31_staged.zarr']

**Get forcing and static filenames**

In [5]:
# The forcing and static NetCDF files sit side by side in the same CREDIT campaign directory
credit_campaign_dir = '/glade/campaign/cisl/aiml/ksha/CREDIT'
forcing_name = f'{credit_campaign_dir}/forcing.nc'
static_name = f'{credit_campaign_dir}/static.nc'

**Get the dataset**

In [6]:
# Combine the yearly ERA5 stores with the forcing and static files into one dataset.
# With history_len=2 and forecast_len=2, each sample's datetime_index holds 5
# consecutive hourly timestamps (see the datetime_index output at the end).
ERA5_dataset = ERA5_and_Forcing_Dataset(
    filenames=all_ERA_files,
    filename_forcing=forcing_name,
    filename_static=static_name,
    history_len=2,  # number of input (history) time steps
    forecast_len=2,  # targeted forecast lead time, e.g. 0 for next-hour training
    transform=None,  # no transform applied to the samples
    skip_periods=None,  # subsampling stride; works like array[::skip]
    one_shot=True,  # True: return only the last forecast lead time as the target; None: return all forecast lead times
    max_forecast_len=None)

### Explore the rollout samples

In [7]:
# Pull the first sample the dataset yields: a dict holding inputs, targets and time metadata
dataset_iterator = iter(ERA5_dataset)
samples = next(dataset_iterator)

In [8]:
# List the fields carried by a single sample
samples.keys()

dict_keys(['historical_ERA5_images', 'target_ERA5_images', 'datetime_index', 'index'])

**Input dataset**

In [9]:
# Model inputs. NOTE(review): presumably the history_len=2 input time steps; confirm from the rendered output
samples['historical_ERA5_images']

**Target dataset**

In [10]:
# Training target(s). With one_shot=True this is presumably only the last forecast lead time; confirm
samples['target_ERA5_images']

**Datetime info**

In [11]:
# Timestamps covered by this sample. The values look like Unix-epoch seconds
# (283996800 == 1979-01-01T00:00Z, entries 3600 s apart). NOTE(review): confirm units
samples['datetime_index']

array([283996800, 284000400, 284004000, 284007600, 284011200])