In [36]:
import xarray as xr
import numpy as np
import matplotlib.pyplot as plt

import os
os.environ["DDEBACKEND"] = "pytorch"
import torch
from torch import nn, optim
import deepxde as dde
from shapely.geometry import Point
import cartopy.feature as cfeature

from models import ChlorophyllDeepONet
from boundary_conds import boundary_condition, get_xt_geom
from pdes import pde
from data_utils import *


%matplotlib inline

In [33]:
# Load and preprocess data
def load_and_preprocess_data(
    store="../shared-public/mind_the_chl_gap/IO.zarr",
    lat_range=(32, -11.75),
    lon_range=(42, 101.75),
    time_range=("2019-01-01", "2022-12-31"),
    chl_var="CHL_cmes-level3",
):
    """Open the chlorophyll zarr store and subset it for PINN training.

    All defaults reproduce the previously hard-coded behaviour, so existing
    ``load_and_preprocess_data()`` calls are unchanged.

    Args:
        store: Path to the consolidated zarr store.
        lat_range: ``(start, stop)`` passed to ``slice`` on ``lat``. The default
            runs north-to-south, which assumes the store's latitude coordinate
            is descending — NOTE(review): confirm, otherwise the slice is empty.
        lon_range: ``(start, stop)`` passed to ``slice`` on ``lon``.
        time_range: ``(start, stop)`` date strings passed to ``slice`` on
            ``time``, or ``None`` to keep the full record.
        chl_var: Variable used to detect time steps with no valid data.

    Returns:
        The subset dataset, sorted by time, restricted to the requested
        window, with every time step whose ``chl_var`` is NaN over the entire
        lat/lon region removed.
    """
    print("Starting data load and preprocessing...")
    ds = xr.open_zarr(store=store, consolidated=True)
    ds = ds.sel(lat=slice(*lat_range), lon=slice(*lon_range))

    # Sort first: label slicing on time is only reliable on a monotonic index.
    # Apply the time window *before* the all-NaN scan so the .compute() below
    # reduces over the requested period only, not the whole record.
    ds = ds.sortby("time")
    if time_range is not None:
        ds = ds.sel(time=slice(*time_range))

    # Boolean mask over time: True where the variable is NaN at every grid cell.
    all_nan_dates = np.isnan(ds[chl_var]).all(dim=["lon", "lat"]).compute()
    return ds.sel(time=~all_nan_dates)

In [34]:
# Load the regional, time-windowed dataset, then convert it into the arrays the
# PINN cells below consume. `data` is indexed by variable name (the chlorophyll
# lookup here, the forcing lookups in the next cell); `lat`, `lon` and `time`
# are passed to get_xt_geom in the next cell.
# NOTE(review): `water_mask` is not used by any visible cell.
zarr_ds = load_and_preprocess_data()
data, time, lat, lon, water_mask = prepare_data_for_pinn(zarr_ds)
# Chlorophyll variable from the prepared data.
# NOTE(review): `chl_data` is not used by any visible cell.
chl_data = data["CHL_cmes-level3"]

Starting data load and preprocessing...
Starting data preparation for PINN...


In [35]:
# Space-time geometry for the PINN.
# NOTE(review): `coastline` is returned here but not used by any visible cell,
# while the RobinBC cell fails with "is_in_ocean() missing 1 required positional
# argument: 'coastline'"; it probably needs to reach the boundary predicate.
geomtime, coastline = get_xt_geom(lat, lon, time)

# Forcing fields, read by `pde` as module-level globals. One table-driven
# conversion keeps the variable list in a single place instead of seven
# copy-pasted lines. torch.tensor copies, so these do not share memory with
# `data`.
FORCING_VARS = ("air_temp", "sst", "curr_dir", "ug_curr", "u_wind", "v_wind", "v_curr")
forcing_tensors = {
    name: torch.tensor(data[name], dtype=torch.float32) for name in FORCING_VARS
}
air_temp, sst, curr_dir, ug_curr, u_wind, v_wind, v_curr = (
    forcing_tensors[name] for name in FORCING_VARS
)

In [37]:
def _nearest_grid_index(points, grid):
    """Index of the nearest ``grid`` value for each entry of ``points``.

    Args:
        points: (N, 1) tensor of coordinates.
        grid: 1-D tensor of grid coordinates. Works for ascending or
            descending grids, because the |difference| + argmin search does not
            rely on ordering.

    Returns:
        (N,) int64 tensor of indices into ``grid``.
    """
    return (points - grid.reshape(1, -1)).abs().argmin(dim=1)


