In [11]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [12]:
import torch
from bsd_dataset import get_dataset, regions, DatasetRequest

In [13]:
# Sanity check: print how many CUDA devices PyTorch can see.
# Nothing is printed when CUDA is unavailable.
if torch.cuda.is_available():
    print(torch.cuda.device_count())

1


In [14]:
# Model inputs: a list of dataset requests. Here there is a single request for
# the 'precipitation' variable of the 'gfdl_esm4' model from the
# 'projections-cmip6' dataset.
input_datasets = [
    DatasetRequest(
        dataset='projections-cmip6',
        model='gfdl_esm4',
        variable='precipitation',
    )
]

# Prediction target: the 'chirps' dataset at resolution 0.25
# (presumably degrees -- TODO confirm against the bsd_dataset documentation).
target_dataset = DatasetRequest(dataset='chirps', resolution=0.25)

In [15]:
# Build the combined dataset: all three splits cover the SouthAmerica region,
# with one calendar year per split (1990 train, 1991 val, 1992 test).
# The device falls back to the CPU when CUDA is unavailable so the notebook
# still runs on machines without a GPU.
# NOTE(review): assumes get_dataset accepts any torch device string,
# including 'cpu' -- confirm against the bsd_dataset documentation.
dataset = get_dataset(
    input_datasets,
    target_dataset,
    train_region=regions.SouthAmerica,
    val_region=regions.SouthAmerica,
    test_region=regions.SouthAmerica,
    train_dates=('1990-01-01', '1990-12-31'),
    val_dates=('1991-01-01', '1991-12-31'),
    test_dates=('1992-01-01', '1992-12-31'),
    download=False,  # CHANGE ME (as needed)
    extract=True,   # CHANGE ME (as needed)
    device='cuda:0' if torch.cuda.is_available() else 'cpu',
)

In [16]:
# Fetch the three splits in one pass; the unpacking order matches the order of
# the split names, so each variable receives the split it is named after.
train_dataset, val_dataset, test_dataset = (
    dataset.get_split(split_name) for split_name in ('train', 'val', 'test')
)

All tensors are laid out latitude by longitude: in the last two dimensions, rows index latitude and columns index longitude.

In [17]:
# Pull the first training sample: the input tensor, the target tensor, and a
# dict of auxiliary tensors.
x, y, info = train_dataset[0]

print(f'INPUT SHAPE: {x.shape} ({x.device})')
# Report the target's own device so a placement mismatch between input and
# target would actually be visible here.
print(f'TARGET SHAPE: {y.shape} ({y.device})')

INPUT SHAPE: torch.Size([1, 75, 48]) (cuda:0)
TARGET SHAPE: torch.Size([280, 240]) (cuda:0)


In [18]:
# List every auxiliary tensor in `info` with its shape and device.
print('INFO SUMMARY')
for field, tensor in info.items():
    print(f' - {field} shape: {tensor.shape} ({tensor.device})')

INFO SUMMARY
 - x_lat shape: torch.Size([75, 48]) (cuda:0)
 - x_lon shape: torch.Size([75, 48]) (cuda:0)
 - y_lat shape: torch.Size([75, 48]) (cuda:0)
 - y_lon shape: torch.Size([75, 48]) (cuda:0)
 - y_mask shape: torch.Size([280, 240]) (cuda:0)


Latitudes and longitudes are provided unnormalized. Latitudes are in the range \[-90, 90\], and longitudes are in the range \[0, 360\]. The functions defined below perform the normalization (I will eventually migrate this into the dataset itself).

In [20]:
# Display the raw (unnormalized) latitude grid stored under 'x_lat'.
info['x_lat']

