In [1]:
import numpy as np
import pandas as pd
import attrs
import copy
from typing import Optional
import pandas_gbq

## Set up data

You will need a Google Cloud Platform (GCP) project ID to run the SQL examples. To get one:

1.   Use the [Cloud Resource Manager](https://console.cloud.google.com/cloud-resource-manager) to create a Cloud Platform project if you do not already have one.
2.   [Enable billing](https://support.google.com/cloud/answer/6293499#enable-billing) for the project.
3.   [Enable the BigQuery API](https://console.cloud.google.com/flows/enableapi?apiid=bigquery) for the project.

Then set `project_id` in the next cell to your own project ID.


In [None]:
# Project that receives the demo table; also read by the SQL code paths below.
project_id='meterstick-personal'

# Seed NumPy's global RNG so the synthetic rows are the same on every run.
np.random.seed(42)
n = 100
# The keyword arguments are evaluated left to right, so the three columns are
# drawn from the RNG in the order lost -> region -> experiment.
df = pd.DataFrame(dict(
    lost=np.random.choice([0, 1.], n),
    region=np.random.choice(('US', 'non-US'), n),
    experiment=np.random.choice(
        ('control', 'experiment1', 'experiment2', 'experiment3'), n),
))
# Upload the frame as table `demo.data`, overwriting any previous copy.
pandas_gbq.to_gbq(df, 'demo.data', project_id=project_id, if_exists='replace')

100%|██████████| 1/1 [00:00<00:00, 9822.73it/s]


## Metric implementations

In [None]:
class Metric:
  """Base class for metrics that can be computed with pandas or with SQL.

  Concrete metrics are expected to supply:
    compute(data): the metric value(s) for a DataFrame or a groupby object.
    default_names: the default output column names.
    sql: the SQL expressions, one per output column.
  """

  def compute_on(self, data, split_by=None):
    """Computes the metric on `data` with pandas.

    Args:
      data: Object handed to `self.compute`; it is grouped with
        `data.groupby(split_by)` first when `split_by` is non-empty.
      split_by: Optional column name(s) to group by.

    Returns:
      A DataFrame whose columns are `self.names`.
    """
    if split_by:
      res = self.compute(data.groupby(split_by))
    else:
      res = [self.compute(data)]
    res = pd.DataFrame(res)
    res.columns = self.names
    return res

  def set_names(self, names):
    """Overrides the output column names and returns `self` for chaining."""
    self._names = names
    return self

  @property
  def names(self):
    """Output column names: the `set_names` override, else `default_names`."""
    # Only fall back to `default_names` when no override exists. Evaluating it
    # eagerly would fail for metrics whose defaults cannot be computed yet
    # (e.g. an operation that has no child attached) even though an explicit
    # override was provided.
    try:
      return self._names
    except AttributeError:
      return self.default_names

  @property
  def extra_dims(self):
    """Index columns the metric adds on top of `split_by`; none by default."""
    return []

  def sql_aggregate(self, data, dimensions):
    """Builds `SELECT <dimensions>, <sql AS name>... FROM data GROUP BY ...`.

    Args:
      data: Table name or FROM-clause text to select from.
      dimensions: Column names to select and group by; may be empty or None.

    Returns:
      The query as a string.
    """
    dim_sql = ','.join(dimensions) + ',' if dimensions else ''
    groupby = 'GROUP BY ' + ','.join(dimensions) if dim_sql else ''
    val_cols = ','.join(f'{s} AS {n}' for s, n in zip(self.sql, self.names))
    return f'SELECT {dim_sql} {val_cols} FROM {data} {groupby}'

  def to_sql(self, data, split_by=None):
    """Returns the SQL query computing the metric on table `data`."""
    return self.sql_aggregate(data, split_by)

  def compute_on_sql(self, data, split_by=None):
    """Runs the metric's SQL through `pandas_gbq.read_gbq`.

    Uses the module-level `project_id`.

    Args:
      data: Table name to query.
      split_by: Optional list of column names to group by.

    Returns:
      The query result; indexed by `split_by + extra_dims` and sorted by that
      index when there is at least one such column.
    """
    # `None` (the default) cannot be concatenated with a list below.
    split_by = split_by or []
    res = pandas_gbq.read_gbq(self.to_sql(data, split_by), project_id=project_id)
    dims = split_by + list(self.extra_dims)
    return res.set_index(dims).sort_index() if dims else res

  def __truediv__(self, other):
    """Implements `metric / other` as `Div(self, other)`."""
    return Div(self, other)

  def __or__(self, fn):
    """Overwrites the '|' operator to enable pipeline chaining."""
    return fn(self)


class Operation(Metric):
  """A metric that is computed from the results of child metric(s).

  Subclasses supply the hooks used below: `preprocess`, `compute_children` and
  `process_results` for pandas, and `preprocess_sql`, `children_to_sql` and
  `assemble_query` for SQL.
  """

  def compute_on(self, data, split_by=None):
    """Computes the operation with pandas.

    Runs `preprocess`, then `compute_children`, then `process_results`.

    Args:
      data: Input data forwarded to `preprocess`.
      split_by: Optional list of column names to group by.

    Returns:
      Whatever `process_results` returns.
    """
    # Normalise once here: the hooks concatenate `split_by` with other lists,
    # which fails for the default `None`.
    split_by = split_by or []
    data_preprocessed = self.preprocess(data, split_by)
    child_res = self.compute_children(data_preprocessed, split_by)
    return self.process_results(child_res, split_by)

  def __call__(self, child: Metric):
    """Attaches `child` and returns the resulting operation.

    If this operation already has a child, a deep copy receives the new child
    so the existing instance is left untouched; otherwise `self` is updated in
    place and returned.
    """
    op = copy.deepcopy(self) if self.child else self
    op.child = child
    return op

  def sql_select(self, data, dimensions):
    """Builds `SELECT <dimensions>, <sql AS name>... FROM data` (no GROUP BY).

    Args:
      data: Table name or FROM-clause text to select from.
      dimensions: Column names to select ahead of the values; may be empty.

    Returns:
      The query as a string.
    """
    dim_sql = ','.join(dimensions) + ',' if dimensions else ''
    val_cols = ','.join(f'{s} AS {n}' for s, n in zip(self.sql, self.names))
    return f'SELECT {dim_sql} {val_cols} FROM {data}'

  def to_sql(self, data, split_by=None):
    """Returns the SQL for the operation on table `data`.

    Runs `preprocess_sql`, then `children_to_sql`, then `assemble_query`.
    """
    # Same normalisation as in `compute_on`.
    split_by = split_by or []
    data_preprocessed = self.preprocess_sql(data, split_by)
    children_query = self.children_to_sql(data_preprocessed, split_by)
    return self.assemble_query(children_query, split_by)

  @property
  def extra_dims(self):
    """Index columns the operation adds on top of `split_by`; none by default."""
    return []


@attrs.define
class Sum(Metric):
  """Sum of a single column."""
  # Name of the column to add up.
  var: str

  def compute(self, data):
    """Returns `.sum()` of column `var` taken from `data`."""
    column = data[self.var]
    return column.sum()

  @property
  def default_names(self):
    """Default output column name, `sum_<var>`, as a one-element list."""
    return ['sum_{}'.format(self.var)]

  @property
  def sql(self):
    """SQL expression for the metric, `SUM(<var>)`, as a one-element list."""
    return ['SUM({})'.format(self.var)]


@attrs.define
class Count(Metric):
  """Count of the entries of a single column."""
  # Name of the column to count.
  var: str

  def compute(self, data):
    """Returns `.count()` of column `var` taken from `data`."""
    column = data[self.var]
    return column.count()

  @property
  def default_names(self):
    """Default output column name, `count_<var>`, as a one-element list."""
    return ['count_{}'.format(self.var)]

  @property
  def sql(self):
    """SQL expression for the metric, `COUNT(<var>)`, as a one-element list."""
    return ['COUNT({})'.format(self.var)]


@attrs.define
class Div(Operation):
  """Ratio of two metrics, `child1 / child2`, column by column."""
  # Numerator metric.
  child1: Metric
  # Denominator metric.
  child2: Metric

  def preprocess(self, data, split_by):
    """Returns `data` unchanged."""
    return data

  def compute_children(self, data, split_by):
    """Returns the pair (numerator result, denominator result)."""
    return (self.child1.compute_on(data, split_by),
            self.child2.compute_on(data, split_by))

  def process_results(self, child_res, split_by):
    """Divides the numerator result by the denominator result."""
    num, denom = child_res
    # Give both frames the same column labels so the division lines the
    # columns up instead of aligning on the children's differing names.
    names = list(self.names)
    num.columns = names
    denom.columns = names
    return num / denom

  @property
  def default_names(self):
    """Default names, `<child1 name>_div_<child2 name>`, as a list."""
    # A list rather than a one-shot `map` iterator, so the result can be
    # iterated more than once and supports len().
    return ['_div_'.join(pair)
            for pair in zip(self.child1.names, self.child2.names)]

  def preprocess_sql(self, data, split_by):
    """Returns `data` unchanged."""
    return data

  def children_to_sql(self, data, split_by):
    """Returns `data` unchanged; the children's expressions are used inline."""
    return data

  def assemble_query(self, child_res, split_by):
    """Returns one aggregation query that evaluates the ratio expressions."""
    return self.sql_aggregate(child_res, split_by)

  @property
  def sql(self):
    """SQL expressions `(<child1 sql>) / (<child2 sql>)`, as a list."""
    # Parenthesise both operands: without it a nested ratio such as
    # (a / b) / (c / d) would be emitted as `a / b / c / d`, which SQL
    # evaluates left to right as ((a / b) / c) / d.
    return [f'({num}) / ({denom})'
            for num, denom in zip(self.child1.sql, self.child2.sql)]

  @property
  def extra_dims(self):
    """Div adds no index columns of its own."""
    return []


@attrs.define
class PercentChange(Operation):
  """Percent change of the child metric relative to a baseline group.

  The child is computed per value of the `condition` column and every group is
  compared with the group whose value equals `baseline`.
  """
  # Column whose values define the groups being compared.
  condition: str
  # Value of `condition` identifying the baseline group.
  baseline: str
  # Metric to compare; usually attached later via `metric | PercentChange(...)`.
  child: Optional[Metric] = None

  def preprocess(self, data, split_by):
    """Returns `data` unchanged."""
    return data

  def compute_children(self, data, split_by):
    """Computes the child split by `split_by` plus the condition column."""
    # `split_by or []` keeps the concatenation valid when None is passed.
    return self.child.compute_on(data, (split_by or []) + [self.condition])

  def process_results(self, child_res, split_by):
    """Returns `(child / baseline - 1) * 100` with columns `self.names`."""
    if split_by:
      # The condition is one level of a MultiIndex; select the baseline slice.
      base = child_res.xs(self.baseline, level=self.condition)
    else:
      # The condition is the only index level; the baseline is a single row.
      base = child_res.loc[self.baseline]
    res = child_res / base - 1
    res.columns = self.names
    return res * 100

  @property
  def default_names(self):
    """Default names, `pct_change_of_<child name>`, as a list."""
    return [f'pct_change_of_{n}' for n in self.child.names]

  def preprocess_sql(self, data, split_by):
    """Returns `data` unchanged."""
    return data

  def children_to_sql(self, data, split_by):
    """Returns the child's query split by `split_by` plus the condition."""
    return self.child.to_sql(data, (split_by or []) + [self.condition])

  def assemble_query(self, child_res, split_by):
    """Joins the child's result with its baseline rows and selects the change.

    Args:
      child_res: SQL text of the child's query.
      split_by: List of column names the child was split by (may be empty).

    Returns:
      The full query as a string.
    """
    dims = self.extra_dims + (split_by or [])
    # dims[0] is the condition column, which the Base CTE drops, so the join
    # uses the remaining dimensions only.
    u = ','.join(dims[1:])
    join = f'T JOIN Base USING ({u})'
    if not u:
      # No shared dimension: Base is combined with every row of T.
      join = 'T CROSS JOIN Base'
    # NOTE(review): `condition` and `baseline` are interpolated verbatim; a
    # baseline containing a quote character would break the query.
    return f"""
    WITH T AS ({child_res}),
    Base AS (SELECT *
    EXCEPT ({self.condition}) FROM T
    WHERE {self.condition}
      = '{self.baseline}')
    {self.sql_select(join, dims)}"""

  @property
  def sql(self):
    """SQL expressions for the percent change, one per child column."""
    return [f'(T.{c} / Base.{c} - 1) * 100' for c in self.child.names]

  @property
  def extra_dims(self):
    """The condition column followed by the child's extra dimensions."""
    # Leaf metrics may not define `extra_dims`; treat that as "none".
    return [self.condition] + list(getattr(self.child, 'extra_dims', []))


@attrs.define
class Bootstrap(Operation):
  """Standard deviation of the child metric across bootstrap resamples."""
  # Number of resamples to draw.
  n_rep: int = attrs.field(default=50)
  # Metric to resample; usually attached later via `metric | Bootstrap()`.
  child: Optional[Metric] = None

  def preprocess(self, data, split_by):
    """Yields `n_rep` resamples of `data`, same size, drawn with replacement."""
    for _ in range(self.n_rep):
      yield data.sample(frac=1, replace=True)

  def compute_children(self, data, split_by):
    """Computes the child on every resample and concatenates column-wise."""
    sample_res = [self.child.compute_on(sample, split_by) for sample in data]
    return pd.concat(sample_res, axis=1)

  def process_results(self, child_res, split_by):
    """Returns the standard deviation across resamples, per child column."""
    # Every resample contributes columns with the same labels; transposing and
    # grouping by label pools the resamples of each child column.
    std = child_res.T.groupby(level=0).std().T
    std.columns = self.names
    return std

  @property
  def default_names(self):
    """Default names, `se_<child name>`, as a list."""
    return [f'se_{n}' for n in self.child.names]

  def preprocess_sql(self, data, split_by):
    """Returns the (input_data, samples) queries from `resample_n_times`."""
    return resample_n_times(data, split_by, self.n_rep)

  def children_to_sql(self, data, split_by):
    """Appends the child's query, split by `split_by` plus `sample_idx`."""
    # `split_by or []` keeps the concatenation valid when None is passed.
    return (*data, self.child.to_sql('Samples', (split_by or []) + ['sample_idx']))

  def assemble_query(self, child_res, split_by):
    """Builds the script that materialises the data and aggregates resamples.

    Args:
      child_res: Tuple (input_data, samples, sample_res) of SQL strings.
      split_by: List of column names to group by (may be empty).

    Returns:
      A two-statement SQL script as a string.
    """
    (input_data, samples, sample_res) = child_res
    sql = self.sql_aggregate('SampleRes', (split_by or []) + self.extra_dims)
    return f"""
      CREATE TEMP TABLE Data
        AS ({input_data});
      WITH Samples AS ({samples}),
      SampleRes AS ({sample_res})
      {sql}"""

  @property
  def sql(self):
    """SQL expressions `STDDEV(<child name>)`, one per child column."""
    return [f'STDDEV({n})' for n in self.child.names]

  @property
  def extra_dims(self):
    """The child's extra dimensions; Bootstrap adds none of its own."""
    # Leaf metrics may not define `extra_dims`; treat that as "none".
    return list(getattr(self.child, 'extra_dims', []))


def resample_n_times(data, split_by, n_rep):
  """Builds the SQL that resamples `data` with replacement `n_rep` times.

  Args:
    data: Name of the table to resample.
    split_by: Column names that define the groups to resample within; may be
      empty or None.
    n_rep: Number of resamples.

  Returns:
    A tuple `(input_data, samples)` of SQL strings. `input_data` replicates
    the table once per `sample_idx` and tags every row with its row number and
    a random row number. `samples` reads that result under the name `Data` and
    replaces every row by the row whose number equals its random row number.
  """
  by_sql = ','.join(split_by) + ',' if split_by else ''
  # Number the rows within each (split_by, sample_idx) group. The join below
  # matches on those same columns, so a random row number must refer to a row
  # of the same group; numbering across groups would make most draws point at
  # rows of other groups and silently drop them from the resample.
  partition = ','.join(list(split_by or []) + ['sample_idx'])
  input_data = f"""
    SELECT
      *,
      ROW_NUMBER() OVER (PARTITION BY {partition}) AS row_number,
      CEIL(RAND() * COUNT(*) OVER (PARTITION BY {partition}))
        AS random_row_number,
    FROM {data},
    UNNEST(GENERATE_ARRAY(1, {n_rep})) AS sample_idx"""
  samples = f"""
    SELECT b.*
    FROM (
      SELECT
        {by_sql}
        sample_idx,
        random_row_number AS row_number
      FROM Data) AS a
    JOIN Data AS b
    USING ({by_sql} sample_idx, row_number)"""
  return (input_data, samples)

## Output

In [None]:
split_bys = [[], ]
# Build the pipeline: churn rate -> percent change vs. control -> bootstrap.
churn = (Sum("lost") / Count("lost")).set_names(["churn"])
pct = churn | PercentChange("experiment", "control")
bst = pct | Bootstrap()

# Each metric is reported twice: computed with pandas, then with SQL.
labelled_metrics = (
    ('Churn rate', churn),
    ('Percent change', pct),
    ('Bootstrap', bst),
)

for split_by in split_bys:
  print(f'split_by is {split_by}\n\n')
  for label, metric in labelled_metrics:
    in_pandas = metric.compute_on(df, split_by)
    print(f'{label} is {in_pandas}\n')
    in_sql = metric.compute_on_sql("demo.data", split_by)
    print(f'{label} in SQL is {in_sql}\n')

splitby is []


Churn rate is    churn
0   0.56

Downloading: 100%|[32m██████████[0m|
Churn rate in SQL is    churn
0   0.56

Percent change is              pct_change_of_churn
experiment                      
control                 0.000000
experiment1            -2.673797
experiment2           -21.227621
experiment3            -8.496732

Downloading: 100%|[32m██████████[0m|
Percent change in SQL is              pct_change_of_churn
experiment                      
control                 0.000000
experiment1            -2.673797
experiment2           -21.227621
experiment3            -8.496732

Bootstrap is              se_pct_change_of_churn
experiment                         
control                    0.000000
experiment1               26.014820
experiment2               19.063005
experiment3               18.845573

Downloading: 100%|[32m██████████[0m|
Bootstrap in SQL is              se_pct_change_of_churn
experiment                         
control                    0.00