Function description:

1. Select the appropriate list of blobs from the GCP file system: the most recent hour's worth of data from the GOES MCMIPC bucket. This should be 12 blobs, because there is a scan every 5 minutes.
2. Download the set of blobs, pruning unneeded data such as data-quality-flag arrays and unused bands, and return a list of datasets.
3. Concatenate the 12 datasets into one, effectively creating a dataset with a time dimension
4. Take median over the time dimension, so each pixel has median value of the last hour for each band
5. Feature engineer the median dataset, adding more informative bands that are ratios of the spectral channels
6. Reproject this dataset to epsg 4326
7. Download the preprocessed landfire layers. These have been reproject-matched (`reproject_match`) to a GOES CONUS 'template' image, which has itself been reprojected to EPSG:4326. This is intended to match the slightly convoluted preprocessing routine applied to the data the PyTorch model was trained on.
8. Stack the GOES ds with the preprocessed landfire layers into a dataset.
9. Chunk the stacked dataset into PyTorch-manageable pieces and upload them to a bucket. The result is a large list of dataset files containing the stacked raster imagery, with spatial metadata that can be used to georeference the output of the PyTorch container's inference.




In [88]:
from google.cloud import storage
from datetime import datetime, timedelta
import pandas as pd
import rioxarray
import xarray as xr
import numpy as np
import fsspec
import os
import tempfile
import zipfile


In [109]:
def select_blobs(bucket_name='gcp-public-data-goes-16'):
    """
    Pick the 12 newest ABI-L2-MCMIPC blob names from a GOES bucket.

    The bucket is organised as ``ABI-L2-MCMIPC/<year>/<day-of-year>/<hour>/``, so the
    directories for the current UTC hour and the hour before it are listed and pooled.
    Names are ordered by the timestamp embedded in their fourth '_'-separated token
    (the ``s...`` field), newest first.

    Parameters:
    - bucket_name: Name of the Google Cloud Storage bucket to list.

    Returns:
    - List of the 12 blob names with the most recent timestamps, newest first.

    Raises:
    - Exception: if fewer than 12 blobs are found across the two hour directories.
    """
    now_utc = datetime.utcnow()

    gcs_client = storage.Client()
    goes_bucket = gcs_client.get_bucket(bucket_name)

    def hour_prefix(moment):
        # Directory for one hour of one day, e.g. 'ABI-L2-MCMIPC/2024/035/20/'
        day_of_year = moment.timetuple().tm_yday
        return f'ABI-L2-MCMIPC/{moment.year}/{day_of_year:03d}/{moment.hour:02d}/'

    def timestamp_token(blob_name):
        # Fourth '_'-separated token looks like 's<timestamp>'; drop the leading letter
        return blob_name.split('_')[3][1:]

    candidates = []
    for hours_back in (0, 1):  # current hour first, then the previous hour
        listing = goes_bucket.list_blobs(prefix=hour_prefix(now_utc - timedelta(hours=hours_back)))
        candidates += [entry.name for entry in listing]

    newest_first = sorted(candidates, key=timestamp_token, reverse=True)

    # A full hour of 5-minute scans is 12 files; anything less is treated as an error
    if len(newest_first) < 12:
        raise Exception(f"Only {len(newest_first)} blobs found")

    return newest_first[:12]


def create_fs():
    """
    Build an fsspec 'gcs' file system authenticated from the environment.

    The value of the GOOGLE_APPLICATION_CREDENTIALS environment variable is passed
    as the ``token`` argument; a KeyError is raised if the variable is not set.

    Returns: File system object. fs can be interacted with as though it were a local file system.
    """
    credentials = os.environ['GOOGLE_APPLICATION_CREDENTIALS']
    return fsspec.filesystem('gcs', token=credentials)


