In [None]:
import xarray as xr
import numpy as np

# Load .nc files: the three prediction files are opened in order and bound
# to ds4, ds5 and ds6 respectively.
ds4, ds5, ds6 = map(
    xr.open_dataset,
    (
        "predictions_m4_g4_l32.nc",
        "predictions_m5_g4_l32.nc",
        "predictions_m6_g4_l32.nc",
    ),
)

# Display all three datasets as the cell output
ds4, ds5, ds6

In [None]:
def process_nc(ds):
  """Reduce a dataset to one time entry and one level.

  Selects index 0 along ``time``, then the last index along ``level``, and
  finally drops the ``time`` and ``level`` coords from the result.

  Args:
    ds: Dataset-like object exposing ``isel`` and ``drop_vars``.

  Returns:
    The reduced dataset without the ``time`` and ``level`` coords.
  """
  # First time entry, then the last level entry. The last level is meant to be
  # level 1000 -- presumably levels are stored in ascending order; verify
  # against the input files.
  single_slice = ds.isel(time=0).isel(level=-1)

  # The selected time/level are no longer needed as coords
  return single_slice.drop_vars(["time", "level"])

In [None]:
# Apply the same time/level reduction to each loaded dataset, in order.
ds4_proc, ds5_proc, ds6_proc = map(process_nc, (ds4, ds5, ds6))

# Display the processed datasets as the cell output
ds4_proc, ds5_proc, ds6_proc

## Generate data.npz (target) from higher mesh

In [None]:
def gen_data_npz(ds, target_vars=None, out_path="data.npz"):
  """Flatten the target variables of `ds` into one matrix and save it as NPZ.

  Every variable is read through ``ds[var].values`` and reshaped so that its
  last axis enumerates the flattened trailing dimensions. The resulting rows
  are stacked and transposed, so the saved array has one row per flattened
  grid node and one column per stacked row (one column per variable when each
  variable contributes a single 2D field).

  Args:
    ds: Mapping-like dataset where ``ds[name].values`` is a 2D, 3D or 4D
      numpy array. All variables must flatten to the same number of nodes.
    target_vars: Optional iterable of variable names to export. Defaults to
      the built-in list of 11 variables below.
    out_path: Destination file, saved with the key ``'target'``. Defaults to
      ``"data.npz"`` in the current working directory.

  Raises:
    ValueError: If `target_vars` is empty, or a variable is not 2D, 3D or 4D.
  """
  if target_vars is None:
    # Default target variables; the commented-out names are deliberately
    # left out of the export.
    target_vars = [
        #'geopotential_at_surface',
        #'land_sea_mask',
        '2m_temperature',
        'mean_sea_level_pressure',
        '10m_v_component_of_wind',
        '10m_u_component_of_wind',
        'total_precipitation_6hr',
        #'toa_incident_solar_radiation',
        'temperature',
        'geopotential',
        'u_component_of_wind',
        'v_component_of_wind',
        'vertical_velocity',
        'specific_humidity'
    ]
  target_vars = list(target_vars)
  if not target_vars:
    # np.vstack would fail with a less helpful message on an empty list
    raise ValueError("target_vars must contain at least one variable name")

  # One entry per variable; each entry has the flattened nodes on its last axis
  flattened_vars = []

  for var in target_vars:
    data = ds[var].values  # Extract the variable's data as a numpy array

    data_shape = data.shape
    print(f"Shape of {var}: {data_shape}")  # For debugging purposes

    if data.ndim == 2:  # (lat, lon) format
      flattened_data = data.reshape(-1)  # 1D; vstack treats it as one row
    elif data.ndim == 3:  # (batch, lat, lon) format
      flattened_data = data.reshape(data_shape[0], -1)  # Keep batch as rows
    elif data.ndim == 4:  # (batch, level, lat, lon) format
      # Merge batch and level into the row axis
      flattened_data = data.reshape(data_shape[0] * data_shape[1], -1)
    else:
      raise ValueError(f"Unexpected shape for variable {var}: {data_shape}")

    flattened_vars.append(flattened_data)

  # Rows are variables after vstack; transpose so that rows are nodes
  target_signal = np.vstack(flattened_vars).T  # Shape: (n_nodes, n_targets)

  # Display the shape of the target_signal for verification
  print(f"\nShape of target_signal: {target_signal.shape}")  # Expected: (n_nodes, n_targets)

  # Save the target_signal to an NPZ file with the key 'target'
  np.savez(out_path, target=target_signal)
  # Alternatively, to save space, you can use savez_compressed:
  # np.savez_compressed(out_path, target=target_signal)

  print(f"\nSaved target_signal to {out_path} with key 'target'\n")


## Generate points.npy from lower mesh

