In [None]:
import xarray as xr

```python
spatial_columns = ["lon", "lat"], ["x", "y", "z"], ["lon_rad", "lat_rad"]
temporal_columns = ["time"], ["vtime"]
```

**Transformations**

* Spherical to Cartesian
* Spherical Degrees to Radians
* Temporal to Julian
* Timestamps to Day of the Year
* Timestamps to Cycles
* Temporal Scaling

In [None]:
# Machine-specific absolute path to the input `.nc` file; adjust before running elsewhere.
ds_link = "/Volumes/EMANS_HDD/data/dc20a_osse/test/ml/swot1nadir5.nc"

# Open the file with xarray.
ds = xr.open_dataset(ds_link)

# Bare expression so the notebook renders the dataset as the cell output.
ds

In [None]:
import torch
import pandas as pd


class AlongTrackDataset(torch.utils.data.Dataset):
    """Map-style dataset serving spatial, temporal and output column groups.

    The source object is flattened into a table with ``ds.to_dataframe()``,
    its index is reset into ordinary columns, and every row holding a missing
    value in any column is dropped. The three requested column groups are then
    kept as arrays (``x``, ``t`` and ``y``).

    Args:
        ds: object exposing ``to_dataframe()`` (the notebook passes the result
            of ``xr.open_dataset``).
        spatial_columns: list of column names stored in ``self.x``.
        temporal_columns: list of column names stored in ``self.t``.
        output_columns: list of column names stored in ``self.y``.
        spatial_transforms: ``None``, a single callable, or an iterable of
            callables applied in order to the spatial values.
        temporal_transforms: same convention, applied to the temporal values.
        output_transforms: same convention, applied to the output values.
    """

    def __init__(
        self,
        ds,
        spatial_columns,
        temporal_columns,
        output_columns,
        spatial_transforms=None,
        temporal_transforms=None,
        output_transforms=None,
    ):
        # dropna() looks at every column, not only the selected ones.
        df = ds.to_dataframe().reset_index().dropna()
        self.x = df[spatial_columns].values
        self.t = df[temporal_columns].values
        self.y = df[output_columns].values
        self.spatial_transforms = spatial_transforms
        self.temporal_transforms = temporal_transforms
        self.output_transforms = output_transforms
        self.output_columns = output_columns
        self.spatial_columns = spatial_columns
        self.temporal_columns = temporal_columns

    def __len__(self):
        """Return the number of rows kept after dropping missing values."""
        return len(self.x)

    @staticmethod
    def _apply_transforms(values, transforms):
        """Apply ``transforms`` (None, one callable, or an iterable) to ``values``."""
        if transforms is None:
            return values
        # A single callable (e.g. a Compose pipeline) is invoked directly so the
        # three transform arguments all accept the same kinds of value.
        if callable(transforms):
            return transforms(values)
        for itransform in transforms:
            values = itransform(values)
        return values

    def __getitem__(self, item):
        """Return the transformed sample(s) selected by ``item``.

        ``item`` is used directly to index the stored arrays.

        Returns:
            dict with the keys ``"spatial"``, ``"temporal"`` and ``"output"``.
        """
        x = self._apply_transforms(self.x[item], self.spatial_transforms)
        t = self._apply_transforms(self.t[item], self.temporal_transforms)
        y = self._apply_transforms(self.y[item], self.output_transforms)
        return {"spatial": x, "temporal": t, "output": y}

    def create_predict_df(self, outputs):
        """Build a DataFrame of the stored columns plus the given predictions.

        Args:
            outputs: array-like with the same shape as ``self.y``. A
                ``torch.Tensor`` is detached, moved to the CPU and converted
                to a NumPy array first.

        Returns:
            pandas.DataFrame holding the spatial, temporal and output columns
            followed by one ``"<output column>_predict"`` column per output.

        Raises:
            AssertionError: if ``outputs`` does not have the shape of ``self.y``.
        """
        if isinstance(outputs, torch.Tensor):
            # pandas cannot ingest tensors that track gradients or live off-CPU.
            outputs = outputs.detach().cpu().numpy()
        assert tuple(outputs.shape) == tuple(self.y.shape), (
            f"outputs has shape {tuple(outputs.shape)}, "
            f"expected {tuple(self.y.shape)}"
        )
        df = pd.DataFrame()
        df[self.spatial_columns] = self.x
        df[self.temporal_columns] = self.t
        df[self.output_columns] = self.y
        names = [f"{name}_predict" for name in self.output_columns]
        df[names] = outputs
        return df

    def create_predict_ds(self, outputs):
        """Return the prediction table converted with ``DataFrame.to_xarray()``.

        The table from ``create_predict_df`` is indexed by the temporal
        columns before the conversion.
        """
        ds = (
            self.create_predict_df(outputs).set_index(self.temporal_columns).to_xarray()
        )
        return ds