def create_median_image(blob_list, fs, bucket_name='gcp-public-data-goes-16'):
    """
    Builds a per-pixel median dataset from a list of GOES blob names.

    Each blob is opened through ``fs``, loaded fully into memory and cropped to a
    fixed test window. The crops are concatenated along a new 'time' dimension and
    the median over 'time' is taken for every variable, so the result keeps all
    variables of the inputs (it is not a single-band image).

    Parameters:
    - blob_list: Sequence of blob names inside ``bucket_name``. NOTE: as a testing
      shortcut only ``blob_list[4:]`` is used; the first four entries are skipped.
    - fs: File system object whose ``open`` accepts '<bucket>/<blob>' paths.
    - bucket_name: Bucket that the blob names belong to.

    Returns:
    - xarray.Dataset holding the median over 'time' (attributes kept via keep_attrs=True).

    Raises:
    - ValueError: if ``blob_list[4:]`` is empty, i.e. there is nothing to concatenate.
    """
    # Testing shortcut: skip the first four blobs. Use the whole list for a full-hour median.
    blobs_to_use = blob_list[4:]
    if len(blobs_to_use) == 0:
        raise ValueError("blob_list must contain more than 4 blobs; nothing to build a median from")

    # Initialize a list to store the Datasets
    datasets = []

    for blob in blobs_to_use:
        path = f'{bucket_name}/{blob}'
        # Context managers release the remote file handle and the xarray backend as soon
        # as the data is in memory (the dataset is closed first, then the file).
        with fs.open(path) as remote_file:
            print(f'Opening: {path}')
            with xr.open_dataset(remote_file) as opened:
                ds = opened.load()

                # Select a firey region for testing - large NW US region
                ds = ds.isel(x=slice(0, 1250), y=slice(0, 1250))

        datasets.append(ds)

    # Concatenate the datasets along a new 'time' dimension
    concated = xr.concat(datasets, dim='time')

    # Compute the median along the 'time' dimension
    median_ds = concated.median(dim='time', keep_attrs=True)

    # Return the median dataset
    return median_ds

def download_landfire_layers(fs, bucket_name='firenet_reference', blob_name='combined_landfire.nc'):
    """
    Downloads the preprocessed landfire layers and returns them as a rioxarray-opened object.

    The layers have been reproject_matched to a GOES CONUS 'template' image.
    NOTE(review): this docstring previously said the template is in EPSG:5070 while the
    notebook's step list says EPSG:4326 - confirm which one is correct.

    Properly loading and accessing the spatial metadata uses a trick: open with xarray,
    save to a temporary .nc file, then open that file with rioxarray. This is not a "good"
    approach, but the spatial metadata could not be accessed otherwise; opening the bucket
    object directly with rioxarray runs into an interfacing problem.

    Parameters:
    - fs: File system object whose ``open`` accepts '<bucket>/<blob>' paths.
    - bucket_name: Bucket holding the landfire file.
    - blob_name: Name of the landfire netCDF blob.

    Returns: Preprocessed landfire layers, fully loaded into memory.
    """
    path = f'{bucket_name}/{blob_name}'

    # Open the blob as a full dataset and load it into memory; the context managers
    # close the xarray backend and then the remote file handle once loading is done.
    with fs.open(path) as f:
        print(f'Opening: {path}')
        with xr.open_dataset(f) as opened:
            ds = opened.load()

    # Create a temporary file
    with tempfile.NamedTemporaryFile(suffix='.nc') as tmpfile:
        # Save the dataset to the temporary file
        ds.to_netcdf(tmpfile.name)

        # Open the temporary file with rioxarray. The temporary file is deleted when this
        # block exits, so the data is loaded here rather than left to be read lazily later.
        # Assumes open_rasterio returns a Dataset/DataArray (not a list), which the
        # downstream reproject_match / merge steps also require.
        landfire_layers = rioxarray.open_rasterio(tmpfile.name).load()

    return landfire_layers

def reproject_dataset(dataset, landfire_layers):
    """
    Reproject ``dataset`` onto the grid of the static landfire layers.

    The dataset is first round-tripped through a temporary netCDF file and re-opened
    with rioxarray. This is the same workaround used when loading the landfire layers:
    not a "good" approach, but the spatial metadata could not be accessed otherwise.

    Parameters:
    - dataset: Dataset to reproject; must be writable with ``to_netcdf``.
    - landfire_layers: Template with landfire layers, generated by
      `reprojectmatch_and_stack_landfire_for_bucket.ipynb`.

    Returns: Reprojected dataset.
    """
    scratch = tempfile.NamedTemporaryFile(suffix='.nc')
    with scratch:
        scratch_path = scratch.name

        # Round-trip through disk so rioxarray can pick up the spatial metadata
        dataset.to_netcdf(scratch_path)
        georeferenced = rioxarray.open_rasterio(scratch_path)

        # Match the landfire template while the scratch file still exists
        return georeferenced.rio.reproject_match(landfire_layers)

