In [1]:
from nowcasting_dataset.datamodule import NowcastingDataModule
from pathlib import Path
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import pandas as pd
import numpy as np

import torch
from torch import nn
import torch.nn.functional as F
import pytorch_lightning as pl

from neptune.new.integrations.pytorch_lightning import NeptuneLogger

import logging
# Install a default root handler and turn on DEBUG output for the
# 'nowcasting_dataset' logger, so its messages show up in the cell outputs.
# NOTE: the name `logger` is rebound to a NeptuneLogger further down this notebook.
logging.basicConfig()
logger = logging.getLogger('nowcasting_dataset')
logger.setLevel(logging.DEBUG)

In [2]:
# Bucket-relative locations of the input data. They are kept as Path objects so
# that sub-paths can be joined; the 'gs://' scheme is prepended further down,
# where most of them are handed to NowcastingDataModule.
BUCKET = Path('solar-pv-nowcasting-data')

# Solar PV data
PV_PATH = BUCKET.joinpath('PV', 'PVOutput.org')
PV_DATA_FILENAME = PV_PATH.joinpath('UK_PV_timeseries_batch.nc')
PV_METADATA_FILENAME = PV_PATH.joinpath('UK_PV_metadata.csv')

# Satellite imagery
# Alternative dataset:
# SAT_FILENAME = BUCKET / 'satellite/EUMETSAT/SEVIRI_RSS/OSGB36/all_zarr_int16_single_timestep_quarter_geospatial.zarr'
SAT_FILENAME = BUCKET.joinpath(
    'satellite', 'EUMETSAT', 'SEVIRI_RSS', 'OSGB36',
    'all_zarr_int16_single_timestep.zarr')

# Numerical weather predictions
# Alternative datasets:
#NWP_BASE_PATH = BUCKET / 'NWP/UK_Met_Office/UKV_zarr'
#NWP_BASE_PATH = BUCKET / 'NWP/UK_Met_Office/UKV_single_step_and_single_timestep_all_vars.zarr'
#NWP_BASE_PATH = BUCKET / 'NWP/UK_Met_Office/UKV_single_step_and_single_timestep_all_vars_full_spatial_2018_7-12_float32.zarr'
NWP_BASE_PATH = BUCKET.joinpath(
    'NWP', 'UK_Met_Office',
    'UKV__2018-01_to_2019-12__chunks__variable10__init_time1__step1__x548__y704__.zarr')

In [3]:
# Hyperparameters shared by the data module, the model and the plotting code.
params = {
    'batch_size': 32,
    'history_len': 6,  #: Number of timesteps of history, not including t0.
    'forecast_len': 12,  #: Number of timesteps of forecast.
    'image_size_pixels': 32,
    'nwp_channels': (
        't', 'dswrf', 'prate', 'r', 'sde', 'si10', 'vis', 'lcc', 'mcc', 'hcc',
    ),
    'sat_channels': (
        'HRV', 'IR_016', 'IR_039', 'IR_087', 'IR_097', 'IR_108', 'IR_120',
        'IR_134', 'VIS006', 'VIS008', 'WV_062', 'WV_073',
    ),
}

In [4]:
# Build the Lightning data module from the three data sources.
# NOTE(review): the PV power file is passed without the 'gs://' prefix that the
# other three paths get -- presumably the PV loader expects a bare bucket path; confirm.
data_module = NowcastingDataModule(
    # Data locations
    pv_power_filename=PV_DATA_FILENAME,
    pv_metadata_filename=f'gs://{PV_METADATA_FILENAME}',
    sat_filename=f'gs://{SAT_FILENAME}',
    nwp_base_path=f'gs://{NWP_BASE_PATH}',
    # Passed to DataLoader.
    pin_memory=True,
    num_workers=22,
    prefetch_factor=256,
    # Passed to NowcastingDataset
    n_samples_per_timestep=8,
    **params,
)

In [5]:
# Run the data module's prepare_data step. The captured output below shows it
# opening the satellite Zarr store and loading the PV power data.
data_module.prepare_data()

DEBUG:nowcasting_dataset:Opening satellite data: gs://solar-pv-nowcasting-data/satellite/EUMETSAT/SEVIRI_RSS/OSGB36/all_zarr_int16_single_timestep.zarr


