### Climate downscaling using CDDLT package

In [1]:
# !pip install --index-url https://test.pypi.org/simple/ \
#              --extra-index-url https://pypi.org/simple \
#              cddlt --quiet

In [1]:
import os
import argparse
import torch
import torchmetrics
import cddlt

from cddlt.datasets.rekis_dataset import ReKIS
from cddlt.datasets.cordex_dataset import CORDEX
from cddlt.dataloaders.downscaling_transform import DownscalingTransform

from cddlt.models.espcn import ESPCN
from cddlt.models.srcnn import SRCNN
from cddlt.models.bicubic import Bicubic
from cddlt.models.fno import FNO



In [2]:
parser = argparse.ArgumentParser()
parser.add_argument("--batch_size", default=64, type=int)
parser.add_argument("--epochs", default=10, type=int)
parser.add_argument("--seed", default=42, type=int)
parser.add_argument("--threads", default=4, type=int)
parser.add_argument("--upscale_factor", default=10, type=int)
parser.add_argument("--lr", default=0.001, type=float)
parser.add_argument("--logdir", default="logs", type=str)
parser.add_argument("--variables", default=["TM"], type=list)

_StoreAction(option_strings=['--variables'], dest='variables', nargs=None, const=None, default=['TM'], type=<class 'list'>, choices=None, required=False, help=None, metavar=None)

In [4]:
args = parser.parse_args([])
cddlt.startup(args, os.path.basename("notebook.ipynb"))

In [5]:
REKIS_DATA_PATH = "/home/kostape4/ReKIS/KlimRefDS_v3.1_1961-2023/Raster/Tag/GK4/TM/"
CORDEX_DATA_PATH = "/home/kostape4/CORDEX/tas/GERICS_REMO2015"

### ReKIS

In [10]:
rekis = ReKIS(
    data_path=REKIS_DATA_PATH,
    variables=args.variables,
    train_len=("2000-01-01", "2001-01-01"),
    dev_len=("2001-01-01", "2002-01-01"),
    test_len=("2002-01-01", "2003-01-01"), ## value framework input
    resampling="cubic_spline"
)

Loading 63 NetCDF file(s)...
Loaded data shape: {'time': 366, 'northing': 401, 'easting': 418}
Time range: 2000-01-01 00:00:00 to 2000-12-31 00:00:00
Variables in dataset: ['TM']
inputs
Shape of inputs: torch.Size([366, 1, 40, 40])
targets
Shape of targets: torch.Size([366, 1, 400, 400])
Loading 63 NetCDF file(s)...
Loaded data shape: {'time': 365, 'northing': 401, 'easting': 418}
Time range: 2001-01-01 00:00:00 to 2001-12-31 00:00:00
Variables in dataset: ['TM']
inputs
Shape of inputs: torch.Size([365, 1, 40, 40])
targets
Shape of targets: torch.Size([365, 1, 400, 400])
Loading 63 NetCDF file(s)...
Loaded data shape: {'time': 365, 'northing': 401, 'easting': 418}
Time range: 2002-01-01 00:00:00 to 2002-12-31 00:00:00
Variables in dataset: ['TM']
inputs
Shape of inputs: torch.Size([365, 1, 40, 40])
targets
Shape of targets: torch.Size([365, 1, 400, 400])
ReKIS dataset initalized.
train size: (366)
dev size: (365)
test size: (365)


In [22]:
rekis_train = DownscalingTransform(dataset=rekis.train).dataloader(args.batch_size, shuffle=True)
rekis_dev = DownscalingTransform(dataset=rekis.dev).dataloader(args.batch_size)
rekis_test = DownscalingTransform(dataset=rekis.test).dataloader(args.batch_size)

### Models

In [12]:
espcn = ESPCN(
    n_channels=1,
    upscale_factor=args.upscale_factor,
    stochastic=
)

In [13]:
espcn.configure(
    optimizer = torch.optim.AdamW(params=espcn.parameters(), lr=args.lr),
    scheduler = None,
    loss = torchmetrics.MeanSquaredError(squared=False),
    logdir = args.logdir
)

ESPCN(
  (model): Sequential(
    (0): Conv2d(1, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): Tanh()
    (2): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): Tanh()
    (4): Conv2d(32, 100, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): PixelShuffle(upscale_factor=10)
  )
  (loss): MeanSquaredError()
)

