# Read a test radar file and define the gridding and interpolation scheme

In [None]:
import xarray as xr
import xradar as xd
import matplotlib.pyplot as plt
import numpy as np

import sys
import os
sys.path.append(os.path.dirname(os.getcwd()))
import src.global_variables as global_vars
import src.local_variables as local_vars

**Read the radar data**

In [None]:
filename = "../data/01-06-2024/2024060112000700dBZ.vol"

In [None]:
def load_all_sweeps(filename, n_sweeps=12, engine="rainbow"):
    """Load the sweeps (elevation scans) stored in one radar volume file.

    Parameters
    ----------
    filename : str
        Path to the radar volume file.
    n_sweeps : int, default 12
        Number of sweeps to read. The groups ``sweep_0`` ... ``sweep_{n_sweeps - 1}``
        are opened in that order.
    engine : str, default "rainbow"
        Backend engine name handed to ``xr.open_dataset``.

    Returns
    -------
    list
        One dataset per sweep, ordered by sweep index.
    """
    return [
        xr.open_dataset(filename, group=f"sweep_{i}", engine=engine)
        for i in range(n_sweeps)
    ]

sweeps = load_all_sweeps(filename)

**Plot lowest elevation**

In [None]:
# Pull the raw coordinate and data arrays of the first sweep out of xarray.
range_vals = sweeps[0].range.values      # shape (480,)
azimuth_vals = sweeps[0].azimuth.values  # shape (360,) 
dbzh_vals = sweeps[0].DBZH.values        # shape (360, 480)

# Colormap limits
# (vmin/vmax are module-level and are reused by the Cartesian scatter plot further down)
vmin, vmax = -60, 60

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

# Manual plot with imshow
# The array is drawn as a plain rectangular image: columns span the range axis and
# rows span the azimuth axis, as set by `extent`.
# NOTE(review): this labelling assumes the azimuth rows are sorted ascending and evenly
# spaced - confirm against the file.
im1 = ax1.imshow(dbzh_vals, 
                 extent=[range_vals.min(), range_vals.max(), 
                        azimuth_vals.min(), azimuth_vals.max()],
                 aspect='auto', 
                 origin='lower',
                 cmap='seismic',
                 vmin=vmin,
                 vmax=vmax)
ax1.set_xlabel('Range (m)')
ax1.set_ylabel('Azimuth (degrees)')
ax1.set_title('Manual imshow plot')
plt.colorbar(im1, ax=ax1, label='DBZH (dBZ)')

# Built-in xarray plot for comparison
# No vmin/vmax/cmap are passed here, so its colour scale may differ from the manual plot.
sweeps[0].DBZH.plot(ax=ax2)
ax2.set_title('Built-in xarray plot')

plt.tight_layout()
plt.show()

**Convert from polar to Cartesian coordinates**

In [None]:
def spherical_to_cartesian_3D(sweep):
    """Turn the polar gate positions of one sweep into Cartesian x, y, z arrays.

    The sweep's fixed elevation angle is converted to a zenith angle and combined
    with every (azimuth, range) pair. Azimuth is treated as measured from the +y
    axis towards the +x axis (x uses sin, y uses cos). The conversion is purely
    geometric: no earth-curvature or refraction term is applied.

    Parameters
    ----------
    sweep : object
        Must expose ``range.values``, ``azimuth.values`` (degrees) and
        ``sweep_fixed_angle.values`` (degrees).

    Returns
    -------
    tuple of ndarray
        ``(x, y, z)``, each of shape ``(n_azimuth, n_range)``, in the units of ``range``.
    """
    # Zenith angle of the beam (90 degrees minus elevation), in radians
    zenith = np.pi/2 - np.radians(sweep.sweep_fixed_angle.values)

    # One row per azimuth, one column per range gate
    slant_range, azimuth_rad = np.meshgrid(
        sweep.range.values,
        np.radians(sweep.azimuth.values),
    )

    # Projection of the slant range onto the horizontal plane
    ground_range = slant_range * np.sin(zenith)

    x = ground_range * np.sin(azimuth_rad)
    y = ground_range * np.cos(azimuth_rad)
    z = slant_range * np.cos(zenith)

    return x, y, z

x, y, z = spherical_to_cartesian_3D(sweeps[0])

