In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import datetime
import glob
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
import tqdm
import xarray as xr
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

# custom
import common.loss_utils as loss_utils
import common.climatehack_dataset as climatehack_dataset

sys.path.append('metnet')
import metnet

In [3]:
# Train on the GPU whenever one is visible to PyTorch; otherwise fall back to the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
DEVICE

device(type='cuda')

In [4]:
# Local zarr store; the filename matches the public EUMETSAT SEVIRI RSS HRV (UK) store,
# so it is presumably a copy of it. To read from the bucket instead, use:
# SATELLITE_ZARR_PATH = "gs://public-datasets-eumetsat-solar-forecasting/satellite/EUMETSAT/SEVIRI_RSS/v3/eumetsat_seviri_hrv_uk.zarr"
SATELLITE_ZARR_PATH = 'data/full/eumetsat_seviri_hrv_uk.zarr/'

# chunks="auto" makes xarray back the variables with Dask arrays (lazy loading).
dataset = xr.open_dataset(SATELLITE_ZARR_PATH, engine="zarr", chunks="auto")

print(dataset)


<xarray.Dataset>
Dimensions:  (time: 173624, y: 891, x: 1843)
Coordinates:
  * time     (time) datetime64[ns] 2020-01-01T00:05:00 ... 2021-11-07T15:50:00
  * x        (x) float32 2.8e+04 2.7e+04 2.6e+04 ... -1.813e+06 -1.814e+06
    x_osgb   (y, x) float32 dask.array<chunksize=(891, 1843), meta=np.ndarray>
  * y        (y) float32 4.198e+06 4.199e+06 4.2e+06 ... 5.087e+06 5.088e+06
    y_osgb   (y, x) float32 dask.array<chunksize=(891, 1843), meta=np.ndarray>
Data variables:
    data     (time, y, x) int16 dask.array<chunksize=(22, 891, 1843), meta=np.ndarray>


In [5]:
BATCH_SIZE = 1

# random_state is forwarded to the project dataset (its exact use lives in climatehack_dataset).
ds = climatehack_dataset.ClimatehackDataset(dataset, random_state=7)

# Shuffled loader; 8 worker processes prepare batches in parallel with training.
ch_dataloader = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=8)


In [6]:
FORECAST = 10  # number of future timesteps the model predicts

# Set EXISTING > 0 to resume from the checkpoint written after that epoch.
# The training loop below saves 'metnet_epochs={epoch}_loss={loss}.pt'; the older
# 'metnet_epochs_{epoch}.pt' naming is also accepted.
EXISTING = 0

model = metnet.MetNet(
    hidden_dim=32,
    forecast_steps=FORECAST,  # FORECAST (=10) timesteps out
    input_channels=1,  # NOTE(review): an earlier comment said "12 timesteps in"; confirm what input_channels counts in this MetNet version
    output_channels=1,  # 1 data channel out
    sat_channels=1,  # 1 satellite data channel in
    input_size=32,  # =128/4, where 128 is the image dimensions
)

if EXISTING:
    checkpoint_paths = sorted(
        glob.glob(f'metnet_epochs={EXISTING}_loss=*.pt')
        + glob.glob(f'metnet_epochs_{EXISTING}.pt')
    )
    if not checkpoint_paths:
        raise FileNotFoundError(f'no checkpoint found for epoch {EXISTING}')
    print(f"LOADING EPOCHS {EXISTING} from {checkpoint_paths[0]}")
    # Load onto the CPU first so a checkpoint saved on a GPU also loads on a CPU-only
    # machine; the model is then moved to DEVICE below.
    model.load_state_dict(torch.load(checkpoint_paths[0], map_location='cpu'))

model = model.to(DEVICE)


