In [1]:
import torch
import gcsfs
import xarray as xr
import pandas as pd
import numpy as np
from typing import Optional, Callable
import matplotlib.pyplot as plt

## Consts & config

In [2]:
ZARR = 'solar-pv-nowcasting-data/satellite/EUMETSAT/SEVIRI_RSS/OSGB36/all_zarr_int16'

In [3]:
plt.rcParams['figure.figsize'] = (10, 10)
plt.rcParams['image.interpolation'] = 'none'

In [4]:
torch.cuda.is_available()

True

# Load satellite data

In [5]:
def get_sat_data(filename: str=ZARR) -> xr.DataArray:
    """Lazily opens the Zarr store on Google Cloud Storage (GCS).
    
    Selects the High Resolution Visible (HRV) satellite channel.
    """
    gcs = gcsfs.GCSFileSystem()
    store = gcsfs.GCSMap(root=filename, gcs=gcs)
    dataset = xr.open_zarr(store, consolidated=True)
    return dataset['stacked_eumetsat_data'].sel(variable='HRV')

In [6]:
%%time
sat_data = get_sat_data()

CPU times: user 1.62 s, sys: 72.1 ms, total: 1.69 s
Wall time: 1.93 s


Caution: Wierdly, plotting `sat_data` at this point causes the code to hang (with no errors messages) when it gets to `enumerate(dataloader)`.  The code hangs even if we first do `sat_data.close(); del sat_data`

In [7]:
sat_data

Unnamed: 0,Array,Chunk
Bytes,125.83 GB,27.78 MB
Shape,"(163079, 704, 548)","(36, 704, 548)"
Count,58891 Tasks,4530 Chunks
Type,int16,numpy.ndarray
"Array Chunk Bytes 125.83 GB 27.78 MB Shape (163079, 704, 548) (36, 704, 548) Count 58891 Tasks 4530 Chunks Type int16 numpy.ndarray",548  704  163079,

Unnamed: 0,Array,Chunk
Bytes,125.83 GB,27.78 MB
Shape,"(163079, 704, 548)","(36, 704, 548)"
Count,58891 Tasks,4530 Chunks
Type,int16,numpy.ndarray


## Simple PyTorch Dataset

We have several TB of satellite data. To keep the GPU fed with data during training, we need to read chunks of data quickly from the Zarr store; and we also want to load data asynchronously.  That is, while the GPU is training on the current batch, the data loader should be simultaneously loading the next batch from disk.

PyTorch makes this easy!  PyTorch's `DataLoader` spawns multiple worker processes when constructed with `num_workers` set to more than 1.  Each worker process receives a copy of the `SatelliteDataset` object.

There is a small challenge: The code hangs when it gets to `enumerate(dataloader)` if we open the `xarray.DataArray` in the main process and copy that opened `DataArray` to the child processes.  Our solution is to delay the creation of the `DataArray` until _after_ the worker processes have been created.  PyTorch makes this easy by allowing us to pass a `worker_init_fn` to `DataLoader`.  `worker_init_fn` is called on each worker process.  Our `worker_init_fn` just has one job: to call `SatelliteDataset.per_worker_init()` which, in turn, opens the `DataArray`.

This approach achieves read speeds of 600 MB/s from Google Cloud Storage to a single GCP VM with 12 vCPUs (as measured by `nethogs`).

We use `IterableDataset` instead of `Dataset` so `SatelliteDataset` can pre-load the next example from disk and then block (on the `yield`) waiting for PyTorch to read that data.  This allows the worker process to load the next batch from disk while the main process is training the current batch on the GPU.

We can't pin the memory in each worker process because pinned memory can't be shared across processes.  Instead we ask `DataLoader` to pin the collated batch so that pytorch-lightning can asynchronously load the next batch from pinned CPU memory into GPU memory.