tensor([[-54.5000, -54.5000, -54.5000,  ..., -54.5000, -54.5000, -54.5000],
        [-53.5000, -53.5000, -53.5000,  ..., -53.5000, -53.5000, -53.5000],
        [-52.5000, -52.5000, -52.5000,  ..., -52.5000, -52.5000, -52.5000],
        ...,
        [ 17.5000,  17.5000,  17.5000,  ...,  17.5000,  17.5000,  17.5000],
        [ 18.5000,  18.5000,  18.5000,  ...,  18.5000,  18.5000,  18.5000],
        [ 19.5000,  19.5000,  19.5000,  ...,  19.5000,  19.5000,  19.5000]],
       device='cuda:0', dtype=torch.float64)

In [21]:
# Display the raw (unnormalized) longitude grid stored under 'x_lon'.
info['x_lon']

tensor([[270.6250, 271.8750, 273.1250,  ..., 326.8750, 328.1250, 329.3750],
        [270.6250, 271.8750, 273.1250,  ..., 326.8750, 328.1250, 329.3750],
        [270.6250, 271.8750, 273.1250,  ..., 326.8750, 328.1250, 329.3750],
        ...,
        [270.6250, 271.8750, 273.1250,  ..., 326.8750, 328.1250, 329.3750],
        [270.6250, 271.8750, 273.1250,  ..., 326.8750, 328.1250, 329.3750],
        [270.6250, 271.8750, 273.1250,  ..., 326.8750, 328.1250, 329.3750]],
       device='cuda:0', dtype=torch.float64)

In [22]:
def normalize_latitudes(lats):
    """Map latitudes from the range [-90, 90] onto [0, 1].

    Args:
        lats: A latitude value or any array-like/tensor that supports
            arithmetic with Python numbers; the mapping is applied elementwise.

    Returns:
        The shifted and scaled latitudes: -90 maps to 0 and 90 maps to 1.
    """
    offset = 90   # shifts [-90, 90] to [0, 180]
    span = 180    # width of the latitude range
    shifted = lats + offset
    return shifted / span

def normalize_longitudes(lons):
    """Map longitudes from the range [0, 360] onto [0, 1].

    Args:
        lons: A longitude value or any array-like/tensor that supports
            division by a Python number; the mapping is applied elementwise.

    Returns:
        The scaled longitudes: 0 maps to 0 and 360 maps to 1.
    """
    full_circle = 360  # width of the longitude range
    return lons / full_circle

In [23]:
# Apply the latitude normalization to the 'x_lat' grid and display the result.
normalize_latitudes(info['x_lat'])

tensor([[0.1972, 0.1972, 0.1972,  ..., 0.1972, 0.1972, 0.1972],
        [0.2028, 0.2028, 0.2028,  ..., 0.2028, 0.2028, 0.2028],
        [0.2083, 0.2083, 0.2083,  ..., 0.2083, 0.2083, 0.2083],
        ...,
        [0.5972, 0.5972, 0.5972,  ..., 0.5972, 0.5972, 0.5972],
        [0.6028, 0.6028, 0.6028,  ..., 0.6028, 0.6028, 0.6028],
        [0.6083, 0.6083, 0.6083,  ..., 0.6083, 0.6083, 0.6083]],
       device='cuda:0', dtype=torch.float64)

In [24]:
# Apply the longitude normalization to the 'x_lon' grid and display the result.
normalize_longitudes(info['x_lon'])

tensor([[0.7517, 0.7552, 0.7587,  ..., 0.9080, 0.9115, 0.9149],
        [0.7517, 0.7552, 0.7587,  ..., 0.9080, 0.9115, 0.9149],
        [0.7517, 0.7552, 0.7587,  ..., 0.9080, 0.9115, 0.9149],
        ...,
        [0.7517, 0.7552, 0.7587,  ..., 0.9080, 0.9115, 0.9149],
        [0.7517, 0.7552, 0.7587,  ..., 0.9080, 0.9115, 0.9149],
        [0.7517, 0.7552, 0.7587,  ..., 0.9080, 0.9115, 0.9149]],
       device='cuda:0', dtype=torch.float64)