In [1]:
from nowcasting_dataset.datamodule import NowcastingDataModule
from pathlib import Path

In [2]:
import torch
from torch import nn
import torch.nn.functional as F
import pytorch_lightning as pl

In [3]:
# Bucket that holds every dataset used in this notebook. The 'gs://' scheme is
# only prepended later, where these paths are handed to the data module.
BUCKET = Path('solar-pv-nowcasting-data')

# Solar PV data: time series plus per-system metadata.
PV_PATH = BUCKET.joinpath('PV/PVOutput.org')
PV_DATA_FILENAME = PV_PATH.joinpath('UK_PV_timeseries_batch.nc')
PV_METADATA_FILENAME = PV_PATH.joinpath('UK_PV_metadata.csv')

# Satellite imagery. Alternative store, kept for reference:
# SAT_FILENAME = BUCKET / 'satellite/EUMETSAT/SEVIRI_RSS/OSGB36/all_zarr_int16_single_timestep_quarter_geospatial.zarr'
SAT_FILENAME = BUCKET.joinpath('satellite/EUMETSAT/SEVIRI_RSS/OSGB36/all_zarr_int16_single_timestep.zarr')

# Numerical weather predictions.
NWP_BASE_PATH = BUCKET.joinpath('NWP/UK_Met_Office/UKV_zarr')

In [4]:
%%time
data_module = NowcastingDataModule(
    # PV inputs are currently switched off:
    # pv_power_filename=PV_DATA_FILENAME,
    # pv_metadata_filename=f'gs://{PV_METADATA_FILENAME}',
    batch_size=32,
    # Number of timesteps of history, not including t0.
    history_len=0,
    # Number of timesteps of forecast.
    forecast_len=1,
    sat_filename=f'gs://{SAT_FILENAME}',
    # sat_channels=('HRV', 'WV_062', 'WV_073'),
    nwp_base_path=f'gs://{NWP_BASE_PATH}',
    # The next three are passed to DataLoader.
    pin_memory=True,
    num_workers=8,
    prefetch_factor=256,
    # Passed to NowcastingDataset.
    n_samples_per_timestep=8,
)

CPU times: user 59 µs, sys: 10 µs, total: 69 µs
Wall time: 72.7 µs


In [5]:
%%time
# Run the data module's preparation step. In the recorded run below this
# returned almost immediately (~139 µs wall time).
data_module.prepare_data()

CPU times: user 117 µs, sys: 19 µs, total: 136 µs
Wall time: 139 µs


In [6]:
%%time
# Run the data module's setup step. In the recorded run below this took
# ~18.6 s of wall time and produced the messages shown in the cell output.
data_module.setup()

  unixtime = np.array(time.astype(np.int64)/10**9)
  unixtime = np.array(time.astype(np.int64)/10**9)
  unixtime = np.array(time.astype(np.int64)/10**9)
  unixtime = np.array(time.astype(np.int64)/10**9)


CPU times: user 16.2 s, sys: 506 ms, total: 16.7 s
Wall time: 18.6 s


  a = a.astype(int)


## Define very simple ML model

In [7]:
def normalise_images_in_model(images, device):
    """Standardise satellite images channel by channel: ``(images - mean) / std``.

    Twelve hard-coded per-channel means and standard deviations are reshaped
    to ``(12, 1, 1)`` and broadcast over the last two axes of ``images``, so
    the third-from-last axis of ``images`` is the channel axis.

    Args:
        images: Image tensor with the channel axis third from last
            (the caller in this notebook passes batch, channel, then two
            spatial axes). Converted to float before normalising.
        device: Device on which the constant tensors are created; it must
            match the device of ``images``.

    Returns:
        A float tensor with the same shape as ``images``.
    """
    channel_mean = torch.tensor(
        [
            93.23458, 131.71373, 843.7779 , 736.6148 , 771.1189 , 589.66034,
            862.29816, 927.69586,  90.70885, 107.58985, 618.4583 , 532.47394
        ],
        dtype=torch.float, device=device)
    channel_std = torch.tensor(
        [
            115.34247 , 139.92636 ,  36.99538 ,  57.366386,  30.346825,
            149.68007 ,  51.70631 ,  35.872967, 115.77212 , 120.997154,
            98.57828 ,  99.76469
        ],
        dtype=torch.float, device=device)

    # Add two trailing singleton axes -> shape (12, 1, 1), so the statistics
    # broadcast across the two spatial axes.
    mean = channel_mean[:, None, None]
    std = channel_std[:, None, None]

    return (images.float() - mean) / std