In [None]:
# Plot top view to verify correctness
# Left panel: our own polar -> Cartesian conversion; right panel: xradar's georeferencing.
fig, ax = plt.subplots(1, 2, figsize=(15, 6))  
# Top view (X-Y plane)
# Coordinates are divided by 1000 to show kilometres; colours reuse the vmin/vmax
# limits defined in the lowest-elevation plotting cell.
scatter1 = ax[0].scatter(x.flatten()/1000, y.flatten()/1000,
                     c=sweeps[0].DBZH.values.flatten(), s=0.5, cmap="seismic", vmin=vmin, vmax=vmax)

ax[0].set_xlabel('X (km)')
ax[0].set_ylabel('Y (km)')
ax[0].set_title('Manual view')
ax[0].set_aspect('equal')
# ax[0].set_xlim([-40,40])
# ax[0].set_ylim([-40, 40])

# Add colorbar
cbar = plt.colorbar(scatter1, ax=ax[0])  
cbar.set_label('DBZH (dBZ)')

# Verify with built-in plotting routine
# NOTE(review): georeference() presumably attaches x/y coordinates (used by the plot call
# below) in metres, so the two panels would use different axis units - confirm.
rd = sweeps[0].xradar.georeference()
rd.DBZH.plot(ax=ax[1], x="x", y="y")

plt.show()

**Interpolate data**

In [None]:
def aggregate_all_elevations(sweeps, parameter='DBZH'):
    """Flatten every sweep into one point cloud in Cartesian coordinates.

    Each sweep is converted with ``spherical_to_cartesian_3D`` and its gates are
    flattened in row-major order, so the coordinate arrays and the payload array
    stay aligned element by element.

    Parameters
    ----------
    sweeps : iterable
        Sweeps (elevations) to aggregate.
    parameter : str, default 'DBZH'
        Name of the variable read from each sweep via ``sweep[parameter]``.

    Returns
    -------
    tuple of ndarray
        ``(x, y, z, payload)`` as 1-D arrays of equal length. All four are empty
        when ``sweeps`` is empty.
    """
    xs, ys, zs, payloads = [], [], [], []
    for sweep in sweeps:
        x, y, z = spherical_to_cartesian_3D(sweep)
        xs.append(np.ravel(x))
        ys.append(np.ravel(y))
        zs.append(np.ravel(z))
        payloads.append(np.ravel(sweep[parameter].values))

    if not payloads:
        # np.concatenate rejects an empty sequence; keep the "no sweeps" case working.
        return np.array([]), np.array([]), np.array([]), np.array([])

    # Concatenate whole arrays instead of extending Python lists gate by gate,
    # which avoids boxing every sample into a Python object.
    return (np.concatenate(xs), np.concatenate(ys),
            np.concatenate(zs), np.concatenate(payloads))

all_x, all_y, all_z, all_dbzh = aggregate_all_elevations(sweeps, parameter='DBZH')


In [None]:
# Load the predefined grid
# The .npz archive holds the target grid axes; the three arrays below are the 1-D axes
# that the interpolation functions expand into a full 3-D mesh with np.meshgrid.
grid = np.load("../data/radar_hurum_grid_10x10_8km_spacing.npz")
x_m = grid['x_centers_m']  # grid cell centres along x (presumably metres, per the key name)
y_m = grid['y_centers_m']  # grid cell centres along y
z_m = grid['z_levels_m']   # vertical levels

# List the array names stored in the archive (displayed as the cell output)
grid.files

