### Climate downscaling using CDDLT package

In [1]:
# !pip install --index-url https://test.pypi.org/simple/ \
#              --extra-index-url https://pypi.org/simple \
#              cddlt --quiet

In [2]:
import os
import argparse
import torch
import torchmetrics
import cddlt

from cddlt.datasets.rekis_dataset import ReKIS
from cddlt.datasets.cordex_dataset import CORDEX
from cddlt.dataloaders.downscaling_transform import DownscalingTransform

from cddlt.models.bicubic import Bicubic
from cddlt.models.espcn import ESPCN
from cddlt.models.edsr import EDSR
from cddlt.models.srcnn import SRCNN
from cddlt.models.fno import FNO
from cddlt.models.swinir import SwinIR

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
parser = argparse.ArgumentParser()
parser.add_argument("--batch_size", default=8, type=int)
parser.add_argument("--epochs", default=100, type=int)
parser.add_argument("--seed", default=42, type=int)
parser.add_argument("--threads", default=4, type=int)
parser.add_argument("--upscale_factor", default=10, type=int)
parser.add_argument("--lr", default=0.001, type=float)
parser.add_argument("--logdir", default="logs", type=str)
parser.add_argument("--variables", default=["TM"], type=list)

_StoreAction(option_strings=['--variables'], dest='variables', nargs=None, const=None, default=['TM'], type=<class 'list'>, choices=None, required=False, help=None, metavar=None)

In [4]:
REKIS_DATA_PATH = "/mnt/personal/kostape4/data/climate/rekis/"
CORDEX_DATA_PATH = "/mnt/personal/kostape4/data/climate//cordex/"

In [5]:
args = parser.parse_args([])
cddlt.startup(args)

### ReKIS

In [6]:
rekis = ReKIS(
    data_path=REKIS_DATA_PATH,
    variables=args.variables,
    train_len=("1979-01-01", "2000-01-01"),
    dev_len=("2000-01-01", "2005-01-01"),
    test_len=("2005-01-01", "2012-01-01"), ## value framework input
    resampling="cubic_spline"
)

Loading 63 NetCDF file(s)...
Loaded data shape: {'time': 7670, 'northing': 400, 'easting': 400}
Time range: 1979-01-01 00:00:00 to 1999-12-31 00:00:00
Loading 63 NetCDF file(s)...
Loaded data shape: {'time': 1827, 'northing': 400, 'easting': 400}
Time range: 2000-01-01 00:00:00 to 2004-12-31 00:00:00
Loading 63 NetCDF file(s)...
Loaded data shape: {'time': 2556, 'northing': 400, 'easting': 400}
Time range: 2005-01-01 00:00:00 to 2011-12-31 00:00:00

ReKIS dataset initalized.
train size: (7670)
dev size: (1827)
test size: (2556)



In [7]:
rekis_train = DownscalingTransform(dataset=rekis.train).dataloader(args.batch_size, shuffle=True)
rekis_dev = DownscalingTransform(dataset=rekis.dev).dataloader(args.batch_size)
rekis_test = DownscalingTransform(dataset=rekis.test).dataloader(args.batch_size)

### Models

#### Bicubic RMSE baseline

In [14]:
bicubic = Bicubic(
    upscale_factor=args.upscale_factor
)

In [16]:
bicubic.configure(
    optimizer = None,
    scheduler = None,
    loss = torchmetrics.MeanSquaredError(squared=False),
    args = args
)

Bicubic(
  (loss): MeanSquaredError()
)

In [17]:
bicubic.evaluate(rekis_dev, print_loss=True, epochs=args.epochs)

Evaluation - dev_loss: 0.2569


#### Bicubic L1 baseline

In [8]:
bicubic = Bicubic(
    upscale_factor=args.upscale_factor
)

In [9]:
bicubic.configure(
    optimizer = None,
    scheduler = None,
    loss = torch.nn.L1Loss(),
    args = args
)

Bicubic(
  (loss): L1Loss()
)

In [10]:
bicubic.evaluate(rekis_dev, print_loss=True, epochs=args.epochs)

Evaluation - dev_loss: 0.1518


#### SRCNN

In [18]:
srcnn = SRCNN(
    n_channels=1,
    upscale_factor=args.upscale_factor
)

In [19]:
srcnn.configure(
    optimizer = torch.optim.AdamW(params=srcnn.parameters(), lr=args.lr),
    scheduler = None,
    loss = torchmetrics.MeanSquaredError(squared=False),
    args = args
)

