In [13]:
import torch
import torchvision.models as models
from torch.utils.data import DataLoader
import xarray as xr

# custom
from utils.loss import MS_SSIMLoss
from utils.data import ClimatehackDataset, CustomDataset

In [14]:
from fastai.vision.all import *

## Data

In [15]:
# Local copy of the EUMETSAT SEVIRI HRV (UK) zarr store.
# Remote alternative:
#   SATELLITE_ZARR_PATH = "gs://public-datasets-eumetsat-solar-forecasting/satellite/EUMETSAT/SEVIRI_RSS/v3/eumetsat_seviri_hrv_uk.zarr"
SATELLITE_ZARR_PATH = 'data/eumetsat_seviri_hrv_uk.zarr/'

# chunks="auto" makes `data` a lazy Dask array (see the repr below) rather than an in-memory array.
dataset = xr.open_dataset(SATELLITE_ZARR_PATH, engine="zarr", chunks="auto")
print(dataset)

<xarray.Dataset>
Dimensions:  (time: 173624, y: 891, x: 1843)
Coordinates:
  * time     (time) datetime64[ns] 2020-01-01T00:05:00 ... 2021-11-07T15:50:00
  * x        (x) float32 2.8e+04 2.7e+04 2.6e+04 ... -1.813e+06 -1.814e+06
    x_osgb   (y, x) float32 dask.array<chunksize=(891, 1843), meta=np.ndarray>
  * y        (y) float32 4.198e+06 4.199e+06 4.2e+06 ... 5.087e+06 5.088e+06
    y_osgb   (y, x) float32 dask.array<chunksize=(891, 1843), meta=np.ndarray>
Data variables:
    data     (time, y, x) int16 dask.array<chunksize=(22, 891, 1843), meta=np.ndarray>


In [17]:
BATCH_SIZE = 32

# Training and validation samplers over the same zarr-backed `dataset`, differing only by seed.
# NOTE(review): both draw from the full time range, so validation samples may overlap training
# samples and make valid_loss optimistic — consider splitting `dataset` by time first. Confirm
# against ClimatehackDataset's sampling logic.
train_ds = ClimatehackDataset(dataset, random_state=7)
valid_ds = ClimatehackDataset(dataset, random_state=3)

# Plain PyTorch loaders; the fully-qualified name is needed because the fastai star import
# rebinds the bare `DataLoader` name.
train_dl, valid_dl = (
    torch.utils.data.DataLoader(split_ds, batch_size=BATCH_SIZE)
    for split_ds in (train_ds, valid_ds)
)

In [20]:
# `DataLoader` here resolves to fastai's class, not torch's: `from fastai.vision.all import *`
# rebinds the name imported from torch.utils.data. The `fastai.data.load.DataLoader` reprs
# printed further down confirm this.
train_loader, valid_loader = (
    DataLoader(split_ds, batch_size=BATCH_SIZE, pin_memory=True)
    for split_ds in (train_ds, valid_ds)
)
dls = DataLoaders(train_loader, valid_loader)

In [21]:
# Pull one training batch to sanity-check what the loader yields.
sample_batch = next(iter(train_loader))
sample_batch

