In [1]:
import time
import uuid

from collections import defaultdict, Counter
from math import ceil, sqrt
from pprint import pprint
from random import choice, gauss, seed
from statistics import stdev


##########
# Entities
##########

class Worker:
    """A processing node that owns a set of devices.

    ``load_index`` holds the summed load index of the owned devices; it is
    recomputed by the scheduling code whenever devices are (re)assigned.

    Workers order by ``load_index`` first and by number of owned devices
    second, so ``min``/``sorted`` yield the least-loaded worker first.
    Note: because ``__len__`` is defined, a worker with no devices is falsy.
    """

    def __init__(self, identity=None):
        # Fall back to a generated id only when none was given, so falsy ids
        # such as 0 or '' are preserved. uuid4 (random) is used instead of
        # uuid1, which embeds the host's network address.
        self.identity = identity if identity is not None else uuid.uuid4().hex
        self.devices = set()
        self.databases = set()  # not used in the visible code
        self.load_index = 0

    def __repr__(self):
        return 'Worker(ID: {}, Load index: {}, Devices: {})'.format(
            self.identity, self.load_index, sorted(self.devices, reverse=True))

    def __lt__(self, other):
        # Lower load first; on equal load, fewer devices first.
        return ((self.load_index, len(self.devices))
                < (other.load_index, len(other.devices)))

    def __contains__(self, device):
        return device in self.devices

    def __len__(self):
        return len(self.devices)
    
    
class Device:
    """A data source whose processing cost is tracked by ``load_index``.

    Equality and hashing use ``id_`` only, so a set never holds two devices
    with the same id. Ordering (``<``) compares ``load_index``.
    """

    def __init__(self, id_, load_index=0):
        self.id_ = id_
        self.load_index = load_index
        # TODO: reprocessing will be a memory hog; we need a mechanism for
        # dealing with it effectively. Some out-of-order data is OK, since the
        # scheduler will just reassign properly. A good idea would be to
        # increase the memory threshold on the reprocessing workers and
        # decrease it on the others.
        self.reprocessing = False

    def __repr__(self):
        return '(ID: {id_}, Load index: {load_index})'.format(**self.__dict__)

    def __lt__(self, other):
        return self.load_index < other.load_index

    def __eq__(self, other):
        # Defer to Python's default handling for foreign types instead of
        # raising AttributeError on a missing ``id_``.
        if not isinstance(other, Device):
            return NotImplemented
        return self.id_ == other.id_

    def __hash__(self):
        return hash(self.id_)

    
class Database:
    """A named database with its own device set and load index."""

    def __init__(self, name):
        self.name = name
        self.devices = set()
        self.load_index = 0

    def __repr__(self):
        # The format string previously had no placeholder for the load index.
        return 'DB(Name: {name}, Load index: {load_index})'.format(**self.__dict__)

    
##########
# Settings
##########

_statistics = defaultdict(Counter)  # Redis simulator: per-device 'count' and 'calc_time' totals
_device_num = 10  # device count
_worker_num = 2  # worker count
_devices = [Device(i) for i in range(_device_num)]
_workers = [Worker(i + 1) for i in range(_worker_num)]  # worker ids start at 1
int_time = 10  # length of one collection interval, in the same units as simulated calc times

# Rounding precision for load indices, growing with the number of digits in the
# device count (e.g. 10 devices -> 2 digits -> ceil(2 + 2.5) = 5 decimal places).
decimal_points = len(str(_device_num))
decimal_points = ceil(decimal_points + (decimal_points * 5 / 4))

# Fair load share per device; also the mean of each message's simulated calc time.
device_load = round((1 / _device_num), decimal_points)
gauss_deviation = device_load * sqrt(_device_num)  # std deviation of a message's simulated calc time
worker_deviation = 0.1  # extra load (fraction of the fair share, 0.1 = 10%) a worker may carry and keep its devices

load_per_worker = round((1 / _worker_num), decimal_points)  # fair load share per worker (whole system ~1)
deviation_per_worker_load = load_per_worker + (load_per_worker * worker_deviation)  # load ceiling for keeping devices
deviation_per_worker_load = load_per_worker+(load_per_worker*worker_deviation)  # load with deviation


###########
# Calculate
###########

def load_simulator():
    """Simulate processing a single message from a randomly chosen device.

    The processing time is drawn from a normal distribution with mean
    ``device_load`` and deviation ``gauss_deviation``, rounded to 5 decimals
    and folded to a non-negative value. It is added to the chosen device's
    entry in ``_statistics`` (message 'count' and summed 'calc_time').

    Returns:
        The simulated processing time of this message.
    """
    picked = choice(_devices)
    duration = abs(round(gauss(device_load, gauss_deviation), 5))
    stats = _statistics[picked]
    stats['count'] += 1
    stats['calc_time'] += duration
    return duration

def generate_data(seed_num):
    """Simulate one collection interval of traffic.

    Reseeds the RNG with ``seed_num`` (for reproducible runs), replaces the
    module-level ``_statistics`` with an empty store, then simulates messages
    until their accumulated processing time reaches ``int_time``.

    Returns:
        The number of messages simulated in the interval.
    """
    global _statistics
    seed(seed_num)
    _statistics = defaultdict(Counter)  # fresh "Redis" for this interval
    elapsed = 0
    message_count = 0
    while elapsed < int_time:
        message_count += 1
        elapsed += load_simulator()
    return message_count

