<a href="https://colab.research.google.com/github/falco1978/trimmed_match/blob/master/trimmed_match/notebook/post_analysis_colab_for_trimmed_match.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:

#@markdown * Connect to the hosted runtime and run each cell after updating the necessary inputs
#@markdown * Download the file "example_data_for_post_analysis.csv" from the folder "example_datasets" on GitHub.
#@markdown * Upload the csv file to your Google Drive and open it with Google Sheets
#@markdown * In the cell below, copy and paste the url of the sheet.

In [None]:
#@markdown ### Load the required packages, e.g. trimmed_match.

BAZEL_VERSION = '6.1.2'
!wget https://github.com/bazelbuild/bazel/releases/download/{BAZEL_VERSION}/bazel-{BAZEL_VERSION}-installer-linux-x86_64.sh
!chmod +x bazel-{BAZEL_VERSION}-installer-linux-x86_64.sh
!./bazel-{BAZEL_VERSION}-installer-linux-x86_64.sh
!sudo apt-get install python3-dev python3-setuptools git
!git clone https://github.com/google/trimmed_match
!python3 -m pip install ./trimmed_match

"""Loading the necessary python modules."""
import matplotlib.pyplot as plt
import pandas as pd
import re
import seaborn as sns

from IPython.display import display
from IPython.core.interactiveshell import InteractiveShell
from pandas.plotting import register_matplotlib_converters

import gspread
import warnings
from google import auth as google_auth
from google.colab import auth
from google.colab import data_table
from google.colab import drive
from trimmed_match.design.common_classes import GeoAssignment
from trimmed_match.design import plot_utilities
from trimmed_match.design import util
from trimmed_match.post_analysis import trimmed_match_post_analysis

# Notebook-wide settings.
# Silence library warnings so that the report output stays readable; this is
# done first so that the calls below cannot emit any.
warnings.filterwarnings('ignore')
# Display the value of every expression statement of a cell, not only the last
# one (later cells assign throwaway results to `useless` to avoid echoes).
InteractiveShell.ast_node_interactivity = "all"
# Register pandas' date/time converters with matplotlib.
register_matplotlib_converters()

In [None]:
#@markdown ### Enter the URL of the Google Sheet containing the data:
#@markdown The spreadsheet should contain the mandatory columns:
#@markdown * date: date in the format YYYY-MM-DD
#@markdown * geo: the number which identifies the geo
#@markdown * pair: the number which identifies the geo pair
#@markdown * assignment: geo assignment (1=Treatment, 2=Control)
#@markdown * response: variable on which you want to measure incrementality
#@markdown (e.g. sales, transactions)
#@markdown * cost: variable on ad spend

#@markdown ---

## Load the input spreadsheet
#@markdown Spreadsheet URL


experiment_table = "add your url here, which should look like https://docs.google.com/spreadsheets/d/???/edit#gid=???" #@param {type:"string"}
# Authenticate the Colab user and open the first worksheet of the spreadsheet.
auth.authenticate_user()
creds, _ = google_auth.default()
gc = gspread.authorize(creds)
wks = gc.open_by_url(experiment_table).sheet1
# All cell values of the worksheet as a list of rows; the first row is the header.
rows = wks.get_all_values()
if not rows:
  raise ValueError(
      "The spreadsheet {} is empty: expected a header row.".format(
          experiment_table))
headers = rows[0]
data = pd.DataFrame(rows[1:], columns=headers)

# Fail early with an explicit message instead of a bare KeyError further down.
required_columns = ["date", "geo", "pair", "assignment", "response", "cost"]
missing_columns = [col for col in required_columns if col not in data.columns]
if missing_columns:
  raise ValueError(
      "The spreadsheet is missing the mandatory column(s): {}".format(
          ", ".join(missing_columns)))

data["date"] = pd.to_datetime(data["date"])
for colname in ["geo", "pair", "assignment", "response", "cost"]:
  data[colname] = pd.to_numeric(data[colname])

In [None]:
#@title Summary of the data for the design, test, and test+cooldown period

test_start_date = "2020-11-04" #@param {type:"date"}
test_end_date = "2020-12-01" #@param {type:"date"}
cooldown_end_date = "2020-12-16" #@param {type:"date"}
design_eval_start_date = "2020-09-03" #@param {type:"date"}
design_eval_end_date = "2020-10-01" #@param {type:"date"}

#@markdown Use an average order value of 1 if the experiment is based on sales/revenue or an actual average order value (e.g. 80$) for an experiment based on transactions/footfall/contracts.
average_order_value =  1#@param{type: "number"}

test_start_date = pd.to_datetime(test_start_date)
test_end_date = pd.to_datetime(test_end_date)
cooldown_end_date = pd.to_datetime(cooldown_end_date)
design_eval_start_date = pd.to_datetime(design_eval_start_date)
design_eval_end_date = pd.to_datetime(design_eval_end_date)