15 bad PV systems found and removed!
pv_power = 400.0 MB


In [6]:
# Run the data module's setup step. The captured output below shows it opening
# the satellite and NWP Zarr stores; the train/validation t0 datetimes inspected
# in the next cells are presumably populated here.
data_module.setup()

DEBUG:nowcasting_dataset:Opening satellite data: gs://solar-pv-nowcasting-data/satellite/EUMETSAT/SEVIRI_RSS/OSGB36/all_zarr_int16_single_timestep.zarr
DEBUG:nowcasting_dataset:Opening NWP data: gs://solar-pv-nowcasting-data/NWP/UK_Met_Office/UKV__2018-01_to_2019-12__chunks__variable10__init_time1__step1__x548__y704__.zarr
DEBUG:nowcasting_dataset:Opening satellite data: gs://solar-pv-nowcasting-data/satellite/EUMETSAT/SEVIRI_RSS/OSGB36/all_zarr_int16_single_timestep.zarr
  unixtime = np.array(time.astype(np.int64)/10**9)
  unixtime = np.array(time.astype(np.int64)/10**9)
  unixtime = np.array(time.astype(np.int64)/10**9)
  unixtime = np.array(time.astype(np.int64)/10**9)
  a = a.astype(int)


In [7]:
# Inspect the t0 datetimes available to the training split.
data_module.train_t0_datetimes

DatetimeIndex(['2018-06-01 03:50:00', '2018-06-01 03:55:00',
               '2018-06-01 04:00:00', '2018-06-01 04:05:00',
               '2018-06-01 04:10:00', '2018-06-01 04:15:00',
               '2018-06-01 04:20:00', '2018-06-01 04:25:00',
               '2018-06-01 04:30:00', '2018-06-01 04:35:00',
               ...
               '2019-06-16 15:15:00', '2019-06-16 15:20:00',
               '2019-06-16 15:25:00', '2019-06-16 15:30:00',
               '2019-06-16 15:35:00', '2019-06-16 15:40:00',
               '2019-06-16 15:45:00', '2019-06-16 15:50:00',
               '2019-06-16 15:55:00', '2019-06-16 16:00:00'],
              dtype='datetime64[ns]', length=47620, freq=None)

In [8]:
# Inspect the t0 datetimes available to the validation split.
data_module.val_t0_datetimes

DatetimeIndex(['2019-06-16 16:05:00', '2019-06-16 16:10:00',
               '2019-06-16 16:15:00', '2019-06-16 16:20:00',
               '2019-06-16 16:25:00', '2019-06-16 16:30:00',
               '2019-06-16 16:35:00', '2019-06-16 16:40:00',
               '2019-06-16 16:45:00', '2019-06-16 16:50:00',
               ...
               '2019-08-20 18:00:00', '2019-08-20 18:05:00',
               '2019-08-20 18:10:00', '2019-08-20 18:15:00',
               '2019-08-20 18:20:00', '2019-08-20 18:25:00',
               '2019-08-20 18:30:00', '2019-08-20 18:35:00',
               '2019-08-20 18:40:00', '2019-08-20 18:45:00'],
              dtype='datetime64[ns]', length=11904, freq=None)

In [9]:
# Number of rows in the PV metadata table. The model below uses this same
# length as `num_embeddings` for its PV-system embedding.
len(data_module.pv_data_source.pv_metadata)

940

## Define a very simple ML model

In [None]:
import tilemapbase
from nowcasting_dataset.geospatial import osgb_to_lat_lon

In [None]:
# Initialise tilemapbase before plot_example() requests OpenStreetMap tiles.
# NOTE(review): `create=True` presumably creates the local tile cache if it is missing -- confirm.
tilemapbase.init(create=True)