def engineer_features(dataset):
    """
    Add band-ratio features to the dataset.

    Four new variables are appended, each the ratio of two existing CMI channels:
    feat_6_5, feat_7_5, feat_7_6 and feat_14_7 (numerator channel first). The CRS read
    from the input is written back onto the result and onto each of its data variables.

    Parameters:
    - dataset: Dataset with a rioxarray CRS and the variables CMI_C05, CMI_C06,
      CMI_C07 and CMI_C14.

    Returns: Feature engineered dataset.
    """
    # Remember the CRS up front so it can be re-applied after the new variables are added
    source_crs = dataset.rio.crs

    # (new variable name, numerator channel, denominator channel)
    ratio_specs = (
        ('feat_6_5', 'CMI_C06', 'CMI_C05'),
        ('feat_7_5', 'CMI_C07', 'CMI_C05'),
        ('feat_7_6', 'CMI_C07', 'CMI_C06'),
        ('feat_14_7', 'CMI_C14', 'CMI_C07'),
    )
    ratios = {
        feature: dataset[numerator] / dataset[denominator]
        for feature, numerator, denominator in ratio_specs
    }

    result = dataset.assign(ratios)

    # Dataset-level CRS first, then every variable so all variable attrs agree
    result.rio.write_crs(source_crs, inplace=True)
    for name in result.data_vars:
        result[name].rio.write_crs(source_crs, inplace=True)

    return result

def stack_datasets(goes_ds, landfire_layers):
    """
    Merge the GOES dataset and the preprocessed landfire layers into one dataset.

    After merging, every data variable gets ``encoding['grid_mapping'] = 'spatial_ref'``,
    and the 'goes_imager_projection' variable is removed if it is present.

    Parameters:
    - goes_ds: GOES dataset.
    - landfire_layers: Preprocessed landfire layers.

    Returns: Stacked dataset.
    """
    merged = xr.merge([goes_ds, landfire_layers])

    # Point every variable at the shared 'spatial_ref' grid mapping
    for name in merged.data_vars:
        merged[name].encoding.update(grid_mapping='spatial_ref')

    # Drop the 'goes_imager_projection' variable when the merge carried it over
    if 'goes_imager_projection' in merged:
        del merged['goes_imager_projection']

    return merged


# def chunk_image(multiband_image, chunk_size=64):
#     """
#     Splits the multiband image into chunks.
#     Args:
#         multiband_image (xarray.DataArray): The multiband image to be chunked.
#         chunk_size (int): The size of the chunks. Default is 64.
#     Returns:
#         chunks (list): A list of xarray Datasets representing the chunks.
#         spatial_info (list): A list of tuples representing the spatial information of each chunk.
#     """
#     # Get the width and height of the image
#     width = multiband_image.dims['x']
#     height = multiband_image.dims['y']

#     # Calculate the number of chunks in x and y direction
#     nx, ny = width // chunk_size, height // chunk_size

#     # Initialize a list to store the chunks
#     chunks = []

#     # Loop over the image
#     for i in range(ny):
#         for j in range(nx):
#             # Define the slice
#             y_slice = slice(i * chunk_size, (i + 1) * chunk_size)
#             x_slice = slice(j * chunk_size, (j + 1) * chunk_size)

#             # Extract the chunk across all bands
#             chunk = multiband_image.isel(y=y_slice, x=x_slice)

#             # Store the chunk and its spatial information
#             chunks.append(chunk)
#     return chunks

