In [1]:
# Automatically re-import edited modules (e.g. the project's common.* helpers) before each cell executes.
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
from torchvision import transforms
import numpy as np
import tqdm
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
import xarray as xr
import datetime

# custom
import common.loss_utils as loss_utils
import common.climatehack_dataset as climatehack_dataset

import sys
sys.path.append('metnet')
import metnet

In [3]:
# Train on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')
DEVICE

device(type='cuda')

In [4]:
# Public GCS location of the dataset (presumably where the local copy below came from):
# SATELLITE_ZARR_PATH = "gs://public-datasets-eumetsat-solar-forecasting/satellite/EUMETSAT/SEVIRI_RSS/v3/eumetsat_seviri_hrv_uk.zarr"
SATELLITE_ZARR_PATH = 'data/full/eumetsat_seviri_hrv_uk.zarr/'

# chunks="auto" backs the variables with Dask arrays, so data is read lazily on access.
_open_kwargs = dict(engine="zarr", chunks="auto")
dataset = xr.open_dataset(SATELLITE_ZARR_PATH, **_open_kwargs)

print(dataset)


<xarray.Dataset>
Dimensions:  (time: 173624, y: 891, x: 1843)
Coordinates:
  * time     (time) datetime64[ns] 2020-01-01T00:05:00 ... 2021-11-07T15:50:00
  * x        (x) float32 2.8e+04 2.7e+04 2.6e+04 ... -1.813e+06 -1.814e+06
    x_osgb   (y, x) float32 dask.array<chunksize=(891, 1843), meta=np.ndarray>
  * y        (y) float32 4.198e+06 4.199e+06 4.2e+06 ... 5.087e+06 5.088e+06
    y_osgb   (y, x) float32 dask.array<chunksize=(891, 1843), meta=np.ndarray>
Data variables:
    data     (time, y, x) int16 dask.array<chunksize=(22, 891, 1843), meta=np.ndarray>


In [5]:
BATCH_SIZE = 1
ds = climatehack_dataset.ClimatehackDataset(dataset, random_state=7)
# Reshuffle sample order every epoch; load batches in 4 worker processes.
ch_dataloader = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)


In [6]:
import torch.nn as nn
import torch.nn.init as init

# Module types whose weight is drawn from N(0, 1).
_NORMAL_WEIGHT_TYPES = (nn.Conv1d, nn.ConvTranspose1d)
# Module types whose weight uses Xavier-normal initialisation.
_XAVIER_WEIGHT_TYPES = (
    nn.Conv2d,
    nn.Conv3d,
    nn.ConvTranspose2d,
    nn.ConvTranspose3d,
    nn.Linear,
)
# Batch-norm variants: weight ~ N(0, 1), bias = 0.
_BATCHNORM_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)
# Recurrent layers: matrices orthogonal, vectors ~ N(0, 1).
_RECURRENT_TYPES = (nn.LSTM, nn.LSTMCell, nn.GRU, nn.GRUCell)


def weight_init(m):
    '''
    Initializes a single module's parameters in place; meant for ``model.apply``.
    Credits to: https://gist.github.com/jeasinema

    Per module type:
      * Conv1d, ConvTranspose1d: weight ~ N(0, 1), bias ~ N(0, 1)
      * Conv2d/3d, ConvTranspose2d/3d, Linear: Xavier-normal weight, bias ~ N(0, 1)
      * BatchNorm1d/2d/3d: weight ~ N(0, 1), bias = 0
      * LSTM, LSTMCell, GRU, GRUCell: parameters with >= 2 dims orthogonal,
        1-D parameters ~ N(0, 1)
    Any other module type is left untouched. Absent parameters (``bias=False``,
    or ``affine=False`` batch norms, where weight and bias are None) are skipped
    instead of raising.

    Usage:
        model = Model()
        model.apply(weight_init)
    '''
    if isinstance(m, _NORMAL_WEIGHT_TYPES):
        init.normal_(m.weight)
        if m.bias is not None:
            init.normal_(m.bias)
    elif isinstance(m, _XAVIER_WEIGHT_TYPES):
        init.xavier_normal_(m.weight)
        if m.bias is not None:
            init.normal_(m.bias)
    elif isinstance(m, _BATCHNORM_TYPES):
        # With affine=False both weight and bias are None; there is nothing to initialise.
        if m.weight is not None:
            init.normal_(m.weight, mean=0, std=1)
        if m.bias is not None:
            init.constant_(m.bias, 0)
    elif isinstance(m, _RECURRENT_TYPES):
        for param in m.parameters():
            if param.dim() >= 2:
                init.orthogonal_(param)
            else:
                init.normal_(param)

