In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
from torchvision import transforms
import numpy as np
import tqdm
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
import xarray as xr
import datetime

# custom
import common.loss_utils as loss_utils
import common.climatehack_dataset as climatehack_dataset

import sys
sys.path.append('metnet')
import metnet

In [3]:
# Use the GPU when PyTorch reports CUDA as available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')
DEVICE

device(type='cuda')

In [4]:
# Alternative (remote) location of the store, left disabled:
# SATELLITE_ZARR_PATH = "gs://public-datasets-eumetsat-solar-forecasting/satellite/EUMETSAT/SEVIRI_RSS/v3/eumetsat_seviri_hrv_uk.zarr"
SATELLITE_ZARR_PATH = 'data/full/eumetsat_seviri_hrv_uk.zarr/'

# chunks="auto" loads the data as Dask arrays.
dataset = xr.open_dataset(SATELLITE_ZARR_PATH, engine="zarr", chunks="auto")

print(dataset)


<xarray.Dataset>
Dimensions:  (time: 173624, y: 891, x: 1843)
Coordinates:
  * time     (time) datetime64[ns] 2020-01-01T00:05:00 ... 2021-11-07T15:50:00
  * x        (x) float32 2.8e+04 2.7e+04 2.6e+04 ... -1.813e+06 -1.814e+06
    x_osgb   (y, x) float32 dask.array<chunksize=(891, 1843), meta=np.ndarray>
  * y        (y) float32 4.198e+06 4.199e+06 4.2e+06 ... 5.087e+06 5.088e+06
    y_osgb   (y, x) float32 dask.array<chunksize=(891, 1843), meta=np.ndarray>
Data variables:
    data     (time, y, x) int16 dask.array<chunksize=(22, 891, 1843), meta=np.ndarray>


In [5]:
BATCH_SIZE = 1

# Wrap the opened satellite dataset; random_state=7 is handed to ClimatehackDataset.
ds = climatehack_dataset.ClimatehackDataset(dataset, random_state=7)

# Shuffled loader over `ds`, fed by 4 worker processes.
ch_dataloader = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)


In [6]:
# Number of forecast steps requested from the model; the training code also keeps
# only the first FORECAST target frames.
FORECAST = 10

model = metnet.MetNet(
    hidden_dim=32,
    forecast_steps=FORECAST,
    input_channels=1,
    output_channels=1,
    sat_channels=1,  # a single satellite data channel
    input_size=32,  # =128/4, where 128 is the image dimensions
)

# Epoch number encoded in the name of the checkpoint that training resumes from.
EXISTING = 1
print(f"LOADING EPOCHS {EXISTING}")
# map_location remaps the stored tensors onto DEVICE, so a checkpoint written from a
# CUDA model still loads when DEVICE has fallen back to the CPU.
model.load_state_dict(
    torch.load(f'metnet_epochs_{EXISTING}.pt', map_location=DEVICE)
)
model = model.to(DEVICE)


LOADING EPOCHS 1


In [7]:
def count_parameters(model):
    """Return the total element count of the model's parameters that require gradients."""
    trainable = 0
    for param in model.parameters():
        # Parameters with requires_grad=False are left out of the count.
        if param.requires_grad:
            trainable += param.numel()
    return trainable

# Report how many parameter elements of the loaded model require gradients.
print(f'The model has {count_parameters(model)} trainable parameters')

The model has 1823649 trainable parameters


In [8]:
# Adam over all model parameters, with the library's default hyperparameters.
optimizer = torch.optim.Adam(params=model.parameters())
# MS-SSIM loss configured with one channel per forecast step
# (presumably the forecast axis is treated as the channel axis — confirm in loss_utils).
criterion = loss_utils.MS_SSIMLoss(channels=FORECAST)

In [9]:
# Constants used to normalise pixel values: divide by _MAX_PIXEL, then subtract _MEAN.
_MAX_PIXEL = 1023
_MEAN = 0.1787


def transform(x):
    """Scale `x` by 1/_MAX_PIXEL and shift it down by _MEAN."""
    scaled = x / _MAX_PIXEL
    return scaled - _MEAN


def inv_transform(x):
    """Undo `transform`: shift `x` up by _MEAN, then scale by _MAX_PIXEL."""
    shifted = x + _MEAN
    return shifted * _MAX_PIXEL


