In [1]:
import pandas as pd
import numpy as np
from pyproj import Proj
from scipy.spatial import cKDTree
from netCDF4 import Dataset
from datetime import datetime, timedelta

class RealWind1:
    def __init__(self, start_date="2019-10-01", num_days=9):
        """
        Precompute and store wind data in a pandas DataFrame for fast lookup.
        The table includes time (in minutes), northing, easting, U (east wind), and V (north wind).
        """
        self.proj = Proj(proj='utm', zone=56, south=True, ellps='WGS84')
        self.wind_table = []  # List to accumulate wind data

        start_date = datetime.strptime(start_date, "%Y-%m-%d")
        layer_range = (56, 68)  # Height range for averaging wind speeds
        bounding_box = (140.5, -39, 150, -34)

        for day_offset in range(num_days):
            current_date = start_date + timedelta(days=day_offset)
            date_str = current_date.strftime('%Y%m%d')
            file_path = rf"C:\\Users\\Nur Izfarwiza\\Documents\\Dissertation\\Wind\\MERRA2_400.tavg3_3d_asm_Nv.{date_str}.nc4"

            try:
                dataset = Dataset(file_path, 'r')
                lats = dataset.variables['lat'][:]
                lons = dataset.variables['lon'][:]
                
                # Apply the bounding box filter
                lat_indices = np.where((lats >= bounding_box[1]) & (lats <= bounding_box[3]))[0]
                lon_indices = np.where((lons >= bounding_box[0]) & (lons <= bounding_box[2]))[0]

                lats = lats[lat_indices]
                lons = lons[lon_indices]

                # Load and filter wind data within the bounding box
                eastward_wind = dataset.variables['U'][:, :, lat_indices, :][:, :, :, lon_indices]
                northward_wind = dataset.variables['V'][:, :, lat_indices, :][:, :, :, lon_indices]

                # Average over the specified vertical layers
                layer_range_slice = slice(layer_range[0], layer_range[1])
                eastward_wind_avg = np.mean(eastward_wind[:, layer_range_slice, :, :], axis=1)
                northward_wind_avg = np.mean(northward_wind[:, layer_range_slice, :, :], axis=1)

                # Read and store time directly in minutes
                print("NASA Time Units:", dataset.variables['time'].units)
                time_var = np.array(dataset.variables['time'][:], dtype=np.float64)  # Already in minutes!

                for t_idx, t_val in enumerate(time_var):
                    if not np.isfinite(t_val):
                        continue  # Skip invalid time values
                    timestamp = t_val *60 #convert minutes to seconds

                    # Vectorized coordinate transformation
                    lon_grid, lat_grid = np.meshgrid(lons, lats)
                    easting, northing = self.proj(lon_grid, lat_grid)

                    # Flatten arrays to create a table of points
                    easting = easting.flatten()
                    northing = northing.flatten()
                    u_wind = eastward_wind_avg[t_idx].flatten()
                    v_wind = northward_wind_avg[t_idx].flatten()

                    # Filter out invalid projections
                    valid_mask = np.isfinite(easting) & np.isfinite(northing) & np.isfinite(u_wind) & np.isfinite(v_wind)

                    # Append valid data to the wind table
                    self.wind_table.extend(
                        zip([timestamp] * np.sum(valid_mask),
                            northing[valid_mask],
                            easting[valid_mask],
                            u_wind[valid_mask],
                            v_wind[valid_mask])
                    )

                dataset.close()

            except FileNotFoundError:
                print(f"File for {date_str} not found.")
            except Exception as e:
                print(f"Error loading {date_str}: {e}")

        # Convert to pandas DataFrame for fast lookups
        self.wind_table = pd.DataFrame(self.wind_table, columns=["Timestamp", "Northing", "Easting", "East Wind", "North Wind"])

        # Clean data: Drop NaN or Inf values
        self.wind_table = self.wind_table.replace([np.inf, -np.inf], np.nan)
        self.wind_table = self.wind_table.dropna()

        # Sort by timestamp for faster temporal queries
        self.wind_table = self.wind_table.sort_values(by="Timestamp").reset_index(drop=True)

        # Build a spatial KD-Tree for efficient spatial querying
        self.wind_tree = cKDTree(self.wind_table[["Easting", "Northing"]].values)

        print(f"Wind data precomputed for {len(self.wind_table)} points across {num_days} days.")
    
    def getwind(self, coords):
        """
        Get the nearest-neighbor wind speed for particles at given positions and times.

        Parameters:
        - coords: A tensor of shape (num_particles, num_observations, 3)
                  where each entry is [time, easting, northing].

        Returns:
        - wind_data: A tensor of shape (num_particles, num_observations, 2)
                     where each entry is [east_wind_speed, north_wind_speed].
        """
        num_particles, num_observations, _ = coords.shape
        wind_data = np.full((num_particles, num_observations, 2), np.nan)  # Initialize with NaN

        timestamps = coords[:, :, 0].flatten()  # Keep time in minutes

        for i in range(num_particles):
            for j in range(num_observations):
                easting, northing = coords[i, j, 1], coords[i, j, 2]
                timestamp = timestamps[i * num_observations + j]

                closest_time_idx = (np.abs(self.wind_table["Timestamp"] - timestamp)).idxmin()
                dist, closest_idx = self.wind_tree.query([easting, northing], k=1)

                wind_data[i, j] = self.wind_table.loc[closest_idx, ["East Wind", "North Wind"]]

        return wind_data