def _load_index(int_time, system_msgs, device_calc, device_msgs):
    calc_time_index = .7
    count_index = 0.3    
    return round((device_calc*calc_time_index/int_time +
            device_msgs*count_index/system_msgs)/(calc_time_index + count_index), decimal_points)


def run_analytics(total_messages, devices=None):
    """Recompute device load indices from the current ``_statistics``.

    Args:
        total_messages: Number of messages processed in the interval
            (as returned by ``generate_data``).
        devices: Devices whose load index is refreshed; defaults to the
            module-level ``_devices``. A device that received no messages
            in this interval gets a load index of 0 instead of silently
            keeping the value from a previous interval.

    Division by zero cannot occur for entries in ``_statistics``: each entry
    is created by a simulated message, so its count and ``total_messages``
    are both positive.
    """
    print('Total msgs:', total_messages)
    if devices is None:
        devices = _devices
    # Devices absent from _statistics handled no messages this interval.
    for device in devices:
        device.load_index = 0
    for device, value in _statistics.items():
        device.load_index = _load_index(
            int_time, total_messages, value['calc_time'], value['count'])


def sort_devices(devices, workers, load_per_worker, deviation_per_worker_load):
    """Assign devices to workers, keeping existing assignments where possible.

    Phase 1 (sticky): workers are visited from most to least loaded, using
    their previous load index. Each worker keeps the devices it already owns,
    heaviest first, as long as its recomputed load stays strictly below
    ``deviation_per_worker_load``.

    Phase 2 (greedy): every device not kept in phase 1 goes, heaviest first,
    to the currently least-loaded worker, as ordered by the workers' ``<``
    comparison.

    Args:
        devices: Iterable of devices to place. It is copied, not modified.
        workers: Workers to assign to. Their ``devices`` and ``load_index``
            are rewritten in place.
        load_per_worker: Fair load share per worker (not used by this function).
        deviation_per_worker_load: Exclusive upper bound on the load a worker
            may reach while keeping its existing devices.

    Returns:
        The same ``workers`` object, with updated assignments.
    """
    unassigned = set(devices)  # work on a copy; don't mutate the caller's set

    for worker in sorted(workers, reverse=True):
        kept = set()
        worker.load_index = 0
        for device in sorted(unassigned, reverse=True):
            if device not in worker:
                continue
            if device.load_index + worker.load_index < deviation_per_worker_load:
                kept.add(device)
                worker.load_index += device.load_index
        unassigned -= kept
        worker.devices = kept

    # TODO: This part needs to be smarter than plain greedy placement.
    # Heavy workers (e.g. with one or two heavy-load devices) shouldn't get
    # small-load devices. One worker should always be elected as the
    # 'reprocessing' worker, dealing only with devices in the reprocess state.
    # (This approach works well with the existing coordination service.)
    for device in sorted(unassigned, reverse=True):
        target = min(workers)  # first least-loaded worker, same as sorted(workers)[0]
        target.devices.add(device)
        target.load_index += device.load_index

    return workers
    
def _print_devices():
    """Print a dashed separator line, then each device of ``_devices`` on its own line."""
    print('-' * 113, *_devices, sep='\n')
    
def _print_workers(workers):
    print('='*113)
    for worker in workers:
        print(worker)
    
#####################
# Ride the unicorn 🦄
#####################

def main():
    """Run the scheduling demo.

    Makes an initial assignment, then runs two simulated intervals (seeds 2
    and 8). After each interval it recomputes the load indices and
    reschedules, printing devices and workers at every step.
    """
    _print_devices()
    workers = sort_devices(set(_devices), _workers, load_per_worker, deviation_per_worker_load)
    _print_workers(workers)
    for seed_num in (2, 8):
        msg_count = generate_data(seed_num)
        run_analytics(msg_count)
        _print_devices()
        workers = sort_devices(set(_devices), workers, load_per_worker, deviation_per_worker_load)
        _print_workers(workers)

    
# Run the demo only when executed as a script/notebook, not on import.
if __name__ == '__main__':
    main()
        

-----------------------------------------------------------------------------------------------------------------
(ID: 0, Load index: 0)
(ID: 1, Load index: 0)
(ID: 2, Load index: 0)
(ID: 3, Load index: 0)
(ID: 4, Load index: 0)
(ID: 5, Load index: 0)
(ID: 6, Load index: 0)
(ID: 7, Load index: 0)
(ID: 8, Load index: 0)
(ID: 9, Load index: 0)
Worker(ID: 1 Load index: 0, Devices: [(ID: 0, Load index: 0), (ID: 8, Load index: 0), (ID: 2, Load index: 0), (ID: 4, Load index: 0), (ID: 6, Load index: 0)])
Worker(ID: 2 Load index: 0, Devices: [(ID: 1, Load index: 0), (ID: 3, Load index: 0), (ID: 5, Load index: 0), (ID: 9, Load index: 0), (ID: 7, Load index: 0)])
Total msgs: 33
-----------------------------------------------------------------------------------------------------------------
(ID: 0, Load index: 0.08709)
(ID: 1, Load index: 0)
(ID: 2, Load index: 0.11495)
(ID: 3, Load index: 0)
(ID: 4, Load index: 0.01594)
(ID: 5, Load index: 0.07236)
(ID: 6, Load index: 0.19955)
(ID: 7, Load index

In [5]:
def test(*args):
    """Print how many positional arguments were passed."""
    count = len(args)
    print(count)
    
# Comparison binds more loosely than subtraction, so this is (5 - 3) > 1.
a = (5 - 3) > 1

print(a)

True
