In [1]:
# Enable IPython's autoreload extension in mode 2, so edits to imported project
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
from torchvision import transforms
import numpy as np
import tqdm
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
import xarray as xr
import datetime

# custom
import common.loss_utils as loss_utils
import common.climatehack_dataset as climatehack_dataset

import sys
sys.path.append('metnet')
import metnet

In [4]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')
DEVICE

device(type='cuda')

In [5]:
# Alternative source (kept for reference): read the store from the GCS bucket
# instead of the local copy.
# SATELLITE_ZARR_PATH = "gs://public-datasets-eumetsat-solar-forecasting/satellite/EUMETSAT/SEVIRI_RSS/v3/eumetsat_seviri_hrv_uk.zarr"

# Local zarr store that is actually used by this notebook.
SATELLITE_ZARR_PATH = 'data/full/eumetsat_seviri_hrv_uk.zarr/'

# chunks="auto" loads the variables as Dask arrays rather than reading
# everything into memory up front.
dataset = xr.open_dataset(SATELLITE_ZARR_PATH, engine="zarr", chunks="auto")

print(dataset)


<xarray.Dataset>
Dimensions:  (time: 173624, y: 891, x: 1843)
Coordinates:
  * time     (time) datetime64[ns] 2020-01-01T00:05:00 ... 2021-11-07T15:50:00
  * x        (x) float32 2.8e+04 2.7e+04 2.6e+04 ... -1.813e+06 -1.814e+06
    x_osgb   (y, x) float32 dask.array<chunksize=(891, 1843), meta=np.ndarray>
  * y        (y) float32 4.198e+06 4.199e+06 4.2e+06 ... 5.087e+06 5.088e+06
    y_osgb   (y, x) float32 dask.array<chunksize=(891, 1843), meta=np.ndarray>
Data variables:
    data     (time, y, x) int16 dask.array<chunksize=(22, 891, 1843), meta=np.ndarray>


In [7]:
# Number of future frames the model predicts; also used to slice the targets
# and to configure the loss.
FORECAST = 10

# NOTE(review): the parameter descriptions below are inferred from the argument
# names; confirm them against the local `metnet` package.
model = metnet.MetNet(
        hidden_dim=32,
        forecast_steps=FORECAST,  # number of predicted timesteps (FORECAST = 10)
        input_channels=1,  # presumably channels per input frame -- TODO confirm
        output_channels=1,  # presumably channels per predicted frame -- TODO confirm
        sat_channels=1,  # single satellite data channel
        input_size=32,  # =128/4, where 128 is the image dimensions
)
model = model.to(DEVICE)
# map_location keeps the checkpoint loadable on CPU-only machines: without it,
# tensors saved from a CUDA device fail to deserialize when no GPU is present,
# even though DEVICE falls back to 'cpu'.
model.load_state_dict(torch.load('metnet_epochs_1_batch_1750.pt', map_location=DEVICE))


<All keys matched successfully>

In [10]:
# One sample per batch for evaluation.
BATCH_SIZE = 1

# Build the project dataset over the first days of January 2020, with a fixed
# random_state and 5 crops per slice.
ds = climatehack_dataset.ClimatehackDataset(
    dataset,
    start_date=datetime.date(2020, 1, 1),
    end_date=datetime.date(2020, 1, 4),
    crops_per_slice=5,
    random_state=7,
)
ch_dataloader = DataLoader(dataset=ds, batch_size=BATCH_SIZE)


In [11]:
# MS-SSIM loss from the project's loss_utils module. `channels` is set to
# FORECAST, the number of target frames kept by `trgs[:, :FORECAST]` in
# test_epoch.
criterion = loss_utils.MS_SSIMLoss(channels=FORECAST)

In [12]:
_MAX_PIXEL = 1023
_MEAN = 0.1787

def transform(x):
    return (x / _MAX_PIXEL) - _MEAN

def inv_transform(x):
    return (x + _MEAN) * _MAX_PIXEL


In [13]:
def test_epoch(model, dl, criterion):
    """Evaluate `model` over every batch of `dl` and return the averaged loss.

    Args:
        model: network called as `model(x)`; it is switched to eval mode here.
        dl: iterable yielding `(srcs, trgs)` batches.
        criterion: callable `criterion(preds, y)` returning a scalar tensor.

    Returns:
        The sum of the per-batch loss values divided by the number of samples
        seen (`len(srcs)` accumulated over all batches).

    Raises:
        ZeroDivisionError: if `dl` yields no samples.
    """
    # Evaluation must run in eval mode: in train mode dropout stays active and
    # batch-norm layers use batch statistics and update their running
    # statistics, even inside torch.no_grad().
    model.eval()

    epoch_loss = 0
    total_count = 0
    pbar = tqdm.tqdm(dl)
    for srcs, trgs in pbar:
        x = transform(srcs).to(DEVICE)
        # we don't need to transform y: predictions are mapped back to the
        # raw scale with inv_transform before the loss is computed
        y = trgs[:, :FORECAST].float().to(DEVICE)

        # metnet expects a channel for satellite channel (like RGB). But we only have 1 of those
        x = torch.unsqueeze(x, dim=2)

        # No gradients are needed anywhere in evaluation, including the loss.
        with torch.no_grad():
            preds = model(x)
            # remove the satellite channel dimension from the prediction
            preds = torch.squeeze(preds, dim=2)
            preds = inv_transform(preds)
            batch_loss = criterion(preds, y).item()

        epoch_loss += batch_loss
        total_count += len(srcs)
        pbar.set_description(f'Current MS-SSIM Loss: {round(batch_loss, 4)}')

    # NOTE(review): if `criterion` returns a per-batch mean, dividing by the
    # sample count under-reports the loss for batch sizes > 1 (it is exact for
    # batch size 1) -- confirm the criterion's reduction.
    return epoch_loss / total_count

In [14]:
# Number of evaluation passes over the dataloader.
EPOCHS = 1

for epoch in range(EPOCHS):
    print(f"Epoch {epoch + 1}")
    # Keep and report the value returned by test_epoch; otherwise the only
    # visible result is the last batch's progress-bar text.
    mean_loss = test_epoch(model, ch_dataloader, criterion)
    print(f"Epoch {epoch + 1} mean MS-SSIM loss: {mean_loss:.4f}")

Epoch 1


Current MS-SSIM Loss: 0.105:  50%|███████████████████████████████████████████▌                                           | 15/30 [00:26<00:26,  1.74s/it]
