In [None]:
import numpy as np
import xarray as xr
import xrft
import einops
from itertools import product

import matplotlib.pyplot as plt

%matplotlib inline

In [None]:
# Fetch each zipped dataset from Google Drive into ./data/ and unzip it in place.
from google_drive_downloader import GoogleDriveDownloader as gdd

# Drive file IDs, paired index-by-index with the archive names in `paths`.
file_ids = [
    "1KVDlXdaqPCi_gomHn6QK76NYmF7RONT5",  # aviso.zip
    "1MRUr0wAhVFIdrFjUg0YcDjaKaABKs09o",  # chla.zip
    "1gw7BG7wlZfU_BGNDf3QH1UARWUjwsIjz",  # sst.zip
    "1Cwl0gLUHFTwkvdSEm1KZiAsUU7jZyrAN",  # winds.zip
]
paths = ["aviso.zip", "chla.zip", "sst.zip", "winds.zip"]

for drive_id, archive in zip(file_ids, paths):
    gdd.download_file_from_google_drive(
        file_id=drive_id,
        dest_path=f"./data/{archive}",
        showsize=True,
        unzip=True,
    )

In [None]:
# n_
# image = np.random.randn(

# offsets = product(range(0, n_columns, width), range(0, n_rows, height))

## Sea Surface Temperature (SST)

In [None]:
# Load the xarray tutorial dataset "ersstv5" and keep its SST field as `da`.
ds = xr.tutorial.open_dataset("ersstv5")
da = ds["sst"]

### Sea Surface Height (SSH)

In [None]:
# Load the OSE baseline SSH map and keep its `ssh` variable as `da`.
# The machine-specific default path can be overridden with the SSH_MAPPING_PATH
# environment variable instead of editing this cell.
import os

url = os.environ.get(
    "SSH_MAPPING_PATH",
    "/Users/eman/.CMVolumes/cal1_data/dc_2021/results/OSE_ssh_mapping_BASELINE.nc",
)
ds = xr.open_dataset(url)
da = ds.ssh

In [None]:
# Inspect `da` (the `ssh` variable of the OSE baseline map when run top to bottom).
da

In [None]:
# ds = xr.tutorial.open_dataset("air_temperature")
# da = ds.air

### FFT (Flatten)

In [None]:
import scipy

In [None]:
# Welch power spectral density of the whole field flattened to one 1-D series,
# treating consecutive samples as if they were `delta_x` km apart.
# NOTE(review): flattening mixes the (time, lat, lon) dimensions, so this is only a
# rough diagnostic; any NaNs in `da` will also propagate into the whole PSD.
import scipy.signal  # on older SciPy, plain `import scipy` does not load `scipy.signal`

# Sampling parameters (C2 parameter)
delta_t = 0.9434  # s
velocity = 6.77  # km/s
delta_x = velocity * delta_t  # km between consecutive samples
length_scale = 10000  # segment length scale in km

npt = int(length_scale / delta_x)  # samples per Welch segment

# fs is in samples/km, so `wavenumber` comes out in cycles/km.
wavenumber, psd_signal = scipy.signal.welch(
    da.data.flatten(), fs=1.0 / delta_x, nperseg=npt, scaling="density", noverlap=0
)

# welch's first frequency bin is 0, so its inverse is inf; that is expected,
# so silence the divide-by-zero warning rather than letting it leak into the output.
with np.errstate(divide="ignore"):
    wavenumber_inv = 1.0 / wavenumber
wavenumber_inv

In [None]:
# Welch PSD from the flattened field, plotted against wavelength.
fig, ax = plt.subplots()

# Skip index 0: it is the zero wavenumber, whose wavelength is infinite.
ax.plot(1.0 / wavenumber[1:], psd_signal[1:], color="black")

ax.set(xlabel="Wavelength (km)", ylabel="PSD", xscale="log", yscale="log")

# Invert the wavelength (x) axis, not the PSD (y) axis: long wavelengths on the
# left, power increasing upward, as in the other spectrum plots of this notebook.
ax.invert_xaxis()

ax.grid(which="both", alpha=0.5)
plt.tight_layout()
plt.show()

### FFT (Average)

In [None]:
# One dask chunk per latitude row; time and longitude stay whole (-1 = full dimension).
signal = da.chunk({"lat": 1, "time": -1, "lon": -1})

# Express time as days since the first snapshot so freq_time is in cycles/day.
signal["time"] = (signal["time"] - signal["time"][0]) / np.timedelta64(1, "D")

# Joint (time, lon) power spectrum for every latitude row, detrended and windowed.
signal_psd = xrft.power_spectrum(
    signal,
    dim=["time", "lon"],
    window=True,
    detrend="linear",
).compute()

# Average over latitude and keep only strictly positive frequencies on both axes.
signal_psd_mean = (
    signal_psd
    .mean(dim="lat")
    .where((signal_psd["freq_time"] > 0) & (signal_psd["freq_lon"] > 0), drop=True)
)