def create_spatial_template(dataset):
    """
    Creates a spatial template from the original dataset by keeping only one data variable
    and setting its values to NaN, while preserving spatial metadata. This step allows us to
    take the data variables out as a NumPy array (e.g. 21x64x64), then add the model output
    back (e.g. 1x64x64) so that the spatial projection of the output is untouched.

    Parameters:
    - dataset: xarray.Dataset or rioxarray object with spatial dimensions and CRS information.
      Must contain at least one data variable (an IndexError is raised otherwise).

    Returns:
    - template: xarray.Dataset with a single data variable filled with NaNs and original
      spatial metadata. The variable keeps its dtype when that dtype is floating; any other
      dtype is promoted to float64, because NaN cannot be stored in it.
    """
    # Clone the dataset to avoid modifying the original
    template = dataset.copy()

    # Select the first data variable (assuming there's at least one)
    first_var_name = list(template.data_vars)[0]
    first_var = template[first_var_name]

    # NaN only exists in floating dtypes; promote anything else so the fill value is valid
    if np.issubdtype(first_var.dtype, np.floating):
        fill_dtype = first_var.dtype
    else:
        fill_dtype = np.float64

    # Create a NaN-filled template of the first variable
    nan_template = xr.full_like(first_var, fill_value=np.nan, dtype=fill_dtype)

    # Remove all data variables from the template
    for var_name in list(template.data_vars):
        del template[var_name]

    # Add the NaN-filled template variable back
    template[first_var_name] = nan_template

    # Ensure the spatial metadata is preserved
    # Note: This step might be redundant if the metadata is already attached to the coordinates
    # and not the data variables themselves. However, it's a safeguard for maintaining CRS.
    if hasattr(dataset, 'rio') and hasattr(dataset.rio, 'crs'):
        template.rio.write_crs(dataset.rio.crs, inplace=True)

    return template

def extract_data_as_array(dataset):
    """
    Pull every data variable's values out of the dataset as a single NumPy array.

    The per-variable arrays are stacked along a new leading axis (one entry per data
    variable, in ``dataset.data_vars`` order) and all length-1 axes are then squeezed
    away, which removes the singleton placeholder dimension. For the stacked dataset in
    this notebook the result has shape (42, 3506, 2266); the last two dimensions may
    change if the region of interest changes.
    """
    per_variable = [dataset[name].values for name in dataset.data_vars]
    return np.stack(per_variable, axis=0).squeeze()

def chunk_ndarray(arr, chunk_size=64):
    """
    Tile an array into square chunks over its second and third dimensions,
    keeping the first dimension whole in every chunk.

    Only complete tiles are produced: any remainder at the bottom or right edge that is
    narrower than `chunk_size` is left out. Tiles are ordered row by row (top to
    bottom), and left to right within a row.

    Parameters:
    - arr: Input NumPy array with shape (Variables, Height, Width).
    - chunk_size: Edge length of each tile along Height and Width.

    Returns:
    - A list of chunks, each with shape (Variables, chunk_size, chunk_size).
    """
    height, width = arr.shape[1], arr.shape[2]

    # Start offsets of the tiles that fit completely inside the array
    row_starts = range(0, height - chunk_size + 1, chunk_size)
    col_starts = range(0, width - chunk_size + 1, chunk_size)

    return [
        arr[:, top:top + chunk_size, left:left + chunk_size]
        for top in row_starts
        for left in col_starts
    ]

def process_chunks(chunks):
    """
    Placeholder for model inference: collapse each chunk over its variable axis.

    The chunks are treated as one batch of shape (N, Variables, H, W) and summed over
    axis 1, giving one (H, W) plane per chunk. Once the neural net is wired in here it
    should return the same kind of per-chunk 2D output.

    Parameters:
    - chunks: Sequence of equally shaped arrays, each (Variables, H, W).

    Returns:
    - NumPy array of shape (N, H, W).
    """
    batch = np.asarray(chunks)
    return batch.sum(axis=1)

def stitch_chunks(processed_chunks, original_shape):
    """
    Reassemble square 2D chunks into a single 2D array of the original shape.

    Chunks are placed row by row, left to right, with ``original_shape[1] // chunk_size``
    chunks per row; the chunk size is taken from the first chunk. Any bottom/right
    remainder of the canvas that no chunk covers is left at 0.

    Parameters:
    - processed_chunks: Sequence (list or array) of 2D arrays, each (chunk_size, chunk_size).
    - original_shape: (Height, Width) of the 2D plane of the original array.

    Returns:
    - float64 NumPy array of shape ``original_shape``. If ``processed_chunks`` is empty
      the all-zero canvas is returned.
    """
    stitched = np.zeros(original_shape)

    # Nothing to place (e.g. the raster was smaller than one chunk)
    if len(processed_chunks) == 0:
        return stitched

    chunk_size = processed_chunks[0].shape[0]
    # Only whole chunks were produced, so this is the number of chunks per row
    chunks_per_row = original_shape[1] // chunk_size

    for k, chunk in enumerate(processed_chunks):
        i, j = divmod(k, chunks_per_row)
        stitched[i*chunk_size:(i+1)*chunk_size, j*chunk_size:(j+1)*chunk_size] = chunk
    return stitched

