In [1]:
import numpy as np
import xarray as xr
import os
import sys
sys.path.append("../")

import torch
from torch import nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau

from src.data_utils import *        
from src.models import DeepONet
from src.pinn_train_utils import *

import matplotlib.pyplot as plt
%matplotlib inline

%load_ext autoreload
%autoreload 2

In [2]:
data_dir = "../data/"

# Open the consolidated Zarr store, crop it to the lat/lon box below, and
# order the result along `time`.
store_path = data_dir + "IO.zarr"
# The lat slice runs from 32 down to -11.75 -- presumably latitude is stored
# north-to-south in the store; verify if the crop comes back empty.
region = dict(lat=slice(32, -11.75), lon=slice(42, 101.75))
zarr_ds = (
    xr.open_zarr(store=store_path, consolidated=True)
    .sel(**region)
    .sortby("time")
)

In [3]:
# Names of the dataset variables used by the cells below. The list order is
# preserved exactly, since the variables are later stacked in this order.
vars = (
    ["CHL_cmes-level3"]
    + ["air_temp", "sst", "adt", "curr_dir"]
    + ["u_wind", "v_wind"]
    + ["u_curr", "v_curr"]
    + ["CHL_cmes-gapfree"]
)

In [4]:
# Keep only the 2000-2022 time window, then only the variables listed in `vars`.
time_window = slice("2000-01-01", "2022-12-31")
x = zarr_ds.sel(time=time_window)[vars]

In [29]:
vars_to_log = ["CHL_cmes-level3", "CHL_cmes-gapfree"]


def _log_if_selected(da):
    """Return the natural log of `da` when its name is in `vars_to_log`; otherwise return `da` unchanged."""
    if da.name in vars_to_log:
        return np.log(da)
    return da


# Apply the log transform variable by variable; unlisted variables pass through.
x_logged = x.map(_log_if_selected)

In [36]:
# Collapse every dimension except `time` (across all selected variables) into
# one stacked dimension named "meow", then replace NaNs with 0.0.
# NOTE(review): for the log-transformed variables a fill value of 0.0
# corresponds to an untransformed value of 1 -- confirm this is intended.
stacked = (
    x_logged[vars]
    .to_stacked_array(new_dim="meow", sample_dims=["time"])
    .fillna(value=0.0)
)

In [37]:
# Pull the stacked array out as NumPy and wrap it as a tensor
# (torch.from_numpy shares memory with the array rather than copying it).
stacked_np = stacked.values
data = torch.from_numpy(stacked_np)

In [53]:
# Unflatten each sample into a (10, 176, 240) block; the leading -1 lets the
# number of samples be inferred.
# NOTE(review): 10 matches the number of entries in `vars`; 176 x 240 is
# presumably the (lat, lon) grid of the cropped region -- TODO confirm the
# axis order produced by the stacking step.
n_channels, n_rows, n_cols = 10, 176, 240
img_data = data.reshape(-1, n_channels, n_rows, n_cols)

In [54]:
# Wrap the image tensor in a TensorDataset; indexing it yields a 1-tuple
# containing the sample.
wrapped_tensors = (img_data,)
dataset = torch.utils.data.TensorDataset(*wrapped_tensors)

In [55]:
# Fraction of samples (rows of `data`) assigned to the training subset.
frac_train = 0.8
# Fixed seed so the random partition is identical on every kernel restart.
split_seed = 42

n_train = int(np.ceil(frac_train * len(data)))
n_test = len(data) - n_train

# NOTE(review): this splits the flat `data` tensor, not the reshaped
# `img_data` / `dataset` built in the preceding cells -- confirm which one the
# model is meant to consume.
# `random_split` is referenced through `torch.utils.data` (as TensorDataset is
# elsewhere in this notebook) so the cell does not depend on a star import.
train_set, test_set = torch.utils.data.random_split(
    data,
    [n_train, n_test],
    generator=torch.Generator().manual_seed(split_seed),
)

In [62]:
# Serialise the full `dataset` object to disk.
dataset_path = "../data/chl_train_channels_2000_2022.pt"
torch.save(dataset, dataset_path)

In [50]:
# Batch size shared by both loaders.
BATCH_SIZE = 32

# Training batches are reshuffled on every pass over the data.
# `DataLoader` is referenced through `torch.utils.data` so the cell does not
# depend on a star import to bring the name into scope.
train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

# Evaluation gains nothing from shuffling; a fixed order keeps the test
# batches identical from run to run.
test_loader = torch.utils.data.DataLoader(
    test_set,
    batch_size=BATCH_SIZE,
    shuffle=False,
)

In [52]:
# Draw a single batch from the training loader and display its shape.
first_batch = next(iter(train_loader))
first_batch.shape

torch.Size([32, 422400])