[tensor([[[[157., 157., 157.,  ..., 132., 125., 121.],
           [157., 156., 159.,  ..., 114., 109., 107.],
           [157., 159., 157.,  ..., 100.,  96.,  95.],
           ...,
           [ 95.,  95.,  92.,  ..., 132., 136., 141.],
           [ 95.,  94.,  94.,  ..., 127., 129., 138.],
           [ 95.,  91.,  92.,  ..., 124., 125., 135.]],
 
          [[157., 156., 157.,  ..., 132., 129., 128.],
           [156., 156., 154.,  ..., 123., 120., 118.],
           [156., 159., 157.,  ..., 102.,  98.,  96.],
           ...,
           [ 88.,  95.,  96.,  ..., 136., 136., 139.],
           [ 87.,  92.,  95.,  ..., 136., 138., 139.],
           [ 88.,  91.,  94.,  ..., 136., 139., 142.]],
 
          [[152., 153., 153.,  ..., 145., 129., 116.],
           [152., 152., 153.,  ..., 131., 121., 112.],
           [153., 153., 154.,  ..., 106., 102.,  99.],
           ...,
           [ 91.,  92.,  95.,  ..., 136., 135., 134.],
           [ 89.,  95.,  95.,  ..., 135., 136., 134.],
           

## Training

In [6]:
# Number of output channels (presumably forecast frames) per sample. The MS-SSIM loss
# compares this many channels, so it must equal the model's output channel count.
FORECAST = 24
criterion = MS_SSIMLoss(channels=FORECAST)

In [7]:
# Wrap the plain-PyTorch loaders for the fastai Learner.
# NOTE(review): this parallels `dls` (built from fastai loaders); the learner is later
# switched over with `learn.dls = dls`.
data = DataLoaders(train_dl, valid_dl)

In [8]:
# Self-attention U-Net on a ResNet-50 encoder (pretrained weights), 128x128 inputs.
# n_in=12: input channels, presumably the history frames per sample — TODO confirm.
# n_out is tied to FORECAST, so the model's output channels always match the channel
# count the MS-SSIM criterion was constructed with (previously a duplicated literal 24).
model = create_unet_model(
    arch=models.resnet50,
    img_size=(128, 128),
    n_out=FORECAST,
    pretrained=True,
    n_in=12,
    self_attention=True,
)

In [9]:
# Combine loaders, model and MS-SSIM criterion into a fastai Learner.
# NOTE(review): no `splitter` is passed, so fastai's default splitter applies (a single
# parameter group, as far as I can tell). The "frozen" first epoch of `fine_tune` would then
# also train the pretrained encoder. `unet_learner` supplies an arch-specific splitter
# instead — confirm before retraining.
learn = Learner(dls=data, model=model, loss_func=criterion)

In [10]:
# learn.lr_find()

In [11]:
# NOTE(review): the logged table stops at epoch 9 of 30, so this run was presumably interrupted.
learn.fine_tune(epochs=30, base_lr=1e-3)

epoch,train_loss,valid_loss,time
0,0.461904,0.36767,44:56


epoch,train_loss,valid_loss,time
0,0.364078,0.370544,44:49
1,0.363126,0.362556,44:47
2,0.357845,0.335379,44:52
3,0.344939,0.331677,44:48
4,0.338443,0.328063,44:46
5,0.334038,0.325128,44:41
6,0.331881,0.329352,44:46
7,0.326995,0.324047,44:46
8,0.324831,0.329879,44:47
9,0.32344,0.322895,44:46


In [63]:
# Persist only the model weights (state dict), not the whole Learner.
torch.save(obj=learn.model.state_dict(), f='checkpoints/dynamic_unet.pth')

In [13]:
# Continue training with a flat-then-cosine-annealed learning-rate schedule.
learn.fit_flat_cos(n_epoch=50, lr=5e-4)

epoch,train_loss,valid_loss,time
0,0.248178,0.253115,44:54
1,0.250437,0.254625,44:53
2,0.248723,0.251835,45:12
3,0.2489,0.253866,44:55
4,0.247123,0.244329,44:52
5,0.244221,0.246293,44:54
6,0.240939,0.23982,44:53
7,0.241679,0.249434,44:52
8,0.240875,0.246027,44:52
9,0.241449,0.245839,44:54


In [39]:
# NOTE(review): by execution count this cell (In [39]) ran after `learn.dls = dls` (In [30])
# and the `new_empty` patch (In [38]), both of which appear further down. A top-to-bottom run
# reaches this export before them — move this cell below the patch.
learn.export(fname="checkpoints/unet50_learner.pkl")

In [30]:
# Point the learner at the fastai-built loaders (`dls`). The following cells inspect their
# `.new` method, which the plain PyTorch loaders in `data` presumably lack.
learn.dls = dls

In [37]:
# Debug: show the `.new` method of each loader in `dls` (see also `dls.new_empty()`).
for loader in dls.loaders:
    print(loader.new)

<bound method DataLoader.new of <fastai.data.load.DataLoader object at 0x7f364e45eeb0>>
<bound method DataLoader.new of <fastai.data.load.DataLoader object at 0x7f364e45e250>>


In [38]:
@patch
def new_empty(self: ClimatehackDataset):  # `@patch` attaches this to the class named in the `self` annotation
    """Return the dataset itself as its own "empty" copy.

    Presumably needed by ``learn.export``, which asks each loader's dataset for an empty
    copy. The annotation previously named ``climatehack_dataset.ClimatehackDataset``, a
    module that is never imported here, so defining the patch raised NameError. It now
    uses the ``ClimatehackDataset`` class imported at the top of the notebook.

    NOTE(review): returning ``self`` rather than a genuinely empty dataset means the exported
    pickle presumably carries the whole dataset object, including its reference to the
    xarray ``dataset`` — confirm that is acceptable.
    """
    return self

In [27]:
# Debug: inspect the `.new` method of the fastai training loader.
train_loader.new

<bound method DataLoader.new of <fastai.data.load.DataLoader object at 0x7f364e6ea910>>

In [42]:
# (predictions, targets) for the validation loader; ds_idx=1 is fastai's default.
preds = learn.get_preds(ds_idx=1)

In [59]:
# `get_preds` returns (predictions, targets); give each a leading singleton dimension.
# Non-in-place `unsqueeze` keeps this cell idempotent. The previous `unsqueeze_` mutated the
# tensors stored in `preds`, so every re-run added yet another leading dimension.
# NOTE(review): the extra leading dim presumably matches what a downstream metric expects — confirm.
x, y = (tensor.unsqueeze(0) for tensor in preds)
x.shape, y.shape

(torch.Size([1, 676, 24, 128, 128]), torch.Size([1, 676, 24, 128, 128]))