In [None]:
# Inspect the latitude-averaged spectrum over freq_time and freq_lon (positive frequencies only).
signal_psd_mean

In [None]:
# Latitude-averaged wavenumber–frequency PSD, shown as spatial wavelength vs. period.
from matplotlib import ticker

fig, ax = plt.subplots()

# 1/freq_lon * 111 converts cycles/degree into km (~111 km per degree, exact only
# at the equator for longitude; assumes lon is in degrees).
pts = ax.contourf(
    1.0 / signal_psd_mean["freq_lon"] * 111,
    1.0 / signal_psd_mean["freq_time"],
    signal_psd_mean,
    # Raw PSD spans orders of magnitude and is not bounded to [0, 1] like a PSD
    # score, so let matplotlib pick log-spaced levels from the data.
    locator=ticker.LogLocator(),
    cmap="viridis",
)

cbar = fig.colorbar(pts, ax=ax, pad=0.01)
cbar.set_label("PSD")

ax.set(
    xlabel="Spatial Wavelength (km)",
    ylabel="Temporal Period (Days)",
    title="PSD",
)

ax.grid(linestyle="--", lw=1, color="white")
plt.tight_layout()
plt.show()

In [None]:
# Collapse the temporal-frequency axis: mean PSD as a function of freq_lon only.
signal_psd_mean_lon = signal_psd_mean.mean(dim="freq_time")

In [None]:
# 1-D longitude spectrum (averaged over latitude and positive temporal frequencies).
fig, ax = plt.subplots()

ax.plot(
    # ~111 km per degree turns cycles/degree into a wavelength in km (assumes lon in degrees).
    1.0 / signal_psd_mean_lon.freq_lon * 111,
    signal_psd_mean_lon,
    label="Signal",
    color="black",
)

ax.set(
    yscale="log",
    xscale="log",
    xlabel="Wavelength (km)",
    # Only the x axis was converted to km; the PSD values are still a density in
    # (cycles/deg, cycles/day) because time was rescaled to days before the FFT.
    ylabel="PSD (m$^{2}$/(cycles/deg)/(cycles/day))",
    xlim=(1e0, 1e4),
)
ax.invert_xaxis()
ax.legend()
ax.grid(which="both", alpha=0.5)

plt.tight_layout()
plt.show()

### FFT (Spatial-Temporal)

### Isotropic

In [None]:
# Reload the OSE baseline SSH map. The machine-specific default path can be
# overridden with the SSH_MAPPING_PATH environment variable.
import os

url = os.environ.get(
    "SSH_MAPPING_PATH",
    "/Users/eman/.CMVolumes/cal1_data/dc_2021/results/OSE_ssh_mapping_BASELINE.nc",
)
ds = xr.open_dataset(url)
signal = ds.ssh

# Express time as days since the first snapshot.
signal["time"] = (signal.time - signal.time[0]) / np.timedelta64(1, "D")

# Isotropic spatial power spectrum over (lat, lon), one per time step.
# NOTE(review): no `window` argument is passed here (xrft's default is used).
signal_psd = xrft.isotropic_power_spectrum(signal, dim=["lat", "lon"], detrend="linear")
# Also expose the radial frequency under the name "wavenumber".
signal_psd.coords["wavenumber"] = ("freq_r", signal_psd["freq_r"].data)

# Average the spatial spectra over time.
signal_psd_mean = signal_psd.mean(dim=["time"]).compute()

In [None]:
# Inspect the time-mean isotropic spectrum (indexed by freq_r when run top to bottom).
signal_psd_mean

In [None]:
# Time-mean isotropic spectrum plotted against spatial wavelength.
fig, ax = plt.subplots()

# Keep strictly positive radial frequencies so the wavelength 1/freq_r stays finite.
psd_positive = signal_psd_mean.where(signal_psd_mean["freq_r"] > 0, drop=True)

# x = wavelength (km), y = PSD. The original had these swapped (1/PSD on x, freq_r on y).
# ~111 km per degree converts cycles/degree to km (assumes lat/lon in degrees).
ax.plot(
    1.0 / psd_positive["freq_r"] * 111,
    psd_positive,
    label="Signal",
    color="black",
)

ax.set(
    yscale="log",
    xscale="log",
    xlabel="Wavelength (km)",
    # NOTE(review): PSD values are per cycle/degree and were not rescaled to
    # per cycle/km; confirm units before relying on this label.
    ylabel="PSD (m$^{2}$/cycles/km)",
)
ax.invert_xaxis()
ax.legend()
ax.grid(which="both", alpha=0.5)

plt.tight_layout()
plt.show()

In [None]:
# Radial frequencies of the isotropic spectrum computed in the "Isotropic" section
# (`psd_signal_eNATL` is never defined in this notebook).
signal_psd["freq_r"]