In [10]:
def train_epoch(model, dl, optimizer, criterion, log_every=1):
    """Run one training pass over `dl` and return the average loss per sample.

    Args:
        model: network to train; it is put in train mode and called as `model(x)`.
        dl: iterable yielding `(srcs, trgs)` batches.
        optimizer: optimizer whose gradients are reset and which is stepped once per batch.
        criterion: loss callable invoked as `criterion(preds, y)`.
        log_every: refresh the progress-bar description every `log_every` batches.

    Returns:
        The sample-weighted average of the batch losses, or NaN when `dl`
        yields no batches.

    Raises:
        ValueError: if `log_every` is smaller than 1.
    """
    if log_every < 1:
        raise ValueError("log_every must be >= 1")

    model.train()

    loss_sum = 0.0
    total_count = 0
    pbar = tqdm.tqdm(dl)
    for i, (srcs, trgs) in enumerate(pbar):
        optimizer.zero_grad()

        x = transform(srcs).to(DEVICE)
        # we don't need to transform y
        y = trgs[:, :FORECAST].float().to(DEVICE)

        # metnet expects a channel for satellite channel (like RGB). But we only have 1 of those
        x = torch.unsqueeze(x, dim=2)

        preds = model(x)
        # remove the satellite channel dimension from the prediction
        preds = torch.squeeze(preds, dim=2)
        # bring predictions back to the pixel scale of the untransformed targets
        preds = inv_transform(preds)

        loss = criterion(preds, y)

        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        batch_size = len(srcs)
        # Weight each batch loss by its sample count so that dividing by
        # total_count gives a per-sample average for any batch size
        # (assumes criterion returns a batch mean — TODO confirm in loss_utils).
        loss_sum += batch_loss * batch_size
        total_count += batch_size
        if i % log_every == 0:
            avg_loss = loss_sum / total_count
            pbar.set_description(
                f'MS-SSIM Avg Loss, Batch Loss: {avg_loss, round(batch_loss, 4)}'
            )

    if total_count == 0:
        # Nothing was iterated; avoid dividing by zero.
        return float('nan')
    return loss_sum / total_count

In [None]:
# EPOCHS is the exclusive upper bound of the epoch counter: the loop below runs
# 20 epochs, numbered EXISTING + 1 through EPOCHS - 1.
EPOCHS = 20 + EXISTING + 1

for epoch in range(EXISTING + 1, EPOCHS):
    print(f"Epoch {epoch}")
    train_epoch(model, ch_dataloader, optimizer, criterion)
    # Name each checkpoint after the epoch that just finished, so earlier checkpoints
    # are not overwritten and the number matches the `metnet_epochs_{EXISTING}.pt`
    # naming used when resuming.
    torch.save(model.state_dict(), f'metnet_epochs_{epoch}.pt')

Epoch 2


MS-SSIM Avg Loss, Batch Loss: (0.35108789537079943, 0.2627): 100%|█████████████████████████████████████████████████████| 676/676 [13:32<00:00,  1.20s/it]


Epoch 3


MS-SSIM Avg Loss, Batch Loss: (0.3342696348061928, 0.2368): 100%|██████████████████████████████████████████████████████| 676/676 [13:31<00:00,  1.20s/it]


Epoch 4


MS-SSIM Avg Loss, Batch Loss: (0.34877828230871954, 0.1787): 100%|█████████████████████████████████████████████████████| 676/676 [13:34<00:00,  1.20s/it]


Epoch 5


MS-SSIM Avg Loss, Batch Loss: (0.34351057546025904, 0.1039): 100%|█████████████████████████████████████████████████████| 676/676 [14:00<00:00,  1.24s/it]


Epoch 6


MS-SSIM Avg Loss, Batch Loss: (0.33726178453518796, 0.379): 100%|██████████████████████████████████████████████████████| 676/676 [13:47<00:00,  1.22s/it]


Epoch 7


MS-SSIM Avg Loss, Batch Loss: (0.3368464523165889, 0.2002): 100%|██████████████████████████████████████████████████████| 676/676 [13:19<00:00,  1.18s/it]


Epoch 8


MS-SSIM Avg Loss, Batch Loss: (0.3356718748102527, 0.0172): 100%|██████████████████████████████████████████████████████| 676/676 [13:57<00:00,  1.24s/it]


Epoch 9


MS-SSIM Avg Loss, Batch Loss: (0.33697222076224154, 0.6793): 100%|█████████████████████████████████████████████████████| 676/676 [15:40<00:00,  1.39s/it]


Epoch 10


MS-SSIM Avg Loss, Batch Loss: (0.3337864677405216, 0.545): 100%|███████████████████████████████████████████████████████| 676/676 [14:53<00:00,  1.32s/it]


Epoch 11


MS-SSIM Avg Loss, Batch Loss: (0.3366006387937704, 0.2564): 100%|██████████████████████████████████████████████████████| 676/676 [14:20<00:00,  1.27s/it]


Epoch 12


MS-SSIM Avg Loss, Batch Loss: (0.3360196580371913, 0.2516): 100%|██████████████████████████████████████████████████████| 676/676 [15:48<00:00,  1.40s/it]


Epoch 13


MS-SSIM Avg Loss, Batch Loss: (0.34347571882270495, 0.3102): 100%|█████████████████████████████████████████████████████| 676/676 [15:33<00:00,  1.38s/it]


Epoch 14


MS-SSIM Avg Loss, Batch Loss: (0.3323899324123676, 0.3376): 100%|██████████████████████████████████████████████████████| 676/676 [15:14<00:00,  1.35s/it]


Epoch 15


MS-SSIM Avg Loss, Batch Loss: (0.34053814622777456, 0.5606): 100%|█████████████████████████████████████████████████████| 676/676 [14:02<00:00,  1.25s/it]


Epoch 16


MS-SSIM Avg Loss, Batch Loss: (0.3391773364760659, 0.3246):  65%|███████████████████████████████████                   | 439/676 [09:51<06:39,  1.69s/it]

In [None]:
# Write the model's current weights to the checkpoint file whose name carries EPOCHS.
# NOTE(review): EPOCHS is the exclusive upper bound of the epoch range, not a count of
# completed epochs (and the run may have been interrupted) — confirm the intended name.
torch.save(model.state_dict(), f'metnet_epochs_{EPOCHS}.pt')