# A/B Testing Course

## Lesson 11. Traffic Splitting

### Homework

#### Task 1.

Implement the `add_experiment` method of the `SplittingService` class.

In [1]:
from pydantic import BaseModel


class Experiment(BaseModel):
    """Experiment descriptor used when allocating experiments to buckets.

    Fields:
        id: experiment identifier.
        buckets_count: number of buckets the experiment should occupy.
        conflicts: identifiers of experiments that must not be run on the
            same users (i.e. must not share a bucket) with this one.
    """
    id: int
    buckets_count: int
    # pydantic copies non-hashable defaults for each instance, so this [] is
    # not shared between Experiment objects.
    conflicts: list = []


class SplittingService:

    def __init__(self, buckets_count):
        """Class for distributing experiments and users into buckets.

        :param buckets_count (int): number of buckets.
        """
        self.buckets_count = buckets_count
        self.buckets = [[] for _ in range(buckets_count)]
        # Conflict sets of every experiment added so far, keyed by experiment id.
        # Kept so that a conflict declared by only one side of a pair is still
        # honoured when the other experiment is added later.
        self._conflicts = {}

    def _fits(self, experiment_id, own_conflicts, bucket):
        """Return True if the experiment may be placed into `bucket`.

        The bucket is rejected if it already contains the experiment, contains
        an experiment from `own_conflicts`, or contains an experiment that
        itself declared `experiment_id` as a conflict.
        """
        for other_id in bucket:
            if other_id == experiment_id or other_id in own_conflicts:
                return False
            if experiment_id in self._conflicts.get(other_id, ()):
                return False
        return True

    def add_experiment(self, experiment):
        """Checks if an experiment can be added, and adds it if possible.

        Conflicts are treated as symmetric: a bucket is unusable if any of its
        experiments is listed in `experiment.conflicts`, or lists
        `experiment.id` among its own conflicts. Among the usable buckets the
        fullest ones are taken first (ties broken by bucket index), so that
        emptier buckets stay free for later experiments with more conflicts.

        :param experiment (Experiment): experiment parameters to be added
        :return success, buckets:
            success (boolean) - whether the experiment can be added, True if possible, False otherwise
            buckets (list[list[int]]) - list of buckets, where each bucket contains the identifiers of experiments being conducted in it.
        """
        own_conflicts = set(experiment.conflicts)
        candidates = [
            index
            for index, bucket in enumerate(self.buckets)
            if self._fits(experiment.id, own_conflicts, bucket)
        ]
        # A non-positive request places nothing and trivially succeeds.
        needed = max(experiment.buckets_count, 0)
        if len(candidates) < needed:
            return False, self.buckets

        # list.sort is stable, so equally full buckets keep their index order.
        candidates.sort(key=lambda index: -len(self.buckets[index]))
        for index in candidates[:needed]:
            self.buckets[index].append(experiment.id)
        self._conflicts.setdefault(experiment.id, set()).update(own_conflicts)
        return True, self.buckets

        
def check_correct_buckets(buckets, experiments):
    """Assert that `buckets` is a valid allocation of `experiments`.

    For each experiment, checks that it occupies exactly
    `experiment.buckets_count` buckets and that no experiment sharing a
    bucket with it appears in its `conflicts`. Raises AssertionError on the
    first violation found.
    """
    for exp in experiments:
        hosting = [bucket for bucket in buckets if exp.id in bucket]
        assert len(hosting) == exp.buckets_count, 'Wrong number of buckets with experiments'
        neighbours = {other for bucket in hosting for other in bucket}
        clash = neighbours.intersection(exp.conflicts)
        assert not clash, 'Incompatible experiments in one bucket'


if __name__ == '__main__':
    # Smoke test: experiments are added one at a time; each answer must match
    # the expected one and the buckets must stay valid after every step.
    experiments = [
        Experiment(id=1, buckets_count=4, conflicts=[4]),
        Experiment(id=2, buckets_count=2, conflicts=[3]),
        Experiment(id=3, buckets_count=2, conflicts=[2]),
        Experiment(id=4, buckets_count=1, conflicts=[1]),
    ]
    ideal_answers = [True, True, True, False]

    splitting_service = SplittingService(buckets_count=4)
    added_experiments = []
    for experiment, ideal_answer in zip(experiments, ideal_answers):
        success, buckets = splitting_service.add_experiment(experiment)
        print(success, buckets)
        assert success == ideal_answer, 'The split system is functioning suboptimally or incorrectly.'
        if success:
            added_experiments.append(experiment)
        # Re-validate every experiment added so far, not only the latest one.
        check_correct_buckets(buckets, added_experiments)
    print('simple test passed')

