In [1]:
from nowcasting_dataset.datamodule import NowcastingDataModule
from pathlib import Path
import pandas as pd
import numpy as np
import xarray as xr
import numcodecs
import gcsfs
from typing import List

import logging
# Attach a default stderr handler to the root logger, then make the
# nowcasting_dataset package logger emit DEBUG-level records.
logging.basicConfig()
logger = logging.getLogger('nowcasting_dataset')
logger.setLevel('DEBUG')

In [2]:
# Bucket root. All paths below are bucket-relative; the 'gs://' scheme is
# prepended where they are handed to the data module.
BUCKET = Path('solar-pv-nowcasting-data')

# Solar PV data from PVOutput.org: timeseries and per-system metadata.
PV_PATH = BUCKET.joinpath('PV', 'PVOutput.org')
PV_DATA_FILENAME = PV_PATH.joinpath('UK_PV_timeseries_batch.nc')
PV_METADATA_FILENAME = PV_PATH.joinpath('UK_PV_metadata.csv')

# Satellite imagery (EUMETSAT SEVIRI RSS, OSGB36 directory).
# Alternative store in the same directory:
#   all_zarr_int16_single_timestep_quarter_geospatial.zarr
SAT_FILENAME = BUCKET.joinpath(
    'satellite', 'EUMETSAT', 'SEVIRI_RSS', 'OSGB36',
    'all_zarr_int16_single_timestep.zarr')

# Numerical weather predictions (UK Met Office UKV).
# Alternative stores under NWP/UK_Met_Office:
#   UKV_zarr
#   UKV_single_step_and_single_timestep_all_vars.zarr
#   UKV_single_step_and_single_timestep_all_vars_full_spatial_2018_7-12_float32.zarr
NWP_BASE_PATH = BUCKET.joinpath(
    'NWP', 'UK_Met_Office',
    'UKV__2018-01_to_2019-12__chunks__variable10__init_time1__step1__x548__y704__.zarr')

In [3]:
# Destination Zarr store for the pre-prepared examples; write_examples reads
# this module-level name.
filename = (
    'gs://solar-pv-nowcasting-data'
    '/prepared_ML_training_data/testing.zarr'
)

In [4]:
# Shared sample-shape settings, unpacked into NowcastingDataModule(**params).
params = {
    'batch_size': 32,
    'history_len': 6,  #: Number of timesteps of history, not including t0.
    'forecast_len': 12,  #: Number of timesteps of forecast.
    'image_size_pixels': 32,
    'nwp_channels': (
        't', 'dswrf', 'prate', 'r', 'sde', 'si10', 'vis', 'lcc', 'mcc', 'hcc'),
    'sat_channels': (
        'HRV', 'IR_016', 'IR_039', 'IR_087', 'IR_097', 'IR_108', 'IR_120',
        'IR_134', 'VIS006', 'VIS008', 'WV_062', 'WV_073'),
}

In [5]:
data_module = NowcastingDataModule(
    # Input data locations.
    # NOTE(review): pv_power_filename is passed without the 'gs://' prefix the
    # other paths get; it worked in the run below, but confirm this is intended.
    pv_power_filename=PV_DATA_FILENAME,
    pv_metadata_filename=f'gs://{PV_METADATA_FILENAME}',
    sat_filename=f'gs://{SAT_FILENAME}',
    nwp_base_path=f'gs://{NWP_BASE_PATH}',
    # Passed to DataLoader.
    pin_memory=True,
    num_workers=16,
    prefetch_factor=8,
    # Passed to NowcastingDataset.
    n_samples_per_timestep=8,
    n_training_batches_per_epoch=50_000,
    # Identity collate (presumably forwarded to the DataLoader), so each batch
    # arrives as a plain list of examples.
    collate_fn=lambda x: x,
    # Leave data as Pandas / Xarray for pre-preparing.
    convert_to_numpy=False,
    **params
)

In [6]:
# Run the data module's preparation step. The captured output below shows it
# opening the satellite store and removing bad PV systems.
data_module.prepare_data()

DEBUG:nowcasting_dataset:Opening satellite data: gs://solar-pv-nowcasting-data/satellite/EUMETSAT/SEVIRI_RSS/OSGB36/all_zarr_int16_single_timestep.zarr


15 bad PV systems found and removed!
pv_power = 400.0 MB


