# XArrayDataset

In [78]:
import autoroot
import typing as tp
from dataclasses import dataclass
import numpy as np
import pandas as pd
import xarray_dataclasses as xrdataclass
from oceanbench._src.datasets.base import XRDABatcher

%load_ext autoreload
%autoreload 2




This tutorial walks through the main features of the custom `XRDABatcher` class.
The class slices an `xr.DataArray` into patches, where the user explicitly specifies the patch size and the stride for each dimension.
The *slices* are preallocated up front, so any patch can then be retrieved at will.
This is very similar to the dataset objects in *torch.utils.data*, except that we work with `xr.DataArray`s directly.
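
To make the idea of preallocated slices concrete, here is a minimal pure-Python sketch for a 1D sequence. The `SlicePatcher` class and its arguments are hypothetical and only illustrate the pattern; this is not how `XRDABatcher` is actually implemented.

```python
from typing import List, Sequence


class SlicePatcher:
    """Toy 1D patcher (hypothetical): precompute every window, then index at will."""

    def __init__(self, data: Sequence[float], patch: int, stride: int) -> None:
        self.data = data
        # Preallocate all slices up front; a window must fit entirely in the data.
        self.slices: List[slice] = [
            slice(start, start + patch)
            for start in range(0, len(data) - patch + 1, stride)
        ]

    def __len__(self) -> int:
        return len(self.slices)

    def __getitem__(self, index: int) -> Sequence[float]:
        # Map-style access: look up the stored slice and apply it to the data.
        return self.data[self.slices[index]]


patcher = SlicePatcher(data=list(range(10)), patch=4, stride=2)
print(len(patcher))  # number of windows
print(patcher[1])    # second window
```

Because the slices are stored rather than the data, the patches themselves are only materialised when they are requested.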


There have been previous attempts at this, e.g. `xBatcher`.
However, we found its API cumbersome and unintuitive.
This is our attempt to design an API that we are comfortable with and find easy to use.

Below, we outline a few use cases that users may be interested in.
These use cases are:

* Chunking a 1-Dimensional Time Series
* Patch-ify a 2D Grid
* Cube-ify a 3D Volume
* Cube-ify a 2D+T Spatio-Temporal Field
* Reconstructing Multiple Variables
* Choosing Specific Dimensions for Reconstructions

We will walk through each of these and highlight how this can be achieved with the custom `XRDABatcher` class.

## Case I: Chunking a 1D TS

In [88]:
T = tp.Literal["t"]


@dataclass
class Variable1D:
    data: xrdataclass.Data[T, np.ndarray]
    t: xrdataclass.Coord[T, np.ndarray] = 0
    name: xrdataclass.Name[str] = "var"

In [90]:

t = np.arange(1, 360+1, 1)
rng = np.random.RandomState(seed=123)
ts = np.sin(t)

ts = Variable1D(data=ts, t=t, name="var")

da = xrdataclass.asdataarray(ts)

da

In this first example, we are going to do a non-overlapping style.
We will take a 30 day window with a 30 day stride.
This will give us exactly 12 patches (like 12 months).

In [93]:
patches = {"t": 30}
strides = {"t": 30}
domain_limits = None#{"lat": slice(-10, 10)}
check_full_scan = True

xrda_batches = XRDABatcher(
    da=da,
    patches=patches,
    strides=strides,
    check_full_scan=check_full_scan
)

print(xrda_batches)
print(f"Dataset(size): {len(xrda_batches)}")


XArray Patcher
DataArray size: OrderedDict([('t', 360)])
Patches:        OrderedDict([('t', 30)])
Strides:        OrderedDict([('t', 30)])
Num Batches:    OrderedDict([('t', 12)])
Dataset(size): 12


In this example, we will incorporate overlapping windows.
We will do a 30 day window but we will have a 15 day stride.
So, we have a 15 day overlap when creating the patches.
We can do the mental calculation already because it's quite simple:

$$
\text{Patches} = \frac{360 \text{ days total } - 30 \text{ day patches }}{15 \text{ day stride }} + 1
$$

If the numerator is evenly divisible by the stride, we won't have any problems.
However, often it is not, so we have to use the `floor` operator to ensure we get an integer number of patches.
In that case, our method can (optionally) give a warning that lets the user know the scan is incomplete.
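
The patch-count formula and the optional warning can be sketched in a few lines. The `num_patches` helper below is hypothetical, and its warning message is an assumption; it only mirrors the `check_full_scan` idea and is not the code used inside `XRDABatcher`.

```python
import warnings


def num_patches(size: int, patch: int, stride: int, check_full_scan: bool = True) -> int:
    """Number of windows of length `patch` taken every `stride` steps."""
    if patch > size:
        raise ValueError("patch is larger than the dimension")
    remainder = (size - patch) % stride
    if check_full_scan and remainder != 0:
        # The last `remainder` points are never covered by any window.
        warnings.warn(f"incomplete scan: {remainder} trailing points are dropped")
    # Floor division guarantees an integer number of windows.
    return (size - patch) // stride + 1


print(num_patches(360, 30, 30))  # non-overlapping case
print(num_patches(360, 30, 15))  # overlapping case
```

With a 20 day stride, for example, `(360 - 30)` is not divisible by the stride, so the helper warns and the last 10 days are left out of every window.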

In [96]:
patches = {"t": 30}
strides = {"t": 15}
domain_limits = None#{"lat": slice(-10, 10)}
check_full_scan = True

xrda_batches = XRDABatcher(
    da=da,
    patches=patches,
    strides=strides,
    check_full_scan=check_full_scan
)

print(xrda_batches)
print(f"Dataset(size): {len(xrda_batches)}")

XArray Patcher
DataArray size: OrderedDict([('t', 360)])
Patches:        OrderedDict([('t', 30)])
Strides:        OrderedDict([('t', 15)])
Num Batches:    OrderedDict([('t', 23)])
Dataset(size): 23


## Case II: Patchify a 2D Grid

In [116]:
X = tp.Literal["x"]
Y = tp.Literal["y"]

@dataclass
class Variable2D:
    data: xrdataclass.Data[tuple[X, Y], np.ndarray]
    x: xrdataclass.Coord[X, np.ndarray] = 0
    y: xrdataclass.Coord[Y, np.ndarray] = 0
    name: xrdataclass.Name[str] = "var"

In [121]:

x = np.linspace(-1, 1, 128)
y = np.linspace(-2, 2, 128)
rng = np.random.RandomState(seed=123)

data = rng.randn(x.shape[0], y.shape[0])

grid = Variable2D(data=data, x=x, y=y, name="var")

da = xrdataclass.asdataarray(grid)

da

We will use an `[8,8]` patch with no overlap, i.e. strides of `[8,8]`.

In [122]:
patches = {"x": 8, "y": 8}
strides = {"x": 8, "y": 8}
domain_limits = None#{"lat": slice(-10, 10)}
check_full_scan = True

xrda_batches = XRDABatcher(
    da=da,
    patches=patches,
    strides=strides,
    check_full_scan=check_full_scan
)

print(xrda_batches)
print(f"Dataset(size): {len(xrda_batches)}")

XArray Patcher
DataArray size: OrderedDict([('x', 128), ('y', 128)])
Patches:        OrderedDict([('x', 8), ('y', 8)])
Strides:        OrderedDict([('x', 8), ('y', 8)])
Num Batches:    OrderedDict([('x', 16), ('y', 16)])
Dataset(size): 256


We will use an `[8,8]` patch with overlap at the boundaries, using strides of `[2,2]`.

In [123]:
patches = {"x": 8, "y": 8}
strides = {"x": 2, "y": 2}
domain_limits = None#{"lat": slice(-10, 10)}
check_full_scan = True

xrda_batches = XRDABatcher(
    da=da,
    patches=patches,
    strides=strides,
    check_full_scan=check_full_scan
)

print(xrda_batches)
print(f"Dataset(size): {len(xrda_batches)}")

XArray Patcher
DataArray size: OrderedDict([('x', 128), ('y', 128)])
Patches:        OrderedDict([('x', 8), ('y', 8)])
Strides:        OrderedDict([('x', 2), ('y', 2)])
Num Batches:    OrderedDict([('x', 61), ('y', 61)])
Dataset(size): 3721
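
The total dataset size is the product of the per-dimension patch counts, because every combination of per-dimension slices is one patch. The sketch below reproduces the two counts for the 128 x 128 grid; `dim_slices` and `grid_slices` are hypothetical helper names, not part of `XRDABatcher`.

```python
import itertools
from typing import Dict, List


def dim_slices(size: int, patch: int, stride: int) -> List[slice]:
    """All windows along a single dimension."""
    return [slice(s, s + patch) for s in range(0, size - patch + 1, stride)]


def grid_slices(
    sizes: Dict[str, int], patches: Dict[str, int], strides: Dict[str, int]
) -> List[Dict[str, slice]]:
    """Cartesian product of the per-dimension windows."""
    per_dim = [dim_slices(sizes[d], patches[d], strides[d]) for d in sizes]
    return [dict(zip(sizes, combo)) for combo in itertools.product(*per_dim)]


sizes = {"x": 128, "y": 128}
no_overlap = grid_slices(sizes, {"x": 8, "y": 8}, {"x": 8, "y": 8})
overlap = grid_slices(sizes, {"x": 8, "y": 8}, {"x": 2, "y": 2})
print(len(no_overlap), len(overlap))
```

Each entry is a dictionary of dimension name to `slice`, which is the form of indexer that `xr.DataArray.isel` accepts. The same helpers work unchanged for three or more dimensions.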


## Case III: Cube-ify a 3D Volume

In [125]:
X = tp.Literal["x"]
Y = tp.Literal["y"]
Z = tp.Literal["z"]

@dataclass
class Variable3D:
    data: xrdataclass.Data[tuple[X, Y, Z], np.ndarray]
    x: xrdataclass.Coord[X, np.ndarray] = 0
    y: xrdataclass.Coord[Y, np.ndarray] = 0
    z: xrdataclass.Coord[Z, np.ndarray] = 0
    name: xrdataclass.Name[str] = "var"

In [126]:

x = np.linspace(-1, 1, 128)
y = np.linspace(-2, 2, 128)
z = np.linspace(-5, 5, 128)
rng = np.random.RandomState(seed=123)

data = rng.randn(x.shape[0], y.shape[0], z.shape[0])

grid = Variable3D(data=data, x=x, y=y, z=z, name="var")

da = xrdataclass.asdataarray(grid)

da

We will use an `[8,8,8]` patch with no overlap, i.e. strides of `[8,8,8]`.

In [127]:
patches = {"x": 8, "y": 8, "z": 8}
strides = {"x": 8, "y": 8, "z": 8}
domain_limits = None#{"lat": slice(-10, 10)}
check_full_scan = True

xrda_batches = XRDABatcher(
    da=da,
    patches=patches,
    strides=strides,
    check_full_scan=check_full_scan
)

print(xrda_batches)
print(f"Dataset(size): {len(xrda_batches)}")

XArray Patcher
DataArray size: OrderedDict([('x', 128), ('y', 128), ('z', 128)])
Patches:        OrderedDict([('x', 8), ('y', 8), ('z', 8)])
Strides:        OrderedDict([('x', 8), ('y', 8), ('z', 8)])
Num Batches:    OrderedDict([('x', 16), ('y', 16), ('z', 16)])
Dataset(size): 4096


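We will use an `[8,8,8]` patch with overlap at the boundaries, using strides of `[2,2,2]`.
By the formula from Case I, each axis yields $(128 - 8)/2 + 1 = 61$ patches, so the full dataset should contain $61^3 = 226981$ cubes.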

In [None]:
patches = {"x": 8, "y": 8, "z": 8}
strides = {"x": 2, "y": 2, "z": 2}
domain_limits = None#{"lat": slice(-10, 10)}
check_full_scan = True

xrda_batches = XRDABatcher(
    da=da,
    patches=patches,
    strides=strides,
    check_full_scan=check_full_scan
)

print(xrda_batches)
print(f"Dataset(size): {len(xrda_batches)}")



## Case IV: Cube-ify a 2D+T Spatio-Temporal Field

In [129]:
X = tp.Literal["x"]
Y = tp.Literal["y"]
T = tp.Literal["t"]

@dataclass
class Variable2DT:
    data: xrdataclass.Data[tuple[T, X, Y], np.ndarray]
    t: xrdataclass.Coord[T, np.ndarray] = 0
    x: xrdataclass.Coord[X, np.ndarray] = 0
    y: xrdataclass.Coord[Y, np.ndarray] = 0
    name: xrdataclass.Name[str] = "var"

In [134]:

x = np.linspace(-1, 1, 200)
y = np.linspace(-2, 2, 200)
t = np.arange(1, 360+1, 1)
rng = np.random.RandomState(seed=123)

data = rng.randn(t.shape[0], x.shape[0], y.shape[0])

grid = Variable2DT(data=data, x=x, y=y, t=t, name="var")

da = xrdataclass.asdataarray(grid)

da

Now, this is a rather big field.
Let's say we want to use some ML method with a CNN to learn how to predict ...
However, ingesting the whole field at once would be very difficult.
So instead, we will use a patch size that is common for CNNs, `[64,64]`.
In addition, we will use a temporal window of 15 days.
So the patch will be `[15,64,64]`.

As with the examples above, we will also overlap the patches at the spatial borders by using `[4,4]` strides.
Lastly, we will use a 5 day stride for the time steps, so consecutive windows overlap by 10 days.
So the final strides will be `[5,4,4]`.

In [135]:
patches = {"x": 64, "y": 64, "t": 15}
strides = {"x": 4, "y": 4, "t": 5}
domain_limits = None#{"lat": slice(-10, 10)}
check_full_scan = True

xrda_batches = XRDABatcher(
    da=da,
    patches=patches,
    strides=strides,
    check_full_scan=check_full_scan
)

print(xrda_batches)
print(f"Dataset(size): {len(xrda_batches)}")

XArray Patcher
DataArray size: OrderedDict([('t', 360), ('x', 200), ('y', 200)])
Patches:        OrderedDict([('t', 15), ('x', 64), ('y', 64)])
Strides:        OrderedDict([('t', 5), ('x', 4), ('y', 4)])
Num Batches:    OrderedDict([('t', 70), ('x', 35), ('y', 35)])
Dataset(size): 85750


All of a sudden, we have a LOT of data if we do things in a patch-wise manner, more than 85K samples!
However, we know from statistics that this is perhaps not the greatest idea, because there is a lot of overlap between the samples.
So we can be clever and use a training dataset with less overlap.
For predictions, on the other hand, we can create a different dataset where we reduce the strides considerably, so that we can take a weighted average over the overlapping predictions!

In [136]:
patches = {"x": 64, "y": 64, "t": 15}
strides = {"x": 1, "y": 1, "t": 1}
domain_limits = None#{"lat": slice(-10, 10)}
check_full_scan = True

xrda_batches = XRDABatcher(
    da=da,
    patches=patches,
    strides=strides,
    check_full_scan=check_full_scan
)

print(xrda_batches)
print(f"Dataset(size): {len(xrda_batches)}")

XArray Patcher
DataArray size: OrderedDict([('t', 360), ('x', 200), ('y', 200)])
Patches:        OrderedDict([('t', 15), ('x', 64), ('y', 64)])
Strides:        OrderedDict([('t', 1), ('x', 1), ('y', 1)])
Num Batches:    OrderedDict([('t', 346), ('x', 137), ('y', 137)])
Dataset(size): 6494074


So this has us covered: every point falls inside many patches, and we can take a weighted average of all of the predictions!
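
As an illustration of averaging overlapping predictions, here is a minimal 1D sketch in pure Python. It assumes uniform weights (a plain mean over every patch that covers a point), and `reconstruct_mean` is a hypothetical helper, not the reconstruction routine of `XRDABatcher`.

```python
from typing import List, Sequence


def reconstruct_mean(
    patches: Sequence[Sequence[float]], slices: Sequence[slice], size: int
) -> List[float]:
    """Average overlapping patches back onto a 1D domain of length `size`."""
    total = [0.0] * size
    count = [0] * size
    for patch, sl in zip(patches, slices):
        for i, value in zip(range(sl.start, sl.stop), patch):
            total[i] += value
            count[i] += 1
    # Points that no patch covers are marked as NaN.
    return [t / c if c else float("nan") for t, c in zip(total, count)]


signal = [float(i) for i in range(8)]
slices = [slice(s, s + 4) for s in range(0, len(signal) - 4 + 1, 2)]
patches = [signal[sl] for sl in slices]
recon = reconstruct_mean(patches, slices, len(signal))
print(recon == signal)
```

Non-uniform weights, for example ones that down-weight patch borders, would replace the `+= value` and `+= 1` updates with weighted versions.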

## Case V: Reconstructing with multiple variables

In this example, we look at how we can do reconstructions with multiple variables.
This may occur when we have used different methods to make predictions and we want to reconstruct all of them.

Another example is when we have some sort of latent variable representation and we would like to reconstruct each of the latent variable representations.
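
One way to organise this, sketched under the assumption that each patch produces a dictionary of variable name to values, is to regroup the per-patch outputs by variable so that each variable can then be reconstructed on its own. The `group_by_variable` helper and the variable names `"pred"` and `"latent"` are hypothetical.

```python
from collections import defaultdict
from typing import Dict, List, Sequence


def group_by_variable(
    outputs: Sequence[Dict[str, List[float]]]
) -> Dict[str, List[List[float]]]:
    """Turn a list of per-patch dicts into one list of patches per variable."""
    grouped = defaultdict(list)
    for output in outputs:
        for name, values in output.items():
            grouped[name].append(values)
    return dict(grouped)


# Two patches, each carrying a prediction and a latent representation.
outputs = [
    {"pred": [0.1, 0.2], "latent": [1.0, 2.0]},
    {"pred": [0.3, 0.4], "latent": [3.0, 4.0]},
]
grouped = group_by_variable(outputs)
print(sorted(grouped))
```

The patch order is preserved within each variable, so every list still lines up with the preallocated slices.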


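In [None]:
import itertools

In [None]:
# collect the underlying array of every patch
all_batches = list(map(lambda x: x.data, xrda_batches))
# flatten the collected patches into a single list of items
items = list(itertools.chain(*[all_batches]))
# `all_batches_latent` is not defined in this notebook; build it the same way
# from the latent-variable patches before flattening it:
# items_latent = list(itertools.chain(*[all_batches_latent]))
items[0]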

array([-1.0856306033005612, 0.9973454465835858, 0.28297849805199204,
       -1.5062947139180922, -0.5786002519685364, 1.651436537097151,
       -2.426679243393074, -0.42891262885617726, 1.2659362587055338,
       -0.8667404022651016, -0.6788861516220543, -0.09470896893689112,
       1.4913896261242878, -0.638901996684651, -0.44398195964606546,
       -0.43435127561851733, 2.205930082725455, 2.1867860889737862,
       1.004053897878877, 0.386186399174856, 0.7373685758962422,
       1.490732028150799, -0.9358338684023914, 1.1758290447821034,
       -1.2538806677490124, -0.6377515024534103, 0.9071051958003014,
       -1.4286807002259692, -0.1400687201886661, -0.8617548958596855],
      dtype=object)