In [1]:
from nowcasting_dataset.datamodule import NowcastingDataModule
from pathlib import Path

import torch
from torch import nn
import torch.nn.functional as F
import pytorch_lightning as pl

import logging
# Install the default root handler, then turn on DEBUG-level messages for the
# 'nowcasting_dataset' logger only (this is what produces the
# "DEBUG:nowcasting_dataset:..." lines in the cell outputs below).
logging.basicConfig()
logger = logging.getLogger('nowcasting_dataset')
logger.setLevel(logging.DEBUG)

In [2]:
# Bucket-relative locations of every input dataset.
# These are kept WITHOUT a protocol prefix; 'gs://' is prepended at the point
# of use where required.
BUCKET = Path('solar-pv-nowcasting-data')

# --- Solar PV data ---------------------------------------------------------
PV_PATH = BUCKET.joinpath('PV', 'PVOutput.org')
PV_DATA_FILENAME = PV_PATH.joinpath('UK_PV_timeseries_batch.nc')
PV_METADATA_FILENAME = PV_PATH.joinpath('UK_PV_metadata.csv')

# --- Satellite data --------------------------------------------------------
# Alternative (currently unused) store in the same directory:
#   all_zarr_int16_single_timestep_quarter_geospatial.zarr
SAT_FILENAME = BUCKET.joinpath(
    'satellite', 'EUMETSAT', 'SEVIRI_RSS', 'OSGB36',
    'all_zarr_int16_single_timestep.zarr',
)

# --- Numerical weather predictions -----------------------------------------
# Alternatives (currently unused) under NWP/UK_Met_Office/:
#   UKV_zarr
#   UKV_single_step_and_single_timestep_all_vars.zarr
NWP_BASE_PATH = BUCKET.joinpath(
    'NWP', 'UK_Met_Office',
    'UKV_single_step_and_single_timestep_all_vars_2018_7-12_float32.zarr',
)

In [3]:
%%time
# Build the data module. Arguments are grouped by purpose; the values are
# what this experiment was run with.
data_module = NowcastingDataModule(
    # --- Data sources ---
    # NOTE(review): unlike the other paths, the PV power file is passed as a
    # bare Path without the 'gs://' prefix - confirm this is what the data
    # module expects.
    pv_power_filename=PV_DATA_FILENAME,
    pv_metadata_filename=f'gs://{PV_METADATA_FILENAME}',
    sat_filename=f'gs://{SAT_FILENAME}',
    # sat_channels=('HRV', 'WV_062', 'WV_073'),
    nwp_base_path=f'gs://{NWP_BASE_PATH}',
    # --- Example layout ---
    batch_size=32,
    history_len=0,  #: Number of timesteps of history, not including t0.
    forecast_len=1,  #: Number of timesteps of forecast.
    n_samples_per_timestep=8,  #: Passed to NowcastingDataset
    # --- DataLoader options (all passed to DataLoader) ---
    pin_memory=True,
    num_workers=12,
    prefetch_factor=256,
)

CPU times: user 53 µs, sys: 9 µs, total: 62 µs
Wall time: 65.8 µs


In [4]:
%%time
# One-off data preparation. In the recorded run below this opened the
# satellite data, removed bad PV systems and reported the size of pv_power.
data_module.prepare_data()

DEBUG:nowcasting_dataset:Opening satellite data: gs://solar-pv-nowcasting-data/satellite/EUMETSAT/SEVIRI_RSS/OSGB36/all_zarr_int16_single_timestep.zarr


15 bad PV systems found and removed!
pv_power = 400.0 MB
CPU times: user 55 s, sys: 3.5 s, total: 58.5 s
Wall time: 59 s


In [5]:
%%time
# Set up the datasets. In the recorded run below this opened the satellite
# and NWP Zarr stores.
data_module.setup()

DEBUG:nowcasting_dataset:Opening satellite data: gs://solar-pv-nowcasting-data/satellite/EUMETSAT/SEVIRI_RSS/OSGB36/all_zarr_int16_single_timestep.zarr
DEBUG:nowcasting_dataset:Opening NWP data: gs://solar-pv-nowcasting-data/NWP/UK_Met_Office/UKV_single_step_and_single_timestep_all_vars_2018_7-12_float32.zarr
DEBUG:nowcasting_dataset:Opening satellite data: gs://solar-pv-nowcasting-data/satellite/EUMETSAT/SEVIRI_RSS/OSGB36/all_zarr_int16_single_timestep.zarr
  unixtime = np.array(time.astype(np.int64)/10**9)
  unixtime = np.array(time.astype(np.int64)/10**9)
  unixtime = np.array(time.astype(np.int64)/10**9)
  unixtime = np.array(time.astype(np.int64)/10**9)


