In [1]:
from nowcasting_dataset.datamodule import NowcastingDataModule
from pathlib import Path

In [2]:
import torch
from torch import nn
import torch.nn.functional as F
import pytorch_lightning as pl

In [3]:
# Root of the GCS bucket holding all inputs. Paths below are bucket-relative;
# the 'gs://' scheme is prepended where the data module is configured.
BUCKET = Path('solar-pv-nowcasting-data')

# Solar PV data from PVOutput.org: a timeseries file (.nc) plus per-system metadata (.csv).
PV_PATH = BUCKET.joinpath('PV', 'PVOutput.org')
PV_DATA_FILENAME = PV_PATH.joinpath('UK_PV_timeseries_batch.nc')
PV_METADATA_FILENAME = PV_PATH.joinpath('UK_PV_metadata.csv')

# Satellite imagery (EUMETSAT SEVIRI RSS, OSGB36 directory) as a Zarr store.
# Alternative store in the same directory:
#   'all_zarr_int16_single_timestep_quarter_geospatial.zarr'
SAT_FILENAME = BUCKET.joinpath(
    'satellite', 'EUMETSAT', 'SEVIRI_RSS', 'OSGB36',
    'all_zarr_int16_single_timestep.zarr',
)

In [4]:
%%time
# Data module that serves batches of 32 examples built from the PV and
# satellite stores defined above ('gs://' points the loaders at GCS).
# NOTE(review): PV_DATA_FILENAME is passed without the 'gs://' prefix, unlike the
# other two paths; presumably the PV loader opens it through a GCS filesystem
# (it loads fine below) — confirm this asymmetry is intended.
data_module = NowcastingDataModule(
    pv_power_filename=PV_DATA_FILENAME,
    pv_metadata_filename=f'gs://{PV_METADATA_FILENAME}',
    batch_size = 32,
    # NOTE(review): LitAutoEncoder below defaults to history_len=1 while this is 0,
    # so the model indexes the time step after this one — confirm intended.
    history_len = 0,  #: Number of timesteps of history, not including t0.
    forecast_len = 1,  #: Number of timesteps of forecast.
    sat_filename = f'gs://{SAT_FILENAME}',
    # Presumably None selects all channels; normalise_images_in_model expects 12.
    sat_channels = None, #('HRV', 'WV_062', 'WV_073'),
    pin_memory = True,  #: Passed to DataLoader.
    num_workers = 8,  #: Passed to DataLoader.
    # DataLoader buffers up to num_workers * prefetch_factor batches (8 * 256
    # here); check memory headroom before raising either value.
    prefetch_factor = 256,  #: Passed to DataLoader.
    n_samples_per_timestep = 8,  #: Passed to NowcastingDataset
)

CPU times: user 69 µs, sys: 10 µs, total: 79 µs
Wall time: 82.3 µs


In [5]:
# Load the PV data; judging by its output, this also drops bad PV systems.
data_module.prepare_data()

15 bad PV systems found and removed!
pv_power = 400.0 MB


In [6]:
# Lightning data-module setup hook; presumably builds the train/validation
# datasets. The casting warnings it prints come from inside the library.
data_module.setup()

  unixtime = np.array(time.astype(np.int64)/10**9)
  unixtime = np.array(time.astype(np.int64)/10**9)
  unixtime = np.array(time.astype(np.int64)/10**9)
  unixtime = np.array(time.astype(np.int64)/10**9)
  a = a.astype(int)


In [7]:
# Grab the training DataLoader so one batch can be inspected by hand before training.
dl = data_module.train_dataloader()

In [8]:
%%time
# Run the dataset's per-worker initialisation in the main process (as worker 0)
# so that batches can be fetched directly below.
# NOTE(review): presumably required because the dataset defers opening its data
# until per-worker init — confirm.
dl.dataset.per_worker_init(worker_id=0)

CPU times: user 1.56 s, sys: 35.8 ms, total: 1.6 s
Wall time: 1.73 s


In [9]:
%%time
# Fetch a single batch through the dataset's private helper to time loading
# and to have a concrete batch to inspect.
batch = dl.dataset._get_batch()

CPU times: user 476 ms, sys: 56.9 ms, total: 533 ms
Wall time: 339 ms


## Define a very simple ML model

