# Data prep

ERA5 reanalysis data from the Microsoft Planetary Computer `era5-pds` collection:
https://planetarycomputer.microsoft.com/dataset/era5-pds#Example-Notebook

## Imports

In [None]:
import numpy as np
import planetary_computer as pc
import pystac_client
import xarray as xr
import fsspec
from ndpyramid import pyramid_reproject

## Load datasets

In [None]:
# Query the Planetary Computer STAC API for the ERA5 items on a single date.
# Swap in the commented-out alternatives to filter by kind or widen to a year.
catalog = pystac_client.Client.open("https://planetarycomputer.microsoft.com/api/stac/v1/")
search = catalog.search(
    datetime="2020-01-01",
    # datetime="2020",
    collections=["era5-pds"],
    # query={"era5:kind": {"eq": "an"}},
)
items = search.item_collection()
print(len(items))  # how many STAC items matched

In [None]:
# Sign each item with planetary_computer, then open every one of its assets
# as a separate dataset, using the xarray open kwargs published in the
# asset's STAC metadata.
datasets = [
    xr.open_dataset(asset.href, **asset.extra_fields["xarray:open_kwargs"])
    for signed_item in map(pc.sign, items)
    for asset in signed_item.assets.values()
]

## Combine into one dataset and normalise coordinates

In [None]:
# Merge the per-asset datasets into one (their coordinates must match
# exactly) and tag it as EPSG:4326. Shift longitudes into [-180, 180), use
# x/y as the spatial names, and sort so both axes ascend.
# NOTE(review): the `.rio` accessor is provided by rioxarray, which this
# notebook never imports explicitly. Confirm it has been imported (e.g. as a
# side effect of another import) before this cell runs.
ds = (
    xr.combine_by_coords(datasets, join="exact")
    .rio.write_crs("EPSG:4326")
    .assign_coords(lon=lambda d: ((d["lon"] + 180) % 360) - 180)
    .rename({"lon": "x", "lat": "y"})
    .sortby(["x", "y"])
)

In [None]:
# Uncomment to restrict the data to a geographic region
# (y = latitude, x = longitude, after the renaming above).
# ds = ds.sel(y=slice(45, 60)).sel(x=slice(-12, 5))

In [None]:
# Uncomment to keep only the first time stamp.
# NOTE(review): an integer isel drops the `time` dimension, so the
# `resample(time="1D")` cell below would no longer work on the result.
# ds = ds.isel(time=0)

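## Create arrays of weather "ok"ness

A day counts as "ok" when all three of these hold:
- its maximum 10 m wind speed is below 10 knots;
- its total precipitation is below 0.001 (1 mm if the source is in metres);
- its maximum 2 m air temperature is between 16 °C and 27 °C.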

In [None]:
# Wind speed at 10 m from its eastward (u) and northward (v) components.
# 1.94384 is the m/s -> knots factor (assumes the source winds are in m/s).
ds["wind"] = np.sqrt(
    ds["eastward_wind_at_10_metres"] ** 2 + ds["northward_wind_at_10_metres"] ** 2
) * 1.94384

In [None]:
# Aggregate the sub-daily fields to one value per day: the strongest wind,
# the summed precipitation, and the warmest 2 m air temperature.
dsr = ds["wind"].resample(time="1D").max(dim="time").to_dataset()
dsr["rain"] = ds["precipitation_amount_1hour_Accumulation"].resample(time="1D").sum(dim="time")
dsr["temp"] = ds["air_temperature_at_2_metres"].resample(time="1D").max(dim="time")

# Boolean per-day flags. The temperature bounds are Celsius + 273.15,
# i.e. they are compared in Kelvin.
dsr["wind_ok"] = dsr["wind"] < 10
dsr["rain_ok"] = dsr["rain"] < (1 / 1000)
dsr["temp_ok"] = (dsr["temp"] > (16 + 273.15)) & (dsr["temp"] < (27 + 273.15))
dsr["all_ok"] = dsr["wind_ok"] & dsr["rain_ok"] & dsr["temp_ok"]

In [None]:
# Per pixel, count the days on which each condition held, as float32.
# Select the boolean flags *before* reducing: summing the whole dataset
# would also reduce the continuous wind/rain/temp fields, which are then
# discarded anyway.
res = (
    dsr[
        [
            "wind_ok",
            "rain_ok",
            "temp_ok",
            "all_ok",
        ]
    ]
    .sum(dim="time")
    .astype("float32")
)

## Create pyramids

In [None]:
# Reproject the per-pixel day counts into a `levels`-deep multiscale pyramid.
# Nearest-neighbour resampling copies source values instead of blending them.
levels = 5
pyr = pyramid_reproject(res, levels=levels, resampling="nearest")

In [None]:
# Rechunk the pyramid into 128 x 128 tiles along x/y before writing.
# If the data carries a per-year dimension, also add "year": 1 here.
pyr = pyr.chunk({"y": 128, "x": 128})

## Save pyramids to local/S3

In [None]:
# Write the pyramid with consolidated metadata, replacing any existing store.
output = "s3://BUCKET/viz/name.zarr"
pyr.to_zarr(output, mode="w", consolidated=True)

# Scratch
Everything below is exploratory scratch work for testing ideas; it is not part of the pipeline above.

## Save raw as Zarr

In [None]:
# Write the un-pyramided counts to S3 as a plain Zarr store. `chunk()` with
# no arguments turns non-dask variables into single-block dask arrays.
# No `mode` is passed, so xarray's default write mode applies.
path = "s3://BUCKET/raw/nums_chunked.zarr"
res_chunked = res.chunk()
res_chunked.to_zarr(fsspec.get_mapper(path))

## Load Zarr

In [None]:
# Re-open a consolidated Zarr store from S3 for inspection.
# NOTE(review): the pipeline above writes "viz/name.zarr", not "viz/nums.zarr".
# Confirm which store is meant here.
path = "s3://BUCKET/viz/nums.zarr"
check = xr.open_zarr(
    fsspec.get_mapper(path),
    consolidated=True,
)

## Save as GeoTIFF

In [None]:
import rasterio
from rasterio.transform import from_bounds

In [None]:
def save_raster(ds, path, x_dim="lon", y_dim="lat"):
    """Write a 2-D DataArray to a single-band int32 GeoTIFF in EPSG:4326.

    The coordinate values are treated as pixel centres on a regular grid.
    The GeoTIFF bounds are therefore padded by half a cell on every side.

    Parameters
    ----------
    ds : xarray.DataArray
        Two-dimensional array with dimensions ``y_dim`` and ``x_dim``.
        Values are cast to int32 before writing.
    path : str
        Destination path of the GeoTIFF.
    x_dim, y_dim : str, optional
        Names of the longitude / latitude coordinates. They default to
        ``"lon"`` / ``"lat"``. Pass ``"x"`` / ``"y"`` for arrays coming from
        the pipeline above, which renames the coordinates.

    Raises
    ------
    ValueError
        If either spatial dimension has fewer than two cells, so the grid
        spacing cannot be inferred.
    """
    xs = np.asarray(ds[x_dim].values, dtype="float64")
    ys = np.asarray(ds[y_dim].values, dtype="float64")
    width, height = xs.size, ys.size
    if width < 2 or height < 2:
        raise ValueError(
            f"need at least 2 cells along {x_dim!r} and {y_dim!r} to infer "
            f"the grid spacing, got {width} x {height}"
        )

    # Coordinates mark pixel centres. from_bounds expects the outer edges,
    # so pad by half a cell; otherwise the raster is shifted by half a pixel
    # and its pixel size is too small.
    half_dx = (xs.max() - xs.min()) / (width - 1) / 2
    half_dy = (ys.max() - ys.min()) / (height - 1) / 2
    transform = from_bounds(
        float(xs.min() - half_dx),
        float(ys.min() - half_dy),
        float(xs.max() + half_dx),
        float(ys.max() + half_dy),
        width,
        height,
    )

    dtype = "int32"
    # Put the array in (row=y, col=x) order. The transform's origin is the
    # north-west corner, so flip rows only when latitude is stored
    # south-to-north.
    data = ds.transpose(y_dim, x_dim).astype(dtype).values
    if ys[0] < ys[-1]:
        data = np.flipud(data)

    with rasterio.open(
        path,
        "w",
        driver="GTiff",
        height=height,
        width=width,
        count=1,
        dtype=dtype,
        crs="EPSG:4326",
        transform=transform,
    ) as dst:
        dst.write(data, 1)

In [None]:
# Export the per-pixel count of wind-"ok" days as a GeoTIFF.
# `num_wind` is not defined anywhere in this notebook (NameError). The
# matching quantity is `res["wind_ok"]`. save_raster reads `lat`/`lon`
# coordinates, so undo the x/y renaming done by the pipeline.
save_raster(res["wind_ok"].rename({"x": "lon", "y": "lat"}), "wind.tif")

## Check saved TIFF

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Read back the GeoTIFF written by the export cell above and plot it as a
# quick visual check. "wind.tif" is the only GeoTIFF this notebook writes;
# the previous "rain.tif" is never produced here.
tif_path = "wind.tif"
with rasterio.open(tif_path) as dataset:
    data = dataset.read(1)  # band 1; save_raster writes single-band files
# Plot after the file is closed; `data` is already an in-memory array.
plt.imshow(data, cmap="viridis")
plt.colorbar(label="Value")
plt.title("GeoTIFF Data Visualization")
plt.show()

## Merge raw Zarrs

In [None]:
# Stack the per-year raw Zarr stores along a new leading `year` dimension,
# with one coordinate value per source year.
# Built with a comprehension so this scratch cell no longer rebinds the
# notebook-level `ds`, `datasets` and `path`. The main pipeline above relies
# on `ds` and `datasets`, and re-running its cells after this one would
# otherwise silently operate on the last yearly store.
combined_ds = xr.concat(
    [
        xr.open_zarr(f"./s3data/raw/{year}.zarr").expand_dims(year=[year])
        for year in range(2015, 2021)
    ],
    dim="year",
)

In [None]:
# Optional: append a multi-year mean as an extra entry along `year`.
# mean_ds = combined_ds.mean(dim="year")
# mean_ds = mean_ds.assign_coords(year="mean")  # Set the year coordinate to "mean"

# Concatenate the mean dataset along the `year` dimension
# NOTE(review): `year` otherwise holds integers, so adding the string "mean"
# likely yields an object-dtype coordinate. Confirm that is intended.
# ds_mean = xr.concat([combined_ds, mean_ds], dim="year")

In [None]:
# Uncomment to open a local Zarr store as an xarray DataTree
# (the path suggests a previously written pyramid).
# dt = xr.open_datatree("./s3data/pyr/era5_2020_num_l6.zarr/", engine="zarr")