CPU times: user 5.79 s, sys: 150 ms, total: 5.94 s
Wall time: 6.89 s


  a = a.astype(int)


## Define a very simple ML model

In [6]:
def normalise_images_in_model(images, device):
    """Standardise satellite images channel-by-channel: (x - mean) / std.

    Twelve hard-coded per-channel means and standard deviations are used.
    They are shaped (12, 1, 1), so dimension -3 of `images` must be
    broadcastable against 12 (i.e. channels come before the two spatial
    dimensions).

    Args:
        images: Tensor of any numeric dtype; it is cast to float32 first.
        device: Device on which the statistics tensors are created. It must
            be compatible with the device `images` lives on.

    Returns:
        A float32 tensor with the broadcast shape of `images` and (12, 1, 1).
    """
    channel_means = [
        93.23458, 131.71373, 843.7779, 736.6148, 771.1189, 589.66034,
        862.29816, 927.69586, 90.70885, 107.58985, 618.4583, 532.47394,
    ]
    channel_stds = [
        115.34247, 139.92636, 36.99538, 57.366386, 30.346825, 149.68007,
        51.70631, 35.872967, 115.77212, 120.997154, 98.57828, 99.76469,
    ]

    # Trailing singleton dims let the statistics broadcast over the two
    # spatial dimensions.
    mean = torch.tensor(channel_means, dtype=torch.float, device=device).view(-1, 1, 1)
    std = torch.tensor(channel_stds, dtype=torch.float, device=device).view(-1, 1, 1)

    return (images.float() - mean) / std

In [7]:
# Hyper-parameters read by LitAutoEncoder below.
CHANNELS = 144      # Output channels of the 2nd and 3rd satellite conv layers (the 1st uses CHANNELS // 2).
KERNEL = 3          # Kernel size shared by the conv layers and the max-pool.
EMBEDDING_DIM = 16  # Width of the PV-system-ID embedding; a falsy value disables the embedding.

# Number of NWP features per example once flattened: channels x width x height.
NWP_CHANNELS, NWP_WIDTH, NWP_HEIGHT = 10, 2, 2
NWP_SIZE = NWP_CHANNELS * NWP_WIDTH * NWP_HEIGHT


