# demo of the xarray interface to foscat

## set up example data

In [None]:
import healpy as hp
import numpy as np
import xarray as xr

In [None]:
# HEALPix resolution of the final map built in the cell below
nside = 64

In [None]:
from scipy.interpolate import RegularGridInterpolator

# Resample the input image onto a HEALPix map (nested ordering) at the
# intermediate resolution `l_nside`, then bring it to the target `nside`.
l_nside = 128

im = np.load("target_map_lss.npy")
xsize, ysize = im.shape

# Pad the image by one row and one column so that the interpolation grid spans
# [0, xsize] x [0, ysize], the range covered by the rescaled angles below.
new_row = im[0:1, :]  # copy of the first row, appended after the last one
# Copy of the first column, extended by one element to match the padded height.
# NOTE(review): the extra element is im[-2, 0] (second-to-last original row),
# not im[0, 0] as the appended row would suggest -- confirm this is intended.
new_column = np.concatenate([im[:, 0:1], im[-2:-1, 0:1]], 0)

im = np.vstack([im, new_row])
im = np.hstack([im, new_column])

# Grid coordinates are simply the array indices along each axis
x = np.linspace(0, im.shape[0] - 1, im.shape[0])
y = np.linspace(0, im.shape[1] - 1, im.shape[1])

interpolator = RegularGridInterpolator((x, y), im)

# Angular position of every HEALPix pixel, rescaled to image index units:
# colatitude in [0, pi] -> axis 0, longitude in [0, 2*pi) -> axis 1.
colatitude, longitude = hp.pix2ang(l_nside, np.arange(12 * l_nside**2), nest=True)
coords = np.stack(
    [colatitude / np.pi * xsize, longitude / (2 * np.pi) * ysize], axis=-1
)

# Perform the interpolation
heal_im = interpolator(coords)

# Bring the map to the expected resolution. The comparison uses `l_nside`
# rather than a literal so the branch stays correct if `l_nside` is changed.
if nside > l_nside:
    # finer target: interpolate the intermediate map at the target pixel centres
    th, ph = hp.pix2ang(nside, np.arange(12 * nside**2), nest=True)
    heal_im = hp.get_interp_val(heal_im, th, ph, nest=True)
else:
    # coarser (or equal) target: average each run of (l_nside // nside) ** 2
    # consecutive values; in nested ordering these are expected to be the
    # sub-pixels of one target pixel.
    heal_im = np.mean(heal_im.reshape(12 * nside**2, (l_nside // nside) ** 2), 1)
hp.mollview(heal_im, cmap="plasma", nest=True, title="LSS", min=-3, max=3)

# free memory
del coords, interpolator, colatitude, longitude

In [None]:
# Wrap the map in a dataset: one value per cell, with integer cell ids
# 0 .. heal_im.size - 1 attached as a coordinate along "cells".
target_lss = xr.Dataset(
    data_vars={"lss": ("cells", heal_im)},
    coords={"cell_ids": ("cells", np.arange(heal_im.size))},
)
target_lss

In [None]:
# No seed is given, so the generated noise differs from run to run
rng = np.random.default_rng()

In [None]:
n_timesteps = 5
# One Gaussian noise map (scale 0.1) per time step, on the same cells as the
# target map; time steps start at 2020-01-01 with frequency "MS".
random_noise = xr.DataArray(
    data=rng.normal(scale=0.1, size=(n_timesteps, heal_im.size)),
    dims=("time", "cells"),
    coords=dict(
        cell_ids=("cells", np.arange(heal_im.size)),
        time=xr.date_range("2020-01-01", freq="MS", periods=n_timesteps),
    ),
)
random_noise

In [None]:
# Add the per-time-step noise to the static map; `lss` only has the "cells"
# dimension, so it is broadcast against the "time" dimension of the noise.
target_lss2 = target_lss + random_noise
target_lss2

## compute reference statistics

In [None]:
import foscat.xarray as foscat

In [None]:
# Settings handed to the statistics call below via `parameters=`
params = foscat.Parameters(
    n_orientations=4,
    kernel_size=5,
    jmax_delta=0,
    dtype="float64",
    backend="numpy",
)

In [None]:
# Statistics of the noisy maps.
# NOTE(review): `variances=True` presumably also requests variance estimates
# -- confirm against the foscat documentation.
stats = foscat.reference_statistics(
    target_lss2["lss"],
    parameters=params,
    variances=True,
)
stats

In [None]:
# Plot the statistics through the `foscat` accessor of the result
stats.foscat.plot()

## compute cross statistics

In [None]:
# Same settings as used for the reference statistics, repeated so that this
# section can be run on its own.
params = foscat.Parameters(
    n_orientations=4,
    kernel_size=5,
    jmax_delta=0,
    dtype="float64",
    backend="numpy",
)

In [None]:
# Cross statistics between the noise-free map and the noisy maps.
# NOTE(review): `variances=True` presumably also requests variance estimates
# -- confirm against the foscat documentation.
stats = foscat.cross_statistics(
    target_lss["lss"],
    target_lss2["lss"],
    parameters=params,
    variances=True,
)
stats

In [None]:
# Plot the cross statistics through the `foscat` accessor of the result
stats.foscat.plot()