In [1]:
import pyrootutils
path = pyrootutils.find_root(search_from="./", indicator=".home")
pyrootutils.set_root(
    path=path, # path to the root directory
    project_root_env_var=True, # set the PROJECT_ROOT environment variable to root directory
    dotenv=True, # load environment variables from .env if exists in root directory
    pythonpath=True, # add root directory to the PYTHONPATH (helps with imports)
    cwd=True, # change current working directory to the root directory (helps with filepaths)
)


In [2]:
from oceanbench._src.datasets.base import XArrayDataset, XRConcatDataset
import numpy as np
import pandas as pd
from oceanbench._src.utils.custom_dtypes import (
    LongitudeAxis, LatitudeAxis, 
    TimeAxis, SSH2D, SSH2DT
)
from xarray_dataclasses import asdataarray, asdataset
%load_ext autoreload
%autoreload 2


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from oceanbench._src.utils.custom_dtypes import (
    LongitudeAxis, LatitudeAxis, 
    TimeAxis, SSH2D, SSH2DT
)
from xarray_dataclasses import asdataarray, asdataset

## Test Case I - 2D Field

In [4]:
lon_axis = LongitudeAxis.init_from_limits(lon_min=40, lon_max=50, dlon=0.05)
lat_axis = LatitudeAxis.init_from_limits(lat_min=-60, lat_max=-50, dlat=0.05)
time_axis = TimeAxis.init_from_limits("2012-01-01", "2012-01-30", "1D")

ssh_xrds = SSH2D.init_from_axis(lon=lon_axis, lat=lat_axis)
ssh_xrds = asdataarray(ssh_xrds).to_dataset()
ssh_xrds

In [5]:
ssh_xrds.lon.values.min(), ssh_xrds.lon.values.max(), ssh_xrds.lat.values.min(), ssh_xrds.lat.values.max()

(40.0, 49.99999999999943, -60.0, -50.00000000000057)

In [6]:
from oceanbench._src.geoprocessing.select import select_bounds, select_bounds_multiple
from oceanbench._src.utils.custom_dtypes import Bounds

In [7]:
# create bounds object
lon_bnds = Bounds(val_min=42, val_max=48, name="lon")
lat_bnds = Bounds(val_min=-58, val_max=-52, name="lat")

# subset dataarayy
d_sub = select_bounds(ssh_xrds, lon_bnds)
d_sub

In [8]:
# create bounds object
lon_bnds = Bounds(val_min=42, val_max=48, name="lon")
lat_bnds = Bounds(val_min=-58, val_max=-52, name="lat")
bounds = [lon_bnds, lat_bnds]

d_sub = select_bounds_multiple(ssh_xrds, bounds)
d_sub

In [9]:
# create bounds object
lon_bnds = Bounds(val_min=42, val_max=48, name="lon")
lat_bnds = Bounds(val_min=-58, val_max=-52, name="lat")

# subset dataarayy
d_sub = select_bounds(ssh_xrds, lon_bnds)
d_sub

In [10]:
ssh_xrds.dims

Frozen({'lon': 201, 'lat': 201})

In [37]:
lon_axis = LongitudeAxis.init_from_limits(lon_min=40, lon_max=50, dlon=0.05)
lat_axis = LatitudeAxis.init_from_limits(lat_min=-60, lat_max=-50, dlat=0.05)
time_axis = TimeAxis.init_from_limits("2012-01-01", "2012-01-30", "1D")

ssh_xrds = SSH2D.init_from_axis(lon=lon_axis, lat=lat_axis)
ssh_xrds = asdataarray(ssh_xrds).to_dataset()
ssh_xrds

In [60]:
# patch dims
# create bounds object
lon_bnds = Bounds(val_min=42, val_max=48, name="lon")
lat_bnds = Bounds(val_min=-58, val_max=-52, name="lat")
bounds = [lon_bnds, lat_bnds]

patch_dims = {"lat": 10, "lon": 10}
strides = {"lat": 5, "lon": 5}
check_full_scan = False
check_dim_order = False
transforms = None


