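# XArrayDataset

Round-trip checks for `XRDABatcher`. A DataArray is cut into patches, and the patches are stitched back together with uniform weights. The result is then asserted equal to the input. Three cases are covered: a 1D series, a 2D lat/lon field, and a 2D+T field. The 2D and 2D+T cases are also checked with extra trailing "latent" axes.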

In [1]:
import autoroot
from oceanbench._src.datasets.base import XArrayDataset, XRConcatDataset
import numpy as np
import pandas as pd
from oceanbench._src.utils.custom_dtypes import (
    LongitudeAxis, LatitudeAxis, 
    TimeAxis, SSH2D, SSH2DT
)
from xarray_dataclasses import asdataarray, asdataset
# Reload edited modules (e.g. oceanbench sources) automatically before each cell runs.
%load_ext autoreload
%autoreload 2


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from oceanbench._src.utils.custom_dtypes import (
    LongitudeAxis, LatitudeAxis, CoordinateAxis,
    TimeAxis, SSH2D, SSH2DT, Bounds
)
from xarray_dataclasses import asdataarray, asdataset, Data, Name, Coordof, Coord
from oceanbench._src.geoprocessing.gridding import create_coord_grid

In [3]:
from dataclasses import dataclass
from typing import Literal, Tuple


# Dimension-name literals for the xarray_dataclasses annotations below:
# `Data[X, ...]` / `Coord[X, ...]` tie an array axis to the dimension named "x".
X = Literal["x"]
Y = Literal["y"]
Z = Literal["z"]

@dataclass
class Variable1D:
    """1D variable schema consumed by ``xarray_dataclasses.asdataarray``.

    ``data`` is laid out along dim "x", ``x`` supplies that dim's coordinate
    values, and ``name`` becomes the resulting DataArray's name.
    """
    # Values along dimension "x".
    data: Data[X, np.ndarray]
    # Coordinate values for "x". NOTE(review): the scalar default 0 is presumably
    # broadcast by xarray_dataclasses when no coordinate is given; confirm.
    x: Coord[X, np.ndarray] = 0
    # Name assigned to the DataArray.
    name: Name[str] = "var"


@dataclass
class Variable2D:
    """2D variable schema consumed by ``xarray_dataclasses.asdataarray``.

    ``data`` is laid out along dims ("x", "y"), ``x``/``y`` supply coordinate
    values for those dims, and ``name`` becomes the DataArray's name.
    """
    # Values on the (x, y) grid.
    data: Data[tuple[X,Y], np.ndarray]
    # Coordinate values for "x" and "y". NOTE(review): scalar default 0 presumably
    # broadcast by xarray_dataclasses; confirm.
    x: Coord[X, np.ndarray] = 0
    y: Coord[Y, np.ndarray] = 0
    # Name assigned to the DataArray.
    name: Name[str] = "var"


@dataclass
class Variable3D:
    """3D variable schema consumed by ``xarray_dataclasses.asdataarray``.

    ``data`` is laid out along dims ("x", "y", "z"), ``x``/``y``/``z`` supply
    coordinate values for those dims, and ``name`` becomes the DataArray's name.
    """
    # Values on the (x, y, z) grid.
    data: Data[tuple[X, Y, Z], np.ndarray]
    # Coordinate values per dim. NOTE(review): scalar default 0 presumably
    # broadcast by xarray_dataclasses; confirm.
    x: Coord[X, np.ndarray] = 0
    y: Coord[Y, np.ndarray] = 0
    z: Coord[Z, np.ndarray] = 0
    # Name assigned to the DataArray.
    name: Name[str] = "var"



**Case I**: 1D Time Series

* `[360]` samples, patch size `30` (`p`), stride `30` (`s`) → `[12]` patches
* Latent Vector - `[12, Z]`
* Ensemble Predictions - `[12, 20]`

**Case II**: Lat/Lon

* `[200, 200] * [10, 10] = [20, 20]`

**Case III**: Lat/Lon/Time

* `[200, 200, 360] * [10, 10, 30] = [20, 20, 12]`
* 360 Days to 30 Days
* Latent Vector
* Ensemble Predictions

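## Test Case I - 1D Temporal Field

We build a synthetic 1D series of 360 standard-normal samples on `x = 1..360`. We cut it into overlapping patches (size 30, stride 5), then check that reconstruction recovers the original series.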



In [25]:
from oceanbench._src.datasets.base import XRDABatcher

In [26]:

# Synthetic 1D signal on x = 1..360 (dx = 1), seeded for reproducibility.
x = CoordinateAxis.init_from_limits(x_min=1, x_max=360, dx = 1)
rng = np.random.RandomState(seed=123)
data = rng.randn(*x.data.shape)

# Reuse the `Variable1D` schema from the dataclass cell at the top of the
# notebook; an identical redefinition here would only shadow it.
variable = Variable1D(data=data, x=x.data, name="ssh")
da = asdataarray(variable)
da.x

In [27]:
patches = {"x": 30}
strides = {"x": 5}
# NOTE(review): not passed to XRDABatcher in this cell; presumably a placeholder
# for subsetting such as {"lat": slice(-10, 10)}.
domain_limits = None
check_full_scan = True

xrda_batches = XRDABatcher(
    da=da,
    patches=patches,
    strides=strides,
    check_full_scan=check_full_scan
)

print(xrda_batches)
print(f"Torch Dataset(size): {len(xrda_batches)}")

# Sliding windows per patched dim: (size - patch) // stride + 1
# -> (360 - 30) // 5 + 1 = 67 patches.
n_expected = int(np.prod([(da.sizes[dim] - patches[dim]) // strides[dim] + 1 for dim in patches]))
assert len(xrda_batches) == n_expected, (len(xrda_batches), n_expected)

XArray Patcher
DataArray size: OrderedDict([('x', 360)])
Patches:        OrderedDict([('x', 30)])
Strides:        OrderedDict([('x', 5)])
Num Batches:    OrderedDict([('x', 67)])
Torch Dataset(size): 67


In [28]:
import functools as ft
import itertools as it
from einops import repeat

# Materialise every patch the batcher yields (one per batch index).
all_batches = list(xrda_batches)
# Fake "latent" outputs: repeat each patch along a new trailing axis of size 5.
all_batches_latent = [repeat(batch, "... -> ... N", N=5) for batch in all_batches]

all_coords = xrda_batches.get_coords()

# One coordinate entry per patch.
assert len(all_batches) == len(all_coords)

In [29]:
import itertools

# Flatten the outer list of batch lists (currently just one) into a flat list of
# patches; more batch lists can be appended to the outer list later.
items = list(itertools.chain.from_iterable([all_batches]))
items_latent = list(itertools.chain.from_iterable([all_batches_latent]))

In [30]:
# Coordinate names tracked by the batcher (here ['x']); presumably these are the
# dims to pass as `dims_label` when reconstructing below.
xrda_batches.coord_names

['x']

In [31]:
from oceanbench._src.datasets.utils import reconstruct_from_items

# Stitch the overlapping patches back together with uniform weights. The
# weight array must match a single patch, so take its shape from one.
w = np.ones(all_batches[0].shape)
dims_label = ["x"]
items = list(itertools.chain(*[all_batches]))
rec_da = reconstruct_from_items(items, dims_label, xrda_batches, weight=w)

# The round trip must recover the original signal (up to float rounding).
np.testing.assert_array_almost_equal(rec_da.data, xrda_batches.da.data)

100%|██████████████████████████████████████████| 67/67 [00:00<00:00, 475.37it/s]


## Test Case II - 2D Field

In [49]:
from oceanbench._src.utils.custom_dtypes import (
    LongitudeAxis, LatitudeAxis, 
    TimeAxis, SSH2D, SSH2DT
)

In [62]:
# 200 x 200 grid at 0.05 deg spacing: lon 40..49.95, lat -60..-50.05.
lon_axis = LongitudeAxis.init_from_limits(lon_min=40, lon_max=50-0.05, dlon=0.05)
lat_axis = LatitudeAxis.init_from_limits(lat_min=-60, lat_max=-50-0.05, dlat=0.05)

ssh_field = SSH2D.init_from_axis(lon=lon_axis, lat=lat_axis)
# NOTE: despite its name, `da` holds an xarray Dataset; its `ssh` variable is
# what gets batched below.
da = asdataarray(ssh_field).to_dataset()
da

In [63]:
# Only "lat" is patched; "lon" is not listed, so each patch spans the full lon axis.
patches = {"lat": 40,}
strides = {"lat": 40,}
# NOTE(review): not passed to XRDABatcher in this cell; presumably a placeholder
# for subsetting such as {"lat": slice(-10, 10)}.
domain_limits = None
check_full_scan = True

xrda_batches = XRDABatcher(
    da=da.ssh,
    patches=patches,
    strides=strides,
    check_full_scan=check_full_scan
)

print(xrda_batches)
print(f"Torch Dataset(size): {len(xrda_batches)}")

# Sliding windows per patched dim: (size - patch) // stride + 1
# -> (200 - 40) // 40 + 1 = 5 patches (unlisted dims contribute one window).
n_expected = int(np.prod([(da.ssh.sizes[dim] - patches[dim]) // strides[dim] + 1 for dim in patches]))
assert len(xrda_batches) == n_expected, (len(xrda_batches), n_expected)

XArray Patcher
DataArray size: OrderedDict([('lat', 200), ('lon', 200)])
Patches:        OrderedDict([('lat', 40), ('lon', 200)])
Strides:        OrderedDict([('lat', 40), ('lon', 1)])
Num Batches:    OrderedDict([('lat', 5), ('lon', 1)])
Torch Dataset(size): 5


In [64]:
import functools as ft
import itertools as it
from einops import repeat

# Materialise every patch the batcher yields (one per batch index).
all_batches = list(xrda_batches)

all_coords = xrda_batches.get_coords()

# One coordinate entry per patch.
assert len(all_batches) == len(all_coords)
all_batches[0].shape

(40, 200)

In [68]:
# Round trip through the batcher's own `reconstruct` method. The weight array
# must match a single patch, so take its shape from one.
# NOTE(review): this passes `[all_batches]` (a list of batch lists), while
# `reconstruct_from_items` receives a flat list; confirm the intended API.
# The result of this call is not asserted on in this cell.
w = np.ones(all_batches[0].shape)
dims_label = ["lat", "lon"]

rec_da = xrda_batches.reconstruct([all_batches], dims_label, weight=w)


100%|████████████████████████████████████████████| 5/5 [00:00<00:00, 278.57it/s]


In [66]:
from oceanbench._src.datasets.utils import reconstruct_from_items

# Stitch the patches back together with uniform weights. The weight array must
# match a single patch, so take its shape from one.
w = np.ones(all_batches[0].shape)
dims_label = ["lat", "lon"]
items = list(itertools.chain(*[all_batches]))
rec_da = reconstruct_from_items(items, dims_label, xrda_batches, weight=w)

# The round trip must recover the original field (up to float rounding).
np.testing.assert_array_almost_equal(rec_da.data, xrda_batches.da.data)

100%|████████████████████████████████████████████| 5/5 [00:00<00:00, 258.12it/s]


In [55]:
from oceanbench._src.datasets.utils import reconstruct_from_items

# Weights cover the spatial patch only; the trailing latent axis "z" is extra.
w = np.ones(all_batches[0].shape)
dims_label_latent = ["lat", "lon", "z"]
n_latent = 5

# Fake latent outputs: repeat each patch along a new trailing axis ("z").
all_batches_latent = [repeat(batch, "... -> ... N", N=n_latent) for batch in all_batches]
items_latent = list(itertools.chain(*[all_batches_latent]))

rec_da = reconstruct_from_items(items_latent, dims_label_latent, xrda_batches, weight=w)

# Every z slice is a copy of the input, so each one (not just z=0) must
# reconstruct the original field.
assert rec_da.sizes["z"] == n_latent, rec_da.sizes
for iz in range(rec_da.sizes["z"]):
    np.testing.assert_array_almost_equal(
        rec_da.isel(z=iz).data, xrda_batches.da.data, err_msg=f"z={iz}"
    )

100%|████████████████████████████████████████████| 5/5 [00:00<00:00, 185.41it/s]


## Test Case III - 2D+T Field

In [38]:
from oceanbench._src.utils.custom_dtypes import (
    LongitudeAxis, LatitudeAxis, 
    TimeAxis, SSH2D, SSH2DT
)

In [39]:
# 200 x 200 grid at 0.05 deg spacing, plus 30 daily time steps in January 2012.
lon_axis = LongitudeAxis.init_from_limits(lon_min=40, lon_max=50-0.05, dlon=0.05)
lat_axis = LatitudeAxis.init_from_limits(lat_min=-60, lat_max=-50-0.05, dlat=0.05)
time_axis = TimeAxis.init_from_limits("2012-01-01", "2012-01-30", "1D")

ssh_field = SSH2DT.init_from_axis(lon=lon_axis, lat=lat_axis, time=time_axis)
# NOTE: despite its name, `da` holds an xarray Dataset; its `ssh` variable is
# what gets batched below.
da = asdataarray(ssh_field).to_dataset()
da

In [40]:
# Patch keys appear deliberately out of order (lon, time, lat); the batcher
# reports them in the DataArray's dim order (time, lat, lon), as shown in the output.
patches = {"lon": 20, "time": 10, "lat": 40, }
strides = {"time": 10, "lat": 40, "lon": 20, }
# NOTE(review): not passed to XRDABatcher in this cell; presumably a placeholder
# for subsetting such as {"lat": slice(-10, 10)}.
domain_limits = None
check_full_scan = True

xrda_batches = XRDABatcher(
    da=da.ssh,
    patches=patches,
    strides=strides,
    check_full_scan=check_full_scan
)

print(xrda_batches)
print(f"Torch Dataset(size): {len(xrda_batches)}")

# Sliding windows per patched dim: (size - patch) // stride + 1
# -> time 3 * lat 5 * lon 10 = 150 patches.
n_expected = int(np.prod([(da.ssh.sizes[dim] - patches[dim]) // strides[dim] + 1 for dim in patches]))
assert len(xrda_batches) == n_expected, (len(xrda_batches), n_expected)

XArray Patcher
DataArray size: OrderedDict([('time', 30), ('lat', 200), ('lon', 200)])
Patches:        OrderedDict([('time', 10), ('lat', 40), ('lon', 20)])
Strides:        OrderedDict([('time', 10), ('lat', 40), ('lon', 20)])
Num Batches:    OrderedDict([('time', 3), ('lat', 5), ('lon', 10)])
Torch Dataset(size): 150


In [41]:
import functools as ft
import itertools as it
from einops import repeat

# Materialise every patch the batcher yields (one per batch index).
all_batches = list(xrda_batches)

all_coords = xrda_batches.get_coords()

# One coordinate entry per patch.
assert len(all_batches) == len(all_coords)
all_batches[0].shape

(10, 40, 20)

In [42]:
from oceanbench._src.datasets.utils import reconstruct_from_items

# Stitch the patches back together with uniform weights. Patches follow the
# DataArray's dim order (time, lat, lon), not the key order of `patches`
# (lon, time, lat), so take the weight shape from an actual patch.
w = np.ones(all_batches[0].shape)
dims_label = ["time", "lat", "lon"]
items = list(itertools.chain(*[all_batches]))
rec_da = reconstruct_from_items(items, dims_label, xrda_batches, weight=w)

# The round trip must recover the original field (up to float rounding).
np.testing.assert_array_almost_equal(rec_da.data, xrda_batches.da.data)

100%|████████████████████████████████████████| 150/150 [00:00<00:00, 220.63it/s]


In [43]:
from oceanbench._src.datasets.utils import reconstruct_from_items

# Weights cover the (time, lat, lon) patch only; the latent axes are extra.
w = np.ones(all_batches[0].shape)
dims_label_latent = ["time", "lat", "lon", "z", "d"]

# Fake two latent axes by repeating each patch: N -> "z" (5), M -> "d" (10).
all_batches_latent = [repeat(batch, "... -> ... N M", N=5, M=10) for batch in all_batches]
items_latent = list(itertools.chain(*[all_batches_latent]))

rec_da = reconstruct_from_items(items_latent, dims_label_latent, xrda_batches, weight=w)

100%|█████████████████████████████████████████| 150/150 [00:03<00:00, 43.59it/s]


In [44]:
# Every (z, d) slice is a copy of the input, so each one (not just z=0, d=0)
# must reconstruct the original field. Slices are compared one at a time to
# avoid materialising full-size temporaries of the 5-D result.
for iz, idd in itertools.product(range(rec_da.sizes["z"]), range(rec_da.sizes["d"])):
    np.testing.assert_array_almost_equal(
        rec_da.isel(z=iz, d=idd).data, xrda_batches.da.data, err_msg=f"z={iz}, d={idd}"
    )

In [45]:
# Display the reconstructed latent DataArray (presumably dims time, lat, lon, z, d).
rec_da