
# Inference CONUS

In this Colab, we explore the GenFocal debiased and downscaled forecasts for the continental United States (CONUS) for the months of June, July, and August of 2010-2019. These datasets
contain the following variables:
 - 10mW, wind speed at 10 meters (m/s)
 - 2mT, temperature at 2 meters (K)
 - MSL, mean sea-level pressure (Pa)
 - Q1000, near-surface specific humidity (kg/kg)

In [None]:
# @title PIP Installs
!pip install -q zarr xarray[complete] fsspec aiohttp requests gcsfs cartopy \
  cfgrib eccodes cf_xarray pint_xarray


In [None]:
# @title Imports
import h5py
import gcsfs
import matplotlib.pyplot as plt
from google.colab import auth
from google.cloud import storage
from datetime import datetime
import pandas as pd
import numpy as np
from cartopy import config
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import xarray as xr
import cf_xarray.units
import pint_xarray


In [None]:
# @title Plotting Functions

def plot_scalars(temp_data1, temp_data2, lat_min, lat_max, lon_min, lon_max):
    """
    Plots two scalar arrays on the same plot with a shared colorbar.

    Args:
        temp_data1: xarray DataArray of the first temperature data.
        temp_data2: xarray DataArray of the second temperature data.
        lat_min: Minimum latitude for the plot.
        lat_max: Maximum latitude for the plot.
        lon_min: Minimum longitude for the plot.
        lon_max: Maximum longitude for the plot.
    """

    fig, axs = plt.subplots(nrows=1, ncols=2,
                            subplot_kw={'projection': ccrs.PlateCarree()},
                            figsize=(12, 6))

    # Combine data for shared colorbar limits
    vmin = min(temp_data1.min(), temp_data2.min())
    vmax = max(temp_data1.max(), temp_data2.max())

    # Plot the first temperature data
    im1 = temp_data1.plot(
        ax=axs[0], transform=ccrs.PlateCarree(), add_colorbar=False,
           x='longitude', y='latitude',
           vmin=vmin, vmax=vmax,
           cmap='viridis'
    )

    # Overlay the second temperature data
    im2 = temp_data2.plot(
        ax=axs[1], transform=ccrs.PlateCarree(), add_colorbar=False,
           x='longitude', y='latitude',
           vmin=vmin, vmax=vmax,
           cmap='viridis'
    )


    # Add coastlines and gridlines
    for ax in axs:
      ax.coastlines()
      ax.add_feature(cfeature.BORDERS)
      ax.add_feature(cfeature.RIVERS)
      # Set plot extent
      ax.set_extent([lon_min, lon_max, lat_min, lat_max], crs=ccrs.PlateCarree())
      cbar = plt.colorbar(im1, ax=ax, shrink=0.7)  # Use im1 for the colorbar
      cbar.set_label('Temperature (K)')

    plt.show()

def plot_scalar(temp_data1, lat_min, lat_max, lon_min, lon_max, title_string="",
                cbar_label='Temperature (K)'):
    """
    Plots a single scalar field on a PlateCarree map with its own colorbar.

    Color limits span the field's own minimum and maximum.

    Args:
        temp_data1: xarray DataArray to plot; must have 'longitude' and
            'latitude' coordinates.
        lat_min: Minimum latitude for the plot.
        lat_max: Maximum latitude for the plot.
        lon_min: Minimum longitude for the plot.
        lon_max: Maximum longitude for the plot.
        title_string: Title placed above the map. The default empty string
            leaves the map untitled.
        cbar_label: Label for the colorbar. Defaults to 'Temperature (K)'.
    """
    fig, ax = plt.subplots(nrows=1, ncols=1,
                           subplot_kw={'projection': ccrs.PlateCarree()},
                           figsize=(12, 6))

    # float() turns the 0-d DataArrays from .min()/.max() into plain numbers.
    vmin = float(temp_data1.min())
    vmax = float(temp_data1.max())

    im1 = temp_data1.plot(
        ax=ax, transform=ccrs.PlateCarree(), add_colorbar=False,
        x='longitude', y='latitude',
        vmin=vmin, vmax=vmax,
        cmap='viridis'
    )

    # Add coastlines, country borders and state boundaries.
    ax.coastlines()
    ax.add_feature(cfeature.BORDERS)
    ax.add_feature(cfeature.STATES)
    ax.set_extent([lon_min, lon_max, lat_min, lat_max], crs=ccrs.PlateCarree())

    cbar = fig.colorbar(im1, ax=ax, shrink=0.7)
    cbar.set_label(cbar_label)
    # Set the title on the map axes explicitly rather than via plt's
    # "current axes" state.
    ax.set_title(f"{title_string}")

    plt.show()


We now need to authenticate with Google Cloud so that we can access the GenFocal bucket.

In [None]:
# Authenticate this Colab session with Google Cloud so the gs://genfocal
# Zarr stores opened in the cells below can be read.
auth.authenticate_user()

## Chunked by pixels
First, we will look at the dataset chunked by pixel. This chunking strategy
collects all of the forecast timesteps and members for a given model grid location into one file, which makes it easy to visualize long time series.

In [None]:
# Pixel-chunked store: each chunk holds every forecast timestep and member
# for one grid location, so reading a single-point time series is cheap.
pixel_chunks_url = "gs://genfocal/staging/inference/conus/debiased_100members_jja10s_8samples_xm153999662_pixel_chunks.zarr"
inference_conus = xr.open_zarr(pixel_chunks_url, consolidated=True)
inference_conus

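## Chunked by member
The dataset is also available chunked by member. This chunking strategy
collects all of the grid points for a specific timestep and member
combination into one file, which makes it easy to plot maps of the individual forecasts; we use it in the final section of this notebook.
First, we extract a time series at a single point from the pixel-chunked dataset opened above.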

In [None]:
# @markdown Latitude of point (-90 -> 90)
latitude_pt = 32.7079 # @param {type:"number"}

# @markdown Longitude of point (-180 -> 180)
longitude_pt = -96.9209 # @param {type:"number"}

# @markdown First Day To Plot
first_day = '2015-06-01' # @param {type:"date"}

# @markdown Last Day To Plot
last_day = '2015-08-31' # @param {type:"date"}


print(f"Latitude: {latitude_pt}")
print(f"Longitude: {longitude_pt}")
time_slice = slice(first_day, last_day)

# The point is looked up in 0 -> 360 longitude form (assumes the dataset's
# longitude coordinate uses that convention -- TODO confirm). `% 360` maps any
# -180 -> 180 input into [0, 360); a plain +360 would push non-negative
# inputs past 360.
longitude_pt_360 = longitude_pt % 360

# NOTE(review): method="nearest" has no tolerance, so a point outside the
# domain silently snaps to the nearest edge grid cell.
inference_conus_pt = inference_conus.sel(time=time_slice).sel(
    latitude=latitude_pt, longitude=longitude_pt_360, method="nearest")

First, we plot the ensemble mean and standard deviation of the forecast 2 m temperature at the selected point.



In [None]:
plt.figure(figsize=(12, 6))

# Only 2 m temperature is plotted, so reduce that variable alone rather than
# every variable in the dataset. compute() evaluates each statistic once so
# the mean line and the shaded band reuse the same loaded values.
# inference_conus_pt is already restricted to time_slice, so no re-slicing.
temp_2m_pt = inference_conus_pt['2mT']
mean_hourly_temp = temp_2m_pt.mean(dim='member').compute()
std_hourly_temp = temp_2m_pt.std(dim='member').compute()

mean_hourly_temp.plot(label='Mean')
# Shade one standard deviation either side of the ensemble mean.
plt.fill_between(mean_hourly_temp.time.values,
                 (mean_hourly_temp - std_hourly_temp).values,
                 (mean_hourly_temp + std_hourly_temp).values,
                 color='blue', alpha=0.2, label='Std Deviation')

plt.title(f'2m Temperature at Lat: {latitude_pt}, Lon: {longitude_pt}')
plt.xlabel('Time')
plt.ylabel('2m Temperature (K)')
plt.legend()
plt.grid(True)
plt.show()

The diurnal cycle makes it hard to see the spread (standard deviation) of the temperature forecasts, so we reduce the data to daily maximum and minimum temperatures.

In [None]:
# Group the point time series into calendar days of its time coordinate, then
# take the per-day minimum and maximum (separately for every member).
daily_groups = inference_conus_pt.resample(time='1D')
inference_conus_pt_dailymin = daily_groups.min()
inference_conus_pt_dailymax = daily_groups.max()


In [None]:
# @title Daily Maximum Temperature
# Ensemble statistics (mean, std, 5th and 95th percentiles across members) of
# the daily maximum 2 m temperature. Only '2mT' is reduced, and each statistic
# is computed once and then reused by the plotting calls below.
daily_max_2mt = inference_conus_pt_dailymax['2mT']
mean_daily_max = daily_max_2mt.mean(dim='member').compute()
std_daily_max = daily_max_2mt.std(dim='member').compute()
percentile_5_daily_max = daily_max_2mt.quantile(0.05, dim='member').compute()
percentile_95_daily_max = daily_max_2mt.quantile(0.95, dim='member').compute()

# Plotting
plt.figure(figsize=(12, 6))
mean_daily_max.plot(label='Mean Daily Maximum 2m Temperature', color='blue')

# Shade one standard deviation either side of the ensemble mean
plt.fill_between(mean_daily_max.time.values,
                 (mean_daily_max - std_daily_max).values,
                 (mean_daily_max + std_daily_max).values,
                 color='blue', alpha=0.2, label='Std Deviation')

# Plot the 5th and 95th percentiles
percentile_5_daily_max.plot(label='5th Percentile', color='red', linestyle='--')
percentile_95_daily_max.plot(label='95th Percentile', color='green', linestyle='--')

plt.title(f'Daily Maximum 2m Temperature at Lat: {latitude_pt}, Lon: {longitude_pt}')
plt.xlabel('Time')
plt.ylabel('2m Temperature (K)')
plt.legend()
plt.grid(True)
plt.show()

In [None]:
# @title Daily Minimum Temperature Forecast
# Ensemble statistics (mean, std, 5th and 95th percentiles across members) of
# the daily minimum 2 m temperature. Only '2mT' is reduced, and each statistic
# is computed once and then reused by the plotting calls below.
daily_min_2mt = inference_conus_pt_dailymin['2mT']
mean_daily_min = daily_min_2mt.mean(dim='member').compute()
std_daily_min = daily_min_2mt.std(dim='member').compute()
percentile_5_daily_min = daily_min_2mt.quantile(0.05, dim='member').compute()
percentile_95_daily_min = daily_min_2mt.quantile(0.95, dim='member').compute()

# Plotting
plt.figure(figsize=(12, 6))
mean_daily_min.plot(label='Mean Daily Minimum 2m Temperature', color='blue')

# Shade one standard deviation either side of the ensemble mean
plt.fill_between(mean_daily_min.time.values,
                 (mean_daily_min - std_daily_min).values,
                 (mean_daily_min + std_daily_min).values,
                 color='blue', alpha=0.2, label='Std Deviation')

# Plot the 5th and 95th percentiles
percentile_5_daily_min.plot(label='5th Percentile', color='red', linestyle='--')
percentile_95_daily_min.plot(label='95th Percentile', color='green', linestyle='--')

plt.title(f'Daily Minimum 2m Temperature at Lat: {latitude_pt}, Lon: {longitude_pt}')
plt.xlabel('Time')
plt.ylabel('2m Temperature (K)')
plt.legend()
plt.grid(True)
plt.show()

## Inference CONUS chunked by member
Now we will work with the CONUS inference dataset chunked by member to plot maps of the daily maximum temperature.

In [None]:
# Member-chunked store: each chunk holds all grid points for one
# timestep/member combination, which suits drawing whole-domain maps.
# NOTE(review): this store is under data/ with ID xm153985229, while the
# pixel-chunked store is under staging/ with ID xm153999662 -- confirm they
# contain the same forecasts before comparing results between them.
member_chunks_url = "gs://genfocal/data/inference/conus/debiased_100members_jja10s_8samples_xm153985229_member_chunks.zarr"
inference_conus_member = xr.open_zarr(member_chunks_url, consolidated=True)
inference_conus_member

In [None]:
# @title Ensemble mean of daily maximum 2 meter temperature
surface_variable_name = "2mT"
date = "2015-08-01" # @param {type:"date"}
# A day-resolution partial date string selects every timestep on that
# calendar day for both slice bounds, without relying on the non-ISO
# "YYYY-MM-DD THH" form.
time_slice = slice(date, date)

scalar_array_daily = inference_conus_member[surface_variable_name].sel(time=time_slice).squeeze()
# NOTE(review): squeeze() drops every length-1 dimension; if the selected day
# held only one timestep, 'time' would be dropped and .max(dim='time') below
# would fail -- confirm the data has several timesteps per day.
conus_lat = inference_conus_member.latitude
conus_lon = inference_conus_member.longitude
title_string = f"Ensemble Mean Daily Maximum Temperature on {date}"
# Daily maximum for each member first, then the mean across the ensemble.
scalar_array_dailymx_mean = scalar_array_daily.max(dim='time').mean(dim='member').compute()
# float() converts the 0-d coordinate reductions into plain numbers for the
# map extent.
plot_scalar(scalar_array_dailymx_mean, float(conus_lat.min()), float(conus_lat.max()),
            float(conus_lon.min()), float(conus_lon.max()), title_string)


In [None]:
# @title Ensemble standard deviation of daily maximum 2 meter temperature
# Spread across members of each member's daily maximum, for the day selected
# in the previous cell (reuses scalar_array_daily, date, conus_lat, conus_lon).
scalar_array_daily_std = scalar_array_daily.max(dim='time').std(dim='member').compute()
title_string = f"Ensemble Spread (Std. Dev.) of Daily Maximum Temperature on {date}"
# float() converts the 0-d coordinate reductions into plain numbers for the
# map extent.
plot_scalar(scalar_array_daily_std, float(conus_lat.min()), float(conus_lat.max()),
            float(conus_lon.min()), float(conus_lon.max()), title_string)
