In [None]:
import os
import sys

from pyprojroot import here

# PyTorch: allow ops that are not implemented on MPS to fall back to CPU.
# This must be set before `torch` is imported (done in the next cell).
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

# Walk up from the working directory to the directory marked by a `.root` file.
root = here(project_files=[".root"])

# Experiment folder for dc21a. `here` locates a `.local` marker and joins the
# (already absolute) path onto it.
# NOTE(review): the marker search must succeed even though the joined path is absolute; confirm.
exp = here(
    relative_project_path=root.joinpath("experiments/dc21a"),
    project_files=[".local"],
)

# Make both the project root and the experiment folder importable.
for extra_path in (root, exp):
    sys.path.append(str(extra_path))

In [None]:
import time

import pytorch_lightning as pl
import torch
import torch.nn as nn
from torchvision.transforms import Compose
import xarray as xr
from inr4ssh._src.datasets import AlongTrackDataset
from inr4ssh._src.datasets.utils import get_num_training

from inr4ssh._src.transforms.dataset import (
    TimeJulianMinMax,
    TimeJulian,
    TimeMinMax,
    ToTensor,
)

%matplotlib inline
%load_ext autoreload
%autoreload 2

Candidate column sets — each right-hand side is a tuple of alternative column choices:

```python
spatial_columns = ["lon", "lat"], ["x", "y", "z"], ["lon_rad", "lat_rad"]
temporal_columns = ["time"], ["vtime"]
```

**Transformations**

* Spherical to Cartesian coordinates
* Spherical degrees to radians
* Timestamps to Julian time
* Timestamps to day of the year
* Timestamps to cycles
* Temporal scaling

In [None]:
# Path to the along-track NetCDF file. Set the DC20A_NADIR_PATH environment
# variable to point elsewhere instead of editing this machine-specific default.
ds_link = os.environ.get(
    "DC20A_NADIR_PATH", "/Volumes/EMANS_HDD/data/dc20a_osse/test/ml/nadir1.nc"
)

ds = xr.open_dataset(ds_link)

ds

### Dataset Transforms

There are two kinds of transformations we apply within the dataset: 1) converting timestamps to numbers and 2) converting arrays to torch tensors. The xarray datasets almost always store the spatial and output values as numpy arrays, so these need to become torch tensors. The time values are stored as `datetime64`, so they must first be transformed into numerical values and then into torch tensors.

**Note**: There are other transformations we could apply, e.g. spherical-to-Cartesian conversions. I decided to offload them to the `trainer`, which is discussed later. In general, the dataset transformations should be limited to turning the raw values (numpy arrays and timestamps) into numerical torch tensors.

The library provides the following dataset transformations:
* Julian Time (Temporal Transform)
* TimeMinMax (Temporal Transform)
* TimeJulianMinMax (Temporal Transform)
* ToTensor (Spatial, Temporal, Output Transform)

In [None]:
# Column names read from the xarray dataset
spatial_columns = ["lon", "lat"]
temporal_columns = ["time"]
output_columns = ["ssh_model"]

# Raw dataset (no transform) that also returns the output column
torch_ds = AlongTrackDataset(
    ds,
    spatial_columns,
    temporal_columns,
    output_columns,
    transform=None,
)

# Peek at the first `batchsize` samples via slicing
batchsize = 32
ibatch = torch_ds[0:batchsize]
tuple(ibatch[key].shape for key in ("spatial", "temporal", "output"))

In [None]:
# Column names read from the xarray dataset
spatial_columns = ["lon", "lat"]
temporal_columns = ["time"]
output_columns = ["ssh_model"]

# Raw dataset (no transform) with no output columns requested: inputs only
torch_ds = AlongTrackDataset(
    ds,
    spatial_columns,
    temporal_columns,
    output_columns=None,
    transform=None,
)

# Peek at the first `batchsize` samples via slicing
batchsize = 32
ibatch = torch_ds[0:batchsize]
tuple(ibatch[key].shape for key in ("spatial", "temporal"))

In [None]:
# Dataset-level transforms, applied in order:
#   1. TimeJulianMinMax: temporal transform (Julian time with min-max scaling,
#      judging by the name; confirm in inr4ssh).
#   2. ToTensor: convert the arrays to torch tensors.
# TimeMinMax() or TimeJulian() can be swapped in for step 1 if only one of the
# two temporal conversions is wanted.
transform = Compose(
    [
        TimeJulianMinMax(),
        ToTensor(),
    ]
)

