### Climate downscaling using CDDLT package

In [None]:
# !pip install --index-url https://test.pypi.org/simple/ \
#              --extra-index-url https://pypi.org/simple \
#              cddlt --quiet

In [36]:
import os
import argparse
import torch
import torchmetrics
import cddlt

from cddlt.datasets.rekis_dataset import ReKIS
from cddlt.datasets.cordex_dataset import CORDEX
from cddlt.dataloaders.downscaling_transform import DownscalingTransform

from cddlt.models.espcn import ESPCN
from cddlt.models.srcnn import SRCNN
from cddlt.models.bicubic import Bicubic
from cddlt.models.fno import FNO

In [29]:
parser = argparse.ArgumentParser()
parser.add_argument("--batch_size", default=8, type=int)
parser.add_argument("--epochs", default=2, type=int)
parser.add_argument("--seed", default=42, type=int)
parser.add_argument("--threads", default=4, type=int)
parser.add_argument("--upscale_factor", default=10, type=int)
parser.add_argument("--lr", default=0.001, type=float)
parser.add_argument("--logdir", default="logs", type=str)
parser.add_argument("--variables", default=["TM"], type=list)

_StoreAction(option_strings=['--variables'], dest='variables', nargs=None, const=None, default=['TM'], type=<class 'list'>, choices=None, required=False, help=None, metavar=None)

In [18]:
args = parser.parse_args([])
cddlt.startup(args, os.path.basename("notebook.ipynb"))

In [30]:
REKIS_DATA_PATH = "/home/kostape4/ReKIS/KlimRefDS_v3.1_1961-2023/Raster/Tag/GK4/TM/"
CORDEX_DATA_PATH = "/home/kostape4/CORDEX/tas/GERICS_REMO2015"

### ReKIS

In [7]:
rekis = ReKIS(
    data_path=REKIS_DATA_PATH,
    variables=args.variables,
    train_len=("2000-01-01", "2000-02-01"),
    dev_len=("2000-02-01", "2000-03-01"),
    test_len=("2000-03-01", "2000-04-01"), ## value framework input
    resampling="cubic_spline"
)

Loading 63 NetCDF file(s)...
Loaded data shape: {'easting': 418, 'northing': 401, 'time': 31}
Time range: 2000-01-01 00:00:00 to 2000-01-31 00:00:00
Variables in dataset: ['TM']
Pre reproject shape:
TM: (31, 400, 400)
Inputs shape:
TM: (31, 40, 40)
Targets shape:
TM: (31, 400, 400)
Loading 63 NetCDF file(s)...
Loaded data shape: {'easting': 418, 'northing': 401, 'time': 29}
Time range: 2000-02-01 00:00:00 to 2000-02-29 00:00:00
Variables in dataset: ['TM']
Pre reproject shape:
TM: (29, 400, 400)
Inputs shape:
TM: (29, 40, 40)
Targets shape:
TM: (29, 400, 400)
Loading 63 NetCDF file(s)...
Loaded data shape: {'easting': 418, 'northing': 401, 'time': 31}
Time range: 2000-03-01 00:00:00 to 2000-03-31 00:00:00
Variables in dataset: ['TM']
Pre reproject shape:
TM: (31, 400, 400)
Inputs shape:
TM: (31, 40, 40)
Targets shape:
TM: (31, 400, 400)
ReKIS dataset initalized.
train size: (29)
dev size: (31)
test size: (31)


In [14]:
rekis_train = DownscalingTransform(dataset=rekis.train).dataloader(args.batch_size, shuffle=True)
rekis_dev = DownscalingTransform(dataset=rekis.dev).dataloader(args.batch_size)

### Models

In [19]:
espcn = ESPCN(
    n_channels=1,
    upscale_factor=args.upscale_factor
)

In [24]:
espcn.configure(
    optimizer = torch.optim.AdamW(params=espcn.parameters(), lr=args.lr),
    scheduler = None,
    loss = torchmetrics.MeanSquaredError(squared=False),
    logdir = args.logdir,
)

ESPCN(
  (model): Sequential(
    (0): Conv2d(1, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): Tanh()
    (2): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): Tanh()
    (4): Conv2d(32, 100, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): PixelShuffle(upscale_factor=10)
  )
  (loss): MeanSquaredError()
)

In [None]:
espcn.fit(rekis_train, rekis_dev, args.epochs)

In [33]:
srcnn = SRCNN(
    n_channels=1,
    upscale_factor=args.upscale_factor
)

In [34]:
srcnn.configure(
    optimizer = torch.optim.AdamW(params=espcn.parameters(), lr=args.lr),
    scheduler = None,
    loss = torchmetrics.MeanSquaredError(squared=False),
    logdir = args.logdir,
)

SRCNN(
  (model): Sequential(
    (0): Conv2d(1, 64, kernel_size=(9, 9), stride=(1, 1), padding=(4, 4))
    (1): ReLU()
    (2): Conv2d(64, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (3): ReLU()
    (4): Conv2d(32, 1, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  )
  (loss): MeanSquaredError()
)

In [None]:
srcnn.fit(rekis_train, rekis_dev, args.epochs)

In [42]:
### perhaps option for non-configurable models?

In [37]:
bicubic = Bicubic(
    upscale_factor=args.upscale_factor
)

In [40]:
bicubic.configure(
    optimizer = torch.optim.AdamW(params=espcn.parameters(), lr=args.lr),
    scheduler = None,
    loss = torchmetrics.MeanSquaredError(squared=False),
    logdir = args.logdir,
)

Bicubic(
  (loss): MeanSquaredError()
)

In [None]:
bicubic.fit(rekis_train, rekis_dev, args.epochs)

In [45]:
fno = FNO(
    n_channels=1,
    upscale_factor=args.upscale_factor
)

In [46]:
fno.configure(
    optimizer = torch.optim.AdamW(params=espcn.parameters(), lr=args.lr),
    scheduler = None,
    loss = torchmetrics.MeanSquaredError(squared=False),
    logdir = args.logdir,
)

FNO(
  (P): Linear(in_features=1, out_features=32, bias=True)
  (spectral_convs): ModuleList(
    (0-3): 4 x SpectralConv2d()
  )
  (weights): ModuleList(
    (0-3): 4 x Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1))
  )
  (Q): Linear(in_features=32, out_features=1, bias=True)
  (loss): MeanSquaredError()
)

In [None]:
fno.fit(rekis_train, rekis_dev, args.epochs)

### CORDEX

In [48]:
### FIX
cordex = CORDEX(
    data_path=CORDEX_DATA_PATH,
    variables=args.variables,
    dev_len=("2000-03-01", "2000-04-01"),
    test_len=("2000-04-01", "2000-06-01"),
    resampling="cubic_spline"
)

2000-03-01
2023-12-31
2000-04-01
2023-12-31
Loading 8 NetCDF file(s)...
Loaded data shape: {'time': 31, 'bnds': 2, 'rlat': 412, 'rlon': 424, 'vertices': 4}
Time range: 2000-03-01 12:00:00 to 2000-03-31 12:00:00
Variables in dataset: ['tas']
Pre reproject shape:
tas: (31, 412, 424)


MissingSpatialDimensionError: x dimension (easting) not found.