In [14]:
# newer version of CUDA required
# torch.backends.cuda.matmul.allow_tf32 = False
# torch.backends.cudnn.allow_tf32 = False

In [17]:
espcn.fit(rekis_train, rekis_dev, args.epochs)

                                                                                                                                                                                                                

In [20]:
espcn.load_weights(args.logdir)

In [23]:
predict = espcn.predict(rekis_test)

In [24]:
predict

[tensor([[[[ 3.0604,  3.1791,  3.0490,  ..., -5.1304, -4.5075, -5.1123],
           [ 3.0697,  2.8275,  2.7292,  ..., -4.5778, -4.7703, -5.3395],
           [ 2.9577,  2.9833,  2.6916,  ..., -4.8499, -4.5563, -5.1420],
           ...,
           [-4.5775, -4.4216, -4.5465,  ..., -3.7267, -3.8545, -3.6929],
           [-4.0174, -4.2278, -3.8430,  ..., -3.6744, -3.5268, -3.3905],
           [-4.5355, -4.5837, -4.1477,  ..., -3.7071, -3.8272, -3.5612]]],
 
 
         [[[ 0.8157,  0.8946,  0.8117,  ..., -1.1485, -0.9807, -1.1794],
           [ 0.8512,  0.7060,  0.6916,  ..., -1.1049, -1.1098, -1.2791],
           [ 0.8053,  0.8152,  0.6394,  ..., -1.1099, -1.1141, -1.2247],
           ...,
           [-4.7602, -4.5932, -4.6953,  ..., -3.2838, -3.3346, -3.1987],
           [-4.1890, -4.3859, -4.0185,  ..., -3.2020, -3.0944, -2.9622],
           [-4.6977, -4.7552, -4.2996,  ..., -3.2252, -3.3416, -3.0572]]],
 
 
         [[[-5.0403, -5.1057, -4.9917,  ..., -6.9941, -6.1297, -6.6924],
       

In [33]:
srcnn = SRCNN(
    n_channels=1,
    upscale_factor=args.upscale_factor
)

In [34]:
srcnn.configure(
    optimizer = torch.optim.AdamW(params=espcn.parameters(), lr=args.lr),
    scheduler = None,
    loss = torchmetrics.MeanSquaredError(squared=False),
    logdir = args.logdir,
)

SRCNN(
  (model): Sequential(
    (0): Conv2d(1, 64, kernel_size=(9, 9), stride=(1, 1), padding=(4, 4))
    (1): ReLU()
    (2): Conv2d(64, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (3): ReLU()
    (4): Conv2d(32, 1, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  )
  (loss): MeanSquaredError()
)

In [None]:
srcnn.fit(rekis_train, rekis_dev, args.epochs)

In [42]:
### perhaps option for non-configurable models?

In [37]:
bicubic = Bicubic(
    upscale_factor=args.upscale_factor
)

In [40]:
bicubic.configure(
    optimizer = torch.optim.AdamW(params=espcn.parameters(), lr=args.lr),
    scheduler = None,
    loss = torchmetrics.MeanSquaredError(squared=False),
    logdir = args.logdir,
)

Bicubic(
  (loss): MeanSquaredError()
)

In [None]:
bicubic.fit(rekis_train, rekis_dev, args.epochs)

In [45]:
fno = FNO(
    n_channels=1,
    upscale_factor=args.upscale_factor
)

In [46]:
fno.configure(
    optimizer = torch.optim.AdamW(params=espcn.parameters(), lr=args.lr),
    scheduler = None,
    loss = torchmetrics.MeanSquaredError(squared=False),
    logdir = args.logdir,
)

FNO(
  (P): Linear(in_features=1, out_features=32, bias=True)
  (spectral_convs): ModuleList(
    (0-3): 4 x SpectralConv2d()
  )
  (weights): ModuleList(
    (0-3): 4 x Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1))
  )
  (Q): Linear(in_features=32, out_features=1, bias=True)
  (loss): MeanSquaredError()
)

In [None]:
fno.fit(rekis_train, rekis_dev, args.epochs)

### CORDEX

In [25]:
### FIX
cordex = CORDEX(
    data_path=CORDEX_DATA_PATH,
    variables=args.variables,
    dev_len=("2000-01-01", "2001-01-01"),
    test_len=("2001-01-01", "2002-01-01"),
    resampling="cubic_spline"
)

Loading 8 NetCDF file(s)...


KeyError: 'TM'