In [None]:
from torchvision.transforms import ToTensor, Lambda, Compose
import numpy as np
import math

import pandas as pd


class JulianTime:
    """Callable transform that converts timestamps to Julian dates."""

    def __call__(self, x):
        """Convert ``x`` to Julian dates while keeping its shape.

        Args:
            x: array-like of values accepted by ``pd.to_datetime``
                (for example a ``datetime64`` array).

        Returns:
            numpy.ndarray of Julian dates with the same shape as ``x``.
        """
        x = np.asarray(x)
        # pd.to_datetime needs one-dimensional input, so flatten for the
        # conversion and restore the caller's shape afterwards.
        julian = pd.to_datetime(x.ravel()).to_julian_date()
        return np.asarray(julian).reshape(x.shape)


class TimeMinMax:
    """Callable transform that rescales timestamps relative to a fixed window.

    A value equal to ``time_min`` maps to 0 and a value equal to ``time_max``
    maps to 1; both bounds are stored as the strings given to the constructor.
    """

    def __init__(self, time_min: str = "2005-01-10", time_max: str = "2022-01-01"):
        self.time_min, self.time_max = time_min, time_max

    def __call__(self, x):
        """Return ``(x - time_min) / (time_max - time_min)``."""
        # The bounds are parsed on every call, so later changes to the
        # attributes are picked up.
        lower = np.datetime64(self.time_min)
        upper = np.datetime64(self.time_max)
        elapsed = x - lower
        window = upper - lower
        return elapsed / window


# class TimeSinCos:
#     def __init__(self):
#         pass

#     def __call__(self, x):
#         print(x.shape)
#         x1 = torch.sin(2*math.pi * x/24)
#         x2 = torch.cos(2*math.pi * x/24)
#         print(x1.shape, x2.shape)
#         return torch.cat([x1, x2], dim=1)

In [None]:
# Wrap torch.tensor in a torchvision Lambda so it can be used as a transform step.
to_tensor = Lambda(torch.tensor)
# Earlier variants based on JulianTime, kept for reference:
# temporal_transforms = [JulianTime(), torch.FloatTensor]
# temporal_transforms = Compose(
#     [JulianTime(), to_tensor]
# )

# Active temporal pipeline: rescale with TimeMinMax's default window, then convert to a tensor.
temporal_transforms = Compose([TimeMinMax(), to_tensor])

In [None]:
# Column groups handed to AlongTrackDataset below.
spatial_columns = ["lon", "lat"]
temporal_columns = ["time"]
output_columns = ["ssh_model"]

In [None]:
# Build the dataset; the spatial and output values are only converted to tensors,
# while the temporal values go through the temporal_transforms pipeline.
torch_ds = AlongTrackDataset(
    ds,
    spatial_columns,
    temporal_columns,
    output_columns,
    temporal_transforms=temporal_transforms,
    spatial_transforms=[to_tensor],
    output_transforms=[to_tensor],
)

# Index with a slice to fetch several rows in one call.
ibatch = torch_ds[0:3]

# Show the shapes of the three entries as the cell output.
ibatch["spatial"].shape, ibatch["temporal"].shape, ibatch["output"].shape

In [None]:
# Use the dataset's own transformed outputs as stand-in predictions.
outputs = torch_ds[:]["output"]

# NOTE: despite its name, `ds_` holds a pandas DataFrame here.
ds_ = torch_ds.create_predict_df(outputs)
ds_.head()

In [None]:
# Same round trip, this time converted with to_xarray(); rebinds `ds_` from the previous cell.
ds_ = torch_ds.create_predict_ds(outputs)
ds_

In [None]:
# Wrap the dataset in a DataLoader with a batch size of 32.
torch_dl = torch.utils.data.DataLoader(torch_ds, batch_size=32)

In [None]:
# torch_ds[0:10]

In [None]:
# Fetch the first batch: the loop breaks right away, leaving `ibatch` bound to
# the first element the loader yields.
# ibatch = next(iter(torch_dl))
for ibatch in torch_dl:
    break

In [None]:
# Show the shapes of the three entries of the batch fetched above.
ibatch["spatial"].shape, ibatch["temporal"].shape, ibatch["output"].shape