In [1]:
import sys, os
from pyprojroot import here


# Spider up the directory tree to find the project root: pyprojroot's `here`
# is asked for the directory matching the `.local` project-file criterion.

local = here(project_files=[".local"])

# Make the project root importable (needed for the `src.*` imports below).
# The membership check keeps this cell idempotent: re-running it no longer
# stacks duplicate entries on sys.path.
if str(local) not in sys.path:
    sys.path.append(str(local))

In [16]:
import jax
import jax.numpy as jnp
import jax.random as jrandom
import optax
import equinox as eqx
from einops import rearrange
import numpy as np
import xarray as xr
import pandas as pd

import tqdm.notebook as tqdm

from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split

import matplotlib.pyplot as plt
import hvplot.xarray
from xmovie import Movie

from src.utils import get_meshgrid, calculate_gradient, calculate_laplacian
from src.mlp import MLPNet
from src.siren import SirenNet

from src.data import make_mini_batcher
from src.viz import create_movie

import wandb

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [3]:
# wandb.init(
#     tags=["eda"],
#     project="nerf4ssh",
#     entity="ige",
#     dir=local,
# )

# smoke_test = False
# wandb.config.update({"smoke_test": smoke_test})

## Train Data

In [4]:
model = "ssh"

# Open the training file and order its records chronologically in one go.
img = xr.open_dataset(
    "/mnt/meom/workdir/johnsonj/data/ssh_data/interim_gulf/train.nc"
).sortby("time")

# Show the dataset summary as the cell output.
img

In [5]:
# Length of the leading dimension of the training variable.
n_obs = img.sla_unfiltered.shape[0]
print(f"# Datapoints: {n_obs:_}")

# Datapoints: 1_338_530


#### Rolling Mean

In [8]:
# Flatten the training dataset into a table ...
df = img.to_dataframe()
# ... and turn the index levels into ordinary columns.
df = df.reset_index()
df
# Note: `.to_xarray()` on the frame would convert it back to an xarray object.

Unnamed: 0,time,longitude,latitude,sla_unfiltered
0,2015-06-19 21:25:34.252000000,305.982123,35.871405,-0.024
1,2015-06-19 21:25:35.276000000,305.963770,35.927864,-0.089
2,2015-06-19 21:25:36.300000000,305.945397,35.984320,-0.127
3,2015-06-19 21:25:37.324000000,305.927004,36.040774,-0.179
4,2015-06-19 21:25:38.348000000,305.908590,36.097226,-0.244
...,...,...,...,...
1338525,2016-12-20 17:09:33.261716736,305.866657,37.530990,0.192
1338526,2016-12-20 17:09:34.340296704,305.899122,37.481461,0.207
1338527,2016-12-20 17:09:35.418876672,305.931539,37.431921,0.200
1338528,2016-12-20 17:09:36.497456640,305.963907,37.382369,0.193


In [12]:
# Row count of the flattened training table.
n_rows = len(df)
print(f"# Datapoints: {n_rows:_}")

# Datapoints: 1_338_530


In [18]:
# %%time

# ds = df.set_index(["longitude", "latitude", "time"]).to_xarray()

## Model

In [6]:
model = "ssh"

# Open the model file and order it chronologically in one go.
ds = xr.open_dataset(
    "/mnt/meom/workdir/johnsonj/data/ssh_data/interim_gulf/model.nc"
).sortby("time")

# Show the dataset summary as the cell output.
ds

In [None]:
# Animate the SLA field, stepping through the "time" dimension frame by frame.
create_movie(
    ds["sla"],
    framedim="time",
    name="ssh_gulf_duacs",
)

In [None]:
![](plots/movie_ssh_gulf_duacs.gif)

In [17]:
# This cell sits above the "Gradients" section, where `sla_grad` is normally
# assigned. On a fresh kernel run top-to-bottom the variable may not exist
# yet, so derive it here when it is missing (no-op if it is already present).
if "sla_grad" not in ds:
    ds["sla_grad"] = calculate_gradient(ds.sla, "longitude", "latitude")

# Animate the gradient field along the "time" dimension.
create_movie(ds.sla_grad, framedim="time", name="ssh_gulf_duacs_grad")



  0%|          | 0/1826 [00:00<?, ?it/s]

Movie created at movie_ssh_gulf_duacs_grad.mp4
GIF created at plots/movie_ssh_gulf_duacs_grad.gif


![](plots/movie_ssh_gulf_duacs_grad.gif)

In [7]:
# Length of the leading dimension of the model SLA variable.
n_steps = ds.sla.shape[0]
print(f"# Datapoints: {n_steps:_}")

# Datapoints: 1_826


In [9]:
# Interactive image of the SLA field on the longitude/latitude plane.
ds["sla"].hvplot.image(
    x="longitude",
    y="latitude",
    width=500,
    height=400,
    cmap="RdBu_r",
    title="DUACs Model",
)

### Gradients

In [11]:
%%time
ds["sla_grad"] = calculate_gradient(ds.sla, "longitude", "latitude")

CPU times: user 104 ms, sys: 28 ms, total: 132 ms
Wall time: 150 ms


In [None]:
# Animate the gradient field, stepping through the "time" dimension.
create_movie(
    ds["sla_grad"],
    framedim="time",
    name="ssh_gulf_duacs_grad",
)

In [13]:
# Interactive image of the gradient field on the longitude/latitude plane.
ds["sla_grad"].hvplot.image(
    x="longitude",
    y="latitude",
    width=500,
    height=400,
    cmap="RdBu_r",
    title="DUACs Model",
)

### Laplacian

In [14]:
%%time
ds["sla_lap"] = calculate_laplacian(ds.sla, "longitude", "latitude")

CPU times: user 156 ms, sys: 32 ms, total: 188 ms
Wall time: 227 ms


In [None]:
# Animate the Laplacian field, stepping through the "time" dimension.
create_movie(
    ds["sla_lap"],
    framedim="time",
    name="ssh_gulf_duacs_lap",
)

In [15]:
# Interactive image of the Laplacian field on the longitude/latitude plane.
ds["sla_lap"].hvplot.image(
    x="longitude",
    y="latitude",
    width=500,
    height=400,
    cmap="RdBu_r",
    title="DUACs Model",
)