SRCNN(
  (model): Sequential(
    (0): Conv2d(1, 64, kernel_size=(9, 9), stride=(1, 1), padding=(4, 4))
    (1): ReLU()
    (2): Conv2d(64, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (3): ReLU()
    (4): Conv2d(32, 1, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  )
  (loss): MeanSquaredError()
)

In [20]:
srcnn.fit(rekis_train, rekis_dev, args.epochs)

  return F.conv2d(input, weight, bias, self.stride,
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
                                                                                                                                                                                                                

Epoch 1/100 - train_loss: 0.8270, dev_loss: 0.3375 (47.87s)


                                                                                                                                                                                                                

Epoch 2/100 - train_loss: 0.4741, dev_loss: 0.3103 (47.20s)


                                                                                                                                                                                                                

Epoch 3/100 - train_loss: 0.4667, dev_loss: 0.3316 (47.13s)


                                                                                                                                                                                                                

Epoch 4/100 - train_loss: 0.3985, dev_loss: 0.5010 (47.13s)


                                                                                                                                                                                                                

Epoch 5/100 - train_loss: 0.3978, dev_loss: 0.2737 (47.11s)


                                                                                                                                                                                                                

Epoch 6/100 - train_loss: 0.3153, dev_loss: 0.4485 (47.13s)


                                                                                                                                                                                                                

Epoch 7/100 - train_loss: 0.2967, dev_loss: 0.2632 (47.13s)


                                                                                                                                                                                                                

Epoch 8/100 - train_loss: 0.3809, dev_loss: 0.3172 (47.14s)


                                                                                                                                                                                                                

Epoch 9/100 - train_loss: 0.4273, dev_loss: 0.3849 (47.12s)


                                                                                                                                                                                                                

Epoch 10/100 - train_loss: 0.3595, dev_loss: 0.2637 (47.11s)


                                                                                                                                                                                                                

Epoch 11/100 - train_loss: 0.2889, dev_loss: 0.2555 (47.11s)


                                                                                                                                                                                                                

Epoch 12/100 - train_loss: 0.2709, dev_loss: 0.2634 (47.13s)


                                                                                                                                                                                                                

Epoch 13/100 - train_loss: 0.3344, dev_loss: 0.2520 (47.13s)


                                                                                                                                                                                                                

Epoch 14/100 - train_loss: 0.2888, dev_loss: 0.2566 (47.13s)


                                                                                                                                                                                                                

Epoch 15/100 - train_loss: 0.3694, dev_loss: 0.4901 (47.11s)


                                                                                                                                                                                                                

Epoch 16/100 - train_loss: 0.5094, dev_loss: 0.3963 (47.14s)


                                                                                                                                                                                                                

Epoch 17/100 - train_loss: 0.4471, dev_loss: 0.3200 (47.12s)


                                                                                                                                                                                                                

Epoch 18/100 - train_loss: 0.3279, dev_loss: 0.2484 (47.14s)


                                                                                                                                                                                                                

Epoch 19/100 - train_loss: 0.2943, dev_loss: 0.4035 (47.11s)


                                                                                                                                                                                                                

Epoch 20/100 - train_loss: 0.3209, dev_loss: 0.2477 (47.14s)


                                                                                                                                                                                                                

Epoch 21/100 - train_loss: 0.2508, dev_loss: 0.2338 (47.12s)


                                                                                                                                                                                                                

Epoch 22/100 - train_loss: 0.2555, dev_loss: 0.2336 (47.13s)


                                                                                                                                                                                                                

Epoch 23/100 - train_loss: 0.2523, dev_loss: 0.2400 (47.13s)


                                                                                                                                                                                                                

Epoch 24/100 - train_loss: 0.5080, dev_loss: 0.3641 (47.13s)


                                                                                                                                                                                                                

Epoch 25/100 - train_loss: 0.3365, dev_loss: 0.2654 (47.11s)


                                                                                                                                                                                                                

Epoch 26/100 - train_loss: 0.2757, dev_loss: 0.2804 (47.11s)


                                                                                                                                                                                                                

Epoch 27/100 - train_loss: 0.4170, dev_loss: 0.4190 (47.12s)


                                                                                                                                                                                                                

Epoch 28/100 - train_loss: 0.3665, dev_loss: 0.4634 (47.12s)


                                                                                                                                                                                                                

Epoch 29/100 - train_loss: 0.2988, dev_loss: 0.2810 (47.14s)


                                                                                                                                                                                                                

Epoch 30/100 - train_loss: 0.3032, dev_loss: 0.2551 (47.14s)


                                                                                                                                                                                                                

Epoch 31/100 - train_loss: 0.2507, dev_loss: 0.2404 (47.13s)


                                                                                                                                                                                                                

Epoch 32/100 - train_loss: 0.2531, dev_loss: 0.2595 (47.13s)


                                                                                                                                                                                                                

Epoch 33/100 - train_loss: 0.2604, dev_loss: 0.2820 (47.14s)


                                                                                                                                                                                                                

Epoch 34/100 - train_loss: 0.2549, dev_loss: 0.2345 (47.12s)


                                                                                                                                                                                                                

Epoch 35/100 - train_loss: 0.2534, dev_loss: 0.4387 (47.12s)


                                                                                                                                                                                                                

Epoch 36/100 - train_loss: 0.3357, dev_loss: 0.2736 (47.13s)


                                                                                                                                                                                                                

Epoch 37/100 - train_loss: 0.2795, dev_loss: 0.2535 (47.12s)


                                                                                                                                                                                                                

Epoch 38/100 - train_loss: 0.3358, dev_loss: 0.3166 (47.13s)


                                                                                                                                                                                                                

Epoch 39/100 - train_loss: 0.2742, dev_loss: 0.2574 (47.13s)


                                                                                                                                                                                                                

Epoch 40/100 - train_loss: 0.4622, dev_loss: 0.2708 (47.14s)


                                                                                                                                                                                                                

Epoch 41/100 - train_loss: 0.2613, dev_loss: 0.2942 (47.12s)


                                                                                                                                                                                                                

Epoch 42/100 - train_loss: 0.2461, dev_loss: 0.2625 (47.15s)


                                                                                                                                                                                                                

Epoch 43/100 - train_loss: 0.2854, dev_loss: 0.2900 (47.14s)


                                                                                                                                                                                                                

Epoch 44/100 - train_loss: 0.2440, dev_loss: 0.2319 (47.12s)


                                                                                                                                                                                                                

Epoch 45/100 - train_loss: 0.2415, dev_loss: 0.2533 (47.12s)


                                                                                                                                                                                                                

Epoch 46/100 - train_loss: 0.3085, dev_loss: 0.3659 (47.10s)


                                                                                                                                                                                                                

Epoch 47/100 - train_loss: 0.3266, dev_loss: 0.2314 (47.12s)


                                                                                                                                                                                                                

Epoch 48/100 - train_loss: 0.5409, dev_loss: 0.3347 (47.11s)


                                                                                                                                                                                                                

Epoch 49/100 - train_loss: 0.3474, dev_loss: 0.2862 (47.11s)


                                                                                                                                                                                                                

Epoch 50/100 - train_loss: 0.2641, dev_loss: 0.3405 (47.12s)


                                                                                                                                                                                                                

Epoch 51/100 - train_loss: 0.2917, dev_loss: 0.2633 (47.11s)


                                                                                                                                                                                                                

Epoch 52/100 - train_loss: 0.3133, dev_loss: 0.2523 (47.13s)


                                                                                                                                                                                                                

Epoch 53/100 - train_loss: 0.2434, dev_loss: 0.2312 (47.14s)


                                                                                                                                                                                                                

Epoch 54/100 - train_loss: 0.2745, dev_loss: 0.2846 (47.13s)


                                                                                                                                                                                                                

Epoch 55/100 - train_loss: 0.2941, dev_loss: 0.2493 (47.11s)


                                                                                                                                                                                                                

Epoch 56/100 - train_loss: 0.2687, dev_loss: 0.2592 (47.13s)


                                                                                                                                                                                                                

Epoch 57/100 - train_loss: 0.2590, dev_loss: 0.2333 (47.11s)


                                                                                                                                                                                                                

Epoch 58/100 - train_loss: 0.2527, dev_loss: 0.2600 (47.11s)


                                                                                                                                                                                                                

Epoch 59/100 - train_loss: 0.2617, dev_loss: 0.2434 (47.10s)


                                                                                                                                                                                                                

Epoch 60/100 - train_loss: 0.2631, dev_loss: 0.5785 (47.14s)


                                                                                                                                                                                                                

Epoch 61/100 - train_loss: 0.2877, dev_loss: 0.2327 (47.11s)


                                                                                                                                                                                                                

Epoch 62/100 - train_loss: 0.2408, dev_loss: 0.2485 (47.12s)


                                                                                                                                                                                                                

Epoch 63/100 - train_loss: 0.2546, dev_loss: 0.2595 (47.11s)


                                                                                                                                                                                                                

Epoch 64/100 - train_loss: 0.2489, dev_loss: 0.2560 (47.12s)


                                                                                                                                                                                                                

Epoch 65/100 - train_loss: 0.2490, dev_loss: 0.2487 (47.11s)


                                                                                                                                                                                                                

Epoch 66/100 - train_loss: 0.3105, dev_loss: 0.2854 (47.11s)


                                                                                                                                                                                                                

Epoch 67/100 - train_loss: 0.3207, dev_loss: 0.2678 (47.12s)


                                                                                                                                                                                                                

Epoch 68/100 - train_loss: 0.3772, dev_loss: 0.3800 (47.11s)


                                                                                                                                                                                                                

Epoch 69/100 - train_loss: 0.3128, dev_loss: 0.2704 (47.10s)


                                                                                                                                                                                                                

Epoch 70/100 - train_loss: 0.2900, dev_loss: 0.2911 (47.12s)


                                                                                                                                                                                                                

Epoch 71/100 - train_loss: 0.2891, dev_loss: 0.2385 (47.12s)


                                                                                                                                                                                                                

Epoch 72/100 - train_loss: 0.2529, dev_loss: 0.2379 (47.11s)


                                                                                                                                                                                                                

Epoch 73/100 - train_loss: 0.2520, dev_loss: 0.2720 (47.11s)


                                                                                                                                                                                                                

Epoch 74/100 - train_loss: 0.2805, dev_loss: 0.3560 (47.14s)


                                                                                                                                                                                                                

Epoch 75/100 - train_loss: 0.3164, dev_loss: 0.3753 (47.12s)


                                                                                                                                                                                                                

Epoch 76/100 - train_loss: 0.3075, dev_loss: 0.2814 (47.10s)


                                                                                                                                                                                                                

Epoch 77/100 - train_loss: 0.2505, dev_loss: 0.2354 (47.12s)


                                                                                                                                                                                                                

Epoch 78/100 - train_loss: 0.2403, dev_loss: 0.2345 (47.12s)


                                                                                                                                                                                                                

Epoch 79/100 - train_loss: 0.3134, dev_loss: 0.3630 (47.10s)


                                                                                                                                                                                                                

Epoch 80/100 - train_loss: 0.3242, dev_loss: 0.3639 (47.11s)


                                                                                                                                                                                                                

Epoch 81/100 - train_loss: 0.2669, dev_loss: 0.2603 (47.12s)


                                                                                                                                                                                                                

Epoch 82/100 - train_loss: 0.2450, dev_loss: 0.2352 (47.12s)


                                                                                                                                                                                                                

Epoch 83/100 - train_loss: 0.2414, dev_loss: 0.3154 (47.13s)


                                                                                                                                                                                                                

Epoch 84/100 - train_loss: 0.2777, dev_loss: 0.2362 (47.13s)


                                                                                                                                                                                                                

Epoch 85/100 - train_loss: 0.3234, dev_loss: 0.3869 (47.12s)


                                                                                                                                                                                                                

Epoch 86/100 - train_loss: 0.3172, dev_loss: 0.2710 (47.12s)


                                                                                                                                                                                                                

Epoch 87/100 - train_loss: 0.3085, dev_loss: 0.2320 (47.12s)


                                                                                                                                                                                                                

Epoch 88/100 - train_loss: 0.2683, dev_loss: 0.2328 (47.12s)


                                                                                                                                                                                                                

Epoch 89/100 - train_loss: 0.2390, dev_loss: 0.2575 (47.13s)


                                                                                                                                                                                                                

Epoch 90/100 - train_loss: 0.2399, dev_loss: 0.2280 (47.14s)


                                                                                                                                                                                                                

Epoch 91/100 - train_loss: 0.2394, dev_loss: 0.2421 (47.11s)


                                                                                                                                                                                                                

Epoch 92/100 - train_loss: 0.2420, dev_loss: 0.2306 (47.14s)


                                                                                                                                                                                                                

Epoch 93/100 - train_loss: 0.2824, dev_loss: 0.2524 (47.12s)


                                                                                                                                                                                                                

Epoch 94/100 - train_loss: 0.2456, dev_loss: 0.2395 (47.12s)


                                                                                                                                                                                                                

Epoch 95/100 - train_loss: 0.3753, dev_loss: 0.2442 (47.11s)


                                                                                                                                                                                                                

Epoch 96/100 - train_loss: 0.2679, dev_loss: 0.2356 (47.11s)


                                                                                                                                                                                                                

Epoch 97/100 - train_loss: 0.2516, dev_loss: 0.2334 (47.12s)


                                                                                                                                                                                                                

Epoch 98/100 - train_loss: 0.2484, dev_loss: 0.2736 (47.13s)


                                                                                                                                                                                                                

Epoch 99/100 - train_loss: 0.2566, dev_loss: 0.2380 (47.13s)


                                                                                                                                                                                                                

Epoch 100/100 - train_loss: 0.2489, dev_loss: 0.2441 (47.11s)


#### ESPCN

In [None]:
args = parser.parse_args([])
cddlt.startup(args)

In [21]:
espcn = ESPCN(
    n_channels=1,
    upscale_factor=args.upscale_factor,
)

In [22]:
espcn.configure(
    optimizer = torch.optim.AdamW(params=espcn.parameters(), lr=args.lr),
    scheduler = None,
    loss = torchmetrics.MeanSquaredError(squared=False),
    args = args
)

ESPCN(
  (model): Sequential(
    (0): Conv2d(1, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): Tanh()
    (2): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): Tanh()
    (4): Conv2d(32, 100, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): PixelShuffle(upscale_factor=10)
  )
  (loss): MeanSquaredError()
)

In [23]:
espcn.fit(rekis_train, rekis_dev, args.epochs)

                                                                                                                                                                                                                

Epoch 1/100 - train_loss: 4.4526, dev_loss: 1.3290 (4.93s)


                                                                                                                                                                                                                

Epoch 2/100 - train_loss: 1.1119, dev_loss: 0.8996 (3.34s)


                                                                                                                                                                                                                

Epoch 3/100 - train_loss: 0.8910, dev_loss: 0.7705 (3.35s)


                                                                                                                                                                                                                

Epoch 4/100 - train_loss: 0.7864, dev_loss: 0.8812 (3.15s)


                                                                                                                                                                                                                

Epoch 5/100 - train_loss: 0.6399, dev_loss: 0.6064 (3.40s)


                                                                                                                                                                                                                

Epoch 6/100 - train_loss: 0.6217, dev_loss: 0.7238 (3.21s)


                                                                                                                                                                                                                

Epoch 7/100 - train_loss: 0.6492, dev_loss: 0.9036 (3.28s)


                                                                                                                                                                                                                

Epoch 8/100 - train_loss: 0.5758, dev_loss: 0.5785 (3.49s)


                                                                                                                                                                                                                

Epoch 9/100 - train_loss: 0.5345, dev_loss: 0.5463 (3.41s)


                                                                                                                                                                                                                

Epoch 10/100 - train_loss: 0.4547, dev_loss: 0.3970 (3.21s)


                                                                                                                                                                                                                

Epoch 11/100 - train_loss: 0.5833, dev_loss: 0.5709 (3.27s)


                                                                                                                                                                                                                

Epoch 12/100 - train_loss: 0.4858, dev_loss: 0.4380 (3.31s)


                                                                                                                                                                                                                

Epoch 13/100 - train_loss: 0.4112, dev_loss: 0.4045 (3.21s)


                                                                                                                                                                                                                

Epoch 14/100 - train_loss: 0.4093, dev_loss: 0.3982 (3.29s)


                                                                                                                                                                                                                

Epoch 15/100 - train_loss: 0.4860, dev_loss: 0.3811 (3.33s)


                                                                                                                                                                                                                

Epoch 16/100 - train_loss: 0.4435, dev_loss: 0.4518 (3.56s)


                                                                                                                                                                                                                

Epoch 17/100 - train_loss: 0.4044, dev_loss: 0.3629 (3.23s)


                                                                                                                                                                                                                

Epoch 18/100 - train_loss: 0.3855, dev_loss: 0.4199 (3.21s)


                                                                                                                                                                                                                

Epoch 19/100 - train_loss: 0.3863, dev_loss: 0.3526 (3.27s)


                                                                                                                                                                                                                

Epoch 20/100 - train_loss: 0.4814, dev_loss: 0.6997 (3.21s)


                                                                                                                                                                                                                

Epoch 21/100 - train_loss: 0.5641, dev_loss: 0.3741 (3.27s)


                                                                                                                                                                                                                

Epoch 22/100 - train_loss: 0.3527, dev_loss: 0.3519 (3.31s)


                                                                                                                                                                                                                

Epoch 23/100 - train_loss: 0.3589, dev_loss: 0.3186 (3.27s)


                                                                                                                                                                                                                

Epoch 24/100 - train_loss: 0.3694, dev_loss: 0.3328 (3.15s)


                                                                                                                                                                                                                

Epoch 25/100 - train_loss: 0.5135, dev_loss: 0.7281 (3.20s)


                                                                                                                                                                                                                

Epoch 26/100 - train_loss: 0.4422, dev_loss: 0.3489 (3.04s)


                                                                                                                                                                                                                

Epoch 27/100 - train_loss: 0.4626, dev_loss: 0.4888 (3.20s)


                                                                                                                                                                                                                

Epoch 28/100 - train_loss: 0.3786, dev_loss: 0.3233 (3.28s)


                                                                                                                                                                                                                

Epoch 29/100 - train_loss: 0.3490, dev_loss: 0.3085 (3.49s)


                                                                                                                                                                                                                

Epoch 30/100 - train_loss: 0.4903, dev_loss: 0.5512 (3.28s)


                                                                                                                                                                                                                

Epoch 31/100 - train_loss: 0.3674, dev_loss: 0.3309 (3.23s)


                                                                                                                                                                                                                

Epoch 32/100 - train_loss: 0.4576, dev_loss: 0.5465 (3.47s)


                                                                                                                                                                                                                

Epoch 33/100 - train_loss: 0.4090, dev_loss: 0.3563 (3.18s)


                                                                                                                                                                                                                

Epoch 34/100 - train_loss: 0.4389, dev_loss: 0.3607 (3.07s)


                                                                                                                                                                                                                

Epoch 35/100 - train_loss: 0.3276, dev_loss: 0.3073 (3.19s)


                                                                                                                                                                                                                

Epoch 36/100 - train_loss: 0.3431, dev_loss: 0.3237 (3.04s)


                                                                                                                                                                                                                

Epoch 37/100 - train_loss: 0.4916, dev_loss: 0.3595 (3.14s)


                                                                                                                                                                                                                

Epoch 38/100 - train_loss: 0.4283, dev_loss: 0.7920 (3.17s)


                                                                                                                                                                                                                

Epoch 39/100 - train_loss: 0.5306, dev_loss: 0.4154 (3.33s)


                                                                                                                                                                                                                

Epoch 40/100 - train_loss: 0.4631, dev_loss: 0.3551 (3.20s)


                                                                                                                                                                                                                

Epoch 41/100 - train_loss: 0.3647, dev_loss: 0.3145 (3.24s)


                                                                                                                                                                                                                

Epoch 42/100 - train_loss: 0.3643, dev_loss: 0.2980 (3.25s)


                                                                                                                                                                                                                

Epoch 43/100 - train_loss: 0.3840, dev_loss: 0.4921 (3.32s)


                                                                                                                                                                                                                

Epoch 44/100 - train_loss: 0.3300, dev_loss: 0.3576 (3.34s)


                                                                                                                                                                                                                

Epoch 45/100 - train_loss: 0.4704, dev_loss: 0.4243 (3.29s)


                                                                                                                                                                                                                

Epoch 46/100 - train_loss: 0.4124, dev_loss: 0.4201 (3.15s)


                                                                                                                                                                                                                

Epoch 47/100 - train_loss: 0.3323, dev_loss: 0.3204 (3.22s)


                                                                                                                                                                                                                

Epoch 48/100 - train_loss: 0.3860, dev_loss: 0.2970 (3.19s)


                                                                                                                                                                                                                

Epoch 49/100 - train_loss: 0.3139, dev_loss: 0.2903 (3.26s)


                                                                                                                                                                                                                

Epoch 50/100 - train_loss: 0.3228, dev_loss: 0.3049 (3.20s)


                                                                                                                                                                                                                

Epoch 51/100 - train_loss: 0.4173, dev_loss: 0.3578 (3.45s)


                                                                                                                                                                                                                

Epoch 52/100 - train_loss: 0.3271, dev_loss: 0.4408 (3.24s)


                                                                                                                                                                                                                

Epoch 53/100 - train_loss: 0.4700, dev_loss: 0.2874 (3.16s)


                                                                                                                                                                                                                

Epoch 54/100 - train_loss: 0.3238, dev_loss: 0.3579 (3.22s)


                                                                                                                                                                                                                

Epoch 55/100 - train_loss: 0.4467, dev_loss: 0.3079 (3.38s)


                                                                                                                                                                                                                

Epoch 56/100 - train_loss: 0.3347, dev_loss: 0.3477 (3.23s)


                                                                                                                                                                                                                

Epoch 57/100 - train_loss: 0.3157, dev_loss: 0.2830 (3.28s)


                                                                                                                                                                                                                

Epoch 58/100 - train_loss: 0.3262, dev_loss: 0.2806 (3.33s)


                                                                                                                                                                                                                

Epoch 59/100 - train_loss: 0.3593, dev_loss: 0.2937 (3.15s)


                                                                                                                                                                                                                

Epoch 60/100 - train_loss: 0.4028, dev_loss: 0.7070 (3.25s)


                                                                                                                                                                                                                

Epoch 61/100 - train_loss: 0.4327, dev_loss: 0.2843 (3.27s)


                                                                                                                                                                                                                

Epoch 62/100 - train_loss: 0.3049, dev_loss: 0.3081 (3.32s)


                                                                                                                                                                                                                

Epoch 63/100 - train_loss: 0.3678, dev_loss: 0.2852 (3.28s)


                                                                                                                                                                                                                

Epoch 64/100 - train_loss: 0.3736, dev_loss: 0.3355 (3.57s)


                                                                                                                                                                                                                

Epoch 65/100 - train_loss: 0.3928, dev_loss: 0.3737 (3.37s)


                                                                                                                                                                                                                

Epoch 66/100 - train_loss: 0.4870, dev_loss: 0.4936 (3.23s)


                                                                                                                                                                                                                

Epoch 67/100 - train_loss: 0.4654, dev_loss: 0.5691 (3.25s)


                                                                                                                                                                                                                

Epoch 68/100 - train_loss: 0.3063, dev_loss: 0.2730 (3.14s)


                                                                                                                                                                                                                

Epoch 69/100 - train_loss: 0.3386, dev_loss: 0.3634 (3.29s)


                                                                                                                                                                                                                

Epoch 70/100 - train_loss: 0.4535, dev_loss: 0.3232 (3.22s)


                                                                                                                                                                                                                

Epoch 71/100 - train_loss: 0.3387, dev_loss: 0.2720 (3.07s)


                                                                                                                                                                                                                

Epoch 72/100 - train_loss: 0.3362, dev_loss: 0.5020 (3.20s)


                                                                                                                                                                                                                

Epoch 73/100 - train_loss: 0.4147, dev_loss: 0.2971 (3.32s)


                                                                                                                                                                                                                

Epoch 74/100 - train_loss: 0.3358, dev_loss: 0.5713 (3.32s)


                                                                                                                                                                                                                

Epoch 75/100 - train_loss: 0.4605, dev_loss: 0.2861 (3.24s)


                                                                                                                                                                                                                

Epoch 76/100 - train_loss: 0.2901, dev_loss: 0.2796 (3.31s)


                                                                                                                                                                                                                

Epoch 77/100 - train_loss: 0.4734, dev_loss: 0.5401 (3.27s)


                                                                                                                                                                                                                

Epoch 78/100 - train_loss: 0.3061, dev_loss: 0.2706 (3.25s)


                                                                                                                                                                                                                

Epoch 79/100 - train_loss: 0.3183, dev_loss: 0.2964 (3.23s)


                                                                                                                                                                                                                

Epoch 80/100 - train_loss: 0.2977, dev_loss: 0.2805 (3.21s)


                                                                                                                                                                                                                

Epoch 81/100 - train_loss: 0.3085, dev_loss: 0.3077 (3.15s)


                                                                                                                                                                                                                

Epoch 82/100 - train_loss: 0.4239, dev_loss: 0.2830 (3.38s)


                                                                                                                                                                                                                

Epoch 83/100 - train_loss: 0.3410, dev_loss: 0.2736 (3.25s)


                                                                                                                                                                                                                

Epoch 84/100 - train_loss: 0.3363, dev_loss: 0.3855 (3.18s)


                                                                                                                                                                                                                

Epoch 85/100 - train_loss: 0.3752, dev_loss: 0.3020 (3.22s)


                                                                                                                                                                                                                

Epoch 86/100 - train_loss: 0.3104, dev_loss: 0.2749 (3.20s)


                                                                                                                                                                                                                

Epoch 87/100 - train_loss: 0.3129, dev_loss: 0.3037 (3.18s)


                                                                                                                                                                                                                

Epoch 88/100 - train_loss: 0.3090, dev_loss: 0.3340 (3.10s)


                                                                                                                                                                                                                

Epoch 89/100 - train_loss: 0.4893, dev_loss: 0.2993 (3.21s)


                                                                                                                                                                                                                

Epoch 90/100 - train_loss: 0.4326, dev_loss: 0.2969 (3.55s)


                                                                                                                                                                                                                

Epoch 91/100 - train_loss: 0.3242, dev_loss: 0.6143 (3.23s)


                                                                                                                                                                                                                

Epoch 92/100 - train_loss: 0.5161, dev_loss: 0.4811 (3.22s)


                                                                                                                                                                                                                

Epoch 93/100 - train_loss: 0.3760, dev_loss: 0.2729 (3.29s)


                                                                                                                                                                                                                

Epoch 94/100 - train_loss: 0.2870, dev_loss: 0.2932 (3.34s)


                                                                                                                                                                                                                

Epoch 95/100 - train_loss: 0.2890, dev_loss: 0.2741 (3.28s)


                                                                                                                                                                                                                

Epoch 96/100 - train_loss: 0.3456, dev_loss: 0.2957 (3.33s)


                                                                                                                                                                                                                

Epoch 97/100 - train_loss: 0.2956, dev_loss: 0.3610 (3.36s)


                                                                                                                                                                                                                

Epoch 98/100 - train_loss: 0.5555, dev_loss: 0.2966 (3.22s)


                                                                                                                                                                                                                

Epoch 99/100 - train_loss: 0.2868, dev_loss: 0.2881 (3.12s)


                                                                                                                                                                                                                

Epoch 100/100 - train_loss: 0.4074, dev_loss: 0.2717 (3.17s)


#### FNO

In [24]:
fno = FNO(
    n_channels=1,
    upscale_factor=args.upscale_factor
)

In [25]:
fno.configure(
    optimizer = torch.optim.AdamW(params=espcn.parameters(), lr=args.lr),
    scheduler = None,
    loss = torchmetrics.MeanSquaredError(squared=False),
    args = args
)

FNO(
  (P): Linear(in_features=1, out_features=32, bias=True)
  (spectral_convs): ModuleList(
    (0-3): 4 x SpectralConv2d()
  )
  (weights): ModuleList(
    (0-3): 4 x Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1))
  )
  (Q): Linear(in_features=32, out_features=1, bias=True)
  (loss): MeanSquaredError()
)

