<a href="https://colab.research.google.com/github/google/business_intelligence_group/blob/development/solutions/causal-impact/CausalImpact_with_Experimental_Design.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **CausalImpact with Experimental Design**

This Colab file contains *Experimental Design* and *CausalImpact Analysis*.

See [README.md](https://github.com/google/business_intelligence_group/tree/main/solutions/causal-impact) for details.

---

Copyright 2023 Google LLC

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

In [None]:
# @title Step.1 (~ 2min)
%%time
import sys
if 'fastdtw' not in sys.modules:
  !pip install 'fastdtw' --q
if 'tfp-causalimpact' not in sys.modules:
  !pip install 'tfp-causalimpact' --q

# Data Load
from google.colab import auth, files, widgets
from google.auth import default
from google.cloud import bigquery
import io
import os
import gspread
from oauth2client.client import GoogleCredentials

# Calculate
import altair as alt
import itertools
import random
import numpy as np
import pandas as pd
import fastdtw

from scipy.spatial.distance import euclidean
from sklearn.metrics import mean_absolute_percentage_error
from sklearn.preprocessing import MinMaxScaler
from statsmodels.tsa.seasonal import STL

# Input
import datetime
from dateutil.relativedelta import relativedelta
import ipywidgets
from IPython.display import display, Markdown, HTML, Javascript
from tqdm.auto import tqdm
import warnings
warnings.simplefilter('ignore')

# causalimpact
import causalimpact
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions


class PreProcess(object):
  """Collects analysis settings through ipywidgets and prepares the data.

  The widgets created in ``__init__`` are laid out by
  ``generate_purpose_section``. ``load_data`` reads the selected source into
  ``self.df_sheet`` and ``format_data`` turns it into the wide frame
  ``self.df_shaped`` (a date column plus one column per series).
  """

  def __init__(self):
    # soure_selection
    self.sheet_url = ipywidgets.Text(
        placeholder='Please enter google spreadsheet url',
        value='https://docs.google.com/spreadsheets/d/1dISrbX1mZHgzpsIct2QXFOWWRRJiCxDSmSzjuZz64Tw/edit#gid=0',
        description='spreadsheet url:',
        style={'description_width': 'initial'},
        layout=ipywidgets.Layout(width='1000px'),
    )
    self.sheet_name = ipywidgets.Text(
        placeholder='Please enter sheet name',
        value='analysis_data',
        # value='raw_data',
        description='sheet name:',
    )
    self.csv_name = ipywidgets.Text(
        placeholder='Please enter csv name',
        description='csv name:',
        layout=ipywidgets.Layout(width='500px'),
    )
    self.bq_project_id = ipywidgets.Text(
        placeholder='Please enter project id',
        description='project id:',
        layout=ipywidgets.Layout(width='500px'),
    )
    self.bq_table_name = ipywidgets.Text(
        placeholder='Please enter table name',
        description='table name:',
        layout=ipywidgets.Layout(width='500px'),
    )
    # data_type_selection
    self.date_col = ipywidgets.Text(
        placeholder='Please enter date column name',
        value='Date',
        description='date column:',
    )
    self.pivot_col = ipywidgets.Text(
        placeholder='Please enter pivot column name',
        value='Geo',
        description='pivot column:',
    )
    self.kpi_col = ipywidgets.Text(
        placeholder='Please enter kpi column name',
        value='KPI',
        description='kpi column:',
    )
    # date
    self.pre_period_start = ipywidgets.DatePicker(
        description='Pre Start:',
        value=datetime.date.today() - relativedelta(days=122),
    )
    self.pre_period_end = ipywidgets.DatePicker(
        description='Pre End:',
        value=datetime.date.today() - relativedelta(days=32),
    )
    self.post_period_start = ipywidgets.DatePicker(
        description='Post Start:',
        value=datetime.date.today() - relativedelta(days=31),
    )
    self.post_period_end = ipywidgets.DatePicker(
        description='Post End:',
        value=datetime.date.today(),
    )
    self.start_date = ipywidgets.DatePicker(
        description='Start Date:',
        value=datetime.date.today() - relativedelta(days=122),
    )
    self.end_date = ipywidgets.DatePicker(
        description='End Date:',
        value=datetime.date.today() - relativedelta(days=32),
    )
    self.depend_data = ipywidgets.ToggleButton(
        value=False,
        description='Click >> Use the beginning and end of data',
        disabled=False,
        button_style='info',
        tooltip='Description',
        layout=ipywidgets.Layout(width='300px'),
    )
    # design_type
    self.num_of_split = ipywidgets.Dropdown(
        options=[2, 3, 4, 5],
        value=2,
        description='split#:',
        disabled=False,
    )
    self.target_columns = ipywidgets.Text(
        placeholder='Please enter comma-separated entries',
        value='Tokyo, Kanagawa',
        description='target_cols:',
        layout=ipywidgets.Layout(width='500px'),
    )
    self.control_columns = ipywidgets.Text(
        placeholder='Please enter comma-separated entries',
        value='Aomori, Akita',
        description='control_cols:',
        layout=ipywidgets.Layout(width='500px'),
    )
    self.num_of_pick_range = ipywidgets.IntRangeSlider(
        value=[5, 10],
        min=1,
        max=50,
        step=1,
        description='max pick#:',
        orientation='horizontal',
        readout=True,
        readout_format='d',
    )
    self.num_of_covariate = ipywidgets.Dropdown(
        options=[1, 2, 3, 4, 5],
        value=1,
        description='covariate#:',
        layout=ipywidgets.Layout(width='192px'),
    )
    self.target_share = ipywidgets.FloatSlider(
        value=0.3,
        min=0.05,
        max=0.5,
        step=0.05,
        description='target share#:',
        orientation='horizontal',
        readout=True,
        readout_format='.1%',
    )
    self.un_needed_cols = ipywidgets.Text(
        placeholder=(
            'Enter comma-separated columns if any columns are not used in the'
            ' design.'
        ),
        description='un need col:',
        layout=ipywidgets.Layout(width='1000px'),
    )
    # simulation
    self.has_seasons = ipywidgets.ToggleButton(
        value=False,
        description='Click >> Set the following number of seasons',
        disabled=False,
        button_style='info',
        layout=ipywidgets.Layout(width='300px'),
    )
    self.num_of_nseasons = ipywidgets.IntText(
        value=7,
        description='nseasons#:',
        disabled=False,
    )
    self.estimate_icpa = ipywidgets.IntText(
        value=1000,
        description='Estimated iCPA:',
        style={'description_width': 'initial'},
    )

    # option
    self.your_choice = ipywidgets.Dropdown(
        options=['option_1', 'option_2', 'option_3'],
        description='your choice:',
    )
    self.test_column = ipywidgets.Dropdown(
        options=['col_1', 'col_2', 'col_3', 'col_4', 'col_5'],
        description='test column:',
    )
    self.control_column = ipywidgets.SelectMultiple(
        options=['col_1', 'col_2', 'col_3', 'col_4', 'col_5'],
        description='control column:',
        value=('col_2',),
        style={'description_width': 'initial'},
    )

  @staticmethod
  def _apply_style(size, text):
    """Return an HTML widget showing ``text`` with a blue bottom highlight.

    Args:
      size: Font size in pixels.
      text: HTML text to display inside the styled span.

    Returns:
      An ``ipywidgets.HTML`` widget.
    """
    # The closing tag must match the opening <span> (it was '</style>').
    return ipywidgets.HTML(
        "<span style='font-size:" + str(size) + "px; background: "
        "linear-gradient(transparent 90%, #4285F4 0%);'>"
        + text
        + '</span>'
    )

  @staticmethod
  def success_text(text):
    """Print ``text`` in green using a 24-bit ANSI colour escape."""
    print('\033[38;2;15;157;88m ' + text + '\033[0m')

  @staticmethod
  def failure_text(text):
    """Print ``text`` in red using a 24-bit ANSI colour escape."""
    print('\033[38;2;219;68;55m ' + text + '\033[0m')

  @staticmethod
  def resize_colab_cell():
    """Ask Colab to let the output iframe grow up to 10000px tall."""
    display(
        Javascript(
            "google.colab.output.setIframeHeight(0, true, {maxHeight: 10000})"
        )
    )

  @staticmethod
  def saving_params(instance):
    """Snapshot the current widget values of ``instance`` into a dict.

    Requires ``generate_purpose_section`` to have run, since the tab and
    accordion widgets it creates are read here.

    Args:
      instance: A ``PreProcess`` (or subclass) instance.

    Returns:
      A dict that ``set_params`` can restore.
    """
    params_dict = {
        'soure_selection': instance.soure_selection.selected_index,
        'sheet_url': instance.sheet_url.value,
        'sheet_name': instance.sheet_name.value,
        'csv_name': instance.csv_name.value,
        'bq_project_id': instance.bq_project_id.value,
        'bq_table_name': instance.bq_table_name.value,

        'data_type_selection': instance.data_type_selection.selected_index,
        'date_col': instance.date_col.value,
        'pivot_col': instance.pivot_col.value,
        'kpi_col': instance.kpi_col.value,

        'purpose_selection': instance.purpose_selection.selected_index,
        'pre_period_start': instance.pre_period_start.value,
        'pre_period_end': instance.pre_period_end.value,
        'post_period_start': instance.post_period_start.value,
        'post_period_end': instance.post_period_end.value,
        'start_date': instance.start_date.value,
        'end_date': instance.end_date.value,
        'depend_data': instance.depend_data.value,

        'design_type': instance.design_type.selected_index,
        'num_of_split': instance.num_of_split.value,
        'target_columns': instance.target_columns.value,
        'control_columns': instance.control_columns.value,
        'num_of_pick_range': instance.num_of_pick_range.value,
        'num_of_covariate': instance.num_of_covariate.value,
        'target_share': instance.target_share.value,
        'un_needed_cols': instance.un_needed_cols.value,

        'has_seasons': instance.has_seasons.value,
        'num_of_nseasons': instance.num_of_nseasons.value,
        'estimate_icpa': instance.estimate_icpa.value,
        }
    return params_dict

  @staticmethod
  def set_params(instance, dict_params):
    """Restore widget values on ``instance`` from a ``saving_params`` dict.

    Args:
      instance: A ``PreProcess`` (or subclass) instance whose
        ``generate_purpose_section`` has already run.
      dict_params: A dict produced by ``saving_params``.
    """
    instance.soure_selection.selected_index = dict_params['soure_selection']
    instance.sheet_url.value = dict_params['sheet_url']
    instance.sheet_name.value = dict_params['sheet_name']
    instance.csv_name.value = dict_params['csv_name']
    instance.bq_project_id.value = dict_params['bq_project_id']
    instance.bq_table_name.value = dict_params['bq_table_name']

    instance.data_type_selection.selected_index = dict_params['data_type_selection']
    instance.date_col.value = dict_params['date_col']
    instance.pivot_col.value = dict_params['pivot_col']
    instance.kpi_col.value = dict_params['kpi_col']

    instance.purpose_selection.selected_index = dict_params['purpose_selection']
    instance.pre_period_start.value = dict_params['pre_period_start']
    instance.pre_period_end.value = dict_params['pre_period_end']
    instance.post_period_start.value = dict_params['post_period_start']
    instance.post_period_end.value = dict_params['post_period_end']
    instance.start_date.value = dict_params['start_date']
    instance.end_date.value = dict_params['end_date']
    instance.depend_data.value = dict_params['depend_data']

    instance.design_type.selected_index = dict_params['design_type']
    instance.num_of_split.value = dict_params['num_of_split']
    instance.target_columns.value = dict_params['target_columns']
    instance.control_columns.value = dict_params['control_columns']
    instance.num_of_pick_range.value = dict_params['num_of_pick_range']
    instance.num_of_covariate.value = dict_params['num_of_covariate']
    instance.target_share.value = dict_params['target_share']
    instance.un_needed_cols.value = dict_params['un_needed_cols']

    instance.has_seasons.value = dict_params['has_seasons']
    instance.num_of_nseasons.value = dict_params['num_of_nseasons']
    instance.estimate_icpa.value = dict_params['estimate_icpa']

  def generate_purpose_section(self):
    """Build the source / format / purpose tabs and display them.

    Creates ``self.soure_selection``, ``self.data_type_selection``,
    ``self.design_type`` and ``self.purpose_selection``. Several widgets are
    deliberately placed in more than one container; ipywidgets keeps every
    rendering of the same widget in sync.
    """
    # soure_selection
    self.soure_selection = ipywidgets.Tab()
    self.soure_selection.children = [
        ipywidgets.VBox([self.sheet_url, self.sheet_name]),
        ipywidgets.VBox([self.csv_name]),
        ipywidgets.VBox([self.bq_project_id, self.bq_table_name]),
    ]
    self.soure_selection.set_title(0, 'Google_Spreadsheet')
    self.soure_selection.set_title(1, 'CSV_file')
    self.soure_selection.set_title(2, 'Big_Query')

    # data_type_selection
    self.data_type_selection = ipywidgets.Tab()
    self.data_type_selection.children = [
        ipywidgets.VBox([
            ipywidgets.Label(
                'Wide, or unstacked data is presented with each different data'
                ' variable in a separate column.'),
            self.date_col]),
        ipywidgets.VBox([
            ipywidgets.Label(
                'Narrow, stacked, or long data is presented with one column '
                'containing all the values and another column listing the '
                'context of the value'
            ),
            ipywidgets.HBox([self.date_col, self.pivot_col, self.kpi_col]),
        ]),
    ]
    self.data_type_selection.set_title(0, 'Wide_Format')
    self.data_type_selection.set_title(1, 'Narrow_Format')

    # design_type
    self.design_type = ipywidgets.Accordion(
        children=[
            ipywidgets.VBox([
                self.num_of_split,
                self.un_needed_cols,
            ]),
            ipywidgets.VBox([
                ipywidgets.HBox([
                    self.target_columns,
                    self.num_of_pick_range,
                    self.num_of_covariate,
                ]),
                self.un_needed_cols,
            ]),
            ipywidgets.VBox([
                ipywidgets.HBox([
                    self.target_share,
                    self.num_of_covariate,
                ]),
                self.un_needed_cols,
            ]),
            ipywidgets.VBox([
                ipywidgets.HTML('To improve reproducibility, it is important to create an accurate counterfactual model rather than a balanced assignment.'),
                self.target_columns,
                self.control_columns,
            ]),
        ]
    )
    self.design_type.set_title(
        0,
        (
            'A: divide_equally divides the time series data into N'
            ' groups(split#) with similar movements.'
        ),
    )
    self.design_type.set_title(
        1,
        (
            'B: similarity_selection extracts N groups(covariate#) that move'
            ' similarly to particular columns(target_cols).'
        ),
    )
    self.design_type.set_title(
        2,
        (
            'C: target share extracts targeted time series data from'
            ' the proportion of interventions.'
        ),
    )
    self.design_type.set_title(
        3,
        (
            'D: Assignment of specified targets and controls.'
        ),
    )

    # purpose_selection
    self.purpose_selection = ipywidgets.Tab()
    self.purpose_selection.children = [
        # Causalimpact
        ipywidgets.VBox([
            PreProcess._apply_style(
                15,
                '⑶ - a: Enter the Pre and Post the intervention.'
            ),
            # Two-column grid: the left column (option 1) repeats the
            # pre_end / post_start pickers that must still be entered by hand.
            ipywidgets.GridBox(
                [
                    ipywidgets.HTML(
                        '<b>Option 1: Use the beginning and end of input data for'
                        ' Pre Start & Post End</b>'
                    ),
                    ipywidgets.HTML(
                        '<b>Option 2: Manually enter the following</b>'
                    ),
                    self.depend_data,
                    self.pre_period_start,
                    ipywidgets.HTML('& Enter the rest manually'),
                    self.pre_period_end,
                    self.pre_period_end,
                    self.post_period_start,
                    self.post_period_start,
                    self.post_period_end,
                ],
                layout=ipywidgets.Layout(grid_template_columns="repeat(2, 500px)")
            ),
            ipywidgets.Label(''),
            PreProcess._apply_style(
                15,
                '⑶ - b: (Optional) Enter the number of periodicities in the '
                'time series data.'
            ),
            ipywidgets.VBox([
                self.has_seasons,
                self.num_of_nseasons,
            ]),
        ]),
        # Experimental_Design
        ipywidgets.VBox([
            PreProcess._apply_style(
                15,
                '⑶ - a: Enter the time period to be used for experimental '
                'design.'
            ),
            ipywidgets.GridBox(
                [
                    ipywidgets.HTML(
                        '<b>Option 1: Use the beginning and end of input data for'
                        ' Start & End</b>'),
                    ipywidgets.HTML(
                        '<b>Option 2: Manually enter the following</b>'),
                    self.depend_data,
                    self.start_date,
                    ipywidgets.Label(''),
                    self.end_date,
                ],
                layout=ipywidgets.Layout(grid_template_columns="repeat(2, 500px)")
            ),
            ipywidgets.Label(''),
            PreProcess._apply_style(
                15,
                '⑶ - b: Select the <b>experimental design method</b> and'
                ' enter the necessary items.'
            ),
            self.design_type,
            ipywidgets.Label(''),
            PreProcess._apply_style(
                15,
                '⑶ - c: (Optional) Enter <b>Estimated incremental CPA</b>(Cost'
                ' of intervention ÷ Lift from intervention without bias) & the '
                'number of periodicities in the time series data.'
            ),
            ipywidgets.HBox([
                self.estimate_icpa,
                ipywidgets.VBox([
                    self.has_seasons,
                    self.num_of_nseasons,
                ]),
            ]),
        ]),
    ]
    self.purpose_selection.set_title(0, 'Causalimpact')
    self.purpose_selection.set_title(1, 'Experimental_Design')

    display(
        PreProcess._apply_style(18, '⑴ Please select a data source.'),
        self.soure_selection,
        Markdown('<br>'),
        PreProcess._apply_style(
            18,
            '⑵ Please select wide or narrow data format.'),
        self.data_type_selection,
        Markdown('<br>'),
        PreProcess._apply_style(
            18,
            '⑶ Please select the purpose and set conditions.'),
        self.purpose_selection,
    )

  @staticmethod
  def _report_load_failure(error, details):
    """Print a formatted report for a failed data load.

    Args:
      error: The exception raised by the loader.
      details: Pre-formatted lines listing the settings to check.
    """
    PreProcess.failure_text('\n\nFailure!!')
    print('Error: {}'.format(error))
    print('Please check the following:')
    for line in details:
      print(line)
    PreProcess.failure_text('▲▲▲▲▲▲\n\n')

  def load_data(self):
    """Load data from the source selected in the source tab.

    Populates ``self.df_sheet`` and displays its first three rows.

    Raises:
      Exception: if no valid source is selected or loading fails. On a load
        failure the underlying error is printed and chained as the cause.
    """
    source = self.soure_selection.selected_index
    if source == 0:
      try:
        self._load_data_from_sheet(self.sheet_url.value, self.sheet_name.value)
      except Exception as e:
        PreProcess._report_load_failure(e, [
            '* There is something wrong with the spreadsheet-related settings.',
            '* sheet url:{}'.format(self.sheet_url.value),
            '* sheet name:{}'.format(self.sheet_name.value),
        ])
        raise Exception('Please check Failure') from e
    elif source == 1:
      try:
        self._load_data_from_csv(self.csv_name.value)
      except Exception as e:
        PreProcess._report_load_failure(e, [
            '* There is something wrong with the CSV-related settings.',
            '* CSV name:{}'.format(self.csv_name.value),
        ])
        raise Exception('Please check Failure') from e
    elif source == 2:
      try:
        self._load_data_from_bigquery(
            self.bq_project_id.value, self.bq_table_name.value
        )
      except Exception as e:
        PreProcess._report_load_failure(e, [
            '* There is something wrong with the bq-related settings.',
            '* bq project id:{}'.format(self.bq_project_id.value),
            '* bq table name :{}'.format(self.bq_table_name.value),
        ])
        raise Exception('Please check Failure') from e
    else:
      raise Exception('Please select a data source at Step.1-2.')

    PreProcess.success_text('Success! The target data has been loaded.')
    display(self.df_sheet.head(3))

  @staticmethod
  def _coerce_numeric(column):
    """Return ``column`` converted to numbers, or unchanged if it cannot be.

    Same result as ``pd.to_numeric(column, errors='ignore')``. That option is
    deprecated since pandas 2.2, so the fallback is done explicitly here.
    """
    try:
      return pd.to_numeric(column)
    except (ValueError, TypeError):
      return column

  @staticmethod
  def _clean_loaded_frame(frame):
    """Normalise a freshly loaded DataFrame (shared by all loaders).

    Removes ',' characters from string cells (typically thousands separators),
    strips spaces from column names, and converts every column that parses as
    numeric. Columns that do not parse are left unchanged.

    Args:
      frame: The raw DataFrame.

    Returns:
      A new, cleaned DataFrame.
    """
    frame = frame.replace(',', '', regex=True)
    frame = frame.rename(columns=lambda x: x.replace(" ", ""))
    return frame.apply(PreProcess._coerce_numeric)

  def _load_data_from_sheet(self, spreadsheet_url, sheet_name):
    """Read a Google Spreadsheet worksheet into ``self.df_sheet``.

    The first row of the worksheet is used as the header.

    Args:
      spreadsheet_url: Spreadsheet url with data.
      sheet_name: Sheet name with data.
    """
    auth.authenticate_user()
    creds, _ = default()
    gc = gspread.authorize(creds)
    self._workbook = gc.open_by_url(spreadsheet_url)
    self._worksheet = self._workbook.worksheet(sheet_name)
    raw = pd.DataFrame(self._worksheet.get_all_values())
    raw.columns = list(raw.loc[0, :])
    raw = raw.drop(0).reset_index(drop=True)
    self.df_sheet = PreProcess._clean_loaded_frame(raw)

  def _load_data_from_csv(self, csv_name):
    """Prompt for a CSV upload and read the named file into ``self.df_sheet``.

    Args:
      csv_name: Name of the uploaded CSV file to read.
    """
    uploaded = files.upload()
    self.df_sheet = PreProcess._clean_loaded_frame(
        pd.read_csv(io.BytesIO(uploaded[csv_name])))

  def _load_data_from_bigquery(self, bq_project_id, bq_table_name):
    """Read a whole BigQuery table into ``self.df_sheet``.

    Args:
      bq_project_id: Google Cloud project used for the query client.
      bq_table_name: Fully qualified table name, interpolated into the query.
    """
    auth.authenticate_user()
    client = bigquery.Client(project=bq_project_id)
    # NOTE(review): the table name is string-built into SQL; it comes from the
    # notebook user's own widget, but it is not validated or escaped.
    self.query = 'SELECT * FROM `' + bq_table_name + '`;'
    self.df_sheet = PreProcess._clean_loaded_frame(
        client.query(self.query).to_dataframe())

  @staticmethod
  def _parse_column_list(text):
    """Split a comma-separated list of column names.

    Whitespace around each name is trimmed and empty entries are dropped.
    """
    return [name.strip() for name in text.split(',') if name.strip()]

  def format_data(self):
    """Build the wide frame ``self.df_shaped`` from ``self.df_sheet``.

    Wide input (format tab 0) is copied as-is. Narrow input (tab 1) is
    pivoted with ``_shape_wide``. In both cases the columns listed in
    ``un_needed_cols`` are dropped, the date column is parsed to datetimes,
    and the trend charts are displayed. Nothing is returned.

    Raises:
      Exception: if formatting fails; the cause is printed and chained.
    """
    self.data_type = self.data_type_selection.selected_index
    self.date_col_name = self.date_col.value.replace(" ", "")
    self.pivot_col_name = self.pivot_col.value.replace(" ", "")
    self.kpi_col_name = self.kpi_col.value.replace(" ", "")

    try:
      unneeded = PreProcess._parse_column_list(self.un_needed_cols.value)
      if self.data_type == 0:
        self.df_shaped = self.df_sheet.copy()
        self.df_shaped.drop(unneeded, axis=1, errors='ignore', inplace=True)
        self.df_shaped[self.date_col_name] = pd.to_datetime(
            self.df_shaped[self.date_col_name])
        self._trend_check()
      elif self.data_type == 1:
        self.df_shaped = self._shape_wide(
            self.df_sheet,
            self.date_col_name,
            self.pivot_col_name,
            self.kpi_col_name,
        )
        PreProcess.success_text('\nSuccess! The data was formatted for analysis.')
        self.df_shaped.drop(unneeded, axis=1, errors='ignore', inplace=True)
        self.df_shaped[self.date_col_name] = pd.to_datetime(
            self.df_shaped[self.date_col_name])
        display(self.df_shaped.head(3))
        self._trend_check()
      else:
        raise ValueError('Invalid data type.')
    except Exception as e:
      PreProcess.failure_text('\n\nFailure!!')
      print('Error: {}'.format(e))
      print('Please check the following:')
      print('* The selected data format.')
      print('* The data must be wide. Check the format in the previous cell.')
      print('* The values of the date, pivot, and kpi column.\n\n')
      raise Exception('Please check Failure') from e

  def _shape_wide(self, dataframe, date_column, pivot_column, kpi_column):
    """Pivot long (narrow) data into wide data, one column per pivot key.

    Converts long data to wide data suitable for experiment design using
    fastDTW. KPI values sharing a date and pivot key are summed, and missing
    combinations are filled with 0.

    Args:
        dataframe: The DataFrame to be pivoted.
        date_column: The name of the column that contains the dates.
        pivot_column: The pivot column name, or several comma-separated names.
        kpi_column: The name of the column that contains the KPI values.

    Returns:
        A DataFrame with the date column followed by one column per pivot
        key. With several pivot columns, each key's parts are joined with '_'.
    """
    # Check if the pivot_column is a single column or a list of columns.
    if ',' in pivot_column:
      self.group_cols = pivot_column.replace(', ', ',').split(',')
    else:
      self.group_cols = [pivot_column]

    # Group the dataframe by the date and group columns, and sum the kpi column.
    dataframe = dataframe[[date_column] + [kpi_column] + self.group_cols]
    dataframe = dataframe.groupby([date_column] + self.group_cols).sum()

    # Pivot the dataframe, filling missing values with 0.
    dataframe = pd.pivot_table(
        dataframe, index=date_column, columns=self.group_cols, fill_value=0
    )

    # Drop the first level of the column names (the KPI name).
    dataframe.columns = dataframe.columns.droplevel(0)

    # If there are multiple levels, flatten each key tuple to one string.
    # str() is needed because pivot values may be non-strings (e.g. ints).
    if len(dataframe.columns.names) > 1:
      self.new_cols = [
          '_'.join(str(x).replace(',', '_') for x in y)
          for y in dataframe.columns.values
      ]
      dataframe.columns = self.new_cols

    return dataframe.reset_index()

  def _trend_check(self):
    """Display trend charts for ``self.df_shaped`` in a three-tab bar.

    The tabs are:
      * 'all': the summed KPI with a centred 7-row moving average.
      * 'each': every series min-max scaled to [0, 1], one row per series.
      * 'describe': summary statistics.
    """
    self.df_each = pd.DataFrame(index=self.df_shaped[self.date_col_name])
    self.df_each.index = pd.to_datetime(self.df_each.index)
    self.tick_count = len(self.df_each.resample('M')) - 1

    self.col_list = list(self.df_shaped.columns)
    self.col_list.remove(self.date_col_name)
    for column in self.col_list:
      series = self.df_shaped[column]
      # .to_numpy() assigns positionally: df_each is indexed by date, while
      # df_shaped has a RangeIndex, so a Series would misalign.
      self.df_each[column] = (
          (series - series.min()) / (series.max() - series.min())
      ).to_numpy()

    self.line_each = (
        alt.Chart(self.df_each.reset_index())
        .transform_fold(fold=self.col_list, as_=['pivot', 'kpi'])
        .mark_line()
        .encode(
            alt.X(
                self.date_col_name + ':T',
                title=None,
                axis=alt.Axis(
                    grid=False, format='%Y %b', tickCount=self.tick_count
                ),
            ),
            alt.Y('kpi:Q', stack=None, axis=None),
            alt.Color('pivot:N'),
            alt.Row(
                'pivot:N',
                title=None,
                header=alt.Header(labelAngle=0, labelAlign='left'),
            ),
        )
        .properties(bounds='flush', height=50)
        .configure_facet(spacing=0)
        .configure_view(stroke=None)
        .configure_title(anchor='end')
    )

    self.df_long = (
        pd.melt(self.df_shaped, id_vars=self.date_col_name)
        .groupby(self.date_col_name)
        .sum(numeric_only=True)
        .reset_index()
    )
    self.line_total = (
        alt.Chart(self.df_long)
        .mark_line()
        .encode(
            x=alt.X(
                self.date_col_name + ':T',
                axis=alt.Axis(
                    title='', format='%Y %b', tickCount=self.tick_count
                ),
            ),
            y=alt.Y('value:Q', axis=alt.Axis(title='kpi')),
            color=alt.value('#4285F4'),
        )
    )
    self.moving_average = (
        alt.Chart(self.df_long)
        .transform_window(
            rolling_mean='mean(value)',
            # 3 preceding + current + 3 following = 7 rows. The previous
            # [-4, 3] averaged 8 rows despite the "7days" title.
            frame=[-3, 3],
        )
        .mark_line()
        .encode(
            x=alt.X(self.date_col_name + ':T'),
            y=alt.Y('rolling_mean:Q'),
            color=alt.value('#DB4437'),
        )
    )

    self.tb_trend = widgets.TabBar(['all', 'each', 'describe'])
    with self.tb_trend.output_to('all'):
      display(
          (self.line_total + self.moving_average).properties(
              width=700,
              height=200,
              title={
                  'text': ['Daily Trend(blue) & 7days moving average(red)'],
              },
          )
      )
    with self.tb_trend.output_to('each'):
      display((self.line_each).properties(width=700))
    with self.tb_trend.output_to('describe'):
      display(self.df_shaped.describe(include='all'))
    # NOTE(review): presumably re-selects the 'all' tab so it is shown first
    # after rendering -- confirm against google.colab.widgets.TabBar behavior.
    with self.tb_trend.output_to('all'):
      pass


class CausalImpact(PreProcess):
  """CausalImpact analysis and experimental design on CausalImpact.

  CausalImpact analysis fits a CausalImpact model on the loaded data for the
  pre/post periods chosen in the widgets and displays the result.

  Experimental design builds candidate groupings of the data columns with
  one of four methods, chosen by `design_type` (in tab order): an N-part
  split, covariates similar to given target columns, a target picked by
  share, or an assignment typed in by the user. Candidates are drawn by
  random sampling (`num_of_iteration` draws) and the three with the smallest
  summed DTW distance are kept. For the candidate the user then selects,
  every combination of simulated lift and post-period length is fitted, to
  show which combinations would produce a significant difference.

  Attributes:
    colors: Colour palette for the per-series line charts.
    num_of_iteration: Number of random draws per experimental design.
    combination_target: Number of target column sets sampled by the
      share-based design.
    ci_objs: CausalImpact results of the last analysis or simulation.
    df_dtw: Design candidates: a 'distance' column plus one column per group
      holding str() of that group's list of column names.
  """

  def __init__(self):
    super().__init__()
    self.colors = [
      '#DB4437',
      '#AB47BC',
      '#4285F4',
      '#00ACC1',
      '#0F9D58',
      '#9E9D24',
      '#F4B400',
      '#FF7043',
    ]
    self.num_of_iteration = 1000
    self.combination_target = 10

  def run_causalImpact(self):
    """Runs CausalImpact on the pre/post periods chosen in the widgets.

    On success the fitted object is stored in `self.ci_obj` and as the only
    element of `self.ci_objs`, which `display_causalimpact_result` reads.

    Raises:
      Exception: If the model cannot be fitted (the original error is
        chained as the cause).
    """
    self.ci_objs = []
    try:
      self.ci_obj = self.create_causalimpact_object(
        self.df_shaped,
        self.date_col_name,
        self.pre_period_start.value,
        self.pre_period_end.value,
        self.post_period_start.value,
        self.post_period_end.value,
      )
    except Exception as e:
      PreProcess.failure_text('\n\nFailure!!')
      print('Error: {}'.format(e))
      print('Please check the following:')
      print('* Data source.')
      print('* Date Column Name.')
      print('* Duration of experiment (pre and post).')
      PreProcess.failure_text('▲▲▲▲▲▲\n\n')
      raise Exception('Please check Failure') from e
    self.ci_objs.append(self.ci_obj)
    PreProcess.success_text(
      '\nSuccess! CausalImpact has been performed. Check the results in the next cell.'
    )

  def create_causalimpact_object(
      self, data, date_col, pre_start, pre_end, post_start, post_end):
    """Fits a CausalImpact model on `data` for the given periods.

    Args:
      data: DataFrame containing the `date_col` column plus the series to
        model (callers put the test series first — presumably the target
        expected by tfp-causalimpact; confirm).
      date_col: Name of the date column; it becomes the index.
      pre_start, pre_end: Pre-period bounds (converted with str()).
      post_start, post_end: Post-period bounds (converted with str()).

    Returns:
      The fitted object, also stored in `self.causalimpact_object`. A
      seasonal component with `num_of_nseasons` seasons is added when the
      `has_seasons` widget is checked.
    """
    fit_kwargs = dict(
        data=data.set_index(date_col),
        pre_period=(str(pre_start), str(pre_end)),
        post_period=(str(post_start), str(post_end)),
    )
    if self.has_seasons.value:
      fit_kwargs['model_options'] = causalimpact.ModelOptions(
          seasons=[causalimpact.Seasons(num_seasons=self.num_of_nseasons.value),])
    self.causalimpact_object = causalimpact.fit_causalimpact(**fit_kwargs)
    return self.causalimpact_object

  def display_causalimpact_result(self):
    """Displays the input series and the CausalImpact result.

    Expects `run_causalImpact` to have populated `self.ci_objs`.
    """
    self.col_list = list(self.df_shaped.columns)
    self.col_list.remove(self.date_col_name)

    print('Test & Control Time Series')
    display(alt.Chart(self.df_shaped).transform_fold(
        self.col_list
    ).mark_line().encode(
        alt.X(
            self.date_col_name + ':T',
            title=None,
            axis=alt.Axis(
                format='%Y %b', tickCount=self.tick_count
            ),
        ),
        y=alt.Y('value:Q', axis=alt.Axis(title='kpi')),
        color=alt.Color('key:N', legend=alt.Legend(
            title=None,
            orient='none',
            legendY=-20,
            direction='horizontal',
            titleAnchor='start'),
            scale=alt.Scale(domain=self.col_list, range=self.colors)),
    ).properties(height=200, width=600))
    print("="*100)
    print('\n')

    self.plot_causalimpact(
        self.ci_objs[0],
        self.pre_period_start.value,
        self.pre_period_end.value,
        self.post_period_start.value,
        self.post_period_end.value,
    )

  def plot_causalimpact(self, causalimpact_object, pre_start, pre_end, tread_start, treat_end):
    """Prints the pre-period MAPE and summary, then plots the result.

    The three stacked panels show observed vs. posterior mean, the
    point-wise effect and the cumulative effect, each with its posterior
    interval band. Dashed rules mark `tread_start` (red) and `treat_end`
    (orange).

    Args:
      causalimpact_object: Result of `create_causalimpact_object`.
      pre_start, pre_end: Pre-period bounds, used for the MAPE.
      tread_start: First day of the treatment period (the name is a typo of
        "treat_start", kept so keyword callers keep working).
      treat_end: Last day of the treatment period.
    """
    self.causalimpact_df = causalimpact_object.series.copy()
    pre_period = slice(str(pre_start), str(pre_end))
    self.mape = mean_absolute_percentage_error(
        self.causalimpact_df['observed'][pre_period],
        self.causalimpact_df['posterior_mean'][pre_period])
    print('Approximate model accuracy >> MAPE:{:.2%}\n'.format(self.mape))
    print(causalimpact.summary(causalimpact_object, output_format='summary'))

    chart_df = self.causalimpact_df.reset_index()
    x_field = 'yearmonthdate(' + self.date_col_name + ')'

    self.line_1 = alt.Chart(chart_df).transform_fold(
        ['observed', 'posterior_mean',]
      ).mark_line().encode(
        x=alt.X(x_field,
                axis=alt.Axis(
                    title='',
                    labels=False,
                    ticks=False,
                    format='%Y %b',
                    tickCount=self.tick_count
                    )),
        y=alt.Y('value:Q', scale=alt.Scale(zero=False), axis=alt.Axis(title=''),),
        color=alt.Color('key:N', legend=alt.Legend(
            title=None,
            orient='none',
            legendY=-20,
            direction='horizontal',
            titleAnchor='start'),
            sort=['posterior_mean', 'observed']),
        strokeDash=alt.condition(
            alt.datum.key == 'posterior_mean',
            alt.value([5, 5]),
            alt.value([0]),))
    self.area_1 = alt.Chart(chart_df).mark_area(opacity=0.3).encode(
        x=alt.X(x_field),
        y=alt.Y("posterior_lower:Q", scale=alt.Scale(zero=False)),
        y2=alt.Y2("posterior_upper:Q"))
    self.line_2 = alt.Chart(chart_df).mark_line(strokeDash=[5, 5]).encode(
        x=alt.X(x_field,
                axis=alt.Axis(
                    title='',
                    labels=False,
                    ticks=False,
                    format='%Y %b',
                    tickCount=self.tick_count
                    )),
        y=alt.Y('point_effects_mean:Q', scale=alt.Scale(zero=False), axis=alt.Axis(title='')),)
    self.area_2 = alt.Chart(chart_df).mark_area(opacity=0.3).encode(
        x=alt.X(x_field),
        y=alt.Y("point_effects_lower:Q", scale=alt.Scale(zero=False)),
        y2=alt.Y2("point_effects_upper:Q"),)
    self.line_3 = alt.Chart(chart_df).mark_line(strokeDash=[5, 5]).encode(
        x=alt.X(x_field,
                axis=alt.Axis(
                    title='',
                    format='%Y %b',
                    tickCount=self.tick_count
                    )),
        y=alt.Y('cumulative_effects_mean:Q', scale=alt.Scale(zero=False), axis=alt.Axis(title='')),)
    self.area_3 = alt.Chart(chart_df).mark_area(opacity=0.3).encode(
        x=alt.X(x_field),
        y=alt.Y("cumulative_effects_lower:Q", scale=alt.Scale(zero=False)),
        y2=alt.Y2("cumulative_effects_upper:Q"),)
    self.zero_line = alt.Chart(pd.DataFrame({'y': [0]})).mark_rule().encode(y='y', color=alt.value("gray"))
    self.rules = alt.Chart(
        pd.DataFrame({'Date': [str(tread_start), str(treat_end)], 'color': ['red', 'orange']})
        ).mark_rule(strokeDash=[5, 5]).encode(
            x='Date:T',
            color=alt.Color('color:N', scale=None))
    self.plot = alt.vconcat(
        (self.line_1 + self.area_1 + self.rules).properties(height=100, width=600),
        (self.line_2 + self.area_2 + self.rules + self.zero_line).properties(height=100, width=600),
        (self.line_3 + self.area_3 + self.rules + self.zero_line).properties(height=100, width=600)
    )
    display(self.plot)

  def run_experimental_design(self):
    """Builds experimental-design candidates and displays them.

    Restricts the data to the chosen design period (or uses the whole data
    range when `depend_data` is checked), runs the design method selected
    in `design_type`, then shows the candidates via
    `reconstitute_dataframe`.

    Raises:
      Exception: If `design_type` has no matching design method.
    """
    self.df_design = self.df_shaped.copy().set_index(self.date_col_name)
    self.df_design.index = pd.to_datetime(self.df_design.index)
    if self.depend_data.value:
      self.start_date_value = min(self.df_design.index).date()
      self.end_date_value = max(self.df_design.index).date()
    else:
      self.start_date_value = self.start_date.value
      self.end_date_value = self.end_date.value
      self.df_design = self.df_design.query(
          '@self.start_date_value <= index <= @self.end_date_value'
      )

    # design_type tab index -> design method (A, B, C, D).
    designs = {
        0: self.n_part_split,
        1: self.find_similar,
        2: self.from_share,
        3: self.given_assignment,
    }
    design = designs.get(self.design_type.selected_index)
    if design is None:
      PreProcess.failure_text('\n\nFailure!!')
      print('Please check the following:')
      print('* There is something wrong with design type.')
      print('* Please select A, B, C or D.')
      PreProcess.failure_text('▲▲▲▲▲▲\n\n')
      raise Exception('Please check Failure')
    design()

    self.reconstitute_dataframe()

  def n_part_split(self):
    """Design A: randomly splits the columns into `num_of_split` groups.

    Each of `num_of_iteration` draws assigns len(columns) // num_of_split
    randomly chosen columns to every group, sums each group into one series
    and scores the grouping with `_calculate_distance`. The three best
    groupings are kept in `self.df_dtw`.
    """
    self.df_dtw = pd.DataFrame(columns=['distance'])
    self.num_of_pick = len(self.df_design.columns) // self.num_of_split.value

    for trial in tqdm(range(self.num_of_iteration)):
      self.col_list = list(self.df_design.columns)

      # Draw each group from the columns not yet assigned.
      self.picks = []
      for _ in range(self.num_of_split.value):
        self.pick = random.sample(self.col_list, self.num_of_pick)
        self.picks.append(self.pick)
        self.col_list = [ele for ele in self.col_list if ele not in self.pick]
      # NOTE(review): when the columns do not divide evenly, the leftovers
      # (still in self.col_list) are not assigned to any group — confirm
      # this is intended.

      self._build_picked_frame()
      self._record_candidate(trial, [str(sorted(group)) for group in self.picks])

    self._keep_best_candidates()

  def find_similar(self):
    """Design B: finds covariate groups that track the given target columns.

    Each of `num_of_iteration` draws builds `num_of_covariate` disjoint
    groups of non-target columns (each between `num_of_pick_range` columns),
    puts the target columns in front as group 0 and scores the candidate
    with `_calculate_distance`. The three best are kept in `self.df_dtw`.

    Raises:
      Exception: If there are fewer non-target columns than
        max pick × number of covariates.
    """
    self.df_dtw = pd.DataFrame(columns=['distance'])
    self.target_cols = self._parse_column_list(self.target_columns.value)
    self.pick_range = self.num_of_pick_range.value

    num_candidates = len(self.df_design.columns) - len(self.target_cols)
    if num_candidates < self.pick_range[1] * self.num_of_covariate.value:
      print('Please check the following:')
      print('* There is something wrong with similarity settings.')
      print('* Total number of columns ー the target = {}'.format(
          num_candidates))
      print('* But your settings are {}(max pick#) × {}(covariate#)'.format(
          self.pick_range[1], self.num_of_covariate.value))
      print('* Please set it so that it does not exceed.')
      PreProcess.failure_text('▲▲▲▲▲▲\n\n')
      raise Exception('Please check Failure')

    for trial in tqdm(range(self.num_of_iteration)):
      self.remained_list = [
          col for col in self.df_design.columns if col not in self.target_cols
      ]

      self.picks = []
      for _ in range(self.num_of_covariate.value):
        self.num_of_pick = random.randrange(
            self.pick_range[0], self.pick_range[1] + 1, 1
        )
        self.pick = random.sample(self.remained_list, self.num_of_pick)
        self.picks.append(self.pick)
        self.remained_list = [
            ele for ele in self.remained_list if ele not in self.pick
        ]
      self.picks.insert(0, self.target_cols)

      self._build_picked_frame()
      self._record_candidate(trial, [str(list(group)) for group in self.picks])

    self._keep_best_candidates()

  def from_share(self):
    """Design C: uses target groups whose KPI share is close to `target_share`.

    First samples `combination_target` column sets whose summed KPI lies
    within 10% of `target_share` × the grand total. Then, for each set,
    draws covariate groups from the remaining columns
    (`num_of_iteration // combination_target` times) and scores every
    candidate with `_calculate_distance`. The three best candidates are
    kept in `self.df_dtw`.

    Raises:
      Exception: If not enough matching sets are found within ~10,000 draws.
    """
    self.df_dtw = pd.DataFrame(columns=['distance'])
    self.pick_range = self.num_of_pick_range.value
    # NOTE(review): the share is computed over all of df_shaped, not only
    # the design period in df_design — confirm this is intended.
    self.df_sum = pd.DataFrame(
        self.df_shaped
        .drop(self.date_col_name, axis=1)
        .sum(axis=0)).T
    self.share = self.target_share.value
    self.target = self.df_sum.sum(axis=1).loc[0] * self.share
    self.combinations = []

    n = 10000
    while len(self.combinations) < self.combination_target:
      n -= 1
      self.num_of_pick = random.randint(1, len(self.df_sum.columns) // 2 + 1)
      self.picked_col = np.random.choice(
          self.df_sum.columns, self.num_of_pick, replace=False)
      self.sum_of_picked_numbers = self.df_sum[self.picked_col].sum(axis=1).loc[0]

      if abs(self.sum_of_picked_numbers - self.target) < self.target * 0.1:
        self.combination = set(self.picked_col)
        self.combinations.append(sorted(self.combination))
      if n == 1:
        PreProcess.failure_text('\n\nFailure!!')
        print('Please check the following:')
        print('* There is something wrong with design type C.')
        print('* Please re-set target share')
        PreProcess.failure_text('▲▲▲▲▲▲\n\n')
        raise Exception('Please check Failure')

    self._id = 0
    draws_per_combination = self.num_of_iteration // self.combination_target
    for comb in tqdm(self.combinations):
      for _ in tqdm(range(draws_per_combination)):
        self.remained_list = [
            col for col in self.df_design.columns if col not in comb
        ]

        self.picks = []
        for _ in range(self.num_of_covariate.value):
          per_covariate = (
              len(self.remained_list) // self.num_of_covariate.value)
          # randrange excludes its upper bound and round(0.5) == 0, so with
          # one column per covariate left the only possible draw would be an
          # empty group (whose scaled series is all NaN). Always pick at
          # least one column; for per_covariate >= 2 the range is unchanged.
          low = max(1, round(per_covariate * 0.5))
          high = max(per_covariate, low + 1)
          self.num_of_pick = random.randrange(low, high, 1)
          self.pick = random.sample(self.remained_list, self.num_of_pick)
          self.picks.append(self.pick)
          self.remained_list = [
              ele for ele in self.remained_list if ele not in self.pick
          ]
        self.picks.insert(0, comb)

        self._build_picked_frame()
        self._record_candidate(
            self._id, [str(list(group)) for group in self.picks])
        self._id += 1

    self._keep_best_candidates()

  def given_assignment(self):
    """Design D: uses the test/control columns typed in by the user as-is."""
    self.df_dtw = pd.DataFrame()
    self.df_dtw.loc[0, 'distance'] = 0
    self.df_dtw.loc[0, 0] = str(self._parse_column_list(self.target_columns.value))
    self.df_dtw.loc[0, 1] = str(self._parse_column_list(self.control_columns.value))

  @staticmethod
  def _parse_column_list(text):
    """Splits comma-separated column names, trimming surrounding spaces.

    Empty entries (e.g. from a trailing comma) are dropped.
    """
    return [name.strip() for name in text.split(',') if name.strip()]

  def _build_picked_frame(self):
    """Sums each group in `self.picks` into column i of `self.df_picked`."""
    self.df_picked = pd.DataFrame()
    for idx, group in enumerate(self.picks):
      self.picked = pd.DataFrame(
          self.df_design[group].sum(axis=1), columns=[idx]
      )
      self.df_picked = pd.concat([self.df_picked, self.picked], axis=1)
    return self.df_picked

  def _record_candidate(self, row, labels):
    """Scores `self.df_picked` and stores it with its group labels at `row`."""
    self.distance = self._calculate_distance(
        self.df_picked.reset_index(drop=True)
    )
    self.df_dtw.loc[row, 'distance'] = float(self.distance)
    for j, label in enumerate(labels):
      self.df_dtw.at[row, j] = label

  def _keep_best_candidates(self, top_n=3):
    """Keeps the `top_n` distinct candidates with the smallest distance."""
    self.df_dtw = (
        self.df_dtw.drop_duplicates()
        .sort_values('distance')
        .head(top_n)
        .reset_index(drop=True)
    )

  def _calculate_distance(self, dataframe):
    """Returns the summed pairwise DTW distance between the columns.

    Every column is min-max scaled and first-differenced; each pair of
    columns is then compared with fastdtw on (row position, value) points
    using Euclidean distance. The input frame is not modified.

    Args:
      dataframe: One column per group; callers pass a RangeIndex frame.

    Returns:
      The summed distance (also stored in `self.dist`).
    """
    value_cols = list(dataframe.columns)
    scaled = (dataframe - dataframe.min()) / (dataframe.max() - dataframe.min())
    # reset_index adds the row position as an 'index' column, used as the
    # first coordinate of every DTW point. dropna removes the first row (no
    # diff). NOTE: a constant column scales to NaN and empties the frame.
    diffs = scaled.diff().reset_index().dropna()
    self.dist = 0
    # Only compare the value columns with each other; the helper 'index'
    # column is a coordinate, not a series to be compared.
    for col_a, col_b in itertools.combinations(value_cols, 2):
      self.distance, self.path = fastdtw.fastdtw(
          diffs.loc[:, ['index', col_a]],
          diffs.loc[:, ['index', col_b]],
          dist=euclidean,
      )
      self.dist = self.dist + self.distance
    return self.dist

  def reconstitute_dataframe(self):
    """Displays every design candidate in a tab so the user can pick one.

    For each row of `self.df_dtw` the tab shows the group members, each
    group's daily average, share and the std of its STL seasonal + residual
    component, raw and min-max scaled line charts, and a scatter matrix of
    day-over-day differences. Then the option / test / control selectors
    are displayed.
    """
    # Both endpoints belong to the design period, hence the +1.
    self.design_days = (self.end_date_value - self.start_date_value).days + 1
    self.pre_total = self.df_design.sum().sum()
    self.pre_daily = self.pre_total // self.design_days

    self.candidate_tb = widgets.TabBar(
        ['option_' + str(sub + 1) for sub in self.df_dtw.index.tolist()]
    )
    for i in range(len(self.df_dtw)):
      with self.candidate_tb.output_to(i):
        self.candidate_df = pd.DataFrame(
            index=self.df_shaped[self.date_col_name]
        )
        for col in range(len(self.df_dtw.columns) - 1):
          print('col_' + str(col + 1) + self.df_dtw.at[i, col])
          # NOTE(review): eval() parses the str(list) text stored in df_dtw
          # (user-typed names for design D); ast.literal_eval would be a
          # safer parser for these literals.
          self.candidate_df[col + 1] = list(
              self.df_shaped.loc[:, eval(self.df_dtw.at[i, col])].sum(axis=1)
          )
        print('\n')
        self.candidate_df = self.candidate_df.add_prefix('col_')

        self.candidate_share = pd.DataFrame(
            self.candidate_df.loc[str(self.start_date_value):str(self.end_date_value), :].sum(),
            columns=['total'])
        self.candidate_share['daily'] = self.candidate_share['total'] // self.design_days
        self.candidate_share['share'] = self.candidate_share['total'] / self.pre_total
        for series_name in self.candidate_df.columns:
          self.stl = STL(self.candidate_df[series_name], robust=True).fit()
          self.candidate_share.loc[series_name, 'std'] = np.std(
              self.stl.seasonal + self.stl.resid)

        display(self.candidate_share[['daily', 'share', 'std']].style.format({
            'daily': '{:,.0f}',
            'share': '{:.1%}',
            'std': '{:,.0f}',
        }))

        self.chart_line = (
            alt.Chart(self.candidate_df.reset_index())
            .transform_fold(
                fold=list(self.candidate_df.columns), as_=['pivot', 'kpi']
            )
            .mark_line()
            .encode(
                x=alt.X(
                    self.date_col_name + ':T',
                    title=None,
                    axis=alt.Axis(
                    grid=False, format='%Y %b', tickCount=self.tick_count
                    ),
                ),
                y=alt.Y('kpi:Q'),
                color=alt.Color(
                  'pivot:N',
                  legend=alt.Legend(
                    title=None,
                    orient='none',
                    legendY=-20,
                    direction='horizontal',
                    titleAnchor='start'),
                  scale=alt.Scale(domain=list(self.candidate_df.columns), range=self.colors)),
                )
            .properties(width=600, height=200)
        )

        self.rules = alt.Chart(
            pd.DataFrame({'Date': [str(self.start_date_value), str(self.end_date_value)], 'color': ['red', 'orange']})
            ).mark_rule(strokeDash=[5, 5]).encode(
                x='Date:T',
                color=alt.Color('color:N', scale=None))

        self.df_scaled = self.candidate_df.copy()
        self.df_scaled[:] = MinMaxScaler().fit_transform(self.df_scaled)
        self.chart_line_scaled = (
            alt.Chart(self.df_scaled.reset_index())
            .transform_fold(
                fold=list(self.df_scaled.columns),
                as_=['pivot', 'kpi']
            )
            .mark_line()
            .encode(
                x=alt.X(
                    self.date_col_name + ':T',
                    title=None,
                    axis=alt.Axis(
                    grid=False, format='%Y %b', tickCount=self.tick_count
                    ),
                ),
                y=alt.Y('kpi:Q'),
                color=alt.Color(
                  'pivot:N',
                  legend=alt.Legend(
                    title=None,
                    orient='none',
                    legendY=-20,
                    direction='horizontal',
                    titleAnchor='start'),
                  scale=alt.Scale(domain=list(self.df_scaled.columns), range=self.colors)),
                )
            .properties(width=600, height=80)
        )

        self.df_diff = pd.DataFrame(
            np.diff(self.candidate_df, axis=0),
            columns=self.candidate_df.columns.values,
        )
        self.scatter = (
            alt.Chart(self.df_diff.reset_index())
            .mark_circle()
            .encode(
                alt.X(alt.repeat('column'), type='quantitative'),
                alt.Y(alt.repeat('row'), type='quantitative'),
            )
            .properties(width=80, height=80)
            .repeat(
                row=self.df_diff.columns.values,
                column=self.df_diff.columns.values,
            )
        )
        display(
            alt.vconcat(self.chart_line + self.rules, self.chart_line_scaled) | self.scatter)

    # Return focus to the first tab once every tab has been filled.
    with self.candidate_tb.output_to(0):
      pass

    display(
        PreProcess._apply_style(
            18,
            'Please select option, test column & control column(s).'),
        ipywidgets.HBox([
            self.your_choice,
            self.test_column,
            self.control_column,
        ]),
    )

  def generate_simulation(self):
    """Simulates lifts on the chosen candidate and reports detectability.

    Builds a 'test' + 'control_N' frame from the option / test / control
    selections. For every (impact, duration) combination it multiplies the
    last `duration` days of 'test' by `impact`, fits CausalImpact with the
    preceding days as the pre-period, then shows a summary table (MAPE,
    effects, required budget, p-value) and one plot tab per combination.
    """
    # Rebuild the selected candidate as test + control series.
    self.selection_row = int(self.your_choice.value.replace('option_', '')) - 1
    self.selection_cols = [
        int(self.test_column.value.replace('col_', '')) - 1
    ] + [
        int(s.replace('col_', '')) - 1 for s in list(self.control_column.value)
    ]
    self.colnames = ['test'] + [
        'control_' + str(i + 1) for i in range(len(self.selection_cols) - 1)
    ]
    self.df = pd.DataFrame(index=self.df_shaped[self.date_col_name])
    self.df.index = pd.to_datetime(self.df.index)
    for col in self.selection_cols:
      # NOTE(review): eval() parses the str(list) text stored in df_dtw.
      self.df[col] = list(
          self.df_shaped.loc[
              :, eval(self.df_dtw.at[self.selection_row, col])
          ].sum(axis=1)
      )
    self.df.columns = self.colnames
    for x, i in zip(self.df.columns, self.selection_cols):
      print('{}: {}'.format(x, self.df_dtw.at[self.selection_row, i]))

    # Simulation: one CausalImpact fit per (impact, duration) combination.
    self.ci_objs = []
    self.simulate_periods = []
    self.treat_duration = [7, 14, 28]
    self.treat_impact = [1, 1.01, 1.03, 1.05, 1.10, 1.2]
    self.simulation_combination = list(
        itertools.product(self.treat_impact, self.treat_duration)
    )
    for impact, duration in tqdm(self.simulation_combination):
      # Start every scenario from the unmodified series; reusing one frame
      # leaked the previous (longer) scenario's lift into this pre-period.
      # float dtype lets the scaled values be written back without upcasting.
      self.adjusted_df = self.df.astype(float)
      self.pre_end_date = self.end_date_value - datetime.timedelta(days=duration)
      self.post_start_date = self.pre_end_date + datetime.timedelta(days=1)
      post_window = slice(
          np.datetime64(self.post_start_date), np.datetime64(self.end_date_value))
      self.adjusted_df.loc[post_window, 'test'] = (
          self.df.loc[post_window, 'test'] * impact
      )
      self.ci_obj = self.create_causalimpact_object(
          self.adjusted_df.reset_index(),
          self.date_col_name,
          self.start_date_value,
          self.pre_end_date,
          self.post_start_date,
          self.end_date_value,
      )
      self.simulate_periods.append([
          self.start_date_value,
          self.pre_end_date,
          self.post_start_date,
          self.end_date_value,
      ])
      self.ci_objs.append(self.ci_obj)

    rows = []
    for i, ci_obj in enumerate(self.ci_objs):
      self.periods = self.simulate_periods[i]
      pre_period = slice(str(self.periods[0]), str(self.periods[1]))
      self.impact_df = ci_obj.series.copy()
      cumulative_effect = ci_obj.summary.loc['cumulative', 'abs_effect']
      self.impact_dict = {
          'Simulated_impact': self.simulation_combination[i][0] - 1,
          'Days_simulated': self.simulation_combination[i][1],
          'MAPE': mean_absolute_percentage_error(
              self.impact_df.loc[:, 'observed'][pre_period],
              self.impact_df.loc[:, 'posterior_mean'][pre_period],
          ),
          'Total_effect': cumulative_effect,
          'Average_effect': ci_obj.summary.loc['average', 'abs_effect'],
          'Required_budget': cumulative_effect * self.estimate_icpa.value,
          'p_value': ci_obj.summary.loc['average', 'p_value'],
      }
      rows.append(self.impact_dict)
    self.simulation_df = pd.DataFrame(
        rows,
        columns=[
            'Simulated_impact',
            'Days_simulated',
            'MAPE',
            'Total_effect',
            'Average_effect',
            'Required_budget',
            'p_value',
        ],
    )

    display(
        self.simulation_df.style.format({
            'Days_simulated': '{:.0f} d',
            'Simulated_impact': '{:+.0%}',
            'MAPE': '{:.2%}',
            'Total_effect': '{:,.2f}',
            'Average_effect': '{:,.2f}',
            'Required_budget': '{:,.0f}',
            'p_value': '{:,.2f}',
        })
    )

    self.simulation_tb = widgets.TabBar(self.simulation_combination)
    for i, periods in enumerate(self.simulate_periods):
      with self.simulation_tb.output_to(i):
        print('Pre Period:{} ~ {}\nPost Period:{} ~ {}'.format(*periods))
        self.plot_causalimpact(self.ci_objs[i], *periods)
    # Return focus to the first tab once every tab has been filled.
    with self.simulation_tb.output_to(0):
      pass

# Re-running this cell redefines PreProcess and registers a new resize hook.
# Unregister hooks left by earlier runs first; otherwise one more copy of
# the hook runs before every cell each time this cell is re-run.
_events = get_ipython().events
for _callback in list(_events.callbacks.get('pre_run_cell', [])):
  if getattr(_callback, '__name__', None) == 'resize_colab_cell':
    _events.unregister('pre_run_cell', _callback)
_events.register("pre_run_cell", PreProcess.resize_colab_cell)
case_1 = CausalImpact()
case_1.generate_purpose_section()
# dict_params exists only after Step.2 has been run once in this session.
if 'dict_params' in globals():
  PreProcess.set_params(case_1, dict_params)

In [None]:
# @title Step.2
%%time
case_1.load_data()
case_1.format_data()
dict_params = PreProcess.saving_params(case_1)

if case_1.purpose_selection.selected_index == 0:
  case_1.run_causalImpact()
else:
  case_1.run_experimental_design()

In [None]:
# @title Step.3
%%time
if case_1.purpose_selection.selected_index == 0:
  case_1.display_causalimpact_result()
else:
  case_1.generate_simulation()

# ================ 【Option】 ================

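The following cells are optional.

If you want to run another design or analysis under different conditions, you can carry over the parameters you set for the first analysis (saved when its Step.2 was run) by checking `overwrite_pramas` in Step.1.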

In [None]:
# @title Step.1
overwrite_pramas = True #@param {type:"boolean"}
case_2 = CausalImpact()
case_2.generate_purpose_section()
# dict_params is saved by Step.2 of the first analysis; reuse it only if it
# exists, instead of failing with a NameError.
if overwrite_pramas:
  if 'dict_params' in globals():
    PreProcess.set_params(case_2, dict_params)
  else:
    print('No saved parameters found: run Step.2 of the first analysis first.')


In [None]:
# @title Step.2
%%time
case_2.load_data()
case_2.format_data()

if case_2.purpose_selection.selected_index == 0:
  case_2.run_causalImpact()
else:
  case_2.run_experimental_design()

In [None]:
# @title Step.3
%%time
if case_2.purpose_selection.selected_index == 0:
  case_2.display_causalimpact_result()
else:
  case_2.generate_simulation()

##### One more analysis

In [None]:
# @title Step.1
overwrite_pramas = False #@param {type:"boolean"}
case_3 = CausalImpact()
case_3.generate_purpose_section()
# dict_params is saved by Step.2 of the first analysis; reuse it only if it
# exists, instead of failing with a NameError.
if overwrite_pramas:
  if 'dict_params' in globals():
    PreProcess.set_params(case_3, dict_params)
  else:
    print('No saved parameters found: run Step.2 of the first analysis first.')

In [None]:
# @title Step.2
%%time
case_3.load_data()
case_3.format_data()

if case_3.purpose_selection.selected_index == 0:
  case_3.run_causalImpact()
else:
  case_3.run_experimental_design()

In [None]:
# @title Step.3
%%time
if case_3.purpose_selection.selected_index == 0:
  case_3.display_causalimpact_result()
else:
  case_3.generate_simulation()