# Data Assimilation Windows

In [21]:
import autoroot
from pathlib import Path
import numpyro
import numpyro.distributions as dist
from jax import config
config.update("jax_enable_x64", True)
import einx
import jax
import jax.numpy as jnp
import numpy as np
import jax.random as jr
import xarray as xr
from jaxtyping import Float, Array
import cola
from oi_toolz._src.ops.kernels import kernel_rbf, gram
from oi_toolz._src.ops.linalg import create_psd_matrix
from oi_toolz._src.ops.varda import (
    linear_3dvar_model_space,
    linear_3dvar_model_space_incremental,
    linear_3dvar_obs_space_incremental
)
from cola.linalg import Auto
from oi_toolz._src.ops.enskf import analysis_etkf
from sklearn.datasets import make_spd_matrix, make_sparse_spd_matrix
from xrpatcher import XRDAPatcher

# JAX PRNG key built from a fixed seed (123)
key = jr.key(123)

import matplotlib.pyplot as plt
import seaborn as sns
# reset seaborn to its default styling, then apply the "talk" context
# with a reduced font scale
sns.reset_defaults()
sns.set_context(context="talk", font_scale=0.75)
import collections

%matplotlib inline

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Data

In [22]:
# load the "air_temperature" dataset via xarray's tutorial loader
ds = xr.tutorial.load_dataset("air_temperature")

In [25]:
ds

In [31]:
# --- windowing ---
# The namedtuple's fields list the dataset variables that make up one item.
TrainingItem = collections.namedtuple("TrainingItem", ["air"])

# Select those variables, stack them along a "variable" dimension and fix
# the dimension order to (time, lon, lat, variable).
train_ds = (
    ds[list(TrainingItem._fields)]
    .to_array()
    .transpose("time", "lon", "lat", "variable")
)

# Patch (window) size per dimension; the strides reuse the same mapping, so
# each window starts exactly one patch-length after the previous one.
patches = {"time": 5, "lat": 20, "lon": 20}
strides = patches
patcher = XRDAPatcher(da=train_ds, patches=patches, strides=strides)

In [34]:
patcher

XArray Patcher
DataArray Size: OrderedDict([('time', 2920), ('lon', 53), ('lat', 25), ('variable', 1)])
Patches:        OrderedDict([('time', 5), ('lon', 20), ('lat', 20), ('variable', 1)])
Strides:        OrderedDict([('time', 5), ('lon', 20), ('lat', 20), ('variable', 1)])
Num Items:    OrderedDict([('time', 584), ('lon', 2), ('lat', 1), ('variable', 1)])

In [35]:
patcher[0]

In [75]:
# TODO: rescale time (days, hours)
# TODO: add time features (year, month, week, season, hour, minute)
# TODO: rescale space (lat, lon)

# get variable names from the "variable" coordinate of the first patch
names = patcher[0].coords["variable"].values

# create dataframe from the first patch
# NOTE(review): `*names` unpacks the variable names as positional arguments;
# with the single variable used here ("air") that passes just one name, which
# becomes the value column. With several variables the extra names would be
# passed as further positional arguments -- confirm before adding variables.
df = patcher[0].to_dataframe(*names)

# TODO: rescale lat-lon


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,air
time,lon,lat,variable,Unnamed: 4_level_1
2013-01-01,200.0,75.0,air,241.199997
2013-01-01,200.0,72.5,air,243.799988
2013-01-01,200.0,70.0,air,250.000000
2013-01-01,200.0,67.5,air,266.500000
2013-01-01,200.0,65.0,air,274.500000
...,...,...,...,...
2013-01-02,247.5,37.5,air,267.199982
2013-01-02,247.5,35.0,air,273.790009
2013-01-02,247.5,32.5,air,282.290009
2013-01-02,247.5,30.0,air,288.790009


In [None]:
def rescale_lat_lon(ds):
    """Placeholder for latitude/longitude rescaling.

    No rescaling is implemented yet: the input object is handed back
    unchanged (the very same object, not a copy).

    Args:
        ds: the object to rescale.

    Returns:
        ``ds`` itself, unmodified.
    """
    return ds

### Coordinates

**Time**

* Rescale Datetime -> Days, Minutes, Seconds
* Add Features -> Splines, Fourier, etc

**Space**

* Coordinate Transform
* Spatial Features -> Splines

In [17]:
from sklearn.kernel_approximation import RBFSampler
from sklearn.preprocessing import SplineTransformer, StandardScaler

In [18]:
ds["lon"].values

array([200. , 202.5, 205. , 207.5, 210. , 212.5, 215. , 217.5, 220. ,
       222.5, 225. , 227.5, 230. , 232.5, 235. , 237.5, 240. , 242.5,
       245. , 247.5, 250. , 252.5, 255. , 257.5, 260. , 262.5, 265. ,
       267.5, 270. , 272.5, 275. , 277.5, 280. , 282.5, 285. , 287.5,
       290. , 292.5, 295. , 297.5, 300. , 302.5, 305. , 307.5, 310. ,
       312.5, 315. , 317.5, 320. , 322.5, 325. , 327.5, 330. ],
      dtype=float32)

In [19]:
def xrgrid_to_coords(ds):
    """Expand the 1-D lon/lat coordinate vectors of ``ds`` into 2-D grids.

    Args:
        ds: mapping-like object (e.g. an xarray Dataset) whose ``"lon"`` and
            ``"lat"`` entries expose their coordinate vector as ``.values``.

    Returns:
        Tuple ``(X, Y)`` of arrays, both of shape ``(n_lon, n_lat)``:
        ``X[i, j]`` is the i-th longitude and ``Y[i, j]`` the j-th latitude.
    """
    # extract lat lon coordinates
    x, y = ds["lon"].values, ds["lat"].values

    # create meshgrid from the extracted vectors (the original passed the
    # not-yet-assigned X, Y here, which raised UnboundLocalError);
    # indexing="ij" keeps the (lon, lat) axis order of the inputs
    X, Y = np.meshgrid(x, y, indexing="ij")
    return X, Y

def xrcoords_to_grid(x, y):
    """Unfinished stub; the original definition had no body.

    Presumably intended as the counterpart of ``xrgrid_to_coords`` (turning
    coordinates back into a grid) -- TODO confirm and implement.

    Args:
        x: first coordinate input (currently unused).
        y: second coordinate input (currently unused).

    Raises:
        NotImplementedError: always, until the function is implemented.
    """
    # An explicit failure keeps the cell parseable while making the missing
    # implementation obvious to any caller.
    raise NotImplementedError("xrcoords_to_grid is not implemented yet")

In [None]:
ds["x"

In [15]:
# Tabulate the lon/lat selection of the dataset as a dataframe, with the
# index moved into ordinary columns by reset_index().
x_coords = (
    ds[["lon", "lat"]]
    .to_dataframe()
    .reset_index()
)
x_coords

Unnamed: 0,lon,lat
0,200.0,75.0
1,200.0,72.5
2,200.0,70.0
3,200.0,67.5
4,200.0,65.0
...,...,...
1320,330.0,25.0
1321,330.0,22.5
1322,330.0,20.0
1323,330.0,17.5


In [None]:
rbf_feature = RBFSampler(gamma=1, random_state=1)

X_features = rbf_feature.fit_transform([ds.lat.values