In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from bsd_dataset import get_dataset, regions, DatasetRequest
import bsd_dataset.common.metrics as metrics
import bsd_dataset.common.transforms as transforms
from bsd_dataset.datasets.check_dataset import Interpolator

import torch

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
variable = [
    'daily_maximum_near_surface_air_temperature',
    'daily_minimum_near_surface_air_temperature',
    'near_surface_air_temperature',
    'near_surface_specific_humidity',
    'near_surface_wind_speed',
    'precipitation',
    'sea_level_pressure'
]

input_datasets = [
    DatasetRequest(
        dataset='projections-cmip6',
        model='gfdl_esm4',
        variable=variable
    )
]

target_dataset = DatasetRequest(dataset='chirps', resolution=0.25)

root = f'/home/data/BSDD/experiment-0.1'
study_region = regions.SouthAmerica

In [4]:
dataset = get_dataset(
    input_datasets,
    target_dataset,
    train_region=study_region,
    val_region=study_region,
    test_region=study_region,
    train_dates=('1983-01-01', '2010-12-31'),
    val_dates=('2011-01-01', '2012-12-31'),
    test_dates=('2013-01-01', '2014-12-31'),
    download=False,
    extract=False, 
    root=root
)

In [5]:
# In CDS, precipitation's variable name is "pr"
transform = transforms.ConvertPrecipitation(var_name='pr')
test_dataset = dataset.get_split('test', transform)

In [6]:
# Create an interpolator, which downsamples the hi-res target to the low-res input
metrics_dict = {'rmse': metrics.rmse, 'bias': metrics.bias, 'pearsons_r': metrics.pearsons_r}
interp = Interpolator(metrics_dict)

In [8]:
# Then we want to see how the downsampled hi-res targets compare against the low-res inputs
running = torch.tensor([0., 0., 0.])

for i, (x, y, info) in enumerate(test_dataset):
    yy, mask = interp(x, y, info)
    # channel 5 is the precipitation channel for this particular dataset
    results = interp.eval(x[5], yy, mask)
    running += torch.tensor(list(results.values()))
    
running /= (torch.ones_like(running) * i)
for metric_name, val in zip(metrics_dict.keys(), running.tolist()):
    print(f'Average {metric_name}: {val}')

Average rmse: 5.909879684448242
Average bias: -543.5753784179688
Average pearsons_r: 0.055460602045059204