In [10]:
def normalise_images_in_model(images, device):
    """Standardise satellite imagery channel by channel using fixed statistics.

    Computes ``(images - mean) / std`` with one precomputed mean and standard
    deviation per channel (12 channels).

    Args:
        images: Tensor whose channel axis is the third-from-last dimension,
            e.g. ``(batch, channels, height, width)``, holding 12 channels.
            Any dtype is accepted; it is cast to float32 first.
        device: Device on which the statistics tensors are created; should
            match ``images.device``.

    Returns:
        A float32 tensor with the same shape as ``images``.
    """
    # Per-channel statistics, in the channel order of the satellite data.
    # NOTE(review): presumably computed offline over the training imagery — confirm provenance.
    channel_means = (
        93.23458, 131.71373, 843.7779, 736.6148, 771.1189, 589.66034,
        862.29816, 927.69586, 90.70885, 107.58985, 618.4583, 532.47394,
    )
    channel_stds = (
        115.34247, 139.92636, 36.99538, 57.366386, 30.346825,
        149.68007, 51.70631, 35.872967, 115.77212, 120.997154,
        98.57828, 99.76469,
    )

    def as_channel_column(values):
        # Shape (channels, 1, 1) so it broadcasts over the trailing spatial dims.
        return torch.tensor(values, dtype=torch.float, device=device).view(-1, 1, 1)

    centred = images.float() - as_channel_column(channel_means)
    return centred / as_channel_column(channel_stds)

In [11]:
CHANNELS = 144  # Width of the conv stack (output channels of encoder_conv2/3).
KERNEL = 3  # Kernel size shared by every conv layer and by the max-pool.


