In [5]:
import xarray as xr
import numpy as np
import matplotlib.pyplot as plt

import os
# The backend selection is written to the environment *before* `import deepxde`
# below, so it is already in place when DeepXDE is imported.
# NOTE(review): current DeepXDE documentation names this variable `DDE_BACKEND`;
# confirm the installed version still honours the legacy `DDEBACKEND` spelling.
os.environ["DDEBACKEND"] = "pytorch"
import torch
from torch import nn, optim
import deepxde as dde
from shapely.geometry import Point
import cartopy.feature as cfeature

import sys
sys.path.append("../")

from src.models import ChlorophyllDeepONet
from src.boundary_conds import get_xt_geom, is_in_ocean
from src.pdes import pde
from src.data_utils import *

%matplotlib inline
%load_ext autoreload
%autoreload 2

In [5]:
# Load the dataset and run the project's preprocessing step
# (helper presumably provided by the `from src.data_utils import *` import).
zarr_ds = load_and_preprocess_data()
# Unpack the variable dictionary plus the time/lat/lon coordinates and the
# water mask; time, lat and lon feed the geometry construction in a later cell.
data, time, lat, lon, water_mask = prepare_data_for_pinn(zarr_ds)
# Pull out the "CHL_cmes-level3" entry (presumably the chlorophyll field — confirm).
chl_data = data["CHL_cmes-level3"]

Starting data load and preprocessing...
Starting data preparation for PINN...


In [6]:
# Build the space-time geometry and the coastline object from the coordinate
# arrays; `coastline` is consumed by the boundary-condition callback.
geomtime, coastline = get_xt_geom(lat, lon, time)


def _as_float32_tensor(values):
    """Return `values` as a new float32 PyTorch tensor."""
    return torch.tensor(values, dtype=torch.float32)


# Convert each forcing field from the prepared data dictionary to a tensor.
air_temp = _as_float32_tensor(data["air_temp"])
sst = _as_float32_tensor(data["sst"])
curr_dir = _as_float32_tensor(data["curr_dir"])
ug_curr = _as_float32_tensor(data["ug_curr"])
u_wind = _as_float32_tensor(data["u_wind"])
v_wind = _as_float32_tensor(data["v_wind"])
v_curr = _as_float32_tensor(data["v_curr"])

In [9]:
def boundary_condition(x, on_boundary):
    """Decide whether a sampled point belongs to the ocean part of the boundary.

    Args:
        x: Coordinates of one sampled point. ``x[0]`` and ``x[1]`` are passed
            to ``is_in_ocean`` as latitude and longitude respectively
            (assumes the geometry from ``get_xt_geom`` orders its axes as
            (lat, lon, ...) — TODO confirm).
        on_boundary: Truthy when the geometry reports ``x`` as a boundary point.

    Returns:
        bool: True only when ``x`` is on the geometry boundary and
        ``is_in_ocean`` reports it as ocean.
    """
    # Check the cheap flag first: DeepXDE evaluates this callback for the
    # sampled training points, not just boundary ones, so the coastline lookup
    # must not run for points that are already known to be off the boundary.
    if not on_boundary:
        return False
    # Distinct local names avoid shadowing the notebook-level `lat`/`lon` arrays.
    point_lat, point_lon = x[0], x[1]
    return bool(is_in_ocean(point_lat, point_lon, coastline))

def _robin_rhs(X, y):
    """Right-hand side of the Robin condition: the network output itself."""
    return y


# Robin condition restricted to the points accepted by `boundary_condition`;
# per the DeepXDE API the supplied function gives the value of dy/dn.
bc_robin = dde.icbc.RobinBC(geomtime, _robin_rhs, boundary_condition)

# Number of training points to sample in each region of the space-time domain.
# NOTE(review): initial-time points are requested, but the constraint list
# below contains only the Robin BC — confirm whether an initial condition
# is meant to be added.
_sampling_counts = {
    "num_domain": 10000,
    "num_boundary": 2000,
    "num_initial": 2000,
}

_constraints = [bc_robin]
data_pinn = dde.data.TimePDE(geomtime, pde, _constraints, **_sampling_counts)

KeyboardInterrupt: 