In [1]:
# General imports
import os
import warnings
from pathlib import Path
from dataclasses import dataclass
from typing import Optional, Iterable, Any
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

plt.rcParams["figure.figsize"] = (18, 10)
plt.rcParams["figure.facecolor"] = "white"

# ML imports
import torch
from torch.utils import data
from torch import nn
import torch.nn.functional as F
import einops
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
import wandb
import sklearn.manifold

# power_perceiver imports
from power_perceiver.load_prepared_batches.prepared_dataset import PreparedDataset
from power_perceiver.consts import BatchKey
from power_perceiver.load_prepared_batches.data_loader import HRVSatellite, PV, Sun
from power_perceiver.xr_batch_processor import SelectPVSystemsNearCenterOfImage, ReduceNumPVSystems, ReduceNumTimesteps
from power_perceiver.np_batch_processor import EncodeSpaceTime, Topography
from power_perceiver.transforms.satellite import PatchSatellite
from power_perceiver.transforms.pv import PVPowerRollingWindow

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Root of the prepared ML training data (must contain "train" and "test" sub-directories).
# Override with the POWER_PERCEIVER_DATA_PATH environment variable, e.g.
# POWER_PERCEIVER_DATA_PATH=~/dev/ocf/power_perceiver/data_for_testing/ for the small test dataset.
DATA_PATH = Path(
    os.environ.get(
        "POWER_PERCEIVER_DATA_PATH",
        "/mnt/storage_ssd_4tb/data/ocf/solar_pv_nowcasting/nowcasting_dataset_pipeline/prepared_ML_training_data/v15/",
    )
).expanduser()
if not DATA_PATH.exists():
    raise FileNotFoundError(
        f"{DATA_PATH} does not exist. Set POWER_PERCEIVER_DATA_PATH to the prepared data directory.")

In [3]:
def get_dataloader(
    data_path: Path,
    tag: str,
    topography_path: str = "/home/jack/europe_dem_2km_osgb.tif",
    num_workers: int = 16,
) -> data.DataLoader:
    """Build a DataLoader over the prepared batches stored in `data_path`.

    Args:
        data_path: Directory holding the prepared batches for one split.
        tag: Either "train" or "validation". For "train", each example is additionally
            reduced to 4 timesteps via `ReduceNumTimesteps(requested_timesteps=4)`.
        topography_path: Path of the elevation GeoTIFF handed to the `Topography` processor.
            The default is the original machine-specific location.
        num_workers: Number of DataLoader worker processes.

    Returns:
        A `torch.utils.data.DataLoader` with automatic batching disabled (`batch_size=None`),
        so every item the dataset yields is passed through unchanged (presumably a whole batch).

    Raises:
        AssertionError: if `tag` is not recognised or `data_path` does not exist.
    """
    assert tag in ["train", "validation"], f"{tag=} must be 'train' or 'validation'"
    assert data_path.exists(), f"{data_path} does not exist"

    xr_batch_processors = [
        SelectPVSystemsNearCenterOfImage(),
        ReduceNumPVSystems(requested_num_pv_systems=8),
    ]
    if tag == "train":
        xr_batch_processors.append(ReduceNumTimesteps(requested_timesteps=4))

    dataset = PreparedDataset(
        data_path=data_path,
        data_loaders=[
            HRVSatellite(transforms=[PatchSatellite()]),
            PV(transforms=[PVPowerRollingWindow()]),
            Sun(),
        ],
        xr_batch_processors=xr_batch_processors,
        np_batch_processors=[
            EncodeSpaceTime(),
            Topography(topography_path),
        ],
    )

    return data.DataLoader(
        dataset,
        batch_size=None,  # Disable automatic batching; the dataset's items are used as-is.
        num_workers=num_workers,
        pin_memory=True,
    )

# One dataloader per split. NOTE: the "validation" dataloader reads the prepared "test" split.
train_dataloader, val_dataloader = (
    get_dataloader(DATA_PATH / split_dir, tag=split_tag)
    for split_dir, split_tag in (("train", "train"), ("test", "validation"))
)

  return self._crs.to_proj4(version=version)
  return self._crs.to_proj4(version=version)
  return self._crs.to_proj4(version=version)
  return self._crs.to_proj4(version=version)


