# Sparse Sheaf Signal Processing

In [12]:
import numpy as np
import matplotlib.pyplot as plt
import networkx as nx
from sklearn.neighbors import NearestNeighbors
from vdm import VDM
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils import data as Data
import torch.nn.functional as F
import cvxpy as cp
from sklearn.linear_model import OrthogonalMatchingPursuit
from wavelet import Wavelet
import xarray as xr

# One seed shared by every stochastic component so that
# "Restart Kernel -> Run All" reproduces identical results.
SEED: int = 6111983
torch.manual_seed(seed=SEED)  # PyTorch RNG (documented to seed all devices)
np.random.seed(seed=SEED)  # legacy global NumPy RNG; returns None, so the cell shows no output

## Dataset

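The input field is the ERA5 single-level 10 m wind (components `u10` and `v10`) on the global 0.25° grid, for a single snapshot at 2026-02-01 00:00. It was downloaded from the Copernicus Climate Data Store with the `cdsapi` client:

```python
import cdsapi

dataset = "reanalysis-era5-single-levels"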
request = {
    "product_type": ["reanalysis"],
    "variable": [
        "10m_u_component_of_wind",
        "10m_v_component_of_wind"
    ],
    "year": ["2026"],
    "month": ["02"],
    "day": ["01"],
    "time": ["00:00"],
    "data_format": "grib",
    "download_format": "zip"
}

client = cdsapi.Client()
client.retrieve(dataset, request).download()
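```

Because the request sets `"download_format": "zip"`, the result is a zip archive. Extract it and save the GRIB file as `data.grib` next to this notebook; that is the path read in the next cell.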

In [4]:
# Load the ERA5 snapshot downloaded above. The "cfgrib" engine requires the
# cfgrib package (and its ecCodes dependency) to be installed.
ERA5_GRIB_PATH = "data.grib"
data = xr.open_dataset(ERA5_GRIB_PATH, engine="cfgrib")

# Fail fast if the download did not contain both requested wind components.
assert {"u10", "v10"} <= set(data.data_vars), (
    f"{ERA5_GRIB_PATH} must contain u10 and v10, found {sorted(data.data_vars)}"
)

# Bare trailing expression -> xarray's rich (HTML) repr instead of print().
data

<xarray.Dataset> Size: 8MB
Dimensions:     (latitude: 721, longitude: 1440)
Coordinates:
  * latitude    (latitude) float64 6kB 90.0 89.75 89.5 ... -89.5 -89.75 -90.0
  * longitude   (longitude) float64 12kB 0.0 0.25 0.5 0.75 ... 359.2 359.5 359.8
    number      int64 8B ...
    time        datetime64[ns] 8B ...
    step        timedelta64[ns] 8B ...
    surface     float64 8B ...
    valid_time  datetime64[ns] 8B ...
Data variables:
    u10         (latitude, longitude) float32 4MB ...
    v10         (latitude, longitude) float32 4MB ...
Attributes:
    GRIB_edition:            1
    GRIB_centre:             ecmf
    GRIB_centreDescription:  European Centre for Medium-Range Weather Forecasts
    GRIB_subCentre:          0
    Conventions:             CF-1.7
    institution:             European Centre for Medium-Range Weather Forecasts
    history:                 2026-02-16T19:52 GRIB to CDM+CF via cfgrib-0.9.1...


In [5]:
# The grid coordinates shown above are in degrees: latitude runs 90 -> -90
# (descending) and longitude 0 -> 359.75. Convert both axes to radians for
# the computations that follow; the original ordering is kept.
lat_rad, lon_rad = (
    np.deg2rad(data[coord_name].values) for coord_name in ("latitude", "longitude")
)