#@markdown (OPTIONAL) List the pairs of geos you need to exclude separated by a comma e.g. 1,2. Leave empty to select all pairs.
pairs_exclude = "" #@param {type: "string"}
pairs_exclude = [] if pairs_exclude == "" else [
    int(re.sub(r"\W+", "", x)) for x in pairs_exclude.split(",")
]

# Numerical identifiers used in the "assignment" column of the input table to
# identify the two groups.
group_treatment = GeoAssignment.TREATMENT
group_control = GeoAssignment.CONTROL

geox_data = trimmed_match_post_analysis.check_input_data(
    data.copy(),
    group_control=group_control,
    group_treatment=group_treatment)
# Take an explicit copy so that the "period" column added below is written to an
# independent frame rather than to a filtered view (SettingWithCopy).
geox_data = geox_data[~geox_data["pair"].isin(pairs_exclude)].copy()

# Build each period once instead of once per row.
design_eval_period = pd.Interval(
    design_eval_start_date, design_eval_end_date, closed="both")
test_period = pd.Interval(test_start_date, test_end_date, closed="both")
cooldown_period = pd.Interval(test_end_date, cooldown_end_date, closed="right")


def _period_label(day):
  """Returns 0 (design evaluation/pretest), 1 (test), 2 (cooldown) or -1.

  The first matching period wins, in the order listed; -1 marks days outside
  all three periods.
  """
  if day in design_eval_period:
    return 0
  if day in test_period:
    return 1
  if day in cooldown_period:
    return 2
  return -1


geox_data["period"] = geox_data["date"].apply(_period_label)
geox_data = geox_data[["date", "geo", "pair", "assignment", "response", "cost",
                       "period"]]
pairs = geox_data["pair"].sort_values().drop_duplicates().to_list()

# Total spend over the test period (period == 1).
total_cost = geox_data.loc[geox_data["period"] == 1, "cost"].sum()
print("Total cost: {}".format(util.human_readable_number(total_cost)))

print("Total response and cost by period and group")
# Pretest (0) and test (1) totals, one row per (period, group).
pretest_and_test = geox_data["period"].isin([0, 1])
output_table = (
    geox_data.loc[pretest_and_test, ["period", "assignment", "response", "cost"]]
    .groupby(["period", "assignment"], as_index=False)
    .sum()
)
# Replace the numeric codes by readable names for display.
output_table["assignment"] = output_table["assignment"].map(
    {group_control: "Control", group_treatment: "Treatment"})
output_table["period"] = output_table["period"].map({0: "Pretest", 1: "Test"})

data_table.DataTable(output_table, include_index=False)

# Square root of the total response per pair, in the pretest (0) and test (1)
# periods, for each group.
sqrt_response = (
    geox_data[geox_data["period"].isin([0, 1])]
    .groupby(["period", "assignment", "pair"])["response"].sum() ** 0.5
).reset_index()


def _sqrt_response_by_pair(period, assignment):
  """Returns the square-rooted responses of one (period, group), indexed by pair."""
  rows = sqrt_response[(sqrt_response["period"] == period) &
                       (sqrt_response["assignment"] == assignment)]
  return rows.set_index("pair")["response"]


# Building the frame from pair-indexed series aligns the four columns on the
# pair id: a pair without data for some period/group gets NaN in that column
# instead of shifting the remaining values or failing on unequal lengths.
comp = pd.DataFrame({
    "pretreatment": _sqrt_response_by_pair(0, group_treatment),
    "precontrol": _sqrt_response_by_pair(0, group_control),
    "posttreatment": _sqrt_response_by_pair(1, group_treatment),
    "postcontrol": _sqrt_response_by_pair(1, group_control),
}).reset_index(drop=True)


fig, ax = plt.subplots(4, 4, figsize=(15, 15))
label = ["pretreatment", "precontrol", "posttreatment", "postcontrol"]
min_ax = min(comp.min())
max_ax = max(comp.max())
# Common range, with a 3% margin, shared by every panel and by the y = x line.
limits = [min_ax*0.97, max_ax*1.03]
# Scatter-plot matrix of the columns of `comp`: the upper triangle compares two
# columns, the diagonal shows the column name, the lower triangle is blank.
# The `useless =` assignments keep IPython (ast_node_interactivity = "all")
# from echoing the return value of every call made inside the loop.
for row_ind in range(4):
  for col_ind in range(4):
    panel = ax[row_ind, col_ind]
    if row_ind > col_ind:
      useless = panel.axis("off")
      continue
    if row_ind == col_ind:
      useless = panel.annotate(label[col_ind], size=20, xy=(0.15, 0.5),
                               xycoords="axes fraction")
    else:
      useless = panel.scatter(comp[label[col_ind]], comp[label[row_ind]])
      useless = panel.plot(limits, limits, 'r')
    useless = panel.set_xlim(limits)
    useless = panel.set_ylim(limits)

In [None]:
#@title Visualization of experiment data.

# Order the rows chronologically before plotting.
geox_data = geox_data.sort_values(by="date")

