In [1]:
from nowcasting_dataset.datamodule import NowcastingDataModule
from pathlib import Path

  rank_zero_deprecation(


In [2]:
import torch
from torch import nn
import torch.nn.functional as F
import pytorch_lightning as pl

In [3]:
# Root of the GCS bucket that holds all training data. No "gs://" prefix here;
# the prefix is added with an f-string wherever a URL is required.
BUCKET = Path('solar-pv-nowcasting-data')

# Solar PV data: generation timeseries (NetCDF) and per-system metadata (CSV).
PV_PATH = BUCKET / 'PV/PVOutput.org'
PV_DATA_FILENAME = PV_PATH.joinpath('UK_PV_timeseries_batch.nc')
PV_METADATA_FILENAME = PV_PATH.joinpath('UK_PV_metadata.csv')

# Satellite imagery (EUMETSAT SEVIRI_RSS, OSGB36 per the path), stored as a Zarr store.
SAT_FILENAME = BUCKET.joinpath('satellite/EUMETSAT/SEVIRI_RSS/OSGB36/all_zarr_int16_single_timestep_quarter_geospatial.zarr')

In [4]:
%%time
data_module = NowcastingDataModule(
    pv_power_filename=PV_DATA_FILENAME,
    pv_metadata_filename=f'gs://{PV_METADATA_FILENAME}',
    batch_size = 32,
    history_len = 0,  #: Number of timesteps of history, not including t0.
    forecast_len = 1,  #: Number of timesteps of forecast.
    sat_filename = f'gs://{SAT_FILENAME}',
    sat_channels = None, #('HRV', 'WV_062', 'WV_073'),
    pin_memory = True,  #: Passed to DataLoader.
    num_workers = 16,  #: Passed to DataLoader.
    prefetch_factor = 256,  #: Passed to DataLoader.
    n_samples_per_timestep = 2,  #: Passed to NowcastingDataset
)

CPU times: user 51 µs, sys: 7 µs, total: 58 µs
Wall time: 61.8 µs


## Define a very simple ML model

In [38]:
# Default per-channel statistics used to standardise satellite images. There are
# 12 entries, one per input channel — presumably the 12 satellite channels in the
# order the dataset returns them (TODO confirm the channel order against the data).
SAT_IMAGE_MEAN = (
    93.23458, 131.71373, 843.7779, 736.6148, 771.1189, 589.66034,
    862.29816, 927.69586, 90.70885, 107.58985, 618.4583, 532.47394,
)
SAT_IMAGE_STD = (
    115.34247, 139.92636, 36.99538, 57.366386, 30.346825,
    149.68007, 51.70631, 35.872967, 115.77212, 120.997154,
    98.57828, 99.76469,
)


def normalise_images_in_model(images, device, mean=None, std=None):
    """Standardise images channel-wise: ``(images.float() - mean) / std``.

    Args:
        images: tensor whose channel axis is the third-from-last dimension,
            e.g. ``(batch, channels, height, width)`` as fed to ``nn.Conv2d``.
            It is cast to float before normalising.
        device: device on which the statistics tensors are created; it should
            match ``images.device``.
        mean: optional per-channel means (sequence or 1-D tensor). Defaults to
            the 12-entry ``SAT_IMAGE_MEAN``. Pass matching statistics when the
            data contains a different set of channels.
        std: optional per-channel standard deviations. Defaults to the 12-entry
            ``SAT_IMAGE_STD``.

    Returns:
        A float tensor with the same shape as ``images``.
    """
    if mean is None:
        mean = SAT_IMAGE_MEAN
    if std is None:
        std = SAT_IMAGE_STD
    # Shape the statistics as (channels, 1, 1) so they broadcast over the two
    # trailing spatial dimensions (and any leading batch dimensions).
    mean = torch.as_tensor(mean, dtype=torch.float, device=device).view(-1, 1, 1)
    std = torch.as_tensor(std, dtype=torch.float, device=device).view(-1, 1, 1)

    images = images.float()
    return (images - mean) / std

In [39]:
CHANNELS = 32  # width of conv2/conv3; conv1 uses half of it
KERNEL = 3  # kernel size shared by every Conv2d and by the MaxPool2d


class LitAutoEncoder(pl.LightningModule):
    """Small CNN + MLP that regresses PV yield from one satellite frame per sample.

    Despite the name there is no decoder. Three convolution layers, with
    max-pooling in between, feed five fully-connected layers that emit a single
    value per sample. Training and validation minimise the mean absolute error
    (MAE) against ``batch['pv_yield'][:, history_len:]``.

    Args:
        history_len: index along the time axis of ``batch['sat_data']`` of the
            frame fed to the network; also the first time index used as target.
        learning_rate: Adam learning rate. The default reproduces the original
            hard-coded value of 0.001.
    """

    def __init__(
        self,
        history_len: int = 1,
        learning_rate: float = 0.001,
    ):
        super().__init__()
        self.history_len = history_len
        self.learning_rate = learning_rate

        # Convolutional encoder; expects 12 input channels, matching the 12
        # per-channel statistics used by normalise_images_in_model.
        self.encoder_conv1 = nn.Conv2d(in_channels=12, out_channels=CHANNELS // 2, kernel_size=KERNEL)
        self.encoder_conv2 = nn.Conv2d(in_channels=CHANNELS // 2, out_channels=CHANNELS, kernel_size=KERNEL)
        self.encoder_conv3 = nn.Conv2d(in_channels=CHANNELS, out_channels=CHANNELS, kernel_size=KERNEL)

        self.maxpool = nn.MaxPool2d(kernel_size=KERNEL)

        # The 11x11 spatial size is tied to the input resolution. With unpadded
        # 3x3 convs and 3x3 pooling, a 128x128 image shrinks
        # 126 -> 42 -> 40 -> 13 -> 11. Update this if the image size changes.
        self.fc1 = nn.Linear(in_features=CHANNELS * 11 * 11, out_features=256)
        self.fc2 = nn.Linear(in_features=256, out_features=128)
        self.fc3 = nn.Linear(in_features=128, out_features=128)
        self.fc4 = nn.Linear(in_features=128, out_features=128)
        self.fc5 = nn.Linear(in_features=128, out_features=1)

    def forward(self, x):
        """Predict PV yield for a batch dict; returns a tensor of shape (batch, 1)."""
        # Select the frame at time index `history_len`. Assumes sat_data is
        # (batch, time, dim_a, dim_b, channels) — TODO confirm the axis order.
        images = x['sat_data'][:, self.history_len, :, :, :]
        # Conv2d expects channels as dim 1. This permute also swaps the two
        # spatial axes: (batch, a, b, C) -> (batch, C, b, a).
        images = images.permute(0, 3, 2, 1)
        images = normalise_images_in_model(images, self.device)

        out = F.relu(self.encoder_conv1(images))
        out = self.maxpool(out)
        out = F.relu(self.encoder_conv2(out))
        out = self.maxpool(out)
        out = F.relu(self.encoder_conv3(out))

        # Flatten per sample. Unlike view(-1, CHANNELS * 11 * 11), this keeps
        # the batch dimension intact. A wrongly sized feature map now fails
        # loudly in fc1 instead of being silently folded into the batch axis.
        out = out.flatten(start_dim=1)
        out = F.relu(self.fc1(out))

        # TODO: NWP features (e.g. temperature above the PV system) could be
        # concatenated here; fc2's in_features would then have to grow to match.
        out = F.relu(self.fc2(out))
        out = F.relu(self.fc3(out))
        out = F.relu(self.fc4(out))
        return self.fc5(out)

    def _training_or_validation_step(self, batch, is_train_step):
        """Compute and log the MAE loss for one batch.

        Logged under ``MAE/Train`` (per step and per epoch) or
        ``MAE/Validation`` (per epoch only). The loss is returned so Lightning
        can backpropagate it during training.
        """
        y_hat = self(batch)
        # y_hat is (batch, 1); if y has several time steps, the prediction is
        # broadcast against each of them.
        y = batch['pv_yield'][:, self.history_len:]
        mae_loss = (y_hat - y).abs().mean()
        tag = "Train" if is_train_step else "Validation"
        self.log_dict({'MAE/' + tag: mae_loss}, on_step=is_train_step, on_epoch=True)
        return mae_loss

    def training_step(self, batch, batch_idx):
        return self._training_or_validation_step(batch, is_train_step=True)

    def validation_step(self, batch, batch_idx):
        return self._training_or_validation_step(batch, is_train_step=False)

    def configure_optimizers(self):
        """Adam over all parameters, using the configured learning rate."""
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)

In [40]:
# history_len is passed explicitly; 1 is also the class default. The data module
# above uses history_len=0 and forecast_len=1, so time index 1 is presumably the
# forecast step rather than t0. NOTE(review): confirm this offset is intended.
model = LitAutoEncoder(history_len=1)

In [41]:
# Train on a single GPU.
# NOTE(review): `gpus=` is the older PyTorch Lightning spelling; newer releases use
# `accelerator="gpu", devices=1` instead — update this when upgrading Lightning.
trainer = pl.Trainer(gpus=1)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores


In [None]:
# Train `model` on batches served by `data_module`. The data module is passed by
# keyword rather than positionally in the dataloader slot.
trainer.fit(model, datamodule=data_module)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type      | Params
--------------------------------------------
0 | encoder_conv1 | Conv2d    | 448   
1 | encoder_conv2 | Conv2d    | 4.6 K 
2 | encoder_conv3 | Conv2d    | 9.2 K 
3 | maxpool       | MaxPool2d | 0     
4 | fc1           | Linear    | 991 K 
5 | fc2           | Linear    | 32.9 K
6 | fc3           | Linear    | 16.5 K
7 | fc4           | Linear    | 16.5 K
8 | fc5           | Linear    | 129   
--------------------------------------------
1.1 M     Trainable params
0         Non-trainable params
1.1 M     Total params
4.287     Total estimated model params size (MB)


Epoch 0: : 7568it [18:32,  6.80it/s, loss=0.104, v_num=120]   
Epoch 0: : 7568it [18:32,  6.80it/s, loss=0.104, v_num=120]   
Epoch 0: : 7568it [18:33,  6.80it/s, loss=0.104, v_num=120]   
Epoch 0: : 7568it [18:33,  6.80it/s, loss=0.104, v_num=120]   
Epoch 0: : 7568it [18:33,  6.80it/s, loss=0.104, v_num=120]   
Epoch 0: : 7568it [18:33,  6.80it/s, loss=0.104, v_num=120]   
Epoch 0: : 7568it [18:33,  6.80it/s, loss=0.104, v_num=120]   
Epoch 0: : 7568it [18:33,  6.80it/s, loss=0.104, v_num=120]   
Epoch 0: : 7568it [18:33,  6.80it/s, loss=0.104, v_num=120]   
Epoch 0: : 7568it [18:33,  6.80it/s, loss=0.104, v_num=120]   
Epoch 0: : 7568it [18:33,  6.80it/s, loss=0.104, v_num=120]   
Epoch 0: : 7568it [18:33,  6.80it/s, loss=0.104, v_num=120]   
Epoch 0: : 7568it [18:33,  6.80it/s, loss=0.104, v_num=120]   
Epoch 0: : 7568it [18:33,  6.80it/s, loss=0.104, v_num=120]   
Epoch 0: : 7568it [18:33,  6.80it/s, loss=0.104, v_num=120]   
Epoch 0: : 7568it [18:33,  6.80it/s, loss=0.104, v_num=