In [26]:
fno.fit(rekis_train, rekis_dev, args.epochs)

                                                                                                                                                                                                                

Epoch 1/100 - train_loss: 11.5460, dev_loss: 10.7557 (134.69s)


                                                                                                                                                                                                                

Epoch 2/100 - train_loss: 11.5400, dev_loss: 10.7557 (133.12s)


                                                                                                                                                                                                                

Epoch 3/100 - train_loss: 11.5461, dev_loss: 10.7557 (133.11s)


                                                                                                                                                                                                                

Epoch 4/100 - train_loss: 11.5436, dev_loss: 10.7557 (133.12s)


                                                                                                                                                                                                                

Epoch 5/100 - train_loss: 11.5472, dev_loss: 10.7557 (133.12s)


                                                                                                                                                                                                                

Epoch 6/100 - train_loss: 11.5418, dev_loss: 10.7557 (133.12s)


                                                                                                                                                                                                                

Epoch 7/100 - train_loss: 11.5486, dev_loss: 10.7557 (133.12s)


                                                                                                                                                                                                                

KeyboardInterrupt: 

#### EDSR

In [11]:
parser = argparse.ArgumentParser()
parser.add_argument("--batch_size", default=32, type=int)
parser.add_argument("--epochs", default=60, type=int)
parser.add_argument("--seed", default=42, type=int)
parser.add_argument("--threads", default=0, type=int)
parser.add_argument("--upscale_factor", default=10, type=int)
parser.add_argument("--lr", default=1e-4, type=float)
parser.add_argument("--wd", default=1e-5, type=float)
parser.add_argument("--logdir", default="logs", type=str)
parser.add_argument("--variables", default=["TM"], type=list)

