Run the cell below to create a UI for browsing the FireBench datasets. Building the UI takes about a minute; after that, you can interact with it without rerunning the cell.

In [None]:
# @title UI
try:
  import zarr
except:
  !pip3 install zarr

from IPython import display
import functools
import dask
import gcsfs
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import axes_divider
import numpy as np
import re
import types
import xarray
import ipywidgets as widgets


def add_colorbar(ax, im, **kwargs):
  """Attach a slim colorbar for ``im`` to the right-hand side of ``ax``.

  The colorbar axes is carved out of ``ax`` with an axes divider (2% of its
  size, pad 0.1), so the main plot shrinks to make room instead of being
  overlapped.

  Returns:
    Whatever ``ax.figure.colorbar`` returns; any extra ``kwargs`` are
    forwarded to that call.
  """
  cbar_ax = axes_divider.make_axes_locatable(ax).append_axes(
      'right', size='2%', pad=0.1, axes_class=plt.Axes)
  fig = ax.figure
  return fig.colorbar(im, cax=cbar_ax, **kwargs)


def find_fire_zarrs(fs, path, max_depth=2):
  """Recursively collect paths ending in ``fire.zarr`` under ``path``.

  Args:
    fs: Filesystem object (e.g. ``gcsfs.GCSFileSystem``) providing
      ``info(path)`` and ``ls(path)``.
    path: Path at which to start the search.
    max_depth: How many more directory levels may be listed. Children of a
      directory are visited with ``max_depth - 1``; once it goes negative
      the search stops descending.

  Returns:
    List of matching paths, in the order produced by ``fs.ls``.
  """
  # Tolerate a trailing slash, which some listings put on directory-like
  # entries such as zarr stores.
  if path.rstrip('/').endswith('fire.zarr'):
    return [path]
  # Check the depth budget before calling ``fs.info``: on a remote filesystem
  # each ``info`` call is a network round trip, and its answer is irrelevant
  # once we are not going to descend any further.
  if max_depth < 0:
    return []
  if fs.info(path)['type'] != 'directory':
    return []
  out = []
  for sub_path in fs.ls(path):
    out.extend(find_fire_zarrs(fs, sub_path, max_depth - 1))
  return out


def extract_wind_and_slope(path):
  """Parse the wind speed and slope out of a dataset path.

  Expects a path segment of the form ``u<wind>/ramp<slope>/``, e.g.
  ``.../u5.0/ramp10.0/fire.zarr``.

  Returns:
    ``(wind, slope)`` as the matched strings (not converted to numbers), or
    ``None`` when ``path`` does not contain that pattern.
  """
  m = re.search(r'u([\d.]+)/ramp([\d.]+)/', path)
  if m is None:
    return None
  return m.group(1), m.group(2)


def build_slope_by_wind_map(paths):
  """Group the available slopes by wind speed.

  Args:
    paths: Dataset paths, e.g. as returned by ``find_fire_zarrs``.

  Returns:
    Dict mapping each wind-speed string to the list of slope strings found
    for it, in input order. Paths that do not match the
    ``u<wind>/ramp<slope>/`` layout are skipped instead of raising.
  """
  out = {}
  for path in paths:
    parsed = extract_wind_and_slope(path)
    # A stray entry would otherwise fail to unpack ``None`` (TypeError).
    if parsed is None:
      continue
    wind, slope = parsed
    out.setdefault(wind, []).append(slope)
  return out


def create_widgets_for_dim(dim):
  """Build the plane-position slider and zoom-range slider for one axis.

  Args:
    dim: Axis name, used in the slider descriptions.

  Returns:
    ``(slice_pos_widget, zoom_range_widget, box_widget)``, where
    ``box_widget`` is an ``HBox`` laying the two sliders out side by side.
    Both sliders span 0 to 1 in steps of 0.01.
  """
  # Settings shared by both sliders.
  common = dict(
    min=0,
    max=1,
    step=0.01,
    disabled=False,
    continuous_update=False,
    orientation='horizontal',
    readout=True,
    readout_format='.01f',
  )
  position = widgets.FloatSlider(
    value=.5, description=f'{dim} plane', **common)
  zoom = widgets.FloatRangeSlider(
    value=[0, 1], description=f'{dim} zoom', **common)
  return position, zoom, widgets.HBox([position, zoom])


