# Алгоритм Томпсона на батчах (без контекста)

1. На первом батче распределяем юзеров 50% на 50%.
2. Априорное распределение конверсии каждого варианта — бета-распределение с параметрами $\alpha=1$,
$\beta=1$ (равномерное на $[0, 1]$).
3. В конце каждого батча пересчитываем вероятности превосходства по точной формуле,
взятой отсюда: https://www.johndcook.com/UTMDABTR-005-05.pdf
4. Распределяем трафик в пропорции вероятностей превосходства для каждого варианта
5. Останавливаем эксперимент при достижении определенной вероятности превосходства,
но не раньше определенного дня, чтобы учесть календарные факторы

Потенциальные проблемы:
- слишком рано отдаем трафик победителю
- из-за дисбаланса распределения трафика может быть больше успешных конверсий в этом варианте
(можно попробовать применить нормализацию)


***Реализуем алгоритм***

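0. **Инициализация** - *BatchThompson(p_list, batch_size_share)*. Аргументы на вход: список конверсий вариантов
(их число задает число вариантов сплита) и размер батча в долях от выборки классического АБ-теста.
Здесь также инициализируются массивы для параметров Бета-распределений и вероятности превосходства (по 0.5 для двух вариантов).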

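1. **Метод сплита** - *split_data_random()* (или *split_data_historic()* для заранее сгенерированных данных).
Получает число наблюдений для каждого варианта (оно вычисляется из вероятностей превосходства) и возвращает
данные по конверсии на текущем батче для пересчета Бета-распределений.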

2. **Метод изменения параметров распределения** - *update_beta_params(batch_data, method)*. Аргументы на вход: numpy-массив
со значениями конверсии по каждому варианту (в случае неравномерного распределения по вариантам недостающие значения
заполняются NaN) и способ обновления: `"summation"` или `"normalization"`.
 - Проверяем, чтобы число столбцов совпадало с числом вариантов из инициализации.
 - Суммируем нули и единицы и обновляем параметры

3. **Метод пересчета** - *update_prob_super(method_calc)*. Единственный аргумент — способ расчета (`"integrating"`);
остальное берется из обновленных параметров $\alpha$ (априорная единица плюс накопленное число успешных конверсий) и
$\beta$ (априорная единица плюс накопленное число неудачных конверсий) для всех вариантов.
 - Считаем по точной формуле
 - Выдаем массив из вероятностей превосходства

4. **Вероятность превосходства** — атрибут *probability_superiority_tuple*, который обновляется методом *update_prob_super()*.
5. **Критерий остановки** (*stopping_criterion*) — условие выхода из цикла по батчам: либо вероятность превосходства
достигла заданной величины (0.99), либо в какую-то руку ушло больше наблюдений, чем нужно для классического АБ-теста.

In [3]:
from typing import List, Tuple
import os
import numpy as np
import pandas as pd
from numpy import ndarray
from scipy.stats import beta
from tqdm.notebook import tqdm
from AB_classic import get_size_zratio
from MAB import calc_prob_between

# Графики
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages

import warnings
warnings.filterwarnings("ignore")



In [24]:
def create_directory_plot(folder: str, file_name: str):
    """Create ``Plot/Thompson/<folder>/`` if needed and open a PDF for plots in it.

    :param folder: sub-folder name under ``Plot/Thompson``
    :param file_name: PDF file name without the ``.pdf`` extension
    :return: an open ``PdfPages`` object; the caller is responsible for closing it
    """
    directory_plots = os.path.join('Plot', 'Thompson', folder)
    # makedirs also creates missing parent folders and, unlike a bare
    # ``except: pass`` around os.mkdir, does not hide real errors such as
    # permission problems; an already existing folder is fine.
    os.makedirs(directory_plots, exist_ok=True)
    beta_distr_plot = PdfPages(os.path.join(directory_plots, file_name + ".pdf"))
    return beta_distr_plot


class BatchThompson:
    """Batched Thompson sampling for an A/B test on simulated Bernoulli data.

    Every arm's conversion rate has a Beta posterior starting from the Beta(1, 1)
    prior. After each batch the posteriors are updated and the next batch is split
    between arms in proportion to their probabilities of superiority.
    """

    def __init__(self, p_list: List[float], batch_size_share: float):
        """
        :param p_list: true conversion rate of every arm (simulation ground truth)
        :param batch_size_share: batch size as a share of ``n_arms * n_obs_every_arm``
        """
        self.p_list = p_list
        self.n_arms = len(p_list)
        # Per-arm sample size of a classic AB test for the first two arms; it sets
        # the batch size and serves as the observation budget in start_experiment.
        self.n_obs_every_arm = get_size_zratio(p_list[0], p_list[1], alpha=0.05, beta=0.2)
        self.batch_size_share = batch_size_share
        # Plain int instead of np.uint16: uint16 silently wraps around above 65535.
        self.batch_size = int(self.batch_size_share * self.n_arms * self.n_obs_every_arm)

        # Beta(1, 1) priors for every arm.
        self.alphas = np.repeat(1.0, self.n_arms)
        self.bethas = np.repeat(1.0, self.n_arms)
        # Before any data every arm is equally likely to be the best one.
        self.probability_superiority_tuple = tuple([1.0 / self.n_arms] * self.n_arms)

        # Generating data
        # NOTE(review): this reseeds the *global* NumPy RNG with one of only 100
        # seeds (0..99), so runs in a parameter sweep may share identical random
        # streams — confirm this is intended.
        np.random.seed(np.uint16(np.random.random(size=1) * 100).item())
        # One pre-generated Bernoulli column per arm (p broadcasts over the columns);
        # consumed by split_data_historic.
        self.data = np.random.binomial(n=1, p=self.p_list, size=(self.n_obs_every_arm, self.n_arms))

    def split_data_historic(self, cumulative_observations: List, batch_split_obs: List) -> np.ndarray:
        """Take this batch's observations for every arm from the pre-generated ``self.data``.

        :param cumulative_observations: per arm, how many rows of ``self.data`` were
            already consumed *before* this batch (start offset of the slice)
        :param batch_split_obs: per arm, how many observations to take in this batch
        :return: array of shape ``(max(batch_split_obs), n_arms)``; column ``i`` holds
            ``batch_split_obs[i]`` values followed by NaN padding
        """
        n_rows, n_cols = np.max(batch_split_obs), self.n_arms
        data_split = np.full((n_rows, n_cols), np.nan)
        for i in range(self.n_arms):
            start = cumulative_observations[i]
            data_split[:batch_split_obs[i], i] = self.data[start:start + batch_split_obs[i], i]
        return data_split

    def split_data_random(self, batch_split_obs: np.ndarray) -> np.ndarray:
        """Simulate a fresh batch of Bernoulli(``p_list[i]``) outcomes for every arm.

        :param batch_split_obs: number of observations to draw for every arm
        :return: array of shape ``(max(batch_split_obs), n_arms)``, NaN-padded
        """
        data_split = np.full((np.max(batch_split_obs), self.n_arms), np.nan)
        for i in range(self.n_arms):
            data_split[:batch_split_obs[i], i] = np.random.binomial(n=1, p=self.p_list[i],
                                                                    size=batch_split_obs[i])
        return data_split

    def update_beta_params(self, batch_data: np.ndarray, method: str):
        """Update the Beta parameters with one batch of 0/1 outcomes (NaN = no observation).

        :param batch_data: array with one column per arm holding 1, 0 or NaN
        :param method: ``"summation"`` adds the raw numbers of successes/failures;
            ``"normalization"`` adds ``(M / K) * rate`` to alpha and
            ``(M / K) * (1 - rate)`` to beta, where ``rate`` is the arm's observed
            conversion in the batch, ``M`` the number of rows and ``K`` the number
            of arms, so every observed arm gets the same total weight regardless of
            its traffic. An arm without observations in the batch is left unchanged.
        :return: the updated ``(alphas, bethas)`` arrays
        :raises ValueError: on an unknown ``method``
        """
        successes = np.nansum(batch_data, axis=0)
        failures = np.sum(batch_data == 0, axis=0)  # NaN == 0 is False, so padding is ignored
        if method == "summation":
            self.alphas += successes
            self.bethas += failures
        elif method == "normalization":
            # NOTE(review): M is the row count, i.e. the largest per-arm sample in the
            # batch, not the total batch size — confirm this is the intended formula.
            scale = batch_data.shape[0] / self.n_arms
            totals = successes + failures
            observed = totals > 0
            # Divide only where the arm got data: avoids 0/0 without relying on the
            # global warnings filter.
            rates = np.divide(successes, totals, out=np.zeros(totals.shape), where=observed)
            self.alphas += scale * rates
            self.bethas += np.where(observed, scale * (1 - rates), 0.0)
        else:
            raise ValueError(f"Unknown update method: {method!r}")
        return self.alphas, self.bethas

    def update_prob_super(self, method_calc) -> Tuple:
        """Recompute the probabilities of superiority from the current Beta parameters.

        Only the two-arm ``"integrating"`` method is implemented: the tuple becomes
        ``(p, 1 - p)`` with ``p`` returned by ``calc_prob_between``.

        :param method_calc: calculation method; only ``"integrating"`` is supported
        :return: the updated ``probability_superiority_tuple``
        :raises ValueError: on an unknown ``method_calc``
        """
        if method_calc != 'integrating':
            raise ValueError(f"Unknown calculation method: {method_calc!r}")
        prob_superiority = calc_prob_between(self.alphas, self.bethas)
        self.probability_superiority_tuple = (prob_superiority, 1 - prob_superiority)
        return self.probability_superiority_tuple

    # def create_plots(self, beta_distr_plot):
    #     x = np.linspace(0, 1, 100)
    #     rv1 = beta(self.alphas[0], self.bethas[0])
    #     rv2 = beta(self.alphas[1], self.bethas[1])
    #     fix, ax = plt.subplots()
    #     ax.plot(x, rv1.pdf(x), label='control')
    #     ax.plot(x, rv2.pdf(x), label='testing')
    #     leg = ax.legend();
    #     plt.title(f"Вероятность превосходства в %: "
    #               f"{np.round(tuple(map(lambda x: x * 100, self.probability_superiority_tuple)), 1)}")
    #     beta_distr_plot.savefig()
    #     plt.close()

    def start_experiment(self):
        """Run batches until one arm is superior with probability >= 0.99 or the budget is used.

        At most ``n_arms / batch_size_share`` batches are run. Each batch is split in
        proportion to ``probability_superiority_tuple``, simulated with
        ``split_data_random`` and used to update the posteriors with the
        ``"normalization"`` method.

        :return: ``(probabilities, observations)`` — per batch, the probability of
            superiority of every arm (rounded to 3 decimals) and the number of
            observations given to every arm
        """
        probability_superiority_step_list: List[ndarray] = []  # how share of traffic changes across experiment
        observations_step_list: List[ndarray] = []  # how many observations every arm got in every step

        # Plots
        # folder, file_name = self.experiment_name, str(self.p1) + "_" + str(self.p2)
        cumulative_observations = np.zeros(self.n_arms, dtype=int)  # observations consumed so far, per arm

        n_batches = int(1 / (self.batch_size_share / self.n_arms))
        for _ in tqdm(range(n_batches)):
            # Number of observations for every arm, proportional to its probability of superiority.
            batch_split_obs = np.round(self.batch_size * np.asarray(self.probability_superiority_tuple)).astype(int)
            # Alternative based on pre-generated data; it needs the offsets *before* this
            # batch, which is why cumulative_observations is updated only afterwards:
            # batch_data = self.split_data_historic(cumulative_observations=cumulative_observations,
            #                                       batch_split_obs=batch_split_obs)
            batch_data = self.split_data_random(batch_split_obs)  # based on generate batch online
            cumulative_observations += batch_split_obs

            # Updating all
            self.update_beta_params(batch_data, method="normalization")
            self.update_prob_super(method_calc="integrating")

            # Append for resulting
            probability_superiority_step_list.append(self.probability_superiority_tuple)
            observations_step_list.append(batch_split_obs)

            stopping_criterion = (np.max(self.probability_superiority_tuple) >= 0.99
                                  or np.max(cumulative_observations) > self.n_obs_every_arm)
            if stopping_criterion:
                break

        return np.round(probability_superiority_step_list, 3), observations_step_list

bts = BatchThompson(p_list=[0.4, 0.5], batch_size_share=0.05)
probability_superiority_steps, observations_step_list = bts.start_experiment()

# Running total of observations per arm after every batch.
cumulative_observations_steps = np.cumsum(observations_step_list, axis=0)
print(f"Наблюдений в каждую руку во время эксперимента: "
      f"\n {cumulative_observations_steps}")

print(f"Вероятности превосходства каждой руки во время эксперимента: "
      f"\n {probability_superiority_steps}")

  0%|          | 0/40 [00:00<?, ?it/s]

Наблюдений в каждую руку во время эксперимента: 
 [[ 19  19]
 [ 47  29]
 [ 72  42]
 [ 90  62]
 [110  80]
 [126 102]
 [140 126]
 [150 154]
 [159 183]
 [169 211]
 [176 242]
 [183 273]
 [195 299]
 [203 329]
 [213 357]
 [226 382]
 [243 403]]
Вероятности превосходства каждой руки во время эксперимента: 
 [[0.747 0.253]
 [0.666 0.334]
 [0.482 0.518]
 [0.525 0.475]
 [0.416 0.584]
 [0.365 0.635]
 [0.259 0.741]
 [0.229 0.771]
 [0.274 0.726]
 [0.178 0.822]
 [0.174 0.826]
 [0.326 0.674]
 [0.223 0.777]
 [0.267 0.733]
 [0.342 0.658]
 [0.446 0.554]
 [0.46  0.54 ]]


# Эксперименты на разных теоретических конверсиях и размерах батча (summation vs normalization)

In [6]:
from itertools import combinations, product
p1_control = np.round(np.linspace(0.01, 0.5, 10), 3)
mde_test_effect = np.round(np.linspace(0, 0.15, 3), 3)
tuple_batch_size_share = np.round(np.linspace(0.001, 0.1, 10), 3)

# One row per (control conversion, relative effect, batch share) combination.
result_experiments_df = pd.DataFrame(index=pd.MultiIndex.from_product([p1_control, mde_test_effect, tuple_batch_size_share],
                                                                      names=["p1", "mde", "batch_size_share"]),
                                     columns=['probability_superiority_steps', 'observations_step_list',
                                              'n_obs_per_every_arm'])
for index in tqdm(result_experiments_df.index):
    p1, mde, batch_size_share = index
    p_list = [p1, p1 * (1 + mde)]  # test arm = control conversion uplifted by the relative effect
    bts = BatchThompson(p_list=p_list, batch_size_share=batch_size_share)
    probability_superiority_steps, observations_step_list = bts.start_experiment()
    # .at sets exactly one cell, so an array/list is stored as a single object;
    # .loc may try to spread a sequence over the selection and fail.
    result_experiments_df.at[index, "n_obs_per_every_arm"] = bts.n_obs_every_arm
    result_experiments_df.at[index, "probability_superiority_steps"] = probability_superiority_steps
    result_experiments_df.at[index, "observations_step_list"] = observations_step_list

0it [00:00, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/166 [00:00<?, ?it/s]

  0%|          | 0/86 [00:00<?, ?it/s]

  0%|          | 0/58 [00:00<?, ?it/s]

  0%|          | 0/44 [00:00<?, ?it/s]

  0%|          | 0/35 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/22 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/166 [00:00<?, ?it/s]

  0%|          | 0/86 [00:00<?, ?it/s]

  0%|          | 0/58 [00:00<?, ?it/s]

  0%|          | 0/44 [00:00<?, ?it/s]

  0%|          | 0/35 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/22 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/166 [00:00<?, ?it/s]

  0%|          | 0/86 [00:00<?, ?it/s]

  0%|          | 0/58 [00:00<?, ?it/s]

  0%|          | 0/44 [00:00<?, ?it/s]

  0%|          | 0/35 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/22 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/166 [00:00<?, ?it/s]

  0%|          | 0/86 [00:00<?, ?it/s]

  0%|          | 0/58 [00:00<?, ?it/s]

  0%|          | 0/44 [00:00<?, ?it/s]

  0%|          | 0/35 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/22 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/166 [00:00<?, ?it/s]

  0%|          | 0/86 [00:00<?, ?it/s]

  0%|          | 0/58 [00:00<?, ?it/s]

  0%|          | 0/44 [00:00<?, ?it/s]

  0%|          | 0/35 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/22 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/166 [00:00<?, ?it/s]

  0%|          | 0/86 [00:00<?, ?it/s]

  0%|          | 0/58 [00:00<?, ?it/s]

  0%|          | 0/44 [00:00<?, ?it/s]

  0%|          | 0/35 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/22 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/166 [00:00<?, ?it/s]

  0%|          | 0/86 [00:00<?, ?it/s]

  0%|          | 0/58 [00:00<?, ?it/s]

  0%|          | 0/44 [00:00<?, ?it/s]

  0%|          | 0/35 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/22 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/166 [00:00<?, ?it/s]

  0%|          | 0/86 [00:00<?, ?it/s]

  0%|          | 0/58 [00:00<?, ?it/s]

  0%|          | 0/44 [00:00<?, ?it/s]

  0%|          | 0/35 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/22 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/166 [00:00<?, ?it/s]

  0%|          | 0/86 [00:00<?, ?it/s]

  0%|          | 0/58 [00:00<?, ?it/s]

  0%|          | 0/44 [00:00<?, ?it/s]

  0%|          | 0/35 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/22 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/166 [00:00<?, ?it/s]

  0%|          | 0/86 [00:00<?, ?it/s]

  0%|          | 0/58 [00:00<?, ?it/s]

  0%|          | 0/44 [00:00<?, ?it/s]

  0%|          | 0/35 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/22 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/166 [00:00<?, ?it/s]

  0%|          | 0/86 [00:00<?, ?it/s]

  0%|          | 0/58 [00:00<?, ?it/s]

  0%|          | 0/44 [00:00<?, ?it/s]

  0%|          | 0/35 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/22 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/166 [00:00<?, ?it/s]

  0%|          | 0/86 [00:00<?, ?it/s]

  0%|          | 0/58 [00:00<?, ?it/s]

  0%|          | 0/44 [00:00<?, ?it/s]

  0%|          | 0/35 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/22 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/166 [00:00<?, ?it/s]

  0%|          | 0/86 [00:00<?, ?it/s]

  0%|          | 0/58 [00:00<?, ?it/s]

  0%|          | 0/44 [00:00<?, ?it/s]

  0%|          | 0/35 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/22 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/166 [00:00<?, ?it/s]

  0%|          | 0/86 [00:00<?, ?it/s]

  0%|          | 0/58 [00:00<?, ?it/s]

  0%|          | 0/44 [00:00<?, ?it/s]

  0%|          | 0/35 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/22 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/166 [00:00<?, ?it/s]

  0%|          | 0/86 [00:00<?, ?it/s]

  0%|          | 0/58 [00:00<?, ?it/s]

  0%|          | 0/44 [00:00<?, ?it/s]

  0%|          | 0/35 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/22 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/166 [00:00<?, ?it/s]

  0%|          | 0/86 [00:00<?, ?it/s]

  0%|          | 0/58 [00:00<?, ?it/s]

  0%|          | 0/44 [00:00<?, ?it/s]

  0%|          | 0/35 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/22 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/166 [00:00<?, ?it/s]

  0%|          | 0/86 [00:00<?, ?it/s]

  0%|          | 0/58 [00:00<?, ?it/s]

  0%|          | 0/44 [00:00<?, ?it/s]

  0%|          | 0/35 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/22 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/166 [00:00<?, ?it/s]

  0%|          | 0/86 [00:00<?, ?it/s]

  0%|          | 0/58 [00:00<?, ?it/s]

  0%|          | 0/44 [00:00<?, ?it/s]

  0%|          | 0/35 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/22 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/166 [00:00<?, ?it/s]

  0%|          | 0/86 [00:00<?, ?it/s]

  0%|          | 0/58 [00:00<?, ?it/s]

  0%|          | 0/44 [00:00<?, ?it/s]

  0%|          | 0/35 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/22 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/166 [00:00<?, ?it/s]

  0%|          | 0/86 [00:00<?, ?it/s]

  0%|          | 0/58 [00:00<?, ?it/s]

  0%|          | 0/44 [00:00<?, ?it/s]

  0%|          | 0/35 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/22 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/166 [00:00<?, ?it/s]

  0%|          | 0/86 [00:00<?, ?it/s]

  0%|          | 0/58 [00:00<?, ?it/s]

  0%|          | 0/44 [00:00<?, ?it/s]

  0%|          | 0/35 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/22 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/166 [00:00<?, ?it/s]

  0%|          | 0/86 [00:00<?, ?it/s]

  0%|          | 0/58 [00:00<?, ?it/s]

  0%|          | 0/44 [00:00<?, ?it/s]

  0%|          | 0/35 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/22 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/166 [00:00<?, ?it/s]

  0%|          | 0/86 [00:00<?, ?it/s]

  0%|          | 0/58 [00:00<?, ?it/s]

  0%|          | 0/44 [00:00<?, ?it/s]

  0%|          | 0/35 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/22 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/166 [00:00<?, ?it/s]

  0%|          | 0/86 [00:00<?, ?it/s]

  0%|          | 0/58 [00:00<?, ?it/s]

  0%|          | 0/44 [00:00<?, ?it/s]

  0%|          | 0/35 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/22 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/166 [00:00<?, ?it/s]

  0%|          | 0/86 [00:00<?, ?it/s]

  0%|          | 0/58 [00:00<?, ?it/s]

  0%|          | 0/44 [00:00<?, ?it/s]

  0%|          | 0/35 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/22 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/166 [00:00<?, ?it/s]

  0%|          | 0/86 [00:00<?, ?it/s]

  0%|          | 0/58 [00:00<?, ?it/s]

  0%|          | 0/44 [00:00<?, ?it/s]

  0%|          | 0/35 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/22 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/166 [00:00<?, ?it/s]

  0%|          | 0/86 [00:00<?, ?it/s]

  0%|          | 0/58 [00:00<?, ?it/s]

  0%|          | 0/44 [00:00<?, ?it/s]

  0%|          | 0/35 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/22 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/166 [00:00<?, ?it/s]

  0%|          | 0/86 [00:00<?, ?it/s]

  0%|          | 0/58 [00:00<?, ?it/s]

  0%|          | 0/44 [00:00<?, ?it/s]

  0%|          | 0/35 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/22 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/166 [00:00<?, ?it/s]

  0%|          | 0/86 [00:00<?, ?it/s]

  0%|          | 0/58 [00:00<?, ?it/s]

  0%|          | 0/44 [00:00<?, ?it/s]

  0%|          | 0/35 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/22 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/166 [00:00<?, ?it/s]

  0%|          | 0/86 [00:00<?, ?it/s]

  0%|          | 0/58 [00:00<?, ?it/s]

  0%|          | 0/44 [00:00<?, ?it/s]

  0%|          | 0/35 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/22 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

Попробуем реализовать метод, описанный в статье
https://www.researchgate.net/publication/352117401_Parallelizing_Thompson_Sampling

Авторы предлагают следующий алгоритм. Пусть $k_a$ - число раз выбора руки $a$,
$l_a = 1$ - число раз выбора руки подряд.

Для каждого батча $t = 1, 2, .. T$:

- смотрим на вероятности бета распределений
- выбираем руку с наибольшей вероятностью
- присваиваем $k_a = k_a + 1$
- ЕСЛИ $k_a < 2^{l_a}$, то кидаем ВЕСЬ трафик в эту руку
- ИНАЧЕ: присваиваем $l_a = l_a + 1$ и распределяем трафик в ОБЕ руки ПОЛНОСТЬЮ (без долей)
- обновляем параметры распределений и переходим к следующему батчу

Авторы утверждают, что алгоритм хорошо работает и для динамических батчей — когда их размер заранее неизвестен.

In [64]:
def all_equal(lst):
    """Return True when every element of *lst* equals the first one.

    Elements are compared with ``np.array_equal(..., equal_nan=True)``, so they
    may be scalars or arrays and NaN counts as equal to NaN. An empty or
    single-element sequence gives True.
    """
    rest = lst[1:]
    return all(np.array_equal(lst[0], other, equal_nan=True) for other in rest)


class BatchThompson1(BatchThompson):
    """Batched Thompson sampling with the doubling scheme from "Parallelizing Thompson Sampling".

    Every batch goes entirely to the currently best arm, except when that arm's
    selection counter ``k`` reaches ``2 ** l``: then ``l`` is incremented and a
    full batch is sent to every arm (exploration round).
    """

    def split_data_random(self, best_arms: np.ndarray) -> np.ndarray:
        """Simulate ``batch_size`` Bernoulli outcomes for each selected arm.

        :param best_arms: indices of the arms that receive traffic in this batch
        :return: array of shape ``(batch_size, n_arms)``; columns of arms that got
            no traffic are all NaN
        """
        data_split = np.full((self.batch_size, self.n_arms), np.nan)
        for i in range(self.n_arms):
            if i in best_arms:
                data_split[:, i] = np.random.binomial(n=1, p=self.p_list[i], size=self.batch_size)
        return data_split

    def start_experiment(self):
        """Run the doubling-batch Thompson experiment.

        Stops when an arm's probability of superiority reaches 0.99, when some arm
        has received more than ``n_obs_every_arm`` observations, or after
        ``n_arms / batch_size_share`` batches.

        :return: ``(probabilities, observations, k_list, l_list)`` — per batch, the
            probability of superiority of every arm (rounded to 3 decimals) and the
            number of observations every arm received; plus the final per-arm
            selection counters ``k`` and doubling exponents ``l``
        """
        cumulative_observations = np.zeros(self.n_arms, dtype=int)
        probability_superiority_step_list: List[ndarray] = []  # how share of traffic changes across experiment
        observations_step_list = []
        k_list = [0] * self.n_arms  # how many times every arm was chosen as the best one
        l_list = [0] * self.n_arms  # doubling exponent: an exploration round happens when k reaches 2 ** l

        n_batches = int(1 / (self.batch_size_share / self.n_arms))
        for _ in tqdm(range(n_batches)):

            # Determine argmax arm; on a tie (e.g. the uniform start) pick any arm at random.
            if all_equal(self.probability_superiority_tuple):
                best_arm = np.random.choice(self.n_arms)
            else:
                best_arm = np.argmax(self.probability_superiority_tuple)
            k_list[best_arm] += 1

            if k_list[best_arm] >= 2 ** l_list[best_arm]:
                # Exploration round: a full batch goes to every arm.
                l_list[best_arm] += 1
                batch_data = self.split_data_random(best_arms=np.arange(self.n_arms))
            else:
                # Exploitation round: the whole batch goes to the best arm only.
                batch_data = self.split_data_random(best_arms=np.array([best_arm]))

            batch_observations_step = np.sum(~np.isnan(batch_data), axis=0)
            cumulative_observations += batch_observations_step
            observations_step_list.append(batch_observations_step)

            # Updating all
            self.update_beta_params(batch_data, method="summation")
            self.update_prob_super(method_calc="integrating")

            # Append for resulting
            probability_superiority_step_list.append(self.probability_superiority_tuple)

            stopping_criterion = (np.max(self.probability_superiority_tuple) >= 0.99
                                  or np.max(cumulative_observations) > self.n_obs_every_arm)
            if stopping_criterion:
                break

        return np.round(probability_superiority_step_list, 3), observations_step_list, \
               k_list, l_list
bt1 = BatchThompson1(p_list=[0.4, 0.42], batch_size_share=0.1)
experiment_result = bt1.start_experiment()
probability_superiority_experiment, observations_step_list, k_list, l_list = experiment_result
# Per-batch probabilities, per-batch observations, final k counters, final l exponents.
for printed_value in experiment_result:
    print(printed_value)

  0%|          | 0/20 [00:00<?, ?it/s]

[[0.083 0.917]
 [0.013 0.987]
 [0.001 0.999]]
[array([1897, 1897]), array([1897, 1897]), array([1897, 1897])]
[1, 2]
[1, 2]


In [35]:
# Display the per-batch numbers of observations given to every arm (from the most recently executed experiment cell).
observations_step_list

[array([61, 61]),
 array([61, 61]),
 array([61,  0]),
 array([61, 61]),
 array([61,  0]),
 array([61,  0]),
 array([61,  0]),
 array([61, 61]),
 array([61,  0])]

In [61]:
# Per-arm classic AB-test sample size, used as the observation budget in the stopping criterion.
bt1.n_obs_every_arm

9489

In [60]:
# Cumulative number of observations per arm after every batch.
np.cumsum(observations_step_list, axis=0)

array([[ 189,  189],
       [ 378,  378],
       [ 567,  567],
       [ 756,  756],
       [ 756,  945],
       [ 945, 1134],
       [ 945, 1323],
       [ 945, 1512],
       [ 945, 1701],
       [1134, 1890],
       [1134, 2079],
       [1134, 2268],
       [1134, 2457],
       [1134, 2646],
       [1134, 2835],
       [1134, 3024],
       [1134, 3213],
       [1323, 3402],
       [1323, 3591],
       [1323, 3780],
       [1323, 3969],
       [1323, 4158],
       [1323, 4347],
       [1323, 4536],
       [1323, 4725],
       [1323, 4914],
       [1323, 5103],
       [1323, 5292],
       [1323, 5481],
       [1323, 5670],
       [1323, 5859],
       [1323, 6048],
       [1323, 6237],
       [1512, 6426],
       [1512, 6615],
       [1512, 6804],
       [1512, 6993],
       [1512, 7182],
       [1512, 7371],
       [1512, 7560],
       [1512, 7749],
       [1512, 7938],
       [1512, 8127],
       [1512, 8316],
       [1512, 8505],
       [1512, 8694],
       [1512, 8883],
       [1512,