In [1]:
import logging
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from torch import nn

from nowcasting_dataset.datamodule import NowcastingDataModule

logging.basicConfig()
# Verbose (DEBUG) logging from the data-loading library. Bound to
# `dataset_logger` rather than `logger` because `logger` is later reused for
# the NeptuneLogger that is handed to the Trainer.
dataset_logger = logging.getLogger('nowcasting_dataset')
dataset_logger.setLevel(logging.DEBUG)

In [2]:
# Every dataset lives in one GCS bucket. These paths are bucket-relative; the
# 'gs://' scheme is prepended where the paths are passed to the data module.
BUCKET = Path('solar-pv-nowcasting-data')

# Solar PV power time series and per-system metadata.
PV_PATH = BUCKET.joinpath('PV/PVOutput.org')
PV_DATA_FILENAME = PV_PATH.joinpath('UK_PV_timeseries_batch.nc')
PV_METADATA_FILENAME = PV_PATH.joinpath('UK_PV_metadata.csv')

# Satellite imagery. Alternative store, not used here:
#   satellite/EUMETSAT/SEVIRI_RSS/OSGB36/all_zarr_int16_single_timestep_quarter_geospatial.zarr
SAT_FILENAME = BUCKET.joinpath(
    'satellite/EUMETSAT/SEVIRI_RSS/OSGB36/all_zarr_int16_single_timestep.zarr')

# Numerical weather predictions. Earlier stores, not used here:
#   NWP/UK_Met_Office/UKV_zarr
#   NWP/UK_Met_Office/UKV_single_step_and_single_timestep_all_vars.zarr
NWP_BASE_PATH = BUCKET.joinpath(
    'NWP/UK_Met_Office/UKV_single_step_and_single_timestep_all_vars_full_spatial_2018_7-12_float32.zarr')

In [66]:
# Hyperparameters shared by the data module (unpacked into its constructor)
# and the model / plotting code.
params = {
    'batch_size': 32,
    'history_len': 6,  # Number of timesteps of history, not including t0.
    'forecast_len': 12,  # Number of timesteps of forecast.
    # NWP variables to load. plot_example uses these as column labels, so it
    # assumes this order matches the channel axis of batch['nwp'].
    'nwp_channels': ('t', 'dswrf', 'prate', 'r', 'sde', 'si10', 'vis', 'lcc', 'mcc', 'hcc'),
}

In [4]:
%%time
# Build the data module from the shared `params` plus data-source locations.
# NOTE(review): pv_power_filename is a bucket-relative Path without 'gs://',
# unlike the other sources. prepare_data() below succeeds with it, so the
# library presumably resolves it itself. Confirm before changing.
data_module = NowcastingDataModule(
    **params,
    # Data sources.
    pv_power_filename=PV_DATA_FILENAME,
    pv_metadata_filename=f'gs://{PV_METADATA_FILENAME}',
    sat_filename=f'gs://{SAT_FILENAME}',
    # sat_channels=('HRV', 'WV_062', 'WV_073'),  # not passed: library default is used
    nwp_base_path=f'gs://{NWP_BASE_PATH}',
    # Passed to DataLoader. prefetch_factor is per worker, so 22 workers x 256
    # prefetched batches can hold a lot of data in memory.
    pin_memory=True,
    num_workers=22,
    prefetch_factor=256,
    # Passed to NowcastingDataset.
    n_samples_per_timestep=8,
)

CPU times: user 54 µs, sys: 8 µs, total: 62 µs
Wall time: 65.8 µs


In [5]:
%%time
# Judging by the cell output below, this opens the satellite store and loads the
# PV data, dropping PV systems the library flags as bad.
data_module.prepare_data()

DEBUG:nowcasting_dataset:Opening satellite data: gs://solar-pv-nowcasting-data/satellite/EUMETSAT/SEVIRI_RSS/OSGB36/all_zarr_int16_single_timestep.zarr


15 bad PV systems found and removed!
pv_power = 400.0 MB
CPU times: user 53.3 s, sys: 3.37 s, total: 56.7 s
Wall time: 58 s