def plot_ts_comparison(geox_data, metric):
  """Plots the daily sum of `metric` for the treatment and the control group.

  Vertical lines mark the test period (solid black), the design evaluation
  period (dashed red) and the end of the cooldown period (dashed black).

  Args:
    geox_data: data frame with (at least) the columns "date", "assignment"
      and `metric`.
    metric: name of the column to sum per day, e.g. "response" or "cost".
  """
  _, axes = plt.subplots(1, 1, figsize=(15, 7.5))

  for group, group_name in ((group_treatment, "treatment"),
                            (group_control, "control")):
    group_rows = geox_data[geox_data["assignment"] == group]
    daily = group_rows.groupby(["date"], as_index=False)[metric].sum()
    axes.plot(daily["date"], daily[metric], label=group_name)

  axes.set_ylabel(metric)
  axes.set_xlabel("date")

  # Only the first line of each boundary style carries a legend label.
  boundaries = (
      (test_end_date, "black", "-", {"label": 'Experiment period'}),
      (design_eval_start_date, "red", "--",
       {"label": 'Design evaluation period'}),
      (cooldown_end_date, "black", "--", {"label": 'End of cooldown period'}),
      (test_start_date, "black", "-", {}),
      (design_eval_end_date, "red", "--", {}),
  )
  for boundary_date, line_color, line_style, legend_kwargs in boundaries:
    axes.axvline(x=boundary_date, color=line_color, ls=line_style,
                 **legend_kwargs)
  axes.legend(bbox_to_anchor=(0.5, 1.1), loc='center')

plot_ts_comparison(geox_data, "response")

plot_ts_comparison(geox_data, "cost")

def ts_plot(x, y, **kwargs):
  """Plots column `y` against column `x` of one facet on the current axes.

  Callback for `FacetGrid.map_dataframe`: the facet's rows arrive in the
  `data` keyword; all remaining keywords are forwarded to `DataFrame.plot`.
  """
  ax = plt.gca()
  data = kwargs.pop("data")
  data.plot(x=x, y=y, ax=ax, grid=False, **kwargs)

# One facet per pair. col_order ties each facet to the matching entry of
# `pairs`, and hue_order draws treatment before control. The per-facet legend
# labels below rely on both orders, so they are made explicit here instead of
# depending on seaborn's default sorting of the column values.
g = sns.FacetGrid(geox_data, col="pair", hue="assignment", col_wrap=3,
                  col_order=pairs, hue_order=[group_treatment, group_control],
                  sharey=False, sharex=False, legend_out=False, height=5,
                  aspect=2)
g = g.map_dataframe(ts_plot, "date", "response").add_legend()
# `useless =` keeps IPython (ast_node_interactivity = "all") from echoing the
# return value of each call inside the loop.
for pair, pair_ax in zip(pairs, g.axes.flat):
  pair_rows = geox_data[geox_data["pair"] == pair]
  cont = pair_rows.loc[pair_rows["assignment"] == group_control,
                       "geo"].values[0]
  treat = pair_rows.loc[pair_rows["assignment"] == group_treatment,
                        "geo"].values[0]
  useless = pair_ax.axvline(x=test_end_date, color="black", ls="-")
  useless = pair_ax.axvline(x=design_eval_start_date, color="red", ls="--")
  useless = pair_ax.axvline(x=cooldown_end_date, color="black", ls="--")
  useless = pair_ax.axvline(x=test_start_date, color="black", ls="-")
  useless = pair_ax.axvline(x=design_eval_end_date, color="red", ls="--")
  # Labels are matched to lines in drawing order: the treatment and control
  # series, then the first three vertical lines above.
  useless = pair_ax.legend(["treatment"+" (geo {})".format(treat),
                            "control"+" (geo {})".format(cont),
                            "Experiment period", "Design evaluation period",
                            "End of cooldown period"], loc="best")

In [None]:
#@title Exclude the cooldown period.

# Prepare the data with exclude_cooldown=True, estimate the results and report
# them using average_order_value.
geo_data = trimmed_match_post_analysis.prepare_data_for_post_analysis(
    geox_data=geox_data,
    group_control=group_control,
    group_treatment=group_treatment,
    exclude_cooldown=True,
)
results = trimmed_match_post_analysis.calculate_experiment_results(geo_data)
trimmed_match_post_analysis.report_experiment_results(
    results, average_order_value)

In [None]:
#@title Include the cooldown period

# Same analysis as above but with exclude_cooldown=False.
geo_data_including_cooldown = (
    trimmed_match_post_analysis.prepare_data_for_post_analysis(
        geox_data=geox_data,
        group_control=group_control,
        group_treatment=group_treatment,
        exclude_cooldown=False,
    ))
results_with_cd = trimmed_match_post_analysis.calculate_experiment_results(
    geo_data_including_cooldown)
trimmed_match_post_analysis.report_experiment_results(
    results_with_cd, average_order_value)