In [None]:
def gen_points_npy(ds):
  """Save the (lat, lon) coordinates of every grid node to ``points.npy``.

  The coordinate pairs are enumerated with latitude as the slow axis and
  longitude as the fast axis, giving an array of shape (n_nodes, 2) whose
  columns are latitude and longitude.

  Args:
    ds: Mapping-like dataset where ``ds['lat'].values`` and
      ``ds['lon'].values`` are 1D numpy arrays.
  """
  latitudes = ds['lat'].values
  longitudes = ds['lon'].values

  # 'ij' indexing with (lat, lon) yields grids of shape (n_lat, n_lon):
  # lat_grid varies along rows, lon_grid along columns.
  lat_grid, lon_grid = np.meshgrid(latitudes, longitudes, indexing='ij')

  # Row-major flattening, then pair up as columns -> (n_nodes, 2)
  coord_rows = (lat_grid.ravel(), lon_grid.ravel())
  points = np.vstack(coord_rows).T

  print(f"Shape of points: {points.shape}")  # Should print (n_nodes, 2)

  np.save("points.npy", points)
  print("\nSaved points to points.npy\n")

## Generate fourier.npy from lower mesh

In [None]:
def gen_fourier_npy(ds, target_vars=None, out_path="fourier.npy"):
  """Save FFT-magnitude embeddings of the target variables as a .npy file.

  Every variable is read through ``ds[var].values``, flattened to
  (batch, n_nodes), transformed with a 1-D FFT along the flattened node axis,
  and stored as the magnitude of the result with shape (n_nodes, batch).
  The per-variable blocks are concatenated column-wise, so the saved array
  has n_nodes rows and one column per variable and batch entry.

  Args:
    ds: Mapping-like dataset where ``ds[name].values`` is a 2D (lat, lon) or
      3D (batch, lat, lon) numpy array. All variables must flatten to the
      same number of nodes.
    target_vars: Optional iterable of variable names to embed. Defaults to
      the built-in list of 11 variables below.
    out_path: Destination file. Defaults to ``"fourier.npy"`` in the current
      working directory.

  Raises:
    ValueError: If `target_vars` is empty, or a variable is not 2D or 3D.
  """
  if target_vars is None:
    # Default target variables; the commented-out names are deliberately
    # left out of the embedding.
    target_vars = [
        #'geopotential_at_surface',
        #'land_sea_mask',
        '2m_temperature',
        'mean_sea_level_pressure',
        '10m_v_component_of_wind',
        '10m_u_component_of_wind',
        'total_precipitation_6hr',
        #'toa_incident_solar_radiation',
        'temperature',
        'geopotential',
        'u_component_of_wind',
        'v_component_of_wind',
        'vertical_velocity',
        'specific_humidity'
    ]
  target_vars = list(target_vars)
  if not target_vars:
    # np.hstack would fail with a less helpful message on an empty list
    raise ValueError("target_vars must contain at least one variable name")

  # One (n_nodes, batch) magnitude block per variable
  fourier_embeddings = []

  for var in target_vars:
    data = ds[var].values  # Get the variable's data as a numpy array
    data_shape = data.shape
    print(f"Shape of {var}: {data_shape}")  # To check if the shape is as expected

    if data.ndim == 2:  # (lat, lon) format
      signals = data.reshape(1, -1)  # (1, n_nodes)
    elif data.ndim == 3:  # (batch, lat, lon) format, including batch == 1
      signals = data.reshape(data_shape[0], -1)  # (batch, n_nodes)
    else:
      raise ValueError(f"Unexpected shape for variable {var}: {data_shape}")

    # The transform has to run along the node axis (axis 1 here). Taking the
    # FFT of an (n_nodes, 1) column along axis 1 would operate on a length-1
    # axis and return the input unchanged.
    fourier_transformed = np.fft.fft(signals, axis=1)

    # Keep only the magnitude and put nodes on the rows: (n_nodes, batch)
    magnitude = np.abs(fourier_transformed).T

    # Number of embedding columns contributed by this variable
    n_fourier = magnitude.shape[1]
    print(f"Number of Fourier components: {n_fourier}")

    fourier_embeddings.append(magnitude)

  # Every block has n_nodes rows, so they can be concatenated column-wise
  fourier_embeddings = np.hstack(fourier_embeddings)  # (n_nodes, n_fourier)

  print(f"\nShape of Fourier embeddings: {fourier_embeddings.shape}")

  np.save(out_path, fourier_embeddings)
  print(f"\nSaved spectral embeddings to {out_path}\n")


## Generate files

In [None]:
_M4 = True
_M5 = False
_M6 = False

# Walk the flags in priority order (M4, M5, M6). Only the first enabled entry
# is processed, which matches an if/elif chain; if no flag is set, nothing is
# generated.
for _enabled, _ds_proc in ((_M4, ds4_proc), (_M5, ds5_proc), (_M6, ds6_proc)):
  if not _enabled:
    continue
  gen_data_npz(_ds_proc)
  gen_points_npy(_ds_proc)
  gen_fourier_npy(_ds_proc)
  break