In [1]:
from pathlib import Path
import pandas as pd
import numpy as np
import xarray as xr
import gcsfs
from typing import List
import io
import hashlib
import os
import matplotlib.pyplot as plt

import torch

from nowcasting_dataset.example import Example, to_numpy
import nowcasting_dataset.time as nd_time
from nowcasting_dataset.dataset import worker_init_fn

import logging
# Attach a default stderr handler to the root logger, then let DEBUG-level
# records from the 'nowcasting_dataset' logger (and its children) through.
logging.basicConfig()
logger = logging.getLogger('nowcasting_dataset')
logger.setLevel(logging.DEBUG)

In [2]:
def get_netcdf_filename(batch_idx: int) -> str:
    """Generate the filename (basename only, no directory) for a batch.

    The filename is prefixed with the first 6 hex digits of the MD5 hash of
    ``f'{batch_idx}.nc'``, as recommended by Google Cloud in order to
    distribute data across multiple back-end servers.

    Args:
        batch_idx: Index of the batch.

    Returns:
        A string of the form ``'<6 hex digits>_<batch_idx>.nc'``.
    """
    filename = f'{batch_idx}.nc'
    hash_of_filename = hashlib.md5(filename.encode()).hexdigest()
    return f'{hash_of_filename[:6]}_{filename}'


# Test
assert get_netcdf_filename(10) == '77eb6f_10.nc'

In [36]:
class NetCDFDataset(torch.utils.data.Dataset):
    """Map-style dataset where each item is one pre-built batch stored as a NetCDF file.

    Each ``__getitem__`` call downloads ``<src_path>/<get_netcdf_filename(batch_idx)>``
    into ``tmp_path``, loads it with ``xr.load_dataset``, deletes the local copy, and
    returns the selected variables as an ``Example`` converted with ``to_numpy``.

    Args:
        n_batches: Number of batch files available; defines ``len(self)``.
        src_path: Remote directory (e.g. a ``gs://`` URL) holding the NetCDF files.
        tmp_path: Local directory used as scratch space for downloads.
        add_datetime_features: If True, also add per-example datetime features from
            ``nd_time.datetime_features_in_example``. Defaults to False, matching the
            previous behaviour, in which that code path was unreachable.
    """

    def __init__(
            self, n_batches: int, src_path: str, tmp_path: str,
            add_datetime_features: bool = False):
        self.n_batches = n_batches
        self.src_path = src_path
        self.tmp_path = tmp_path
        self.add_datetime_features = add_datetime_features
        # Created per worker in per_worker_init (or lazily in __getitem__).
        # NOTE(review): presumably because a gcsfs handle should not be shared
        # across forked DataLoader worker processes -- confirm.
        self.gcs = None

    def per_worker_init(self, worker_id: int = 0):
        """Create this process's GCS filesystem handle.

        Args:
            worker_id: DataLoader worker ID. Unused; accepted so this can be called
                from a ``worker_init_fn``.
        """
        self.gcs = gcsfs.GCSFileSystem()

    def __len__(self) -> int:
        return self.n_batches

    def __getitem__(self, batch_idx: int) -> Example:
        """Return a whole batch at once.

        Args:
            batch_idx: Index of the batch file to load.

        Returns:
            An ``Example`` holding the satellite, NWP and PV variables of the batch,
            plus datetime features when ``add_datetime_features`` is True.
        """
        if self.gcs is None:
            # E.g. num_workers=0: worker_init_fn only runs in worker subprocesses,
            # so per_worker_init may never have been called.
            self.per_worker_init()

        netcdf_filename = get_netcdf_filename(batch_idx)
        remote_netcdf_filename = os.path.join(self.src_path, netcdf_filename)
        local_netcdf_filename = os.path.join(self.tmp_path, netcdf_filename)
        try:
            self.gcs.get(remote_netcdf_filename, local_netcdf_filename)
            # load_dataset reads the file fully into memory and closes it, so the
            # local copy can be deleted straight afterwards.
            batch = xr.load_dataset(local_netcdf_filename)
        finally:
            # Always remove the scratch file, even if the download or load failed,
            # so a failing run does not fill up tmp_path.
            if os.path.exists(local_netcdf_filename):
                os.remove(local_netcdf_filename)

        example = Example(
            sat_datetime_index=batch.sat_time_coords,
            nwp_target_time=batch.nwp_time_coords)
        for key in [
                'nwp', 'nwp_x_coords', 'nwp_y_coords',
                'sat_data', 'sat_x_coords', 'sat_y_coords',
                'pv_yield', 'pv_system_id', 'pv_system_row_number']:
            example[key] = batch[key]
        example = to_numpy(example)

        if self.add_datetime_features:
            example.update(self._datetime_features(batch))
        return example

    @staticmethod
    def _datetime_features(batch):
        """Compute datetime features for each example's satellite times and collate them."""
        dt_features = []
        # Iterates over the first axis of sat_time_coords.
        # NOTE(review): presumably one row of timestamps per example -- confirm.
        for dt_index_for_example in batch.sat_time_coords.values:
            dt_index_for_example = pd.DatetimeIndex(dt_index_for_example)
            dt_features_for_example = nd_time.datetime_features_in_example(dt_index_for_example)
            dt_features_for_example = to_numpy(dt_features_for_example)
            dt_features.append(dt_features_for_example)
        return torch.utils.data._utils.collate.default_collate(dt_features)

In [49]:
%%time
ds = NetCDFDataset(30_465, 'gs://solar-pv-nowcasting-data/prepared_ML_training_data/netcdf4/', '/home/jack/temp2/')
#ds.per_worker_init()
#batch = ds[1]

CPU times: user 12 µs, sys: 0 ns, total: 12 µs
Wall time: 15 µs


In [50]:
dl = torch.utils.data.DataLoader(
    ds,
    pin_memory=True,
    num_workers=30,
    prefetch_factor=1,
    # NOTE(review): presumably calls ds.per_worker_init in each worker so every
    # worker gets its own GCS filesystem handle -- confirm.
    worker_init_fn=worker_init_fn,

    # Disable automatic batching, because NetCDFDataset.__getitem__ already
    # returns a complete batch. batch_size=None (with the default
    # batch_sampler=None) turns batching off. batch_size=1 would instead collate
    # a list of one item, adding an extra leading dimension of size 1 to every
    # array.
    batch_size=None,
)

In [51]:
import time

In [54]:
%%time
durations = []
t0 = None
for i, batch in enumerate(dl):
    t1 = time.time()
    if t0:
        durations.append(t1 - t0)
    t0 = t1
    if i > 2000:
        break

CPU times: user 5.81 s, sys: 14.9 s, total: 20.7 s
Wall time: 25.8 s


In [55]:
np.asarray(durations).mean()  # mean seconds between consecutive batches

0.012244936467885137

In [56]:
# Filesystem handle in the notebook's own process, used for the bucket size queries below.
gcs = gcsfs.GCSFileSystem()

In [57]:
%%time
n_bytes = gcs.du('gs://solar-pv-nowcasting-data/prepared_ML_training_data/netcdf4/')

CPU times: user 46.9 s, sys: 53 ms, total: 47 s
Wall time: 49.8 s


In [59]:
# Convert to decimal gigabytes (1 GB = 1E9 bytes).
n_bytes / 1E9

468.720893878

In [None]:
%%time
gcs.du('gs://solar-pv-nowcasting-data/prepared_ML_training_data/testing.zarr') / 1E9