In [None]:
import math

import numpy as np
import xarray as xr
import matplotlib.pyplot as plt
import scipy.io
import scipy.spatial
from shapely.geometry import LineString, Point, box
from shapely.ops import nearest_points

import src.utils as utils
from src.parcels_utils import HFRGrid, rename_dataset_vars, xr_dataset_to_fieldset

In [None]:
# Input files: the velocity-field NetCDF and the coastline polyline (.mat).
coastline_mat = utils.MATLAB_DIR / "socal_boundary.mat"
velocity_field_nc = utils.CURRENT_NETCDF_DIR / "hunington_latest_ThreddsCode.USWC_6KM_HOURLY.nc"
# Coastline-grid spacing overrides; leave as None to derive them from the coastline points.
lat_diff, lon_diff = None, None

In [None]:
# Load the velocity field and the coastline polyline. Points are kept in (lat, lon)
# order, so in shapely geometries x is latitude and y is longitude.
vel_ds = xr.open_dataset(velocity_field_nc)
lats, lons = utils.load_pts_mat(coastline_mat, "yb", "xb")
coastline_xy = np.array([lats, lons]).T
coastline = LineString(coastline_xy)

In [None]:
# Derive the coastline-grid spacing from the coastline points when not set explicitly.
# NOTE(review): abs(np.diff(x).min()) is the magnitude of the most negative signed step
# between consecutive points (or the smallest step if none are negative), not the
# smallest step magnitude -- confirm this is the intended resolution.
if lat_diff is None:
    lat_diff = abs(np.diff(lats).min()) / 2
if lon_diff is None:
    lon_diff = abs(np.diff(lons).min()) / 2

# Fail fast: a zero or NaN spacing would otherwise surface later as an opaque
# error from math.ceil when sizing the grid.
for _name, _step in (("lat_diff", lat_diff), ("lon_diff", lon_diff)):
    if not (np.isfinite(_step) and _step > 0):
        raise ValueError(f"{_name} must be a positive finite number, got {_step!r}")

In [None]:
# Bounding box of the coastline points; the grid is anchored at (lat_start, lon_start).
lat_start, lat_end = lats.min(), lats.max()
lon_start, lon_end = lons.min(), lons.max()

In [None]:
# Number of cell centres needed to cover each bounding-box span (both ends included).
lat_span = lat_end - lat_start
lon_span = lon_end - lon_start
lat_steps = 1 + math.ceil(lat_span / lat_diff)
lon_steps = 1 + math.ceil(lon_span / lon_diff)

In [None]:
# Cell-centre coordinates, evenly spaced from the bounding-box origin.
lats_grid = lat_diff * np.arange(lat_steps) + lat_start
lons_grid = lon_diff * np.arange(lon_steps) + lon_start

In [None]:
# Object arrays indexed [lat_index, lon_index], initialised to None:
# `grid` will hold one shapely box per cell, `info` the coastline direction of
# cells the coastline passes through (None elsewhere).
grid_shape = (len(lats_grid), len(lons_grid))
grid = np.empty(grid_shape, dtype=object)
info = np.empty(grid_shape, dtype=object)

In [None]:
# Build a shapely box for every grid cell, centred on (lats_grid[i], lons_grid[j])
# and one cell wide in each direction. shapely's box() takes (minx, miny, maxx, maxy);
# x is latitude here, matching how the coastline points are ordered.
# A plain double loop over all cells (brute force, chosen over a ray-tracing approach).
half_lat = lat_diff / 2
half_lon = lon_diff / 2
for i, lat_c in enumerate(lats_grid):
    for j, lon_c in enumerate(lons_grid):
        grid[i, j] = box(lat_c - half_lat, lon_c - half_lon, lat_c + half_lat, lon_c + half_lon)

In [None]:
# Tag every grid cell a coastline segment passes through with that segment's direction,
# stored as (d_lon, d_lat) so it lines up with (u, v) in the projection step.
# The first segment to reach a cell wins; later segments never overwrite it.
n_lat_cells, n_lon_cells = info.shape


def _cell_range(p0, p1, origin, step, n_cells):
    """Inclusive range of cell indices along one axis that may touch the span [p0, p1]."""
    # Cell c is centred on origin + c * step, so the cell containing p is
    # floor((p - origin) / step + 0.5) -- plain floor() misses the upper cell.
    # Pad one cell each side so boundary-touching cells are still tested, then clip.
    lo = math.floor((min(p0, p1) - origin) / step + 0.5) - 1
    hi = math.floor((max(p0, p1) - origin) / step + 0.5) + 1
    return range(max(lo, 0), min(hi, n_cells - 1) + 1)