In [None]:
# Interpolate with scipy griddata
def interpolate_scipy_griddata(x_m, y_m, z_m, method, x=None, y=None, z=None, values=None):
    """Interpolate scattered radar samples onto a regular 3-D grid with ``scipy`` ``griddata``.

    Parameters
    ----------
    x_m, y_m, z_m : array_like
        1-D grid axes. They are expanded with ``np.meshgrid`` (default 'xy' indexing).
    method : str
        Interpolation method, passed straight to ``scipy.interpolate.griddata``.
    x, y, z, values : array_like, optional
        Sample coordinates and sample values. When all four are omitted, the
        notebook-level arrays ``all_x``, ``all_y``, ``all_z`` and ``all_dbzh`` are used,
        which is the behaviour of the original three-axis call signature.

    Returns
    -------
    ndarray
        Interpolated values with shape ``(len(y_m), len(x_m), len(z_m))``.

    Raises
    ------
    ValueError
        If only some of ``x``, ``y``, ``z``, ``values`` are supplied.
    """
    from scipy.interpolate import griddata

    supplied = [arg is not None for arg in (x, y, z, values)]
    if not any(supplied):
        # Backward-compatible default: the aggregated samples defined at notebook level.
        x, y, z, values = all_x, all_y, all_z, all_dbzh
    elif not all(supplied):
        raise ValueError("x, y, z and values must be supplied together")
    x, y, z, values = (np.asarray(arr).ravel() for arr in (x, y, z, values))

    # Create 3D grid
    x_grid, y_grid, z_grid = np.meshgrid(x_m, y_m, z_m)

    # Only use valid data (NaN samples carry no measurement)
    valid = ~np.isnan(values)

    # Do the interpolation
    grid_values = griddata(
        points=np.column_stack([x[valid], y[valid], z[valid]]),
        values=values[valid],
        xi=np.column_stack([x_grid.ravel(), y_grid.ravel(), z_grid.ravel()]),
        method=method,
    )

    return grid_values.reshape(x_grid.shape)

In [None]:
dbzh_interpolated_nearest = interpolate_scipy_griddata(x_m, y_m, z_m, "nearest")

In [None]:
dbzh_interpolated_linear = interpolate_scipy_griddata(x_m, y_m, z_m, "linear")

In [None]:
# Visualize altitude slices
def visualize_altitude_slices(dbzh):
    """Plot one horizontal slice per vertical level of a gridded field.

    The number of panels follows ``dbzh.shape[2]`` instead of a fixed list of 20
    indices, so grids with fewer levels no longer raise ``IndexError`` and grids
    with more levels are shown completely. Panels are laid out four per row; for
    20 levels this is the original 5 x 4 figure of size (15, 12).

    Uses the notebook-level grid axes ``x_m``, ``y_m`` (image extent, shown in
    units of 1/1000 of the axis values) and ``z_m`` (panel titles).

    Parameters
    ----------
    dbzh : ndarray
        3-D array indexed as ``[row, column, level]``.
    """
    n_levels = dbzh.shape[2]
    n_cols = 4
    n_rows = max(1, -(-n_levels // n_cols))  # ceiling division, at least one row

    # squeeze=False keeps `axes` two-dimensional even for a single row
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 12 * n_rows / 5), squeeze=False)
    axes = axes.flatten()

    extent_km = [x_m.min()/1000, x_m.max()/1000,  # x-range (East-West)
                 y_m.min()/1000, y_m.max()/1000]

    for level in range(n_levels):
        panel = axes[level]
        panel.imshow(dbzh[:, :, level],
                     cmap="seismic", vmin=-60, vmax=60,
                     origin="lower",
                     extent=extent_km)
        panel.set_title(f"Altitude: {z_m[level]}m")
        panel.set_xticks([])  # Remove x-axis ticks
        panel.set_yticks([])

    # Hide the panels left over when n_levels is not a multiple of n_cols
    for panel in axes[n_levels:]:
        panel.set_visible(False)

    plt.tight_layout(pad=0.1)  # Minimal padding
    plt.subplots_adjust(wspace=0.05, hspace=0.2)  # Reduce spacing between plots
    plt.show()

In [None]:
visualize_altitude_slices(dbzh_interpolated_nearest)

In [None]:
visualize_altitude_slices(dbzh_interpolated_linear)