def pde(x, y):
    """PINN residual of the chlorophyll Poisson-type equation.

    residual = U_lat,lat + U_lon,lon + U_t,t - rho(lat, lon, t)

    rho is an analytic term plus a hand-weighted sum of the forcing fields
    (``air_temp``, ``sst``, ``curr_dir``, ``ug_curr``, ``u_wind``, ``v_wind``,
    ``v_curr``), each sampled at the collocation point's own grid cell.

    NOTE(review): this definition shadows the ``pde`` imported from ``pdes`` in
    the import cell; keep only one of them.

    Args:
        x: Collocation points; columns 0, 1, 2 are (lat, lon, t).
        y: Network output; component 0 is U.

    Returns:
        (N, 1) residual tensor.

    Raises:
        ValueError: If the forcing fields are not 3-D, or the module-level
            ``lat``/``lon`` coordinates are not 1-D.
    """
    lat_pt, lon_pt, t = x[:, 0:1], x[:, 1:2], x[:, 2:3]
    d2U_dlat2 = dde.grad.hessian(y, x, component=0, i=0, j=0)
    d2U_dlon2 = dde.grad.hessian(y, x, component=0, i=1, j=1)
    d2U_dt2 = dde.grad.hessian(y, x, component=0, i=2, j=2)

    # Index every axis per point. Indexing only the time axis (field[t, :])
    # with an (N, 1) index returns a whole spatial slice per point, with shape
    # (N, 1, ...), and broadcasts rho to the wrong shape.
    # Assumes the forcing fields are laid out (time, lat, lon), that the
    # module-level `lat` / `lon` are their 1-D coordinates, and that t is a day
    # index into the time axis — TODO confirm against data_utils / get_xt_geom.
    if air_temp.ndim != 3:
        raise ValueError(
            f"expected (time, lat, lon) forcing fields, got shape {tuple(air_temp.shape)}"
        )
    lat_grid = torch.as_tensor(np.asarray(lat), dtype=x.dtype, device=x.device)
    lon_grid = torch.as_tensor(np.asarray(lon), dtype=x.dtype, device=x.device)
    if lat_grid.ndim != 1 or lon_grid.ndim != 1:
        raise ValueError("pde expects 1-D lat/lon grid coordinates")

    # Truncate to an integer day, then clamp so that t at the upper edge of the
    # time domain (or sampled slightly outside it) stays a valid index.
    t_idx = t.detach().squeeze(-1).long().clamp(0, air_temp.shape[0] - 1)
    lat_idx = _nearest_grid_index(lat_pt.detach(), lat_grid)
    lon_idx = _nearest_grid_index(lon_pt.detach(), lon_grid)

    def at_points(field):
        # Forcing value at each collocation point, as (N, 1) on x's device.
        # Indices are moved to the field's device first, so CPU-held forcings
        # still work when x lives on a GPU.
        dev = field.device
        values = field[t_idx.to(dev), lat_idx.to(dev), lon_idx.to(dev)]
        return values.to(x.device).unsqueeze(-1)

    # NOTE(review): if the forcing fields hold NaN over land, points mapped to
    # land cells produce NaN residuals — consider masking with `water_mask`.
    forcing = (
        0.5 * at_points(air_temp)
        - 1.0 * at_points(sst)
        + 0.05 * at_points(curr_dir)
        + 0.15 * at_points(ug_curr)
        + 0.4 * at_points(u_wind)
        - 0.2 * at_points(v_wind)
        + 0.3 * at_points(v_curr)
    )

    # NOTE(review): lat/lon go into sin/cos unchanged; if they are in degrees,
    # torch.deg2rad is probably intended.
    rho = (
        0.1 * torch.sin(lat_pt) * torch.cos(lon_pt) * torch.exp(-0.1 * t)
        + 0.05 * torch.sin(2 * torch.pi * t / 365)
        + forcing
    )

    return d2U_dlat2 + d2U_dlon2 + d2U_dt2 - rho

In [39]:
# Robin condition on the space-time boundary. Per the DeepXDE API,
# RobinBC(geom, func, on_boundary) enforces dU/dn = func(x, U). The lambda
# returns U unchanged, so this imposes dU/dn = U. `boundary_condition` (from
# boundary_conds) is passed as the on_boundary predicate.
# NOTE(review): this cell's recorded output is
#   TypeError: is_in_ocean() missing 1 required positional argument: 'coastline'
# and the `coastline` returned by get_xt_geom is not passed anywhere in the
# visible cells. It likely needs to be bound into the predicate or geometry
# (e.g. functools.partial) — confirm in boundary_conds before training.
bc_robin = dde.icbc.RobinBC(geomtime, lambda X, y: y, boundary_condition)

# Assemble the time-dependent PINN dataset. `pde` supplies the residual, and the
# num_* arguments set how many interior, boundary and initial-time points are
# sampled.
# NOTE(review): the ic/bc list holds only the Robin BC — no initial condition is
# supplied even though num_initial points are requested; confirm this is intended.
data_pinn = dde.data.TimePDE(
    geomtime,
    pde,
    [bc_robin],
    num_domain=10000,
    num_boundary=2000,
    num_initial=2000,
)

TypeError: is_in_ocean() missing 1 required positional argument: 'coastline'