# Introduction

Our ultimate aim is to predict solar electricity power generation over the next few hours.

## Loading terabytes of data efficiently from cloud storage

We have several TB of satellite data. To keep the GPU fed with data during training, we need to read chunks of data quickly from the Zarr store; and we also want to load data asynchronously.  That is, while the GPU is training on the current batch, the data loader should simultaneously load the _next_ batch from disk.

PyTorch makes this easy!  PyTorch's `DataLoader` spawns worker processes when constructed with `num_workers` set to 1 or more (with the default, `num_workers=0`, data is loaded in the main process).  Each worker process receives a copy of the `SatelliteDataset` object.

There is a small challenge: The code hangs when it gets to `enumerate(dataloader)` if we open the `xarray.DataArray` in the main process and copy that opened `DataArray` to the child processes.  Our solution is to delay the creation of the `DataArray` until _after_ the worker processes have been created.  PyTorch makes this easy by allowing us to pass a `worker_init_fn` to `DataLoader`.  `worker_init_fn` is called on each worker process.  Our `worker_init_fn` just has one job: to call `SatelliteDataset.per_worker_init()` which, in turn, opens the `DataArray`.

This approach achieves read speeds of 600 MB/s from Google Cloud Storage to a single GCP VM with 12 vCPUs (as measured by `nethogs`).

We use `IterableDataset` instead of `Dataset` so `SatelliteDataset` can pre-load the next example from disk and then block (on the `yield`) waiting for PyTorch to read that data.  This allows the worker processes to be processing the next data samples while the main process is training the current batch on the GPU.

We can't pin the memory in each worker process because pinned memory can't be shared across processes.  Instead we ask `DataLoader` to pin the collated batch so that PyTorch Lightning can asynchronously copy the next batch from pinned CPU memory into GPU memory.

The satellite data is stored on disk as `int16`.  To speed up the movement of the satellite data between processes and from the CPU memory to the GPU memory, we keep the data as `int16` until the data gets to the GPU, where it is converted to `float32` and normalised.
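A minimal NumPy sketch of why keeping the data as `int16` helps: the same batch takes half the bytes of its `float32` equivalent, so there is half as much data to move between processes and from CPU memory to GPU memory. In the real pipeline the conversion and normalisation would happen on the GPU (on a `torch` tensor); NumPy is used here only to keep the sketch self-contained, and `SAT_MEAN` / `SAT_STD` are made-up placeholder values, not statistics of the real dataset.

```python
import numpy as np

# Placeholder normalisation constants (hypothetical; real values would be
# computed from the training data).
SAT_MEAN = 93.0
SAT_STD = 115.0


def normalise(batch_int16: np.ndarray, mean: float = SAT_MEAN, std: float = SAT_STD) -> np.ndarray:
    """Convert raw int16 pixel values to normalised float32 values."""
    return (batch_int16.astype(np.float32) - mean) / std


# A toy batch: 32 examples x 12 timesteps x 128 x 128 pixels, stored as int16.
rng = np.random.default_rng(seed=0)
batch = rng.integers(0, 1024, size=(32, 12, 128, 128), dtype=np.int16)

normalised = normalise(batch)
print(f'int16 batch:   {batch.nbytes / 1e6:.1f} MB')       # 12.6 MB
print(f'float32 batch: {normalised.nbytes / 1e6:.1f} MB')  # 25.2 MB
```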

### Loading data from disk into memory in chunks

Cloud storage buckets can't seek into files like 'proper' POSIX filesystems can.  So, even if we just want 1 byte from a 1 GB file, we have to load the entire 1 GB file from the bucket.

Zarr is designed with this challenge in mind.  Zarr gets round the inability of cloud storage buckets to seek by chunking the data into lots of small files.  But, still, we have to load entire Zarr chunks at once, even if we only want a part of a chunk.  And, even though we can pull 600 MB/s from a cloud storage bucket, the reads from the storage bucket are still the rate-limiting step.  (GPUs are very fast and have a voracious appetite for data!)
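To make the cost concrete, here is a small sketch (plain Python, assuming every chunk along the time dimension holds the same number of timesteps) of which chunks a read has to touch, and how much more data that means reading than was asked for. The function names are illustrative, not part of Zarr's API.

```python
def zarr_chunks_touched(start: int, stop: int, chunk_len: int) -> range:
    """Indices of the Zarr chunks that must be read to get timesteps [start, stop).

    Assumes every chunk along the time dimension is exactly `chunk_len` timesteps long.
    """
    first_chunk = start // chunk_len
    last_chunk = (stop - 1) // chunk_len
    return range(first_chunk, last_chunk + 1)


def read_amplification(start: int, stop: int, chunk_len: int) -> float:
    """Ratio of timesteps read from the bucket to timesteps actually wanted."""
    n_chunks = len(zarr_chunks_touched(start, stop, chunk_len))
    return n_chunks * chunk_len / (stop - start)


# Wanting 12 timesteps that straddle a chunk boundary (timesteps 30 to 41)
# means reading two whole 36-timestep chunks.
print(list(zarr_chunks_touched(30, 42, chunk_len=36)))  # [0, 1]
print(read_amplification(30, 42, chunk_len=36))         # 6.0
```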

To get the most out of each disk read, our worker processes load several contiguous chunks of Zarr data from disk into memory at once.  We then randomly sample from the in-memory data multiple times, before loading another set of chunks from disk into memory.   This trick increases training speed by about 10x.
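A minimal sketch of the "load once, sample many times" pattern, written as a generator. `load_chunks` is a hypothetical stand-in for the code that reads several contiguous Zarr chunks into memory; `n_samples_per_load` and `seq_len` are illustrative parameter names, not taken from the notebook.

```python
from typing import Callable, Iterator

import numpy as np


def sample_from_disk_loads(
    load_chunks: Callable[[], np.ndarray],
    n_samples_per_load: int,
    seq_len: int,
    rng: np.random.Generator,
) -> Iterator[np.ndarray]:
    """Yield random windows of `seq_len` timesteps, reading from disk only
    once every `n_samples_per_load` samples.

    `load_chunks` (hypothetical) returns an in-memory array whose first axis is time.
    """
    while True:
        in_memory = load_chunks()  # One expensive read from the bucket...
        n_timesteps = in_memory.shape[0]
        for _ in range(n_samples_per_load):  # ...amortised over many samples.
            start = rng.integers(0, n_timesteps - seq_len + 1)
            yield in_memory[start:start + seq_len]


# Usage with a fake loader that records how often it is called.
load_count = []


def fake_load_chunks() -> np.ndarray:
    load_count.append(1)
    return np.arange(3 * 36)  # 3 chunks x 36 timesteps, one value per timestep


samples = sample_from_disk_loads(
    fake_load_chunks, n_samples_per_load=50, seq_len=24, rng=np.random.default_rng(42))
first_100 = [next(samples) for _ in range(100)]
print(len(load_count))  # 2 disk loads for 100 samples
```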

Each Zarr chunk is 36 timesteps long and contains the entire geographical extent.  Timesteps are about 5 minutes apart, so each Zarr chunk spans about 3 hours (36 × 5 minutes), assuming the timesteps are contiguous (more on contiguous chunks later).
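The chunk arithmetic as a quick pandas check, assuming exactly 5 minutes between timesteps and using `n_chunks_per_disk_load = 3` as in the notebook's call to `get_zarr_chunk_sequences` further down:

```python
import pandas as pd

TIMESTEPS_PER_ZARR_CHUNK = 36
TIMESTEP = pd.Timedelta('5 minutes')
N_CHUNKS_PER_DISK_LOAD = 3

# Time covered by one Zarr chunk, and by one disk load of contiguous chunks.
chunk_duration = TIMESTEPS_PER_ZARR_CHUNK * TIMESTEP
disk_load_duration = N_CHUNKS_PER_DISK_LOAD * chunk_duration
print(chunk_duration)      # 0 days 03:00:00
print(disk_load_duration)  # 0 days 09:00:00
```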

### Loading only daylight data

We're interested in forecasting solar power generation, so we don't care about nighttime data :)

In the UK in summer, the sun rises first in the north east, and sets last in the north west (see [video of June 2019](https://www.youtube.com/watch?v=IOp-tj-IJpk&t=0s)).  In summer, the north gets more hours of sunshine per day.

In the UK in winter, the sun rises first in the south east, and sets last in the south west (see [video of Jan 2019](https://www.youtube.com/watch?v=CJ4prUVa2nQ)).  In winter, the south gets more hours of sunshine per day.

|                        | Summer | Winter |
|           ---:         |  :---: |  :---: |
| Sun rises first in     | N.E.   | S.E.   |
| Sun sets last in       | N.W.   | S.W.   |
| Most hours of sunlight | North  | South  |

We always load a pre-defined number of Zarr chunks from disk on every disk load (set by `n_chunks_per_disk_load`).

Before training, we select timesteps which have at least some sunlight.  We do this by computing the clearsky global horizontal irradiance (GHI) for the four corners of the satellite imagery, and for all the timesteps in the dataset.  We only use timesteps where the maximum global horizontal irradiance across all four corners is above some threshold.
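The selection rule in miniature: with made-up clearsky GHI values (W/m²) for the four corners, a timestep is kept when the *maximum* over the corners exceeds the threshold, i.e. when at least one corner has some sunlight. The numbers and corner labels are invented for illustration; in the notebook the GHI values come from `pvlib` inside `select_daylight_timestamps()`.

```python
import pandas as pd

# Made-up clearsky GHI values (W/m^2) for the four corners at four timesteps.
dt_index = pd.date_range('2019-06-01 03:00', periods=4, freq='30min', tz='UTC')
ghi_for_all_locations = pd.DataFrame(
    {'NE': [0.0, 12.0, 40.0, 90.0],
     'NW': [0.0, 2.0, 20.0, 70.0],
     'SE': [0.0, 8.0, 35.0, 85.0],
     'SW': [0.0, 0.0, 15.0, 60.0]},
    index=dt_index)

GHI_THRESHOLD = 10  # Same default as select_daylight_timestamps()

# Keep a timestep if the brightest corner is above the threshold.
mask = ghi_for_all_locations.max(axis='columns') > GHI_THRESHOLD
daylight = dt_index[mask.values]
print(daylight)  # 03:30, 04:00 and 04:30 UTC; 03:00 is dropped
```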

(The 'clearsky [solar irradiance](https://en.wikipedia.org/wiki/Solar_irradiance)' is the amount of sunlight we'd expect on a clear day at a specific time and location.  The SI unit of irradiance is the watt per square meter (W/m²).  The 'global horizontal irradiance' is the total sunlight that would hit a horizontal surface on the surface of the Earth.  The GHI is the sum of the direct irradiance falling on that horizontal surface (sunlight which takes a direct path from the Sun to the Earth's surface) and the diffuse horizontal irradiance (sunlight scattered by the atmosphere).)
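The standard relationship between these components, written as a small Python function (a sketch; the function and argument names are illustrative). Direct sunlight arrives along the line from the Sun, so only its component perpendicular to a horizontal surface counts, which scales the direct normal irradiance (DNI) by the cosine of the solar zenith angle:

```python
import numpy as np


def ghi_from_components(dni: float, dhi: float, solar_zenith_deg: float) -> float:
    """Global horizontal irradiance (W/m^2) from its components.

    dni: direct normal irradiance, measured on a surface facing the Sun (W/m^2).
    dhi: diffuse horizontal irradiance (W/m^2).
    solar_zenith_deg: angle between the Sun and the vertical, in degrees.
    """
    # Direct sunlight on a horizontal surface is reduced by cos(zenith);
    # when the Sun is below the horizon (cos < 0) it contributes nothing.
    direct_horizontal = dni * max(np.cos(np.radians(solar_zenith_deg)), 0.0)
    return direct_horizontal + dhi


print(ghi_from_components(dni=800, dhi=100, solar_zenith_deg=60))  # ~500 W/m^2
```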

### Finding contiguous sequences

Once we have a list of 'lit' timesteps, we then find contiguous sequences (timeseries without any gaps).  And we then compute a list of contiguous Zarr chunks that we'll load at once during training.
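One compact way to find contiguous sequences with pandas is to flag every timestep that follows a gap, take a cumulative sum of those flags to give each run its own label, and then group by that label. This is a sketch of an alternative technique, not the notebook's `get_contiguous_segments()` (defined further down), though it should select the same segments; `contiguous_runs` is a hypothetical name.

```python
import pandas as pd


def contiguous_runs(dt_index: pd.DatetimeIndex, max_gap: pd.Timedelta, min_timesteps: int):
    """Return (start, end) pairs for runs of timestamps with no gap larger than max_gap.

    Runs shorter than min_timesteps are dropped.  `end` is inclusive.
    """
    series = dt_index.to_series()
    is_new_run = series.diff() > max_gap  # True on the first timestamp after each gap
    run_id = is_new_run.cumsum()          # Same integer label for every timestamp in a run
    runs = series.groupby(run_id).agg(['first', 'last', 'count'])
    runs = runs[runs['count'] >= min_timesteps]
    return list(zip(runs['first'], runs['last']))


dt_index = pd.DatetimeIndex([
    '2019-06-01 12:00', '2019-06-01 12:05', '2019-06-01 12:10',
    '2019-06-01 18:00', '2019-06-01 18:05'])
print(contiguous_runs(dt_index, max_gap=pd.Timedelta('5 minutes'), min_timesteps=2))
# Two runs: 12:00 to 12:10, and 18:00 to 18:05
```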

### Loading data during training

During training, each worker process randomly picks multiple contiguous Zarr chunk sequences from the list of contiguous sequences pre-computed before training started.  The worker loads that data into memory and then draws many random samples from that in-memory data before loading more data from disk.

#### Ensuring each batch contains a random sample of the dataset

When PyTorch's `DataLoader` constructs a batch, it reads from just one worker process.  (This is not how I had _assumed_ it would work:  I assumed PyTorch would construct each batch by randomly sampling from all workers.)  This is an issue because, for stochastic gradient descent to work correctly, each batch must contain random samples of the dataset.  So it's not sufficient for each worker to load just one contiguous Zarr chunk (because then each batch would be made up entirely of samples from roughly the same time of day).  So, instead, each worker process loads multiple contiguous Zarr sequences into memory.  This also means that each worker must load quite a lot of data from disk.  To avoid training pausing while a worker process loads more data from disk, the data loading is done asynchronously using a separate thread within each worker process.
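A minimal sketch of both ideas, using `concurrent.futures` (which the notebook imports): a single background thread loads the *next* set of sequences while the generator yields samples from the current set, and each sample is drawn from a randomly chosen sequence so that consecutive samples (and therefore each batch) mix different times of day. `load_sequences`, `n_loads` and the other parameter names are hypothetical; `n_loads` only exists to make the sketch finite.

```python
from concurrent import futures
from typing import Callable, Iterator, List

import numpy as np


def prefetching_sample_generator(
    load_sequences: Callable[[], List[np.ndarray]],
    n_samples_per_load: int,
    seq_len: int,
    rng: np.random.Generator,
    n_loads: int,
) -> Iterator[np.ndarray]:
    """Yield samples from in-memory data while a background thread loads the next data."""
    with futures.ThreadPoolExecutor(max_workers=1) as executor:
        next_load = executor.submit(load_sequences)
        for i in range(n_loads):
            sequences = next_load.result()  # Blocks only if the load hasn't finished yet.
            if i + 1 < n_loads:
                next_load = executor.submit(load_sequences)  # Start the next load straight away.
            for _ in range(n_samples_per_load):
                # Pick a random sequence, then a random window within it.
                seq = sequences[rng.integers(len(sequences))]
                start = rng.integers(0, len(seq) - seq_len + 1)
                yield seq[start:start + seq_len]


# Usage with a fake loader: 4 sequences, each offset by 1000 so we can tell them apart.
load_count = []


def fake_load_sequences() -> List[np.ndarray]:
    load_count.append(1)
    return [np.arange(k * 1000, k * 1000 + 108) for k in range(4)]


all_samples = list(prefetching_sample_generator(
    fake_load_sequences, n_samples_per_load=20, seq_len=24,
    rng=np.random.default_rng(0), n_loads=3))
```

Using a thread rather than another process keeps the loaded arrays in the worker's own memory, so no extra copy between processes is needed.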

## Timestep numbering

* t<sub>0</sub> is 'now': the most recent observation.
* t<sub>1</sub> is the first timestep of the forecast.
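A small sketch of this numbering for one example. The helper name is hypothetical, and it assumes evenly spaced timesteps and that `history_len` counts the timesteps strictly before t<sub>0</sub>:

```python
import pandas as pd


def example_datetimes(t0: pd.Timestamp, history_len: int, forecast_len: int,
                      freq: str = '5min') -> pd.DatetimeIndex:
    """Datetimes for one example: history_len steps before t0, t0 itself,
    then forecast_len forecast steps (t1, t2, ...).  Hypothetical helper."""
    step = pd.Timedelta(freq)
    return pd.date_range(t0 - history_len * step, t0 + forecast_len * step, freq=freq)


index = example_datetimes(pd.Timestamp('2019-06-01 12:00'), history_len=6, forecast_len=12)
t0 = index[6]  # 'now': the most recent observation
t1 = index[7]  # the first timestep of the forecast
```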

In [1]:
# Python core
from typing import Optional, Callable, TypedDict, Union, Iterable, Tuple, NamedTuple, List
from dataclasses import dataclass
import datetime
from itertools import product
from concurrent import futures
from pathlib import Path
import numbers

# Scientific python
import numpy as np
import pandas as pd
import xarray as xr
import numcodecs
import matplotlib.pyplot as plt
import dask

# Cloud compute
import gcsfs

# PyTorch
import torch
from torch import nn
import torch.nn.functional as F
import pytorch_lightning as pl

# PV & geospatial
import pvlib
import pyproj

In [2]:
xr.__version__

'0.18.0'

In [3]:
gcsfs.__version__

'0.7.2'

In [4]:
torch.__version__

'1.8.1'

In [5]:
torch.cuda.is_available()

True

In [6]:
pyproj.__version__

'3.0.1'

## Consts & config

The [Zarr docs](https://zarr.readthedocs.io/en/stable/tutorial.html#configuring-blosc) say we should tell the Blosc compression library not to use threads because we're using multiple processes to read from our Zarr store:

In [7]:
numcodecs.blosc.use_threads = False

plt.rcParams['figure.figsize'] = (25, 10)
plt.rcParams['image.interpolation'] = 'none'

In [8]:
BUCKET = Path('solar-pv-nowcasting-data')

# Satellite data
SAT_DATA_ZARR = BUCKET / 'satellite/EUMETSAT/SEVIRI_RSS/OSGB36/all_zarr_int16'

# Solar PV data
PV_PATH = BUCKET / 'PV/PVOutput.org'
PV_DATA_FILENAME = PV_PATH / 'UK_PV_timeseries_batch.nc'
PV_METADATA_FILENAME = PV_PATH / 'UK_PV_metadata.csv'

# Numerical weather predictions
NWP_ZARR = BUCKET / 'NWP/UK_Met_Office/UKV_zarr'

In [9]:
PV_DATA_FILENAME

PosixPath('solar-pv-nowcasting-data/PV/PVOutput.org/UK_PV_timeseries_batch.nc')

# Load satellite data

In [10]:
def get_sat_data(filename: Union[str, Path]=SAT_DATA_ZARR) -> xr.DataArray:
    """Lazily opens the Zarr store on Google Cloud Storage (GCS).
    
    Selects the High Resolution Visible (HRV) satellite channel.
    """
    #gcs = gcsfs.GCSFileSystem(access='read_only')
    #store = gcsfs.GCSMap(root=filename, gcs=gcs)
    store = 'gs://' + str(filename)
    dataset = xr.open_zarr(store, consolidated=True)
    data_array = dataset['stacked_eumetsat_data'].sel(variable='HRV')
    #gcs.clear_instance_cache()  # See https://github.com/dask/gcsfs/issues/379#issuecomment-826930203
    return data_array

In [11]:
#sat_data.close(); del sat_data

In [None]:
%%time
sat_data = get_sat_data()

Caution: Weirdly, plotting `sat_data` at this point causes the code to hang (with no error messages) when it gets to `for batch in dataloader:`.  The code hangs even if we first do `sat_data.close(); del sat_data`.  See https://github.com/dask/gcsfs/issues/379

In [None]:
# sat_data.variable

Get the datetime boundaries of each Zarr chunk.  Later, we will use these boundaries to ensure we load complete chunks at a time.

In [None]:
# First, get the integer indices of the chunk boundaries.
zarr_chunk_boundaries = np.concatenate(([0], np.cumsum(sat_data.chunks[0])))

# Then convert to datetimes.  The last integer index is one past the end of the
# time array, so use the final timestamp for the last boundary instead.
zarr_chunk_boundaries = pd.DatetimeIndex(
    np.concatenate((sat_data['time'].values[zarr_chunk_boundaries[:-1]], sat_data['time'].values[-1:])), 
    tz='UTC')

assert zarr_chunk_boundaries[-2] != zarr_chunk_boundaries[-1]

zarr_chunk_boundaries

## Select daylight hours

In [None]:
# OSGB is also called "OSGB 1936 / British National Grid -- United Kingdom Ordnance Survey".
# OSGB is used in many UK electricity system maps, and is used by the UK Met Office UKV model.
# OSGB is a Transverse Mercator projection, using 'easting' and 'northing' coordinates
# which are in meters.
OSGB = 27700

# WGS84 is short for "World Geodetic System 1984", used in GPS. Uses latitude and longitude.
WGS84 = 4326

# osgb_to_wgs84.transform(x, y) returns latitude (north-south), longitude (east-west)
osgb_to_wgs84 = pyproj.Transformer.from_crs(crs_from=OSGB, crs_to=WGS84)

# wgs84_to_osgb.transform(lat, lon) returns x, y
wgs84_to_osgb = pyproj.Transformer.from_crs(crs_from=WGS84, crs_to=OSGB)

def select_daylight_timestamps(
    dt_index: pd.DatetimeIndex, 
    locations: Iterable[Tuple[float, float]],
    ghi_threshold: float = 10
    ) -> pd.DatetimeIndex:
    """Returns datetimes for which the clearsky global horizontal irradiance
    (GHI) is above ghi_threshold for at least one location.

    Args:
      dt_index: DatetimeIndex to filter.  Must be UTC.
      locations: List of Tuples of x, y coordinates in OSGB projection.
      ghi_threshold: Global horizontal irradiance threshold, in watts per square meter (W/m^2).
    """
    assert dt_index.tz.zone == 'UTC'
    ghi_for_all_locations = []
    for x, y in locations:
        lat, lon = osgb_to_wgs84.transform(x, y)
        location = pvlib.location.Location(latitude=lat, longitude=lon)
        clearsky = location.get_clearsky(dt_index)
        ghi = clearsky['ghi']
        ghi_for_all_locations.append(ghi)
        
    ghi_for_all_locations = pd.concat(ghi_for_all_locations, axis='columns')
    max_ghi = ghi_for_all_locations.max(axis='columns')
    mask = max_ghi > ghi_threshold
    return dt_index[mask]

In [None]:
lat, lon = osgb_to_wgs84.transform(100_000, 50_000)
lat, lon

In [None]:
x, y = wgs84_to_osgb.transform(lat, lon)
x, y

In [None]:
# Get 'corner' coordinates for a rectangle GEO_BORDER pixels inside the actual boundary of the satellite imagery.
GEO_BORDER: int = 64  #: Number of pixels to inset the corners from the edges of sat_data.
corners = [
    (sat_data.x.values[x], sat_data.y.values[y]) 
    for x, y in product(
        [GEO_BORDER, -GEO_BORDER], 
        [GEO_BORDER, -GEO_BORDER])]

In [None]:
%%time
datetimes = select_daylight_timestamps(
    dt_index=pd.DatetimeIndex(sat_data.time.values, tz='UTC'), 
    locations=corners)

In [None]:
len(datetimes)

In [None]:
#sat_data.close()
#del sat_data

# Load PV power timeseries data

In [None]:
%%time
pv_metadata = pd.read_csv(f'gs://{PV_METADATA_FILENAME}', index_col='system_id')
pv_metadata.dropna(subset=['longitude', 'latitude'], how='any', inplace=True)

In [None]:
pv_metadata

In [None]:
pv_metadata['location_x'], pv_metadata['location_y'] = wgs84_to_osgb.transform(pv_metadata['latitude'], pv_metadata['longitude'])

In [None]:
# Remove PV systems outside of the geospatial extent of the satellite data
GEOSPATIAL_BOUNDARY = {
    'WEST': sat_data.x.values[0],
    'EAST': sat_data.x.values[-1],
    'NORTH': sat_data.y.values[0],
    'SOUTH': sat_data.y.values[-1]}

In [None]:
GEOSPATIAL_BOUNDARY

In [None]:
pv_metadata = pv_metadata[
    (pv_metadata.location_x >= GEOSPATIAL_BOUNDARY['WEST']) &
    (pv_metadata.location_x <= GEOSPATIAL_BOUNDARY['EAST']) &
    (pv_metadata.location_y <= GEOSPATIAL_BOUNDARY['NORTH']) &
    (pv_metadata.location_y >= GEOSPATIAL_BOUNDARY['SOUTH'])]

In [None]:
import io

In [None]:
def load_solar_pv_data(start_datetime, end_datetime) -> pd.DataFrame:
    gcs = gcsfs.GCSFileSystem(access='read_only')

    # It is possible to simplify the code below and do the xr.open_dataset(file) in
    # the first 'with' block, and delete the second 'with' block.  But that takes 6 minutes to load the data, 
    # whereas loading into memory first and then loading from there takes 23 seconds!
    with gcs.open(PV_DATA_FILENAME, mode='rb') as file:
        file_bytes = file.read()

    with io.BytesIO(file_bytes) as file:
        pv_power = xr.open_dataset(file)
        pv_power_df = pv_power.sel(datetime=slice(start_datetime, end_datetime)).to_dataframe()

    # Save memory
    del file_bytes    
    del pv_power
    
    # Tidy up
    gcs.clear_instance_cache()  # See https://github.com/dask/gcsfs/issues/379#issuecomment-826930203
    
    # Process the data a little
    pv_power_df = pv_power_df.dropna(axis='columns', how='all')
    pv_power_df = pv_power_df.clip(lower=0, upper=5E7)
    pv_power_df.columns = [np.int32(col) for col in pv_power_df.columns]    
    return pv_power_df.tz_localize('Europe/London').tz_convert('UTC')

In [None]:
%%time
# Clip to start and end date of satellite data.
# Note that the data in the NetCDF file is in the Europe/London timezone,
# but xarray can only handle timezone-naive timestamps.  So convert to
# Europe/London and then drop the timezone, keeping the local wall-clock time.
datetimes_naive = datetimes.tz_convert('Europe/London').tz_localize(None)
pv_power_df = load_solar_pv_data(
    start_datetime=datetimes_naive[0],
    end_datetime=datetimes_naive[-1])
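To get timezone-naive *local* (Europe/London) times, as the comment in the cell above asks for, the pandas call is `tz_localize(None)`; `tz_convert(None)` instead converts to UTC before dropping the timezone. A minimal demonstration with a summer timestamp (British Summer Time is UTC+1):

```python
import pandas as pd

# 11:00 UTC is 12:00 in Europe/London during British Summer Time.
utc = pd.DatetimeIndex(['2019-06-01 11:00'], tz='UTC')
london = utc.tz_convert('Europe/London')

# tz_convert(None) converts back to UTC, then drops the timezone.
print(london.tz_convert(None)[0])   # 2019-06-01 11:00:00
# tz_localize(None) drops the timezone but keeps the local wall-clock time.
print(london.tz_localize(None)[0])  # 2019-06-01 12:00:00
```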

In [None]:
pv_power_df

In [None]:
pv_power_df[10003].plot()

In [None]:
pv_power_df.values.nbytes / 1E9

In [None]:
# A bit of hand-crafted cleaning
pv_power_df.loc['2018-10-29':'2019-01-03', 30248] = np.nan

In [None]:
# Only pick PV systems for which we have metadata.
def align_pv_system_ids(pv_metadata, pv_power_df):
    pv_system_ids = pv_metadata.index.intersection(pv_power_df.columns)
    pv_system_ids = np.sort(pv_system_ids)

    pv_power_df = pv_power_df[pv_system_ids]
    pv_metadata = pv_metadata.loc[pv_system_ids]
    return pv_metadata, pv_power_df
    
pv_metadata, pv_power_df = align_pv_system_ids(pv_metadata, pv_power_df)

In [None]:
len(pv_metadata)

In [None]:
# Scale to the range [0, 1]
pv_power_min = pv_power_df.min()
pv_power_df -= pv_power_min
pv_power_max = pv_power_df.max()
pv_power_df /= pv_power_max

In [None]:
pv_power_df.min().min()

In [None]:
# Drop systems which produce power overnight
NIGHT_YIELD_THRESHOLD = 0.4
night_hours = [22, 23, 0, 1, 2]
pv_data_at_night = pv_power_df.loc[pv_power_df.index.hour.isin(night_hours)]
pv_above_threshold_at_night = (pv_data_at_night > NIGHT_YIELD_THRESHOLD).any()
bad_systems = pv_power_df.columns[pv_above_threshold_at_night]
print(len(bad_systems), 'bad systems found.')

In [None]:
# TODO: Of these bad systems, 24647, 42656, 42807, 43081, 51247, 59919 might have some salvageable data?

In [None]:
pv_power_df = pv_power_df.drop(columns=bad_systems)

In [None]:
%%time
# Resample to 5-minutely and interpolate up to 15 minutes ahead.
pv_power_df = pv_power_df.resample('5T').interpolate(method='time', limit=3)
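A toy example of what `limit=3` means here. It upsamples with `.asfreq()` and then interpolates, which should match what `resample(...).interpolate(...)` does in this case because the original timestamps already lie on the 5-minute grid; `'5min'` is the same frequency as `'5T'` (newer pandas versions deprecate the `'T'` alias). Note that for a gap longer than 15 minutes, only its first three 5-minute slots are filled and the rest stay `NaN`:

```python
import numpy as np
import pandas as pd

# Toy PV series with a 30-minute gap between two observations.
pv = pd.Series([0.0, 6.0], index=pd.DatetimeIndex(['2019-06-01 12:00', '2019-06-01 12:30']))

# Upsample to 5-minutely (inserting NaNs), then fill at most 3 consecutive NaNs.
pv_5min = pv.resample('5min').asfreq().interpolate(method='time', limit=3)
print(pv_5min)
# 12:05 -> 1.0, 12:10 -> 2.0, 12:15 -> 3.0, 12:20 and 12:25 stay NaN
```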

In [None]:
pv_power_df[pv_power_df.columns[5]]['2019-01'].plot()

In [None]:
# Align again, after removing dud PV systems
pv_metadata, pv_power_df = align_pv_system_ids(pv_metadata, pv_power_df)
len(pv_power_df.columns)

In [None]:
len(pv_power_df)

In [None]:
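# Select datetimes where there is data for at least 1 PV system.
# The satellite timestamps are at 4, 9, 14, ... minutes past the hour, whereas the
# PV timestamps are at 0, 5, 10, ...; so add 1 minute to align them.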
pv_power_df = pv_power_df.reindex(datetimes + pd.Timedelta('1 minute')).dropna(how='all')
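A toy illustration of the 1-minute shift used above (made-up PV values): satellite timestamps at :04, :09, :14 land exactly on the PV timestamps at :05, :10, :15 after the shift, so `reindex` finds a match for each one.

```python
import pandas as pd

# Satellite images at 4, 9, 14 minutes past the hour; PV data at 0, 5, 10, 15.
sat_times = pd.date_range('2019-06-01 12:04', periods=3, freq='5min', tz='UTC')
pv = pd.Series([0.1, 0.2, 0.3, 0.4],
               index=pd.date_range('2019-06-01 12:00', periods=4, freq='5min', tz='UTC'))

# Shift the satellite timestamps forward by 1 minute to line them up with the PV index.
pv_aligned = pv.reindex(sat_times + pd.Timedelta('1 minute'))
print(pv_aligned)  # 12:05 -> 0.2, 12:10 -> 0.3, 12:15 -> 0.4
```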

In [None]:
len(pv_power_df)

In [None]:
len(datetimes)

In [None]:
datetimes = pv_power_df.index - pd.Timedelta('1 minute')

In [None]:
len(datetimes)

In [None]:
sat_data.sel(time=datetimes[20]).plot()

In [None]:
pv_power_df.values.nbytes / 1E6

In [None]:
# Minimum number of PV systems available for one timestep:
(~pv_power_df.isna()).sum(axis='columns').min()

# Load numerical weather predictions

In [None]:
def load_nwp(base_path: Path = NWP_ZARR):
    nwp_datasets = []
    for zarr_store in ['2018_1-6', '2018_7-12', '2019_1-6', '2019_7-12']:
        full_dir = base_path / zarr_store
        full_dir = 'gs://' + str(full_dir)
        print(full_dir)
        dataset = xr.open_zarr(full_dir, consolidated=True)
        dataset = dataset.rename({'time': 'init_time'})
        nwp_datasets.append(dataset)

    # The isobaricInhPa coordinates look messed up, especially in the 2018_7-12 and 2019_7-12 Zarr stores.
    # So let's drop all the variables with multiple vertical levels for now:
    for ds in nwp_datasets:
        del ds['isobaricInhPa'], ds['gh_p'], ds['r_p'], ds['t_p'], ds['wdir_p'], ds['ws_p']

    # Concat.
    dask.config.set({"array.slicing.split_large_chunks": False})  # Silence warning about large chunks
    nwp = xr.concat(nwp_datasets, dim='init_time')
    
    # There are a lot of duplicated indices from 2018-07-18 00:00 to 2018-08-27 09:00.
    # De-duplicate the index. Code adapted from https://stackoverflow.com/a/51077784/732596
    _, unique_index = np.unique(nwp.init_time, return_index=True)
    return nwp.isel(init_time=unique_index)
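How the de-duplication in `load_nwp()` works, on a toy array of init times: `np.unique(..., return_index=True)` returns the position of the *first* occurrence of each unique value, in sorted order, which can be passed straight to `isel()`. Because `np.unique` sorts, this also puts the init times into ascending order.

```python
import numpy as np

# Toy init times with one duplicate.
init_times = np.array(
    ['2018-07-18T00:00', '2018-07-18T03:00', '2018-07-18T03:00', '2018-07-18T06:00'],
    dtype='datetime64[m]')

unique_values, unique_index = np.unique(init_times, return_index=True)
print(unique_index)  # [0 1 3]
```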

In [None]:
%%time
nwp = load_nwp()

In [None]:
# Get the names and units of the NWP variables
nwp_var_details = []
for var_name, var in nwp.variables.items():
    attrs = var.attrs
    # Prefer the CF 'standard_name', then fall back to 'long_name'.
    name = attrs.get('standard_name', attrs.get('long_name', ''))
    units = attrs.get('units', '')
    nwp_var_details.append({'var_name': var_name, 'name': name, 'units': units})

nwp_var_details = pd.DataFrame(nwp_var_details).set_index('var_name')
nwp_var_details

In [None]:
fig, axes = plt.subplots(nrows=3, ncols=2, figsize=(20, 20))

ax = axes[0][0]
nwp['lcc'].sel(init_time='2018-06-01T12:00').isel(step=0).clip(0).plot(ax=ax)

ax = axes[0][1]
nwp['mcc'].sel(init_time='2018-06-01T12:00').isel(step=0).clip(0).plot(ax=ax)

ax = axes[1][0]
nwp['hcc'].sel(init_time='2018-06-01T12:00').isel(step=0).clip(0).plot(ax=ax)

ax = axes[1][1]
sat_data.sel(time='2018-06-01T12:04').plot(ax=ax)

ax = axes[2][0]
nwp['dswrf'].sel(init_time='2018-06-01T12:00').isel(step=0).plot(ax=ax)

ax = axes[2][1]
nwp['dlwrf'].sel(init_time='2018-06-01T12:00').isel(step=0).plot(ax=ax);

## Get contiguous segments

In [None]:
class Segment(NamedTuple):
    """Represents the start and end datetimes of a segment of contiguous samples.
    
    The Segment covers the range [start, end].
    """
    start: pd.Timestamp
    end: pd.Timestamp
        
    def to_naive(self):
        start = self.start.tz_convert('UTC').tz_convert(None)
        end = self.end.tz_convert('UTC').tz_convert(None)
        return Segment(start, end)


def get_contiguous_segments(dt_index: pd.DatetimeIndex, min_timesteps: int, max_gap: pd.Timedelta) -> List[Segment]:
    """Chunk datetime index into contiguous segments, each at least min_timesteps long.
    
    max_gap defines the threshold for what constitutes a 'gap' between contiguous segments.
    
    Throw away any timesteps in a sequence shorter than min_timesteps long.
    """
    gap_mask = np.diff(dt_index) > max_gap
    gap_indices = np.argwhere(gap_mask)[:, 0]

    # gap_indices are the indices into dt_index for the timestep immediately before the gap.
    # e.g. if the datetimes are at 12:00, 12:05, 18:00, 18:05 then gap_indices will be [1].
    segment_boundaries = gap_indices + 1

    # Capture the last segment of dt_index.
    segment_boundaries = np.concatenate((segment_boundaries, [len(dt_index)]))

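    # Guard against IndexError when there are no gaps (only one boundary).
    if len(segment_boundaries) > 1:
        assert segment_boundaries[-2] != segment_boundaries[-1]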
    
    segments = []
    start_i = 0
    for end_i in segment_boundaries:
        n_timesteps = end_i - start_i
        if n_timesteps >= min_timesteps:
            segment = Segment(start=dt_index[start_i], end=dt_index[end_i-1])
            segments.append(segment)
        start_i = end_i
        
    return segments

In [None]:
%%time
contiguous_segments = get_contiguous_segments(
    dt_index=datetimes,
    min_timesteps=int(36 * 1.5),
    max_gap=pd.Timedelta('5 minutes'))

contiguous_segments[:5]

In [None]:
len(contiguous_segments)

## Turn the contiguous segments into sequences of Zarr chunks, which will be loaded together during training

In [None]:
def get_zarr_chunk_sequences(
    n_chunks_per_disk_load: int, 
    zarr_chunk_boundaries: List[datetime.datetime], 
    contiguous_segments: List[Segment]) -> List[Segment]:
    """Split contiguous segments into overlapping sequences of Zarr chunks to be loaded from disk together.
    
    Args:
      n_chunks_per_disk_load: Maximum number of Zarr chunks to load from disk in one go.
      zarr_chunk_boundaries: The datetime indices into the Zarr store's time dimension which define the Zarr chunk boundaries.
        Must be sorted.
      contiguous_segments: Datetime indices into the Zarr store's time dimension that define contiguous timeseries.
        That is, timeseries with no gaps.
    
    Returns zarr_chunk_sequences: a list of Segments representing the start and end datetimes of contiguous sequences of multiple Zarr chunks,
    each spanning (parts of) exactly n_chunks_per_disk_load Zarr chunks (for contiguous segments at least as long as n_chunks_per_disk_load Zarr chunks),
    and at least one end of each sequence will lie on a 'natural' Zarr chunk boundary.
    
    For example, say that n_chunks_per_disk_load = 3, and the Zarr chunks sizes are all 5:
    
    
                  0    5   10   15   20   25   30   35 
                  |....|....|....|....|....|....|....|

    INPUTS:
                     |------CONTIGUOUS SEGMENT----|
                     
    zarr_chunk_boundaries:
                  |----|----|----|----|----|----|----|
    
    OUTPUT:
    zarr_chunk_sequences:
           3 to 15:  |-|----|----|
           5 to 20:    |----|----|----|
          10 to 25:         |----|----|----|
          15 to 30:              |----|----|----|
          20 to 32:                   |----|----|-|
    
    """
    assert n_chunks_per_disk_load > 0
    
    zarr_chunk_sequences = []

    for contig_segment in contiguous_segments:
        # searchsorted() returns the index into zarr_chunk_boundaries at which contig_segment.start
        # should be inserted into zarr_chunk_boundaries to maintain a sorted list.
        # i_of_first_zarr_chunk is the index to the element in zarr_chunk_boundaries which defines
        # the start of the current contig chunk.
        i_of_first_zarr_chunk = np.searchsorted(zarr_chunk_boundaries, contig_segment.start)
        
        # i_of_first_zarr_chunk will be too large by 1 unless contig_segment.start lies
        # exactly on a Zarr chunk boundary.  Hence we must subtract 1, or else we'll
        # end up with the first contig_chunk being 1 + n_chunks_per_disk_load chunks long.
        if zarr_chunk_boundaries[i_of_first_zarr_chunk] > contig_segment.start:
            i_of_first_zarr_chunk -= 1
            
        # Prepare for looping to create multiple Zarr chunk sequences for the current contig_segment.
        zarr_chunk_seq_start = contig_segment.start
        zarr_chunk_seq_end = None  # Just a convenience to allow us to break the while loop by checking if zarr_chunk_seq_end != contig_segment.end.
        while zarr_chunk_seq_end != contig_segment.end:
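            # Clamp the index so we don't run off the end of zarr_chunk_boundaries
            # (its last element is the final timestamp, which is >= contig_segment.end).
            i_of_end_boundary = min(i_of_first_zarr_chunk + n_chunks_per_disk_load, len(zarr_chunk_boundaries) - 1)
            zarr_chunk_seq_end = zarr_chunk_boundaries[i_of_end_boundary]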
            zarr_chunk_seq_end = min(zarr_chunk_seq_end, contig_segment.end)
            zarr_chunk_sequences.append(Segment(start=zarr_chunk_seq_start, end=zarr_chunk_seq_end))
            i_of_first_zarr_chunk += 1
            zarr_chunk_seq_start = zarr_chunk_boundaries[i_of_first_zarr_chunk]
            
    return zarr_chunk_sequences

In [None]:
zarr_chunk_sequences = get_zarr_chunk_sequences(
    n_chunks_per_disk_load=3,
    zarr_chunk_boundaries=zarr_chunk_boundaries,
    contiguous_segments=contiguous_segments)

zarr_chunk_sequences[:10]

In [None]:
# Durations of zarr chunk sequences
[segment.end - segment.start for segment in zarr_chunk_sequences[:10]]

## PyTorch data storage & processing

In [None]:
Array = Union[np.ndarray, xr.DataArray]

IMAGE_ATTR_NAMES = ('historical_sat_images', 'target_sat_images')

class Sample(TypedDict):
    """Simple class for structuring data for the ML model.
    
    Using typing.TypedDict gives us several advantages:
      1. Single 'source of truth' for the type and documentation of each example.
      2. A static type checker can check the types are correct.

    Instead of TypedDict, we could use typing.NamedTuple,
    which would check at runtime that every field is provided (although, like TypedDict,
    it does not check the types).  But the deal-breaker with Tuples is that they're immutable,
    so we cannot change the values in the transforms.
    """
    # IMAGES
    # Shape: batch_size, seq_length, width, height
    historical_sat_images: Array
    target_sat_images: Array
        
    # PV yield time series
    historical_pv_yield: pd.Series
    target_pv_yield: pd.Series
        
    # Numerical weather predictions (NWPs)
    nwp_above_pv: Array  #: The NWP at a single point nearest to the PV system.
    
    # METADATA
    forecast_len: int
    history_len: int
    pv_system_id: int
    pv_system_row_number: int  #: Guaranteed to be in the range [0, len(pv_metadata) - 1]
    pv_location_x: float
    pv_location_y: float
    datetime_index: Array  #: At 5-minute timings like 00, 05, 10, ...; *not* the 04, 09, ... sequence of the satellite imagery.


class BadData(Exception):
    pass


@dataclass
class RandomSquareCrop():
    size: int = 128  #: Size (in pixels) of the cropped image.

    def __call__(self, sample: Sample) -> Sample:
        for attr_name in IMAGE_ATTR_NAMES:
            image = sample[attr_name]
            # TODO: Random crop!
            cropped_image = image[..., :self.size, :self.size]
            sample[attr_name] = cropped_image
        return sample


def crop_square(data_array: xr.DataArray, centre_x_osgb: float, centre_y_osgb: float, size_pixels: int):
    half_size_pixels = size_pixels // 2

    # centre_y_osgb and centre_x_osgb are in OSGB-space; but size_pixels is number of pixels!
    # Need to convert to integer index into image pixels.
    # The y array is in _descending_ order.
    centre_x_index = np.searchsorted(data_array.x, centre_x_osgb)
    centre_y_index = np.searchsorted(data_array.y[::-1], centre_y_osgb)
    centre_y_index = len(data_array.y) - centre_y_index
    
    # Get coordinates for boundaries of the cropped image.
    north = centre_y_index - half_size_pixels
    south = centre_y_index + half_size_pixels
    east = centre_x_index + half_size_pixels
    west = centre_x_index - half_size_pixels

    cropped = data_array.isel(
        x=slice(west, east), 
        y=slice(north, south))
    
    assert len(cropped.x) == size_pixels, len(cropped.x)
    assert len(cropped.y) == size_pixels, len(cropped.y)
    
    return cropped
    

@dataclass
class CropCentredOnPv():
    size_pixels: int = 128  #: Size (in pixels) of the cropped square image.  Must be an even number.
        
    def __post_init__(self):
        assert self.size_pixels % 2 == 0
    
    def __call__(self, sample: Sample) -> Sample:
        x = sample['pv_location_x']
        y = sample['pv_location_y']
        
        for attr_name in IMAGE_ATTR_NAMES:
            image = sample[attr_name]
            cropped_image = crop_square(image, x, y, self.size_pixels)
            sample[attr_name] = cropped_image
        return sample


class CheckForBadData():
    def __call__(self, sample: Sample) -> Sample:
        for attr_name in IMAGE_ATTR_NAMES:
            image = sample[attr_name]
            if np.any(image < 0):
                raise BadData(f'\n{attr_name} has negative values at {image.time.values}')
        return sample

        
class ToTensor():
    def __call__(self, sample: Sample) -> Sample:
        for key, value in sample.items():
            original_type = type(value)  # Helpful for debugging.
            if isinstance(value, (xr.DataArray, pd.Series, pd.DataFrame)):
                value = value.values
            elif isinstance(value, pd.DatetimeIndex):
                value = value.values.astype('datetime64[s]').astype(np.int64)
            elif isinstance(value, numbers.Number):
                value = np.asanyarray(value)
            
            try:
                sample[key] = torch.from_numpy(value)
            except:
                print(f'Failed to convert {key}, with value of type {original_type} = {value}')
                raise
        return sample
    
    
class Compose():
    # Copied from https://pytorch.org/vision/stable/_modules/torchvision/transforms/transforms.html#Compose
    # But not using torchvision, because it appears to create conda package conflicts
    # with opencv?  But need to explore more!
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img):
        for t in self.transforms:
            img = t(img)
        return img

## PyTorch dataset

In [None]:
def get_nwp_example(
    nwp: xr.Dataset, 
    start: pd.Timestamp, 
    end: pd.Timestamp,
    t0: pd.Timestamp,
    freq: str='5T',
    interpolation_method: str='cubic'
    ) -> xr.Dataset:
    """Select the numerical weather predictions (NWP) for a single ML example.
    
    The NWP for each example covers a contiguous timespan running from `start` to `end`.
    The first part of the timeseries [`start`, `t0`] is the 'recent history'.
    The second part of the timeseries (`t0`, `end`] is the 'future'.
    
    For each timestep in the recent history [`start`, `t0`], get predictions 
    produced by the freshest NWP run to each timestep.
    
    For the future (`t0`, `end`], use the NWP initialised most recently to t0.
    
    This function will also resample the NWP using `freq` and `interpolation_method`.
    
    Args:
      nwp: Numerical weather prediction Dataset.  Assumed to be hourly.
      start, end: The start and end datetimes of the entire example.
      t0: The datetime that represents 'now'.  That is, any datetimes <= 'now'
        are considered recent history; and datetimes > 'now' are considered forecast.
      freq: The frequency to resample to (e.g. '5T' for 5-minutely).
      interpolation_method: The interpolation method to use when resampling.
    """
    # First we get the hourly NWPs; then we resample to `freq` at the end of the function.
    
    # Extend the start and end of the timespan by 1 hour, so that the
    # cubic interpolation has more 'context' to work with. 
    BUFFER = pd.Timedelta('1H')
    start_hourly = start.floor('H') - BUFFER
    t0_hourly = t0.ceil('H')
    end_hourly = end.ceil('H') + BUFFER
    
    target_times_hourly = pd.date_range(start_hourly, end_hourly, freq='H')
    
    # Get the most recent NWP initialisation time for each target_time_hourly.
    init_times = nwp.sel(init_time=target_times_hourly, method='ffill').init_time.values
    
    # Find the NWP init time in use at t0 (rounded up to the hour).
    init_time_future = init_times[target_times_hourly == t0_hourly]
    
    # For the 'future' portion of the example, replace all the NWP init times with
    # the init time in use at t0, so the whole forecast comes from a single NWP run.
    init_times[target_times_hourly > t0_hourly] = init_time_future
    
    # steps is the forecast lead time (a timedelta) of each target time beyond its NWP init time.
    steps = target_times_hourly - init_times

    def _get_data_array_indexer(index):
        # We want one timestep for each target_time_hourly (obviously!)
        # If we simply do nwp.sel(init_time=init_times, step=steps)
        # then we'll get the *product* of init_times and steps,
        # which is not what we want!  Instead, we use xarray's vectorized-indexing mode
        # by using a DataArray indexer.  See the last example here:
        # http://xarray.pydata.org/en/stable/user-guide/indexing.html#vectorized-indexing
        return xr.DataArray(index, dims='target_time', coords={'target_time': target_times_hourly})
    
    init_time_indexer = _get_data_array_indexer(init_times)
    step_indexer = _get_data_array_indexer(steps)
    nwp_selected = nwp.sel(init_time=init_time_indexer, step=step_indexer)
    nwp_selected = nwp_selected.resample({'target_time': freq}).interpolate(interpolation_method)
    return nwp_selected.sel(target_time=slice(start, end))
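The init-time bookkeeping in `get_nwp_example` can be illustrated without xarray. The sketch below is a pandas-only analogue, not the notebook's code: `select_init_times` is a hypothetical helper, it assumes `t0` is already one of the hourly `target_times` (the notebook rounds `t0` up to the hour first), and it uses `DatetimeIndex.searchsorted` to mimic `nwp.sel(init_time=..., method='ffill')`, i.e. picking the latest init time at or before each target time.

```python
import pandas as pd


def select_init_times(available_init_times, target_times, t0):
    """Hypothetical helper mirroring the init-time logic of get_nwp_example.

    Assumes `t0` is one of `target_times`.
    Returns (init_times, steps), one entry per target time.
    """
    available = pd.DatetimeIndex(available_init_times).sort_values()
    target_times = pd.DatetimeIndex(target_times)

    # Forward-fill: index of the latest init time <= each target time.
    idx = available.searchsorted(target_times, side='right') - 1
    if (idx < 0).any():
        raise ValueError('Some target times precede the earliest NWP init time.')
    init_times = available[idx].to_numpy(copy=True)

    # 'Future' part: reuse the init time chosen for t0.
    init_time_at_t0 = init_times[target_times == t0][0]
    init_times[target_times > t0] = init_time_at_t0

    # Forecast lead time of each target time beyond its init time.
    steps = target_times - pd.DatetimeIndex(init_times)
    return pd.DatetimeIndex(init_times), steps


# Made-up example: NWP runs every 3 hours; hourly target times; t0 = 05:00.
available = pd.to_datetime(['2018-04-01 00:00', '2018-04-01 03:00',
                            '2018-04-01 06:00', '2018-04-01 09:00'])
targets = pd.to_datetime(['2018-04-01 04:00', '2018-04-01 05:00', '2018-04-01 06:00',
                          '2018-04-01 07:00', '2018-04-01 08:00'])
init_times, steps = select_init_times(available, targets, t0=pd.Timestamp('2018-04-01 05:00'))
```

Every target time ends up using the 03:00 run, with lead times of 1 to 5 hours. Without the 'future' override, the 06:00 to 08:00 targets would have used the 06:00 run, which was initialised after `t0`.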

In [None]:
selected_nwp = get_nwp_example(
    nwp,
    start=pd.Timestamp('2018-04-01T05:00'),
    end=pd.Timestamp('2018-04-01T07:00'),
    t0=pd.Timestamp('2018-04-01T06:00')
)

In [None]:
selected_nwp

In [None]:
%%time
_ = selected_nwp['t'].sel(x=100, y=100, method='nearest').plot()
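The comment in `get_nwp_example` about vectorized indexing is the crux of the NWP selection. A toy illustration, with made-up integer coordinates standing in for `init_time` and `step`: passing plain lists to `.sel` selects the outer product of the two lists, whereas passing `DataArray` indexers that share a dimension selects one point per element of that dimension.

```python
import numpy as np
import xarray as xr

# Toy 3x4 array standing in for an NWP variable indexed by (init_time, step).
da = xr.DataArray(
    np.arange(12).reshape(3, 4),
    dims=('init_time', 'step'),
    coords={'init_time': [0, 1, 2], 'step': [0, 1, 2, 3]},
)

# Orthogonal (outer) indexing: every combination of the two lists -> 2x2 result.
outer = da.sel(init_time=[0, 2], step=[1, 3])

# Vectorized (pointwise) indexing: both indexers share the 'target_time' dim,
# so we get the pairs (init_time=0, step=1) and (init_time=2, step=3).
pointwise = da.sel(
    init_time=xr.DataArray([0, 2], dims='target_time'),
    step=xr.DataArray([1, 3], dims='target_time'),
)
```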

In [None]:
class DataSource():
    """Base class for additional data sources to be added into the Satellite data."""
    def __init__(self):
        self.rng = np.random.default_rng()  # Will be replaced if per_worker_init is called.
        
    def per_worker_init(self, worker_id: int=0) -> None:
        # Each worker must have a different seed for its random number generator.
        # Otherwise all the workers will output exactly the same data!
        seed = torch.initial_seed()
        self.rng = np.random.default_rng(seed=seed)
    
    def __getitem__(self, sample: Sample) -> Sample:
        """The method that sub-classes should override.
        
        Typically, this method should set one or more attributes in the `sample` dictionary.
        """
        raise NotImplementedError()

    
@dataclass
class PVDataSource(DataSource):
    pv_power_df: pd.DataFrame = pv_power_df  #: PV yield, resampled to 5-minutely
    pv_metadata: pd.DataFrame = pv_metadata
        
        
    def __post_init__(self):
        super().__init__()
        

    def __getitem__(self, sample: Sample) -> Sample:
        datetime_index = sample['datetime_index']
        history_len = sample['history_len']
        
        selected_pv_power_df = self.pv_power_df.loc[datetime_index]
        
        # Drop PV systems with any missing values in this timespan, then select one at random.
        selected_pv_power_df = selected_pv_power_df.dropna(axis='columns', how='any')
        pv_system_ids = selected_pv_power_df.columns
        pv_system_id = self.rng.choice(pv_system_ids)
        selected_pv_power = selected_pv_power_df[pv_system_id]
        
        # Get metadata for PV system
        metadata_for_pv_system = self.pv_metadata.loc[pv_system_id]
        
        # Save data into the Sample dict...
        sample['pv_system_id'] = pv_system_id
        sample['pv_system_row_number'] = self.pv_metadata.index.get_loc(pv_system_id)
        sample['historical_pv_yield'] = selected_pv_power.iloc[:history_len]
        sample['target_pv_yield'] = selected_pv_power.iloc[history_len:]
        sample['pv_location_x'] = metadata_for_pv_system.location_x
        sample['pv_location_y'] = metadata_for_pv_system.location_y
        return sample


@dataclass
class NWPDataSource(DataSource):
    nwp: xr.Dataset = nwp
    params: Iterable[str] = (
        't',  # Temperature in Kelvin.
        'dswrf',  # Downward short-wave radiation flux in W/m2 (irradiance).
        'prate',  # Precipitation rate in kg m^-2 s^-1.
        'r',  # Relative humidity in %.
        'sde',  # Snow depth in meters.
        'si10',  # 10-meter wind speed in m/s.
        'vis',  # Visibility in meters.
        'lcc',  # Low-level cloud cover in %.
        'mcc',  # Medium-level cloud cover in %.
        'hcc',  # High-level cloud cover in %.
    )

    def __post_init__(self):
        super().__init__()

    def __getitem__(self, sample: Sample) -> Sample:
        datetime_index = sample['datetime_index'].tz_convert('UTC').tz_convert(None)
        start = datetime_index[0]
        end = datetime_index[-1]
        history_len = sample['history_len']
        # t0 is the last timestep of the 'history' (the first `history_len` timesteps).
        t0 = datetime_index[history_len - 1]
        x = sample['pv_location_x']
        y = sample['pv_location_y']
        params = list(self.params)
        
        selected_nwp = get_nwp_example(self.nwp, start=start, end=end, t0=t0)
        
        # Now select data for nearest to PV system
        selected_nwp = selected_nwp.sel(x=x, y=y, method='nearest')[params].to_array()
        sample['nwp_above_pv'] = selected_nwp
        return sample
        
    

@dataclass
class SatelliteDataset(torch.utils.data.IterableDataset):
    zarr_chunk_sequences: Iterable[Segment]  #: Defines multiple Zarr chunks to be loaded from disk at once.
    history_len: int = 12  #: The number of timesteps of 'history' to load.
    forecast_len: int = 1  #: The number of timesteps of 'forecast' to load.
        
    #: Append additional data, such as PV power or NWPs.
    #: Additional data is added before the transforms are applied.
    additional_data_sources: Iterable[DataSource] = (PVDataSource(), NWPDataSource())    
    transform: Optional[Callable] = None

    n_disk_loads_per_epoch: int = 10_000  #: Number of disk loads per worker process per epoch.
    min_n_samples_per_disk_load: int = 1_000  #: Min number of samples each worker will draw from each disk load.
    max_n_samples_per_disk_load: int = 2_000  #: Max number of samples each worker will draw from each disk load.  The actual number is chosen at random between min & max.
    n_zarr_chunk_sequences_to_load_at_once: int = 8  #: Number of chunk seqs to load at once.  These are sampled at random.
    
    def __post_init__(self):
        #: Total sequence length of each sample.
        self.total_seq_len = self.history_len + self.forecast_len

    def per_worker_init(self, worker_id: int=0) -> None:
        """Called by worker_init_fn on each copy of SatelliteDataset after the worker process has been spawned."""
        self.worker_id = worker_id
        self.data_array = get_sat_data()
        # Each worker must have a different seed for its random number generator.
        # Otherwise all the workers will output exactly the same data!
        seed = torch.initial_seed()
        self.rng = np.random.default_rng(seed=seed)
        for data_source in self.additional_data_sources:
            data_source.per_worker_init(worker_id)
    
    def __iter__(self):
        """
        Asynchronously loads next data from disk while sampling from data_in_mem.
        """
        with futures.ThreadPoolExecutor(max_workers=1) as executor:
            future_data = executor.submit(self._load_data_from_disk)
            for _ in range(self.n_disk_loads_per_epoch):
                data_in_mem = future_data.result()
                future_data = executor.submit(self._load_data_from_disk)
                n_samples = self.rng.integers(self.min_n_samples_per_disk_load, self.max_n_samples_per_disk_load)
                for _ in range(n_samples):
                    sample = self._get_sample(data_in_mem)
                    
                    # Add in additional data.
                    for data_source in self.additional_data_sources:
                        sample = data_source[sample]

                    # Transform.
                    if self.transform:
                        try:
                            sample = self.transform(sample)
                        except BadData as e:
                            # print(e)
                            continue

                    yield sample

    def _load_data_from_disk(self) -> List[xr.DataArray]:
        """Loads data from contiguous Zarr chunks from disk into memory."""
        sat_images_list = []
        for _ in range(self.n_zarr_chunk_sequences_to_load_at_once):
            zarr_chunk_sequence = self.rng.choice(self.zarr_chunk_sequences)
            # rng.choice first converts the whole sequence of Segments into an ndarray, so it
            # returns a 2-element ndarray rather than a Segment.  Convert back to a Segment
            # and then to naive datetimes:
            zarr_chunk_sequence = Segment(*zarr_chunk_sequence)
            zarr_chunk_sequence = zarr_chunk_sequence.to_naive()
            sat_images = self.data_array.sel(time=slice(*zarr_chunk_sequence))

            # Sanity checks
            n_timesteps_available = len(sat_images)
            if n_timesteps_available < self.total_seq_len:
                raise RuntimeError(f'Not enough timesteps in loaded data!  Need at least {self.total_seq_len}.  Got {n_timesteps_available}!')

            sat_images_list.append(sat_images.load())
        return sat_images_list

    def _get_sample(self, sat_data_in_mem_list: List[xr.DataArray]) -> Sample:
        # Select a random Zarr chunk sequence from the Zarr chunk sequences loaded into memory
        i = self.rng.integers(0, len(sat_data_in_mem_list))
        sat_data_in_mem = sat_data_in_mem_list[i]
        
        # Select random start index
        n_timesteps_available = len(sat_data_in_mem)
        max_start_idx = n_timesteps_available - self.total_seq_len
        # rng.integers excludes `high`, so add 1 to allow starting at max_start_idx.
        start_idx = self.rng.integers(low=0, high=max_start_idx + 1, dtype=np.uint32)
        end_idx = start_idx + self.total_seq_len
        selected_sat_images = sat_data_in_mem.isel(time=slice(start_idx, end_idx))
        datetime_index = pd.DatetimeIndex(selected_sat_images.time.values, tz='UTC') + pd.Timedelta('1 minute')
        return Sample(
            history_len=self.history_len,
            forecast_len=self.forecast_len,
            historical_sat_images=selected_sat_images[:self.history_len],
            target_sat_images=selected_sat_images[self.history_len:],
            datetime_index=datetime_index
        )


def worker_init_fn(worker_id):
    """Configures each dataset worker process.
    
    Just has one job!  To call SatelliteDataset.per_worker_init().
    """
    # get_worker_info() returns information specific to each worker process.
    worker_info = torch.utils.data.get_worker_info()
    if worker_info is None:
        print('worker_info is None!')
    else:
        dataset_obj = worker_info.dataset  # The Dataset copy in this worker process.
        dataset_obj.per_worker_init(worker_id=worker_info.id)
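`SatelliteDataset.__iter__` overlaps disk I/O with sampling by keeping one `_load_data_from_disk` call in flight on a single background thread. The pattern can be sketched with the standard library alone; `load_chunk` is a hypothetical stand-in for the slow disk read and `prefetching_iter` is an illustrative name, not part of the notebook.

```python
import time
from concurrent import futures


def load_chunk(i):
    """Hypothetical stand-in for a slow disk read such as _load_data_from_disk."""
    time.sleep(0.01)
    return [f'chunk{i}-sample{j}' for j in range(3)]


def prefetching_iter(n_loads):
    """Yield samples from chunk i while chunk i+1 loads on a background thread."""
    with futures.ThreadPoolExecutor(max_workers=1) as executor:
        future = executor.submit(load_chunk, 0)
        for i in range(n_loads):
            data_in_mem = future.result()  # Block until chunk i is ready.
            if i + 1 < n_loads:
                # Start loading the next chunk before consuming this one.
                future = executor.submit(load_chunk, i + 1)
            for sample in data_in_mem:
                yield sample


samples = list(prefetching_iter(n_loads=2))
```

One small design difference: this sketch skips submitting a load after the final iteration, whereas the notebook's loop submits one more `_load_data_from_disk` call on its last iteration, and the executor then waits for that unused load to finish on exit.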

### Testing

In [None]:
def test_pv_data_source():
    pv_data_source = PVDataSource()
    sample = Sample(history_len=1, datetime_index=pd.DatetimeIndex(["2018-06-01T11:45", "2018-06-01T11:50"], tz='UTC'))
    pv_data_source[sample]
    return sample

sample = test_pv_data_source()
sample
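`PVDataSource.__getitem__` first drops every PV system with any missing reading in the example's timespan, then picks one of the remaining systems at random and splits its readings into history and target. A toy sketch with made-up system IDs and readings (not real data):

```python
import numpy as np
import pandas as pd

# Made-up 5-minutely PV yield for three hypothetical PV systems.
pv_power_df = pd.DataFrame(
    {
        101: [0.1, 0.2, 0.3],
        102: [0.4, np.nan, 0.6],  # Has a gap, so it can't be selected.
        103: [0.7, 0.8, 0.9],
    },
    index=pd.date_range('2018-06-01 11:45', periods=3, freq='5min'),
)

rng = np.random.default_rng(seed=0)
complete = pv_power_df.dropna(axis='columns', how='any')
pv_system_id = rng.choice(complete.columns)

history_len = 2
historical = complete[pv_system_id].iloc[:history_len]
target = complete[pv_system_id].iloc[history_len:]
```

If no PV system has complete data for a timespan, `rng.choice` raises a `ValueError` because the population is empty.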

In [None]:
dataset_without_to_tensor = SatelliteDataset(
    zarr_chunk_sequences=zarr_chunk_sequences,
    transform=Compose([
        CropCentredOnPv(),
        CheckForBadData(),
    ]),
)

In [None]:
dataset_without_to_tensor.per_worker_init()

In [None]:
dataset_iter = iter(dataset_without_to_tensor)

In [None]:
sample = next(dataset_iter)
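Each sample is a random contiguous window of `total_seq_len` timesteps. A minimal sketch of drawing such a window, assuming `n` timesteps are available and the window length is `L` (`random_window` is a hypothetical helper): because `Generator.integers` excludes `high`, passing `high=n - L + 1` makes every start index from `0` to `n - L` possible, including the case `n == L`.

```python
import numpy as np


def random_window(n_timesteps, seq_len, rng):
    """Return (start, end) of a random contiguous window of `seq_len` steps."""
    max_start_idx = n_timesteps - seq_len
    if max_start_idx < 0:
        raise ValueError('Not enough timesteps for one window.')
    # integers() excludes `high`, hence the + 1.
    start = int(rng.integers(low=0, high=max_start_idx + 1))
    return start, start + seq_len


rng = np.random.default_rng(seed=0)
# With 15 timesteps and a 13-step window, the valid starts are 0, 1 and 2.
starts = {random_window(15, 13, rng)[0] for _ in range(1_000)}
```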

In [None]:
fig, axes = plt.subplots(nrows=3, figsize=(10, 20))

ax = axes[0]
sample['historical_sat_images'].isel(time=0).plot(ax=ax)
ax.scatter(sample['pv_location_x'], sample['pv_location_y'], c='black')

ax = axes[1]
sample['target_sat_images'].isel(time=0).plot(ax=ax)
ax.scatter(sample['pv_location_x'], sample['pv_location_y'], c='black')

ax = axes[2]
ax.plot(pd.concat((sample['historical_pv_yield'], sample['target_pv_yield'])))

In [None]:
sample['nwp_above_pv'].sel(variable='t').plot()

In [None]:
num_workers = 8

torch.manual_seed(42)

dataset = SatelliteDataset(
    zarr_chunk_sequences=zarr_chunk_sequences,
    transform=Compose([
        CropCentredOnPv(),
        CheckForBadData(),
        ToTensor(),
    ]),
)

if num_workers == 0:
    dataset.per_worker_init()

dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=8,
    num_workers=num_workers,  # Timings: 4=13.8s; 8=11.6s; 10=11.3s; 11=11.5s; 12=12.6s.  10 workers = 3 it/s.
    worker_init_fn=worker_init_fn,
    pin_memory=True,
    # multiprocessing_context='spawn',
    # persistent_workers=True,
)
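Why `per_worker_init` re-seeds from `torch.initial_seed()`: a NumPy `Generator` created in the main process (for example in `DataSource.__init__`) is copied into every worker, so without re-seeding each worker would draw the same random numbers. According to the PyTorch `DataLoader` documentation, each worker's torch seed is `base_seed + worker_id`, where `base_seed` is drawn from the main process's RNG (so `torch.manual_seed(42)` above makes it reproducible). The NumPy-only sketch below mimics that scheme; `make_worker_rngs` and the base seed value are hypothetical.

```python
import numpy as np


def make_worker_rngs(base_seed, n_workers, per_worker_seed=True):
    """Hypothetical helper: one NumPy Generator per simulated worker.

    With per_worker_seed=True, worker i is seeded with base_seed + i,
    mimicking torch.initial_seed() inside a DataLoader worker.
    """
    rngs = []
    for worker_id in range(n_workers):
        seed = base_seed + worker_id if per_worker_seed else base_seed
        rngs.append(np.random.default_rng(seed=seed))
    return rngs


BASE_SEED = 1234  # Hypothetical; the DataLoader chooses the real base_seed.
shared = [rng.integers(0, 1_000_000, size=5).tolist()
          for rng in make_worker_rngs(BASE_SEED, n_workers=4, per_worker_seed=False)]
distinct = [rng.integers(0, 1_000_000, size=5).tolist()
            for rng in make_worker_rngs(BASE_SEED, n_workers=4)]
```

With a shared seed all four simulated workers produce identical draws; with per-worker seeds they differ.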

In [None]:
%%time
for batch in dataloader:
    print(batch['historical_sat_images'].shape)
    break

In [None]:
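batch_datetimes = batch['datetime_index']  # Seconds since the Unix epoch; one row per example.
pd.to_datetime(batch_datetimes.numpy().flatten(), unit='s').values.reshape(batch_datetimes.shape).astype('datetime64[s]')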

In [None]:
batch['historical_sat_images'].shape

In [None]:
batch['target_sat_images'].shape

In [None]:
batch['historical_sat_images'].dtype

In [None]:
plt.imshow(batch['historical_sat_images'][0, 0])

In [None]:
plt.imshow(batch['target_sat_images'][0, 0])

In [None]:
batch['historical_pv_yield']

In [None]:
batch['target_pv_yield']

# Simple ML model

In [None]:
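def normalise_images_in_model(images, device):
    """Convert int16 satellite images to float32 and normalise them on `device`."""
    SAT_IMAGE_MEAN = torch.tensor(93.23458, dtype=torch.float, device=device)
    SAT_IMAGE_STD = torch.tensor(115.34247, dtype=torch.float, device=device)
    
    # Use out-of-place ops: if `images` is already float32 then `.float()` returns
    # the same tensor, and in-place ops would silently modify the caller's batch.
    return (images.float() - SAT_IMAGE_MEAN) / SAT_IMAGE_STD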
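The satellite pixels stay `int16` until they reach the GPU, which halves the bytes moved between processes and to the GPU compared with `float32`. A NumPy sketch of the same arithmetic as `normalise_images_in_model`; the image values and the 128x128 size are made up, while the mean and standard deviation are the values hard-coded above.

```python
import numpy as np

SAT_IMAGE_MEAN = 93.23458
SAT_IMAGE_STD = 115.34247

# Made-up batch: 8 examples, 1 timestep, 128x128 pixels, stored as int16.
images_int16 = np.full((8, 1, 128, 128), 100, dtype=np.int16)

# Convert to float32 only at the last moment (on the GPU, in the notebook).
images_float32 = images_int16.astype(np.float32)
normalised = (images_float32 - np.float32(SAT_IMAGE_MEAN)) / np.float32(SAT_IMAGE_STD)
```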

In [None]:
CHANNELS = 32
KERNEL = 3


class LitAutoEncoder(pl.LightningModule):
    def __init__(self):
        super().__init__()
        
        self.encoder_conv1 = nn.Conv2d(in_channels=1, out_channels=CHANNELS//2, kernel_size=KERNEL)
        self.encoder_conv2 = nn.Conv2d(in_channels=CHANNELS//2, out_channels=CHANNELS, kernel_size=KERNEL)
        self.encoder_conv3 = nn.Conv2d(in_channels=CHANNELS, out_channels=CHANNELS, kernel_size=KERNEL)

        self.maxpool = nn.MaxPool2d(kernel_size=KERNEL)
        
        self.fc1 = nn.Linear(
            in_features=CHANNELS * 11 * 11, 
            out_features=256 - 4  # Minus 4 to leave room for the 4-dim PV system embedding concatenated in forward().
        )
        #self.fc_just_prev_yield = nn.Linear(in_features=1, out_features=256)
        self.fc2 = nn.Linear(in_features=256, out_features=128)
        self.fc3 = nn.Linear(in_features=128, out_features=128)
        self.fc4 = nn.Linear(in_features=128, out_features=128)
        self.fc5 = nn.Linear(in_features=128, out_features=1)
        
        self.pv_system_id_embedding = nn.Embedding(
            num_embeddings=len(pv_metadata),
            embedding_dim=4
        )
        
    def forward(self, x):
        images = x['target_sat_images']
        images = normalise_images_in_model(images, self.device)
        
        # Pass data through the network :)
        out = F.relu(self.encoder_conv1(images))
        out = self.maxpool(out)
        out = F.relu(self.encoder_conv2(out))
        out = self.maxpool(out)
        out = F.relu(self.encoder_conv3(out))
        
        out = out.view(-1, CHANNELS * 11 * 11)
        out = F.relu(self.fc1(out))
        
        pv_embedding = self.pv_system_id_embedding(x['pv_system_row_number'])
        
        out = torch.cat(
            (
                out, 
                #x['historical_pv_yield'][:, :1], 
                pv_embedding
            ), dim=1)

        #out = F.relu(self.fc_just_prev_yield(x['historical_pv_yield'][:, :1]))
        out = F.relu(self.fc2(out))
        out = F.relu(self.fc3(out))
        out = F.relu(self.fc4(out))
        out = self.fc5(out)

        return out
    
    def _training_or_validation_step(self, batch, is_train_step):
        y_hat = self(batch)
        y = batch['target_pv_yield']
        #mse_loss = F.mse_loss(y_hat, y)
        mae_loss = (y_hat - y).abs().mean()
        tag = "Train" if is_train_step else "Validation"
        #self.log_dict({'MSE/' + tag: mse_loss}, on_step=is_train_step, on_epoch=True)
        self.log_dict({'MAE/' + tag: mae_loss}, on_step=is_train_step, on_epoch=True)
        return mae_loss

    def training_step(self, batch, batch_idx):
        return self._training_or_validation_step(batch, is_train_step=True)
    
    def validation_step(self, batch, batch_idx):
        return self._training_or_validation_step(batch, is_train_step=False)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
        return optimizer
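`CHANNELS * 11 * 11` in `fc1` is tied to the size of the satellite crop. Each unpadded 3x3 convolution shrinks each side by 2, and each `MaxPool2d(kernel_size=3)` (whose stride defaults to the kernel size) divides it by 3, rounding down. The sketch below traces that arithmetic with hypothetical helper names; the 128-pixel crop is an assumption, since the crop size is set by `CropCentredOnPv`, which isn't shown here.

```python
def conv_out(size, kernel=3, stride=1, padding=0):
    """Output side length of an nn.Conv2d / nn.MaxPool2d layer (dilation=1, ceil_mode=False)."""
    return (size + 2 * padding - kernel) // stride + 1


def encoder_out_size(crop_size, kernel=3):
    """Trace one side of the image through the LitAutoEncoder encoder."""
    size = conv_out(crop_size, kernel)            # encoder_conv1
    size = conv_out(size, kernel, stride=kernel)  # maxpool
    size = conv_out(size, kernel)                 # encoder_conv2
    size = conv_out(size, kernel, stride=kernel)  # maxpool
    size = conv_out(size, kernel)                 # encoder_conv3
    return size


CHANNELS = 32
side = encoder_out_size(128)  # Assumed crop size.
fc1_in = CHANNELS * side * side
# fc1 outputs 256 - 4 features; the 4-dim PV system embedding brings it back to fc2's 256.
fc2_in = (256 - 4) + 4
```

Under these assumptions, any crop from 125 to 133 pixels per side yields the 11x11 feature map that `fc1` expects; other sizes would make the `out.view(-1, CHANNELS * 11 * 11)` call fail or mis-shape the batch.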

In [None]:
model = LitAutoEncoder()

In [None]:
model(batch).shape

In [None]:
trainer = pl.Trainer(gpus=1, max_epochs=400, terminate_on_nan=False)

In [None]:
%%time
trainer.fit(model, dataloader)