In [None]:
def interpolate_k_nearest(x_m, y_m, z_m, k=4, x=None, y=None, z=None, values=None):
    """K-nearest neighbor interpolation with inverse-distance weighting.

    Every grid point receives the weighted mean of its ``k`` nearest valid samples,
    with weights ``1 / (distance + 1e-6)`` normalised to sum to one.

    Parameters
    ----------
    x_m, y_m, z_m : array_like
        1-D grid axes. They are expanded with ``np.meshgrid`` (default 'xy' indexing).
    k : int, default 4
        Number of neighbours. Values larger than the number of valid samples are
        clamped to that number.
    x, y, z, values : array_like, optional
        Sample coordinates and sample values. When all four are omitted, the
        notebook-level arrays ``all_x``, ``all_y``, ``all_z`` and ``all_dbzh`` are used,
        which is the behaviour of the original call signature.

    Returns
    -------
    ndarray
        Interpolated values with shape ``(len(y_m), len(x_m), len(z_m))``. All NaN
        when there is no valid (non-NaN) sample.

    Raises
    ------
    ValueError
        If ``k < 1`` or only some of ``x``, ``y``, ``z``, ``values`` are supplied.
    """
    from scipy.spatial import cKDTree

    if k < 1:
        raise ValueError("k must be a positive integer")

    supplied = [arg is not None for arg in (x, y, z, values)]
    if not any(supplied):
        # Backward-compatible default: the aggregated samples defined at notebook level.
        x, y, z, values = all_x, all_y, all_z, all_dbzh
    elif not all(supplied):
        raise ValueError("x, y, z and values must be supplied together")
    x, y, z, values = (np.asarray(arr).ravel() for arr in (x, y, z, values))

    # Create 3D grid
    x_grid, y_grid, z_grid = np.meshgrid(x_m, y_m, z_m)

    # Only use valid data
    valid = ~np.isnan(values)
    samples = values[valid]
    if samples.size == 0:
        return np.full(x_grid.shape, np.nan)

    # Build KDTree for fast neighbor search
    tree = cKDTree(np.column_stack([x[valid], y[valid], z[valid]]))

    # Grid points to interpolate to
    grid_points = np.column_stack([x_grid.ravel(), y_grid.ravel(), z_grid.ravel()])

    # Asking for more neighbours than there are samples makes query() return
    # out-of-range "missing neighbour" indices, so clamp k first.
    k_eff = min(k, samples.size)
    distances, indices = tree.query(grid_points, k=k_eff)

    # query() drops the neighbour axis when k == 1; restore it so axis=1 is always valid.
    distances = distances.reshape(len(grid_points), -1)
    indices = indices.reshape(len(grid_points), -1)

    # Inverse distance weighting; the small offset avoids division by zero on exact hits
    weights = 1.0 / (distances + 1e-6)
    weights = weights / np.sum(weights, axis=1, keepdims=True)

    # Weighted average of k nearest neighbors
    interpolated = np.sum(samples[indices] * weights, axis=1)

    return interpolated.reshape(x_grid.shape)

In [None]:
dbzh_interpolated_k4 = interpolate_k_nearest(x_m, y_m, z_m, k=4)

In [None]:
visualize_altitude_slices(dbzh_interpolated_k4)

In [None]:
def visualize_all_radar_cones(sweeps, x_m, y_m, z_m):
    """Plot a vertical cross-section of every sweep against the target grid.

    For each sweep, the gates lying within 2000 (coordinate units) of the plane
    y = 0 are drawn in the x-z plane, one colour per elevation. The grid points
    (x centres against z levels) and the radar position at the origin are overlaid.
    The colour table and the title follow ``len(sweeps)`` rather than assuming
    exactly 12 sweeps.

    Parameters
    ----------
    sweeps : sequence
        Sweeps to draw; each must be accepted by ``spherical_to_cartesian_3D`` and
        expose ``sweep_fixed_angle.values``.
    x_m, z_m : ndarray
        Grid x centres and z levels; divided by 1000 for display.
    y_m : ndarray
        Not used by the plot; kept so existing calls keep working.
    """
    n_sweeps = len(sweeps)

    fig, ax = plt.subplots(figsize=(15, 8))

    # One colour per elevation, spread evenly over the tab20 colormap
    colors = plt.cm.tab20(np.linspace(0, 1, n_sweeps))

    for idx, sweep in enumerate(sweeps):
        x, y, z = spherical_to_cartesian_3D(sweep)

        # Take slice near y=0 (2km tolerance)
        y_slice_mask = np.abs(y) < 2000
        if not np.any(y_slice_mask):
            continue

        elevation = sweep.sweep_fixed_angle.values
        ax.scatter(x[y_slice_mask]/1000, z[y_slice_mask]/1000,
                   c=[colors[idx]], s=1, alpha=0.6,
                   label=f'{elevation:.1f}°')

    # Plot the grid points
    grid_x_mesh, grid_z_mesh = np.meshgrid(x_m / 1000, z_m / 1000)
    ax.scatter(grid_x_mesh.flatten(), grid_z_mesh.flatten(),
               c='k', s=30, marker='s', alpha=0.8, label='Grid points')

    # Radar location
    ax.scatter(0, 0, c='black', s=100, marker='^', label='Radar')

    ax.set_xlabel('Distance (km)')
    ax.set_ylabel('Height (km)')
    ax.set_title(f'All {n_sweeps} Radar Elevations vs Grid')
    ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    ax.grid(True, alpha=0.3)

    ax.set_xlim([-45, 45])  # Slightly wider than 40km grid
    ax.set_ylim([0, 10])

    plt.tight_layout()
    plt.show()