In [8]:
class SatelliteDataset(torch.utils.data.IterableDataset):
    def __init__(self, length: int, history_len: int=24, forecast_len: int=12, transform: Optional[Callable]=None):
        self.history_len = history_len
        self.forecast_len = forecast_len
        self.total_seq_len = history_len + forecast_len
        self.length = length - self.total_seq_len
        self.transform = transform
        self.n_samples_per_epoch_total = 256
        self.base_seed = 42  # random number generator seed
        
    def per_worker_init(self, worker_id: int=0, n_workers: int=1) -> None:
        """Called by worker_init_fn on each copy of SatelliteDataset within each worker process."""
        self.data_array = get_sat_data()
        self.n_samples_per_epoch_per_worker = self.n_samples_per_epoch_total // n_workers
        # Each worker must have a different seed for its random number generator.
        # Otherwise all the workers will output exactly the same data!
        seed = self.base_seed + worker_id
        self.rng = np.random.default_rng(seed=seed)
    
    def __iter__(self):
        for _ in range(self.n_samples_per_epoch_per_worker):
            start_idx = self.rng.integers(low=0, high=self.length, dtype=np.uint32)
            end_idx = start_idx + self.total_seq_len
            sample = self.data_array.isel(time=slice(start_idx, end_idx))
            if self.transform:
                sample = self.transform(sample)
            yield sample.values


class CropSquare():
    def __init__(self, size=128):
        self.size = size
        
    def __call__(self, sample):
        return sample[:, :self.size, :self.size]
    

def worker_init_fn(worker_id):
    """Configures each dataset worker process.
    
    Just has one job!  To call SatelliteDataset.per_worker_init().
    """
    worker_info = torch.utils.data.get_worker_info()
    if worker_info is None:
        print('worker_info is None!')
    else:
        dataset_obj = worker_info.dataset  # The Dataset copy in this worker process.
        dataset_obj.per_worker_init(worker_id=worker_info.id, n_workers=worker_info.num_workers)


dataset = SatelliteDataset(
    length=len(sat_data),
    transform=CropSquare(),
)

In [9]:
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=8,
    num_workers=10,  # timings:  4=13.8s; 8=11.6; 10=11.3s; 11=11.5s; 12=12.6s
    worker_init_fn=worker_init_fn,
    pin_memory=True
)

In [10]:
%%time
for i, batch in enumerate(dataloader):
    print(i, batch.shape)

0 torch.Size([8, 36, 128, 128])
1 torch.Size([8, 36, 128, 128])
2 torch.Size([8, 36, 128, 128])
3 torch.Size([8, 36, 128, 128])
4 torch.Size([8, 36, 128, 128])
5 torch.Size([8, 36, 128, 128])
6 torch.Size([8, 36, 128, 128])
7 torch.Size([8, 36, 128, 128])
8 torch.Size([8, 36, 128, 128])
9 torch.Size([8, 36, 128, 128])
10 torch.Size([8, 36, 128, 128])
11 torch.Size([8, 36, 128, 128])
12 torch.Size([8, 36, 128, 128])
13 torch.Size([8, 36, 128, 128])
14 torch.Size([8, 36, 128, 128])
15 torch.Size([8, 36, 128, 128])
16 torch.Size([8, 36, 128, 128])
17 torch.Size([8, 36, 128, 128])
18 torch.Size([8, 36, 128, 128])
19 torch.Size([8, 36, 128, 128])
20 torch.Size([8, 36, 128, 128])
21 torch.Size([8, 36, 128, 128])
22 torch.Size([8, 36, 128, 128])
23 torch.Size([8, 36, 128, 128])
24 torch.Size([8, 36, 128, 128])
25 torch.Size([8, 36, 128, 128])
26 torch.Size([8, 36, 128, 128])
27 torch.Size([8, 36, 128, 128])
28 torch.Size([8, 36, 128, 128])
29 torch.Size([8, 36, 128, 128])
30 torch.Size([1, 36

In [11]:
batch.is_pinned()

True