def create_spatial_widgets(ui, sample_ds):
  """Create the slice-direction dropdown and the per-axis sliders.

  Sets on ``ui``:
    slice_dim_widget: dropdown choosing which of x/y/z to slice through.
    xp, yp, zp: plane-position sliders.
    xr, yr, zr: zoom-range sliders.
    spatial: a VBox holding all of the above.

  ``ui.var_widget`` must already exist: which controls are shown depends on
  the dims of the selected variable in ``sample_ds``.
  """
  ui.slice_dim_widget = widgets.Dropdown(
    options=['x', 'y', 'z'],
    value='z',
    description='Slice through:',
    disabled=False,
  )
  ui.xp, ui.xr, xb = create_widgets_for_dim('x')
  ui.yp, ui.yr, yb = create_widgets_for_dim('y')
  ui.zp, ui.zr, zb = create_widgets_for_dim('z')
  ui.spatial = widgets.VBox([ui.slice_dim_widget, xb, yb, zb])

  def set_visibility(w, is_hidden):
    w.layout.visibility = 'hidden' if is_hidden else 'visible'

  def update_xyz_controls(_):
    da = sample_ds[ui.var_widget.value]
    # A variable lacking one of x/y/z can only be sliced through that dim.
    for dim in ['x', 'y', 'z']:
      if dim not in da.dims:
        ui.slice_dim_widget.value = dim
    slice_dim = ui.slice_dim_widget.value
    # Along the slice dim only the plane position matters; along the other
    # two dims only the zoom range does.
    for dim, pos_w, range_w in (('x', ui.xp, ui.xr),
                                ('y', ui.yp, ui.yr),
                                ('z', ui.zp, ui.zr)):
      set_visibility(range_w, slice_dim == dim)
      set_visibility(pos_w, slice_dim != dim)

  update_xyz_controls(None)
  # React only to value changes; observing every trait would also fire the
  # handler for auxiliary traits (e.g. the dropdown's index/label).
  ui.slice_dim_widget.observe(update_xyz_controls, names='value')
  # The visible controls also depend on the selected variable's dims.
  ui.var_widget.observe(update_xyz_controls, names='value')


def create_widgets(ui):
  """Discover the FireBench datasets and build every widget on ``ui``.

  Lists the ``fire.zarr`` stores under ``gs://firebench/v2024.04`` with
  anonymous access and opens the first one to discover the variables. It
  then sets on ``ui``:
    - the wind-speed, slope, variable and time selectors;
    - the spatial controls (via ``create_spatial_widgets``);
    - ``plot_button``;
    - ``main``, the assembled layout.

  Raises:
    RuntimeError: If no dataset, no parsable ``u<wind>/ramp<slope>/`` path,
      or no plottable variable is found.
  """
  root = 'gs://firebench/v2024.04'
  fs = gcsfs.GCSFileSystem(token="anon")
  paths = find_fire_zarrs(fs, root)
  if not paths:
    raise RuntimeError(f'No fire.zarr datasets found under {root}')
  sample_ds = xarray.open_zarr(store=fs.get_mapper(paths[0]))
  all_vars = [v for v in sample_ds.data_vars
              if set(sample_ds[v].dims) == {'x', 'y', 'z', 't'}]
  # Preferred variables are listed first. Keep only those the dataset
  # actually has; an absent name would only fail later, when selected.
  top_vars = [v for v in ['u', 'v', 'w', 'theta', 'T_s', 'rho_f']
              if v in sample_ds.data_vars]
  ui_vars = top_vars + sorted(set(all_vars) - set(top_vars),
                              key=lambda x: x.lower())
  if not ui_vars:
    raise RuntimeError(f'No plottable variables found in {paths[0]}')
  slope_by_wind = build_slope_by_wind_map(paths)
  if not slope_by_wind:
    raise RuntimeError(f'No u<wind>/ramp<slope>/ datasets found under {root}')
  # set().union(...) is safe for any number of arguments, unlike
  # set.union(*...), which needs at least one.
  all_slopes = set().union(*slope_by_wind.values())
  ui.wind_speed_widget = widgets.Dropdown(
      options=sorted(slope_by_wind.keys(), key=float),
      value=None,
      description='Wind speed:',
      disabled=False,
  )
  ui.slope_widget = widgets.Dropdown(
      options=sorted(all_slopes, key=float),
      value=None,
      description='Slope:',
      disabled=False,
  )
  ui.var_widget = widgets.Dropdown(
      options=ui_vars,
      value=ui_vars[0],
      description='Variables:',
      disabled=False,
  )
  ui.time_widget = widgets.FloatSlider(
      value=1,
      min=0,
      max=1,
      step=0.01,
      description='Time fraction:',
      disabled=False,
      continuous_update=False,
      orientation='horizontal',
      readout=True,
      readout_format='.01f',
  )
  ui.plot_button = widgets.Button(
      description='Plot',
      disabled=False,
      button_style='', # 'success', 'info', 'warning', 'danger' or ''
  )

  def limit_wind_slope_to_available(_):
    # Snap the selection onto a (wind, slope) pair that has a dataset.
    first_wind = next(iter(slope_by_wind))
    if ui.wind_speed_widget.value not in slope_by_wind:
      ui.wind_speed_widget.value = first_wind
    if ui.slope_widget.value not in slope_by_wind[ui.wind_speed_widget.value]:
      ui.slope_widget.value = slope_by_wind[ui.wind_speed_widget.value][0]

  limit_wind_slope_to_available(None)
  # Re-apply the constraint whenever either dropdown changes. Without this,
  # the user could pick a combination whose store does not exist, and Plot
  # would fail.
  ui.wind_speed_widget.observe(limit_wind_slope_to_available, names='value')
  ui.slope_widget.observe(limit_wind_slope_to_available, names='value')

  create_spatial_widgets(ui, sample_ds)
  ui.main = widgets.VBox(
      [widgets.HBox([ui.wind_speed_widget, ui.slope_widget]),
       widgets.HBox([ui.var_widget, ui.time_widget]),
       ui.spatial, ui.plot_button])