visualize_all_radar_cones(sweeps, x_m, y_m, z_m)

In [None]:
def k_nearest_with_masking(x_m, y_m, z_m, k=4, max_distance=5000, x=None, y=None, z=None, values=None):
    """K-nearest interpolation with distance-based masking.

    Every grid point receives the inverse-distance-weighted mean of its ``k``
    nearest valid samples (weights ``1 / (distance + 1e-6)``). Grid points whose
    nearest sample is not closer than ``max_distance`` are set to 0.0 and flagged
    0 in the returned mask.

    Parameters
    ----------
    x_m, y_m, z_m : array_like
        1-D grid axes. They are expanded with ``np.meshgrid`` (default 'xy' indexing).
    k : int, default 4
        Number of neighbours. Values larger than the number of valid samples are
        clamped to that number.
    max_distance : float, default 5000
        A grid point is trusted only if its nearest sample is strictly closer than this.
    x, y, z, values : array_like, optional
        Sample coordinates and sample values. When all four are omitted, the
        notebook-level arrays ``all_x``, ``all_y``, ``all_z`` and ``all_dbzh`` are used,
        which is the behaviour of the original call signature.

    Returns
    -------
    interpolated : ndarray
        Shape ``(len(y_m), len(x_m), len(z_m))``; 0.0 where masked out.
    mask : ndarray
        Float array of the same shape, 1.0 where the value is trusted, else 0.0.

    Raises
    ------
    ValueError
        If ``k < 1`` or only some of ``x``, ``y``, ``z``, ``values`` are supplied.
    """
    from scipy.spatial import cKDTree

    if k < 1:
        raise ValueError("k must be a positive integer")

    supplied = [arg is not None for arg in (x, y, z, values)]
    if not any(supplied):
        # Backward-compatible default: the aggregated samples defined at notebook level.
        x, y, z, values = all_x, all_y, all_z, all_dbzh
    elif not all(supplied):
        raise ValueError("x, y, z and values must be supplied together")
    x, y, z, values = (np.asarray(arr).ravel() for arr in (x, y, z, values))

    # Create 3D grid
    x_grid, y_grid, z_grid = np.meshgrid(x_m, y_m, z_m)

    # Only use valid data
    valid = ~np.isnan(values)
    samples = values[valid]
    if samples.size == 0:
        # No sample anywhere: nothing can be trusted.
        return np.zeros(x_grid.shape), np.zeros(x_grid.shape)

    # Build KDTree for fast neighbor search
    tree = cKDTree(np.column_stack([x[valid], y[valid], z[valid]]))

    # Grid points to interpolate to
    grid_points = np.column_stack([x_grid.ravel(), y_grid.ravel(), z_grid.ravel()])

    # Clamp k so query() never returns out-of-range "missing neighbour" indices.
    k_eff = min(k, samples.size)
    distances, indices = tree.query(grid_points, k=k_eff)

    # query() drops the neighbour axis when k == 1; restore it so axis=1 is always valid.
    distances = distances.reshape(len(grid_points), -1)
    indices = indices.reshape(len(grid_points), -1)

    # Calculate weights (inverse distance weighting)
    weights = 1.0 / (distances + 1e-6)
    weights = weights / np.sum(weights, axis=1, keepdims=True)

    # Weighted average of k nearest neighbors
    interpolated = np.sum(samples[indices] * weights, axis=1)

    # Only trust interpolation within max_distance of real data. The nearest distance
    # is already among the k results, so no second tree query is needed.
    mask = distances.min(axis=1) < max_distance
    interpolated[~mask] = 0.0  # Set distant interpolations to 0

    # Reshape outputs
    return interpolated.reshape(x_grid.shape), mask.reshape(x_grid.shape).astype(float)

In [None]:
# Run the masked k-nearest interpolation with 100 neighbours per grid point.
# max_distance = sqrt(3 * 4000**2), about 6928: the centre-to-corner distance of a cube
# with sides of 8000.
# NOTE(review): presumably chosen to match the 8 km spacing in the grid file name - confirm.
dbzh_interpolated_masked, mask = k_nearest_with_masking(x_m, y_m, z_m, k=100, 
                                                        max_distance=np.sqrt(4000**2 + 4000**2 + (4000)**2))




