In [1]:
import torch
import gcsfs
import xarray as xr
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

In [2]:
# Notebook-wide matplotlib defaults: 10x10 figures, and no interpolation when
# images are displayed.
plt.rcParams.update({
    'figure.figsize': (10, 10),
    'image.interpolation': 'none',
})

In [3]:
# Ask PyTorch whether CUDA is available; the cell displays the boolean result.
torch.cuda.is_available()

True

# Load satellite data

In [7]:
# Path of the Zarr store that get_sat_data() opens by default.
# The commented-out line is a different store path, apparently kept for reference.
#ZARR = 'solar-pv-nowcasting-data/satellite/EUMETSAT/SEVIRI_RSS/full_extent_TM_int16'
ZARR = 'solar-pv-nowcasting-data/satellite/EUMETSAT/SEVIRI_RSS/OSGB36/all_zarr_int16'

In [98]:
def get_sat_data(filename=ZARR):
    """Open the satellite Zarr store on Google Cloud Storage and select the HRV variable.

    Args:
        filename: Path of the Zarr store, used as the root of the GCS mapping.
            Defaults to the module-level ``ZARR`` path.

    Returns:
        The ``'stacked_eumetsat_data'`` array of the opened dataset, restricted
        to ``variable='HRV'``.
    """
    filesystem = gcsfs.GCSFileSystem()
    zarr_store = gcsfs.GCSMap(root=filename, gcs=filesystem)
    # The store is opened using its consolidated metadata.
    stacked = xr.open_zarr(zarr_store, consolidated=True)['stacked_eumetsat_data']
    # The time coordinate is used as stored; converting it with
    # pd.to_datetime(..., unit='s') is intentionally left disabled.
    return stacked.sel(variable='HRV')

In [99]:
# Open the HRV array in the notebook process, using get_sat_data()'s default store path.
sat_data = get_sat_data()

In [100]:
# Length of the array's first dimension; this value is later passed to
# SatelliteDataset as `length`.
len(sat_data)

163079

# Simple PyTorch Dataset

In [101]:
class SatelliteDataset(torch.utils.data.IterableDataset):
    """Iterable dataset yielding randomly positioned windows of consecutive timesteps.

    Each sample is a slice of ``total_seq_len = history_len + forecast_len``
    consecutive entries along the ``time`` dimension of the array returned by
    ``get_sat_data()``. The array is opened by ``per_worker_init()`` rather
    than by ``__init__``, so each DataLoader worker process opens its own
    handle and draws from its own random number generator.

    Args:
        length: Number of timesteps available along the ``time`` dimension.
        history_len: Number of timesteps in the history part of a window.
        forecast_len: Number of timesteps in the forecast part of a window.
            Only the sum ``history_len + forecast_len`` affects sampling.
        transform: Optional callable applied to each window before it is
            yielded. Its result must expose ``.values``.
        n_samples_per_epoch_total: Number of windows yielded per pass, summed
            over all workers.
        base_seed: Base random seed; worker ``i`` uses ``base_seed + i``.

    Raises:
        ValueError: If ``length`` is shorter than a single window.
    """

    def __init__(self, length, history_len=24, forecast_len=12, transform=None,
                 n_samples_per_epoch_total=1024, base_seed=42):
        self.history_len = history_len
        self.forecast_len = forecast_len
        self.total_seq_len = history_len + forecast_len
        if length < self.total_seq_len:
            raise ValueError(
                f'length ({length}) must be at least history_len + forecast_len '
                f'({self.total_seq_len})')
        # Largest valid start index: a window starting here ends exactly at `length`.
        self.length = length - self.total_seq_len
        self.transform = transform
        self.n_samples_per_epoch_total = n_samples_per_epoch_total
        self.base_seed = base_seed  # random number generator seed
        # Populated by per_worker_init(); None until then.
        self.data_array = None
        self.n_samples_per_epoch_per_worker = None
        self.rng = None

    def per_worker_init(self, worker_id=0, n_workers=1):
        """Open the data and set up the sample quota and RNG for one worker.

        Args:
            worker_id: Index of this worker, in ``range(n_workers)``.
            n_workers: Total number of workers sharing the epoch.

        NOTE(review): if the DataLoader re-creates its workers for every epoch,
        this re-seeds each worker with the same value, so every epoch would
        replay the same windows. Confirm whether that is intended.
        """
        self.data_array = get_sat_data()
        # Spread the epoch across workers; the first `extra` workers take one
        # more sample so the per-worker quotas add up to the requested total.
        base, extra = divmod(self.n_samples_per_epoch_total, n_workers)
        self.n_samples_per_epoch_per_worker = base + (1 if worker_id < extra else 0)
        # Each worker must have a different initial random number generator seed.
        seed = self.base_seed + worker_id
        self.rng = np.random.default_rng(seed=seed)

    def __iter__(self):
        """Yield this worker's quota of random windows as ``.values`` arrays."""
        if self.rng is None:
            # No worker_init_fn ran (e.g. iterating in the main process), so
            # initialise as the single worker.
            self.per_worker_init()
        for _ in range(self.n_samples_per_epoch_per_worker):
            # endpoint=True makes `high` inclusive, so the final window is reachable.
            start_idx = int(self.rng.integers(low=0, high=self.length, endpoint=True))
            end_idx = start_idx + self.total_seq_len
            sample = self.data_array.isel(time=slice(start_idx, end_idx))
            if self.transform is not None:
                sample = self.transform(sample)
            yield sample.values


