<a href="https://colab.research.google.com/github/google/business_intelligence_group/blob/development/solutions/causal-impact/CausalImpact_with_Experimental_Design.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **CausalImpact with Experimental Design**

This Colab file contains *Experimental Design* and *CausalImpact Analysis*.

See [README.md](https://github.com/google/business_intelligence_group/tree/main/solutions/causal-impact) for details.

---

Copyright 2024 Google LLC

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

In [3]:
# @title Step.1 (~ 2min)
%%time

import sys
if 'fastdtw' not in sys.modules:
  !pip install 'fastdtw' --q
if 'tslearn' not in sys.modules:
  !pip install 'tslearn' --q
if 'tfp-causalimpact' not in sys.modules:
  !pip install 'tfp-causalimpact' --q

# Data Load
from google.colab import auth, files, widgets
from google.auth import default
from google.cloud import bigquery
import io
import os
import gspread
from oauth2client.client import GoogleCredentials

# Calculate
import altair as alt
import itertools
import random
import numpy as np
import pandas as pd
import fastdtw

from tslearn.clustering import TimeSeriesKMeans
from decimal import Decimal, ROUND_HALF_UP
from scipy.spatial.distance import euclidean
from sklearn.metrics import mean_absolute_percentage_error
from sklearn.preprocessing import MinMaxScaler
from statsmodels.tsa.seasonal import STL

# UI/UX
import datetime
from dateutil.relativedelta import relativedelta
import ipywidgets
from IPython.display import display, Markdown, HTML, Javascript
from tqdm.auto import tqdm
import warnings
warnings.simplefilter('ignore')

# causalimpact
import causalimpact
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

class UIUtils:
  """Utility class for UI styling and helpers."""

  @staticmethod
  def _heading_html(font_size, text):
    """Builds the underlined heading markup used for numbered UI sections.

    Args:
        font_size (int): Font size in pixels.
        text (str): Heading content; may itself contain HTML (e.g. <b>).

    Returns:
        str: A single, properly closed <span> element.
    """
    return (
        "<span style='font-size:" + str(font_size) + "px; background: "
        "linear-gradient(transparent 90%, #4285F4 0%);'>"
        + text
        + '</span>'  # must close the <span> opened above
    )

  @staticmethod
  def apply_text_style(type, text):
    """Applies a specific style to the text based on the type.

    Args:
        type (str or int): The type of style ('success', 'failure') or font size as an integer.
        text (str): The text content to display.

    Returns:
        None or ipywidgets.HTML: Prints ANSI-coloured text (returning None) for
        'success'/'failure'; returns an HTML widget for an integer font size;
        returns None for any other value.
    """
    if type == 'success':
      # 24-bit ANSI foreground colour (green), then reset.
      print("\033[38;2;15;157;88m " + text + "\033[0m")
      return None

    if type == 'failure':
      # 24-bit ANSI foreground colour (red), then reset.
      print("\033[38;2;219;68;55m " + text + "\033[0m")
      return None

    if isinstance(type, int):
      return ipywidgets.HTML(UIUtils._heading_html(type, text))

    return None

class InteractiveUI:
  """Manages the interactive user interface widgets and their layout.

  The individual input widgets are created in ``__init__``. The tab containers
  (``soure_selection``, ``data_type_selection``, ``design_type``,
  ``date_selection`` and ``purpose_selection``) are only created by
  ``generate_ui``, so ``generate_ui`` must be called before ``get_params`` or
  ``set_params``.
  """

  def __init__(self):
    """Initializes the InteractiveUI by defining all widgets."""
    self._define_widgets()

  def _define_widgets(self):
    """Defines all widgets used in the UI."""
    self._define_data_source_widgets()
    self._define_data_format_widgets()
    self._define_date_widgets()
    self._define_experimental_design_widgets()
    self._define_simulation_widgets()
    self.your_choice = ipywidgets.Dropdown(
        options=['option_1', 'option_2', 'option_3'],
        description='your choice:',
    )
    self.target_col_to_simulate = ipywidgets.SelectMultiple(
        options=['col_1', 'col_2', 'col_3', 'col_4', 'col_5', 'col_6'],
        description='target col:',
        value=['col_1',],
    )
    self.covariate_col_to_simulate = ipywidgets.SelectMultiple(
        options=['col_1', 'col_2', 'col_3', 'col_4', 'col_5', 'col_6'],
        description='covariate col:',
        value=['col_2',],
        style={'description_width': 'initial'},
    )

  def _define_data_source_widgets(self):
    """Defines widgets for selecting data sources (Sheet, CSV, BigQuery)."""
    self.sheet_url = ipywidgets.Text(
        placeholder='Please enter google spreadsheet url',
        value='https://docs.google.com/spreadsheets/d/1dISrbX1mZHgzpsIct2QXFOWWRRJiCxDSmSzjuZz64Tw/edit#gid=0',
        description='spreadsheet url:',
        style={'description_width': 'initial'},
        layout=ipywidgets.Layout(width='800px'),
    )
    self.sheet_name = ipywidgets.Text(
        placeholder='Please enter sheet name',
        value='analysis_data',
        description='sheet name:',
    )
    self.csv_name = ipywidgets.Text(
        placeholder='Please enter csv name',
        description='csv name:',
        layout=ipywidgets.Layout(width='500px'),
    )
    self.bq_project_id = ipywidgets.Text(
        placeholder='Please enter project id',
        description='project id:',
        layout=ipywidgets.Layout(width='500px'),
    )
    self.bq_table_name = ipywidgets.Text(
        placeholder='Please enter table name',
        description='table name:',
        layout=ipywidgets.Layout(width='500px'),
    )

  def _define_data_format_widgets(self):
    """Defines widgets for specifying data format columns."""
    self.date_col = ipywidgets.Text(
        placeholder='Please enter date column name',
        value='Date',
        description='date column:',
    )
    self.pivot_col = ipywidgets.Text(
        placeholder='Please enter pivot column name',
        value='Geo',
        description='pivot column:',
    )
    self.kpi_col = ipywidgets.Text(
        placeholder='Please enter kpi column name',
        value='KPI',
        description='kpi column:',
    )

  def _define_experimental_design_widgets(self):
    """Defines widgets for experimental design parameters."""
    self.exclude_cols = ipywidgets.Text(
        placeholder=('Enter comma-separated columns if any columns are not used in the design.'),
        description='exclude cols:',
        layout=ipywidgets.Layout(width='1000px'),
    )
    self.num_of_split = ipywidgets.Dropdown(
        options=[2, 3, 4, 5],
        value=2,
        description='split#:',
        disabled=False,
    )
    self.target_columns = ipywidgets.Text(
        placeholder='Please enter comma-separated entries',
        value='Tokyo, Kanagawa',
        description='target_cols:',
        layout=ipywidgets.Layout(width='500px'),
    )
    self.num_of_pick_range = ipywidgets.IntRangeSlider(
        value=[5, 10],
        min=1,
        max=50,
        step=1,
        description='pick range:',
        orientation='horizontal',
        readout=True,
        readout_format='d',
    )
    self.num_of_covariate = ipywidgets.Dropdown(
        options=[1, 2, 3, 4, 5],
        value=1,
        description='covariate#:',
        layout=ipywidgets.Layout(width='192px'),
    )
    self.target_share = ipywidgets.FloatSlider(
        value=0.3,
        min=0.05,
        max=0.5,
        step=0.05,
        description='target share#:',
        orientation='horizontal',
        readout=True,
        readout_format='.1%',
    )
    self.control_columns = ipywidgets.Text(
        placeholder='Please enter comma-separated entries',
        value='Aomori, Akita',
        description='control_cols:',
        layout=ipywidgets.Layout(width='500px'),
    )

  def _define_simulation_widgets(self):
    """Defines widgets for simulation parameters."""
    self.num_of_seasons = ipywidgets.IntText(
        value=1,
        description='num_of_seasons:',
        disabled=False,
        style={'description_width': 'initial'},
    )
    self.estimate_icpa = ipywidgets.IntText(
        value=1000,
        description='Estimated iCPA:',
        style={'description_width': 'initial'},
    )
    self.credible_interval = ipywidgets.RadioButtons(
        options=[70, 80, 90, 95],
        value=90,
        description='Credible interval %:',
        style={'description_width': 'initial'},
    )

  def _define_date_widgets(self):
    """Defines widgets for date selection."""
    # Read the clock once so the defaults are mutually consistent (e.g. Pre End
    # is always the day before Post Start) even if this runs across midnight.
    today = datetime.date.today()
    self.pre_period_start = ipywidgets.DatePicker(
        description='Pre Start:',
        value=today - relativedelta(days=122),
    )
    self.pre_period_end = ipywidgets.DatePicker(
        description='Pre End:',
        value=today - relativedelta(days=32),
    )
    self.post_period_start = ipywidgets.DatePicker(
        description='Post Start:',
        value=today - relativedelta(days=31),
    )
    self.post_period_end = ipywidgets.DatePicker(
        description='Post End:',
        value=today,
    )
    self.start_date = ipywidgets.DatePicker(
        description='Start Date:',
        value=today - relativedelta(days=122),
    )
    self.end_date = ipywidgets.DatePicker(
        description='End Date:',
        value=today - relativedelta(days=32),
    )
    self.depend_data = ipywidgets.ToggleButton(
        value=False,
        description='Click >> Use the beginning and end of data',
        disabled=False,
        button_style='info',
        tooltip='Description',
        layout=ipywidgets.Layout(width='300px'),
    )

  def generate_ui(self):
    """Constructs and displays the user interface tabs."""
    self._build_source_selection_tab()
    self._build_data_type_selection_tab()
    self._build_design_type_tab()
    self._build_purpose_selection_tab()

    display(
        UIUtils.apply_text_style(18, '⑴ Please select a data source.'),
        self.soure_selection,
        Markdown('<br>'),
        UIUtils.apply_text_style(
            18, '⑵ Please select wide or narrow data format.'
        ),
        self.data_type_selection,
        Markdown('<br>'),
        UIUtils.apply_text_style(
            18, '⑶ Please select the purpose and set conditions.'
        ),
        self.purpose_selection,
    )

  def _build_source_selection_tab(self):
    """Builds the tab for selecting data sources."""
    self.soure_selection = ipywidgets.Tab()
    self.soure_selection.children = [
        ipywidgets.VBox([self.sheet_url, self.sheet_name]),
        ipywidgets.VBox([self.csv_name]),
        ipywidgets.VBox([self.bq_project_id, self.bq_table_name]),
    ]
    self.soure_selection.set_title(0, 'Google_Spreadsheet')
    self.soure_selection.set_title(1, 'CSV_file')
    self.soure_selection.set_title(2, 'Big_Query')

  def _build_data_type_selection_tab(self):
    """Builds the tab for selecting data format (wide vs narrow)."""
    self.data_type_selection = ipywidgets.Tab()
    self.data_type_selection.children = [
        ipywidgets.VBox([
            ipywidgets.Label(
                'Wide, or unstacked data is presented with each different'
                ' data variable in a separate column.'
            ),
            self.date_col,
        ]),
        ipywidgets.VBox([
            ipywidgets.Label(
                'Narrow, stacked, or long data is presented with one column '
                'containing all the values and another column listing the '
                'context of the value'
            ),
            ipywidgets.HBox([self.date_col, self.pivot_col, self.kpi_col]),
        ]),
    ]
    self.data_type_selection.set_title(0, 'Wide_Format')
    self.data_type_selection.set_title(1, 'Narrow_Format')

  def _build_design_type_tab(self):
    """Builds the tab for selecting experimental design type."""
    # The same widget instance (e.g. exclude_cols) is shown in several tabs;
    # ipywidgets keeps all views of one widget in sync.
    self.design_type = ipywidgets.Tab(
        children=[
            ipywidgets.VBox([
                ipywidgets.HTML(
                    'divide_equally divides the time series data into N'
                    ' groups(split#) with similar movements.'
                ),
                self.num_of_split,
                self.exclude_cols,
            ]),
            ipywidgets.VBox([
                ipywidgets.HTML(
                    'similarity_selection extracts N groups(covariate#) that '
                    'move similarly to particular columns(target_cols).'
                ),
                ipywidgets.HBox([
                    self.target_columns,
                    self.num_of_covariate,
                    self.num_of_pick_range,
                ]),
                self.exclude_cols,
            ]),
            ipywidgets.VBox([
                ipywidgets.HTML(
                    'target share extracts targeted time series data from'
                    ' the proportion of interventions.'
                ),
                self.target_share,
                self.exclude_cols,
            ]),
            ipywidgets.VBox([
                ipywidgets.HTML(
                    'To improve reproducibility, it is important to create an'
                    ' accurate counterfactual model rather than a balanced'
                    ' assignment.'
                ),
                self.target_columns,
                self.control_columns,
            ]),
        ]
    )
    self.design_type.set_title(0, 'A: divide_equally')
    self.design_type.set_title(1, 'B: similarity_selection')
    self.design_type.set_title(2, 'C: target_share')
    self.design_type.set_title(3, 'D: pre-allocated')

  def _build_purpose_selection_tab(self):
    """Builds the tab for selecting the purpose (CausalImpact or Experimental Design)."""
    self.purpose_selection = ipywidgets.Tab()
    self.date_selection = ipywidgets.Tab()
    self.date_selection.children = [
        ipywidgets.VBox(
            [
                ipywidgets.HTML('The <b>minimum</b> date of the data is '
                'selected as the start date.'),
                ipywidgets.HTML('The <b>maximum</b> date in the data is '
                'selected as the end date.'),
            ]),
        ipywidgets.VBox(
            [
                self.start_date,
                self.end_date,
            ]
        )]
    self.date_selection.set_title(0, 'automatic selection')
    self.date_selection.set_title(1, 'manual input')

    self.purpose_selection.children = [
        # Causalimpact
        ipywidgets.VBox([
            UIUtils.apply_text_style(
                15, '⑶ - a: Enter the Pre and Post the intervention.'
            ),
            self.pre_period_start,
            self.pre_period_end,
            self.post_period_start,
            self.post_period_end,
            UIUtils.apply_text_style(
                15,
                '⑶ - b: Enter the number of periodicities in the'
                ' time series data.(default=1)',
            ),
            ipywidgets.VBox([self.num_of_seasons, self.credible_interval]),
        ]),
        # Experimental_Design
        ipywidgets.VBox([
            UIUtils.apply_text_style(
                15,
                '⑶ - a: Please select date for experimental design'
            ),
            self.date_selection,
            UIUtils.apply_text_style(
                15,
                '⑶ - b: Select the <b>experimental design method</b> and'
                ' enter the necessary items.',
            ),
            self.design_type,
            UIUtils.apply_text_style(
                15,
                '⑶ - c: (Optional) Enter <b>Estimated incremental CPA</b>(Cost'
                ' of intervention ÷ Lift from intervention without bias) & the '
                'number of periodicities in the time series data.',
            ),
            ipywidgets.VBox([
                self.estimate_icpa,
                self.num_of_seasons,
                self.credible_interval,
            ]),
        ]),
    ]
    self.purpose_selection.set_title(0, 'Causalimpact')
    self.purpose_selection.set_title(1, 'Experimental_Design')

  def display_simulation_choice(self):
    """Displays the simulation choice and simulated column selectors."""
    display(
        UIUtils.apply_text_style(
            18,
            '⑷ Please select option, test column & control column(s).'),
        ipywidgets.HBox([
            self.your_choice,
            self.target_col_to_simulate,
            self.covariate_col_to_simulate,
        ]),
    )

  def get_params(self):
    """Retrieves current values from all widgets.

    Requires ``generate_ui`` to have been called (it creates the tab widgets).

    Returns:
        dict: A dictionary containing all widget values.
    """
    params_dict = {
        'soure_selection': self.soure_selection.selected_index,
        'sheet_url': self.sheet_url.value,
        'sheet_name': self.sheet_name.value,
        'csv_name': self.csv_name.value,
        'bq_project_id': self.bq_project_id.value,
        'bq_table_name': self.bq_table_name.value,
        'data_type_selection': self.data_type_selection.selected_index,
        'date_col': self.date_col.value,
        'pivot_col': self.pivot_col.value,
        'kpi_col': self.kpi_col.value,
        'purpose_selection': self.purpose_selection.selected_index,
        'pre_period_start': self.pre_period_start.value,
        'pre_period_end': self.pre_period_end.value,
        'post_period_start': self.post_period_start.value,
        'post_period_end': self.post_period_end.value,
        'start_date': self.start_date.value,
        'end_date': self.end_date.value,
        'depend_data': self.depend_data.value,
        'design_type': self.design_type.selected_index,
        'num_of_split': self.num_of_split.value,
        'target_columns': self.target_columns.value,
        'control_columns': self.control_columns.value,
        'num_of_pick_range': self.num_of_pick_range.value,
        'num_of_covariate': self.num_of_covariate.value,
        'target_share': self.target_share.value,
        'exclude_cols': self.exclude_cols.value,
        'num_of_seasons': self.num_of_seasons.value,
        'estimate_icpa': self.estimate_icpa.value,
        'credible_interval': self.credible_interval.value,
        'date_selection_index': self.date_selection.selected_index,
    }
    return params_dict

  def set_params(self, dict_params):
    """Sets widget values based on a parameter dictionary.

    This is the inverse of ``get_params``. Requires ``generate_ui`` to have
    been called (it creates the tab widgets).

    Args:
        dict_params (dict): A dictionary of parameters to set. The
            'date_selection_index' key is optional so that dictionaries saved
            without it are still accepted.
    """
    # section for data source
    self.soure_selection.selected_index = dict_params['soure_selection']
    self.sheet_url.value = dict_params['sheet_url']
    self.sheet_name.value = dict_params['sheet_name']
    self.csv_name.value = dict_params['csv_name']
    self.bq_project_id.value = dict_params['bq_project_id']
    self.bq_table_name.value = dict_params['bq_table_name']

    # section for data format(narrow or wide)
    self.data_type_selection.selected_index = dict_params['data_type_selection']
    self.date_col.value = dict_params['date_col']
    self.pivot_col.value = dict_params['pivot_col']
    self.kpi_col.value = dict_params['kpi_col']

    # section for purpose(CausalImpact or Experimental Design)
    self.purpose_selection.selected_index = dict_params['purpose_selection']
    self.pre_period_start.value = dict_params['pre_period_start']
    self.pre_period_end.value = dict_params['pre_period_end']
    self.post_period_start.value = dict_params['post_period_start']
    self.post_period_end.value = dict_params['post_period_end']
    self.start_date.value = dict_params['start_date']
    self.end_date.value = dict_params['end_date']
    # get_params emits this key; restore it so the round-trip is lossless.
    if 'date_selection_index' in dict_params:
      self.date_selection.selected_index = dict_params['date_selection_index']
    self.depend_data.value = dict_params['depend_data']

    self.design_type.selected_index = dict_params['design_type']
    self.num_of_split.value = dict_params['num_of_split']
    self.target_columns.value = dict_params['target_columns']
    self.control_columns.value = dict_params['control_columns']
    self.num_of_pick_range.value = dict_params['num_of_pick_range']
    self.num_of_covariate.value = dict_params['num_of_covariate']
    self.target_share.value = dict_params['target_share']
    self.exclude_cols.value = dict_params['exclude_cols']

    self.num_of_seasons.value = dict_params['num_of_seasons']
    self.estimate_icpa.value = dict_params['estimate_icpa']
    self.credible_interval.value = dict_params['credible_interval']

class DataLoader:
  """Handles data loading from various sources."""

  def load_data(self, params):
    """Loads data based on the selected source in params.

    Args:
        params (dict): Parameters containing source selection
            ('soure_selection': 0 = Google Sheet, 1 = CSV upload,
            2 = BigQuery) and the source-specific fields.

    Returns:
        pd.DataFrame: The loaded data as a DataFrame.

    Raises:
        Exception: If an invalid data source is selected.
    """
    source_index = params['soure_selection']
    if source_index == 0:
      return self._load_data_from_sheet(params['sheet_url'], params['sheet_name'])
    if source_index == 1:
      return self._load_data_from_csv(params['csv_name'])
    if source_index == 2:
      return self._load_data_from_bigquery(params['bq_project_id'], params['bq_table_name'])
    raise Exception('Please select a data source.')

  @staticmethod
  def _to_numeric_if_possible(series):
    """Converts a column to numbers, returning it unchanged if that fails.

    Equivalent to the deprecated ``pd.to_numeric(series, errors='ignore')``.
    """
    try:
      return pd.to_numeric(series)
    except (ValueError, TypeError):
      return series

  @staticmethod
  def _clean_dataframe(dataframe):
    """Applies the clean-up shared by every data source.

    Removes all commas from cell values (e.g. thousands separators such as
    '1,000'), removes spaces from column names, and converts every column that
    is fully numeric to a numeric dtype. Returns a new DataFrame; the input is
    not modified.
    """
    cleaned = dataframe.replace(',', '', regex=True)
    cleaned = cleaned.rename(columns=lambda x: x.replace(" ", ""))
    return cleaned.apply(DataLoader._to_numeric_if_possible)

  @staticmethod
  def _load_data_from_sheet(spreadsheet_url, sheet_name):
    """Loads data from a Google Sheet.

    The first row of the worksheet is used as the header.

    Args:
        spreadsheet_url (str): URL of the Google Sheet.
        sheet_name (str): Name of the sheet to load.

    Returns:
        pd.DataFrame: The loaded data.

    Raises:
        ValueError: If the worksheet contains no rows at all.
    """
    auth.authenticate_user()
    creds, _ = default()
    gc = gspread.authorize(creds)
    rows = gc.open_by_url(spreadsheet_url).worksheet(sheet_name).get_all_values()
    if not rows:
      raise ValueError('The sheet "' + sheet_name + '" is empty.')
    df_sheet = pd.DataFrame(rows[1:], columns=rows[0])
    return DataLoader._clean_dataframe(df_sheet)

  @staticmethod
  def _load_data_from_csv(csv_name):
    """Loads data from an uploaded CSV file.

    Opens the Colab upload dialog. If ``csv_name`` is empty and exactly one
    file is uploaded, that file is used.

    Args:
        csv_name (str): Name of the CSV file.

    Returns:
        pd.DataFrame: The loaded data.

    Raises:
        KeyError: If no uploaded file matches ``csv_name``.
    """
    uploaded = files.upload()
    if csv_name not in uploaded:
      if not csv_name and len(uploaded) == 1:
        csv_name = next(iter(uploaded))
      else:
        raise KeyError(
            f'{csv_name!r} was not uploaded; uploaded files: {sorted(uploaded)}'
        )
    df_csv = pd.read_csv(io.BytesIO(uploaded[csv_name]))
    return DataLoader._clean_dataframe(df_csv)

  @staticmethod
  def _load_data_from_bigquery(bq_project_id, bq_table_name):
    """Loads data from a BigQuery table.

    Args:
        bq_project_id (str): Google Cloud Project ID.
        bq_table_name (str): BigQuery table name, with or without
            surrounding backticks.

    Returns:
        pd.DataFrame: The loaded data.
    """
    auth.authenticate_user()
    client = bigquery.Client(project=bq_project_id)
    # Table identifiers cannot contain backticks; removing them keeps the
    # quoted identifier below well-formed (and prevents breaking out of it).
    table_id = bq_table_name.strip().replace('`', '')
    query = 'SELECT * FROM `' + table_id + '`;'
    df_bq = client.query(query).to_dataframe()
    return DataLoader._clean_dataframe(df_bq)

class DataPreprocessor:
  """Handles data formatting and preprocessing."""

  @staticmethod
  def _parse_column_list(text):
    """Splits a comma-separated string into stripped, non-empty names.

    Example: ' Tokyo , Osaka,' -> ['Tokyo', 'Osaka'].
    """
    return [name.strip() for name in text.split(',') if name.strip()]

  def format_data(self, dataframe, params):
    """Formats the raw dataframe for analysis.

    Converts narrow data to wide if requested, drops excluded columns, indexes
    by date, and reindexes to a gap-free daily range. Missing days become NaN
    rows.

    Args:
        dataframe (pd.DataFrame): The raw input dataframe (not modified).
        params (dict): Parameters for formatting (date column, pivot, etc).

    Returns:
        tuple: (formatted_data (pd.DataFrame), date_col_name (str), tick_count (int))
        where tick_count is the number of calendar months spanned minus one.

    Raises:
        ValueError: If params['data_type_selection'] is neither 0 nor 1.
    """
    date_col_name = params['date_col'].replace(' ', '')
    pivot_col_name = params['pivot_col'].replace(' ', '')
    kpi_col_name = params['kpi_col'].replace(' ', '')

    data_type = params['data_type_selection']
    if data_type == 0:  # wide format: use as-is
      formatted_data = dataframe.copy()
    elif data_type == 1:  # narrow format: pivot to wide
      formatted_data = self._shape_wide(
          dataframe,
          date_col_name,
          pivot_col_name,
          kpi_col_name,
      )
    else:
      raise ValueError('Please select wide or narrow data format.')

    formatted_data = formatted_data.drop(
        columns=self._parse_column_list(params['exclude_cols']),
        errors='ignore',
    )
    formatted_data[date_col_name] = pd.to_datetime(
        formatted_data[date_col_name]
    )
    formatted_data = formatted_data.set_index(date_col_name)
    formatted_data = formatted_data.reindex(
        pd.date_range(
            start=formatted_data.index.min(),
            end=formatted_data.index.max(),
            name=formatted_data.index.name))

    # The index is now a continuous daily range, so the number of distinct
    # months equals the number of monthly bins (avoids deprecated 'M' alias).
    tick_count = formatted_data.index.to_period('M').nunique() - 1

    return formatted_data, date_col_name, tick_count

  @staticmethod
  def _shape_wide(dataframe, date_column, pivot_column, kpi_column):
    """Reshapes data from narrow to wide format.

    KPI values are summed per (date, pivot value); missing combinations are
    filled with 0. With several pivot columns the resulting column names are
    the pivot values joined by '_'.

    Args:
        dataframe (pd.DataFrame): The narrow dataframe.
        date_column (str): Name of the date column.
        pivot_column (str): Name of the column to pivot (comma-separated for
            several columns).
        kpi_column (str): Name of the KPI column.

    Returns:
        pd.DataFrame: The reshaped wide dataframe with the date as a column.
    """
    group_cols = DataPreprocessor._parse_column_list(pivot_column)

    pivoted_df = pd.pivot_table(
        (dataframe[[date_column] + [kpi_column] + group_cols])
        .groupby([date_column] + group_cols)
        .sum(),
        index=date_column,
        columns=group_cols,
        fill_value=0,
    )
    # Drop the outer level holding the KPI column name.
    pivoted_df.columns = pivoted_df.columns.droplevel(0)
    if len(pivoted_df.columns.names) > 1:
      pivoted_df.columns = [
          '_'.join(str(x).replace(',', '_') for x in y)
          for y in pivoted_df.columns.values
      ]
    pivoted_df = pivoted_df.reset_index()
    return pivoted_df

class ExploratoryDataAnalyzer:
  """Performs exploratory data analysis and quality checks."""

  # Row offsets for the rolling mean on the total-trend chart:
  # 3 days before + the day itself + 3 days after = the 7 days named in the
  # chart title.
  moving_average_frame = (-3, 3)
  # Upper bound on the number of DTW clusters drawn in the each_trend tab.
  max_clusters = 5

  def check_data_quality(self, formatted_data):
    """Checks and prints data quality summary (ranges, missing values).

    Args:
        formatted_data (pd.DataFrame): The dataframe to check.
    """
    UIUtils.apply_text_style(
        'failure',
        '\nCheck! Here is an overview of the data.'
    )
    print(
        'Index name:{} | The earliest date: {} | The latest date: {}'.format(
            formatted_data.index.name,
            min(formatted_data.index),
            max(formatted_data.index)
        ))
    print('* Rows with missing values')
    missing_row = formatted_data[
        formatted_data.isnull().any(axis=1)]
    if len(missing_row) > 0:
      display(missing_row)
    else:
      print('>> Does not include missing values')

  @staticmethod
  def _min_max_scale(dataframe):
    """Scales every column independently to the [0, 1] range.

    A constant column (max == min) is mapped to all zeros instead of the all-NaN
    result of 0/0, which the DTW clustering cannot use. Existing NaN values stay
    NaN.

    Args:
        dataframe (pd.DataFrame): Numeric columns to scale.

    Returns:
        pd.DataFrame: A new scaled dataframe with the same index and columns.
    """
    col_min = dataframe.min()
    col_range = dataframe.max() - col_min
    col_range = col_range.mask(col_range == 0, 1)
    return (dataframe - col_min) / col_range

  def trend_check(self, dataframe, date_col_name, tick_count):
    """Performs trend analysis and clustering of time series.

    Displays three tabs: the summed trend with a 7-day moving average, each
    series grouped into DTW k-means clusters, and ``describe()`` output.

    Args:
        dataframe (pd.DataFrame): The dataframe containing time series,
            indexed by a date index named ``date_col_name``.
        date_col_name (str): Name of the date column.
        tick_count (int): Number of ticks for the x-axis.
    """
    UIUtils.apply_text_style(
        'failure',
        '\nCheck! below [total_trend] / [each_trend] / [describe_data]'
    )

    df_each = self._min_max_scale(dataframe)
    col_list = list(dataframe.columns)

    # Never request more clusters than there are series.
    n_clusters = min(self.max_clusters, len(col_list))
    tskm_base = TimeSeriesKMeans(n_clusters=n_clusters, metric='dtw',
                                 max_iter=100, random_state=42)
    df_cluster = pd.DataFrame({
        "pivot": col_list,
        "cluster": tskm_base.fit_predict(df_each.T).tolist()})
    # Iterate clusters from the smallest to the largest.
    cluster_counts = (
        df_cluster["cluster"].value_counts().sort_values(ascending=True))

    cluster_text = []
    line_each = []
    for i in cluster_counts.index:
      clust_list = df_cluster.query("cluster == @i")["pivot"].to_list()
      source = df_each.filter(items=clust_list)
      cluster_text.append(', '.join(map(str, clust_list)))
      line_each.append(
          alt.Chart(source.reset_index())
          .transform_fold(fold=clust_list, as_=['pivot', 'kpi'])
          .mark_line()
          .encode(
              alt.X(
                  date_col_name + ':T',
                  title=None,
                  axis=alt.Axis(
                      grid=False, format='%Y %b', tickCount=tick_count
                      ),
                  ),
              alt.Y('kpi:Q', stack=None, axis=None),
              alt.Color(str(i) + ':N', title=None, legend=None),
              alt.Row(
                  'pivot:N',
                  title=None,
                  header=alt.Header(labelAngle=0, labelAlign='left'),
                  ),
              )
          .properties(bounds='flush', height=30)
          .configure_facet(spacing=0)
          .configure_view(stroke=None)
          .configure_title(anchor='end')
          )

    # Sum of all series per day.
    df_long = (
        pd.melt(dataframe.reset_index(), id_vars=date_col_name)
        .groupby(date_col_name)
        .sum(numeric_only=True)
        .reset_index()
    )
    line_total = (
        alt.Chart(df_long)
        .mark_line()
        .encode(
            x=alt.X(
                date_col_name + ':T',
                axis=alt.Axis(
                    title='', format='%Y %b', tickCount=tick_count
                ),
            ),
            y=alt.Y('value:Q', axis=alt.Axis(title='kpi')),
            color=alt.value('#4285F4'),
        )
    )
    moving_average = (
        alt.Chart(df_long)
        .transform_window(
            rolling_mean='mean(value)',
            frame=list(self.moving_average_frame),
        )
        .mark_line()
        .encode(
            x=alt.X(date_col_name + ':T'),
            y=alt.Y('rolling_mean:Q'),
            color=alt.value('#DB4437'),
        )
    )
    tab_total_trend = ipywidgets.Output()
    tab_each_trend = ipywidgets.Output()
    tab_describe_data = ipywidgets.Output()
    tab_result = ipywidgets.Tab(children = [
        tab_total_trend,
        tab_each_trend,
        tab_describe_data,
        ])
    tab_result.set_title(0, '>> total_trend')
    tab_result.set_title(1, '>> each_trend')
    tab_result.set_title(2, '>> describe_data')
    display(tab_result)
    with tab_total_trend:
      display(
          (line_total + moving_average).properties(
              width=700,
              height=200,
              title={
                  'text': ['Daily Trend(blue) & 7days moving average(red)'],
              },
          )
      )
    with tab_each_trend:
      for i, (text, chart) in enumerate(zip(cluster_text, line_each)):
        print('cluster {}:{}'.format(i, text))
        display(chart.properties(width=700))
    with tab_describe_data:
      display(dataframe.describe(include='all'))

class CausalImpactEstimator:
  """Wraps the CausalImpact analysis logic: model fitting and result display."""
  # Line colours for per-series charts (also reused by ExperimentalDesigner).
  colors = [
      '#DB4437', '#AB47BC', '#4285F4', '#00ACC1',
      '#0F9D58', '#9E9D24', '#F4B400', '#FF7043',
  ]

  def create_causalimpact_object(self, data, date_col, pre_start, pre_end, post_start, post_end, num_of_seasons, credible_interval):
    """Fits the CausalImpact model to the data.

    Args:
        data (pd.DataFrame): Input dataframe. NOTE: if it is not already
          indexed by `date_col`, it is re-indexed IN PLACE (kept for
          compatibility with callers that may rely on the date index).
        date_col (str): Name of the date column.
        pre_start (datetime): Start of pre-intervention period.
        pre_end (datetime): End of pre-intervention period.
        post_start (datetime): Start of post-intervention period.
        post_end (datetime): End of post-intervention period.
        num_of_seasons (int): Number of seasons in the data. When it is 1,
          no seasonal model options are passed to the library.
        credible_interval (int): Credible interval percentage (e.g. 95).

    Returns:
        causalimpact.CausalImpact: The fitted CausalImpact object.
    """
    if data.index.name != date_col:
      data.set_index(date_col, inplace=True)

    extra_options = {}
    if num_of_seasons != 1:
      extra_options['model_options'] = causalimpact.ModelOptions(
          seasons=[
              causalimpact.Seasons(num_seasons=num_of_seasons),
          ]
      )
    return causalimpact.fit_causalimpact(
        data=data,
        pre_period=(str(pre_start), str(pre_end)),
        post_period=(str(post_start), str(post_end)),
        alpha=1 - credible_interval / 100,
        **extra_options,
    )

  @staticmethod
  def _is_significant(p_value, threshold):
    """Returns True when `p_value` is at or below the significance threshold.

    Used for both the summary message and the cumulative-effect chart so the
    two can never disagree at the boundary.
    """
    return p_value <= threshold

  @staticmethod
  def _describe_mape(mape):
    """Grades pre-period model accuracy from its MAPE.

    Args:
        mape (float): Mean absolute percentage error as a fraction
          (0.05 == 5%).

    Returns:
        tuple[str, str]: (style name for UIUtils.apply_text_style, message).
    """
    if mape <= 0.05:
      return (
          'success',
          'Very Good: The difference between actual and predicted values '
          'is slight.')
    if mape <= 0.10:
      return (
          'success',
          'Good: The difference between the actual and predicted values '
          'is within the acceptable range.')
    if mape <= 0.15:
      return (
          'failure',
          'Medium: The difference between the actual and predicted values '
          'is moderate, so this is only a reference value.')
    return (
        'failure',
        'Bad: The difference between actual and predicted values '
        'is large, so we do not recommend using it.')

  @staticmethod
  def _series_with_diff_percentage(series):
    """Returns a copy of `series` with a `diff_percentage` column at position 2.

    `diff_percentage` is point_effects_mean / observed. Working on a copy
    leaves the fitted object's own `series` untouched; inserting into it
    directly would raise on a second call because the column already exists.

    Args:
        series (pd.DataFrame): The CausalImpact result series.

    Returns:
        pd.DataFrame: A new dataframe including `diff_percentage`.
    """
    df = series.copy()
    df.insert(2, 'diff_percentage', df['point_effects_mean'] / df['observed'])
    return df

  def plot_causalimpact(self, causalimpact_object, pre_start, pre_end, tread_start, treat_end, credible_interval, date_col_name, tick_count, purpose_selection):
    """Plots the results of CausalImpact analysis.

    Shows a tab widget with a summary (MAPE grade, significance verdict,
    summary text and charts), the textual report, and the result data.

    Args:
        causalimpact_object (causalimpact.CausalImpact): The fitted object.
        pre_start (datetime): Start of pre-period.
        pre_end (datetime): End of pre-period.
        tread_start (datetime): Start of treatment.
        treat_end (datetime): End of treatment.
        credible_interval (int): Credible interval percentage.
        date_col_name (str): Date column name.
        tick_count (int): Number of ticks for x-axis.
        purpose_selection (int): The purpose of analysis; 1 adds a
          'mock experiment' watermark to the cumulative chart.
    """
    causalimpact_df = causalimpact_object.series
    alpha = 1 - credible_interval / 100
    mape = mean_absolute_percentage_error(
        causalimpact_df['observed'][str(pre_start) : str(pre_end)],
        causalimpact_df['posterior_mean'][str(pre_start) : str(pre_end)],
    )
    threshold = round(alpha, 2)
    significant = self._is_significant(
        causalimpact_object.summary.p_value.average, threshold)

    plot_df = causalimpact_df.reset_index()
    x_field = 'yearmonthdate(' + date_col_name + ')'

    line_1 = (
        alt.Chart(plot_df)
        .transform_fold([
            'observed',
            'posterior_mean',
        ])
        .mark_line()
        .encode(
            x=alt.X(
                x_field,
                axis=alt.Axis(
                    title='',
                    labels=False,
                    ticks=False,
                    format='%Y %b',
                    tickCount=tick_count,
                ),
            ),
            y=alt.Y(
                'value:Q',
                scale=alt.Scale(zero=False),
                axis=alt.Axis(title=''),
            ),
            color=alt.Color(
                'key:N',
                legend=alt.Legend(
                    title=None,
                    orient='none',
                    legendY=-20,
                    direction='horizontal',
                    titleAnchor='start',
                ),
                sort=['posterior_mean', 'observed'],
            ),
            strokeDash=alt.condition(
                alt.datum.key == 'posterior_mean',
                alt.value([5, 5]),
                alt.value([0]),
            ),
        )
    )
    area_1 = (
        alt.Chart(plot_df)
        .mark_area(opacity=0.3)
        .encode(
            x=alt.X(x_field),
            y=alt.Y('posterior_lower:Q', scale=alt.Scale(zero=False)),
            y2=alt.Y2('posterior_upper:Q'),
        )
    )
    line_2 = (
        alt.Chart(plot_df)
        .mark_line(strokeDash=[5, 5])
        .encode(
            x=alt.X(
                x_field,
                axis=alt.Axis(
                    title='',
                    labels=False,
                    ticks=False,
                    format='%Y %b',
                    tickCount=tick_count,
                ),
            ),
            y=alt.Y(
                'point_effects_mean:Q',
                scale=alt.Scale(zero=False),
                axis=alt.Axis(title=''),
            ),
        )
    )
    area_2 = (
        alt.Chart(plot_df)
        .mark_area(opacity=0.3)
        .encode(
            x=alt.X(x_field),
            y=alt.Y('point_effects_lower:Q', scale=alt.Scale(zero=False)),
            y2=alt.Y2('point_effects_upper:Q'),
        )
    )
    line_3 = (
        alt.Chart(plot_df)
        .mark_line(strokeDash=[5, 5])
        .encode(
            x=alt.X(
                x_field,
                axis=alt.Axis(title='', format='%Y %b', tickCount=tick_count),
            ),
            y=alt.Y(
                'cumulative_effects_mean:Q',
                scale=alt.Scale(zero=False),
                axis=alt.Axis(title=''),
            ),
        )
    )
    area_3 = (
        alt.Chart(plot_df)
        .mark_area(opacity=0.3)
        .encode(
            x=alt.X(x_field, axis=alt.Axis(title='')),
            y=alt.Y('cumulative_effects_lower:Q', scale=alt.Scale(zero=False),
                    axis=alt.Axis(title='')),
            y2=alt.Y2('cumulative_effects_upper:Q'),
        )
    )
    zero_line = (
        alt.Chart(pd.DataFrame({'y': [0]}))
        .mark_rule()
        .encode(y='y', color=alt.value('gray'))
    )
    # Dashed vertical markers for the treatment start (red) and end (orange).
    rules = (
        alt.Chart(
            pd.DataFrame({
                'Date': [str(tread_start), str(treat_end)],
                'color': ['red', 'orange'],
            })
        )
        .mark_rule(strokeDash=[5, 5])
        .encode(x='Date:T', color=alt.Color('color:N', scale=None))
    )
    watermark = alt.Chart(pd.DataFrame([1])).mark_text(
        align='center',
        dx=0,
        dy=0,
        fontSize=48,
        text='mock experiment',
        color='red'
      ).encode(
        opacity=alt.value(0.5)
    )
    if purpose_selection == 1:
      cumulative = line_3 + area_3 + rules + zero_line + watermark
    elif not significant:
      # Without a significant effect only the uncertainty band is drawn.
      cumulative = area_3 + rules + zero_line
    else:
      cumulative = line_3 + area_3 + rules + zero_line
    plot = alt.vconcat(
        (line_1 + area_1 + rules).properties(height=100, width=600),
        (line_2 + area_2 + rules + zero_line).properties(height=100, width=600),
        (cumulative).properties(height=100, width=600),
    )

    tab_data = ipywidgets.Output()
    tab_report = ipywidgets.Output()
    tab_summary = ipywidgets.Output()
    tab_result = ipywidgets.Tab(children = [tab_summary, tab_report, tab_data])
    tab_result.set_title(0, '>> summary')
    tab_result.set_title(1, '>> report')
    tab_result.set_title(2, '>> data')
    with tab_summary:
      print('Approximate model accuracy >> MAPE:{:.2%}'.format(mape))
      style, message = self._describe_mape(mape)
      UIUtils.apply_text_style(style, message)
      if significant:
          UIUtils.apply_text_style('success', f'\nP-Value is under {threshold}. There is a statistically significant difference.')
      else:
          UIUtils.apply_text_style('failure', f'\nP-Value is over {threshold}. There is not a statistically significant difference.')

      print(causalimpact.summary(
          causalimpact_object,
          output_format='summary',
          alpha=alpha))
      display(plot)
    with tab_report:
      print(causalimpact.summary(
          causalimpact_object,
          output_format="report",
          alpha=alpha))
    with tab_data:
      display(self._series_with_diff_percentage(causalimpact_object.series))
    display(tab_result)

  def display_causalimpact_result(self, formatted_data, date_col_name, tick_count, post_period_start, post_period_end):
    """Displays the comparison between test and control time series.

    Args:
        formatted_data (pd.DataFrame): The formatted data; every column is
          drawn as one line.
        date_col_name (str): Date column name.
        tick_count (int): Number of ticks.
        post_period_start (datetime): Start of post period.
        post_period_end (datetime): End of post period.
    """
    print('Test & Control Time Series')
    series_names = list(formatted_data.columns)
    line = (
        alt.Chart(formatted_data.reset_index())
        .transform_fold(series_names)
        .mark_line()
        .encode(
            alt.X(
                date_col_name + ':T',
                title=None,
                axis=alt.Axis(format='%Y %b', tickCount=tick_count),
            ),
            y=alt.Y('value:Q', axis=alt.Axis(title='kpi')),
            color=alt.Color(
                'key:N',
                legend=alt.Legend(
                    title=None,
                    orient='none',
                    legendY=-20,
                    direction='horizontal',
                    titleAnchor='start',
                ),
                scale=alt.Scale(
                    domain=series_names,
                    range=self.colors,
                ),
            ),
        )
        .properties(height=200, width=600)
    )
    rule = (
        alt.Chart(
            pd.DataFrame({
                'Date': [
                    str(post_period_start),
                    str(post_period_end),
                ],
                'color': ['red', 'orange'],
            })
        )
        .mark_rule(strokeDash=[5, 5])
        .encode(x='Date:T', color=alt.Color('color:N', scale=None))
    )
    display((line + rule).properties(height=200, width=600))
    print('=' * 100)

class ExperimentalDesigner:
  """Performs experimental design calculations and visualization."""
  NUM_OF_ITERATION = 1000
  COMBINATION_TARGET = 10
  MAX_STRING_LENGTH = 150
  colors = CausalImpactEstimator.colors

  def run_design(self, dataframe, params, start_date_value, end_date_value):
    """Executes the experimental design process based on selected type.

    Args:
        dataframe (pd.DataFrame): Input data, one column per series.
        params (dict): Design parameters; `design_type` selects the method
          (0: N-part split, 1: find similar, 2: from share, 3: given
          assignment) and the remaining keys feed that method.
        start_date_value (datetime): Start date for design.
        end_date_value (datetime): End date for design.

    Returns:
        pd.DataFrame: A dataframe containing design candidates and their distances.

    Raises:
        Exception: If an invalid design type is selected.
    """
    design_type_index = params['design_type']

    # String bounds make .loc slice the date index inclusively on both ends.
    df_sliced = dataframe.loc[str(start_date_value):str(end_date_value)]

    if design_type_index == 0:
      distance_data = self._n_part_split(
          df_sliced,
          params['num_of_split'],
          self.NUM_OF_ITERATION
      )
    elif design_type_index == 1:
      distance_data = self._find_similar(
          df_sliced,
          params['target_columns'],
          params['num_of_pick_range'],
          params['num_of_covariate']
      )
    elif design_type_index == 2:
      distance_data = self._from_share(
          df_sliced,
          params['target_share'],
      )
    elif design_type_index == 3:
      distance_data = self._given_assignment(
          params['target_columns'],
          params['control_columns'],
      )
    else:
      raise Exception('Invalid design type.')

    return distance_data

  @staticmethod
  def _calculate_distance(dataframe):
    """Sums pairwise DTW distances between the columns of `dataframe`.

    Each column is min-max scaled and differenced. Every pair of data
    columns is then compared with fastdtw on (row position, value) points.

    Args:
        dataframe (pd.DataFrame): Dataframe containing the time series.

    Returns:
        float: The total distance over all column pairs.
    """
    value_cols = list(dataframe.columns)
    col_min = dataframe.min()
    # A constant column has zero range; map it to all zeros instead of NaN,
    # which would otherwise make dropna() below discard every row.
    col_range = (dataframe.max() - col_min).replace(0, 1)
    scaled_data = ((dataframe - col_min) / col_range).diff()
    scaled_data = scaled_data.reset_index().dropna()
    # reset_index() prepends the row position as the first column. Only the
    # data columns are paired; the position column is part of every point.
    position_col = scaled_data.columns[0]

    total_distance = 0
    for left, right in itertools.combinations(value_cols, 2):
      distance, _ = fastdtw.fastdtw(
          scaled_data.loc[:, [position_col, left]],
          scaled_data.loc[:, [position_col, right]],
          dist=euclidean,
      )
      total_distance += distance
    return total_distance

  @staticmethod
  def _aggregate_groups(dataframe, picks):
    """Sums each group of columns into one series; result columns are 0..k-1."""
    return pd.concat(
        [dataframe[group].sum(axis=1).rename(i) for i, group in enumerate(picks)],
        axis=1,
    )

  def _score_candidate(self, dataframe, distance_data, row, picks):
    """Scores one grouping and writes it into `distance_data` at `row`.

    Column 'distance' receives the total DTW distance; columns 0..k-1 receive
    the string form of each sorted group (parsed back by visualize_candidate).
    """
    picked_data = self._aggregate_groups(dataframe, picks)
    distance = self._calculate_distance(picked_data.reset_index(drop=True))
    distance_data.loc[row, 'distance'] = float(distance)
    for j, group in enumerate(picks):
      distance_data.at[row, j] = str(sorted(group))

  @staticmethod
  def _top_candidates(distance_data, top_n=3):
    """Returns the `top_n` distinct candidates with the smallest distance."""
    return (
        distance_data.drop_duplicates()
        .sort_values('distance')
        .head(top_n)
        .reset_index(drop=True)
    )

  def _n_part_split(self, dataframe, num_of_split, NUM_OF_ITERATION):
    """Splits the data into N groups with similar movements.

    Args:
        dataframe (pd.DataFrame): Input data.
        num_of_split (int): Number of splits.
        NUM_OF_ITERATION (int): Number of iterations for finding best split.

    Returns:
        pd.DataFrame: Top 3 splits with smallest distances.
    """
    distance_data = pd.DataFrame(columns=['distance'])
    num_of_pick = len(dataframe.columns) // num_of_split

    for row in tqdm(range(NUM_OF_ITERATION)):
      col_list = list(dataframe.columns)
      picks = []
      for _ in range(num_of_split):
        random_pick = random.sample(col_list, num_of_pick)
        picks.append(random_pick)
        col_list = [c for c in col_list if c not in random_pick]
      # Columns left over by the integer division join the first group.
      picks[0].extend(col_list)
      self._score_candidate(dataframe, distance_data, row, picks)

    return self._top_candidates(distance_data)

  def _find_similar(
      self,
      dataframe,
      target_columns,
      num_of_pick_range,
      num_of_covariate,
      ):
    """Finds covariate groups that move similarly to target columns.

    Args:
        dataframe (pd.DataFrame): Input data.
        target_columns (str): Comma-separated target column names.
        num_of_pick_range (tuple): (min, max) number of columns per group,
          both inclusive.
        num_of_covariate (int): Number of covariate groups.

    Returns:
        pd.DataFrame: Top 3 candidates.

    Raises:
        Exception: If settings for similarity are invalid.
    """
    distance_data = pd.DataFrame(columns=['distance'])
    target_cols = target_columns.replace(', ', ',').split(',')

    if (
        len(dataframe.columns) - len(target_cols)
        < num_of_pick_range[1] * num_of_covariate):
      print('Please check the following:')
      print('* There is something wrong with similarity settings.')
      print('* Total number of columns \u30fc the target = {}'.format(
          len(dataframe.columns) - len(target_cols)))
      print('* But your settings are {}(max pick#) \u00d7 {}(covariate#)'.format(
          num_of_pick_range[1], num_of_covariate))
      print('* Please set it so that it does not exceed.')
      UIUtils.apply_text_style('failure', '\u25b2\u25b2\u25b2\u25b2\u25b2\u25b2\n\n')
      raise Exception('Please check Failure')

    candidate_cols = [c for c in dataframe.columns if c not in target_cols]
    for row in tqdm(range(self.NUM_OF_ITERATION)):
      remained_list = list(candidate_cols)
      picks = [target_cols]
      for _ in range(num_of_covariate):
        pick = random.sample(remained_list, random.randrange(
            num_of_pick_range[0], num_of_pick_range[1] + 1
        ))
        picks.append(pick)
        remained_list = [c for c in remained_list if c not in pick]
      self._score_candidate(dataframe, distance_data, row, picks)

    return self._top_candidates(distance_data)

  def _from_share(
      self,
      dataframe,
      target_share
      ):
    """Finds groups that match a specific target share.

    First draws random column sets until COMBINATION_TARGET of them have a
    share of the grand total (rounded half-up to 0.1) equal to
    `target_share`. Then, for each such set, samples random control groups
    from the remaining columns and scores every (target, control) pair.

    Args:
        dataframe (pd.DataFrame): Input data.
        target_share (float): The desired share of the target group.

    Returns:
        pd.DataFrame: Top 3 candidates.

    Raises:
        Exception: If a combination matching the target share is not found.
    """
    distance_data = pd.DataFrame(columns=['distance'])
    combinations = []
    all_columns = list(dataframe.columns)
    grand_total = dataframe.sum().sum()

    for _ in range(self.NUM_OF_ITERATION - 1):
      if len(combinations) >= self.COMBINATION_TARGET:
        break
      picked_col = np.random.choice(
          dataframe.columns,
          random.randint(1, len(dataframe.columns)//2 + 1),
          replace=False)
      share = Decimal(
          dataframe[picked_col].sum().sum() / grand_total
      ).quantize(Decimal('0.1'), ROUND_HALF_UP)
      if float(share) == target_share:
        combinations.append(sorted(set(picked_col)))

    if len(combinations) < self.COMBINATION_TARGET:
      UIUtils.apply_text_style('failure', '\n\nFailure!!')
      print('Please check the following:')
      print('* There is something wrong with design type C.')
      print("* You couldn't find the right combination in the repetitions.")
      print('* Please re-try or re-set target share')
      UIUtils.apply_text_style('failure', '\u25b2\u25b2\u25b2\u25b2\u25b2\u25b2\n\n')
      raise Exception('Please check Failure')

    row = 0
    for comb in tqdm(combinations):
      remained_list = [c for c in all_columns if c not in comb]
      if not remained_list:
        # No columns left to form a control group for this combination.
        continue
      for _ in tqdm(
          range(self.NUM_OF_ITERATION // self.COMBINATION_TARGET),
          leave=False):
        # randrange's upper bound is exclusive; max(..., 2) keeps a
        # one-column remainder usable instead of raising ValueError.
        num_of_control = random.randrange(1, max(len(remained_list), 2))
        picks = [comb, random.sample(remained_list, num_of_control)]
        # Score every sampled pair under its own row, so results from
        # different combinations never overwrite each other.
        self._score_candidate(dataframe, distance_data, row, picks)
        row += 1

    return self._top_candidates(distance_data)

  def _given_assignment(self, target_columns, control_columns):
    """Creates a design with pre-assigned target and control columns.

    Args:
        target_columns (str): Comma-separated target column names.
        control_columns (str): Comma-separated control column names.

    Returns:
        pd.DataFrame: A single-row dataframe (distance 0) representing the
          assignment.
    """
    distance_data = pd.DataFrame(columns=['distance'])
    distance_data.loc[0, 'distance'] = 0
    distance_data.loc[0, 0] = str(target_columns.replace(', ', ',').split(','))
    distance_data.loc[0, 1] = str(control_columns.replace(', ', ',').split(','))
    return distance_data

  def _candidate_line_chart(self, data, date_col_name, tick_count, height):
    """Builds a multi-line chart with one coloured line per column of `data`."""
    return (
        alt.Chart(data.reset_index())
        .transform_fold(
            fold=list(data.columns), as_=['pivot', 'kpi']
        )
        .mark_line()
        .encode(
            x=alt.X(
                date_col_name + ':T',
                title=None,
                axis=alt.Axis(
                    grid=False, format='%Y %b', tickCount=tick_count
                ),
            ),
            y=alt.Y('kpi:Q'),
            color=alt.Color(
                'pivot:N',
                legend=alt.Legend(
                    title=None,
                    orient='none',
                    legendY=-20,
                    direction='horizontal',
                    titleAnchor='start'),
                scale=alt.Scale(
                    domain=list(data.columns),
                    range=self.colors)),
        )
        .properties(width=600, height=height)
    )

  def visualize_candidate(
      self,
      dataframe,
      distance_data,
      start_date_value,
      end_date_value,
      date_col_name,
      tick_count
      ):
    """Visualizes the candidate experimental designs.

    Prints the design parameters, then shows one tab per candidate row with
    each group's daily average/share/std table, raw and min-max-scaled trend
    lines, and a scatter matrix of row-over-row changes.

    Args:
        dataframe (pd.DataFrame): Input data.
        distance_data (pd.DataFrame): Design candidates (column 'distance'
          plus one stringified column list per group, columns 0..k-1).
        start_date_value (datetime): Start date (inclusive).
        end_date_value (datetime): End date (inclusive).
        date_col_name (str): Date column name.
        tick_count (int): Tick count for plots.
    """
    UIUtils.apply_text_style(
          'failure',
          '\nCheck! Experimental Design Parameters.'
          )
    print('* start_date_value: ' + str(start_date_value))
    print('* end_date_value: ' + str(end_date_value))
    print('* columns:')
    strip_brackets = str.maketrans({'[': '', ']': '',  "'": ''})
    line_cols = []
    for column in dataframe.columns:
      line_cols.append(column)
      if len(str(line_cols)) >= self.MAX_STRING_LENGTH:
        print(str(line_cols).translate(strip_brackets))
        line_cols = []
    if line_cols:
      # Flush the final, shorter line of column names.
      print(str(line_cols).translate(strip_brackets))
    print('\n')

    period = slice(str(start_date_value), str(end_date_value))
    # Both bounds are inclusive (.loc string slicing), hence the +1.
    num_of_days = (end_date_value - start_date_value).days + 1
    period_total = dataframe.loc[period].sum().sum()

    sub_tab = [ipywidgets.Output() for _ in distance_data.index.tolist()]
    tab_option = ipywidgets.Tab(sub_tab)
    for option, output in enumerate(sub_tab):
      tab_option.set_title(option, "option_{}".format(option + 1))
      with output:
        candidate_df = pd.DataFrame(index=dataframe.index)
        for col in range(len(distance_data.columns) - 1):
          group_repr = distance_data.at[option, col]
          print('col_' + str(col + 1) + ': ' + group_repr.replace("'", ""))
          # NOTE(review): eval() parses the list repr written by
          # _score_candidate/_given_assignment. It executes arbitrary code,
          # so distance_data must never come from an untrusted source.
          candidate_df[col + 1] = list(
              dataframe.loc[:, eval(group_repr)].sum(axis=1)
          )
          print('\n')
        candidate_df = candidate_df.add_prefix('col_')

        candidate_share = pd.DataFrame(
            candidate_df.loc[period, :].sum(),
            columns=['total'])
        candidate_share['daily_average'] = candidate_share['total'] / num_of_days
        candidate_share['share'] = candidate_share['total'] / period_total

        # Best effort: STL can fail (e.g. no inferable period); fall back to
        # the table without the std column.
        try:
          for group_col in candidate_df.columns:
            stl = STL(candidate_df[group_col], robust=True).fit()
            candidate_share.loc[group_col, 'std'] = np.std(stl.seasonal + stl.resid)
          display(
              candidate_share[['daily_average', 'share', 'std']].style.format(
                  {
                      'daily_average': '{:,.0f}',
                      'share': '{:.1%}',
                      'std': '{:,.0f}',
                      }))
        except Exception as e:
          print(e)
          display(
              candidate_share[['daily_average', 'share']].style.format({
              'daily_average': '{:,.0f}',
              'share': '{:.1%}',
              }))

        chart_line = self._candidate_line_chart(
            candidate_df, date_col_name, tick_count, height=200)

        rules = alt.Chart(
            pd.DataFrame(
                {
                    'Date': [str(start_date_value), str(end_date_value)],
                    'color': ['red', 'orange']
                    })
            ).mark_rule(strokeDash=[5, 5]).encode(
                x='Date:T',
                color=alt.Color('color:N', scale=None))

        df_scaled = candidate_df.copy()
        df_scaled[:] = MinMaxScaler().fit_transform(candidate_df)
        chart_line_scaled = self._candidate_line_chart(
            df_scaled, date_col_name, tick_count, height=80)

        df_diff = pd.DataFrame(
            np.diff(candidate_df, axis=0),
            columns=candidate_df.columns.values,
        )
        scatter = (
            alt.Chart(df_diff.reset_index())
            .mark_circle()
            .encode(
                alt.X(alt.repeat('column'), type='quantitative'),
                alt.Y(alt.repeat('row'), type='quantitative'),
            )
            .properties(width=80, height=80)
            .repeat(
                row=df_diff.columns.values,
                column=df_diff.columns.values,
            )
        )
        display(
            alt.vconcat(chart_line + rules, chart_line_scaled) | scatter)
    display(tab_option)

class SimulationOptimizer:
  """Optimizes and runs simulations for causal impact.

  For one selected design (a test group plus covariate groups), the optimizer
  multiplies the 'test' series by each factor in TREAT_IMPACT over the last
  N days (N taken from TREAT_DURATION), builds a CausalImpact object for every
  (duration, impact) pair via the estimator, and reports how well each mock
  lift is recovered.
  """
  # Lengths, in days, of the simulated post (treatment) periods, counted back
  # from the end date.
  TREAT_DURATION = [14, 21, 28]
  # Multipliers applied to the 'test' series in the simulated post period.
  # A factor of 1 injects no lift and is reported as the A/A test.
  TREAT_IMPACT = [1, 1.01, 1.03, 1.05, 1.10, 1.15]

  def __init__(self, estimator):
    """Initializes the optimizer with an estimator.

    Args:
        estimator (CausalImpactEstimator): The estimator instance used to
          create and plot CausalImpact objects.
    """
    self.estimator = estimator

  def generate_simulation(self, choice_value, target_cols, covariate_cols, formatted_data, distance_data, date_col_name, start_date, end_date, num_of_seasons, credible_interval, estimate_icpa, purpose_selection_index):
    """Runs a simulation based on selected design and displays results.

    Args:
        choice_value (str): The selected option value, e.g. 'option_1'.
        target_cols (list): Target column selections, e.g. ['col_1'].
        covariate_cols (list): Covariate column selections, e.g. ['col_2'].
        formatted_data (pd.DataFrame): Formatted data.
        distance_data (pd.DataFrame): Design candidates.
        date_col_name (str): Date column name.
        start_date (datetime.date): Start date.
        end_date (datetime.date): End date.
        num_of_seasons (int): Number of seasons.
        credible_interval (int): Credible interval percentage.
        estimate_icpa (int): Estimated iCPA.
        purpose_selection_index (int): Purpose selection index.
    """
    test_data = self._extract_data_from_choice(
        choice_value,
        target_cols,
        covariate_cols,
        formatted_data,
        distance_data,
    )
    simulation_params, ci_objs = self._execute_simulation(
        test_data,
        date_col_name,
        start_date,
        end_date,
        num_of_seasons,
        credible_interval,
    )
    self._display_simulation_result(simulation_params, ci_objs, estimate_icpa)
    # Tick count for the date axis: the number of monthly resample groups
    # spanned by the selected data, minus one.
    tick_count = len(test_data.resample('M')) - 1
    self._plot_simulation_result(
        simulation_params, ci_objs, date_col_name, tick_count,
        purpose_selection=purpose_selection_index,
        credible_interval=credible_interval,
    )

  def _extract_data_from_choice(self, your_choice, target_col_to_simulate, covariate_col_to_simulate, dataframe, distance):
    """Extracts test data based on the user's choice.

    `your_choice` looks like 'option_<k>' and selects row k-1 of `distance`.
    Each selection entry looks like 'col_<j>' and selects the `distance` cell
    in column label j-1. A selected cell holds text that evaluates to a list
    of `dataframe` column names; those columns are summed row-wise.

    Args:
        your_choice (str): Selected choice.
        target_col_to_simulate (list): Target columns.
        covariate_col_to_simulate (list): Covariate columns.
        dataframe (pd.DataFrame): Input data.
        distance (pd.DataFrame): Distance/Design data.

    Returns:
        pd.DataFrame: Indexed like `dataframe`, with a 'test' column (sum over
          all target groups) followed by one 'col_<j>' column per covariate
          group.
    """
    selection_row = int(your_choice.replace('option_', '')) - 1
    target_idx = [
        int(t.replace('col_', '')) - 1 for t in target_col_to_simulate
    ]
    covariate_idx = [
        int(t.replace('col_', '')) - 1 for t in covariate_col_to_simulate
    ]

    def _group_members(col_idx):
      # NOTE(review): the cell text is decoded with eval(), so it must be a
      # Python expression. These cells presumably come from the designer in
      # this notebook rather than from user input; if that ever changes,
      # switch to ast.literal_eval, which rejects arbitrary code.
      return eval(distance.at[selection_row, col_idx])

    test_data = pd.DataFrame(index=dataframe.index)

    test_column = []
    for idx in target_idx:
      test_column.extend(_group_members(idx))
    test_data['test'] = dataframe.loc[:, test_column].sum(axis=1)

    # Decode each covariate cell once and reuse it for the printout below.
    covariate_members = []
    for idx in covariate_idx:
      members = _group_members(idx)
      covariate_members.append(members)
      test_data['col_' + str(idx + 1)] = dataframe.loc[:, members].sum(axis=1)

    print('* test: {}\n'.format(str(test_column).replace("'", "")))
    print('* covariate')
    for name, members in zip(test_data.columns[1:], covariate_members):
      print('> {}: {}'.format(name, str(members)).replace("'", ""))
    return test_data

  def _execute_simulation(self, dataframe, date_col_name, start_date_value, end_date_value, num_of_seasons, credible_interval):
    """Executes CausalImpact simulation for various durations and impacts.

    For every duration in TREAT_DURATION, the post period is the last
    `duration` days up to and including `end_date_value`. For every factor in
    TREAT_IMPACT, the 'test' column in that window is multiplied by the factor
    and a CausalImpact object is created from the result. Each scenario
    starts from the unmodified input, so results do not depend on the order
    of TREAT_DURATION / TREAT_IMPACT.

    Args:
        dataframe (pd.DataFrame): Input data with a 'test' column.
        date_col_name (str): Date column name.
        start_date_value (datetime.date): Start date.
        end_date_value (datetime.date): End date.
        num_of_seasons (int): Number of seasons.
        credible_interval (int): Credible interval.

    Returns:
        tuple: (simulation_params, ci_objs). simulation_params[k] is
          [pre_start, pre_end, post_start, post_end, impact, duration] and
          ci_objs[k] is the matching CausalImpact object.
    """
    ci_objs = []
    simulation_params = []
    # Scaling by non-integer factors yields floats; cast once so writing them
    # back never silently upcasts an integer column (deprecated in pandas).
    base_data = dataframe.astype({'test': 'float64'})
    post_end = np.datetime64(end_date_value)

    for duration in tqdm(self.TREAT_DURATION):
      pre_end_date = end_date_value - datetime.timedelta(days=duration)
      post_start_date = pre_end_date + datetime.timedelta(days=1)
      post_window = slice(np.datetime64(post_start_date), post_end)
      original_post = base_data.loc[post_window, 'test']

      for impact in tqdm(self.TREAT_IMPACT, leave=False):
        # Fresh copy per scenario: a shared, in-place-mutated frame would let
        # an earlier lift survive in days a later window does not overwrite,
        # and every estimator call would receive the same mutated object.
        adjusted_data = base_data.copy()
        adjusted_data.loc[post_window, 'test'] = original_post * impact

        ci_obj = self.estimator.create_causalimpact_object(
            adjusted_data,
            date_col_name,
            start_date_value,
            pre_end_date,
            post_start_date,
            end_date_value,
            num_of_seasons,
            credible_interval,
        )
        simulation_params.append([
            start_date_value,
            pre_end_date,
            post_start_date,
            end_date_value,
            impact,
            duration,
        ])
        ci_objs.append(ci_obj)
    return simulation_params, ci_objs

  @staticmethod
  def _period_mape(impact_df, period_start, period_end):
    """Returns the MAPE of posterior_mean vs observed within [start, end].

    Args:
        impact_df (pd.DataFrame): CausalImpact series with 'observed' and
          'posterior_mean' columns, indexed by date.
        period_start: First date of the window (converted with str()).
        period_end: Last date of the window (converted with str()).
    """
    window = slice(str(period_start), str(period_end))
    return mean_absolute_percentage_error(
        impact_df.loc[:, 'observed'][window],
        impact_df.loc[:, 'posterior_mean'][window],
    )

  def _display_simulation_result(self, simulation_params, ci_objs, estimate_icpa):
    """Displays the simulation results in a tabular format.

    Shows the A/A rows (mock lift of 0) first, then one table per simulated
    duration with the non-zero mock lifts.

    Args:
        simulation_params (list): List of simulation parameters, as returned
          by _execute_simulation.
        ci_objs (list): List of CausalImpact objects, aligned with
          simulation_params.
        estimate_icpa (int): Estimated iCPA.
    """
    columns = [
        'test_period',
        'mock_lift_rate',
        'predicted_lift_rate',
        'Days_simulated',
        'Pre_Period_MAPE',
        'Post_Period_MAPE',
        'Total_effect',
        'Average_effect',
        'Required_budget',
        'p_value',
    ]
    rows = []
    for params, ci_obj in zip(simulation_params, ci_objs):
      pre_start, pre_end, post_start, post_end, impact, duration = params
      summary = ci_obj.summary
      total_effect = summary.loc['cumulative', 'abs_effect']
      rows.append({
          'test_period': '({}d) {}~{}'.format(duration, post_start, post_end),
          'mock_lift_rate': impact - 1,
          'predicted_lift_rate': summary.loc['average', 'rel_effect'],
          'Days_simulated': duration,
          'Pre_Period_MAPE': self._period_mape(
              ci_obj.series, pre_start, pre_end),
          'Post_Period_MAPE': self._period_mape(
              ci_obj.series, post_start, post_end),
          'Total_effect': total_effect,
          'Average_effect': summary.loc['average', 'abs_effect'],
          'Required_budget': total_effect * estimate_icpa,
          'p_value': summary.loc['average', 'p_value'],
      })
    # Build the frame once (explicit columns keep the queries below valid even
    # when there are no rows) instead of concatenating onto an empty frame.
    simulation_df = pd.DataFrame(rows, columns=columns)

    display(UIUtils.apply_text_style(
          18,
          'A/A Test: Check the error without intervention'))
    print('> If p_value < 0.05, please suspect "poor model accuracy"(See Pre_Period_MAPE) or "data drift"(See Time Series Chart).\n')
    display(
        simulation_df.query('mock_lift_rate == 0')[
            ['test_period', 'Pre_Period_MAPE', 'Post_Period_MAPE', 'p_value']
            ].style.format({
                'Pre_Period_MAPE': '{:.2%}',
                'Post_Period_MAPE': '{:.2%}',
                'p_value': '{:,.2f}',
                }).hide()
            )
    print('\n')
    display(UIUtils.apply_text_style(
          18,
          'Simulation with increments as a mock experiment'))
    for i in simulation_df.Days_simulated.unique():
      print('\n During the last {} days'.format(i))
      display(
          simulation_df.query('mock_lift_rate != 0 & Days_simulated == @i')[
              [
                  'mock_lift_rate',
                  'predicted_lift_rate',
                  'Pre_Period_MAPE',
                  'Total_effect',
                  'Average_effect',
                  'Required_budget',
                  'p_value',
                  ]
          ].style.format({
              'mock_lift_rate': '{:+.0%}',
              'predicted_lift_rate': '{:+.1%}',
              'Pre_Period_MAPE': '{:.2%}',
              'Total_effect': '{:,.2f}',
              'Average_effect': '{:,.2f}',
              'Required_budget': '{:,.0f}',
              'p_value': '{:,.2f}',
          }).hide()
      )

  def _plot_simulation_result(
      self,
      simulation_params,
      ci_objs,
      date_col_name,
      tick_count,
      purpose_selection,
      credible_interval,
      ):
    """Plots the simulation results using CausalImpact plots.

    One tab per scenario, titled '<duration>d:+<lift>%'.

    Args:
        simulation_params (list): Simulation parameters.
        ci_objs (list): CausalImpact objects.
        date_col_name (str): Date column name.
        tick_count (int): Tick count.
        purpose_selection (int): Purpose selection index.
        credible_interval (int): Credible interval.
    """
    # Tab titles must be plain strings (ipywidgets 8 validates them as text).
    tab_titles = [
        '{}d:+{:.0%}'.format(params[5], params[4] - 1)
        for params in simulation_params
    ]
    simulation_tb = [ipywidgets.Output() for _ in tab_titles]
    tab_simulation = ipywidgets.Tab(simulation_tb)
    for idx, title in enumerate(tab_titles):
      tab_simulation.set_title(idx, title)
      params = simulation_params[idx]
      with simulation_tb[idx]:
        print(
            'Pre Period:{} ~ {}\nPost Period:{} ~ {}'.format(
                params[0],
                params[1],
                params[2],
                params[3],
            )
        )
        self.estimator.plot_causalimpact(
            ci_objs[idx],
            params[0],
            params[1],
            params[2],
            params[3],
            credible_interval,
            date_col_name,
            tick_count,
            purpose_selection
        )
    display(tab_simulation)

class CausalImpactAnalysis:
  """Orchestrator class for Causal Impact Analysis and Experimental Design.

  Holds one instance of each helper (UI, loader, preprocessor, analyzer,
  estimator, designer, simulation optimizer) and stores the intermediate
  result of each notebook step as an attribute so later steps can reuse it.
  """

  def __init__(self):
    """Initializes the analysis environment and helper classes."""
    self.ui = InteractiveUI()
    self.loader = DataLoader()
    self.preprocessor = DataPreprocessor()
    self.analyzer = ExploratoryDataAnalyzer()
    self.estimator = CausalImpactEstimator()
    self.designer = ExperimentalDesigner()
    self.optimizer = SimulationOptimizer(self.estimator)

    # Set by load_data().
    self.loaded_df = None
    # Set by format_data().
    self.formatted_data = None
    self.date_col_name = None
    self.tick_count = None
    # Set by run_experimental_design().
    self.distance_data = None
    # Set by run_causalImpact().
    self.ci_obj = None

  def generate_ui(self):
    """Generates and displays the main user interface."""
    self.ui.generate_ui()

  def set_params(self, params):
    """Sets the parameters of the analysis.

    Args:
        params (dict): A dictionary of parameters.
    """
    self.ui.set_params(params)

  def get_params(self):
    """Retrieves current parameters from the UI.

    Returns:
        dict: Current parameters.
    """
    return self.ui.get_params()

  def _resolve_date_range(self, params):
    """Returns the (start, end) dates for experimental design / simulation.

    When `date_selection_index` is 0, the full span of the formatted data's
    index is used; otherwise the user-entered `start_date` / `end_date`.

    Args:
        params (dict): Current UI parameters.

    Returns:
        tuple: (start_date_value, end_date_value).
    """
    if params['date_selection_index'] == 0:
      return (min(self.formatted_data.index).date(),
              max(self.formatted_data.index).date())
    return params['start_date'], params['end_date']

  def load_data(self):
    """Loads the data using the current parameters.

    The result is stored in `self.loaded_df` and its first rows are
    displayed. Nothing is returned.

    Raises:
        Exception: Re-raises whatever the loader raised, after printing a
          failure banner and the error message.
    """
    params = self.ui.get_params()
    try:
      self.loaded_df = self.loader.load_data(params)
      UIUtils.apply_text_style(
          'success',
          'Success! The target data has been loaded.')
      display(self.loaded_df.head(3))
    except Exception as e:
      UIUtils.apply_text_style('failure', '\n\nFailure!!')
      print('Error: {}'.format(e))
      raise

  def format_data(self):
    """Formats the loaded data and runs the data-quality and trend checks.

    The results are stored in `self.formatted_data`, `self.date_col_name`
    and `self.tick_count`. Nothing is returned.

    Raises:
        Exception: Re-raises whatever formatting or the checks raised, after
          printing a failure banner and the error message.
    """
    params = self.ui.get_params()
    try:
      self.formatted_data, self.date_col_name, self.tick_count = (
          self.preprocessor.format_data(self.loaded_df, params)
      )
      UIUtils.apply_text_style(
          'success',
          '\nSuccess! The data was formatted for analysis.'
          )
      display(self.formatted_data.head(3))

      self.analyzer.check_data_quality(self.formatted_data)
      self.analyzer.trend_check(
          self.formatted_data, self.date_col_name, self.tick_count)

    except Exception as e:
      UIUtils.apply_text_style('failure', '\n\nFailure!!')
      print('Error: {}'.format(e))
      raise

  def run_causalImpact(self):
    """Runs CausalImpact analysis and stores the result in `self.ci_obj`.

    Raises:
        Exception: Re-raises whatever the estimator raised, after printing a
          failure banner and the error message.
    """
    params = self.ui.get_params()
    try:
      self.ci_obj = self.estimator.create_causalimpact_object(
          self.formatted_data,
          self.date_col_name,
          params['pre_period_start'],
          params['pre_period_end'],
          params['post_period_start'],
          params['post_period_end'],
          params['num_of_seasons'],
          params['credible_interval'],
      )
      UIUtils.apply_text_style(
          'success',
          '\nSuccess! CausalImpact has been performed. Check the'
          ' results in the next cell.',
      )
    except Exception as e:
      UIUtils.apply_text_style('failure', '\n\nFailure!!')
      print('Error: {}'.format(e))
      raise

  def display_causalimpact_result(self):
    """Displays the results of the CausalImpact analysis."""
    params = self.ui.get_params()
    self.estimator.display_causalimpact_result(
        self.formatted_data,
        self.date_col_name,
        self.tick_count,
        params['post_period_start'],
        params['post_period_end']
    )
    self.estimator.plot_causalimpact(
        self.ci_obj,
        params['pre_period_start'],
        params['pre_period_end'],
        params['post_period_start'],
        params['post_period_end'],
        params['credible_interval'],
        self.date_col_name,
        self.tick_count,
        params['purpose_selection']
    )

  def run_experimental_design(self):
    """Runs the Experimental Design logic and shows the candidate designs.

    The design candidates are stored in `self.distance_data`.
    """
    params = self.ui.get_params()
    start_date_value, end_date_value = self._resolve_date_range(params)

    self.distance_data = self.designer.run_design(
        self.formatted_data, params, start_date_value, end_date_value
    )

    self.designer.visualize_candidate(
        self.formatted_data,
        self.distance_data,
        start_date_value,
        end_date_value,
        self.date_col_name,
        self.tick_count
    )
    self.ui.display_simulation_choice()

  def generate_simulation(self):
    """Generates the simulation for the design chosen in the UI."""
    params = self.ui.get_params()
    start_date_value, end_date_value = self._resolve_date_range(params)

    self.optimizer.generate_simulation(
        self.ui.your_choice.value,
        self.ui.target_col_to_simulate.value,
        self.ui.covariate_col_to_simulate.value,
        self.formatted_data,
        self.distance_data,
        self.date_col_name,
        start_date_value,
        end_date_value,
        params['num_of_seasons'],
        params['credible_interval'],
        params['estimate_icpa'],
        params['purpose_selection']
    )

# Build the orchestrator for case 1 and render its input widgets.
case_1 = CausalImpactAnalysis()
case_1.generate_ui()

# Pre-fill the widgets with parameters captured by an earlier Step.2 run,
# which stores them in the notebook-global `dict_params`.
if 'dict_params' in globals():
  case_1.set_params(globals()['dict_params'])

print(f'\nExecution datetime(GMT):{datetime.datetime.now()}')

HTML(value="<span style='font-size:18px; background: linear-gradient(transparent 90%, #4285F4 0%);'>⑴ Please s…

Tab(children=(VBox(children=(Text(value='https://docs.google.com/spreadsheets/d/1dISrbX1mZHgzpsIct2QXFOWWRRJiC…

<br>

HTML(value="<span style='font-size:18px; background: linear-gradient(transparent 90%, #4285F4 0%);'>⑵ Please s…

Tab(children=(VBox(children=(Label(value='Wide, or unstacked data is presented with each different data variab…

<br>

HTML(value="<span style='font-size:18px; background: linear-gradient(transparent 90%, #4285F4 0%);'>⑶ Please s…

Tab(children=(VBox(children=(HTML(value="<span style='font-size:15px; background: linear-gradient(transparent …


Execution datetime(GMT):2025-12-24 07:09:27.844535
CPU times: user 1.36 s, sys: 213 ms, total: 1.57 s
Wall time: 4.73 s


In [4]:
# @title Step.2
%%time
# Load the source data, then format it (format_data also runs the
# data-quality and trend checks).
case_1.load_data()
case_1.format_data()
# Snapshot the current UI parameters as a notebook global; Step.1 and the
# optional Case_2/Case_3 cells read `dict_params` to pre-fill their widgets.
dict_params = case_1.get_params()

# purpose_selection == 0 runs CausalImpact on the configured periods;
# any other value runs the experimental design instead.
if dict_params['purpose_selection'] == 0:
  case_1.run_causalImpact()
else:
  case_1.run_experimental_design()

print('\nExecution datetime(GMT):{}'.format(datetime.datetime.now()))

[38;2;15;157;88m Success! The target data has been loaded.[0m


Unnamed: 0,Date,Test,Control
0,2025-08-24,115133,70628
1,2025-08-25,108481,65315
2,2025-08-26,106624,66472


[38;2;15;157;88m 
Success! The data was formatted for analysis.[0m


Unnamed: 0_level_0,Test,Control
Date,Unnamed: 1_level_1,Unnamed: 2_level_1
2025-08-24,115133,70628
2025-08-25,108481,65315
2025-08-26,106624,66472


[38;2;219;68;55m 
Check! Here is an overview of the data.[0m
Index name:Date | The earliest date: 2025-08-24 00:00:00 | The latest date: 2025-12-24 00:00:00
* Rows with missing values
>> Does not include missing values
[38;2;219;68;55m 
Check! below [total_trend] / [each_trend] / [describe_data][0m


Tab(children=(Output(), Output(), Output()), _titles={'0': '>> total_trend', '1': '>> each_trend', '2': '>> de…

[38;2;15;157;88m 
Success! CausalImpact has been performed. Check the results in the next cell.[0m

Execution datetime(GMT):2025-12-24 07:12:56.017356
CPU times: user 11.6 s, sys: 705 ms, total: 12.3 s
Wall time: 33.9 s


In [5]:
# @title Step.3
%%time
# purpose_selection == 0: show the CausalImpact result computed in Step.2.
# Otherwise: run the mock-lift simulation for the design chosen in the
# simulation-choice widgets shown by Step.2.
if case_1.get_params()['purpose_selection'] == 0:
  case_1.display_causalimpact_result()
else:
  case_1.generate_simulation()

print('\nExecution datetime(GMT):{}'.format(datetime.datetime.now()))

Test & Control Time Series




Tab(children=(Output(), Output(), Output()), _titles={'0': '>> summary', '1': '>> report', '2': '>> data'})


Execution datetime(GMT):2025-12-24 07:12:58.775235
CPU times: user 363 ms, sys: 3.52 ms, total: 366 ms
Wall time: 383 ms


# (Optional) Case_2

In [None]:
# @title Case_2 Step.1
overwrite_pramas = True #@param {type:"boolean"}
# Build an independent orchestrator for a second analysis and render its UI.
case_2 = CausalImpactAnalysis()
case_2.generate_ui()
# Optionally copy case_1's parameters, which Step.2 stores in the notebook
# global `dict_params`. If Step.2 has not run yet, report it instead of
# failing with a NameError.
if overwrite_pramas:
  if 'dict_params' in globals():
    case_2.set_params(dict_params)
  else:
    print('dict_params is not defined yet. Run Step.2 first to copy its '
          'parameters into this case.')

In [None]:
# @title Case_2 Step.2
%%time
# Same flow as case_1's Step.2 (load, format, then analysis or design), but
# for case_2; this cell does not update the global `dict_params`.
case_2.load_data()
case_2.format_data()

if case_2.get_params()['purpose_selection'] == 0:
  case_2.run_causalImpact()
else:
  case_2.run_experimental_design()

print('\nExecution datetime(GMT):{}'.format(datetime.datetime.now()))

In [None]:
# @title Case_2 Step.3
%%time
# Show case_2's CausalImpact result, or run its simulation when the purpose
# is experimental design.
if case_2.get_params()['purpose_selection'] == 0:
  case_2.display_causalimpact_result()
else:
  case_2.generate_simulation()

print('\nExecution datetime(GMT):{}'.format(datetime.datetime.now()))

# (Optional) Case_3

In [None]:
# @title Case_3 Step.1
overwrite_pramas = False #@param {type:"boolean"}
# Build an independent orchestrator for a third analysis and render its UI.
case_3 = CausalImpactAnalysis()
case_3.generate_ui()
# Optionally copy case_1's parameters, which Step.2 stores in the notebook
# global `dict_params`. If Step.2 has not run yet, report it instead of
# failing with a NameError.
if overwrite_pramas:
  if 'dict_params' in globals():
    case_3.set_params(dict_params)
  else:
    print('dict_params is not defined yet. Run Step.2 first to copy its '
          'parameters into this case.')

In [None]:
# @title Case_3 Step.2
%%time
# Same flow as case_1's Step.2 (load, format, then analysis or design), but
# for case_3; this cell does not update the global `dict_params`.
case_3.load_data()
case_3.format_data()

if case_3.get_params()['purpose_selection'] == 0:
  case_3.run_causalImpact()
else:
  case_3.run_experimental_design()

print('\nExecution datetime(GMT):{}'.format(datetime.datetime.now()))

In [None]:
# @title Case_3 Step.3
%%time
# Show case_3's CausalImpact result, or run its simulation when the purpose
# is experimental design.
if case_3.get_params()['purpose_selection'] == 0:
  case_3.display_causalimpact_result()
else:
  case_3.generate_simulation()

print('\nExecution datetime(GMT):{}'.format(datetime.datetime.now()))