<a href="https://colab.research.google.com/github/google/business_intelligence_group/blob/development/solutions/causal-impact/CausalImpact_with_Experimental_Design.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **CausalImpact with Experimental Design**

This Colab file contains *Experimental Design* and *CausalImpact Analysis*.

See [README.md](https://github.com/google/business_intelligence_group/tree/main/solutions/causal-impact) for details.

---

Copyright 2023 Google LLC

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

In [None]:
# @title Step.1 (~ 2min)

print('Installing packages')
!pip install fastdtw --quiet
!pip install rpy2==3.5.1 -q
print('Installed packages' + '\n')

# Data Load
from google.colab import auth, files, widgets
from google.auth import default
from google.cloud import bigquery
import io
import os
import gspread
from oauth2client.client import GoogleCredentials

# Calculate
import altair as alt
import itertools
import random
import numpy as np
import pandas as pd
import fastdtw
from scipy.spatial.distance import euclidean
from sklearn.metrics import mean_absolute_percentage_error

# Input
import datetime
from dateutil.relativedelta import relativedelta
import ipywidgets
from IPython.display import display, Markdown, HTML, Javascript

import warnings
warnings.simplefilter('ignore')

# R
import rpy2.robjects as robjects
import rpy2.rinterface_lib.callbacks as callbacks
# Load the rpy2 IPython extension (this is what provides the %R magic).
%load_ext rpy2.ipython
# Replace rpy2's warning/error console writer with a no-op.
callbacks.consolewrite_warnerror = lambda s: None
# Run `apt install r-cran-causalimpact` through R's system() call.
robjects.r('system("apt install r-cran-causalimpact")')
# Attach the CausalImpact library through the %R line magic.
get_ipython().run_line_magic('R', 'library(CausalImpact)')

def resize_colab_cell(info=None):
  """Ask Colab to let the output iframe grow up to 5000px high.

  The function is registered below as an IPython ``pre_run_cell`` callback.
  IPython documents that such callbacks receive an execution-info object, so
  the parameter is accepted (and ignored); the default keeps plain
  zero-argument calls working.

  Args:
    info: Optional execution-info object supplied by IPython. Unused.
  """
  del info  # Only accepted to match the callback signature.
  display(
      Javascript(
          "google.colab.output.setIframeHeight(0, true, {maxHeight: 5000})"
      )
  )


# Register resize_colab_cell on IPython's "pre_run_cell" event.
get_ipython().events.register("pre_run_cell", resize_colab_cell)


def success_text(text):
  """Print ``text`` wrapped in a 24-bit ANSI colour sequence (RGB 15,157,88).

  Args:
    text: Message string to print.

  Returns:
    None (the return value of ``print``).
  """
  colour_on = "\033[38;2;15;157;88m "
  colour_off = "\033[0m"
  return print("".join((colour_on, text, colour_off)))


def failure_text(text):
  """Print ``text`` wrapped in a 24-bit ANSI colour sequence (RGB 219,68,55).

  Args:
    text: Message string to print.

  Returns:
    None (the return value of ``print``).
  """
  colour_on = "\033[38;2;219;68;55m "
  colour_off = "\033[0m"
  return print("".join((colour_on, text, colour_off)))

class PreProcess(object):

  def __init__(self):
    """Create every input widget of the notebook form with its default value.

    Only the widget objects are built here; nothing is displayed.
    """

    def days_ago(days):
      # Default dates are relative to the day on which the cell is executed.
      return datetime.date.today() - relativedelta(days=days)

    def width(pixels):
      return ipywidgets.Layout(width=pixels)

    def column_choices():
      return ['col_' + str(number) for number in range(1, 6)]

    # Date pickers: (attribute name, label, default offset in days).
    for attr, label, offset in (
        ('start_date', 'Start Date:', 122),
        ('end_date', 'End Date:', 32),
        ('pre_period_start', 'Pre Start:', 122),
        ('pre_period_end', 'Pre End:', 32),
        ('post_period_start', 'Post Start:', 31),
    ):
      setattr(
          self,
          attr,
          ipywidgets.DatePicker(description=label, value=days_ago(offset)),
      )
    self.post_period_end = ipywidgets.DatePicker(
        value=datetime.date.today(),
        description='Post End:',
    )

    # Experimental-design parameters.
    self.num_of_split = ipywidgets.Dropdown(
        description='split#:',
        options=[2, 3, 4, 5],
        value=2,
        disabled=False,
    )
    self.num_of_pick_range = ipywidgets.IntRangeSlider(
        description='max pick#:',
        value=[5, 10],
        min=1,
        max=30,
        step=1,
        orientation='horizontal',
        readout=True,
        readout_format='d',
    )
    self.num_of_covariate = ipywidgets.Dropdown(
        description='covariate#:',
        options=[1, 2, 3, 4, 5],
        value=2,
        layout=width('192px'),
    )
    self.target_columns = ipywidgets.Text(
        description='target_cols:',
        value='Tokyo, Kanagawa',
        placeholder='Please enter comma-separated entries',
        layout=width('400px'),
    )
    self.un_needed_cols = ipywidgets.Text(
        description='un need col:',
        placeholder=(
            'Enter comma-separated columns if any columns are not used in the'
            ' design.'
        ),
        layout=width('1000px'),
    )
    self.estimate_icpa = ipywidgets.IntText(
        description='Estimated iCPA:',
        value=1000,
        style={'description_width': 'initial'},
    )

    # Data-source settings.
    self.sheet_url = ipywidgets.Text(
        description='spreadsheet url:',
        value='https://docs.google.com/spreadsheets/d/1dISrbX1mZHgzpsIct2QXFOWWRRJiCxDSmSzjuZz64Tw/edit#gid=0',
        placeholder='Please enter google spreadsheet url',
        style={'description_width': 'initial'},
        layout=width('1000px'),
    )
    self.sheet_name = ipywidgets.Text(
        description='sheet name:',
        value='analysis_data',
        placeholder='Please enter sheet name',
    )
    self.csv_name = ipywidgets.Text(
        description='csv name:',
        placeholder='Please enter csv name',
        layout=width('500px'),
    )
    self.bq_project_id = ipywidgets.Text(
        description='project id:',
        placeholder='Please enter project id',
        layout=width('500px'),
    )
    self.bq_table_name = ipywidgets.Text(
        description='table name:',
        placeholder='Please enter table name',
        layout=width('500px'),
    )

    # Column-name settings for the loaded data.
    self.date_col = ipywidgets.Text(
        description='date column:',
        value='Date',
        placeholder='Please enter date column name',
    )
    self.pivot_col = ipywidgets.Text(
        description='pivot column:',
        value='Geo',
        placeholder='Please enter pivot column name',
    )
    self.kpi_col = ipywidgets.Text(
        description='kpi column:',
        value='KPI',
        placeholder='Please enter kpi column name',
    )

    # Selection of the preferred design option and its test/control columns.
    self.your_choice = ipywidgets.Dropdown(
        description='your choice:',
        options=['option_' + str(number) for number in range(1, 4)],
    )
    self.test_column = ipywidgets.Dropdown(
        description='test column:',
        options=column_choices(),
    )
    self.control_column = ipywidgets.SelectMultiple(
        description='control column:',
        options=column_choices(),
        value=('col_2',),
        style={'description_width': 'initial'},
    )

  def _apply_style(self, text):
    self.span_style = (
        "<span style='font-size:15px; background: linear-gradient(transparent"
        " 90%, #4285F4 0%);'>"
        + text
        + '</style>'
    )
    return self.span_style

  def generate_purpose_section(self):
    """Assemble the input form from the widgets and display it.

    Creates and stores ``design_type`` (accordion with the two design
    methods), ``purpose_selection`` (CausalImpact / experimental design tabs),
    ``soure_selection`` (spreadsheet / CSV / BigQuery tabs) and
    ``data_type_selection`` (wide / narrow tabs), then displays the last three
    under two headings.
    """
    vbox, hbox, html = ipywidgets.VBox, ipywidgets.HBox, ipywidgets.HTML

    def set_titles(container, titles):
      for position, title in enumerate(titles):
        container.set_title(position, title)

    # Accordion with the inputs of the two experimental-design methods.
    divide_equally_inputs = vbox([self.num_of_split, self.un_needed_cols])
    similarity_inputs = vbox([
        hbox([
            self.target_columns,
            self.num_of_pick_range,
            self.num_of_covariate,
        ]),
        self.un_needed_cols,
    ])
    self.design_type = ipywidgets.Accordion(
        children=[divide_equally_inputs, similarity_inputs]
    )
    set_titles(
        self.design_type,
        [
            (
                'A: divide_equally divides the time series data into N'
                ' groups(split#) with similar movements.'
            ),
            (
                'B: similarity_selection extracts N groups(covariate#) that'
                ' move similarly to particular columns(target_cols).'
            ),
        ],
    )

    # Tab 0: CausalImpact periods. Tab 1: experimental-design settings.
    causalimpact_page = vbox([
        html(value=self._apply_style('Enter the Pre and Post the intervention.')),
        hbox([
            vbox([self.pre_period_start, self.pre_period_end]),
            vbox([self.post_period_start, self.post_period_end]),
        ]),
    ])
    period_heading = html(
        value=self._apply_style(
            'ⅰ. Enter the time period to be used for experimental design.'
        )
    )
    method_heading = html(
        value=self._apply_style(
            'ⅱ. Select the <b>experimental design method</b> and enter the'
            ' necessary items.'
        )
    )
    icpa_heading = html(
        value=self._apply_style(
            'ⅲ. Enter the <b>Estimated incremental CPA</b>(Cost of'
            ' intervention ÷ Lift from intervention without bias)'
        )
    )
    icpa_note = html(
        value=(
            '\n' + ' ' * 16
            + '<li>Hypothetical values are not a problem. The cost required to'
            ' verify the hypothesis is used in the calculation.</li>'
            + '\n' + ' ' * 16
        )
    )
    design_page = vbox([
        period_heading,
        hbox([self.start_date, self.end_date]),
        method_heading,
        self.design_type,
        icpa_heading,
        icpa_note,
        self.estimate_icpa,
    ])
    self.purpose_selection = ipywidgets.Tab()
    self.purpose_selection.children = [causalimpact_page, design_page]
    set_titles(self.purpose_selection, ['Causalimpact', 'Experimental_Design'])

    # Data-source tabs.
    self.soure_selection = ipywidgets.Tab()
    self.soure_selection.children = [
        vbox([self.sheet_url, self.sheet_name]),
        vbox([self.csv_name]),
        vbox([self.bq_project_id, self.bq_table_name]),
    ]
    set_titles(
        self.soure_selection, ['Google_Spreadsheet', 'CSV_file', 'Big_Query']
    )

    # Data-format tabs with a short explanation of each layout.
    self.text_wide = (
        'Wide, or unstacked data is presented with each different data variable'
        ' in a separate column.'
    )
    self.text_narrow = (
        'Narrow, stacked, or long data is presented with one column containing'
        ' all the values and another column listing the context of the value'
    )
    self.data_type_selection = ipywidgets.Tab()
    self.data_type_selection.children = [
        vbox([html(self.text_wide), self.date_col]),
        vbox([
            ipywidgets.Label(self.text_narrow),
            hbox([self.date_col, self.pivot_col, self.kpi_col]),
        ]),
    ]
    set_titles(self.data_type_selection, ['Wide_Format', 'Narrow_Format'])

    purpose_heading = Markdown(
        '<h3>1. Please select the purpose and set conditions.</h3>'
    )
    source_heading = Markdown(
        "<h3>2. Please select a data source and choose format <a href='https://en.wikipedia.org/wiki/Wide_and_narrow_data'>wide or narrow</a></h3>"
    )
    display(
        purpose_heading,
        self.purpose_selection,
        source_heading,
        self.soure_selection,
        self.data_type_selection,
    )

  def _load_data_from_sheet(self, spreadsheet_url, sheet_name):
    """Read a Google Sheets worksheet into ``self.df_sheet``.

    The first row of the sheet becomes the header. Commas are removed from
    all string cells, every column that can then be parsed as numbers is
    converted, and spaces are removed from the column names.

    Args:
      spreadsheet_url: Spreadsheet url with data.
      sheet_name: Sheet name with data.
    """

    def _numeric_or_unchanged(column):
      # Same contract as the deprecated ``errors='ignore'``: a column that
      # cannot be parsed is returned untouched.
      try:
        return pd.to_numeric(column)
      except (ValueError, TypeError):
        return column

    auth.authenticate_user()
    creds, _ = default()
    gc = gspread.authorize(creds)
    self._workbook = gc.open_by_url(spreadsheet_url)
    self._worksheet = self._workbook.worksheet(sheet_name)
    self.df_sheet = pd.DataFrame(self._worksheet.get_all_values())
    # Promote the first row to the header and drop it from the data.
    self.df_sheet.columns = list(self.df_sheet.loc[0, :])
    self.df_sheet.drop(0, inplace=True)
    self.df_sheet.reset_index(drop=True, inplace=True)
    # Strip thousands separators *before* the numeric conversion; doing it
    # afterwards left values such as "1,234" as strings.
    self.df_sheet.replace(',', '', regex=True, inplace=True)
    self.df_sheet = self.df_sheet.apply(_numeric_or_unchanged)
    self.df_sheet.rename(columns=lambda x: x.replace(" ", ""), inplace=True)

  def _load_data_from_csv(self, csv_name):
    """Upload a CSV file through Colab and read it into ``self.df_sheet``.

    Commas are removed from all string cells, every column that can then be
    parsed as numbers is converted, and spaces are removed from the column
    names.

    Args:
      csv_name: Name of the uploaded file; used as the key into the mapping
        returned by ``files.upload()``.
    """

    def _numeric_or_unchanged(column):
      # Same contract as the deprecated ``errors='ignore'``: a column that
      # cannot be parsed is returned untouched.
      try:
        return pd.to_numeric(column)
      except (ValueError, TypeError):
        return column

    uploaded = files.upload()
    self.df_sheet = pd.read_csv(io.BytesIO(uploaded[csv_name]))
    # Strip thousands separators *before* the numeric conversion; doing it
    # afterwards left values such as "1,234" as strings.
    self.df_sheet.replace(',', '', regex=True, inplace=True)
    self.df_sheet = self.df_sheet.apply(_numeric_or_unchanged)
    self.df_sheet.rename(columns=lambda x: x.replace(" ", ""), inplace=True)

  def _load_data_from_bigquery(self, bq_project_id, bq_table_name):
    """Read a whole BigQuery table into ``self.df_sheet``.

    Runs ``SELECT *`` on the table (the query text is kept on ``self.query``).
    Commas are then removed from all string cells, every column that can be
    parsed as numbers is converted, and spaces are removed from the column
    names.

    Args:
      bq_project_id: Project id used to create the BigQuery client.
      bq_table_name: Table name placed between backticks in the query.
    """

    def _numeric_or_unchanged(column):
      # Same contract as the deprecated ``errors='ignore'``: a column that
      # cannot be parsed is returned untouched.
      try:
        return pd.to_numeric(column)
      except (ValueError, TypeError):
        return column

    auth.authenticate_user()
    client = bigquery.Client(project=bq_project_id)
    # NOTE(review): the table name is interpolated into the SQL text as-is.
    self.query = 'SELECT * FROM `' + bq_table_name + '`;'
    self.df_sheet = client.query(self.query).to_dataframe()

    # Strip thousands separators *before* the numeric conversion; doing it
    # afterwards left values such as "1,234" as strings.
    self.df_sheet.replace(',', '', regex=True, inplace=True)
    self.df_sheet = self.df_sheet.apply(_numeric_or_unchanged)
    self.df_sheet.rename(columns=lambda x: x.replace(" ", ""), inplace=True)

  def load_data(self):
    if self.soure_selection.selected_index == 0:
      try:
        self._load_data_from_sheet(self.sheet_url.value, self.sheet_name.value)
      except Exception as e:
        failure_text('\n\nFailure!!')
        print('Error: {}'.format(e))
        print('Please check the following:')
        print(
            '* There is something wrong with the spreadsheet-related settings.'
        )
        print('* sheet url:{}'.format(self.sheet_url.value))
        print('* sheet name:{}'.format(self.sheet_name.value))
        failure_text('▲▲▲▲▲▲\n\n')
        raise Exception('Please check Failure')

    elif self.soure_selection.selected_index == 1:
      try:
        self._load_data_from_csv(self.csv_name.value)
      except Exception as e:
        failure_text('\n\nFailure!!')
        print('Error: {}'.format(e))
        print('Please check the following:')
        print('* There is something wrong with the CSV-related settings.')
        print('* CSV namel:{}'.format(self.csv_name.value))
        failure_text('▲▲▲▲▲▲\n\n')
        raise Exception('Please check Failure')
    elif self.soure_selection.selected_index == 2:
      try:
        self._load_data_from_bigquery(
            self.bq_project_id.value, self.bq_table_name.value
        )
      except Exception as e:
        failure_text('\n\nFailure!!')
        print('Error: {}'.format(e))
        print('Please check the following:')
        print('* There is something wrong with the bq-related settings.')
        print('* bq project id:{}'.format(self.bq_project_id.value))
        print('* bq table name :{}'.format(self.bq_table_name.value))
        failure_text('▲▲▲▲▲▲\n\n')
        raise Exception('Please check Failure')
    else:
      raise Exception('Please select a data souce at Step.1-2.')

    success_text('Success! The target data has been loaded.')
    display(self.df_sheet.head(3))

  def format_data(self):
    """format_data formats the data according to the selected data type.

    Args:
        self: The instance of the class.

    Returns:
        The formatted data frame.
    """
    # Get the selected data type.
    self.data_type = self.data_type_selection.selected_index
    self.date_col_name = self.date_col.value.replace(" ", "")
    self.pivot_col_name = self.pivot_col.value.replace(" ", "")
    self.kpi_col_name = self.kpi_col.value.replace(" ", "")

    # Format the data.
    try:
      if self.data_type == 0:
        self.df_shaped = self.df_sheet.copy()
        self._trend_check()
      elif self.data_type == 1:
        self.df_shaped = self._shape_wide(
            self.df_sheet,
            self.date_col_name,
            self.pivot_col_name,
            self.kpi_col_name,
        )
        success_text('\nSuccess! The data was formatted for analysis.')
        display(self.df_shaped.head(3))
        self._trend_check()
      else:
        raise ValueError('Invalid data type.')
    except Exception as e:
      failure_text('\n\nFailure!!')
      print('Error: {}'.format(e))
      print('Please check the following:')
      print('* The selected data format.')
      print('* The data must be wide. Check the format in the previous cell.')
      print('* The values of the date, pivot, and kpi column.\n\n')
      raise Exception('Please check Failure')

  def _shape_wide(self, dataframe, date_column, pivot_column, kpi_column):
    """shape_wide pivots the data in the specified column.

    Converts long data to wide data suitable for experiment design using
    fastDTW.

    Args:
        dataframe: The DataFrame to be pivoted.
        date_column: The name of the column that contains the dates.
        pivot_column: The name of the column that contains the pivot keys.
        kpi_column: The name of the column that contains the KPI values.

    Returns:
        A DataFrame with the pivoted data.
    """
    # Check if the pivot_column is a single column or a list of columns.
    if ',' in pivot_column:
      group_cols = pivot_column.replace(' ', '').split(',')
    else:
      group_cols = [pivot_column]

    # Group the dataframe by the date and group columns, and sum the kpi column.
    dataframe = dataframe[[date_column] + [kpi_column] + group_cols]
    dataframe = dataframe.groupby([date_column] + group_cols).sum()

    # Pivot the dataframe, filling missing values with 0.
    dataframe = pd.pivot_table(
        dataframe, index=date_column, columns=group_cols, fill_value=0
    )

    # # Drop the first level of the column names.
    dataframe.columns = dataframe.columns.droplevel(0)

    # If there are multiple columns, convert the column names to a single string.
    if len(dataframe.columns.names) > 1:
      new_cols = ['_'.join([x.replace(",", "_") for x in y]) for y in dataframe.columns.values]
      dataframe.columns = new_cols

    dataframe = dataframe.reset_index()

    return dataframe

  def _trend_check(self):
    """Plot the trend of ``self.df_shaped`` in a three-tab view.

    Tabs: ``all`` (sum over all columns per date plus its 7-day centred
    moving average), ``each`` (every column min-max scaled, one small row per
    column) and ``describe`` (``DataFrame.describe`` of the data).
    """
    self.df_each = pd.DataFrame(index=self.df_shaped[self.date_col_name])
    self.df_each.index = pd.to_datetime(self.df_each.index)
    # Axis tick count derived from the number of monthly resample bins.
    self.tick_count = len(self.df_each.resample('M')) - 1

    self.col_list = list(self.df_shaped.columns)
    self.col_list.remove(self.date_col_name)
    for column in self.col_list:
      series = self.df_shaped[column]
      low, high = series.min(), series.max()
      # Min-max scale. Assign positionally (to_numpy) because df_each is
      # indexed by date while the source series is not.
      self.df_each[column] = ((series - low) / (high - low)).to_numpy()

    self.line_each = (
        alt.Chart(self.df_each.reset_index())
        .transform_fold(fold=self.col_list, as_=['pivot', 'kpi'])
        .mark_line()
        .encode(
            alt.X(
                self.date_col_name + ':T',
                title=None,
                axis=alt.Axis(
                    grid=False, format='%Y %b', tickCount=self.tick_count
                ),
            ),
            alt.Y('kpi:Q', stack=None, axis=None),
            alt.Color('pivot:N'),
            alt.Row(
                'pivot:N',
                title=None,
                header=alt.Header(labelAngle=0, labelAlign='left'),
            ),
        )
        .properties(bounds='flush', height=50)
        .configure_facet(spacing=0)
        .configure_view(stroke=None)
        .configure_title(anchor='end')
    )

    # Total over all columns per date.
    self.df_long = (
        pd.melt(self.df_shaped, id_vars=self.date_col_name)
        .groupby(self.date_col_name)
        .sum(numeric_only=True)
        .reset_index()
    )
    self.line_total = (
        alt.Chart(self.df_long)
        .mark_line()
        .encode(
            x=alt.X(
                self.date_col_name + ':T',
                axis=alt.Axis(
                    title='', format='%Y %b', tickCount=self.tick_count
                ),
            ),
            y=alt.Y('value:Q', axis=alt.Axis(title='kpi')),
            color=alt.value('#4285F4'),
        )
    )
    self.moving_average = (
        alt.Chart(self.df_long)
        .transform_window(
            rolling_mean='mean(value)',
            # 3 rows before + current row + 3 rows after = 7 rows, matching
            # the "7days" chart title ([-4, 3] averaged 8 rows).
            frame=[-3, 3],
        )
        .mark_line()
        .encode(
            x=alt.X(self.date_col_name + ':T'),
            y=alt.Y('rolling_mean:Q'),
            color=alt.value('#DB4437'),
        )
    )

    self.tb_trend = widgets.TabBar(['all', 'each', 'describe'])
    with self.tb_trend.output_to('all'):
      display(
          (self.line_total + self.moving_average).properties(
              width=700,
              height=200,
              title={
                  'text': ['Daily Trend(blue) & 7days moving average(red)'],
              },
          )
      )
    with self.tb_trend.output_to('each'):
      display((self.line_each).properties(width=700))
    with self.tb_trend.output_to('describe'):
      display(self.df_shaped.describe(include='all'))
    # Re-enter the first tab after all three have been filled.
    with self.tb_trend.output_to('all'):
      pass

class CausalImpact(PreProcess):

  def __init__(self):
    """Create the input widgets and the list of result series names.

    ``self.col_name`` holds 14 names: the response and cumulative response,
    followed by the value / lower / upper triple for each of ``point_pred``,
    ``cum_pred``, ``point_effect`` and ``cum_effect``.
    """
    super().__init__()
    names = ['response', 'cum_response']
    for stem in ('point_pred', 'cum_pred', 'point_effect', 'cum_effect'):
      names += [stem, stem + '_lower', stem + '_upper']
    self.col_name = names

  def create_causalimpact_object(
    self, data, date_col, pre_start, pre_end, post_start, post_end):
    """Fit the R ``CausalImpact`` model on ``data``.

    The frame is pushed to the R session as ``df``, turned into a ``zoo``
    series indexed by the date column, and passed to ``CausalImpact`` with the
    given pre and post periods.

    Args:
      data: DataFrame containing the date column and the series to model.
      date_col: Name of the date column in ``data``.
      pre_start: First date of the pre-intervention period.
      pre_end: Last date of the pre-intervention period.
      post_start: First date of the post-intervention period.
      post_end: Last date of the post-intervention period.

    Returns:
      The object returned by rpy2 for the R code; it is also stored on
      ``self.causalimpact_object``.
    """
    self.ci_code = '''
    index <- colnames(df) != "{date}"
    df_zoo <- zoo(df[index], as.Date(df${date}))
    pre.period <- as.Date(c("{pre_start}", "{pre_end}"))
    post.period <- as.Date(c("{post_start}", "{post_end}"))
    impact <- CausalImpact(df_zoo, pre.period, post.period)
    '''
    # The local must be named ``df``: ``-i df`` imports the variable of that
    # name into the R session.
    df = data
    # Explicit form of the ``%R -i df`` line magic, so the method is plain
    # Python like the other magic call in this file.
    get_ipython().run_line_magic('R', '-i df')
    self.causalimpact_object = robjects.r(self.ci_code.format(
        date=date_col,
        pre_start=pre_start,
        pre_end=pre_end,
        post_start=post_start,
        post_end=post_end
        )
    )
    return self.causalimpact_object

  def display_causalimpact_object(
    self, causalimpact_object, original_df, date_col, pre_start, pre_end, tread_start, treat_end):
    """Print the model summary and draw the three CausalImpact panels.

    Panels, top to bottom: response vs. prediction, point effect and
    cumulative effect, each with its lower/upper band and two vertical rules
    at ``tread_start`` (red) and ``treat_end`` (orange).

    Args:
      causalimpact_object: Result of ``create_causalimpact_object``. Its first
        element is split into 14 equal parts named after ``self.col_name``
        (assumes that element holds the 14 result series back to back; TODO
        confirm).
      original_df: Not used by this method.
      date_col: Not used by this method; ``self.date_col_name`` is used.
      pre_start: Start of the period used for the MAPE.
      pre_end: End of the period used for the MAPE.
      tread_start: Position of the first vertical rule.
      treat_end: Position of the second vertical rule.
    """
    # prep: one column per result series, indexed by the dates of df_shaped.
    series_parts = np.array_split(list(causalimpact_object[0]), 14)
    self.causalimpact_df = pd.DataFrame(series_parts, index = self.col_name).T
    self.causalimpact_df = self.causalimpact_df.set_index(self.df_shaped[self.date_col_name])

    pre_window = slice(str(pre_start), str(pre_end))
    self.mape = mean_absolute_percentage_error(
        self.causalimpact_df.loc[:,'response'][pre_window],
        self.causalimpact_df.loc[:,'point_pred'][pre_window])

    # display result
    print('Approximate model accuracy >> MAPE:{:.2%}\n'.format(self.mape))
    robjects.r.assign("ci_obj", causalimpact_object)
    robjects.r('summary(ci_obj)')

    date_field = 'yearmonthdate('+self.date_col_name+')'

    def new_chart():
      return alt.Chart(self.causalimpact_df.reset_index())

    def band(lower_field, upper_field):
      # Shaded area between the lower and upper bound of a series.
      return new_chart().mark_area(opacity=0.3).encode(
          x = alt.X(date_field),
          y = alt.Y(lower_field + ":Q", scale=alt.Scale(zero=False)),
          y2 = alt.Y2(upper_field + ":Q"),)

    def dashed_line(field, x_axis):
      return new_chart().mark_line(strokeDash=[5,5]).encode(
          x = alt.X(date_field, axis=x_axis),
          y = alt.Y(field + ':Q', scale=alt.Scale(zero=False), axis=alt.Axis(title='')),)

    # Panel 1: observed response (solid) and prediction (dashed).
    legend = alt.Legend(
        title=None,
        orient='none',
        legendY=-20,
        direction='horizontal',
        titleAnchor='start')
    self.line_1 = new_chart().transform_fold(
        ['response', 'point_pred',]
    ).mark_line().encode(
        x = alt.X(date_field, axis=alt.Axis(title='',labels=False, ticks=False)),
        y = alt.Y('value:Q', scale=alt.Scale(zero=False), axis=alt.Axis(title=''),),
        color=alt.Color('key:N',legend=legend),
        strokeDash=alt.condition(
            alt.datum.key == 'point_pred',
            alt.value([5, 5]),
            alt.value([0]),))
    self.area_1 = band('point_pred_lower', 'point_pred_upper')

    # Panel 2: point effect. Panel 3: cumulative effect (keeps axis labels).
    self.line_2 = dashed_line(
        'point_effect', alt.Axis(title='',labels=False, ticks=False))
    self.area_2 = band('point_effect_lower', 'point_effect_upper')
    self.line_3 = dashed_line('cum_effect', alt.Axis(title=''))
    self.area_3 = band('cum_effect_lower', 'cum_effect_upper')

    self.zero_line = alt.Chart(pd.DataFrame({'y': [0]})).mark_rule().encode(y='y',color=alt.value("gray"))
    rule_frame = pd.DataFrame({'Date': [str(tread_start), str(treat_end)], 'color': ['red', 'orange']})
    self.rules = alt.Chart(rule_frame).mark_rule(strokeDash=[5, 5]).encode(
        x='Date:T',
        color=alt.Color('color:N', scale=None))

    panel_size = dict(height=100, width=600)
    self.plot = alt.vconcat(
        (self.line_1 + self.area_1 + self.rules).properties(**panel_size),
        (self.line_2 + self.area_2 + self.rules + self.zero_line).properties(**panel_size),
        (self.line_3 + self.area_3 + self.rules + self.zero_line).properties(**panel_size)
    )
    display(self.plot)

  def _calculate_distance(self, dataframe):
    """Sum the fastDTW distances over every pair of columns.

    Each column is min-max scaled, then every series is compared as a
    sequence of (index value, scaled value) points using the Euclidean point
    distance. The input frame is left unmodified.

    Args:
      dataframe: DataFrame whose columns are the series to compare. The
        callers in this class pass a frame with a freshly reset index.

    Returns:
      The sum of the pairwise distances (also stored on ``self.dist``).
    """
    self.dist = 0
    # Only the data columns are paired. Building the pairs after
    # reset_index() also paired the helper 'index' column with every series.
    series_names = list(dataframe.columns)
    scaled = (dataframe - dataframe.min()) / (dataframe.max() - dataframe.min())
    scaled = scaled.reset_index()
    for first, second in itertools.combinations(series_names, 2):
      self.distance, self.path = fastdtw.fastdtw(
          scaled.loc[:, ['index', first]],
          scaled.loc[:, ['index', second]],
          dist=euclidean,
      )
      self.dist = self.dist + self.distance
    return self.dist

  def n_part_split(self):
    """Search random partitions of the columns for the most similar groups.

    ``self.df_shaped`` is restricted to the dates between the start and end
    date widgets. Then, 1000 times, the columns are drawn at random into
    ``split#`` groups of equal size (columns left over by the integer
    division join the first group), each group is summed into one series and
    the candidate is scored with ``_calculate_distance``. The three
    lowest-distance candidates are kept in ``self.df_dtw``: a ``distance``
    column plus one column per group holding the stringified member list.

    Raises:
      ValueError: If there are fewer columns than groups.
    """
    self.df_design = self.df_shaped.copy().set_index(self.date_col_name)
    self.df_design.index = pd.to_datetime(self.df_design.index)
    self.df_design = self.df_design.query(
        '@self.start_date.value <= index <= @self.end_date.value'
    )

    self.df_dtw = pd.DataFrame(columns=['distance'])
    self.num_of_pick = len(self.df_design.columns) // self.num_of_split.value
    if self.num_of_pick == 0:
      raise ValueError(
          'Cannot split {} column(s) into {} groups.'.format(
              len(self.df_design.columns), self.num_of_split.value
          )
      )

    self.num_of_iteration = 1000
    for l in range(self.num_of_iteration):
      self.col_list = list(self.df_design.columns)
      self.df_picked = pd.DataFrame()

      # random pick
      self.picks = []
      for _ in range(self.num_of_split.value):
        self.pick = random.sample(self.col_list, self.num_of_pick)
        self.picks.append(self.pick)
        self.col_list = [ele for ele in self.col_list if ele not in self.pick]
      # Leftover columns go to the first group. The previous bare expression
      # `self.picks[0] + self.col_list` discarded its result.
      self.picks[0] = self.picks[0] + self.col_list
      for i, members in enumerate(self.picks):
        self.picked = pd.DataFrame(
            self.df_design[members].sum(axis=1), columns=[i]
        )
        self.df_picked = pd.concat([self.df_picked, self.picked], axis=1)

      # dtw
      self.distance = self._calculate_distance(
          self.df_picked.reset_index(drop=True)
      )
      self.df_dtw.loc[l, 'distance'] = float(self.distance)
      for j, members in enumerate(self.picks):
        self.df_dtw.at[l, j] = str(list(members))

    self.df_dtw = (
        self.df_dtw.sort_values('distance').head(3).reset_index(drop=True)
    )

  def find_similar(self):
    """Search random column groups that move like the target columns.

    ``self.df_shaped`` is restricted to the dates between the start and end
    date widgets. Then, 1000 times, ``covariate#`` groups are drawn without
    overlap from the non-target columns (each group size is drawn from the
    ``max pick#`` range), every group and the target group are summed into
    one series each, and the candidate is scored with
    ``_calculate_distance``. The three lowest-distance candidates are kept in
    ``self.df_dtw``; column 0 holds the target columns.
    """
    design = self.df_shaped.copy().set_index(self.date_col_name)
    design.index = pd.to_datetime(design.index)
    self.df_design = design
    self.df_design = self.df_design.query(
        '@self.start_date.value <= index <= @self.end_date.value'
    )
    self.df_dtw = pd.DataFrame(columns=['distance'])

    self.target_cols = self.target_columns.value.replace(' ', '').split(',')
    self.pick_range = self.num_of_pick_range.value
    self.num_of_iteration = 1000
    for trial in range(self.num_of_iteration):
      self.df_picked = pd.DataFrame()
      self.remained_list = [
          name
          for name in list(self.df_design.columns)
          if name not in self.target_cols
      ]

      # similar: draw the covariate groups without reusing a column.
      self.picks = []
      for _ in range(self.num_of_covariate.value):
        smallest, largest = self.pick_range[0], self.pick_range[1]
        self.num_of_pick = random.randrange(smallest, largest + 1, 1)
        self.pick = random.sample(self.remained_list, self.num_of_pick)
        self.picks.append(self.pick)
        self.remained_list = [
            name for name in self.remained_list if name not in self.pick
        ]
      # The target group always comes first.
      self.picks.insert(0, self.target_cols)
      for position, members in enumerate(self.picks):
        group_total = self.df_design[members].sum(axis=1)
        self.picked = pd.DataFrame(group_total, columns=[position])
        self.df_picked = pd.concat([self.df_picked, self.picked], axis=1)

      # dtw
      self.distance = self._calculate_distance(
          self.df_picked.reset_index(drop=True)
      )
      self.df_dtw.loc[trial, 'distance'] = float(self.distance)
      for position, members in enumerate(self.picks):
        self.df_dtw.at[trial, position] = str(list(members))

    ranked = self.df_dtw.sort_values('distance')
    self.df_dtw = ranked.head(3).reset_index(drop=True)

  def reconstitute_dataframe(self):
    """Show every candidate in ``self.df_dtw`` and the selection widgets.

    One tab per candidate: the member columns of each group are printed, the
    groups are summed over the whole of ``self.df_shaped`` and drawn as lines,
    next to a scatter matrix of their first differences. Below the tabs the
    option / test column / control column selectors are displayed.
    """
    tab_names = ['option_' + str(sub + 1) for sub in self.df_dtw.index.tolist()]
    self.candidate_tb = widgets.TabBar(tab_names)
    for option in range(len(self.df_dtw)):
      with self.candidate_tb.output_to(option):
        self.candidate_df = pd.DataFrame(
            index=self.df_shaped[self.date_col_name]
        )
        # Every column except 'distance' describes one group.
        for group in range(len(self.df_dtw.columns) - 1):
          members_repr = self.df_dtw.at[option, group]
          print('col_' + str(group + 1) + members_repr)
          # NOTE(review): eval() turns the stored "['a', 'b']" string back
          # into a list; the string is built from the data's column names.
          group_total = self.df_shaped.loc[:, eval(members_repr)].sum(axis=1)
          self.candidate_df[group + 1] = list(group_total)
        print('\n')
        self.candidate_df = self.candidate_df.add_prefix('col_')

        folded = alt.Chart(self.candidate_df.reset_index()).transform_fold(
            fold=list(self.candidate_df.columns), as_=['pivot', 'kpi']
        )
        self.chart_line = (
            folded.mark_line()
            .encode(x=alt.X(self.date_col_name + ':T'), y=alt.Y('kpi:Q'), color='pivot:N')
            .properties(width=600, height=200)
        )

        # Row-to-row changes of every group, compared pairwise.
        self.df_diff = pd.DataFrame(
            np.diff(self.candidate_df, axis=0),
            columns=self.candidate_df.columns.values,
        )
        diff_columns = self.df_diff.columns.values
        self.scatter = (
            alt.Chart(self.df_diff.reset_index())
            .mark_circle()
            .encode(
                alt.X(alt.repeat('column'), type='quantitative'),
                alt.Y(alt.repeat('row'), type='quantitative'),
            )
            .properties(width=80, height=80)
            .repeat(row=diff_columns, column=diff_columns)
        )
        display(self.chart_line | self.scatter)

    # Re-enter the first tab after all tabs have been filled.
    with self.candidate_tb.output_to(0):
      pass

    prompt = ipywidgets.HTML(
        value=self._apply_style(
            'Please select option, test column & control column(s).'
        )
    )
    selectors = ipywidgets.HBox([
        self.your_choice,
        self.test_column,
        self.control_column,
    ])
    display(prompt, selectors)

  def generate_simulation(self):
    """Simulates uplifts on the selected design and displays the results.

    Reads the option, test column and control column(s) chosen in the
    `your_choice`, `test_column` and `control_column` widgets and rebuilds
    `self.df` with one 'test' column followed by 'control_N' columns. For every
    combination of treatment duration (7, 14, 28 days, counted back from
    `end_date`) and multiplier (1.01 ... 1.2) the 'test' values inside the
    post period are scaled and passed to `create_causalimpact_object`. A
    summary table and one tab per scenario are then displayed.

    Raises:
      ValueError: If the same column is selected more than once (for example
        the test column is also chosen as a control column).
    """
    # reconstitute_dataframe
    self.selection_row = int(self.your_choice.value.replace('option_', '')) - 1
    self.selection_cols = [
        int(self.test_column.value.replace('col_', '')) - 1
    ] + [
        int(s.replace('col_', '')) - 1 for s in list(self.control_column.value)
    ]
    if len(set(self.selection_cols)) != len(self.selection_cols):
      # A repeated column would be assigned to the same key twice below, and
      # the later column rename would fail with an opaque length-mismatch
      # error; fail early with a readable message instead.
      raise ValueError('Test and control columns must all be different.')
    self.colnames = ['test'] + [
        'control_' + str(i + 1) for i in range(len(self.selection_cols) - 1)
    ]
    self.df = pd.DataFrame(index=self.df_shaped[self.date_col_name])
    self.df.index = pd.to_datetime(self.df.index)
    for col in self.selection_cols:
      # NOTE(review): eval() executes the text stored in df_dtw as Python.
      # Presumably each cell is a stringified list of column names produced
      # earlier in this class -- confirm, and prefer ast.literal_eval if so.
      self.df[col] = list(
          self.df_shaped.loc[
              :, eval(self.df_dtw.at[self.selection_row, col])
          ].sum(axis=1)
      )
    self.df.columns = self.colnames
    for x, i in zip(self.df.columns, self.selection_cols):
      print('{}: {}'.format(x, self.df_dtw.at[self.selection_row, i]))

    # simulation
    self.ci_objs = []
    self.simulate_periods = []
    self.treat_duration = [7, 14, 28]
    self.treat_impact = [1.01, 1.03, 1.05, 1.10, 1.2]
    self.simulation_combination = list(
        itertools.product(self.treat_duration, self.treat_impact)
    )
    self.pre_start_date = self.start_date.value
    self.post_end_date = self.end_date.value
    # Iterate the combinations themselves so that ci_objs / simulate_periods
    # are guaranteed to line up with simulation_combination by position.
    for duration, impact in self.simulation_combination:
      self.pre_end_date = self.post_end_date + datetime.timedelta(
          days=-duration
      )
      self.post_start_date = self.pre_end_date + datetime.timedelta(days=1)
      post_window = slice(
          np.datetime64(self.post_start_date),
          np.datetime64(self.post_end_date),
      )
      # Start every scenario from the untouched baseline so no uplift from a
      # previous scenario can leak in, whatever the order of treat_duration.
      self.adjusted_df = self.df.copy()
      self.adjusted_df.loc[post_window, 'test'] = (
          self.df.loc[post_window, 'test'] * impact
      )
      self.ci_obj = self.create_causalimpact_object(
          self.adjusted_df.reset_index(),
          self.date_col_name,
          self.pre_start_date,
          self.pre_end_date,
          self.post_start_date,
          self.post_end_date,
      )
      self.simulate_periods.append([
          self.pre_start_date,
          self.pre_end_date,
          self.post_start_date,
          self.post_end_date,
      ])
      self.ci_objs.append(self.ci_obj)

    self.simulation_df = pd.DataFrame(
        index=[],
        columns=[
            'Days_simulated',
            'Simulated_impact',
            'MAPE',
            'Total_effect',
            'Average_effect',
            'Required_budget',
            'p_value',
        ],
    )
    for i, result in enumerate(self.ci_objs):
      self.periods = self.simulate_periods[i]
      duration, impact = self.simulation_combination[i]
      # NOTE(review): 14 is presumably the number of series in the first
      # element of the CausalImpact result; it has to match
      # len(self.col_name) for the DataFrame constructor to succeed -- confirm.
      self.impact_df = pd.DataFrame(
          np.array_split(list(result[0]), 14), index=self.col_name
      ).T
      self.impact_df = self.impact_df.set_index(
          self.df_shaped[self.date_col_name]
      )
      # MAPE is measured on the pre period only (pre start .. pre end).
      pre_window = slice(str(self.periods[0]), str(self.periods[1]))
      # NOTE(review): positional lookups into the second element of the
      # result, presumably the summary table built by
      # create_causalimpact_object -- confirm the row/column meaning there.
      summary = result[1]
      self.impact_dict = {
          'Days_simulated': duration,
          'Simulated_impact': impact - 1,
          'MAPE': [
              mean_absolute_percentage_error(
                  self.impact_df.loc[:, 'response'][pre_window],
                  self.impact_df.loc[:, 'point_pred'][pre_window],
              )
          ],
          'Total_effect': summary[5][1],
          'Average_effect': summary[5][0],
          'Required_budget': [summary[5][1] * self.estimate_icpa.value],
          'p_value': summary[14][0],
      }
      self.simulation_df = pd.concat(
          [self.simulation_df, pd.DataFrame.from_dict(self.impact_dict)],
          ignore_index=True,
      )

    display(
        self.simulation_df.style.format({
            'Days_simulated': '{:.0f} d',
            'Simulated_impact': '{:+.0%}',
            'MAPE': '{:.2%}',
            'Total_effect': '{:,.2f}',
            'Average_effect': '{:,.2f}',
            'Required_budget': '{:,.0f}',
            'p_value': '{:,.2f}',
        })
    )

    self.simulation_tb = widgets.TabBar(self.simulation_combination)
    for i, periods in enumerate(self.simulate_periods):
      with self.simulation_tb.output_to(i):
        print('Pre Period:{} ~ {}\nPost Period:{} ~ {}'.format(*periods))
        self.display_causalimpact_object(
            self.ci_objs[i],
            self.df_shaped,
            self.date_col_name,
            *periods,
        )

    # Switch back to the first tab once, after every tab has been filled.
    with self.simulation_tb.output_to(0):
      pass

# Create the first analysis instance. The Step.2 and Step.3 cells read its
# widget state (e.g. case_1.purpose_selection), which
# generate_purpose_section() presumably builds and displays -- confirm.
case_1 = CausalImpact()
case_1.generate_purpose_section()

In [None]:
# @title Step.2
# Load and reshape the input, then run the branch selected in Step.1.
case_1.load_data()
case_1.format_data()

if case_1.purpose_selection.selected_index == 0:
  # First purpose tab: run CausalImpact on the periods entered in the widgets
  # and keep the result for the Step.3 cell.
  case_1.ci_objs = []
  try:
    case_1.ci_obj = case_1.create_causalimpact_object(
        case_1.df_shaped,
        case_1.date_col_name,
        case_1.pre_period_start.value,
        case_1.pre_period_end.value,
        case_1.post_period_start.value,
        case_1.post_period_end.value,
    )
    case_1.ci_objs.append(case_1.ci_obj)
    success_text(
        '\nSuccess! CausalImpact has been performed. Check the results in the'
        ' next cell.'
    )
  except Exception as e:
    failure_text('\n\nFailure!!')
    print('Error: {}'.format(e))
    print('Please check the following:')
    print('* Duration of experiment (pre and post).')
    failure_text('▲▲▲▲▲▲\n\n')
    # Chain explicitly so the traceback marks the original error as the cause.
    raise Exception('Please check Failure') from e

elif case_1.purpose_selection.selected_index == 1:
  # Second purpose tab: build candidate designs with the method that matches
  # the selected design-type tab, then show them for selection.
  if case_1.design_type.selected_index == 0:
    case_1.n_part_split()
    case_1.reconstitute_dataframe()
  elif case_1.design_type.selected_index == 1:
    case_1.find_similar()
    case_1.reconstitute_dataframe()

In [None]:
# @title Step.3
# Any purpose tab other than the first one runs the simulation; the first tab
# shows the CausalImpact result stored by the Step.2 cell.
if case_1.purpose_selection.selected_index != 0:
  case_1.generate_simulation()
else:
  case_1.display_causalimpact_object(
      case_1.ci_objs[0],
      case_1.df_shaped,
      case_1.date_col_name,
      case_1.pre_period_start.value,
      case_1.pre_period_end.value,
      case_1.post_period_start.value,
      case_1.post_period_end.value,
  )

## (Optional) Another analysis

In [None]:
# @title Step.1
# A separate instance for an additional analysis, so its widgets and results
# are kept apart from case_1 (assuming CausalImpact holds no class-level
# state -- confirm).
case_2 = CausalImpact()
case_2.generate_purpose_section()

In [None]:
# @title Step.2
# Load and reshape the input, then run the branch selected in Step.1.
case_2.load_data()
case_2.format_data()

if case_2.purpose_selection.selected_index == 0:
  # First purpose tab: run CausalImpact on the periods entered in the widgets
  # and keep the result for the Step.3 cell.
  case_2.ci_objs = []
  try:
    case_2.ci_obj = case_2.create_causalimpact_object(
        case_2.df_shaped,
        case_2.date_col_name,
        case_2.pre_period_start.value,
        case_2.pre_period_end.value,
        case_2.post_period_start.value,
        case_2.post_period_end.value,
    )
    case_2.ci_objs.append(case_2.ci_obj)
    success_text(
        '\nSuccess! CausalImpact has been performed. Check the results in the'
        ' next cell.'
    )
  except Exception as e:
    failure_text('\n\nFailure!!')
    print('Error: {}'.format(e))
    print('Please check the following:')
    print('* Duration of experiment (pre and post).')
    failure_text('▲▲▲▲▲▲\n\n')
    # Chain explicitly so the traceback marks the original error as the cause.
    raise Exception('Please check Failure') from e

elif case_2.purpose_selection.selected_index == 1:
  # Second purpose tab: build candidate designs with the method that matches
  # the selected design-type tab, then show them for selection.
  if case_2.design_type.selected_index == 0:
    case_2.n_part_split()
    case_2.reconstitute_dataframe()
  elif case_2.design_type.selected_index == 1:
    case_2.find_similar()
    case_2.reconstitute_dataframe()

In [None]:
# @title Step.3
# Any purpose tab other than the first one runs the simulation; the first tab
# shows the CausalImpact result stored by the Step.2 cell.
if case_2.purpose_selection.selected_index != 0:
  case_2.generate_simulation()
else:
  case_2.display_causalimpact_object(
      case_2.ci_objs[0],
      case_2.df_shaped,
      case_2.date_col_name,
      case_2.pre_period_start.value,
      case_2.pre_period_end.value,
      case_2.post_period_start.value,
      case_2.post_period_end.value,
  )