In [3]:
# List the 12 most recent MCMIPC blob names from the default GOES bucket.
selected_blobs = select_blobs()


In [4]:

# Create the GCS file system, then reduce the selected blobs to a per-pixel median dataset.
fs = create_fs()
median_ds = create_median_image(selected_blobs, fs)

Opening: gcp-public-data-goes-16/ABI-L2-MCMIPC/2024/035/20/OR_ABI-L2-MCMIPC-M6_G16_s20240352036176_e20240352038555_c20240352039070.nc
Opening: gcp-public-data-goes-16/ABI-L2-MCMIPC/2024/035/20/OR_ABI-L2-MCMIPC-M6_G16_s20240352031176_e20240352033555_c20240352034074.nc
Opening: gcp-public-data-goes-16/ABI-L2-MCMIPC/2024/035/20/OR_ABI-L2-MCMIPC-M6_G16_s20240352026176_e20240352028549_c20240352029070.nc
Opening: gcp-public-data-goes-16/ABI-L2-MCMIPC/2024/035/20/OR_ABI-L2-MCMIPC-M6_G16_s20240352021176_e20240352023555_c20240352024065.nc
Opening: gcp-public-data-goes-16/ABI-L2-MCMIPC/2024/035/20/OR_ABI-L2-MCMIPC-M6_G16_s20240352016176_e20240352018555_c20240352019068.nc
Opening: gcp-public-data-goes-16/ABI-L2-MCMIPC/2024/035/20/OR_ABI-L2-MCMIPC-M6_G16_s20240352011176_e20240352013549_c20240352014072.nc
Opening: gcp-public-data-goes-16/ABI-L2-MCMIPC/2024/035/20/OR_ABI-L2-MCMIPC-M6_G16_s20240352006176_e20240352008556_c20240352009073.nc
Opening: gcp-public-data-goes-16/ABI-L2-MCMIPC/2024/035/20/OR_

In [98]:
# Load the landfire layers, reproject the median dataset onto their grid, then add the
# band-ratio features.
# NOTE: feature engineering runs after reprojection here, whereas the step list at the top
# of the notebook lists feature engineering (step 5) before reprojection (step 6).
landfire_layers = download_landfire_layers(fs)
reprojected_median_ds = reproject_dataset(median_ds, landfire_layers)
reprojected_median_ds = engineer_features(reprojected_median_ds)



Opening: firenet_reference/combined_landfire.nc


  dataset.to_netcdf(tmpfile.name)
  dataset.to_netcdf(tmpfile.name)


In [99]:
# Merge the feature-engineered GOES dataset with the landfire layers.
stacked_ds = stack_datasets(reprojected_median_ds, landfire_layers)

In [100]:
# Display the stacked dataset for inspection.
stacked_ds

In [108]:
# Build a NaN-filled spatial template from the stacked dataset, then pull all of its
# data variables out as a single NumPy array and check the array's shape.
template = create_spatial_template(stacked_ds)
npy_array = extract_data_as_array(stacked_ds)
npy_array.shape

(42, 3506, 2266)

In [111]:

# Split the array into full 64x64 tiles (default chunk size), then run the placeholder
# per-chunk processing that collapses the variable axis.
chunks = chunk_ndarray(npy_array)
processed_chunks = process_chunks(chunks)

In [116]:

# Reassemble the processed tiles into one 2D array with the original height and width.
stitched = stitch_chunks(processed_chunks, npy_array.shape[-2:])
stitched.shape

(3506, 2266)

In [117]:

# Write the stitched output into the template's variable (named after the first data
# variable of stacked_ds), adding a leading length-1 axis to the 2D array.
template[list(stacked_ds.data_vars)[0]].values = np.expand_dims(stitched, axis=0)

In [115]:
# Save the template to a netCDF file in the working directory.
# NOTE(review): this cell's execution count (115) is lower than that of the cell above
# (117), so it may have last run before the stitched values were written - re-run in
# order to confirm the file contains them.
template.to_netcdf("processed_dataset.nc")
