<a href="https://colab.research.google.com/github/google/business_intelligence_group/blob/development/solutions/causal-impact/CausalImpact_with_Experimental_Design.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **CausalImpact with Experimental Design**

This Colab file contains *Experimental Design* and *CausalImpact Analysis*.

See [README.md](https://github.com/google/business_intelligence_group/tree/main/solutions/causal-impact) for details.

---

Copyright 2024 Google LLC

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

In [None]:
# @title Step.1 (~ 2min)
%%time

import sys
if 'fastdtw' not in sys.modules:
  !pip install 'fastdtw' --q
if 'tslearn' not in sys.modules:
  !pip install 'tslearn' --q
if 'tfp-causalimpact' not in sys.modules:
  !pip install 'tfp-causalimpact' --q

# Data Load
from google.colab import auth, files, widgets
from google.auth import default
from google.cloud import bigquery
import io
import os
import gspread
from oauth2client.client import GoogleCredentials

# Calculate
import altair as alt
import itertools
import random
import numpy as np
import pandas as pd
import fastdtw

from tslearn.clustering import TimeSeriesKMeans
from decimal import Decimal, ROUND_HALF_UP
from scipy.spatial.distance import euclidean
from sklearn.metrics import mean_absolute_percentage_error
from sklearn.preprocessing import MinMaxScaler
from statsmodels.tsa.seasonal import STL

# UI/UX
import datetime
from dateutil.relativedelta import relativedelta
import ipywidgets
from IPython.display import display, Markdown, HTML, Javascript
from tqdm.auto import tqdm
import warnings
warnings.simplefilter('ignore')

# causalimpact
import causalimpact
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

class PreProcess(object):
  """Builds the input UI, loads time series data and prepares it for analysis.

  The UI (ipywidgets) collects the data source, the data layout (wide or
  narrow) and the analysis settings. The loaded data is reshaped into a wide,
  daily-indexed DataFrame, and charts are shown so the user can check the
  data before running the analysis.

  Methods:
    _styled_heading_html: Build the HTML for an underlined heading.
    _apply_text_style: Print a colored status message or build a heading.
    define_ui: Create every input widget.
    generate_ui: Arrange the widgets into tabs and display them.
    load_data: Load data from the selected data source.
    _load_data_from_sheet: Load data from a Google Spreadsheet.
    _load_data_from_csv: Load data from an uploaded CSV file.
    _load_data_from_bigquery: Load data from a BigQuery table.
    _to_numeric_if_possible: Convert one column to numeric when possible.
    _clean_loaded_data: Cleanup shared by all loaders.
    format_data: Convert loaded data into a wide, date-indexed DataFrame.
    _shape_wide: Pivot narrow (long) data into wide data.
    _trend_check: Visualize total trend, per-column trends and statistics.
    saving_params: Collect the current widget values into a dict.
    set_params: Restore widget values from such a dict.
  """

  def __init__(self):
    """Creates all input widgets (see define_ui)."""
    self.define_ui()

  @staticmethod
  def _styled_heading_html(font_size, text):
    """Returns HTML for a heading of `font_size` px with a blue underline.

    Args:
      font_size: Font size in pixels.
      text: Heading text; may contain inline HTML such as <b>.

    Returns:
      An HTML string wrapping `text` in a styled <span> element.
    """
    return (
        "<span style='font-size:" + str(font_size) + "px; background: "
        "linear-gradient(transparent 90%, #4285F4 0%);'>"
        + text
        + '</span>'
    )

  @staticmethod
  def _apply_text_style(type, text):
    """Prints a colored status message or builds a styled heading widget.

    Args:
      type: 'success' (green) or 'failure' (red) prints `text` to stdout with
        ANSI 24-bit color codes; an int builds a heading of that font size (px).
      text: Message or heading text.

    Returns:
      An ipywidgets.HTML heading when `type` is an int, otherwise None.
    """
    # TODO(rhirota): Reconsider the overloaded meaning of `type`.
    if type == 'success':
      print("\033[38;2;15;157;88m " + text + "\033[0m")
      return None

    if type == 'failure':
      print("\033[38;2;219;68;55m " + text + "\033[0m")
      return None

    if isinstance(type, int):
      return ipywidgets.HTML(PreProcess._styled_heading_html(type, text))
    return None

  def define_ui(self):
    """Creates every input widget and stores it as an instance attribute."""
    # Input box for data sources
    self.sheet_url = ipywidgets.Text(
        placeholder='Please enter google spreadsheet url',
        value='https://docs.google.com/spreadsheets/d/1dISrbX1mZHgzpsIct2QXFOWWRRJiCxDSmSzjuZz64Tw/edit#gid=0',
        description='spreadsheet url:',
        style={'description_width': 'initial'},
        layout=ipywidgets.Layout(width='800px'),
    )
    self.sheet_name = ipywidgets.Text(
        placeholder='Please enter sheet name',
        value='analysis_data',
        description='sheet name:',
    )
    self.csv_name = ipywidgets.Text(
        placeholder='Please enter csv name',
        description='csv name:',
        layout=ipywidgets.Layout(width='500px'),
    )
    self.bq_project_id = ipywidgets.Text(
        placeholder='Please enter project id',
        description='project id:',
        layout=ipywidgets.Layout(width='500px'),
    )
    self.bq_table_name = ipywidgets.Text(
        placeholder='Please enter table name',
        description='table name:',
        layout=ipywidgets.Layout(width='500px'),
    )

    # Input box for data format
    self.date_col = ipywidgets.Text(
        placeholder='Please enter date column name',
        value='Date',
        description='date column:',
    )
    self.pivot_col = ipywidgets.Text(
        placeholder='Please enter pivot column name',
        value='Geo',
        description='pivot column:',
    )
    self.kpi_col = ipywidgets.Text(
        placeholder='Please enter kpi column name',
        value='KPI',
        description='kpi column:',
    )

    # Input box for Date-related
    self.pre_period_start = ipywidgets.DatePicker(
        description='Pre Start:',
        value=datetime.date.today() - relativedelta(days=122),
    )
    self.pre_period_end = ipywidgets.DatePicker(
        description='Pre End:',
        value=datetime.date.today() - relativedelta(days=32),
    )
    self.post_period_start = ipywidgets.DatePicker(
        description='Post Start:',
        value=datetime.date.today() - relativedelta(days=31),
    )
    self.post_period_end = ipywidgets.DatePicker(
        description='Post End:',
        value=datetime.date.today(),
    )
    self.start_date = ipywidgets.DatePicker(
        description='Start Date:',
        value=datetime.date.today() - relativedelta(days=122),
    )
    self.end_date = ipywidgets.DatePicker(
        description='End Date:',
        value=datetime.date.today() - relativedelta(days=32),
    )
    self.depend_data = ipywidgets.ToggleButton(
        value=False,
        description='Click >> Use the beginning and end of data',
        disabled=False,
        button_style='info',
        tooltip='Description',
        layout=ipywidgets.Layout(width='300px'),
    )

    # Input box for Experimental_Design-related
    self.exclude_cols = ipywidgets.Text(
        placeholder=(
            'Enter comma-separated columns if any columns are not used in the'
            ' design.'
        ),
        description='exclude cols:',
        layout=ipywidgets.Layout(width='1000px'),
    )
    self.num_of_split = ipywidgets.Dropdown(
        options=[2, 3, 4, 5],
        value=2,
        description='split#:',
        disabled=False,
    )
    self.target_columns = ipywidgets.Text(
        placeholder='Please enter comma-separated entries',
        value='Tokyo, Kanagawa',
        description='target_cols:',
        layout=ipywidgets.Layout(width='500px'),
    )
    self.num_of_pick_range = ipywidgets.IntRangeSlider(
        value=[5, 10],
        min=1,
        max=50,
        step=1,
        description='pick range:',
        orientation='horizontal',
        readout=True,
        readout_format='d',
    )
    self.num_of_covariate = ipywidgets.Dropdown(
        options=[1, 2, 3, 4, 5],
        value=1,
        description='covariate#:',
        layout=ipywidgets.Layout(width='192px'),
    )
    self.target_share = ipywidgets.FloatSlider(
        value=0.3,
        min=0.05,
        max=0.5,
        step=0.05,
        description='target share#:',
        orientation='horizontal',
        readout=True,
        readout_format='.1%',
    )
    self.control_columns = ipywidgets.Text(
        placeholder='Please enter comma-separated entries',
        value='Aomori, Akita',
        description='control_cols:',
        layout=ipywidgets.Layout(width='500px'),
    )

    # Input box for simulation params
    self.num_of_seasons = ipywidgets.IntText(
        value=1,
        description='num_of_seasons:',
        disabled=False,
        style={'description_width': 'initial'},
    )
    self.estimate_icpa = ipywidgets.IntText(
        value=1000,
        description='Estimated iCPA:',
        style={'description_width': 'initial'},
    )
    self.confidence_interval = ipywidgets.RadioButtons(
        options=[90, 95],
        value=95,
        description='Confidence interval %:',
        style={'description_width': 'initial'},
    )

  def generate_ui(self):
    """Arranges the widgets from define_ui into tabs and displays them.

    Sets the tab containers `soure_selection`, `data_type_selection`,
    `design_type`, `date_selection` and `purpose_selection`. Their
    `selected_index` values drive load_data, format_data and saving_params.
    """
    # UI for data source
    self.soure_selection = ipywidgets.Tab()
    self.soure_selection.children = [
        ipywidgets.VBox([self.sheet_url, self.sheet_name]),
        ipywidgets.VBox([self.csv_name]),
        ipywidgets.VBox([self.bq_project_id, self.bq_table_name]),
    ]
    self.soure_selection.set_title(0, 'Google_Spreadsheet')
    self.soure_selection.set_title(1, 'CSV_file')
    self.soure_selection.set_title(2, 'Big_Query')

    # UI for data type(narrow or wide)
    self.data_type_selection = ipywidgets.Tab()
    self.data_type_selection.children = [
        ipywidgets.VBox([
            ipywidgets.Label(
                'Wide, or unstacked data is presented with each different'
                ' data variable in a separate column.'
            ),
            self.date_col,
        ]),
        ipywidgets.VBox([
            ipywidgets.Label(
                'Narrow, stacked, or long data is presented with one column '
                'containing all the values and another column listing the '
                'context of the value'
            ),
            ipywidgets.HBox([self.date_col, self.pivot_col, self.kpi_col]),
        ]),
    ]
    self.data_type_selection.set_title(0, 'Wide_Format')
    self.data_type_selection.set_title(1, 'Narrow_Format')

    # UI for experimental design
    self.design_type = ipywidgets.Tab(
        children=[
            ipywidgets.VBox([
                ipywidgets.HTML(
                    'divide_equally divides the time series data into N'
                    ' groups(split#) with similar movements.'
                ),
                self.num_of_split,
                self.exclude_cols,
            ]),
            ipywidgets.VBox([
                ipywidgets.HTML(
                    'similarity_selection extracts N groups(covariate#) that '
                    'move similarly to particular columns(target_cols).'
                ),
                ipywidgets.HBox([
                    self.target_columns,
                    self.num_of_covariate,
                    self.num_of_pick_range,
                ]),
                self.exclude_cols,
            ]),
            ipywidgets.VBox([
                ipywidgets.HTML(
                    'target share extracts targeted time series data from'
                    ' the proportion of interventions.'
                ),
                self.target_share,
                self.exclude_cols,
            ]),
            ipywidgets.VBox([
                ipywidgets.HTML(
                    'To improve reproducibility, it is important to create an'
                    ' accurate counterfactual model rather than a balanced'
                    ' assignment.'
                ),
                self.target_columns,
                self.control_columns,
            ]),
        ]
    )
    self.design_type.set_title(0, 'A: divide_equally')
    self.design_type.set_title(1, 'B: similarity_selection')
    self.design_type.set_title(2, 'C: target_share')
    self.design_type.set_title(3, 'D: pre-allocated')

    # UI for purpose (CausalImpact or Experimental Design)
    self.purpose_selection = ipywidgets.Tab()
    self.date_selection = ipywidgets.Tab()
    self.date_selection.children = [
        ipywidgets.VBox(
            [
                ipywidgets.HTML('The <b>minimum</b> date of the data is '
                'selected as the start date.'),
                ipywidgets.HTML('The <b>maximum</b> date in the data is '
                'selected as the end date.'),
            ]),
        ipywidgets.VBox(
            [
                self.start_date,
                self.end_date,
            ]
        )]
    self.date_selection.set_title(0, 'automatic selection')
    self.date_selection.set_title(1, 'manual input')

    self.purpose_selection.children = [
        # Causalimpact
        ipywidgets.VBox([
            PreProcess._apply_text_style(
                15, '⑶ - a: Enter the Pre and Post the intervention.'
            ),
            self.pre_period_start,
            self.pre_period_end,
            self.post_period_start,
            self.post_period_end,
            PreProcess._apply_text_style(
                15,
                '⑶ - b: Enter the number of periodicities in the'
                ' time series data.(default=1)',
            ),
            ipywidgets.VBox([self.num_of_seasons, self.confidence_interval]),
        ]),
        # Experimental_Design
        ipywidgets.VBox([
            PreProcess._apply_text_style(
                15,
                '⑶ - a: Please select date for experimental design',
            ),
            self.date_selection,
            PreProcess._apply_text_style(
                15,
                '⑶ - b: Select the <b>experimental design method</b> and'
                ' enter the necessary items.',
            ),
            self.design_type,
            PreProcess._apply_text_style(
                15,
                '⑶ - c: (Optional) Enter <b>Estimated incremental CPA</b>(Cost'
                ' of intervention ÷ Lift from intervention without bias) & the '
                'number of periodicities in the time series data.',
            ),
            ipywidgets.VBox([
                self.estimate_icpa,
                self.num_of_seasons,
                self.confidence_interval,
            ]),
        ]),
    ]
    self.purpose_selection.set_title(0, 'Causalimpact')
    self.purpose_selection.set_title(1, 'Experimental_Design')

    display(
        PreProcess._apply_text_style(18, '⑴ Please select a data source.'),
        self.soure_selection,
        Markdown('<br>'),
        PreProcess._apply_text_style(
            18, '⑵ Please select wide or narrow data format.'
        ),
        self.data_type_selection,
        Markdown('<br>'),
        PreProcess._apply_text_style(
            18, '⑶ Please select the purpose and set conditions.'
        ),
        self.purpose_selection,
    )

  def load_data(self):
    """Loads data from the source selected in the UI into `self.loaded_df`.

    Raises:
      Exception: If loading fails (the original error is chained as the
        cause, and hints are printed first) or no data source is selected.
    """
    if self.soure_selection.selected_index == 0:
      try:
        self.loaded_df = self._load_data_from_sheet(
            self.sheet_url.value, self.sheet_name.value
        )
      except Exception as e:
        self._apply_text_style('failure', '\n\nFailure!!')
        print('Error: {}'.format(e))
        print('Please check the following:')
        print('* sheet url:{}'.format(self.sheet_url.value))
        print('* sheet name:{}'.format(self.sheet_name.value))
        raise Exception('Please check Failure') from e

    elif self.soure_selection.selected_index == 1:
      try:
        self.loaded_df = self._load_data_from_csv(self.csv_name.value)
      except Exception as e:
        self._apply_text_style('failure', '\n\nFailure!!')
        print('Error: {}'.format(e))
        print('Please check the following:')
        print('* There is something wrong with the CSV-related settings.')
        print('* CSV namel:{}'.format(self.csv_name.value))
        raise Exception('Please check Failure') from e

    elif self.soure_selection.selected_index == 2:
      try:
        self.loaded_df = self._load_data_from_bigquery(
            self.bq_project_id.value, self.bq_table_name.value
        )
      except Exception as e:
        self._apply_text_style('failure', '\n\nFailure!!')
        print('Error: {}'.format(e))
        print('Please check the following:')
        print('* There is something wrong with the bq-related settings.')
        print('* bq project id:{}'.format(self.bq_project_id.value))
        print('* bq table name :{}'.format(self.bq_table_name.value))
        raise Exception('Please check Failure') from e

    else:
      raise Exception('Please select a data souce at Step.1-2.')

    self._apply_text_style(
        'success',
        'Success! The target data has been loaded.')
    display(self.loaded_df.head(3))

  @staticmethod
  def _load_data_from_sheet(spreadsheet_url, sheet_name):
    """Loads a worksheet; its first row is used as the column names.

    Args:
      spreadsheet_url: Spreadsheet url with data.
      sheet_name: Sheet name with data.

    Returns:
      The sheet as a DataFrame, cleaned by _clean_loaded_data.
    """
    auth.authenticate_user()
    creds, _ = default()
    gc = gspread.authorize(creds)
    _workbook = gc.open_by_url(spreadsheet_url)
    _worksheet = _workbook.worksheet(sheet_name)
    df_sheet = pd.DataFrame(_worksheet.get_all_values())
    df_sheet.columns = list(df_sheet.loc[0, :])
    df_sheet.drop(0, inplace=True)
    df_sheet.reset_index(drop=True, inplace=True)
    return PreProcess._clean_loaded_data(df_sheet)

  @staticmethod
  def _load_data_from_csv(csv_name):
    """Prompts for a file upload and reads the uploaded `csv_name` as CSV.

    Args:
      csv_name: Name of the uploaded CSV file to read.

    Returns:
      The CSV content as a DataFrame, cleaned by _clean_loaded_data.
    """
    uploaded = files.upload()
    df_csv = pd.read_csv(io.BytesIO(uploaded[csv_name]))
    return PreProcess._clean_loaded_data(df_csv)

  @staticmethod
  def _load_data_from_bigquery(bq_project_id, bq_table_name):
    """Reads the whole BigQuery table `bq_table_name` (SELECT *).

    Args:
      bq_project_id: BigQuery project id used for the client.
      bq_table_name: Table reference, inserted between backticks in the query.

    Returns:
      The table as a DataFrame, cleaned by _clean_loaded_data.
    """
    auth.authenticate_user()
    client = bigquery.Client(project=bq_project_id)
    query = 'SELECT * FROM `' + bq_table_name + '`;'
    df_bq = client.query(query).to_dataframe()
    return PreProcess._clean_loaded_data(df_bq)

  @staticmethod
  def _to_numeric_if_possible(column):
    """Converts `column` with pd.to_numeric, returning it unchanged on failure.

    Same result as pd.to_numeric(column, errors='ignore'), whose 'ignore'
    option is deprecated in recent pandas versions.
    """
    try:
      return pd.to_numeric(column)
    except (ValueError, TypeError):
      return column

  @staticmethod
  def _clean_loaded_data(dataframe):
    """Applies the cleanup shared by every data loader.

    Removes ',' (thousands separators) from string cells, removes spaces from
    column names and converts every column that parses as numbers to numeric.

    Args:
      dataframe: Freshly loaded DataFrame with string column names.

    Returns:
      A cleaned DataFrame.
    """
    dataframe = dataframe.replace(',', '', regex=True)
    dataframe = dataframe.rename(columns=lambda x: x.replace(" ", ""))
    return dataframe.apply(PreProcess._to_numeric_if_possible)

  def format_data(self):
    """Builds `self.formatted_data`: wide data with a continuous daily index.

    Narrow data is pivoted with _shape_wide. Columns listed in `exclude_cols`
    are dropped, and the date column becomes the index, reindexed to every
    day between the earliest and latest date (missing days become NaN rows).
    An overview, the rows with missing values and trend charts are displayed.

    Raises:
      Exception: If any step fails; the original error is chained as the
        cause and hints about the inputs are printed first.
    """
    # Remove spaces from input data
    self.date_col_name = self.date_col.value.replace(' ', '')
    self.pivot_col_name = self.pivot_col.value.replace(' ', '')
    self.kpi_col_name = self.kpi_col.value.replace(' ', '')

    try:
      if self.data_type_selection.selected_index == 0:
        self.formatted_data = self.loaded_df.copy()
      elif self.data_type_selection.selected_index == 1:
        self.formatted_data = self._shape_wide(
            self.loaded_df,
            self.date_col_name,
            self.pivot_col_name,
            self.kpi_col_name,
        )

      self.formatted_data.drop(
          self.exclude_cols.value.replace(', ', ',').split(','),
          axis=1,
          errors='ignore',
          inplace=True,
      )
      self.formatted_data[self.date_col_name] = pd.to_datetime(
          self.formatted_data[self.date_col_name]
      )
      self.formatted_data = self.formatted_data.set_index(self.date_col_name)
      self.formatted_data = self.formatted_data.reindex(
          pd.date_range(
              start=self.formatted_data.index.min(),
              end=self.formatted_data.index.max(),
              name=self.formatted_data.index.name))
      # Number of calendar months spanned, minus one (used as chart tick
      # count). The index is a continuous daily range, so counting distinct
      # months equals the number of monthly resample bins, without the
      # resample alias 'M' that newer pandas versions deprecate.
      self.tick_count = (
          self.formatted_data.index.to_period('M').nunique() - 1)
      self._apply_text_style(
          'success',
          '\nSuccess! The data was formatted for analysis.'
          )
      display(self.formatted_data.head(3))
      self._apply_text_style(
          'failure',
          '\nCheck! Here is an overview of the data.'
          )
      print(
          'Index name:{} | The earliest date: {} | The latest date: {}'.format(
              self.formatted_data.index.name,
              min(self.formatted_data.index),
              max(self.formatted_data.index)
              ))
      print('* Rows with missing values')
      self.missing_row = self.formatted_data[
          self.formatted_data.isnull().any(axis=1)]
      if len(self.missing_row) > 0:
        # A bare expression is not rendered inside a method; display it.
        display(self.missing_row)
      else:
        print('>> Does not include missing values')

      self._apply_text_style(
          'failure',
          '\nCheck! below [total_trend] / [each_trend] / [describe_data]'
          )
      self._trend_check(
          self.formatted_data,
          self.date_col_name,
          self.tick_count)

    except Exception as e:
      self._apply_text_style('failure', '\n\nFailure!!')
      print('Error: {}'.format(e))
      self._apply_text_style('failure', '\nPlease check the following:')
      if self.data_type_selection.selected_index == 0:
        print('* Your selected data format: Wide format at (2)')
        print('1. Check if the data source is wide.')
        print('2. Compare "date column"( {} ) and "data source"'.format(
            self.date_col.value))
        print('\n\n')
      else:
        print('* Your selected data format: Narrow format at (2)')
        print('1. Check if the data source is narrow.')
        print('2. Compare "your input" and "data source')
        print('>> date column: {}'.format(self.date_col.value))
        print('>> pivot column: {}'.format(self.pivot_col.value))
        print('>> kpi column: {}'.format(self.kpi_col.value))
        print('\n\n')
      raise Exception('Please check Failure') from e

  @staticmethod
  def _shape_wide(dataframe, date_column, pivot_column, kpi_column):
    """Pivots narrow (long) data into wide data suitable for experiment design.

    KPI values are summed per date and pivot key. Missing combinations are
    filled with 0.

    Args:
        dataframe: The DataFrame to be pivoted.
        date_column: The name of the column that contains the dates.
        pivot_column: The name of the pivot column, or several names
          separated by commas; their values become the new column names.
        kpi_column: The name of the column that contains the KPI values.

    Returns:
        A DataFrame with `date_column` as a regular column followed by one
        column per pivot key. Multiple pivot columns are joined as
        '<value1>_<value2>'.
    """
    # Check if the pivot_column is a single column or a list of columns.
    if ',' in pivot_column:
      group_cols = pivot_column.replace(', ', ',').split(',')
    else:
      group_cols = [pivot_column]

    pivoted_df = pd.pivot_table(
        (dataframe[[date_column] + [kpi_column] + group_cols])
        .groupby([date_column] + group_cols)
        .sum(),
        index=date_column,
        columns=group_cols,
        fill_value=0,
    )
    # Drop the first level (the KPI column name) of the column names.
    pivoted_df.columns = pivoted_df.columns.droplevel(0)
    # If there are multiple columns, convert the column names to a single
    # string. str() is needed because pivot values may be numbers.
    if len(pivoted_df.columns.names) > 1:
      new_cols = [
          '_'.join(str(x).replace(',', '_') for x in y)
          for y in pivoted_df.columns.values
      ]
      pivoted_df.columns = new_cols
    pivoted_df = pivoted_df.reset_index()
    return pivoted_df

  @staticmethod
  def _trend_check(dataframe, date_col_name, tick_count):
    """Displays tabs with the total trend, per-column trends and statistics.

    The total trend is the daily sum of all columns plus a rolling mean over
    a [-4, 3] window. For the per-column view, each column is min-max scaled
    and the columns are grouped with DTW-based TimeSeriesKMeans (at most 5
    clusters); each cluster is drawn as small-multiple line charts.

    Args:
      dataframe: Wide data indexed by date, one column per time series.
      date_col_name: Name of the date index; used as the x-axis field after
        reset_index().
      tick_count: Desired number of x-axis ticks.
    """
    col_list = list(dataframe.columns)
    # Min-max scale every column so clustering compares shapes, not levels.
    # (A constant column becomes NaN because max == min.)
    col_min = dataframe.min()
    df_each = (dataframe - col_min) / (dataframe.max() - col_min)

    # There cannot be more clusters than series to cluster.
    n_clusters = min(5, len(col_list))
    tskm_base = TimeSeriesKMeans(n_clusters=n_clusters, metric='dtw',
                                 max_iter=100, random_state=42)
    df_cluster = pd.DataFrame({
        "pivot": col_list,
        "cluster": tskm_base.fit_predict(df_each.T).tolist()})
    cluster_counts = (
        df_cluster["cluster"].value_counts().sort_values(ascending=True))

    cluster_text = []
    line_each = []
    for cluster_id in cluster_counts.index:
      clust_list = df_cluster.loc[
          df_cluster["cluster"] == cluster_id, "pivot"].to_list()
      source = df_each.filter(items=clust_list)
      cluster_text.append(str(clust_list).translate(
          str.maketrans({'[': '', ']': '',  "'": ''})))
      line_each.append(
          alt.Chart(source.reset_index())
          .transform_fold(fold=clust_list, as_=['pivot', 'kpi'])
          .mark_line()
          .encode(
              alt.X(
                  date_col_name + ':T',
                  title=None,
                  axis=alt.Axis(
                      grid=False, format='%Y %b', tickCount=tick_count
                      ),
                  ),
              alt.Y('kpi:Q', stack=None, axis=None),
              alt.Color(str(cluster_id) + ':N', title=None, legend=None),
              alt.Row(
                  'pivot:N',
                  title=None,
                  header=alt.Header(labelAngle=0, labelAlign='left'),
                  ),
              )
          .properties(bounds='flush', height=30)
          .configure_facet(spacing=0)
          .configure_view(stroke=None)
          .configure_title(anchor='end')
          )

    df_long = (
        pd.melt(dataframe.reset_index(), id_vars=date_col_name)
        .groupby(date_col_name)
        .sum(numeric_only=True)
        .reset_index()
    )
    line_total = (
        alt.Chart(df_long)
        .mark_line()
        .encode(
            x=alt.X(
                date_col_name + ':T',
                axis=alt.Axis(
                    title='', format='%Y %b', tickCount=tick_count
                ),
            ),
            y=alt.Y('value:Q', axis=alt.Axis(title='kpi')),
            color=alt.value('#4285F4'),
        )
    )
    moving_average = (
        alt.Chart(df_long)
        .transform_window(
            rolling_mean='mean(value)',
            frame=[-4, 3],
        )
        .mark_line()
        .encode(
            x=alt.X(date_col_name + ':T'),
            y=alt.Y('rolling_mean:Q'),
            color=alt.value('#DB4437'),
        )
    )
    tab_total_trend = ipywidgets.Output()
    tab_each_trend = ipywidgets.Output()
    tab_describe_data = ipywidgets.Output()
    tab_result = ipywidgets.Tab(children=[
        tab_total_trend,
        tab_each_trend,
        tab_describe_data,
        ])
    tab_result.set_title(0, '>> total_trend')
    tab_result.set_title(1, '>> each_trend')
    tab_result.set_title(2, '>> describe_data')
    display(tab_result)
    with tab_total_trend:
      display(
          (line_total + moving_average).properties(
              width=700,
              height=200,
              title={
                  'text': ['Daily Trend(blue) & 7days moving average(red)'],
              },
          )
      )
    with tab_each_trend:
      for i, (text, chart) in enumerate(zip(cluster_text, line_each)):
        print('cluster {}:{}'.format(i, text))
        display(chart.properties(width=700))
    with tab_describe_data:
      display(dataframe.describe(include='all'))

  @staticmethod
  def saving_params(instance):
    """Returns the current widget values of `instance` as a plain dict.

    Tab widgets are stored by `selected_index`, other widgets by `.value`.
    The result can be passed to set_params to restore the UI state.
    """
    params_dict = {
        # section for data source
        'soure_selection': instance.soure_selection.selected_index,
        'sheet_url': instance.sheet_url.value,
        'sheet_name': instance.sheet_name.value,
        'csv_name': instance.csv_name.value,
        'bq_project_id': instance.bq_project_id.value,
        'bq_table_name': instance.bq_table_name.value,

        # section for data format(narrow or wide)
        'data_type_selection': instance.data_type_selection.selected_index,
        'date_col': instance.date_col.value,
        'pivot_col': instance.pivot_col.value,
        'kpi_col': instance.kpi_col.value,

        # section for purpose(CausalImpact or Experimental Design)
        'purpose_selection': instance.purpose_selection.selected_index,
        'pre_period_start': instance.pre_period_start.value,
        'pre_period_end': instance.pre_period_end.value,
        'post_period_start': instance.post_period_start.value,
        'post_period_end': instance.post_period_end.value,
        'start_date': instance.start_date.value,
        'end_date': instance.end_date.value,
        'depend_data': instance.depend_data.value,

        'design_type': instance.design_type.selected_index,
        'num_of_split': instance.num_of_split.value,
        'target_columns': instance.target_columns.value,
        'control_columns': instance.control_columns.value,
        'num_of_pick_range': instance.num_of_pick_range.value,
        'num_of_covariate': instance.num_of_covariate.value,
        'target_share': instance.target_share.value,
        'exclude_cols': instance.exclude_cols.value,

        'num_of_seasons': instance.num_of_seasons.value,
        'estimate_icpa': instance.estimate_icpa.value,
        'confidence_interval': instance.confidence_interval.value,
        }
    return params_dict

  @staticmethod
  def set_params(instance, dict_params):
    """Restores widget values on `instance` from a saving_params dict.

    generate_ui must have been called first, because the tab containers
    (e.g. `soure_selection`) are created there.
    """
    # section for data source
    instance.soure_selection.selected_index = dict_params['soure_selection']
    instance.sheet_url.value = dict_params['sheet_url']
    instance.sheet_name.value = dict_params['sheet_name']
    instance.csv_name.value = dict_params['csv_name']
    instance.bq_project_id.value = dict_params['bq_project_id']
    instance.bq_table_name.value = dict_params['bq_table_name']

    # section for data format(narrow or wide)
    instance.data_type_selection.selected_index = (
        dict_params['data_type_selection'])
    instance.date_col.value = dict_params['date_col']
    instance.pivot_col.value = dict_params['pivot_col']
    instance.kpi_col.value = dict_params['kpi_col']

    # section for purpose(CausalImpact or Experimental Design)
    instance.purpose_selection.selected_index = dict_params['purpose_selection']
    instance.pre_period_start.value = dict_params['pre_period_start']
    instance.pre_period_end.value = dict_params['pre_period_end']
    instance.post_period_start.value = dict_params['post_period_start']
    instance.post_period_end.value = dict_params['post_period_end']
    instance.start_date.value = dict_params['start_date']
    instance.end_date.value = dict_params['end_date']
    instance.depend_data.value = dict_params['depend_data']

    instance.design_type.selected_index = dict_params['design_type']
    instance.num_of_split.value = dict_params['num_of_split']
    instance.target_columns.value = dict_params['target_columns']
    instance.control_columns.value = dict_params['control_columns']
    instance.num_of_pick_range.value = dict_params['num_of_pick_range']
    instance.num_of_covariate.value = dict_params['num_of_covariate']
    instance.target_share.value = dict_params['target_share']
    instance.exclude_cols.value = dict_params['exclude_cols']

    instance.num_of_seasons.value = dict_params['num_of_seasons']
    instance.estimate_icpa.value = dict_params['estimate_icpa']
    instance.confidence_interval.value = dict_params['confidence_interval']
# @title dev
class CausalImpact(PreProcess):
  """CausalImpact analysis and experimental design on CausalImpact.

  CausalImpact Analysis performs a CausalImpact analysis on the given data and
  outputs the results. The experimental design will be based on N partitions,
  similarity, or share, with 1000 iterations of random sampling, and will output
  the three candidate groups with the closest DTW distance. A combination of
  increments and periods will be used to simulate and return which combination
  will result in a significantly different validation.

  Attributes:
    run_causalImpact: Runs CausalImpact on the given case.
    create_causalimpact_object:
    display_causalimpact_result:
    plot_causalimpact:

  Returns:
    The CausalImpact object.
  """

  colors = [
      '#DB4437',
      '#AB47BC',
      '#4285F4',
      '#00ACC1',
      '#0F9D58',
      '#9E9D24',
      '#F4B400',
      '#FF7043',
  ]
  num_of_iteration = 1000
  combination_target = 10
  treat_duration = [14, 21, 28]
  treat_impact = [1, 1.01, 1.03, 1.05, 1.10, 1.15]
  max_string_length = 150

  def __init__(self):
    """Creates the helper; all widget/state setup is done by PreProcess."""
    super(CausalImpact, self).__init__()

  def run_causalImpact(self):
    """Fits CausalImpact on ``self.formatted_data`` for the configured periods.

    ``self.ci_objs`` is reset on every call. On success the fitted object is
    stored in ``self.ci_obj`` and appended to ``self.ci_objs`` so that
    ``display_causalimpact_result`` can render it, and a success message is
    shown.

    Raises:
      Exception: if fitting fails. The original error is printed together
        with a checklist of likely causes and is chained as ``__cause__`` so
        the real traceback is preserved.
    """
    self.ci_objs = []
    try:
      self.ci_obj = self.create_causalimpact_object(
          self.formatted_data,
          self.date_col_name,
          self.pre_period_start.value,
          self.pre_period_end.value,
          self.post_period_start.value,
          self.post_period_end.value,
          self.num_of_seasons.value,
          self.confidence_interval.value,
      )
    except Exception as e:
      self._apply_text_style('failure', '\n\nFailure!!')
      print('Error: {}'.format(e))
      print('Please check the following:')
      print('* Data source.')
      print('* Date Column Name.')
      print('* Duration of experiment (pre and post).')
      # Chain the original error so its traceback is not lost.
      raise Exception('Please check Failure') from e

    self.ci_objs.append(self.ci_obj)
    self._apply_text_style(
        'success',
        '\nSuccess! CausalImpact has been performed. Check the'
        ' results in the next cell.',
    )

  @staticmethod
  def create_causalimpact_object(
      data,
      date_col,
      pre_start,
      pre_end,
      post_start,
      post_end,
      num_of_seasons,
      confidence_interval):
    """Fits a CausalImpact model and returns the fitted object.

    Args:
      data: DataFrame with the observed series and covariates. If its index is
        not named ``date_col``, that column is set as the index IN PLACE, so
        the caller's DataFrame is modified.
      date_col: Name of the date column / index.
      pre_start, pre_end: Pre-period bounds (converted with ``str``).
      post_start, post_end: Post-period bounds (converted with ``str``).
      num_of_seasons: Seasonal period; 1 means no seasonal component.
      confidence_interval: Confidence level in percent (alpha = 1 - ci/100).

    Returns:
      The object returned by ``causalimpact.fit_causalimpact``.
    """
    if data.index.name != date_col:
      data.set_index(date_col, inplace=True)

    fit_kwargs = dict(
        data=data,
        pre_period=(str(pre_start), str(pre_end)),
        post_period=(str(post_start), str(post_end)),
        alpha=1 - confidence_interval / 100,
    )
    # A seasonal component is only requested for periods other than 1.
    if num_of_seasons != 1:
      fit_kwargs['model_options'] = causalimpact.ModelOptions(
          seasons=[causalimpact.Seasons(num_seasons=num_of_seasons)]
      )
    return causalimpact.fit_causalimpact(**fit_kwargs)

  def display_causalimpact_result(self):
    """Shows the input series, then the CausalImpact result of ``ci_objs[0]``.

    Every column of ``self.formatted_data`` is drawn as a line, with dashed
    vertical rules at the post-period start (red) and end (orange). The plots
    and summary tabs of the first fitted object are then rendered through
    ``plot_causalimpact``.
    """
    print('Test & Control Time Series')
    series_names = list(self.formatted_data.columns)

    x_encoding = alt.X(
        self.date_col_name + ':T',
        title=None,
        axis=alt.Axis(format='%Y %b', tickCount=self.tick_count),
    )
    legend = alt.Legend(
        title=None,
        orient='none',
        legendY=-20,
        direction='horizontal',
        titleAnchor='start',
    )
    palette = alt.Scale(
        domain=list(self.formatted_data.columns),
        range=CausalImpact.colors,
    )
    line = (
        alt.Chart(self.formatted_data.reset_index())
        .transform_fold(series_names)
        .mark_line()
        .encode(
            x_encoding,
            y=alt.Y('value:Q', axis=alt.Axis(title='kpi')),
            color=alt.Color('key:N', legend=legend, scale=palette),
        )
        .properties(height=200, width=600)
    )

    markers = pd.DataFrame({
        'Date': [
            str(self.post_period_start.value),
            str(self.post_period_end.value),
        ],
        'color': ['red', 'orange'],
    })
    rule = (
        alt.Chart(markers)
        .mark_rule(strokeDash=[5, 5])
        .encode(x='Date:T', color=alt.Color('color:N', scale=None))
    )
    display((line + rule).properties(height=200, width=600))
    print('=' * 100)

    self.plot_causalimpact(
        self.ci_objs[0],
        self.pre_period_start.value,
        self.pre_period_end.value,
        self.post_period_start.value,
        self.post_period_end.value,
        self.confidence_interval.value,
        self.date_col_name,
        self.tick_count,
        self.purpose_selection.selected_index
    )

  @staticmethod
  def plot_causalimpact(
      causalimpact_object,
      pre_start,
      pre_end,
      tread_start,
      treat_end,
      confidence_interval,
      date_col_name,
      tick_count,
      purpose_selection
    ):
    """Renders a fitted CausalImpact result as summary / report / data tabs.

    Args:
      causalimpact_object: Fitted object; its ``series`` DataFrame is read and
        passed to ``causalimpact.summary``. It is not modified.
      pre_start, pre_end: Pre-period bounds, used for the MAPE of the fit.
      tread_start, treat_end: Post-period bounds, drawn as red / orange rules.
        (``tread_start`` keeps its historical spelling for compatibility.)
      confidence_interval: Confidence level in percent (alpha = 1 - ci/100).
      date_col_name: Name of the date index of ``series``.
      tick_count: Approximate number of x-axis ticks.
      purpose_selection: 1 adds a 'mock experiment' watermark to the
        cumulative-effect panel.
    """
    # Work on a copy of the result series. The data tab below inserts a
    # 'diff_percentage' column; inserting into causalimpact_object.series
    # itself made any second call on the same object (e.g. re-running the
    # display cell) fail with "cannot insert diff_percentage, already exists".
    causalimpact_df = causalimpact_object.series.copy()
    alpha = 1 - confidence_interval / 100
    pre_window = slice(str(pre_start), str(pre_end))
    mape = mean_absolute_percentage_error(
        causalimpact_df['observed'][pre_window],
        causalimpact_df['posterior_mean'][pre_window],
    )

    chart_data = causalimpact_df.reset_index()
    x_field = 'yearmonthdate(' + date_col_name + ')'

    def hidden_x():
      # x encoding for the upper panels: the date labels are shown only on
      # the bottom (cumulative) panel.
      return alt.X(
          x_field,
          axis=alt.Axis(
              title='',
              labels=False,
              ticks=False,
              format='%Y %b',
              tickCount=tick_count,
          ),
      )

    # Panel 1: observed vs. posterior mean with its interval.
    line_1 = (
        alt.Chart(chart_data)
        .transform_fold([
            'observed',
            'posterior_mean',
        ])
        .mark_line()
        .encode(
            x=hidden_x(),
            y=alt.Y(
                'value:Q',
                scale=alt.Scale(zero=False),
                axis=alt.Axis(title=''),
            ),
            color=alt.Color(
                'key:N',
                legend=alt.Legend(
                    title=None,
                    orient='none',
                    legendY=-20,
                    direction='horizontal',
                    titleAnchor='start',
                ),
                sort=['posterior_mean', 'observed'],
            ),
            strokeDash=alt.condition(
                alt.datum.key == 'posterior_mean',
                alt.value([5, 5]),
                alt.value([0]),
            ),
        )
    )
    area_1 = (
        alt.Chart(chart_data)
        .mark_area(opacity=0.3)
        .encode(
            x=alt.X(x_field),
            y=alt.Y('posterior_lower:Q', scale=alt.Scale(zero=False)),
            y2=alt.Y2('posterior_upper:Q'),
        )
    )
    # Panel 2: pointwise effect with its interval.
    line_2 = (
        alt.Chart(chart_data)
        .mark_line(strokeDash=[5, 5])
        .encode(
            x=hidden_x(),
            y=alt.Y(
                'point_effects_mean:Q',
                scale=alt.Scale(zero=False),
                axis=alt.Axis(title=''),
            ),
        )
    )
    area_2 = (
        alt.Chart(chart_data)
        .mark_area(opacity=0.3)
        .encode(
            x=alt.X(x_field),
            y=alt.Y('point_effects_lower:Q', scale=alt.Scale(zero=False)),
            y2=alt.Y2('point_effects_upper:Q'),
        )
    )
    # Panel 3: cumulative effect with its interval (carries the date axis).
    line_3 = (
        alt.Chart(chart_data)
        .mark_line(strokeDash=[5, 5])
        .encode(
            x=alt.X(
                x_field,
                axis=alt.Axis(title='', format='%Y %b', tickCount=tick_count),
            ),
            y=alt.Y(
                'cumulative_effects_mean:Q',
                scale=alt.Scale(zero=False),
                axis=alt.Axis(title=''),
            ),
        )
    )
    area_3 = (
        alt.Chart(chart_data)
        .mark_area(opacity=0.3)
        .encode(
            x=alt.X(x_field),
            y=alt.Y('cumulative_effects_lower:Q', scale=alt.Scale(zero=False)),
            y2=alt.Y2('cumulative_effects_upper:Q'),
        )
    )
    zero_line = (
        alt.Chart(pd.DataFrame({'y': [0]}))
        .mark_rule()
        .encode(y='y', color=alt.value('gray'))
    )
    rules = (
        alt.Chart(
            pd.DataFrame({
                'Date': [str(tread_start), str(treat_end)],
                'color': ['red', 'orange'],
            })
        )
        .mark_rule(strokeDash=[5, 5])
        .encode(x='Date:T', color=alt.Color('color:N', scale=None))
    )
    watermark = alt.Chart(pd.DataFrame([1])).mark_text(
        align='center',
        dx=0,
        dy=0,
        fontSize=48,
        text='mock experiment',
        color='red'
    ).encode(
        opacity=alt.value(0.5)
    )
    cumulative = line_3 + area_3 + rules + zero_line
    if purpose_selection == 1:
      cumulative = cumulative + watermark
    plot = alt.vconcat(
        (line_1 + area_1 + rules).properties(height=100, width=600),
        (line_2 + area_2 + rules + zero_line).properties(height=100, width=600),
        (cumulative).properties(height=100, width=600),
    )

    tab_data = ipywidgets.Output()
    tab_report = ipywidgets.Output()
    tab_summary = ipywidgets.Output()
    tab_result = ipywidgets.Tab(children = [tab_summary, tab_report, tab_data])
    tab_result.set_title(0, '>> summary')
    tab_result.set_title(1, '>> report')
    tab_result.set_title(2, '>> data')
    with tab_summary:
      print('Approximate model accuracy >> MAPE:{:.2%}\n'.format(mape))
      print(causalimpact.summary(
          causalimpact_object,
          output_format='summary',
          alpha=alpha))
      display(plot)
    with tab_report:
      print(causalimpact.summary(
          causalimpact_object,
          output_format="report",
          alpha=alpha))
    with tab_data:
      # Relative pointwise effect, shown as the third column of the local copy.
      causalimpact_df.insert(
          2,
          'diff_percentage',
          causalimpact_df['point_effects_mean'] / causalimpact_df['observed'],
      )
      display(causalimpact_df)
    display(tab_result)

  def run_experimental_design(self):
    """Searches candidate test/control groupings and shows them.

    The design window is the full data range when ``date_selection`` index
    is 0, otherwise the user-selected start/end dates. The grouping strategy
    follows ``design_type``: 0 -> ``_n_part_split``, 1 -> ``_find_similar``,
    2 -> ``_from_share``, 3 -> ``_given_assignment``. The candidates are stored
    in ``self.distance_data``, visualised, and followed by the choice widgets.

    Raises:
      Exception: if ``design_type`` has an unsupported index.
    """
    if self.date_selection.selected_index != 0:
      self.start_date_value = self.start_date.value
      self.end_date_value = self.end_date.value
    else:
      dates = self.formatted_data.index
      self.start_date_value = min(dates).date()
      self.end_date_value = max(dates).date()

    design = self.design_type.selected_index
    # Evaluated by DataFrame.query, which resolves '@self' from this frame.
    in_period = '@self.start_date_value <= index <= @self.end_date_value'

    if design == 0:
      self.distance_data = self._n_part_split(
          self.formatted_data.query(in_period),
          self.num_of_split.value,
          CausalImpact.num_of_iteration,
      )
    elif design == 1:
      self.distance_data = self._find_similar(
          self.formatted_data.query(in_period),
          self.target_columns.value,
          self.num_of_pick_range.value,
          self.num_of_covariate.value,
      )
    elif design == 2:
      self.distance_data = self._from_share(
          self.formatted_data.query(in_period),
          self.target_share.value,
      )
    elif design == 3:
      self.distance_data = self._given_assignment(
          self.target_columns.value,
          self.control_columns.value,
      )
    else:
      self._apply_text_style('failure', '\n\nFailure!!')
      print('Please check the following:')
      print('* There is something wrong with design type.')
      raise Exception('Please check Failure')

    self._visualize_candidate(
        self.formatted_data,
        self.distance_data,
        self.start_date_value,
        self.end_date_value,
        self.date_col_name,
        self.tick_count,
    )
    self._generate_choice()

  @staticmethod
  def _n_part_split(dataframe, num_of_split, num_of_iteration):
    """Randomly partitions the columns into ``num_of_split`` groups.

    In each of ``num_of_iteration`` trials, ``len(columns) // num_of_split``
    columns are drawn without replacement for every group, and the columns
    left over are appended to the first group. Each group is summed into one
    series and the grouping is scored with ``_calculate_distance``.

    Args:
      dataframe: Time-indexed data, one column per candidate unit.
      num_of_split: Number of groups to form.
      num_of_iteration: Number of random trials.

    Returns:
      DataFrame with a ``distance`` column plus one column per group (labelled
      0 .. num_of_split-1) holding ``str(sorted(members))``. It keeps the three
      lowest-distance distinct rows, re-indexed from 0.
    """
    results = pd.DataFrame(columns=['distance'])
    group_size = len(dataframe.columns) // num_of_split

    for trial in tqdm(range(num_of_iteration)):
      available = list(dataframe.columns)
      groups = []
      for _ in range(num_of_split):
        chosen = random.sample(available, group_size)
        groups.append(chosen)
        available = [c for c in available if c not in chosen]
      # Columns that do not divide evenly go to the first group.
      groups[0].extend(available)

      group_sums = pd.concat(
          [
              pd.DataFrame(dataframe[members].sum(axis=1), columns=[idx])
              for idx, members in enumerate(groups)
          ],
          axis=1,
      )

      score = CausalImpact._calculate_distance(
          group_sums.reset_index(drop=True)
      )
      results.loc[trial, 'distance'] = float(score)
      for idx, members in enumerate(groups):
        results.at[trial, idx] = str(sorted(members))

    return (
        results.drop_duplicates()
        .sort_values('distance')
        .head(3)
        .reset_index(drop=True)
    )

  @staticmethod
  def _find_similar(
      dataframe,
      target_columns,
      num_of_pick_range,
      num_of_covariate,
      ):
    """Searches covariate groups whose shape is close to the target columns.

    The target group is fixed. In each of ``CausalImpact.num_of_iteration``
    trials, ``num_of_covariate`` disjoint groups of random size in
    ``[num_of_pick_range[0], num_of_pick_range[1]]`` are drawn from the
    non-target columns. Every group is summed into one series and the set is
    scored with ``_calculate_distance``.

    Args:
      dataframe: Time-indexed data, one column per candidate unit.
      target_columns: Comma-separated target column names.
      num_of_pick_range: (min, max) number of columns per covariate group.
      num_of_covariate: Number of covariate groups.

    Returns:
      DataFrame with ``distance``, column 0 (targets) and columns
      1..num_of_covariate (covariate groups) as ``str(sorted(members))``. It
      keeps the three lowest-distance distinct rows, re-indexed from 0.

    Raises:
      Exception: if the non-target columns cannot cover
        ``max pick * num_of_covariate``.
    """
    results = pd.DataFrame(columns=['distance'])
    targets = target_columns.replace(', ', ',').split(',')
    spare_columns = len(dataframe.columns) - len(targets)

    # Even the largest draw for every covariate group must fit into the
    # non-target columns without reuse.
    if spare_columns < num_of_pick_range[1] * num_of_covariate:
      print('Please check the following:')
      print('* There is something wrong with similarity settings.')
      print('* Total number of columns ー the target = {}'.format(
          spare_columns))
      print('* But your settings are {}(max pick#) × {}(covariate#)'.format(
          num_of_pick_range[1], num_of_covariate))
      print('* Please set it so that it does not exceed.')
      PreProcess.failure_text('▲▲▲▲▲▲\n\n')
      raise Exception('Please check Failure')

    for trial in tqdm(range(CausalImpact.num_of_iteration)):
      pool = [c for c in list(dataframe.columns) if c not in targets]
      groups = []
      for _ in range(num_of_covariate):
        size = random.randrange(
            num_of_pick_range[0], num_of_pick_range[1] + 1, 1
        )
        chosen = random.sample(pool, size)
        groups.append(chosen)
        pool = [c for c in pool if c not in chosen]
      groups.insert(0, targets)

      group_sums = pd.concat(
          [
              pd.DataFrame(dataframe[members].sum(axis=1), columns=[idx])
              for idx, members in enumerate(groups)
          ],
          axis=1,
      )

      score = CausalImpact._calculate_distance(
          group_sums.reset_index(drop=True)
      )
      results.loc[trial, 'distance'] = float(score)
      for idx, members in enumerate(groups):
        results.at[trial, idx] = str(sorted(members))

    return (
        results.drop_duplicates()
        .sort_values('distance')
        .head(3)
        .reset_index(drop=True)
    )

  @staticmethod
  def _from_share(
      dataframe,
      target_share
      ):
    """Finds test groups with a given KPI share and scores control groups.

    Step 1 draws random column subsets of size 1 .. len(columns)//2 + 1. It
    stops once ``CausalImpact.combination_target`` subsets are found whose
    share of the grand total, rounded half-up to one decimal, equals
    ``target_share``.

    Step 2 runs, for each such test subset, ``num_of_iteration //
    combination_target`` trials. Each trial draws a random control subset
    from the remaining columns, and every (test, control) pair is scored with
    ``_calculate_distance``.

    Args:
      dataframe: Time-indexed data, one column per candidate unit.
      target_share: Desired share of the test group, at one-decimal precision.

    Returns:
      DataFrame with ``distance``, column 0 (test members) and column 1
      (control members) as ``str(sorted(members))``. It keeps the three
      lowest-distance distinct rows, re-indexed from 0.

    Raises:
      Exception: if not enough matching subsets are found within
        ``num_of_iteration`` draws.
    """
    distance_data = pd.DataFrame(columns=['distance'])
    combinations = []

    n = CausalImpact.num_of_iteration
    while len(combinations) < CausalImpact.combination_target:
      n -= 1
      picked_col = np.random.choice(
          dataframe.columns,
          # Shares go up to 50%, so draw from at most about half the columns.
          random.randint(1, len(dataframe.columns)//2 + 1),
          replace=False)

      # TODO(rhirota): decide whether the share should be computed over all
      # columns or only over the columns left after exclusions.
      share = dataframe[picked_col].sum().sum() / dataframe.sum().sum()
      if float(Decimal(share).quantize(
          Decimal('0.1'), ROUND_HALF_UP)) == target_share:
        combinations.append(sorted(set(picked_col)))
      if n == 1:
        PreProcess.failure_text('\n\nFailure!!')
        print('Please check the following:')
        print('* There is something wrong with design type C.')
        print('* Please re-set target share')
        PreProcess.failure_text('▲▲▲▲▲▲\n\n')
        raise Exception('Please check Failure')

    trials_per_combination = (
        CausalImpact.num_of_iteration // CausalImpact.combination_target)
    row = 0  # unique row label across all combinations and trials
    for comb in tqdm(combinations):
      remained_list = [
          i for i in list(dataframe.columns) if i not in comb
      ]
      for _ in tqdm(range(trials_per_combination), leave=False):
        # TODO(rhirota): revisit the minimum number of control columns.
        # randrange's upper bound is exclusive (at most len - 1 columns are
        # drawn); max(..., 2) keeps a single remaining column usable instead
        # of raising "empty range for randrange()".
        num_control = random.randrange(1, max(len(remained_list), 2), 1)
        picks = [comb, random.sample(remained_list, num_control)]

        picked_data = pd.concat(
            [
                pd.DataFrame(dataframe[cols].sum(axis=1), columns=[i])
                for i, cols in enumerate(picks)
            ],
            axis=1,
        )

        # Score every drawn pair. Previously this ran after the inner loop,
        # so only the last draw of each combination was scored and every
        # combination overwrote the same row, leaving a single candidate.
        distance = CausalImpact._calculate_distance(
            picked_data.reset_index(drop=True)
        )
        distance_data.loc[row, 'distance'] = float(distance)
        for j, cols in enumerate(picks):
          distance_data.at[row, j] = str(sorted(cols))
        row += 1

    distance_data = (
          distance_data.drop_duplicates()
          .sort_values('distance')
          .head(3)
          .reset_index(drop=True)
    )
    return distance_data

  @staticmethod
  def _given_assignment(target_columns, control_columns):
    """Wraps a user-specified test/control split in the candidate format.

    Args:
      target_columns: Comma-separated test column names.
      control_columns: Comma-separated control column names.

    Returns:
      One-row DataFrame with ``distance`` 0, column 0 = str(test list) and
      column 1 = str(control list). Names keep the order typed by the user.
    """
    def as_list_literal(text):
      # 'a, b,c' -> "['a', 'b', 'c']" (only ', ' is normalised).
      return str(text.replace(', ', ',').split(','))

    assignment = pd.DataFrame(columns=['distance'])
    assignment.loc[0, 'distance'] = 0
    assignment.loc[0, 0] = as_list_literal(target_columns)
    assignment.loc[0, 1] = as_list_literal(control_columns)
    return assignment

  @staticmethod
  def _calculate_distance(dataframe):
    """Returns the summed pairwise FastDTW distance between the columns.

    Each column is min-max scaled to [0, 1] and first-differenced. Every pair
    of columns is then compared with FastDTW on 2-D points
    (row position, differenced value) using Euclidean point distance.

    Args:
      dataframe: One column per group. Callers pass
        ``reset_index(drop=True)``, so the index is the row position.

    Returns:
      Sum of the DTW distances over all column pairs (0 for fewer than two
      columns).
    """
    col_min = dataframe.min()
    span = dataframe.max() - col_min
    # A constant column would give 0/0 = NaN on every row. dropna() below
    # would then drop every row, all distances would be 0, and degenerate
    # (e.g. all-zero) groups would rank as the best match. Scale such
    # columns to a flat 0 instead.
    scaled_data = (dataframe - col_min) / span.replace(0, 1)
    scaled_data = scaled_data.diff().reset_index().dropna()

    total_distance = 0
    # 'index' (added by reset_index) is only the time coordinate of each
    # point, not a series to compare, so pairs are taken over the original
    # columns only.
    for left, right in itertools.combinations(list(dataframe.columns), 2):
      distance, _ = fastdtw.fastdtw(
          scaled_data.loc[:, ['index', left]],
          scaled_data.loc[:, ['index', right]],
          dist=euclidean,
      )
      total_distance += distance
    return total_distance

  @staticmethod
  def _visualize_candidate(
      dataframe,
      distance_data,
      start_date_value,
      end_date_value,
      date_col_name,
      tick_count
      ):
    """Prints the design parameters and shows one tab per candidate option.

    For every row of ``distance_data``, the member columns of each group are
    summed into ``col_1``, ``col_2``, and so on. Each tab shows:
      * the member lists;
      * a table of daily average, share and STL-based std over the design
        period;
      * raw and min-max scaled line charts;
      * a scatter matrix of row-to-row differences.

    Args:
      dataframe: Time-indexed data, one column per candidate unit.
      distance_data: Candidate table as built by the design helpers: a
        ``distance`` column plus integer-labelled group columns holding
        ``str(list_of_column_names)``.
      start_date_value, end_date_value: Inclusive design period (dates).
      date_col_name: Name of the date index, used for the x axis.
      tick_count: Approximate number of x-axis ticks.
    """
    PreProcess._apply_text_style(
        'failure',
        '\nCheck! Experimental Design Parameters.'
    )
    print('* start_date_value: ' + str(start_date_value))
    print('* end_date_value: ' + str(end_date_value))
    print('* columns:')
    strip_brackets = str.maketrans({'[': '', ']': '', "'": ''})
    chunk = []
    for column in dataframe.columns:
      chunk.append(column)
      if len(str(chunk)) >= CausalImpact.max_string_length:
        print(str(chunk).translate(strip_brackets))
        chunk = []
    # Print the last, shorter chunk too. Previously it was dropped, so short
    # column lists (or the tail of long ones) were never shown.
    if chunk:
      print(str(chunk).translate(strip_brackets))
    print('\n')

    # Both ends of the period are included, hence the + 1 day.
    period_days = (end_date_value - start_date_value).days + 1
    period_total = dataframe.query(
        '@start_date_value <= index <= @end_date_value'
    ).sum().sum()

    def line_chart(frame, fold_columns, height):
      # One line per candidate column, coloured with CausalImpact.colors.
      return (
          alt.Chart(frame.reset_index())
          .transform_fold(fold=fold_columns, as_=['pivot', 'kpi'])
          .mark_line()
          .encode(
              x=alt.X(
                  date_col_name + ':T',
                  title=None,
                  axis=alt.Axis(
                      grid=False, format='%Y %b', tickCount=tick_count
                  ),
              ),
              y=alt.Y('kpi:Q'),
              color=alt.Color(
                  'pivot:N',
                  legend=alt.Legend(
                      title=None,
                      orient='none',
                      legendY=-20,
                      direction='horizontal',
                      titleAnchor='start'),
                  scale=alt.Scale(
                      domain=list(fold_columns),
                      range=CausalImpact.colors)),
          )
          .properties(width=600, height=height)
      )

    sub_tab = [ipywidgets.Output() for _ in distance_data.index.tolist()]
    tab_option = ipywidgets.Tab(sub_tab)
    for option_idx in range(len(distance_data.index.tolist())):
      tab_option.set_title(option_idx, "option_{}".format(option_idx + 1))
      with sub_tab[option_idx]:
        candidate_df = pd.DataFrame(index=dataframe.index)
        for col in range(len(distance_data.columns) - 1):
          members = distance_data.at[option_idx, col]
          print('col_' + str(col + 1) + ': ' + members.replace("'", ""))
          # NOTE(review): eval() parses the str(list) written by the design
          # helpers. Entries can be numpy scalar reprs (e.g. np.str_ from
          # _from_share), which ast.literal_eval cannot parse, so eval stays.
          # Never feed it untrusted text.
          candidate_df[col + 1] = list(
              dataframe.loc[:, eval(members)].sum(axis=1)
          )
          print('\n')
        candidate_df = candidate_df.add_prefix('col_')

        candidate_share = pd.DataFrame(
            candidate_df.loc[str(start_date_value):str(end_date_value), :
                             ].sum(),
            columns=['total'])
        candidate_share['daily_average'] = (
            candidate_share['total'] / period_days)
        candidate_share['share'] = candidate_share['total'] / period_total

        try:
          for group in candidate_df.columns:
            stl = STL(candidate_df[group], robust=True).fit()
            candidate_share.loc[group, 'std'] = np.std(stl.seasonal + stl.resid)
          display(
              candidate_share[['daily_average', 'share', 'std']].style.format(
                  {
                      'daily_average': '{:,.0f}',
                      'share': '{:.1%}',
                      'std': '{:,.0f}',
                  }))
        except Exception as e:
          # Best effort: STL can fail on short or irregular series; fall back
          # to the table without the std column.
          print(e)
          display(
              candidate_share[['daily_average', 'share']].style.format({
                  'daily_average': '{:,.0f}',
                  'share': '{:.1%}',
              }))

        chart_line = line_chart(candidate_df, list(candidate_df.columns), 200)

        rules = alt.Chart(
            pd.DataFrame(
                {
                    'Date': [str(start_date_value), str(end_date_value)],
                    'color': ['red', 'orange']
                })
        ).mark_rule(strokeDash=[5, 5]).encode(
            x='Date:T',
            color=alt.Color('color:N', scale=None))

        df_scaled = candidate_df.copy()
        df_scaled[:] = MinMaxScaler().fit_transform(candidate_df)
        chart_line_scaled = line_chart(
            df_scaled, list(candidate_df.columns), 80)

        df_diff = pd.DataFrame(
            np.diff(candidate_df, axis=0),
            columns=candidate_df.columns.values,
        )
        scatter = (
            alt.Chart(df_diff.reset_index())
            .mark_circle()
            .encode(
                alt.X(alt.repeat('column'), type='quantitative'),
                alt.Y(alt.repeat('row'), type='quantitative'),
            )
            .properties(width=80, height=80)
            .repeat(
                row=df_diff.columns.values,
                column=df_diff.columns.values,
            )
        )
        display(
            alt.vconcat(chart_line + rules, chart_line_scaled) | scatter)
    display(tab_option)

  def _generate_choice(self):
    """Displays widgets to pick a candidate option and its test/control cols.

    Choices are derived from ``self.distance_data``:
      * one ``option_N`` per candidate row;
      * one ``col_K`` per group column (every column except ``distance``).

    Previously 3 options and 6 columns were hard-coded. That offered
    selections that do not exist (e.g. option_2/3 for a given assignment),
    which later raised a KeyError, and it hid groups beyond col_6.
    """
    num_options = len(self.distance_data.index)
    num_groups = len(self.distance_data.columns) - 1  # minus 'distance'
    option_labels = ['option_{}'.format(i + 1) for i in range(num_options)]
    col_labels = ['col_{}'.format(k + 1) for k in range(num_groups)]

    self.your_choice = ipywidgets.Dropdown(
        options=option_labels,
        description='your choice:',
    )
    self.target_col_to_simulate = ipywidgets.SelectMultiple(
        options=col_labels,
        description='target col:',
        # Defaults must be valid options, so drop them if they do not exist.
        value=[c for c in ['col_1'] if c in col_labels],
    )
    self.covariate_col_to_simulate = ipywidgets.SelectMultiple(
        options=col_labels,
        description='covariate col:',
        value=[c for c in ['col_2'] if c in col_labels],
        style={'description_width': 'initial'},
    )
    display(
        PreProcess._apply_text_style(
            18,
            '⑷ Please select option, test column & control column(s).'),
        ipywidgets.HBox([
            self.your_choice,
            self.target_col_to_simulate,
            self.covariate_col_to_simulate,
        ]),
    )

  def generate_simulation(self):
    """Runs the mock-lift simulation for the chosen option and shows results.

    It builds the ``test``/covariate series from the widget choices and fits
    one CausalImpact model per (duration, lift) pair taken from
    ``CausalImpact.treat_duration`` and ``CausalImpact.treat_impact``. It then
    displays the result tables and plots. Inputs and results are kept in
    ``self.test_data``, ``self.simulation_params`` and ``self.ci_objs``.
    """
    chosen_data = self._extract_data_from_choice(
        self.your_choice.value,
        self.target_col_to_simulate.value,
        self.covariate_col_to_simulate.value,
        self.formatted_data,
        self.distance_data,
    )
    self.test_data = chosen_data

    params, fitted = self._execute_simulation(
        chosen_data,
        self.date_col_name,
        self.start_date_value,
        self.end_date_value,
        self.num_of_seasons.value,
        self.confidence_interval.value,
        CausalImpact.treat_duration,
        CausalImpact.treat_impact,
    )
    self.simulation_params = params
    self.ci_objs = fitted

    self._display_simulation_result(
        self.simulation_params,
        self.ci_objs,
        self.estimate_icpa.value,
    )
    self._plot_simulation_result(
        self.simulation_params,
        self.ci_objs,
        self.date_col_name,
        self.tick_count,
        self.purpose_selection.selected_index,
        self.confidence_interval.value,
    )

  @staticmethod
  def _extract_data_from_choice(
      your_choice,
      target_col_to_simulate,
      covariate_col_to_simulate,
      dataframe,
      distance
      ):
    """Builds the simulation input from a chosen option and column labels.

    Args:
      your_choice: 'option_N' label; selects row N-1 of ``distance``.
      target_col_to_simulate: 'col_K' labels. The member columns of all of
        them are summed into a single ``test`` series.
      covariate_col_to_simulate: 'col_K' labels. Each is summed into its own
        ``col_K`` covariate series.
      dataframe: Time-indexed data, one column per candidate unit.
      distance: Candidate table with integer-labelled group columns holding
        ``str(list_of_column_names)``.

    Returns:
      DataFrame indexed like ``dataframe``: ``test`` first, then one column
      per chosen covariate group.
    """
    row = int(your_choice.replace('option_', '')) - 1

    def to_position(label):
      # 'col_K' -> group column K-1 of ``distance``.
      return int(label.replace('col_', '')) - 1

    target_positions = [to_position(t) for t in list(target_col_to_simulate)]
    covariate_positions = [
        to_position(t) for t in list(covariate_col_to_simulate)
    ]
    test_data = pd.DataFrame(index = dataframe.index)

    # NOTE(review): eval() parses the str(list) produced by the design
    # helpers; do not feed it untrusted text.
    test_column = []
    for pos in target_positions:
      test_column.extend(eval(distance.at[row, pos]))
    test_data['test'] = dataframe.loc[:, test_column].sum(axis=1)

    for pos in covariate_positions:
      test_data['col_' + str(pos + 1)] = dataframe.loc[
          :, eval(distance.at[row, pos])
      ].sum(axis=1)

    print('* test: {}\n'.format(str(test_column).replace("'", "")))
    print('* covariate')
    for name, pos in zip(test_data.columns[1:], covariate_positions):
      line = '> {}: {}'.format(name, str(eval(distance.at[row, pos])))
      print(line.replace("'", ""))
    return test_data

  @staticmethod
  def _execute_simulation(
      dataframe,
      date_col_name,
      start_date_value,
      end_date_value,
      num_of_seasons,
      confidence_interval,
      treat_duration,
      treat_impact,
    ):
    """Fits CausalImpact on synthetic lifts applied to the ``test`` column.

    For every duration ``d`` in ``treat_duration`` and lift factor ``m`` in
    ``treat_impact``:
      * the ``test`` values in the last ``d`` days up to ``end_date_value``
        are multiplied by ``m``;
      * the pre-period is [start_date_value, end_date_value - d];
      * the post-period is [end_date_value - d + 1 day, end_date_value].

    Args:
      dataframe: Must contain a ``test`` column and a date index.
      date_col_name: Name of the date index.
      start_date_value, end_date_value: Simulation window (dates).
      num_of_seasons: Seasonal period passed to the model (1 = none).
      confidence_interval: Confidence level in percent.
      treat_duration: Post-period lengths in days.
      treat_impact: Multiplicative lift factors (1 = no lift).

    Returns:
      (simulation_params, ci_objs): parallel lists. Each params entry is
      [pre_start, pre_end, post_start, post_end, impact, duration].
    """
    ci_objs = []
    simulation_params = []

    for duration in tqdm(treat_duration):
      pre_end_date = end_date_value - datetime.timedelta(days=duration)
      post_start_date = pre_end_date + datetime.timedelta(days=1)
      post_window = slice(
          np.datetime64(post_start_date), np.datetime64(end_date_value))
      for impact in tqdm(treat_impact, leave=False):
        # Start every scenario from the unmodified data. A single shared
        # copy only worked while durations were increasing; any other order
        # left an earlier lift inside a later scenario's pre-period.
        adjusted_data = dataframe.copy()
        adjusted_data.loc[post_window, 'test'] = (
            dataframe.loc[post_window, 'test'] * impact
        )

        ci_obj = CausalImpact.create_causalimpact_object(
            adjusted_data,
            date_col_name,
            start_date_value,
            pre_end_date,
            post_start_date,
            end_date_value,
            num_of_seasons,
            confidence_interval,
        )
        simulation_params.append([
            start_date_value,
            pre_end_date,
            post_start_date,
            end_date_value,
            impact,
            duration,
        ])
        ci_objs.append(ci_obj)
    return simulation_params, ci_objs

  @staticmethod
  def _display_simulation_result(simulation_params, ci_objs, estimate_icpa):
    """Shows A/A (no-lift) and mock-lift simulation results as tables.

    Args:
      simulation_params: Per-fit [pre_start, pre_end, post_start, post_end,
        impact, duration], parallel to ``ci_objs``.
      ci_objs: Fitted CausalImpact objects; ``series`` and ``summary`` are
        read.
      estimate_icpa: Multiplier applied to the cumulative absolute effect to
        give ``Required_budget`` (presumably an incremental CPA; confirm).
    """
    def period_mape(series_df, start, end):
      # MAPE of the posterior mean against the observations on [start, end].
      window = slice(str(start), str(end))
      return mean_absolute_percentage_error(
          series_df.loc[:, 'observed'][window],
          series_df.loc[:, 'posterior_mean'][window],
      )

    columns = [
        'test_period',
        'mock_lift_rate',
        'predicted_lift_rate',
        'Days_simulated',
        'Pre_Period_MAPE',
        'Post_Period_MAPE',
        'Total_effect',
        'Average_effect',
        'Required_budget',
        'p_value',
    ]
    rows = []
    for params, ci_obj in zip(simulation_params, ci_objs):
      pre_start, pre_end, post_start, post_end, impact, duration = params
      impact_df = ci_obj.series
      summary = ci_obj.summary
      rows.append({
          'test_period': '(' + str(duration) + 'd) ' + str(post_start)
                         + '~' + str(post_end),
          'mock_lift_rate': impact - 1,
          'predicted_lift_rate': summary.loc['average', 'rel_effect'],
          'Days_simulated': duration,
          'Pre_Period_MAPE': period_mape(impact_df, pre_start, pre_end),
          'Post_Period_MAPE': period_mape(impact_df, post_start, post_end),
          'Total_effect': summary.loc['cumulative', 'abs_effect'],
          'Average_effect': summary.loc['average', 'abs_effect'],
          'Required_budget':
              summary.loc['cumulative', 'abs_effect'] * estimate_icpa,
          'p_value': summary.loc['average', 'p_value'],
      })
    # Build the table once. Concatenating row by row onto an empty frame is
    # quadratic and warned about by recent pandas, and the old empty frame
    # declared two columns ('mock_lift', 'predicted_lift') that were never
    # filled.
    simulation_df = pd.DataFrame(rows, columns=columns)

    display(PreProcess._apply_text_style(
        18,
        'A/A Test: Check the error without intervention'))
    print('> If p_value < 0.05, please suspect "poor model accuracy"(See Pre_Period_MAPE) or "data drift"(See Time Series Chart).\n')
    display(
        simulation_df.query('mock_lift_rate == 0')[
            ['test_period', 'Pre_Period_MAPE', 'Post_Period_MAPE', 'p_value']
        ].style.format({
            'Pre_Period_MAPE': '{:.2%}',
            'Post_Period_MAPE': '{:.2%}',
            'p_value': '{:,.2f}',
        }).hide()
    )
    print('\n')
    display(PreProcess._apply_text_style(
        18,
        'Simulation with increments as a mock experiment'))
    for i in simulation_df.Days_simulated.unique():
      print('\n During the last {} days'.format(i))
      display(
          simulation_df.query('mock_lift_rate != 0 & Days_simulated == @i')[
              [
                  'mock_lift_rate',
                  'predicted_lift_rate',
                  'Pre_Period_MAPE',
                  'Total_effect',
                  'Average_effect',
                  'Required_budget',
                  'p_value',
              ]
          ].style.format({
              'mock_lift_rate': '{:+.0%}',
              'predicted_lift_rate': '{:+.1%}',
              'Pre_Period_MAPE': '{:.2%}',
              'Total_effect': '{:,.2f}',
              'Average_effect': '{:,.2f}',
              'Required_budget': '{:,.0f}',
              'p_value': '{:,.2f}',
          }).hide()
      )

  @staticmethod
  def _plot_simulation_result(
      simulation_params,
      ci_objs,
      date_col_name,
      tick_count,
      purpose_selection,
      confidence_interval,
      ):
    """Displays one tab per simulated scenario with its CausalImpact chart.

    Each tab prints the scenario's pre/post periods and then calls
    CausalImpact.plot_causalimpact for the matching result object.

    Args:
      simulation_params: Sequence of per-scenario parameter sequences. The
        indices used here are: [0]/[1] pre-period start/end, [2]/[3]
        post-period start/end, [4] mock lift multiplier (the displayed lift
        is this value minus 1), [5] number of days simulated.
      ci_objs: CausalImpact result objects, aligned index-for-index with
        simulation_params.
      date_col_name: Forwarded to CausalImpact.plot_causalimpact.
      tick_count: Forwarded to CausalImpact.plot_causalimpact.
      purpose_selection: Forwarded to CausalImpact.plot_causalimpact.
      confidence_interval: Forwarded to CausalImpact.plot_causalimpact.
    """
    # Tab titles must be plain strings: ipywidgets 8 types `Tab.titles` as a
    # tuple of unicode strings, so one-element lists are rejected there.
    # '{:+.0%}' emits the sign itself (same format as the mock_lift_rate
    # column), so a negative lift reads '-10%' rather than '+-10%'.
    tab_titles = [
        '{}d:{:+.0%}'.format(params[5], params[4] - 1)
        for params in simulation_params
    ]
    tab_outputs = [ipywidgets.Output() for _ in tab_titles]
    tab_simulation = ipywidgets.Tab(tab_outputs)
    for idx, (title, params) in enumerate(zip(tab_titles, simulation_params)):
      tab_simulation.set_title(idx, title)
      # Everything printed/plotted inside this context lands in tab `idx`.
      with tab_outputs[idx]:
        print(
            'Pre Period:{} ~ {}\nPost Period:{} ~ {}'.format(
                params[0],
                params[1],
                params[2],
                params[3],
            )
        )
        CausalImpact.plot_causalimpact(
            ci_objs[idx],
            params[0],
            params[1],
            params[2],
            params[3],
            confidence_interval,
            date_col_name,
            tick_count,
            purpose_selection,
        )
    display(tab_simulation)

# Create the first analysis case and render its input form. If an earlier
# Step.2 run left `dict_params` in globals (e.g. when this cell is re-run),
# restore those values into the freshly built form.
case_1 = CausalImpact()
case_1.generate_ui()
if 'dict_params' in globals():
  # NOTE(review): the Case_2/Case_3 cells use PreProcess.set_params for the
  # same purpose; confirm both names resolve to the same implementation.
  CausalImpact.set_params(case_1, dict_params)
# Naive datetime.now() is runtime-local time; pin it to UTC so the "GMT"
# label is accurate on any runtime (printed format is unchanged).
print('\nExecution datetime(GMT):{}'.format(
    datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)))

In [None]:
# @title Step.2
%%time
case_1.load_data()
case_1.format_data()
dict_params = PreProcess.saving_params(case_1)

if case_1.purpose_selection.selected_index == 0:
  case_1.run_causalImpact()
else:
  case_1.run_experimental_design()

print('\nExecution datetime(GMT):{}'.format(datetime.datetime.now()))

In [None]:
# @title Step.3
%%time
if case_1.purpose_selection.selected_index == 0:
  case_1.display_causalimpact_result()
else:
  case_1.generate_simulation()

print('\nExecution datetime(GMT):{}'.format(datetime.datetime.now()))

# (Optional) Case_2

In [None]:
# @title Case_2 Step.1
overwrite_pramas = True #@param {type:"boolean"}
# Build an independent second case. When the checkbox above is ticked,
# pre-fill its form with the parameters saved by Case_1 Step.2.
case_2 = CausalImpact()
case_2.generate_ui()
if overwrite_pramas:
  # Guard like Case_1 Step.1 does: `dict_params` only exists after Case_1
  # Step.2 has run, so avoid a bare NameError.
  if 'dict_params' in globals():
    PreProcess.set_params(case_2, dict_params)
  else:
    print('> dict_params not found: run Case_1 Step.2 first to copy its parameters.')

In [None]:
# @title Case_2 Step.2
%%time
case_2.load_data()
case_2.format_data()

if case_2.purpose_selection.selected_index == 0:
  case_2.run_causalImpact()
else:
  case_2.run_experimental_design()

print('\nExecution datetime(GMT):{}'.format(datetime.datetime.now()))

In [None]:
# @title Case_2 Step.3
%%time
if case_2.purpose_selection.selected_index == 0:
  case_2.display_causalimpact_result()
else:
  case_2.generate_simulation()

print('\nExecution datetime(GMT):{}'.format(datetime.datetime.now()))

# (Optional) Case_3

In [None]:
# @title Case_3 Step.1
overwrite_pramas = False #@param {type:"boolean"}
# Build an independent third case. When the checkbox above is ticked,
# pre-fill its form with the parameters saved by Case_1 Step.2.
case_3 = CausalImpact()
case_3.generate_ui()
if overwrite_pramas:
  # Guard like Case_1 Step.1 does: `dict_params` only exists after Case_1
  # Step.2 has run, so avoid a bare NameError.
  if 'dict_params' in globals():
    PreProcess.set_params(case_3, dict_params)
  else:
    print('> dict_params not found: run Case_1 Step.2 first to copy its parameters.')

In [None]:
# @title Case_3 Step.2
%%time
case_3.load_data()
case_3.format_data()

if case_3.purpose_selection.selected_index == 0:
  case_3.run_causalImpact()
else:
  case_3.run_experimental_design()

print('\nExecution datetime(GMT):{}'.format(datetime.datetime.now()))

In [None]:
# @title Case_3 Step.3
%%time
if case_3.purpose_selection.selected_index == 0:
  case_3.display_causalimpact_result()
else:
  case_3.generate_simulation()

print('\nExecution datetime(GMT):{}'.format(datetime.datetime.now()))