# Introduction to xarray: Why and How

## Why use xarray?

- **N-dimensional labeled arrays:** Unlike plain NumPy arrays, xarray attaches names to dimensions, coordinate labels to each axis, and arbitrary metadata (`attrs`) to the data.  
- **Easy handling of multi-dimensional scientific data:** Perfect for datasets like climate model outputs, satellite data, and geospatial grids.  
- **Powerful indexing and slicing:** Access data by coordinate labels instead of integer indices, making code more readable and less error-prone.  
- **Integration with other libraries:** Works well with pandas, NumPy, matplotlib, and Dask for parallel computing.  
- **Built-in support for NetCDF:** Read and write NetCDF, a common format for climate and oceanographic data, with `xr.open_dataset()` and `Dataset.to_netcdf()`.

---

## How does xarray work?

- The core data structure is the **`xarray.DataArray`**, which holds multi-dimensional data with dimension names and coordinates.  
- Multiple variables that share dimensions and coordinates are grouped in an **`xarray.Dataset`**, a dict-like container of DataArrays.  
- Coordinates provide meaningful labels for axes (e.g., time, latitude, longitude).  
- Arithmetic, group-by operations, resampling, and reductions are expressed in terms of dimension names, and operands are automatically aligned on their coordinate labels.
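
A minimal sketch of these ideas on a toy example. The station names, values, and the `units` attribute are made up for illustration; the calls themselves (`xr.DataArray`, `xr.Dataset`, `.sel`, `.mean`) are standard xarray API.

```python
import numpy as np
import pandas as pd
import xarray as xr

times = pd.date_range("2022-01-01", periods=2, freq="D")
stations = ["A", "B", "C"]  # hypothetical station labels

# A 2-D DataArray: values plus named dimensions, coordinate labels and metadata
temp = xr.DataArray(
    np.array([[15.0, 16.0, 17.0], [18.0, 19.0, 20.0]]),
    dims=("time", "station"),
    coords={"time": times, "station": stations},
    attrs={"units": "degC"},
)
humidity = xr.DataArray(
    np.array([[60.0, 65.0, 70.0], [55.0, 50.0, 45.0]]),
    dims=("time", "station"),
    coords={"time": times, "station": stations},
)

# A Dataset groups variables that share dimensions and coordinates
ds_demo = xr.Dataset({"temperature": temp, "humidity": humidity})

# Label-based selection and reduction over a named dimension
station_b = ds_demo["temperature"].sel(station="B")     # 1-D over time
station_mean = ds_demo["temperature"].mean(dim="time")  # 1-D over station

# Arithmetic aligns on coordinate labels, not on array position:
# offset lists stations in a different order and lacks "B"
offset = xr.DataArray([1.0, 2.0], dims="station", coords={"station": ["C", "A"]})
shifted = ds_demo["temperature"] + offset  # A gets +2, C gets +1, B is dropped
```

Station `B` drops out of `shifted` because the default join for arithmetic is an inner join on coordinate labels, so mismatched labels are never silently paired by position as they would be with plain NumPy arrays.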

---


In [None]:
# Install packages from requirements.txt (needed for this session)
! pip install -r ./../../requirements.txt

### 1) Setup: Create a Synthetic Dataset

This code block creates a synthetic xarray `Dataset` simulating daily temperature data over Cyprus for one year (2022). It builds:

- A spatial grid of latitudes and longitudes roughly covering Cyprus  
- A base temperature pattern varying with latitude and longitude  
- A seasonal cycle to simulate yearly temperature changes  
- Random noise for realism  
- The resulting 3D data array has dimensions `[time, latitude, longitude]`

This dataset provides a realistic basis to demonstrate xarray’s data handling features.
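
Because xarray broadcasts by dimension name, the same kind of field can also be assembled from 1-D labeled arrays without an explicit loop over days. The sketch below is an alternative construction, not the notebook's own code; the cosine seasonal phase (coldest around mid-January, warmest around mid-July) and the random seed are assumptions.

```python
import numpy as np
import pandas as pd
import xarray as xr

rng = np.random.default_rng(0)  # assumed seed, for reproducibility

times = pd.date_range("2022-01-01", "2022-12-31", freq="D")
lats = np.linspace(34, 36, 50)
lons = np.linspace(32, 35, 60)

# 1-D coordinate arrays, each carrying its own dimension name
lat = xr.DataArray(lats, coords={"latitude": lats}, dims="latitude")
lon = xr.DataArray(lons, coords={"longitude": lons}, dims="longitude")
doy = xr.DataArray(np.arange(len(times)), coords={"time": times}, dims="time")

# Spatial pattern: combining arrays with different dims broadcasts to (latitude, longitude)
base = 20 - (lat - 34) * 5 + (lon - 32) * 2

# Seasonal cycle over time only (assumed phase: minimum at index 15, i.e. mid-January)
seasonal = -10 * np.cos(2 * np.pi * (doy - 15) / len(times))

# base + seasonal broadcasts to 3-D; put time first to match the notebook's layout
temperature = (base + seasonal).transpose("time", "latitude", "longitude")

# Noise is a plain NumPy array, so it must match the DataArray's shape positionally
temperature = temperature + rng.normal(0, 0.5, size=temperature.shape)

ds_b = xr.Dataset({"temperature": temperature})
print(ds_b)
```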


In [None]:
import numpy as np
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt

# Define coords
times = pd.date_range('2022-01-01', '2022-12-31', freq='D')
lats = np.linspace(34, 36, 50)  # ~Cyprus latitude range
lons = np.linspace(32, 35, 60)  # ~Cyprus longitude range

# Create meshgrid for lat/lon to simulate spatial pattern
lon2d, lat2d = np.meshgrid(lons, lats)

# Base spatial pattern (noise is added later): cooler toward the north, warmer toward the east
base_temp = 20 - (lat2d - 34) * 5 + (lon2d - 32) * 2

# Seasonal cycle (no noise): cosine with its minimum in mid-January and maximum in mid-July
days_in_year = len(times)
seasonal_cycle = -10 * np.cos(2 * np.pi * (np.arange(days_in_year) - 15) / days_in_year)

# Combine spatial and temporal patterns into a 3D array (time x lat x lon) with Gaussian noise
rng = np.random.default_rng(42)  # seeded so results are reproducible
data = np.empty((days_in_year, len(lats), len(lons)))
for i, day_val in enumerate(seasonal_cycle):
    noise = rng.normal(0, 0.5, size=base_temp.shape)
    data[i, :, :] = base_temp + day_val + noise

# Create xarray Dataset
ds = xr.Dataset(
    {
        "temperature": (["time", "latitude", "longitude"], data)
    },
    coords={
        "time": times,
        "latitude": lats,
        "longitude": lons
    }
)

print(ds)


---


### 2) Slicing and Selecting

Here we demonstrate how to select subsets of data using labeled coordinates:

- Select temperature values for a single day (`2022-06-01`) using `.sel()`  
- Extract a smaller geographic subset by slicing latitude and longitude ranges

This highlights xarray’s powerful and readable indexing by dimension names and coordinate values instead of numeric indices.
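
A small sketch contrasting label-based `.sel()` with position-based `.isel()`, plus two behaviours worth knowing: label slices include both endpoints, and on a coordinate stored in descending order (common for latitude in many gridded products) the slice bounds must follow that order. The toy 5 x 4 grid is hypothetical.

```python
import numpy as np
import xarray as xr

# Hypothetical 5 x 4 grid; each value encodes its position as 4 * row + column
grid_lats = np.linspace(34, 36, 5)  # 34.0, 34.5, 35.0, 35.5, 36.0
grid_lons = np.linspace(32, 35, 4)  # 32.0, 33.0, 34.0, 35.0
da = xr.DataArray(
    np.arange(20.0).reshape(5, 4),
    coords={"latitude": grid_lats, "longitude": grid_lons},
    dims=("latitude", "longitude"),
)

# The same cell, selected by coordinate label and by integer position
by_label = da.sel(latitude=35.0, longitude=33.0)
by_position = da.isel(latitude=2, longitude=1)

# Labels that are not on the grid need method="nearest" (an exact lookup raises an error)
nearest = da.sel(latitude=35.1, longitude=33.2, method="nearest")

# Label slices include both endpoints, unlike Python integer slices
box = da.sel(latitude=slice(34.5, 35.5), longitude=slice(32.5, 34.0))

# With latitude stored north-to-south, slice bounds must follow the stored order
da_desc = da.sortby("latitude", ascending=False)
empty = da_desc.sel(latitude=slice(34.5, 35.5))     # wrong order: no rows selected
box_desc = da_desc.sel(latitude=slice(35.5, 34.5))  # stored order: 35.5, 35.0, 34.5
```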


In [None]:
# Select temperature on June 1, 2022
temp_june1 = ds.sel(time="2022-06-01")
print(temp_june1)

# Select data in latitude range 34.5 to 35.5 and longitude 32.5 to 34.0
subset = ds.sel(latitude=slice(34.5, 35.5), longitude=slice(32.5, 34.0))
print(subset)

---


### 3) Aggregation Over Space

This example computes the spatial average temperature for each day by taking the mean over the latitude and longitude dimensions.

Passing dimension names to `.mean()` (or `.sum()`, `.max()`, and similar reductions) collapses those dimensions without tracking axis numbers, leaving a 1-D series over `time`.
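
The plain `.mean()` gives every grid cell equal weight. On a regular latitude-longitude grid, the area of equally spaced cells is proportional to cos(latitude), so area-weighted means are often preferred. A minimal sketch with xarray's `.weighted()` on a deliberately extreme toy grid (the 0°N and 60°N rows and their values are hypothetical):

```python
import numpy as np
import xarray as xr

# Hypothetical toy grid: a warm row at the equator, a cold row at 60°N
temp = xr.DataArray(
    [[30.0, 30.0], [0.0, 0.0]],
    coords={"latitude": [0.0, 60.0], "longitude": [0.0, 10.0]},
    dims=("latitude", "longitude"),
)

# Unweighted: every cell counts equally, (30 + 30 + 0 + 0) / 4 = 15
plain_mean = temp.mean(dim=["latitude", "longitude"])

# Area weights: cos(latitude) is 1 at the equator and 0.5 at 60°N
weights = np.cos(np.deg2rad(temp.latitude))

# Weighted: (2 * 30 * 1) / (2 * 1 + 2 * 0.5), about 20
area_mean = temp.weighted(weights).mean(dim=["latitude", "longitude"])
```

For the Cyprus box the weights only range from about 0.83 (34°N) to 0.81 (36°N), so weighted and unweighted means will be very close there; the distinction matters for larger domains.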


In [None]:
# Compute daily mean temperature over whole spatial domain
daily_mean_temp = ds.temperature.mean(dim=["latitude", "longitude"])
print(daily_mean_temp)


---


### 4) Aggregation Over Time and Anomaly Computation

Here we calculate:

- A **daily climatology** by averaging temperature values for each calendar day (day of year) across the time dimension  
- An **anomaly** for a specific date (`2022-06-01`), which is the difference between the temperature that day and the climatological average for that day of year

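This demonstrates grouping, aggregation over time, and arithmetic operations with labeled data. Note that the synthetic dataset spans a single year, so every day-of-year group contains exactly one value: the "climatology" equals the data itself, and the June 1 anomaly is zero at every grid point. A meaningful climatology requires many years of data (a 30-year reference period is the usual standard).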
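
A climatology is only meaningful when each day-of-year group contains several years. A sketch with three years of hypothetical daily data at a single point (the seasonal shape, noise level, and seed are assumptions), using groupby arithmetic to subtract the matching day-of-year mean from every day in one step, and ending with the one-year case for comparison:

```python
import numpy as np
import pandas as pd
import xarray as xr

rng = np.random.default_rng(1)  # assumed seed

# Three years (2020 is a leap year) of synthetic daily values at one point
times = pd.date_range("2019-01-01", "2021-12-31", freq="D")
doy = times.dayofyear.values
temp = xr.DataArray(
    20 - 8 * np.cos(2 * np.pi * (doy - 15) / 365.25) + rng.normal(0, 1.0, len(times)),
    coords={"time": times},
    dims="time",
)

# Climatology: mean over all years for each day of year (366 groups because of 2020)
climatology = temp.groupby("time.dayofyear").mean("time")

# Anomaly for every day: each value minus the climatology of its own day of year
anomaly = temp.groupby("time.dayofyear") - climatology

# One year only: each group has a single member, so the anomaly is identically zero
one_year = temp.sel(time="2021")
trivial = one_year.groupby("time.dayofyear") - one_year.groupby("time.dayofyear").mean("time")
```

One caveat of grouping by `dayofyear`: in leap years every date after February 28 shifts by one day-of-year number (March 1 is day 61 instead of 60), so calendar dates are slightly misaligned across years.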


In [None]:
# Compute climatological mean for each day of the year (using all data, here just 1 year)
climatology = ds.temperature.groupby('time.dayofyear').mean('time')

# Compute anomaly for June 1, 2022 (difference from climatology)
day_of_year = pd.Timestamp("2022-06-01").dayofyear
anomaly_june1 = ds.sel(time="2022-06-01").temperature - climatology.sel(dayofyear=day_of_year)

print(anomaly_june1)


---


### 5) Filtering (Rolling Mean)

This section applies a **7-day rolling mean** filter in the time dimension to smooth the temperature time series.

We plot the original and smoothed temperature for a single geographic location to visualize how rolling averages reduce day-to-day noise.

Rolling window operations are common in time series analysis and are easy to implement with xarray.
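
A centered window needs values on both sides, so the first and last three days of a centered 7-day mean come out as NaN. This small sketch on hypothetical data shows that edge behaviour, how `min_periods` relaxes it, and how a trailing (non-centered) window lags the data:

```python
import numpy as np
import pandas as pd
import xarray as xr

# Hypothetical 10-day series with values 0, 1, ..., 9
days = pd.date_range("2022-01-01", periods=10, freq="D")
series = xr.DataArray(np.arange(10.0), coords={"time": days}, dims="time")

# Centered 7-day window: the first and last 3 days lack a full window, so they are NaN
centered = series.rolling(time=7, center=True).mean()

# min_periods=1 lets the window shrink at the edges instead of returning NaN
centered_edges = series.rolling(time=7, center=True, min_periods=1).mean()

# Trailing window (the default): day t averages days t-6 .. t, so the result lags the data
trailing = series.rolling(time=7).mean()
```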


In [None]:
# Apply 7-day rolling mean in time dimension to smooth temperature
smoothed_temp = ds.temperature.rolling(time=7, center=True).mean()

# Plot original vs smoothed at a single location (lat=35, lon=33)
lat_sel = 35
lon_sel = 33

orig_series = ds.temperature.sel(latitude=lat_sel, longitude=lon_sel, method='nearest')
smooth_series = smoothed_temp.sel(latitude=lat_sel, longitude=lon_sel, method='nearest')

plt.figure(figsize=(10, 4))
plt.plot(orig_series.time, orig_series, label='Original')
plt.plot(smooth_series.time, smooth_series, label='7-day Rolling Mean', linewidth=2)
plt.title(f"Temperature Time Series at lat={lat_sel}, lon={lon_sel}")
plt.ylabel("Temperature (°C)")
plt.legend()
plt.show()


---

### High-Resolution Temperature Time Series with Multi-Scale Rolling Means

In this example, we simulate a high-frequency temperature time series over **3 weeks** with measurements every **5 minutes**. The temperature follows a sinusoidal daily cycle with added random noise to mimic natural variability.

We then apply **rolling mean filters** with window sizes corresponding to:

- **1 hour** (12 data points)  
- **3 hours** (36 data points)  
- **6 hours** (72 data points)  
- **12 hours** (144 data points)  

These rolling means progressively smooth the data by averaging over longer time intervals.

The plot zooms in on a **2-day window (June 7 to June 9, 2022)** to clearly visualize how increasing the rolling window size reduces short-term fluctuations while preserving longer-term trends.

This demonstrates the trade-off between noise reduction and temporal resolution in time series analysis using rolling averages.
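
Two details behind these window choices can be checked numerically: the point count is the window duration divided by the sampling interval (60 min / 5 min = 12), and a running mean of width W also damps the daily cycle itself. For a sinusoid of period T, a continuous boxcar average keeps a fraction sin(πW/T)/(πW/T) of the amplitude, which for W = 12 h and T = 24 h is 2/π ≈ 0.64. The sketch below uses a noise-free version of the daily cycle (an assumption, so that only the smoothing affects the amplitude).

```python
import numpy as np
import pandas as pd
import xarray as xr

step = pd.Timedelta("5min")
times = pd.date_range("2022-06-01", periods=3 * 7 * 24 * 12, freq="5min")

# Noise-free daily cycle with a 10 degC amplitude (assumed, to isolate the smoothing effect)
amplitude = 10.0
minutes = np.arange(len(times)) * 5
temp = xr.DataArray(
    20 + amplitude * np.sin(2 * np.pi * minutes / (24 * 60)),
    coords={"time": times},
    dims="time",
)

# Window size in points = window duration / sampling interval
durations = ["1h", "3h", "6h", "12h"]
points = {d: int(round(pd.Timedelta(d) / step)) for d in durations}

# Fraction of the daily-cycle amplitude that survives each rolling mean
retained = {}
for d, n in points.items():
    smooth = temp.rolling(time=n, center=True).mean()
    # half the peak-to-peak range of the smoothed series, relative to the original amplitude
    retained[d] = float((smooth.max() - smooth.min()) / 2 / amplitude)

print(points)
print(retained)
```

The 1-hour window leaves the daily cycle almost untouched, while the 12-hour window keeps only about 64% of its amplitude, which is the temporal-resolution cost of stronger smoothing.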


In [None]:
import numpy as np
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt

# Create high-frequency time axis: 3 weeks at 5-minute intervals
times_highres = pd.date_range('2022-06-01', periods=3*7*24*12, freq='5min')

# Parameters for sinusoidal daily temperature cycle
# (named mean_temp so it does not overwrite the 2-D base_temp array from the Cyprus example)
mean_temp = 20
amplitude = 10
period = 24 * 60  # minutes in a day

minutes = np.arange(len(times_highres)) * 5

# Simulate temperature: sinusoidal daily cycle + low noise
temp_values = mean_temp + amplitude * np.sin(2 * np.pi * minutes / period) + np.random.normal(0, 0.3, len(times_highres))

# Create xarray Dataset
ds_highres = xr.Dataset(
    {
        "temperature": (["time"], temp_values)
    },
    coords={
        "time": times_highres
    }
)

# Rolling window sizes in number of 5-min points
# 1 hour = 12 points, 3 hours = 36, 6 hours = 72, 12 hours = 144 points
window_sizes_points = [12, 36, 72, 144]

# Calculate rolling means (centered)
rolled_dict = {}
for w in window_sizes_points:
    rolled_dict[w] = ds_highres.temperature.rolling(time=w, center=True).mean()

# Plot original full data and zoom to a few days to see detail
plt.figure(figsize=(15, 7))
plt.plot(ds_highres.time, ds_highres.temperature, label='Original', color='black', lw=1.5)

# Zoom interval: June 7 00:00 to June 9 00:00 (2 days)
zoom_start = pd.Timestamp("2022-06-07 00:00")
zoom_end = pd.Timestamp("2022-06-09 00:00")

colors = ['red', 'orange', 'green', 'blue']
for w, c in zip(window_sizes_points, colors):
    zoomed = rolled_dict[w].sel(time=slice(zoom_start, zoom_end))
    plt.plot(zoomed.time, zoomed, label=f'{w * 5 // 60} h Rolling Mean', lw=2, color=c)

plt.xlim(zoom_start, zoom_end)
plt.title("5-min Temperature Data with 1 to 12 h Rolling Means (zoom: June 7 to 9, 2022)")
plt.ylabel("Temperature (°C)")
plt.xlabel("Time")
plt.legend()
plt.grid(True, linestyle='--', alpha=0.5)
plt.show()


---

### Plotting Local vs Regional Temperature Time Series

This plot compares temperature variations at two scales over Cyprus for the year 2022:

- **Local Temperature:** Time series extracted from the grid point nearest to latitude 35° and longitude 33°. This shows the detailed daily temperature fluctuations at that specific location.

- **Regional Mean Temperature:** The daily spatial average of temperature over the entire Cyprus grid. This smooths out local variability and highlights broader regional trends.

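Comparing these two time series shows how local variability relates to the overall regional signal. In this synthetic dataset the seasonal cycle is identical at every grid point, so the local series differs from the regional mean only by a roughly constant spatial offset plus independent daily noise.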
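
Beyond overlaying the two curves, their difference separates a persistent offset from day-to-day local variability, and a correlation coefficient summarizes how closely the point tracks the region. A self-contained sketch on a regenerated version of the synthetic grid (the seed and the cosine seasonal phase are assumptions) using `xr.corr`:

```python
import numpy as np
import pandas as pd
import xarray as xr

rng = np.random.default_rng(2)  # assumed seed

# Regenerate a grid like the notebook's synthetic Cyprus dataset
times = pd.date_range("2022-01-01", "2022-12-31", freq="D")
lats = np.linspace(34, 36, 50)
lons = np.linspace(32, 35, 60)
seasonal = -10 * np.cos(2 * np.pi * (np.arange(len(times)) - 15) / len(times))
base = 20 - (lats[:, None] - 34) * 5 + (lons[None, :] - 32) * 2
data = base[None, :, :] + seasonal[:, None, None] + rng.normal(0, 0.5, (len(times), 50, 60))
temp = xr.DataArray(
    data,
    coords={"time": times, "latitude": lats, "longitude": lons},
    dims=("time", "latitude", "longitude"),
)

local = temp.sel(latitude=35, longitude=33, method="nearest")
regional = temp.mean(dim=["latitude", "longitude"])

# Local departure from the regional signal
departure = local - regional
offset = float(departure.mean("time"))            # persistent spatial offset
spread = float(departure.std("time"))             # day-to-day local variability
r = float(xr.corr(local, regional, dim="time"))   # how closely local tracks regional

print(f"offset={offset:.2f} degC, spread={spread:.2f} degC, r={r:.4f}")
```

Because every grid point shares the same seasonal cycle here, the correlation is close to 1 and the departure is essentially a constant offset (about 1 °C cooler than the regional mean) plus noise; with real observations the departure series would also carry local weather effects.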


In [None]:
import matplotlib.pyplot as plt

# Select location time series (nearest grid point to lat=35, lon=33)
lat_sel = 35
lon_sel = 33
local_ts = ds.temperature.sel(latitude=lat_sel, longitude=lon_sel, method='nearest')

# Compute daily spatial mean time series
regional_ts = ds.temperature.mean(dim=["latitude", "longitude"])

plt.figure(figsize=(12, 5))
plt.plot(local_ts.time, local_ts, label="Local Temperature (lat=35°, lon=33°)")
plt.plot(regional_ts.time, regional_ts, label="Regional Mean Temperature")
plt.title("Temperature Time Series: Local vs Regional (Cyprus, 2022)")
plt.xlabel("Date")
plt.ylabel("Temperature (°C)")
plt.legend()
plt.grid(True)
plt.show()