In [7]:
# Set up the data module. The captured log shows the satellite and NWP stores
# being opened. Presumably this also populates the train/val t0 datetimes
# inspected in the next cells — confirm.
data_module.setup()

DEBUG:nowcasting_dataset:Opening satellite data: gs://solar-pv-nowcasting-data/satellite/EUMETSAT/SEVIRI_RSS/OSGB36/all_zarr_int16_single_timestep.zarr
DEBUG:nowcasting_dataset:Opening NWP data: gs://solar-pv-nowcasting-data/NWP/UK_Met_Office/UKV__2018-01_to_2019-12__chunks__variable10__init_time1__step1__x548__y704__.zarr
DEBUG:nowcasting_dataset:Opening satellite data: gs://solar-pv-nowcasting-data/satellite/EUMETSAT/SEVIRI_RSS/OSGB36/all_zarr_int16_single_timestep.zarr
  unixtime = np.array(time.astype(np.int64)/10**9)
  unixtime = np.array(time.astype(np.int64)/10**9)
  unixtime = np.array(time.astype(np.int64)/10**9)
  unixtime = np.array(time.astype(np.int64)/10**9)
  a = a.astype(int)


In [8]:
# Training-split t0 datetimes (output: 47,620 timestamps, 2018-06-01 to 2019-06-16 16:00).
data_module.train_t0_datetimes

DatetimeIndex(['2018-06-01 03:50:00', '2018-06-01 03:55:00',
               '2018-06-01 04:00:00', '2018-06-01 04:05:00',
               '2018-06-01 04:10:00', '2018-06-01 04:15:00',
               '2018-06-01 04:20:00', '2018-06-01 04:25:00',
               '2018-06-01 04:30:00', '2018-06-01 04:35:00',
               ...
               '2019-06-16 15:15:00', '2019-06-16 15:20:00',
               '2019-06-16 15:25:00', '2019-06-16 15:30:00',
               '2019-06-16 15:35:00', '2019-06-16 15:40:00',
               '2019-06-16 15:45:00', '2019-06-16 15:50:00',
               '2019-06-16 15:55:00', '2019-06-16 16:00:00'],
              dtype='datetime64[ns]', length=47620, freq=None)

In [9]:
# Validation-split t0 datetimes (output: 11,904 timestamps, starting at
# 2019-06-16 16:05, immediately after the last training t0).
data_module.val_t0_datetimes

DatetimeIndex(['2019-06-16 16:05:00', '2019-06-16 16:10:00',
               '2019-06-16 16:15:00', '2019-06-16 16:20:00',
               '2019-06-16 16:25:00', '2019-06-16 16:30:00',
               '2019-06-16 16:35:00', '2019-06-16 16:40:00',
               '2019-06-16 16:45:00', '2019-06-16 16:50:00',
               ...
               '2019-08-20 18:00:00', '2019-08-20 18:05:00',
               '2019-08-20 18:10:00', '2019-08-20 18:15:00',
               '2019-08-20 18:20:00', '2019-08-20 18:25:00',
               '2019-08-20 18:30:00', '2019-08-20 18:35:00',
               '2019-08-20 18:40:00', '2019-08-20 18:45:00'],
              dtype='datetime64[ns]', length=11904, freq=None)

In [10]:
def get_dataloader():
    """Return the training DataLoader built by the module-level `data_module`."""
    return data_module.train_dataloader()


train_dl = get_dataloader()

In [11]:
import xarray as xr
import pandas as pd
from copy import deepcopy
import numcodecs
from concurrent import futures
import zarr
import gcsfs
from typing import List

In [12]:
def coord_to_arange(da: xr.Dataset, dim: str, prefix: str, dtype=np.int32) -> xr.Dataset:
    """Relabel dimension `dim` with integer positions 0..n-1, keeping the old labels.

    The original labels of `dim` are stored as a new variable named
    `f'{prefix}_{dim}_coords'`, indexed by the new integer positions.

    Args:
        da: the object to relabel (called with xr.Dataset objects in this notebook).
        dim: name of the dimension whose labels are replaced.
        prefix: prefix for the variable that keeps the original labels.
        dtype: dtype of the new integer positions.

    Returns:
        A shallow copy of `da` with the relabelled dimension. The caller's
        object is left untouched (the previous version mutated it in place).
    """
    # Shallow copy: item assignment below replaces entries in the copy only,
    # without duplicating the underlying arrays.
    da = da.copy(deep=False)
    original_labels = da[dim]
    da[dim] = np.arange(len(original_labels), dtype=dtype)
    da[f'{prefix}_{dim}_coords'] = xr.DataArray(original_labels, coords=[da[dim]], dims=[dim])
    return da