# Visualize both the data and mask
# (the mask panels hold 0/1 values, drawn with the same -60..60 colour scale as the data)
visualize_altitude_slices(dbzh_interpolated_masked)
visualize_altitude_slices(mask)  

In [None]:
def k_nearest_with_masking_anisotropic(x, y, z, values,
                                       grid,
                                       k=4, max_distance=4000,
                                       vertical_scale=16.0):
    """
    Interpolate radar data to regular 3D grid using anisotropic k-nearest neighbors.

    The vertical coordinate of both the samples and the grid points is multiplied
    by ``vertical_scale`` before the neighbour search, so vertical separation counts
    ``vertical_scale`` times more than horizontal separation. Each voxel receives the
    unweighted mean of those of its ``k`` nearest samples that lie strictly within
    ``max_distance`` in the scaled space.

    Parameters
    ----------
    x, y, z : array_like
        Radar measurement coordinates in meters
    values : array_like
        Radar values to interpolate; NaN entries are ignored
    grid : dict
        Grid definition containing 'x_centers_m', 'y_centers_m', 'z_levels_m'
    k : int, default 4
        Number of nearest neighbors to find. Values larger than the number of
        valid samples are clamped to that number.
    max_distance : float, default 4000
        Maximum search distance in scaled coordinates (meters)
    vertical_scale : float, default 16.0
        Scaling factor applied to the vertical coordinate before the search

    Returns
    -------
    ndarray
        Interpolated values with shape (n_y, n_x, n_z), NaN where no neighbor lies
        within ``max_distance``

    Raises
    ------
    ValueError
        If ``k < 1``.
    """
    from scipy.spatial import cKDTree

    if k < 1:
        raise ValueError("k must be a positive integer")

    # Create 3D grid to interpolate to
    x_grid, y_grid, z_grid = np.meshgrid(grid['x_centers_m'],
                                         grid['y_centers_m'],
                                         grid['z_levels_m'])

    # Flatten grid, scaling the vertical coordinate
    grid_points = np.column_stack([
        x_grid.ravel(),
        y_grid.ravel(),
        z_grid.ravel() * vertical_scale,
    ])
    n_voxels = len(grid_points)

    # Only use valid data
    x, y, z, values = (np.asarray(arr).ravel() for arr in (x, y, z, values))
    valid = ~np.isnan(values)
    samples = values[valid]

    # Voxels without any neighbour in range stay NaN
    interpolated = np.full(n_voxels, np.nan)

    if samples.size:
        # Rescale z before building the KDTree
        tree = cKDTree(np.column_stack([
            x[valid],
            y[valid],
            z[valid] * vertical_scale,
        ]))

        # Clamp k so query() never returns out-of-range "missing neighbour" indices.
        k_eff = min(k, samples.size)
        distances, indices = tree.query(grid_points, k=k_eff)

        # query() drops the neighbour axis when k == 1; restore it.
        distances = distances.reshape(n_voxels, -1)
        indices = indices.reshape(n_voxels, -1)

        # Uniform average over the neighbours within max_distance, for all voxels at once
        within = distances < max_distance
        counts = within.sum(axis=1)
        sums = np.where(within, samples[indices], 0.0).sum(axis=1)
        has_neighbors = counts > 0
        interpolated[has_neighbors] = sums[has_neighbors] / counts[has_neighbors]

    # Reshape to grid
    return interpolated.reshape(x_grid.shape)

In [None]:
# Run the anisotropic k-nearest interpolation with 100 neighbours per voxel.
# max_distance is compared against distances measured after z has been multiplied by
# vertical_scale, so the vertical reach is 16 times shorter than the horizontal reach.
dbzh_interpolated_anisotropic = k_nearest_with_masking_anisotropic(all_x, all_y, all_z, all_dbzh,
                                       grid,
                                       k=100, max_distance=np.sqrt(4000**2 + 4000**2 + (4000)**2),
                                       vertical_scale=16)




# Visualize the interpolated data (this function returns no separate mask;
# voxels without a neighbour in range are NaN)
visualize_altitude_slices(dbzh_interpolated_anisotropic)