In [7]:
FORECAST = 10  # number of future frames predicted; training targets are trgs[:, :FORECAST]
# GROUP = 40
# Number of per-pixel classes scored by CrossEntropyLoss (old formula: round(1023 / GROUP) + 1).
# Presumably the dataset yields target labels in [0, OUTPUTS) -- TODO confirm.
OUTPUTS = 4
model = metnet.MetNet(
    hidden_dim=32,
    forecast_steps=FORECAST,  # one prediction per forecast frame
    input_channels=1,  # NOTE(review): previously annotated "12 timesteps in" -- confirm MetNet's meaning
    output_channels=OUTPUTS,  # one logit per class
    sat_channels=1,  # single satellite channel (HRV); an axis is unsqueezed for it in train_epoch
    input_size=32,  # =128/4, where 128 is the image dimensions (per original note)
)
model.apply(weight_init)

# Epoch of a saved checkpoint to resume from; 0 trains from scratch.
# To resume, set EXISTING and make RESUME_PATH name the saved state dict
# (the loss suffix in the filename differs per checkpoint).
EXISTING = 0
RESUME_PATH = f'metnet_epochs={EXISTING}_loss=1.928.pt'
if EXISTING:
    print(f"LOADING EPOCHS {EXISTING}")
    # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
    model.load_state_dict(torch.load(RESUME_PATH, map_location=DEVICE))
model = model.to(DEVICE)


In [8]:
def count_parameters(model):
    """Return the total number of scalar values in ``model``'s parameters that have ``requires_grad=True``."""
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        total += param.numel()
    return total

n_trainable = count_parameters(model)
print(f'The model has {n_trainable} trainable parameters')

The model has 1822404 trainable parameters


In [9]:
# Per-pixel classification loss; the MS-SSIM regression loss is kept below for reference.
# criterion = loss_utils.MS_SSIMLoss(channels=FORECAST)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)


In [10]:
# _MEDIAN_PIXEL = 216.0
# _IQR = 201.0

# def transform(x):
#     return (x - _MEDIAN_PIXEL) / _IQR


# # predetermined
# from sklearn.cluster import KMeans
# _KM = KMeans(n_clusters=4, random_state=7)
# _KM.cluster_centers_ = np.array([
#     [81.25423],
#     [192.11592],
#     [310.74716],
#     [484.8715],
# ], dtype=np.float16
# )
# _KM._n_threads = 1

# def transform_y(y):
#     y_grouped = _KM.predict(y.reshape(-1,1))
#     y_grouped = y_grouped.reshape(y.shape)
#     return y_grouped
    

# # def inv_transform(x):
# #     return x * _IQR + _MEDIAN_PIXEL