In [None]:
def plot_example(batch, model_output, example_i: int=0, border: int=0):
    """Plot one example of a batch together with the model's PV forecast.

    The figure is a 2 x 4 grid:

    * top row: the first satellite channel at the first timestep, at t0 and at
      the last timestep (each with a white cross at the example's centre),
      followed by an OpenStreetMap view of the same extent;
    * bottom row: NWP channels at pixel (0, 0), the four datetime features, and
      actual vs. predicted PV yield.

    Args:
        batch: Mapping of batched tensors. Keys read here: 'sat_data',
            'sat_x_coords', 'sat_y_coords', 'x_meters_center',
            'y_meters_center', 'nwp', 'nwp_target_time', 'hour_of_day_sin',
            'hour_of_day_cos', 'day_of_year_sin', 'day_of_year_cos',
            'pv_system_id' and 'pv_yield'.
        model_output: Batched predictions; ``model_output[example_i]`` is
            plotted against the timesteps after t0.
        example_i: Index of the example within the batch.
        border: Unused. Kept so existing callers keep working.

    Returns:
        The matplotlib Figure. The caller is responsible for closing it.
    """
    fig = plt.figure(figsize=(20, 20))
    ncols = 4
    nrows = 2

    # The sequence holds `history_len` history steps, then t0, then the
    # forecast steps, so t0 sits at index `history_len`.
    t0_idx = params['history_len']

    # Satellite data
    extent = (
        float(batch['sat_x_coords'][example_i, 0].cpu().numpy()),
        float(batch['sat_x_coords'][example_i, -1].cpu().numpy()),
        float(batch['sat_y_coords'][example_i, -1].cpu().numpy()),
        float(batch['sat_y_coords'][example_i, 0].cpu().numpy()))  # left, right, bottom, top

    def _format_ax(ax):
        # Mark the centre of the example with a white cross.
        ax.scatter(
            batch['x_meters_center'][example_i].cpu(),
            batch['y_meters_center'][example_i].cpu(),
            s=500, color='white', marker='x')

    # Only the first satellite channel is shown; all three frames share one
    # colour scale so they can be compared directly.
    sat_data = batch['sat_data'][example_i, :, :, :, 0].cpu().numpy()
    sat_min = np.min(sat_data)
    sat_max = np.max(sat_data)

    ax = fig.add_subplot(nrows, ncols, 1) #, projection=ccrs.OSGB(approx=False))
    ax.imshow(sat_data[0], extent=extent, interpolation='none', vmin=sat_min, vmax=sat_max)
    ax.set_title('t = -{}'.format(params['history_len']))
    _format_ax(ax)

    ax = fig.add_subplot(nrows, ncols, 2)
    # Fixed off-by-one: index history_len + 1 is t = +1, not t0.
    ax.imshow(sat_data[t0_idx], extent=extent, interpolation='none', vmin=sat_min, vmax=sat_max)
    ax.set_title('t = 0')
    _format_ax(ax)

    ax = fig.add_subplot(nrows, ncols, 3)
    ax.imshow(sat_data[-1], extent=extent, interpolation='none', vmin=sat_min, vmax=sat_max)
    ax.set_title('t = {}'.format(params['forecast_len']))
    _format_ax(ax)

    # Map of the same extent.
    ax = fig.add_subplot(nrows, ncols, 4)
    lat_lon_bottom_left = osgb_to_lat_lon(extent[0], extent[2])
    lat_lon_top_right = osgb_to_lat_lon(extent[1], extent[3])
    tiles = tilemapbase.tiles.build_OSM()
    lat_lon_extent = tilemapbase.Extent.from_lonlat(
        longitude_min=lat_lon_bottom_left[1],
        longitude_max=lat_lon_top_right[1],
        latitude_min=lat_lon_bottom_left[0],
        latitude_max=lat_lon_top_right[0])
    plotter = tilemapbase.Plotter(lat_lon_extent, tile_provider=tiles, zoom=6)
    plotter.plot(ax, tiles)

    ############## TIMESERIES ##################
    # NWP
    ax = fig.add_subplot(nrows, ncols, 5)
    nwp_dt_index = pd.to_datetime(batch['nwp_target_time'][example_i].cpu().numpy(), unit='s')
    pd.DataFrame(
        batch['nwp'][example_i, :, :, 0, 0].T.cpu().numpy(),
        index=nwp_dt_index,
        columns=params['nwp_channels']).plot(ax=ax)
    ax.set_title('NWP')

    # datetime features
    ax = fig.add_subplot(nrows, ncols, 6)
    ax.set_title('datetime features')
    datetime_feature_cols = ['hour_of_day_sin', 'hour_of_day_cos', 'day_of_year_sin', 'day_of_year_cos']
    datetime_features_df = pd.DataFrame(
        {key: batch[key][example_i].cpu().numpy() for key in datetime_feature_cols},
        index=nwp_dt_index)
    datetime_features_df.plot(ax=ax)
    ax.legend()
    ax.set_xlabel(nwp_dt_index[0].date())

    # PV yield
    ax = fig.add_subplot(nrows, ncols, 7)
    pv_system_id = int(batch['pv_system_id'][example_i].cpu())
    ax.set_title('PV yield for PV ID {:,d}'.format(pv_system_id))
    pv_actual = pd.Series(
        batch['pv_yield'][example_i].cpu().numpy(),
        index=nwp_dt_index,
        name='actual')
    # Predictions cover only the timesteps after t0.
    pv_pred = pd.Series(
        model_output[example_i].detach().cpu().numpy(),
        index=nwp_dt_index[t0_idx + 1:],
        name='prediction')
    pd.concat([pv_actual, pv_pred], axis='columns').plot(ax=ax)
    ax.legend()

    # fig.tight_layout()

    return fig