In [2]:
import pandas as pd
import numpy as np
from pyproj import Proj
from scipy.spatial import cKDTree
from netCDF4 import Dataset
from datetime import datetime, timedelta

class RealWind2:
    def __init__(self, start_date="2019-10-01", num_days=9):
        self.proj = Proj(proj='utm', zone=56, south=True, ellps='WGS84')
        self.wind_by_time = {}
        self.kdtrees = {}

        start_date = datetime.strptime(start_date, "%Y-%m-%d")
        layer_range = (56, 68)
        bounding_box = (110, -45, 155, -10)

        for day_offset in range(num_days):
            current_date = start_date + timedelta(days=day_offset)
            date_str = current_date.strftime('%Y%m%d')
            file_path = rf"C:\\Users\\Nur Izfarwiza\\Documents\\Dissertation\\Wind\\MERRA2_400.tavg3_3d_asm_Nv.{date_str}.nc4"

            try:
                dataset = Dataset(file_path, 'r')
                lats = dataset.variables['lat'][:]
                lons = dataset.variables['lon'][:]

                lat_indices = np.where((lats >= bounding_box[1]) & (lats <= bounding_box[3]))[0]
                lon_indices = np.where((lons >= bounding_box[0]) & (lons <= bounding_box[2]))[0]

                lats = lats[lat_indices]
                lons = lons[lon_indices]

                eastward_wind = dataset.variables['U'][:, :, lat_indices, :][:, :, :, lon_indices]
                northward_wind = dataset.variables['V'][:, :, lat_indices, :][:, :, :, lon_indices]

                layer_range_slice = slice(layer_range[0], layer_range[1])
                eastward_wind_avg = np.mean(eastward_wind[:, layer_range_slice, :, :], axis=1)
                northward_wind_avg = np.mean(northward_wind[:, layer_range_slice, :, :], axis=1)

                time_var = np.array(dataset.variables['time'][:], dtype=np.float64)

                for t_idx, t_val in enumerate(time_var):
                    if not np.isfinite(t_val):
                        continue
                    timestamp = int((t_val * 60) // 10800 * 10800)  # Round to nearest 3-hour (in seconds)

                    lon_grid, lat_grid = np.meshgrid(lons, lats)
                    easting, northing = self.proj(lon_grid, lat_grid)

                    easting = easting.flatten()
                    northing = northing.flatten()
                    u_wind = eastward_wind_avg[t_idx].flatten()
                    v_wind = northward_wind_avg[t_idx].flatten()

                    valid_mask = np.isfinite(easting) & np.isfinite(northing) & np.isfinite(u_wind) & np.isfinite(v_wind)
                    data = np.stack([easting[valid_mask], northing[valid_mask], u_wind[valid_mask], v_wind[valid_mask]], axis=1)

                    if timestamp not in self.wind_by_time:
                        self.wind_by_time[timestamp] = data
                    else:
                        self.wind_by_time[timestamp] = np.vstack([self.wind_by_time[timestamp], data])

                dataset.close()

            except Exception as e:
                print(f"Error loading {date_str}: {e}")

        for ts_key, data in self.wind_by_time.items():
            self.kdtrees[ts_key] = cKDTree(data[:, :2])

        print(f"Loaded wind data for {len(self.wind_by_time)} timestamps.")

    def getwind(self, coords):
        num_particles, num_observations, _ = coords.shape
        wind_data = np.full((num_particles, num_observations, 2), np.nan)

        flat_coords = coords.reshape(-1, 3)
        coords_flat = coords.reshape(-1, 3)
        times = coords_flat[:, 0]
        eastings = coords_flat[:, 1]
        northings = coords_flat[:, 2]
    
        # Define bounding box in UTM (convert once at init for faster runtime)
        # Example: UTM for (110E, -45) to (155E, -10)
        x_min, y_min = self.proj(110, -45)
        x_max, y_max = self.proj(155, -10)
    
        for idx, (t, x, y) in enumerate(zip(times, eastings, northings)):
            if not np.isfinite(x) or not np.isfinite(y):
                print(f"Particle {idx} has NaN or Inf: x={x}, y={y}")
                continue
    
            if x < x_min or x > x_max or y < y_min or y > y_max:
                print(f"⚠️ Particle {idx} out of bounds:")
                print(f"  Time: {t:.2f} s | Easting: {x:.2f} | Northing: {y:.2f}")
                continue
        timestamps = (flat_coords[:, 0] // 10800 * 10800).astype(int)
        unique_ts = np.unique(timestamps)

        for ts in unique_ts:
            if ts not in self.kdtrees:
                continue

            mask = timestamps == ts
            coords_subset = flat_coords[mask]
            spatial_coords = coords_subset[:, 1:3]

            tree = self.kdtrees[ts]
            data = self.wind_by_time[ts]

            _, idx = tree.query(spatial_coords)
            wind_vals = data[idx, 2:]

            wind_data.reshape(-1, 2)[mask] = wind_vals

        return wind_data


In [7]:
import time
from advectionGP.sensors import RemoteSensingModel
# Generate test particles (example: shape [N_particles, N_obs, 3])
# Let's say you have a sensors class or random dummy input
sensors = RemoteSensingModel()
particles = sensors.genParticles(30)  # or np.random.randn(5, 120, 3) with proper values

# Test RealWind version 1
wind_model_1 = RealWind1(start_date="2019-10-01", num_days=3)
start_time = time.time()
wind_1 = wind_model_1.getwind(particles)
elapsed_1 = time.time() - start_time
print(f"⏱️ Version 1 getwind() took {elapsed_1:.4f} seconds")

if np.isnan(wind_1).any():
    print("❌ NaNs found in wind_1!")
    nan_indices = np.argwhere(np.isnan(wind_1))
    print(f"NaN at indices:\n{nan_indices[:10]} ...")  # show only first 10
else:
    print("✅ No NaNs in wind_1!")

if np.isinf(wind_1).any():
    print("❌ Inf values found in wind_1!")

# Test RealWind version 2
wind_model_2 = RealWind2(start_date="2019-10-01", num_days=3)  # assuming you've split them
start_time = time.time()
wind_2 = wind_model_2.getwind(particles)
elapsed_2 = time.time() - start_time
print(f"⏱️ Version 2 getwind() took {elapsed_2:.4f} seconds")

if np.isnan(wind_2).any():
    print("❌ NaNs found in wind_2!")
    nan_indices = np.argwhere(np.isnan(wind_2))
    print(f"NaN at indices:\n{nan_indices[:10]} ...")  # show only first 10
else:
    print("✅ No NaNs in wind_2!")

if np.isinf(wind_2).any():
    print("❌ Inf values found in wind_2!")


NASA Time Units: minutes since 2019-10-01 01:30:00


  time_var = np.array(dataset.variables['time'][:], dtype=np.float64)  # Already in minutes!
cannot be safely cast to variable data type
  time_var = np.array(dataset.variables['time'][:], dtype=np.float64)  # Already in minutes!


NASA Time Units: minutes since 2019-10-02 01:30:00
NASA Time Units: minutes since 2019-10-03 01:30:00
Wind data precomputed for 4224 points across 3 days.
⏱️ Version 1 getwind() took 2.4771 seconds
✅ No NaNs in wind_2!


  time_var = np.array(dataset.variables['time'][:], dtype=np.float64)
cannot be safely cast to variable data type
  time_var = np.array(dataset.variables['time'][:], dtype=np.float64)


Loaded wind data for 8 timestamps.
⏱️ Version 2 getwind() took 0.0141 seconds
✅ No NaNs in wind_2!
