In [1]:
from nowcasting_dataset import data_sources
from nowcasting_dataset import dataset
import nowcasting_dataset.time as nd_time

In [2]:
import torch
from torch import nn
import torch.nn.functional as F
import pytorch_lightning as pl

In [3]:
# Satellite imagery Zarr store on Google Cloud Storage.
# NOTE(review): the path suggests EUMETSAT SEVIRI RSS data on the OSGB36 grid,
# stored as int16 — confirm against the dataset's own documentation.
FILENAME = (
    'gs://solar-pv-nowcasting-data/satellite/EUMETSAT/SEVIRI_RSS/OSGB36/all_zarr_int16_single_timestep.zarr'
)

In [4]:
# Timesteps requested before and after t0 for every example.
HISTORY_LEN = 0
FORECAST_LEN = 1
# Total timesteps per example; the extra 1 presumably counts the t0 step itself.
TOTAL_SEQ_LEN = 1 + HISTORY_LEN + FORECAST_LEN

In [5]:
# Satellite source restricted to the HRV channel, cropped to 128-pixel images.
# NOTE(review): LitAutoEncoder hard-codes an 11 x 11 feature map, which only
# matches image_size_pixels=128 — change both together.
sat_data_source = data_sources.SatelliteDataSource(
    filename=FILENAME,
    consolidated=False,
    history_len=HISTORY_LEN,
    forecast_len=FORECAST_LEN,
    channels=('HRV',),
    image_size_pixels=128,
)

In [6]:
%%time
t0_datetimes = nd_time.get_t0_datetimes(
    sat_data_source.datetime_index(),
    total_seq_len=TOTAL_SEQ_LEN,
    history_len=HISTORY_LEN)

CPU times: user 1.82 s, sys: 105 ms, total: 1.92 s
Wall time: 4.13 s


In [7]:
# How many t0 datetimes get_t0_datetimes returned (the sampling pool for the dataset).
len(t0_datetimes)

162789

In [8]:
# Dataset that builds whole batches itself (the DataLoader is created with
# batch_size=None). NOTE(review): n_samples_per_timestep presumably controls how
# many examples are drawn per t0 — confirm in NowcastingDataset.
ds = dataset.NowcastingDataset(
    data_sources=[sat_data_source],
    t0_datetimes=t0_datetimes,
    batch_size=32,
    n_samples_per_timestep=4,
)

In [9]:
# batch_size=None / batch_sampler=None disable the DataLoader's automatic
# batching: each item yielded by `ds` is treated as one complete batch.
# NOTE(review): prefetch_factor is per worker, so up to 16 * 256 = 4096 batches
# can be buffered at once — lower it if host memory becomes a problem.
dataloader = torch.utils.data.DataLoader(
    ds,
    batch_size=None,
    batch_sampler=None,
    num_workers=16,
    worker_init_fn=dataset.worker_init_fn,
    prefetch_factor=256,
    pin_memory=True,
)

## Define a very simple ML model

In [10]:
def normalise_images_in_model(images, device, mean=93.23458, std=115.34247):
    """Standardise raw satellite pixel values to roughly zero mean, unit variance.

    Args:
        images: Tensor of raw pixel values (any shape, any numeric dtype).
        device: Device that `images` lives on. Kept for backward compatibility;
            the scalar statistics broadcast onto any device automatically.
        mean: Value subtracted from every pixel (presumably the dataset-wide
            pixel mean — confirm where 93.23458 was computed).
        std: Value every centred pixel is divided by (presumably the
            dataset-wide pixel standard deviation).

    Returns:
        A new float32 tensor equal to `(images.float() - mean) / std`.
        The input tensor is never modified.
    """
    # Out-of-place on purpose: `Tensor.float()` returns `images` itself when it
    # is already float32, so in-place `-=` / `/=` would overwrite the caller's
    # tensor (in forward() that is a view into the batch).
    return (images.float() - mean) / std