In [None]:
# plot_example(batch, model_output, example_i=20);  

In [None]:
# Constants used by the model's forward pass to standardise the satellite
# image coordinates: (coords - MEAN) / STD.
# NOTE(review): presumably OSGB metres, matching 'sat_x_coords' / 'sat_y_coords' -- confirm.
SAT_X_MEAN, SAT_X_STD = np.float32(309_000), np.float32(316_387.42073603)
SAT_Y_MEAN, SAT_Y_STD = np.float32(519_000), np.float32(406_454.17945938)

In [None]:
from neptune.new.types import File

In [None]:
# Size constants for the model defined below.

# Timesteps per example: the history steps, t0 itself, then the forecast steps.
TOTAL_SEQ_LEN = params['history_len'] + params['forecast_len'] + 1

# Satellite CNN
CHANNELS = 32  # Output channels of the first two conv layers.
N_CHANNELS_LAST_CONV = 4
KERNEL = 3
N_CONV_LAYERS = 3  # sat_conv1, sat_conv2 and sat_conv3.
# Each un-padded, stride-1 convolution trims (KERNEL - 1) pixels from each
# spatial dimension, so the flattened CNN output size follows from the kernel
# size and the layer count instead of a hard-coded "- 6".
CNN_OUTPUT_SIZE = N_CHANNELS_LAST_CONV * (
    (params['image_size_pixels'] - N_CONV_LAYERS * (KERNEL - 1)) ** 2)

EMBEDDING_DIM = 16
NWP_SIZE = len(params['nwp_channels']) * 2 * 2  # channels x width x height
N_DATETIME_FEATURES = 4
FC_OUTPUT_SIZE = 8
RNN_HIDDEN_SIZE = 16

