In [1]:
import numpy as np
import xarray as xr
from torchdata.datapipes import functional_datapipe
from torchdata.datapipes.iter import IterDataPipe
import torch
import matplotlib.pyplot as plt

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# test notebook for pretrained cnn feature extraction

In [2]:
from typing import Callable, Iterable, Optional, Union
from scipy.ndimage import zoom
import datetime as dt


PRETRAINED_OUTPUT_DIMS = 1_000
NWP_VARIABLE_NUM = 17

def _downsample_and_process_for_pretrained_input(nwp_data_time_slice: xr.DataArray) -> torch.Tensor:
    nwp_data_time_slice = nwp_data_time_slice.as_numpy().values
    nwp_data_time_slice = np.nan_to_num(nwp_data_time_slice)
    nwp_data_time_slice = np.tile(zoom(nwp_data_time_slice, (1, 64/704, 64/548), order=1), (3, 1, 1, 1)).reshape(17, 3, 64, 64)
    return torch.from_numpy(nwp_data_time_slice)

def _downsample_pretrained_output(model_output: torch.Tensor) -> np.ndarray:
    with torch.no_grad():
        output = torch.softmax(model_output, 1)
        output = np.split(output, range(200, PRETRAINED_OUTPUT_DIMS, 200), axis=1)
        dsampled_output = torch.concat([torch.linalg.norm(x, axis=1).reshape(1, NWP_VARIABLE_NUM) for x in output]).T
        assert dsampled_output.shape == (NWP_VARIABLE_NUM, 5)
        return dsampled_output.numpy().flatten()


@functional_datapipe("process_nwp_pretrained")
class ProcessNWPPretrainedIterDataPipe(IterDataPipe):
    def __init__(self, base_nwp_datapipe: IterDataPipe, step: int, pretrained_model: Callable[[torch.Tensor], torch.Tensor], interpolate: bool = False, interpolation_timepoints = Optional[Iterable[Union[dt.datetime, np.datetime64]]]) -> None:
        if interpolate:
            assert interpolation_timepoints is not None, "Must provide points for interpolation."
        self.source_datapipe = base_nwp_datapipe
        self.step = step
        self.pretrained_model = pretrained_model
        self.interpolation_timepoints = interpolation_timepoints
    def __iter__(self):
        for nwp in self.source_datapipe:
            nwp = nwp.isel(step=self.step)  # select the horizon we want
            nwp.interp(init_time_utc=self.interpolation_timepoints, method="cubic")  # interpolate to perscribed points
            for time, nwp_by_init_time in nwp.groupby("init_time_utc"):
                # at each time point, pass the nwp data to pretrained model
                data = _downsample_and_process_for_pretrained_input(nwp_by_init_time)
                data = self.pretrained_model(data)
                data = _downsample_pretrained_output(data)
                yield (time, data)

In [4]:
from torchvision.models import resnet101
from ocf_datapipes.load.nwp.nwp import OpenNWPIterDataPipe
import pandas as pd

nwp_path = "gs://solar-pv-nowcasting-data/NWP/UK_Met_Office/UKV_intermediate_version_3.zarr/"
base_nwp_dpipe =  OpenNWPIterDataPipe(nwp_path)
model = resnet101(pretrained=True)
process_nwp = ProcessNWPPretrainedIterDataPipe(base_nwp_dpipe, 4, model, interpolate=True, interpolation_timepoints=pd.date_range("2020-01-01 00:00", "2022-01-01 00:00", freq="0.5H"))



In [10]:
output_size = 5*17
obs = 1_000
results = []

count = 0
for x in process_nwp:
    results.append(x)
    count +=1
    
    if count == 1_000:
        break

In [15]:
results

[(numpy.datetime64('2020-01-01T00:00:00.000000000'),
  array([1.67538300e-01, 3.08745682e-01, 2.89069000e-03, 2.38098484e-03,
         8.51276983e-03, 3.99928467e-05, 1.27211772e-02, 9.16136801e-01,
         1.62498429e-02, 3.88147198e-02, 9.43456497e-03, 2.62956936e-02,
         8.40119347e-02, 1.34962067e-01, 3.65997329e-02, 1.19008645e-02,
         2.44912179e-03, 3.22482251e-02, 2.45866373e-01, 1.46334887e-01,
         4.13943827e-01, 2.14200228e-01, 1.45370672e-02, 2.20285840e-02,
         1.53110707e-02, 4.58092615e-03, 3.27864569e-03, 5.53700440e-02,
         3.44276249e-01, 1.35193303e-01, 4.04100865e-02, 5.40920794e-02,
         3.28413844e-02, 4.25724834e-02, 3.89460064e-02, 6.08137883e-02,
         3.66210610e-01, 1.53056374e-02, 9.67598986e-03, 3.66750136e-02,
         3.85019407e-02, 2.48547923e-02, 5.84433153e-02, 2.50712391e-02,
         3.81160714e-02, 4.54318635e-02, 7.83828873e-05, 2.51499772e-01,
         9.58697870e-02, 4.21592444e-01, 4.73523542e-05, 4.58744034e-05

In [None]:
import pickle 

with open('/home/tom/local_data/pretrained_nwp_processing_step_0.pkl', 'rb') as f:
    results = pickle.load(f)