<a href="https://colab.research.google.com/github/google/business_intelligence_group/blob/development/solutions/causal-impact/CausalImpact_with_Experimental_Design.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **CausalImpact with Experimental Design**

This Colab file contains *Experimental Design* and *CausalImpact Analysis*.

See [README.md](https://github.com/google/business_intelligence_group/tree/main/solutions/causal-impact) for details.

---

Copyright 2023 Google LLC

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

In [None]:
# @title Step.1

# library
print('Installing tfcausalimpact')
!pip install tfcausalimpact --quiet
print('Installed tfcausalimpact' + '\n')

from google.colab import auth
from google.colab import files
from google.colab import widgets
from google.auth import default
from google.cloud import bigquery


from dateutil.relativedelta import relativedelta
from IPython.display import display
from IPython.display import Markdown
from IPython.display import HTML
from IPython.display import Javascript
from oauth2client.client import GoogleCredentials
from causalimpact import CausalImpact
from scipy.spatial.distance import euclidean
from sklearn.metrics import mean_absolute_percentage_error

import altair as alt
import datetime
import fastdtw
import gspread
import logging
import io
import ipywidgets
import itertools
import numpy as np
import os
import pandas as pd
import random
import tensorflow
tensorflow.get_logger().setLevel(logging.ERROR)

def resize_colab_cell(*_event_args):
  """Lets the Colab output iframe of the running cell grow up to 5000px.

  Registered below as an IPython ``pre_run_cell`` callback, so it runs before
  every cell execution.

  Args:
    *_event_args: Ignored. IPython passes an ``ExecutionInfo`` object to
      ``pre_run_cell`` callbacks; accepting (and ignoring) it keeps the
      callback valid both on IPython versions that adapt zero-argument
      callbacks and on versions that always pass the argument.
  """
  display(Javascript('google.colab.output.setIframeHeight(0, true, {maxHeight: 5000})'))
get_ipython().events.register('pre_run_cell', resize_colab_cell)

# ipywidgets settings
def apply_style(text):
  """Wraps text in the styled <span> used for the step headings.

  Args:
    text: Text or HTML fragment to emphasize.

  Returns:
    An HTML string with `text` inside a <span> that sets a 20px font size and
    a blue (#4285F4) gradient along the bottom of the background.
  """
  opening_tag = "<span style='font-size:20px; background: linear-gradient(transparent 90%, #4285F4 0%);'>"
  # The element opened above is a <span>, so it has to be closed with </span>;
  # closing it with </style> left the span unterminated.
  return opening_tag + text + "</span>"

# Parameter overview. The columns are presumably: widget | meaning |
# used by design A (divide_equally) | used by design B (similarity_selection)
# -- TODO confirm what the two ◯ columns stand for.
# Start Date | Start date to be used for design | ◯ | ◯
# End Date| End date to be used for design| ◯ | ◯
# iteration# | Number of times the random sampling is repeated | ◯ | ◯
# Split# | Number of divisions to be made | ◯ |
# Covariates# | Number of covariates to create | | ◯
# Pick# | Number of columns per covariate | | ◯
# target_geo | Areas to look for similar time series | | ◯

# 1. Widgets for the purpose / condition settings.
# Experimental-design period: defaults to 122 days ago .. 32 days ago.
start_date = ipywidgets.DatePicker(
    description='Start Date:',
    value=datetime.date.today() - relativedelta(days=122),
    disabled=False
)
end_date = ipywidgets.DatePicker(
    description='End Date:',
    value=datetime.date.today() - relativedelta(days=32),
    disabled=False
)
# CausalImpact pre-period: same default window as the design period above.
pre_period_start = ipywidgets.DatePicker(
    description='Pre Start:',
    value=datetime.date.today() - relativedelta(days=122),
    disabled=False
)
pre_period_end = ipywidgets.DatePicker(
    description='Pre End:',
    value=datetime.date.today() - relativedelta(days=32),
    disabled=False
)
# CausalImpact post-period: defaults to 31 days ago .. today, i.e. it starts
# the day after the default pre-period ends.
post_period_start = ipywidgets.DatePicker(
    description='Post Start:',
    value=datetime.date.today() - relativedelta(days=31),
    disabled=False
)
post_period_end = ipywidgets.DatePicker(
    description='Post End:',
    value=datetime.date.today(),
    disabled=False
)
# "invisible": this widget is not added to any layout in this cell; only its
# value (default 1000) is read by ExperimentalDesign.__init__.
# NOTE: the last character of the name is a fullwidth 'ｓ' (U+FF53). Python
# NFKC-normalizes identifiers, so it is the same name as `num_of_iterations`.
num_of_iterationｓ = ipywidgets.Dropdown(
    options=[10, 100, 1000, 10000],
    value=1000,
    disabled=False,
)
# Number of groups the columns are divided into.
num_of_split = ipywidgets.Dropdown(
    options=[2, 3, 4, 5],
    value=2,
    description='split#:',
    disabled=False,
)
# Number of columns drawn per covariate group (design B).
num_of_pick = ipywidgets.IntSlider(
    value=5,
    min=1,
    max=50,
    step=1,
    description='pick#:',
    disabled=False,
    continuous_update=False,
    orientation='horizontal',
    readout=True,
    readout_format='d'
)
# Number of covariate groups to build (design B).
num_of_covariate = ipywidgets.Dropdown(
    options=[1, 2, 3, 4, 5],
    value=2,
    description='covariate#:',
    disabled=False,
)
# Comma-separated column names that design B looks for similar series to.
target_columns = ipywidgets.Text(
    placeholder='Please enter comma-separated entries',
    value='Tokyo, Kanagawa',
    description='target_geo:',
    disabled=False,
)
# Comma-separated columns to leave out of the design.
un_needed_cols = ipywidgets.Text(
    placeholder='Please enter them separated by commas.',
    description='un need col:',
    disabled=False,
    layout=ipywidgets.Layout(width='500px'))
# Estimated incremental CPA (see the explanation shown next to this widget).
estimate_icpa = ipywidgets.IntText(
    value=1000,
    description='Estimated iCPA:',
    style = {'description_width': 'initial'},
    disabled=False,
    )

# Accordion with one section per experimental-design method. Its
# selected_index (0 = A, 1 = B) is what ExperimentalDesign.design() branches
# on. num_of_split and un_needed_cols are placed in both sections.
text_un_need_col = "If any columns are not needed in the design, please list them comma-separated."
design_type = ipywidgets.Accordion(
    children=[
      ipywidgets.VBox([
          num_of_split,
          ipywidgets.HTML(value=text_un_need_col),
          un_needed_cols,
          ]),
      ipywidgets.VBox([
          target_columns,
          num_of_split,
          num_of_pick,
          num_of_covariate,
          ipywidgets.HTML(value=text_un_need_col),
          un_needed_cols,
          ]),])
design_type.set_title(0, 'A: divide_equally')
design_type.set_title(1, 'B: similarity_selection')

# Headings for the three steps of the "Experimental_Design" tab; they are
# rendered through apply_style().
text_process_1 = "ⅰ. Enter the time period covered by this experimental design."
text_process_2 = """ⅱ. Please select the <b>experimental design method</b> and select the necessary items."""
text_process_3 = """ⅲ. Please enter the <b>Estimated incremental CPA</b>"""
# Top-level purpose tabs: tab 0 holds the CausalImpact pre/post period
# pickers, tab 1 holds the experimental-design settings.
purpose_selection = ipywidgets.Tab()
purpose_selection.children = [
    ipywidgets.VBox([
        ipywidgets.Label(value="Enter the time period covered by this CausalImpact analysis."),
        pre_period_start,
        pre_period_end,
        post_period_start,
        post_period_end]),
    ipywidgets.VBox([
        ipywidgets.HTML(value=apply_style(text_process_1)),
        start_date,
        end_date,
        ipywidgets.HTML(value=apply_style(text_process_2)),
        ipywidgets.HTML(value="""
        <li>A: divide_equally divides the time series data into n groups with similar movements.</li>
        <li>B: similarity_selection extracts n groups that move similarly to particular columns.</li>
        """),
        design_type,
        ipywidgets.HTML(value=apply_style(text_process_3)),
        ipywidgets.HTML(value="""
        <li>Incremental CPA = Cost of intervention ÷ Lift from intervention without bias.</li>
        <li>Hypothetical values are not a problem.</li>
        <li>The cost required to verify the hypothesis is used in the calculation.</li>
        """),
        estimate_icpa]),
    ]
purpose_selection.set_title(0, 'Causalimpact_Analysis')
purpose_selection.set_title(1, 'Experimental_Design')


# 2. Data source and data format widgets.
# Google Spreadsheet source; pre-filled with the sample spreadsheet.
sheet_url = ipywidgets.Text(
    placeholder='Please enter google spreadsheet url',
    value='https://docs.google.com/spreadsheets/d/1dISrbX1mZHgzpsIct2QXFOWWRRJiCxDSmSzjuZz64Tw/edit#gid=0',
    description='spreadsheet url:',
    style={'description_width': 'initial'},
    disabled=False,
    layout=ipywidgets.Layout(width='1000px'))
sheet_name = ipywidgets.Text(
    placeholder='Please enter sheet name',
    value='analysis_data',
    description='sheet name:',
    disabled=False,
)
# CSV source: PreProcess.load_data_from_csv uses this name as the key into
# the dict returned by files.upload().
csv_name = ipywidgets.Text(
    placeholder='Please enter csv name',
    description='csv name:',
    disabled=False,
    layout=ipywidgets.Layout(width='500px'))
# BigQuery source: project id and the table that is selected from.
bq_project_id = ipywidgets.Text(
    placeholder='Please enter project id',
    description='project id:',
    disabled=False,
    layout=ipywidgets.Layout(width='500px'))
bq_table_name = ipywidgets.Text(
    placeholder='Please enter table name',
    description='table name:',
    disabled=False,
    layout=ipywidgets.Layout(width='500px'))
text_ss_sample = "Sample data is available below. For Experimental Design, replace sheet name <b>analysis_data</b> with <b>raw_data</b>."
# Data source tabs; selected_index 0/1/2 selects spreadsheet/CSV/BigQuery in
# PreProcess.load_data. The name is misspelled ("soure") but it is read by the
# classes below, so it must not be renamed here alone.
soure_selection = ipywidgets.Tab()
soure_selection.children = [
    ipywidgets.VBox([ipywidgets.HTML(text_ss_sample),sheet_url, sheet_name]),
    ipywidgets.VBox([csv_name]),
    ipywidgets.VBox([bq_project_id, bq_table_name]),
    ]
soure_selection.set_title(0, 'Google_Spreadsheet')
soure_selection.set_title(1, 'CSV_file')
soure_selection.set_title(2, 'Big_Query')

# Column names of the input data. pivot_col and kpi_col are only shown in the
# narrow-format tab; date_col is shown in both tabs.
date_col = ipywidgets.Text(
    placeholder='Please enter date column name',
    value='Date',
    description='date column:',
    disabled=False,
)
pivot_col = ipywidgets.Text(
    placeholder='Please enter pivot column name',
    value='Geo',
    description='pivot column:',
    disabled=False,
)
kpi_col = ipywidgets.Text(
    placeholder='Please enter kpi column name',
    value='KPI',
    description='kpi column:',
    disabled=False,
)
text_wide = "Wide, or unstacked data is presented with each different data variable in a separate column."
text_narrow = "Narrow, stacked, or long data is presented with one column containing all the values and another column listing the context of the value"
# Data format tabs; selected_index 0 = wide, 1 = narrow (narrow data is
# pivoted by PreProcess.shape_wide).
data_type_selection = ipywidgets.Tab()
data_type_selection.children = [
    ipywidgets.VBox([
        ipywidgets.HTML(text_wide),
        date_col
    ]),
    ipywidgets.VBox([
        ipywidgets.Label(text_narrow),
        date_col,
        pivot_col,
        kpi_col]),
    ]
data_type_selection.set_title(0, 'Wide_Format')
data_type_selection.set_title(1, 'Narrow_Format')

# Render the whole Step.1 form: purpose tabs, then source and format tabs.
display(
    Markdown(f"""<h2>1. Please select the purpose and set the conditions.</h2>"""),
    purpose_selection,
    Markdown(f"""<h2>2. Please select a data source and choose format <a href='https://en.wikipedia.org/wiki/Wide_and_narrow_data'>wide or narrow</a></h2>"""),
    soure_selection,
    ipywidgets.HTML(value="For the CausalImpact Analysis, please process in advance the <b>date, intervention group, and control group</b> in that order like <a href='https://docs.google.com/spreadsheets/d/1dISrbX1mZHgzpsIct2QXFOWWRRJiCxDSmSzjuZz64Tw/edit#gid=0'>analysis_data</a>."),
    data_type_selection
)


# class
class PreProcess(object):
  """Reads, formats and checks the data used by the analysis and the design.

  The loaded data is kept in ``self.df_sheet`` and is reshaped in place by
  the formatting methods.

  Attributes:
    df_sheet: pandas DataFrame with the loaded (and later reshaped) data.

  Methods:
    load_data_from_sheet: Read data from a Google Spreadsheet.
    load_data_from_csv: Read data from an uploaded CSV file.
    load_data_from_bigquery: Read data from a BigQuery table.
    load_data: Dispatch to one of the three loaders.
    shape_wide: Pivot narrow data into wide data.
    data_shaping: Index the data by date, validate it and plot it.
    trend_check: Visualize the series.
    _get_converted_multi_columns: Flatten multi-level column names.
  """

  def __init__(self):
    pass

  @staticmethod
  def _to_numeric_if_possible(column):
    """Returns `column` converted with pd.to_numeric, or unchanged on failure."""
    try:
      return pd.to_numeric(column)
    except (ValueError, TypeError):
      # Not a numeric column (e.g. dates or names): keep it as it is.
      return column

  def _normalize_numeric_columns(self):
    """Strips commas from string cells, then converts numeric columns.

    The commas have to be removed *before* the conversion: a value such as
    "1,234" is not parseable as a number, so converting first would leave the
    whole column as strings.
    """
    self.df_sheet = self.df_sheet.replace(',', '', regex=True)
    self.df_sheet = self.df_sheet.apply(self._to_numeric_if_possible)

  def load_data_from_sheet(self, spreadsheet_url, sheet_name):
    """Reads a worksheet of a Google Spreadsheet into self.df_sheet.

    The first row of the sheet is used as the header.

    Args:
      spreadsheet_url: Spreadsheet url with data.
      sheet_name: Sheet name with data.
    """
    auth.authenticate_user()
    creds, _ = default()
    gc = gspread.authorize(creds)
    self._workbook = gc.open_by_url(spreadsheet_url)
    self._worksheet = self._workbook.worksheet(sheet_name)
    self.df_sheet = pd.DataFrame(self._worksheet.get_all_values())
    # Promote the first row to column names and drop it from the data.
    self.df_sheet.columns = list(self.df_sheet.loc[0, :])
    self.df_sheet.drop(0, inplace=True)
    self.df_sheet.reset_index(drop=True, inplace=True)
    self._normalize_numeric_columns()

  def load_data_from_csv(self, csv_name):
    """Asks the user to upload a file and reads it as CSV into self.df_sheet.

    Args:
      csv_name: Name of the uploaded file to read; it must match a key of the
        dict returned by files.upload().
    """
    uploaded = files.upload()
    self.df_sheet = pd.read_csv(io.BytesIO(uploaded[csv_name]))
    self._normalize_numeric_columns()

  def load_data_from_bigquery(self, bq_project_id, bq_table_name):
    """Reads a whole BigQuery table into self.df_sheet.

    Args:
      bq_project_id: Project id used to create the BigQuery client.
      bq_table_name: Table to select from; it is placed between backticks in
        the query, so it is expected to be a table reference.
    """
    auth.authenticate_user()
    client = bigquery.Client(project=bq_project_id)
    # NOTE(review): the table name is concatenated into the SQL text. It comes
    # from the notebook user's own widget input; do not reuse this pattern
    # with untrusted input.
    self.query =  "SELECT * FROM `" + bq_table_name + "`;"
    self.df_sheet = client.query(self.query).to_dataframe()
    self._normalize_numeric_columns()

  def load_data(
      self, soure_selection, sheet_url, sheet_name, csv_name, bq_project_id, bq_table_name):
    """Loads the data from the selected source into self.df_sheet.

    Args:
      soure_selection: Index of the selected source (0: Google Spreadsheet,
        1: CSV file, 2: BigQuery).
      sheet_url: Spreadsheet url (source 0).
      sheet_name: Sheet name (source 0).
      csv_name: Uploaded CSV file name (source 1).
      bq_project_id: BigQuery project id (source 2).
      bq_table_name: BigQuery table name (source 2).

    Raises:
      Exception: If no valid source is selected, or if loading fails (the
        settings in use are printed first and the original error is chained).
    """
    if soure_selection == 0:
      try:
        self.load_data_from_sheet(sheet_url, sheet_name)
      except Exception as err:
        print('1. There is something wrong with the spreadsheet-related settings.')
        print('sheet url:{}\nsheet name:{}'.format(sheet_url, sheet_name))
        raise Exception('Please check the top of this cell') from err
    elif soure_selection == 1:
      try:
        self.load_data_from_csv(csv_name)
      except Exception as err:
        print('1. There is something wrong with the CSV-related settings.')
        print('CSV namel:{}'.format(csv_name))
        raise Exception('Please check the top of this cell') from err
    elif soure_selection == 2:
      try:
        self.load_data_from_bigquery(bq_project_id, bq_table_name)
      except Exception as err:
        print('1. There is something wrong with the bq-related settings.')
        print('namel:{},{}'.format(bq_project_id, bq_table_name))
        raise Exception('Please check the top of this cell') from err
    else:
      raise Exception('Please select a data souce at Step.1-2.')

    print('1. The target data has been loaded and stored in analysis.df_sheet or design.df_sheet.\n')

  def shape_wide(self, date_column, pivot_column, kpi_column):
    """Pivots narrow data in self.df_sheet into wide data.

    The KPI is summed per date and pivot key, then every pivot key becomes a
    column (missing combinations are filled with 0). With several pivot
    columns the resulting column names are the key values joined by '_'.

    Args:
      date_column: Date column name.
      pivot_column: Name of the pivot key column, or several names separated
        by commas. Whitespace around each name is ignored.
      kpi_column: KPI column name.
    """
    # Strip only the whitespace around each name so that column names that
    # contain spaces themselves keep working.
    self._group_cols = [name.strip() for name in pivot_column.split(',')]

    self.df_sheet = self.df_sheet[[date_column, kpi_column, *self._group_cols
                                  ]].groupby([date_column, *self._group_cols
                                             ]).sum().reset_index()
    self.df_sheet = pd.pivot_table(
        self.df_sheet, index=date_column, columns=self._group_cols,fill_value=0)
    # The top column level only carries the KPI column name.
    self.df_sheet.columns = self.df_sheet.columns.droplevel(0)
    if len(self.df_sheet.columns.names) > 1:
      self.df_sheet.columns = self._get_converted_multi_columns(self.df_sheet)

  def data_shaping(self, data_type_selection, date_col, pivot_col, kpi_col):
    """Brings self.df_sheet into a date-indexed wide format and plots it.

    Args:
      data_type_selection: Index of the selected format (0: wide, 1: narrow).
        Any other value only prints a hint and the data is treated as wide.
      date_col: Date column name.
      pivot_col: Pivot column name(s), used for narrow data only.
      kpi_col: KPI column name, used for narrow data only.

    Raises:
      Exception: If pivoting, indexing by date, date parsing or plotting
        fails (a hint and the head of the data are shown first and the
        original error is chained).
    """
    if data_type_selection == 0:
      pass
    elif data_type_selection == 1:
      try:
        self.shape_wide(date_col, pivot_col, kpi_col)
      except Exception as err:
        print('2. Data formatting does not work. Please check these columns\n')
        print('Narrow_Format settings>> date col name:{}, pivot col name:{}, kpi col name:{}'.format(date_col, pivot_col, kpi_col))
        display(self.df_sheet.head())
        raise Exception('Please check the top of this cell') from err
    else:
      print('Please select a data format wide or narrow at Step.1-2')

    # Pivoted data is already indexed by the date column.
    if self.df_sheet.index.name != date_col:
      try:
        self.df_sheet.set_index(date_col, inplace=True)
        self.df_sheet.index = self.df_sheet.index.map(str)
      except Exception as err:
        print('2. The name of the date column appears to be incorrect.')
        print('date col name:{}'.format(date_col))
        display(self.df_sheet.head())
        raise Exception('Please check the top of this cell') from err

    try:
      self.df_sheet.index = pd.to_datetime(self.df_sheet.index)
      self.trend_check(self.df_sheet, date_col)
    except Exception as err:
      print('2. Data is not in wide format.')
      print('Please format the date and intervention and control group(s).')
      display(self.df_sheet.head())
      raise Exception('Please check the top of this cell') from err

    print('2. The data frame was formatted for analysis.')
    display(self.df_sheet.head())

  def _get_converted_multi_columns(self, df):
    """Returns flat column names: all levels of each column joined by '_'.

    Every level is used (not just the first two) and each label is converted
    to str, so non-string pivot values do not break the concatenation.

    Args:
      df: DataFrame whose columns are a MultiIndex.
    """
    return ['_'.join(map(str, col)) for col in df.columns.values]

  def trend_check(self, dataframe_wide, date_column):
    """Displays the daily trend, its 7-day moving average and each series.

    Three tabs are shown: 'all' (sum of all columns per date plus the moving
    average), 'each' (every column min-max scaled, one row per column) and
    'describe' (DataFrame.describe()).

    Args:
      dataframe_wide: Wide data to check the trend.
      date_column: Name of the date index of dataframe_wide.
    """
    self._df_long = pd.melt(
        dataframe_wide.reset_index(),
        id_vars=date_column).groupby(date_column).sum().reset_index()
    self._line = alt.Chart(self._df_long).mark_line().encode(
        x=alt.X(date_column, axis=alt.Axis(title='')),
        y=alt.Y('value:Q', axis=alt.Axis(title='kpi')),
        color=alt.value('#4285F4'))
    # frame=[-6, 0] covers the 6 preceding rows plus the current one, i.e. the
    # 7 days announced in the chart title ([-7, 0] would average 8 rows).
    self._moving_average = alt.Chart(self._df_long).transform_window(
        rolling_mean='mean(value)',
        frame=[-6, 0],
    ).mark_line().encode(
        x=alt.X(date_column), y=alt.Y('rolling_mean:Q'), color=alt.value('#DB4437'))

    # Min-max scale every column so the per-column charts are comparable.
    self._df_scaled = dataframe_wide.copy()
    for column in self._df_scaled:
      self._df_scaled[column] = (
          self._df_scaled[column] - self._df_scaled[column].min()) / (
              self._df_scaled[column].max() - self._df_scaled[column].min())

    self._df_scaled_cols = list(self._df_scaled.columns)
    self._each = alt.Chart(self._df_scaled.reset_index()).transform_fold(
        self._df_scaled_cols, as_=['pivot', 'kpi']).mark_line().encode(
            alt.X(date_column+':T', title=None, axis=alt.Axis(grid=False)),
            alt.Y('kpi:Q', stack=None, axis=None), alt.Color('pivot:N'),
            alt.Row(
                'pivot:N',
                title=None,
                header=alt.Header(labelAngle=0, labelAlign='left'))).properties(
                    bounds='flush',
                    height=50).configure_facet(spacing=0).configure_view(
                        stroke=None).configure_title(anchor='end')

    self._tb_trend = widgets.TabBar(['all', 'each', 'describe'])
    with self._tb_trend.output_to('all'):
      display((self._line + self._moving_average).properties(
          width=600,
          height=250,
          title={
              'text': ['Daily Trend(blue) & 7days moving average(red)'],
          }))
    with self._tb_trend.output_to('each'):
      display((self._each).properties(width=800))
    with self._tb_trend.output_to('describe'):
      display(dataframe_wide.describe(include='all'))
    # Presumably re-selects the first tab after the others were filled.
    with self._tb_trend.output_to('all'):
      pass

class CausalImpact_Analysis(PreProcess):
  """Runs CausalImpact on the data loaded through PreProcess.

  The constructor copies the current values of the Step.1 widgets (periods,
  data source and data format settings) onto the instance.
  """

  def __init__(self):
    # Pre / post intervention periods.
    self.pre_period_start = pre_period_start.value
    self.pre_period_end = pre_period_end.value
    self.post_period_start = post_period_start.value
    self.post_period_end = post_period_end.value

    # Data source settings.
    self.soure_selection = soure_selection.selected_index
    self.sheet_url = sheet_url.value
    self.sheet_name = sheet_name.value
    self.csv_name = csv_name.value
    self.bq_project_id = bq_project_id.value
    self.bq_table_name = bq_table_name.value

    # Data format settings.
    self.data_type_selection = data_type_selection.selected_index
    self.date_col = date_col.value
    self.pivot_col = pivot_col.value
    self.kpi_col = kpi_col.value

  def analyze(self):
    """Fits CausalImpact on self.df_sheet and prints / plots the result.

    Besides the CausalImpact summary and plot, the MAPE between the first
    data column and the model's predicted means over the pre-period is
    printed. The periods, the fitted object and the MAPE are kept in
    self.pre_period, self.post_period, self.ci and self.mape.
    """
    self.pre_period = [
        pd.to_datetime(day)
        for day in (self.pre_period_start, self.pre_period_end)
    ]
    self.post_period = [
        pd.to_datetime(day)
        for day in (self.post_period_start, self.post_period_end)
    ]

    self.ci = CausalImpact(self.df_sheet, self.pre_period, self.post_period)

    # Compare observed and predicted values over the pre-period only.
    fit_start = self.ci.pre_period[0]
    fit_end = self.ci.pre_period[1]
    observed = self.ci.data.iloc[:, 0][fit_start:fit_end]
    predicted = self.ci.inferences.complete_preds_means[fit_start:fit_end]
    self.mape = mean_absolute_percentage_error(observed, predicted)

    print('=====================================' + '\n')
    print(f'MAPE: {self.mape:.2%}' + '\n')
    print(self.ci.summary())
    self.ci.plot()

class ExperimentalDesign(PreProcess):
  """ExperimentalDesign runs xxx

  Attributes:
    drop_cols:
    equal_division:
    similar_divisions:
    _run_dtw:
    _visualize_trend_diff:
    select_dataframe:
    causalimpact_simulation:
  """

  def __init__(self):
    """Copies the current values of the Step.1 widgets onto the instance."""
    # (attribute name, widget) pairs whose current `value` is stored.
    # `num_of_iterations` is the same identifier as the widget defined with a
    # fullwidth trailing 'ｓ': Python NFKC-normalizes identifiers.
    value_widgets = (
        # Design period and design parameters.
        ('start_date', start_date),
        ('end_date', end_date),
        ('num_of_iterations', num_of_iterations),
        ('num_of_split', num_of_split),
        ('target_columns', target_columns),
        ('num_of_pick', num_of_pick),
        ('num_of_covariate', num_of_covariate),
        ('un_needed_cols', un_needed_cols),
        ('estimate_icpa', estimate_icpa),
        # Data source settings.
        ('sheet_url', sheet_url),
        ('sheet_name', sheet_name),
        ('csv_name', csv_name),
        ('bq_project_id', bq_project_id),
        ('bq_table_name', bq_table_name),
        # Column names of the input data.
        ('date_col', date_col),
        ('pivot_col', pivot_col),
        ('kpi_col', kpi_col),
    )
    # (attribute name, container) pairs whose `selected_index` is stored.
    index_widgets = (
        ('design_type', design_type),
        ('soure_selection', soure_selection),
        ('data_type_selection', data_type_selection),
    )

    for attribute, widget in value_widgets:
      setattr(self, attribute, widget.value)
    for attribute, container in index_widgets:
      setattr(self, attribute, container.selected_index)

  def design(self):
    """Runs the experimental design method selected in Step.1.

    design_type 0 calls equal_division, design_type 1 calls
    similar_divisions (with target_columns split on commas after removing
    spaces).

    Raises:
      Exception: If design_type is neither 0 nor 1.
    """
    if self.design_type not in (0, 1):
      raise Exception('Please select the purpose at Step.1')

    # Arguments both design methods take.
    shared_kwargs = dict(
        dataframe_wide=self.df_sheet,
        num_of_iteration=self.num_of_iterations,
        start_date=self.start_date,
        end_date=self.end_date,
    )

    if self.design_type == 0:
      self.equal_division(num_of_split=self.num_of_split, **shared_kwargs)
      return

    self.similar_divisions(
        target_columns=self.target_columns.replace(' ', '').split(','),
        num_of_pick=self.num_of_pick,
        num_of_covariate=self.num_of_covariate,
        **shared_kwargs,
    )

  def equal_division(
      self, dataframe_wide, num_of_iteration, num_of_split, start_date, end_date
  ):
    """Randomly splits the columns into groups and keeps the most similar split.

    For every iteration the columns of the date-filtered data are randomly
    partitioned into num_of_split groups, the columns of each group are summed
    into one series, and the total DTW distance between the group series is
    computed with _run_dtw. The split with the smallest distance is kept,
    visualized with _visualize_trend_diff, and widgets for choosing the test
    and control columns are displayed.

    Side effects: overwrites self.num_of_pick with the group size and stores
    intermediate and final results on self (df_dtw, df_dtw_columns,
    df_dtw_values, test_column, control_column, ...).

    Args:
      dataframe_wide: Wide, date-indexed data.
      num_of_iteration: Number of random splits to try.
      num_of_split: Number of groups per split.
      start_date: First date (inclusive) of the rows to use.
      end_date: Last date (inclusive) of the rows to use.

    Returns:
      None.
    """
    self.df_dtw = pd.DataFrame(index=[], columns=[])
    # `@start_date` / `@end_date` refer to the parameters of this method.
    self.dataframe = dataframe_wide.query('@start_date <= index <= @end_date')
    # Group size; the columns left over by the integer division are added to
    # the last group below.
    self.num_of_pick = len(self.dataframe.columns) // num_of_split

    for l in range(num_of_iteration):
      self.col_list = list(range(0, len(self.dataframe.columns)))
      self.choice_list = pd.DataFrame(index=[], columns=range(num_of_split))
      self.df_picked = pd.DataFrame(index=[], columns=[])
      self.df_candidate = pd.DataFrame()

      # Draw the column positions of each group without replacement.
      for s in range(num_of_split):
        self.tg = random.sample(self.col_list, self.num_of_pick)
        self.choice_list.loc[0, s] = self.tg
        self.col_list = [ele for ele in self.col_list if ele not in self.tg]

      # `s` is still the index of the last group here: the columns that were
      # not drawn are appended to it.
      self.choice_list.loc[0, s].extend(self.col_list)

      # One summed series per group.
      for i in range(len(self.choice_list.columns)):
        self.picked = pd.DataFrame(
            self.dataframe.iloc[:, self.choice_list.loc[0, i]].sum(axis=1),
            columns=[i],
        )
        self.df_picked = pd.concat([self.df_picked, self.picked], axis=1)

      # self.dtw_row = self._run_dtw(self.df_picked)
      self.dtw_row = self._run_dtw(self.df_picked.reset_index(drop=True))

      # Row layout: [total distance, str(group 0), str(group 1), ...].
      for c in range(len(self.choice_list.columns)):
        self.dtw_row.append(str(self.choice_list.loc[0, c]))

      self.dtw_row = pd.DataFrame(self.dtw_row).T
      self.df_dtw = pd.concat([self.df_dtw, self.dtw_row], ignore_index=True)

    # Keep only the split with the smallest total distance (column 0).
    self.df_dtw = (
        self.df_dtw.drop_duplicates().sort_values(0).reset_index(drop=True)
    )
    self.df_dtw = self.df_dtw.head(1)

    # Column names of every group, stored as the string form of a list.
    # NOTE(review): eval() parses the stringified position lists built above.
    self.df_dtw_columns = pd.DataFrame()
    for i in range(len(self.df_dtw)):
      for x in range(len(self.df_dtw.columns) - 1):
        self.df_dtw_columns.at[i, x] = str(
            list(
                self.dataframe.iloc[:, eval(self.df_dtw.iloc[i, x + 1])].columns
            )
        )

    # Dates and summed group values, each serialized as one newline-separated
    # string per cell (parsed again in _visualize_trend_diff).
    self.df_dtw_values = pd.DataFrame()
    for i in range(len(self.df_dtw)):
      # NOTE(review): `x` is the value left over from a previous loop (the
      # last group index), so the dates are taken from the last group.
      self.df_dtw_values.at[i, self.date_col] = (
          self.dataframe.iloc[:, eval(self.df_dtw.iloc[i, x + 1])]
          .sum(axis=1)
          .reset_index()[self.date_col]
          .to_string(index=False)
          .replace(' ', '')
      )
      for x in range(len(self.df_dtw.columns) - 1):
        self.df_dtw_values.at[i, x] = (
            self.dataframe.iloc[:, eval(self.df_dtw.iloc[i, x + 1])]
            .sum(axis=1)
            .reset_index()[0]
            .to_string(index=False)
            .replace(' ', '')
        )

    # Also rebuilds self.df_candidate, whose columns are col1, col2, ...
    self._visualize_trend_diff(self.df_dtw_values, self.df_dtw_columns)

    self.test_column = ipywidgets.Dropdown(
        options=list(self.df_candidate.columns),
        description='test column:',
        disabled=False,
    )

    # 'col2' is preselected as the control column.
    self.control_column = ipywidgets.SelectMultiple(
        options=list(self.df_candidate.columns),
        description='control column:',
        value=('col2',),
        style={'description_width': 'initial'},
        disabled=False,
    )
    self.test_control_selection = (
        'Please select one test column & control column(s).'
    )
    display(
        Markdown(f"""<h3>{self.test_control_selection}</h3>"""),
        ipywidgets.HBox([
            self.test_column,
            self.control_column,
        ]),
    )

  def similar_divisions(
      self,
      dataframe_wide,
      target_columns,
      num_of_iteration,
      num_of_pick,
      num_of_covariate,
      start_date,
      end_date):
    """Searches covariate groups that move similarly to the target columns.

    The target columns form group 0. For every iteration, num_of_covariate
    disjoint groups of num_of_pick columns are drawn at random from the
    remaining columns; the columns of each group are summed into one series
    and the total DTW distance between all group series is computed with
    _run_dtw. The (up to) five combinations with the smallest distance are
    visualized with _visualize_trend_diff, and widgets for choosing an
    option, the test column and the control column(s) are displayed.

    Intermediate and final results are stored on self (df_dtw,
    df_dtw_columns, df_dtw_values, your_choice, test_column, control_column,
    ...).

    Args:
      dataframe_wide: Wide, date-indexed data.
      target_columns: List of column names that form the target group.
      num_of_iteration: Number of random combinations to try.
      num_of_pick: Number of columns per covariate group.
      num_of_covariate: Number of covariate groups.
      start_date: First date (inclusive) of the rows to use.
      end_date: Last date (inclusive) of the rows to use.

    Returns:
      None.

    Raises:
      ValueError: If a target column does not exist, or if there are not
        enough non-target columns for the requested groups.
    """
    self.df_dtw = pd.DataFrame()
    self.choice_list = pd.DataFrame()
    self.target_cols = []
    # `@start_date` / `@end_date` refer to the parameters of this method.
    self.dataframe = dataframe_wide.query('@start_date <= index <= @end_date')

    column_names = self.dataframe.columns.values.tolist()
    for target in target_columns:
      self.target_cols.append(column_names.index(target))

    # Positions of all columns that are available as covariates.
    self.col_list = list(range(0, len(self.dataframe.columns)))
    for target in target_columns:
      self.col_list.remove(column_names.index(target))

    # random.sample would fail with a generic message further below; report
    # the actual cause instead (same exception type).
    if num_of_pick * num_of_covariate > len(self.col_list):
      raise ValueError(
          f'{num_of_covariate} covariate group(s) of {num_of_pick} column(s) '
          f'need more than the {len(self.col_list)} available non-target '
          'column(s).'
      )

    for l in range(num_of_iteration):
      self.col_list_candidate = self.col_list.copy()
      self.choice_list = [self.target_cols]
      self.df_picked = pd.DataFrame()

      # Draw the column positions of each covariate group without replacement.
      for s in range(num_of_covariate):
        self.tg = random.sample(self.col_list_candidate, num_of_pick)
        self.choice_list.append(self.tg)
        self.col_list_candidate = [
            ele for ele in self.col_list_candidate if ele not in self.tg
        ]

      # One summed series per group (group 0 is the target group).
      for i in range(len(self.choice_list)):
        self.picked = pd.DataFrame(
            self.dataframe.iloc[:, self.choice_list[i]].sum(axis=1), columns=[i]
        )
        self.df_picked = pd.concat([self.df_picked, self.picked], axis=1)

      self.dtw_row = self._run_dtw(self.df_picked.reset_index(drop=True))

      # Row layout: [total distance, str(group 0), str(group 1), ...].
      for c in range(len(self.choice_list)):
        self.dtw_row.append(str(self.choice_list[c]))

      self.dtw_row = pd.DataFrame(self.dtw_row).T
      self.df_dtw = pd.concat([self.df_dtw, self.dtw_row], ignore_index=True)

    # Keep the (up to) five combinations with the smallest total distance.
    self.df_dtw = (
        self.df_dtw.drop_duplicates().sort_values(0).reset_index(drop=True)
    )
    self.df_dtw = self.df_dtw.head()

    # Column names of every group, stored as the string form of a list.
    # eval() only parses the stringified position lists built above.
    self.df_dtw_columns = pd.DataFrame()
    for i in range(len(self.df_dtw)):
      for x in range(len(self.df_dtw.columns) - 1):
        self.df_dtw_columns.at[i, x] = str(
            list(
                self.dataframe.iloc[:, eval(self.df_dtw.iloc[i, x + 1])].columns
            )
        )

    # Dates and summed group values, each serialized as one newline-separated
    # string per cell (parsed again in _visualize_trend_diff).
    self.df_dtw_values = pd.DataFrame()
    for i in range(len(self.df_dtw)):
      # The dates are the same for every group; they are read from the last
      # group explicitly instead of through a loop variable left over from an
      # earlier loop.
      self.df_dtw_values.at[i, self.date_col] = (
          self.dataframe.iloc[:, eval(self.df_dtw.iloc[i, -1])]
          .sum(axis=1)
          .reset_index()[self.date_col]
          .to_string(index=False)
          .replace(' ', '')
      )
      for x in range(len(self.df_dtw.columns) - 1):
        self.df_dtw_values.at[i, x] = (
            self.dataframe.iloc[:, eval(self.df_dtw.iloc[i, x + 1])]
            .sum(axis=1)
            .reset_index()[0]
            .to_string(index=False)
            .replace(' ', '')
        )

    # Also rebuilds self.df_candidate, whose columns are col1, col2, ...
    self._visualize_trend_diff(self.df_dtw_values, self.df_dtw_columns)

    # Offer exactly the options that exist: fewer than five distinct
    # combinations may have been found, and selecting a missing one would
    # fail later in select_dataframe.
    self.your_choice = ipywidgets.Dropdown(
        options=['option_' + str(n + 1) for n in range(len(self.df_dtw))],
        description='your choice:',
        disabled=False,
    )

    self.test_column = ipywidgets.Dropdown(
        options=list(self.df_candidate.columns),
        description='test column:',
        disabled=False,
    )

    # 'col2' is preselected as the control column.
    self.control_column = ipywidgets.SelectMultiple(
        options=list(self.df_candidate.columns),
        description='control column:',
        value=('col2',),
        style={'description_width': 'initial'},
        disabled=False,
    )
    self.test_control_selection = (
        'Please select option, test column & control column(s).'
    )
    display(
        Markdown(f"""<h3>{self.test_control_selection}</h3>"""),
        ipywidgets.HBox([
            self.your_choice,
            self.test_column,
            self.control_column,
        ]),
    )

  def _run_dtw(self, df_picked):
    """_run_dtw calculates the DTW distance between time series

    Args:
      xxx: xxx.

    Returns:
      xxx
    """
    self.dist = 0
    self.df_dtw_row = []
    for column in df_picked:
      df_picked[column] = (df_picked[column] - df_picked[column].min()) / (
          df_picked[column].max() - df_picked[column].min()
      )
    for v in itertools.combinations(list(df_picked.columns), 2):
      distance, path = fastdtw.fastdtw(
          # df_picked.loc[:, v[0]],
          # df_picked.loc[:, v[1]],
          df_picked.reset_index().loc[:, ['index', v[0]]],
          df_picked.reset_index().loc[:, ['index', v[1]]],
          dist=euclidean
      )
      self.dist = self.dist + distance
    self.df_dtw_row.append(self.dist)

    return self.df_dtw_row

  def _visualize_trend_diff(self, df_dtw_values, df_dtw_columns):
    """Shows one tab per candidate with its groups, trend and differences.

    For every row of df_dtw_columns the column names of each group are
    printed, the serialized dates and group values of df_dtw_values are
    parsed back into self.df_candidate (columns col1, col2, ...), and a line
    chart of the group series is displayed next to a scatter matrix of their
    row-to-row differences.

    After the call, self.df_candidate and self.df_diff hold the data of the
    last candidate.

    Args:
      df_dtw_values: DataFrame with one row per candidate; the self.date_col
        column and the integer-labelled group columns each hold a
        newline-separated string of values.
      df_dtw_columns: DataFrame with one row per candidate holding the column
        names of every group as strings.

    Returns:
      None.
    """
    # Tabs are named option_1, option_2, ... after the row labels.
    self.candidate_tb = widgets.TabBar(
        ['option_' + str(sub + 1) for sub in df_dtw_values.index.tolist()]
    )

    for i in range(len(df_dtw_columns)):
      with self.candidate_tb.output_to(i):
        for col in range(len(df_dtw_columns.columns)):
          print('col' + str(col + 1) + df_dtw_columns.iloc[i, col])
        print('\n')

        # Parse the newline-separated strings back into columns.
        self.df_candidate = pd.DataFrame(
            {'date': df_dtw_values.loc[i, self.date_col].split('\n')}
        )

        for col in range(len(df_dtw_values.columns) - 1):
          self.df_candidate = pd.concat(
              [
                  self.df_candidate,
                  pd.Series(df_dtw_values.loc[i, col].split('\n')),
              ],
              ignore_index=True,
              axis=1,
          )
        # NOTE(review): errors='ignore' is deprecated in recent pandas
        # releases -- confirm against the pandas version in use.
        self.df_candidate = self.df_candidate.apply(
            pd.to_numeric, errors='ignore'
        )
        # After the concat with ignore_index=True the date column is column 0
        # and the groups are columns 1..n, which become col1..coln.
        self.df_candidate.set_index(0, inplace=True)
        self.df_candidate.index.rename('date', inplace=True)
        self.df_candidate = self.df_candidate.add_prefix('col')

        # Row-to-row differences of every group series.
        self.df_diff = pd.DataFrame(
            np.diff(self.df_candidate, axis=0),
            columns=self.df_candidate.columns.values,
        )

        self.line = (
            alt.Chart(self.df_candidate.reset_index())
            .transform_fold(
                [str(n) for n in list(self.df_candidate.columns)],
            )
            .mark_line()
            .encode(alt.X('date:T'), alt.Y('value:Q'), color='key:N')
            .properties(width=600, height=250)
        )

        # Scatter matrix of the differences, one panel per pair of groups.
        self.scatter = (
            alt.Chart(self.df_diff.reset_index())
            .mark_circle()
            .encode(
                alt.X(alt.repeat('column'), type='quantitative'),
                alt.Y(alt.repeat('row'), type='quantitative'),
            )
            .properties(width=100, height=100)
            .repeat(
                row=self.df_diff.columns.values,
                column=np.flipud(self.df_diff.columns.values),
            )
        )
        alt.hconcat(self.line, self.scatter).display()
    # Presumably re-selects the first tab after all tabs were filled.
    with self.candidate_tb.output_to(0):
      pass

  def select_dataframe(self):
    """Builds `self.df` from the chosen candidate's test and control groups.

    The candidate is 'option_1' when `self.design_type` is 0 and the value of
    the `self.your_choice` widget when it is 1. Each 'colN' picked in
    `self.test_column` / `self.control_column` refers to the N-th cell of that
    candidate's row in `self.df_dtw_columns`; the cell text is evaluated to a
    list of `self.df_sheet` column labels, which are summed row-wise into one
    series. The resulting columns are named 'test', 'control_1', 'control_2',
    ... and the name-to-group mapping is printed.

    Returns:
      None. Sets `self.your_choice_selection`, `self.selection_row`,
      `self.selection_cols`, `self.colnames` and `self.df`.
    """
    if self.design_type == 0:
      self.your_choice_selection = 'option_1'
    elif self.design_type == 1:
      self.your_choice_selection = self.your_choice.value

    # 'option_N' and 'colN' are 1-based labels; turn them into 0-based offsets.
    self.selection_row = self.df_dtw_columns.iloc[
        int(self.your_choice_selection.replace('option_', '')) - 1
    ]
    chosen_labels = [self.test_column.value] + list(self.control_column.value)
    self.selection_cols = [
        int(label.replace('col', '')) - 1 for label in chosen_labels
    ]
    self.colnames = ['test'] + [
        'control_' + str(n) for n in range(1, len(self.selection_cols))
    ]

    self.df = pd.DataFrame()
    for position in self.selection_cols:
      # NOTE(review): eval() executes the cell text as Python. That is only
      # safe while df_dtw_columns is produced by this notebook itself; confirm
      # the cells can never carry user-controlled text.
      group_columns = eval(self.selection_row[position])
      self.df = pd.concat(
          [self.df, self.df_sheet.loc[:, group_columns].sum(axis=1)],
          axis=1,
      )
    # set_axis() lost its `inplace` keyword in pandas 2.0, so rebind the result
    # instead of mutating in place; this form works on older versions too.
    self.df = self.df.set_axis(self.colnames, axis=1)

    for name, position in zip(self.df.columns, self.selection_cols):
      print('{}: {}'.format(name, self.selection_row[position]))

  def color_p_value(self, val):
    """Returns the CSS colour rule for a p-value cell.

    Args:
      val: The p-value to style.

    Returns:
      'color: red' when `val` is at most 0.05, otherwise 'color: black'.
    """
    if val <= 0.05:
      color = 'red'
    else:
      color = 'black'
    return 'color: %s' % color

  def simulate_params(
      self,
      dataframe,
      estimation_icpa,
      treat_duration=(7, 14, 28),
      treat_impact=(1.01, 1.03, 1.05, 1.10)):
    """Simulates which test durations and uplifts yield a detectable effect.

    For every (duration, impact) pair, the last `duration` days of `dataframe`
    are treated as the post period, the 'test' column in that period is
    multiplied by `impact`, and CausalImpact is fitted on the result. One
    summary row per pair is collected and the table is displayed, with
    p-values at or below 0.05 coloured red.

    Args:
      dataframe: Time series data for CausalImpact. Needs a 'test' column and
        an index whose values can be compared with a pandas Timestamp.
      estimation_icpa: Multiplier applied to the cumulative absolute effect to
        obtain the 'budget' column (presumably an estimated incremental cost
        per acquisition).
      treat_duration: Iterable of post-period lengths in days.
      treat_impact: Iterable of multipliers applied to 'test' in the post
        period (for example 1.05 scales the values by 5%).

    Returns:
      None. Results are stored in `self.simulation_df` (one row per pair),
      `self.simulation` (the fitted CausalImpact objects, in the same order)
      and `self.simulation_combination` (the pairs themselves).
    """
    # Materialise once so one-shot iterables can be walked more than once.
    durations = list(treat_duration)
    impacts = list(treat_impact)
    self.simulation_combination = list(itertools.product(durations, impacts))

    self.pre_start_date = pd.to_datetime(min(dataframe.index))
    self.post_end_date = pd.to_datetime(max(dataframe.index))
    self.adjusted_df = dataframe.copy()
    self.simulation = []
    result_columns = [
        'days',
        'impact',
        'MAPE',
        'abs_effect',
        'mean',
        'budget',
        'p_value',
    ]
    result_frames = []

    for duration in durations:
      # The post period is the last `duration` days of the series.
      self.pre_end_date = self.post_end_date + datetime.timedelta(
          days=-duration
      )
      self.post_start_date = self.pre_end_date + datetime.timedelta(days=1)
      self.pre_period = [self.pre_start_date, self.pre_end_date]
      self.post_period = [self.post_start_date, self.post_end_date]
      in_post_period = dataframe.index >= self.post_start_date

      for impact in impacts:
        # Start from the untouched input for every scenario. The index of the
        # working copy is turned into strings below, so reusing it would make
        # the next assignment align timestamp labels against string labels,
        # and uplift from an earlier, longer duration would stay in what is
        # now the pre period.
        self.adjusted_df = dataframe.copy()
        self.adjusted_df['test'] = dataframe['test'].mask(
            in_post_period, dataframe['test'] * impact
        )
        self.adjusted_df.index = self.adjusted_df.index.map(str)
        self.ci = CausalImpact(
            self.adjusted_df, self.pre_period, self.post_period
        )
        self.simulation.append(self.ci)

        pre_start, pre_end = self.ci.pre_period[0], self.ci.pre_period[1]
        cumulative_effect = self.ci.summary_data.loc[
            'abs_effect', 'cumulative'
        ]
        self.impact_dict = {
            'days': [duration],
            'impact': [impact],
            # Fit quality on the pre period: observed vs. predicted values.
            'MAPE': [
                mean_absolute_percentage_error(
                    self.ci.data.iloc[:, 0][pre_start:pre_end],
                    self.ci.inferences.complete_preds_means[pre_start:pre_end],
                )
            ],
            'abs_effect': [cumulative_effect],
            'mean': [self.ci.summary_data.loc['abs_effect', 'average']],
            'budget': [cumulative_effect * estimation_icpa],
            'p_value': [self.ci.p_value],
        }
        result_frames.append(pd.DataFrame.from_dict(self.impact_dict))

    # Concatenate once; growing the table from an empty frame inside the loop
    # makes the resulting dtypes depend on the pandas version.
    if result_frames:
      self.simulation_df = pd.concat(result_frames, ignore_index=True)
    else:
      self.simulation_df = pd.DataFrame(index=[], columns=result_columns)

    styler = self.simulation_df.style
    # Styler.applymap is deprecated in favour of Styler.map from pandas 2.1;
    # use whichever this pandas version provides.
    style_cells = getattr(styler, 'map', None) or styler.applymap
    display(
        style_cells(self.color_p_value, subset='p_value')
        .format('{:,.0f}', subset=['abs_effect', 'mean', 'budget'])
        .format('{:,.2f}', subset=['impact', 'p_value'])
        .format('{:,.2%}', subset=['MAPE'])
    )

  def plot(self):
    """Shows every simulated scenario's summary and plot in its own tab.

    Creates one tab per entry of `self.simulation_combination` and fills tab i
    with the printed summary and the displayed plot of `self.simulation[i]`.
    """
    scenarios = self.simulation_combination
    self.simulation_tb = widgets.TabBar(scenarios)
    for tab_index, _ in enumerate(scenarios):
      with self.simulation_tb.output_to(tab_index):
        result = self.simulation[tab_index]
        print(result.summary())
        display(result.plot())
    # Re-enter the first tab without writing anything (presumably to bring it
    # back into focus after the loop left the last tab active).
    with self.simulation_tb.output_to(0):
      pass

In [None]:
# @title Step.2
# Run the workflow for the purpose chosen in Step.1: selected index 0 runs the
# CausalImpact analysis, selected index 1 runs the experimental design.
if purpose_selection.selected_index == 0:
  analysis = CausalImpact_Analysis()
  # Load from the source configured on the object, then reshape and analyze.
  analysis.load_data(
      analysis.soure_selection,
      analysis.sheet_url,
      analysis.sheet_name,
      analysis.csv_name,
      analysis.bq_project_id,
      analysis.bq_table_name,
  )
  analysis.data_shaping(
      analysis.data_type_selection,
      analysis.date_col,
      analysis.pivot_col,
      analysis.kpi_col,
  )
  analysis.analyze()
elif purpose_selection.selected_index == 1:
  # `design` is used again by the Step.3 and Step.4 cells.
  design = ExperimentalDesign()
  design.load_data(
      design.soure_selection,
      design.sheet_url,
      design.sheet_name,
      design.csv_name,
      design.bq_project_id,
      design.bq_table_name,
  )
  design.data_shaping(
      design.data_type_selection,
      design.date_col,
      design.pivot_col,
      design.kpi_col,
  )
  design.design()
else:
  print('Please choose Experimental Design or CausalImpact Analysis in Step.1 for the purpose of the analysis.')

In [None]:
# @title Step.3 (Experimental Design Only) it takes about 5+ minutes.
# Assemble the test/control frame from the chosen option, then fit one
# CausalImpact model per (duration, impact) scenario.
design.select_dataframe()
design.simulate_params(
    dataframe=design.df,
    estimation_icpa=design.estimate_icpa,
    treat_duration=[7, 14, 28],
    treat_impact=[1.01, 1.03, 1.05, 1.10],
)

In [None]:
# @title Step.4 (Experimental Design Only)
# Show the summary and plot of every scenario simulated in Step.3, one tab each.
design.plot()