class LitAutoEncoder(pl.LightningModule):
    """Small CNN + MLP that regresses PV yield for a single timestep.

    Despite the name, this is not an autoencoder: it maps one satellite
    image, one timestep of NWP data and a learned per-PV-system embedding to
    a single scalar prediction, and is trained with mean absolute error.

    Batches are dicts with the keys 'sat_data', 'nwp',
    'pv_system_row_number' and 'pv_yield'.

    Architecture:
        * Three grouped 3x3 convolutions over the 12 satellite channels
          (groups == in_channels in every layer, so channels are never mixed
          by the convolutions), with max-pooling after the first two.
        * fc1 compresses the flattened feature map to 256 features.
        * Those features are concatenated with the flattened NWP data and the
          PV-system embedding, then passed through fc2..fc5.

    fc1 is sized for an 11x11 final feature map. With unpadded 3x3
    convolutions and 3x3 max-pooling, a 128x128 input produces exactly that
    (128 -> 126 -> 42 -> 40 -> 13 -> 11).
    NOTE(review): the actual satellite image size is not visible here -
    confirm it is 128x128 (or another size that reduces to 11x11).

    Args:
        history_len: Index along the time dimension of the timestep fed to
            the model, and the first timestep of 'pv_yield' used as target.
            NOTE(review): the default (1) differs from the `history_len=0`
            the data module in this notebook is built with - confirm that
            the selected timestep is the intended one.
        num_pv_systems: Number of rows in the PV-system embedding table. If
            None (the default), falls back to
            `len(data_module.pv_data_source.pv_metadata)`, read from the
            notebook-global `data_module`, which must then already exist.
    """

    def __init__(
        self,
        history_len: int=1,
        num_pv_systems=None,
    ):
        super().__init__()
        self.history_len = history_len

        # Grouped convolutions: each input channel gets its own filters.
        self.sat_conv1 = nn.Conv2d(in_channels=12, out_channels=CHANNELS//2, kernel_size=KERNEL, groups=12)
        self.sat_conv2 = nn.Conv2d(in_channels=CHANNELS//2, out_channels=CHANNELS, kernel_size=KERNEL, groups=CHANNELS//2)
        self.sat_conv3 = nn.Conv2d(in_channels=CHANNELS, out_channels=CHANNELS, kernel_size=KERNEL, groups=CHANNELS)

        self.maxpool = nn.MaxPool2d(kernel_size=KERNEL)

        # 11 x 11 is the spatial size expected after the last convolution
        # (see the class docstring).
        self.fc1 = nn.Linear(
            in_features=CHANNELS * 11 * 11,
            out_features=256)

        # fc2 consumes: satellite features + PV-system embedding + flat NWP.
        self.fc2 = nn.Linear(in_features=256 + EMBEDDING_DIM + NWP_SIZE, out_features=128)
        self.fc3 = nn.Linear(in_features=128, out_features=128)
        self.fc4 = nn.Linear(in_features=128, out_features=128)
        self.fc5 = nn.Linear(in_features=128, out_features=1)

        if EMBEDDING_DIM:
            if num_pv_systems is None:
                # Backward-compatible fallback to the notebook-global data
                # module. Prefer passing `num_pv_systems` explicitly so the
                # model does not depend on hidden notebook state.
                num_pv_systems = len(data_module.pv_data_source.pv_metadata)
            self.pv_system_id_embedding = nn.Embedding(
                num_embeddings=num_pv_systems,
                embedding_dim=EMBEDDING_DIM
            )

    def forward(self, x):
        """Predict one value per example.

        Args:
            x: Batch dict. Reads 'sat_data', 'nwp' and (when EMBEDDING_DIM is
                truthy) 'pv_system_row_number'.

        Returns:
            Tensor of shape (batch_size, 1).
        """
        # Select a single timestep, then swap dims 1 and 3 so that channels
        # become the 2nd dim, which is what Conv2d expects.
        # (assumes channels are the last dim of 'sat_data' - TODO confirm)
        sat_data = x['sat_data'][:, self.history_len]
        sat_data = sat_data.permute(0, 3, 2, 1)
        sat_data = normalise_images_in_model(sat_data, self.device)

        # Convolutional feature extractor.
        out = F.relu(self.sat_conv1(sat_data))
        out = self.maxpool(out)
        out = F.relu(self.sat_conv2(out))
        out = self.maxpool(out)
        out = F.relu(self.sat_conv3(out))

        # Flatten everything except the batch dim. Unlike
        # `view(-1, CHANNELS * 11 * 11)`, this can never silently fold a
        # spatial-size mismatch into the batch dimension; a wrong input size
        # makes fc1 raise a clear shape error instead.
        out = out.flatten(start_dim=1)
        out = F.relu(self.fc1(out))

        # 'nwp' shape: batch_size, channel, seq_length, width, height.
        # Selecting one timestep leaves (batch_size, channel, width, height).
        nwp_data = x['nwp'][:, :, self.history_len].float()
        batch_size, n_nwp_chans, nwp_width, nwp_height = nwp_data.shape
        nwp_data = nwp_data.reshape(batch_size, n_nwp_chans * nwp_width * nwp_height)
        out = torch.cat((out, nwp_data), dim=1)

        if EMBEDDING_DIM:
            # assumes 'pv_system_row_number' is a 1-D integer tensor of
            # length batch_size, so the lookup yields
            # (batch_size, EMBEDDING_DIM) - TODO confirm
            pv_embedding = self.pv_system_id_embedding(x['pv_system_row_number'])
            out = torch.cat((out, pv_embedding), dim=1)

        out = F.relu(self.fc2(out))
        out = F.relu(self.fc3(out))
        out = F.relu(self.fc4(out))
        out = self.fc5(out)

        return out

    def _training_or_validation_step(self, batch, is_train_step):
        """Compute and log the mean absolute error for one batch.

        Args:
            batch: Batch dict; 'pv_yield' supplies the targets.
            is_train_step: True for training (logged per step and per epoch
                under 'MAE/Train'), False for validation (logged per epoch
                only, under 'MAE/Validation').

        Returns:
            The scalar MAE loss tensor.
        """
        y_hat = self(batch)
        # Targets start at index `history_len` along the time dimension.
        # NOTE(review): y_hat is (batch_size, 1); if this slice keeps more
        # than one timestep the subtraction below broadcasts silently.
        y = batch['pv_yield'][:, self.history_len:]
        mae_loss = (y_hat - y).abs().mean()
        tag = "Train" if is_train_step else "Validation"
        self.log_dict({'MAE/' + tag: mae_loss}, on_step=is_train_step, on_epoch=True)
        return mae_loss

    def training_step(self, batch, batch_idx):
        """Lightning hook: return the training loss for `batch`."""
        return self._training_or_validation_step(batch, is_train_step=True)

    def validation_step(self, batch, batch_idx):
        """Lightning hook: compute and log the validation loss for `batch`."""
        return self._training_or_validation_step(batch, is_train_step=False)

    def configure_optimizers(self):
        """Lightning hook: Adam over all parameters with lr=0.001."""
        optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
        return optimizer

In [8]:
# Instantiate with the default history_len=1.
# NOTE(review): the data module above was built with history_len=0 - confirm
# the default selects the intended timestep.
model = LitAutoEncoder()

In [9]:
# Train on one GPU. max_epochs is very large, so presumably the run is meant
# to be stopped manually rather than to finish on its own.
trainer = pl.Trainer(gpus=1, max_epochs=10_000)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores


In [None]:
# Start training; the data module supplies the training and validation data.
trainer.fit(model, data_module)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name                   | Type      | Params
-----------------------------------------------------
0 | sat_conv1              | Conv2d    | 720   
1 | sat_conv2              | Conv2d    | 1.4 K 
2 | sat_conv3              | Conv2d    | 1.4 K 
3 | maxpool                | MaxPool2d | 0     
4 | fc1                    | Linear    | 4.5 M 
5 | fc2                    | Linear    | 40.1 K
6 | fc3                    | Linear    | 16.5 K
7 | fc4                    | Linear    | 16.5 K
8 | fc5                    | Linear    | 129   
9 | pv_system_id_embedding | Embedding | 15.0 K
-----------------------------------------------------
4.6 M     Trainable params
0         Non-trainable params
4.6 M     Total params
18.211    Total estimated model params size (MB)


Validation sanity check:   0%|                                                                  | 0/2 [00:00<?, ?it/s]

DEBUG:nowcasting_dataset:Opening satellite data: gs://solar-pv-nowcasting-data/satellite/EUMETSAT/SEVIRI_RSS/OSGB36/all_zarr_int16_single_timestep.zarr
DEBUG:nowcasting_dataset:Opening satellite data: gs://solar-pv-nowcasting-data/satellite/EUMETSAT/SEVIRI_RSS/OSGB36/all_zarr_int16_single_timestep.zarr
DEBUG:nowcasting_dataset:Opening satellite data: gs://solar-pv-nowcasting-data/satellite/EUMETSAT/SEVIRI_RSS/OSGB36/all_zarr_int16_single_timestep.zarr
DEBUG:nowcasting_dataset:Opening satellite data: gs://solar-pv-nowcasting-data/satellite/EUMETSAT/SEVIRI_RSS/OSGB36/all_zarr_int16_single_timestep.zarr
DEBUG:nowcasting_dataset:Opening satellite data: gs://solar-pv-nowcasting-data/satellite/EUMETSAT/SEVIRI_RSS/OSGB36/all_zarr_int16_single_timestep.zarr
DEBUG:nowcasting_dataset:Opening satellite data: gs://solar-pv-nowcasting-data/satellite/EUMETSAT/SEVIRI_RSS/OSGB36/all_zarr_int16_single_timestep.zarr
DEBUG:nowcasting_dataset:Opening satellite data: gs://solar-pv-nowcasting-data/satellite

Epoch 0: : 0it [00:00, ?it/s]                                                                                         

DEBUG:nowcasting_dataset:Opening satellite data: gs://solar-pv-nowcasting-data/satellite/EUMETSAT/SEVIRI_RSS/OSGB36/all_zarr_int16_single_timestep.zarr
DEBUG:nowcasting_dataset:Opening satellite data: gs://solar-pv-nowcasting-data/satellite/EUMETSAT/SEVIRI_RSS/OSGB36/all_zarr_int16_single_timestep.zarr
DEBUG:nowcasting_dataset:Opening satellite data: gs://solar-pv-nowcasting-data/satellite/EUMETSAT/SEVIRI_RSS/OSGB36/all_zarr_int16_single_timestep.zarr
DEBUG:nowcasting_dataset:Opening satellite data: gs://solar-pv-nowcasting-data/satellite/EUMETSAT/SEVIRI_RSS/OSGB36/all_zarr_int16_single_timestep.zarr
DEBUG:nowcasting_dataset:Opening satellite data: gs://solar-pv-nowcasting-data/satellite/EUMETSAT/SEVIRI_RSS/OSGB36/all_zarr_int16_single_timestep.zarr
DEBUG:nowcasting_dataset:Opening satellite data: gs://solar-pv-nowcasting-data/satellite/EUMETSAT/SEVIRI_RSS/OSGB36/all_zarr_int16_single_timestep.zarr
DEBUG:nowcasting_dataset:Opening satellite data: gs://solar-pv-nowcasting-data/satellite

Epoch 0: : 408it [01:06,  6.17it/s, loss=0.118, v_num=205] 

ERROR:gcsfs:_request non-retriable exception: <html><title>Error 400 (Bad Request)!!1</title></html>, 400
Traceback (most recent call last):
  File "/home/jack/miniconda3/envs/nowcasting_dataset/lib/python3.9/site-packages/gcsfs/retry.py", line 110, in retry_request
    return await func(*args, **kwargs)
  File "/home/jack/miniconda3/envs/nowcasting_dataset/lib/python3.9/site-packages/gcsfs/core.py", line 332, in _request
    validate_response(status, contents, path)
  File "/home/jack/miniconda3/envs/nowcasting_dataset/lib/python3.9/site-packages/gcsfs/retry.py", line 99, in validate_response
    raise HttpError({"code": status, "message": msg})  # text-like
gcsfs.retry.HttpError: <html><title>Error 400 (Bad Request)!!1</title></html>, 400


Epoch 0: : 1020it [02:30,  6.77it/s, loss=0.0961, v_num=205]
Validating: 0it [00:00, ?it/s][A
Validating: 0it [00:00, ?it/s][A

DEBUG:nowcasting_dataset:Opening satellite data: gs://solar-pv-nowcasting-data/satellite/EUMETSAT/SEVIRI_RSS/OSGB36/all_zarr_int16_single_timestep.zarr
DEBUG:nowcasting_dataset:Opening satellite data: gs://solar-pv-nowcasting-data/satellite/EUMETSAT/SEVIRI_RSS/OSGB36/all_zarr_int16_single_timestep.zarr
DEBUG:nowcasting_dataset:Opening satellite data: gs://solar-pv-nowcasting-data/satellite/EUMETSAT/SEVIRI_RSS/OSGB36/all_zarr_int16_single_timestep.zarr
DEBUG:nowcasting_dataset:Opening satellite data: gs://solar-pv-nowcasting-data/satellite/EUMETSAT/SEVIRI_RSS/OSGB36/all_zarr_int16_single_timestep.zarr
DEBUG:nowcasting_dataset:Opening satellite data: gs://solar-pv-nowcasting-data/satellite/EUMETSAT/SEVIRI_RSS/OSGB36/all_zarr_int16_single_timestep.zarr
DEBUG:nowcasting_dataset:Opening satellite data: gs://solar-pv-nowcasting-data/satellite/EUMETSAT/SEVIRI_RSS/OSGB36/all_zarr_int16_single_timestep.zarr
DEBUG:nowcasting_dataset:Opening satellite data: gs://solar-pv-nowcasting-data/satellite


Epoch 0: : 1023it [02:39,  6.43it/s, loss=0.0961, v_num=205]
Epoch 0: : 1028it [02:39,  6.44it/s, loss=0.0961, v_num=205]
Epoch 0: : 1033it [02:40,  6.45it/s, loss=0.0961, v_num=205]
Validating: 13it [00:09,  2.34it/s][A
Epoch 0: : 1038it [02:41,  6.44it/s, loss=0.0961, v_num=205]
Epoch 0: : 1045it [02:41,  6.47it/s, loss=0.0961, v_num=205]
Epoch 1: : 0it [00:00, ?it/s, loss=0.0961, v_num=205]       

DEBUG:nowcasting_dataset:Opening satellite data: gs://solar-pv-nowcasting-data/satellite/EUMETSAT/SEVIRI_RSS/OSGB36/all_zarr_int16_single_timestep.zarr
DEBUG:nowcasting_dataset:Opening satellite data: gs://solar-pv-nowcasting-data/satellite/EUMETSAT/SEVIRI_RSS/OSGB36/all_zarr_int16_single_timestep.zarr
DEBUG:nowcasting_dataset:Opening satellite data: gs://solar-pv-nowcasting-data/satellite/EUMETSAT/SEVIRI_RSS/OSGB36/all_zarr_int16_single_timestep.zarr
DEBUG:nowcasting_dataset:Opening satellite data: gs://solar-pv-nowcasting-data/satellite/EUMETSAT/SEVIRI_RSS/OSGB36/all_zarr_int16_single_timestep.zarr
DEBUG:nowcasting_dataset:Opening satellite data: gs://solar-pv-nowcasting-data/satellite/EUMETSAT/SEVIRI_RSS/OSGB36/all_zarr_int16_single_timestep.zarr
DEBUG:nowcasting_dataset:Opening satellite data: gs://solar-pv-nowcasting-data/satellite/EUMETSAT/SEVIRI_RSS/OSGB36/all_zarr_int16_single_timestep.zarr
DEBUG:nowcasting_dataset:Opening satellite data: gs://solar-pv-nowcasting-data/satellite

Epoch 1: : 1020it [02:30,  6.79it/s, loss=0.103, v_num=205] 
Validating: 0it [00:00, ?it/s][A
Validating: 0it [00:00, ?it/s][A

DEBUG:nowcasting_dataset:Opening satellite data: gs://solar-pv-nowcasting-data/satellite/EUMETSAT/SEVIRI_RSS/OSGB36/all_zarr_int16_single_timestep.zarr
DEBUG:nowcasting_dataset:Opening satellite data: gs://solar-pv-nowcasting-data/satellite/EUMETSAT/SEVIRI_RSS/OSGB36/all_zarr_int16_single_timestep.zarr
DEBUG:nowcasting_dataset:Opening satellite data: gs://solar-pv-nowcasting-data/satellite/EUMETSAT/SEVIRI_RSS/OSGB36/all_zarr_int16_single_timestep.zarr
DEBUG:nowcasting_dataset:Opening satellite data: gs://solar-pv-nowcasting-data/satellite/EUMETSAT/SEVIRI_RSS/OSGB36/all_zarr_int16_single_timestep.zarr
DEBUG:nowcasting_dataset:Opening satellite data: gs://solar-pv-nowcasting-data/satellite/EUMETSAT/SEVIRI_RSS/OSGB36/all_zarr_int16_single_timestep.zarr
DEBUG:nowcasting_dataset:Opening satellite data: gs://solar-pv-nowcasting-data/satellite/EUMETSAT/SEVIRI_RSS/OSGB36/all_zarr_int16_single_timestep.zarr
DEBUG:nowcasting_dataset:Opening satellite data: gs://solar-pv-nowcasting-data/satellite


Validating: 1it [00:08,  8.33s/it][A
Epoch 1: : 1025it [02:38,  6.45it/s, loss=0.103, v_num=205]
Epoch 1: : 1030it [02:40,  6.43it/s, loss=0.103, v_num=205]
Epoch 1: : 1039it [02:41,  6.45it/s, loss=0.103, v_num=205]
Epoch 1: : 1045it [02:41,  6.48it/s, loss=0.103, v_num=205]
Epoch 2: : 0it [00:00, ?it/s, loss=0.103, v_num=205]       

DEBUG:nowcasting_dataset:Opening satellite data: gs://solar-pv-nowcasting-data/satellite/EUMETSAT/SEVIRI_RSS/OSGB36/all_zarr_int16_single_timestep.zarr
DEBUG:nowcasting_dataset:Opening satellite data: gs://solar-pv-nowcasting-data/satellite/EUMETSAT/SEVIRI_RSS/OSGB36/all_zarr_int16_single_timestep.zarr
DEBUG:nowcasting_dataset:Opening satellite data: gs://solar-pv-nowcasting-data/satellite/EUMETSAT/SEVIRI_RSS/OSGB36/all_zarr_int16_single_timestep.zarr
DEBUG:nowcasting_dataset:Opening satellite data: gs://solar-pv-nowcasting-data/satellite/EUMETSAT/SEVIRI_RSS/OSGB36/all_zarr_int16_single_timestep.zarr
DEBUG:nowcasting_dataset:Opening satellite data: gs://solar-pv-nowcasting-data/satellite/EUMETSAT/SEVIRI_RSS/OSGB36/all_zarr_int16_single_timestep.zarr
DEBUG:nowcasting_dataset:Opening satellite data: gs://solar-pv-nowcasting-data/satellite/EUMETSAT/SEVIRI_RSS/OSGB36/all_zarr_int16_single_timestep.zarr
DEBUG:nowcasting_dataset:Opening satellite data: gs://solar-pv-nowcasting-data/satellite

Epoch 2: : 190it [00:39,  4.77it/s, loss=0.101, v_num=205] 