In [11]:
CHANNELS = 32
KERNEL = 3


class LitAutoEncoder(pl.LightningModule):
    """Small CNN + MLP that maps satellite imagery to one scalar per example.

    Despite the name this is not an autoencoder: there is no decoder. Three
    convolutions (the first two each followed by max-pooling) feed a
    five-layer fully connected head ending in a single output.

    Args:
        history_len: Number of leading timesteps of `x['sat_data']` that are
            dropped before the images reach the network.
    """

    def __init__(
        self,
        history_len: int = 1
    ):
        super().__init__()
        self.history_len = history_len

        # in_channels=1: the time axis left after slicing in forward() is used as
        # Conv2d's channel axis, so exactly one timestep must remain.
        self.encoder_conv1 = nn.Conv2d(in_channels=1, out_channels=CHANNELS//2, kernel_size=KERNEL)
        self.encoder_conv2 = nn.Conv2d(in_channels=CHANNELS//2, out_channels=CHANNELS, kernel_size=KERNEL)
        self.encoder_conv3 = nn.Conv2d(in_channels=CHANNELS, out_channels=CHANNELS, kernel_size=KERNEL)

        # MaxPool2d's stride defaults to kernel_size, so each pool divides H and W by ~3.
        self.maxpool = nn.MaxPool2d(kernel_size=KERNEL)

        # 11 x 11 is the feature-map size for a 128 x 128 input:
        # 128 -conv-> 126 -pool-> 42 -conv-> 40 -pool-> 13 -conv-> 11.
        self.fc1 = nn.Linear(in_features=CHANNELS * 11 * 11, out_features=256)
        # TODO: if extra per-example features (e.g. NWP temperature above the PV
        # system) are concatenated after fc1, grow fc2's in_features to match.
        self.fc2 = nn.Linear(in_features=256, out_features=128)
        self.fc3 = nn.Linear(in_features=128, out_features=128)
        self.fc4 = nn.Linear(in_features=128, out_features=128)
        self.fc5 = nn.Linear(in_features=128, out_features=1)

    def forward(self, x):
        """Predict one scalar per example.

        Args:
            x: Batch dict; only `x['sat_data']` is read. It is indexed with five
                subscripts, so it must be 5-D — presumably
                (batch, time, y, x, channel); TODO confirm.

        Returns:
            Tensor of shape (batch, 1) produced by `fc5`.
        """
        # Drop the first `history_len` timesteps and keep index 0 of the last
        # axis (presumably the only satellite channel, HRV).
        images = x['sat_data'][:, self.history_len:, :, :, 0]
        images = normalise_images_in_model(images, self.device)

        out = F.relu(self.encoder_conv1(images))
        out = self.maxpool(out)
        out = F.relu(self.encoder_conv2(out))
        out = self.maxpool(out)
        out = F.relu(self.encoder_conv3(out))

        # Flatten per sample. Unlike view(-1, CHANNELS * 11 * 11), a wrong spatial
        # size now fails loudly in fc1 instead of silently mixing samples.
        out = out.flatten(start_dim=1)
        out = F.relu(self.fc1(out))

        # TODO: concatenate extra features here once available, e.g.
        # (x['nwp_above_pv'][:, 0] - 130) / 5 — that standardisation also needs fixing.

        out = F.relu(self.fc2(out))
        out = F.relu(self.fc3(out))
        out = F.relu(self.fc4(out))
        out = self.fc5(out)

        return out

    def _training_or_validation_step(self, batch, is_train_step):
        """Compute and log the mean absolute error for one batch.

        NOTE: the target is currently uniform random noise — a placeholder that
        exercises the data pipeline until a real target is wired in.

        Returns:
            The scalar MAE loss.
        """
        y_hat = self(batch)
        # Placeholder target shaped like the prediction, so it follows the actual
        # batch size instead of assuming 32 (a batch of 1 used to broadcast silently).
        # TODO: replace with the real target, e.g. batch['pv_yield'][:, self.history_len:].
        y = torch.rand(y_hat.shape, device=self.device)
        mae_loss = F.l1_loss(y_hat, y)
        tag = "Train" if is_train_step else "Validation"
        self.log_dict({'MAE/' + tag: mae_loss}, on_step=is_train_step, on_epoch=True)
        return mae_loss

    def training_step(self, batch, batch_idx):
        """Lightning hook: MAE loss on a training batch (logged per step and per epoch)."""
        return self._training_or_validation_step(batch, is_train_step=True)

    def validation_step(self, batch, batch_idx):
        """Lightning hook: MAE loss on a validation batch (logged per epoch only)."""
        return self._training_or_validation_step(batch, is_train_step=False)

    def configure_optimizers(self):
        """Lightning hook: Adam over all parameters with lr=1e-3."""
        optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
        return optimizer

In [12]:
# history_len=1 is the constructor default: forward() drops that many leading
# timesteps of x['sat_data'] before the conv layers.
# NOTE(review): this differs from HISTORY_LEN (0) given to the data source —
# presumably deliberate, so that a single timestep remains for conv1
# (in_channels=1); confirm before changing either value.
model = LitAutoEncoder(history_len=1)

In [13]:
# Single-GPU trainer.
# NOTE(review): newer Lightning releases replace `gpus=` with
# `accelerator="gpu", devices=1`; keep `gpus=` while pinned to this version.
trainer = pl.Trainer(
    gpus=1,
)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores


In [None]:
# Only a training loader is supplied, so validation_step is never exercised.
# Passing it positionally works whether Lightning names the parameter
# `train_dataloader` (older) or `train_dataloaders` (newer).
trainer.fit(model, dataloader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type      | Params
--------------------------------------------
0 | encoder_conv1 | Conv2d    | 160   
1 | encoder_conv2 | Conv2d    | 4.6 K 
2 | encoder_conv3 | Conv2d    | 9.2 K 
3 | maxpool       | MaxPool2d | 0     
4 | fc1           | Linear    | 991 K 
5 | fc2           | Linear    | 32.9 K
6 | fc3           | Linear    | 16.5 K
7 | fc4           | Linear    | 16.5 K
8 | fc5           | Linear    | 129   
--------------------------------------------
1.1 M     Trainable params
0         Non-trainable params
1.1 M     Total params
4.286     Total estimated model params size (MB)


Epoch 0: : 0it [00:00, ?it/s]              

Exception ignored in: <finalize object at 0x7f9120c0a700; dead>
Traceback (most recent call last):
  File "/home/jack/miniconda3/envs/nowcasting_dataset/lib/python3.9/weakref.py", line 580, in __call__
    return info.func(*info.args, **(info.kwargs or {}))
  File "/home/jack/miniconda3/envs/nowcasting_dataset/lib/python3.9/site-packages/gcsfs/core.py", line 245, in close_session
    sync(loop, session.close, timeout=0.1)
  File "/home/jack/miniconda3/envs/nowcasting_dataset/lib/python3.9/site-packages/fsspec/asyn.py", line 63, in sync
    raise FSTimeoutError
fsspec.exceptions.FSTimeoutError: 
Exception ignored in: <finalize object at 0x7f9120c0a700; dead>
Traceback (most recent call last):
  File "/home/jack/miniconda3/envs/nowcasting_dataset/lib/python3.9/weakref.py", line 580, in __call__
    return info.func(*info.args, **(info.kwargs or {}))
  File "/home/jack/miniconda3/envs/nowcasting_dataset/lib/python3.9/site-packages/gcsfs/core.py", line 245, in close_session
    sync(loop, 

Epoch 0: : 5it [00:07,  1.51s/it, loss=0.468, v_num=77]

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


Epoch 0: : 4753it [03:03, 25.87it/s, loss=0.244, v_num=77]