In [None]:
# Same columns as before, now with `transform` attached and outputs included
torch_ds = AlongTrackDataset(
    ds,
    spatial_columns,
    temporal_columns,
    output_columns,
    transform=transform,
)

# Peek at the first `batchsize` transformed samples
batchsize = 32
ibatch = torch_ds[0:batchsize]
tuple(ibatch[key].shape for key in ("spatial", "temporal", "output"))

In [None]:
# Same columns with `transform` attached, but no output columns requested
torch_ds = AlongTrackDataset(
    ds,
    spatial_columns,
    temporal_columns,
    output_columns=None,
    transform=transform,
)

# Peek at the first `batchsize` transformed samples (inputs only)
batchsize = 32
ibatch = torch_ds[0:batchsize]
tuple(ibatch[key].shape for key in ("spatial", "temporal"))

In [None]:
# Range of the transformed time values in this batch (sanity check on the time transform)
(ibatch["temporal"].min(), ibatch["temporal"].max())

### Utility: `pd.DataFrame`, `xr.Dataset`

For inference, we typically have a dataset of query points. We want to pair the model predictions with their coordinates in a `pd.DataFrame` and/or an `xr.Dataset`. The `create_predict_df` and `create_predict_ds` utilities do this using the attributes stored in the dataset.

In [None]:
# Dataset with the output column, so every target value can be extracted
torch_ds = AlongTrackDataset(
    ds,
    spatial_columns,
    temporal_columns,
    output_columns,
    transform=transform,
)

# Outputs of all samples (full slice), used here as stand-in predictions
outputs = torch_ds[:]["output"]

# Pair those values with their coordinates in a DataFrame
ds_ = torch_ds.create_predict_df(outputs)
ds_.head()

In [None]:
# initialize dataset
torch_ds = AlongTrackDataset(
    ds, spatial_columns, temporal_columns, output_columns=None, transform=transform
)


ds_ = torch_ds.create_predict_df(outputs)

ds_.head()

In [None]:
# Same predictions, this time returned as an xarray Dataset instead of a DataFrame
ds_ = torch_ds.create_predict_ds(outputs)
ds_

## Train/Validation Split

In [None]:
import numpy as np

# Build the dataset to split explicitly, including the output (target) column.
# Previously this cell silently used whichever `torch_ds` was last defined in
# the kernel, and the most recent one was created with `output_columns=None`.
train_val_ds = AlongTrackDataset(
    ds, spatial_columns, temporal_columns, output_columns, transform=transform
)

# Fraction of samples assigned to training; the remainder goes to validation.
# random_split requires the two sizes to sum to len(train_val_ds).
train_prct = 0.9
num_train, num_valid = get_num_training(len(train_val_ds), train_prct=train_prct)

# Fixed seed so the split is identical across kernel restarts
train_split_seed = 42

train_set, valid_set = torch.utils.data.random_split(
    train_val_ds,
    [num_train, num_valid],
    generator=torch.Generator().manual_seed(train_split_seed),
)

## DataLoader

Finally, we wrap the two splits in `DataLoader`s, which take care of batching (and of shuffling for the training set).

In [None]:
# Training batches are shuffled with a seeded generator (same seed as the split),
# so batch order is reproducible. Validation batches keep their natural order.
train_dl = torch.utils.data.DataLoader(
    train_set,
    batch_size=32,
    shuffle=True,
    generator=torch.Generator().manual_seed(train_split_seed),
)
valid_dl = torch.utils.data.DataLoader(valid_set, batch_size=64, shuffle=False)

In [None]:
# Grab the first batch from each loader for a shape check.
# next(iter(...)) raises StopIteration on an empty loader, instead of silently
# keeping a stale `ibatch`/`ivbatch` left over from earlier cells.
ibatch = next(iter(train_dl))
ivbatch = next(iter(valid_dl))

In [None]:
# Input shapes of one training batch (add "output" to the keys to also check targets)
tuple(ibatch[key].shape for key in ("spatial", "temporal"))

In [None]:
# Input shapes of one validation batch (add "output" to the keys to also check targets)
tuple(ivbatch[key].shape for key in ("spatial", "temporal"))