In [11]:
def train_epoch(model, dl, optimizer, criterion, accum_steps=16, checkpoint_every=100, epoch=None):
    """Train ``model`` for one pass over ``dl`` using gradient accumulation.

    Each batch is unpacked as ``(srcs, trgs, _)``.
    * Inputs are cast to float and given a singleton satellite-channel axis at dim 2.
    * The first ``FORECAST`` target frames are cast to integer class labels.
    * Predictions are reshaped to ``(-1, OUTPUTS, 64, 64)`` and labels to
      ``(-1, 64, 64)``, so a per-pixel classification loss such as
      ``CrossEntropyLoss`` can score them.

    Relies on the notebook globals ``DEVICE``, ``FORECAST`` and ``OUTPUTS``.

    Args:
        model: network to train; called as ``model(x)``.
        dl: iterable yielding ``(srcs, trgs, _)`` batches.
        optimizer: stepped once every ``accum_steps`` batches, plus once at the
            end if a partial accumulation window is left over.
        criterion: loss called as ``criterion(preds, labels)``; assumed to
            return a mean over its inputs (the default reduction for
            ``CrossEntropyLoss``).
        accum_steps: number of batches whose gradients are accumulated per
            optimizer step. The effective batch size is ``accum_steps`` times
            the loader's batch size.
        checkpoint_every: save a mid-epoch checkpoint every this many
            optimizer steps; ``0``/``None`` disables it.
        epoch: optional epoch number included in mid-epoch checkpoint filenames.

    Returns:
        Mean per-sample loss over the epoch, or ``nan`` if ``dl`` was empty.

    Raises:
        ValueError: if ``accum_steps`` < 1.
    """
    if accum_steps < 1:
        raise ValueError(f"accum_steps must be >= 1, got {accum_steps}")
    model.train()

    epoch_loss = 0.0
    total_count = 0
    pending = False  # True while accumulated gradients have not yet been applied
    pbar = tqdm.tqdm(dl)
    optimizer.zero_grad()
    for i, (srcs, trgs, _) in enumerate(pbar):
        x = srcs.float().to(DEVICE)
        y = trgs[:, :FORECAST].long().to(DEVICE)

        # metnet expects a channel for satellite channel (like RGB). But we only have 1 of those
        x = torch.unsqueeze(x, dim=2)

        preds = model(x)
        # remove the satellite channel dimension from the prediction
        preds = torch.squeeze(preds, dim=2)

        # Fold the forecast steps into the batch axis: CrossEntropyLoss wants
        # (N, C, H, W) logits and (N, H, W) integer labels.
        preds = preds.reshape(-1, OUTPUTS, 64, 64)
        y = y.reshape(-1, 64, 64)

        loss = criterion(preds, y)
        # Scale so the gradient summed over the window is the window's mean gradient.
        (loss / accum_steps).backward()
        pending = True

        # criterion returns a batch mean, so weight it by batch size for a per-sample average.
        batch_size = len(srcs)
        batch_loss = loss.item()
        epoch_loss += batch_loss * batch_size
        total_count += batch_size
        avg_loss = round(epoch_loss / total_count, 4)
        pbar.set_description(f'Avg Loss, Batch Loss: {avg_loss, round(batch_loss, 4)}')

        # do gradient accumulation
        if (i + 1) % accum_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
            pending = False
            n_steps = (i + 1) // accum_steps
            if checkpoint_every and n_steps % checkpoint_every == 0:
                tag = '' if epoch is None else f'epochs={epoch}_'
                torch.save(model.state_dict(), f'metnet_{tag}batch_{i + 1}_loss={avg_loss}.pt')

    # Apply gradients from a trailing partial window. Skipped when nothing is
    # pending: stepping Adam on zeroed grads would still move the weights via momentum.
    if pending:
        optimizer.step()
        optimizer.zero_grad()
    return epoch_loss / total_count if total_count else float('nan')

In [None]:
EPOCHS = 40
# Epoch already completed by a loaded checkpoint. Reuse EXISTING if an earlier
# cell defined it; otherwise train from scratch. Reading an undefined EXISTING
# raised NameError on a fresh kernel.
EXISTING = globals().get('EXISTING', 0)

for epoch in range(EXISTING + 1, EPOCHS + 1):
    print(f"Epoch {epoch}")
    avg_loss = train_epoch(model, ch_dataloader, optimizer, criterion)
    avg_loss = round(avg_loss, 4)
    torch.save(model.state_dict(), f'metnet_epochs={epoch}_loss={avg_loss}.pt')