def get_slicer_for_dim(da, pos, value_range, dim, slice_dim):
  """Return the ``isel`` indexer for one spatial dim of ``da``.

  Args:
    da: Array with ``dims`` and per-dim coordinates ``da[dim]`` (e.g. an
      ``xarray.DataArray``).
    pos: Fraction in [0, 1] locating the slicing plane; used only when
      ``dim == slice_dim``.
    value_range: ``(lo, hi)`` fractions in [0, 1] of the extent to keep;
      used for the other dims.
    dim: Dimension to build the indexer for.
    slice_dim: Dimension being sliced through.

  Returns:
    - ``{}`` if ``da`` lacks ``dim``.
    - ``{dim: index}`` for the slice dim, where ``pos == 1`` selects the
      last element.
    - Otherwise ``{dim: slice(start, stop)}`` covering at least one element.
  """
  if dim not in da.dims:
    return {}
  n = len(da[dim])
  if dim == slice_dim:
    # Scale by (n - 1) so pos == 1 selects the last element instead of
    # indexing one past the end.
    return {dim: int((n - 1) * pos)}
  start = int((n - 1) * value_range[0])
  # ``stop`` is exclusive, so scale by n. Never let the range collapse to
  # nothing when both handles of the range slider meet.
  stop = max(int(n * value_range[1]), start + 1)
  return {dim: slice(start, stop)}


def get_slicer(da, ui):
  """Merge the per-dim indexers chosen in the x/y/z spatial widgets.

  Returns a dict for ``da.isel(...)`` built by ``get_slicer_for_dim``. It
  holds an integer index for the slice dim and a ``slice`` for every other
  spatial dim that ``da`` has.
  """
  slice_dim = ui.slice_dim_widget.value
  controls = (
      ('x', ui.xp, ui.xr),
      ('y', ui.yp, ui.yr),
      ('z', ui.zp, ui.zr),
  )
  indexers = {}
  for dim, pos_widget, range_widget in controls:
    indexers.update(get_slicer_for_dim(
        da, pos_widget.value, range_widget.value, dim, slice_dim))
  return indexers