In [None]:
# Isotropic spatial spectrum of a second ("L4") SSH product, averaged over time.
# NOTE(review): `L4` is not defined anywhere in this notebook; load it before running.
signal_L4 = L4["ssh"]
# Express time as days since the first snapshot (`numpy` is imported as `np`).
signal_L4["time"] = (signal_L4.time - signal_L4.time[0]) / np.timedelta64(1, "D")
psd_signal_L4 = xrft.isotropic_power_spectrum(
    signal_L4, dim=["lat", "lon"], detrend="linear", nfactor=4, truncate=True
)
psd_signal_L4.coords["wavenumber"] = ("freq_r", psd_signal_L4["freq_r"].data)
# `rename` returns a new object, so the result has to be assigned to take effect.
psd_signal_L4 = psd_signal_L4.rename("Spatial spectrum")
mean_psd_signal_L4 = psd_signal_L4.mean(dim="time").compute()

In [None]:
# Compare the longitude PSDs of three SSH products. Each curve is drawn only up to
# its minimum (presumably where it reaches its noise/resolution floor) and that
# minimum is annotated with its wavelength in km.
# NOTE(review): mean_psd_signal_L4_lon, mean_psd_signal_seNATL_lon and
# mean_psd_signal_eNATL_lon are not defined anywhere in this notebook; compute
# them before running this cell.


def _plot_until_min(ax, psd_lon, label, fmt=None):
    """Plot a 1-D PSD vs. wavelength up to its minimum and annotate that minimum.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes to draw on.
    psd_lon : xarray.DataArray
        1-D spectrum with a ``freq_lon`` coordinate (assumed cycles/degree).
    label : str
        Legend label for the curve.
    fmt : str, optional
        Matplotlib format string such as ``"k"`` or ``"--"``; default style if None.

    Returns
    -------
    int
        Index of the first minimum of ``psd_lon`` (NaNs are skipped).
    """
    i_min = int(psd_lon.argmin())
    # ~111 km per degree converts cycles/degree into a wavelength in km.
    wavelength_km = 1 / psd_lon.freq_lon * 111
    fmt_args = (fmt,) if fmt is not None else ()
    ax.plot(wavelength_km[:i_min], psd_lon[:i_min], *fmt_args, label=label)
    ax.text(
        float(wavelength_km[i_min]),
        float(psd_lon[i_min]),
        str(round(float(wavelength_km[i_min]), 2)),
    )
    return i_min


# Size this figure directly instead of mutating plt.rcParams, which would leak
# into every figure created afterwards.
fig, ax1 = plt.subplots(figsize=(20.50, 10.50))

mini1 = _plot_until_min(ax1, mean_psd_signal_L4_lon, "SSH_l4", "k")
mini2 = _plot_until_min(ax1, mean_psd_signal_seNATL_lon, "seNATL")
mini3 = _plot_until_min(ax1, mean_psd_signal_eNATL_lon, "eNATL", "--")

ax1.set_ylabel("PSD [$m^2$/cyc/km]")
ax1.set_xlabel("Wavelength [km]")
ax1.grid(which="both", alpha=0.5)
ax1.legend()
ax1.set_yscale("log")
ax1.set_xlim([0, 100])
ax1.invert_xaxis()
print(mini1, mini2, mini3)

plt.show()

In [None]:
# Re-display whichever `signal_psd_mean` was assigned last (the time-mean
# isotropic spectrum when the notebook is run top to bottom).
signal_psd_mean

### Fourier

In [None]:
# FFT along latitude of the first time step of `da`, with xrft's
# true_amplitude / true_phase corrections enabled.
Fda = xrft.fft(
    da.isel({"time": 0}),
    dim=["lat"],
    true_amplitude=True,
    true_phase=True,
)
Fda

## Sea Surface Height (SSH)

In [None]:
# from intake import open_catalog
# cat = open_catalog("https://raw.githubusercontent.com/pangeo-data/pangeo-datastore/master/intake-catalogs/ocean.yaml")
# ds  = cat["sea_surface_height"].to_dask()

In [None]:
# Open the local AVISO zarr store (presumably unpacked from data/aviso.zip above — verify).
ds = xr.open_zarr(store="data/aviso.zarr")
ds

In [None]:
# Redundant when run in order: re-displays the same `ds` as the previous cell.
ds

## Air Temperature

In [None]:
# Load the xarray tutorial dataset "air_temperature"; `da` now refers to its `air` field.
ds = xr.tutorial.open_dataset("air_temperature")
da = ds["air"]

In [None]:
# Inspect the air-temperature DataArray (`ds.air`) loaded in the previous cell.
da

In [None]:
# Latitude FFT of the first air-temperature snapshot
# (xrft true_amplitude / true_phase corrections enabled).
first_step = {"time": 0}
Fda = xrft.fft(
    da.isel(first_step),
    dim=["lat"],
    true_phase=True,
    true_amplitude=True,
)
Fda