# XArrayDataset

In [1]:
import pyrootutils
path = pyrootutils.find_root(search_from="./", indicator=".home")
pyrootutils.set_root(
    path=path, # path to the root directory
    project_root_env_var=True, # set the PROJECT_ROOT environment variable to root directory
    dotenv=True, # load environment variables from .env if exists in root directory
    pythonpath=True, # add root directory to the PYTHONPATH (helps with imports)
    cwd=True, # change current working directory to the root directory (helps with filepaths)
)


In [84]:
from oceanbench._src.datasets.base import XArrayDataset, XRConcatDataset
import numpy as np
import pandas as pd
from oceanbench._src.utils.custom_dtypes import (
    LongitudeAxis, LatitudeAxis, 
    TimeAxis, SSH2D, SSH2DT
)
from xarray_dataclasses import asdataarray, asdataset
%load_ext autoreload
%autoreload 2


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [104]:
ds.dims[0] not in list(patch_dims.keys())

False

In [168]:
from typing import Dict, List

def get_xrda_dims(da: xr.DataArray) -> Dict[str, int]:
    return dict(zip(da.dims, da.shape))


def get_xrda_size(da: xr.DataArray, patches: Dict[str, int], strides: Dict[str, int]) -> Dict[str, int]:
    
    da_dims = get_xrda_dims(da)
    
    check_lists_equal(list(da_dims.keys()), list(patches.keys()))
    check_lists_equal(list(da_dims.keys()), list(strides.keys()))
    
    
    dim_size = {}
    for dim in patch_dims:
        dim_size[dim] = max((da_dims[dim] - patch_dims[dim]) // strides[dim] + 1, 0)
    
    return dim_size

def check_lists_equal(list_1: List, list_2: List):
    msg = f"Lists not equal...: \n{list_1}\n{list_2}"
    assert sorted(list_1) == sorted(list_2), msg
    
def check_lists_subset(list_1: List, list_2: List):
    msg = f"Lists not subset...: \n{list_1}\n{list_2}"
    assert set(list_1) <= set(list_2)
    

def update_dict_xdims(da: xr.DataArray, dims: Dict) -> Dict:
    
    update_dims = {f"{idim}":1 for idim in da.dims if idim not in list(dims.keys())}
    
    dims = {**dims, **update_dims}
    
    check_lists_equal(list(da.dims), list(dims.keys()))
    
    return dims

In [192]:
patches = {"lon": 5}
strides = {"lat": 10}

patches = update_dict_xdims(ssh_da, patches)
strides = update_dict_xdims(ssh_da, strides)

da_size = get_xrda_size(ssh_da, patches, strides)

patches, strides, da_size

({'lon': 5, 'lat': 1}, {'lat': 10, 'lon': 1}, {})

In [199]:
da.dims

('time', 'space')

In [87]:
slices = dict(lon=slice(-20, 20), lat=slice(-10,10))
ds = ds.sel(**slices)
ds

In [3]:
from oceanbench._src.utils.custom_dtypes import (
    LongitudeAxis, LatitudeAxis, CoordinateAxis,
    TimeAxis, SSH2D, SSH2DT, Bounds
)
from xarray_dataclasses import asdataarray, asdataset, Data, Name, Coordof, Coord
from oceanbench._src.geoprocessing.gridding import create_coord_grid

In [8]:
from dataclasses import dataclass
from typing import Literal, Tuple


X = Literal["x"]
Y = Literal["y"]
Z = Literal["z"]

@dataclass
class Variable1D:
    data: Data[X, np.ndarray]
    x: Coord[X, np.ndarray] = 0
    name: Name[str] = "var"


@dataclass
class Variable2D:
    data: Data[tuple[X,Y], np.ndarray]
    x: Coord[X, np.ndarray] = 0
    y: Coord[Y, np.ndarray] = 0
    name: Name[str] = "var"


@dataclass
class Variable3D:
    data: Data[tuple[X, Y, Z], np.ndarray]
    x: Coord[X, np.ndarray] = 0
    y: Coord[Y, np.ndarray] = 0
    z: Coord[Z, np.ndarray] = 0
    name: Name[str] = "var"



In [26]:
ones.shape, axis1.shape

((1,), (21,))

In [29]:
axis1 = np.arange(-10, 10+1, 1)

ones = np.ones((len(axis1),))
var = Variable1D(ones, axis1)
var = asdataarray(var)

### Null Case

In [173]:
lon_axis = LongitudeAxis.init_from_limits(-50, 50, 0.05)
lat_axis = LatitudeAxis.init_from_limits(-50, 50, 0.05)
ssh_da = SSH2D.init_from_axis(lon=lon_axis, lat=lat_axis)
ssh_da = asdataarray(ssh_da)
# ds

In [213]:
patches = {"lon": 10, "lat": 10}
strides = {"lon": 2, "lat": 2}
domain_limits = None
check_full_scan = False
check_dim_order = False
transforms = None

ds = XArrayDataset(
    da=ssh_da,
    patches=patches,
    strides=strides,
    domain_limits=domain_limits,
    check_full_scan=check_full_scan,
    check_dim_order=check_dim_order,
    transforms=transforms
)

# assert ds[0].shape == (len(var.data),)
# np.testing.assert_array_equal(ds[0], np.arange(0, len(var.x), 1))
print(f"Patch Dims: {ds.patches}")
print(f"Size: {ds.da_size}")
print(f"Strides: {ds.strides}")
print(f"Batch (size): {ds[0].shape}")
# print(f"Batch: {ds[0]}")

Patch Dims: {'lon': 10, 'lat': 10}
Size: {'lon': 996, 'lat': 996}
Strides: {'lon': 2, 'lat': 2}
Batch (size): (10, 10)


In [181]:
ds[0].shape

(1, 1)

### Patch Dims

In [33]:
axis1 = np.arange(-10, 10+1, 1)
axis2 = np.arange(-20, 20+1, 1)

ones = np.ones((len(axis1),len(axis2)))
var = Variable2D(ones, axis1, axis2)
var = asdataarray(var)

In [39]:
# var = Variable2D(data=ones, x=axis1, y=axis2)
# var = asdataarray(var)

patch = (1,1)
stride = 1#None
# size = 
batch = 1
xlims = (-5,5)

In [41]:
patch_dims = {"x": patch[0], "y": patch[1]} if patch is not None else {}
strides = {"x": stride[0]} if stride is not None else {}
# create bounds object
if xlims is not None:
    domain_limits = Bounds(val_min=xlims[0], val_max=xlims[1], name="x")
else:
    domain_limits = None
check_full_scan = True
check_dim_order = True
transforms = None

ds = XArrayDataset(
    da=var,
    patch_dims=patch_dims,
    strides=strides,
    domain_limits=domain_limits,
    check_full_scan=check_full_scan,
    check_dim_order=check_dim_order,
    transforms=transforms
)

print(f"Patch Dims: {ds.patch_dims}")
print(f"Size: {ds.da_size}")
print(f"Strides: {ds.strides}")
print(f"Batch (size): {ds[0].shape}")
print(f"Batch: {ds[0]}")
ds.da

Patch Dims: {'x': 1, 'y': 1}
Size: {'x': 11, 'y': 41}
Strides: {}
Batch (size): (1, 1)
Batch: [[1.]]


In [85]:
isinstance(domain_limits, Bounds)

True

### Domain Limits

In [None]:
patch_dims = {}
strides = {}
domain_limits = None
check_full_scan = True
check_dim_order = True
transforms = None

ds = XArrayDataset(
    da=var,
    patch_dims=patch_dims,
    strides=strides,
    domain_limits=domain_limits,
    check_full_scan=check_full_scan,
    check_dim_order=check_dim_order,
    transforms=transforms
)

## Test Case I - 2D Field

In [4]:
lon_axis = LongitudeAxis.init_from_limits(lon_min=40, lon_max=50, dlon=0.05)
lat_axis = LatitudeAxis.init_from_limits(lat_min=-60, lat_max=-50, dlat=0.05)
time_axis = TimeAxis.init_from_limits("2012-01-01", "2012-01-30", "1D")

ssh_xrds = SSH2D.init_from_axis(lon=lon_axis, lat=lat_axis)
ssh_xrds = asdataarray(ssh_xrds).to_dataset()
ssh_xrds

In [5]:
ssh_xrds.lon.values.min(), ssh_xrds.lon.values.max(), ssh_xrds.lat.values.min(), ssh_xrds.lat.values.max()

(40.0, 49.99999999999943, -60.0, -50.00000000000057)

In [6]:
from oceanbench._src.geoprocessing.select import select_bounds, select_bounds_multiple
from oceanbench._src.utils.custom_dtypes import Bounds

In [7]:
# create bounds object
lon_bnds = Bounds(val_min=42, val_max=48, name="lon")
lat_bnds = Bounds(val_min=-58, val_max=-52, name="lat")

# subset dataarayy
d_sub = select_bounds(ssh_xrds, lon_bnds)
d_sub

In [8]:
# create bounds object
lon_bnds = Bounds(val_min=42, val_max=48, name="lon")
lat_bnds = Bounds(val_min=-58, val_max=-52, name="lat")
bounds = [lon_bnds, lat_bnds]

d_sub = select_bounds_multiple(ssh_xrds, bounds)
d_sub

In [9]:
# create bounds object
lon_bnds = Bounds(val_min=42, val_max=48, name="lon")
lat_bnds = Bounds(val_min=-58, val_max=-52, name="lat")

# subset dataarayy
d_sub = select_bounds(ssh_xrds, lon_bnds)
d_sub

In [10]:
ssh_xrds.dims

Frozen({'lat': 201, 'lon': 201})

In [11]:
lon_axis = LongitudeAxis.init_from_limits(lon_min=40, lon_max=50, dlon=0.05)
lat_axis = LatitudeAxis.init_from_limits(lat_min=-60, lat_max=-50, dlat=0.05)
time_axis = TimeAxis.init_from_limits("2012-01-01", "2012-01-30", "1D")

ssh_xrds = SSH2D.init_from_axis(lon=lon_axis, lat=lat_axis)
ssh_xrds = asdataarray(ssh_xrds).to_dataset()
ssh_xrds

In [12]:
# patch dims
# create bounds object
lon_bnds = Bounds(val_min=42, val_max=48, name="lon")
lat_bnds = Bounds(val_min=-58, val_max=-52, name="lat")
bounds = [lon_bnds, lat_bnds]

patch_dims = {"lat": 10, "lon": 10}
strides = {"lat": 5, "lon": 5}
check_full_scan = False
check_dim_order = False
transforms = None


torch_ds = XArrayDataset(
    da=ssh_xrds.ssh, 
    patch_dims=patch_dims,
    strides=strides,
    domain_limits=bounds,
    check_full_scan=check_full_scan,
    check_dim_order=check_dim_order,
    transforms=transforms
)

print(f"Patch Dims: {torch_ds.patch_dims}")
print(f"Size: {torch_ds.da_size}")
print(f"Strides: {torch_ds.strides}")
torch_ds.da

Patch Dims: {'lat': 10, 'lon': 10}
Size: {'lat': 23, 'lon': 23}
Strides: {'lat': 5, 'lon': 5}


In [13]:
torch_ds[0].shape

(10, 10)

In [14]:
lon_axis = LongitudeAxis.init_from_limits(lon_min=40, lon_max=50, dlon=0.05)
lat_axis = LatitudeAxis.init_from_limits(lat_min=-60, lat_max=-50, dlat=0.05)
time_axis = TimeAxis.init_from_limits("2012-01-01", "2012-03-30", "1D")

ssh_xrds = SSH2DT.init_from_axis(lon=lon_axis, lat=lat_axis, time=time_axis)
ssh_xrds = asdataarray(ssh_xrds).to_dataset()
ssh_xrds

In [15]:
xr_dims = ssh_xrds.ssh.dims
xr_dims
# null_dims = dict(ikey==1 for (ikey, idata) in xr_dims.items())
# null_dims

('time', 'lat', 'lon')

In [25]:
from xarray_dataclasses import Data, Coordof, Name, asdataarray
from oceanbench._src.utils.custom_dtypes import CoordinateAxis, X
from dataclasses import dataclass


axis = CoordinateAxis.init_from_limits(-10, 10, 1)

@dataclass
class Variable:
    data: Data[X, np.ndarray]
    x: Coordof[CoordinateAxis] = 0
    name: Name[str] = "var"
    
ones = np.ones((axis.ndim,))
var = Variable(data=ones, x=axis)
var = asdataarray(var)
var

In [16]:
# patch dims
# create bounds object
lon_bnds = Bounds(val_min=42, val_max=48, name="lon")
lat_bnds = Bounds(val_min=-58, val_max=-52, name="lat")
time_bnds = Bounds(
    val_min=pd.to_datetime("2012-01-01"), 
    val_max=pd.to_datetime("2012-01-30"), name="time"
)
bounds = [lon_bnds, lat_bnds, time_bnds]

patch_dims = {"time": 2, "lat": 10, "lon": 10}
strides = {"time":2, "lat": 5, "lon": 5}
check_full_scan = True
check_dim_order = True
transforms = None


# NULL Case!
torch_ds = XArrayDataset(
    da=ssh_xrds.ssh, 
    patch_dims=None,
    strides=None,
    domain_limits=None,
    check_full_scan=check_full_scan,
    check_dim_order=check_dim_order,
    transforms=transforms
)

print(f"Patch Dims: {torch_ds.patch_dims}")
print(f"Size: {torch_ds.da_size}")
print(f"Strides: {torch_ds.strides}")
print(f"Batch: {torch_ds[0].shape}")
torch_ds.da

Patch Dims: {'time': 1, 'lat': 1, 'lon': 1}
Size: {'time': 90, 'lat': 201, 'lon': 201}
Strides: {}
Batch: (1, 1, 1)


In [17]:
# patch dims
# create bounds object
lon_bnds = Bounds(val_min=42, val_max=48, name="lon")
lat_bnds = Bounds(val_min=-58, val_max=-52, name="lat")
time_bnds = Bounds(
    val_min=pd.to_datetime("2012-01-01"), 
    val_max=pd.to_datetime("2012-01-30"), name="time"
)
bounds = [lon_bnds, lat_bnds, time_bnds]

patch_dims = {"time": 2, "lat": 10, "lon": 10}
strides = {"time":2, "lat": 5, "lon": 5}
check_full_scan = True
check_dim_order = True
transforms = None


torch_ds = XArrayDataset(
    da=ssh_xrds.ssh, 
    patch_dims=patch_dims,
    strides=strides,
    domain_limits=bounds,
    check_full_scan=check_full_scan,
    check_dim_order=check_dim_order,
    transforms=transforms
)

print(f"Patch Dims: {torch_ds.patch_dims}")
print(f"Size: {torch_ds.da_size}")
print(f"Strides: {torch_ds.strides}")
print(f"Batch: {torch_ds[0].shape}")
torch_ds.da

Patch Dims: {'time': 2, 'lat': 10, 'lon': 10}
Size: {'time': 15, 'lat': 23, 'lon': 23}
Strides: {'time': 2, 'lat': 5, 'lon': 5}
Batch: (2, 10, 10)


In [64]:
torch_ds[0].shape

(2, 10, 10)