In [6]:
%%time
# Judging by the cell output below, this opens the satellite and NWP stores.
# The echoed `unixtime = ...` lines are warnings raised inside library code.
data_module.setup()

DEBUG:nowcasting_dataset:Opening satellite data: gs://solar-pv-nowcasting-data/satellite/EUMETSAT/SEVIRI_RSS/OSGB36/all_zarr_int16_single_timestep.zarr
DEBUG:nowcasting_dataset:Opening NWP data: gs://solar-pv-nowcasting-data/NWP/UK_Met_Office/UKV_single_step_and_single_timestep_all_vars_full_spatial_2018_7-12_float32.zarr
DEBUG:nowcasting_dataset:Opening satellite data: gs://solar-pv-nowcasting-data/satellite/EUMETSAT/SEVIRI_RSS/OSGB36/all_zarr_int16_single_timestep.zarr
  unixtime = np.array(time.astype(np.int64)/10**9)
  unixtime = np.array(time.astype(np.int64)/10**9)
  unixtime = np.array(time.astype(np.int64)/10**9)
  unixtime = np.array(time.astype(np.int64)/10**9)


CPU times: user 5.47 s, sys: 121 ms, total: 5.59 s
Wall time: 6.67 s


  a = a.astype(int)


## Define a very simple ML model

In [167]:
def plot_example(batch, model_output, example_i: int=0):
    """Plot one example from a batch alongside the model's PV-yield forecast.

    Draws a 2 x 3 grid:
      * top row: satellite channel 0 at the first timestep (t = -history_len),
        at t0 (sequence index ``params['history_len']``) and at the last timestep;
      * bottom row: the NWP variables at pixel (0, 0), the four datetime
        features, and actual vs predicted PV yield.

    Args:
        batch: dict of tensors from the data module. The indexing below assumes
            ``sat_data`` is (batch, seq, width, height, channel) and ``nwp`` is
            (batch, channel, seq, width, height), matching the shape comments
            in ``LitAutoEncoder.forward``.
        model_output: predicted PV yield, one row per example. Predictions are
            aligned with the final timesteps of ``batch['pv_yield']``, which is
            the slice the loss compares them against.
        example_i: index of the example within the batch to plot.

    Returns:
        ``(fig, axes)``. The caller owns the figure and should ``plt.close(fig)``
        once it has been displayed or logged.
    """
    history_len = params['history_len']
    fig, axes = plt.subplots(ncols=3, nrows=2, figsize=(20, 20))

    # Satellite data (channel 0 only) at the start, t0 and end of the sequence.
    sat_data = batch['sat_data'][example_i, :, :, :, 0].cpu()
    for ax, timestep, title in (
        (axes[0, 0], 0, 't = -{}'.format(history_len)),
        (axes[0, 1], history_len, 't = 0'),
        (axes[0, 2], -1, 't = {}'.format(params['forecast_len'])),
    ):
        ax.imshow(sat_data[timestep])
        ax.set_title(title)

    # NWP at a single pixel. Transpose (channel, seq) -> (seq, channel) so each
    # NWP variable becomes one DataFrame column, i.e. one line.
    # No `figsize` here: when `ax` is given, pandas applies figsize to the whole
    # parent figure, which would shrink the 20x20 grid created above.
    nwp = batch['nwp'][example_i, :, :, 0, 0].T.cpu().numpy()
    pd.DataFrame(nwp, columns=params['nwp_channels']).plot(ax=axes[1, 0])
    axes[1, 0].set_title('NWP')

    # datetime features
    ax = axes[1, 1]
    ax.set_title('datetime features')
    for key in ['hour_of_day_sin', 'hour_of_day_cos', 'day_of_year_sin', 'day_of_year_cos']:
        ax.plot(batch[key][example_i].cpu(), label=key)
    ax.legend()

    # PV yield. The predictions cover the last len(prediction) timesteps, the
    # same slice the loss uses (pv_yield[:, -forecast_len:]). Deriving the x
    # positions from the tensor shapes avoids depending on the TOTAL_SEQ_LEN
    # global, which is only defined in a later cell.
    pv_yield = batch['pv_yield'][example_i].cpu()
    prediction = model_output[example_i].detach().cpu()
    seq_len = pv_yield.shape[0]
    forecast_start = seq_len - prediction.shape[0]
    ax = axes[1, 2]
    ax.set_title('PV yield')
    ax.plot(pv_yield, label='actual')
    ax.plot(range(forecast_start, seq_len), prediction, label='prediction')
    ax.legend()

    return fig, axes

