# ML accelerated CFD data analysis

This notebook reproduces key figures in our [PNAS paper](https://www.pnas.org/content/118/21/e2101784118) based on saved datasets. The data is stored in netCDF files in Google Cloud Storage, and the analysis uses xarray and JAX-CFD.




In [1]:
# Install JAX-CFD 0.1.0 with its `data` extras and upgrade xarray.
# NOTE(review): jax itself is not pinned here. The AttributeError about
# `jax.numpy.DeviceArray` recorded after the utility-functions cell suggests the
# jax release that gets installed is newer than jax-cfd 0.1.0 supports; consider
# pinning a compatible jax/jaxlib pair (exact versions TODO confirm).
! pip install -U xarray jax-cfd[data]==0.1.0



## Figure 1

Replication of Figure 1 from the PNAS paper, except with a [bootstrap](https://en.wikipedia.org/wiki/Bootstrapping_(statistics))-based estimate of uncertainty, given the sample size of 16 trajectories.

In [2]:
# @title Utility functions
import xarray
import seaborn
import numpy as np
import pandas as pd
import jax_cfd.data.xarray_utils as xru
from jax_cfd.data import evaluation
import matplotlib.pyplot as plt


def correlation(x, y):
  """Spatial correlation between two fields.

  Each input is normalized over the 'x' and 'y' dimensions with
  `xru.normalize`; the normalized fields are multiplied pointwise and summed
  over 'x' and 'y'. All other dimensions (e.g. model, sample, time) are kept.
  Presumably this is a Pearson-style correlation if `xru.normalize` centers and
  scales to unit norm (not verified here).
  """
  spatial_dims = ['x', 'y']
  x_normalized = xru.normalize(x, spatial_dims)
  y_normalized = xru.normalize(y, spatial_dims)
  return (x_normalized * y_normalized).sum(spatial_dims)

def calculate_time_until(vorticity_corr, threshold=0.95):
  """First time at which the sample-mean correlation drops below `threshold`.

  Args:
    vorticity_corr: correlation values with (at least) 'sample' and 'time'
      dimensions, e.g. the output of `correlation`.
    threshold: correlation level below which a trajectory counts as
      decorrelated. Defaults to 0.95.

  Returns:
    DataArray named 'time_until'. For every remaining coordinate (e.g. model,
    bootstrap replicate), it holds the first 'time' label where the mean over
    'sample' is below `threshold`. Entries are NaN where the mean correlation
    never drops below `threshold` within the supplied time window.
  """
  mean_corr = vorticity_corr.mean('sample')
  still_correlated = mean_corr >= threshold
  # On a boolean array idxmin returns the first label whose value is False,
  # i.e. the first time below the threshold.
  first_drop = still_correlated.idxmin('time')
  # If every value is True (e.g. the reference model compared with itself),
  # idxmin would return the *first* time label instead; mask those out.
  never_drops = still_correlated.all('time')
  return first_drop.where(~never_drops).rename('time_until')

def calculate_time_until_bootstrap(vorticity_corr, bootstrap_samples=10000,
                                   seed=0):
  """Bootstrap distribution of `calculate_time_until` over trajectories.

  The 'sample' dimension is resampled with replacement `bootstrap_samples`
  times, and the time until decorrelation is computed for every resample.

  Args:
    vorticity_corr: correlation values with 'sample' and 'time' dimensions.
    bootstrap_samples: number of bootstrap resamples (new 'boot' dimension).
    seed: seed for the resampling RNG, fixed by default for reproducibility.

  Returns:
    Output of `calculate_time_until` with an extra 'boot' dimension of length
    `bootstrap_samples`.
  """
  rs = np.random.RandomState(seed)
  # Use the actual number of trajectories rather than assuming 16.
  num_samples = vorticity_corr.sizes['sample']
  indices = rs.choice(num_samples, size=(bootstrap_samples, num_samples),
                      replace=True)
  # Vectorized indexing: each row of `indices` becomes one bootstrap replicate
  # along 'boot'. The resampled axis gets a temporary name ('sample2'), which is
  # then renamed back to 'sample'.
  boot_vorticity_corr = vorticity_corr.isel(
      sample=(('boot', 'sample2'), indices)).rename({'sample2': 'sample'})
  return calculate_time_until(boot_vorticity_corr)

def calculate_upscaling(time_until):
  """Effective resolution upscaling factor of 'learned_interp_64'.

  The upscaling factor is taken relative to 64x64, so 'baseline_512' is 8x and
  'baseline_1024' is 16x. log(factor) is interpolated linearly (extrapolated
  outside the two points) as a function of time-until-decorrelation. The
  result is evaluated at the time reached by the learned model.

  Args:
    time_until: output of `calculate_time_until` (or its bootstrap variant)
      with a 'model' coordinate containing the three models named above.

  Returns:
    The upscaling factor(s), keeping any non-model dimensions of `time_until`.
  """
  t_8x = time_until.sel(model='baseline_512')
  t_16x = time_until.sel(model='baseline_1024')
  t_learned = time_until.sel(model='learned_interp_64')
  log_factor_per_time = (np.log(16) - np.log(8)) / (t_16x - t_8x)
  return np.exp(log_factor_per_time * (t_learned - t_8x) + np.log(8))

def calculate_speedup(time_until):
  """Runtime speed-up of 'learned_interp_64' over the baseline solver.

  log(runtime) is interpolated linearly (extrapolated outside the two points)
  in time-until-decorrelation between 'baseline_512' and 'baseline_1024'.
  Evaluating it at the learned model's time-until estimates the baseline
  runtime needed for equal accuracy. That runtime is divided by the learned
  model's runtime.

  Args:
    time_until: output of `calculate_time_until` (or its bootstrap variant)
      with a 'model' coordinate containing the three models named above.

  Returns:
    The speed-up factor(s), keeping any non-model dimensions of `time_until`.
  """
  # Measured runtimes. '8x' and '16x' are 512 and 1024 (8 and 16 times 64).
  # Presumably all share one unit, which cancels in the final ratio.
  runtime_baseline_8x = 44.053293
  runtime_baseline_16x = 412.725656
  runtime_learned = 1.155115

  t_8x = time_until.sel(model='baseline_512')
  t_16x = time_until.sel(model='baseline_1024')
  t_learned = time_until.sel(model='learned_interp_64')

  log_runtime_per_time = (
      (np.log(runtime_baseline_16x) - np.log(runtime_baseline_8x))
      / (t_16x - t_8x))
  matched_baseline_runtime = np.exp(
      log_runtime_per_time * (t_learned - t_8x) + np.log(runtime_baseline_8x))
  return matched_baseline_runtime / runtime_learned

AttributeError: module 'jax.numpy' has no attribute 'DeviceArray'

### Load data

In [None]:
# Copy the Figure 1 evaluation data (netCDF model trajectories and the TPU
# timing CSV read further below) from the public GCS bucket into /content.
%time ! gsutil -m cp -r gs://gresearch/jax-cfd/public_eval_datasets/kolmogorov_re_1000_fig1 /content

In [None]:
# Sanity check: list the downloaded Figure 1 files.
! ls /content/kolmogorov_re_1000_fig1

In [None]:
# Model name -> netCDF file inside the Figure 1 dataset directory.
baseline_filenames = {
    f'baseline_{r}': f'baseline_{r}x{r}.nc'
    for r in [64, 128, 256, 512, 1024, 2048]
}
learned_filenames = {
    f'learned_interp_{r}': f'learned_{r}x{r}.nc'
    for r in [32, 64, 128]
}

# Open every dataset lazily, chunked into ~100MB pieces along 'time'.
models = {
    name: xarray.open_dataset(f'/content/kolmogorov_re_1000_fig1/{fname}',
                              chunks={'time': '100MB'})
    for name, fname in baseline_filenames.items()
}
# Learned-model outputs are reindexed (nearest neighbour) onto the coordinates
# of 'baseline_64' so they line up with the baselines before concatenation.
for name, fname in learned_filenames.items():
  learned_ds = xarray.open_dataset(f'/content/kolmogorov_re_1000_fig1/{fname}',
                                   chunks={'time': '100MB'})
  models[name] = learned_ds.reindex_like(models['baseline_64'], method='nearest')

# Stack all models along a new 'model' dimension labelled by model name.
combined_fig1 = xarray.concat(list(models.values()), dim='model')
combined_fig1.coords['model'] = list(models.keys())
combined_fig1['vorticity'] = xru.vorticity_2d(combined_fig1)

In [None]:
# Inspect the combined Figure 1 dataset. Note that the Figure 1 data was
# resampled to 32x32 for validation, the coarsest resolution of any of the
# constituent models.
combined_fig1

In [None]:
# Per-model TPU speed measurements (displayed in the next cell).
df_raw = pd.read_csv('/content/kolmogorov_re_1000_fig1/tpu-speed-measurements.csv').reset_index(drop=True)

In [None]:
# Raw timing data, as read from the CSV.
df_raw

### Calculate headline result

Note: the small discrepancy between the 85.46x speed-up computed here and the "86x" reported in the paper arises because the paper did not use `.thin(time=2)`; the public Colab runtime doesn't have enough memory to avoid thinning.

In [None]:
%%time
v = combined_fig1.vorticity.thin(time=2).sel(time=slice(10))
vorticity_correlation = correlation(v, v.sel(model='baseline_2048')).compute()

times = calculate_time_until(vorticity_correlation)
times_boot = calculate_time_until_bootstrap(vorticity_correlation)

speedup = calculate_speedup(times)
print('speedup estimate:', float(speedup))

speedups = calculate_speedup(times_boot)
print('speedup bootstrap mean:', float(speedups.mean('boot')))
print('speedup bootstrap stddev:', float(speedups.std('boot')))
print('speedup bootstrap median:', float(speedups.median('boot')))
print('speedup bootstrap range:', speedups.quantile(dim='boot', q=[0.05, 0.95]).values.tolist())

upscaling = calculate_upscaling(times)
print('upscaling estimate:', float(upscaling))

upscalings = calculate_upscaling(times_boot)
print('upscaling bootstrap mean:', float(upscalings.mean('boot')))
print('upscaling bootstrap stddev:', float(upscalings.std('boot')))
print('upscaling bootstrap median:', float(upscalings.median('boot')))
print('upscaling bootstrap range:', upscalings.quantile(dim='boot', q=[0.05, 0.95]).values.tolist())


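The reported 86x speed-up should thus be associated with a bootstrap interval of roughly [64, 140]. Note that the `bootstrap range` printed above spans the 5th–95th percentiles (a 90% interval), whereas the error bars in the Pareto plots below use the 2.5th–97.5th percentiles (a 95% interval):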

In [None]:
# Distribution of the bootstrap speed-up estimates (one value per resample).
speedups.rename('speedup factor').plot.hist(bins=50);

### Pareto frontier plots

In [None]:
#@title Prepare dataframe
# Time step of the 64x64 simulations. It is used below to convert
# "per time step at 64x64" into "per unit of simulation time".
TIME_STEP_64X64 = 0.007012


def _bootstrap_quantile_frame(boot_times, q, column):
  """One-column DataFrame, indexed by model_name, of a bootstrap quantile.

  Args:
    boot_times: time until decorrelation with 'model' and 'boot' dimensions.
    q: quantile to take over the 'boot' dimension.
    column: name of the resulting DataFrame column.
  """
  return (
      boot_times
      .quantile(q=q, dim='boot')
      # Remove the scalar 'quantile' coordinate (`.drop` is deprecated for this).
      .drop_vars('quantile')
      .rename(column)
      .rename({'model': 'model_name'})
      .to_dataframe()
  )


# Join the timing measurements with the point estimate and the 2.5%/97.5%
# bootstrap quantiles (a 95% interval) of the time until decorrelation.
# NOTE: `times`/`times_boot` must still hold the Figure 1 results; the Figure 2
# cells further down overwrite them.
df = (
    df_raw
    .drop(columns=['model', 'resolution', 'msec_per_sim_step'])
    .set_index('model_name')
    .join(times.rename({'model': 'model_name'}).to_dataframe())
    .join(_bootstrap_quantile_frame(times_boot, 0.975, 'time_until_upper'))
    .join(_bootstrap_quantile_frame(times_boot, 0.025, 'time_until_lower'))
    .reset_index()
)
# e.g. 'learned_interp_64' -> model='learned_interp', resolution=64.
# `n` is keyword-only in pandas >= 2.0 (positional use was deprecated in 1.5).
df[['model', 'resolution']] = df.model_name.str.rsplit('_', n=1, expand=True)
df['resolution'] = df['resolution'].astype(int)
# Switch units from "msec per time step at 64x64" to
# "seconds per unit of simulation time".
df['sec_per_sim_time'] = df['msec_per_dt'] / TIME_STEP_64X64 * 1e-3
df = df.sort_values(['resolution', 'model'])

In [None]:
#@title Pareto frontier with uncertainty

plt.figure(figsize=(5, 5))

# One curve per model family. Horizontal error bars span
# [time_until_lower, time_until_upper] around the point estimate.
for family, legend_label in [('baseline', 'baseline'),
                             ('learned_interp', 'learned')]:
  rows = df[df['model'] == family]
  x_err = (rows.time_until - rows.time_until_lower,
           rows.time_until_upper - rows.time_until)
  plt.errorbar(rows.time_until, rows.sec_per_sim_time, xerr=x_err,
               marker='s', label=legend_label)

plt.xlim(0, 8.1)
plt.ylim(1.5e-2, 1e3)
plt.xlabel('Time until correlation < 0.95')
plt.ylabel('Runtime per time unit (s)')
plt.yscale('log')
plt.legend()
seaborn.despine()

In [None]:
#@title Pareto frontier (transposed) with uncertainty
plt.figure(figsize=(5, 5))

# Same data as the previous plot with the axes swapped: runtime on x and time
# until decorrelation on y, with vertical bootstrap error bars.
for family, legend_label in [('baseline', 'baseline'),
                             ('learned_interp', 'learned')]:
  rows = df[df['model'] == family]
  y_err = (rows.time_until - rows.time_until_lower,
           rows.time_until_upper - rows.time_until)
  plt.errorbar(rows.sec_per_sim_time, rows.time_until, yerr=y_err,
               marker='s', label=legend_label)

plt.ylim(0, 8.1)
plt.xlim(1.5e-2, 1e3)
plt.ylabel('Time until correlation < 0.95')
plt.xlabel('Runtime per time unit (s)')
plt.xscale('log')
plt.legend()
seaborn.despine()

## Figure 2

Here we reproduce the key parts of Figure 2, from scratch.

Note that Figure 2 (and Figure 5) inadvertently used a different evaluation dataset than Figure 1 (different random initial velocity fields), saved at 64x64 resolution.

### Copy data

In [None]:
# Copy the learned-model evaluation files (learned*.nc) for Figure 2 into /content.
%time ! gsutil -m cp gs://gresearch/jax-cfd/public_eval_datasets/kolmogorov_re_1000/learned*.nc /content

In [None]:
# Copy the baseline evaluation files (long_eval*.nc) for Figure 2 into /content.
%time ! gsutil -m cp gs://gresearch/jax-cfd/public_eval_datasets/kolmogorov_re_1000/long_eval*.nc /content

### Load data at 64x64 and 32x32 resolutions

In [None]:
import xarray
import seaborn
import jax_cfd.data.xarray_utils as xru
from jax_cfd.data import evaluation
import matplotlib.pyplot as plt

# Six YlGnBu shades for the six baseline resolutions (the first, lightest of
# seven is skipped), then orange for the learned model. This matches the model
# order used by the Figure 2 plots below.
baseline_palette = seaborn.color_palette('YlGnBu', n_colors=7)[1:]
models_color = seaborn.xkcd_palette(['burnt orange'])
palette = [*baseline_palette, *models_color]


In [None]:
# Model name -> netCDF file (all saved at 64x64 resolution).
filenames = {
    f'baseline_{r}': f'long_eval_{r}x{r}_64x64.nc'
    for r in [64, 128, 256, 512, 1024, 2048]
}
filenames['learned_interp_64'] = 'learned_interpolation_long_eval_64x64_64x64.nc'

# Open every file lazily, chunked into ~100MB pieces along 'time'.
models = {
    name: xarray.open_dataset(f'/content/{fname}', chunks={'time': '100MB'})
    for name, fname in filenames.items()
}

# Stack along a new 'model' dimension labelled by model name, then add vorticity.
combined = xarray.concat(list(models.values()), dim='model')
combined.coords['model'] = list(models.keys())
combined['vorticity'] = xru.vorticity_2d(combined)

In [None]:
from jax_cfd.base import resize
import numpy as np
import pandas as pd

def resize_64_to_32(ds):
  """Downsample the `u` and `v` velocity fields by a factor of 2 per axis.

  `u` keeps every other x position starting at index 1 and averages adjacent
  pairs along y. `v` does the same with the roles of x and y swapped. This is
  presumably because u and v live on x- and y-faces of a staggered grid. The
  coarse coordinates take the max of each coarsened pair
  (`coord_func='max'`). Dataset attributes are carried over; all other
  variables (e.g. vorticity) are dropped and must be recomputed.
  """
  u_coarse = ds.u.isel(x=slice(1, None, 2)).coarsen(y=2, coord_func='max').mean()
  v_coarse = ds.v.isel(y=slice(1, None, 2)).coarsen(x=2, coord_func='max').mean()
  coarse = xarray.Dataset({'u': u_coarse, 'v': v_coarse})
  coarse.attrs = ds.attrs
  return coarse


combined_32 = resize_64_to_32(combined)
combined_32['vorticity'] = xru.vorticity_2d(combined_32)

# Per-model 32x32 datasets; the speed-up cells below use `combined_32` instead.
models_32 = {name: resize_64_to_32(ds) for name, ds in models.items()}

### Plot solutions: Fig 2(a)

In [None]:
# Inspect the 32x32 dataset produced by resize_64_to_32.
combined_32

In [None]:
# Vorticity of the first trajectory: every 50th saved frame, keeping the first
# five such frames, one row per model, on a fixed colour range [-10, 10].
snapshots = combined.vorticity.isel(sample=0).thin(time=50).head(time=5)
snapshots.plot.imshow(
    row='model', col='time', x='x', y='y', robust=True, size=2.3, aspect=0.9,
    add_colorbar=False, cmap=seaborn.cm.icefire, vmin=-10, vmax=10)

### Calculate speed-up

We can calculate speed-up based upon vorticity correlation either at a resolution of 32x32 or 64x64. The numbers are similar.

Note that on this dataset our ML model achieves a somewhat larger speed-up over the FVM baseline than on the Figure 1 dataset (138x vs 86x), but the difference is still within the range of uncertainty.

In [None]:
%%time
v = combined_32.vorticity.sel(time=slice(20))
vorticity_correlation = correlation(v, v.sel(model='baseline_2048')).compute()

times = calculate_time_until(vorticity_correlation)
times_boot = calculate_time_until_bootstrap(vorticity_correlation)

speedup = calculate_speedup(times)
print('speedup estimate:', float(speedup))

speedups = calculate_speedup(times_boot)
print('speedup bootstrap mean:', float(speedups.mean('boot')))
print('speedup bootstrap stddev:', float(speedups.std('boot')))
print('speedup bootstrap median:', float(speedups.median('boot')))
print('speedup bootstrap range:', speedups.quantile(dim='boot', q=[0.05, 0.95]).values.tolist())

upscaling = calculate_upscaling(times)
print('upscaling estimate:', float(upscaling))

upscalings = calculate_upscaling(times_boot)
print('upscaling bootstrap mean:', float(upscalings.mean('boot')))
print('upscaling bootstrap stddev:', float(upscalings.std('boot')))
print('upscaling bootstrap median:', float(upscalings.median('boot')))
print('upscaling bootstrap range:', upscalings.quantile(dim='boot', q=[0.05, 0.95]).values.tolist())


In [None]:
%%time
v = combined.vorticity.sel(time=slice(20))
vorticity_correlation = correlation(v, v.sel(model='baseline_2048')).compute()

times = calculate_time_until(vorticity_correlation)
times_boot = calculate_time_until_bootstrap(vorticity_correlation)

speedup = calculate_speedup(times)
print('speedup estimate:', float(speedup))

speedups = calculate_speedup(times_boot)
print('speedup bootstrap mean:', float(speedups.mean('boot')))
print('speedup bootstrap stddev:', float(speedups.std('boot')))
print('speedup bootstrap median:', float(speedups.median('boot')))
print('speedup bootstrap range:', speedups.quantile(dim='boot', q=[0.05, 0.95]).values.tolist())

upscaling = calculate_upscaling(times)
print('upscaling estimate:', float(upscaling))

upscalings = calculate_upscaling(times_boot)
print('upscaling bootstrap mean:', float(upscalings.mean('boot')))
print('upscaling bootstrap stddev:', float(upscalings.std('boot')))
print('upscaling bootstrap median:', float(upscalings.median('boot')))
print('upscaling bootstrap range:', upscalings.quantile(dim='boot', q=[0.05, 0.95]).values.tolist())


In [None]:
# Distribution of the bootstrap speed-up estimates from the most recent
# speed-up cell (64x64 correlations).
speedups.rename('speedup factor').plot.hist(bins=50);

### Calculate metrics

In [None]:
%%time
summary = xarray.concat([
    evaluation.compute_summary_dataset(ds, models['baseline_2048'])
    for ds in models.values()
], dim='model')
summary.coords['model'] = list(models.keys())

In [None]:
# Inspect the per-model evaluation summary.
summary

In [None]:
# Vorticity correlation with the baseline_2048 reference for times <= 20,
# taken from the evaluation summary.
# NOTE(review): this rebinds `correlation`, shadowing the `correlation()` helper
# defined in the utility-functions cell. Cells that call `correlation(...)` will
# fail if re-run after this one; consider renaming this variable (together with
# its use in the Fig 2(b) plot).
%time correlation = summary.vorticity_correlation.sel(time=slice(20)).compute()

In [None]:
# Energy spectrum averaged over the last 2000 saved time steps.
%time spectrum = summary.energy_spectrum_mean.tail(time=2000).mean('time').compute()

### Plot correlation over time: Fig 2(b)

In [None]:
plt.figure(figsize=(7, 6))
# NOTE: `correlation` here is the DataArray computed from `summary` above; it
# shadows the `correlation()` helper function.
for color, model in zip(palette, summary['model'].data):
  # Baselines are drawn solid, the learned model dashed.
  style = '-' if 'baseline' in model else '--'
  correlation.sel(model=model).plot.line(
      color=color, linestyle=style, label=model, linewidth=3)
# Decorrelation threshold. axhline's xmin/xmax are axes fractions in [0, 1],
# not data coordinates; the defaults already span the full width.
plt.axhline(y=0.95, color='gray')
plt.legend()
plt.title('')
plt.xlim(0, 15);

### Plot spectrum: Fig 2(c)

In [None]:
plt.figure(figsize=(10, 6))
# Spectrum scaled by k**5, computed once for all models.
compensated = spectrum.k ** 5 * spectrum
for color, model in zip(palette, summary['model'].data):
  linestyle = '--' if 'baseline' not in model else '-'
  compensated.sel(model=model).plot.line(
      color=color, linestyle=linestyle, label=model, linewidth=3)
plt.legend()
plt.xscale('log')
plt.yscale('log')
plt.title('')
plt.xlim(3.5, None)
plt.ylim(1e9, None)