In [4]:
# Grab a single training batch for interactive inspection. Unlike `for batch in ...: break`,
# this fails right here (StopIteration) if the dataloader yields nothing, instead of leaving
# `batch` undefined and causing a NameError in a later cell.
batch = next(iter(train_dataloader))

In [5]:
# Shape of the PV power tensor in the training batch; the saved output shows (31, 4, 8).
# Presumably (example, timestep, pv_system), given ReduceNumTimesteps(4) and ReduceNumPVSystems(8).
batch[BatchKey.pv].shape

torch.Size([31, 4, 8])

In [6]:
# Shape of the PV timestamps; the saved output shows (31, 4), presumably one per example and timestep.
batch[BatchKey.pv_time_utc].shape

torch.Size([31, 4])

In [None]:
# NOTE(review): this duplicates the In [5] cell above; consider deleting it.
batch[BatchKey.pv].shape

In [None]:
from power_perceiver.pytorch_modules.satellite_processor import HRVSatelliteProcessor
from power_perceiver.pytorch_modules.query_generator import QueryGenerator
from power_perceiver.pytorch_modules.self_attention import MultiLayerTransformerEncoder



@dataclass(eq=False)  # See https://discuss.pytorch.org/t/typeerror-unhashable-type-for-my-torch-nn-module/109424/6
class Model(pl.LightningModule):
    """Transformer that predicts PV power from HRV satellite patches and PV-system queries.

    The satellite "byte array" (from `HRVSatelliteProcessor`) and the queries (from
    `QueryGenerator`) are zero-padded to `query_dim`, concatenated along dim 1 and passed through
    `MultiLayerTransformerEncoder`. The encoder outputs at the query positions are mapped to one
    value each by `output_module` and trained against `BatchKey.pv` with an MSE loss.
    """

    # Params for Perceiver
    query_dim: int = 36  # byte_array and query will be automatically padded with zeros to get to this size.
    num_fourier_features: int = 16  # TOTAL for both x and y
    pv_system_id_embedding_dim: int = 16
    num_heads: int = 6
    dropout: float = 0.0
    share_weights_across_latent_transformer_layers: bool = False
    num_latent_transformer_encoders: int = 4

    # Other params:
    num_elements_query_padding: int = 0  # Probably keep this at zero while using MultiLayerTransformerEncoder
    learning_rate: float = 1e-4  # Adam learning rate used by configure_optimizers.

    def __post_init__(self):
        # The dataclass-generated __init__ has already stored the fields; set up the
        # LightningModule / nn.Module machinery before any sub-modules are assigned.
        super().__init__()
        self.hrvsatellite_processor = HRVSatelliteProcessor()

        self.query_generator = QueryGenerator(
            num_fourier_features=self.num_fourier_features,  # TOTAL (for both x and y)
            pv_system_id_embedding_dim=self.pv_system_id_embedding_dim,
            num_elements_query_padding=self.num_elements_query_padding,
        )

        self.transformer_encoder = MultiLayerTransformerEncoder(
            d_model=self.query_dim,
            num_heads=self.num_heads,
            dropout=self.dropout,
            share_weights_across_latent_transformer_layers=self.share_weights_across_latent_transformer_layers,
            num_latent_transformer_encoders=self.num_latent_transformer_encoders,
        )

        # Maps each encoder output element to a single value.
        self.output_module = nn.Sequential(
            nn.Linear(in_features=self.query_dim, out_features=self.query_dim),
            nn.ReLU(),
            nn.Linear(in_features=self.query_dim, out_features=1),
        )

        # Do this at the end of __post_init__ to capture model topology:
        self.save_hyperparameters()

    def forward(self, x: dict[BatchKey, torch.Tensor]) -> torch.Tensor:
        """Predict one value per query element.

        The sub-modules operate on a flattened `(batch_size * n_timesteps)` leading dimension,
        which is un-flattened at the end, so the result starts with `(batch_size, n_timesteps, ...)`
        and ends with a size-1 dim from `output_module`.
        """
        original_batch_size = x[BatchKey.pv].shape[0]
        byte_array = self.hrvsatellite_processor(x)
        query = self.query_generator(x)

        # Pad with zeros if necessary to get up to self.query_dim:
        byte_array = self._maybe_pad_with_zeros(byte_array)
        query = self._maybe_pad_with_zeros(query)

        # Prepare the attention input and run through the transformer_encoder:
        attn_input = torch.concat((byte_array, query), dim=1)
        attn_output = self.transformer_encoder(attn_input)

        # Select the elements of the output which correspond to the query
        # (they follow the byte_array elements along dim 1):
        out = attn_output[:, byte_array.shape[1]:]

        out = self.output_module(out)

        # Reshape back to (batch_size, n_timesteps, ...)
        return einops.rearrange(
            out,
            "(batch_size n_timesteps) ... -> batch_size n_timesteps ...",
            batch_size=original_batch_size,
        )

    def _maybe_pad_with_zeros(self, tensor: torch.Tensor) -> torch.Tensor:
        """Right-pad the LAST dimension of `tensor` with zeros up to `self.query_dim`.

        Works for tensors of any rank; dtype and device are preserved.

        Raises:
            AssertionError: if the last dimension is already larger than `query_dim`.
        """
        num_zeros_to_pad = self.query_dim - tensor.shape[-1]
        assert num_zeros_to_pad >= 0, f"{self.query_dim=}, {tensor.shape=}"
        if num_zeros_to_pad > 0:
            # F.pad's padding tuple starts at the last dim: (left, right) -> pad only on the right.
            tensor = F.pad(tensor, (0, num_zeros_to_pad))
        return tensor

    def _training_or_validation_step(
            self,
            batch: dict[BatchKey, torch.Tensor],
            batch_idx: int,
            tag: str
        ) -> dict[str, object]:
        """Compute and log the MSE between predicted and actual PV power.

        Args:
            batch: The training or validation batch.  A dictionary.
            tag: Either "train" or "validation"; used as the logging prefix.
            batch_idx: The index of the batch.

        Returns:
            Dict with the `loss` and the `predicted_pv_power` tensor.
        """
        actual_pv_power = batch[BatchKey.pv]
        # Where pv_mask is False, replace the target with 0 (this also removes any NaNs there).
        # unsqueeze(1) lets the mask broadcast over dim 1 of the PV tensor
        # (presumably mask is (batch, pv_system) and dim 1 is time — TODO confirm).
        actual_pv_power = torch.where(
            batch[BatchKey.pv_mask].unsqueeze(1),
            actual_pv_power,
            torch.tensor(0.0, dtype=actual_pv_power.dtype, device=actual_pv_power.device))

        # forward() ends with a size-1 dim (output_module has out_features=1). Squeeze only that
        # dim: a bare .squeeze() would also drop the batch / timestep / PV-system dims whenever one
        # of them has size 1, and F.mse_loss would then broadcast mismatched shapes (only a warning).
        predicted_pv_power = self(batch).squeeze(-1)

        # NOTE(review): masked-out PV systems still contribute to the loss (target 0, prediction
        # unconstrained). A mean over valid systems only may be what was intended.
        mse_loss = F.mse_loss(predicted_pv_power, actual_pv_power)

        self.log(f"{tag}/mse", mse_loss)

        return {
            'loss': mse_loss,
            'predicted_pv_power': predicted_pv_power,
            }

    def training_step(self, batch: dict[BatchKey, torch.Tensor], batch_idx: int) -> dict[str, object]:
        return self._training_or_validation_step(batch=batch, batch_idx=batch_idx, tag="train")

    def validation_step(self, batch: dict[BatchKey, torch.Tensor], batch_idx: int) -> dict[str, object]:
        return self._training_or_validation_step(batch=batch, batch_idx=batch_idx, tag="validation")

    def configure_optimizers(self):
        """Adam over all parameters, using the `learning_rate` field."""
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)