class LitAutoEncoder(pl.LightningModule):
    """CNN + GRU encoder/decoder that forecasts PV yield.

    For every timestep, a three-layer CNN turns the satellite image (plus five
    extra positional channels) into a feature vector. That vector is combined
    with a learnt embedding of the PV system and reduced to FC_OUTPUT_SIZE
    features by a stack of fully connected layers. Those features are then
    concatenated with the flattened NWP data and four datetime features.

    An encoder GRU reads the first ``history_len + 1`` timesteps, which also
    carry the observed PV yield. Its final hidden state initialises a decoder
    GRU that reads the last ``forecast_len`` timesteps and produces one PV
    yield value per forecast timestep.

    The class relies on these module-level names: ``params``, ``data_module``,
    the size constants (``CHANNELS``, ``KERNEL``, ``CNN_OUTPUT_SIZE``, ...),
    ``SAT_X_MEAN`` / ``SAT_X_STD`` / ``SAT_Y_MEAN`` / ``SAT_Y_STD``,
    ``plot_example`` and ``File``.
    """

    def __init__(
        self,
        history_len = params['history_len'],
        forecast_len = params['forecast_len'],
    ):
        """Build all layers and the constant extra image channels.

        Args:
            history_len: Number of history timesteps, not including t0.
            forecast_len: Number of forecast timesteps.
        """
        super().__init__()
        self.history_len = history_len
        self.forecast_len = forecast_len

        # +5 input channels: centre marker, geo-spatial x & y, pixel x & y.
        self.sat_conv1 = nn.Conv2d(in_channels=len(params['sat_channels'])+5, out_channels=CHANNELS, kernel_size=KERNEL)
        self.sat_conv2 = nn.Conv2d(in_channels=CHANNELS, out_channels=CHANNELS, kernel_size=KERNEL)
        self.sat_conv3 = nn.Conv2d(in_channels=CHANNELS, out_channels=N_CHANNELS_LAST_CONV, kernel_size=KERNEL)

        self.fc1 = nn.Linear(
            in_features=CNN_OUTPUT_SIZE,
            out_features=256)

        self.fc2 = nn.Linear(
            in_features=256 + EMBEDDING_DIM,
            out_features=128)
        self.fc3 = nn.Linear(in_features=128, out_features=64)
        self.fc4 = nn.Linear(in_features=64, out_features=32)
        self.fc5 = nn.Linear(in_features=32, out_features=FC_OUTPUT_SIZE)

        if EMBEDDING_DIM:
            self.pv_system_id_embedding = nn.Embedding(
                num_embeddings=len(data_module.pv_data_source.pv_metadata),
                embedding_dim=EMBEDDING_DIM)

        self.encoder_rnn = nn.GRU(
            input_size=FC_OUTPUT_SIZE + N_DATETIME_FEATURES + 1 + NWP_SIZE,  # plus 1 for history
            hidden_size=RNN_HIDDEN_SIZE,
            num_layers=2,
            batch_first=True)
        self.decoder_rnn = nn.GRU(
            input_size=FC_OUTPUT_SIZE + N_DATETIME_FEATURES + NWP_SIZE,
            hidden_size=RNN_HIDDEN_SIZE,
            num_layers=2,
            batch_first=True)

        self.decoder_fc1 = nn.Linear(
            in_features=RNN_HIDDEN_SIZE,
            out_features=8)
        self.decoder_fc2 = nn.Linear(
            in_features=8,
            out_features=1)

        ### EXTRA CHANNELS
        # Stored once per image with shape (1, 1, size, size) and expanded to
        # the actual batch in forward(), so the model no longer assumes that
        # every batch has exactly params['batch_size'] examples. Registering
        # them as non-persistent buffers moves them to the module's device
        # while keeping them out of the state_dict.
        image_size = params['image_size_pixels']

        # Center marker: a 4 x 4 block of ones in the middle of the image.
        center_marker = torch.zeros((1, 1, image_size, image_size), dtype=torch.float32)
        half_width = image_size // 2
        center_marker[..., half_width-2:half_width+2, half_width-2:half_width+2] = 1
        self.register_buffer('center_marker', center_marker, persistent=False)

        # pixel x & y
        # NOTE(review): 64 and 37 look like the centre and spread of a
        # 128-pixel axis; with a different image size the result is not
        # zero-centred. Left unchanged to keep the model's inputs as they were -- confirm.
        pixel_range = (torch.arange(image_size) - 64) / 37
        pixel_x = pixel_range.reshape(1, 1, 1, image_size).expand(1, 1, image_size, image_size)
        pixel_y = pixel_range.reshape(1, 1, image_size, 1).expand(1, 1, image_size, image_size)
        self.register_buffer('pixel_x', pixel_x.contiguous(), persistent=False)
        self.register_buffer('pixel_y', pixel_y.contiguous(), persistent=False)

    def forward(self, x):
        """Predict PV yield for every forecast timestep.

        Args:
            x: Mapping of batched tensors. Keys read here: 'sat_data',
                'sat_x_coords', 'sat_y_coords', 'pv_system_row_number', 'nwp',
                'hour_of_day_sin', 'hour_of_day_cos', 'day_of_year_sin',
                'day_of_year_cos' and 'pv_yield'.

        Returns:
            Tensor of shape (batch_size, forecast_len).
        """
        # ******************* Satellite imagery *************************
        # Shape: batch_size, seq_length, width, height, channel
        # TODO: Use optical flow, not actual sat images of the future!
        sat_data = x['sat_data']
        batch_size, seq_len, width, height, n_chans = sat_data.shape

        # Stack timesteps as extra examples
        new_batch_size = batch_size * seq_len
        #                                 0           1       2      3
        sat_data = sat_data.reshape(new_batch_size, width, height, n_chans)

        # Conv2d expects channels to be the 2nd dim!
        sat_data = sat_data.permute(0, 3, 1, 2)
        # Now shape: new_batch_size, n_chans, width, height

        ### EXTRA CHANNELS
        # geo-spatial x
        x_coords = x['sat_x_coords']  # shape:  batch_size, image_size_pixels
        x_coords = (x_coords - SAT_X_MEAN) / SAT_X_STD
        # Repeat each example's coordinates once per timestep, matching the
        # example-major order produced by the reshape above.
        x_coords = x_coords.unsqueeze(1).expand(-1, width, -1).unsqueeze(1).repeat_interleave(repeats=seq_len, dim=0)

        # geo-spatial y
        y_coords = x['sat_y_coords']  # shape:  batch_size, image_size_pixels
        y_coords = (y_coords - SAT_Y_MEAN) / SAT_Y_STD
        y_coords = y_coords.unsqueeze(-1).expand(-1, -1, height).unsqueeze(1).repeat_interleave(repeats=seq_len, dim=0)

        # Constant channels: expand the per-image buffers to this batch.
        center_marker = self.center_marker.expand(new_batch_size, -1, -1, -1)
        pixel_x = self.pixel_x.expand(new_batch_size, -1, -1, -1)
        pixel_y = self.pixel_y.expand(new_batch_size, -1, -1, -1)

        # Concat
        sat_data = torch.cat((sat_data, center_marker, x_coords, y_coords, pixel_x, pixel_y), dim=1)

        del x_coords, y_coords

        # Pass data through the network :)
        out = F.relu(self.sat_conv1(sat_data))
        out = F.relu(self.sat_conv2(out))
        out = F.relu(self.sat_conv3(out))

        out = out.reshape(new_batch_size, CNN_OUTPUT_SIZE)
        out = F.relu(self.fc1(out))

        # ********************** Embedding of PV system ID *********************
        if EMBEDDING_DIM:
            pv_embedding = self.pv_system_id_embedding(x['pv_system_row_number'].repeat_interleave(seq_len))
            out = torch.cat((out, pv_embedding), dim=1)

        # Fully connected layers.
        out = F.relu(self.fc2(out))
        out = F.relu(self.fc3(out))
        out = F.relu(self.fc4(out))
        out = F.relu(self.fc5(out))

        # ******************* PREP DATA FOR RNN *****************************************
        # Undo the "timesteps as extra examples" stacking.
        out = out.reshape(batch_size, seq_len, FC_OUTPUT_SIZE) # TODO: Double-check this does what we expect!

        # The RNN encoder gets recent history: satellite, NWP, datetime features, and recent PV history.
        # The RNN decoder gets what we know about the future: satellite, NWP, and datetime features.

        # *********************** NWP Data **************************************
        nwp_data = x['nwp'].float() # Shape: batch_size, channel, seq_length, width, height
        nwp_data = nwp_data.permute(0, 2, 1, 3, 4)  # RNN expects seq_len to be dim 1.
        batch_size, nwp_seq_len, n_nwp_chans, nwp_width, nwp_height = nwp_data.shape
        nwp_data = nwp_data.reshape(batch_size, nwp_seq_len, n_nwp_chans * nwp_width * nwp_height)

        # Concat
        rnn_input = torch.cat(
            (
                out,
                nwp_data,
                x['hour_of_day_sin'].unsqueeze(-1),
                x['hour_of_day_cos'].unsqueeze(-1),
                x['day_of_year_sin'].unsqueeze(-1),
                x['day_of_year_cos'].unsqueeze(-1),
            ),
            dim=2)

        pv_yield_history = x['pv_yield'][:, :self.history_len+1].unsqueeze(-1)
        encoder_input = torch.cat(
            (
                rnn_input[:, :self.history_len+1],
                pv_yield_history
            ),
            dim=2)

        encoder_output, encoder_hidden = self.encoder_rnn(encoder_input)
        decoder_output, _ = self.decoder_rnn(rnn_input[:, -self.forecast_len:], encoder_hidden)
        # decoder_output is shape batch_size, seq_len, rnn_hidden_size

        decoder_output = F.relu(self.decoder_fc1(decoder_output))
        decoder_output = self.decoder_fc2(decoder_output)

        # Drop only the trailing feature dimension. A bare squeeze() would
        # also drop the batch dimension when batch_size == 1, or the time
        # dimension when forecast_len == 1.
        return decoder_output.squeeze(-1)

    def _training_or_validation_step(self, batch, is_train_step, y_hat=None):
        """Compute and log the losses for one batch.

        Args:
            batch: Mapping of batched tensors; see forward(). 'pv_yield' also
                supplies the targets.
            is_train_step: True for training, False for validation. Selects
                the log tag and whether metrics are also logged per step.
            y_hat: Optional predictions already computed for this batch. When
                omitted, the model is run on the batch.

        Returns:
            The mean absolute error, logged as 'NMAE/<tag>'. training_step
            returns this value as the loss.
        """
        if y_hat is None:
            y_hat = self(batch)
        y = batch['pv_yield'][:, -self.forecast_len:]
        mse_loss = F.mse_loss(y_hat, y)
        nmae_loss = (y_hat - y).abs().mean()
        # TODO: Compute correlation coef using np.corrcoef(tensor with shape (2, num_timesteps))[0, 1]
        # on each example, and taking the mean across the batch?
        tag = "Train" if is_train_step else "Validation"
        self.log_dict({f'MSE/{tag}': mse_loss}, on_step=is_train_step, on_epoch=True)
        self.log_dict({f'NMAE/{tag}': nmae_loss}, on_step=is_train_step, on_epoch=True)

        return nmae_loss

    def training_step(self, batch, batch_idx):
        """Return the training loss for one batch."""
        return self._training_or_validation_step(batch, is_train_step=True)

    def validation_step(self, batch, batch_idx):
        """Log validation losses; for the first batch also log an example plot."""
        # Run the model once and reuse the output for both the plot and the loss.
        y_hat = self(batch)
        if batch_idx == 0:
            # Plot example
            fig = plot_example(batch, y_hat)
            try:
                self.logger.experiment['validation/plot'].log(File.as_image(fig))
            finally:
                # Figures created via pyplot stay registered until closed, so
                # close each one to avoid piling up a figure per validation run.
                plt.close(fig)

        return self._training_or_validation_step(batch, is_train_step=False, y_hat=y_hat)

    def configure_optimizers(self):
        """Use Adam over all parameters with a learning rate of 0.001."""
        optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
        return optimizer