In [13]:
def _image_to_dataset(example, name):
    """Convert one gridded input of an example ('sat_data' or 'nwp') to a Dataset.

    `time`, `x` and `y` are relabelled with integer positions via
    `coord_to_arange` (original labels kept as `<short_name>_<dim>_coords`).
    `variable`, `x` and `y` are renamed with the short source name so the
    satellite and NWP dims stay distinct after merging. `time` is not renamed,
    so every source shares one `time` dim. NOTE(review): xr.merge outer-joins
    by default, so sources with different time lengths would be NaN-padded.
    """
    short_name = name.replace('_data', '')  # 'sat_data' -> 'sat'; 'nwp' unchanged
    ds = example[name].to_dataset(name=name)
    if name == 'nwp':
        # NWP is indexed by `target_time`; rename it onto the shared `time` dim.
        ds = ds.rename({'target_time': 'time'})
    for dim in ['time', 'x', 'y']:
        ds = coord_to_arange(ds, dim, prefix=short_name)
    return ds.rename({
        'variable': f'{short_name}_variable',
        'x': f'{short_name}_x',
        'y': f'{short_name}_y',
    })


def _pv_to_dataset(example, example_index):
    """Convert an example's PV yield and PV system identifiers to a Dataset.

    `example['pv_yield']` (presumably a pandas Series indexed by `datetime`)
    becomes a `pv_yield` variable along `time`, relabelled with integer
    positions. The `pv_system_id` / `pv_system_row_number` values become
    length-1 variables along an `example` dim labelled `example_index`.
    """
    pv_yield = (
        example['pv_yield']
        .rename('pv_yield')
        .to_xarray()
        .rename({'datetime': 'time'})
        .to_dataset()
    )
    pv_yield = coord_to_arange(pv_yield, 'time', prefix='pv_yield')
    example_coord = {'example': np.array([example_index], dtype=np.int32)}
    for name in ['pv_system_id', 'pv_system_row_number']:
        pv_yield[name] = xr.DataArray([example[name]], coords=example_coord, dims=['example'])
    return pv_yield


def concat_examples(examples, start_example_index=0):
    """Merge each example's inputs into one Dataset and stack them along `example`.

    Args:
        examples: iterable of per-example mappings with keys 'sat_data', 'nwp',
            'pv_yield', 'pv_system_id' and 'pv_system_row_number'.
        start_example_index: `example` label of the first example; later
            examples are numbered consecutively.

    Returns:
        The concatenated xr.Dataset, with an `example` dimension.

    Raises:
        ValueError: if `examples` is empty.
    """
    datasets = []
    for offset, example in enumerate(examples):
        parts = [_image_to_dataset(example, name) for name in ('sat_data', 'nwp')]
        parts.append(_pv_to_dataset(example, start_example_index + offset))
        datasets.append(xr.merge(parts))
    if not datasets:
        raise ValueError('concat_examples() needs at least one example.')
    # Before this point only the PV system variables carry `example`.
    # xr.concat (default data_vars='all') adds the dim to every other data variable.
    return xr.concat(datasets, dim='example')


def fix_dtypes(concat_ds):
    """Cast selected variables of `concat_ds` to compact 32-bit dtypes.

    The object is updated in place via item assignment and also returned.
    int32: `example`, satellite x/y coords, PV system id and row number.
    float32: the NWP data and its x/y coords.
    NOTE(review): the int32 cast truncates fractional values — confirm the
    satellite x/y coords are whole numbers.
    """
    target_dtypes = dict(
        example=np.int32,
        sat_x_coords=np.int32,
        sat_y_coords=np.int32,
        nwp=np.float32,
        nwp_x_coords=np.float32,
        nwp_y_coords=np.float32,
        pv_system_id=np.int32,
        pv_system_row_number=np.int32,
    )
    for var_name in target_dtypes:
        concat_ds[var_name] = concat_ds[var_name].astype(target_dtypes[var_name])
    return concat_ds