In [None]:
# Build the model with the default hyperparameters declared on the Model dataclass.
model = Model()
# Alternatively, resume from a saved checkpoint.
# NOTE(review): "~" may not be expanded here; wrap the path in Path(...).expanduser() if loading fails.
#model = Model.load_from_checkpoint(
#    "~/dev/ocf/power_perceiver/notebooks/2022-04-04_train_ML_model/model.ckpt")


In [None]:
from typing import Callable


def forward_pre_hook(module, args, kwargs=None):
    """Force `need_weights=True` on a call to an `nn.MultiheadAttention`-style module.

    Supports both registration styles:

    * `register_forward_pre_hook(hook, with_kwargs=True)` (newer PyTorch): the hook also receives
      the keyword arguments and must return the rewritten `(args, kwargs)`.
    * `register_forward_pre_hook(hook)`: only positional arguments are visible. `need_weights` is
      the 5th positional parameter of `nn.MultiheadAttention.forward`
      (query, key, value, key_padding_mask, need_weights, attn_mask, ...), so it can only be
      overridden when the caller passed it positionally.

    Callers such as `nn.TransformerEncoderLayer` pass `need_weights` etc. as keywords, so the args
    tuple may hold only (query, key, value); it must not be unpacked into a fixed number of names.
    """
    if len(args) > 4:
        # need_weights was passed positionally.
        args = args[:4] + (True,) + args[5:]
        return args if kwargs is None else (args, kwargs)
    if kwargs is None:
        # Positional-only hook and need_weights passed by keyword (or omitted): nothing to rewrite.
        return None
    return args, {**kwargs, "need_weights": True}