In [None]:
# Instantiate the model with the default history/forecast lengths taken from `params`.
model = LitAutoEncoder()

In [None]:
def get_batch():
    """Return the first batch yielded by the training dataset.

    Intended for interactive debugging: fetch one batch without going through
    a DataLoader.

    Returns:
        The first item produced by iterating ``data_module.train_dataset``.

    Raises:
        ValueError: If the training dataset yields nothing.
    """
    train_ds = data_module.train_dataset
    # NOTE(review): per_worker_init is presumably what a DataLoader worker
    # would normally call; here it is invoked by hand for worker 0 -- confirm.
    train_ds.per_worker_init(0)
    try:
        return next(iter(train_ds))
    except StopIteration:
        raise ValueError('data_module.train_dataset yielded no batches') from None

In [None]:
# batch = get_batch()

In [None]:
#model_output = model(batch)

In [None]:
#model_output.shape

In [None]:
#model_output.shape

In [None]:
#plot_example(batch, model_output, example_i=2);

In [None]:
# Create the Neptune logger used by the Trainer below.
# NOTE: this rebinds `logger`, which until now referred to the stdlib
# 'nowcasting_dataset' logger configured in the first cell.
# NOTE(review): no API token is passed here -- presumably it is picked up from
# the environment; confirm.
logger = NeptuneLogger(
    project='OpenClimateFix/predict-pv-yield',
    #params=params,
    #experiment_name='climatology',
    #experiment_id='PRED-1'
)

In [None]:
# Record the shared hyperparameters with the Neptune run.
logger.log_hyperparams(params)

In [None]:
# Show the logger's version so this notebook run can be matched to its Neptune entry.
print(f'logger.version = {logger.version}')

In [None]:
# Train on a single GPU and log to Neptune.
# NOTE(review): max_epochs=10_000 presumably means "run until stopped manually" -- confirm.
trainer = pl.Trainer(gpus=1, max_epochs=10_000, logger=logger)

In [None]:
# Start training; the train and validation data both come from `data_module`.
trainer.fit(model, data_module)