torch_ds = XArrayDataset(
    da=ssh_xrds.ssh, 
    patch_dims=patch_dims,
    strides=strides,
    domain_limits=bounds,
    check_full_scan=check_full_scan,
    check_dim_order=check_dim_order,
    transforms=transforms
)

print(f"Patch Dims: {torch_ds.patch_dims}")
print(f"Size: {torch_ds.da_size}")
print(f"Strides: {torch_ds.strides}")
torch_ds.da

Patch Dims: {'lat': 10, 'lon': 10}
Size: {'lat': 23, 'lon': 23}
Strides: {'lat': 5, 'lon': 5}


In [74]:
torch_ds[0].shape

(2, 10, 10)

In [75]:
lon_axis = LongitudeAxis.init_from_limits(lon_min=40, lon_max=50, dlon=0.05)
lat_axis = LatitudeAxis.init_from_limits(lat_min=-60, lat_max=-50, dlat=0.05)
time_axis = TimeAxis.init_from_limits("2012-01-01", "2012-03-30", "1D")

ssh_xrds = SSH2DT.init_from_axis(lon=lon_axis, lat=lat_axis, time=time_axis)
ssh_xrds = asdataarray(ssh_xrds).to_dataset()
ssh_xrds

In [84]:
xr_dims = ssh_xrds.ssh.dims
xr_dims
# null_dims = dict(ikey==1 for (ikey, idata) in xr_dims.items())
# null_dims

('time', 'lat', 'lon')

In [92]:
# patch dims
# create bounds object
lon_bnds = Bounds(val_min=42, val_max=48, name="lon")
lat_bnds = Bounds(val_min=-58, val_max=-52, name="lat")
time_bnds = Bounds(
    val_min=pd.to_datetime("2012-01-01"), 
    val_max=pd.to_datetime("2012-01-30"), name="time"
)
bounds = [lon_bnds, lat_bnds, time_bnds]

patch_dims = {"time": 2, "lat": 10, "lon": 10}
strides = {"time":2, "lat": 5, "lon": 5}
check_full_scan = True
check_dim_order = True
transforms = None


# NULL Case!
torch_ds = XArrayDataset(
    da=ssh_xrds.ssh, 
    patch_dims=None,
    strides=None,
    domain_limits=None,
    check_full_scan=check_full_scan,
    check_dim_order=check_dim_order,
    transforms=transforms
)

print(f"Patch Dims: {torch_ds.patch_dims}")
print(f"Size: {torch_ds.da_size}")
print(f"Strides: {torch_ds.strides}")
print(f"Batch: {torch_ds[0].shape}")
torch_ds.da

Patch Dims: {'time': 1, 'lat': 1, 'lon': 1}
Size: {'time': 90, 'lat': 201, 'lon': 201}
Strides: {}
Batch: (1, 1, 1)


In [90]:
# patch dims
# create bounds object
lon_bnds = Bounds(val_min=42, val_max=48, name="lon")
lat_bnds = Bounds(val_min=-58, val_max=-52, name="lat")
time_bnds = Bounds(
    val_min=pd.to_datetime("2012-01-01"), 
    val_max=pd.to_datetime("2012-01-30"), name="time"
)
bounds = [lon_bnds, lat_bnds, time_bnds]

patch_dims = {"time": 2, "lat": 10, "lon": 10}
strides = {"time":2, "lat": 5, "lon": 5}
check_full_scan = True
check_dim_order = True
transforms = None


torch_ds = XArrayDataset(
    da=ssh_xrds.ssh, 
    patch_dims=patch_dims,
    strides=strides,
    domain_limits=bounds,
    check_full_scan=check_full_scan,
    check_dim_order=check_dim_order,
    transforms=transforms
)

print(f"Patch Dims: {torch_ds.patch_dims}")
print(f"Size: {torch_ds.da_size}")
print(f"Strides: {torch_ds.strides}")
print(f"Batch: {torch_ds[0].shape}")
torch_ds.da

Patch Dims: {'time': 2, 'lat': 10, 'lon': 10}
Size: {'time': 15, 'lat': 23, 'lon': 23}
Strides: {'time': 2, 'lat': 5, 'lon': 5}
Batch: (2, 10, 10)


In [64]:
torch_ds[0].shape

(2, 10, 10)