attn_weights = {}


def get_attn_weights(name: str) -> Callable:
    """Return a forward hook that stores the module's attention weights in `attn_weights[name]`."""
    # Adapted from https://web.stanford.edu/~nanbhas/blog/forward-hooks-pytorch/
    def forward_hook(module, input, output):
        _attn_output, attn_output_weights = output
        # Detach so the stored weights don't keep the autograd graph alive. The weights are None
        # if need_weights could not be forced on.
        attn_weights[name] = None if attn_output_weights is None else attn_output_weights.detach()
    return forward_hook


# Presumably an nn.TransformerEncoder inside MultiLayerTransformerEncoder — TODO confirm.
layer0_self_attn = model.transformer_encoder.transformer_encoder.layers[0].self_attn
try:
    # Newer PyTorch: the pre-hook also sees (and can rewrite) keyword arguments such as need_weights.
    pre_hook = layer0_self_attn.register_forward_pre_hook(forward_pre_hook, with_kwargs=True)
except TypeError:
    # Older PyTorch: register_forward_pre_hook has no with_kwargs; only positional args are visible.
    pre_hook = layer0_self_attn.register_forward_pre_hook(forward_pre_hook)
f_hook = layer0_self_attn.register_forward_hook(get_attn_weights(name="layer0"))

In [None]:
# Inspect the encoder layer whose self-attention module the hooks above are attached to.
model.transformer_encoder.transformer_encoder.layers[0]

In [None]:
# Run one forward pass on the inspection batch so the registered hooks capture layer-0 attention.
# NOTE(review): keep gradients enabled here; PyTorch's fast inference path for
# nn.TransformerEncoderLayer may bypass self_attn (and hence the hooks) — confirm for your torch version.
model_output = model(batch)
model_output.dtype

In [None]:
# Detach both hooks so later forward passes (e.g. during training) no longer capture attention.
for hook_handle in (pre_hook, f_hook):
    hook_handle.remove()


In [None]:
# Shape of the layer-0 attention weights captured by the forward hook during the pass above.
attn_weights["layer0"].shape

In [None]:
# Shape of the satellite x (OSGB) coordinates, used as the pcolormesh grid in the plot below.
batch[BatchKey.hrvsatellite_x_osgb].shape

In [None]:
import cartopy.crs as ccrs

In [None]:
# Compare one PV-system query's layer-0 attention over the satellite patches with the image itself.
BATCH_IDX = 11
TIMESTEP_IDX = 0
PV_SYSTEM_IDX = 3

# Layout constants; these must agree with the data pipeline and the model defined above.
N_TIMESTEPS = 4  # Timesteps per training example: ReduceNumTimesteps(requested_timesteps=4).
SAT_GRID_SHAPE = (16, 16)  # (y, x) satellite patches per image — TODO confirm against PatchSatellite.
N_SAT_ELEMENTS = SAT_GRID_SHAPE[0] * SAT_GRID_SHAPE[1]  # Satellite elements precede the queries in Model.forward.
PATCH_SIZE = 4  # Pixels per side of one satellite patch — TODO confirm against PatchSatellite.
BORDER_METERS = 50_000  # Margin added around the satellite footprint on the map axes.


def unpatch_satellite(patched: torch.Tensor, patch_size: int = PATCH_SIZE) -> torch.Tensor:
    """Turn a patched (y, x, patch_size**2) image back into a (y*patch_size, x*patch_size) image."""
    return einops.rearrange(
        patched,
        "y x (patch_size_y patch_size_x) -> (y patch_size_y) (x patch_size_x)",
        patch_size_y=patch_size,
        patch_size_x=patch_size,
    )


