Licensed under the Apache License, Version 2.0

In [None]:
# Sign in to Google Cloud so the dataset can be fetched from Cloud Storage.
# The auth helper comes from google.colab, so this cell presumably only works
# when the notebook is run inside Colab.
from google.colab import auth
auth.authenticate_user()
project_id = 'contrails-predictions-external'
# Make that project the active gcloud project; {project_id} is filled in from
# the Python variable above by the notebook's `!` shell syntax.
!gcloud config set project {project_id}

In [None]:
# Copy the dataset CSV from its Cloud Storage bucket into the local /content
# directory.
!gsutil cp gs://contrails_measurement_paper_data/contrail_bench_dataset.csv /content

In [None]:
import bisect
import os
import matplotlib.pyplot as plt
import matplotlib
matplotlib.rcParams.update({'font.size': 20})
import numpy as np
import pandas as pd


In [None]:
def get_metric(tdf, field, cutoff, roc=False):
  # Score the prediction "tdf[field] > cutoff" against the ground truth stored
  # in the 'match' column of tdf.
  # tdf: dataframe with a 'match' column (truthy where the row is a real
  #   positive) and a numeric column named by `field`
  # field: name of the score column; a row is predicted positive when its
  #   value is strictly greater than cutoff
  # cutoff: threshold applied to tdf[field]
  # roc: if True, return (hit rate, false alarm rate);
  #   if False, return (precision, hit rate)
  # A ratio whose denominator is zero is returned as NaN.

  # Work on plain boolean arrays so that `~` is a logical NOT even when
  # 'match' was loaded as 0/1 integers instead of booleans.
  actual = np.asarray(tdf['match'], dtype=bool)
  predicted = np.asarray(tdf[field] > cutoff, dtype=bool)

  tp = int(np.count_nonzero(actual & predicted))   # true positives
  fp = int(np.count_nonzero(~actual & predicted))  # false positives
  npos = int(np.count_nonzero(actual))             # all positives
  nneg = actual.size - npos                        # all negatives

  # hit rate = true positives / all positives
  hr = tp / npos if npos else np.nan
  if roc:
    # FAR = false positives / all negatives
    far = fp / nneg if nneg else np.nan
    return hr, far

  # Precision straight from the counts: true positives / predicted positives.
  # This stays defined when there are no negatives at all, and is NaN only
  # when nothing is predicted positive.
  npred = tp + fp
  p = tp / npred if npred else np.nan
  return p, hr

def get_curve(tdf, field, cutoffs, roc=False):
  # Evaluate get_metric once per cutoff and stack the resulting pairs.
  # Row i of the returned (len(cutoffs), 2) float array holds the pair that
  # get_metric produced for cutoffs[i], in the order get_metric returns it.
  pairs = [get_metric(tdf, field, threshold, roc=roc) for threshold in cutoffs]
  # The explicit reshape keeps the two-column shape when cutoffs is empty.
  return np.array(pairs, dtype=float).reshape(len(pairs), 2)

def plot_df(ax, tdf, field, roc, cutoffs, **kwargs):
  # Draw one Precision/Hit Rate curve, or one ROC curve, onto ax.
  # ax: matplotlib axis that receives the curve
  # tdf: dataframe to plot from. It needs a boolean column called 'match'
  # (did this flight segment match a contrail?) and a numeric column named by
  # `field`, which is higher for segments more likely to make a contrail
  # field: name of that score column
  # roc: if true, draw an ROC curve instead of a PR curve
  # cutoffs: each of these values contributes one point to the curve
  # kwargs: forwarded unchanged to ax.plot
  curve = get_curve(tdf, field, cutoffs, roc=roc)
  # Column 1 of the curve goes on the x axis, column 0 on the y axis.
  x_values = curve[:, 1]
  y_values = curve[:, 0]
  ax.plot(x_values, y_values, '.-', **kwargs)