In [15]:
# Hyper-parameters read by the LitAutoEncoder class defined below.
CHANNELS = 144  # Output channels of sat_conv2 and sat_conv3; sat_conv1 outputs CHANNELS // 2.
KERNEL = 3  # Kernel size shared by the three satellite convolutions and the max-pool layer.
EMBEDDING_DIM = 0  # Width of the PV-system-ID embedding; 0 (falsy) disables the embedding branch.
NWP_SIZE = 10 * 2 * 2  # Flattened NWP feature count added to fc2's input: channels x width x height


class LitAutoEncoder(pl.LightningModule):
    """Small CNN + MLP regressor over satellite imagery and NWP data.

    Despite the name this is not an auto-encoder: ``forward`` returns one
    value per example. A single satellite frame goes through three grouped
    convolutions, is flattened and projected to 256 features, concatenated
    with the flattened NWP data (and, if ``EMBEDDING_DIM`` is non-zero, a
    PV-system-ID embedding), and then passed through four linear layers.

    The module-level constants ``CHANNELS``, ``KERNEL``, ``EMBEDDING_DIM`` and
    ``NWP_SIZE`` size the layers.
    """

    def __init__(
        self,
        history_len: int=1
    ):
        """Build the layers.

        Args:
            history_len: Index along the time axis at which ``forward`` reads
                ``x['sat_data']`` and ``x['nwp']``.
        """
        super().__init__()
        self.history_len = history_len

        # Grouped convolutions: sat_conv1 uses one group per input channel
        # (12 satellite channels), so each channel is filtered separately.
        self.sat_conv1 = nn.Conv2d(in_channels=12, out_channels=CHANNELS//2, kernel_size=KERNEL, groups=12)
        self.sat_conv2 = nn.Conv2d(in_channels=CHANNELS//2, out_channels=CHANNELS, kernel_size=KERNEL, groups=CHANNELS//2)
        self.sat_conv3 = nn.Conv2d(in_channels=CHANNELS, out_channels=CHANNELS, kernel_size=KERNEL, groups=CHANNELS)

        self.maxpool = nn.MaxPool2d(kernel_size=KERNEL)

        # fc1 expects the conv stack to end in an 11 x 11 feature map.
        self.fc1 = nn.Linear(
            in_features=CHANNELS * 11 * 11,
            out_features=256)
        # fc2 consumes the image features plus the NWP vector and the
        # (optional) PV-system embedding.
        self.fc2 = nn.Linear(in_features=256 + EMBEDDING_DIM + NWP_SIZE, out_features=128)
        self.fc3 = nn.Linear(in_features=128, out_features=128)
        self.fc4 = nn.Linear(in_features=128, out_features=128)
        self.fc5 = nn.Linear(in_features=128, out_features=1)

        if EMBEDDING_DIM:
            # NOTE: this branch reads the notebook-global `data_module`, so it
            # only works after the data module cell has been run.
            self.pv_system_id_embedding = nn.Embedding(
                num_embeddings=len(data_module.pv_data_source.pv_metadata),
                embedding_dim=EMBEDDING_DIM
            )

    def forward(self, x):
        """Predict one value per example.

        Args:
            x: Batch mapping with keys ``'sat_data'`` and ``'nwp'`` (and
                ``'pv_system_row_number'`` when ``EMBEDDING_DIM`` is non-zero).

        Returns:
            Tensor of shape ``(batch_size, 1)``.
        """
        # Pick one timestep, then move channels to dim 1.
        sat_data = x['sat_data'][:, self.history_len]
        sat_data = sat_data.permute(0, 3, 2, 1)  # Conv2d expects channels to be the 2nd dim!
        sat_data = normalise_images_in_model(sat_data, self.device)

        # Pass data through the network :)
        out = F.relu(self.sat_conv1(sat_data))
        out = self.maxpool(out)
        out = F.relu(self.sat_conv2(out))
        out = self.maxpool(out)
        out = F.relu(self.sat_conv3(out))

        # Flatten everything except the batch axis. Unlike view(-1, ...), this
        # can never fold a wrong spatial size into the batch axis; a size
        # mismatch is reported by fc1 instead.
        out = out.flatten(start_dim=1)
        out = F.relu(self.fc1(out))

        nwp_data = x['nwp'][:, :, self.history_len] # Shape: batch_size, channel, seq_length, width, height
        batch_size, n_nwp_chans, nwp_width, nwp_height = nwp_data.shape
        nwp_data = nwp_data.reshape(batch_size, n_nwp_chans * nwp_width * nwp_height)
        out = torch.cat((out, nwp_data), dim=1)

        if EMBEDDING_DIM:
            pv_embedding = self.pv_system_id_embedding(x['pv_system_row_number'])
            out = torch.cat(
                (
                    out,
                    pv_embedding
                    #(x['nwp_above_pv'][:, 0] - 130) / 5,  # TODO fix horrible standardisation of temperature!
                ), dim=1)

        out = F.relu(self.fc2(out))
        out = F.relu(self.fc3(out))
        out = F.relu(self.fc4(out))
        out = self.fc5(out)

        return out

    def _training_or_validation_step(self, batch, is_train_step):
        """Shared train/validation step: compute and log the mean absolute error.

        Args:
            batch: Batch mapping passed straight to ``forward``.
            is_train_step: True for training (also logs per step and tags the
                metric "Train"); False tags it "Validation".

        Returns:
            The scalar MAE loss tensor.
        """
        y_hat = self(batch)
        #y = batch['pv_yield'][:, self.history_len:]
        # TODO: placeholder target. The real PV-yield target above is disabled,
        # so the loss is computed against uniform random values. The target is
        # shaped like y_hat (rather than a hard-coded (32, 1)) so batches of
        # any size work and a batch of one cannot broadcast silently.
        y = torch.rand_like(y_hat)
        #mse_loss = F.mse_loss(y_hat, y)
        mae_loss = (y_hat - y).abs().mean()
        tag = "Train" if is_train_step else "Validation"
        #self.log_dict({'MSE/' + tag: mse_loss}, on_step=is_train_step, on_epoch=True)
        self.log_dict({'MAE/' + tag: mae_loss}, on_step=is_train_step, on_epoch=True)
        return mae_loss

    def training_step(self, batch, batch_idx):
        """Lightning hook: run one training step and return its loss."""
        return self._training_or_validation_step(batch, is_train_step=True)

    def validation_step(self, batch, batch_idx):
        """Lightning hook: run one validation step and return its loss."""
        return self._training_or_validation_step(batch, is_train_step=False)

    def configure_optimizers(self):
        """Lightning hook: Adam over all parameters with a learning rate of 0.001."""
        optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
        return optimizer

In [16]:
# Instantiate the model with its default history_len=1.
# NOTE(review): the data module above was created with history_len=0, so the
# timestep index the model reads (1) differs from the data module's history
# length — confirm this is intended.
model = LitAutoEncoder()

In [10]:
# Lightning trainer on one GPU; max_epochs caps the run at 10,000 epochs.
trainer = pl.Trainer(gpus=1, max_epochs=10_000)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores


In [17]:
# Train the model on the data module.
# NOTE: the recorded run below aborted inside a DataLoader worker with pandas'
# InvalidIndexError ("Reindexing only valid with uniquely valued Index objects"),
# raised from the `self.data.sel(...)` call in nwp_data_source._get_time_slice
# (see the traceback in the cell output).
trainer.fit(model, data_module)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type      | Params
----------------------------------------
0 | sat_conv1 | Conv2d    | 720   
1 | sat_conv2 | Conv2d    | 1.4 K 
2 | sat_conv3 | Conv2d    | 1.4 K 
3 | maxpool   | MaxPool2d | 0     
4 | fc1       | Linear    | 4.5 M 
5 | fc2       | Linear    | 38.0 K
6 | fc3       | Linear    | 16.5 K
7 | fc4       | Linear    | 16.5 K
8 | fc5       | Linear    | 129   
----------------------------------------
4.5 M     Trainable params
0         Non-trainable params
4.5 M     Total params
18.142    Total estimated model params size (MB)


Validation sanity check:   0%|                                                                  | 0/2 [00:00<?, ?it/s]



                                                                                                                      





                                                                                                                      



                                                                                                                      



                                                                                                                      



                                                                                                                      



                                                                                                                      



Epoch 0: : 0it [00:00, ?it/s]                                                                                         



                                                                                                                      



                                                                                                                      







                                                                                                                      



                                                                                                                      

InvalidIndexError: Caught InvalidIndexError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/jack/dev/ocf/nowcasting_dataset/nowcasting_dataset/data_sources/data_source.py", line 64, in _get_cached_time_slice
    return self._cache[t0_dt]
KeyError: Timestamp('2018-11-04 13:50:00')

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/jack/miniconda3/envs/nowcasting_dataset/lib/python3.9/site-packages/torch/utils/data/_utils/worker.py", line 287, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/jack/miniconda3/envs/nowcasting_dataset/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 34, in fetch
    data = next(self.dataset_iter)
  File "/home/jack/dev/ocf/nowcasting_dataset/nowcasting_dataset/dataset.py", line 61, in __iter__
    yield self._get_batch()
  File "/home/jack/dev/ocf/nowcasting_dataset/nowcasting_dataset/dataset.py", line 81, in _get_batch
    examples = [
  File "/home/jack/dev/ocf/nowcasting_dataset/nowcasting_dataset/dataset.py", line 82, in <listcomp>
    future_example.result() for future_example in future_examples]
  File "/home/jack/miniconda3/envs/nowcasting_dataset/lib/python3.9/concurrent/futures/_base.py", line 438, in result
    return self.__get_result()
  File "/home/jack/miniconda3/envs/nowcasting_dataset/lib/python3.9/concurrent/futures/_base.py", line 390, in __get_result
    raise self._exception
  File "/home/jack/miniconda3/envs/nowcasting_dataset/lib/python3.9/concurrent/futures/thread.py", line 52, in run
    result = self.fn(*self.args, **self.kwargs)
  File "/home/jack/dev/ocf/nowcasting_dataset/nowcasting_dataset/dataset.py", line 118, in _get_example
    example_from_source = data_source.get_example(
  File "/home/jack/dev/ocf/nowcasting_dataset/nowcasting_dataset/data_sources/data_source.py", line 148, in get_example
    selected_data = self._get_cached_time_slice(t0_dt)
  File "/home/jack/dev/ocf/nowcasting_dataset/nowcasting_dataset/data_sources/data_source.py", line 66, in _get_cached_time_slice
    data = self._get_time_slice(t0_dt)
  File "/home/jack/dev/ocf/nowcasting_dataset/nowcasting_dataset/data_sources/nwp_data_source.py", line 103, in _get_time_slice
    init_times = self.data.sel(
  File "/home/jack/miniconda3/envs/nowcasting_dataset/lib/python3.9/site-packages/xarray/core/dataarray.py", line 1271, in sel
    ds = self._to_temp_dataset().sel(
  File "/home/jack/miniconda3/envs/nowcasting_dataset/lib/python3.9/site-packages/xarray/core/dataset.py", line 2365, in sel
    pos_indexers, new_indexes = remap_label_indexers(
  File "/home/jack/miniconda3/envs/nowcasting_dataset/lib/python3.9/site-packages/xarray/core/coordinates.py", line 421, in remap_label_indexers
    pos_indexers, new_indexes = indexing.remap_label_indexers(
  File "/home/jack/miniconda3/envs/nowcasting_dataset/lib/python3.9/site-packages/xarray/core/indexing.py", line 274, in remap_label_indexers
    idxr, new_idx = convert_label_indexer(index, label, dim, method, tolerance)
  File "/home/jack/miniconda3/envs/nowcasting_dataset/lib/python3.9/site-packages/xarray/core/indexing.py", line 200, in convert_label_indexer
    indexer = get_indexer_nd(index, label, method, tolerance)
  File "/home/jack/miniconda3/envs/nowcasting_dataset/lib/python3.9/site-packages/xarray/core/indexing.py", line 101, in get_indexer_nd
    flat_indexer = index.get_indexer(flat_labels, method=method, tolerance=tolerance)
  File "/home/jack/miniconda3/envs/nowcasting_dataset/lib/python3.9/site-packages/pandas/core/indexes/base.py", line 3442, in get_indexer
    raise InvalidIndexError(self._requires_unique_msg)
pandas.errors.InvalidIndexError: Reindexing only valid with uniquely valued Index objects


                                                                                                                      

In [None]:
# torch.save(model.state_dict(), 'model_state_dict.pt')