In [None]:
from nowcasting_dataset import data_sources
from nowcasting_dataset import dataset
import nowcasting_dataset.time as nd_time

In [None]:
import torch
from torch import nn
import torch.nn.functional as F
import pytorch_lightning as pl

In [None]:
# Zarr store of HRV-capable SEVIRI RSS satellite imagery on GCS (int16, one timestep per chunk per the name).
FILENAME = (
    'gs://solar-pv-nowcasting-data/satellite/EUMETSAT/SEVIRI_RSS/OSGB36/'
    'all_zarr_int16_single_timestep.zarr'
)

In [None]:
# Frames requested from the data source around each t0.
HISTORY_LEN = 0
FORECAST_LEN = 1
# Frames per example: history + forecast + one more (presumably the t0 frame itself).
TOTAL_SEQ_LEN = 1 + HISTORY_LEN + FORECAST_LEN

In [None]:
# Satellite source: HRV channel only, 128-pixel image crops.
# NOTE(review): LitAutoEncoder's default layer sizes assume 128-pixel images;
# keep image_size_pixels in sync with the model.
sat_data_source = data_sources.SatelliteDataSource(
    filename=FILENAME,
    consolidated=False,
    history_len=HISTORY_LEN,
    forecast_len=FORECAST_LEN,
    channels=('HRV',),
    image_size_pixels=128,
)

In [None]:
%%time
t0_datetimes = nd_time.get_t0_datetimes(
    sat_data_source.datetime_index(),
    total_seq_len=TOTAL_SEQ_LEN,
    history_len=HISTORY_LEN)

In [None]:
# How many t0 datetimes get_t0_datetimes returned (the pool the dataset samples from).
len(t0_datetimes)

In [None]:
# The dataset is given its own batch_size, and the DataLoader below uses
# batch_size=None, so batching presumably happens inside NowcastingDataset.
ds = dataset.NowcastingDataset(
    data_sources=[sat_data_source],
    t0_datetimes=t0_datetimes,
    batch_size=32,
    n_samples_per_timestep=4,
)

In [None]:
# batch_size=None disables the DataLoader's automatic batching: each item the
# dataset yields is passed straight through (the dataset builds its own batches).
# NOTE(review): prefetch_factor is per worker, so up to 16 * 256 batches may be
# buffered in host memory at once; confirm this fits in RAM.
dataloader = torch.utils.data.DataLoader(
    ds,
    batch_size=None,
    batch_sampler=None,
    num_workers=16,
    worker_init_fn=dataset.worker_init_fn,
    prefetch_factor=256,
    pin_memory=True,
)

## Define very simple ML model

In [None]:
def normalise_images_in_model(images, device):
    """Standardise raw satellite pixel values with a fixed global mean and std.

    Args:
        images: Tensor of raw pixel values (any numeric dtype; cast to float32).
        device: Device on which the mean/std scalars are created; should match
            ``images.device``.

    Returns:
        A new float32 tensor equal to ``(images - mean) / std``. The input
        tensor is never modified.
    """
    sat_image_mean = torch.tensor(93.23458, dtype=torch.float, device=device)
    sat_image_std = torch.tensor(115.34247, dtype=torch.float, device=device)

    # Out-of-place on purpose: ``Tensor.float()`` returns ``images`` itself
    # (not a copy) when it is already float32, so in-place ``-=`` / ``/=``
    # would overwrite the caller's tensor (e.g. a view into the input batch).
    return (images.float() - sat_image_mean) / sat_image_std

In [None]:
CHANNELS = 32
KERNEL = 3


