In [11]:
import numpy as np
import pandas as pd
import hashlib
import mmh3
from pandarallel import pandarallel
pandarallel.initialize()

"""hash_func  
各种hash函数，根据一个字符串获取hash结果
"""
hash_func_map = {
    "md5": lambda content: int(hashlib.md5(content.encode("utf8")).hexdigest(),16),
    "dwmd5": lambda content: int(str(int(hashlib.md5(content.encode("utf8")).hexdigest(), 16))[:16]),
    # "mmh3": lambda content: mmh3.hash(content),
    "blake2b": lambda content: int(hashlib.blake2b(content.encode("utf8")).hexdigest(),16)
}

"""ctx_func 
对hash内容做预处理
"""
ctx_func_map = {
    "default": lambda x: x,
    "salt": lambda x: x +"."+ x[::-1],
    "reverse": lambda x: x[::-1],
    "premd5": lambda x: hashlib.md5(x.encode("utf8")).hexdigest()
}

"""modfunc
自定义mod函数
"""
mod_func_map = {"default":lambda y: lambda x: (x % y)}

"""source
自定义数据源
"""
source_map = {
    "default":lambda n: np.random.choice(np.arange(100000, 100000+n), n, False).astype(str)
}

"""metrics
自定义指标计算方式
"""
metrics_map = {
    # 方差
    "var": lambda x: np.var(x), 
    # 标准差
    "std": lambda x: np.std(x,ddof=1),
    "max-min": lambda x: np.max(x) -np.min(x)
}

INFO: Pandarallel will run on 8 workers.
INFO: Pandarallel will use standard multiprocessing data transfer (pipe) to transfer data between the main process and workers.


In [12]:
import plotly.express as px

class Benchmark:
    """Measure how uniformly hash pipelines spread ids across business slots.

    Every combination of (source, ctx, hash, mod, slots) is one run: ids from the
    source are preprocessed by the ctx function, hashed, mapped to a bucket by the
    mod function, and each metric is computed over the per-bucket id counts.

    Parameters
    ----------
    size : int
        Number of ids generated per source (default 20,000,000).
    slots : iterable of int
        Final business slot counts to test (default (100, 101)).
    """

    # Identifying columns of the results table; metric columns follow them.
    _KEY_COLUMNS = ("source", "hash", "ctx", "mod", "slots")
    # Plot title prefixes for the known metrics; other metrics use their own name.
    _METRIC_LABELS = {"var": "方差", "std": "标准差", "max-min": "max-min"}

    def __init__(self, size=20000000, slots=(100, 101)):
        self.current = []  # not used by this class; kept for compatibility
        self.hash_funcs = {}
        self.mod_funcs = {}
        self.ctx_funcs = {}
        self.sources = {}
        self.metrics = {}
        self.size = size
        self.slots = list(slots)
        # Filled by simulate(); None means the benchmark has not been run yet.
        self.report_table = None

    def load(self, sources, hash_fs, mod_fs, ctx_fs, slots=None, metrics=None):
        """Register the function maps to benchmark.

        sources: name -> f(n) returning n id strings.
        hash_fs: name -> f(str) returning a hash value (ints in this notebook).
        mod_fs:  name -> f(slot_count) returning f(hash_value) -> bucket.
        ctx_fs:  name -> f(str) returning the preprocessed string.
        slots:   optional slot counts to test; None keeps the current ``self.slots``.
        metrics: optional name -> f(bucket_counts); None uses the module-level ``metrics_map``.
        """
        self.sources = sources
        self.hash_funcs = hash_fs
        self.mod_funcs = mod_fs
        self.ctx_funcs = ctx_fs
        if slots is not None:
            self.slots = list(slots)
        self.metrics = metrics_map if metrics is None else metrics

    @staticmethod
    def _apply(series, func):
        """Apply ``func`` element-wise, via ``parallel_apply`` when it is available.

        ``parallel_apply`` exists once pandarallel has been initialized; otherwise
        fall back to the serial ``Series.apply`` so the class still works.
        """
        parallel = getattr(series, "parallel_apply", None)
        return parallel(func) if parallel is not None else series.apply(func)

    @staticmethod
    def _bucket_counts(buckets, slot_count):
        """Count ids per bucket, including buckets that received no ids (count 0)."""
        counts = buckets.value_counts()
        # value_counts omits empty buckets; without them a hash that piles every id
        # into a single bucket would be scored as perfectly uniform. Assumes mod
        # functions map into range(slot_count); any other labels are kept, not dropped.
        all_buckets = counts.index.union(pd.RangeIndex(slot_count))
        return counts.reindex(all_buckets, fill_value=0)

    def simulate(self):
        """Run every (source, ctx, hash, mod, slots) combination and store the metrics.

        The result is stored in ``self.report_table``: a dict of equal-length column
        lists (key columns, then one column per metric), one row per combination.
        """
        report_table = {column: [] for column in self._KEY_COLUMNS}
        for metrics_name in self.metrics:
            report_table[metrics_name] = []

        for source_name, source_func in self.sources.items():
            ids = pd.Series(source_func(self.size), name="id")
            for ctx_name, ctx_f in self.ctx_funcs.items():
                for hash_name, hash_f in self.hash_funcs.items():
                    # Preprocessing + hashing do not depend on mod/slots: compute once.
                    hashed = self._apply(self._apply(ids, ctx_f), hash_f)
                    for mod_name, mod_f in self.mod_funcs.items():
                        for slots in self.slots:
                            buckets = self._apply(hashed, mod_f(slots))
                            bcount = self._bucket_counts(buckets, slots)

                            report_table["source"].append(source_name)
                            report_table["hash"].append(hash_name)
                            report_table["ctx"].append(ctx_name)
                            report_table["mod"].append(mod_name)
                            report_table["slots"].append(slots)
                            for metrics_name, mfunc in self.metrics.items():
                                report_table[metrics_name].append(mfunc(bcount))

        self.report_table = report_table

    def report(self, path="report.csv"):
        """Write the results table to a CSV file (default ``report.csv``)."""
        self.getResult().to_csv(path)

    def getResult(self):
        """Return the results of the last ``simulate()`` as a DataFrame.

        Raises RuntimeError if ``simulate()`` has not been called.
        """
        if getattr(self, "report_table", None) is None:
            raise RuntimeError("simulate() must be called before reading results")
        return pd.DataFrame(self.report_table)

    def showReport(self):
        """Plot each metric twice: grouped by hash (coloured by ctx) and vice versa.

        Bars are faceted by slot count (rows) and source (columns).
        """
        df = self.getResult()
        metric_columns = [c for c in df.columns if c not in self._KEY_COLUMNS]
        for metric in metric_columns:
            label = self._METRIC_LABELS.get(metric, metric)
            for x_axis, color in (("hash", "ctx"), ("ctx", "hash")):
                fig = px.bar(
                    df,
                    x=x_axis,
                    y=metric,
                    color=color,
                    barmode="group",
                    facet_row="slots",
                    facet_col="source",
                    title=f"{label}情况({x_axis})",
                )
                fig.show()

# Run the benchmark end to end: simulate every combination, write report.csv, then plot.
m = Benchmark()
m.load(source_map, hash_func_map, mod_func_map, ctx_func_map)
for step in (m.simulate, m.report, m.showReport):
    step()