# Limited Reproduction of Google Weather Prediction Model

This notebook attempts a limited reproduction of the weather prediction model described [here](https://cloud.google.com/blog/topics/sustainability/weather-prediction-with-ai). The original model also uses precipitation data from the GPM mission; however, to make this model's predictions more directly comparable to the other models developed in this project, we use the data already collected in `data/process.ipynb`.

In [6]:
from data.data import process_gs_rainfall_daily
from data.data import get_files

# Upper bound on the number of daily rainfall images used in this notebook.
# Only file paths are collected here; the arrays are loaded one at a time when
# needed so the whole dataset never has to be held in RAM at once.
# NOTE(review): get_files presumably returns at most n_images matching paths —
# confirm it can return fewer if the directory holds fewer files.
n_images = 100

image_files = get_files(
    process_gs_rainfall_daily(),
    '*.npy',
    n_images,
)

In [9]:
%%time
import numpy as np

# Fractions of the image list used for training and validation.
train_frac = 0.8
val_frac = 0.1

train_n = int(n_images * train_frac)
val_n = int(n_images * val_frac)
train_files = image_files[:train_n]

# Shape of one sample image, kept for reference by later cells.
ref_image = np.load(image_files[0])
image_size = ref_image.shape

# Pass 1: accumulate the total and the element count file by file, so only one
# image is in memory at a time. Counting `image.size` from each loaded array
# (instead of assuming train_n images shaped like image_size[:2]) stays correct
# when images carry extra dimensions or when fewer than train_n files exist.
# Sums are accumulated in float64 to limit rounding error.
pixel_sum = 0.0
train_size = 0
for file in train_files:
    image = np.load(file)
    pixel_sum += image.sum(dtype=np.float64)
    train_size += image.size

if train_size < 2:
    raise ValueError(
        f'Need at least two training values to estimate a standard deviation, got {train_size}'
    )
train_mean = pixel_sum / train_size

# Pass 2: sum of squared deviations from the mean; dividing by (N - 1) gives the
# sample standard deviation (Bessel's correction).
squared_dev_sum = 0.0
for file in train_files:
    squared_dev_sum += ((np.load(file).astype(np.float64) - train_mean) ** 2).sum()
train_std = np.sqrt(squared_dev_sum / (train_size - 1))

print(f'Training mean: {train_mean:.4f}')
print(f'Training standard deviation: {train_std:.4f}')

Training mean: 0.2265
Training standard deviation: 0.2378
CPU times: user 166 ms, sys: 110 ms, total: 276 ms
Wall time: 280 ms


In [11]:
# See: https://github.com/GoogleCloudPlatform/python-docs-samples/blob/main/people-and-planet-ai/weather-forecasting/serving/weather-model/weather/model.py
from __future__ import annotations
from typing import Any as AnyType, Optional
from datasets.arrow_dataset import Dataset

import torch
from transformers import PretrainedConfig, PreTrainedModel

class WeatherConfig(PretrainedConfig):
    """Hyperparameters for ``WeatherModel``.

    Args:
        mean: Input normalization means (``WeatherModel`` passes this to its
            ``Normalization`` layer). Defaults to an empty list.
        std: Input normalization standard deviations, used alongside ``mean``.
            Defaults to an empty list.
        num_inputs: Number of input channels (in_channels of the first Conv2d).
        num_hidden1: Output channels of the Conv2d layer.
        num_hidden2: Output channels of the ConvTranspose2d layer.
        num_outputs: Output features of the final Linear layer.
        kernel_size: Kernel size shared by both convolution layers.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    # Identifier transformers associates with this configuration class.
    model_type = "weather"

    def __init__(
        self,
        mean: Optional[list] = None,
        std: Optional[list] = None,
        num_inputs: int = 5,
        num_hidden1: int = 64,
        num_hidden2: int = 128,
        num_outputs: int = 1,
        kernel_size: tuple[int, int] = (3, 3),
        **kwargs: AnyType,
    ) -> None:
        # None sentinels instead of `[]` defaults: a list default is created
        # once at function-definition time and would be shared by every
        # config instance that relies on it.
        self.mean = mean if mean is not None else []
        self.std = std if std is not None else []
        self.num_inputs = num_inputs
        self.num_hidden1 = num_hidden1
        self.num_hidden2 = num_hidden2
        self.num_outputs = num_outputs
        self.kernel_size = kernel_size
        super().__init__(**kwargs)

class WeatherModel(PreTrainedModel):
    """Per-pixel precipitation model adapted from Google's weather-forecasting sample.

    The layer stack normalizes the input, moves the last (channel) axis to
    position 1 for the convolutions, applies Conv2d -> ReLU ->
    ConvTranspose2d -> ReLU, moves channels back to the last axis and
    projects them with a Linear layer. A final ReLU clamps predictions at
    zero because precipitation cannot be negative.

    The class must inherit from ``PreTrainedModel`` (as the upstream sample
    does). Without it, ``super().__init__(config)`` reaches
    ``object.__init__`` and raises ``TypeError``, and the ``.to()`` call in
    ``predict_batch`` does not exist.
    """

    config_class = WeatherConfig

    def __init__(self, config: WeatherConfig) -> None:
        super().__init__(config)
        # Inputs are presumably (batch, height, width, channels) — channels-last,
        # per the MoveDim comments below. TODO confirm against the data pipeline.
        self.layers = torch.nn.Sequential(
            Normalization(config.mean, config.std),
            MoveDim(-1, 1),  # convert to channels-first
            torch.nn.Conv2d(config.num_inputs, config.num_hidden1, config.kernel_size),
            torch.nn.ReLU(),
            torch.nn.ConvTranspose2d(
                config.num_hidden1, config.num_hidden2, config.kernel_size
            ),
            torch.nn.ReLU(),
            MoveDim(1, -1),  # convert to channels-last
            torch.nn.Linear(config.num_hidden2, config.num_outputs),
            torch.nn.ReLU(),  # precipitation cannot be negative
        )

    def forward(
        self, inputs: torch.Tensor, labels: Optional[torch.Tensor] = None
    ) -> dict[str, torch.Tensor]:
        """Run the layer stack.

        Returns:
            ``{"logits": predictions}`` when ``labels`` is None, otherwise
            ``{"loss": SmoothL1Loss(predictions, labels), "logits": predictions}``.
        """
        predictions = self.layers(inputs)
        if labels is None:
            return {"logits": predictions}

        loss_fn = torch.nn.SmoothL1Loss()
        loss = loss_fn(predictions, labels)
        return {"loss": loss, "logits": predictions}

    @staticmethod
    def create(inputs: Dataset, **kwargs: AnyType) -> WeatherModel:
        """Build a model whose normalization statistics come from ``inputs``.

        ``inputs`` is converted to a float32 array and reduced over axes
        (0, 1, 2), so it presumably has shape (samples, height, width,
        channels) — confirm. The resulting per-channel mean/std are stored
        with shape (1, 1, 1, C) in the config. Extra ``kwargs`` go to
        ``WeatherConfig``.
        """
        data = np.array(inputs, np.float32)
        mean = data.mean(axis=(0, 1, 2))[None, None, None, :]
        std = data.std(axis=(0, 1, 2))[None, None, None, :]
        config = WeatherConfig(mean.tolist(), std.tolist(), **kwargs)
        return WeatherModel(config)

    def predict(self, inputs: AnyType) -> np.ndarray:
        """Predict for a single example by adding, then stripping, a batch axis."""
        # unsqueeze(0) instead of torch.as_tensor([inputs]): same result, but it
        # avoids torch's slow list-of-ndarrays conversion path.
        return self.predict_batch(torch.as_tensor(inputs).unsqueeze(0))[0]

    def predict_batch(self, inputs_batch: AnyType) -> np.ndarray:
        """Predict for a batch without tracking gradients; returns a numpy array.

        Uses CUDA when available, otherwise CPU. Note that ``self.to(device)``
        moves the model itself to that device as a side effect.
        """
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = self.to(device)
        with torch.no_grad():
            outputs = model(torch.as_tensor(inputs_batch, device=device))
            predictions = outputs["logits"]
            return predictions.cpu().numpy()

class Normalization(torch.nn.Module):
    """Z-score layer computing ``(x - mean) / std`` with broadcasting.

    ``mean`` and ``std`` are converted with ``torch.as_tensor`` and wrapped in
    ``torch.nn.Parameter``, so they appear in ``parameters()`` and
    ``state_dict()`` and have ``requires_grad=True`` by default.
    NOTE(review): being Parameters, an optimizer will update these statistics
    during training; if they are meant to stay fixed, consider
    ``requires_grad=False`` — confirm intent.
    """

    def __init__(self, mean: AnyType, std: AnyType) -> None:
        super().__init__()
        mean_tensor = torch.as_tensor(mean)
        std_tensor = torch.as_tensor(std)
        self.mean = torch.nn.Parameter(mean_tensor)
        self.std = torch.nn.Parameter(std_tensor)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        centered = x.sub(self.mean)
        return centered.div(self.std)

class MoveDim(torch.nn.Module):
    """Move one tensor dimension to a new position (e.g. channels-last -> channels-first).

    Args:
        src: Index of the dimension to move; negative indices count from the end.
        dest: Index that dimension occupies in the output.
    """

    def __init__(self, src: int, dest: int) -> None:
        super().__init__()
        self.src, self.dest = src, dest

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # torch.movedim is the function form of Tensor.moveaxis/movedim.
        return torch.movedim(x, self.src, self.dest)