True [[1], [1], [1], [1]]
True [[1, 2], [1, 2], [1], [1]]
True [[1, 2], [1, 2], [1, 3], [1, 3]]
False [[1, 2], [1, 2], [1, 3], [1, 3]]
simple test passed


#### Task 2.

Implement the `process_user` method of the `SplittingService` class.

In [2]:
from pydantic import BaseModel
from hashlib import md5


class Experiment(BaseModel):
    """Experiment descriptor used when assigning users to groups.

    Fields:
        id: experiment identifier.
        salt: experiment-specific salt, used to randomly split the users of a
            bucket between the control and pilot groups.
    """
    id: int  # experiment identifier
    salt: str  # per-experiment hashing salt for the group split


class SplittingService:

    def __init__(self, buckets_count, bucket_salt, buckets=None, id2experiment=None):
        """Class for distributing experiments and users into buckets.

        :param buckets_count (int): number of buckets.
        :param bucket_salt (str): salt for user bucketing.
        With the same salt, each user should always be assigned to the same bucket.
        If the salt is changed, the distribution of users into buckets should change as well.
        :param buckets (list[list[int]]) - list of buckets, where each bucket contains the identifiers of experiments conducted in it.
            If omitted or empty, `buckets_count` empty buckets are created.
        :param id2experiment (dict[int, Experiment]) - dictionary of experiment identifiers to experiments.
            A dict passed in is used as-is (not copied), so experiments registered in it later are visible.
        """
        self.buckets_count = buckets_count
        self.bucket_salt = bucket_salt
        # An empty list cannot be indexed by bucket number, so fall back to fresh buckets.
        self.buckets = buckets if buckets else [[] for _ in range(buckets_count)]
        # `is None` rather than truthiness: an empty registry supplied by the caller
        # must be kept (and shared), not silently replaced by a private dict.
        self.id2experiment = {} if id2experiment is None else id2experiment

    @staticmethod
    def _get_bucket(value, n, salt=''):
        """Deterministically map the string `value` to an integer in [0, n).

        The md5 digest of `value + salt` is read as an integer and reduced
        modulo `n`; changing `salt` reshuffles the assignment.
        """
        hash_value = int(md5((value + salt).encode()).hexdigest(), 16)
        return hash_value % n

    def process_user(self, user_id):
        """Determines which experiments the user belongs to.

        First, the user's bucket needs to be determined.
        Then, for each experiment in that bucket, select the pilot or control group.
        :param user_id (str): user identifier
        :return bucket_id, experiment_groups:
            bucket_id (int) - bucket number (index in self.buckets)
            experiment_groups (list[tuple]) - list of pairs: experiment id, group. Groups: 'A', 'B'.
            Example: (8, [(194, 'A'), (73, 'B')])
        :raises KeyError: if an experiment id in the user's bucket is missing from id2experiment.
        """
        bucket_id = self._get_bucket(user_id, self.buckets_count, self.bucket_salt)

        experiment_groups = []
        for experiment_id in self.buckets[bucket_id]:
            salt = self.id2experiment[experiment_id].salt
            # NOTE: if an experiment's salt equals bucket_salt and buckets_count is
            # even, (h % n) % 2 == h % 2, so the group would be fixed by the bucket.
            group = 'AB'[self._get_bucket(user_id, 2, salt)]
            experiment_groups.append((experiment_id, group))

        return bucket_id, experiment_groups
        

if __name__ == '__main__':
    # Smoke test: every user must land in a valid bucket and get exactly one
    # known group label per experiment running in that bucket.
    id2experiment = {
        exp_id: Experiment(id=exp_id, salt=str(exp_id)) for exp_id in (0, 1)
    }
    buckets = [[0, 1], [1], [], []]
    buckets_count = len(buckets)
    bucket_salt = 'a2N4'

    splitting_service = SplittingService(buckets_count, bucket_salt, buckets, id2experiment)
    user_ids = list(map(str, range(1000)))
    for user_id in user_ids:
        bucket_id, experiment_groups = splitting_service.process_user(user_id)
        assert bucket_id in (0, 1, 2, 3), 'Wrong bucket_id'
        expected_count = len(buckets[bucket_id])
        assert len(experiment_groups) == expected_count, 'Wrong number of experiments in a bucket'
        for exp_id, group in experiment_groups:
            assert exp_id in id2experiment, 'Wrong experiment_id'
            assert group in ('A', 'B'), 'Wrong group'
    print('simple test passed')

simple test passed