Epoch 1


MS-SSIM Avg Loss, Batch Loss: (1.3802, 1.2945): 100%|██████████| 676/676 [16:20<00:00,  1.45s/it]


Epoch 2


MS-SSIM Avg Loss, Batch Loss: (1.1692, 0.8875): 100%|██████████| 676/676 [17:10<00:00,  1.52s/it]


Epoch 3


MS-SSIM Avg Loss, Batch Loss: (1.0479, 0.7668): 100%|██████████| 676/676 [13:07<00:00,  1.16s/it]


Epoch 4


MS-SSIM Avg Loss, Batch Loss: (0.9303, 0.6696): 100%|██████████| 676/676 [13:07<00:00,  1.16s/it]


Epoch 5


MS-SSIM Avg Loss, Batch Loss: (0.8714, 0.6347): 100%|██████████| 676/676 [13:06<00:00,  1.16s/it]


Epoch 6


MS-SSIM Avg Loss, Batch Loss: (0.9171, 0.5885): 100%|██████████| 676/676 [13:10<00:00,  1.17s/it]


Epoch 7


MS-SSIM Avg Loss, Batch Loss: (0.9027, 1.1334): 100%|██████████| 676/676 [13:07<00:00,  1.16s/it]


Epoch 8


MS-SSIM Avg Loss, Batch Loss: (0.8719, 0.4553): 100%|██████████| 676/676 [13:04<00:00,  1.16s/it]


Epoch 9


MS-SSIM Avg Loss, Batch Loss: (0.8046, 0.6362): 100%|██████████| 676/676 [13:01<00:00,  1.16s/it]


Epoch 10


MS-SSIM Avg Loss, Batch Loss: (0.8509, 1.0841): 100%|██████████| 676/676 [12:53<00:00,  1.14s/it]


Epoch 11


MS-SSIM Avg Loss, Batch Loss: (0.807, 1.2832): 100%|██████████| 676/676 [13:06<00:00,  1.16s/it] 


Epoch 12


MS-SSIM Avg Loss, Batch Loss: (0.8348, 1.6543): 100%|██████████| 676/676 [12:56<00:00,  1.15s/it]


Epoch 13


MS-SSIM Avg Loss, Batch Loss: (0.8688, 0.7841): 100%|██████████| 676/676 [13:06<00:00,  1.16s/it]


Epoch 14


MS-SSIM Avg Loss, Batch Loss: (0.7874, 1.1374): 100%|██████████| 676/676 [12:51<00:00,  1.14s/it]


Epoch 15


MS-SSIM Avg Loss, Batch Loss: (0.8063, 0.4837): 100%|██████████| 676/676 [12:59<00:00,  1.15s/it]


Epoch 16


MS-SSIM Avg Loss, Batch Loss: (0.7943, 1.0873): 100%|██████████| 676/676 [12:55<00:00,  1.15s/it]


Epoch 17


MS-SSIM Avg Loss, Batch Loss: (0.7679, 1.1072): 100%|██████████| 676/676 [12:56<00:00,  1.15s/it]


Epoch 18


MS-SSIM Avg Loss, Batch Loss: (0.802, 1.1828): 100%|██████████| 676/676 [13:01<00:00,  1.16s/it] 


Epoch 19


MS-SSIM Avg Loss, Batch Loss: (0.7745, 0.9202): 100%|██████████| 676/676 [13:00<00:00,  1.16s/it]


Epoch 20


MS-SSIM Avg Loss, Batch Loss: (0.767, 0.3734): 100%|██████████| 676/676 [12:53<00:00,  1.14s/it] 


Epoch 21


MS-SSIM Avg Loss, Batch Loss: (0.8202, 0.5115):   4%|▍         | 28/676 [00:36<10:00,  1.08it/s] 

In [None]:
# torch.save(model.state_dict(), f'metnet_epochs_20_bk.pt')