def get_slice(ui):
  """Open the selected dataset and select the requested 2-D slice.

  The store path is built from the wind-speed and slope dropdowns. The
  chosen variable is taken at the time step given by the time-fraction
  slider, then indexed spatially via ``get_slicer``. The data comes from
  ``xarray.open_zarr``; presumably it is still lazy here, since the caller
  invokes ``compute()`` on the result.
  """
  wind = ui.wind_speed_widget.value
  slope = ui.slope_widget.value
  url = (f'gs://firebench/v2024.04/u{wind}/'
         f'ramp{slope}/fire.zarr')
  store = gcsfs.GCSFileSystem(token="anon").get_mapper(url)
  variable = xarray.open_zarr(store=store)[ui.var_widget.value]
  # The time slider is a fraction of the run; map it onto [0, len(t) - 1].
  t_index = int((len(variable.t) - 1) * ui.time_widget.value)
  at_time = variable.isel(t=t_index)
  return at_time.isel(get_slicer(at_time, ui))


def time_to_str(t):
  """Format a time offset as ``'<minutes>m <seconds>s'``.

  ``t`` is floored to whole seconds first, so 125.7 s gives ``'2m 5s'``.
  Minutes are not rolled over into hours.
  """
  total_seconds = (t // np.timedelta64(1, 's')).item()
  minutes, seconds = divmod(total_seconds, 60)
  return f'{minutes}m {seconds}s'


def plot_slice(ui, da, placeholder, *, cmap='PuOr', vmin=-20, vmax=20):
  """Draw a 2-D slice (or a text placeholder) with title and axis labels.

  Args:
    ui: Widget namespace; ``ui.slice_dim_widget.value`` decides which two
      dims go on the plot axes.
    da: The slice. Its name, ``t`` coordinate and (if present) slice-dim
      coordinate are used in the title.
    placeholder: If not ``None``, this text is drawn instead of the data.
    cmap, vmin, vmax: Colour map and limits passed to ``pcolormesh``. The
      defaults reproduce the previous fixed range; pass ``vmin=None`` /
      ``vmax=None`` to let matplotlib autoscale.

  Raises:
    ValueError: If the slice dim is not one of 'x', 'y', 'z'.
  """
  slice_dim = ui.slice_dim_widget.value
  # (horizontal, vertical) plot axes for each dim we can slice through.
  axes_by_slice_dim = {'x': ('y', 'z'), 'y': ('x', 'z'), 'z': ('x', 'y')}
  if slice_dim not in axes_by_slice_dim:
    raise ValueError(f'Unsupported slice dimension: {slice_dim!r}')
  xdim, ydim = axes_by_slice_dim[slice_dim]

  fig, ax = plt.subplots(1, 1, figsize=(7, 5))
  if placeholder is None:
    pcm = ax.pcolormesh(da[xdim], da[ydim], da.transpose(ydim, xdim),
                        cmap=cmap, vmin=vmin, vmax=vmax)
    add_colorbar(ax, pcm)
  else:
    ax.text(0.05, 0.95, placeholder, transform=ax.transAxes, fontsize=14,
            verticalalignment='top')
  title = f'{da.name}, t = {time_to_str(da.t)}'
  # A variable without the slice dim has no coordinate to report; omit it
  # rather than raising KeyError.
  if slice_dim in da.coords:
    title += f', {slice_dim} = {da[slice_dim].item()} m'
  ax.set_title(title)
  ax.set_xlabel(f'{xdim} (m)')
  ax.set_ylabel(f'{ydim} (m)')
  plt.show()


def run_ui(ui):
  """Hook the Plot button up to the plotting pipeline and show the UI.

  Each click clears the output area and shows a loading placeholder
  (vertical slices can take up to a minute). It then loads the slice and
  replaces the placeholder with the real plot. Plots are rendered into a
  500px-high ``Output`` widget displayed below ``ui.main``.
  """
  output = widgets.Output()
  output.layout.height = '500px'

  def on_plot_button_clicked(_button):
    with output:
      display.clear_output()
      lazy_slice = get_slice(ui)
      # Show a placeholder immediately so the user knows work is underway.
      plot_slice(ui, lazy_slice,
                 'Loading - vertical slices might take up to a minute...')
      loaded_slice = lazy_slice.compute()
      display.clear_output()
      plot_slice(ui, loaded_slice, None)

  display.clear_output()
  ui.plot_button.on_click(on_plot_button_clicked)
  display.display(ui.main, output)

# Shared namespace that the builder functions attach widgets to.
ui_objects = types.SimpleNamespace()
# Discover datasets and build the widgets (the slow step), then wire the
# Plot button and display the UI.
create_widgets(ui_objects)
run_ui(ui_objects)