_StoreAction(option_strings=['--variables'], dest='variables', nargs=None, const=None, default=['TM'], type=<class 'list'>, choices=None, required=False, help=None, metavar=None)

In [12]:
args = parser.parse_args([])
cddlt.startup(args)

In [13]:
edsr = EDSR(
    channels=1,
    scale_factor=10
)

In [15]:
optimizer = torch.optim.AdamW(params=edsr.parameters(), lr=args.lr, weight_decay=args.wd)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)

In [16]:
edsr.configure(
    optimizer = optimizer,
    scheduler = scheduler,
    loss = torch.nn.L1Loss(),
    args = args
)

EDSR(
  (loss): L1Loss()
  (head): Sequential(
    (0): Conv2d(1, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (body): Sequential(
    (0): ResidualBlock(
      (residual_block): Sequential(
        (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU()
        (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
    )
    (1): ResidualBlock(
      (residual_block): Sequential(
        (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU()
        (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
    )
    (2): ResidualBlock(
      (residual_block): Sequential(
        (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU()
        (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
    )
    (3): ResidualBlock(
      (residual_block): Sequential(
        (0): Conv2d(256,

In [None]:
edsr.fit(rekis_train, rekis_dev, args.epochs)

  return F.conv2d(input, weight, bias, self.stride,
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
                                                                                                         

Epoch 1/60 - train_loss: 0.6536, dev_loss: 0.3353 (431.99s)


                                                                                                         

Epoch 2/60 - train_loss: 0.2577, dev_loss: 0.8569 (424.62s)


                                                                                                         

Epoch 3/60 - train_loss: 0.2534, dev_loss: 0.1974 (425.25s)


                                                                                                         

Epoch 4/60 - train_loss: 0.2158, dev_loss: 0.2896 (424.68s)


                                                                                                         

Epoch 5/60 - train_loss: 0.2095, dev_loss: 0.5226 (424.71s)


                                                                                                         

Epoch 6/60 - train_loss: 0.1779, dev_loss: 0.2876 (424.95s)


                                                                                     

Epoch 7/60 - train_loss: 0.1751, dev_loss: 0.3609 (425.05s)


                                                                                     

Epoch 8/60 - train_loss: 0.2013, dev_loss: 0.3146 (424.70s)


                                                                               

Epoch 9/60 - train_loss: 0.1875, dev_loss: 0.3413 (424.76s)


                                                                                  

Epoch 10/60 - train_loss: 0.1561, dev_loss: 0.1502 (425.79s)


                                                                                  

Epoch 11/60 - train_loss: 0.1870, dev_loss: 0.2298 (424.70s)


                                                                                  

Epoch 12/60 - train_loss: 0.1769, dev_loss: 0.2474 (424.47s)


                                                                                  

Epoch 13/60 - train_loss: 0.1474, dev_loss: 0.1678 (424.34s)


                                                                                  

Epoch 14/60 - train_loss: 0.1663, dev_loss: 0.2240 (423.61s)


                                                                                  

Epoch 15/60 - train_loss: 0.1451, dev_loss: 0.1790 (423.78s)


                                                                                  

Epoch 16/60 - train_loss: 0.1524, dev_loss: 0.1362 (424.13s)


                                                                                  

Epoch 17/60 - train_loss: 0.1304, dev_loss: 0.2522 (423.78s)


                                                                                  

Epoch 18/60 - train_loss: 0.1292, dev_loss: 0.1990 (423.86s)


                                                                                  

Epoch 19/60 - train_loss: 0.1203, dev_loss: 0.1206 (424.48s)


                                                                                  

Epoch 20/60 - train_loss: 0.1290, dev_loss: 0.1717 (423.61s)


                                                                                  

Epoch 21/60 - train_loss: 0.1156, dev_loss: 0.2201 (423.75s)


                                                                                  

Epoch 22/60 - train_loss: 0.0959, dev_loss: 0.1982 (423.83s)


                                                                                  

Epoch 23/60 - train_loss: 0.1241, dev_loss: 0.1206 (424.15s)


                                                                                  

Epoch 24/60 - train_loss: 0.0929, dev_loss: 0.1182 (424.43s)


                                                                                  

Epoch 25/60 - train_loss: 0.0876, dev_loss: 0.1262 (423.83s)


                                                                                  

Epoch 26/60 - train_loss: 0.0985, dev_loss: 0.1498 (423.78s)


                                                                                  

Epoch 27/60 - train_loss: 0.0995, dev_loss: 0.0896 (424.42s)


                                                                                  

Epoch 28/60 - train_loss: 0.1240, dev_loss: 0.1652 (423.84s)


                                                                                  

Epoch 29/60 - train_loss: 0.0880, dev_loss: 0.1204 (423.85s)


                                                                                  

Epoch 30/60 - train_loss: 0.0857, dev_loss: 0.1078 (423.77s)


                                                                                  

Epoch 31/60 - train_loss: 0.0750, dev_loss: 0.0554 (425.50s)


                                                                                  

Epoch 32/60 - train_loss: 0.0815, dev_loss: 0.1888 (424.85s)


33/60:  20%|███████▍                             | 48/240 [01:17<05:18,  1.66s/it]

#### DeepESDtas

In [10]:
class DeepESDtas(cddlt.DLModule):

    """
    DeepESD model as proposed in Baño-Medina et al. 2024 for temperature
    downscasling. This implementation allows for a deterministic (MSE-based)
    and stochastic (NLL-based) definition.

    Baño-Medina, J., Manzanas, R., Cimadevilla, E., Fernández, J., González-Abad,
    J., Cofiño, A. S., and Gutiérrez, J. M.: Downscaling multi-model climate projection
    ensembles with deep learning (DeepESD): contribution to CORDEX EUR-44, Geosci. Model
    Dev., 15, 6747–6758, https://doi.org/10.5194/gmd-15-6747-2022, 2022.

    Parameters
    ----------
    x_shape : tuple
        Shape of the data used as predictor. This must have dimension 4
        (time, channels/variables, lon, lat).

    y_shape : tuple
        Shape of the data used as predictand. This must have dimension 2
        (time, gridpoint)

    filters_last_conv : int
        Number of filters/kernels of the last convolutional layer

    stochastic: bool
        If set to True, the model is composed of two final dense layers computing
        the mean and log fo the variance. Otherwise, the models is composed of one
        final layer computing the values.
    """


    def __init__(self, x_shape: tuple, y_shape: tuple,
                 filters_last_conv: int, stochastic: bool):

        super(DeepESDtas, self).__init__()

        if (len(x_shape) != 4) or (len(y_shape) != 2):
            error_msg =\
            'X and Y data must have a dimension of length 4'
            'and 2, correspondingly'

            raise ValueError(error_msg)

        self.x_shape = x_shape
        self.y_shape = y_shape
        self.filters_last_conv = filters_last_conv
        self.stochastic = stochastic

        self.conv_1 = torch.nn.Conv2d(in_channels=self.x_shape[1],
                                      out_channels=50,
                                      kernel_size=3,
                                      padding=1)

        self.conv_2 = torch.nn.Conv2d(in_channels=50,
                                      out_channels=25,
                                      kernel_size=3,
                                      padding=1)

        self.conv_3 = torch.nn.Conv2d(in_channels=25,
                                      out_channels=self.filters_last_conv,
                                      kernel_size=3,
                                      padding=1)

        if self.stochastic:
            self.out_mean = torch.nn.Linear(in_features=\
                                            self.x_shape[2] * self.x_shape[3] * self.filters_last_conv,
                                            out_features=self.y_shape[1])

            self.out_log_var = torch.nn.Linear(in_features=\
                                               self.x_shape[2] * self.x_shape[3] * self.filters_last_conv,
                                               out_features=self.y_shape[1])

        else:
            self.out = torch.nn.Linear(in_features=\
                                       self.x_shape[2] * self.x_shape[3] * self.filters_last_conv,
                                       out_features=self.y_shape[1])

    def forward(self, x: torch.Tensor) -> torch.Tensor:

        x = self.conv_1(x)
        x = torch.relu(x)

        x = self.conv_2(x)
        x = torch.relu(x)

        x = self.conv_3(x)
        x = torch.relu(x)

        x = torch.flatten(x, start_dim=1)

        if self.stochastic:
            mean = self.out_mean(x)
            log_var = self.out_log_var(x)
            out = torch.cat((mean, log_var), dim=1)
        else:
            out = self.out(x)
        
        return out

In [11]:
from typing import Dict, Tuple
from cddlt.models.deepesdtas import DeepESDtas

In [12]:
import torch
from typing import Type, Dict, Union, Any, List, Tuple, Optional, Callable

class DownscalingTransform:
    def __init__(
        self, 
        dataset: Type[torch.utils.data.Dataset],
        transform: Optional[Callable] = None,
        collate_fn: Optional[Callable] = None,
    ) -> None:
        self.dataset = dataset
        self.transform = transform if transform is not None else self.default_transform
        self.collate_fn = collate_fn if collate_fn is not None else self.default_collate_fn

    def __len__(self) -> int:
        return len(self.dataset)

    def __getitem__(self, index: int) -> Dict[str, Union[torch.Tensor, Any]]:
        sample = self.dataset[index]
        return self.transform(sample)
    
    @staticmethod
    def default_transform(sample: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
        return sample["input"], sample["target"]

    @staticmethod
    def default_collate_fn(batch: List[Tuple[torch.Tensor, torch.Tensor]]) -> Tuple[torch.Tensor, torch.Tensor]:
        inputs = torch.stack([item[0] for item in batch])
        targets = torch.stack([item[1] for item in batch])
        return inputs, targets

    def dataloader(self, batch_size: int, shuffle: bool = False, num_workers: int = 0) -> torch.utils.data.DataLoader:
        return torch.utils.data.DataLoader(
            self, 
            batch_size=batch_size, 
            shuffle=shuffle,
            collate_fn=self.collate_fn,
            num_workers=num_workers
        )

In [13]:
def sample_transform(sample: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
    inputs = sample[0]
    targets = sample[1]
    B, H, W = targets.shape
    targets = targets.view(B, H * W).squeeze(1)
    return inputs, targets

In [14]:
def collate_fn(batch: List[Tuple[torch.Tensor, torch.Tensor]]) -> Tuple[torch.Tensor, torch.Tensor]:
    inputs = torch.stack([item[0] for item in batch])
    targets = torch.stack([item[1] for item in batch]).squeeze(1)
    return inputs, targets.squeeze(1)

In [15]:
rekis_deepesd_train = DownscalingTransform(
    dataset=rekis.train, 
    transform=sample_transform,
    collate_fn=collate_fn
).dataloader(args.batch_size, shuffle=True)

rekis_deepesd_dev = DownscalingTransform(
    dataset=rekis.dev, 
    transform=sample_transform,
    collate_fn=collate_fn
).dataloader(args.batch_size)

In [16]:
deepesdtas = DeepESDtas(
    x_shape = (..., 1, 40, 40),
    y_shape = (..., 160000),
    filters_last_conv = 4,
    stochastic = False
)

In [17]:
deepesdtas.configure(
    optimizer = torch.optim.Adam(params=deepesdtas.parameters(), lr=args.lr),
    scheduler = None,
    loss = torchmetrics.MeanSquaredError(squared=False),
    args = args,
)

DeepESDtas(
  (conv_1): Conv2d(1, 50, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv_2): Conv2d(50, 25, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv_3): Conv2d(25, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (out): Linear(in_features=6400, out_features=160000, bias=True)
  (loss): MeanSquaredError()
)

In [None]:
deepesdtas.fit(rekis_deepesd_train, rekis_deepesd_dev, args.epochs)

                                                                                                                        

Epoch 1/100 - train_loss: 0.9945, dev_loss: 0.5351 (126.20s)


                                                                                                           

Epoch 2/100 - train_loss: 0.5234, dev_loss: 0.3506 (125.22s)


                                                                                                           

Epoch 3/100 - train_loss: 0.3756, dev_loss: 0.5112 (114.57s)


                                                                                                           

Epoch 4/100 - train_loss: 0.3325, dev_loss: 0.3784 (114.49s)


                                                                                                           

Epoch 5/100 - train_loss: 0.2986, dev_loss: 0.2090 (125.07s)


                                                                                                           

Epoch 6/100 - train_loss: 0.2641, dev_loss: 0.2760 (114.40s)


                                                                                                           

Epoch 7/100 - train_loss: 0.2417, dev_loss: 0.1945 (125.89s)


                                                                                                           

Epoch 8/100 - train_loss: 0.2230, dev_loss: 0.3382 (114.40s)


                                                                                                           

Epoch 9/100 - train_loss: 0.2026, dev_loss: 0.1816 (125.54s)


                                                                                                           

Epoch 10/100 - train_loss: 0.1927, dev_loss: 0.1530 (125.16s)


                                                                                                           

Epoch 11/100 - train_loss: 0.1949, dev_loss: 0.1611 (114.42s)


                                                                                                           

Epoch 12/100 - train_loss: 0.1877, dev_loss: 0.1698 (114.39s)


                                                                                                           

Epoch 13/100 - train_loss: 0.1805, dev_loss: 0.1693 (114.38s)


                                                                                                           

Epoch 14/100 - train_loss: 0.1835, dev_loss: 0.1527 (125.95s)


                                                                                                           

Epoch 15/100 - train_loss: 0.1616, dev_loss: 0.1323 (124.81s)


                                                                                                           

Epoch 16/100 - train_loss: 0.1775, dev_loss: 0.1513 (114.65s)


                                                                                                           

Epoch 17/100 - train_loss: 0.1579, dev_loss: 0.1130 (125.71s)


                                                                                                           

Epoch 18/100 - train_loss: 0.1663, dev_loss: 0.2585 (114.72s)


                                                                                                           

Epoch 19/100 - train_loss: 0.1609, dev_loss: 0.1998 (114.71s)


                                                                                                           

Epoch 20/100 - train_loss: 0.1613, dev_loss: 0.1008 (126.30s)


                                                                                                           

Epoch 21/100 - train_loss: 0.1745, dev_loss: 0.1184 (114.71s)


                                                                                                           

Epoch 22/100 - train_loss: 0.1447, dev_loss: 0.1238 (114.71s)


                                                                                                           

Epoch 23/100 - train_loss: 0.1509, dev_loss: 0.1582 (114.72s)


                                                                                                           

Epoch 24/100 - train_loss: 0.1476, dev_loss: 0.1810 (114.70s)


                                                                                                           

Epoch 25/100 - train_loss: 0.1401, dev_loss: 0.0949 (125.90s)


                                                                                                           

Epoch 26/100 - train_loss: 0.1420, dev_loss: 0.1146 (114.69s)


                                                                                                           

Epoch 27/100 - train_loss: 0.1431, dev_loss: 0.0982 (114.69s)


                                                                                                           

Epoch 28/100 - train_loss: 0.1298, dev_loss: 0.1323 (114.69s)


                                                                                                           

Epoch 29/100 - train_loss: 0.1272, dev_loss: 0.1202 (114.70s)


                                                                                                           

Epoch 30/100 - train_loss: 0.1381, dev_loss: 0.1330 (114.68s)


                                                                                                           

Epoch 31/100 - train_loss: 0.1320, dev_loss: 0.1014 (114.70s)


                                                                                                           

Epoch 32/100 - train_loss: 0.1309, dev_loss: 0.1425 (114.69s)


                                                                                                           

Epoch 33/100 - train_loss: 0.1263, dev_loss: 0.1621 (114.51s)


                                                                                                           

Epoch 34/100 - train_loss: 0.1239, dev_loss: 0.1679 (114.35s)


                                                                                                           

Epoch 35/100 - train_loss: 0.1260, dev_loss: 0.1217 (114.35s)


                                                                                                           

Epoch 36/100 - train_loss: 0.1168, dev_loss: 0.1247 (114.48s)


                                                                                                           

Epoch 37/100 - train_loss: 0.1193, dev_loss: 0.0878 (126.60s)


                                                                                                           

Epoch 38/100 - train_loss: 0.1177, dev_loss: 0.1150 (114.51s)


                                                                                                           

Epoch 39/100 - train_loss: 0.1314, dev_loss: 0.1365 (114.52s)


                                                                                                           

Epoch 40/100 - train_loss: 0.1199, dev_loss: 0.0984 (114.52s)


                                                                                                           

Epoch 41/100 - train_loss: 0.1088, dev_loss: 0.0884 (114.53s)


                                                                                                           

Epoch 42/100 - train_loss: 0.1086, dev_loss: 0.1089 (114.54s)


                                                                                                           

Epoch 43/100 - train_loss: 0.1201, dev_loss: 0.0828 (124.87s)


                                                                                                           

Epoch 44/100 - train_loss: 0.1154, dev_loss: 0.1131 (114.50s)


                                                                                                           

Epoch 45/100 - train_loss: 0.1054, dev_loss: 0.0911 (114.48s)


                                                                                                           

Epoch 46/100 - train_loss: 0.1209, dev_loss: 0.1696 (114.50s)


                                                                                                           

Epoch 47/100 - train_loss: 0.1158, dev_loss: 0.0908 (114.49s)


                                                                                                           

Epoch 48/100 - train_loss: 0.1212, dev_loss: 0.0893 (114.50s)


                                                                                                      

Epoch 49/100 - train_loss: 0.1101, dev_loss: 0.1170 (114.51s)


                                                                                                      

Epoch 50/100 - train_loss: 0.1114, dev_loss: 0.1030 (114.53s)


                                                                                                      

Epoch 51/100 - train_loss: 0.1121, dev_loss: 0.1085 (114.53s)


                                                                                                      

Epoch 52/100 - train_loss: 0.1148, dev_loss: 0.0981 (114.51s)


                                                                                                      

Epoch 53/100 - train_loss: 0.1071, dev_loss: 0.1041 (114.51s)


                                                                                                      

Epoch 54/100 - train_loss: 0.1103, dev_loss: 0.0811 (124.97s)


                                                                                                      

Epoch 55/100 - train_loss: 0.1148, dev_loss: 0.0815 (114.52s)


                                                                                                      

Epoch 56/100 - train_loss: 0.1088, dev_loss: 0.1403 (114.54s)


                                                                                                      

Epoch 57/100 - train_loss: 0.1038, dev_loss: 0.0877 (114.53s)


                                                                                                      

Epoch 58/100 - train_loss: 0.1007, dev_loss: 0.1173 (114.54s)


                                                                                                      

Epoch 59/100 - train_loss: 0.1029, dev_loss: 0.0962 (114.53s)


                                                                                                      

Epoch 60/100 - train_loss: 0.0994, dev_loss: 0.0995 (114.52s)


                                                                                                      

Epoch 61/100 - train_loss: 0.0982, dev_loss: 0.0943 (114.52s)


                                                                                                      

Epoch 62/100 - train_loss: 0.0985, dev_loss: 0.0960 (114.52s)


                                                                                                      

Epoch 63/100 - train_loss: 0.1035, dev_loss: 0.0907 (114.54s)


                                                                                                      

Epoch 64/100 - train_loss: 0.1097, dev_loss: 0.0973 (114.51s)


                                                                                                      

Epoch 65/100 - train_loss: 0.0987, dev_loss: 0.1575 (114.54s)


                                                                                                      

Epoch 66/100 - train_loss: 0.1045, dev_loss: 0.0908 (114.54s)


                                                                                                      

Epoch 67/100 - train_loss: 0.1069, dev_loss: 0.1049 (114.53s)


                                                                                                      

Epoch 68/100 - train_loss: 0.1072, dev_loss: 0.0970 (114.55s)


                                                                                                      

Epoch 69/100 - train_loss: 0.1007, dev_loss: 0.0800 (125.26s)


                                                                                                      

Epoch 70/100 - train_loss: 0.0959, dev_loss: 0.0749 (125.61s)


                                                                                                      

Epoch 71/100 - train_loss: 0.1009, dev_loss: 0.1400 (114.51s)


                                                                                                      

Epoch 72/100 - train_loss: 0.1031, dev_loss: 0.1888 (114.52s)


                                                                                                      

Epoch 73/100 - train_loss: 0.0989, dev_loss: 0.1164 (114.52s)


                                                                                                      

Epoch 74/100 - train_loss: 0.0979, dev_loss: 0.1053 (114.52s)


                                                                                                      

Epoch 75/100 - train_loss: 0.0951, dev_loss: 0.0843 (114.53s)


                                                                                                      

Epoch 76/100 - train_loss: 0.0879, dev_loss: 0.0950 (114.51s)


                                                                                                      

Epoch 77/100 - train_loss: 0.0948, dev_loss: 0.0754 (114.51s)


                                                                                                      

Epoch 78/100 - train_loss: 0.0963, dev_loss: 0.0857 (114.52s)


                                                                                                      

Epoch 79/100 - train_loss: 0.0942, dev_loss: 0.0761 (114.50s)


                                                                                                      

Epoch 80/100 - train_loss: 0.0970, dev_loss: 0.1050 (114.53s)


                                                                                                      

Epoch 81/100 - train_loss: 0.0959, dev_loss: 0.0830 (114.53s)


                                                                                                      

Epoch 82/100 - train_loss: 0.0943, dev_loss: 0.1034 (114.54s)


                                                                                                      

Epoch 83/100 - train_loss: 0.0892, dev_loss: 0.1208 (114.52s)


                                                                                                      

Epoch 84/100 - train_loss: 0.0916, dev_loss: 0.1387 (114.52s)


                                                                                                      

Epoch 85/100 - train_loss: 0.0967, dev_loss: 0.1074 (114.51s)


                                                                                                      

Epoch 86/100 - train_loss: 0.1010, dev_loss: 0.1006 (114.53s)


                                                                                                      

Epoch 87/100 - train_loss: 0.0903, dev_loss: 0.2227 (114.52s)


                                                                                                      

Epoch 88/100 - train_loss: 0.0935, dev_loss: 0.0950 (114.51s)


                                                                                                  

Epoch 89/100 - train_loss: 0.0873, dev_loss: 0.0809 (114.52s)


                                                                                                  

Epoch 90/100 - train_loss: 0.0936, dev_loss: 0.0847 (114.52s)


                                                                                                  

Epoch 91/100 - train_loss: 0.0917, dev_loss: 0.0964 (114.53s)


                                                                                                  

Epoch 92/100 - train_loss: 0.0908, dev_loss: 0.0821 (114.52s)


                                                                                                  

Epoch 93/100 - train_loss: 0.0895, dev_loss: 0.0769 (114.51s)


                                                                                                  

Epoch 94/100 - train_loss: 0.0877, dev_loss: 0.0764 (114.52s)


                                                                                                  

Epoch 95/100 - train_loss: 0.0957, dev_loss: 0.0876 (114.52s)


                                                                                                  

Epoch 96/100 - train_loss: 0.0910, dev_loss: 0.0877 (114.51s)


97/100:  98%|█████████████████████████████████████████████████▉ | 938/959 [01:50<00:02,  8.51it/s]

In [None]:
deepesdtas.load_weights(os.path.join(args.logdir, deepesdtas.model_name))
deepesdtas.evaluate(rekis_deepesd_dev, print_loss=True)

#### SwinIR

In [8]:
from typing import List

In [10]:
parser = argparse.ArgumentParser()
parser.add_argument("--batch_size", default=32, type=int)
parser.add_argument("--epochs", default=60, type=int)
parser.add_argument("--seed", default=42, type=int)
parser.add_argument("--threads", default=0, type=int)
parser.add_argument("--upscale_factor", default=10, type=int)
parser.add_argument("--embedding_dim", default=180, type=int)
parser.add_argument("--window_size", default=8, type=int)
parser.add_argument("--depths", default=[6, 6, 6, 6, 6, 6], type=List[int])
parser.add_argument("--num_heads", default=[6, 6, 6, 6, 6, 6], type=List[int])
parser.add_argument("--lr", default=1e-4, type=float)
parser.add_argument("--wd", default=1e-5, type=float)
parser.add_argument("--logdir", default="logs", type=str)
parser.add_argument("--variables", default=["TM"], type=list)

_StoreAction(option_strings=['--variables'], dest='variables', nargs=None, const=None, default=['TM'], type=<class 'list'>, choices=None, required=False, help=None, metavar=None)

In [11]:
args = parser.parse_args([])
cddlt.startup(args)

In [12]:
swin = SwinIR(
    img_size=40,
    in_chans=len(args.variables),
    window_size=args.window_size,
    embed_dim=args.embedding_dim,
    depths=args.depths,
    num_heads=args.num_heads,
    upsampler="pixelshuffle",
    upscale=args.upscale_factor
)

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [13]:
optimizer = torch.optim.AdamW(params=swin.parameters(), lr=args.lr, weight_decay=args.wd)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)

In [14]:
swin.configure(
    optimizer = optimizer,
    scheduler = scheduler,
    loss = torch.nn.L1Loss(),
    args = args
)

SwinIR(
  (conv_first): Conv2d(1, 180, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (patch_embed): PatchEmbed(
    (norm): LayerNorm((180,), eps=1e-05, elementwise_affine=True)
  )
  (patch_unembed): PatchUnEmbed()
  (pos_drop): Dropout(p=0.0, inplace=False)
  (layers): ModuleList(
    (0): RSTB(
      (residual_group): BasicLayer(
        dim=180, input_resolution=(40, 40), depth=6
        (blocks): ModuleList(
          (0): SwinTransformerBlock(
            dim=180, input_resolution=(40, 40), num_heads=6, window_size=8, shift_size=0, mlp_ratio=4.0
            (norm1): LayerNorm((180,), eps=1e-05, elementwise_affine=True)
            (attn): WindowAttention(
              dim=180, window_size=(8, 8), num_heads=6
              (qkv): Linear(in_features=180, out_features=540, bias=True)
              (attn_drop): Dropout(p=0.0, inplace=False)
              (proj): Linear(in_features=180, out_features=180, bias=True)
              (proj_drop): Dropout(p=0.0, inplace=False)
     

In [None]:
swin.fit(rekis_train, rekis_dev, args.epochs)

  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
                                                                                     

Epoch 1/60 - train_loss: 0.6433, dev_loss: 0.2780 (287.35s)


                                                                                     

Epoch 2/60 - train_loss: 0.2490, dev_loss: 0.3682 (283.79s)


                                                                                     

Epoch 3/60 - train_loss: 0.2116, dev_loss: 0.3340 (283.93s)


                                                                                     

Epoch 4/60 - train_loss: 0.1877, dev_loss: 0.3196 (284.13s)


                                                                                     

Epoch 5/60 - train_loss: 0.1702, dev_loss: 0.1540 (284.43s)


                                                                                     

Epoch 6/60 - train_loss: 0.1742, dev_loss: 0.4562 (283.13s)


                                                                                     

Epoch 7/60 - train_loss: 0.1581, dev_loss: 0.1712 (284.17s)


                                                                                     

Epoch 8/60 - train_loss: 0.1582, dev_loss: 0.1492 (284.35s)


                                                                                     

Epoch 9/60 - train_loss: 0.1453, dev_loss: 0.1150 (284.51s)


                                                                                     

Epoch 10/60 - train_loss: 0.1484, dev_loss: 0.1537 (284.26s)


                                                                                     

Epoch 11/60 - train_loss: 0.1367, dev_loss: 0.1432 (284.26s)


                                                                                     

Epoch 12/60 - train_loss: 0.1245, dev_loss: 0.1691 (284.14s)


                                                                                     

Epoch 13/60 - train_loss: 0.1127, dev_loss: 0.1249 (284.36s)


                                                                                     

Epoch 14/60 - train_loss: 0.1150, dev_loss: 0.0851 (284.63s)


                                                                                     

Epoch 15/60 - train_loss: 0.1048, dev_loss: 0.0729 (284.64s)


                                                                                     

Epoch 16/60 - train_loss: 0.1016, dev_loss: 0.1224 (284.33s)


                                                                                     

Epoch 17/60 - train_loss: 0.0935, dev_loss: 0.0644 (284.49s)


                                                                                     

Epoch 18/60 - train_loss: 0.0893, dev_loss: 0.0589 (283.56s)


                                                                                     

Epoch 19/60 - train_loss: 0.0855, dev_loss: 0.1622 (284.34s)


                                                                                     

Epoch 20/60 - train_loss: 0.0842, dev_loss: 0.0598 (284.34s)


                                                                                     

Epoch 21/60 - train_loss: 0.0830, dev_loss: 0.0920 (284.41s)


                                                                                     

Epoch 22/60 - train_loss: 0.0746, dev_loss: 0.0495 (284.47s)


                                                                                     

Epoch 23/60 - train_loss: 0.0760, dev_loss: 0.0481 (284.62s)


                                                                                     

Epoch 24/60 - train_loss: 0.0759, dev_loss: 0.0447 (284.62s)


                                                                                     

Epoch 25/60 - train_loss: 0.0687, dev_loss: 0.0492 (284.30s)


                                                                                     

Epoch 26/60 - train_loss: 0.0796, dev_loss: 0.0521 (284.30s)


                                                                                     

Epoch 27/60 - train_loss: 0.0695, dev_loss: 0.0439 (284.51s)


                                                                                     

Epoch 28/60 - train_loss: 0.0689, dev_loss: 0.0471 (284.37s)


                                                                                     

Epoch 29/60 - train_loss: 0.0635, dev_loss: 0.0420 (284.65s)


                                                                                     

Epoch 30/60 - train_loss: 0.0642, dev_loss: 0.0391 (283.52s)


                                                                                     

Epoch 31/60 - train_loss: 0.0651, dev_loss: 0.0392 (284.25s)


                                                                                     

Epoch 32/60 - train_loss: 0.0599, dev_loss: 0.0401 (284.37s)


                                                                                     

Epoch 33/60 - train_loss: 0.0628, dev_loss: 0.0382 (284.65s)


                                                                                     

Epoch 34/60 - train_loss: 0.0614, dev_loss: 0.0375 (283.58s)


                                                                                     

Epoch 35/60 - train_loss: 0.0633, dev_loss: 0.0356 (284.70s)


                                                                                     

Epoch 36/60 - train_loss: 0.0569, dev_loss: 0.0360 (284.28s)


                                                                                     

Epoch 37/60 - train_loss: 0.0644, dev_loss: 0.0408 (284.44s)


                                                                                     

Epoch 38/60 - train_loss: 0.0575, dev_loss: 0.0348 (284.69s)


                                                                                     

Epoch 39/60 - train_loss: 0.0499, dev_loss: 0.0346 (284.68s)


                                                                                     

Epoch 40/60 - train_loss: 0.0549, dev_loss: 0.0339 (283.59s)


                                                                                     

Epoch 41/60 - train_loss: 0.0590, dev_loss: 0.0332 (283.49s)


                                                                                     

Epoch 42/60 - train_loss: 0.0562, dev_loss: 0.0339 (284.40s)


                                                                                     

Epoch 43/60 - train_loss: 0.0576, dev_loss: 0.0330 (284.68s)


                                                                                     

Epoch 44/60 - train_loss: 0.0542, dev_loss: 0.0331 (284.42s)


                                                                                     

Epoch 45/60 - train_loss: 0.0605, dev_loss: 0.0341 (284.44s)


                                                                                     

Epoch 46/60 - train_loss: 0.0510, dev_loss: 0.0317 (284.59s)


                                                                                     

Epoch 47/60 - train_loss: 0.0495, dev_loss: 0.0331 (284.45s)


48/60:   3%|█▏                                      | 28/959 [00:07<04:16,  3.64it/s]

In [None]:
swin.load_weights(os.path.join(args.logdir, swin.model_name))
swin.evaluate(rekis_dev, print_loss=True)

### CORDEX

In [34]:
### FIX
cordex = CORDEX(
    data_path=CORDEX_DATA_PATH,
    variables=args.variables,
    dev_len=("2000-01-01", "2001-01-01"),
    test_len=("2001-01-01", "2002-01-01"),
)

Loading 34 NetCDF file(s)...
Loaded data shape: {'time': 366, 'y': 40, 'x': 40}
Time range: 2000-01-01 12:00:00 to 2000-12-31 12:00:00
Loading 34 NetCDF file(s)...
Loaded data shape: {'time': 365, 'y': 40, 'x': 40}
Time range: 2001-01-01 12:00:00 to 2001-12-31 12:00:00

CORDEX dataset initalized.
dev size: (366)
test size: (365)