class CropSquare:
    """Callable transform that crops the second and third axes of a sample.

    The first axis is kept whole; along each of the next two axes only the
    first ``size`` entries are kept.
    """

    def __init__(self, size=128):
        # Number of leading entries kept along each cropped axis.
        self.size = size

    def __call__(self, sample):
        """Return ``sample`` indexed with ``[:, :size, :size]``."""
        keep = slice(None, self.size)
        return sample[(slice(None), keep, keep)]
    
    
def worker_init_fn(worker_id):
    """Run the dataset's per-worker setup inside a DataLoader worker process.

    Looks up the current worker's info and, when it is available, calls
    ``per_worker_init`` on that worker's dataset copy with the worker's id and
    the total number of workers. When no worker info is available, a message
    is printed and nothing else happens.

    Args:
        worker_id: Worker index passed in by the DataLoader. The id reported
            by the worker info is the one actually forwarded to the dataset.
    """
    info = torch.utils.data.get_worker_info()
    if info is not None:
        # info.dataset is the Dataset copy living in this worker process.
        info.dataset.per_worker_init(worker_id=info.id, n_workers=info.num_workers)
        return
    print('worker_info is None!')


# Windows are drawn from the full length of sat_data and cropped by CropSquare
# (default size) before being yielded.
dataset = SatelliteDataset(length=len(sat_data), transform=CropSquare())

In [102]:
# worker_init_fn runs in each worker process and calls the dataset's
# per_worker_init(), which is where the data is opened and the RNG is seeded.
dataloader = torch.utils.data.DataLoader(
    dataset=dataset,
    worker_init_fn=worker_init_fn,
    pin_memory=False,
    # Timings noted for different worker counts:
    # 4=13.8s; 8=11.6; 10=11.3s; 11=11.5s; 12=12.6s
    num_workers=10,
    batch_size=8,
)

In [103]:
%%time
# Time one full pass over the DataLoader, printing each batch's index and shape.
# `batch` stays defined after the loop and is inspected by a later cell.
for i, batch in enumerate(dataloader):
    print(i, batch.shape)
    # Uncomment the two lines below to stop after the first 26 batches:
    #if i > 24:
    #    break

0 torch.Size([8, 36, 128, 128])
1 torch.Size([8, 36, 128, 128])
2 torch.Size([8, 36, 128, 128])
3 torch.Size([8, 36, 128, 128])
4 torch.Size([8, 36, 128, 128])
5 torch.Size([8, 36, 128, 128])
6 torch.Size([8, 36, 128, 128])
7 torch.Size([8, 36, 128, 128])
8 torch.Size([8, 36, 128, 128])
9 torch.Size([8, 36, 128, 128])
10 torch.Size([8, 36, 128, 128])
11 torch.Size([8, 36, 128, 128])
12 torch.Size([8, 36, 128, 128])
13 torch.Size([8, 36, 128, 128])
14 torch.Size([8, 36, 128, 128])
15 torch.Size([8, 36, 128, 128])
16 torch.Size([8, 36, 128, 128])
17 torch.Size([8, 36, 128, 128])
18 torch.Size([8, 36, 128, 128])
19 torch.Size([8, 36, 128, 128])
20 torch.Size([8, 36, 128, 128])
21 torch.Size([8, 36, 128, 128])
22 torch.Size([8, 36, 128, 128])
23 torch.Size([8, 36, 128, 128])
24 torch.Size([8, 36, 128, 128])
25 torch.Size([8, 36, 128, 128])
26 torch.Size([8, 36, 128, 128])
27 torch.Size([8, 36, 128, 128])
28 torch.Size([8, 36, 128, 128])
29 torch.Size([8, 36, 128, 128])
30 torch.Size([8, 36

In [82]:
# Check whether the most recent batch tensor is in pinned memory.
# NOTE(review): the recorded output of this cell comes from an earlier execution
# (In [82]) than the DataLoader cells above it, which set pin_memory=False.
# Re-run after a kernel restart to confirm the current result.
batch.is_pinned()

True