In [182]:
from neptune.new.types import File

In [190]:
# Length of each example's full sequence: history + t0 + forecast.
TOTAL_SEQ_LEN = params['history_len'] + params['forecast_len'] + 1
CHANNELS = 144  # output channels of sat_conv2 / sat_conv3 (sat_conv1 uses half)
KERNEL = 3  # kernel size shared by the conv layers and the max-pool
EMBEDDING_DIM = 16  # size of the learned PV-system embedding; 0 disables it
# Flattened NWP input size: channels x width x height x timesteps. The channel
# count follows params['nwp_channels'] rather than a hard-coded 10. The 2 x 2
# spatial size must match the NWP images produced by the data module.
NWP_SIZE = len(params['nwp_channels']) * 2 * 2 * TOTAL_SEQ_LEN
# Four datetime features (hour-of-day / day-of-year, sin & cos) per timestep.
N_DATETIME_FEATURES = 4 * TOTAL_SEQ_LEN

class LitAutoEncoder(pl.LightningModule):
    """Simple PV-yield forecaster. Despite the name, this is not an autoencoder.

    Pipeline:
      1. Satellite history passes through three grouped Conv2d layers, then fc1.
      2. That output is concatenated with the PV-yield history, the flattened
         NWP data, the four datetime features and, if EMBEDDING_DIM, a learned
         PV-system embedding.
      3. The result goes through fc2..fc5 to predict ``forecast_len`` future
         PV-yield values.
    """

    def __init__(
        self,
        history_len = params['history_len'],
        forecast_len = params['forecast_len'],
    ):
        """
        Args:
            history_len: number of history timesteps (excluding t0) taken from
                the satellite and PV-yield sequences.
            forecast_len: number of future PV-yield timesteps to predict.
        """
        super().__init__()
        self.history_len = history_len
        self.forecast_len = forecast_len

        # Grouped convolutions over the satellite timesteps stacked into the
        # channel dim. in_channels=history_len assumes a single satellite
        # channel (forward() stacks seq_len * n_chans channels). TODO: generalise
        # if more sat_channels are requested from the data module.
        self.sat_conv1 = nn.Conv2d(in_channels=history_len, out_channels=CHANNELS//2, kernel_size=KERNEL, groups=history_len)
        self.sat_conv2 = nn.Conv2d(in_channels=CHANNELS//2, out_channels=CHANNELS, kernel_size=KERNEL, groups=CHANNELS//2)
        self.sat_conv3 = nn.Conv2d(in_channels=CHANNELS, out_channels=CHANNELS, kernel_size=KERNEL, groups=CHANNELS)

        self.maxpool = nn.MaxPool2d(kernel_size=KERNEL)

        # 11 x 11 is the spatial size remaining after conv/pool/conv/pool/conv.
        # It is hard-coded and therefore tied to the satellite image size
        # produced by the data module.
        self.fc1 = nn.Linear(
            in_features=CHANNELS * 11 * 11,
            out_features=256)

        # fc2 consumes everything concatenated in forward(): conv features,
        # PV-system embedding, flattened NWP, datetime features and the
        # `history_len` PV-yield history values. It uses the constructor's
        # history_len (not params['history_len']) so the size matches the
        # slice taken in forward().
        self.fc2 = nn.Linear(
            in_features=256 + EMBEDDING_DIM + NWP_SIZE + N_DATETIME_FEATURES + history_len,
            out_features=128)
        self.fc3 = nn.Linear(in_features=128, out_features=128)
        self.fc4 = nn.Linear(in_features=128, out_features=128)
        # One output per forecast timestep. This must match the target slice
        # pv_yield[:, -self.forecast_len:] used by the loss.
        self.fc5 = nn.Linear(in_features=128, out_features=forecast_len)

        if EMBEDDING_DIM:
            # NOTE: reads the number of PV systems from the global `data_module`,
            # which must exist with its PV metadata loaded before construction.
            self.pv_system_id_embedding = nn.Embedding(
                num_embeddings=len(data_module.pv_data_source.pv_metadata),
                embedding_dim=EMBEDDING_DIM)

    def forward(self, x):
        """Predict PV yield for the ``forecast_len`` timesteps after t0.

        Args:
            x: batch dict from the data module. Keys used: 'sat_data', 'nwp',
                'pv_yield', the four datetime features and, if EMBEDDING_DIM,
                'pv_system_row_number'.

        Returns:
            Tensor of shape (batch_size, forecast_len).
        """
        # ******************* Satellite imagery *************************
        # Shape: batch_size, seq_length, width, height, channel
        sat_data = x['sat_data'][:, :self.history_len]
        batch_size, seq_len, width, height, n_chans = sat_data.shape

        # Move seq_length to be the last dim, ready for changing the shape
        sat_data = sat_data.permute(0, 2, 3, 4, 1)

        # Stack timesteps into the channel dimension. Use `reshape`, not `view`:
        # after the permute the tensor is non-contiguous, and `view` fails
        # whenever there is more than one satellite channel.
        sat_data = sat_data.reshape(batch_size, width, height, seq_len * n_chans)

        sat_data = sat_data.permute(0, 3, 1, 2)  # Conv2d expects channels to be the 2nd dim!

        # Pass data through the network :)
        out = F.relu(self.sat_conv1(sat_data))
        out = self.maxpool(out)
        out = F.relu(self.sat_conv2(out))
        out = self.maxpool(out)
        out = F.relu(self.sat_conv3(out))

        # Flatten per example. Keeping the batch dim explicit (instead of
        # view(-1, CHANNELS * 11 * 11)) means an unexpected image size fails
        # loudly in fc1 rather than silently changing the batch size.
        out = out.reshape(batch_size, -1)
        out = F.relu(self.fc1(out))

        # *********************** NWP Data **************************************
        nwp_data = x['nwp'].float() # Shape: batch_size, channel, seq_length, width, height
        batch_size, n_nwp_chans, nwp_seq_len, nwp_width, nwp_height = nwp_data.shape
        nwp_data = nwp_data.reshape(batch_size, n_nwp_chans * nwp_seq_len * nwp_width * nwp_height)

        # Concat
        out = torch.cat(
            (
                out,
                x['pv_yield'][:, :self.history_len],
                nwp_data,
                x['hour_of_day_sin'],
                x['hour_of_day_cos'],
                x['day_of_year_sin'],
                x['day_of_year_cos'],
            ),
            dim=1)

        # Embedding of PV system ID
        if EMBEDDING_DIM:
            pv_embedding = self.pv_system_id_embedding(x['pv_system_row_number'])
            out = torch.cat(
                (
                    out,
                    pv_embedding
                ),
                dim=1)

        # Fully connected layers.
        out = F.relu(self.fc2(out))
        out = F.relu(self.fc3(out))
        out = F.relu(self.fc4(out))
        out = self.fc5(out)

        return out

    def _training_or_validation_step(self, batch, is_train_step, y_hat=None):
        """Compute and log the losses for one batch.

        Args:
            batch: batch dict from the data module.
            is_train_step: True for training (also logs per step), False for
                validation (epoch-level logging only).
            y_hat: optional precomputed model output for ``batch``. It is
                computed here when None.

        Returns:
            The mean absolute error, which is the loss being optimised.
        """
        if y_hat is None:
            y_hat = self(batch)
        # Targets are the final forecast_len PV-yield values (after t0).
        y = batch['pv_yield'][:, -self.forecast_len:]
        mse_loss = F.mse_loss(y_hat, y)
        # Mean absolute error. It is logged as "NMAE", presumably because
        # pv_yield is normalised upstream (TODO confirm).
        nmae_loss = (y_hat - y).abs().mean()
        # TODO: Compute correlation coef using np.corrcoef(tensor with shape (2, num_timesteps))[0, 1]
        # on each example, and taking the mean across the batch?
        tag = "Train" if is_train_step else "Validation"
        self.log_dict({f'MSE/{tag}': mse_loss}, on_step=is_train_step, on_epoch=True)
        self.log_dict({f'NMAE/{tag}': nmae_loss}, on_step=is_train_step, on_epoch=True)

        return nmae_loss

    def training_step(self, batch, batch_idx):
        """Lightning hook: return the training loss for one batch."""
        return self._training_or_validation_step(batch, is_train_step=True)

    def validation_step(self, batch, batch_idx):
        """Lightning hook: log validation losses and a plot for the first batch."""
        # Run the model once and reuse the output for both the plot and the loss.
        y_hat = self(batch)
        if batch_idx == 0:
            # Plot example
            fig, axes = plot_example(batch, y_hat)
            self.logger.experiment['validation/plot'].log(File.as_image(fig))
            # Close the figure. Otherwise every validation epoch leaks one
            # matplotlib figure for the lifetime of the kernel.
            plt.close(fig)

        return self._training_or_validation_step(batch, is_train_step=False, y_hat=y_hat)

    def configure_optimizers(self):
        """Adam over all parameters with a fixed learning rate of 1e-3."""
        optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
        return optimizer

In [191]:
#train_dl = data_module.train_dataloader()
#for batch in train_dl:
#    break

# model_output = model(batch)
# plot_example(batch, model_output);

In [192]:
# history_len / forecast_len default to the values in `params`. The constructor
# also reads the global `data_module` to size the PV-system embedding.
model = LitAutoEncoder()

In [198]:
from neptune.new.integrations.pytorch_lightning import NeptuneLogger

In [201]:
# Send metrics and validation plots to this Neptune project. No API token is
# passed here, so it presumably comes from the environment (TODO confirm).
# NOTE(review): `params` is not logged to the run; consider
# `logger.log_hyperparams(params)` for reproducibility.
logger = NeptuneLogger(project='OpenClimateFix/predict-pv-yield')

In [202]:
# Neptune run ID for this experiment (e.g. 'PRED-16' in the output below).
logger.version

https://app.neptune.ai/OpenClimateFix/predict-pv-yield/e/PRED-16
Remember to stop your run once you’ve finished logging your metadata (https://docs.neptune.ai/api-reference/run#stop). It will be stopped automatically only when the notebook kernel/interactive console is terminated.


'PRED-16'

In [203]:
# Single-GPU trainer that reports to the Neptune logger created above.
trainer = pl.Trainer(
    logger=logger,
    max_epochs=10_000,  # a very large cap; training is expected to be stopped manually
    gpus=1,
)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores


In [None]:
# Train using the data module's loaders. validation_step logs an example plot
# for the first validation batch.
trainer.fit(model, data_module)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name                   | Type      | Params
-----------------------------------------------------
0 | sat_conv1              | Conv2d    | 720   
1 | sat_conv2              | Conv2d    | 1.4 K 
2 | sat_conv3              | Conv2d    | 1.4 K 
3 | maxpool                | MaxPool2d | 0     
4 | fc1                    | Linear    | 4.5 M 
5 | fc2                    | Linear    | 142 K 
6 | fc3                    | Linear    | 16.5 K
7 | fc4                    | Linear    | 16.5 K
8 | fc5                    | Linear    | 1.5 K 
9 | pv_system_id_embedding | Embedding | 15.0 K
-----------------------------------------------------
4.7 M     Trainable params
0         Non-trainable params
4.7 M     Total params
18.627    Total estimated model params size (MB)


Validation sanity check:   0%|                                                                     | 0/2 [00:00<?, ?it/s]

DEBUG:nowcasting_dataset:Opening satellite data: gs://solar-pv-nowcasting-data/satellite/EUMETSAT/SEVIRI_RSS/OSGB36/all_zarr_int16_single_timestep.zarr
DEBUG:nowcasting_dataset:Opening satellite data: gs://solar-pv-nowcasting-data/satellite/EUMETSAT/SEVIRI_RSS/OSGB36/all_zarr_int16_single_timestep.zarr
DEBUG:nowcasting_dataset:Opening satellite data: gs://solar-pv-nowcasting-data/satellite/EUMETSAT/SEVIRI_RSS/OSGB36/all_zarr_int16_single_timestep.zarr
DEBUG:nowcasting_dataset:Opening satellite data: gs://solar-pv-nowcasting-data/satellite/EUMETSAT/SEVIRI_RSS/OSGB36/all_zarr_int16_single_timestep.zarr
DEBUG:nowcasting_dataset:Opening satellite data: gs://solar-pv-nowcasting-data/satellite/EUMETSAT/SEVIRI_RSS/OSGB36/all_zarr_int16_single_timestep.zarr
DEBUG:nowcasting_dataset:Opening satellite data: gs://solar-pv-nowcasting-data/satellite/EUMETSAT/SEVIRI_RSS/OSGB36/all_zarr_int16_single_timestep.zarr
DEBUG:nowcasting_dataset:Opening satellite data: gs://solar-pv-nowcasting-data/satellite

                                                                                                                         

DEBUG:nowcasting_dataset:Opening NWP data: gs://solar-pv-nowcasting-data/NWP/UK_Met_Office/UKV_single_step_and_single_timestep_all_vars_full_spatial_2018_7-12_float32.zarr
DEBUG:nowcasting_dataset:Opening NWP data: gs://solar-pv-nowcasting-data/NWP/UK_Met_Office/UKV_single_step_and_single_timestep_all_vars_full_spatial_2018_7-12_float32.zarr
DEBUG:nowcasting_dataset:Opening NWP data: gs://solar-pv-nowcasting-data/NWP/UK_Met_Office/UKV_single_step_and_single_timestep_all_vars_full_spatial_2018_7-12_float32.zarr
DEBUG:nowcasting_dataset:Opening NWP data: gs://solar-pv-nowcasting-data/NWP/UK_Met_Office/UKV_single_step_and_single_timestep_all_vars_full_spatial_2018_7-12_float32.zarr
DEBUG:nowcasting_dataset:Opening NWP data: gs://solar-pv-nowcasting-data/NWP/UK_Met_Office/UKV_single_step_and_single_timestep_all_vars_full_spatial_2018_7-12_float32.zarr
DEBUG:nowcasting_dataset:Opening NWP data: gs://solar-pv-nowcasting-data/NWP/UK_Met_Office/UKV_single_step_and_single_timestep_all_vars_full

Epoch 0: : 0it [00:00, ?it/s]                                                                                            

DEBUG:nowcasting_dataset:Opening satellite data: gs://solar-pv-nowcasting-data/satellite/EUMETSAT/SEVIRI_RSS/OSGB36/all_zarr_int16_single_timestep.zarr
DEBUG:nowcasting_dataset:Opening satellite data: gs://solar-pv-nowcasting-data/satellite/EUMETSAT/SEVIRI_RSS/OSGB36/all_zarr_int16_single_timestep.zarr
DEBUG:nowcasting_dataset:Opening satellite data: gs://solar-pv-nowcasting-data/satellite/EUMETSAT/SEVIRI_RSS/OSGB36/all_zarr_int16_single_timestep.zarr
DEBUG:nowcasting_dataset:Opening satellite data: gs://solar-pv-nowcasting-data/satellite/EUMETSAT/SEVIRI_RSS/OSGB36/all_zarr_int16_single_timestep.zarr
DEBUG:nowcasting_dataset:Opening satellite data: gs://solar-pv-nowcasting-data/satellite/EUMETSAT/SEVIRI_RSS/OSGB36/all_zarr_int16_single_timestep.zarr
DEBUG:nowcasting_dataset:Opening satellite data: gs://solar-pv-nowcasting-data/satellite/EUMETSAT/SEVIRI_RSS/OSGB36/all_zarr_int16_single_timestep.zarr
DEBUG:nowcasting_dataset:Opening satellite data: gs://solar-pv-nowcasting-data/satellite

Epoch 0: : 5945it [18:09,  5.46it/s, loss=0.085, v_num=D-16] 