def compute_metrics(df, key, cutoffs, label):
  # Print the metrics that go in the ContrailBench table for one predictor.
  # df: dataframe with a 'match' column and the score column `key`
  # key: name of the score column to threshold
  # cutoffs: thresholds to evaluate. The lookup below assumes they are in
  #   increasing order, so that hit rate does not increase along the curve.
  # label: name printed in the header line
  # Prints 'nan' for a metric that cannot be determined from the given cutoffs.
  rcs = get_curve(df, key, cutoffs=cutoffs, roc=True)
  # Reverse the ROC curve so hit rate is ascending, as bisect requires.
  hit_rates = rcs[::-1, 0]
  false_alarm_rates = rcs[::-1, 1]
  # Get the false alarm rate at the first point whose hit rate is >= 20%
  i = bisect.bisect_left(hit_rates, 0.2)
  print(f'Metrics for {label}')
  # i == len(hit_rates) means no cutoff reaches a 20% hit rate; report NaN
  # instead of indexing past the end of the curve.
  far_at_hr = false_alarm_rates[i] if i < len(hit_rates) else np.nan
  print('FAR@HR=20%', far_at_hr)
  prs = get_curve(df, key, cutoffs=cutoffs, roc=False)
  # Best precision x hit rate product over all cutoffs, ignoring NaN points.
  # np.nanmax raises on an empty array, so an empty curve is reported as NaN.
  products = prs[:, 0] * prs[:, 1]
  best_product = np.nanmax(products) if products.size else np.nan
  print('1/(PxHR)', 1/best_product)
In [None]:
# Load the dataset CSV from /content into a dataframe.
goes_df = pd.read_csv('/content/contrail_bench_dataset.csv')

In [None]:
# Show the dataframe as this cell's output.
goes_df

In [None]:
def make_outputs(df, keys, all_cutoffs, labels=None):
  # Plot Precision/Hit Rate and ROC curves for all the prediction metrics in 'keys'
  # Also print the ContrailBench metrics
  # df: dataframe with a 'match' column and one score column per key
  # keys: names of the score columns to evaluate
  # all_cutoffs: one sequence of cutoffs per key, in the same order as keys
  # labels: one display name per key, used in the legend and the printed
  #   metrics. Optional: when omitted, a notebook-level `labels` variable is
  #   used if one exists (the behaviour callers relied on before this
  #   parameter was added), otherwise the keys themselves.
  if labels is None:
    labels = globals().get('labels', keys)

  plt.figure(figsize=(18, 6))
  ax1 = plt.subplot(121)
  ax2 = plt.subplot(122)

  for key, cutoffs, label in zip(keys, all_cutoffs, labels):
    plot_df(ax1, df, key, roc=False, cutoffs=cutoffs, label=label)
    plot_df(ax2, df, key, roc=True, cutoffs=cutoffs, label=label)
    compute_metrics(df, key, cutoffs, label)
  ax1.set_xlabel('Hit Rate')
  ax1.set_ylabel('Precision')
  ax2.set_xlim([0, 0.15])
  ax2.set_xlabel('False Alarm Rate')
  ax2.set_ylabel('Hit Rate')
  # The legend goes on the ROC panel only; ax2 is the axis plt.legend() would
  # have targeted, since it is the most recently created subplot.
  ax2.legend()
  ax1.grid()
  ax2.grid()

In [None]:
# Score columns to evaluate:
# rh: relative humidity from ECMWF
# cocip_ef_lw: integral of long-wave radiative forcing over the first 2 hours
# of cocip predictions. This is the most predictive of whether we will observe
# a contrail (total ef, optical depth and contrail age are less predictive;
# they give worse metrics)
# ML_score: number output by the ML model
keys = ['rh', 'cocip_ef_lw', 'ML_score']
# One array of cutoffs per key, in the same order as `keys`.
cutoffs = [
    np.arange(30, 120, 1),  # rh: 30, 31, ..., 119
    np.concatenate([np.logspace(1, 8, 15), np.logspace(8, 10, 30)]),  # cocip_ef_lw: 1e1 to 1e10, sampled more densely from 1e8 up
    np.arange(0, 1, 0.01)  # ML_score: from 0 up to (not including) 1 in steps of 0.01
    ]
# One display label per key, in the same order as `keys`; make_outputs picks
# this up as a notebook-level global.
labels = ['Baseline', 'CoCiP', 'ML model']
make_outputs(goes_df, keys, cutoffs)