sat_x_osgb = batch[BatchKey.hrvsatellite_x_osgb][BATCH_IDX]
sat_y_osgb = batch[BatchKey.hrvsatellite_y_osgb][BATCH_IDX]
sat_image = unpatch_satellite(batch[BatchKey.hrvsatellite][BATCH_IDX, TIMESTEP_IDX, 0])

# Attention rows follow Model.forward's flattened "(batch_size n_timesteps)" leading dim; the query
# for PV system i presumably sits at position N_SAT_ELEMENTS + i of the attention sequence.
attn_row = attn_weights["layer0"][N_TIMESTEPS * BATCH_IDX + TIMESTEP_IDX, N_SAT_ELEMENTS + PV_SYSTEM_IDX]
attn_to_satellite = attn_row[:N_SAT_ELEMENTS].detach().cpu().numpy().reshape(SAT_GRID_SHAPE)

# .item() hands pandas a plain Python number instead of a 0-d tensor.
date = pd.to_datetime(batch[BatchKey.hrvsatellite_time_utc][BATCH_IDX, TIMESTEP_IDX].item(), unit="s")

projection = ccrs.OSGB(approx=False)
fig = plt.figure()
ax1 = fig.add_subplot(1, 3, 1, projection=projection)
ax2 = fig.add_subplot(1, 3, 2, projection=projection)
ax3 = fig.add_subplot(1, 3, 3)

ax1.set_title("Attention")
ax1.pcolormesh(sat_x_osgb, sat_y_osgb, attn_to_satellite)

ax2.set_title(f"Satellite {date}")
ax2.pcolormesh(
    # Repeat each patch coordinate so the grid matches the un-patched pixel image.
    sat_x_osgb.numpy().repeat(PATCH_SIZE, axis=0).repeat(PATCH_SIZE, axis=1),
    sat_y_osgb.numpy().repeat(PATCH_SIZE, axis=0).repeat(PATCH_SIZE, axis=1),
    sat_image,
)

ax3.set_title("Satellite unprojected")
ax3.imshow(
    sat_image,
    extent=(  # left, right, bottom, top
        sat_x_osgb[-1, 0],
        sat_x_osgb[0, -1],
        sat_y_osgb[0, -1],
        sat_y_osgb[-1, 0],
    ),
    origin="lower",
)

# Mark the selected PV system on every panel.
for ax in (ax1, ax2, ax3):
    ax.scatter(
        batch[BatchKey.pv_x_osgb][BATCH_IDX, PV_SYSTEM_IDX],
        batch[BatchKey.pv_y_osgb][BATCH_IDX, PV_SYSTEM_IDX],
        color="white",
    )

# Widen the map panels and overlay coastlines for geographic context.
for ax in (ax1, ax2):
    ylim = ax.get_ylim()
    xlim = ax.get_xlim()
    ax.set_ylim(ylim[0] - BORDER_METERS, ylim[1] + BORDER_METERS)
    ax.set_xlim(xlim[0] - BORDER_METERS, xlim[1] + BORDER_METERS)
    ax.coastlines()

In [None]:
# Weights & Biases logger; log_model="all" asks Lightning to upload model checkpoints during training.
wandb_logger = WandbLogger(
    entity="openclimatefix",
    project="power_perceiver",
    log_model="all",
)

# Additionally record gradients, parameter histograms and the model topology.
wandb_logger.watch(model, log="all")

In [None]:
# LogTimeseriesPlots and LogTSNEPlot are neither defined nor imported anywhere in this notebook,
# so referencing them directly raises NameError on a fresh kernel. Use each one only if it exists.
callbacks = []
for callback_name in ("LogTimeseriesPlots", "LogTSNEPlot"):
    callback_cls = globals().get(callback_name)
    if callback_cls is None:
        warnings.warn(f"{callback_name} is not defined; training without this callback.")
    else:
        callbacks.append(callback_cls())

trainer = pl.Trainer(
    gpus=[3],  # Hard-coded GPU index for the original training machine.
    max_epochs=-1,  # No epoch limit: train until interrupted.
    logger=wandb_logger,
    callbacks=callbacks,
)

In [None]:
# Train, validating on the dataloader built from the prepared "test" split.
trainer.fit(
    model,
    val_dataloaders=val_dataloader,
    train_dataloaders=train_dataloader,
)