for i in range(len(lats) - 1):
    d_lat = lats[i + 1] - lats[i]
    d_lon = lons[i + 1] - lons[i]
    if d_lat == 0 and d_lon == 0:
        # Duplicate consecutive point: a zero-length direction would give 0/0 in the projection.
        continue
    segment = LineString([[lats[i], lons[i]], [lats[i + 1], lons[i + 1]]])
    vec = (d_lon, d_lat)
    for j in _cell_range(lats[i], lats[i + 1], lats_grid[0], lat_diff, n_lat_cells):
        for k in _cell_range(lons[i], lons[i + 1], lons_grid[0], lon_diff, n_lon_cells):
            if info[j, k] is None and segment.intersects(grid[j, k]):
                info[j, k] = vec

In [None]:
# Sanity check: grid cells tagged as coastline, in index space
# (x = lon index, y = lat index). `plt` comes from the top import cell.
coast_rows, coast_cols = np.nonzero(np.vectorize(lambda cell: cell is not None, otypes=[bool])(info))
fig, ax = plt.subplots()
ax.scatter(coast_cols, coast_rows, s=0.1)
ax.set(title="Grid cells tagged as coastline", xlabel="lon index", ylabel="lat index");

In [None]:
# Raw coastline points, for visual comparison with the tagged-cell plot.
fig, ax = plt.subplots()
ax.scatter(lons, lats, s=0.1)
ax.set(title="Coastline points", xlabel="lon", ylabel="lat");

In [None]:
# Output velocity arrays on the coastline grid, shaped (time, lat, lon);
# cells that are not on the coastline keep zero velocity.
n_times = len(vel_ds["time"])
coast_u = np.zeros((n_times,) + grid.shape)
coast_v = np.zeros((n_times,) + grid.shape)
# (lat, lon) coordinates of every velocity-grid cell, "ij"-indexed so that
# mgrid[0][a, b], mgrid[1][a, b] are the lat/lon of cell [a, b].
mgrid = np.meshgrid(vel_ds["lat"], vel_ds["lon"], indexing="ij")

In [None]:
# For every coastline cell and time step, take the nearest valid (non-NaN) velocity
# and project it onto the local coastline direction.
# NOTE(review): assumes vel_ds["u"]/["v"] are (time, lat, lon), matching the "ij"
# meshgrid above -- TODO confirm. Distances are Euclidean in degrees.

# The set of coastline cells does not change over time, so gather it once.
coast_j, coast_k = np.nonzero(np.vectorize(lambda cell: cell is not None, otypes=[bool])(info))
coast_pts = np.column_stack([lats_grid[coast_j], lons_grid[coast_k]])
# info holds (d_lon, d_lat), i.e. the coastline direction in (u, v) order.
coast_vecs = np.array([info[j, k] for j, k in zip(coast_j, coast_k)], dtype=float).reshape(-1, 2)
vec_norm2 = coast_vecs[:, 0] ** 2 + coast_vecs[:, 1] ** 2

for i in range(len(vel_ds["time"])):
    u = vel_ds["u"][i].values
    v = vel_ds["v"][i].values
    # Valid cells for this time slice only; require both components so a NaN v
    # is never sampled.
    valid = np.nonzero(~(np.isnan(u) | np.isnan(v)))
    if len(coast_pts) == 0 or len(valid[0]) == 0:
        # Nothing to fill, or no data to sample from: leave this step at zero.
        continue
    # The KD-tree is rebuilt per step because the valid cells can change over time.
    kdtree = scipy.spatial.KDTree(np.column_stack([mgrid[0][valid], mgrid[1][valid]]))
    # One batched nearest-neighbour query for all coastline cells.
    _, closest_idx = kdtree.query(coast_pts)
    closest_u = u[valid[0][closest_idx], valid[1][closest_idx]]
    closest_v = v[valid[0][closest_idx], valid[1][closest_idx]]
    # Vector projection of (u, v) onto the coastline direction; zero-length
    # directions get a zero projection instead of 0/0 = NaN.
    dot = closest_u * coast_vecs[:, 0] + closest_v * coast_vecs[:, 1]
    scalar = np.divide(dot, vec_norm2, out=np.zeros_like(dot, dtype=float), where=vec_norm2 > 0)
    coast_u[i, coast_j, coast_k] = scalar * coast_vecs[:, 0]
    coast_v[i, coast_j, coast_k] = scalar * coast_vecs[:, 1]

In [None]:
# Package the coastline-projected velocities as an xarray Dataset on the
# coastline grid, reusing the source dataset's time axis.
coast_dims = ["time", "lat", "lon"]
coast_coords = {"time": vel_ds["time"], "lat": lats_grid, "lon": lons_grid}
ds = xr.Dataset(
    data_vars={"u": (coast_dims, coast_u), "v": (coast_dims, coast_v)},
    coords=coast_coords,
)

In [None]:
# Convert the coastline-projected dataset into a FieldSet via the project helpers.
fs = xr_dataset_to_fieldset(rename_dataset_vars(ds))