class LitAutoEncoder(pl.LightningModule):
    """Small CNN + MLP regressor: one satellite frame in, one PV yield value out.

    Despite the name there is no decoder: the network maps a single 12-channel
    satellite image to one scalar (``fc5`` has a single output). It is trained
    with mean absolute error against ``batch['pv_yield']``.

    Args:
        history_len: Index along the time axis of ``batch['sat_data']`` of the
            frame fed to the network; the target is
            ``batch['pv_yield'][:, history_len:]``.
            NOTE(review): the data module in this notebook is built with
            ``history_len=0`` while this defaults to 1, so (if t0 sits at index
            ``history_len`` of the batch) the model sees and predicts the step
            after t0 — confirm this is intended.
    """

    def __init__(
        self,
        history_len: int = 1
    ):
        super().__init__()
        self.history_len = history_len

        # Grouped convolutions: each of the 12 input channels gets its own 6
        # filters (72 outputs); conv2 gives each of those 2 filters (144);
        # conv3 is depthwise (one filter per channel).
        self.encoder_conv1 = nn.Conv2d(in_channels=12, out_channels=CHANNELS//2, kernel_size=KERNEL, groups=12)
        self.encoder_conv2 = nn.Conv2d(in_channels=CHANNELS//2, out_channels=CHANNELS, kernel_size=KERNEL, groups=CHANNELS//2)
        self.encoder_conv3 = nn.Conv2d(in_channels=CHANNELS, out_channels=CHANNELS, kernel_size=KERNEL, groups=CHANNELS)

        # MaxPool2d's stride defaults to kernel_size, so each pool downsamples ~3x.
        self.maxpool = nn.MaxPool2d(kernel_size=KERNEL)

        # 11 x 11 is the spatial size left after conv/pool/conv/pool/conv, so it
        # is tied to the input image size (a 128 x 128 input yields 11 x 11).
        # NOTE(review): presumably the satellite crops are 128 x 128 — confirm.
        self.fc1 = nn.Linear(
            in_features=CHANNELS * 11 * 11,
            out_features=256
        )
        self.fc2 = nn.Linear(in_features=256, out_features=128)
        self.fc3 = nn.Linear(in_features=128, out_features=128)
        self.fc4 = nn.Linear(in_features=128, out_features=128)
        self.fc5 = nn.Linear(in_features=128, out_features=1)

    def forward(self, x):
        """Predict one PV yield value per example.

        Args:
            x: Batch dict; only ``x['sat_data']`` is read. It must be 5-D, with
                time on axis 1 and the 12 satellite channels on the last axis.
                NOTE(review): presumably (batch, time, x, y, channels) — confirm.

        Returns:
            Tensor of shape ``(batch, 1)``.
        """
        images = x['sat_data'][:, self.history_len, :, :, :]
        # Conv2d expects channels as the 2nd dim. This permute also swaps the two
        # spatial axes relative to permute(0, 3, 1, 2). That is harmless for a
        # model trained from scratch, but must be kept the same at inference.
        images = images.permute(0, 3, 2, 1)
        images = normalise_images_in_model(images, self.device)

        # Convolutional encoder.
        out = F.relu(self.encoder_conv1(images))
        out = self.maxpool(out)
        out = F.relu(self.encoder_conv2(out))
        out = self.maxpool(out)
        out = F.relu(self.encoder_conv3(out))

        # Flatten per example. Unlike view(-1, CHANNELS * 11 * 11), this never
        # regroups features across examples when the spatial size is not 11 x 11;
        # fc1 raises a shape error instead.
        out = out.flatten(start_dim=1)
        out = F.relu(self.fc1(out))

        # TODO: concatenate extra per-example features here (e.g. the NWP
        # temperature above the PV system, previously sketched as
        # (x['nwp_above_pv'][:, 0] - 130) / 5 with a crude standardisation)
        # and widen fc2's in_features to match.

        # MLP head.
        out = F.relu(self.fc2(out))
        out = F.relu(self.fc3(out))
        out = F.relu(self.fc4(out))
        out = self.fc5(out)

        return out

    def _training_or_validation_step(self, batch, is_train_step):
        """Compute and log the mean absolute error for one batch.

        Logs ``MAE/Train`` (per step and per epoch) or ``MAE/Validation``
        (per epoch only) and returns the loss.

        Raises:
            ValueError: if prediction and target shapes differ, which would
                otherwise broadcast silently into a meaningless loss.
        """
        y_hat = self(batch)
        y = batch['pv_yield'][:, self.history_len:]
        if y_hat.shape != y.shape:
            raise ValueError(
                f"Prediction shape {tuple(y_hat.shape)} does not match target "
                f"shape {tuple(y.shape)}; check history_len / forecast_len.")
        mae_loss = (y_hat - y).abs().mean()
        tag = "Train" if is_train_step else "Validation"
        self.log_dict({'MAE/' + tag: mae_loss}, on_step=is_train_step, on_epoch=True)
        return mae_loss

    def training_step(self, batch, batch_idx):
        """Lightning training hook: MAE loss for one training batch."""
        return self._training_or_validation_step(batch, is_train_step=True)

    def validation_step(self, batch, batch_idx):
        """Lightning validation hook: MAE loss for one validation batch."""
        return self._training_or_validation_step(batch, is_train_step=False)

    def configure_optimizers(self):
        """Use Adam over all parameters with a fixed learning rate of 1e-3."""
        optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
        return optimizer

In [12]:
# Uses the default history_len=1.
# NOTE(review): the data module was built with history_len=0 — confirm which
# time index of sat_data / pv_yield the model should read.
model = LitAutoEncoder()

In [13]:
# Train on a single GPU. NOTE(review): newer PyTorch Lightning releases replace
# `gpus=` with `accelerator="gpu", devices=1`.
trainer = pl.Trainer(gpus=1)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores


In [None]:
# Train the model on batches from the data module; validation runs at the end of each epoch.
trainer.fit(model, data_module)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type      | Params
--------------------------------------------
0 | encoder_conv1 | Conv2d    | 720   
1 | encoder_conv2 | Conv2d    | 1.4 K 
2 | encoder_conv3 | Conv2d    | 1.4 K 
3 | maxpool       | MaxPool2d | 0     
4 | fc1           | Linear    | 4.5 M 
5 | fc2           | Linear    | 32.9 K
6 | fc3           | Linear    | 16.5 K
7 | fc4           | Linear    | 16.5 K
8 | fc5           | Linear    | 129   
--------------------------------------------
4.5 M     Trainable params
0         Non-trainable params
4.5 M     Total params
18.122    Total estimated model params size (MB)


Validation sanity check:   0%|                                                                  | 0/2 [00:00<?, ?it/s]

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


Epoch 0: : 0it [00:00, ?it/s]                                                                                         



Epoch 0: : 1024it [01:10, 14.43it/s, loss=0.105, v_num=147]
Validating: 0it [00:00, ?it/s][A
Validating: 0it [00:00, ?it/s][A




Epoch 0: : 1026it [01:15, 13.65it/s, loss=0.105, v_num=147]
Validating: 2it [00:04,  1.79s/it][A
Epoch 0: : 1030it [01:15, 13.60it/s, loss=0.105, v_num=147]
Epoch 0: : 1034it [01:16, 13.57it/s, loss=0.105, v_num=147]
Validating: 10it [00:05,  3.01it/s][A
Epoch 0: : 1038it [01:17, 13.43it/s, loss=0.105, v_num=147]
Epoch 0: : 1042it [01:17, 13.40it/s, loss=0.105, v_num=147]
Epoch 0: : 1046it [01:17, 13.43it/s, loss=0.105, v_num=147]
Epoch 0: : 1050it [01:18, 13.43it/s, loss=0.105, v_num=147]
Epoch 0: : 1057it [01:18, 13.46it/s, loss=0.105, v_num=147]
Epoch 1: : 0it [00:00, ?it/s, loss=0.105, v_num=147]       



Epoch 1: : 1015it [01:09, 14.51it/s, loss=0.0904, v_num=147]



Epoch 1: : 1024it [01:10, 14.54it/s, loss=0.0975, v_num=147]
Validating: 0it [00:00, ?it/s][A
Validating: 0it [00:00, ?it/s][A




Validating: 1it [00:04,  4.18s/it][A
Epoch 1: : 1029it [01:15, 13.64it/s, loss=0.0975, v_num=147]
Validating: 8it [00:05,  2.28it/s][A
Epoch 1: : 1036it [01:16, 13.62it/s, loss=0.0975, v_num=147]




Epoch 1: : 1043it [01:16, 13.62it/s, loss=0.0975, v_num=147]




Epoch 1: : 1050it [01:16, 13.64it/s, loss=0.0975, v_num=147]
Epoch 1: : 1057it [01:17, 13.67it/s, loss=0.0975, v_num=147]
Epoch 2: : 0it [00:00, ?it/s, loss=0.0975, v_num=147]       



Epoch 2: : 1000it [01:09, 14.36it/s, loss=0.093, v_num=147]



Epoch 2: : 1024it [01:10, 14.51it/s, loss=0.0826, v_num=147]
Validating: 0it [00:00, ?it/s][A
Validating: 0it [00:00, ?it/s][A




Validating: 1it [00:04,  4.08s/it][A
Epoch 2: : 1029it [01:15, 13.69it/s, loss=0.0826, v_num=147]
Validating: 5it [00:04,  1.79it/s][A
Epoch 2: : 1036it [01:16, 13.62it/s, loss=0.0826, v_num=147]
Epoch 2: : 1043it [01:16, 13.62it/s, loss=0.0826, v_num=147]
Validating: 19it [00:05,  6.81it/s][A
Epoch 2: : 1050it [01:17, 13.64it/s, loss=0.0826, v_num=147]
Epoch 2: : 1057it [01:17, 13.67it/s, loss=0.0826, v_num=147]
Epoch 3: : 0it [00:00, ?it/s, loss=0.0826, v_num=147]       



Epoch 3: : 1003it [01:09, 14.42it/s, loss=0.102, v_num=147]



Epoch 3: : 1024it [01:10, 14.54it/s, loss=0.092, v_num=147] 
Validating: 0it [00:00, ?it/s][A
Validating: 0it [00:00, ?it/s][A




Epoch 3: : 1029it [01:15, 13.66it/s, loss=0.092, v_num=147]
Epoch 3: : 1036it [01:15, 13.65it/s, loss=0.092, v_num=147]
Validating: 13it [00:05,  3.56it/s][A
Epoch 3: : 1043it [01:16, 13.65it/s, loss=0.092, v_num=147]




Epoch 3: : 1057it [01:16, 13.74it/s, loss=0.092, v_num=147]
Epoch 4: : 0it [00:00, ?it/s, loss=0.092, v_num=147]       



Epoch 4: : 1023it [01:11, 14.36it/s, loss=0.0937, v_num=147]



Epoch 4: : 1024it [01:11, 14.35it/s, loss=0.0911, v_num=147]
Validating: 0it [00:00, ?it/s][A
Validating: 0it [00:00, ?it/s][A




Validating: 1it [00:03,  3.91s/it][A
Epoch 4: : 1029it [01:15, 13.58it/s, loss=0.0911, v_num=147]
Validating: 5it [00:04,  1.59it/s][A
Epoch 4: : 1036it [01:16, 13.58it/s, loss=0.0911, v_num=147]
Epoch 4: : 1043it [01:16, 13.57it/s, loss=0.0911, v_num=147]
Validating: 19it [00:05,  8.63it/s][A
Epoch 4: : 1050it [01:17, 13.57it/s, loss=0.0911, v_num=147]
Epoch 4: : 1057it [01:17, 13.59it/s, loss=0.0911, v_num=147]
Epoch 5: : 0it [00:00, ?it/s, loss=0.0911, v_num=147]       



Epoch 5: : 134it [00:14,  9.31it/s, loss=0.0907, v_num=147]