class LitAutoEncoder(pl.LightningModule):
    """Tiny CNN + MLP regressor used to exercise the satellite data pipeline.

    Despite the name this is not an autoencoder: it maps satellite frames to a
    single scalar per example (``fc5`` has ``out_features=1``). Training
    targets are currently random noise (see ``_training_or_validation_step``),
    so the logged loss is only useful as a smoke/throughput test.

    Args:
        history_len: Index along dim 1 of ``x['sat_data']`` from which frames
            are fed to the model. The remaining frames become the Conv2d
            channel axis and ``encoder_conv1`` has ``in_channels=1``, so
            exactly one frame must remain after this slice.
        image_size_pixels: Height/width of the (assumed square) input images,
            used to size ``fc1``. The default of 128 gives an 11x11 encoder
            output, i.e. ``fc1.in_features == CHANNELS * 11 * 11``.
    """

    def __init__(
        self,
        history_len: int = 1,
        image_size_pixels: int = 128,
    ):
        super().__init__()
        self.history_len = history_len

        self.encoder_conv1 = nn.Conv2d(in_channels=1, out_channels=CHANNELS // 2, kernel_size=KERNEL)
        self.encoder_conv2 = nn.Conv2d(in_channels=CHANNELS // 2, out_channels=CHANNELS, kernel_size=KERNEL)
        self.encoder_conv3 = nn.Conv2d(in_channels=CHANNELS, out_channels=CHANNELS, kernel_size=KERNEL)

        self.maxpool = nn.MaxPool2d(kernel_size=KERNEL)

        encoder_size = self._encoder_output_size(image_size_pixels)
        self.fc1 = nn.Linear(
            in_features=CHANNELS * encoder_size * encoder_size,
            out_features=256,
        )
        # TODO: when extra features (e.g. NWP temperature above the PV system)
        # are concatenated after fc1, grow fc2.in_features to match.
        self.fc2 = nn.Linear(in_features=256, out_features=128)
        self.fc3 = nn.Linear(in_features=128, out_features=128)
        self.fc4 = nn.Linear(in_features=128, out_features=128)
        self.fc5 = nn.Linear(in_features=128, out_features=1)

    @staticmethod
    def _encoder_output_size(image_size_pixels: int) -> int:
        """Spatial side length after conv1 -> pool -> conv2 -> pool -> conv3.

        The convs are unpadded with stride 1 (each shrinks the side by
        ``KERNEL - 1``). MaxPool2d's stride defaults to its kernel size, so
        with floor rounding the side becomes ``size // KERNEL``.

        Raises:
            ValueError: if the image is too small to survive the encoder.
        """
        size = image_size_pixels
        for pool_after_conv in (True, True, False):
            size -= KERNEL - 1
            if pool_after_conv:
                size //= KERNEL
        if size < 1:
            raise ValueError(
                f"image_size_pixels={image_size_pixels} is too small for the encoder "
                f"(would produce a spatial size of {size})"
            )
        return size

    def forward(self, x):
        # NOTE(review): x['sat_data'] is indexed as rank 5, presumably
        # (batch, time, y, x, channel); confirm against the data source.
        # Keep frames from history_len onwards and the first image channel;
        # the remaining time axis is used as the Conv2d channel axis.
        images = x['sat_data'][:, self.history_len:, :, :, 0]
        images = normalise_images_in_model(images, self.device)

        out = F.relu(self.encoder_conv1(images))
        out = self.maxpool(out)
        out = F.relu(self.encoder_conv2(out))
        out = self.maxpool(out)
        out = F.relu(self.encoder_conv3(out))

        # Flatten everything but the batch dim. Unlike view(-1, N), a spatial
        # size mismatch cannot be silently folded into the batch dimension;
        # fc1 raises instead.
        out = out.flatten(start_dim=1)
        out = F.relu(self.fc1(out))

        # TODO: concatenate extra per-example features here, e.g.
        # (x['nwp_above_pv'][:, 0] - 130) / 5 (needs proper standardisation).

        out = F.relu(self.fc2(out))
        out = F.relu(self.fc3(out))
        out = F.relu(self.fc4(out))
        out = self.fc5(out)

        return out

    def _training_or_validation_step(self, batch, is_train_step):
        """Shared train/validation logic: forward pass, MAE against the target, logging."""
        y_hat = self(batch)
        # TODO: use the real target, e.g. batch['pv_yield'][:, self.history_len:],
        # once PV data is loaded. Until then the target is random noise shaped
        # like the prediction, so a short final batch or a different batch size
        # neither errors nor silently broadcasts.
        y = torch.rand(y_hat.shape, device=self.device)
        mae_loss = (y_hat - y).abs().mean()
        tag = "Train" if is_train_step else "Validation"
        self.log_dict({'MAE/' + tag: mae_loss}, on_step=is_train_step, on_epoch=True)
        return mae_loss

    def training_step(self, batch, batch_idx):
        return self._training_or_validation_step(batch, is_train_step=True)

    def validation_step(self, batch, batch_idx):
        return self._training_or_validation_step(batch, is_train_step=False)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
        return optimizer

In [None]:
# history_len=1 (the default, spelled out): forward() slices frames from index 1
# onwards, and encoder_conv1 accepts exactly one input frame.
model = LitAutoEncoder(history_len=1)

In [None]:
# Train on a single GPU.
# NOTE(review): the `gpus=` flag was deprecated in Lightning 1.7 and removed in 2.0;
# on newer versions use pl.Trainer(accelerator="gpu", devices=1).
trainer = pl.Trainer(gpus=1)

In [None]:
# Pass the loader positionally: the keyword was renamed from `train_dataloader`
# to `train_dataloaders` (deprecated in Lightning 1.4, removed in 1.6), and the
# second positional parameter of Trainer.fit is the training loader in every version.
trainer.fit(model, dataloader)