In [7]:
def count_parameters(model):
    """Return the total number of scalar parameters in `model` that require gradients."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

print('The model has', count_parameters(model), 'trainable parameters')

The model has 1822305 trainable parameters


In [8]:
# MS-SSIM loss configured with one channel per predicted timestep (channels=FORECAST).
criterion = loss_utils.MS_SSIMLoss(channels=FORECAST)
# Adam with PyTorch's default hyper-parameters.
optimizer = optim.Adam(params=model.parameters())

In [9]:
_MEDIAN_PIXEL = 216.0
_IQR = 201.0
# Margin kept from +/-1 before atanh so inv_transform never returns inf/NaN.
_ATANH_EPS = 1e-6


def transform(x, median=_MEDIAN_PIXEL, iqr=_IQR):
    """Normalise raw pixel values into (-1, 1).

    Centres on `median`, scales by `iqr`, then squashes with tanh.
    The defaults are the notebook's _MEDIAN_PIXEL / _IQR constants.
    """
    return torch.tanh((x - median) / iqr)


def inv_transform(x, median=_MEDIAN_PIXEL, iqr=_IQR, eps=_ATANH_EPS):
    """Map values in transform's output space back to pixel values.

    atanh is finite only on the open interval (-1, 1). At +/-1 it returns +/-inf,
    and outside that interval it returns NaN. Nothing in this notebook bounds the
    model's raw output, so the input is first clamped to [-1 + eps, 1 - eps].
    Inputs already inside that range are unaffected.
    """
    return torch.atanh(torch.clamp(x, -1.0 + eps, 1.0 - eps)) * iqr + median


In [10]:
def train_epoch(model, dl, optimizer, criterion, accum_steps=32):
    """Train `model` for one pass over `dl`, accumulating gradients across batches.

    For each loader batch:
      1. Inputs are normalised with `transform`.
      2. The model runs and its predictions are mapped back to pixel space with `inv_transform`.
      3. The predictions are scored with `criterion` against the first `FORECAST` target frames.
    Gradients from `accum_steps` consecutive batches are averaged before each
    optimizer step, emulating a batch `accum_steps` times larger than the loader's.

    Relies on the notebook globals `transform`, `inv_transform`, `FORECAST` and `DEVICE`.

    Args:
        model: network called as `model(x)` on inputs with a singleton satellite-channel
            axis inserted at dim 2 (presumably (batch, time, 1, H, W) -- TODO confirm).
        dl: iterable yielding `(srcs, trgs)` tensor pairs.
        optimizer: optimizer over `model`'s parameters.
        criterion: loss called as `criterion(preds, y)` and returning a scalar tensor.
        accum_steps: loader batches per optimizer step (default 32, the value that
            was previously hard-coded).

    Returns:
        The per-sample mean loss over the epoch, or NaN if `dl` yielded no batches.
    """
    model.train()

    epoch_loss = 0.0
    total_count = 0
    pending = 0  # batches whose gradients are accumulated but not yet applied
    pbar = tqdm.tqdm(dl)
    optimizer.zero_grad()
    for srcs, trgs in pbar:
        x = transform(srcs).to(DEVICE)
        # targets stay in raw pixel space; predictions are mapped back to it below
        y = trgs[:, :FORECAST].float().to(DEVICE)

        # metnet expects a channel for satellite channel (like RGB). But we only have 1 of those
        x = torch.unsqueeze(x, dim=2)

        preds = model(x)
        # drop the satellite-channel axis and undo the input normalisation
        preds = inv_transform(torch.squeeze(preds, dim=2))

        loss = criterion(preds, y)
        # Scale so the accumulated gradient is the mean over the window, not the sum.
        (loss / accum_steps).backward()
        pending += 1

        if pending == accum_steps:
            optimizer.step()
            optimizer.zero_grad()
            pending = 0

        batch_loss = loss.item()
        batch_size = len(srcs)
        # criterion presumably averages over the batch; weight by batch size so the
        # running value is a per-sample mean for any BATCH_SIZE.
        epoch_loss += batch_loss * batch_size
        total_count += batch_size
        avg_loss = round(epoch_loss / total_count, 4)
        pbar.set_description(f'MS-SSIM Avg Loss, Batch Loss: {avg_loss, round(batch_loss, 4)}')

    # Apply gradients left over from a final, partial accumulation window (if any).
    # They are scaled as for a full window.
    if pending:
        optimizer.step()
        optimizer.zero_grad()

    return epoch_loss / total_count if total_count else float('nan')

In [12]:
EPOCHS = 200
# Continue numbering after the last completed epoch. EXISTING may be set by the
# model cell when resuming from a checkpoint; on a fresh run it defaults to 0
# (start at epoch 1) instead of raising NameError.
EXISTING = globals().get('EXISTING', 0)

for epoch in range(EXISTING + 1, EPOCHS + 1):
    print(f"Epoch {epoch}")
    avg_loss = round(train_epoch(model, ch_dataloader, optimizer, criterion), 4)
    # Checkpoint after every epoch; the filename records the epoch and its mean loss.
    torch.save(model.state_dict(), f'metnet_epochs={epoch}_loss={avg_loss}.pt')

Epoch 1


MS-SSIM Avg Loss, Batch Loss: (0.3608, 0.3213): 100%|██████████| 676/676 [09:44<00:00,  1.16it/s]


Epoch 2


MS-SSIM Avg Loss, Batch Loss: (0.3331, 0.276): 100%|██████████| 676/676 [09:41<00:00,  1.16it/s] 


Epoch 3


MS-SSIM Avg Loss, Batch Loss: (0.3398, 0.1301): 100%|██████████| 676/676 [09:37<00:00,  1.17it/s]


Epoch 4


MS-SSIM Avg Loss, Batch Loss: (0.3414, 0.2352): 100%|██████████| 676/676 [09:40<00:00,  1.16it/s]


Epoch 5


MS-SSIM Avg Loss, Batch Loss: (0.3336, 0.3152): 100%|██████████| 676/676 [09:39<00:00,  1.17it/s]


Epoch 6


MS-SSIM Avg Loss, Batch Loss: (0.3478, 0.4486): 100%|██████████| 676/676 [09:35<00:00,  1.17it/s]


Epoch 7


MS-SSIM Avg Loss, Batch Loss: (0.3393, 0.5224): 100%|██████████| 676/676 [09:36<00:00,  1.17it/s]


Epoch 8


MS-SSIM Avg Loss, Batch Loss: (0.3402, 0.3734): 100%|██████████| 676/676 [09:37<00:00,  1.17it/s]


Epoch 9


MS-SSIM Avg Loss, Batch Loss: (0.3399, 0.0924): 100%|██████████| 676/676 [09:36<00:00,  1.17it/s]


Epoch 10


MS-SSIM Avg Loss, Batch Loss: (0.3394, 0.5061): 100%|██████████| 676/676 [09:54<00:00,  1.14it/s]


Epoch 11


MS-SSIM Avg Loss, Batch Loss: (0.3371, 0.603): 100%|██████████| 676/676 [09:48<00:00,  1.15it/s] 


Epoch 12


MS-SSIM Avg Loss, Batch Loss: (0.3403, 0.4604): 100%|██████████| 676/676 [09:53<00:00,  1.14it/s]


Epoch 13


MS-SSIM Avg Loss, Batch Loss: (0.3312, 0.1738): 100%|██████████| 676/676 [09:53<00:00,  1.14it/s]


Epoch 14


MS-SSIM Avg Loss, Batch Loss: (0.3318, 0.2012): 100%|██████████| 676/676 [09:58<00:00,  1.13it/s]


Epoch 15


MS-SSIM Avg Loss, Batch Loss: (0.3407, 0.2188):  49%|████▊     | 329/676 [04:58<05:14,  1.10it/s]


KeyboardInterrupt: 

In [12]:
# torch.save(model.state_dict(), f'metnet_epochs_{EPOCHS}_bk.pt')