def write_examples(examples, start_example_index=0, zarr_path=None):
    """Concatenate `examples` and create or append to a Zarr store.

    Args:
        examples: list of per-example mappings (as accumulated from the
            training DataLoader with its identity `collate_fn`).
        start_example_index: `example` label of the first example written.
        zarr_path: destination store. Defaults to the notebook-level
            `filename` so existing calls keep working.

    An empty `examples` is a no-op; previously xr.concat raised on it.
    """
    if not examples:
        return
    if zarr_path is None:
        zarr_path = filename

    concat_ds = fix_dtypes(
        concat_examples(examples, start_example_index=start_example_index))
    # 32 examples per chunk along `example`.
    concat_ds = concat_ds.chunk({'example': 32})

    gcs = gcsfs.GCSFileSystem()
    if gcs.exists(zarr_path):
        # Append: xarray rejects `encoding` for variables that already exist
        # in the store; they keep the compressor set when it was created.
        to_zarr_kwargs = dict(append_dim='example')
    else:
        to_zarr_kwargs = dict(
            mode='w',
            encoding={
                name: {'compressor': numcodecs.Blosc(cname="zstd", clevel=5)}
                for name in concat_ds.data_vars},
        )

    concat_ds.to_zarr(zarr_path, **to_zarr_kwargs)

In [None]:
%%time
examples = []
start_example_index = 0
print('Getting first batch')
for batch_i, batch in enumerate(train_dl):
    print(f'Got batch {batch_i}')
    examples.extend(batch)
    if batch_i % 32 == 0:
        print('Writing!')
        write_examples(examples, start_example_index=start_example_index)
        start_example_index += len(examples)
        examples = []
    print('getting next batch...')

Getting first batch


DEBUG:nowcasting_dataset:Opening satellite data: gs://solar-pv-nowcasting-data/satellite/EUMETSAT/SEVIRI_RSS/OSGB36/all_zarr_int16_single_timestep.zarr
DEBUG:nowcasting_dataset:Opening satellite data: gs://solar-pv-nowcasting-data/satellite/EUMETSAT/SEVIRI_RSS/OSGB36/all_zarr_int16_single_timestep.zarr
DEBUG:nowcasting_dataset:Opening satellite data: gs://solar-pv-nowcasting-data/satellite/EUMETSAT/SEVIRI_RSS/OSGB36/all_zarr_int16_single_timestep.zarr
DEBUG:nowcasting_dataset:Opening satellite data: gs://solar-pv-nowcasting-data/satellite/EUMETSAT/SEVIRI_RSS/OSGB36/all_zarr_int16_single_timestep.zarr
DEBUG:nowcasting_dataset:Opening satellite data: gs://solar-pv-nowcasting-data/satellite/EUMETSAT/SEVIRI_RSS/OSGB36/all_zarr_int16_single_timestep.zarr
DEBUG:nowcasting_dataset:Opening satellite data: gs://solar-pv-nowcasting-data/satellite/EUMETSAT/SEVIRI_RSS/OSGB36/all_zarr_int16_single_timestep.zarr
DEBUG:nowcasting_dataset:Opening satellite data: gs://solar-pv-nowcasting-data/satellite

Got batch 0
Writing!
getting next batch...
Got batch 1
getting next batch...
Got batch 2
getting next batch...
Got batch 3
getting next batch...
Got batch 4
getting next batch...
Got batch 5
getting next batch...
Got batch 6
getting next batch...
Got batch 7
getting next batch...
Got batch 8
getting next batch...
Got batch 9
getting next batch...
Got batch 10
getting next batch...
Got batch 11
getting next batch...
Got batch 12
getting next batch...
Got batch 13
getting next batch...
Got batch 14
getting next batch...
Got batch 15
getting next batch...
Got batch 16
getting next batch...
Got batch 17
getting next batch...
Got batch 18
getting next batch...
Got batch 19
getting next batch...
Got batch 20
getting next batch...
Got batch 21
getting next batch...
Got batch 22
getting next batch...
Got batch 23
getting next batch...
Got batch 24
getting next batch...
Got batch 25
getting next batch...
Got batch 26
getting next batch...
Got batch 27
getting next batch...
Got batch 28
getting 